Source code for syncopy.datatype.methods.concat

# -*- coding: utf-8 -*-
#
# Syncopy object concatenation
#

# Builtin/3rd party package imports
import numpy as np
import h5py
import dask.distributed as dd

# Local imports
from syncopy import __acme__
from syncopy.datatype.continuous_data import ContinuousData
from syncopy.shared.parsers import data_parser
from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYWarning, SPYInfo
from syncopy.shared.computational_routine import ComputationalRoutine
from syncopy.shared.kwarg_decorators import process_io, detect_parallel_client

import dask.distributed as dd

__all__ = ["concat"]


[docs]def concat(spy_obj1, spy_obj2, dim="channel"): """ Concatenate two Syncopy data objects trial-by-trial along `dim` axis. Only `channel` axis supported at the moment! Parameters ---------- spy_obj1 : Syncopy data object spy_obj1 : Syncopy data object dim : str The concatenation dimension """ # -- sanity checks -- if not issubclass(spy_obj1.__class__, ContinuousData): raise SPYTypeError(spy_obj1, "spy_obj1", "continuous data type") data_parser(spy_obj1, empty=False) data_parser(spy_obj2, empty=False) # catches also differing classes if spy_obj1.dimord != spy_obj2.dimord: raise SPYValueError( "objects with equal dimensional layout", "spy object dimensions", f"{spy_obj1.dimord} and {spy_obj2.dimord}", ) if dim not in spy_obj1.dimord: raise SPYValueError( f"object which has a `{dim}` dimension", "spy_obj1.dimord", f"{spy_obj1.dimord}", ) if dim != "channel": raise NotImplementedError("Only `channel` concatenation supported atm") concat_axis = spy_obj1.dimord.index(dim) # check shapes shape_zip = zip(spy_obj1.data.shape, spy_obj2.data.shape) for axis, (s1, s2) in enumerate(shape_zip): if axis != concat_axis and s1 != s2: raise SPYValueError( "objects with matching shapes", "spy objects", f"{spy_obj1.data.shape} and {spy_obj2.data.shape}", ) # new size of the concatenation axis/dimension # this works for `time`, `channel`, `freq` and even `trials` new_size = len(getattr(spy_obj1, dim)) + len(getattr(spy_obj2, dim)) # these get distributed over the cF calls (one for each) to index the trials of the 2nd object obj2_idxs = [spy_obj2._preview_trial(trlno).idx for trlno in range(len(spy_obj2.trials))] # Create output object result = spy_obj1.__class__(dimord=spy_obj1.dimord) parallel = False try: dd.get_client() parallel = True except ValueError: parallel = False CR = SpyConcat( obj2_idxs, # get unpacked to `trl2_idx` for each concat_cF hdf5_path=spy_obj2.filename, axis=concat_axis, new_size=new_size, ) CR.initialize(spy_obj1, result._stackingDim, chan_per_worker=None, keeptrials=True) log_dict = {"obj1": spy_obj1.filename, "obj2": spy_obj2.filename, "dim": dim} CR.compute(spy_obj1, result, parallel=parallel, log_dict=log_dict) return result
@process_io def concat_cF( trl1_dat, trl2_idx, hdf5_path=None, axis=-1, new_size=None, noCompute=False, chunkShape=None, ): """ Concatenate 2 trials along `axis` Parameters ---------- trl1_dat : :class:`numpy.ndarray` Trial data trl2_idx : tuple A tuple of indices pointing to the 2nd trial hdf5_path : str Path to the backing hdf5 dataset, providing the 2nd trial via `trl2_idx` new_size : int New size along the concatenation axis noCompute : bool Preprocessing flag. If `True`, do not perform actual calculation but instead return expected shape and :class:`numpy.dtype` of output array. chunkShape : None or tuple If not `None`, represents shape of output Returns ------- res : :class:`numpy.ndarray` Concatenate result Notes ----- This method is intended to be used as :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. Thus, input parameters are presumed to be forwarded from a parent metafunction. Consequently, this function does **not** perform any error checking and operates under the assumption that all inputs have been externally validated and cross-checked. See also -------- _perform_computation : execute arithmetic operation SpyArithmetic : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` subclass """ if noCompute: new_shape = np.array(trl1_dat.shape) new_shape[axis] = new_size return tuple(new_shape), trl1_dat.dtype # get the 2nd trial as array with h5py.File(hdf5_path, "r") as h5file: trl2_dat = h5file["data"][trl2_idx] return np.concatenate([trl1_dat, trl2_dat], axis=axis) class SpyConcat(ComputationalRoutine): """ Compute class for performing concatenation of Syncopy objects Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute classes and metafunctions. """ computeFunction = staticmethod(concat_cF) def process_metadata(self, baseObj, out): # -- only channel concat atm -- # Get/set dimensional attributes with the help of a selection if baseObj.selection is None: revert = True baseObj.selectdata(inplace=True) else: revert = False concat_dim = baseObj.dimord[self.cfg["axis"]] # Get/set timing-related selection modifiers out.trialdefinition = baseObj.selection.trialdefinition if baseObj.selection._samplerate: out.samplerate = baseObj.samplerate for prop in [prop for prop in baseObj.selection._dimProps if prop != concat_dim]: selection = getattr(baseObj.selection, prop) if selection is not None: if np.issubdtype(type(selection), np.number): selection = [selection] setattr(out, prop, getattr(baseObj, prop)[selection]) if revert: baseObj.selection = None