# -*- coding: utf-8 -*-
# Load data from NWB file

# Builtin/3rd party package imports
import os
import subprocess
import sys

import h5py
import numpy as np
from tqdm import tqdm

# Local imports
from syncopy.datatype.continuous_data import AnalogData
from syncopy.datatype.discrete_data import EventData, SpikeData
from syncopy.shared.errors import (
    SPYError,
    SPYInfo,
    SPYTypeError,
    SPYValueError,
    SPYWarning,
)
from syncopy.shared.parsers import io_parser, scalar_parser, filename_parser
from syncopy import __pynwb__

# Public API of this module
__all__ = ["load_nwb"]

# `pynwb` is optional: `load_nwb` raises an error asking to install 'pynwb'
# when the `__pynwb__` flag (imported from `syncopy`) is false, so only import
# it when the flag is set.
if __pynwb__:
    import pynwb

def _is_valid_nwb_file(filename):
        this_python = os.path.join(os.path.dirname(sys.executable), "python")[this_python, "-m", "pynwb.validate", filename], check=True)
        return True, None
    except subprocess.CalledProcessError as exc:
        err = f"NWB file validation failed. Original error message: {str(exc)}"
        return False, err

def _register_analog_series(series, warning, dTypes, tStarts, sRates, angSeries):
    """
    Queue an NWB `ElectricalSeries` for import and record its dtype and timing.

    Appends to the passed-in lists in place and returns the number of samples
    along the first axis of ``series.data``. `warning` is emitted if all
    electrodes of the series share the same location label.
    """
    channel_names = series.electrodes[:].location
    if channel_names.unique().size == 1:
        SPYWarning(warning)

    dTypes.append(series.data.dtype)
    # Channel gains are multiplied into the data on import, so their dtype
    # takes part in determining the dtype of the target dataset
    if series.channel_conversion is not None:
        dTypes.append(series.channel_conversion.dtype)

    tStarts.append(series.starting_time)
    sRates.append(series.rate)
    angSeries.append(series)
    return series.data.shape[0]


def _spike_samplerate(units, default_samplerate):
    """
    Return the samplerate to use for spike data when no analog series defines one.

    An explicit `default_samplerate` takes precedence; otherwise the first value
    of a ``samplerate`` column of the units table `units` is used.

    Raises
    ------
    SPYError
        If neither source provides a samplerate.
    """
    if default_samplerate is not None:
        return default_samplerate
    if "samplerate" in units.columns:
        return units.loc[:, "samplerate"].unique()[0]
    raise SPYError(
        "Could not read samplerate for spike data from NWB file. Please provide a samplerate manually via parameter 'default_spike_data_samplerate'."
    )


def _interval_bounds(table):
    """
    Return an ``(nIntervals, 2)`` float array of ``[start_time, stop_time]`` rows.

    `table` is the NWB ``trials`` or ``epochs`` table. Columns are selected by
    name, so additional (possibly non-numeric) columns are ignored.
    """
    intervals = table[:]
    if isinstance(intervals, np.ndarray):
        return intervals[:, 0:2].astype(float)
    return intervals[["start_time", "stop_time"]].to_numpy(dtype=float)


