Source code for syncopy.io.load_tdt

# -*- coding: utf-8 -*-
# @Author: Diljit Singh Kajal
# @Date:   2022-04-08 15:00:00
#
# load_tdt.py Merge separate TDT SEV files into one HDF5 file
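#
# A minimal usage sketch (mirrors the examples in the `load_tdt` docstring
# below; the session path is only a placeholder):
#
#     from syncopy.io.load_tdt import load_tdt
#     adata = load_tdt("/data/session3", subtract_median=True)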

import os
from datetime import datetime
import re
import numpy as np
from tqdm.auto import tqdm
import h5py

# Local imports
from syncopy.shared.parsers import io_parser, scalar_parser
from syncopy.shared.errors import SPYWarning, SPYValueError
from syncopy.shared.tools import StructDict
import syncopy as spy


# --- The user exposed function ---


def load_tdt(data_path, start_code=None, end_code=None, subtract_median=False):
    """
    Imports TDT time series data and meta-information into a single
    :class:`~syncopy.AnalogData` object.

    An ad-hoc trialdefinition will be attached if both `start_code` and
    `end_code` are given. Otherwise a single all-to-all trialdefinition is used.
    Custom trialdefinitions can be done afterwards with :func:`~syncopy.definetrial`.

    All meta-information is stored within the `.info` dict of the
    :class:`~syncopy.AnalogData` object.

    PDio related keys: ("PDio_onset", "PDio_offset", "PDio_data")

    Trigger related keys: ("Trigger_code", "Trigger_timestamp", "Trigger_sample")

    Parameters
    ----------
    data_path : str
        Path to the directory containing the `.sev` files
    start_code : int or None, optional
        Trigger code defining the beginning of a trial
    end_code : int or None, optional
        Trigger code defining the end of a trial
    subtract_median : bool
        Set to `True` to subtract the median from all individual time series

    Returns
    -------
    adata : :class:`~syncopy.AnalogData`
        The TDT data in syncopy format

    Examples
    --------
    Load all channels from a source directory `/data/session3/`:

    >>> adata = load_tdt('/data/session3')

    Access the trigger codes and samples:

    >>> trg_code = adata.info['Trigger_code']
    >>> trg_sample = adata.info['Trigger_sample']

    Load the same data and construct a trialdefinition on the fly:

    >>> adata = load_tdt('/data/session3', start_code=23000, end_code=30020)

    Access the 3rd trial:

    >>> trl_dat = adata.trials[2]
    """

    io_parser(data_path, isfile=False)

    if start_code is not None and end_code is None:
        lgl = "trigger codes for both trial start and end"
        raise SPYValueError(lgl, "end_code", end_code)
    if end_code is not None and start_code is None:
        lgl = "trigger codes for both trial start and end"
        raise SPYValueError(lgl, "start_code", start_code)
    elif end_code is None and start_code is None:
        pass
    # both are given
    else:
        scalar_parser(start_code, "start_code", ntype="int_like")
        scalar_parser(end_code, "end_code", ntype="int_like")

    # initialize tdt info loader class
    TDT_Load_Info = ESI_TDTinfo(data_path)
    # this is a StructDict
    tdt_info = TDT_Load_Info.load_tdt_info()

    # nicely sorted by channel names
    file_paths = _get_source_paths(data_path, ".sev")

    tdt_data_handler = ESI_TDTdata(data_path, subtract_median=subtract_median, channels=None)
    adata = tdt_data_handler.data_aranging(file_paths, tdt_info)

    # we have to open for reading again
    adata.data = h5py.File(adata.filename, "r")["data"]

    # Write log-entry
    msg = f"loaded TDT data from {len(file_paths)} files\n"
    msg += f"\tsource folder: {data_path}\n"
    msg += f"\tsubtract median: {subtract_median}"
    adata.log = msg

    if start_code is not None:
        adata.trialdefinition = _mk_trialdef(adata, start_code, end_code)
        msg = f"created trialdefinition from start code `{start_code}` and end code `{end_code}`"
        adata.log = msg

    return adata
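
# Example: building a custom trialdefinition from the stored trigger info.
# This is a minimal sketch using the `.info` keys documented above; the trigger
# codes 23000 and 30020 are purely illustrative, and equal numbers of start and
# end events are assumed.
#
#     >>> adata = load_tdt('/data/session3')
#     >>> codes = np.array(adata.info['Trigger_code'], dtype=int)
#     >>> samples = np.array(adata.info['Trigger_sample'], dtype=int)
#     >>> starts = samples[codes == 23000]
#     >>> ends = samples[codes == 30020]
#     >>> trldef = np.column_stack([starts, ends, np.zeros(starts.size)])
#     >>> adata.trialdefinition = trldef
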
class ESI_TDTinfo:
    def __init__(self, block_path):
        self.block_path = block_path
        self.nodata = False
        self.t1 = 0
        self.t2 = 0
        self.UNKNOWN = int("00000000", 16)
        self.STRON = int("00000101", 16)
        self.STROFF = int("00000102", 16)
        self.SCALAR = int("00000201", 16)
        self.STREAM = int("00008101", 16)
        self.SNIP = int("00008201", 16)
        self.MARK = int("00008801", 16)
        self.HASDATA = int("00008000", 16)
        self.UCF = int("00000010", 16)
        self.PHANTOM = int("00000020", 16)
        self.MASK = int("0000FF0F", 16)
        self.INVALID_MASK = int("FFFF0000", 16)
        self.STARTBLOCK = int("0001", 16)
        self.STOPBLOCK = int("0002", 16)
        self.ALLOWED_FORMATS = [
            np.float32,
            np.int32,
            np.int16,
            np.int8,
            np.float64,
            np.int64,
        ]
        self.ALLOWED_EVTYPES = ["all", "epocs", "snips", "streams", "scalars"]

    def code_to_type(self, code):
        # given event code, return string 'epocs', 'snips', 'streams', or 'scalars'
        strobe_types = [self.STRON, self.STROFF, self.MARK]
        scalar_types = [self.SCALAR]
        snip_types = [self.SNIP]
        if code in strobe_types:
            s = "epocs"
        elif code in snip_types:
            s = "snips"
        elif code & self.MASK == self.STREAM:
            s = "streams"
        elif code in scalar_types:
            s = "scalars"
        else:
            s = "unknown"
        return s

    def time2sample(self, ts, fs=195312.5, t1=False, t2=False, to_time=False):
        sample = ts * fs
        if t2:
            # drop precision beyond 1e-9
            exact = np.round(sample * 1e9) / 1e9
            sample = np.floor(sample)
            if exact == sample:
                sample -= 1
        else:
            sample = np.ceil(sample) if t1 else np.round(sample)
        sample = np.uint64(sample)
        if to_time:
            return np.float64(sample) / fs
        return sample

    def check_ucf(self, code):
        # given event code, check if it has unique channel files
        return code & self.UCF == self.UCF

    def epoc_to_type(self, code):
        # given epoc event code, return if it is 'onset' or 'offset' event
        strobe_on_types = [self.STRON, self.MARK]
        strobe_off_types = [self.STROFF]
        if code in strobe_on_types:
            return "onset"
        elif code in strobe_off_types:
            return "offset"
        return "unknown"

    def code_to_name(self, code):
        return int(code).to_bytes(4, byteorder="little").decode("cp437")

    def load_tdt_info(self):
        header = StructDict()
        data = StructDict()
        data.epocs = StructDict()
        data.streams = StructDict()
        data.scalars = StructDict()
        data.info = StructDict()

        epocs = StructDict()
        epocs.name = []
        epocs.buddies = []
        epocs.ts = []
        epocs.code = []
        epocs.type = []
        epocs.type_str = []
        epocs.data = []
        epocs.dform = []

        tsq_list = _get_source_paths(self.block_path, ".tsq")
        if len(tsq_list) > 1:
            raise Exception("multiple TSQ files found\n{0}".format(", ".join(tsq_list)))

        tsq = open(tsq_list[0], "rb")
        tsq.seek(0, os.SEEK_SET)
        tsq.seek(48, os.SEEK_SET)
        code1 = np.fromfile(tsq, dtype=np.int32, count=1)
        assert code1 == self.STARTBLOCK, "Block start marker not found"
        tsq.seek(56, os.SEEK_SET)
        header.start_time = np.fromfile(tsq, dtype=np.float64, count=1)

        # read stop time
        tsq.seek(-32, os.SEEK_END)
        code2 = np.fromfile(tsq, dtype=np.int32, count=1)
        if code2 != self.STOPBLOCK:
            SPYWarning(
                "Block end marker not found, block did not end cleanly. "
                "Try setting T2 smaller if errors occur"
            )
            header.stop_time = np.nan
        else:
            tsq.seek(-24, os.SEEK_END)
            header.stop_time = np.fromfile(tsq, dtype=np.float64, count=1)

        [data.info.tankpath, data.info.blockname] = os.path.split(os.path.normpath(self.block_path))
        data.info.start_date = datetime.fromtimestamp(header.start_time[0])
        if not np.isnan(header.start_time):
            data.info.utc_start_time = data.info.start_date.strftime("%H:%M:%S")
        else:
            data.info.utc_start_time = np.nan
        if not np.isnan(header.stop_time):
            data.info.stop_date = datetime.fromtimestamp(header.stop_time[0])
            data.info.utc_stop_time = data.info.stop_date.strftime("%H:%M:%S")
        else:
            data.info.stop_date = np.nan
            data.info.utc_stop_time = np.nan
        if header.stop_time > 0:
            data.info.duration = data.info.stop_date - data.info.start_date  # datestr(s2-s1, 'HH:MM:SS')

        # read the event headers from the TSQ index file in chunks
        tsq.seek(40, os.SEEK_SET)
        read_size = 10000000 if self.t2 > 0 else 50000000
        header.stores = StructDict()
        while True:
            heads = np.frombuffer(tsq.read(read_size * 4), dtype=np.uint32)
            rem = len(heads) % 10
            if rem != 0:
                SPYWarning("Block did not end cleanly, removing last {0} headers".format(rem))
                heads = heads[:-rem]
            # reshape so each column is one header
            heads = heads.reshape((-1, 10)).T

            # check the codes first and build store maps and note arrays
            codes = heads[2, :]
            good_codes = codes > 0
            bad_codes = np.logical_not(good_codes)
            if np.sum(bad_codes) > 0:
                SPYWarning(
                    "Bad TSQ headers were written, removing {0}, keeping {1} headers".format(
                        sum(bad_codes), sum(good_codes)
                    )
                )
                heads = heads[:, good_codes]
                codes = heads[2, :]

            # get set of codes but preserve order in the block
            store_codes = []
            unique_codes, unique_ind = np.unique(codes, return_index=True)
            for counter, x in enumerate(unique_codes):
                store_codes.append(
                    {
                        "code": x,
                        "type": heads[1, unique_ind[counter]],
                        "type_str": self.code_to_type(heads[1, unique_ind[counter]]),
                        "ucf": self.check_ucf(heads[1, unique_ind[counter]]),
                        "epoc_type": self.epoc_to_type(heads[1, unique_ind[counter]]),
                        "dform": heads[8, unique_ind[counter]],
                        "size": heads[0, unique_ind[counter]],
                        "buddy": heads[3, unique_ind[counter]],
                        "temp": heads[:, unique_ind[counter]],
                    }
                )

            # looking only for the Mark, PDio/PDi\ and LFP stores
            looking_for = ["Mark", "PDio", "LFPs", "LFP1", "PDi\\"]

            targets = StructDict()
            for chk, content in enumerate(store_codes):
                if self.code_to_name(content["code"]) in looking_for:
                    targets[self.code_to_name(content["code"])] = chk

            for tar in targets.items():
                store_code = store_codes[tar[1]]
                store_code["name"] = self.code_to_name(store_code["code"])
                # print(code_to_name(store_code['code']), store_code['code'])
                store_code["var_name"] = store_code["name"]
                var_name = store_code["var_name"]
                if store_code["type_str"] == "epocs":
                    if store_code["name"] not in epocs.name:
                        buddy = "".join([str(chr(c)) for c in np.array([store_code["buddy"]]).view(np.uint8)])
                        buddy = buddy.replace("\x00", " ")
                        # if skip_by_name:
                        #     if store:
                        #         if isinstance(store, str):
                        #             if buddy == store:
                        #                 skip_by_name = False
                        #         elif isinstance(store, list):
                        #             if buddy in store:
                        #                 skip_by_name = False
                        # if skip_by_name:
                        #     continue
                        epocs.name.append(store_code["name"])
                        epocs.buddies.append(buddy)
                        epocs.code.append(store_code["code"])
                        epocs.ts.append([])
                        epocs.type.append(store_code["epoc_type"])
                        epocs.type_str.append(store_code["type_str"])
                        epocs.data.append([])
                        epocs.dform.append(store_code["dform"])

                if var_name not in header.stores.keys():
                    if store_code["type_str"] != "epocs":
                        header.stores[var_name] = StructDict(
                            name=store_code["name"],
                            code=store_code["code"],
                            size=store_code["size"],
                            type=store_code["type"],
                            type_str=store_code["type_str"],
                        )
                        if header.stores[var_name].type_str == "streams":
                            header.stores[var_name].ucf = store_code["ucf"]
                        if header.stores[var_name].type_str != "scalars":
                            # finding the sampling rate
                            header.stores[var_name].fs = np.double(
                                np.array([store_code["temp"][9]]).view(np.float32)
                            )
                        header.stores[var_name].dform = store_code["dform"]

                valid_ind = np.where(codes == store_code["code"])[0]
                temp = heads[3, valid_ind].view(np.uint16)
                if store_code["type_str"] != "epocs":
                    if not hasattr(header.stores[var_name], "ts"):
                        header.stores[var_name].ts = []
                    vvv = (
                        np.reshape(heads[[[4], [5]], valid_ind].T, (-1, 1)).T.view(np.float64)
                        - header.start_time
                    )
                    # round timestamps to the nearest sample
                    vvv = self.time2sample(vvv, to_time=True)
                    header.stores[var_name].ts.append(vvv)
                    if (not self.nodata) or (store_code["type_str"] == "streams"):
                        if not hasattr(header.stores[var_name], "data"):
                            header.stores[var_name].data = []
                        header.stores[var_name].data.append(
                            np.reshape(heads[[[6], [7]], valid_ind].T, (-1, 1)).T.view(np.float64)
                        )
                    if not hasattr(header.stores[var_name], "chan"):
                        header.stores[var_name].chan = []
                    header.stores[var_name].chan.append(temp[::2])
                else:
                    loc = epocs.name.index(store_code["name"])
                    vvv = (
                        np.reshape(heads[[[4], [5]], valid_ind].T, (-1, 1)).T.view(np.float64)
                        - header.start_time
                    )
                    # round timestamps to the nearest sample
                    vvv = self.time2sample(vvv, to_time=True)
                    epocs.ts[loc] = np.append(epocs.ts[loc], vvv)
                    epocs.data[loc] = np.append(
                        epocs.data[loc],
                        np.reshape(heads[[[6], [7]], valid_ind].T, (-1, 1)).T.view(np.float64),
                    )

            last_ts = heads[[4, 5], -1].view(np.float64) - header.start_time
            last_ts = last_ts[0]
            if self.t2 > 0 and last_ts > self.t2:
                break

            # eof reached
            if heads.size < read_size:
                break

        print(
            "Reading data from t = {0}s to t = {1}s".format(
                np.round(self.t1, 2), np.round(np.maximum(last_ts, self.t2), 2)
            )
        )

        for ii in range(len(epocs.name)):
            # find all non-buddies first
            if epocs.type[ii] == "onset":
                var_name = epocs.name[ii]
                header.stores[var_name] = StructDict()
                header.stores[var_name].name = epocs.name[ii]
                ts = epocs.ts[ii]
                header.stores[var_name].onset = ts
                header.stores[var_name].offset = np.append(ts[1:], np.inf)
                header.stores[var_name].type = epocs.type[ii]
                header.stores[var_name].type_str = epocs.type_str[ii]
                header.stores[var_name].data = epocs.data[ii]
                header.stores[var_name].dform = epocs.dform[ii]
                header.stores[var_name].size = 10

        for ii in range(len(epocs.name)):
            if epocs.type[ii] == "offset":
                var_name = epocs.buddies[ii]
                if var_name not in header.stores.keys():
                    SPYWarning(epocs.buddies[ii] + " buddy epoc not found, skipping")
                    continue
                header.stores[var_name].offset = epocs.ts[ii]
                # handle odd case where there is a single offset event and no onset events
                if "onset" not in header.stores[var_name].keys():
                    header.stores[var_name].name = epocs.buddies[ii]
                    header.stores[var_name].onset = 0
                    header.stores[var_name].type_str = "epocs"
                    header.stores[var_name].type = "onset"
                    header.stores[var_name].data = 0
                    header.stores[var_name].dform = 4
                    header.stores[var_name].size = 10
                # fix time ranges
                if header.stores[var_name].offset[0] < header.stores[var_name].onset[0]:
                    header.stores[var_name].onset = np.append(0, header.stores[var_name].onset)
                    header.stores[var_name].data = np.append(
                        header.stores[var_name].data[0], header.stores[var_name].data
                    )
                if header.stores[var_name].onset[-1] > header.stores[var_name].offset[-1]:
                    header.stores[var_name].offset = np.append(header.stores[var_name].offset, np.inf)

        for var_name in header.stores.keys():
            # convert cell arrays to regular arrays
            if "ts" in header.stores[var_name].keys():
                header.stores[var_name].ts = np.concatenate(header.stores[var_name].ts, axis=1)[0]
            if "chan" in header.stores[var_name].keys():
                header.stores[var_name].chan = np.concatenate(header.stores[var_name].chan)
            if "sortcode" in header.stores[var_name].keys():
                header.stores[var_name].sortcode = np.concatenate(header.stores[var_name].sortcode)
            if "data" in header.stores[var_name].keys():
                if header.stores[var_name].type_str != "epocs":
                    header.stores[var_name].data = np.concatenate(header.stores[var_name].data, axis=1)[0]

            # if it's a data type, cast as a file offset pointer instead of data
            if header.stores[var_name].type_str in ["streams", "snips"]:
                if "data" in header.stores[var_name].keys():
                    header.stores[var_name].data = header.stores[var_name].data.view(np.uint64)

            if "chan" in header.stores[var_name].keys():
                if np.max(header.stores[var_name].chan) == 1:
                    header.stores[var_name].chan = [1]

        valid_time_range = (
            np.array([[self.t1], [self.t2]]) if self.t2 > 0 else np.array([[self.t1], [np.inf]])
        )
        ranges = None
        if hasattr(ranges, "__len__"):
            valid_time_range = ranges
        num_ranges = valid_time_range.shape[1]
        if num_ranges > 0:
            data.time_ranges = valid_time_range

        for var_name in header.stores.keys():
            current_type_str = header.stores[var_name].type_str
            data[current_type_str][var_name] = header.stores[var_name]
            firstStart = valid_time_range[0, 0]
            last_stop = valid_time_range[1, -1]

            if "ts" in header.stores[var_name].keys():
                if current_type_str == "streams":
                    data[current_type_str][var_name].start_time = [0 for jj in range(num_ranges)]
                else:
                    this_dtype = data[current_type_str][var_name].ts.dtype
                    data[current_type_str][var_name].filtered_ts = [
                        np.array([], dtype=this_dtype) for jj in range(num_ranges)
                    ]
                if hasattr(data[current_type_str][var_name], "chan"):
                    data[current_type_str][var_name].filtered_chan = [[] for jj in range(num_ranges)]
                if hasattr(data[current_type_str][var_name], "sortcode"):
                    this_dtype = data[current_type_str][var_name].sortcode.dtype
                    data[current_type_str][var_name].filtered_sort_code = [
                        np.array([], dtype=this_dtype) for jj in range(num_ranges)
                    ]
                if hasattr(data[current_type_str][var_name], "data"):
                    this_dtype = data[current_type_str][var_name].data.dtype
                    data[current_type_str][var_name].filtered_data = [
                        np.array([], dtype=this_dtype) for jj in range(num_ranges)
                    ]

                filter_ind = [[] for i in range(num_ranges)]
                for jj in range(num_ranges):
                    start = valid_time_range[0, jj]
                    stop = valid_time_range[1, jj]
                    ind1 = data[current_type_str][var_name].ts >= start
                    ind2 = data[current_type_str][var_name].ts < stop
                    filter_ind[jj] = np.where(ind1 & ind2)[0]
                    bSkip = 0
                    if len(filter_ind[jj]) == 0:
                        # if it's a stream and a short window, we might have missed it
                        if current_type_str == "streams":
                            ind2 = np.where(ind2)[0]
                            if len(ind2) > 0:
                                ind2 = ind2[-1]
                                # keep one prior for streams (for all channels)
                                nchan = max(data[current_type_str][var_name].chan)
                                if ind2 - nchan >= -1:
                                    filter_ind[jj] = ind2 - np.arange(nchan - 1, -1, -1)
                                    temp = data[current_type_str][var_name].ts[filter_ind[jj]]
                                    data[current_type_str][var_name].start_time[jj] = temp[0]
                                    bSkip = 1
                    if len(filter_ind[jj]) > 0:
                        # parse out the information we need
                        if current_type_str == "streams":
                            # keep one prior for streams (for all channels)
                            if not bSkip:
                                nchan = max(data[current_type_str][var_name].chan)
                                temp = filter_ind[jj]
                                if temp[0] - nchan > -1:
                                    filter_ind[jj] = np.concatenate(
                                        [-np.arange(nchan, 0, -1) + temp[0], filter_ind[jj]]
                                    )
                                temp = data[current_type_str][var_name].ts[filter_ind[jj]]
                                data[current_type_str][var_name].start_time[jj] = temp[0]
                        else:
                            data[current_type_str][var_name].filtered_ts[jj] = data[current_type_str][var_name].ts[filter_ind[jj]]
                        if hasattr(data[current_type_str][var_name], "chan"):
                            if len(data[current_type_str][var_name].chan) > 1:
                                data[current_type_str][var_name].filtered_chan[jj] = data[current_type_str][var_name].chan[filter_ind[jj]]
                            else:
                                data[current_type_str][var_name].filtered_chan[jj] = data[current_type_str][var_name].chan
                        if hasattr(data[current_type_str][var_name], "sortcode"):
                            data[current_type_str][var_name].filtered_sort_code[jj] = data[current_type_str][var_name].sortcode[filter_ind[jj]]
                        if hasattr(data[current_type_str][var_name], "data"):
                            data[current_type_str][var_name].filtered_data[jj] = data[current_type_str][var_name].data[filter_ind[jj]]

                if current_type_str == "streams":
                    delattr(data[current_type_str][var_name], "ts")
                    delattr(data[current_type_str][var_name], "data")
                    delattr(data[current_type_str][var_name], "chan")
                    if not hasattr(data[current_type_str][var_name], "filtered_chan"):
                        data[current_type_str][var_name].filtered_chan = [[] for i in range(num_ranges)]
                    if not hasattr(data[current_type_str][var_name], "filtered_data"):
                        data[current_type_str][var_name].filtered_data = [[] for i in range(num_ranges)]
                    if not hasattr(data[current_type_str][var_name], "start_time"):
                        data[current_type_str][var_name].start_time = -1
                else:
                    # consolidate other fields
                    if hasattr(data[current_type_str][var_name], "filtered_ts"):
                        data[current_type_str][var_name].ts = np.concatenate(
                            data[current_type_str][var_name].filtered_ts
                        )
                        delattr(data[current_type_str][var_name], "filtered_ts")
                    else:
                        data[current_type_str][var_name].ts = []
                    if hasattr(data[current_type_str][var_name], "chan"):
                        if hasattr(data[current_type_str][var_name], "filtered_chan"):
                            data[current_type_str][var_name].chan = np.concatenate(
                                data[current_type_str][var_name].filtered_chan
                            )
                            delattr(data[current_type_str][var_name], "filtered_chan")
                            if current_type_str == "snips":
                                if len(set(data[current_type_str][var_name].chan)) == 1:
                                    data[current_type_str][var_name].chan = [
                                        data[current_type_str][var_name].chan[0]
                                    ]
                        else:
                            data[current_type_str][var_name].chan = []
                    if hasattr(data[current_type_str][var_name], "sortcode"):
                        if hasattr(data[current_type_str][var_name], "filtered_sort_code"):
                            data[current_type_str][var_name].sortcode = np.concatenate(
                                data[current_type_str][var_name].filtered_sort_code
                            )
                            delattr(data[current_type_str][var_name], "filtered_sort_code")
                        else:
                            data[current_type_str][var_name].sortcode = []
                    if hasattr(data[current_type_str][var_name], "data"):
                        if hasattr(data[current_type_str][var_name], "filtered_data"):
                            data[current_type_str][var_name].data = np.concatenate(
                                data[current_type_str][var_name].filtered_data
                            )
                            delattr(data[current_type_str][var_name], "filtered_data")
                        else:
                            data[current_type_str][var_name].data = []
            else:
                # handle epoc events
                filter_ind = []
                for jj in range(num_ranges):
                    start = valid_time_range[0, jj]
                    stop = valid_time_range[1, jj]
                    ind1 = data[current_type_str][var_name].onset >= start
                    ind2 = data[current_type_str][var_name].onset < stop
                    filter_ind.append(np.where(ind1 & ind2)[0])
                filter_ind = np.concatenate(filter_ind)
                if len(filter_ind) > 0:
                    data[current_type_str][var_name].onset = data[current_type_str][var_name].onset[filter_ind]
                    data[current_type_str][var_name].data = data[current_type_str][var_name].data[filter_ind]
                    data[current_type_str][var_name].offset = data[current_type_str][var_name].offset[filter_ind]
                    if data[current_type_str][var_name].offset[0] < data[current_type_str][var_name].onset[0]:
                        if data[current_type_str][var_name].onset[0] > firstStart:
                            data[current_type_str][var_name].onset = np.concatenate(
                                [[firstStart], data[current_type_str][var_name].onset]
                            )
                    if data[current_type_str][var_name].offset[-1] > last_stop:
                        data[current_type_str][var_name].offset[-1] = last_stop
                else:
                    # default case is no valid events for this store
                    data[current_type_str][var_name].onset = []
                    data[current_type_str][var_name].data = []
                    data[current_type_str][var_name].offset = []
                    if var_name == "Note":
                        data[current_type_str][var_name].notes = []

        for current_name in header.stores.keys():
            current_type_str = header.stores[current_name].type_str
            # if current_type_str not in evtype:
            #     continue
            current_type_str = data[current_type_str][current_name].type_str
            current_data_format = self.ALLOWED_FORMATS[data[current_type_str][current_name].dform]
            if current_type_str == "scalars":
                if len(data[current_type_str][current_name].chan) > 0:
                    nchan = int(np.max(data[current_type_str][current_name].chan))
                else:
                    nchan = 0
                if nchan > 1:
                    # organize data by sample
                    # find channels with most and least amount of data
                    ind = []
                    min_length = np.inf
                    max_length = 0
                    for xx in range(nchan):
                        ind.append(np.where(data[current_type_str][current_name].chan == xx + 1)[0])
                        min_length = min(len(ind[-1]), min_length)
                        max_length = max(len(ind[-1]), max_length)
                    if min_length != max_length:
                        SPYWarning(
                            "Truncating store {0} to {1} values (from {2})".format(
                                current_name, min_length, max_length
                            )
                        )
                        ind = [ind[xx][:min_length] for xx in range(nchan)]
                    if not self.nodata:
                        data[current_type_str][current_name].data = (
                            data[current_type_str][current_name].data[np.concatenate(ind)].reshape(nchan, -1)
                        )
                    # only use timestamps from first channel
                    data[current_type_str][current_name].ts = data[current_type_str][current_name].ts[ind[0]]
                    # remove channels field
                    delattr(data[current_type_str][current_name], "chan")

        tsq.close()
        del epocs
        del header

        # collect only the stores needed downstream
        Data = StructDict()
        Data.PDio = data.epocs.PDio
        try:
            Data.LFPs = data.streams.LFPs
        except Exception:
            Data.LFPs = data.streams.LFP1
        Data.Mark = data.scalars.Mark
        Data.info = data.info

        return Data


class ESI_TDTdata:
    def __init__(
        self,
        inputdir,
        subtract_median=False,
        channels=None,
    ):
        self.inputdir = inputdir
        self.chan_in_chunks = 16
        self.subtract_median = subtract_median
        self.channels = "all" if channels is None else channels

    def arrange_header(self, DataInfo_loaded, Files):
        header = StructDict()
        header["fs"] = DataInfo_loaded.LFPs.fs
        header["total_num_channel"] = len(Files)
        return header

    def read_data(self, filename):
        """Read data from a TDT SEV file created by the RS4 streamer"""
        HEADERSIZE = 40
        with open(filename, "rb") as f:
            f.seek(HEADERSIZE)
            data = np.fromfile(f, dtype="single")
        return data

    def md5sum(self, filename):
        from hashlib import md5

        hash = md5()
        with open(filename, "rb") as f:
            for chunk in iter(lambda: f.read(128 * hash.block_size), b""):
                hash.update(chunk)
        return hash.hexdigest()

    def data_aranging(self, Files, DataInfo_loaded):
        AData = spy.AnalogData(dimord=["time", "channel"])
        hdf_out_path = AData.filename
        # length of the data is always set to the length of the first channel
        LenOfData = self.read_data(Files[0]).shape[0]

        with h5py.File(hdf_out_path, "w") as combined_data_file:
            idxStartStop = [
                np.clip(
                    np.array((jj, jj + self.chan_in_chunks)),
                    a_min=None,
                    a_max=len(Files),
                )
                for jj in range(0, len(Files), self.chan_in_chunks)
            ]
            print(
                "Merging {0} files in {1} chunks each with {2} channels into \n {3}".format(
                    len(Files), len(idxStartStop), self.chan_in_chunks, hdf_out_path
                )
            )
            for (start, stop) in tqdm(iterable=idxStartStop, desc="chunk", unit="chunk", disable=None):
                data = [self.read_data(Files[jj])[:LenOfData] for jj in range(start, stop)]
                data = np.vstack(data).T
                if start == 0:
                    # this is the actual dataset for the AnalogData
                    target = combined_data_file.create_dataset(
                        "data", shape=(data.shape[0], len(Files)), dtype="single"
                    )
                if self.subtract_median:
                    data -= np.median(data, keepdims=True).astype(data.dtype)
                target[:, start:stop] = data

            # link dataset to AnalogData instance
            AData.data = target
            # temporary fix to get at least all-to-all
            AData.trialdefinition = None

            chanlist = (
                None
                if self.channels == "all"
                else ["channel" + str(trch + 1).zfill(3) for trch in self.channels]
            )
            AData.samplerate = DataInfo_loaded.LFPs.fs
            AData.channel = chanlist

            # helper to make serializable
            def serial(arr):
                return arr.tolist()

            # write info file
            AData.info["originalFiles"] = (Files,)
            # AData.info["md5sum"] = self.md5sum(hdf_out_path)
            AData.info["blockname"] = DataInfo_loaded.info.blockname
            AData.info["start_date"] = str(DataInfo_loaded.info.start_date)
            AData.info["utc_start_time"] = DataInfo_loaded.info.utc_start_time
            AData.info["stop_date"] = str(DataInfo_loaded.info.stop_date)
            AData.info["utc_stop_time"] = DataInfo_loaded.info.utc_stop_time
            AData.info["duration"] = str(DataInfo_loaded.info.duration)

            AData.info["PDio_onset"] = serial(DataInfo_loaded.PDio.onset)
            AData.info["PDio_offset"] = serial(DataInfo_loaded.PDio.offset)
            AData.info["PDio_data"] = serial(DataInfo_loaded.PDio.data)

            AData.info["Trigger_timestamp"] = serial(DataInfo_loaded.Mark.ts)
            AData.info["Trigger_sample"] = serial(np.round(DataInfo_loaded.Mark.ts * DataInfo_loaded.LFPs.fs))
            AData.info["Trigger_code"] = serial(DataInfo_loaded.Mark.data[0])

        return AData


# --- Helpers ---


def _mk_trialdef(adata, start_code, end_code):
    """
    Create a basic trialdefinition from the trial start and end trigger codes
    """

    # trigger codes and samples
    trg_codes = np.array(adata.info["Trigger_code"], dtype=int)
    trg_sample = np.array(adata.info["Trigger_sample"], dtype=int)

    # boolean indexing
    trl_starts = trg_sample[trg_codes == start_code]
    trl_ends = trg_sample[trg_codes == end_code]

    if trl_starts.size == 0:
        lgl = "at least one occurrence of trial start code"
        raise SPYValueError(lgl, "start_code", start_code)
    if trl_ends.size == 0:
        lgl = "at least one occurrence of trial end code"
        raise SPYValueError(lgl, "end_code", end_code)

    if trl_starts.size > trl_ends.size:
        msg = f"Found {trl_starts.size} trial starts and {trl_ends.size} trial end codes!\n"
        msg += "truncating to number of trial ends.."
        SPYWarning(msg)
        N = trl_ends.size
    elif trl_ends.size > trl_starts.size:
        msg = f"Found {trl_starts.size} trial starts and {trl_ends.size} trial end codes!\n"
        msg += "truncating to number of trial starts.."
        SPYWarning(msg)
        N = trl_starts.size
    # both are equal
    else:
        N = trl_starts.size

    trldef = np.zeros((N, 3))
    trldef[:, 0] = trl_starts[:N]
    trldef[:, 1] = trl_ends[:N]

    return trldef


def _get_source_paths(directory, ext=".sev"):
    """
    Returns all abs. paths in `directory` for files which end with `ext`
    """

    # get all data source file names
    f_names = [f for f in os.listdir(directory) if f.endswith(ext)]

    # parse absolute paths to source files
    tdtPaths = []
    for fname in f_names:
        fname = os.path.join(directory, fname)
        f_path, f_name = io_parser(fname, varname="tdt source file name", isfile=True, exists=True)
        tdtPaths.append(os.path.join(f_path, f_name))

    tdtPaths = _natural_sort(tdtPaths)

    return tdtPaths


def _natural_sort(file_names):
    """Sort a list of strings using numbers

    Ch1 will be followed by Ch2 and not Ch11.
    """

    def convert(text):
        return int(text) if text.isdigit() else text.lower()

    def alphanum_key(key):
        return [convert(c) for c in re.split("([0-9]+)", key)]

    return sorted(file_names, key=alphanum_key)
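

# Example of the natural sort order used for the SEV channel files; the file
# names below are hypothetical:
#
#     >>> _natural_sort(["Blk_Ch11.sev", "Blk_Ch2.sev", "Blk_Ch1.sev"])
#     ['Blk_Ch1.sev', 'Blk_Ch2.sev', 'Blk_Ch11.sev']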