# -*- coding: utf-8 -*-
# @Author: Diljit Singh Kajal
# @Date: 2022-04-08 15:00:00
#
# load_tdt.py: Merge separate TDT SEV files into one HDF5 file
import os
from datetime import datetime
import re
import numpy as np
from getpass import getuser
from tqdm.auto import tqdm
import h5py
# Local imports
from syncopy.shared.parsers import io_parser, scalar_parser
from syncopy.shared.errors import SPYWarning, SPYValueError
from syncopy.shared.tools import StructDict
import syncopy as spy
# --- The user exposed function ---
def load_tdt(data_path, start_code=None, end_code=None,
subtract_median=False):
"""
Imports TDT time series data and meta-information
into a single :class:`~syncopy.AnalogData` object.
An ad-hoc trialdefinition will be attached if
both `start_code` and `end_code` are given.
    Otherwise a single all-to-all trialdefinition
    is used. Custom trialdefinitions can be attached
    afterwards with :func:`~syncopy.definetrial`.
    All meta-information is stored within the `.info`
    dict of the :class:`~syncopy.AnalogData` object.
    PDio-related keys:
    ("PDio_onset", "PDio_offset", "PDio_data")
    Trigger-related keys:
    ("Trigger_code", "Trigger_timestamp", "Trigger_sample")
Parameters
----------
data_path : str
Path to the directory containing the `.sev` files
start_code : int or None, optional
Trigger code defining the beginning of a trial
end_code : int or None, optional
Trigger code defining the end of a trial
    subtract_median : bool
        Set to `True` to subtract each channel's
        median from its time series
Returns
-------
adata : :class:`~syncopy.AnalogData`
        The TDT data in Syncopy format
Examples
--------
Load all channels from a source directory `/data/session3/`:
>>> adata = load_tdt('/data/session3')
Access the trigger codes and samples:
>>> trg_code = adata.info['Trigger_code']
>>> trg_sample = adata.info['Trigger_sample']
Load the same data and construct a trialdefinition on the fly:
>>> adata = load_tdt('/data/session3', start_code=23000, end_code=30020)
Access the 3rd trial:
>>> trl_dat = adata.trials[2]
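    Subtract each channel's median while loading:
    >>> adata = load_tdt('/data/session3', subtract_median=True)
    Access the photodiode onsets:
    >>> pd_onset = adata.info['PDio_onset']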
"""
io_parser(data_path, isfile=False)
    if start_code is not None and end_code is None:
        lgl = "trigger codes for both trial start and end"
        raise SPYValueError(lgl, "end_code", end_code)
    if end_code is not None and start_code is None:
        lgl = "trigger codes for both trial start and end"
        raise SPYValueError(lgl, "start_code", start_code)
    # if both codes are given, validate them
    if start_code is not None and end_code is not None:
        scalar_parser(start_code, "start_code", ntype='int_like')
        scalar_parser(end_code, "end_code", ntype='int_like')
# initialize tdt info loader class
TDT_Load_Info = ESI_TDTinfo(data_path)
# this is a StructDict
tdt_info = TDT_Load_Info.load_tdt_info()
# nicely sorted by channel names
file_paths = _get_source_paths(data_path, ".sev")
tdt_data_handler = ESI_TDTdata(
data_path, subtract_median=subtract_median, channels=None
)
    adata = tdt_data_handler.data_arranging(file_paths, tdt_info)
# we have to open for reading again
adata.data = h5py.File(adata.filename, "r")['data']
# Write log-entry
msg = f"loaded TDT data from {len(file_paths)} files\n"
msg += f"\tsource folder: {data_path}\n"
msg += f"\tsubtract median: {subtract_median}"
adata.log = msg
if start_code is not None:
adata.trialdefinition = _mk_trialdef(adata, start_code, end_code)
msg = f"created trialdefinition from start code `{start_code}` and end code `{end_code}`"
adata.log = msg
return adata
class ESI_TDTinfo:
def __init__(self, block_path):
self.block_path = block_path
self.nodata = False
self.t1 = 0
self.t2 = 0
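        # TDT TSQ event-type codes and bit masks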
        self.UNKNOWN = 0x00000000
        self.STRON = 0x00000101
        self.STROFF = 0x00000102
        self.SCALAR = 0x00000201
        self.STREAM = 0x00008101
        self.SNIP = 0x00008201
        self.MARK = 0x00008801
        self.HASDATA = 0x00008000
        self.UCF = 0x00000010
        self.PHANTOM = 0x00000020
        self.MASK = 0x0000FF0F
        self.INVALID_MASK = 0xFFFF0000
        self.STARTBLOCK = 0x0001
        self.STOPBLOCK = 0x0002
self.ALLOWED_FORMATS = [np.float32, np.int32, np.int16, np.int8, np.float64, np.int64]
self.ALLOWED_EVTYPES = ["all", "epocs", "snips", "streams", "scalars"]
def code_to_type(self, code):
# given event code, return string 'epocs', 'snips', 'streams', or 'scalars'
strobe_types = [self.STRON, self.STROFF, self.MARK]
scalar_types = [self.SCALAR]
snip_types = [self.SNIP]
if code in strobe_types:
s = "epocs"
elif code in snip_types:
s = "snips"
elif code & self.MASK == self.STREAM:
s = "streams"
elif code in scalar_types:
s = "scalars"
else:
s = "unknown"
return s
def time2sample(self, ts, fs=195312.5, t1=False, t2=False, to_time=False):
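        # Convert a timestamp in seconds into a sample index: range starts
        # (`t1=True`) are rounded up, range ends (`t2=True`) are rounded down
        # (stepping back one sample on exact hits) so the window stays inside
        # the recording; `to_time=True` converts the sample back to seconds.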
sample = ts * fs
if t2:
# drop precision beyond 1e-9
exact = np.round(sample * 1e9) / 1e9
sample = np.floor(sample)
if exact == sample:
sample -= 1
else:
sample = np.ceil(sample) if t1 else np.round(sample)
sample = np.uint64(sample)
if to_time:
return np.float64(sample) / fs
return sample
def check_ucf(self, code):
# given event code, check if it has unique channel files
return code & self.UCF == self.UCF
def epoc_to_type(self, code):
# given epoc event code, return if it is 'onset' or 'offset' event
strobe_on_types = [self.STRON, self.MARK]
strobe_off_types = [self.STROFF]
if code in strobe_on_types:
return "onset"
elif code in strobe_off_types:
return "offset"
return "unknown"
def code_to_name(self, code):
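        # the store code is the 4-character store name packed into a
        # little-endian uint32, e.g. int.from_bytes(b'Mark', 'little')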
return int(code).to_bytes(4, byteorder="little").decode("cp437")
def load_tdt_info(self):
header = StructDict()
data = StructDict()
data.epocs = StructDict()
data.streams = StructDict()
data.scalars = StructDict()
data.info = StructDict()
epocs = StructDict()
epocs.name = []
epocs.buddies = []
epocs.ts = []
epocs.code = []
epocs.type = []
epocs.type_str = []
epocs.data = []
epocs.dform = []
tsq_list = _get_source_paths(self.block_path, ".tsq")
        if len(tsq_list) == 0:
            raise SPYValueError("a directory containing a TSQ file", "data_path", self.block_path)
        if len(tsq_list) > 1:
            raise SPYValueError("a single TSQ file", "data_path", ", ".join(tsq_list))
        tsq = open(tsq_list[0], "rb")
        tsq.seek(48, os.SEEK_SET)
code1 = np.fromfile(tsq, dtype=np.int32, count=1)
assert code1 == self.STARTBLOCK, "Block start marker not found"
tsq.seek(56, os.SEEK_SET)
header.start_time = np.fromfile(tsq, dtype=np.float64, count=1)
# read stop time
tsq.seek(-32, os.SEEK_END)
code2 = np.fromfile(tsq, dtype=np.int32, count=1)
if code2 != self.STOPBLOCK:
SPYWarning(
"Block end marker not found, block did not end cleanly. Try setting T2 smaller if errors occur"
)
header.stop_time = np.nan
else:
tsq.seek(-24, os.SEEK_END)
header.stop_time = np.fromfile(tsq, dtype=np.float64, count=1)
[data.info.tankpath, data.info.blockname] = os.path.split(os.path.normpath(self.block_path))
        if not np.isnan(header.start_time):
            data.info.start_date = datetime.fromtimestamp(header.start_time[0])
            data.info.utc_start_time = data.info.start_date.strftime("%H:%M:%S")
        else:
            data.info.start_date = np.nan
            data.info.utc_start_time = np.nan
if not np.isnan(header.stop_time):
data.info.stop_date = datetime.fromtimestamp(header.stop_time[0])
data.info.utc_stop_time = data.info.stop_date.strftime("%H:%M:%S")
else:
data.info.stop_date = np.nan
data.info.utc_stop_time = np.nan
        if header.stop_time > 0:
            data.info.duration = data.info.stop_date - data.info.start_date
        else:
            data.info.duration = np.nan
tsq.seek(40, os.SEEK_SET)
read_size = 10000000 if self.t2 > 0 else 50000000
header.stores = StructDict()
while True:
heads = np.frombuffer(tsq.read(read_size * 4), dtype=np.uint32)
            rem = len(heads) % 10
            if rem != 0:
                SPYWarning("Block did not end cleanly, removing last {0} header words".format(rem))
                heads = heads[:-rem]
# reshape so each column is one header
heads = heads.reshape((-1, 10)).T
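            # each header is 10 uint32 words: [size, type, code, channel/buddy,
            # timestamp (2 words, float64), data value or file offset (2 words,
            # float64), data format, sampling rate (float32)]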
# check the codes first and build store maps and note arrays
codes = heads[2, :]
good_codes = codes > 0
bad_codes = np.logical_not(good_codes)
if np.sum(bad_codes) > 0:
SPYWarning(
"Bad TSQ headers were written, removing {0}, keeping {1} headers".format(
sum(bad_codes), sum(good_codes)
)
)
heads = heads[:, good_codes]
codes = heads[2, :]
# get set of codes but preserve order in the block
store_codes = []
unique_codes, unique_ind = np.unique(codes, return_index=True)
for counter, x in enumerate(unique_codes):
store_codes.append(
{
"code": x,
"type": heads[1, unique_ind[counter]],
"type_str": self.code_to_type(heads[1, unique_ind[counter]]),
"ucf": self.check_ucf(heads[1, unique_ind[counter]]),
"epoc_type": self.epoc_to_type(heads[1, unique_ind[counter]]),
"dform": heads[8, unique_ind[counter]],
"size": heads[0, unique_ind[counter]],
"buddy": heads[3, unique_ind[counter]],
"temp": heads[:, unique_ind[counter]],
}
)
            # only look for the Mark, PDio, LFP and PDi\ stores
            looking_for = ["Mark", "PDio", "LFPs", "LFP1", "PDi\\"]
targets = StructDict()
for chk, content in enumerate(store_codes):
if self.code_to_name(content["code"]) in looking_for:
targets[self.code_to_name(content["code"])] = chk
            for name, idx in targets.items():
                store_code = store_codes[idx]
store_code["name"] = self.code_to_name(store_code["code"])
store_code["var_name"] = store_code["name"]
var_name = store_code["var_name"]
if store_code["type_str"] == "epocs":
                    if store_code["name"] not in epocs.name:
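                        # decode the 4-byte buddy code into the paired store's ASCII name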
buddy = "".join([str(chr(c)) for c in np.array([store_code["buddy"]]).view(np.uint8)])
buddy = buddy.replace("\x00", " ")
epocs.name.append(store_code["name"])
epocs.buddies.append(buddy)
epocs.code.append(store_code["code"])
epocs.ts.append([])
epocs.type.append(store_code["epoc_type"])
epocs.type_str.append(store_code["type_str"])
epocs.data.append([])
epocs.dform.append(store_code["dform"])
if var_name not in header.stores.keys():
if store_code["type_str"] != "epocs":
header.stores[var_name] = StructDict(
name=store_code["name"],
code=store_code["code"],
size=store_code["size"],
type=store_code["type"],
type_str=store_code["type_str"],
)
if header.stores[var_name].type_str == "streams":
header.stores[var_name].ucf = store_code["ucf"]
if header.stores[var_name].type_str != "scalars":
# Finding the sampling rate
header.stores[var_name].fs = np.double(
np.array([store_code["temp"][9]]).view(np.float32)
)
header.stores[var_name].dform = store_code["dform"]
valid_ind = np.where(codes == store_code["code"])[0]
temp = heads[3, valid_ind].view(np.uint16)
if store_code["type_str"] != "epocs":
if not hasattr(header.stores[var_name], "ts"):
header.stores[var_name].ts = []
vvv = (
np.reshape(
heads[[[4], [5]], valid_ind].T,
(-1, 1)).T.view(np.float64) - header.start_time
)
# round timestamps to the nearest sample
vvv = self.time2sample(vvv, to_time=True)
header.stores[var_name].ts.append(vvv)
if (not self.nodata) or (store_code["type_str"] == "streams"):
if not hasattr(header.stores[var_name], "data"):
header.stores[var_name].data = []
header.stores[var_name].data.append(
np.reshape(heads[[[6], [7]], valid_ind].T, (-1, 1)).T.view(np.float64)
)
if not hasattr(header.stores[var_name], "chan"):
header.stores[var_name].chan = []
header.stores[var_name].chan.append(temp[::2])
else:
loc = epocs.name.index(store_code["name"])
vvv = (
np.reshape(
heads[[[4], [5]], valid_ind].T,
(-1, 1)).T.view(np.float64) - header.start_time
)
# round timestamps to the nearest sample
vvv = self.time2sample(vvv, to_time=True)
epocs.ts[loc] = np.append(epocs.ts[loc], vvv)
epocs.data[loc] = np.append(
epocs.data[loc], np.reshape(heads[[[6], [7]], valid_ind].T, (-1, 1)).T.view(np.float64)
)
last_ts = heads[[4, 5], -1].view(np.float64) - header.start_time
last_ts = last_ts[0]
if self.t2 > 0 and last_ts > self.t2:
break
# eof reached
if heads.size < read_size:
break
print(
"Reading data from t = {0}s to t = {1}s".format(
np.round(self.t1, 2), np.round(np.maximum(last_ts, self.t2), 2)
)
)
for ii in range(len(epocs.name)):
# find all non-buddies first
if epocs.type[ii] == "onset":
var_name = epocs.name[ii]
header.stores[var_name] = StructDict()
header.stores[var_name].name = epocs.name[ii]
ts = epocs.ts[ii]
header.stores[var_name].onset = ts
header.stores[var_name].offset = np.append(ts[1:], np.inf)
header.stores[var_name].type = epocs.type[ii]
header.stores[var_name].type_str = epocs.type_str[ii]
header.stores[var_name].data = epocs.data[ii]
header.stores[var_name].dform = epocs.dform[ii]
header.stores[var_name].size = 10
for ii in range(len(epocs.name)):
if epocs.type[ii] == "offset":
var_name = epocs.buddies[ii]
if var_name not in header.stores.keys():
SPYWarning(epocs.buddies[ii] + " buddy epoc not found, skipping")
continue
header.stores[var_name].offset = epocs.ts[ii]
# handle odd case where there is a single offset event and no onset events
if "onset" not in header.stores[var_name].keys():
header.stores[var_name].name = epocs.buddies[ii]
header.stores[var_name].onset = 0
header.stores[var_name].type_str = "epocs"
header.stores[var_name].type = "onset"
header.stores[var_name].data = 0
header.stores[var_name].dform = 4
header.stores[var_name].size = 10
# fix time ranges
if header.stores[var_name].offset[0] < header.stores[var_name].onset[0]:
header.stores[var_name].onset = np.append(0, header.stores[var_name].onset)
header.stores[var_name].data = np.append(
header.stores[var_name].data[0], header.stores[var_name].data
)
if header.stores[var_name].onset[-1] > header.stores[var_name].offset[-1]:
header.stores[var_name].offset = np.append(header.stores[var_name].offset, np.inf)
for var_name in header.stores.keys():
# convert cell arrays to regular arrays
if "ts" in header.stores[var_name].keys():
header.stores[var_name].ts = np.concatenate(header.stores[var_name].ts, axis=1)[0]
if "chan" in header.stores[var_name].keys():
header.stores[var_name].chan = np.concatenate(header.stores[var_name].chan)
if "sortcode" in header.stores[var_name].keys():
header.stores[var_name].sortcode = np.concatenate(header.stores[var_name].sortcode)
if "data" in header.stores[var_name].keys():
if header.stores[var_name].type_str != "epocs":
header.stores[var_name].data = np.concatenate(header.stores[var_name].data, axis=1)[0]
# if it's a data type, cast as a file offset pointer instead of data
if header.stores[var_name].type_str in ["streams", "snips"]:
if "data" in header.stores[var_name].keys():
header.stores[var_name].data = header.stores[var_name].data.view(np.uint64)
if "chan" in header.stores[var_name].keys():
if np.max(header.stores[var_name].chan) == 1:
header.stores[var_name].chan = [1]
valid_time_range = np.array([[self.t1], [self.t2]]) if self.t2 > 0 else np.array([[self.t1], [np.inf]])
num_ranges = valid_time_range.shape[1]
if num_ranges > 0:
data.time_ranges = valid_time_range
for var_name in header.stores.keys():
current_type_str = header.stores[var_name].type_str
data[current_type_str][var_name] = header.stores[var_name]
firstStart = valid_time_range[0, 0]
last_stop = valid_time_range[1, -1]
if "ts" in header.stores[var_name].keys():
if current_type_str == "streams":
data[current_type_str][var_name].start_time = [0 for jj in range(num_ranges)]
else:
this_dtype = data[current_type_str][var_name].ts.dtype
data[current_type_str][var_name].filtered_ts = [
np.array([], dtype=this_dtype) for jj in range(num_ranges)
]
if hasattr(data[current_type_str][var_name], "chan"):
data[current_type_str][var_name].filtered_chan = [[] for jj in range(num_ranges)]
if hasattr(data[current_type_str][var_name], "sortcode"):
this_dtype = data[current_type_str][var_name].sortcode.dtype
data[current_type_str][var_name].filtered_sort_code = [
np.array([], dtype=this_dtype) for jj in range(num_ranges)
]
if hasattr(data[current_type_str][var_name], "data"):
this_dtype = data[current_type_str][var_name].data.dtype
data[current_type_str][var_name].filtered_data = [
np.array([], dtype=this_dtype) for jj in range(num_ranges)
]
filter_ind = [[] for i in range(num_ranges)]
for jj in range(num_ranges):
start = valid_time_range[0, jj]
stop = valid_time_range[1, jj]
ind1 = data[current_type_str][var_name].ts >= start
ind2 = data[current_type_str][var_name].ts < stop
filter_ind[jj] = np.where(ind1 & ind2)[0]
bSkip = 0
if len(filter_ind[jj]) == 0:
# if it's a stream and a short window, we might have missed it
if current_type_str == "streams":
ind2 = np.where(ind2)[0]
if len(ind2) > 0:
ind2 = ind2[-1]
# keep one prior for streams (for all channels)
nchan = max(data[current_type_str][var_name].chan)
if ind2 - nchan >= -1:
filter_ind[jj] = ind2 - np.arange(nchan - 1, -1, -1)
temp = data[current_type_str][var_name].ts[filter_ind[jj]]
data[current_type_str][var_name].start_time[jj] = temp[0]
bSkip = 1
if len(filter_ind[jj]) > 0:
# parse out the information we need
if current_type_str == "streams":
# keep one prior for streams (for all channels)
if not bSkip:
nchan = max(data[current_type_str][var_name].chan)
temp = filter_ind[jj]
if temp[0] - nchan > -1:
filter_ind[jj] = np.concatenate(
[-np.arange(nchan, 0, -1) + temp[0], filter_ind[jj]]
)
temp = data[current_type_str][var_name].ts[filter_ind[jj]]
data[current_type_str][var_name].start_time[jj] = temp[0]
else:
data[current_type_str][var_name].filtered_ts[jj] = data[current_type_str][
var_name
].ts[filter_ind[jj]]
if hasattr(data[current_type_str][var_name], "chan"):
if len(data[current_type_str][var_name].chan) > 1:
data[current_type_str][var_name].filtered_chan[jj] = data[current_type_str][
var_name
].chan[filter_ind[jj]]
else:
data[current_type_str][var_name].filtered_chan[jj] = data[current_type_str][
var_name
].chan
if hasattr(data[current_type_str][var_name], "sortcode"):
data[current_type_str][var_name].filtered_sort_code[jj] = data[current_type_str][
var_name
].sortcode[filter_ind[jj]]
if hasattr(data[current_type_str][var_name], "data"):
data[current_type_str][var_name].filtered_data[jj] = data[current_type_str][
var_name
].data[filter_ind[jj]]
if current_type_str == "streams":
delattr(data[current_type_str][var_name], "ts")
delattr(data[current_type_str][var_name], "data")
delattr(data[current_type_str][var_name], "chan")
if not hasattr(data[current_type_str][var_name], "filtered_chan"):
data[current_type_str][var_name].filtered_chan = [[] for i in range(num_ranges)]
if not hasattr(data[current_type_str][var_name], "filtered_data"):
data[current_type_str][var_name].filtered_data = [[] for i in range(num_ranges)]
if not hasattr(data[current_type_str][var_name], "start_time"):
data[current_type_str][var_name].start_time = -1
else:
# consolidate other fields
if hasattr(data[current_type_str][var_name], "filtered_ts"):
data[current_type_str][var_name].ts = np.concatenate(
data[current_type_str][var_name].filtered_ts
)
delattr(data[current_type_str][var_name], "filtered_ts")
else:
data[current_type_str][var_name].ts = []
if hasattr(data[current_type_str][var_name], "chan"):
if hasattr(data[current_type_str][var_name], "filtered_chan"):
data[current_type_str][var_name].chan = np.concatenate(
data[current_type_str][var_name].filtered_chan
)
delattr(data[current_type_str][var_name], "filtered_chan")
if current_type_str == "snips":
if len(set(data[current_type_str][var_name].chan)) == 1:
data[current_type_str][var_name].chan = [
data[current_type_str][var_name].chan[0]
]
else:
data[current_type_str][var_name].chan = []
if hasattr(data[current_type_str][var_name], "sortcode"):
if hasattr(data[current_type_str][var_name], "filtered_sort_code"):
data[current_type_str][var_name].sortcode = np.concatenate(
data[current_type_str][var_name].filtered_sort_code
)
delattr(data[current_type_str][var_name], "filtered_sort_code")
else:
data[current_type_str][var_name].sortcode = []
if hasattr(data[current_type_str][var_name], "data"):
if hasattr(data[current_type_str][var_name], "filtered_data"):
data[current_type_str][var_name].data = np.concatenate(
data[current_type_str][var_name].filtered_data
)
delattr(data[current_type_str][var_name], "filtered_data")
else:
data[current_type_str][var_name].data = []
else:
# handle epoc events
filter_ind = []
for jj in range(num_ranges):
start = valid_time_range[0, jj]
stop = valid_time_range[1, jj]
ind1 = data[current_type_str][var_name].onset >= start
ind2 = data[current_type_str][var_name].onset < stop
filter_ind.append(np.where(ind1 & ind2)[0])
filter_ind = np.concatenate(filter_ind)
if len(filter_ind) > 0:
data[current_type_str][var_name].onset = data[current_type_str][var_name].onset[filter_ind]
data[current_type_str][var_name].data = data[current_type_str][var_name].data[filter_ind]
data[current_type_str][var_name].offset = data[current_type_str][var_name].offset[
filter_ind
]
if data[current_type_str][var_name].offset[0] < data[current_type_str][var_name].onset[0]:
if data[current_type_str][var_name].onset[0] > firstStart:
data[current_type_str][var_name].onset = np.concatenate(
[[firstStart], data[current_type_str][var_name].onset]
)
if data[current_type_str][var_name].offset[-1] > last_stop:
data[current_type_str][var_name].offset[-1] = last_stop
else:
# default case is no valid events for this store
data[current_type_str][var_name].onset = []
data[current_type_str][var_name].data = []
data[current_type_str][var_name].offset = []
if var_name == "Note":
data[current_type_str][var_name].notes = []
for current_name in header.stores.keys():
current_type_str = header.stores[current_name].type_str
current_type_str = data[current_type_str][current_name].type_str
current_data_format = self.ALLOWED_FORMATS[data[current_type_str][current_name].dform]
if current_type_str == "scalars":
if len(data[current_type_str][current_name].chan) > 0:
nchan = int(np.max(data[current_type_str][current_name].chan))
else:
nchan = 0
if nchan > 1:
# organize data by sample
# find channels with most and least amount of data
ind = []
min_length = np.inf
max_length = 0
for xx in range(nchan):
ind.append(np.where(data[current_type_str][current_name].chan == xx + 1)[0])
min_length = min(len(ind[-1]), min_length)
max_length = max(len(ind[-1]), max_length)
if min_length != max_length:
SPYWarning(
"Truncating store {0} to {1} values (from {2})".format(
current_name, min_length, max_length
))
ind = [ind[xx][:min_length] for xx in range(nchan)]
if not self.nodata:
data[current_type_str][current_name].data = (
data[current_type_str][current_name].data[np.concatenate(ind)].reshape(nchan, -1)
)
# only use timestamps from first channel
data[current_type_str][current_name].ts = data[current_type_str][current_name].ts[ind[0]]
# remove channels field
delattr(data[current_type_str][current_name], "chan")
tsq.close()
del epocs
del header
Data = StructDict()
Data.PDio = data.epocs.PDio
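        # the LFP stream store may be named either 'LFPs' or 'LFP1'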
try:
Data.LFPs = data.streams.LFPs
except Exception:
Data.LFPs = data.streams.LFP1
Data.Mark = data.scalars.Mark
Data.info = data.info
return Data
class ESI_TDTdata:
def __init__(
self,
inputdir,
subtract_median=False,
channels=None,
):
self.inputdir = inputdir
self.chan_in_chunks = 16
self.subtract_median = subtract_median
self.channels = "all" if channels is None else channels
def arrange_header(self, DataInfo_loaded, Files):
header = StructDict()
header["fs"] = DataInfo_loaded.LFPs.fs
header["total_num_channel"] = len(Files)
return header
    def read_data(self, filename):
        """Read data from a TDT SEV file created by the RS4 streamer"""
        HEADERSIZE = 40
with open(filename, "rb") as f:
f.seek(HEADERSIZE)
data = np.fromfile(f, dtype="single")
return data
    def md5sum(self, filename):
        from hashlib import md5
        file_hash = md5()
        with open(filename, "rb") as f:
            for chunk in iter(lambda: f.read(128 * file_hash.block_size), b""):
                file_hash.update(chunk)
        return file_hash.hexdigest()
    def data_arranging(self, Files, DataInfo_loaded):
        AData = spy.AnalogData(dimord=['time', 'channel'])
        hdf_out_path = AData.filename
        LenOfData = self.read_data(Files[0]).shape[0]  # the data length is fixed to the length of the first channel
with h5py.File(hdf_out_path, "w") as combined_data_file:
idxStartStop = [
np.clip(np.array((jj, jj + self.chan_in_chunks)), a_min=None, a_max=len(Files))
for jj in range(0, len(Files), self.chan_in_chunks)
]
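            # (start, stop) pairs over the channel files, e.g. 40 files with
            # chan_in_chunks=16 -> [(0, 16), (16, 32), (32, 40)]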
print(
"Merging {0} files in {1} chunks each with {2} channels into \n {3}".format(
len(Files), len(idxStartStop), self.chan_in_chunks, hdf_out_path
)
)
for (start, stop) in tqdm(iterable=idxStartStop, desc="chunk", unit="chunk", disable=None):
data = [self.read_data(Files[jj])[:LenOfData] for jj in range(start, stop)]
data = np.vstack(data).T
if start == 0:
# this is the actual dataset for the AnalogData
target = combined_data_file.create_dataset(
"data", shape=(data.shape[0], len(Files)), dtype="single"
)
                if self.subtract_median:
                    # subtract each channel's median (samples run along axis 0)
                    data -= np.median(data, axis=0, keepdims=True).astype(data.dtype)
target[:, start:stop] = data
# link dataset to AnalogData instance
AData.data = target
# temporary fix to get at least all-to-all
AData.trialdefinition = None
chanlist = (
None if self.channels == "all" else ["channel" + str(trch + 1).zfill(3) for trch in self.channels]
)
AData.samplerate = DataInfo_loaded.LFPs.fs
AData.channel = chanlist
# helper to make serializable
def serial(arr):
return arr.tolist()
# write info file
AData.info["originalFiles"] = (Files,)
# AData.info["md5sum"] = self.md5sum(hdf_out_path)
AData.info["blockname"] = DataInfo_loaded.info.blockname
AData.info["start_date"] = str(DataInfo_loaded.info.start_date)
AData.info["utc_start_time"] = DataInfo_loaded.info.utc_start_time
AData.info["stop_date"] = str(DataInfo_loaded.info.stop_date)
AData.info["utc_stop_time"] = DataInfo_loaded.info.utc_stop_time
AData.info["duration"] = str(DataInfo_loaded.info.duration)
AData.info["PDio_onset"] = serial(DataInfo_loaded.PDio.onset)
AData.info["PDio_offset"] = serial(DataInfo_loaded.PDio.offset)
AData.info["PDio_data"] = serial(DataInfo_loaded.PDio.data)
AData.info["Trigger_timestamp"] = serial(DataInfo_loaded.Mark.ts)
AData.info["Trigger_sample"] = serial(np.round(DataInfo_loaded.Mark.ts * DataInfo_loaded.LFPs.fs))
AData.info["Trigger_code"] = serial(DataInfo_loaded.Mark.data[0])
return AData
# --- Helpers ---
def _mk_trialdef(adata, start_code, end_code):
"""
Create a basic trialdefinition from the trial
start and end trigger codes
"""
# trigger codes and samples
trg_codes = np.array(adata.info['Trigger_code'], dtype=int)
trg_sample = np.array(adata.info['Trigger_sample'], dtype=int)
# boolean indexing
trl_starts = trg_sample[trg_codes == start_code]
trl_ends = trg_sample[trg_codes == end_code]
    if trl_starts.size == 0:
        lgl = "at least one occurrence of the trial start code"
        raise SPYValueError(lgl, "start_code", start_code)
    if trl_ends.size == 0:
        lgl = "at least one occurrence of the trial end code"
        raise SPYValueError(lgl, "end_code", end_code)
    if trl_starts.size > trl_ends.size:
        msg = f"Found {trl_starts.size} trial start and {trl_ends.size} trial end codes!\n"
        msg += "truncating to the number of trial ends..."
        SPYWarning(msg)
        N = trl_ends.size
    elif trl_ends.size > trl_starts.size:
        msg = f"Found {trl_starts.size} trial start and {trl_ends.size} trial end codes!\n"
        msg += "truncating to the number of trial starts..."
        SPYWarning(msg)
        N = trl_starts.size
# both are equal
else:
N = trl_starts.size
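    # a syncopy trialdefinition is an (N, 3) array of
    # [start_sample, end_sample, offset]; the offset column stays 0,
    # so t = 0 coincides with the trial start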
trldef = np.zeros((N, 3))
trldef[:, 0] = trl_starts[:N]
trldef[:, 1] = trl_ends[:N]
return trldef
def _get_source_paths(directory, ext=".sev"):
"""
Returns all abs. paths in `directory`
for files which end with `ext`
"""
# get all data source file names
f_names = [f for f in os.listdir(directory) if f.endswith(ext)]
# parse absolute paths to source files
tdtPaths = []
for fname in f_names:
fname = os.path.join(directory, fname)
f_path, f_name = io_parser(fname, varname="tdt source file name", isfile=True, exists=True)
tdtPaths.append(os.path.join(f_path, f_name))
tdtPaths = _natural_sort(tdtPaths)
return tdtPaths
def _natural_sort(file_names):
    """Sort a list of strings by their embedded numbers,
    such that Ch2 comes after Ch1 and before Ch11.
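
    >>> _natural_sort(["Ch11.sev", "Ch2.sev", "Ch1.sev"])
    ['Ch1.sev', 'Ch2.sev', 'Ch11.sev']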
"""
def convert(text):
return int(text) if text.isdigit() else text.lower()
def alphanum_key(key):
return [convert(c) for c in re.split("([0-9]+)", key)]
return sorted(file_names, key=alphanum_key)