Source code for syncopy.statistics.spike_psth

# -*- coding: utf-8 -*-
#
# Syncopy PSTH frontend
#

import numpy as np
from copy import deepcopy
import logging
import platform

# Syncopy imports
import syncopy as spy
from syncopy.shared.parsers import data_parser, scalar_parser, array_parser
from syncopy.shared.tools import get_defaults, get_frontend_cfg
from syncopy.datatype import TimeLockData

from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYInfo
from syncopy.shared.kwarg_decorators import (
    unwrap_cfg,
    unwrap_select,
    detect_parallel_client,
)
from syncopy.shared.input_processors import check_passed_kwargs
from syncopy.shared.latency import get_analysis_window, create_trial_selection

# Local imports
from syncopy.statistics.compRoutines import PSTH
from syncopy.statistics.psth import Rice_rule, sqrt_rule, get_chan_unit_combs

available_binsizes = {"rice": Rice_rule, "sqrt": sqrt_rule}
available_outputs = ["rate", "spikecount", "proportion"]


[docs]@unwrap_cfg @unwrap_select @detect_parallel_client def spike_psth( data, binsize="rice", output="rate", latency="maxperiod", vartriallen=True, keeptrials=True, **kwargs, ): """ Peristimulus time histogram Parameters ---------- data : :class:`~syncopy.SpikeData` A non-empty Syncopy :class:`~syncopy.datatype.SpikeData` object binsize : float or one of {'rice', 'sqrt'}, optional Binsize in seconds or get optimal bin width via Rice rule (`'rice'`) or square root of number of observations (`'sqrt'`) output : {'rate', 'spikecount', 'proportion'}, optional Set to `'rate'` to convert the output to firing rates (spikes/sec), 'spikecount' to count the number spikes per trial or 'proportion' to normalize the area under the PSTH to 1 Defaults to `'rate'` vartriallen : bool, optional `True` (default): accept variable trial lengths and use all available trials and the samples in every trial. Missing values (empty bins) will be ignored in the computation and results stored as NaNs `False` : only select those trials that fully cover the window as specified by `latency` and discard those trials that do not. latency : array_like or {'maxperiod', 'minperiod', 'prestim', 'poststim'} Either set desired time interval (`[begin, end]`) for spike counting in seconds, 'maxperiod' (default) for the maximum period available or `'minperiod' for the minimal time interval all trials share, or `'prestim'` (all t < 0) or `'poststim'` (all t > 0) keeptrials : bool, optional If `True` the psth's of individual trials are returned, otherwise results are averaged across trials. Returns ------- out : :class:`~syncopy.TimeLockData` Time locked data object, with additional datasets: ``out.avg`` and ``out.var`` Examples -------- `spd` is a :class:`~syncopy.SpikeData` object. Computing the rate histogram with a 0.2 seconds bin size and on the maximally available time interval, covering all events: >>> spy.spike_psth(spd, binsize=0.2, latency='maxperiod') Get the spike counts for the max. common time interval where there is activity for all trials and use the square root bin size selection rule: >>> spy.spike_psth(spd, binsize='sqrt', latency='minperiod', output='spikecount') Firing rates between 0.1 and 0.5 seconds in 50ms bins, discarding trials which do not have events in every bin of the selected latency interval: >>> spy.spike_psth(spd, binsize=0.05, latency=[0.1, 0.5], vartriallen=False) """ # Make sure our one mandatory input object can be processed try: data_parser( data, varname="data", dataclass="SpikeData", writable=None, empty=False, dimord=["sample", "channel", "unit"], ) except Exception as exc: raise exc if not isinstance(vartriallen, bool): raise SPYTypeError(vartriallen, varname="vartriallen", expected="Bool") defaults = get_defaults(spike_psth) lcls = locals() # check for ineffective additional kwargs check_passed_kwargs(lcls, defaults, frontend_name="spike_psth") # save frontend call in cfg new_cfg = get_frontend_cfg(defaults, lcls, kwargs) # digest selections if data.selection is not None: trl_def = data.selection.trialdefinition sinfo = data.selection.trialdefinition[:, :2] trials = data.selection.trials else: trl_def = data.trialdefinition sinfo = data.sampleinfo trials = data.trials trl_starts, trl_ends = data.trialintervals[:, 0], data.trialintervals[:, 1] # validate output parameter if output not in available_outputs: lgl = f"one of {available_outputs}" act = output raise SPYValueError(lgl, "output", act) if isinstance(binsize, str): if binsize not in available_binsizes: lgl = f"one of {available_binsizes}" act = output raise SPYValueError(lgl, "output", act) # --- parse and digest `latency` (time window of analysis) --- window = get_analysis_window(data, latency) # to restore later select_backup = None if data.selection is None else deepcopy(data.selection.select) if not vartriallen: # this will create/ammend the selection, respecting the latency window select, numDiscard = create_trial_selection(data, window) msg = f"Discarded {numDiscard} trials which did not fit into latency window" SPYInfo(msg) # apply the updated selection data.selectdata(select, inplace=True) # now redefine local variables trl_def = data.selection.trialdefinition sinfo = data.selection.trialdefinition[:, :2] trials = data.selection.trials else: numDiscard = 0 # --- determine overall (all selected trials) histogram shape --- # get average trial size for auto-binning av_trl_size = np.diff(sinfo).sum() / len(trials) if binsize in available_binsizes: nBins = available_binsizes[binsize](av_trl_size) bins = np.linspace(*window, nBins) else: # make sure we have at least 2 bins scalar_parser(binsize, varname="binsize", lims=[0, np.diff(window).squeeze()]) # include rightmost bin edge bins = np.arange(window[0], window[1] + binsize, binsize) nBins = len(bins) # it's a sequential loop to get an array of [chan, unit] indices combs = get_chan_unit_combs(trials) # --- populate the log log_dict = { "bins": bins, "binsize": binsize, "latency": latency, "output": output, "vartriallen": vartriallen, "numDiscard": numDiscard, } # --- set up CR --- # trl_start` and `onset` for distributing positional args to psth_cF trl_starts = trl_def[:, 0] trl_ends = trl_def[:, 1] trigger_onsets = trl_def[:, 2] psth_cR = PSTH( trl_starts, trigger_onsets, trl_ends, chan_unit_combs=combs, tbins=bins, output=output, samplerate=data.samplerate, ) # only available dimord labels ['time', 'channel']) psth_results = TimeLockData() psth_cR.initialize( data, chan_per_worker=None, out_stackingdim=psth_results._stackingDim, keeptrials=keeptrials, ) psth_cR.compute(data, psth_results, parallel=kwargs.get("parallel"), log_dict=log_dict) # calculate trial average and variance avg = spy.mean(psth_results, dim="trials", parallel=False) var = spy.var(psth_results, dim="trials", parallel=False) # attach data to TimeLockData psth_results._update_dataset("avg", avg.data) psth_results._update_dataset("var", var.data) # unregister datasets to detach from objects avg._unregister_dataset("data", del_from_file=False) var._unregister_dataset("data", del_from_file=False) # scramble filenames and delete unneeded objects avg.filename, var.filename = "", "" del avg, var # -- propagate old cfg and attach this one -- psth_results.cfg.update(data.cfg) psth_results.cfg.update({"spike_psth": new_cfg}) # finally revert possible in-place selections if select_backup is None: data.selection = None else: data.selectdata(select_backup, inplace=True) return psth_results