def _build_trialdefinition(nwbfile, hasTrials, tStarts, sRates):
    """
    Build a Syncopy trial definition from the ``trials`` (or ``epochs``) table.

    Trial bounds are converted to sample indices relative to the first
    acquisition's ``starting_time`` using its ``rate``. If the trials table
    carries an ``offset`` column, it is used for the third (offset) column.

    Raises
    ------
    SPYError
        If no timing/samplerate information is available.
    SPYValueError
        If timings are unset or differ across acquisitions.
    """
    if not tStarts or not sRates:
        raise SPYError(
            "Found acquisitions and trials but no valid timing/samplerate data in NWB file. Data in file not supported."
        )
    if any(t is None for t in tStarts) or any(r is None for r in sRates):
        lgl = "acquisition timings defined by `starting_time` and `rate`"
        act = "`starting_time` or `rate` not set"
        raise SPYValueError(lgl, varname="starting_time/rate", actual=act)
    if np.unique(tStarts).size > 1 or np.unique(sRates).size > 1:
        lgl = "acquisitions with unique `starting_time` and `rate`"
        act = "`starting_time` or `rate` different across acquisitions"
        raise SPYValueError(lgl, varname="starting_time/rate", actual=act)

    tStart, sRate = tStarts[0], sRates[0]
    bounds = _interval_bounds(nwbfile.trials if hasTrials else nwbfile.epochs)

    trl = np.zeros((bounds.shape[0], 3), dtype=np.intp)
    # Round (instead of truncating) to guard against floating-point products
    # such as 0.57 * 100 == 56.99999999999999
    trl[:, 0:2] = np.rint((bounds - tStart) * sRate)

    # The offset column is only present in some files (e.g., exported by
    # Syncopy); otherwise the default zero offset is kept.
    if hasTrials and "offset" in nwbfile.trials.colnames:
        offsets = nwbfile.trials.to_dataframe()["offset"].to_numpy(dtype=float)
        trl[:, 2] = np.rint(offsets * sRate)
    return trl


def _load_ttl_events(pulses, channels, tStart, trl, filebase, log_msg):
    """
    Create an `EventData` object from TTL pulse values and channel markers.

    If `trl` is `None`, a single trial spanning all events is used.
    """
    SPYInfo("Creating separate EventData object for embedded TTL pulse data...")
    filename = None if filebase is None else filebase + ".event"
    evtData = EventData(dimord=["sample", "eventid", "chans"], filename=filename)

    if "resolution" in channels.__nwbfields__:
        ts_resolution = channels.resolution
    else:
        ts_resolution = channels.timestamps__resolution
    # Event samples are counted in multiples of the timestamp resolution
    samples = np.rint((channels.timestamps[()] - tStart) / ts_resolution).astype(np.intp)

    h5evt = h5py.File(evtData.filename, mode="w")
    evtDset = h5evt.create_dataset("data", dtype=int, shape=(pulses.data.size, 3))
    evtDset[:, 0] = samples  # sample indices
    evtDset[:, 1] = pulses.data[()].astype(int)  # TTL pulse values
    evtDset[:, 2] = channels.data[()].astype(int)  # TTL channel markers
    evtData.data = evtDset
    evtData.samplerate = float(1 / ts_resolution)

    if trl is not None:
        # NOTE(review): `trl` is expressed in analog samples, event samples in
        # units of 1 / ts_resolution -- confirm these agree for such files.
        evtData.trialdefinition = trl
    else:
        evtData.trialdefinition = np.array([[np.min(samples), np.max(samples), 0]])
    evtData.log = log_msg
    return evtData


def _load_analog_series(series, targetDtype, samplerate, tStart, trl, filebase, memBytes, pbarDesc, log_msg):
    """
    Copy an NWB `ElectricalSeries` block-wise into a new `AnalogData` object.

    At most roughly `memBytes` bytes of samples are held in memory at a time.
    Channel gains (``channel_conversion``) are applied while copying.
    """
    filename = None if filebase is None else filebase + "_" + series.name + ".analog"
    angData = AnalogData(dimord=AnalogData._defaultDimord, filename=filename)

    nTime = series.data.shape[0]
    numDataChannels = series.data.shape[1] if series.data.ndim > 1 else 1
    angShape = [None, None]
    angShape[angData.dimord.index("time")] = nTime
    angShape[angData.dimord.index("channel")] = numDataChannels
    h5ang = h5py.File(angData.filename, mode="w")
    angDset = h5ang.create_dataset("data", dtype=targetDtype, shape=angShape)

    # If channel-specific gains are set, load them now (collapse to a scalar
    # if all channels share the same gain)
    gains = None
    if series.channel_conversion is not None:
        gains = series.channel_conversion[()]
        if np.all(gains == gains[0]):
            gains = gains[0]

    # Number of samples fitting into the memory budget; at least one sample per
    # block so that tiny budgets cannot produce empty (or zero-sized) blocks
    nSamp = max(1, int(memBytes / (numDataChannels * angDset.dtype.itemsize)))
    for start in tqdm(range(0, nTime, nSamp), desc=pbarDesc, position=1, leave=False, disable=None):
        stop = min(start + nSamp, nTime)
        block = np.asarray(series.data[start:stop])
        # A 1-D series holds a single channel: give it an explicit channel axis
        if block.ndim == 1:
            block = block[:, np.newaxis]
        if gains is not None:
            block = block * gains
        angDset[start:stop, :] = block

    angData.data = angDset
    channel_names = series.electrodes[:].location
    if channel_names.size != numDataChannels:
        SPYWarning(
            f"Found {channel_names.size} channel names for data with {numDataChannels} channels in NWB file. Discarding channel names."
        )
        angData.channel = None
    elif channel_names.unique().size == 1 and channel_names.size > 1:
        SPYWarning(
            "No unique channel names found for acquisition {}. Discarding channel names.".format(series.name)
        )
        angData.channel = None
    else:
        angData.channel = channel_names.to_list()

    angData.samplerate = samplerate
    angData.trialdefinition = trl
    angData.info = {"starting_time": tStart}
    angData.log = log_msg
    return angData


