# -*- coding: utf-8 -*-
#
# Set/update trial settings of Syncopy data objects
#
# Builtin/3rd party package imports
import sys
import numpy as np
# Local imports
from syncopy.shared.parsers import data_parser, array_parser, scalar_parser
from syncopy.shared.errors import SPYTypeError, SPYValueError
from ...datatype.util import TimeIndexer
# Public names exported by this module (consulted by `from <module> import *`)
__all__ = ["definetrial"]
def definetrial(
    obj,
    trialdefinition=None,
    pre=None,
    post=None,
    start=None,
    trigger=None,
    stop=None,
    clip_edges=False,
):
    """(Re-)define trials of a Syncopy data object

    Data can be structured into trials based on timestamps of a start, trigger
    and end events::

                    start    trigger    stop
        |---- pre ----|--------|---------|--- post----|

    **Note**: To define a trial encompassing the whole dataset simply invoke this
    routine with no arguments, i.e., ``definetrial(obj)`` or equivalently
    ``obj.definetrial()``

    Parameters
    ----------
    obj : Syncopy data object (:class:`BaseData`-like)
        Object whose trial settings are (re-)defined; must hold data
    trialdefinition : :class:`EventData` object or Mx3 array
        [start, stop, trigger_offset] sample indices for `M` trials
    pre : float or array-like
        offset time (s) before start event
    post : float or array-like
        offset time (s) after end event
    start : int or array-like
        event code (id) to be used for start of trial
    stop : int or array-like
        event code (id) to be used for end of trial
    trigger : int or array-like
        event code (id) to be used as center (t=0) of trial
    clip_edges : bool
        trim trials to actual data-boundaries.

    Returns
    -------
    None
        Nothing is returned; the trial definition is assigned in place to the
        target object (`obj`) via its `trialdefinition` setter.

    Raises
    ------
    SPYValueError
        If `obj` is empty, sampling rates are missing, the keyword combination
        is inconsistent, a requested event code does not exist, a trial ends
        before it begins, or no trial layout can be derived from the inputs.
    SPYTypeError
        If `clip_edges` is not a Boolean (only checked when trial keywords
        are used).

    Notes
    -----
    :func:`definetrial` supports the following argument combinations:

    >>> # define M trials based on [start, end, offset] indices
    >>> definetrial(obj, trialdefinition=[M x 3] array)

    >>> # define trials based on event codes stored in <:class:`EventData` object>
    >>> definetrial(obj, trialdefinition=<EventData object>,
                    pre=0, post=0, start=startCode, stop=stopCode,
                    trigger=triggerCode)

    >>> # apply same trial definition as defined in <:class:`EventData` object>
    >>> definetrial(<AnalogData object>,
                    trialdefinition=<EventData object w/sampleinfo/t0/trialinfo>)

    >>> # define whole recording as single trial
    >>> definetrial(obj, trialdefinition=None)
    """

    # Start by vetting input object
    data_parser(obj, varname="obj")
    if obj.data is None:
        lgl = "non-empty Syncopy data object"
        act = "empty Syncopy data object"
        raise SPYValueError(legal=lgl, varname="obj", actual=act)

    # Class names are matched via the MRO so that no data classes have to be imported
    is_continuous = any("ContinuousData" in str(base) for base in obj.__class__.__mro__)

    # `trl` holds the trial layout, `ref` is the object the layout is derived
    # from, `tgt` is the object it is written to; `evt` flags an EventData
    # `trialdefinition`. Everything starts out unset so that input combinations
    # that provide no trial information can be detected below.
    trl = None
    ref = None
    tgt = None
    evt = False

    # Check array/object holding trial specifications
    if trialdefinition is not None:
        if trialdefinition.__class__.__name__ == "EventData":
            data_parser(trialdefinition, varname="trialdefinition", writable=None, empty=False)
            evt = True
        else:
            array_parser(trialdefinition, varname="trialdefinition", dims=2)
            # Trial bounds are only limited by the data length along the time axis
            # for continuous data
            if is_continuous:
                scount = obj.data.shape[obj.dimord.index("time")]
            else:
                scount = np.inf
            array_parser(
                trialdefinition[:, :2],
                varname="sampleinfo",
                dims=(None, 2),
                hasnan=False,
                hasinf=False,
                ntype="int_like",
                lims=[0, scount],
            )
            trl = np.array(trialdefinition, dtype="float")
            ref = obj
            tgt = obj
    else:
        # Construct object-class-specific `trl` arrays treating data-set as single trial
        if is_continuous:
            trl = np.array([[0, obj.data.shape[obj.dimord.index("time")], 0]])
        else:
            sample_col = obj.dimord.index("sample")
            trl = np.array(
                [[np.nanmin(obj.data[:, sample_col]), np.nanmax(obj.data[:, sample_col]), 0]]
            )
        ref = obj
        tgt = obj

    # AnalogData + EventData w/sampleinfo
    if obj.__class__.__name__ == "AnalogData" and evt and trialdefinition.sampleinfo is not None:
        if obj.samplerate is None or trialdefinition.samplerate is None:
            lgl = "non-`None` value - make sure `samplerate` is set before defining trials"
            act = "None"
            raise SPYValueError(legal=lgl, varname="samplerate", actual=act)
        ref = trialdefinition
        tgt = obj
        # Stack [sampleinfo, t0, trialinfo] and rescale from the sampling rate
        # of `ref` to the sampling rate of `tgt`
        trl = np.array(ref.trialinfo)
        t0 = np.array(ref._t0).reshape((ref._t0.size, 1))
        trl = np.hstack([ref.sampleinfo, t0, trl])
        trl = np.round((trl / ref.samplerate) * tgt.samplerate).astype(int)

    # AnalogData + EventData w/keywords or just EventData w/keywords
    if any(kw is not None for kw in (pre, post, start, trigger, stop)):

        # Make sure we actually have valid data objects to work with
        if obj.__class__.__name__ == "EventData" and evt is False:
            ref = obj
            tgt = obj
        elif obj.__class__.__name__ == "AnalogData" and evt is True:
            ref = trialdefinition
            tgt = obj
        else:
            lgl = "AnalogData with associated EventData object"
            act = "{} and {}".format(obj.__class__.__name__, trialdefinition.__class__.__name__)
            raise SPYValueError(legal=lgl, actual=act, varname="input")

        # The only case we might actually need it: ensure `clip_edges` is valid
        if not isinstance(clip_edges, bool):
            raise SPYTypeError(clip_edges, varname="clip_edges", expected="Boolean")

        # Ensure that objects have their sampling-rates set, otherwise break
        if ref.samplerate is None or tgt.samplerate is None:
            lgl = "non-`None` value - make sure `samplerate` is set before defining trials"
            act = "None"
            raise SPYValueError(legal=lgl, varname="samplerate", actual=act)

        # Get input dimensions: all array-like keywords must share one length
        szin = []
        for var in [pre, post, start, trigger, stop]:
            if isinstance(var, (np.ndarray, list)):
                szin.append(len(var))
        if np.unique(szin).size > 1:
            lgl = "all trial-related arrays to have the same length"
            act = "arrays with sizes {}".format(str(np.unique(szin)).replace("[", "").replace("]", ""))
            raise SPYValueError(legal=lgl, varname="trial-keywords", actual=act)
        if len(szin):
            # Array input: one entry per trial, advance through the arrays
            ntrials = szin[0]
            ninc = 1
        else:
            # Scalar input only: the same settings are re-used for every trial
            # (`trialno` never advances), the search below runs until the
            # events are exhausted
            ntrials = 1
            ninc = 0

        # If both `pre` and `start` or `post` and `stop` are `None`, abort
        if (pre is None and start is None) or (post is None and stop is None):
            lgl = "`pre` or `start` and `post` or `stop` to be not `None`"
            act = "both `pre` and `start` and/or `post` and `stop` are simultaneously `None`"
            raise SPYValueError(legal=lgl, actual=act)
        if (trigger is None) and (pre is not None or post is not None):
            lgl = "non-None `trigger` with `pre`/`post` timing information"
            act = "`trigger` = `None`"
            raise SPYValueError(legal=lgl, actual=act)

        # If provided, ensure keywords make sense, otherwise allocate defaults
        kwrds = {}
        vdict = {
            "pre": {"var": pre, "hasnan": False, "ntype": None, "fillvalue": 0},
            "post": {"var": post, "hasnan": False, "ntype": None, "fillvalue": 0},
            "start": {
                "var": start,
                "hasnan": None,
                "ntype": "int_like",
                "fillvalue": np.nan,
            },
            "trigger": {
                "var": trigger,
                "hasnan": None,
                "ntype": "int_like",
                "fillvalue": np.nan,
            },
            "stop": {
                "var": stop,
                "hasnan": None,
                "ntype": "int_like",
                "fillvalue": np.nan,
            },
        }
        for vname, opts in vdict.items():
            if opts["var"] is not None:
                if np.issubdtype(type(opts["var"]), np.number):
                    # Scalars are validated and then broadcast to one value per trial
                    scalar_parser(
                        opts["var"],
                        varname=vname,
                        ntype=opts["ntype"],
                        lims=[-np.inf, np.inf],
                    )
                    opts["var"] = np.full((ntrials,), opts["var"])
                else:
                    array_parser(
                        opts["var"],
                        varname=vname,
                        hasinf=False,
                        hasnan=opts["hasnan"],
                        ntype=opts["ntype"],
                        dims=(ntrials,),
                    )
                kwrds[vname] = opts["var"]
            else:
                kwrds[vname] = np.full((ntrials,), opts["fillvalue"])

        # Prepare `trl` and convert event-codes + sample-numbers to lists
        trl = []
        evtid = list(ref.data[:, ref.dimord.index("eventid")])
        evtsp = list(ref.data[:, ref.dimord.index("sample")])
        nevents = len(evtid)
        searching = True
        trialno = 0
        cnt = 0
        act = ""

        # Do this line-by-line: halt on error (if event-id is not found in `ref`)
        while searching:

            # Allocate begin and end of trial
            begin = None
            end = None
            t0 = 0
            idxl = []
            # Position of this trial's start event (if any); kept per iteration
            # so that the `stop` search never relies on a stale or undefined index
            start_idx = None

            # First, try to assign `start`, then `t0`
            if not np.isnan(kwrds["start"][trialno]):
                try:
                    start_idx = evtid.index(kwrds["start"][trialno])
                except ValueError:
                    act = str(kwrds["start"][trialno])
                    vname = "start"
                    break
                begin = evtsp[start_idx] / ref.samplerate
                # Overwrite the consumed event with a non-integer sentinel so
                # that it cannot be matched again
                evtid[start_idx] = -np.pi
                idxl.append(start_idx)

            if not np.isnan(kwrds["trigger"][trialno]):
                try:
                    idx = evtid.index(kwrds["trigger"][trialno])
                except ValueError:
                    act = str(kwrds["trigger"][trialno])
                    vname = "trigger"
                    break
                t0 = evtsp[idx] / ref.samplerate
                evtid[idx] = -np.pi
                idxl.append(idx)

            # Trial-begin is either `trigger - pre` or `start - pre`
            if begin is not None:
                begin -= kwrds["pre"][trialno]
            else:
                begin = t0 - kwrds["pre"][trialno]

            # Try to assign `stop`, if we got nothing, use `t0 + post`
            if not np.isnan(kwrds["stop"][trialno]):
                # Blank out everything preceding the start event so that the
                # stop event is only searched for after the trial has started
                if start_idx is not None:
                    evtid[:start_idx] = [np.pi] * start_idx
                try:
                    idx = evtid.index(kwrds["stop"][trialno])
                except ValueError:
                    act = str(kwrds["stop"][trialno])
                    vname = "stop"
                    break
                end = evtsp[idx] / ref.samplerate + kwrds["post"][trialno]
                evtid[idx] = -np.pi
                idxl.append(idx)
            else:
                end = t0 + kwrds["post"][trialno]

            # Off-set `t0`
            t0 -= begin

            # Make sure current trial setup makes (some) sense
            if begin >= end:
                lgl = "non-overlapping trial begin-/end-samples"
                act = "trial-begin at {}, trial-end at {}".format(str(begin), str(end))
                raise SPYValueError(legal=lgl, actual=act)

            # Finally, write line of `trl`
            trl.append([begin, end, t0])

            # Update counters and end this mess when we're done: events up to
            # and including the last one used by this trial are discarded
            trialno += ninc
            cnt += 1
            last_used = max(idxl, default=-1)
            evtsp = evtsp[last_used + 1 :]
            evtid = evtid[last_used + 1 :]
            if trialno == ntrials or cnt == nevents:
                searching = False

        # Abort if the above loop ran into troubles
        if len(trl) < ntrials:
            if len(act) > 0:
                raise SPYValueError(legal="existing event-id", varname=vname, actual=act)

        # Make `trl` a NumPy array (times in seconds -> samples of `tgt`)
        trl = np.round(np.array(trl) * tgt.samplerate).astype(int)

    # An EventData `trialdefinition` that neither carries trial settings usable
    # for `obj` nor comes with trial keywords provides nothing to work with
    if trl is None:
        lgl = "EventData object with trial settings or trial-related keywords"
        act = "{} and {} without usable trial information".format(
            obj.__class__.__name__, trialdefinition.__class__.__name__
        )
        raise SPYValueError(legal=lgl, varname="trialdefinition", actual=act)

    # If appropriate, clip `trl` to AnalogData object's bounds (if wanted)
    if clip_edges and evt:
        msk = trl[:, 0] < 0
        trl[msk, 0] = 0
        dmax = tgt.data.shape[tgt.dimord.index("time")]
        msk = trl[:, 1] > dmax
        trl[msk, 1] = dmax
        if np.any(trl[:, 0] >= trl[:, 1]):
            lgl = "non-overlapping trials"
            act = "some trials are overlapping after clipping to AnalogData object range"
            raise SPYValueError(legal=lgl, actual=act)

    # The triplet `sampleinfo`, `t0` and `trialinfo` works identically for
    # all data genres
    if trl.shape[1] < 3:
        raise SPYValueError(
            "array of shape (no. of trials, 3+)",
            varname="trialdefinition",
            actual="shape = {shp:s}".format(shp=str(trl.shape)),
        )

    # Finally: assign `sampleinfo`, `t0` and `trialinfo` (and potentially `trialid`)
    # use target class setter
    tgt.trialdefinition = trl

    # In the discrete case, we have some additional work to do
    if any("DiscreteData" in str(base) for base in tgt.__class__.__mro__):

        # Compute trial-IDs by matching data samples with provided trial-bounds
        samples = tgt.data[:, tgt.dimord.index("sample")]
        idx = np.searchsorted(samples, tgt.sampleinfo.ravel())
        idx = idx.reshape(tgt.sampleinfo.shape)

        tgt._trialslice = [slice(st, end) for st, end in idx]
        # Samples not covered by any trial keep the id -1
        tgt.trialid = np.full((samples.shape), -1, dtype=int)
        for itrl, itrl_slice in enumerate(tgt._trialslice):
            tgt.trialid[itrl_slice] = itrl

    # Write log entry; identity (not equality) decides whether the trial
    # layout came from the object itself or from a separate EventData object
    if ref is tgt:
        ref.log = (
            "updated trial-definition with ["
            + " x ".join([str(numel) for numel in trl.shape])
            + "] element array"
        )
    else:
        ref_log = ref._log.replace("\n\n", "\n\t")
        tgt.log = "trial-definition extracted from EventData object: "
        tgt._log += ref_log
        tgt.cfg = {
            "method": sys._getframe().f_code.co_name,
            "EventData object": ref.cfg,
        }
        ref.log = "updated trial-defnition of {} object".format(tgt.__class__.__name__)

    return