# Source code for syncopy.synthdata.utils

# -*- coding: utf-8 -*-
#
# Utilities for syncopy's synthetic data generators
#

# Builtin/3rd party package imports
from inspect import signature
import numpy as np
import functools

from syncopy import AnalogData
from syncopy.shared.parsers import scalar_parser
from syncopy.shared.kwarg_decorators import (
    unwrap_cfg,
    _append_docstring,
    _append_signature,
)


def collect_trials(trial_func):
    """
    Decorator to wrap around a single trial (nSamples x nChannels shaped np.ndarray)
    synthetic data function. Trials are produced lazily by a generator, so that a
    multi-trial :class:`~syncopy.AnalogData` object can be assembled memory-safely,
    one trial at a time.

    All single trial producing functions (the ``trial_func``) should accept
    `nChannels` and `nSamples` as keyword arguments, OR provide other means
    to define those numbers, e.g. `AdjMat` for :func:`~syncopy.synth_data.ar2_network`

    If the single trial function also accepts a `samplerate` parameter, forward it directly.

    If the underlying trial generating function also accepts
    a `seed`, forward this directly. One can set `seed_per_trial=False` to use
    the same seed for all trials, or leave `seed_per_trial=True` (the default),
    to have this function internally generate an array of `nTrials` seeds
    (integers drawn from ``[0, 1_000_000)``) from the given seed, one seed per trial.

    One can set the `seed` to `None`, in which case `None` is forwarded to every
    trial; the trial function is then expected to pick a fresh random seed each
    time (so trials differ).

    Setting `nTrials=None` (the default is 100) bypasses trial collection:
    the wrapper then just returns the output of the trial generating function
    directly, i.e. a single trial :class:`numpy.ndarray`.

    Positional arguments given to the wrapper are forwarded to ``trial_func`` in
    both the single-trial and the multi-trial case.
    """

    # Inspect the trial function's signature once, at decoration time, instead
    # of on every call and for every single trial.
    tf_params = signature(trial_func).parameters
    accepts_seed = "seed" in tf_params
    accepts_samplerate = "samplerate" in tf_params

    @unwrap_cfg
    @functools.wraps(trial_func)
    def wrapper_synth(*args, nTrials=100, samplerate=1000, seed=None, seed_per_trial=True, **tf_kwargs):

        # append samplerate parameter if also needed by the generator
        if accepts_samplerate:
            tf_kwargs["samplerate"] = samplerate

        # bypass: directly return a single trial (passing on the scalar seed if
        # the function supports it). Positional args are forwarded here as well,
        # consistent with the multi-trial path below.
        if nTrials is None:
            if accepts_seed:
                tf_kwargs["seed"] = seed
            return trial_func(*args, **tf_kwargs)

        # Validate before nTrials is used as an array size / range bound, so
        # invalid input raises syncopy's parser error rather than a numpy error.
        scalar_parser(nTrials, "nTrials", ntype="int_like", lims=[1, np.inf])
        # "int_like" may admit integral floats (e.g. 2.0); range() and numpy's
        # `size=` need a real int.
        nTrials = int(nTrials)

        # One seed per trial: derive nTrials seeds from the single given seed.
        if seed is not None and seed_per_trial:
            seed_array = np.random.default_rng(seed).integers(1_000_000, size=nTrials)
        else:
            seed_array = None

        # create the trial generator (consumed lazily by AnalogData)
        def mk_trl_generator():
            for trial_idx in range(nTrials):
                if accepts_seed:
                    tf_kwargs["seed"] = seed if seed_array is None else seed_array[trial_idx]
                yield trial_func(*args, **tf_kwargs)

        return AnalogData(mk_trl_generator(), samplerate=samplerate)

    # Append the `nTrials` keyword entry to the wrapped function's docstring and signature
    nTrialsDocEntry = (
        "    nTrials : int or None\n"
        "        Number of trials for the returned :class:`~syncopy.AnalogData` object.\n"
        "        When set to `None` a single-trial :class:`~numpy.ndarray`\n"
        "        is returned."
    )

    wrapper_synth.__doc__ = _append_docstring(trial_func, nTrialsDocEntry)
    wrapper_synth.__signature__ = _append_signature(trial_func, "nTrials", kwdefault=100)

    return wrapper_synth