def _load_spike_data(spikes_by_unit, samplerate, tStart, trl, span_spikes, filebase, log_msg):
    """
    Create a `SpikeData` object from a ``{unit: spike_times}`` mapping.

    Spike times (in seconds) are converted to sample indices via `samplerate`.
    All spikes are mapped onto a single channel ``"channel0"``, because no
    channel information is read for spike data. If `span_spikes` is set, a
    single trial covering all spikes is used instead of `trl`.
    """
    dsetname = "nwbspike"  # TODO: Can we get a name for this somewhere in the NWB file?
    filename = None if filebase is None else filebase + "_" + dsetname + ".spike"
    spData = SpikeData(dimord=SpikeData._defaultDimord, filename=filename)

    unitIds = list(spikes_by_unit.keys())
    timesPerUnit = [np.asarray(spikes_by_unit[u], dtype=float).ravel() for u in unitIds]
    spike_times = np.concatenate(timesPerUnit)
    spike_units = np.concatenate([np.full(t.size, u) for u, t in zip(unitIds, timesPerUnit)])
    # Sort chronologically while keeping every spike paired with its unit
    order = np.argsort(spike_times, kind="stable")
    spike_times = spike_times[order]
    spike_units = spike_units[order]
    spike_channels = np.zeros(spike_times.size, dtype=np.int64)  # single channel 0

    # NOTE(review): unlike the trial definition, spike times are not shifted
    # by `tStart` -- only equivalent when `tStart == 0`; verify.
    spike_samples = np.rint(spike_times * samplerate)
    spike_data_sampleidx = np.column_stack((spike_samples, spike_channels, spike_units))
    hdf5_file = h5py.File(spData.filename, mode="w")
    spDset = hdf5_file.create_dataset("data", data=spike_data_sampleidx, dtype=np.int64)

    spData.data = spDset
    spData.channel = ["channel0"]
    spData.samplerate = samplerate
    if span_spikes and spike_samples.size > 0:
        spData.trialdefinition = np.array([[0, int(spike_samples.max()) + 1, 0]])
    else:
        spData.trialdefinition = trl
    spData.info = {"starting_time": tStart}
    spData.log = log_msg
    return spData


