# -*- coding: utf-8 -*-
#
# Auxiliaries used across all of Syncopy
#

# Builtin/3rd party package imports
import numpy as np
from numbers import Number
from copy import deepcopy
import inspect
import json

# Local imports
from syncopy.shared.errors import SPYValueError, SPYWarning, SPYTypeError, SPYError
from syncopy.shared.parsers import sequence_parser

__all__ = ["StructDict", "get_defaults"]


class StructDict(dict):
    """Child-class of dict for emulating MATLAB structs

    Examples
    --------
    cfg = StructDict()
    cfg.a = [0, 25]

    """

    def __init__(self, *args, **kwargs):
        """
        Create a child-class of dict whose attributes are its keys
        (thus ensuring that attributes and items are always in sync)
        """
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        if self.keys():
            ppStr = "Syncopy StructDict\n\n"
            maxKeyLength = max([len(val) for val in self.keys()])
            printString = "{0:>" + str(maxKeyLength + 5) + "} : {1:}\n"
            for key, value in self.items():
                ppStr += printString.format(key, str(value))
            ppStr += "\nUse `dict(cfg)` for copy-paste-friendly format"
        else:
            ppStr = "{}"
        return ppStr

    def copy(self, deep=True):
        """
        Create a copy of this StructDict instance.

        Note: Overwrites the `.copy` method of the parent `dict` class, otherwise
        `copy()` will return a `dict` instead of a `StructDict`.

        Parameters
        ----------
        deep : bool
            Whether to produce a deep copy. Defaults to `True`.

        Returns
        -------
        Copy of StructDict.
        """
        if deep:
            return self.deepcopy()
        else:
            obj = type(self).__new__(self.__class__)
            obj.__dict__.update(self.__dict__)
            return obj

    def deepcopy(self):
        """
        Return a deep copy of this StructDict.

        Notes
        -----
        Call the `.copy()` method instead to get a shallow copy, though that
        seems rather uncommon.
        """
        return deepcopy(self)

    def __deepcopy__(self, memo):
        result = type(self).__new__(self.__class__)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result
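

# Hedged usage sketch (illustration only, not part of the original module): shows that
# StructDict keeps attribute- and item-access in sync and that `copy`/`deepcopy` preserve
# the StructDict type. The parameter names used below are made up for this example.
def _structdict_usage_example():
    cfg = StructDict()
    cfg.method = "mtmfft"               # attribute assignment ...
    assert cfg["method"] == "mtmfft"    # ... is mirrored as a dict item
    cfg["taper"] = "hann"               # item assignment ...
    assert cfg.taper == "hann"          # ... is mirrored as an attribute
    shallow = cfg.copy(deep=False)      # shallow copy
    deep = cfg.copy()                   # deep copy (the default)
    assert isinstance(shallow, StructDict) and isinstance(deep, StructDict)
    assert deep.method == "mtmfft"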


class SerializableDict(dict):
    """
    A dict which checks newly inserted values for JSON-serializability;
    keys should always be serializable.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # check also initial entries
        for key, value in self.items():
            self.is_json(key, value)

    def __setitem__(self, key, value):
        # try simple serialization 1st
        value = _serialize_value(value)
        self.is_json(key, value)
        dict.__setitem__(self, key, value)

    def is_json(self, key, value):
        try:
            json.dumps(value)
        except TypeError as err:
            lgl = "expected serializable data type, e.g. floats, lists, tuples, ... "
            raise SPYError(f"Wrong type of value of {key}: {err}, {lgl}")
        try:
            json.dumps(key)
        except TypeError as err:
            lgl = "expected serializable data type, e.g. floats, lists, tuples, ... "
            raise SPYError(f"Wrong type of key of {key}: {err}, {lgl}")


def _serialize_value(value):
    """
    Helper to serialize 1-level deep sequences (lists, arrays, ranges)
    or single numbers/strings passed as ``value``.

    Main task is to get rid of NumPy data types which are not
    serializable (e.g. np.int64).

    Trivial types like str are passed through unchanged.
    """

    if isinstance(value, np.ndarray):
        # converts to Python scalars
        value = value.tolist()
        return value

    if isinstance(value, range):
        value = list(value)

    # unpack the list; if types are mixed this will go wrong
    if isinstance(value, list) and len(value) != 0:
        if hasattr(value[0], "is_integer"):
            if value[0].is_integer():
                value = [int(v) for v in value]
            else:
                value = [float(v) for v in value]
        # should only be the integers
        elif isinstance(value[0], Number) and not isinstance(value[0], bool):
            value = [int(v) for v in value]

    # singleton/non-sequence type entries
    if isinstance(value, Number) and not isinstance(value, bool):
        # all floating types have this method
        if hasattr(value, "is_integer"):
            # get rid of np.int64 or np.float32
            value = int(value) if value.is_integer() else float(value)
        else:
            value = int(value)

    return value


def get_frontend_cfg(defaults, lcls, kwargs):
    """
    Assemble serializable cfg dict to allow direct replay of frontend calls.

    Most parsing is done in the respective frontends, so the config values
    should be straightforward to serialize.

    Parameters
    ----------
    defaults : dict
        The result of :func:`~get_defaults`, holding all frontend-specific
        parameter names and default values
    lcls : dict
        The `locals()` within a frontend call, contains passed parameter
        names and values
    kwargs : dict
        The `kwargs` attached to every frontend signature, holding additional
        arguments, e.g. `parallel` and `select`

    Returns
    -------
    new_cfg : :class:`~StructDict`
        Holds all (default and non-default) parameter key-value pairs
        passed to the frontend
    """

    # create new cfg dict
    new_cfg = StructDict()
    for par_name in defaults:
        # check only set parameters
        if par_name in lcls:
            value = _serialize_value(lcls[par_name])
            new_cfg[par_name] = value

    # 'select' is the only dictionary parameter allowed within kwargs;
    # we can 'pop' here as the selection got digested beforehand by @unwrap_select
    sdict = kwargs.pop("select", None)
    if sdict is not None:
        # serialized selection dict
        ser_sdict = dict()
        for sel_key in sdict:
            ser_sdict[sel_key] = _serialize_value(sdict[sel_key])
        new_cfg["select"] = ser_sdict

    # should only be 'parallel' and 'chan_per_worker'
    for key in kwargs:
        new_cfg[key] = _serialize_value(kwargs[key])

    # use instantiation for a final check
    SerializableDict(new_cfg)

    return new_cfg
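

# Hedged usage sketch (illustration only, not part of the original module): demonstrates how
# `_serialize_value` strips NumPy container/scalar types and how `get_frontend_cfg` collects
# the parameters actually passed to a frontend. The parameter names (`foi`, `tapsmofrq`) and
# the selection dict below are made up for this example.
def _frontend_cfg_usage_example():
    defaults = {"foi": None, "tapsmofrq": None}         # as produced by get_defaults(<frontend>)
    lcls = {"foi": np.arange(1.0, 5.0), "tapsmofrq": np.float64(2)}   # like locals() in a frontend
    kwargs = {"parallel": None, "select": {"channel": np.array([0, 1])}}

    cfg = get_frontend_cfg(defaults, lcls, kwargs)
    assert cfg["foi"] == [1.0, 2.0, 3.0, 4.0]           # ndarray -> plain list of Python floats
    assert cfg["tapsmofrq"] == 2                        # np.float64(2.0) -> Python int
    assert cfg["select"] == {"channel": [0, 1]}         # selection values get serialized as well
    json.dumps(cfg)                                     # the assembled cfg is JSON-serializable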


def best_match(source, selection, span=False, tol=None, squash_duplicates=False):
    """
    Find matching elements in a given 1d-array/list

    Parameters
    ----------
    source : NumPy 1d-array/list
        Reference array whose elements are to be matched by `selection`
    selection : NumPy 1d-array/list
        Array of query-values whose closest matches are to be found in `source`.
        Note that `source` and `selection` need not be the same length.
    span : bool
        If `True`, `selection` is interpreted as (closed) interval ``[lo, hi]``
        and `source` is queried for all elements contained in the interval,
        i.e., ``lo <= src <= hi for src in source`` (typically used for
        `toilim`/`foilim`-like selections).
    tol : None or float
        If `None`, for each component of `selection` the closest value in `source`
        is selected, e.g., for ``source = [10, 20]`` and ``selection = [-50, 0, 50]``
        the closest values are ``[10, 10, 20]``. If not `None`, ensures values in
        `selection` do not deviate further than `tol` from `source`. If any element
        `sel` of `selection` is outside a `tol`-neighborhood around `source`, i.e.,
        ``np.abs(sel - source).max() >= tol``, a
        :class:`~syncopy.shared.errors.SPYValueError` is raised.
    squash_duplicates : bool
        If `True`, identical matches are removed from the result.

    Returns
    -------
    values : NumPy 1darray
        Values of `source` that most closely match given elements in `selection`
    idx : NumPy 1darray
        Indices of `values` with respect to `source`, such that
        ``source[idx] == values``

    Notes
    -----
    This is an auxiliary method that is intended purely for internal use. Thus,
    no error checking is performed.

    Examples
    --------
    Exact matching, ordered `source` and `selection`:

    >>> best_match(np.arange(10), [2, 5])
    (array([2, 5]), array([2, 5]))

    Inexact matching, ordered `source` and `selection`:

    >>> source = np.arange(10)
    >>> selection = np.array([1.5, 1.5, 2.2, 6.2, 8.8])
    >>> best_match(source, selection)
    (array([2, 2, 2, 6, 9]), array([2, 2, 2, 6, 9]))

    Inexact matching, unordered `source` and `selection`:

    >>> source = np.array([2.2, 1.5, 1.5, 6.2, 8.8])
    >>> selection = np.array([1.9, 9., 1., -0.4, 1.2, 0.2, 9.3])
    >>> best_match(source, selection)
    (array([2.2, 8.8, 1.5, 1.5, 1.5, 1.5, 8.8]), array([0, 4, 1, 1, 1, 1, 4]))

    Same as above, but ignore duplicate matches:

    >>> best_match(source, selection, squash_duplicates=True)
    (array([2.2, 8.8, 1.5]), array([0, 4, 1]))

    Interval-matching:

    >>> best_match(np.arange(10), [2.9, 6.1], span=True)
    (array([3, 4, 5, 6]), array([3, 4, 5, 6]))
    """

    # Make `source` a NumPy array if necessary
    if isinstance(source, list):
        source = np.array(source)

    # If `selection` is a scalar, convert it to a 1-element list
    if np.issubdtype(type(selection), np.number):
        selection = [selection]

    # Ensure `selection` is within `tol` bounds from `source`
    if tol is not None:
        if not np.all([np.all((np.abs(source - value)) < tol) for value in selection]):
            lgl = "all elements of `selection` to be within a {0:2.4f}-band around `source`"
            act = "values in `selection` deviating further than given tolerance " + "of {0:2.4f} from source"
            raise SPYValueError(legal=lgl.format(tol), varname="selection", actual=act.format(tol))

    # Do not perform O(n) potentially unnecessary sort operations...
    issorted = True

    # Interval-selections are a lot easier than discrete time-points...
    if span:
        idx = np.intersect1d(np.where(source >= selection[0])[0], np.where(source <= selection[1])[0])
    else:
        if source.size > 1 and np.diff(source).min() < 0:
            issorted = False
            orig = np.array(source, copy=True)
            idx_orig = np.argsort(orig)
            source = orig[idx_orig]
        idx = np.searchsorted(source, selection, side="left")
        leftNbrs = np.abs(selection - source[np.maximum(idx - 1, np.zeros(idx.shape, dtype=np.intp))])
        rightNbrs = np.abs(
            selection - source[np.minimum(idx, np.full(idx.shape, source.size - 1, dtype=np.intp))]
        )
        shiftLeft = (idx == source.size) | (leftNbrs < rightNbrs)
        idx[shiftLeft] -= 1

    # Account for potentially unsorted selections (and thus unordered `idx`)
    if squash_duplicates:
        _, xdi = np.unique(idx.astype(np.intp), return_index=True)
        idx = idx[np.sort(xdi)]

    # Re-order discrete-selection index arrays in case `source` was unsorted
    if not issorted and not span:
        idx_sort = idx_orig[idx]
        return orig[idx_sort], idx_sort
    else:
        return source[idx], idx
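

# Hedged usage sketch (illustration only, not part of the original module): exercises the
# nearest-value, interval (`span=True`) and duplicate-squashing modes on a small sorted array.
def _best_match_usage_example():
    source = np.arange(10)
    # nearest-value matching: each query is mapped to its closest `source` element
    values, idx = best_match(source, [1.5, 2.2, 6.2])
    assert np.array_equal(idx, [2, 2, 6]) and np.array_equal(values, [2, 2, 6])
    # interval matching: `selection` is read as the closed interval [2.9, 6.1]
    values, idx = best_match(source, [2.9, 6.1], span=True)
    assert np.array_equal(idx, [3, 4, 5, 6])
    # duplicate matches are removed, keeping the first occurrence
    values, idx = best_match(source, [1.5, 1.5, 2.2], squash_duplicates=True)
    assert np.array_equal(idx, [2]) and np.array_equal(values, [2])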


def get_defaults(obj):
    """
    Parse input arguments of `obj` and return dictionary

    Parameters
    ----------
    obj : function or class
        Object whose input arguments to parse. Can be either a class or
        function.

    Returns
    -------
    argdict : dictionary
        Dictionary of `argument : default value` pairs constructed from
        `obj`'s call-signature/instantiation.

    Examples
    --------
    To see the default input arguments of :meth:`syncopy.freqanalysis` use

    >>> spy.get_defaults(spy.freqanalysis)
    """

    if not callable(obj):
        raise SPYTypeError(obj, varname="obj", expected="SyNCoPy function or class")
    dct = {
        k: v.default
        for k, v in inspect.signature(obj).parameters.items()
        if v.default != v.empty and v.name != "cfg"
    }
    return StructDict(dct)
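

# Hedged usage sketch (illustration only, not part of the original module): shows that
# `get_defaults` collects keyword defaults while skipping positional arguments and the
# special `cfg` argument. The function `analysis` below is made up for this example.
def _get_defaults_usage_example():
    def analysis(data, method="mtmfft", taper="hann", cfg=None):
        pass

    defaults = get_defaults(analysis)
    assert defaults == {"method": "mtmfft", "taper": "hann"}   # `data` and `cfg` are skipped
    assert isinstance(defaults, StructDict)                    # returned as attribute-accessible cfg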