def _import_nwbfile(nwbfile, nwbFullName, memuse, container, default_spike_data_samplerate):
    """Convert the contents of an opened `NWBFile` into Syncopy objects (see `load_nwb`)."""
    # Allocate lists for storing temporary NWB info: IMPORTANT use lists to
    # preserve order of data chunks/channels
    nSamples = 0
    tStarts = []
    sRates = []
    dTypes = []
    angSeries = []
    ttlVals = []
    ttlChanStates = []  # recognized, but not imported
    ttlChans = []
    ttlGroups = {
        "TTL_PulseValues": ttlVals,
        "TTL_ChannelStates": ttlChanStates,
        "TTL_Channels": ttlChans,
    }

    # If the file contains `trials` or `epochs`, use them to infer trial information
    fieldNames = nwbfile.fields.keys()
    hasEpochs = "epochs" in fieldNames
    hasTrials = "trials" in fieldNames
    hasSpikedata = "units" in fieldNames
    hasTrialInfo = hasTrials or hasEpochs

    # Access LFPs in ecephys processing module, if any. Only the lookup is
    # guarded: a missing module/container simply means there is no LFP.
    try:
        lfp = nwbfile.processing["ecephys"]["LFP"]["ElectricalSeries"]
    except KeyError:
        lfp = None
    if isinstance(lfp, pynwb.ecephys.ElectricalSeries):
        nLfp = _register_analog_series(
            lfp, "No unique channel names found for LFP.", dTypes, tStarts, sRates, angSeries
        )
        nSamples = max(nSamples, nLfp)

    # Access all (supported) `acquisition` fields in the file
    for acqName, acqValue in nwbfile.acquisition.items():
        if isinstance(acqValue, pynwb.ecephys.ElectricalSeries):
            # Actual extracellular analog time-series data
            nAcq = _register_analog_series(
                acqValue,
                "No unique channel names found for {}".format(acqName),
                dTypes,
                tStarts,
                sRates,
                angSeries,
            )
            nSamples = max(nSamples, nAcq)
        elif ".TTLs" in str(acqValue.__class__):
            # TTL event pulse data
            ttlGroup = ttlGroups.get(acqValue.name)
            if ttlGroup is None:
                lgl = "TTL data exported via `esi-oephys2nwb`"
                act = "unformatted TTL data '{}'"
                raise SPYValueError(lgl, varname=acqName, actual=act.format(acqValue.description))
            ttlGroup.append(acqValue)
        else:
            lgl = "supported NWB Acquisition data class"
            raise SPYValueError(lgl, varname=acqName, actual=str(acqValue.__class__))

    # Load spike data from units; it is turned into a SpikeData object below
    spikes_by_unit = None
    if hasSpikedata:
        units = nwbfile.units.to_dataframe()
        spikes_by_unit = {n: units.loc[n, "spike_times"] for n in units.index}
        # Without analog series there is no samplerate yet (e.g., files with
        # only spike data): use the user-provided value or the units table
        if not sRates:
            sRates.append(_spike_samplerate(units, default_spike_data_samplerate))
            tStarts.append(0.0)

    if hasTrialInfo:
        trl = _build_trialdefinition(nwbfile, hasTrials, tStarts, sRates)
    else:
        # No trial information found: proceed with single all-encompassing trial
        trl = np.array([[0, nSamples, 0]])

    log_msg = "Read data from NWB file {}".format(nwbFullName)

    # Check for container name
    filebase = None
    if container is not None:
        if not isinstance(container, str):
            raise SPYTypeError(container, varname="container", expected="str")
        if os.path.splitext(container)[1] != ".spy":
            container += ".spy"
        os.makedirs(container, exist_ok=True)
        fileInfo = filename_parser(container)
        filebase = os.path.join(fileInfo["folder"], fileInfo["container"], fileInfo["basename"])

    # If TTL data was found, ensure we have exactly one set of values and
    # associated channel markers
    if len(ttlVals) != len(ttlChans):
        lgl = "TTL pulse values and channel markers"
        act = "pulses: {}, channels: {}".format(str(ttlVals), str(ttlChans))
        raise SPYValueError(lgl, varname=(ttlVals or ttlChans)[0].name, actual=act)
    if len(ttlVals) > 1:
        lgl = "one set of TTL pulses"
        act = "{} TTL data sets".format(len(ttlVals))
        raise SPYValueError(lgl, varname=ttlVals[0].name, actual=act)

    objectDict = {}
    if ttlVals:
        evtData = _load_ttl_events(
            ttlVals[0],
            ttlChans[0],
            tStart=tStarts[0] if tStarts else 0.0,
            trl=trl if hasTrials else None,
            filebase=filebase,
            log_msg=log_msg,
        )
        objectDict[os.path.basename(evtData.filename)] = evtData

    # Process analog time series data block by block; `memuse` is in MB
    pbarDesc = "Reading data in blocks of {} GB".format(round(memuse / 1000, 2))
    memBytes = memuse * 1024**2
    targetDtype = np.result_type(*dTypes) if dTypes else None
    pbar = tqdm(angSeries, position=0, disable=None)
    for series in pbar:
        pbar.set_description("Loading {} from disk".format(series.name))
        angData = _load_analog_series(
            series, targetDtype, sRates[0], tStarts[0], trl, filebase, memBytes, pbarDesc, log_msg
        )
        objectDict[os.path.basename(angData.filename)] = angData

    if spikes_by_unit:
        spData = _load_spike_data(
            spikes_by_unit,
            samplerate=sRates[0],
            tStart=tStarts[0],
            trl=trl,
            # Without trials and analog data `trl` would be empty: span the spikes
            span_spikes=not hasTrialInfo and nSamples == 0,
            filebase=filebase,
            log_msg=log_msg,
        )
        objectDict[os.path.basename(spData.filename)] = spData

    return objectDict


def load_nwb(
    filename,
    memuse=3000,
    container=None,
    validate=False,
    default_spike_data_samplerate=None,
):
    """
    Read contents of NWB files

    Parameters
    ----------
    filename : str
        Name of (may include full path to) NWB file (e.g., `"/path/to/mydata.nwb"`).
    memuse : scalar
        Approximate in-memory cache size (in MB) for reading data from disk
    container : str
        Name of syncopy container folder to create the syncopy data in. A
        ``.spy`` extension is appended if missing. If `None`, no container
        folder is created.
    validate : bool
        If `True`, run ``pynwb.validate`` on the file (in a separate Python
        process) before reading it and raise an error if validation fails.
    default_spike_data_samplerate : float, optional
        The samplerate for spike data, in Hz. Only used if no analog series in
        the file provides a samplerate. In that case, if it is not provided,
        the samplerate is read from a ``samplerate`` column of the units table.
        This is not guaranteed to work, as some NWB files which contain only
        spike data do not store a samplerate. If this is `None` and no
        samplerate is found in the file, this function will raise an error, and
        you will have to provide the samplerate manually.

    Returns
    -------
    objdict : dict
        Any NWB `TimeSeries`-like data is imported into an :class:`~syncopy.AnalogData`
        object. If the NWB file contains TTL pulse data, an additional
        :class:`~syncopy.EventData` object is instantiated. Spike data (NWB
        ``units``) is imported into a :class:`~syncopy.SpikeData` object. The
        syncopy objects are returned as a dictionary whose keys are the
        base-names (sans path) of the corresponding files.

    Raises
    ------
    SPYError
        If `pynwb` is unavailable, validation fails or timing information is missing.
    SPYValueError
        If the file contains unsupported or inconsistent data.
    SPYTypeError
        If `container` is not a string.
    """
    if not __pynwb__:
        raise SPYError("NWB support is not available. Please install the 'pynwb' package.")

    # Check if file exists
    nwbPath, nwbBaseName = io_parser(filename, varname="filename", isfile=True, exists=True)
    nwbFullName = os.path.join(nwbPath, nwbBaseName)

    # Ensure `memuse` makes sense
    scalar_parser(memuse, varname="memuse", lims=[0, np.inf])

    # First, perform some basal validation w/NWB if requested.
    if validate:
        is_valid, err = _is_valid_nwb_file(nwbFullName)
        if not is_valid:
            raise SPYError(err)

    # Load NWB meta data from disk; the file handle is released even if the
    # import fails part-way
    nwbio = pynwb.NWBHDF5IO(nwbFullName, "r", load_namespaces=True)
    try:
        return _import_nwbfile(
            nwbio.read(), nwbFullName, memuse, container, default_spike_data_samplerate
        )
    finally:
        nwbio.close()