diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..5bc000434 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,16 @@ +[run] +branch = True +source = syncopy + +[report] +exclude_lines = + if self.debug: + if debug: + raise NotImplementedError + if __name__ == .__main__.: +ignore_errors = True +omit = + syncopy/tests/* + *conda2pip.py + *setup.py + test_* diff --git a/.gitignore b/.gitignore index a7d9512d8..4fe07edc2 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ requirements-test.txt # Editor-related stuff .vscode + +# Mac OS related stuff +.DS_Store \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 1e52b6f42..89addbbc7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,6 @@ python: cache: pip -# safelist branches: only: - master @@ -16,6 +15,9 @@ install: - pip install -r requirements.txt - pip install -r requirements-test.txt - python setup.py -q install -# command to run tests + script: - - pytest -v + - pytest -v --cov=./ + +after_success: + - bash <(curl -s https://codecov.io/bash) || echo "Codecov did not collect coverage reports" diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..ca4291b1f --- /dev/null +++ b/codecov.yml @@ -0,0 +1,14 @@ +coverage: + status: + project: + default: + # Allow coverage to drop `threshold` percent in PRs to master/dev + target: auto + threshold: 5% + base: auto + branches: + - master + - dev + if_ci_failed: error #success, failure, error, ignore + informational: false + only_pulls: true diff --git a/dev_frontend.py b/dev_frontend.py new file mode 100644 index 000000000..d292abab2 --- /dev/null +++ b/dev_frontend.py @@ -0,0 +1,62 @@ +import numpy as np +import scipy.signal as sci + +from syncopy.datatype import CrossSpectralData, padding, SpectralData, AnalogData +from syncopy.connectivity.ST_compRoutines import cross_spectra_cF, ST_CrossSpectra +from syncopy.connectivity.ST_compRoutines import cross_covariance_cF +from syncopy.connectivity import connectivity +from syncopy.specest import freqanalysis +import matplotlib.pyplot as ppl + +from syncopy.shared.parsers import data_parser, scalar_parser, array_parser +from syncopy.shared.tools import get_defaults +from syncopy.datatype import SpectralData, padding + +from syncopy.tests.misc import generate_artificial_data +tdat = generate_artificial_data(inmemory=True, seed=1230, nTrials=50, nChannels=5) + +foilim = [1, 30] +# this still gives type(tsel) = slice :) +sdict1 = {"trials": [0], 'channels' : ['channel1'], 'toi': np.arange(-1, 1, 0.001)} + +coherence = connectivity(data=tdat, + foilim=None, + output='pow', + taper='dpss', + tapsmofrq=5, + foo = 3, # non-sensical + keeptrials=False) + +granger = connectivity(data=tdat, + method='granger', + foilim=[0, 50], + output='pow', + taper='dpss', + tapsmofrq=5, + keeptrials=False) + +# D = SpectralData(dimord=['freq','test1','test2','taper']) +# D2 = AnalogData(dimord=['freq','test1']) + +# a lot of problems here.. +# correlation = connectivity(data=tdat, method='corr', keeptrials=False, taper='df') + + +# the hard wired dimord of the cF + +res = freqanalysis(data=tdat, + method='mtmfft', + samplerate=tdat.samplerate, +# order_max=20, +# foilim=foilim, +# foi=np.arange(502), + output='pow', +# polyremoval=1, + t_ftimwin=0.5, + keeptrials=True, + taper='dpss', + nTaper = 19, + tapsmofrq=5, + keeptapers=True, + parallel=False, # try this!!!!!! 
+                    select={"trials" : [0,1]})
diff --git a/doc/source/user/data_handling.rst b/doc/source/user/data_handling.rst
index ce85eb996..efe789edf 100644
--- a/doc/source/user/data_handling.rst
+++ b/doc/source/user/data_handling.rst
@@ -22,19 +22,20 @@ Reading and writing data with Syncopy
    syncopy.load
    syncopy.save

-Functions for Editing Syncopy Data Objects
--------------------------------------------
-Defining trials, data selection and padding.
+Functions for Inspecting/Editing Syncopy Data Objects
+-----------------------------------------------------
+Defining trials, data selection and padding.

 .. autosummary::

    syncopy.definetrial
+   syncopy.show
    syncopy.selectdata
    syncopy.padding

 Advanced Topics
 ---------------
-More information about Syncopy's data class structure and file format.
+More information about Syncopy's data class structure and file format.

 .. toctree::
diff --git a/syncopy.yml b/syncopy.yml
index ae1fec83c..f6d05ea95 100644
--- a/syncopy.yml
+++ b/syncopy.yml
@@ -7,7 +7,7 @@ dependencies:
   - python >= 3.8, < 3.9
   - pip
   - numpy >= 1.10, < 2.0
-  - scipy >= 1.5, < 1.6
+  - scipy >= 1.5
   - h5py >= 2.9, < 3
   - matplotlib >= 3.3, < 3.5
   - tqdm >= 4.31
@@ -18,7 +18,7 @@ dependencies:
   - memory_profiler
   - numpydoc
   - sphinx_bootstrap_theme
-  - pytest
+  - pytest-cov
   - pylint
   - ipdb
   - tox
diff --git a/syncopy/connectivity/AV_compRoutines.py b/syncopy/connectivity/AV_compRoutines.py
new file mode 100644
index 000000000..f23228b13
--- /dev/null
+++ b/syncopy/connectivity/AV_compRoutines.py
@@ -0,0 +1,507 @@
+# -*- coding: utf-8 -*-
+#
+# computeFunctions and -Routines to post-process
+# the parallel single trial computations to be found in ST_compRoutines.py
+# The standard use case involves computations on the
+# trial average, meaning that the SyNCoPy input to these routines
+# consists of only '1 trial', and parallelising over channels
+# is non-trivial and currently not supported. Pre-processing
+# like padding or detrending already happened in the single trial
+# compute functions.
+#
+
+# Builtin/3rd party package imports
+import numpy as np
+from inspect import signature
+
+# syncopy imports
+from syncopy.shared.const_def import spectralDTypes, spectralConversions
+from syncopy.shared.computational_routine import ComputationalRoutine
+from syncopy.shared.kwarg_decorators import unwrap_io
+from syncopy.shared.errors import (
+    SPYValueError,
+)
+from syncopy.connectivity.wilson_sf import wilson_sf, regularize_csd
+from syncopy.connectivity.granger import granger
+
+
+@unwrap_io
+def normalize_csd_cF(csd_av_dat,
+                     output='abs',
+                     chunkShape=None,
+                     noCompute=False):
+
+    """
+    Given the trial averaged cross spectral densities,
+    calculates the normalizations to arrive at the
+    channel x channel coherencies. If ``S_ij(f)`` is the
+    averaged cross-spectrum between channel `i` and `j`, the
+    coherency [1]_ is defined as:
+
+    .. math::
+
+          C_{ij}(f) = S_{ij}(f) / \sqrt{S_{ii}(f) S_{jj}(f)}
+
+    The coherence is then defined as either ``|C_ij|``
+    or ``|C_ij|^2``; this can be controlled with the `output`
+    parameter.
+
+    Parameters
+    ----------
+    csd_av_dat : (1, nFreq, N, N) :class:`numpy.ndarray`
+        Cross-spectral densities for `N` x `N` channels
+        and `nFreq` frequencies averaged over trials.
+    output : {'abs', 'pow', 'fourier'}, default: 'abs'
+        Even after normalization the coherency is still complex (`'fourier'`);
+        to get the real-valued coherence ``0 < C_ij(f) < 1`` one can either take the
+        absolute (`'abs'`) or the absolute squared (`'pow'`) values of the
+        coherencies.
The definitions are not uniform in the literature, + hence multiple output types are supported. + noCompute : bool + Preprocessing flag. If `True`, do not perform actual calculation but + instead return expected shape and :class:`numpy.dtype` of output + array. + + Returns + ------- + CS_ij : (1, nFreq, N, N) :class:`numpy.ndarray` + Coherence for all channel combinations ``i,j``. + `N` corresponds to number of input channels. + + Notes + ----- + + This method is intended to be used as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. + Thus, input parameters are presumed to be forwarded from a parent metafunction. + Consequently, this function does **not** perform any error checking and operates + under the assumption that all inputs have been externally validated and cross-checked. + + .. [1] Nolte, Guido, et al. "Identifying true brain interaction from EEG + data using the imaginary part of coherency." + Clinical neurophysiology 115.10 (2004): 2292-2307. + + + See also + -------- + cross_spectra_cF : :func:`~syncopy.connectivity.ST_compRoutines.cross_spectra_cF` + Single trial (Multi-)tapered cross spectral densities. + + """ + + # it's the same as the input shape! + outShape = csd_av_dat.shape + + # For initialization of computational routine, + # just return output shape and dtype + if noCompute: + return outShape, spectralDTypes[output] + + # re-shape to (nChannels x nChannels x nFreq) + CS_ij = csd_av_dat.transpose(0, 2, 3, 1)[0, ...] + + # main diagonal has shape (nFreq x nChannels): the auto spectra + diag = CS_ij.diagonal() + + # get the needed product pairs of the autospectra + Ciijj = np.sqrt(diag[:, :, None] * diag[:, None, :]).T + CS_ij = CS_ij / Ciijj + + CS_ij = spectralConversions[output](CS_ij) + + # re-shape to original form and re-attach dummy time axis + return CS_ij[None, ...].transpose(0, 3, 1, 2) + + +class NormalizeCrossSpectra(ComputationalRoutine): + + """ + Compute class that normalizes trial averaged csd's + of :class:`~syncopy.CrossSpectralData` objects + to arrive at the respective coherencies. + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. + + See also + -------- + syncopy.connectivityanalysis : parent metafunction + """ + + # the hard wired dimord of the cF + dimord = ['time', 'freq', 'channel_i', 'channel_j'] + + computeFunction = staticmethod(normalize_csd_cF) + + method = "" # there is no backend + # 1st argument,the data, gets omitted + valid_kws = list(signature(normalize_csd_cF).parameters.keys())[1:] + + def pre_check(self): + ''' + Make sure we have a trial average, + so the input data only consists of `1 trial`. + Can only be performed after initialization! + ''' + + if self.numTrials is None: + lgl = 'Initialize the computational Routine first!' + act = 'ComputationalRoutine not initialized!' + raise SPYValueError(legal=lgl, varname=self.__class__.__name__, actual=act) + + if self.numTrials != 1: + lgl = "1 trial: normalizations can only be done on averaged quantities!" 
+ act = f"DataSet contains {self.numTrials} trials" + raise SPYValueError(legal=lgl, varname="data", actual=act) + + def process_metadata(self, data, out): + + # Some index gymnastics to get trial begin/end "samples" + if data._selection is not None: + chanSec_i = data._selection.channel_i + chanSec_j = data._selection.channel_j + trl = data._selection.trialdefinition + for row in range(trl.shape[0]): + trl[row, :2] = [row, row + 1] + else: + chanSec_i = slice(None) + chanSec_j = slice(None) + time = np.arange(len(data.trials)) + time = time.reshape((time.size, 1)) + trl = np.hstack((time, time + 1, + np.zeros((len(data.trials), 1)), + np.array(data.trialinfo))) + + # Attach constructed trialdef-array (if even necessary) + if self.keeptrials: + out.trialdefinition = trl + else: + out.trialdefinition = np.array([[0, 1, 0]]) + + # Attach remaining meta-data + out.samplerate = data.samplerate + out.channel_i = np.array(data.channel_i[chanSec_i]) + out.channel_j = np.array(data.channel_j[chanSec_j]) + out.freq = data.freq + + +@unwrap_io +def normalize_ccov_cF(trl_av_dat, + chunkShape=None, + noCompute=False): + + """ + Given the trial averaged cross-covariances, + we normalize with the 0-lag auto-covariances + (~averaged single trial variances) + to arrive at the cross-correlations. + + Parameters + ---------- + trl_av_dat : (nLag, 1, N, N) :class:`numpy.ndarray` + Cross-covariances for `N` x `N` channels + and `nLag` epochs averaged over trials. + noCompute : bool + Preprocessing flag. If `True`, do not perform actual calculation but + instead return expected shape and :class:`numpy.dtype` of output + array. + + Returns + ------- + Corr_ij : (nLag, 1, N, N) :class:`numpy.ndarray` + Cross-correlations for all channel combinations ``i,j``. + `N` corresponds to number of input channels. + + Notes + ----- + + This method is intended to be used as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. + Thus, input parameters are presumed to be forwarded from a parent metafunction. + Consequently, this function does **not** perform any error checking and operates + under the assumption that all inputs have been externally validated and cross-checked. + + See also + -------- + cross_covariance_cF : :func:`~syncopy.connectivity.ST_compRoutines.cross_covariance_cF` + Single trial cross covariances. + + """ + + # it's the same as the input shape! + outShape = trl_av_dat.shape + + # For initialization of computational routine, + # just return output shape and dtype + # cross spectra are complex! + if noCompute: + return outShape, spectralDTypes['abs'] + + # re-shape to (nLag x nChannels x nChannels) + CCov_ij = trl_av_dat[:, 0, ...] + + # main diagonal has shape (nChannels x nChannels): + # the auto-covariances at 0-lag (~stds) + diag = trl_av_dat[0, 0, ...].diagonal() + + # get the needed product pairs + Ciijj = np.sqrt(diag[:, None] * diag[None, :]).T + CCov_ij = CCov_ij / Ciijj + + # re-attach dummy freq axis + return CCov_ij[:, None, ...] + + +class NormalizeCrossCov(ComputationalRoutine): + + """ + Compute class that normalizes trial averaged + cross-covariances of :class:`~syncopy.CrossSpectralData` objects + to arrive at the respective correlations + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. 
+
+    See also
+    --------
+    syncopy.connectivityanalysis : parent metafunction
+    """
+
+    # the hard wired dimord of the cF
+    dimord = ['time', 'freq', 'channel_i', 'channel_j']
+
+    computeFunction = staticmethod(normalize_ccov_cF)
+
+    method = ""  # there is no backend
+    # 1st argument, the data, gets omitted
+    valid_kws = list(signature(normalize_ccov_cF).parameters.keys())[1:]
+
+    def pre_check(self):
+        '''
+        Make sure we have a trial average,
+        so the input data only consists of `1 trial`.
+        Can only be performed after initialization!
+        '''
+
+        if self.numTrials is None:
+            lgl = 'Initialize the computational Routine first!'
+            act = 'ComputationalRoutine not initialized!'
+            raise SPYValueError(legal=lgl, varname=self.__class__.__name__, actual=act)
+
+        if self.numTrials != 1:
+            lgl = "1 trial: normalizations can only be done on averaged quantities!"
+            act = f"DataSet contains {self.numTrials} trials"
+            raise SPYValueError(legal=lgl, varname="data", actual=act)
+
+    def process_metadata(self, data, out):
+
+        # Get trialdef array + channels from source
+        if data._selection is not None:
+            chanSec_i = data._selection.channel_i
+            chanSec_j = data._selection.channel_j
+            trl = data._selection.trialdefinition
+        else:
+            chanSec_i = slice(None)
+            chanSec_j = slice(None)
+            trl = data.trialdefinition
+
+        out.trialdefinition = trl
+        # Attach remaining meta-data
+        out.samplerate = data.samplerate
+        out.channel_i = np.array(data.channel_i[chanSec_i])
+        out.channel_j = np.array(data.channel_j[chanSec_j])
+
+
+@unwrap_io
+def granger_cF(csd_av_dat,
+               rtol=1e-8,
+               nIter=100,
+               cond_max=1e6,
+               chunkShape=None,
+               noCompute=False):
+
+    """
+    Given the trial averaged cross spectral densities,
+    calculates the pairwise Granger-Geweke causalities
+    for all (non-symmetric!) channel combinations
+    following the algorithm proposed in [1]_.
+
+    First the CSD matrix is factorized using Wilson's
+    algorithm; the resulting transfer functions and
+    noise covariance matrix are then used to calculate
+    Granger causality according to Eq. 8 in [1]_.
+
+    Selection of channels and frequencies of interest
+    can and should be done beforehand when calculating the CSDs.
+
+    Critical numerical parameters for Wilson's algorithm
+    (`rtol`, `nIter`, `cond_max`) have sensible defaults,
+    which were tested for datasets with up to
+    5000 samples and 256 channels. Changing them is
+    recommended for expert users only.
+
+    Parameters
+    ----------
+    csd_av_dat : (1, nFreq, N, N) :class:`numpy.ndarray`
+        Cross-spectral densities for `N` x `N` channels
+        and `nFreq` frequencies averaged over trials.
+    rtol : float
+        Relative error tolerance for Wilson's algorithm
+        for spectral matrix factorization. Default should
+        be fine for most cases, handle with care!
+    nIter : int
+        Maximum number of iterations for CSD factorization. A result
+        is returned once `nIter` is exhausted, even if the error
+        tolerance was not met.
+    cond_max : float
+        The maximal condition number of the spectral matrix.
+        The CSD matrix can be almost singular in cases of many channels and
+        low sample number. In these cases Wilson's factorization fails
+        to converge, as it relies on positive definiteness of the CSD matrix.
+        If the condition number is above `cond_max`, a brute force
+        regularization is performed until the regularized CSD matrix has a
+        condition number below `cond_max`.
+    noCompute : bool
+        Preprocessing flag. If `True`, do not perform actual calculation but
+        instead return expected shape and :class:`numpy.dtype` of output
+        array.
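
For orientation, a sketch of the factorization pipeline that `granger_cF` wraps, using the backend functions introduced in this changeset on a randomly generated, well-conditioned toy CSD; sizes and seeds are illustrative, and the resulting causalities are of course meaningless for random data:

import numpy as np
from syncopy.connectivity.wilson_sf import wilson_sf, regularize_csd
from syncopy.connectivity.granger import granger

rng = np.random.default_rng(42)
nFreq, nChannels = 64, 2
# Hermitian, positive-definite toy CSD for every frequency bin
X = rng.standard_normal((nFreq, nChannels, 8)) + 1j * rng.standard_normal((nFreq, nChannels, 8))
CSD = X @ X.conj().transpose(0, 2, 1)

# regularize (no-op if already well conditioned), factorize, evaluate
CSDreg, eps = regularize_csd(CSD, cond_max=1e6, eps_max=1e-3)
H, Sigma, converged = wilson_sf(CSDreg, nIter=100, rtol=1e-8)
G = granger(CSDreg, H, Sigma)   # (nFreq, N, N); G[:, i, j] is causality i -> j
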
+ + Returns + ------- + Granger : (1, nFreq, N, N) :class:`numpy.ndarray` + Spectral Granger-Geweke causality between all channel + combinations. Directionality follows array + notation: causality from ``i -> j`` is ``Granger[0,:,i,j]``, + causality from ``j -> i`` is ``Granger[0,:,j,i]`` + + Notes + ----- + + This method is intended to be used as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. + Thus, input parameters are presumed to be forwarded from a parent metafunction. + Consequently, this function does **not** perform any error checking and operates + under the assumption that all inputs have been externally validated and cross-checked. + + .. [1] Dhamala, Mukeshwar, Govindan Rangarajan, and Mingzhou Ding. + "Estimating Granger causality from Fourier and wavelet transforms + of time series data." Physical review letters 100.1 (2008): 018701. + + See also + -------- + cross_spectra_cF : :func:`~syncopy.connectivity.ST_compRoutines.cross_spectra_cF` + Single trial (Multi-)tapered cross spectral densities. Trial averages + can be obtained by calling the respective computational routine + with `keeptrials=False`. + wilson_sf : :func:`~syncopy.connectivity.wilson_sf.wilson_sf + Spectral matrix factorization that yields the + transfer functions and noise covariances + from a cross spectral density. + regularize_csd : :func:`~syncopy.connectivity.wilson_sf.regularize_csd + Brute force regularization scheme for the CSD matrix + granger : :func:`~syncopy.connectivity.granger.granger + Given the results of the spectral matrix + factorization, calculates the granger causalities + """ + + # it's the same as the input shape! + outShape = csd_av_dat.shape + + # For initialization of computational routine, + # just return output shape and dtype + # Granger causalities are real + if noCompute: + return outShape, spectralDTypes['abs'] + + # strip off singleton time dimension + # for the backend calls + CSD = csd_av_dat[0] + + # auto-regularize to `cond_max` condition number + # maximal regularization factor is 1e-3, raises a ValueError + # if this is not enough! + CSDreg, factor = regularize_csd(CSD, cond_max=cond_max, eps_max=1e-3) + # call Wilson + H, Sigma, conv = wilson_sf(CSDreg, nIter=nIter, rtol=rtol) + + # calculate G-causality + Granger = granger(CSDreg, H, Sigma) + + # reattach dummy time axis + return Granger[None, ...] + + +class GrangerCausality(ComputationalRoutine): + + """ + Compute class that computes pairwise Granger causalities + of :class:`~syncopy.CrossSpectralData` objects. + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. + + See also + -------- + syncopy.connectivityanalysis : parent metafunction + """ + + # the hard wired dimord of the cF + dimord = ['time', 'freq', 'channel_i', 'channel_j'] + + computeFunction = staticmethod(granger_cF) + + method = "" # there is no backend + # 1st argument,the data, gets omitted + valid_kws = list(signature(granger_cF).parameters.keys())[1:] + + def pre_check(self): + ''' + Make sure we have a trial average, + so the input data only consists of `1 trial`. + Can only be performed after initialization! + ''' + + if self.numTrials is None: + lgl = 'Initialize the computational Routine first!' + act = 'ComputationalRoutine not initialized!' 
+            raise SPYValueError(legal=lgl, varname=self.__class__.__name__, actual=act)
+
+        if self.numTrials != 1:
+            lgl = "1 trial: Granger causality can only be computed on trial averages!"
+            act = f"DataSet contains {self.numTrials} trials"
+            raise SPYValueError(legal=lgl, varname="data", actual=act)
+
+    def process_metadata(self, data, out):
+
+        # Some index gymnastics to get trial begin/end "samples"
+        if data._selection is not None:
+            chanSec_i = data._selection.channel_i
+            chanSec_j = data._selection.channel_j
+            trl = data._selection.trialdefinition
+            for row in range(trl.shape[0]):
+                trl[row, :2] = [row, row + 1]
+        else:
+            chanSec_i = slice(None)
+            chanSec_j = slice(None)
+            time = np.arange(len(data.trials))
+            time = time.reshape((time.size, 1))
+            trl = np.hstack((time, time + 1,
+                             np.zeros((len(data.trials), 1)),
+                             np.array(data.trialinfo)))
+
+        # Attach constructed trialdef-array (if even necessary)
+        if self.keeptrials:
+            out.trialdefinition = trl
+        else:
+            out.trialdefinition = np.array([[0, 1, 0]])
+
+        # Attach remaining meta-data
+        out.samplerate = data.samplerate
+        out.channel_i = np.array(data.channel_i[chanSec_i])
+        out.channel_j = np.array(data.channel_j[chanSec_j])
+        out.freq = data.freq
diff --git a/syncopy/connectivity/ST_compRoutines.py b/syncopy/connectivity/ST_compRoutines.py
new file mode 100644
index 000000000..54968b863
--- /dev/null
+++ b/syncopy/connectivity/ST_compRoutines.py
@@ -0,0 +1,427 @@
+# -*- coding: utf-8 -*-
+#
+# computeFunctions and -Routines for parallel calculation
+# of single trial measures needed for the averaged
+# measures like cross spectral densities
+#
+
+# Builtin/3rd party package imports
+import numpy as np
+from scipy.signal import fftconvolve, detrend
+from inspect import signature
+
+# syncopy imports
+from syncopy.specest.mtmfft import mtmfft
+from syncopy.shared.const_def import spectralDTypes
+from syncopy.shared.errors import SPYValueError
+from syncopy.datatype import padding
+from syncopy.shared.tools import best_match
+from syncopy.shared.computational_routine import ComputationalRoutine
+from syncopy.shared.kwarg_decorators import unwrap_io
+
+
+@unwrap_io
+def cross_spectra_cF(trl_dat,
+                     samplerate=1,
+                     foi=None,
+                     padding_opt={},
+                     taper="hann",
+                     taper_opt=None,
+                     polyremoval=False,
+                     timeAxis=0,
+                     norm=False,
+                     chunkShape=None,
+                     noCompute=False,
+                     fullOutput=False):
+
+    """
+    Single trial Fourier cross spectral estimates between all channels
+    of the input data. First all the individual Fourier transforms
+    are calculated via a (multi-)tapered FFT, then the pairwise
+    cross-spectra are computed.
+
+    Averaging over tapers is done implicitly
+    for multi-taper analysis with `taper="dpss"`.
+
+    Output consists of all nChannels x (nChannels + 1)/2 different complex
+    estimates arranged in a symmetric fashion (``CS_ij == CS_ji*``). The
+    elements on the main diagonal (`CS_ii`) are the (real) auto-spectra.
+
+    This is NOT the same as what is commonly referred to as
+    "cross spectral density", as there is no (time) averaging!
+    Multi-tapering alone is not necessarily sufficient to get enough
+    statistical power for a robust CSD estimate. Yet for completeness
+    and testing the option `norm=True` will output a single-trial
+    coherence estimate.
+
+    Parameters
+    ----------
+    trl_dat : (K, N) :class:`numpy.ndarray`
+        Uniformly sampled multi-channel time-series data.
+        The 1st dimension is interpreted as the time axis,
+        columns represent individual channels.
+        Dimensions can be transposed to `(N, K)` with the `timeAxis` parameter.
+    samplerate : float
+        Samplerate in Hz
+    foi : 1D :class:`numpy.ndarray` or None, optional
+        Frequencies of interest (Hz) for output. If desired frequencies
+        cannot be matched exactly the closest possible frequencies (respecting
+        data length and padding) are used.
+    padding_opt : dict
+        Parameters to be used for padding. See :func:`syncopy.padding` for
+        more details.
+    taper : str or None
+        Taper function to use, one of scipy.signal.windows.
+        Set to `None` for no tapering.
+    taper_opt : dict, optional
+        Additional keyword arguments passed to the `taper` function.
+        For multi-tapering with `taper='dpss'` set the keys
+        `'Kmax'` and `'NW'`.
+        For further details, please refer to the
+        `SciPy docs `_
+    polyremoval : int or None
+        Order of polynomial used for de-trending data in the time domain prior
+        to spectral analysis. A value of 0 corresponds to subtracting the mean
+        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
+        least squares fit of a linear polynomial).
+        If `polyremoval` is `None`, no de-trending is performed.
+    timeAxis : int, optional
+        Index of running time axis in `trl_dat` (0 or 1)
+    norm : bool, optional
+        Set to `True` to normalize for a single-trial coherence measure.
+        Only meaningful in a multi-taper (`taper="dpss"`) setup and if no
+        additional (trial-)averaging is performed afterwards.
+    noCompute : bool
+        Preprocessing flag. If `True`, do not perform actual calculation but
+        instead return expected shape and :class:`numpy.dtype` of output
+        array.
+    fullOutput : bool
+        For backend testing or stand-alone applications, set to `True`
+        to return also the `freqs` array.
+
+    Returns
+    -------
+    CS_ij : (1, nFreq, N, N) :class:`numpy.ndarray`
+        Complex cross spectra for all channel combinations ``i,j``.
+        `N` corresponds to number of input channels.
+
+    freqs : (nFreq,) :class:`numpy.ndarray`
+        The Fourier frequencies if `fullOutput=True`
+
+    Notes
+    -----
+    This method is intended to be used as
+    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
+    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
+    Thus, input parameters are presumed to be forwarded from a parent metafunction.
+    Consequently, this function does **not** perform any error checking and operates
+    under the assumption that all inputs have been externally validated and cross-checked.
+
+    See also
+    --------
+    mtmfft : :func:`~syncopy.specest.mtmfft.mtmfft`
+             (Multi-)tapered Fourier analysis
+
+    """
+
+    # Re-arrange array if necessary and get dimensional information
+    if timeAxis != 0:
+        dat = trl_dat.T  # does not copy but creates view of `trl_dat`
+    else:
+        dat = trl_dat
+
+    # Symmetric Padding (updates no. of samples)
+    if padding_opt:
+        dat = padding(dat, **padding_opt)
+
+    nChannels = dat.shape[1]
+
+    freqs = np.fft.rfftfreq(dat.shape[0], 1 / samplerate)
+
+    if foi is not None:
+        _, freq_idx = best_match(freqs, foi, squash_duplicates=True)
+        nFreq = freq_idx.size
+    else:
+        freq_idx = slice(None)
+        nFreq = freqs.size
+
+    # we always average over tapers here
+    outShape = (1, nFreq, nChannels, nChannels)
+
+    # For initialization of computational routine,
+    # just return output shape and dtype
+    # cross spectra are complex!
+ if noCompute: + return outShape, spectralDTypes["fourier"] + + # detrend + if polyremoval == 0: + # SciPy's overwrite_data not working for type='constant' :/ + dat = detrend(dat, type='constant', axis=0, overwrite_data=True) + elif polyremoval == 1: + dat = detrend(dat, type='linear', axis=0, overwrite_data=True) + + # compute the individual spectra + # specs have shape (nTapers x nFreq x nChannels) + specs, freqs = mtmfft(dat, samplerate, taper, taper_opt) + + # outer product along channel axes + # has shape (nTapers x nFreq x nChannels x nChannels) + CS_ij = specs[:, :, np.newaxis, :] * specs[:, :, :, np.newaxis].conj() + + # average tapers and transpose: + # now has shape (nChannels x nChannels x nFreq) + CS_ij = CS_ij.mean(axis=0).T + + if norm: + # only meaningful for multi-tapering + if taper != 'dpss': + msg = "Normalization of single trial csd only possible with taper='dpss'" + raise SPYValueError(legal=msg, varname="taper", actual=taper) + # main diagonal has shape (nChannels x nFreq): the auto spectra + diag = CS_ij.diagonal() + # get the needed product pairs of the autospectra + Ciijj = np.sqrt(diag[:, :, None] * diag[:, None, :]).T + CS_ij = CS_ij / Ciijj + + # where does freqs go/come from - + # we will eventually allow tuples as return values yeah! + if not fullOutput: + return CS_ij[None, ..., freq_idx].transpose(0, 3, 1, 2) + else: + return CS_ij[None, ..., freq_idx].transpose(0, 3, 1, 2), freqs[freq_idx] + + +class ST_CrossSpectra(ComputationalRoutine): + + """ + Compute class that calculates single-trial (multi-)tapered cross spectra + of :class:`~syncopy.AnalogData` objects + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. + + See also + -------- + syncopy.connectivityanalysis : parent metafunction + """ + + # the hard wired dimord of the cF + dimord = ['time', 'freq', 'channel_i', 'channel_j'] + + computeFunction = staticmethod(cross_spectra_cF) + + backends = [mtmfft] + # 1st argument,the data, gets omitted + valid_kws = list(signature(mtmfft).parameters.keys())[1:] + valid_kws += list(signature(cross_spectra_cF).parameters.keys())[1:] + # hardcode some parameter names which got digested from the frontend + valid_kws += ['tapsmofrq', 'nTaper', 'pad_to_length'] + + def process_metadata(self, data, out): + + # Some index gymnastics to get trial begin/end "samples" + if data._selection is not None: + chanSec = data._selection.channel + trl = data._selection.trialdefinition + for row in range(trl.shape[0]): + trl[row, :2] = [row, row + 1] + else: + chanSec = slice(None) + time = np.arange(len(data.trials)) + time = time.reshape((time.size, 1)) + trl = np.hstack((time, time + 1, + np.zeros((len(data.trials), 1)), + np.array(data.trialinfo))) + + # Attach constructed trialdef-array (if even necessary) + if self.keeptrials: + out.trialdefinition = trl + else: + out.trialdefinition = np.array([[0, 1, 0]]) + + # Attach remaining meta-data + out.samplerate = data.samplerate + out.channel_i = np.array(data.channel[chanSec]) + out.channel_j = np.array(data.channel[chanSec]) + out.freq = self.cfg['foi'] + + +@unwrap_io +def cross_covariance_cF(trl_dat, + samplerate=1, + padding_opt={}, + polyremoval=0, + timeAxis=0, + norm=False, + chunkShape=None, + noCompute=False, + fullOutput=False): + + """ + Single trial covariance estimates between all channels + of the input data. 
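
As an aside, the pairwise estimates in `cross_spectra_cF` above come from a single broadcasting expression; a self-contained toy sketch (shapes illustrative, not part of this changeset):

import numpy as np

rng = np.random.default_rng(0)
# (nTapers, nFreq, nChannels) complex spectra, as returned by mtmfft
specs = rng.standard_normal((3, 16, 4)) + 1j * rng.standard_normal((3, 16, 4))

# inserting two singleton axes yields all pairwise products in one step
CS_ij = specs[:, :, np.newaxis, :] * specs[:, :, :, np.newaxis].conj()
assert CS_ij.shape == (3, 16, 4, 4)   # (taper, freq, channel, channel)

CS_av = CS_ij.mean(axis=0)            # the implicit multi-taper average
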
+    Output consists of all nChannels x (nChannels + 1)/2
+    different estimates arranged in a symmetric fashion
+    (``COV_ij == COV_ji``). The elements on the
+    main diagonal (`COV_ii`) are the channel variances.
+
+    Parameters
+    ----------
+    trl_dat : (K, N) :class:`numpy.ndarray`
+        Uniformly sampled multi-channel time-series data.
+        The 1st dimension is interpreted as the time axis,
+        columns represent individual channels.
+        Dimensions can be transposed to `(N, K)` with the `timeAxis` parameter.
+    samplerate : float
+        Samplerate in Hz
+    padding_opt : dict
+        Parameters to be used for padding. See :func:`syncopy.padding` for
+        more details.
+    polyremoval : int or None
+        Order of polynomial used for de-trending data in the time domain prior
+        to spectral analysis. A value of 0 corresponds to subtracting the mean
+        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
+        least squares fit of a linear polynomial).
+        If `polyremoval` is `None`, no de-trending is performed.
+    timeAxis : int, optional
+        Index of running time axis in `trl_dat` (0 or 1)
+    norm : bool, optional
+        Set to `True` to normalize for single-trial cross-correlation.
+    noCompute : bool
+        Preprocessing flag. If `True`, do not perform actual calculation but
+        instead return expected shape and :class:`numpy.dtype` of output
+        array.
+    fullOutput : bool
+        For backend testing or stand-alone applications, set to `True`
+        to return also the `lags` array.
+
+    Returns
+    -------
+    CC_ij : (K, 1, N, N) :class:`numpy.ndarray`
+        Cross covariance for all channel combinations ``i,j``.
+        `N` corresponds to number of input channels.
+
+    lags : (M,) :class:`numpy.ndarray`
+        The lag times if `fullOutput=True`
+
+    Notes
+    -----
+    This method is intended to be used as
+    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
+    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
+    Thus, input parameters are presumed to be forwarded from a parent metafunction.
+    Consequently, this function does **not** perform any error checking and operates
+    under the assumption that all inputs have been externally validated and cross-checked.
+
+    """
+
+    # Re-arrange array if necessary and get dimensional information
+    if timeAxis != 0:
+        dat = trl_dat.T  # does not copy but creates view of `trl_dat`
+    else:
+        dat = trl_dat
+
+    # detrend
+    if polyremoval == 0:
+        # SciPy's overwrite_data not working for type='constant' :/
+        dat = detrend(dat, type='constant', axis=0, overwrite_data=True)
+    elif polyremoval == 1:
+        dat = detrend(dat, type='linear', axis=0, overwrite_data=True)
+
+    # Symmetric Padding (updates no. of samples)
+    if padding_opt:
+        dat = padding(dat, **padding_opt)
+
+    nSamples = dat.shape[0]
+    nChannels = dat.shape[1]
+
+    # positive lags in time units
+    if nSamples % 2 == 0:
+        lags = np.arange(0, nSamples // 2)
+    else:
+        lags = np.arange(0, nSamples // 2 + 1)
+    lags = lags * 1 / samplerate
+
+    outShape = (len(lags), 1, nChannels, nChannels)
+
+    # For initialization of computational routine,
+    # just return output shape and dtype
+    # cross covariances are real!
+ if noCompute: + return outShape, spectralDTypes["abs"] + + # re-normalize output for different effective overlaps + norm_overlap = np.arange(nSamples, nSamples // 2, step = -1) + + CC = np.empty(outShape) + for i in range(nChannels): + for j in range(i + 1): + cc12 = fftconvolve(dat[:, i], dat[::-1, j], mode='same') + CC[:, 0, i, j] = cc12[nSamples // 2:] / norm_overlap + if i != j: + # cross-correlation is symmetric with C(tau) = C(-tau)^T + cc21 = cc12[::-1] + CC[:, 0, j, i] = cc21[nSamples // 2:] / norm_overlap + + # normalize with products of std + if norm: + STDs = np.std(dat, axis=0) + N = STDs[:, None] * STDs[None, :] + CC = CC / N + + if not fullOutput: + return CC + else: + return CC, lags + + +class ST_CrossCovariance(ComputationalRoutine): + + """ + Compute class that calculates single-trial cross-covariances + of :class:`~syncopy.AnalogData` objects + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. + + See also + -------- + syncopy.connectivityanalysis : parent metafunction + """ + + # the hard wired dimord of the cF + dimord = ['time', 'freq', 'channel_i', 'channel_j'] + + computeFunction = staticmethod(cross_covariance_cF) + + # 1st argument,the data, gets omitted + valid_kws = list(signature(cross_covariance_cF).parameters.keys())[1:] + + def process_metadata(self, data, out): + + # Get trialdef array + channels from source: note, since lags are encoded + # in time-axis, trial offsets etc. are bogus anyway: simply take max-sample + # counts / 2 to fit lags + if data._selection is not None: + chanSec = data._selection.channel + trl = np.ceil(data._selection.trialdefinition / 2) + else: + chanSec = slice(None) + trl = np.ceil(data.trialdefinition / 2) + + # If trial-averaging was requested, use the first trial as reference + # (all trials had to have identical lengths), and average onset timings + if not self.keeptrials: + t0 = trl[:, 2].mean() + trl = trl[[0], :] + trl[:, 2] = t0 + + out.trialdefinition = trl + # Attach remaining meta-data + out.samplerate = data.samplerate + out.channel_i = np.array(data.channel[chanSec]) + out.channel_j = np.array(data.channel[chanSec]) + + diff --git a/syncopy/connectivity/__init__.py b/syncopy/connectivity/__init__.py new file mode 100644 index 000000000..b77eb5f05 --- /dev/null +++ b/syncopy/connectivity/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# +# Populate namespace with user exposed +# connectivity methods +# + +from .connectivity_analysis import connectivity +from .connectivity_analysis import __all__ as _all_ + +# Populate local __all__ namespace +__all__ = [] +__all__.extend(_all_) diff --git a/syncopy/connectivity/connectivity_analysis.py b/syncopy/connectivity/connectivity_analysis.py new file mode 100644 index 000000000..e74c57065 --- /dev/null +++ b/syncopy/connectivity/connectivity_analysis.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- +# +# Syncopy connectivity analysis methods +# + +# Builtin/3rd party package imports +import numpy as np +from numbers import Number + +# Syncopy imports +from syncopy.shared.parsers import data_parser, scalar_parser +from syncopy.shared.tools import get_defaults +from syncopy.datatype import CrossSpectralData +from syncopy.datatype.methods.padding import _nextpow2 +from syncopy.shared.errors import ( + SPYValueError, + SPYWarning, + SPYInfo) +from syncopy.shared.kwarg_decorators import (unwrap_cfg, unwrap_select, + 
detect_parallel_client)
+from syncopy.shared.input_validators import (
+    validate_taper,
+    validate_foi,
+    check_effective_parameters,
+    check_passed_kwargs
+)
+
+from .ST_compRoutines import ST_CrossSpectra, ST_CrossCovariance
+from .AV_compRoutines import NormalizeCrossSpectra, NormalizeCrossCov, GrangerCausality
+
+__all__ = ["connectivity"]
+availableMethods = ("coh", "corr", "granger")
+
+
+@unwrap_cfg
+@unwrap_select
+@detect_parallel_client
+def connectivity(data, method="coh", keeptrials=False, output="abs",
+                 foi=None, foilim=None, pad_to_length=None,
+                 polyremoval=None, taper="hann", tapsmofrq=None,
+                 nTaper=None, out=None, **kwargs):
+
+    """
+    Perform connectivity analysis of Syncopy :class:`~syncopy.AnalogData` objects
+
+    **Usage Summary**
+
+    Options available in all analysis methods:
+
+    * **foi**/**foilim** : frequencies of interest; either array of frequencies or
+      frequency window (not both)
+    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear)
+
+    List of available analysis methods and respective distinct options:
+
+    "coh" : (Multi-) tapered coherency estimate
+        Compute the normalized cross spectral densities
+        between all channel combinations
+
+        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
+        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
+        * **nTaper** : (optional) number of orthogonal tapers for Slepian tapers
+        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`
+
+    "corr" : Cross-correlations
+        Computes the one-sided (positive lags) cross-correlations
+        between all channel combinations. The maximal lag is half
+        the trial length.
+
+        * **keeptrials** : set to `True` for single trial cross-correlations
+
+    "granger" : Spectral Granger-Geweke causality
+        Computes linear causality estimates between
+        all channel combinations. The needed cross-spectral
+        densities can be computed via multi-tapering.
+
+        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
+        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
+        * **nTaper** : (optional) number of orthogonal tapers for Slepian tapers
+        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`
+
+    **Full documentation below**
+
+    """
+
+    # Make sure our one mandatory input object can be processed
+    try:
+        data_parser(data, varname="data", dataclass="AnalogData",
+                    writable=None, empty=False)
+    except Exception as exc:
+        raise exc
+    timeAxis = data.dimord.index("time")
+
+    # Get everything of interest in local namespace
+    defaults = get_defaults(connectivity)
+    lcls = locals()
+    # check for ineffective additional kwargs
+    check_passed_kwargs(lcls, defaults, frontend_name="connectivity")
+
+    # Ensure a valid computational method was selected
+    if method not in availableMethods:
+        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
+        raise SPYValueError(legal=lgl, varname="method", actual=method)
+
+    # if a subset selection is present
+    # get sampleinfo and check for equidistance
+    if data._selection is not None:
+        sinfo = data._selection.trialdefinition[:, :2]
+        trialList = data._selection.trials
+        # user picked discrete set of time points
+        if isinstance(data._selection.time[0], list):
+            lgl = "equidistant time points (toi) or time slice (toilim)"
+            actual = "non-equidistant set of time points"
+            raise SPYValueError(legal=lgl, varname="select", actual=actual)
+    else:
+        trialList = list(range(len(data.trials)))
+        sinfo = data.sampleinfo
+    lenTrials = np.diff(sinfo).squeeze()
+    numTrials = len(trialList)
+
+    # check polyremoval
+    if polyremoval is not None:
+        scalar_parser(polyremoval, varname="polyremoval", ntype="int_like", lims=[0, 1])
+
+    # --- Padding ---
+
+    if method == "corr" and pad_to_length:
+        lgl = "`None`, no padding needed/allowed for cross-correlations"
+        actual = f"{pad_to_length}"
+        raise SPYValueError(legal=lgl, varname="pad_to_length", actual=actual)
+
+    # here we check for equal-length trials as required for trial averaging;
+    # if no user-specified absolute padding length is given, we do a rough
+    # 'maxlen' padding ('nextpow2' will be overruled in this case)
+    if lenTrials.min() != lenTrials.max() and not isinstance(pad_to_length, Number):
+        pad_to_length = int(lenTrials.max())
+        msg = f"Unequal trial lengths present, automatic padding to {pad_to_length} samples"
+        SPYWarning(msg)
+
+    # symmetric zero padding of ALL trials the same way
+    if isinstance(pad_to_length, Number):
+
+        scalar_parser(pad_to_length,
+                      varname='pad_to_length',
+                      ntype='int_like',
+                      lims=[lenTrials.max(), np.inf])
+        padding_opt = {
+            'padtype' : 'zero',
+            'pad' : 'absolute',
+            'padlength' : pad_to_length
+        }
+        # after padding!
+ nSamples = pad_to_length + + # or pad to optimal FFT lengths + # (not possible for unequal lengths trials) + elif pad_to_length == 'nextpow2': + padding_opt = { + 'padtype' : 'zero', + 'pad' : 'nextpow2' + } + # after padding + nSamples = _nextpow2(int(lenTrials.min())) + # no padding + else: + padding_opt = None + nSamples = int(lenTrials.min()) + + # --- Basic foi sanitization --- + + foi, foilim = validate_foi(foi, foilim, data.samplerate) + + # only now set foi array for foilim in 1Hz steps + if foilim: + foi = np.arange(foilim[0], foilim[1] + 1) + + # Prepare keyword dict for logging (use `lcls` to get actually provided + # keyword values, not defaults set above) + log_dict = {"method": method, + "output": output, + "keeptrials": keeptrials, + "polyremoval": polyremoval, + "pad_to_length": pad_to_length} + + # --- Setting up specific Methods --- + + if method in ['coh', 'granger']: + + # --- set up computation of the single trial CSDs --- + + if keeptrials is not False: + lgl = "False, trial averaging needed!" + act = keeptrials + raise SPYValueError(lgl, varname="keeptrials", actual=act) + + if foi is None and foilim is None: + # Construct array of maximally attainable frequencies + freqs = np.fft.rfftfreq(nSamples, 1 / data.samplerate) + msg = (f"Automatic FFT frequency selection from {freqs[0]:.1f}Hz to " + f"{freqs[-1]:.1f}Hz") + SPYInfo(msg) + foi = freqs + + # sanitize taper selection and retrieve dpss settings + taper_opt = validate_taper(taper, + tapsmofrq, + nTaper, + keeptapers=False, # ST_CSD's always average tapers + foimax=foi.max(), + samplerate=data.samplerate, + nSamples=nSamples, + output="pow") # ST_CSD's always have this unit/norm + + log_dict["foi"] = foi + log_dict["taper"] = taper + # only dpss returns non-empty taper_opt dict + if taper_opt: + log_dict["nTaper"] = taper_opt["Kmax"] + log_dict["tapsmofrq"] = tapsmofrq + + check_effective_parameters(ST_CrossSpectra, defaults, lcls) + # parallel computation over trials + st_compRoutine = ST_CrossSpectra(samplerate=data.samplerate, + padding_opt=padding_opt, + taper=taper, + taper_opt=taper_opt, + polyremoval=polyremoval, + timeAxis=timeAxis, + foi=foi) + # hard coded as class attribute + st_dimord = ST_CrossSpectra.dimord + + if method == 'coh': + # final normalization after trial averaging + av_compRoutine = NormalizeCrossSpectra(output=output) + + if method == 'granger': + # after trial averaging + # hardcoded numerical parameters + av_compRoutine = GrangerCausality(rtol=1e-8, + nIter=100, + cond_max=1e5 + ) + + if method == 'corr': + check_effective_parameters(ST_CrossCovariance, defaults, lcls) + + # single trial cross-correlations + if keeptrials: + av_compRoutine = None # no trial average + norm = True # normalize individual trials within the ST CR + else: + av_compRoutine = NormalizeCrossCov() + norm = False + + # parallel computation over trials + st_compRoutine = ST_CrossCovariance(samplerate=data.samplerate, + padding_opt=padding_opt, + polyremoval=polyremoval, + timeAxis=timeAxis, + norm=norm) + # hard coded as class attribute + st_dimord = ST_CrossCovariance.dimord + + # ------------------------------------------------- + # Call the chosen single trial ComputationalRoutine + # ------------------------------------------------- + + # the single trial results need a new DataSet + st_out = CrossSpectralData(dimord=st_dimord) + + # Perform the trial-parallelized computation of the matrix quantity + st_compRoutine.initialize(data, + st_out._stackingDim, + chan_per_worker=None, # no parallelisation over 
channels possible + keeptrials=keeptrials) # we most likely need trial averaging! + st_compRoutine.compute(data, st_out, parallel=kwargs.get("parallel"), log_dict=log_dict) + + # if ever needed.. + # for single trial cross-corr results <-> keeptrials is True + if keeptrials and av_compRoutine is None: + if out is not None: + msg = "Single trial processing does not support `out` argument but directly returns the results" + SPYWarning(msg) + return st_out + + # ---------------------------------------------------------------------------------- + # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output + # ---------------------------------------------------------------------------------- + + # If provided, make sure output object is appropriate + if out is not None: + try: + data_parser(out, varname="out", writable=True, empty=True, + dataclass="CrossSpectralData", + dimord=st_dimord) + except Exception as exc: + raise exc + new_out = False + else: + out = CrossSpectralData(dimord=st_dimord) + new_out = True + + # now take the trial average from the single trial CR as input + av_compRoutine.initialize(st_out, out._stackingDim, chan_per_worker=None) + av_compRoutine.pre_check() # make sure we got a trial_average + av_compRoutine.compute(st_out, out, parallel=False, log_dict=log_dict) + + # Either return newly created output object or simply quit + return out if new_out else None diff --git a/syncopy/connectivity/granger.py b/syncopy/connectivity/granger.py new file mode 100644 index 000000000..a755a5253 --- /dev/null +++ b/syncopy/connectivity/granger.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# +# Implementation of Granger-Geweke causality +# +# +# Builtin/3rd party package imports +import numpy as np + + +def granger(CSD, Hfunc, Sigma): + """ + Computes the pairwise Granger-Geweke causalities + for all (non-symmetric!) channel combinations + according to Equation 8 in [1]_. + + The transfer functions `Hfunc` and noise covariance + `Sigma` are expected to have been already computed. + + Parameters + ---------- + CSD : (nFreq, N, N) :class:`numpy.ndarray` + Complex cross spectra for all channel combinations ``i,j`` + `N` corresponds to number of input channels. + Hfunc : (nFreq, N, N) :class:`numpy.ndarray` + Spectral transfer functions for all channel combinations ``i,j`` + Sigma : (N, N) :class:`numpy.ndarray` + The noise covariances + + Returns + ------- + Granger : (nFreq, N, N) :class:`numpy.ndarray` + Spectral Granger-Geweke causality between all channel + combinations. Directionality follows array + notation: causality from ``i -> j`` is ``Granger[:,i,j]``, + causality from ``j -> i`` is ``Granger[:,j,i]`` + + See also + -------- + wilson_sf : :func:`~syncopy.connectivity.wilson_sf.wilson_sf + Spectral matrix factorization that yields the + transfer functions and noise covariances + from a cross spectral density. + + Notes + ----- + .. [1] Dhamala, Mukeshwar, Govindan Rangarajan, and Mingzhou Ding. + "Estimating Granger causality from Fourier and wavelet transforms + of time series data." Physical review letters 100.1 (2008): 018701. 
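
For orientation, a hedged usage sketch of the `connectivity` frontend assembled above; the test-data generator is taken from `syncopy.tests.misc` as in `dev_frontend.py`, and all parameter values are illustrative:

from syncopy.connectivity import connectivity
from syncopy.tests.misc import generate_artificial_data

adata = generate_artificial_data(inmemory=True, nTrials=20, nChannels=4)

# multi-tapered coherence (trial-averaged by construction)
coh = connectivity(data=adata, method='coh', taper='dpss',
                   tapsmofrq=5, output='pow', keeptrials=False)

# spectral Granger causality on the same data
gc = connectivity(data=adata, method='granger', foilim=[0, 50],
                  taper='dpss', tapsmofrq=5, keeptrials=False)
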
+ + """ + + nChannels = CSD.shape[1] + auto_spectra = CSD.transpose(1, 2, 0).diagonal() + auto_spectra = np.abs(auto_spectra) # auto-spectra are real + + # we need the stacked auto-spectra of the form (nChannel=3): + # S_11 S_22 S_33 + # Smat(f) = S_11 S_22 S_33 + # S_11 S_22 S_33 + Smat = auto_spectra[:, None, :] * np.ones(nChannels)[:, None] + + # Granger i->j needs H_ji entry + Hmat = np.abs(Hfunc.transpose(0, 2, 1))**2 + # Granger i->j needs Sigma_ji entry + SigmaIJ = np.abs(Sigma.T)**2 + + # imag part should be 0 + auto_cov = np.abs(Sigma.diagonal()) + # same stacking as for the auto spectra (without freq axis) + SigmaII = auto_cov[:, None] * np.ones(nChannels)[:, None] + + # the denominator + denom = SigmaII.transpose() - SigmaIJ / SigmaII + denom = Smat - denom * Hmat + + # linear causality i -> j + Granger = np.log(Smat / denom) + + return Granger diff --git a/syncopy/connectivity/wilson_sf.py b/syncopy/connectivity/wilson_sf.py new file mode 100644 index 000000000..de80a1851 --- /dev/null +++ b/syncopy/connectivity/wilson_sf.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# +# Performs the numerical inner-outer factorization of a spectral matrix, using +# Wilsons method. This implementation here is a Python version of the original +# Matlab implementation by M. Dhamala (mdhamala@bme.ufl.edu) & G. Rangarajan +# (rangaraj@math.iisc.ernet.in), UF, Aug 3-4, 2006. +# +# The algorithm itself was first presented in: +# The Factorization of Matricial Spectral Densities, SIAM J. Appl. Math, +# Vol. 23, No. 4, pgs 420-426 December 1972 by G T Wilson). + +# Builtin/3rd party package imports +import numpy as np + + +def wilson_sf(CSD, nIter=100, rtol=1e-9): + """ + Wilsons spectral matrix factorization ("analytic method") + + Converges extremely fast, so the default number of + iterations should be more than enough in practical situations. + + This is a pure backend function and hence no input argument + checking is performed. + + Parameters + ---------- + CSD : (nFreq, N, N) :class:`numpy.ndarray` + Complex cross spectra for all channel combinations ``i,j``. + `N` corresponds to number of input channels. Has to be + positive definite and well conditioned. + nIter : int + Maximum number of iterations, factorization result + is returned also if error tolerance wasn't met. + rtol : float + Tolerance of the relative maximal + error of the factorization. + + Returns + ------- + Hfunc : (nFreq, N, N) :class:`numpy.ndarray` + The transfer function + Sigma : (N, N) :class:`numpy.ndarray` + Noise covariance + converged : bool + Indicates wether the algorithm converged. + If `False` result was returned after `nIter` + iterations. 
+ """ + + nFreq = CSD.shape[0] + + Ident = np.eye(*CSD.shape[1:]) + + # nChannel x nChannel + psi0 = _psi0_initial(CSD) + + # initial choice of psi, constant for all z(~f) + psi = np.tile(psi0, (nFreq, 1, 1)) + assert psi.shape == CSD.shape + + converged = False + for _ in range(nIter): + + psi_inv = np.linalg.inv(psi) + # the bracket of equation 3.1 + g = psi_inv @ CSD @ psi_inv.conj().transpose(0, 2, 1) + gplus, gplus_0 = _plusOperator(g + Ident) + + # the 'any' matrix + S = np.triu(gplus_0) + S = S - S.conj().T # S + S* = 0 + + # the next step psi_{tau+1} + psi = psi @ (gplus + S) + psi0 = psi0 @ (gplus_0 + S) + + # max relative error + CSDfac = psi @ psi.conj().transpose(0, 2, 1) + err = np.abs(CSD - CSDfac) + err = (err / np.abs(CSD)).max() + + # converged + if err < rtol: + converged = True + break + + # Noise Covariance + Sigma = psi0 @ psi0.conj().T + + # Transfer function + psi0_inv = np.linalg.inv(psi0) + Hfunc = psi @ psi0_inv.conj().T + + return Hfunc, Sigma, converged + + +def _psi0_initial(CSD): + + """ + Initialize Wilson's algorithm with the Cholesky + decomposition of the 1st Fourier series component + of the cross spectral density matrix (CSD). This is + explicitly proposed in section 4. of the original paper. + """ + + nSamples = CSD.shape[1] + + # perform ifft to obtain gammas. + gamma = np.fft.ifft(CSD, axis=0) + gamma0 = gamma[0, ...] + + # Remove any asymmetry due to rounding error. + # This also will zero out any imaginary values + # on the diagonal - real diagonals are required for cholesky. + gamma0 = np.real((gamma0 + gamma0.conj()) / 2) + + # check for positive definiteness + eivals = np.linalg.eigvals(gamma0) + if np.all(np.imag(eivals) == 0): + psi0 = np.linalg.cholesky(gamma0) + # otherwise initialize with 1's as a fallback + else: + psi0 = np.ones((nSamples, nSamples)) + + return psi0.T + + +def _plusOperator(g): + + """ + The []+ operator from definition 1.2, + given by explicit Fourier transformations + + The nFreq x nChannel x nChannel matrix `g` is given + in the frequency domain. + """ + + # 'negative lags' from the ifft + nLag = g.shape[0] // 2 + # the series expansion in beta_k + beta = np.fft.ifft(g, axis=0) + + # take half of the zero lag + beta[0, ...] = 0.5 * beta[0, ...] + g0 = beta[0, ...].copy() + + # Zero out negative lags + beta[nLag + 1:, ...] = 0 + + gp = np.fft.fft(beta, axis=0) + + return gp, g0 + + +# --- End of Wilson's Algorithm --- + + +def regularize_csd(CSD, cond_max=1e6, eps_max=1e-3, nSteps=15): + + """ + Brute force regularization of CSD matrix + by inspecting the maximal condition number + along the frequency axis. + Multiply with different ``epsilon * I``, + starting with ``epsilon = 1e-10`` until the + condition number is smaller than `cond_max`. + Raises a `ValueError` if the maximal regularization + factor `epx_max` was reached but `cond_max` still not met. + + + Parameters + ---------- + CSD : 3D :class:`numpy.ndarray` + The cross spectral density matrix + with shape ``(nFreq, nChannel, nChannel)`` + cond_max : float + The maximal condition number after regularization + eps_max : float + The largest regularization factor to be used. If + also this value does not regularize the CSD up + to `cond_max` a `ValueError` is raised. + nSteps : int + Number of steps between 1e-10 and `eps_max`. 
+ + Returns + ------- + CSDreg : 3D :class:`numpy.ndarray` + The regularized CSD matrix with a maximal + condition number of `cond_max` + eps : float + The regularization factor used + + """ + + epsilons = np.logspace(-10, np.log10(eps_max), nSteps) + I = np.eye(CSD.shape[1]) + + CondNum = np.linalg.cond(CSD).max() + + # nothing to be done + if CondNum < cond_max: + return CSD, 0 + + for eps in epsilons: + CSDreg = CSD + eps * I + CondNum = np.linalg.cond(CSDreg).max() + + if CondNum < cond_max: + return CSDreg, eps + + msg = f"CSD matrix not regularizable with a max epsilon of {eps_max}!" + raise ValueError(msg) diff --git a/syncopy/datatype/__init__.py b/syncopy/datatype/__init__.py index ae1d8b6af..03e33882a 100644 --- a/syncopy/datatype/__init__.py +++ b/syncopy/datatype/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# +# # Populate namespace with datatype routines and classes -# +# # Import __all__ routines from local modules from . import base_data, continuous_data, discrete_data, methods, statistical_data @@ -12,6 +12,7 @@ from .methods.definetrial import * from .methods.padding import * from .methods.selectdata import * +from .methods.show import * # Populate local __all__ namespace __all__ = [] @@ -22,3 +23,4 @@ __all__.extend(methods.definetrial.__all__) __all__.extend(methods.padding.__all__) __all__.extend(methods.selectdata.__all__) +__all__.extend(methods.show.__all__) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index 12769e876..32a0c325d 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -18,16 +18,20 @@ from functools import reduce import shutil import numpy as np +from numpy.lib.arraysetops import isin from numpy.lib.format import open_memmap, read_magic import h5py import scipy as sp # Local imports import syncopy as spy +from .methods.arithmetic import _process_operator +from .methods.selectdata import selectdata +from .methods.show import show from syncopy.shared.tools import StructDict from syncopy.shared.parsers import (scalar_parser, array_parser, io_parser, filename_parser, data_parser) -from syncopy.shared.errors import SPYTypeError, SPYValueError, SPYError, SPYWarning +from syncopy.shared.errors import SPYInfo, SPYTypeError, SPYValueError, SPYError, SPYWarning from syncopy.datatype.methods.definetrial import definetrial as _definetrial from syncopy import __version__, __storage__, __acme__, __sessionid__, __storagelimit__ if __acme__: @@ -67,16 +71,53 @@ class BaseData(ABC): # Dummy allocations of class attributes that are actually initialized in subclasses _mode = None + _stackingDimLabel = None # Set caller for `SPYWarning` to not have it show up as '' _spwCaller = "BaseData.{}" + # Attach data selection and output routines to make them available as class methods + selectdata = selectdata + show = show + + # Initialize hidden attributes used by all children + _cfg = {} + _filename = None + _trialdefinition = None + _dimord = None + _mode = None + _lhd = "\n\t\t>>> SyNCopy v. 
{ver:s} <<< \n\n" +\ + "Created: {timestamp:s} \n\n" +\ + "System Profile: \n" +\ + "{sysver:s} \n" +\ + "ACME: {acver:s}\n" +\ + "Dask: {daver:s}\n" +\ + "NumPy: {npver:s}\n" +\ + "SciPy: {spver:s}\n\n" +\ + "--- LOG ---" + _log_header = _lhd.format(ver=__version__, + timestamp=time.asctime(), + sysver=sys.version, + acver=acme.__version__ if __acme__ else "--", + daver=dask.__version__ if __acme__ else "--", + npver=np.__version__, + spver=sp.__version__) + _log = "" + @property @classmethod @abstractmethod def _defaultDimord(cls): return NotImplementedError + @property + def _stackingDim(self): + if any(["DiscreteData" in str(base) for base in self.__class__.__mro__]): + return 0 + else: + if self._stackingDimLabel is not None and self.dimord is not None: + return self.dimord.index(self._stackingDimLabel) + @property def cfg(self): """Dictionary of previous operations on data""" @@ -121,6 +162,7 @@ def _set_dataset_property(self, dataIn, propertyName, ndim=None): ndim = len(self._defaultDimord) supportedSetters = { + list : self._set_dataset_property_with_list, str : self._set_dataset_property_with_str, np.ndarray : self._set_dataset_property_with_ndarray, np.core.memmap : self._set_dataset_property_with_memmap, @@ -294,6 +336,85 @@ def _set_dataset_property_with_dataset(self, inData, propertyName, ndim): setattr(self, "_" + propertyName, inData) + def _set_dataset_property_with_list(self, inData, propertyName, ndim): + """Set a dataset property with list of NumPy arrays + + Parameters + ---------- + inData : list + list of :class:`numpy.ndarray`s. Each array corresponds to + a trial. Arrays are stacked together to fill dataset. + propertyName : str + Name of the property to be filled with the concatenated array + ndim : int + Number of expected array dimensions. 
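+
+        For example (a sketch): assigning ``[np.ones((10, 2)), np.ones((10, 2))]``
+        to the ``data`` property of an `AnalogData` object stacks both arrays
+        along the time axis and auto-generates a two-trial `trialdefinition`.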
+        """
+
+        # Check list entries: must be numeric, finite NumPy arrays
+        for val in inData:
+            try:
+                array_parser(val, varname="data", hasinf=False, dims=ndim)
+            except Exception as exc:
+                raise exc
+
+        # Ensure we don't have a mix of real/complex arrays
+        if np.unique([np.iscomplexobj(val) for val in inData]).size > 1:
+            lgl = "list of numeric NumPy arrays of same numeric type (real/complex)"
+            act = "real and complex NumPy arrays"
+            raise SPYValueError(legal=lgl, varname="data", actual=act)
+
+        # Requirements for input arrays differ wrt data-class (`DiscreteData` is always 2D)
+        if any(["ContinuousData" in str(base) for base in self.__class__.__mro__]):
+
+            # Ensure shapes match up
+            if any(val.shape != inData[0].shape for val in inData):
+                lgl = "NumPy arrays of identical shape"
+                act = "NumPy arrays with differing shapes"
+                raise SPYValueError(legal=lgl, varname="data", actual=act)
+            trialLens = [val.shape[self.dimord.index("time")] for val in inData]
+
+        else:
+
+            # Ensure all arrays have shape `(N, nCol)`
+            if self.__class__.__name__ == "SpikeData":
+                nCol = 3
+            else: # EventData
+                nCol = 2
+            if any(val.shape[1] != nCol for val in inData):
+                lgl = "NumPy 2d-arrays with {0:d} columns".format(nCol)
+                act = "NumPy arrays with a differing number of columns"
+                raise SPYValueError(legal=lgl, varname="data", actual=act)
+            trialLens = [np.nanmax(val[:, self.dimord.index("sample")]) for val in inData]
+
+        # Now the shaky stuff: if not provided, use determined trial lengths to
+        # cook up a (completely fictional) samplerate: we aim for `smax` Hz and
+        # round down to `sround` Hz
+        nTrials = len(trialLens)
+        msg2 = ""
+        if self.samplerate is None:
+            sround = 50
+            smax = 1000
+            srate = min(max(min(smax, tlen / 2) // sround * sround, 1) for tlen in trialLens)
+            self.samplerate = srate
+            msg2 = ", samplerate = {srate} Hz (rounded to {sround} Hz with max of {smax} Hz)"
+            msg2 = msg2.format(srate=srate, sround=sround, smax=smax)
+        t0 = -self.samplerate
+        msg = "Artificially generated trial-layout: trigger offset = {t0} sec" + msg2
+        SPYWarning(msg.format(t0=t0/self.samplerate), caller="data")
+
+        # Use constructed quantities to set up trial layout matrix
+        accumSamples = np.cumsum(trialLens)
+        trialdefinition = np.zeros((nTrials, 3))
+        trialdefinition[1:, 0] = accumSamples[:-1]
+        trialdefinition[:, 1] = accumSamples
+        trialdefinition[:, 2] = t0
+
+        # Finally, concatenate provided arrays and let the corresponding setting method
+        # perform the actual HDF magic
+        data = np.concatenate(inData, axis=self._stackingDim)
+        self._set_dataset_property_with_ndarray(data, propertyName, ndim)
+        self.trialdefinition = trialdefinition
+
     def _is_empty(self):
         return all([getattr(self, attr) is None
                     for attr in self._hdfFileDatasetProperties])
@@ -321,12 +442,19 @@ def dimord(self, dims):
             self._dimord = None
             return
 
+        # this enforces the `_defaultDimord` labels
         if set(dims) != set(self._defaultDimord):
             base = "dimensional labels {}"
             lgl = base.format("'" + "' x '".join(str(dim) for dim in self._defaultDimord) + "'")
             act = base.format("'" + "' x '".join(str(dim) for dim in dims) + "'")
             raise SPYValueError(legal=lgl, varname="dimord", actual=act)
 
+        # this enforces that custom dimords are set for every axis
+        if len(dims) != len(self._defaultDimord):
+            lgl = f"Custom dimord has length {len(self._defaultDimord)}"
+            act = f"Custom dimord has length {len(dims)}"
+            raise SPYValueError(legal=lgl, varname="dimord", actual=act)
+
         # Canonical way to perform initial allocation of dimensional properties
         # (`self._channel = None`, `self._freq = None` etc.)
self._dimord = list(dims) @@ -497,18 +625,16 @@ def trialinfo(self): def trialinfo(self, trl): raise SPYError("Cannot set trialinfo. Use `BaseData._trialdefinition` or `syncopy.definetrial` instead.") - # Selector method - @abstractmethod - def selectdata(self, trials=None, deepcopy=False, **kwargs): - """ - Docstring mostly pointing to ``selectdata`` - """ - # Helper function that grabs a single trial @abstractmethod def _get_trial(self, trialno): pass + # Helper function that creates a `FauxTrial` object given actual trial information + @abstractmethod + def _preview_trial(self, trialno): + pass + # Convenience function, wiping contents of backing device from memory def clear(self): """Clear loaded data from memory @@ -692,6 +818,107 @@ def __del__(self): shutil.rmtree(os.path.splitext(self.filename)[0], ignore_errors=True) + # Support for basic arithmetic operations (no in-place computations supported yet) + def __add__(self, other): + return _process_operator(self, other, "+") + + def __radd__(self, other): + return _process_operator(self, other, "+") + + def __sub__(self, other): + return _process_operator(self, other, "-") + + def __rsub__(self, other): + return _process_operator(self, other, "-") + + def __mul__(self, other): + return _process_operator(self, other, "*") + + def __rmul__(self, other): + return _process_operator(self, other, "*") + + def __truediv__(self, other): + return _process_operator(self, other, "/") + + def __rtruediv__(self, other): + return _process_operator(self, other, "/") + + def __pow__(self, other): + return _process_operator(self, other, "**") + + def __eq__(self, other): + + # If other object is not a Syncopy data-class, get out + if not "BaseData" in str(other.__class__.__mro__): + SPYInfo("Not a Syncopy object") + return False + + # Check if two Syncopy objects of same type/dimord are present + try: + data_parser(other, dimord=self.dimord, dataclass=self.__class__.__name__) + except Exception as exc: + SPYInfo("Syncopy object of different type/dimord") + return False + + # First, ensure we have something to compare here + if self._is_empty(): + if not other._is_empty(): + SPYInfo("Empty and non-empty Syncopy object") + return False + return True + + # If in-place selections are present, abort + if self._selection is not None or other._selection is not None: + err = "Cannot perform object comparison with existing in-place selection" + raise SPYError(err) + + # Use `_infoFileProperties` to fetch dimensional object props: remove `dimord` + # (has already been checked by `data_parser` above) and remove `cfg` (two + # objects might be identical even if their history deviates) + dimProps = [prop for prop in self._infoFileProperties if not prop.startswith("_")] + dimProps = list(set(dimProps).difference(["dimord", "cfg"])) + for prop in dimProps: + val = getattr(self, prop) + if isinstance(val, np.ndarray): + isEqual = val.tolist() == getattr(other, prop).tolist() + else: + isEqual = val == getattr(other, prop) + if not isEqual: + SPYInfo("Mismatch in {}".format(prop)) + return False + + # Check if trial setup is identical + if not np.array_equal(self.trialdefinition, other.trialdefinition): + SPYInfo("Mismatch in trial layouts") + return False + + # If an object is compared to itself (or its shallow copy), don't bother + # juggling NumPy arrays but simply perform a quick dataset/filename comparison + isEqual = True + if self.filename == other.filename: + for dsetName in self._hdfFileDatasetProperties: + val = getattr(self, dsetName) + if isinstance(val, 
h5py.Dataset): + isEqual = val == getattr(other, dsetName) + else: + isEqual = np.allclose(val, getattr(other, dsetName)) + if not isEqual: + SPYInfo("HDF dataset mismatch") + return False + return True + + # The other object really is a standalone Syncopy class instance and + # everything but the data itself aligns; now the most expensive part: + # trial by trial data comparison + for tk in range(len(self.trials)): + if not np.allclose(self.trials[tk], other.trials[tk]): + SPYInfo("Mismatch in trial #{}".format(tk)) + return False + + # If we made it this far, `self` and `other` really seem to be identical + return True + + # Class "constructor" def __init__(self, filename=None, dimord=None, mode="r+", **kwargs): """ @@ -703,11 +930,6 @@ def __init__(self, filename=None, dimord=None, mode="r+", **kwargs): """ # Initialize hidden attributes - self._cfg = {} - self._filename = None - self._trialdefinition = None - self._dimord = None - self._mode = None for propertyName in self._hdfFileDatasetProperties: setattr(self, "_" + propertyName, None) @@ -738,24 +960,7 @@ def __init__(self, filename=None, dimord=None, mode="r+", **kwargs): for propertyName in self._hdfFileDatasetProperties: setattr(self, propertyName, kwargs[propertyName]) - # Prepare log + header and write first entry - lhd = "\n\t\t>>> SyNCopy v. {ver:s} <<< \n\n" +\ - "Created: {timestamp:s} \n\n" +\ - "System Profile: \n" +\ - "{sysver:s} \n" +\ - "ACME: {acver:s}\n" +\ - "Dask: {daver:s}\n" +\ - "NumPy: {npver:s}\n" +\ - "SciPy: {spver:s}\n\n" +\ - "--- LOG ---" - self._log_header = lhd.format(ver=__version__, - timestamp=time.asctime(), - sysver=sys.version, - acver=acme.__version__ if __acme__ else "--", - daver=dask.__version__ if __acme__ else "--", - npver=np.__version__, - spver=sp.__version__) - self._log = "" + # Write initial log entry self.log = "created {clname:s} object".format(clname=self.__class__.__name__) # Write version @@ -1262,8 +1467,8 @@ def __init__(self, data, select): varname="select", actual=select) if not isinstance(select, dict): raise SPYTypeError(select, "select", expected="dict") - supported = ["trials", "channels", "toi", "toilim", "foi", "foilim", - "tapers", "units", "eventids"] + supported = ["trials", "channels", "channels_i", "channels_j", "toi", + "toilim", "foi", "foilim", "tapers", "units", "eventids"] if not set(select.keys()).issubset(supported): lgl = "dict with one or all of the following keys: '" +\ "'".join(opt + "', " for opt in supported)[:-2] @@ -1276,7 +1481,7 @@ def __init__(self, data, select): # Set up lists of (a) all selectable properties (b) trial-dependent ones # and (c) selectors independent from trials - self._allProps = ["channel", "time", "freq", "taper", "unit", "eventid"] + self._allProps = ["channel", "channel_i", "channel_j", "time", "freq", "taper", "unit", "eventid"] self._byTrialProps = ["time", "unit", "eventid"] self._dimProps = list(self._allProps) for prop in self._byTrialProps: @@ -1347,8 +1552,35 @@ def channel(self): @channel.setter def channel(self, dataselect): data, select = dataselect + chanSpec = select.get("channels") + if self._dataClass == "CrossSpectralData": + if chanSpec is not None: + lgl = "`channel_i` and/or `channel_j` selectors for `CrossSpectralData`" + raise SPYValueError(legal=lgl, varname="select: channels", actual=data.__class__.__name__) + else: + return self._selection_setter(data, select, "channel", "channels") + @property + def channel_i(self): + """List or slice encoding principal channel-pair selection""" + return 
self._channel_i + + @channel_i.setter + def channel_i(self, dataselect): + data, select = dataselect + self._selection_setter(data, select, "channel_i", "channels_i") + + @property + def channel_j(self): + """List or slice encoding principal channel-pair selection""" + return self._channel_j + + @channel_j.setter + def channel_j(self, dataselect): + data, select = dataselect + self._selection_setter(data, select, "channel_j", "channels_j") + @property def time(self): """len(self.trials) list of lists/slices of by-trial time-selections""" diff --git a/syncopy/datatype/continuous_data.py b/syncopy/datatype/continuous_data.py index 96cd07ff8..485db60c6 100644 --- a/syncopy/datatype/continuous_data.py +++ b/syncopy/datatype/continuous_data.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# +# # Syncopy's abstract base class for continuous data + regular children -# +# """Uniformly sampled (continuous data). @@ -9,25 +9,21 @@ """ # Builtin/3rd party package imports -import h5py -import os import inspect import numpy as np from abc import ABC from collections.abc import Iterator -from numpy.lib.format import open_memmap # Local imports from .base_data import BaseData, FauxTrial from .methods.definetrial import definetrial -from .methods.selectdata import selectdata from syncopy.shared.parsers import scalar_parser, array_parser -from syncopy.shared.errors import SPYValueError, SPYIOError +from syncopy.shared.errors import SPYValueError from syncopy.shared.tools import best_match from syncopy.plotting import _plot_analog from syncopy.plotting import _plot_spectral -__all__ = ["AnalogData", "SpectralData"] +__all__ = ["AnalogData", "SpectralData", "CrossSpectralData"] class ContinuousData(BaseData, ABC): @@ -38,15 +34,15 @@ class ContinuousData(BaseData, ABC): This class cannot be instantiated. Use one of the children instead. """ - + _infoFileProperties = BaseData._infoFileProperties + ("samplerate", "channel",) _hdfFileAttributeProperties = BaseData._hdfFileAttributeProperties + ("samplerate", "channel",) _hdfFileDatasetProperties = BaseData._hdfFileDatasetProperties + ("data",) - + @property def data(self): """array-like object representing data without trials - + Trials are concatenated along the time axis. 
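+        For instance, two trials of shape ``(100, nChannel)`` are stored as
+        one ``(200, nChannel)`` dataset.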
""" @@ -57,7 +53,7 @@ def data(self): raise SPYValueError(legal=lgl, actual=act.format(self.filename), varname="data") return self._data - + @data.setter def data(self, inData): @@ -73,16 +69,14 @@ def __str__(self): ppattrs = [attr for attr in ppattrs if not (inspect.ismethod(getattr(self, attr)) or isinstance(getattr(self, attr), Iterator))] - + if self.__class__.__name__ == "CrossSpectralData": + ppattrs.remove("channel") ppattrs.sort() # Construct string for pretty-printing class attributes - dsep = "' x '" - dinfo = "" + dsep = " by " hdstr = "Syncopy {clname:s} object with fields\n\n" - ppstr = hdstr.format(diminfo=dinfo + "'" + \ - dsep.join(dim for dim in self.dimord) + "' " if self.dimord is not None else "Empty ", - clname=self.__class__.__name__) + ppstr = hdstr.format(clname=self.__class__.__name__) maxKeyLength = max([len(k) for k in ppattrs]) printString = "{0:>" + str(maxKeyLength + 5) + "} : {1:}\n" for attr in ppattrs: @@ -110,7 +104,10 @@ def __str__(self): valueString = "[" + " x ".join([str(numel) for numel in value.shape]) \ + "] element " + str(type(value)) elif isinstance(value, list): - valueString = "{0} element list".format(len(value)) + if attr == "dimord" and value is not None: + valueString = dsep.join(dim for dim in self.dimord) + else: + valueString = "{0} element list".format(len(value)) elif isinstance(value, dict): msg = "dictionary with {nk:s}keys{ks:s}" keylist = value.keys() @@ -122,14 +119,13 @@ def __str__(self): ppstr += printString.format(attr, valueString) ppstr += "\nUse `.log` to see object history" return ppstr - + @property def _shapes(self): if self.sampleinfo is not None: - sid = self.dimord.index("time") shp = [list(self.data.shape) for k in range(self.sampleinfo.shape[0])] for k, sg in enumerate(self.sampleinfo): - shp[k][sid] = sg[1] - sg[0] + shp[k][self._stackingDim] = sg[1] - sg[0] return [tuple(sp) for sp in shp] @property @@ -137,28 +133,28 @@ def channel(self): """ :class:`numpy.ndarray` : list of recording channel names """ # if data exists but no user-defined channel labels, create them on the fly if self._channel is None and self._data is not None: - nChannel = self.data.shape[self.dimord.index("channel")] + nChannel = self.data.shape[self.dimord.index("channel")] return np.array(["channel" + str(i + 1).zfill(len(str(nChannel))) - for i in range(nChannel)]) + for i in range(nChannel)]) return self._channel @channel.setter - def channel(self, channel): - + def channel(self, channel): + if channel is None: self._channel = None return - + if self.data is None: raise SPYValueError("Syncopy: Cannot assign `channels` without data. 
" + - "Please assign data first") - + "Please assign data first") + try: - array_parser(channel, varname="channel", ntype="str", + array_parser(channel, varname="channel", ntype="str", dims=(self.data.shape[self.dimord.index("channel")],)) except Exception as exc: raise exc - + self._channel = np.array(channel) @property @@ -171,7 +167,7 @@ def samplerate(self, sr): if sr is None: self._samplerate = None return - + try: scalar_parser(sr, varname="samplerate", lims=[np.finfo('float').eps, np.inf]) except Exception as exc: @@ -223,31 +219,30 @@ def time(self): # Helper function that grabs a single trial def _get_trial(self, trialno): idx = [slice(None)] * len(self.dimord) - sid = self.dimord.index("time") - idx[sid] = slice(int(self.sampleinfo[trialno, 0]), int(self.sampleinfo[trialno, 1])) + idx[self._stackingDim] = slice(int(self.sampleinfo[trialno, 0]), int(self.sampleinfo[trialno, 1])) return self._data[tuple(idx)] - + def _is_empty(self): return super()._is_empty() or self.samplerate is None - - # Helper function that spawns a `FauxTrial` object given actual trial information + + # Helper function that spawns a `FauxTrial` object given actual trial information def _preview_trial(self, trialno): """ Generate a `FauxTrial` instance of a trial - + Parameters ---------- trialno : int Number of trial the `FauxTrial` object is intended to mimic - + Returns ------- faux_trl : :class:`syncopy.datatype.base_data.FauxTrial` An instance of :class:`syncopy.datatype.base_data.FauxTrial` mainly - intended to be used in `noCompute` runs of + intended to be used in `noCompute` runs of :meth:`syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` - to avoid loading actual trial-data into memory. - + to avoid loading actual trial-data into memory. 
+ See also -------- syncopy.datatype.base_data.FauxTrial : class definition and further details @@ -255,20 +250,19 @@ def _preview_trial(self, trialno): """ shp = list(self.data.shape) idx = [slice(None)] * len(self.dimord) - tidx = self.dimord.index("time") stop = int(self.sampleinfo[trialno, 1]) start = int(self.sampleinfo[trialno, 0]) - shp[tidx] = stop - start - idx[tidx] = slice(start, stop) - + shp[self._stackingDim] = stop - start + idx[self._stackingDim] = slice(start, stop) + # process existing data selections if self._selection is not None: - + # time-selection is most delicate due to trial-offset tsel = self._selection.time[self._selection.trials.index(trialno)] if isinstance(tsel, slice): if tsel.start is not None: - tstart = tsel.start + tstart = tsel.start else: tstart = 0 if tsel.stop is not None: @@ -279,15 +273,17 @@ def _preview_trial(self, trialno): # account for trial offsets an compute slicing index + shape start = start + tstart stop = start + (tstop - tstart) - idx[tidx] = slice(start, stop) - shp[tidx] = stop - start - + idx[self._stackingDim] = slice(start, stop) + shp[self._stackingDim] = stop - start + else: - idx[tidx] = [tp + start for tp in tsel] - shp[tidx] = len(tsel) + idx[self._stackingDim] = [tp + start for tp in tsel] + shp[self._stackingDim] = len(tsel) - # process the rest - for dim in ["channel", "freq", "taper"]: + # process the rest + dims = list(self.dimord) + dims.pop(self._stackingDim) + for dim in dims: sel = getattr(self._selection, dim) if sel: dimIdx = self.dimord.index(dim) @@ -308,38 +304,38 @@ def _preview_trial(self, trialno): idx[dimIdx] = slice(begin, end, delta) else: shp[dimIdx] = len(sel) - + return FauxTrial(shp, tuple(idx), self.data.dtype, self.dimord) - + # Helper function that extracts timing-related indices def _get_time(self, trials, toi=None, toilim=None): """ Get relative by-trial indices of time-selections - + Parameters ---------- trials : list List of trial-indices to perform selection on toi : None or list - Time-points to be selected (in seconds) on a by-trial scale. + Time-points to be selected (in seconds) on a by-trial scale. toilim : None or list Time-window to be selected (in seconds) on a by-trial scale - + Returns ------- timing : list of lists - List of by-trial sample-indices corresponding to provided + List of by-trial sample-indices corresponding to provided time-selection. If both `toi` and `toilim` are `None`, `timing` - is a list of universal (i.e., ``slice(None)``) selectors. - + is a list of universal (i.e., ``slice(None)``) selectors. + Notes ----- - This class method is intended to be solely used by - :class:`syncopy.datatype.base_data.Selector` objects and thus has purely + This class method is intended to be solely used by + :class:`syncopy.datatype.base_data.Selector` objects and thus has purely auxiliary character. Therefore, all input sanitization and error checking - is left to :class:`syncopy.datatype.base_data.Selector` and not - performed here. - + is left to :class:`syncopy.datatype.base_data.Selector` and not + performed here. 
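+
+        For example, ``toilim=[0, 1]`` selects, per trial, all samples whose
+        time-points fall between 0 and 1 seconds; contiguous index runs are
+        compressed into ``slice`` objects.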
+ See also -------- syncopy.datatype.base_data.Selector : Syncopy data selectors @@ -353,7 +349,7 @@ def _get_time(self, trials, toi=None, toilim=None): timing.append(slice(selTime[0], selTime[-1] + 1, 1)) else: timing.append(selTime) - + elif toi is not None: for trlno in trials: _, selTime = best_match(self.time[trlno], toi) @@ -363,32 +359,32 @@ def _get_time(self, trials, toi=None, toilim=None): if timeSteps.min() == timeSteps.max() == 1: selTime = slice(selTime[0], selTime[-1] + 1, 1) timing.append(selTime) - + else: timing = [slice(None)] * len(trials) - + return timing # Make instantiation persistent in all subclasses - def __init__(self, data=None, channel=None, samplerate=None, **kwargs): - + def __init__(self, data=None, channel=None, samplerate=None, **kwargs): + self._channel = None self._samplerate = None self._data = None - + + self.samplerate = samplerate # use setter for error-checking + # Call initializer super().__init__(data=data, **kwargs) - + self.channel = channel - self.samplerate = samplerate # use setter for error-checking - self.data = data - + if self.data is not None: # In case of manual data allocation (reading routine would leave a # mark in `cfg`), fill in missing info - if len(self.cfg) == 0: - + if self.sampleinfo is None: + # First, fill in dimensional info definetrial(self, kwargs.get("trialdefinition")) @@ -401,44 +397,28 @@ class AnalogData(ContinuousData): position etc. The data is always stored as a two-dimensional array on disk. On disk, Trials are - concatenated along the time axis. + concatenated along the time axis. Data is only read from disk on demand, similar to memory maps and HDF5 files. """ - + _infoFileProperties = ContinuousData._infoFileProperties + ("_hdr",) _defaultDimord = ["time", "channel"] - + _stackingDimLabel = "time" + # Attach plotting routines to not clutter the core module code singlepanelplot = _plot_analog.singlepanelplot multipanelplot = _plot_analog.multipanelplot - + @property def hdr(self): """dict with information about raw data - + This property is empty for data created by Syncopy. """ return self._hdr - # Selector method FIXME: use plotting-routine-like patching? - def selectdata(self, trials=None, channels=None, toi=None, toilim=None): - """ - Create new `AnalogData` object from selection - - Please refer to :func:`syncopy.selectdata` for detailed usage information. - - Examples - -------- - >>> ang2chan = ang.selectdata(channels=["channel01", "channel02"]) - - See also - -------- - syncopy.selectdata : create new objects via deep-copy selections - """ - return selectdata(self, trials=trials, channels=channels, toi=toi, toilim=toilim) - # "Constructor" def __init__(self, data=None, @@ -448,14 +428,14 @@ def __init__(self, channel=None, dimord=None): """Initialize an :class:`AnalogData` object. - + Parameters ---------- - data : 2D :class:numpy.ndarray or HDF5 dataset - multi-channel time series data with uniform sampling + data : 2D :class:numpy.ndarray or HDF5 dataset + multi-channel time series data with uniform sampling filename : str path to target filename that should be used for writing - trialdefinition : :class:`EventData` object or Mx3 array + trialdefinition : :class:`EventData` object or Mx3 array [start, stop, trigger_offset] sample indices for `M` trials samplerate : float sampling rate in Hz @@ -465,17 +445,17 @@ def __init__(self, 1. `filename` + `data` : create hdf dataset incl. sampleinfo @filename 2. 
just `data` : try to attach data (error checking done by :meth:`AnalogData.data.setter`) - + See also -------- :func:`syncopy.definetrial` - + """ - # FIXME: I think escalating `dimord` to `BaseData` should be sufficient so that + # FIXME: I think escalating `dimord` to `BaseData` should be sufficient so that # the `if any(key...) loop in `BaseData.__init__()` takes care of assigning a default dimord if data is not None and dimord is None: - dimord = self._defaultDimord + dimord = self._defaultDimord # Assign default (blank) values self._hdr = None @@ -488,57 +468,23 @@ def __init__(self, channel=channel, dimord=dimord) - # # Overload ``copy`` method to account for `VirtualData` memmaps - # def copy(self, deep=False): - # """Create a copy of the data object in memory. - - # Parameters - # ---------- - # deep : bool - # If `True`, a copy of the underlying data file is created in the temporary Syncopy folder - - - # Returns - # ------- - # AnalogData - # in-memory copy of AnalogData object - - # See also - # -------- - # save_spy - - # """ - - # cpy = copy(self) - - # if deep: - # if isinstance(self.data, VirtualData): - # print("SyNCoPy core - copy: Deep copy not possible for " + - # "VirtualData objects. Please use `save_spy` instead. ") - # return - # elif isinstance(self.data, (np.memmap, h5py.Dataset)): - # self.data.flush() - # filename = self._gen_filename() - # shutil.copyfile(self._filename, filename) - # cpy.data = filename - # return cpy - class SpectralData(ContinuousData): - """Multi-channel, real or complex spectral data + """ + Multi-channel, real or complex spectral data This class can be used for representing any data with a frequency, channel, and optionally a time axis. The datatype can be complex or float. - """ - + _infoFileProperties = ContinuousData._infoFileProperties + ("taper", "freq",) _defaultDimord = ["time", "taper", "freq", "channel"] + _stackingDimLabel = "time" # Attach plotting routines to not clutter the core module code singlepanelplot = _plot_spectral.singlepanelplot multipanelplot = _plot_spectral.multipanelplot - + @property def taper(self): """ :class:`numpy.ndarray` : list of window functions used """ @@ -550,22 +496,21 @@ def taper(self): @taper.setter def taper(self, tpr): - + if tpr is None: self._taper = None return - + if self.data is None: print("Syncopy core - taper: Cannot assign `taper` without data. "+\ "Please assing data first") - return - + try: array_parser(tpr, dims=(self.data.shape[self.dimord.index("taper")],), varname="taper", ntype="str", ) except Exception as exc: raise exc - + self._taper = np.array(tpr) @property @@ -578,47 +523,28 @@ def freq(self): @freq.setter def freq(self, freq): - + if freq is None: self._freq = None return - + if self.data is None: print("Syncopy core - freq: Cannot assign `freq` without data. "+\ "Please assing data first") return try: - + array_parser(freq, varname="freq", hasnan=False, hasinf=False, dims=(self.data.shape[self.dimord.index("freq")],)) except Exception as exc: raise exc - + self._freq = np.array(freq) - # Selector method - def selectdata(self, trials=None, channels=None, toi=None, toilim=None, - foi=None, foilim=None, tapers=None): - """ - Create new `SpectralData` object from selection - - Please refer to :func:`syncopy.selectdata` for detailed usage information. 
- - Examples - -------- - >>> spcBand = spc.selectdata(foilim=[10, 40]) - - See also - -------- - syncopy.selectdata : create new objects via deep-copy selections - """ - return selectdata(self, trials=trials, channels=channels, toi=toi, - toilim=toilim, foi=foi, foilim=foilim, tapers=tapers) - # Helper function that extracts frequency-related indices def _get_freq(self, foi=None, foilim=None): """ - Coming soon... + Coming soon... Error checking is performed by `Selector` class """ if foilim is not None: @@ -626,7 +552,7 @@ def _get_freq(self, foi=None, foilim=None): selFreq = selFreq.tolist() if len(selFreq) > 1: selFreq = slice(selFreq[0], selFreq[-1] + 1, 1) - + elif foi is not None: _, selFreq = best_match(self.freq, foi) selFreq = selFreq.tolist() @@ -634,12 +560,12 @@ def _get_freq(self, foi=None, foilim=None): freqSteps = np.diff(selFreq) if freqSteps.min() == freqSteps.max() == 1: selFreq = slice(selFreq[0], selFreq[-1] + 1, 1) - + else: selFreq = slice(None) - + return selFreq - + # "Constructor" def __init__(self, data=None, @@ -653,11 +579,11 @@ def __init__(self, self._taper = None self._freq = None - + # FIXME: See similar comment above in `AnalogData.__init__()` if data is not None and dimord is None: dimord = self._defaultDimord - + # Call parent initializer super().__init__(data=data, filename=filename, @@ -684,3 +610,159 @@ def __init__(self, self.freq = [1] if taper is not None: self.taper = ['taper'] + + +class CrossSpectralData(ContinuousData): + """ + Multi-channel real or complex spectral connectivity data + + This class can be used for representing channel-channel interactions involving + frequency and optionally time or lag. The datatype can be complex or float. + """ + + # Adapt `infoFileProperties` and `hdfFileAttributeProperties` from `ContinuousData` + _infoFileProperties = BaseData._infoFileProperties +\ + ("samplerate", "channel_i", "channel_j", "freq", ) + _hdfFileAttributeProperties = BaseData._hdfFileAttributeProperties +\ + ("samplerate", "channel_i", "channel_j", "freq", ) + _defaultDimord = ["time", "freq", "channel_i", "channel_j"] + _stackingDimLabel = "time" + _channel_i = None + _channel_j = None + _samplerate = None + _data = None + + # Steal frequency-related stuff from `SpectralData` + _get_freq = SpectralData._get_freq + freq = SpectralData.freq + + # override channel property to avoid accidental access + @property + def channel(self): + return "see channel_i and channel_j" + + @channel.setter + def channel(self, channel): + if channel is None: + pass + else: + msg = f"CrossSpectralData has no 'channel' to set but dimord: {self._dimord}" + raise NotImplementedError(msg) + + @property + def channel_i(self): + """ :class:`numpy.ndarray` : list of recording channel names """ + # if data exists but no user-defined channel labels, create them on the fly + if self._channel_i is None and self._data is not None: + nChannel = self.data.shape[self.dimord.index("channel_i")] + return np.array(["channel_i-" + str(i + 1).zfill(len(str(nChannel))) + for i in range(nChannel)]) + + return self._channel_i + + @channel_i.setter + def channel_i(self, channel_i): + """ :class:`numpy.ndarray` : list of channel labels """ + if channel_i is None: + self._channel_i = None + return + + if self.data is None: + raise SPYValueError("Syncopy: Cannot assign `channels` without data. 
" + + "Please assign data first") + + try: + array_parser(channel_i, varname="channel_i", ntype="str", + dims=(self.data.shape[self.dimord.index("channel_i")],)) + except Exception as exc: + raise exc + + self._channel_i = np.array(channel_i) + + @property + def channel_j(self): + """ :class:`numpy.ndarray` : list of recording channel names """ + # if data exists but no user-defined channel labels, create them on the fly + if self._channel_j is None and self._data is not None: + nChannel = self.data.shape[self.dimord.index("channel_j")] + return np.array(["channel_j-" + str(i + 1).zfill(len(str(nChannel))) + for i in range(nChannel)]) + + return self._channel_j + + @channel_j.setter + def channel_j(self, channel_j): + """ :class:`numpy.ndarray` : list of channel labels """ + if channel_j is None: + self._channel_j = None + return + + if self.data is None: + raise SPYValueError("Syncopy: Cannot assign `channels` without data. " + + "Please assign data first") + + try: + array_parser(channel_j, varname="channel_j", ntype="str", + dims=(self.data.shape[self.dimord.index("channel_j")],)) + except Exception as exc: + raise exc + + self._channel_j = np.array(channel_j) + + # # Local 2d -> 1d channel index converter + # def _ind2sub(self, channel1, channel2): + # """Convert 2d channel tuple to linear 1d index""" + + # chanIdx = [] + # for ck, channel in enumerate((channel1, channel2)): + # target = getattr(self, "_channel{}".format(ck + 1)) + # if isinstance(channel, str): + # if channel == "all": + # channel = None + # else: + # raise SPYValueError(legal="'all' or `None` or list/array", + # varname="channels", actual=channel) + # if channel is None: + # channel = target + # if isinstance(channel, range): + # channel = list(channel) + # elif isinstance(channel, slice): + # channel = target[channel] + + # # Use set comparison to ensure (a) no mixed-type selections (['a', 2, 'c']) + # # and (b) no invalid selections ([-99, 0.01]) + # if not set(channel).issubset(target): + # lgl = "list/array of existing channel names or indices" + # raise SPYValueError(legal=lgl, varname="channel") + # if not all(isinstance(c, str) for c in channel): + # target = np.arange(target.size) + + # # Preserve order and duplicates of selection - don't use `np.isin` here! + # chanIdx.append([np.where(target == c)[0] for c in channel]) + + # # Almost: `ravel_multi_index` expects a tuple of arrays, so perform some zipping + # linearIndex = [(c1, c2) for c1 in chanIdx[0] for c2 in chanIdx[1]] + # return np.ravel_multi_index(tuple(zip(*linearIndex)), + # dims=(self._channel1.size, self._channel2.size)) + + def __init__(self, + data=None, + filename=None, + channel_i=None, + channel_j=None, + samplerate=None, + freq=None, + dimord=None): + + # Set dimensional labels + self.dimord = dimord + # set frequencies + self.freq = freq + + # Call parent initializer + super().__init__(data=data, + filename=filename, + samplerate=samplerate, + freq=freq, + dimord=dimord) + diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index bdb1bd53a..c8ae89521 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# +# # Syncopy's abstract base class for discrete data + regular children -# +# # Builtin/3rd party package imports import numpy as np @@ -36,7 +36,7 @@ class DiscreteData(BaseData, ABC): @property def data(self): """array-like object representing data without trials - + Trials are concatenated along the time axis. 
""" @@ -47,7 +47,7 @@ def data(self): raise SPYValueError(legal=lgl, actual=act.format(self.filename), varname="data") return self._data - + @data.setter def data(self, inData): @@ -56,24 +56,20 @@ def data(self, inData): if inData is None: return - def __str__(self): + def __str__(self): # Get list of print-worthy attributes ppattrs = [attr for attr in self.__dir__() if not (attr.startswith("_") or attr in ["log", "trialdefinition", "hdr"])] ppattrs = [attr for attr in ppattrs if not (inspect.ismethod(getattr(self, attr)) or isinstance(getattr(self, attr), Iterator))] - + ppattrs.sort() # Construct string for pretty-printing class attributes - dinfo = " '" + self._classname_to_extension()[1:] + "' x " - dsep = "'-'" - + dsep = " by " hdstr = "Syncopy {clname:s} object with fields\n\n" - ppstr = hdstr.format(diminfo=dinfo + "'" + \ - dsep.join(dim for dim in self.dimord) + "' " if self.dimord is not None else "Empty ", - clname=self.__class__.__name__) + ppstr = hdstr.format(clname=self.__class__.__name__) maxKeyLength = max([len(k) for k in ppattrs]) printString = "{0:>" + str(maxKeyLength + 5) + "} : {1:}\n" for attr in ppattrs: @@ -101,7 +97,10 @@ def __str__(self): valueString = "[" + " x ".join([str(numel) for numel in value.shape]) \ + "] element " + str(type(value)) elif isinstance(value, list): - valueString = "{0} element list".format(len(value)) + if attr == "dimord" and value is not None: + valueString = dsep.join(dim for dim in self.dimord) + else: + valueString = "{0} element list".format(len(value)) elif isinstance(value, dict): msg = "dictionary with {nk:s}keys{ks:s}" keylist = value.keys() @@ -112,7 +111,7 @@ def __str__(self): valueString = str(value) ppstr += printString.format(attr, valueString) ppstr += "\nUse `.log` to see object history" - return ppstr + return ppstr @property def hdr(self): @@ -139,7 +138,7 @@ def samplerate(self, sr): if sr is None: self._samplerate = None return - + try: scalar_parser(sr, varname="samplerate", lims=[1, np.inf]) except Exception as exc: @@ -156,7 +155,7 @@ def trialid(self, trlid): if trlid is None: self._trialid = None return - + if self.data is None: print("SyNCoPy core - trialid: Cannot assign `trialid` without data. " + "Please assing data first") @@ -190,69 +189,69 @@ def trialtime(self): # Helper function that grabs a single trial def _get_trial(self, trialno): return self._data[self.trialid == trialno, :] - - # Helper function that spawns a `FauxTrial` object given actual trial information + + # Helper function that spawns a `FauxTrial` object given actual trial information def _preview_trial(self, trialno): """ Generate a `FauxTrial` instance of a trial - + Parameters ---------- trialno : int Number of trial the `FauxTrial` object is intended to mimic - + Returns ------- faux_trl : :class:`syncopy.datatype.base_data.FauxTrial` An instance of :class:`syncopy.datatype.base_data.FauxTrial` mainly - intended to be used in `noCompute` runs of + intended to be used in `noCompute` runs of :meth:`syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` - to avoid loading actual trial-data into memory. - + to avoid loading actual trial-data into memory. 
+ See also -------- syncopy.datatype.base_data.FauxTrial : class definition and further details syncopy.shared.computational_routine.ComputationalRoutine : Syncopy compute engine """ - + trialIdx = np.where(self.trialid == trialno)[0] nCol = len(self.dimord) idx = [trialIdx.tolist(), slice(0, nCol)] if self._selection is not None: # selections are harmonized, just take `.time` idx[0] = trialIdx[self._selection.time[self._selection.trials.index(trialno)]].tolist() shp = [len(idx[0]), nCol] - + return FauxTrial(shp, tuple(idx), self.data.dtype, self.dimord) - + # Helper function that extracts by-trial timing-related indices def _get_time(self, trials, toi=None, toilim=None): """ Get relative by-trial indices of time-selections - + Parameters ---------- trials : list List of trial-indices to perform selection on toi : None or list - Time-points to be selected (in seconds) on a by-trial scale. + Time-points to be selected (in seconds) on a by-trial scale. toilim : None or list Time-window to be selected (in seconds) on a by-trial scale - + Returns ------- timing : list of lists - List of by-trial sample-indices corresponding to provided + List of by-trial sample-indices corresponding to provided time-selection. If both `toi` and `toilim` are `None`, `timing` - is a list of universal (i.e., ``slice(None)``) selectors. - + is a list of universal (i.e., ``slice(None)``) selectors. + Notes ----- - This class method is intended to be solely used by - :class:`syncopy.datatype.base_data.Selector` objects and thus has purely + This class method is intended to be solely used by + :class:`syncopy.datatype.base_data.Selector` objects and thus has purely auxiliary character. Therefore, all input sanitization and error checking - is left to :class:`syncopy.datatype.base_data.Selector` and not - performed here. - + is left to :class:`syncopy.datatype.base_data.Selector` and not + performed here. + See also -------- syncopy.datatype.base_data.Selector : Syncopy data selectors @@ -275,7 +274,7 @@ def _get_time(self, trials, toi=None, toilim=None): if sampSteps.min() == sampSteps.max() == 1: idxList = slice(idxList[0], idxList[-1] + 1, 1) timing.append(idxList) - + elif toi is not None: allTrials = self.trialtime for trlno in trials: @@ -296,32 +295,31 @@ def _get_time(self, trials, toi=None, toilim=None): if sampSteps.min() == sampSteps.max() == 1: idxList = slice(idxList[0], idxList[-1] + 1, 1) timing.append(idxList) - + else: timing = [slice(None)] * len(trials) - + return timing def __init__(self, data=None, samplerate=None, trialid=None, **kwargs): # Assign (default) values self._trialid = None - self._samplerate = None + self._samplerate = None self._hdr = None self._data = None + self.samplerate = samplerate + self.trialid = trialid + # Call initializer super().__init__(data=data, **kwargs) - self.samplerate = samplerate - self.trialid = trialid - self.data = data - if self.data is not None: # In case of manual data allocation (reading routine would leave a # mark in `cfg`), fill in missing info - if len(self.cfg) == 0: + if self.sampleinfo is None: # Fill in dimensional info definetrial(self, kwargs.get("trialdefinition")) @@ -332,7 +330,7 @@ class SpikeData(DiscreteData): This class can be used for representing spike trains. The data is always stored as a two-dimensional [nSpikes x 3] array on disk with the columns - being ``["sample", "channel", "unit"]``. + being ``["sample", "channel", "unit"]``. Data is only read from disk on demand, similar to memory maps and HDF5 files. 
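+
+    A minimal sketch of the expected layout (hypothetical values; assumes
+    `numpy` is imported as ``np``):
+
+    >>> arr = np.array([[10, 0, 0], [35, 1, 1]])   # columns: sample, channel, unit
+    >>> spk = SpikeData(data=arr, samplerate=1000)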
@@ -342,16 +340,17 @@ class SpikeData(DiscreteData):
     _infoFileProperties = DiscreteData._infoFileProperties + ("channel", "unit",)
     _hdfFileAttributeProperties = DiscreteData._hdfFileAttributeProperties + ("channel",)
     _defaultDimord = ["sample", "channel", "unit"]
-
+    _stackingDimLabel = "sample"
+
     @property
     def channel(self):
-        """ :class:`numpy.ndarray` : list of original channel names for each unit"""
+        """ :class:`numpy.ndarray` : list of original channel names for each unit"""
        # if data exists but no user-defined channel labels, create them on the fly
        if self._channel is None and self._data is not None:
            channelNumbers = np.unique(self.data[:, self.dimord.index("channel")])
            return np.array(["channel" + str(int(i + 1)).zfill(len(str(channelNumbers.max() + 1)))
                             for i in channelNumbers])
-
+
         return self._channel

     @channel.setter
@@ -361,22 +360,22 @@ def channel(self, chan):
             return
         if self.data is None:
             raise SPYValueError("Syncopy: Cannot assign `channels` without data. " +
-                                "Please assign data first")
+                                "Please assign data first")
         try:
             array_parser(chan, varname="channel", ntype="str")
         except Exception as exc:
             raise exc
-
+
         # Remove duplicate entries from channel array but preserve original order
         # (e.g., `[2, 0, 0, 1]` -> `[2, 0, 1]`); allows for complex subset-selections
         _, idx = np.unique(chan, return_index=True)
-        chan = np.array(chan)[idx]
+        chan = np.array(chan)[np.sort(idx)]
         nchan = np.unique(self.data[:, self.dimord.index("channel")]).size
         if chan.size != nchan:
             lgl = "channel label array of length {0:d}".format(nchan)
             act = "array of length {0:d}".format(chan.size)
             raise SPYValueError(legal=lgl, varname="channel", actual=act)
-
+
         self._channel = chan

     @property
@@ -393,11 +392,11 @@ def unit(self, unit):
         if unit is None:
             self._unit = None
             return
-
+
         if self.data is None:
             raise SPYValueError("Syncopy - SpikeData - unit: Cannot assign `unit` without data. " +
                                 "Please assign data first")
-
+
         nunit = np.unique(self.data[:, self.dimord.index("unit")]).size
         try:
             array_parser(unit, varname="unit", ntype="str", dims=(nunit,))
@@ -405,51 +404,33 @@ def unit(self, unit):
             raise exc
         self._unit = np.array(unit)

-    # Selector method
-    def selectdata(self, trials=None, toi=None, toilim=None, units=None, channels=None):
-        """
-        Create new `SpikeData` object from selection
-
-        Please refer to :func:`syncopy.selectdata` for detailed usage information.
-
-        Examples
-        --------
-        >>> spkUnit01 = spk.selectdata(units=[0, 1])
-
-        See also
-        --------
-        syncopy.selectdata : create new objects via deep-copy selections
-        """
-        return selectdata(self, trials=trials, channels=channels, toi=toi,
-                          toilim=toilim, units=units)
-
     # Helper function that extracts by-trial unit-indices
     def _get_unit(self, trials, units=None):
         """
         Get relative by-trial indices of unit selections
-
+
         Parameters
         ----------
         trials : list
             List of trial-indices to perform selection on
         units : None or list
             List of unit-indices to be selected
-
+
         Returns
         -------
         indices : list of lists
-            List of by-trial sample-indices corresponding to provided
-            unit-selection. If `units` is `None`, `indices` is a list of universal
-            (i.e., ``slice(None)``) selectors.
-
+            List of by-trial sample-indices corresponding to provided
+            unit-selection. If `units` is `None`, `indices` is a list of universal
+            (i.e., ``slice(None)``) selectors.
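+            For example, ``units=[0, 1]`` yields, per trial, the row indices
+            of all spikes fired by units 0 and 1.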
+ Notes ----- - This class method is intended to be solely used by - :class:`syncopy.datatype.base_data.Selector` objects and thus has purely + This class method is intended to be solely used by + :class:`syncopy.datatype.base_data.Selector` objects and thus has purely auxiliary character. Therefore, all input sanitization and error checking - is left to :class:`syncopy.datatype.base_data.Selector` and not - performed here. - + is left to :class:`syncopy.datatype.base_data.Selector` and not + performed here. + See also -------- syncopy.datatype.base_data.Selector : Syncopy data selectors @@ -469,7 +450,7 @@ def _get_unit(self, trials, units=None): indices.append(trialUnits) else: indices = [slice(None)] * len(trials) - + return indices # "Constructor" @@ -489,13 +470,13 @@ def __init__(self, filename : str path to filename or folder (spy container) - trialdefinition : :class:`EventData` object or nTrials x 3 array + trialdefinition : :class:`EventData` object or nTrials x 3 array [start, stop, trigger_offset] sample indices for `M` trials samplerate : float sampling rate in Hz channel : str or list/array(str) original channel names - unit : str or list/array(str) + unit : str or list/array(str) names of all units dimord : list(str) ordered list of dimension labels @@ -514,7 +495,7 @@ def __init__(self, self._unit = None self._channel = None - + # Call parent initializer super().__init__(data=data, filename=filename, @@ -536,61 +517,45 @@ class EventData(DiscreteData): Data is only read from disk on demand, similar to memory maps and HDF5 files. - """ - + """ + _defaultDimord = ["sample", "eventid"] - + _stackingDimLabel = "sample" + @property def eventid(self): """numpy.ndarray(int): integer event code assocated with each event""" if self.data is None: return None return np.unique(self.data[:, self.dimord.index("eventid")]) - - # Selector method - def selectdata(self, trials=None, toi=None, toilim=None, eventids=None): - """ - Create new `EventData` object from selection - - Please refer to :func:`syncopy.selectdata` for detailed usage information. - - Examples - -------- - >>> evtStimOn = evt.selectdata(eventids=[1]) - - See also - -------- - syncopy.selectdata : create new objects via deep-copy selections - """ - return selectdata(self, trials=trials, toi=toi, toilim=toilim, eventids=eventids) # Helper function that extracts by-trial eventid-indices def _get_eventid(self, trials, eventids=None): """ Get relative by-trial indices of event-id selections - + Parameters ---------- trials : list List of trial-indices to perform selection on eventids : None or list List of event-id-indices to be selected - + Returns ------- indices : list of lists - List of by-trial sample-indices corresponding to provided - event-id-selection. If `eventids` is `None`, `indices` is a list of - universal (i.e., ``slice(None)``) selectors. - + List of by-trial sample-indices corresponding to provided + event-id-selection. If `eventids` is `None`, `indices` is a list of + universal (i.e., ``slice(None)``) selectors. + Notes ----- - This class method is intended to be solely used by - :class:`syncopy.datatype.base_data.Selector` objects and thus has purely + This class method is intended to be solely used by + :class:`syncopy.datatype.base_data.Selector` objects and thus has purely auxiliary character. Therefore, all input sanitization and error checking - is left to :class:`syncopy.datatype.base_data.Selector` and not - performed here. 
-
+        is left to :class:`syncopy.datatype.base_data.Selector` and not
+        performed here.
+
         See also
         --------
         syncopy.datatype.base_data.Selector : Syncopy data selectors
@@ -610,9 +575,9 @@ def _get_eventid(self, trials, eventids=None):
                 indices.append(trialEvents)
         else:
             indices = [slice(None)] * len(trials)
-
+
         return indices
-
+
     # "Constructor"
     def __init__(self,
                  data=None,
@@ -628,10 +593,10 @@ def __init__(self,
 
         filename : str
             path to filename or folder (spy container)
-        trialdefinition : :class:`EventData` object or nTrials x 3 array
+        trialdefinition : :class:`EventData` object or nTrials x 3 array
            [start, stop, trigger_offset] sample indices for the `nTrials` trials
         samplerate : float
-            sampling rate in Hz
+            sampling rate in Hz
         dimord : list(str)
            ordered list of dimension labels
diff --git a/syncopy/datatype/methods/arithmetic.py b/syncopy/datatype/methods/arithmetic.py
new file mode 100644
index 000000000..f05178094
--- /dev/null
+++ b/syncopy/datatype/methods/arithmetic.py
@@ -0,0 +1,503 @@
+# -*- coding: utf-8 -*-
+#
+# Syncopy object arithmetic
+#
+
+# Builtin/3rd party package imports
+import numbers
+import numpy as np
+import h5py
+
+# Local imports
+from syncopy import __acme__
+from syncopy.shared.parsers import data_parser
+from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYWarning
+from syncopy.shared.computational_routine import ComputationalRoutine
+from syncopy.shared.kwarg_decorators import unwrap_io
+if __acme__:
+    import dask.distributed as dd
+
+__all__ = []
+
+
+# Main entry point for overloaded operators
+def _process_operator(obj1, obj2, operator):
+    """
+    Perform a binary arithmetic operation on a Syncopy data object
+
+    Parameters
+    ----------
+    obj1 : Syncopy data class or Python object
+        Depending on left/right application of the arithmetic operator, `obj1` may be
+        either a Syncopy class or any Python object
+    obj2 : Syncopy data class or Python object
+        Depending on left/right application of the arithmetic operator, `obj2` may be
+        either a Syncopy class or any Python object
+    operator : str
+        Operation to be performed encoded as string. Currently supported operators
+        are `'+'`, `'-'`, `'*'`, `'/'` and `'**'` (i.e., `'pow'`).
+
+    Returns
+    -------
+    res : Syncopy object
+        Result of the arithmetic operation
+
+    Notes
+    -----
+    All arithmetic operations are performed on a per-trial basis. This means that
+    any data not covered by a Syncopy object's `trialdefinition` will not be
+    affected by the arithmetic operation.
+    Note further that error checking is only performed on a very basic level, i.e.,
+    the code ensures that instances of different classes are not mashed together
+    (e.g., ``AnalogData + SpectralData``) and that objects have compatible trial
+    counts and dtypes (no mixing of complex/real data). However, as long as trial
+    shapes align, it is possible to process objects with diverging `samplerate`,
+    `channels`, `freqs` etc. The reason for this object parsing leniency is that
+    it might be interesting/necessary to manipulate objects arising from different
+    configurations (e.g., subtract channel `x` in `obj1` from channel `y` in `obj2`).
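+    As a sketch, ``adata + 1`` (with ``adata`` being any non-empty `AnalogData`
+    instance) returns a new object in which every sample of every (selected)
+    trial has been incremented by one.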
+
+    See also
+    --------
+    _parse_input : prepare objects for arithmetic operations
+    _perform_computation : execute arithmetic operation
+    """
+    baseObj, operand, operand_dat, opres_type, operand_idxs = _parse_input(obj1, obj2, operator)
+    return _perform_computation(baseObj, operand, operand_dat, operand_idxs, opres_type, operator)
+
+
+# Error checking and input preparation
+def _parse_input(obj1, obj2, operator):
+    """
+    Prepare objects for performing binary arithmetic
+
+    Parameters
+    ----------
+    obj1 : Syncopy data class or Python object
+        See :func:`_process_operator` for details.
+    obj2 : Syncopy data class or Python object
+        See :func:`_process_operator` for details.
+    operator : str
+        See :func:`_process_operator` for details.
+
+    Returns
+    -------
+    baseObj : Syncopy data object
+        The "base" object to perform arithmetic on. By default, the left object
+        is considered as base (if possible), i.e., in the expression ``data1 + data2``,
+        `data1` is defined as the base object
+    operand : Syncopy data object, scalar or array-like
+        Term to perform the arithmetic operation with.
+    operand_dat : dict or scalar or array-like
+        If `operand` is a scalar, list or NumPy ndarray then ``operand_dat == operand``.
+        If `operand` is a Syncopy object, then `operand_dat` is a dictionary with
+        keys `"filename"` (pointing to the HDF5 backing device of `operand`) and
+        `"dsetname"` (name of the corresponding dataset(s) of `operand`).
+    opres_type : dtype
+        Numerical type of the Syncopy object resulting from applying the arithmetic
+        operation.
+    operand_idxs : None or list
+        If `operand` is a scalar, list or NumPy ndarray then `operand_idxs` is
+        `None`. If `operand` is a Syncopy object, then `operand_idxs` is a list
+        containing the array indices of `operand`'s data (subset) for each (selected)
+        trial.
+
+    Note
+    ----
+    The distinction between `baseObj` and `operand` is not only syntactic sugar
+    but has consequences if both `baseObj` and `operand` are Syncopy objects:
+    the `baseObj` is allowed to come with any valid subset selection (may require
+    advanced indexing involving multiple slice/list combinations, might include
+    repetitions and be unordered). Conversely, the `operand` object can only
+    contain `simple` selections (no fancy indexing allowed, no repetitions or
+    unordered selections). This restriction simplifies the required HDF dataset
+    indexing considerably.
+    """
+
+    # Determine which input is a Syncopy object (depending on left/right application of
+    # operator, i.e., `data + 1` or `1 + data`).
Can be both as well, but we just need + # one `baseObj` to get going + if "BaseData" in str(obj1.__class__.__mro__): + baseObj = obj1 + operand = obj2 + elif "BaseData" in str(obj2.__class__.__mro__): + baseObj = obj2 + operand = obj1 + + # Ensure base object is not discrete + if "DiscreteData" in str(baseObj.__class__.__mro__): + lgl = "`AnalogData`, `SpectralData` or `CrossSpectralData`" + raise SPYTypeError(baseObj, varname="base", expected=lgl) + + # Ensure our base object is not empty + try: + data_parser(baseObj, varname="base", empty=False) + except Exception as exc: + raise exc + + # If no active selection is present, create a "fake" all-to-all selection + # to harmonize processing down the road (and attach `_cleanup` attribute for later removal) + if baseObj._selection is None: + baseObj.selectdata(inplace=True) + baseObj._selection._cleanup = True + baseTrialList = baseObj._selection.trials + + # Use the `_preview_trial` functionality of Syncopy objects to get each trial's + # shape and dtype (existing selections are taken care of automatically) + baseTrials = [baseObj._preview_trial(trlno) for trlno in baseTrialList] + + # Depending on the what is thrown at `baseObj` perform more or less extensive parsing + # First up: operand is a scalar + if isinstance(operand, numbers.Number): + + # Don't allow `np.inf` manipulations and catch zero-divisions + if np.isinf(operand): + raise SPYValueError("finite scalar", varname="operand", actual=str(operand)) + if operator == "/" and operand == 0: + raise SPYValueError("non-zero scalar for division", varname="operand", actual=str(operand)) + + # Ensure complex and real values are not mashed up + _check_complex_operand(baseTrials, operand, "scalar") + + # Determine exact numeric type of operation's result + opres_type = np.result_type(*(trl.dtype for trl in baseTrials), operand) + + # That's it set output vars + operand_dat = operand + operand_idxs = None + + # Operand is array-like + elif isinstance(operand, (np.ndarray, list)): + + # First, ensure operand is a NumPy array to make things easier + operand = np.array(operand) + + # Ensure complex and real values are not mashed up + _check_complex_operand(baseTrials, operand, "array") + + # Determine exact numeric type of the operation's result + opres_type = np.result_type(*(trl.dtype for trl in baseTrials), operand.dtype) + + # Ensure shapes match up + if not all(trl.shape == operand.shape for trl in baseTrials): + lgl = "array of compatible shape" + act = "array with shape {}" + raise SPYValueError(lgl, varname="operand", actual=act.format(operand.shape)) + + # No more info needed, the array is the only quantity we need + operand_dat = operand + operand_idxs = None + + # All good, nevertheless warn of potential havoc this operation may cause... 
+        # All good; nevertheless, warn of potential havoc this operation may cause...
+        msg = "Performing arithmetic with NumPy arrays may cause inconsistency " +\
+              "in Syncopy objects (channels, samplerate, trialdefinitions etc.)"
+        SPYWarning(msg, caller=operator)
+
+    # Operand is another Syncopy object
+    elif "BaseData" in str(operand.__class__.__mro__):
+
+        # Ensure operand object class, and `dimord` match up (and it's non-empty)
+        try:
+            data_parser(operand, varname="operand", dimord=baseObj.dimord,
+                        dataclass=baseObj.__class__.__name__, empty=False)
+        except Exception as exc:
+            raise exc
+
+        # Make sure samplerates are identical (if present)
+        baseSr = getattr(baseObj, "samplerate")
+        opndSr = getattr(operand, "samplerate")
+        if baseSr != opndSr:
+            lgl = "Syncopy objects with identical samplerate"
+            act = "Syncopy object with samplerates {} and {}, respectively"
+            raise SPYValueError(lgl, varname="operand",
+                                actual=act.format(baseSr, opndSr))
+
+        # If only a subset of `operand` is selected, adjust for this
+        if operand._selection is not None:
+            opndTrialList = operand._selection.trials
+        else:
+            opndTrialList = list(range(len(operand.trials)))
+
+        # Ensure the same number of trials is about to be processed
+        opndTrials = [operand._preview_trial(trlno) for trlno in opndTrialList]
+        if len(opndTrials) != len(baseTrials):
+            lgl = "Syncopy object with same number of trials (selected)"
+            act = "Syncopy object with {} trials (selected)"
+            raise SPYValueError(lgl, varname="operand", actual=act.format(len(opndTrials)))
+
+        # Ensure complex and real values are not mashed up
+        baseIsComplex = ["complex" in trl.dtype.name for trl in baseTrials]
+        opndIsComplex = ["complex" in trl.dtype.name for trl in opndTrials]
+        if baseIsComplex != opndIsComplex:
+            lgl = "Syncopy data object of same numerical type (real/complex)"
+            raise SPYTypeError(operand, varname="operand", expected=lgl)
+
+        # Determine the numeric type of the operation's result
+        opres_type = np.result_type(*(trl.dtype for trl in baseTrials),
+                                    *(trl.dtype for trl in opndTrials))
+
+        # Ensure shapes align
+        if not all(baseTrials[k].shape == opndTrials[k].shape for k in range(len(baseTrials))):
+            lgl = "Syncopy object (selection) of compatible shapes {}"
+            act = "Syncopy object (selection) with shapes {}"
+            baseShapes = [trl.shape for trl in baseTrials]
+            opndShapes = [trl.shape for trl in opndTrials]
+            raise SPYValueError(lgl.format(baseShapes), varname="operand",
+                                actual=act.format(opndShapes))
+
+        # Avoid things becoming too nasty: if ``operand`` contains wild selections
+        # (unordered lists or index repetitions) or selections requiring advanced
+        # (aka fancy) indexing (multiple slices mixed with lists), abort
+        for trl in opndTrials:
+            if any(np.diff(sel).min() <= 0 if isinstance(sel, list) and len(sel) > 1 \
+                else False for sel in trl.idx):
+                lgl = "Syncopy object with ordered, repetition-free subset selection"
+                act = "Syncopy object with selection {}"
+                raise SPYValueError(lgl, varname="operand", actual=act.format(operand._selection))
+            if sum(isinstance(sel, slice) for sel in trl.idx) > 1 and \
+                sum(isinstance(sel, list) for sel in trl.idx) > 1:
+                lgl = "Syncopy object without selections requiring advanced indexing"
+                act = "Syncopy object with selection {}"
+                raise SPYValueError(lgl, varname="operand", actual=act.format(operand._selection))
+
+        # Propagate indices for fetching data from operand
+        operand_idxs = [trl.idx for trl in opndTrials]
+
+        # Assemble dict with relevant info for performing operation
+        operand_dat = {"filename" : operand.filename,
+                       "dsetname" : operand._hdfFileDatasetProperties[0]}
+
+    # 
If `operand` is anything else it's invalid for performing arithmetic on + else: + lgl = "Syncopy object, scalar or array-like" + raise SPYTypeError(operand, varname="operand", expected=lgl) + + return baseObj, operand, operand_dat, opres_type, operand_idxs + +# Check for complexity in `operand` vs. `baseObj` +def _check_complex_operand(baseTrials, operand, opDimType): + """ + Local helper to determine if provided scalar/array and `baseObj` are both real/complex + """ + + # Ensure complex and real values are not mashed up + if np.iscomplexobj(operand): + sameType = lambda dt : "complex" in dt.name + else: + sameType = lambda dt : "complex" not in dt.name + if not all(sameType(trl.dtype) for trl in baseTrials): + lgl = "{} of same mathematical type (real/complex)" + raise SPYTypeError(operand, varname="operand", expected=lgl.format(opDimType)) + + return + + +# Invoke `ComputationalRoutine` to compute arithmetic operation +def _perform_computation(baseObj, + operand, + operand_dat, + operand_idxs, + opres_type, + operator): + """ + Leverage `ComputationalRoutine` to process arithmetic operation + + Parameters + ---------- + baseObj : Syncopy data object + See :func:`_parse_input` for details. + operand : Syncopy data object, scalar or array-like + See :func:`_parse_input` for details. + operand_dat : dict or scalar or array-like + See :func:`_parse_input` for details. + opres_type : dtype + See :func:`_parse_input` for details. + operator : str + See :func:`_process_operator` for details. + + Returns + ------- + out : Syncopy data object + Result of performing arithmetic operation on `baseObj` and `operand` + + Note + ---- + This method instantiates a subclass of + :class:`~syncopy.shared.computational_routine.ComputationalRoutine` + to perform arithmetic operations on Syncopy objects either sequentially or + in parallel. Note that due to this code being only invoked via operator + overloading the `@detect_parallel_client` decorator is *not* invoked, since + the user cannot supply any keyword arguments. Instead, the code scans for + running dask distributed computing clients (if ACME is available) and uses + concurrent processing if a client is found. 
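+
+    In pseudo-code, this client detection boils down to the following sketch
+    (assuming ACME and `dask.distributed` are installed):
+
+    >>> import dask.distributed as dd
+    >>> try:
+    ...     dd.get_client()
+    ...     parallel = True
+    ... except ValueError:
+    ...     parallel = False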
+ + See also + -------- + arithmetic_cF : `computeFunction` performing arithmetics + SpyArithmetic : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` subclass + """ + + # Prepare logging info in dictionary: we know that `baseObj` is definitely + # a Syncopy data object, operand may or may not be; account for this + if "BaseData" in str(operand.__class__.__mro__): + opSel = operand._selection + else: + opSel = None + log_dct = {"operator": operator, + "base": baseObj.__class__.__name__, + "base selection": baseObj._selection, + "operand": operand.__class__.__name__, + "operand selection": opSel} + + # Create output object + out = baseObj.__class__(dimord=baseObj.dimord) + + # Now create actual functional operations: wrap operator in lambda + if operator == "+": + operation = lambda x, y : x + y + elif operator == "-": + operation = lambda x, y : x - y + elif operator == "*": + operation = lambda x, y : x * y + elif operator == "/": + operation = lambda x, y : x / y + elif operator == "**": + operation = lambda x, y : x ** y + else: + raise SPYValueError("supported arithmetic operator", actual=operator) + + # If ACME is available, try to attach (already running) parallel computing client + parallel = False + if __acme__: + try: + dd.get_client() + parallel = True + except ValueError: + parallel = False + + # Perform actual computation: instantiate `ComputationalRoutine` w/extracted info + opMethod = SpyArithmetic(operand_dat, operand_idxs, operation=operation, + opres_type=opres_type) + opMethod.initialize(baseObj, + out._stackingDim, + chan_per_worker=None, + keeptrials=True) + + # In case of parallel execution, be careful: use a distributed lock to prevent + # ACME from performing chained operations (`x + y + 3``) simultaneously (thereby + # wrecking the underlying HDF5 datasets). Similarly, if `operand` is a Syncopy + # object, close its corresponding dataset(s) before starting to concurrently read + # from them (triggering locking errors) + if parallel: + lock = dd.lock.Lock(name='arithmetic_ops') + lock.acquire() + if "BaseData" in str(operand.__class__.__mro__): + for dsetName in operand._hdfFileDatasetProperties: + dset = getattr(operand, dsetName) + dset.file.close() + + opMethod.compute(baseObj, out, parallel=parallel, log_dict=log_dct) + + # Re-open `operand`'s dataset(s) and release distributed lock + if parallel: + if "BaseData" in str(operand.__class__.__mro__): + for dsetName in operand._hdfFileDatasetProperties: + setattr(operand, dsetName, operand.filename) + lock.release() + + # Delete any created subset selections + if hasattr(baseObj._selection, "_cleanup"): + baseObj._selection = None + + return out + + +@unwrap_io +def arithmetic_cF(base_dat, operand_dat, operand_idx, operation=None, opres_type=None, + noCompute=False, chunkShape=None): + """ + Perform arithmetic operation + + Parameters + ---------- + base_dat : :class:`numpy.ndarray` + Trial data + operand_dat : dict or scalar or array-like + If two Syncopy objects are processed, then `operand_dat` is a dictionary + containing information about the operand's HDF5 backing device (see + :func:`_parse_input` for details). Otherwise, `operand_dat` is either a + scalar or array-like quantity. + operand_idx : tuple + If `operand` is a scalar, list or NumPy ndarray then `operand_idx` is + `None`. If `operand` is a Syncopy object, then `operand_idx` is an indexing + tuple. + operation : lambda object + A lambda expression encapsulating the requested arithmetic operation. 
+ opres_type : dtype + Numerical type of applying ``operation(base_dat, operand)`` + noCompute : bool + Preprocessing flag. If `True`, do not perform actual calculation but + instead return expected shape and :class:`numpy.dtype` of output + array. + chunkShape : None or tuple + If not `None`, represents shape of output + + Returns + ------- + res : :class:`numpy.ndarray` + Result of ``operation(base_dat, operand)`` + + Notes + ----- + This method is intended to be used as :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. + Thus, input parameters are presumed to be forwarded from a parent metafunction. + Consequently, this function does **not** perform any error checking and operates + under the assumption that all inputs have been externally validated and cross-checked. + + See also + -------- + _perform_computation : execute arithmetic operation + SpyArithmetic : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` subclass + """ + + if noCompute: + return base_dat.shape, opres_type + + if isinstance(operand_dat, dict): + with h5py.File(operand_dat["filename"], "r") as h5f: + operand = h5f[operand_dat["dsetname"]][operand_idx] + else: + operand = operand_dat + + return operation(base_dat, operand) + +class SpyArithmetic(ComputationalRoutine): + """ + Compute class for performing arithmetic operations with Syncopy objects + + Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`, + see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute + classes and metafunctions. + + See also + -------- + _perform_computation : execute arithmetic operation + """ + + computeFunction = staticmethod(arithmetic_cF) + + def process_metadata(self, baseObj, out): + + # Get/set timing-related selection modifiers + out.trialdefinition = baseObj._selection.trialdefinition + # if baseObj._selection._timeShuffle: # FIXME: should be implemented done the road + # out.time = baseObj._selection.timepoints + if baseObj._selection._samplerate: + out.samplerate = baseObj.samplerate + + # Get/set dimensional attributes changed by selection + for prop in baseObj._selection._dimProps: + selection = getattr(baseObj._selection, prop) + if selection is not None: + setattr(out, prop, getattr(baseObj, prop)[selection]) diff --git a/syncopy/datatype/methods/definetrial.py b/syncopy/datatype/methods/definetrial.py index c196d8064..6e0108263 100644 --- a/syncopy/datatype/methods/definetrial.py +++ b/syncopy/datatype/methods/definetrial.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# +# # Set/update trial settings of Syncopy data objects -# +# # Builtin/3rd party package imports import numbers @@ -18,7 +18,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, trigger=None, stop=None, clip_edges=False): """(Re-)define trials of a Syncopy data object - + Data can be structured into trials based on timestamps of a start, trigger and end events:: @@ -29,46 +29,46 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, Parameters ---------- obj : Syncopy data object (:class:`BaseData`-like) - trialdefinition : :class:`EventData` object or Mx3 array + trialdefinition : :class:`EventData` object or Mx3 array [start, stop, trigger_offset] sample indices for `M` trials pre : float offset time (s) before start event - post : float + post : float offset time (s) after end event start : int event code (id) to be used for 
start of trial stop : int event code (id) to be used for end of trial - trigger : - event code (id) to be used center (t=0) of trial + trigger : + event code (id) to be used center (t=0) of trial clip_edges : bool - trim trials to actual data-boundaries. + trim trials to actual data-boundaries. Returns ------- Syncopy data object (:class:`BaseData`-like)) - - + + Notes ----- :func:`definetrial` supports the following argument combinations: - + >>> # define M trials based on [start, end, offset] indices - >>> definetrial(obj, trialdefinition=[M x 3] array) + >>> definetrial(obj, trialdefinition=[M x 3] array) >>> # define trials based on event codes stored in <:class:`EventData` object> - >>> definetrial(obj, trialdefinition=, - pre=0, post=0, start=startCode, stop=stopCode, + >>> definetrial(obj, trialdefinition=, + pre=0, post=0, start=startCode, stop=stopCode, trigger=triggerCode) >>> # apply same trial definition as defined in <:class:`EventData` object> - >>> definetrial(, + >>> definetrial(, trialdefinition=) - >>> # define whole recording as single trial + >>> # define whole recording as single trial >>> definetrial(obj, trialdefinition=None) - + """ # Start by vetting input object @@ -95,17 +95,17 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, array_parser(trialdefinition, varname="trialdefinition", dims=2) except Exception as exc: raise exc - + if any(["ContinuousData" in str(base) for base in obj.__class__.__mro__]): scount = obj.data.shape[obj.dimord.index("time")] else: scount = np.inf try: - array_parser(trialdefinition[:, :2], varname="sampleinfo", dims=(None, 2), hasnan=False, + array_parser(trialdefinition[:, :2], varname="sampleinfo", dims=(None, 2), hasnan=False, hasinf=False, ntype="int_like", lims=[0, scount]) except Exception as exc: - raise exc - + raise exc + trl = np.array(trialdefinition, dtype="float") ref = obj tgt = obj @@ -139,7 +139,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, if any([kw is not None for kw in [pre, post, start, trigger, stop]]): # Make sure we actually have valid data objects to work with - if obj.__class__.__name__ == "EventData" and evt is False: + if obj.__class__.__name__ == "EventData" and evt is False: ref = obj tgt = obj elif obj.__class__.__name__ == "AnalogData" and evt is True: @@ -191,7 +191,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, kwrds = {} vdict = {"pre": {"var": pre, "hasnan": False, "ntype": None, "fillvalue": 0}, "post": {"var": post, "hasnan": False, "ntype": None, "fillvalue": 0}, - "start": {"var": start, "hasnan": None, "ntype": "int_like", "fillvalue": np.nan}, + "start": {"var": start, "hasnan": None, "ntype": "int_like", "fillvalue": np.nan}, "trigger": {"var": trigger, "hasnan": None, "ntype": "int_like", "fillvalue": np.nan}, "stop": {"var": stop, "hasnan": None, "ntype": "int_like", "fillvalue": np.nan}} for vname, opts in vdict.items(): @@ -244,7 +244,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, begin = evtsp[sidx]/ref.samplerate evtid[sidx] = -np.pi idxl.append(sidx) - + if not np.isnan(kwrds["trigger"][trialno]): try: idx = evtid.index(kwrds["trigger"][trialno]) @@ -285,7 +285,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, lgl = "non-overlapping trial begin-/end-samples" act = "trial-begin at {}, trial-end at {}".format(str(begin), str(end)) raise SPYValueError(legal=lgl, actual=act) - + # Finally, write line of `trl` trl.append([begin, 
end, t0]) @@ -317,7 +317,7 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, lgl = "non-overlapping trials" act = "some trials are overlapping after clipping to AnalogData object range" raise SPYValueError(legal=lgl, actual=act) - + # The triplet `sampleinfo`, `t0` and `trialinfo` works identically for # all data genres if trl.shape[1] < 3: @@ -355,5 +355,5 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, tgt.cfg = {"method" : sys._getframe().f_code.co_name, "EventData object": ref.cfg} ref.log = "updated trial-defnition of {} object".format(tgt.__class__.__name__) - + return diff --git a/syncopy/datatype/methods/padding.py b/syncopy/datatype/methods/padding.py index d8cf97480..0984a33c3 100644 --- a/syncopy/datatype/methods/padding.py +++ b/syncopy/datatype/methods/padding.py @@ -7,14 +7,21 @@ import numpy as np # Local imports +from syncopy.datatype.continuous_data import AnalogData +from syncopy.shared.computational_routine import ComputationalRoutine +from syncopy.shared.kwarg_decorators import unwrap_io from syncopy.shared.parsers import data_parser, array_parser, scalar_parser from syncopy.shared.errors import SPYTypeError, SPYValueError, SPYWarning +from syncopy.shared.kwarg_decorators import unwrap_cfg, unwrap_select, detect_parallel_client __all__ = ["padding"] +@unwrap_cfg +@unwrap_select +@detect_parallel_client def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, - postpadlength=None, unit="samples", create_new=True): + postpadlength=None, unit="samples", create_new=True, **kwargs): """ Perform data padding on Syncopy object or :class:`numpy.ndarray` @@ -285,11 +292,33 @@ def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, timeAxis = 0 spydata = False - # FIXME: Creation of new spy-object currently not supported + # If input is a syncopy object, fetch trial list and `sampleinfo` (thereby + # accounting for in-place selections); to not repeat this later, save relevant + # quantities in tmp attributes (all prefixed by `'_pad'`) + if spydata: + if data._selection is not None: + trialList = data._selection.trials + data._pad_sinfo = np.zeros((len(trialList), 2)) + data._pad_t0 = np.zeros((len(trialList),)) + for tk, trlno in enumerate(trialList): + trl = data._preview_trial(trlno) + tsel = trl.idx[timeAxis] + if isinstance(tsel, list): + lgl = "Syncopy AnalogData object with no or channe/trial selection" + raise SPYValueError(lgl, varname="data", actual=data._selection) + else: + data._pad_sinfo[tk, :] = [trl.idx[timeAxis].start, trl.idx[timeAxis].stop] + data._pad_t0[tk] = data._t0[trlno] + data._pad_channel = data.channel[data._selection.channel] + else: + trialList = list(range(len(data.trials))) + data._pad_sinfo = data.sampleinfo + data._pad_t0 = data._t0 + data._pad_channel = data.channel + + # Ensure `create_new` is not weird if not isinstance(create_new, bool): raise SPYTypeError(create_new, varname="create_new", expected="bool") - if spydata and create_new: - raise NotImplementedError("Creation of padded spy objects currently not supported. 
") # Use FT-compatible options (sans FT option 'remove') if not isinstance(padtype, str): @@ -329,7 +358,7 @@ def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, # trials, compute lower bound for padding (in samples or seconds) if pad in ["absolute", "maxlen"]: if spydata: - maxTrialLen = np.diff(data.sampleinfo).max() + maxTrialLen = np.diff(data._pad_sinfo).max() else: maxTrialLen = data.shape[timeAxis] # if `pad="absolute" and data is array else: @@ -441,14 +470,14 @@ def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, "edge": {"mode": "edge"}, "mirror": {"mode": "reflect"}} - # If in put was syncopy data object, padding is done on a per-trial basis + # If input was syncopy data object, padding is done on a per-trial basis if spydata: # A list of input keywords for ``np.pad`` is constructed, no matter if # we actually want to build a new object or not pad_opts = [] - for trl in data.trials: - nSamples = trl.shape[timeAxis] + for tk in trialList: + nSamples = data._preview_trial(tk).shape[timeAxis] if pad == "absolute": padding = (padlength - nSamples)/(prepadlength + postpadlength) elif pad == "relative": @@ -463,8 +492,25 @@ def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, if padtype == "localmean": pad_opts[-1]["stat_length"] = pw[timeAxis, :] + # If a new object is requested, use the legwork performed above to fire + # up the corresponding ComputationalRoutine if create_new: - pass + out = AnalogData(dimord=data.dimord) + log_dct = {"padtype": padtype, + "pad": pad, + "padlength": padlength, + "prepadlength": prepadlength, + "postpadlength": postpadlength, + "unit": unit} + + chanAxis = list(set([0, 1]).difference([timeAxis]))[0] + padMethod = PaddingRoutine(timeAxis, chanAxis, pad_opts) + padMethod.initialize(data, + out._stackingDim, + chan_per_worker=kwargs.get("chan_per_worker"), + keeptrials=True) + padMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dct) + return out else: return pad_opts @@ -496,7 +542,7 @@ def padding(data, padtype, pad="absolute", padlength=None, prepadlength=None, if create_new: if isinstance(data, np.ndarray): return np.pad(data, **pad_opts) - else: # FIXME: currently only supports FauxTrial + else: shp = list(data.shape) shp[timeAxis] += pw[timeAxis, :].sum() idx = list(data.idx) @@ -516,3 +562,99 @@ def _nextpow2(number): while n < number: n *= 2 return n + + +@unwrap_io +def padding_cF(trl_dat, timeAxis, chanAxis, pad_opt, noCompute=False, chunkShape=None): + """ + Perform trial data padding + + Parameters + ---------- + trl_dat : :class:`numpy.ndarray` + Trial data + timeAxis : int + Index of running time axis in `trl_dat` (0 or 1) + chanAxis : int + Index of channel axis in `trl_dat` (0 or 1) + pad_opt : dict + Dictionary of options for :func:`numpy.pad` + noCompute : bool + Preprocessing flag. If `True`, do not perform actual padding but + instead return expected shape and :class:`numpy.dtype` of output + array. + chunkShape : None or tuple + If not `None`, represents shape of output + + Returns + ------- + res : :class:`numpy.ndarray` + Padded array + + Notes + ----- + This method is intended to be used as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. + Thus, input parameters are presumed to be forwarded from a parent metafunction. 
+    Consequently, this function does **not** perform any error checking and operates
+    under the assumption that all inputs have been externally validated and cross-checked.
+
+    See also
+    --------
+    syncopy.padding : pad :class:`syncopy.AnalogData` objects
+    PaddingRoutine : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` subclass
+    """
+
+    nSamples = trl_dat.shape[timeAxis]
+    nChannels = trl_dat.shape[chanAxis]
+
+    if noCompute:
+        outShape = [None] * 2
+        outShape[timeAxis] = pad_opt['pad_width'].sum() + nSamples
+        outShape[chanAxis] = nChannels
+        return outShape, trl_dat.dtype
+
+    # Perform the actual padding (updates no. of samples)
+    return np.pad(trl_dat, **pad_opt)
+
+class PaddingRoutine(ComputationalRoutine):
+    """
+    Compute class for performing data padding on Syncopy AnalogData objects
+
+    Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`,
+    see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute
+    classes and metafunctions.
+
+    See also
+    --------
+    syncopy.padding : pad :class:`syncopy.AnalogData` objects
+    """
+
+    computeFunction = staticmethod(padding_cF)
+
+    def process_metadata(self, data, out):
+
+        # Fetch index of running time and used padding options from provided
+        # positional args and use them to compute new start/stop/trigger onset samples
+        timeAxis = self.argv[0]
+        pad_opts = self.argv[2]
+        prePadded = [pad_opt["pad_width"][timeAxis, 0] for pad_opt in pad_opts]
+        totalPadded = [pad_opt["pad_width"].sum() for pad_opt in pad_opts]
+        accumSamples = np.cumsum(np.diff(data._pad_sinfo).squeeze() + totalPadded)
+
+        # Construct trialdefinition array (columns: start/stop/t0/etc)
+        trialdefinition = np.zeros((len(totalPadded), data.trialdefinition.shape[1]))
+        trialdefinition[1:, 0] = accumSamples[:-1]
+        trialdefinition[:, 1] = accumSamples
+        trialdefinition[:, 2] = data._pad_t0 - prePadded
+
+        # Set relevant properties in output object
+        out.samplerate = data.samplerate
+        out.trialdefinition = trialdefinition
+        out.channel = data._pad_channel
+
+        # Remove impromptu attributes generated above
+        delattr(data, "_pad_sinfo")
+        delattr(data, "_pad_t0")
+        delattr(data, "_pad_channel")
diff --git a/syncopy/datatype/methods/selectdata.py b/syncopy/datatype/methods/selectdata.py
index 6e90dd6af..f81f7d1d6 100644
--- a/syncopy/datatype/methods/selectdata.py
+++ b/syncopy/datatype/methods/selectdata.py
@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-
-# 
+#
 # Syncopy data selection methods
-# 
+#
 
 # Local imports
 from syncopy.shared.parsers import data_parser
-from syncopy.shared.tools import get_defaults
+from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYInfo, SPYWarning
 from syncopy.shared.kwarg_decorators import unwrap_cfg, unwrap_io, detect_parallel_client
 from syncopy.shared.computational_routine import ComputationalRoutine
 
@@ -14,182 +14,187 @@
 
 @unwrap_cfg
 @detect_parallel_client
-def selectdata(data, trials=None, channels=None, toi=None, toilim=None, foi=None,
-               foilim=None, tapers=None, units=None, eventids=None,
-               out=None, **kwargs):
+def selectdata(data, trials=None, channels=None, channels_i=None, channels_j=None,
+               toi=None, toilim=None, foi=None, foilim=None, tapers=None, units=None,
+               eventids=None, out=None, inplace=False, clear=False, **kwargs):
     """
     Create a new Syncopy object from a selection
 
-    **Usage Notice**
-
-    Syncopy offers two modes for 
selecting data: + + * **in-place** selections mark subsets of a Syncopy data object for processing via a ``select`` dictionary *without* creating a new object - * **deep-copy** selections copy subsets of a Syncopy data object to keep and + * **deep-copy** selections copy subsets of a Syncopy data object to keep and preserve in a new object created by :func:`~syncopy.selectdata` - - All Syncopy metafunctions, such as :func:`~syncopy.freqanalysis`, support + + All Syncopy metafunctions, such as :func:`~syncopy.freqanalysis`, support **in-place** data selection via a ``select`` keyword, effectively avoiding - potentially slow copy operations and saving disk space. The keys accepted - by the `select` dictionary are identical to the keyword arguments discussed - below. In addition, ``select = "all"`` can be used to select entire object + potentially slow copy operations and saving disk space. The keys accepted + by the `select` dictionary are identical to the keyword arguments discussed + below. In addition, ``select = "all"`` can be used to select entire object contents. Examples - + >>> select = {"toilim" : [-0.25, 0]} >>> spy.freqanalysis(data, select=select) - >>> # or equivalently + >>> # or equivalently >>> cfg = spy.get_defaults(spy.freqanalysis) >>> cfg.select = select >>> spy.freqanalysis(cfg, data) - + **Usage Summary** - + List of Syncopy data objects and respective valid data selectors: - + :class:`~syncopy.AnalogData` : trials, channels, toi/toilim Examples - + >>> spy.selectdata(data, trials=[0, 3, 5], channels=["channel01", "channel02"]) - >>> cfg = spy.StructDict() + >>> cfg = spy.StructDict() >>> cfg.trials = [5, 3, 0]; cfg.toilim = [0.25, 0.5] >>> spy.selectdata(cfg, data) - + :class:`~syncopy.SpectralData` : trials, channels, toi/toilim, foi/foilim, tapers Examples - + >>> spy.selectdata(data, trials=[0, 3, 5], channels=["channel01", "channel02"]) >>> cfg = spy.StructDict() >>> cfg.foi = [30, 40, 50]; cfg.tapers = slice(2, 4) >>> spy.selectdata(cfg, data) - + :class:`~syncopy.EventData` : trials, toi/toilim, eventids Examples - + >>> spy.selectdata(data, toilim=[-1, 2.5], eventids=[0, 1]) >>> cfg = spy.StructDict() >>> cfg.trials = [0, 0, 1, 0]; cfg.eventids = slice(2, None) >>> spy.selectdata(cfg, data) - + :class:`~syncopy.SpikeData` : trials, toi/toilim, units, channels Examples - + >>> spy.selectdata(data, toilim=[-1, 2.5], units=range(0, 10)) >>> cfg = spy.StructDict() >>> cfg.toi = [1.25, 3.2]; cfg.trials = [0, 1, 2, 3] >>> spy.selectdata(cfg, data) - + **Note** Any property that is not specifically accessed via one of the provided selectors is taken as is, e.g., ``spy.selectdata(data, trials=[1, 2])`` - selects the entire contents of trials no. 2 and 3, while + selects the entire contents of trials no. 2 and 3, while ``spy.selectdata(data, channels=range(0, 50))`` selects the first 50 channels of `data` across all defined trials. Consequently, if no keywords are specified, - the entire contents of `data` is selected. - - **Full documentation below** - + the entire contents of `data` is selected. + + **Full documentation below** + Parameters ---------- data : Syncopy data object A non-empty Syncopy data object. **Note** the type of `data` determines - which keywords can be used. Some keywords are only valid for certain - types of Syncopy objects, e.g., "freqs" is not a valid selector for an - :class:`~syncopy.AnalogData` object. + which keywords can be used. 
Some keywords are only valid for certain + types of Syncopy objects, e.g., "freqs" is not a valid selector for an + :class:`~syncopy.AnalogData` object. trials : list (integers) or None or "all" - List of integers representing trial numbers to be selected; can include - repetitions and need not be sorted (e.g., ``trials = [0, 1, 0, 0, 2]`` - is valid) but must be finite and not NaN. If `trials` is `None`, or - ``trials = "all"`` all trials are selected. + List of integers representing trial numbers to be selected; can include + repetitions and need not be sorted (e.g., ``trials = [0, 1, 0, 0, 2]`` + is valid) but must be finite and not NaN. If `trials` is `None`, or + ``trials = "all"`` all trials are selected. channels : list (integers or strings), slice, range or None or "all" - Channel-selection; can be a list of channel names (``['channel3', 'channel1']``), - a list of channel indices (``[3, 5]``), a slice (``slice(3, 10)``) or - range (``range(3, 10)``). Note that following Python conventions, channels - are counted starting at zero, and range and slice selections are half-open - intervals of the form `[low, high)`, i.e., low is included , high is - excluded. Thus, ``channels = [0, 1, 2]`` or ``channels = slice(0, 3)`` - selects the first up to (and including) the third channel. Selections can - be unsorted and may include repetitions but must match exactly, be finite - and not NaN. If `channels` is `None`, or ``channels = "all"`` all channels - are selected. + Channel-selection; can be a list of channel names (``['channel3', 'channel1']``), + a list of channel indices (``[3, 5]``), a slice (``slice(3, 10)``) or + range (``range(3, 10)``). Note that following Python conventions, channels + are counted starting at zero, and range and slice selections are half-open + intervals of the form `[low, high)`, i.e., low is included , high is + excluded. Thus, ``channels = [0, 1, 2]`` or ``channels = slice(0, 3)`` + selects the first up to (and including) the third channel. Selections can + be unsorted and may include repetitions but must match exactly, be finite + and not NaN. If `channels` is `None`, or ``channels = "all"`` all channels + are selected. toi : list (floats) or None or "all" - Time-points to be selected (in seconds) in each trial. Timing is expected - to be on a by-trial basis (e.g., relative to trigger onsets). Selections - can be approximate, unsorted and may include repetitions but must be - finite and not NaN. Fuzzy matching is performed for approximate selections - (i.e., selected time-points are close but not identical to timing information - found in `data`) using a nearest-neighbor search for elements of `toi`. - If `toi` is `None` or ``toi = "all"``, the entire time-span in each trial - is selected. + Time-points to be selected (in seconds) in each trial. Timing is expected + to be on a by-trial basis (e.g., relative to trigger onsets). Selections + can be approximate, unsorted and may include repetitions but must be + finite and not NaN. Fuzzy matching is performed for approximate selections + (i.e., selected time-points are close but not identical to timing information + found in `data`) using a nearest-neighbor search for elements of `toi`. + If `toi` is `None` or ``toi = "all"``, the entire time-span in each trial + is selected. toilim : list (floats [tmin, tmax]) or None or "all" - Time-window ``[tmin, tmax]`` (in seconds) to be extracted from each trial. 
- Window specifications must be sorted (e.g., ``[2.2, 1.1]`` is invalid) - and not NaN but may be unbounded (e.g., ``[1.1, np.inf]`` is valid). Edges - `tmin` and `tmax` are included in the selection. - If `toilim` is `None` or ``toilim = "all"``, the entire time-span in each - trial is selected. + Time-window ``[tmin, tmax]`` (in seconds) to be extracted from each trial. + Window specifications must be sorted (e.g., ``[2.2, 1.1]`` is invalid) + and not NaN but may be unbounded (e.g., ``[1.1, np.inf]`` is valid). Edges + `tmin` and `tmax` are included in the selection. + If `toilim` is `None` or ``toilim = "all"``, the entire time-span in each + trial is selected. foi : list (floats) or None or "all" - Frequencies to be selected (in Hz). Selections can be approximate, unsorted - and may include repetitions but must be finite and not NaN. Fuzzy matching - is performed for approximate selections (i.e., selected frequencies are + Frequencies to be selected (in Hz). Selections can be approximate, unsorted + and may include repetitions but must be finite and not NaN. Fuzzy matching + is performed for approximate selections (i.e., selected frequencies are close but not identical to frequencies found in `data`) using a nearest- neighbor search for elements of `foi` in `data.freq`. If `foi` is `None` - or ``foi = "all"``, all frequencies are selected. + or ``foi = "all"``, all frequencies are selected. foilim : list (floats [fmin, fmax]) or None or "all" - Frequency-window ``[fmin, fmax]`` (in Hz) to be extracted. Window - specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN - but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin` - and `fmax` are included in the selection. If `foilim` is `None` or - ``foilim = "all"``, all frequencies are selected. + Frequency-window ``[fmin, fmax]`` (in Hz) to be extracted. Window + specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN + but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin` + and `fmax` are included in the selection. If `foilim` is `None` or + ``foilim = "all"``, all frequencies are selected. tapers : list (integers or strings), slice, range or None or "all" - Taper-selection; can be a list of taper names (``['dpss-win-1', 'dpss-win-3']``), - a list of taper indices (``[3, 5]``), a slice (``slice(3, 10)``) or range - (``range(3, 10)``). Note that following Python conventions, tapers are - counted starting at zero, and range and slice selections are half-open - intervals of the form `[low, high)`, i.e., low is included , high is - excluded. Thus, ``tapers = [0, 1, 2]`` or ``tapers = slice(0, 3)`` selects - the first up to (and including) the third taper. Selections can be unsorted - and may include repetitions but must match exactly, be finite and not NaN. - If `tapers` is `None` or ``tapers = "all"``, all tapers are selected. + Taper-selection; can be a list of taper names (``['dpss-win-1', 'dpss-win-3']``), + a list of taper indices (``[3, 5]``), a slice (``slice(3, 10)``) or range + (``range(3, 10)``). Note that following Python conventions, tapers are + counted starting at zero, and range and slice selections are half-open + intervals of the form `[low, high)`, i.e., low is included , high is + excluded. Thus, ``tapers = [0, 1, 2]`` or ``tapers = slice(0, 3)`` selects + the first up to (and including) the third taper. Selections can be unsorted + and may include repetitions but must match exactly, be finite and not NaN. 
+ If `tapers` is `None` or ``tapers = "all"``, all tapers are selected. units : list (integers or strings), slice, range or None or "all" - Unit-selection; can be a list of unit names (``['unit10', 'unit3']``), a - list of unit indices (``[3, 5]``), a slice (``slice(3, 10)``) or range - (``range(3, 10)``). Note that following Python conventions, units are - counted starting at zero, and range and slice selections are half-open - intervals of the form `[low, high)`, i.e., low is included , high is - excluded. Thus, ``units = [0, 1, 2]`` or ``units = slice(0, 3)`` selects - the first up to (and including) the third unit. Selections can be unsorted + Unit-selection; can be a list of unit names (``['unit10', 'unit3']``), a + list of unit indices (``[3, 5]``), a slice (``slice(3, 10)``) or range + (``range(3, 10)``). Note that following Python conventions, units are + counted starting at zero, and range and slice selections are half-open + intervals of the form `[low, high)`, i.e., low is included , high is + excluded. Thus, ``units = [0, 1, 2]`` or ``units = slice(0, 3)`` selects + the first up to (and including) the third unit. Selections can be unsorted and may include repetitions but must match exactly, be finite and not NaN. - If `units` is `None` or ``units = "all"``, all units are selected. + If `units` is `None` or ``units = "all"``, all units are selected. eventids : list (integers), slice, range or None or "all" - Event-ID-selection; can be a list of event-id codes (``[2, 0, 1]``), slice - (``slice(0, 2)``) or range (``range(0, 2)``). Note that following Python - conventions, range and slice selections are half-open intervals of the - form `[low, high)`, i.e., low is included , high is excluded. Selections + Event-ID-selection; can be a list of event-id codes (``[2, 0, 1]``), slice + (``slice(0, 2)``) or range (``range(0, 2)``). Note that following Python + conventions, range and slice selections are half-open intervals of the + form `[low, high)`, i.e., low is included , high is excluded. Selections can be unsorted and may include repetitions but must match exactly, be - finite and not NaN. If `eventids` is `None` or ``eventids = "all"``, all - events are selected. - + finite and not NaN. If `eventids` is `None` or ``eventids = "all"``, all + events are selected. + inplace : bool + If `inplace` is `True` **no** new object is created. Instead the provided + selection is stored in the input object's `_selection` attribute for later + use. By default `inplace` is `False` and all calls to `selectdata` create + a new Syncopy data object. + Returns ------- dataselection : Syncopy data object - Syncopy data object of the same type as `data` but containing only the - subset specified by provided selectors. - + Syncopy data object of the same type as `data` but containing only the + subset specified by provided selectors. + Notes ----- This routine represents a convenience function for creating new Syncopy objects - based on existing data entities. However, in many situations, the creation - of a new object (and thus the allocation of additional disk-space) might not + based on existing data entities. However, in many situations, the creation + of a new object (and thus the allocation of additional disk-space) might not be necessary: all Syncopy metafunctions, such as :func:`~syncopy.freqanalysis`, - support **in-place** data selection. 
- - Consider the following example: assume `data` is an :class:`~syncopy.AnalogData` - object representing 220 trials of LFP recordings containing baseline (between - second -0.25 and 0) and stimulus-on data (on the interval [0.25, 0.5]). + support **in-place** data selection. + + Consider the following example: assume `data` is an :class:`~syncopy.AnalogData` + object representing 220 trials of LFP recordings containing baseline (between + second -0.25 and 0) and stimulus-on data (on the interval [0.25, 0.5]). To compute the baseline spectrum, data-selection does **not** have to be performed before calling :func:`~syncopy.freqanalysis` but instead can be done in-place: - + >>> import syncopy as spy >>> cfg = spy.get_defaults(spy.freqanalysis) >>> cfg.method = 'mtmfft' @@ -205,35 +210,32 @@ def selectdata(data, trials=None, channels=None, toi=None, toilim=None, foi=None >>> # in-place selection of stimulus-on time-frame performed by `freqanalysis` >>> cfg.select = stimSelect >>> stimonSpectrum = spy.freqanalysis(cfg, data) - + Especially for large data-sets, in-place data selection performed by Syncopy's - metafunctions does not only save disk-space but can significantly increase - performance. - + metafunctions does not only save disk-space but can significantly increase + performance. + Examples -------- - Use :func:`~syncopy.tests.misc.generate_artificial_data` to create a synthetic - :class:`syncopy.AnalogData` object. - + Use :func:`~syncopy.tests.misc.generate_artificial_data` to create a synthetic + :class:`syncopy.AnalogData` object. + >>> from syncopy.tests.misc import generate_artificial_data - >>> adata = generate_artificial_data(nTrials=10, nChannels=32) - + >>> adata = generate_artificial_data(nTrials=10, nChannels=32) + Assume a hypothetical trial onset at second 2.0 with the first second of each trial representing baseline recordings. 
To extract only the stimulus-on period from `adata`, one could use
-    
+
     >>> stimon = spy.selectdata(adata, toilim=[2.0, np.inf])
-    
+
     Note that this is equivalent to
-    
+
     >>> stimon = adata.selectdata(toilim=[2.0, np.inf])
-    
+
     See also
     --------
-    :meth:`syncopy.AnalogData.selectdata` : corresponding class method
-    :meth:`syncopy.SpectralData.selectdata` : corresponding class method
-    :meth:`syncopy.EventData.selectdata` : corresponding class method
-    :meth:`syncopy.SpikeData.selectdata` : corresponding class method
+    :func:`syncopy.show` : Show (subsets) of Syncopy objects
     """
 
     # Ensure our one mandatory input is usable
@@ -242,44 +244,79 @@ def selectdata(data, trials=None, channels=None, toi=None, toilim=None, foi=None
     except Exception as exc:
         raise exc
 
+    # Vet the only inputs not checked by `Selector`
+    if not isinstance(inplace, bool):
+        raise SPYTypeError(inplace, varname="inplace", expected="Boolean")
+    if not isinstance(clear, bool):
+        raise SPYTypeError(clear, varname="clear", expected="Boolean")
+
     # If provided, make sure output object is appropriate
-    if out is not None:
-        try:
-            data_parser(out, varname="out", writable=True, empty=True,
-                        dataclass=data.__class__.__name__,
-                        dimord=data.dimord)
-        except Exception as exc:
-            raise exc
-        new_out = False
+    if not inplace:
+        if out is not None:
+            try:
+                data_parser(out, varname="out", writable=True, empty=True,
+                            dataclass=data.__class__.__name__,
+                            dimord=data.dimord)
+            except Exception as exc:
+                raise exc
+            new_out = False
+        else:
+            out = data.__class__(dimord=data.dimord)
+            new_out = True
     else:
-        out = data.__class__(dimord=data.dimord)
-        new_out = True
+        if out is not None:
+            lgl = "no output object for in-place selection"
+            raise SPYValueError(lgl, varname="out", actual=out.__class__.__name__)
+
+    # FIXME: remove once tests are in place (cf #165)
+    if channels_i is not None or channels_j is not None:
+        SPYWarning("CrossSpectralData channel selection currently untested and experimental!")
+
+    # Collect provided keywords in dict
+    selectDict = {"trials": trials,
+                  "channels": channels,
+                  "channels_i": channels_i,
+                  "channels_j": channels_j,
+                  "toi": toi,
+                  "toilim": toilim,
+                  "foi": foi,
+                  "foilim": foilim,
+                  "tapers": tapers,
+                  "units": units,
+                  "eventids": eventids}
+
+    # Simplest case first: determine whether we just need to clear an existing selection
+    if clear:
+        if any(value is not None for value in selectDict.values()):
+            lgl = "no data selectors if `clear = True`"
+            raise SPYValueError(lgl, varname="select", actual=selectDict)
+        if data._selection is None:
+            SPYInfo("No in-place selection found. 
") + else: + data._selection = None + SPYInfo("In-place selection cleared") + return # Pass provided selections on to `Selector` class which performs error checking - data._selection = {"trials": trials, - "channels": channels, - "toi": toi, - "toilim": toilim, - "foi": foi, - "foilim": foilim, - "tapers": tapers, - "units": units, - "eventids": eventids} - - # Create inventory of all available selectors and actually provided values + data._selection = selectDict + + # If an in-place selection was requested we're done + if inplace: + SPYInfo("In-place selection attached to data object: {}".format(data._selection)) + return + + # Create inventory of all available selectors and actually provided values # to create a bookkeeping dict for logging - provided = locals() - available = get_defaults(data.selectdata) - actualSelection = {} - for key in available: - actualSelection[key] = provided[key] - + log_dct = {"inplace": inplace, "clear": clear} + log_dct.update(selectDict) + log_dct.update(**kwargs) + # Fire up `ComputationalRoutine`-subclass to do the actual selecting/copying selectMethod = DataSelection() - selectMethod.initialize(data, chan_per_worker=kwargs.get("chan_per_worker")) - selectMethod.compute(data, out, parallel=kwargs.get("parallel"), - log_dict=actualSelection) - + selectMethod.initialize(data, out._stackingDim, chan_per_worker=kwargs.get("chan_per_worker")) + selectMethod.compute(data, out, parallel=kwargs.get("parallel"), + log_dict=log_dct) + # Wipe data-selection slot to not alter input object data._selection = None @@ -299,14 +336,14 @@ class DataSelection(ComputationalRoutine): computeFunction = staticmethod(_selectdata) def process_metadata(self, data, out): - + # Get/set timing-related selection modifiers out.trialdefinition = data._selection.trialdefinition # if data._selection._timeShuffle: # FIXME: should be implemented done the road - # out.time = data._selection.timepoints + # out.time = data._selection.timepoints if data._selection._samplerate: out.samplerate = data.samplerate - + # Get/set dimensional attributes changed by selection for prop in data._selection._dimProps: selection = getattr(data._selection, prop) diff --git a/syncopy/datatype/methods/show.py b/syncopy/datatype/methods/show.py new file mode 100644 index 000000000..a5ac125d7 --- /dev/null +++ b/syncopy/datatype/methods/show.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# +# Syncopy data slicing methods +# + +# Builtin/3rd party package imports +import numpy as np + +# Local imports +from syncopy.shared.errors import SPYInfo +from syncopy.shared.kwarg_decorators import unwrap_cfg + +__all__ = ["show"] + + +@unwrap_cfg +def show(data, **kwargs): + """ + Show (partial) contents of Syncopy object + + **Usage Notice** + + Syncopy uses HDF5 files as on-disk backing device for data storage. This + allows working with larger-than-memory data-sets by streaming only relevant + subsets of data from disk on demand without excessive RAM use. However, using + :func:`~syncopy.show` this mechanism is bypassed and the requested data subset + is loaded into memory at once. Thus, inadvertent usage of :func:`~syncopy.show` + on a large data object can lead to memory overflow or even out-of-memory errors. + + **Usage Summary** + + Data selectors for showing subsets of Syncopy data objects follow the syntax + of :func:`~syncopy.selectdata`. Please refer to :func:`~syncopy.selectdata` + for a list of valid data selectors for respective Syncopy data objects. 
+ + Parameters + ---------- + data : Syncopy data object + As for subset-selection via :func:`~syncopy.selectdata`, the type of `data` + determines which keywords can be used. Some keywords are only valid for + certain types of Syncopy objects, e.g., "freqs" is not a valid selector + for an :class:`~syncopy.AnalogData` object. + **kwargs : keywords + Valid data selectors (e.g., `trials`, `channels`, `toi` etc.). Please + refer to :func:`~syncopy.selectdata` for a full list of available data + selectors. + + Returns + ------- + arr : NumPy nd-array + A (selection) of data retrieved from the `data` input object. + + Notes + ----- + This routine represents a convenience function for quickly inspecting the + contents of Syncopy objects. It is always possible to manually access an object's + numerical data by indexing the underlying HDF5 dataset: `data.data[idx]`. + The dimension labels of the dataset are encoded in `data.dimord`, e.g., if + `data` is a :class:`~syncopy.AnalogData` with `data.dimord` being `['time', 'channel']` + and `data.data.shape` is `(15000, 16)`, then `data.data[:, 3]` returns the + contents of the fourth channel across all time points. + + Examples + -------- + Use :func:`~syncopy.tests.misc.generate_artificial_data` to create a synthetic + :class:`syncopy.AnalogData` object. + + >>> from syncopy.tests.misc import generate_artificial_data + >>> adata = generate_artificial_data(nTrials=10, nChannels=32) + + Show the contents of `'channel02'` across all trials: + + >>> spy.show(adata, channels=['channel02']) + Syncopy INFO: In-place selection attached to data object: Syncopy AnalogData selector with 1 channels, all times, 10 trials + Syncopy INFO: Showing 1 channels, all times, 10 trials + Out[11]: + array([[1.627 ], + [1.7906], + [1.1757], + ..., + [1.1498], + [0.7753], + [1.0457]], dtype=float32) + + Note that this is equivalent to + + >>> adata.show(channels=['channel02']) + + See also + -------- + :func:`syncopy.selectdata` : Create a new Syncopy object from a selection + """ + + # Account for pathological cases + if data.data is None: + SPYInfo("Empty object, nothing to show") + return + + # Leverage `selectdata` to sanitize input and perform subset picking + data.selectdata(inplace=True, **kwargs) + + # Use an object's `_preview_trial` method fetch required indexing tuples + SPYInfo("Showing{}".format(data._selection.__str__().partition("with")[-1])) + idxList = [] + for trlno in data._selection.trials: + idxList.append(data._preview_trial(trlno).idx) + + # Perform some slicing/list-selection gymnastics: ensure that selections + # that result in contiguous slices are actually returned as such (e.g., + # `idxList = [(slice(1,2), [2]), (slice(2,3), [2])` -> `returnIdx = [slice(1,3), [2]]`) + singleIdx = [False] * len(idxList[0]) + returnIdx = list(idxList[0]) + for sk, selectors in enumerate(zip(*idxList)): + if np.unique(selectors).size == 1: + singleIdx[sk] = True + else: + if all(isinstance(sel, slice) for sel in selectors): + gaps = [selectors[k + 1].start - selectors[k].stop for k in range(len(selectors) - 1)] + if all(gap == 0 for gap in gaps): + singleIdx[sk] = True + returnIdx[sk] = slice(selectors[0].start, selectors[-1].stop) + + # Reset in-place subset selection + data._selection = None + + # If possible slice underlying dataset only once, otherwise return a list + # of arrays corresponding to selected trials + if all(si == True for si in singleIdx): + return data.data[tuple(returnIdx)] + else: + return [data.data[idx] for idx in idxList] diff --git 
a/syncopy/shared/computational_routine.py b/syncopy/shared/computational_routine.py
index 63f2e2b52..6d66f1854 100644
--- a/syncopy/shared/computational_routine.py
+++ b/syncopy/shared/computational_routine.py
@@ -8,13 +8,10 @@
 import sys
 import psutil
 import h5py
-import time
 import numpy as np
 from itertools import chain
 from abc import ABC, abstractmethod
-from collections.abc import Sized
 from copy import copy
-from glob import glob
 from numpy.lib.format import open_memmap
 from tqdm.auto import tqdm
 if sys.platform == "win32":
@@ -26,7 +23,7 @@
 # Local imports
 from .tools import get_defaults
 from syncopy import __storage__, __acme__, __path__
-from syncopy.shared.errors import SPYValueError, SPYWarning, SPYParallelError
+from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYParallelError, SPYWarning
 if __acme__:
     from acme import ParallelMap
     import dask.distributed as dd
@@ -226,7 +223,7 @@ def __init__(self, *argv, **kwargs):
         self._callMax = 10000
         self._callCount = 0
 
-    def initialize(self, data, chan_per_worker=None, keeptrials=True):
+    def initialize(self, data, out_stackingdim, chan_per_worker=None, keeptrials=True):
         """
         Perform dry-run of calculation to determine output shape
 
@@ -235,6 +232,8 @@ def initialize(self, data, chan_per_worker=None, keeptrials=True):
         data : syncopy data object
             Syncopy data object to be processed (has to be the same object
             that is passed to :meth:`compute` for the actual calculation).
+        out_stackingdim : int
+            Index of data dimension for stacking trials in output object
         chan_per_worker : None or int
             Number of channels to be processed by each worker (only relevant in
             case of concurrent processing). If `chan_per_worker` is `None` (default)
@@ -282,7 +281,7 @@
         trials = []
         for tk, trialno in enumerate(self.trialList):
             trial = data._preview_trial(trialno)
-            trlArg = tuple(arg[tk] if isinstance(arg, Sized) and len(arg) == self.numTrials \
+            trlArg = tuple(arg[tk] if isinstance(arg, (list, tuple, np.ndarray)) and len(arg) == self.numTrials \
                            else arg for arg in self.argv)
             chunkShape, dtype = self.computeFunction(trial,
                                                      *trlArg,
@@ -291,17 +290,28 @@
             dtp_list.append(dtype)
             trials.append(trial)
 
+        # Determine trial stacking dimension (validate it before using it for
+        # indexing) and compute aggregate shape of output
+        stackingDim = out_stackingdim
+        if stackingDim < 0 or stackingDim >= len(chunkShape):
+            msg = "valid trial stacking dimension"
+            raise SPYTypeError(out_stackingdim, varname="out_stackingdim", expected=msg)
+        totalSize = sum(cShape[stackingDim] for cShape in chk_list)
+        outputShape = list(chunkShape)
+        outputShape[stackingDim] = totalSize
+
+        # The aggregate shape is computed as max across all chunks
         chk_arr = np.array(chk_list)
-        if np.unique(chk_arr[:, 0]).size > 1 and not self.keeptrials:
+        chunkShape = tuple(chk_arr.max(axis=0))
+        if np.unique(chk_arr[:, stackingDim]).size > 1 and not self.keeptrials:
             err = "Averaging trials of unequal lengths in output currently not supported!"
raise NotImplementedError(err) if np.any([dtp_list[0] != dtp for dtp in dtp_list]): lgl = "unique output dtype" act = "{} different output dtypes".format(np.unique(dtp_list).size) raise SPYValueError(legal=lgl, varname="dtype", actual=act) - chunkShape = tuple(chk_arr.max(axis=0)) - self.outputShape = (chk_arr[:, 0].sum(),) + chunkShape[1:] + + # Save determined shapes and data type + self.outputShape = tuple(outputShape) self.cfg["chunkShape"] = chunkShape self.dtype = np.dtype(dtp_list[0]) @@ -323,7 +333,7 @@ def initialize(self, data, chan_per_worker=None, keeptrials=True): # Allocate control variables trial = trials[0] - trlArg0 = tuple(arg[0] if isinstance(arg, Sized) and len(arg) == self.numTrials \ + trlArg0 = tuple(arg[0] if isinstance(arg, (list, tuple, np.ndarray)) and len(arg) == self.numTrials \ else arg for arg in self.argv) chunkShape0 = chk_arr[0, :] lyt = [slice(0, stop) for stop in chunkShape0] @@ -383,18 +393,15 @@ def initialize(self, data, chan_per_worker=None, keeptrials=True): sourceLayout.append(trial.idx) # Construct dimensional layout of output - # FIXME: should be targetLayout[0][stackingDim].stop - # FIXME: should be lyt[stackingDim] = slice(stacking, stacking + chkshp[stackingDim]) - # FIXME: should be stacking += chkshp[stackingDim] - stacking = targetLayout[0][0].stop + stacking = targetLayout[0][stackingDim].stop for tk in range(1, self.numTrials): trial = trials[tk] - trlArg = tuple(arg[tk] if isinstance(arg, Sized) and len(arg) == self.numTrials \ + trlArg = tuple(arg[tk] if isinstance(arg, (list, tuple, np.ndarray)) and len(arg) == self.numTrials \ else arg for arg in self.argv) chkshp = chk_list[tk] lyt = [slice(0, stop) for stop in chkshp] - lyt[0] = slice(stacking, stacking + chkshp[0]) - stacking += chkshp[0] + lyt[stackingDim] = slice(stacking, stacking + chkshp[stackingDim]) + stacking += chkshp[stackingDim] if chan_per_worker is None: targetLayout.append(tuple(lyt)) targetShapes.append(tuple([slc.stop - slc.start for slc in lyt])) @@ -751,7 +758,9 @@ def preallocate_output(self, out, parallel_store=False): layout = h5py.VirtualLayout(shape=self.outputShape, dtype=self.dtype) for k, idx in enumerate(self.targetLayout): fname = os.path.join(self.virtualDatasetDir, "{0:d}.h5".format(k)) - layout[idx] = h5py.VirtualSource(fname, self.virtualDatasetNames, shape=self.targetShapes[k]) + # Catch empty selections: don't map empty sources into the layout of the VDS + if all([sel for sel in self.sourceLayout[k]]): + layout[idx] = h5py.VirtualSource(fname, self.virtualDatasetNames, shape=self.targetShapes[k]) self.VirtualDatasetLayout = layout self.outFileName = os.path.join(self.virtualDatasetDir, "{0:d}.h5") self.tmpDsetName = self.virtualDatasetNames @@ -865,7 +874,7 @@ def compute_sequential(self, data, out): sigrid = self.sourceSelectors[nblock] outgrid = self.targetLayout[nblock] argv = tuple(arg[nblock] \ - if isinstance(arg, Sized) and len(arg) == self.numTrials \ + if isinstance(arg, (list, tuple, np.ndarray)) and len(arg) == self.numTrials \ else arg for arg in self.argv) # Catch empty source-array selections; this workaround is not diff --git a/syncopy/shared/const_def.py b/syncopy/shared/const_def.py new file mode 100644 index 000000000..468d45e79 --- /dev/null +++ b/syncopy/shared/const_def.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# +# Constant definitions used throughout SyNCoPy +# + +# Builtin/3rd party package imports +import numpy as np +from scipy.signal import windows + +# Module-wide output specs +spectralDTypes = {"pow": np.float32, + 
"fourier": np.complex64, + "abs": np.float32} + +#: output conversion of complex fourier coefficients +spectralConversions = {"pow": lambda x: (x * np.conj(x)).real.astype(np.float32), + "fourier": lambda x: x.astype(np.complex64), + "abs": lambda x: (np.absolute(x)).real.astype(np.float32)} + + +#: available tapers of :func:`~syncopy.freqanalysis` and :func:`~syncopy.connectivity` +all_windows = windows.__all__ +all_windows.remove("exponential") # not symmetric +all_windows.remove("hanning") # deprecated +availableTapers = all_windows + +#: general, method agnostic, parameters for our CRs +generalParameters = ("method", "output", "keeptrials","samplerate", + "foi", "foilim", "polyremoval", "out") diff --git a/syncopy/shared/filetypes.py b/syncopy/shared/filetypes.py index f3717ff76..874a2d8f4 100644 --- a/syncopy/shared/filetypes.py +++ b/syncopy/shared/filetypes.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- -# +# # Supported Syncopy classes and file extensions -# +# def _data_classname_to_extension(name): return "." + name.split('Data')[0].lower() # data file extensions are first word of data class name in lower-case -supportedClasses = ('AnalogData', 'SpectralData', # ContinousData +supportedClasses = ('AnalogData', 'SpectralData', 'CrossSpectralData', # ContinousData 'SpikeData', 'EventData', # DiscreteData 'TimelockData', ) # StatisticalData diff --git a/syncopy/shared/input_validators.py b/syncopy/shared/input_validators.py new file mode 100644 index 000000000..be40ce21e --- /dev/null +++ b/syncopy/shared/input_validators.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- +# +# Validators for user submitted frontend arguments like foi, taper, etc. +# Input args are the parameters to check for validity + auxiliary parameters +# needed for the checks. +# + +# Builtin/3rd party package imports +import numpy as np + +from syncopy.shared.errors import SPYValueError, SPYWarning, SPYInfo +from syncopy.shared.parsers import scalar_parser, array_parser +from syncopy.shared.const_def import availableTapers, generalParameters + + +def validate_foi(foi, foilim, samplerate): + + """ + Parameters + ---------- + foi : 'all' or array like or None + frequencies of interest + foilim : 2-element sequence or None + foi limits + + Other Parameters + ---------------- + samplerate : float + the samplerate in Hz + + Returns + ------- + foi, foilim : tuple + Either both are `None` or the + user submitted one is parsed and returned + + Notes + ----- + Setting both `foi` and `foilim` to `None` is valid, the + subsequent analysis methods should all have a default way to + select a standard set of frequencies (e.g. np.fft.fftfreq). 
+ """ + + if foi is not None and foilim is not None: + lgl = "either `foi` or `foilim` specification" + act = "both" + raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act) + + if foi is not None: + if isinstance(foi, str): + if foi == "all": + foi = None + else: + raise SPYValueError(legal="'all' or `None` or list/array", + varname="foi", actual=foi) + else: + try: + array_parser(foi, varname="foi", hasinf=False, hasnan=False, + lims=[0, samplerate/2], dims=(None,)) + except Exception as exc: + raise exc + foi = np.array(foi, dtype="float") + + if foilim is not None: + if isinstance(foilim, str): + if foilim == "all": + foilim = None + else: + raise SPYValueError(legal="'all' or `None` or `[fmin, fmax]`", + varname="foilim", actual=foilim) + else: + try: + array_parser(foilim, varname="foilim", hasinf=False, hasnan=False, + lims=[0, samplerate/2], dims=(2,)) + except Exception as exc: + raise exc + # foilim is of shape (2,) + if foilim[0] > foilim[1]: + msg = "Sorting foilim low to high.." + SPYInfo(msg) + foilim = np.sort(foilim) + + return foi, foilim + + +def validate_taper(taper, + tapsmofrq, + nTaper, + keeptapers, + foimax, + samplerate, + nSamples, + output): + + """ + General taper validation and Slepian/dpss input sanitization. + The default is to max out `nTaper` to achieve the desired frequency + smoothing bandwidth. For details about the Slepion settings see + + "The Effective Bandwidth of a Multitaper Spectral Estimator, + A. T. Walden, E. J. McCoy and D. B. Percival" + + Parameters + ---------- + taper : str + Windowing function, one of :data:`~syncopy.shared.const_def.availableTapers` + tapsmofrq : float or None + Taper smoothing bandwidth for `taper='dpss'` + nTaper : int_like or None + Number of tapers to use for multi-tapering (not recommended) + + Other Parameters + ---------------- + keeptapers : bool + foimax : float + Maximum frequency for the analysis + samplerate : float + the samplerate in Hz + nSamples : int + Number of samples + output : str, one of {'abs', 'pow', 'fourier'} + Fourier transformation output type + + Returns + ------- + dpss_opt : dict + For multi-tapering (`taper='dpss'`) contains the + parameters `NW` and `Kmax` for `scipy.signal.windows.dpss`. + For all other tapers this is an empty dictionary. + """ + + # See if taper choice is supported + if taper not in availableTapers: + lgl = "'" + "or '".join(opt + "' " for opt in availableTapers) + raise SPYValueError(legal=lgl, varname="taper", actual=taper) + + # Warn user about DPSS only settings + if taper != "dpss": + if tapsmofrq is not None: + msg = "`tapsmofrq` is only used if `taper` is `dpss`!" + SPYWarning(msg) + if nTaper is not None: + msg = "`nTaper` is only used if `taper` is `dpss`!" + SPYWarning(msg) + if keeptapers: + msg = "`keeptapers` is only used if `taper` is `dpss`!" 
+            SPYWarning(msg)
+
+        # empty dpss_opt, only Slepians have options
+        return {}
+
+    # direct mtm estimate (averaging) only valid for spectral power
+    if taper == "dpss" and not keeptapers and output != "pow":
+        lgl = "'pow', the only valid option for taper averaging"
+        raise SPYValueError(legal=lgl, varname="output", actual=output)
+
+    # Set/get `tapsmofrq` if we're working w/Slepian tapers
+    elif taper == "dpss":
+
+        # --- minimal smoothing bandwidth,          ---
+        # --- such that Kmax/nTaper is at least 1   ---
+        minBw = 2 * samplerate / nSamples
+        # ---------------------------------------------
+
+        # user set tapsmofrq directly
+        if tapsmofrq is not None:
+            try:
+                scalar_parser(tapsmofrq, varname="tapsmofrq", lims=[0, np.inf])
+            except Exception as exc:
+                raise exc
+
+            if tapsmofrq < minBw:
+                msg = f'Setting tapsmofrq to the minimal attainable bandwidth of {minBw:.2f}Hz'
+                SPYInfo(msg)
+                tapsmofrq = minBw
+
+        # `tapsmofrq` is mandatory: we now enforce a user-submitted smoothing bandwidth
+        else:
+            lgl = "smoothing bandwidth in Hz, typical values are in the range 1-10Hz"
+            raise SPYValueError(legal=lgl, varname="tapsmofrq", actual=tapsmofrq)
+
+        # Try to derive "sane" settings by using 3/4 octave
+        # smoothing of highest `foi`
+        # following Hill et al. "Oscillatory Synchronization in Large-Scale
+        # Cortical Networks Predicts Perception", Neuron, 2011
+        # FIXME: This "sane setting" seems quite excessive (huge bandwidths)
+
+        # tapsmofrq = (foimax * 2**(3 / 4 / 2) - foimax * 2**(-3 / 4 / 2)) / 2
+        # msg = f'Automatic setting of `tapsmofrq` to {tapsmofrq:.2f}'
+        # SPYInfo(msg)
+
+        # --------------------------------------------
+        # set parameters for scipy.signal.windows.dpss
+        NW = tapsmofrq * nSamples / (2 * samplerate)
+        # given the minBw setting, NW is always at least 1
+        Kmax = int(2 * NW - 1)   # optimal number of tapers
+        # --------------------------------------------
+
+        # the recommended way:
+        # set nTaper automatically to achieve exact effective smoothing bandwidth
+        if nTaper is None:
+            msg = f'Using {Kmax} taper(s) for multi-tapering'
+            SPYInfo(msg)
+            dpss_opt = {'NW' : NW, 'Kmax' : Kmax}
+            return dpss_opt
+
+        elif nTaper is not None:
+            try:
+                scalar_parser(nTaper,
+                              varname="nTaper",
+                              ntype="int_like", lims=[1, np.inf])
+            except Exception as exc:
+                raise exc
+
+            if nTaper != Kmax:
+                msg = f'''
+                Manually setting the number of tapers is not recommended
+                and may (strongly) distort the effective smoothing bandwidth!\n
+                The optimal number of tapers is {Kmax}, you have chosen to use {nTaper}.
+                '''
+                SPYWarning(msg)
+
+            dpss_opt = {'NW' : NW, 'Kmax' : nTaper}
+            return dpss_opt
+
+
+def check_effective_parameters(CR, defaults, lcls):
+
+    """
+    For a given ComputationalRoutine, compare set parameters
+    (*lcls*) with the accepted parameters and the frontend
+    meta function *defaults* to warn if any ineffective parameters are set.
+
+    Parameters
+    ----------
+    CR : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
+        Needs to have a `valid_kws` attribute
+    defaults : dict
+        Result of :func:`~syncopy.shared.tools.get_defaults`, the frontend
+        parameter names mapped to their default values
+    lcls : dict
+        Result of `locals()`, all names and values of the local (frontend) name space
+    """
+    # list of possible parameter names of the CR
+    expected = CR.valid_kws + ["parallel", "select"]
+    relevant = [name for name in defaults if name not in generalParameters]
+    for name in relevant:
+        if name not in expected and (lcls[name] != defaults[name]):
+            msg = f"option `{name}` has no effect in method `{CR.__name__}`!"
+            SPYWarning(msg, caller=__name__.split('.')[-1])
+
+
+def check_passed_kwargs(lcls, defaults, frontend_name):
+
+    """
+    Catch additional kwargs passed to the frontends
+    which have no effect
+    """
+
+    # unpack **kwargs of frontend call which
+    # might contain arbitrary kws passed from the user
+    kw_dict = lcls.get("kwargs")
+
+    # nothing to do..
+    if not kw_dict:
+        return
+
+    relevant = list(kw_dict.keys())
+    expected = [name for name in defaults]
+
+    for name in relevant:
+        if name not in expected:
+            msg = f"option `{name}` has no effect in `{frontend_name}`!"
+            SPYWarning(msg, caller=__name__.split('.')[-1])
+
diff --git a/syncopy/shared/kwarg_decorators.py b/syncopy/shared/kwarg_decorators.py
index 41235c25a..a4f588399 100644
--- a/syncopy/shared/kwarg_decorators.py
+++ b/syncopy/shared/kwarg_decorators.py
@@ -117,12 +117,16 @@ def unwrap_cfg(func):
     """
 
     # Perform a little introspection gymnastics to get the name of the first
-    # positional and keyword argument of `func`
+    # positional and keyword argument of `func` (if we only find anonymous `**kwargs`,
+    # come up with an exemplary keyword - `kwarg0` is only used in the generated docstring)
    funcParams = inspect.signature(func).parameters
    paramList = list(funcParams)
    kwargList = [pName for pName, pVal in funcParams.items() if pVal.default != pVal.empty]
    arg0 = paramList[0]
-    kwarg0 = kwargList[0]
+    if len(kwargList) > 0:
+        kwarg0 = kwargList[0]
+    else:
+        kwarg0 = "some_parameter"
 
     @functools.wraps(func)
     def wrapper_cfg(*args, **kwargs):
@@ -237,12 +241,6 @@ def wrapper_cfg(*args, **kwargs):
         else:
             posargs.append(arg)
 
-        # At this point, `data` is a list: if it's empty, not a single Syncopy data object
-        # was provided (neither via `cfg`, `kwargs`, or `args`) and the call is invalid
-        if len(data) == 0:
-            err = "{0} missing mandatory argument: `{1}`"
-            raise SPYError(err.format(func.__name__, arg0))
-
         # Call function with unfolded `data` + modified positional/keyword args
         return func(*data, *posargs, **cfg)
 
@@ -477,7 +475,7 @@ def parallel_client_detector(*args, **kwargs):
         kwargs["parallel"] = parallel
 
         # Process provided object(s)
-        if nObs == 1:
+        if nObs <= 1:
             results = func(*args, **kwargs)
         else:
             results = []
diff --git a/syncopy/specest/compRoutines.py b/syncopy/specest/compRoutines.py
index 0e5ec9ba6..f4a366258 100644
--- a/syncopy/specest/compRoutines.py
+++ b/syncopy/specest/compRoutines.py
@@ -15,7 +15,7 @@
 # method_keys : list of names of the backend method parameters
 # cF_keys : list of names of the parameters of the middleware computeFunctions
 #
-# the backend method name als gets explictly attached as a class constant:
+# the backend method name also gets explicitly attached as a class constant:
 # method: backend method name
 
 # Builtin/3rd party package imports
@@ -36,7 +36,7 @@
 from syncopy.shared.tools import best_match
 from syncopy.shared.computational_routine import ComputationalRoutine
 from syncopy.shared.kwarg_decorators import unwrap_io
-from syncopy.specest.const_def import (
+from syncopy.shared.const_def import (
     spectralConversions,
     spectralDTypes,
 )
@@ -47,12 +47,10 @@
 # -----------------------
 
 @unwrap_io
-def mtmfft_cF(trl_dat, foi=None, timeAxis=0,
-              keeptapers=True, nTaper=None, tapsmofrq=None,
+def mtmfft_cF(trl_dat, foi=None, timeAxis=0, keeptapers=True,
               pad="nextpow2", padtype="zero",
               padlength=None, polyremoval=None, output_fmt="pow",
-              noCompute=False, chunkShape=None,
-              method_kwargs=None):
+              noCompute=False, chunkShape=None, method_kwargs=None):
     """
     Compute (multi-)tapered Fourier transform of multi-channel
time series data @@ -67,19 +65,10 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, data length and padding) are used. timeAxis : int Index of running time axis in `trl_dat` (0 or 1) - tapsmofrq : float - The amount of spectral smoothing through multi-tapering (Hz) for Slepian - tapers (`taper`="dpss"). keeptapers : bool If `True`, return spectral estimates for each taper. Otherwise power spectrum is averaged across tapers, only valid spectral estimate if `output_fmt` is `pow`. - nTaper : int - Only effective if ``taper='dpss'``. Number of orthogonal tapers to use. - tapsmofrq : float - Only effective if ``taper='dpss'``. The amount of spectral smoothing through - multi-tapering (Hz). Note that smoothing frequency specifications are one-sided, - i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., a 8 Hz smoothing box. pad : str Padding mode; one of `'absolute'`, `'relative'`, `'maxlen'`, or `'nextpow2'`. See :func:`syncopy.padding` for more information. @@ -91,13 +80,11 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, Number of samples to pad to data (if `pad` is 'absolute' or 'relative'). See :func:`syncopy.padding` for more information. polyremoval : int or None - **FIXME: Not implemented yet** Order of polynomial used for de-trending data in the time domain prior to spectral analysis. A value of 0 corresponds to subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the - least squares fit of a linear polynomial), ``polyremoval = N`` for `N > 1` - subtracts a polynomial of order `N` (``N = 2`` quadratic, ``N = 3`` cubic - etc.). If `polyremoval` is `None`, no de-trending is performed. + least squares fit of a linear polynomial). + If `polyremoval` is `None`, no de-trending is performed. output_fmt : str Output of spectral estimation; one of :data:`~syncopy.specest.const_def.availableOutputs` noCompute : bool @@ -111,7 +98,6 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, Keyword arguments passed to :func:`~syncopy.specest.mtmfft.mtmfft` controlling the spectral estimation method - Returns ------- spec : :class:`numpy.ndarray` @@ -119,7 +105,8 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, Notes ----- - This method is intended to be used as :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + This method is intended to be used as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`. Thus, input parameters are presumed to be forwarded from a parent metafunction. 
Consequently, this function does **not** perform any error checking and operates @@ -136,13 +123,6 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, numpy.fft.rfft : NumPy's FFT implementation """ - # Slepian window parameters - if method_kwargs['taper'] == "dpss": - taperopt = {"Kmax" : nTaper, "NW" : tapsmofrq} - else: - taperopt = {} - - method_kwargs['taperopt'] = taperopt # Re-arrange array if necessary and get dimensional information if timeAxis != 0: @@ -161,6 +141,7 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, freqs = np.fft.rfftfreq(nSamples, 1 / method_kwargs["samplerate"]) _, freq_idx = best_match(freqs, foi, squash_duplicates=True) nFreq = freq_idx.size + nTaper = method_kwargs["taper_opt"].get('Kmax', 1) outShape = (1, max(1, nTaper * keeptapers), nFreq, nChannels) # For initialization of computational routine, @@ -168,6 +149,12 @@ def mtmfft_cF(trl_dat, foi=None, timeAxis=0, if noCompute: return outShape, spectralDTypes[output_fmt] + # detrend, does not work with 'FauxTrial' data.. + if polyremoval == 0: + dat = signal.detrend(dat, type='constant', axis=0, overwrite_data=True) + elif polyremoval == 1: + dat = signal.detrend(dat, type='linear', axis=0, overwrite_data=True) + # call actual specest method res, _ = mtmfft(dat, **method_kwargs) @@ -197,11 +184,11 @@ class MultiTaperFFT(ComputationalRoutine): computeFunction = staticmethod(mtmfft_cF) - method = "mtmfft" # 1st argument,the data, gets omitted - method_keys = list(signature(mtmfft).parameters.keys())[1:] - # here also last argument, the method_kwargs, are omitted - cF_keys = list(signature(mtmfft_cF).parameters.keys())[1:-1] + valid_kws = list(signature(mtmfft).parameters.keys())[1:] + valid_kws += list(signature(mtmfft_cF).parameters.keys())[1:] + # hardcode some parameter names which got digested from the frontend + valid_kws += ['tapsmofrq', 'nTaper'] def process_metadata(self, data, out): @@ -249,7 +236,7 @@ def mtmconvol_cF( toi=None, foi=None, nTaper=1, tapsmofrq=None, timeAxis=0, - keeptapers=True, polyremoval=None, output_fmt="pow", + keeptapers=True, polyremoval=0, output_fmt="pow", noCompute=False, chunkShape=None, method_kwargs=None): """ Perform time-frequency analysis on multi-channel time series data using a sliding window FFT @@ -292,7 +279,7 @@ def mtmconvol_cF( Index of running time axis in `trl_dat` (0 or 1) taper : callable Taper function to use, one of :data:`~syncopy.specest.const_def.availableTapers` - taperopt : dict + taper_opt : dict Additional keyword arguments passed to `taper` (see above). For further details, please refer to the `SciPy docs `_ @@ -300,13 +287,11 @@ def mtmconvol_cF( If `True`, results of Fourier transform are preserved for each taper, otherwise spectrum is averaged across tapers. polyremoval : int - **FIXME: Not implemented yet** - Order of polynomial used for de-trending. A value of 0 corresponds to - subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear - trends (subtracting the least squares fit of a linear function), - ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2`` - quadratic, ``N = 3`` cubic etc.). If `polyremoval` is `None`, no de-trending - is performed. + Order of polynomial used for de-trending data in the time domain prior + to spectral analysis. A value of 0 corresponds to subtracting the mean + ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the + least squares fit of a linear polynomial). Detrending is done on each segment! + If `polyremoval` is `None`, no de-trending is performed. 
output_fmt : str Output of spectral estimation; one of :data:`~syncopy.specest.const_def.availableOutputs` noCompute : bool @@ -316,6 +301,9 @@ def mtmconvol_cF( chunkShape : None or tuple If not `None`, represents shape of output object `spec` (respecting provided values of `nTaper`, `keeptapers` etc.) + method_kwargs : dict + Keyword arguments passed to :func:`~syncopy.specest.mtmconvol.mtmconvol` + controlling the spectral estimation method Returns ------- @@ -354,14 +342,6 @@ def mtmconvol_cF( dat = padding(dat, "zero", pad="relative", padlength=None, prepadlength=padbegin, postpadlength=padend) - # Slepian window parameters - if method_kwargs['taper'] == "dpss": - taperopt = {"Kmax" : nTaper, "NW" : tapsmofrq} - else: - taperopt = {} - - method_kwargs['taperopt'] = taperopt - # Get shape of output for dry-run phase nChannels = dat.shape[1] if isinstance(toi, np.ndarray): # `toi` is an array of time-points @@ -377,9 +357,18 @@ def mtmconvol_cF( if noCompute: return outShape, spectralDTypes[output_fmt] + # detrending options for each segment + if polyremoval == 0: + detrend = 'constant' + elif polyremoval == 1: + detrend = 'linear' + else: + detrend = False + # additional keyword args for `stft` in dictionary method_kwargs.update({"boundary": stftBdry, - "padded": stftPad}) + "padded": stftPad, + "detrend" : detrend}) if equidistant: ftr, freqs = mtmconvol(dat[soi, :], **method_kwargs) @@ -392,18 +381,18 @@ def mtmconvol_cF( # every individual soi, so we can use mtmfft! samplerate = method_kwargs['samplerate'] taper = method_kwargs['taper'] - taperopt = method_kwargs['taperopt'] + taper_opt = method_kwargs['taper_opt'] # In case tapers aren't preserved allocate `spec` "too big" # and average afterwards spec = np.full((nTime, nTaper, nFreq, nChannels), np.nan, dtype=spectralDTypes[output_fmt]) - ftr, freqs = mtmfft(dat[soi[0], :], samplerate, taper, taperopt) + ftr, freqs = mtmfft(dat[soi[0], :], samplerate, taper, taper_opt) _, fIdx = best_match(freqs, foi, squash_duplicates=True) spec[0, ...] = spectralConversions[output_fmt](ftr[:, fIdx, :]) # loop over remaining soi to center windows on for tk in range(1, len(soi)): - ftr, freqs = mtmfft(dat[soi[tk], :], samplerate, taper, taperopt) + ftr, freqs = mtmfft(dat[soi[tk], :], samplerate, taper, taper_opt) spec[tk, ...] = spectralConversions[output_fmt](ftr[:, fIdx, :]) # Average across tapers if wanted @@ -428,6 +417,12 @@ class MultiTaperFFTConvol(ComputationalRoutine): computeFunction = staticmethod(mtmconvol_cF) + # 1st argument,the data, gets omitted + valid_kws = list(signature(mtmconvol).parameters.keys())[1:] + valid_kws += list(signature(mtmconvol_cF).parameters.keys())[1:] + # hardcode some parameter names which got digested from the frontend + valid_kws += ['tapsmofrq', 't_ftimwin', 'nTaper'] + def process_metadata(self, data, out): # Get trialdef array + channels from source @@ -468,7 +463,7 @@ def wavelet_cF( postselect, toi=None, timeAxis=0, - polyremoval=None, + polyremoval=0, output_fmt="pow", noCompute=False, chunkShape=None, @@ -495,13 +490,11 @@ def wavelet_cF( timeAxis : int Index of running time axis in `trl_dat` (0 or 1) polyremoval : int - **FIXME: Not implemented yet** - Order of polynomial used for de-trending. A value of 0 corresponds to - subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear - trends (subtracting the least squares fit of a linear function), - ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2`` - quadratic, ``N = 3`` cubic etc.). 
If `polyremoval` is `None`, no de-trending - is performed. + Order of polynomial used for de-trending data in the time domain prior + to spectral analysis. A value of 0 corresponds to subtracting the mean + ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the + least squares fit of a linear polynomial). + If `polyremoval` is `None`, no de-trending is performed. output_fmt : str Output of spectral estimation; one of :data:`~syncopy.specest.const_def.availableOutputs` noCompute : bool @@ -562,6 +555,12 @@ def wavelet_cF( if noCompute: return outShape, spectralDTypes[output_fmt] + # detrend, does not work with 'FauxTrial' data.. + if polyremoval == 0: + dat = signal.detrend(dat, type='constant', axis=0, overwrite_data=True) + elif polyremoval == 1: + dat = signal.detrend(dat, type='linear', axis=0, overwrite_data=True) + # ------------------ # actual method call # ------------------ @@ -588,11 +587,10 @@ class WaveletTransform(ComputationalRoutine): computeFunction = staticmethod(wavelet_cF) - method = "wavelet" # 1st argument,the data, gets omitted - method_keys = list(signature(wavelet).parameters.keys())[1:] + valid_kws = list(signature(wavelet).parameters.keys())[1:] # here also last argument, the method_kwargs, are omitted - cF_keys = list(signature(wavelet_cF).parameters.keys())[1:-1] + valid_kws += list(signature(wavelet_cF).parameters.keys())[1:-1] def process_metadata(self, data, out): @@ -633,10 +631,9 @@ def superlet_cF( trl_dat, preselect, postselect, - # padbegin, # were always 0! - # padend, toi=None, timeAxis=0, + polyremoval=0, output_fmt="pow", noCompute=False, chunkShape=None, @@ -661,6 +658,14 @@ def superlet_cF( Either array of equidistant time-points or `"all"` to perform analysis on all samples in `trl_dat`. Please refer to :func:`~syncopy.freqanalysis` for further details. + timeAxis : int + Index of running time axis in `trl_dat` (0 or 1) + polyremoval : int or None + Order of polynomial used for de-trending data in the time domain prior + to spectral analysis. A value of 0 corresponds to subtracting the mean + ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the + least squares fit of a linear polynomial). + If `polyremoval` is `None`, no de-trending is performed. output_fmt : str Output of spectral estimation; one of :data:`~syncopy.specest.const_def.availableOutputs` @@ -679,7 +684,7 @@ def superlet_cF( ------- gmean_spec : :class:`numpy.ndarray` Complex time-frequency representation of the input data. - Shape is (nTime, 1, nScales, nChannels). + Shape is ``(nTime, 1, nScales, nChannels)``. Notes ----- @@ -694,8 +699,8 @@ def superlet_cF( -------- syncopy.freqanalysis : parent metafunction SuperletTransform : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` - instance that calls this method as - :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` + instance that calls this method as + :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction` """ @@ -716,6 +721,12 @@ def superlet_cF( if noCompute: return outShape, spectralDTypes[output_fmt] + # detrend, does not work with 'FauxTrial' data.. 
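+    # (the `noCompute` dry-run above only sees shape/dtype trial stand-ins -
+    #  the 'FauxTrial' objects - hence detrending must happen here in the
+    #  actual compute run, via scipy.signal.detrend on the real trial array:
+    #  polyremoval=0 -> type='constant', polyremoval=1 -> type='linear')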
+ if polyremoval == 0: + dat = signal.detrend(dat, type='constant', axis=0, overwrite_data=True) + elif polyremoval == 1: + dat = signal.detrend(dat, type='linear', axis=0, overwrite_data=True) + # ------------------ # actual method call # ------------------ @@ -741,11 +752,9 @@ class SuperletTransform(ComputationalRoutine): computeFunction = staticmethod(superlet_cF) - method = "superlet" # 1st argument,the data, gets omitted - method_keys = list(signature(superlet).parameters.keys())[1:] - # here also last argument, the method_kwargs, are omitted - cF_keys = list(signature(superlet_cF).parameters.keys())[1:-1] + valid_kws = list(signature(superlet).parameters.keys())[1:] + valid_kws += list(signature(superlet_cF).parameters.keys())[1:-1] def process_metadata(self, data, out): @@ -760,9 +769,6 @@ def process_metadata(self, data, out): # Construct trialdef array and compute new sampling rate trl, srate = _make_trialdef(self.cfg, trl, data.samplerate) - # Construct trialdef array and compute new sampling rate - trl, srate = _make_trialdef(self.cfg, trl, data.samplerate) - # If trial-averaging was requested, use the first trial as reference # (all trials had to have identical lengths), and average onset timings if not self.keeptrials: diff --git a/syncopy/specest/const_def.py b/syncopy/specest/const_def.py index 2a20069a6..a15cf744f 100644 --- a/syncopy/specest/const_def.py +++ b/syncopy/specest/const_def.py @@ -1,33 +1,16 @@ # -*- coding: utf-8 -*- # -# Constant definitions and helper functions for spectral estimations +# Constant definitions specific for spectral estimations # -# Builtin/3rd party package imports -import numpy as np - -# Module-wide output specs -spectralDTypes = {"pow": np.float32, - "fourier": np.complex128, - "abs": np.float32} - -#: output conversion of complex fourier coefficients -spectralConversions = {"pow": lambda x: (x * np.conj(x)).real.astype(np.float32), - "fourier": lambda x: x.astype(np.complex128), - "abs": lambda x: (np.absolute(x)).real.astype(np.float32)} +from syncopy.shared.const_def import spectralConversions #: available outputs of :func:`~syncopy.freqanalysis` availableOutputs = tuple(spectralConversions.keys()) -#: available tapers of :func:`~syncopy.freqanalysis` -availableTapers = ("hann", "dpss") - #: available wavelet functions of :func:`~syncopy.freqanalysis` availableWavelets = ("Morlet", "Paul", "DOG", "Ricker", "Marr", "Mexican_hat") #: available spectral estimation methods of :func:`~syncopy.freqanalysis` availableMethods = ("mtmfft", "mtmconvol", "wavelet", "superlet") -#: general, method agnostic, parameters of :func:`~syncopy.freqanalysis` -generalParameters = ("method", "output", "keeptrials", - "foi", "foilim", "polyremoval", "out") diff --git a/syncopy/specest/freqanalysis.py b/syncopy/specest/freqanalysis.py index 8441d4182..131bc7f4c 100644 --- a/syncopy/specest/freqanalysis.py +++ b/syncopy/specest/freqanalysis.py @@ -16,6 +16,14 @@ from syncopy.shared.kwarg_decorators import (unwrap_cfg, unwrap_select, detect_parallel_client) from syncopy.shared.tools import best_match +from syncopy.shared.const_def import spectralConversions + +from syncopy.shared.input_validators import ( + validate_taper, + validate_foi, + check_effective_parameters, + check_passed_kwargs +) # method specific imports - they should go! 
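+# (the idea being that once all method-specific parameters are digested
+#  into the `method_kwargs` dicts assembled below, these imports can go)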
import syncopy.specest.wavelets as spywave @@ -24,11 +32,8 @@ # Local imports from .const_def import ( - spectralConversions, - availableTapers, availableWavelets, availableMethods, - generalParameters ) from .compRoutines import ( @@ -46,7 +51,7 @@ @detect_parallel_client def freqanalysis(data, method='mtmfft', output='fourier', keeptrials=True, foi=None, foilim=None, pad=None, padtype='zero', - padlength=None, polyremoval=None, + padlength=None, polyremoval=0, taper="hann", tapsmofrq=None, nTaper=None, keeptapers=False, toi="all", t_ftimwin=None, wavelet="Morlet", width=6, order=None, order_max=None, order_min=1, c_1=3, adaptive=False, @@ -63,18 +68,17 @@ def freqanalysis(data, method='mtmfft', output='fourier', * **foi**/**foilim** : frequencies of interest; either array of frequencies or frequency window (not both) * **keeptrials** : return individual trials or grand average - * **polyremoval** : de-trending method to use (0 = mean, 1 = linear, 2 = quadratic, - 3 = cubic, etc.) + * **polyremoval** : de-trending method to use (0 = mean, 1 = linear) List of available analysis methods and respective distinct options: - :func:`~syncopy.specest.mtmfft.mtmfft` : (Multi-)tapered Fourier transform + "mtmfft" : (Multi-)tapered Fourier transform Perform frequency analysis on time-series trial data using either a single taper window (Hanning) or many tapers based on the discrete prolate spheroidal sequence (DPSS) that maximize energy concentration in the main lobe. - * **taper** : one of :data:`~syncopy.specest.const_def.availableTapers` + * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers` * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz) * **nTaper** : number of orthogonal tapers for slepian tapers * **keeptapers** : return individual tapers or average @@ -87,7 +91,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', * **prepadlength** : number of samples to pre-pend to each trial * **postpadlength** : number of samples to append to each trial - :func:`~syncopy.specest.mtmconvol.mtmconvol` : (Multi-)tapered sliding window Fourier transform + "mtmconvol" : (Multi-)tapered sliding window Fourier transform Perform time-frequency analysis on time-series trial data based on a sliding window short-time Fourier transform using either a single Hanning taper or multiple DPSS tapers. @@ -105,7 +109,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', a window on every sample in the data. * **t_ftimwin** : sliding window length (in sec) - :func:`~syncopy.specest.wavelet.wavelet` : (Continuous non-orthogonal) wavelet transform + "wavelet" : (Continuous non-orthogonal) wavelet transform Perform time-frequency analysis on time-series trial data using a non-orthogonal continuous wavelet transform. @@ -118,7 +122,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', * **order** : Order of Paul wavelet function (>= 4) or derivative order of real-valued DOG wavelets (2 = mexican hat) - :func:`~syncopy.specest.superlet.superlet` : Superlet transform + "superlet" : Superlet transform Perform time-frequency analysis on time-series trial data using the super-resolution superlet transform (SLT) from [Moca2021]_. @@ -143,7 +147,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', output : str Output of spectral estimation. 
One of :data:`~syncopy.specest.const_def.availableOutputs` (see below); use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex - Fourier coefficients (:obj:`numpy.complex128`) or `'abs'` for absolute + Fourier coefficients (:obj:`numpy.complex64`) or `'abs'` for absolute values (:obj:`numpy.float32`). keeptrials : bool If `True` spectral estimates of individual trials are returned, otherwise @@ -192,13 +196,13 @@ def freqanalysis(data, method='mtmfft', output='fourier', samples to append to each trial. See :func:`syncopy.padding` for more information. polyremoval : int or None - **FIXME: Not implemented yet** Order of polynomial used for de-trending data in the time domain prior to spectral analysis. A value of 0 corresponds to subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the - least squares fit of a linear polynomial), ``polyremoval = N`` for `N > 1` - subtracts a polynomial of order `N` (``N = 2`` quadratic, ``N = 3`` cubic - etc.). If `polyremoval` is `None`, no de-trending is performed. + least squares fit of a linear polynomial). + If `polyremoval` is `None`, no de-trending is performed. Note that + for spectral estimation de-meaning is very advisable and hence also the + default. taper : str Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. Windowing function, one of :data:`~syncopy.specest.const_def.availableTapers` (see below). @@ -322,6 +326,8 @@ def freqanalysis(data, method='mtmfft', output='fourier', # Get everything of interest in local namespace defaults = get_defaults(freqanalysis) lcls = locals() + # check for ineffective additional kwargs + check_passed_kwargs(lcls, defaults, frontend_name="freqanalysis") # Ensure a valid computational method was selected if method not in availableMethods: @@ -339,23 +345,22 @@ def freqanalysis(data, method='mtmfft', output='fourier', raise SPYTypeError(lcls[vname], varname=vname, expected="Bool") # If only a subset of `data` is to be processed, make some necessary adjustments - # and compute minimal sample-count across (selected) trials + # of the sampleinfo and trial lengths if data._selection is not None: + sinfo = data._selection.trialdefinition[:, :2] trialList = data._selection.trials - sinfo = np.zeros((len(trialList), 2)) - for tk, trlno in enumerate(trialList): - trl = data._preview_trial(trlno) - tsel = trl.idx[timeAxis] - if isinstance(tsel, list): - sinfo[tk, :] = [0, len(tsel)] - else: - sinfo[tk, :] = [trl.idx[timeAxis].start, trl.idx[timeAxis].stop] else: trialList = list(range(len(data.trials))) sinfo = data.sampleinfo lenTrials = np.diff(sinfo).squeeze() + if not lenTrials.shape: + lenTrials = lenTrials[None] numTrials = len(trialList) + # check polyremoval + if polyremoval is not None: + scalar_parser(polyremoval, varname="polyremoval", ntype="int_like", lims=[0, 1]) + # Sliding window FFT does not support "fancy" padding if method == "mtmconvol" and isinstance(pad, str): msg = "method 'mtmconvol' only supports in-place padding for windows " +\ @@ -417,51 +422,12 @@ def freqanalysis(data, method='mtmfft', output='fourier', # Shortcut to data sampling interval dt = 1 / data.samplerate - # Basic sanitization of frequency specifications - if foi is not None: - if isinstance(foi, str): - if foi == "all": - foi = None - else: - raise SPYValueError(legal="'all' or `None` or list/array", - varname="foi", actual=foi) - else: - try: - array_parser(foi, varname="foi", hasinf=False, hasnan=False, - lims=[0, data.samplerate/2], dims=(None,)) - except 
Exception as exc: - raise exc - foi = np.array(foi, dtype="float") - if foilim is not None: - if isinstance(foilim, str): - if foilim == "all": - foilim = None - else: - raise SPYValueError(legal="'all' or `None` or `[fmin, fmax]`", - varname="foilim", actual=foilim) - else: - try: - array_parser(foilim, varname="foilim", hasinf=False, hasnan=False, - lims=[0, data.samplerate/2], dims=(2,)) - except Exception as exc: - raise exc - # foilim is of shape (2,) - if foilim[0] > foilim[1]: - msg = "Sorting foilim low to high.." - SPYInfo(msg) - foilim = np.sort(foilim) - - if foi is not None and foilim is not None: - lgl = "either `foi` or `foilim` specification" - act = "both" - raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act) - - # FIXME: implement detrending + foi, foilim = validate_foi(foi, foilim, data.samplerate) + # see also https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial if polyremoval is not None: - raise NotImplementedError("Detrending has not been implemented yet.") try: - scalar_parser(polyremoval, varname="polyremoval", lims=[0, 8], ntype="int_like") + scalar_parser(polyremoval, varname="polyremoval", lims=[0, 1], ntype="int_like") except Exception as exc: raise exc @@ -474,8 +440,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', "polyremoval": polyremoval, "pad": lcls["pad"], "padtype": lcls["padtype"], - "padlength": lcls["padlength"], - "foi": lcls["foi"]} + "padlength": lcls["padlength"]} # -------------------------------- # 1st: Check time-frequency inputs @@ -559,6 +524,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', scalar_parser(t_ftimwin, varname="t_ftimwin", lims=[dt, minTrialLength]) except Exception as exc: + SPYInfo("Please specify 't_ftimwin' parameter.. exiting!") raise exc # this is the effective sliding window FFT sample size @@ -579,75 +545,30 @@ def freqanalysis(data, method='mtmfft', output='fourier', f"{freqs[-1]:.1f}Hz") SPYInfo(msg) foi = freqs - + log_dct["foi"] = foi + # Abort if desired frequency selection is empty if foi.size == 0: lgl = "non-empty frequency specification" act = "empty frequency selection" raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act) - # See if taper choice is supported - if taper not in availableTapers: - lgl = "'" + "or '".join(opt + "' " for opt in availableTapers) - raise SPYValueError(legal=lgl, varname="taper", actual=taper) - - # Warn user about DPSS only settings - if taper != "dpss": - if tapsmofrq is not None: - msg = "`tapsmofrq` is only used if `taper` is `dpss`!" - SPYWarning(msg) - if nTaper is not None: - msg = "`nTaper` is only used if `taper` is `dpss`!" - SPYWarning(msg) - if keeptapers: - msg = "`keeptapers` is only used if `taper` is `dpss`!" - SPYWarning(msg) - - # Set/get `tapsmofrq` if we're working w/Slepian tapers - if taper == "dpss": - - # direct mtm estimate (averaging) only valid for spectral power - if not keeptapers and output != "pow": - lgl = "'pow', the only valid option for taper averaging" - raise SPYValueError(legal=lgl, varname="output", actual=output) - - # Try to derive "sane" settings by using 3/4 octave - # smoothing of highest `foi` - # following Hill et al. 
"Oscillatory Synchronization in Large-Scale - # Cortical Networks Predicts Perception", Neuron, 2011 - if tapsmofrq is None: - foimax = foi.max() - tapsmofrq = (foimax * 2**(3/4/2) - foimax * 2**(-3/4/2)) / 2 - msg = f'Automatic setting of `tapsmofrq` to {tapsmofrq:.2f}' - SPYInfo(msg) - - else: - try: - scalar_parser(tapsmofrq, varname="tapsmofrq", lims=[1, np.inf]) - except Exception as exc: - raise exc - - # Get/compute number of tapers to use (at least 1 and max. 50) - if not nTaper: - nTaper = int(max(2, min(50, np.floor(tapsmofrq * minSampleNum * dt)))) - msg = f'Automatic setting of `nTaper` to {nTaper}' - SPYInfo(msg) - else: - try: - scalar_parser(nTaper, - varname="nTaper", - ntype="int_like", lims=[1, np.inf]) - except Exception as exc: - raise exc - - # only taper with frontend supported options is DPSS - else: - nTaper = 1 - - # Update `log_dct` w/method-specific options (use `lcls` to get actually - # provided keyword values, not defaults set in here) - log_dct["taper"] = lcls["taper"] - log_dct["tapsmofrq"] = lcls["tapsmofrq"] + # sanitize taper selection and retrieve dpss settings + taper_opt = validate_taper(taper, + tapsmofrq, + nTaper, + keeptapers, + foimax=foi.max(), + samplerate=data.samplerate, + nSamples=minSampleNum, + output=output) + + # Update `log_dct` w/method-specific options + log_dct["taper"] = taper + # only dpss returns non-empty taper_opt dict + if taper_opt: + log_dct["nTaper"] = taper_opt["Kmax"] + log_dct["tapsmofrq"] = tapsmofrq # ------------------------------------------------------- # Now, prepare explicit compute-classes for chosen method @@ -655,31 +576,31 @@ def freqanalysis(data, method='mtmfft', output='fourier', if method == "mtmfft": - _check_effective_parameters(MultiTaperFFT, defaults, lcls) + check_effective_parameters(MultiTaperFFT, defaults, lcls) # method specific parameters method_kwargs = { 'samplerate' : data.samplerate, - 'taper' : taper + 'taper' : taper, + 'taper_opt' : taper_opt } # Set up compute-class specestMethod = MultiTaperFFT( - samplerate=data.samplerate, foi=foi, timeAxis=timeAxis, pad=pad, padtype=padtype, padlength=padlength, keeptapers=keeptapers, - nTaper = nTaper, - tapsmofrq = tapsmofrq, polyremoval=polyremoval, output_fmt=output, method_kwargs=method_kwargs) elif method == "mtmconvol": + check_effective_parameters(MultiTaperFFTConvol, defaults, lcls) + # Process `toi` for sliding window multi taper fft, # we have to account for three scenarios: (1) center sliding # windows on all samples in (selected) trials (2) `toi` was provided as @@ -810,7 +731,8 @@ def freqanalysis(data, method='mtmfft', output='fourier', method_kwargs = {"samplerate": data.samplerate, "nperseg": nperseg, "noverlap": noverlap, - "taper" : taper} + "taper" : taper, + "taper_opt" : taper_opt} # Set up compute-class specestMethod = MultiTaperFFTConvol( @@ -821,9 +743,6 @@ def freqanalysis(data, method='mtmfft', output='fourier', equidistant=equidistant, toi=toi, foi=foi, - taper=taper, - nTaper=nTaper, - tapsmofrq=tapsmofrq, timeAxis=timeAxis, keeptapers=keeptapers, polyremoval=polyremoval, @@ -832,7 +751,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', elif method == "wavelet": - _check_effective_parameters(WaveletTransform, defaults, lcls) + check_effective_parameters(WaveletTransform, defaults, lcls) # Check wavelet selection if wavelet not in availableWavelets: @@ -894,6 +813,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', # Update `log_dct` w/method-specific options (use `lcls` to get actually # provided 
keyword values, not defaults set in here) + log_dct["foi"] = foi log_dct["wavelet"] = lcls["wavelet"] log_dct["width"] = lcls["width"] log_dct["order"] = lcls["order"] @@ -917,7 +837,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', elif method == "superlet": - _check_effective_parameters(SuperletTransform, defaults, lcls) + check_effective_parameters(SuperletTransform, defaults, lcls) # check and parse superlet specific arguments if order_max is None: @@ -969,6 +889,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', SPYWarning(msg) scales = np.sort(scales)[::-1] + log_dct["foi"] = foi log_dct["c_1"] = lcls["c_1"] log_dct["order_max"] = lcls["order_max"] log_dct["order_min"] = lcls["order_min"] @@ -989,6 +910,7 @@ def freqanalysis(data, method='mtmfft', output='fourier', postSelect, toi=toi, timeAxis=timeAxis, + polyremoval=polyremoval, output_fmt=output, method_kwargs=method_kwargs) @@ -1011,39 +933,10 @@ def freqanalysis(data, method='mtmfft', output='fourier', # Perform actual computation specestMethod.initialize(data, + out._stackingDim, chan_per_worker=kwargs.get("chan_per_worker"), keeptrials=keeptrials) specestMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dct) # Either return newly created output object or simply quit return out if new_out else None - - -def _check_effective_parameters(CR, defaults, lcls): - - ''' - For a given ComputationalRoutine, compare set parameters - (*lcls*) with the accepted parameters and the *defaults* - to warn if any ineffective parameters are set. - - #FIXME: If general structure of this function proofs - useful for all CRs/syncopy in general, - probably best to move this to syncopy.shared.tools - - Parameters - ---------- - - CR : :class:`~syncopy.shared.computational_routine.ComputationalRoutine - defaults : dict - Result of :func:`~syncopy.shared.tools.get_defaults`, the function - parameter names plus values with default values - lcls : dict - Result of `locals()`, all names and values of the local name space - ''' - # list of possible parameter names of the CR - expected = CR.method_keys + CR.cF_keys + ["parallel", "select"] - relevant = [name for name in defaults if name not in generalParameters] - for name in relevant: - if name not in expected and (lcls[name] != defaults[name]): - msg = f"option `{name}` has no effect in method `{CR.method}`!" - SPYWarning(msg, caller=__name__.split('.')[-1]) diff --git a/syncopy/specest/mtmconvol.py b/syncopy/specest/mtmconvol.py index 880ff1bda..fe273b2e0 100644 --- a/syncopy/specest/mtmconvol.py +++ b/syncopy/specest/mtmconvol.py @@ -9,9 +9,9 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", - taperopt={}, boundary='zeros', padded=True): + taper_opt={}, boundary='zeros', padded=True, detrend=False): - ''' + """ (Multi-)tapered short time fast Fourier transform. Returns full complex Fourier transform for each taper. Multi-tapering only supported with Slepian windwows (`taper="dpss"`). @@ -26,14 +26,14 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", nperseg : int Sliding window size in sample units noverlap : int - Overlap between consecutive windows, set to nperseg -1 + Overlap between consecutive windows, set to ``nperseg - 1`` to cover the whole signal taper : str or None - Taper function to use, one of scipy.signal.windows + Taper function to use, one of `scipy.signal.windows` Set to `None` for no tapering. 
- taperopt : dict + taper_opt : dict Additional keyword arguments passed to the `taper` function. - For multi-tapering with `taper='dpss'` set the keys + For multi-tapering with ``taper='dpss'`` set the keys `'Kmax'` and `'NW'`. For further details, please refer to the `SciPy docs `_ @@ -42,28 +42,27 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", sample. If set to `False` half the window size (`nperseg`) will be lost on each side of the signal. padded : bool - Additional padding in case `noverlap != nperseg - 1` to fit an integer number + Additional padding in case ``noverlap != nperseg - 1`` to fit an integer number of windows. Returns ------- ftr : 4D :class:`numpy.ndarray` The Fourier transforms, complex output has shape: - (nTime, nTapers x nFreq x nChannels) + ``(nTime, nTapers x nFreq x nChannels)`` freqs : 1D :class:`numpy.ndarray` Array of Fourier frequencies Notes ----- - For a (MTM) power spectral estimate average the absolute squared transforms across tapers: - Sxx = np.real(ftr * ftr.conj()).mean(axis=0) + ``Sxx = np.real(ftr * ftr.conj()).mean(axis=0)`` The short time FFT result is normalized such that this yields the squared harmonic amplitudes. - ''' + """ # attach dummy channel axis in case only a # single signal/channel is the input @@ -84,14 +83,14 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", # -> normalizes with win.sum() :/ # see also https://github.com/scipy/scipy/issues/14740 if taper == 'dpss': - taperopt['sym'] = False + taper_opt['sym'] = False # only truly 2d for multi-taper "dpss" - windows = np.atleast_2d(taper_func(nperseg, **taperopt)) + windows = np.atleast_2d(taper_func(nperseg, **taper_opt)) # Slepian normalization if taper == 'dpss': - windows = windows * np.sqrt(taperopt.get('Kmax', 1)) / np.sqrt(nperseg) + windows = windows * np.sqrt(taper_opt.get('Kmax', 1)) / np.sqrt(nperseg) # number of time points in the output if boundary is None: @@ -108,8 +107,8 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", for taperIdx, win in enumerate(windows): # pxx has shape (nFreq, nChannels, nTime) _, _, pxx = signal.stft(data_arr, samplerate, win, - nperseg, noverlap, boundary=boundary, - padded=padded, axis=0) + nperseg, noverlap, boundary=boundary, + padded=padded, axis=0, detrend=detrend) if taper == 'dpss': # reverse scipy window normalization @@ -119,8 +118,3 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", ftr[:, taperIdx, ...] = 2 * pxx.transpose(2, 0, 1)[:nTime, ...] return ftr, freqs - - - - - diff --git a/syncopy/specest/mtmfft.py b/syncopy/specest/mtmfft.py index faecc9cff..26bba7e98 100644 --- a/syncopy/specest/mtmfft.py +++ b/syncopy/specest/mtmfft.py @@ -1,20 +1,19 @@ # -*- coding: utf-8 -*- -# +# # Spectral estimation with (multi-)tapered FFT -# +# # Builtin/3rd party package imports import numpy as np from scipy import signal -def mtmfft(data_arr, samplerate, taper="hann", taperopt={}): - - ''' +def mtmfft(data_arr, samplerate, taper="hann", taper_opt=None): + """ (Multi-)tapered fast Fourier transform. Returns full complex Fourier transform for each taper. Multi-tapering only supported with Slepian windwows (`taper="dpss"`). 
- + Parameters ---------- data_arr : (N,) :class:`numpy.ndarray` @@ -23,33 +22,31 @@ def mtmfft(data_arr, samplerate, taper="hann", taperopt={}): samplerate : float Samplerate in Hz taper : str or None - Taper function to use, one of scipy.signal.windows + Taper function to use, one of `scipy.signal.windows` Set to `None` for no tapering. - taperopt : dict - Additional keyword arguments passed to the `taper` function. - For multi-tapering with `taper='dpss'` set the keys + taper_opt : dict or None + Additional keyword arguments passed to the `taper` function. + For multi-tapering with ``taper='dpss'`` set the keys `'Kmax'` and `'NW'`. - For further details, please refer to the + For further details, please refer to the `SciPy docs `_ Returns ------- - ftr : 3D :class:`numpy.ndarray` - Complex output has shape (nTapers x nFreq x nChannels). + Complex output has shape ``(nTapers x nFreq x nChannels)``. freqs : 1D :class:`numpy.ndarray` Array of Fourier frequencies - + Notes ----- - For a (MTM) power spectral estimate average the absolute squared transforms across tapers: - Sxx = np.real(ftr * ftr.conj()).mean(axis=0) + ``Sxx = np.real(ftr * ftr.conj()).mean(axis=0)`` - The FFT result is normalized such that this yields the squared amplitudes. - ''' + The FFT result is normalized such that this yields the squared amplitudes. + """ # attach dummy channel axis in case only a # single signal/channel is the input @@ -66,29 +63,31 @@ def mtmfft(data_arr, samplerate, taper="hann", taperopt={}): if taper is None: taper = 'boxcar' - taper_func = getattr(signal.windows, taper) + if taper_opt is None: + taper_opt = {} + + taper_func = getattr(signal.windows, taper) # only really 2d if taper='dpss' with Kmax > 1 - windows = np.atleast_2d(taper_func(nSamples, **taperopt)) - + windows = np.atleast_2d(taper_func(nSamples, **taper_opt)) + # only(!!) 
Slepian windows are already normalized
     # still have to normalize by number of tapers
     # such that taper-averaging yields correct amplitudes
     if taper == 'dpss':
-        windows = windows * np.sqrt(taperopt.get('Kmax', 1))
-    # per pedes L2 normalisation for all other tapers
+        windows = windows * np.sqrt(taper_opt.get('Kmax', 1))
+    # per pedes L2 normalisation for all other tapers
     else:
         windows = windows * np.sqrt(nSamples) / np.sum(windows)
-
+
     # Fourier transforms (nTapers x nFreq x nChannels)
-    ftr = np.zeros((windows.shape[0], nFreq, nChannels), dtype='complex128')
+    ftr = np.zeros((windows.shape[0], nFreq, nChannels), dtype='complex64')
 
     for taperIdx, win in enumerate(windows):
         win = np.tile(win, (nChannels, 1)).T
         # real fft takes only 'half the energy'/positive frequencies,
         # multiply by 2 to correct for this
         ftr[taperIdx] = 2 * np.fft.rfft(data_arr * win, axis=0)
-        # normalization
+        # normalization
         ftr[taperIdx] /= np.sqrt(nSamples)
 
     return ftr, freqs
-
diff --git a/syncopy/tests/backend/test_connectivity.py b/syncopy/tests/backend/test_connectivity.py
new file mode 100644
index 000000000..d041421ce
--- /dev/null
+++ b/syncopy/tests/backend/test_connectivity.py
@@ -0,0 +1,323 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import matplotlib.pyplot as ppl
+from syncopy.connectivity import ST_compRoutines as stCR
+from syncopy.connectivity import AV_compRoutines as avCR
+from syncopy.connectivity.wilson_sf import wilson_sf, regularize_csd
+from syncopy.connectivity.granger import granger
+
+
+def test_coherence():
+
+    """
+    Tests the normalization cF to
+    arrive at the coherence given
+    a trial-averaged CSD
+    """
+
+    nSamples = 1001
+    fs = 1000
+    tvec = np.arange(nSamples) / fs
+    harm_freq = 40
+    phase_shifts = np.array([0, np.pi / 2, np.pi])
+
+    nTrials = 100
+
+    # shape is (1, nFreq, nChannel, nChannel)
+    nFreq = nSamples // 2 + 1
+    nChannel = len(phase_shifts)
+    avCSD = np.zeros((1, nFreq, nChannel, nChannel), dtype=np.complex64)
+
+    for i in range(nTrials):
+
+        # 3 phase-shifted harmonics + white noise + constant, SNR = 1
+        trl_dat = [10 + np.cos(harm_freq * 2 * np.pi * tvec + ps)
+                   for ps in phase_shifts]
+        trl_dat = np.array(trl_dat).T
+        trl_dat = np.array(trl_dat) + np.random.randn(nSamples, len(phase_shifts))
+
+        # process every trial individually
+        CSD, freqs = stCR.cross_spectra_cF(trl_dat, fs,
+                                           polyremoval=1,
+                                           taper='hann',
+                                           norm=False,  # this is important!
+                                           fullOutput=True)
+
+        assert avCSD.shape == CSD.shape
+        avCSD += CSD
+
+    # this is the trial average
+    avCSD /= nTrials
+
+    # perform the normalisation on the trial-averaged CSDs
+    Cij = avCR.normalize_csd_cF(avCSD)
+
+    # output has shape (1, nFreq, nChannels, nChannels)
+    assert Cij.shape == avCSD.shape
+
+    # coherence between channel 0 and 1
+    coh = Cij[0, :, 0, 1]
+
+    fig, ax = ppl.subplots(figsize=(6, 4), num=None)
+    ax.set_xlabel('frequency (Hz)')
+    ax.set_ylabel('coherence')
+    ax.set_ylim((-.02, 1.05))
+    ax.set_title('Trial average coherence, SNR=1')
+
+    ax.plot(freqs, coh, lw=1.5, alpha=0.8, c='cornflowerblue')
+
+    # we test for the highest peak sitting in
+    # the vicinity (± 5Hz) of the harmonic
+    peak_val = np.max(coh)
+    peak_idx = np.argmax(coh)
+    peak_freq = freqs[peak_idx]
+
+    assert harm_freq - 5 < peak_freq < harm_freq + 5
+
+    # we test that the peak value
+    # is at least 0.9 and at most 1
+    assert 0.9 < peak_val < 1
+
+    # trial averaging should suppress the noise;
+    # we test that away from the harmonic the coherence is low
+    level = 0.4
+    assert np.all(coh[:peak_idx - 2] < level)
+    assert np.all(coh[peak_idx + 2:] < level)
+
+
+def test_csd():
+
+    """
+    Tests multi-tapered single-trial cross-spectral
+    densities
+    """
+
+    nSamples = 1001
+    fs = 1000
+    tvec = np.arange(nSamples) / fs
+    harm_freq = 40
+    phase_shifts = np.array([0, np.pi / 2, np.pi])
+
+    # 3 phase-shifted harmonics + white noise + constant, SNR = 1
+    data = [10 + np.cos(harm_freq * 2 * np.pi * tvec + ps)
+            for ps in phase_shifts]
+    data = np.array(data).T
+    data = np.array(data) + np.random.randn(nSamples, len(phase_shifts))
+
+    bw = 5  # Hz
+    NW = nSamples * bw / (2 * fs)
+    Kmax = int(2 * NW - 1)  # multiple tapers for single trial coherence
+    CSD, freqs = stCR.cross_spectra_cF(data, fs,
+                                       polyremoval=1,
+                                       taper='dpss',
+                                       taper_opt={'Kmax' : Kmax, 'NW' : NW},
+                                       norm=True,
+                                       fullOutput=True)
+
+    # output has shape (1, nFreq, nChannels, nChannels)
+    assert CSD.shape == (1, len(freqs), data.shape[1], data.shape[1])
+
+    # single trial coherence between channel 0 and 1
+    coh = np.abs(CSD[0, :, 0, 1])
+
+    fig, ax = ppl.subplots(figsize=(6, 4), num=None)
+    ax.set_xlabel('frequency (Hz)')
+    ax.set_ylabel('coherence')
+    ax.set_ylim((-.02, 1.05))
+    ax.set_title(f'MTM coherence, {Kmax} tapers, SNR=1')
+
+    ax.plot(freqs, coh, lw=1.5, alpha=0.8, c='cornflowerblue')
+
+    # we test for the highest peak sitting in
+    # the vicinity (± 5Hz) of the harmonic
+    peak_val = np.max(coh)
+    peak_idx = np.argmax(coh)
+    peak_freq = freqs[peak_idx]
+    assert harm_freq - 5 < peak_freq < harm_freq + 5
+
+    # we test that the peak value
+    # is at least 0.9 and at most 1
+    assert 0.9 < peak_val < 1
+
+
+def test_cross_cov():
+
+    nSamples = 1001
+    fs = 1000
+    tvec = np.arange(nSamples) / fs
+
+    cosine = np.cos(2 * np.pi * 30 * tvec)
+    sine = np.sin(2 * np.pi * 30 * tvec)
+    data = np.c_[cosine, sine]
+
+    # output shape is (nLags x 1 x nChannels x nChannels)
+    CC, lags = stCR.cross_covariance_cF(data, samplerate=fs, norm=True, fullOutput=True)
+
+    # the result is returned for the [0, np.ceil(nSamples / 2)] lag interval
+    nLags = int(np.ceil(nSamples / 2))
+
+    # output has shape (nLags, 1, nChannels, nChannels)
+    assert CC.shape == (nLags, 1, data.shape[1], data.shape[1])
+
+    # the cross-correlation (normalized cross-covariance) between
+    # cosine and sine analytically equals minus sine
+    assert np.all(CC[:, 0, 0, 1] + sine[:nLags] < 1e-5)
+
+
+def test_wilson():
+    """
+    Test Wilson's spectral matrix factorization.
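+
+    Wilson's algorithm factorizes the (regularized) CSD matrix as
+    S(f) = H(f) Sigma H(f)^*, with transfer function H and noise
+    covariance Sigma - this is what gets reconstituted as `CSDfac` below.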
+
+    As the routine has relative error-checking
+    built in, we just need to check for convergence.
+    """
+
+    # --- create test data ---
+    fs = 1000
+    nChannels = 10
+    nSamples = 1000
+    f1, f2 = [30, 40]   # 30Hz and 40Hz
+    data = np.zeros((nSamples, nChannels))
+    for i in range(nChannels):
+        # more phase diffusion in the 40Hz band
+        p1 = phase_evo(f1 * 2 * np.pi, eps=0.1, fs=fs, N=nSamples)
+        p2 = phase_evo(f2 * 2 * np.pi, eps=0.35, fs=fs, N=nSamples)
+
+        data[:, i] = np.cos(p1) + 2 * np.sin(p2) + .5 * np.random.randn(nSamples)
+
+    # --- get the (single trial) CSD ---
+
+    bw = 5  # 5Hz smoothing
+    NW = bw * nSamples / (2 * fs)
+    Kmax = int(2 * NW - 1)  # optimal number of tapers
+
+    CSD, freqs = stCR.cross_spectra_cF(data, fs,
+                                       taper='dpss',
+                                       taper_opt={'Kmax' : Kmax, 'NW' : NW},
+                                       norm=False,
+                                       fullOutput=True)
+    # strip off singleton time axis
+    CSD = CSD[0]
+
+    # get the CSD condition number, which is way too large!
+    CN = np.linalg.cond(CSD).max()
+    assert CN > 1e6
+
+    # --- regularize CSD ---
+
+    CSDreg, fac = regularize_csd(CSD, cond_max=1e6, nSteps=25)
+    CNreg = np.linalg.cond(CSDreg).max()
+    assert CNreg < 1e6
+    # check that a 'small' regularization factor is enough
+    assert fac < 1e-5
+
+    # --- factorize CSD with Wilson's algorithm ---
+
+    H, Sigma, conv = wilson_sf(CSDreg, rtol=1e-9)
+
+    # converged - \Psi \Psi^* \approx CSD,
+    # with relative error <= rtol?
+    assert conv
+
+    # reconstitute
+    CSDfac = H @ Sigma @ H.conj().transpose(0, 2, 1)
+
+    fig, ax = ppl.subplots(figsize=(6, 4))
+    ax.set_xlabel('frequency (Hz)')
+    ax.set_ylabel(r'$|CSD_{ij}(f)|$')
+    chan = nChannels // 2
+    # show (real) auto-spectra
+    ax.plot(freqs, np.abs(CSD[:, chan, chan]),
+            '-o', label='original CSD', ms=3)
+    ax.plot(freqs, np.abs(CSDreg[:, chan, chan]),
+            '-o', label='regularized CSD', ms=3)
+    ax.plot(freqs, np.abs(CSDfac[:, chan, chan]),
+            '-o', label='factorized CSD', ms=3)
+    ax.set_xlim((f1 - 5, f2 + 5))
+    ax.legend()
+
+
+def test_granger():
+
+    """
+    Test the Granger causality measure
+    with uni-directionally coupled AR(2)
+    processes akin to the source publication:
+
+    Dhamala, Mukeshwar, Govindan Rangarajan, and Mingzhou Ding.
+    "Estimating Granger causality from Fourier and wavelet transforms
+    of time series data." Physical Review Letters 100.1 (2008): 018701.
+    """
+
+    fs = 200  # Hz
+    nSamples = 2500
+    nTrials = 50
+
+    # both AR(2) processes have the same parameters
+    # and yield a spectral peak at 40Hz
+    alpha1, alpha2 = 0.55, -0.8
+    coupling = 0.25
+
+    CSDav = np.zeros((nSamples // 2 + 1, 2, 2), dtype=np.complex64)
+    for _ in range(nTrials):
+
+        # -- simulate 2 AR(2) processes --
+
+        sol = np.zeros((nSamples, 2))
+        # pick the first two values at random
+        xs_ini = np.random.randn(2, 2)
+        sol[:2, :] = xs_ini
+        # start the recursion at i=2, the first two samples are initial conditions
+        for i in range(2, nSamples):
+            sol[i, 1] = alpha1 * sol[i - 1, 1] + alpha2 * sol[i - 2, 1]
+            sol[i, 1] += np.random.randn()
+            # X2 drives X1
+            sol[i, 0] = alpha1 * sol[i - 1, 0] + alpha2 * sol[i - 2, 0]
+            sol[i, 0] += sol[i - 1, 1] * coupling
+            sol[i, 0] += np.random.randn()
+
+        # --- get CSD ---
+        bw = 5
+        NW = bw * nSamples / (2 * 1000)
+        Kmax = int(2 * NW - 1)  # optimal number of tapers
+        CS2, freqs = stCR.cross_spectra_cF(sol, fs,
+                                           taper='dpss',
+                                           taper_opt={'Kmax' : Kmax, 'NW' : NW},
+                                           fullOutput=True)
+
+        CSD = CS2[0, ...]
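+        # accumulate single-trial estimates; the Granger analysis below
+        # operates on the trial-averaged CSD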
+ CSDav += CSD + + CSDav /= nTrials + # with only 2 channels this CSD is well conditioned + assert np.linalg.cond(CSDav).max() < 1e2 + H, Sigma, conv = wilson_sf(CSDav) + + G = granger(CSDav, H, Sigma) + assert G.shape == CSDav.shape + + # check for directional causality at 40Hz + freq_idx = np.argmin(freqs < 40) + assert 39 < freqs[freq_idx] < 41 + + # check low to no causality for 1->2 + assert G[freq_idx, 0, 1] < 0.1 + # check high causality for 2->1 + assert G[freq_idx, 1, 0] > 0.8 + + fig, ax = ppl.subplots(figsize=(6, 4)) + ax.set_xlabel('frequency (Hz)') + ax.set_ylabel(r'Granger causality(f)') + ax.plot(freqs, G[:, 0, 1], label=r'Granger $1\rightarrow2$') + ax.plot(freqs, G[:, 1, 0], label=r'Granger $2\rightarrow1$') + ax.legend() + +# --- Helper routines --- + + +# noisy phase evolution -> phase diffusion +def phase_evo(omega0, eps, fs=1000, N=1000): + wn = np.random.randn(N) + delta_ts = np.ones(N) * 1 / fs + phase = np.cumsum(omega0 * delta_ts + eps * wn) + return phase diff --git a/syncopy/tests/backend/test_timefreq.py b/syncopy/tests/backend/test_timefreq.py index 9e1866479..929f293a6 100644 --- a/syncopy/tests/backend/test_timefreq.py +++ b/syncopy/tests/backend/test_timefreq.py @@ -12,42 +12,42 @@ def gen_testdata(freqs=[20, 40, 60], cycles=11, fs=1000, eps = 0): - ''' + """ Harmonic superposition of multiple few-cycle oscillations akin to the example of Figure 3 in Moca et al. 2021 NatComm Each harmonic has a frequency neighbor with +10Hz and a time neighbor after 2 cycles(periods). - ''' + """ signal = [] for freq in freqs: - + # 10 cycles of f1 tvec = np.arange(cycles / freq, step=1 / fs) harmonic = np.cos(2 * np.pi * freq * tvec) # frequency neighbor - f_neighbor = np.cos(2 * np.pi * (freq + 10) * tvec) + f_neighbor = np.cos(2 * np.pi * (freq + 10) * tvec) packet = harmonic + f_neighbor # 2 cycles time neighbor delta_t = np.zeros(int(2 / freq * fs)) - + # 5 cycles break pad = np.zeros(int(5 / freq * fs)) signal.extend([pad, packet, delta_t, harmonic]) - # stack the packets together with some padding + # stack the packets together with some padding signal.append(pad) signal = np.concatenate(signal) # additive white noise if eps > 0: signal = np.random.randn(len(signal)) * eps + signal - + return signal @@ -60,7 +60,7 @@ def gen_testdata(freqs=[20, 40, 60], # signal_freqs = np.array([20, 70]) cycles = 12 A = 5 # signal amplitude -signal = A * gen_testdata(freqs=signal_freqs, cycles=cycles, fs=fs, eps=0.) +signal = A * gen_testdata(freqs=signal_freqs, cycles=cycles, fs=fs, eps=0.) 
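# layout of the resulting signal, per entry in `signal_freqs`:
# [5-cycle pad | packet = f + (f+10Hz) | 2-cycle gap | pure f tone],
# i.e. every target frequency has both a spectral neighbor (f+10Hz)
# and a temporal neighbor (the trailing pure tone) to resolve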
# define frequencies of interest for wavelet methods foi = np.arange(1, 101, step=1) @@ -69,18 +69,18 @@ def gen_testdata(freqs=[20, 40, 60], freq_idx = [] for frequency in signal_freqs: freq_idx.append(np.argmax(foi >= frequency)) - + def test_mtmconvol(): - # 10 cycles of 40Hz are 250 samples + # 10 cycles of 40Hz are 250 samples window_size = 750 # default - stft pads with 0's to make windows fit # we choose N-1 overlap to retrieve a time-freq estimate # for each epoch in the signal - # the transforms have shape (nTime, nTaper, nFreq, nChannel) + # the transforms have shape (nTime, nTaper, nFreq, nChannel) ftr, freqs = mtmconvol.mtmconvol(signal, samplerate=fs, taper='cosine', nperseg=window_size, @@ -96,11 +96,11 @@ def test_mtmconvol(): gridspec_kw={"height_ratios": [1, 3]}, figsize=(6, 6)) - ax1.set_title("Short Time Fourier Transform") + ax1.set_title("Short Time Fourier Transform") ax1.plot(np.arange(signal.size) / fs, signal, c='cornflowerblue') ax1.set_ylabel('signal (a.u.)') - ax2.set_xlabel("time (s)") + ax2.set_xlabel("time (s)") ax2.set_ylabel("frequency (Hz)") df = freqs[1] - freqs[0] @@ -108,13 +108,13 @@ def test_mtmconvol(): extent = [0, len(signal) / fs, freqs[0] - df / 2, freqs[-1] - df / 2] # test also the plotting # scale with amplitude - assert ax2.imshow(ampls.T, - cmap='magma', - aspect='auto', - origin='lower', - extent=extent, - vmin=0, - vmax=1.2 * A) + ax2.imshow(ampls.T, + cmap='magma', + aspect='auto', + origin='lower', + extent=extent, + vmin=0, + vmax=1.2 * A) # zoom into foi region ax2.set_ylim((foi[0], foi[-1])) @@ -129,7 +129,7 @@ def test_mtmconvol(): for frequency in signal_freqs: freq_idx.append(np.argmax(freqs >= frequency)) - # test amplitude normalization + # test amplitude normalization for idx, frequency in zip(freq_idx, signal_freqs): ax2.plot([0, len(signal) / fs], @@ -142,10 +142,10 @@ def test_mtmconvol(): cycle_num = (ampls[:, idx] > A / np.e).sum() / fs * frequency print(f'{cycle_num} cycles for the {frequency} band') # we have 2 times the cycles for each frequency (temporal neighbor) - # assert cycle_num > 2 * cycles + assert cycle_num > 2 * cycles # power should decay fast, so we don't detect more cycles - # assert cycle_num < 3 * cycles - + assert cycle_num < 3 * cycles + fig.tight_layout() # ------------------------- @@ -153,10 +153,10 @@ def test_mtmconvol(): # ------------------------- taper = 'dpss' - taperopt = {'Kmax' : 10, 'NW' : 2} + taper_opt = {'Kmax' : 10, 'NW' : 2} # the transforms have shape (nTime, nTaper, nFreq, nChannel) ftr2, freqs2 = mtmconvol.mtmconvol(signal, - samplerate=fs, taper=taper, taperopt=taperopt, + samplerate=fs, taper=taper, taper_opt=taper_opt, nperseg=window_size, noverlap=window_size - 1) @@ -169,22 +169,22 @@ def test_mtmconvol(): gridspec_kw={"height_ratios": [1, 3]}, figsize=(6, 6)) - ax1.set_title("Multi-Taper STFT") + ax1.set_title("Multi-Taper STFT") ax1.plot(np.arange(signal.size) / fs, signal, c='cornflowerblue') ax1.set_ylabel('signal (a.u.)') - ax2.set_xlabel("time (s)") + ax2.set_xlabel("time (s)") ax2.set_ylabel("frequency (Hz)") # test also the plotting # scale with amplitude - assert ax2.imshow(ampls2.T, - cmap='magma', - aspect='auto', - origin='lower', - extent=extent, - vmin=0, - vmax=1.2 * A) + ax2.imshow(ampls2.T, + cmap='magma', + aspect='auto', + origin='lower', + extent=extent, + vmin=0, + vmax=1.2 * A) # zoom into foi region ax2.set_ylim((foi[0], foi[-1])) @@ -205,7 +205,7 @@ def test_mtmconvol(): # for multi-taper stft we can't # check for the whole time domain - # due to too 
much spectral broadening + # due to too much spectral broadening/smearing # so we just check that the maximum estimated # amplitude is within 10% boundsof the real amplitude @@ -213,10 +213,10 @@ def test_mtmconvol(): def test_superlet(): - + scalesSL = superlet.scale_from_period(1 / foi) - # spec shape is nScales x nTime (x nChannels) + # spec shape is nScales x nTime (x nChannels) spec = superlet.superlet(signal, samplerate=fs, scales=scalesSL, @@ -231,24 +231,24 @@ def test_superlet(): sharex=True, gridspec_kw={"height_ratios": [1, 3]}, figsize=(6, 6)) - - ax1.set_title("Superlet Transform") + + ax1.set_title("Superlet Transform") ax1.plot(np.arange(signal.size) / fs, signal, c='cornflowerblue') ax1.set_ylabel('signal (a.u.)') - - ax2.set_xlabel("time (s)") - ax2.set_ylabel("frequency (Hz)") + + ax2.set_xlabel("time (s)") + ax2.set_ylabel("frequency (Hz)") extent = [0, len(signal) / fs, foi[0], foi[-1]] # test also the plotting # scale with amplitude - assert ax2.imshow(ampls, - cmap='magma', - aspect='auto', - extent=extent, - origin='lower', - vmin=0, - vmax=1.2 * A) - + ax2.imshow(ampls, + cmap='magma', + aspect='auto', + extent=extent, + origin='lower', + vmin=0, + vmax=1.2 * A) + # get the 'mappable' im = ax2.images[0] fig.colorbar(im, ax = ax2, orientation='horizontal', @@ -272,7 +272,7 @@ def test_superlet(): fig.tight_layout() - + def test_wavelet(): # get a wavelet function @@ -294,20 +294,20 @@ def test_wavelet(): ax1.set_title("Wavelet Transform") ax1.plot(np.arange(signal.size) / fs, signal, c='cornflowerblue') ax1.set_ylabel('signal (a.u.)') - - ax2.set_xlabel("time (s)") + + ax2.set_xlabel("time (s)") ax2.set_ylabel("frequency (Hz)") extent = [0, len(signal) / fs, foi[0], foi[-1]] # test also the plotting # scale with amplitude - assert ax2.imshow(ampls, - cmap='magma', - aspect='auto', - extent=extent, - origin='lower', - vmin=0, - vmax=1.2 * A) + ax2.imshow(ampls, + cmap='magma', + aspect='auto', + extent=extent, + origin='lower', + vmin=0, + vmax=1.2 * A) # get the 'mappable' im = ax2.images[0] @@ -332,7 +332,7 @@ def test_wavelet(): fig.tight_layout() - + def test_mtmfft(): # superposition 40Hz and 100Hz oscillations A1:A2 for 1s @@ -340,14 +340,14 @@ def test_mtmfft(): A1, A2 = 5, 3 tvec = np.arange(0, 1, 1 / 1000) - signal = A1 * np.cos(2 * np.pi * 40 * tvec) - signal += A2 * np.cos(2 * np.pi * 100 * tvec) + signal = A1 * np.cos(2 * np.pi * f1 * tvec) + signal += A2 * np.cos(2 * np.pi * f2 * tvec) # -------------------- # -- test untapered -- # -------------------- - - # the transforms have shape (nTaper, nFreq, nChannel) + + # the transforms have shape (nTaper, nFreq, nChannel) ftr, freqs = mtmfft.mtmfft(signal, fs, taper=None) # with 1000Hz sampling frequency and 1000 samples this gives @@ -359,7 +359,7 @@ def test_mtmfft(): spec = np.real(ftr * ftr.conj()).mean(axis=0) amplitudes = np.sqrt(spec)[:, 0] # only 1 channel # our FFT normalisation recovers the signal amplitudes: - assert np.allclose([A1, A2], amplitudes[[f1, f2]]) + assert np.allclose([A1, A2], amplitudes[[f1, f2]]) fig, ax = ppl.subplots() ax.set_title(f"Amplitude spectrum {A1} x 40Hz + {A2} x 100Hz") @@ -370,10 +370,10 @@ def test_mtmfft(): # ------------------------- # test multi-taper analysis # ------------------------- - - taperopt = {'Kmax' : 8, 'NW' : 1} - ftr, freqs = mtmfft.mtmfft(signal, fs, taper="dpss", taperopt=taperopt) - # average over tapers + + taper_opt = {'Kmax' : 8, 'NW' : 1} + ftr, freqs = mtmfft.mtmfft(signal, fs, taper="dpss", taper_opt=taper_opt) + # average over tapers 
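+    # each DPSS taper provides a statistically independent estimate of the
+    # same spectrum; averaging the per-taper power estimates below reduces
+    # the estimator variance (roughly by 1/Kmax) at the cost of some
+    # spectral smoothing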
dpss_spec = np.real(ftr * ftr.conj()).mean(axis=0) dpss_amplitudes = np.sqrt(dpss_spec)[:, 0] # only 1 channel # check for amplitudes (and taper normalisation) @@ -385,34 +385,34 @@ def test_mtmfft(): # ----------------- # test kaiser taper (is boxcar for beta -> inf) # ----------------- - - taperopt = {'beta' : 2} - ftr, freqs = mtmfft.mtmfft(signal, fs, taper="kaiser", taperopt=taperopt) + + taper_opt = {'beta' : 2} + ftr, freqs = mtmfft.mtmfft(signal, fs, taper="kaiser", taper_opt=taper_opt) # average over tapers (only 1 here) kaiser_spec = np.real(ftr * ftr.conj()).mean(axis=0) kaiser_amplitudes = np.sqrt(kaiser_spec)[:, 0] # only 1 channel # check for amplitudes (and taper normalisation) - assert np.allclose(kaiser_amplitudes[[f1, f2]], [A1, A2], atol=1e-2) + assert np.allclose(kaiser_amplitudes[[f1, f2]], [A1, A2], atol=1e-2) # ------------------------------- # test all other window functions (which don't need a parameter) # ------------------------------- - + for win in windows.__all__: - taperopt = {} + taper_opt = {} # that guy isn't symmetric if win == 'exponential': continue # that guy is deprecated if win == 'hanning': - continue + continue try: - ftr, freqs = mtmfft.mtmfft(signal, fs, taper=win, taperopt=taperopt) + ftr, freqs = mtmfft.mtmfft(signal, fs, taper=win, taper_opt=taper_opt) # average over tapers (only 1 here) spec = np.real(ftr * ftr.conj()).mean(axis=0) amplitudes = np.sqrt(spec)[:, 0] # only 1 channel - # print(win, amplitudes[[f1, f2]]) - assert np.allclose(amplitudes[[f1, f2]], [A1, A2], atol=1e-3) + # print(win, amplitudes[[f1, f2]]) + assert np.allclose(amplitudes[[f1, f2]], [A1, A2], atol=1e-3) except TypeError: # we didn't provide default parameters.. pass diff --git a/syncopy/tests/run_tests.sh b/syncopy/tests/run_tests.sh index f561ed01c..d041c1506 100755 --- a/syncopy/tests/run_tests.sh +++ b/syncopy/tests/run_tests.sh @@ -34,7 +34,7 @@ if [ "$1" == "" ]; then usage fi -# Set up "global" pytest options for running test-suite +# Set up "global" pytest options for running test-suite (coverage is only done in local pytest runs) export PYTEST_ADDOPTS="--color=yes --tb=short --verbose" # The while construction allows parsing of multiple positional/optional args (future-proofing...) @@ -46,6 +46,8 @@ while [ "$1" != "" ]; do if [ $_useSLURM ]; then srun -p DEV --mem=8000m -c 4 pytest else + PYTEST_ADDOPTS="$PYTEST_ADDOPTS --cov=../../syncopy --cov-config=../../.coveragerc" + export PYTEST_ADDOPTS pytest fi ;; diff --git a/syncopy/tests/spy_setup.py b/syncopy/tests/spy_setup.py index b75334c24..4ddd4ce0c 100644 --- a/syncopy/tests/spy_setup.py +++ b/syncopy/tests/spy_setup.py @@ -23,4 +23,22 @@ if __name__ == "__main__": # Test stuff within here... 
- pass + data1 = generate_artificial_data(nTrials=5, nChannels=16, equidistant=False, inmemory=False) + data2 = generate_artificial_data(nTrials=5, nChannels=16, equidistant=True, inmemory=False) + + + nSamples = 1000 + nChannels = 50 + my_noise = np.random.randn(nSamples, nChannels) + + trl_dat = [my_noise, 5 * my_noise + 10, np.random.randn(nSamples, nChannels)] + + aa = spy.AnalogData(trl_dat) + + + # client = spy.esi_cluster_setup(interactive=False) + # data1 + data2 + + sys.exit() + spec = spy.freqanalysis(artdata, method="mtmfft", taper="dpss", output="pow") + diff --git a/syncopy/tests/test_basedata.py b/syncopy/tests/test_basedata.py index fa3ffeb7f..909ed8795 100644 --- a/syncopy/tests/test_basedata.py +++ b/syncopy/tests/test_basedata.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- -# +# # Test proper functionality of Syncopy's `BaseData` class + helpers -# +# # Builtin/3rd party package imports import os import tempfile +from attr import has import h5py import time import pytest @@ -17,13 +18,20 @@ from syncopy.datatype import AnalogData import syncopy.datatype as spd from syncopy.datatype.base_data import VirtualData -from syncopy.shared.errors import SPYValueError, SPYTypeError +from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYError from syncopy.tests.misc import is_win_vm, is_slurm_node # Construct decorators for skipping certain tests skip_in_vm = pytest.mark.skipif(is_win_vm(), reason="running in Win VM") skip_in_slurm = pytest.mark.skipif(is_slurm_node(), reason="running on cluster node") +# Collect all supported binary arithmetic operators +arithmetics = [lambda x, y : x + y, + lambda x, y : x - y, + lambda x, y : x * y, + lambda x, y : x / y, + lambda x, y : x ** y] + class TestVirtualData(): @@ -141,6 +149,7 @@ class TestBaseData(): nSpikes = 50 data = {} trl = {} + samplerate = 1.0 # Generate 2D array simulating an AnalogData array data["AnalogData"] = np.arange(1, nChannels * nSamples + 1).reshape(nSamples, nChannels) @@ -176,35 +185,18 @@ def test_data_alloc(self): hname = os.path.join(tdir, "dummy.h5") for dclass in self.classes: - # attempt allocation with random file - with open(fname, "w") as f: - f.write("dummy") - # with pytest.raises(SPYValueError): - # getattr(spd, dclass)(fname) # allocation with HDF5 file h5f = h5py.File(hname, mode="w") h5f.create_dataset("dummy", data=self.data[dclass]) h5f.close() - - # dummy = getattr(spd, dclass)(filename=hname) - # assert np.array_equal(dummy.data, self.data[dclass]) - # assert dummy.filename == hname - # del dummy # allocation using HDF5 dataset directly dset = h5py.File(hname, mode="r+")["dummy"] dummy = getattr(spd, dclass)(data=dset) assert np.array_equal(dummy.data, self.data[dclass]) assert dummy.mode == "r+", dummy.data.file.mode - del dummy - - # # allocation with memmaped npy file - # np.save(fname, self.data[dclass]) - # dummy = getattr(spd, dclass)(filename=fname) - # assert np.array_equal(dummy.data, self.data[dclass]) - # assert dummy.filename == fname - # del dummy + del dummy # allocation using memmap directly np.save(fname, self.data[dclass]) @@ -231,15 +223,10 @@ def test_data_alloc(self): with pytest.raises(SPYValueError): getattr(spd, dclass)(data=dset) - # # attempt allocation using illegal HDF5 file + # allocate with valid dataset of "illegal" file del h5f["dummy"] h5f.create_dataset("dummy1", data=self.data[dclass]) - # FIXME: unused: h5f.create_dataset("dummy2", data=self.data[dclass]) h5f.close() - # with pytest.raises(SPYValueError): - # getattr(spd, dclass)(hname) - - # allocate 
with valid dataset of "illegal" file dset = h5py.File(hname, mode="r")["dummy1"] dummy = getattr(spd, dclass)(data=dset, filename=fname) @@ -256,7 +243,34 @@ def test_data_alloc(self): np.save(fname, np.ones((self.nChannels,))) with pytest.raises(SPYValueError): getattr(spd, dclass)(data=open_memmap(fname)) - + + # ensure synthetic data allocation via list of arrays works + dummy = getattr(spd, dclass)(data=[self.data[dclass], self.data[dclass]]) + assert len(dummy.trials) == 2 + + dummy = getattr(spd, dclass)(data=[self.data[dclass], self.data[dclass]], + samplerate=10.0) + assert len(dummy.trials) == 2 + assert dummy.samplerate == 10 + + if any(["ContinuousData" in str(base) for base in self.__class__.__mro__]): + nChan = self.data[dclass].shape[dummy.dimord.index("channel")] + dummy = getattr(spd, dclass)(data=[self.data[dclass], self.data[dclass]], + channel=['label']*nChan) + assert len(dummy.trials) == 2 + assert np.array_equal(dummy.channel, np.array(['label']*nChan)) + + # the most egregious input errors are caught by `array_parser`; only + # test list-routine-specific stuff: complex/real mismatch + with pytest.raises(SPYValueError) as spyval: + getattr(spd, dclass)(data=[self.data[dclass], np.complex64(self.data[dclass])]) + assert "same numeric type (real/complex)" in str(spyval.value) + + # shape mismatch + with pytest.raises(SPYValueError): + getattr(spd, dclass)(data=[self.data[dclass], self.data[dclass].T]) + + time.sleep(0.01) del dummy @@ -264,7 +278,8 @@ def test_data_alloc(self): def test_trialdef(self): for dclass in self.classes: dummy = getattr(spd, dclass)(self.data[dclass], - trialdefinition=self.trl[dclass]) + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) assert np.array_equal(dummy.sampleinfo, self.trl[dclass][:, :2]) assert np.array_equal(dummy._t0, self.trl[dclass][:, 2]) assert np.array_equal(dummy.trialinfo.flatten(), self.trl[dclass][:, 3]) @@ -296,7 +311,7 @@ def test_clear(self): def test_filename(self): # ensure we're salting sufficiently to create at least `numf` # distinct pseudo-random filenames in `__storage__` - numf = 1000 + numf = 10000 dummy = AnalogData() fnames = [] for k in range(numf): @@ -310,13 +325,15 @@ def test_copy(self): # shallow copies are views in memory) for dclass in self.classes: dummy = getattr(spd, dclass)(self.data[dclass], - trialdefinition=self.trl[dclass]) + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) dummy2 = dummy.copy() assert dummy.filename == dummy2.filename assert hash(str(dummy.data)) == hash(str(dummy2.data)) assert hash(str(dummy.sampleinfo)) == hash(str(dummy2.sampleinfo)) assert hash(str(dummy._t0)) == hash(str(dummy2._t0)) assert hash(str(dummy.trialinfo)) == hash(str(dummy2.trialinfo)) + assert hash(str(dummy.samplerate)) == hash(str(dummy2.samplerate)) # test shallow + deep copies of memmaps + HDF5 files with tempfile.TemporaryDirectory() as tdir: @@ -330,13 +347,16 @@ def test_copy(self): mm = open_memmap(fname, mode="r") # hash-matching of shallow-copied memmap - dummy = getattr(spd, dclass)(data=mm, trialdefinition=self.trl[dclass]) + dummy = getattr(spd, dclass)(data=mm, + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) dummy2 = dummy.copy() assert dummy.filename == dummy2.filename assert hash(str(dummy.data)) == hash(str(dummy2.data)) assert hash(str(dummy.sampleinfo)) == hash(str(dummy2.sampleinfo)) assert hash(str(dummy._t0)) == hash(str(dummy2._t0)) assert hash(str(dummy.trialinfo)) == hash(str(dummy2.trialinfo)) + assert hash(str(dummy.samplerate)) == 
hash(str(dummy2.samplerate)) # test integrity of deep-copy dummy3 = dummy.copy(deep=True) @@ -346,16 +366,19 @@ def test_copy(self): assert np.array_equal(dummy._t0, dummy3._t0) assert np.array_equal(dummy.trialinfo, dummy3.trialinfo) assert np.array_equal(dummy.sampleinfo, dummy3.sampleinfo) + assert dummy.samplerate == dummy3.samplerate # hash-matching of shallow-copied HDF5 dataset dummy = getattr(spd, dclass)(data=h5py.File(hname)["dummy"], - trialdefinition=self.trl[dclass]) + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) dummy2 = dummy.copy() assert dummy.filename == dummy2.filename assert hash(str(dummy.data)) == hash(str(dummy2.data)) assert hash(str(dummy.sampleinfo)) == hash(str(dummy2.sampleinfo)) assert hash(str(dummy._t0)) == hash(str(dummy2._t0)) assert hash(str(dummy.trialinfo)) == hash(str(dummy2.trialinfo)) + assert hash(str(dummy.samplerate)) == hash(str(dummy2.samplerate)) # test integrity of deep-copy dummy3 = dummy.copy(deep=True) @@ -364,6 +387,7 @@ def test_copy(self): assert np.array_equal(dummy._t0, dummy3._t0) assert np.array_equal(dummy.trialinfo, dummy3.trialinfo) assert np.array_equal(dummy.data, dummy3.data) + assert dummy.samplerate == dummy3.samplerate # Delete all open references to file objects b4 closing tmp dir del mm, dummy, dummy2, dummy3 @@ -371,3 +395,175 @@ def test_copy(self): # remove file for next round os.unlink(hname) + + # Test basic error handling of arithmetic ops + def test_arithmetic(self): + + # Define list of classes arithmetic ops should and should not work with + # FIXME: include `CrossSpectralData` here and use something like + # if any(["ContinuousData" in str(base) for base in self.__class__.__mro__]) + continuousClasses = ["AnalogData", "SpectralData"] + discreteClasses = ["SpikeData", "EventData"] + + # Illegal classes for arithmetics + for dclass in discreteClasses: + dummy = getattr(spd, dclass)(self.data[dclass], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + for operation in arithmetics: + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, 2) + assert "Wrong type of base: expected `AnalogData`, `SpectralData`" in str(spytyp.value) + + # Now, test basic error handling for allowed classes + for dclass in continuousClasses: + dummy = getattr(spd, dclass)(self.data[dclass], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + otherClass = list(set(self.classes).difference([dclass]))[0] + other = getattr(spd, otherClass)(self.data[otherClass], + trialdefinition=self.trl[otherClass], + samplerate=self.samplerate) + complexArr = np.complex64(dummy.trials[0]) + complexNum = 3+4j + + # Start w/the one operator that does not handle zeros well... 
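+        # (unlike the other operators, division must reject a zero scalar
+        # operand, hence it is exercised separately here)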
+ with pytest.raises(SPYValueError) as spyval: + dummy / 0 + assert "expected non-zero scalar for division" in str(spyval.value) + + # Go through all supported operators and try to sabotage them + for operation in arithmetics: + + # Completely wrong operand + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, np.sin) + assert "expected Syncopy object, scalar or array-like found ufunc" in str(spytyp.value) + + # Empty object + with pytest.raises(SPYValueError) as spyval: + operation(getattr(spd, dclass)(), np.sin) + assert "expected non-empty Syncopy data object" in str(spyval.value) + + # Unbounded scalar + with pytest.raises(SPYValueError) as spyval: + operation(dummy, np.inf) + assert "'inf'; expected finite scalar" in str(spyval.value) + + # Complex scalar (all test data are real) + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, complexNum) + assert "expected scalar of same mathematical type (real/complex)" in str(spytyp.value) + + # Array w/wrong numeric type + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, complexArr) + assert "array of same numerical type (real/complex) found ndarray" in str(spytyp.value) + + # Syncopy object of different type + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, other) + err = "expected Syncopy {} object found {}" + assert err.format(dclass, otherClass) in str(spytyp.value) + + # Next, validate proper functionality of `==` operator for Syncopy objects + for dclass in self.classes: + + # Start simple compare obj to itself, to empty object and compare two empties + dummy = getattr(spd, dclass)(self.data[dclass], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + assert dummy == dummy + assert dummy != getattr(spd, dclass)() + assert getattr(spd, dclass)() == getattr(spd, dclass)() + + # Basic type mismatch + assert dummy != complexArr + assert dummy != complexNum + + # Two differing Syncopy object classes + otherClass = list(set(self.classes).difference([dclass]))[0] + other = getattr(spd, otherClass)(self.data[otherClass], + trialdefinition=self.trl[otherClass], + samplerate=self.samplerate) + assert dummy != other + + # Ensure shallow and deep copies are "==" to their origin + dummy2 = dummy.copy() + assert dummy2 == dummy + dummy3 = dummy.copy(deep=True) + assert dummy3 == dummy + + # Ensure differing samplerate evaluates to `False` + dummy3.samplerate = 2*dummy.samplerate + assert dummy3 != dummy + dummy3.samplerate = dummy.samplerate + + # In-place selections are invalid for `==` comparisons + dummy3.selectdata(inplace=True) + with pytest.raises(SPYError) as spe: + dummy3 == dummy + assert "Cannot perform object comparison" in str(spe.value) + + # Abuse existing in-place selection to alter dimensional props of dummy3 + # and ensure inequality + dimProps = dummy3._selector._dimProps + dummy3.selectdata(clear=True) + for prop in dimProps: + if hasattr(dummy3, prop): + setattr(dummy3, prop, getattr(dummy, prop)[::-1]) + assert dummy3 != dummy + setattr(dummy3, prop, getattr(dummy, prop)) + + # Different trials + dummy3 = dummy.selectdata(trials=list(range(len(dummy.trials) - 1))) + assert dummy3 != dummy + + # Different trial offsets + trl = self.trl[dclass] + trl[:, 1] -= 1 + dummy3 = getattr(spd, dclass)(self.data[dclass], + trialdefinition=trl, + samplerate=self.samplerate) + assert dummy3 != dummy + + # Different trial annotations + trl = self.trl[dclass] + trl[:, -1] = np.sqrt(2) + dummy3 = getattr(spd, dclass)(self.data[dclass], + trialdefinition=trl, + 
samplerate=self.samplerate) + assert dummy3 != dummy + + # Difference in actual numerical data + dummy3 = dummy.copy(deep=True) + for dsetName in dummy3._hdfFileDatasetProperties: + getattr(dummy3, dsetName)[0] = np.pi + assert dummy3 != dummy + + del dummy, dummy2, dummy3, other + + # Same objects but different dimords: `ContinuousData`` children + for dclass in continuousClasses: + dummy = getattr(spd, dclass)(self.data[dclass], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + ymmud = getattr(spd, dclass)(self.data[dclass].T, + dimord=dummy.dimord[::-1], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + assert dummy != ymmud + + # Same objects but different dimords: `DiscreteData` children + for dclass in discreteClasses: + dummy = getattr(spd, dclass)(self.data[dclass], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + ymmud = getattr(spd, dclass)(self.data[dclass], + dimord=dummy.dimord[::-1], + trialdefinition=self.trl[dclass], + samplerate=self.samplerate) + assert dummy != ymmud + + + diff --git a/syncopy/tests/test_computationalroutine.py b/syncopy/tests/test_computationalroutine.py index c5b06e93a..df2f0ed7b 100644 --- a/syncopy/tests/test_computationalroutine.py +++ b/syncopy/tests/test_computationalroutine.py @@ -64,12 +64,12 @@ def process_metadata(self, data, out): def filter_manager(data, b=None, a=None, out=None, select=None, chan_per_worker=None, keeptrials=True, parallel=False, parallel_store=None, log_dict=None): - myfilter = LowPassFilter(b, a=a) - myfilter.initialize(data, chan_per_worker=chan_per_worker, keeptrials=keeptrials) newOut = False if out is None: newOut = True out = AnalogData(dimord=AnalogData._defaultDimord) + myfilter = LowPassFilter(b, a=a) + myfilter.initialize(data, out._stackingDim, chan_per_worker=chan_per_worker, keeptrials=keeptrials) myfilter.compute(data, out, parallel=parallel, parallel_store=parallel_store, diff --git a/syncopy/tests/test_continuousdata.py b/syncopy/tests/test_continuousdata.py index 6d1bf4db0..1b594c1d9 100644 --- a/syncopy/tests/test_continuousdata.py +++ b/syncopy/tests/test_continuousdata.py @@ -27,6 +27,158 @@ skip_without_acme = pytest.mark.skipif( not __acme__, reason="acme not available") +# Collect all supported binary arithmetic operators +arithmetics = [lambda x, y : x + y, + lambda x, y : x - y, + lambda x, y : x * y, + lambda x, y : x / y, + lambda x, y : x ** y] + +# Module-wide set of testing selections +trialSelections = [ + "all", # enforce below selections in all trials of `dummy` + [3, 1, 2] # minimally unordered +] +chanSelections = [ + ["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered + [4, 2, 2, 5, 5], # repetition + unorderd + range(5, 8), # narrow range + slice(-2, None) # negative-start slice + ] +toiSelections = [ + "all", # non-type-conform string + [0.6], # single inexact match + [-0.2, 0.6, 0.9, 1.1, 1.3, 1.6, 1.8, 2.2, 2.45, 3.] 
# unordered, inexact, repetions + ] +toilimSelections = [ + [0.5, 1.5], # regular range + [1.5, 2.0], # minimal range (just two-time points) + [1.0, np.inf] # unbounded from above + ] +foiSelections = [ + "all", # non-type-conform string + [2.6], # single inexact match + [1.1, 1.9, 2.1, 3.9, 9.2, 11.8, 12.9, 5.1, 13.8] # unordered, inexact, repetions + ] +foilimSelections = [ + [2, 11], # regular range + [1, 2.0], # minimal range (just two-time points) + [1.0, np.inf] # unbounded from above + ] +taperSelections = [ + ["TestTaper_03", "TestTaper_01", "TestTaper_01", "TestTaper_02"], # string selection w/repetition + unordered + [0, 1, 1, 2, 3], # preserve repetition, don't convert to slice + range(2, 5), # narrow range + slice(0, 5, 2), # slice w/non-unitary step-size + ] +timeSelections = list(zip(["toi"] * len(toiSelections), toiSelections)) \ + + list(zip(["toilim"] * len(toilimSelections), toilimSelections)) +freqSelections = list(zip(["foi"] * len(foiSelections), foiSelections)) \ + + list(zip(["foilim"] * len(foilimSelections), foilimSelections)) + + +# Local helper function for performing basic arithmetic tests +def _base_op_tests(dummy, ymmud, dummy2, ymmud2, dummyC, operation): + + dummyArr = 2 * np.ones((dummy.trials[0].shape)) + ymmudArr = 2 * np.ones((ymmud.trials[0].shape)) + scalarOperands = [2, np.pi] + dummyOperands = [dummyArr, dummyArr.tolist()] + ymmudOperands = [ymmudArr, ymmudArr.tolist()] + + # Ensure trial counts are properly vetted + dummy2.selectdata(trials=[0], inplace=True) + with pytest.raises(SPYValueError) as spyval: + operation(dummy, dummy2) + assert "Syncopy object with same number of trials (selected)" in str (spyval.value) + dummy2._selection = None + + # Scalar algebra must be commutative (except for pow) + for operand in scalarOperands: + result = operation(dummy, operand) # perform operation from right + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(dummy.trials[tk], operand)) + # Don't try to compute `2 ** data`` + if operation(2,3) != 8: + result2 = operation(operand, dummy) # perform operation from left + assert np.array_equal(result2.data, result.data) + + # Same as above, but swapped `dimord` + result = operation(ymmud, operand) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(ymmud.trials[tk], operand)) + if operation(2,3) != 8: + result2 = operation(operand, ymmud) + assert np.array_equal(result2.data, result.data) + + # Careful: NumPy tries to avoid failure by broadcasting; instead of relying + # on an existing `__radd__` method, it performs arithmetic component-wise, i.e., + # ``np.ones((3,3)) + data`` performs ``1 + data`` nine times, so don't + # test for left/right arithmetics... 
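+    # sketch of the pitfall with a generic (non-Syncopy) class:
+    #     class Dummy:
+    #         def __radd__(self, other): return "radd"
+    #     np.ones(3) + Dummy()  # -> array(['radd', 'radd', 'radd'], dtype=object)
+    # i.e. ``__radd__`` fires once per array element, never once for the object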
+ for operand in dummyOperands: + result = operation(dummy, operand) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(dummy.trials[tk], operand)) + for operand in ymmudOperands: + result = operation(ymmud, operand) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(ymmud.trials[tk], operand)) + + # Ensure erroneous object type-casting is prevented + if dummyC is not None: + with pytest.raises(SPYTypeError) as spytyp: + operation(dummy, dummyC) + assert "Syncopy data object of same numerical type (real/complex)" in str(spytyp.value) + + # Most severe safety hazard: throw two objects at each other (with regular and + # swapped dimord) + result = operation(dummy, dummy2) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(dummy.trials[tk], dummy2.trials[tk])) + result = operation(ymmud, ymmud2) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(ymmud.trials[tk], ymmud2.trials[tk])) + +def _selection_op_tests(dummy, ymmud, dummy2, ymmud2, kwdict, operation): + + # Perform in-place selection and construct array based on new subset + selected = dummy.selectdata(**kwdict) + dummy.selectdata(inplace=True, **kwdict) + arr = 2 * np.ones((selected.trials[0].shape), dtype=np.intp) + for operand in [np.pi, arr]: + result = operation(dummy, operand) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(selected.trials[tk], operand)) + + # Most most complicated: subset selection present in base object + # and operand thrown at it: only attempt to do this if the selection + # is "well-behaved", i.e., is ordered and does not contain repetitions + # The operator code checks for this, so catch the corresponding + # `SpyValueError` and only attempt to test if coast is clear + dummy2.selectdata(inplace=True, **kwdict) + try: + result = operation(dummy, dummy2) + cleanSelection = True + except SPYValueError: + cleanSelection = False + if cleanSelection: + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(selected.trials[tk], + selected.trials[tk])) + selected = ymmud.selectdata(**kwdict) + ymmud.selectdata(inplace=True, **kwdict) + ymmud2.selectdata(inplace=True, **kwdict) + result = operation(ymmud, ymmud2) + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, operation(selected.trials[tk], + selected.trials[tk])) + + # Very important: clear manually set selections for next iteration + dummy._selection = None + dummy2._selection = None + ymmud._selection = None + ymmud2._selection = None + class TestAnalogData(): @@ -331,6 +483,11 @@ def test_object_padding(self): adata = generate_artificial_data(nTrials=7, nChannels=16, equidistant=False, inmemory=False) timeAxis = adata.dimord.index("time") + chanAxis = adata.dimord.index("channel") + + # Define trial/channel selections for tests + trialSel = [0, 2, 1] + chanSel = range(4) # test dictionary generation for `create_new = False`: ensure all trials # have padded length of `total_time` seconds (1 sample tolerance) @@ -341,17 +498,58 @@ def test_object_padding(self): assert "pad_width" in pad_list[tk].keys() assert "constant_values" in pad_list[tk].keys() trl_time = (pad_list[tk]["pad_width"][timeAxis, :].sum() + trl.shape[timeAxis]) / adata.samplerate - assert trl_time - total_time < 1/adata.samplerate + assert trl_time - total_time < 1 / adata.samplerate + + # real thing: pad object with standing channel selection + res = padding(adata, "zero", pad="absolute", 
padlength=total_time,unit="time", + create_new=True, select={"trials": trialSel, "channels": chanSel}) + for tk, trl in enumerate(res.trials): + adataTrl = adata.trials[trialSel[tk]] + nSamples = pad_list[trialSel[tk]]["pad_width"][timeAxis, :].sum() + adataTrl.shape[timeAxis] + assert trl.shape[timeAxis] == nSamples + assert trl.shape[chanAxis] == len(list(chanSel)) + + # test correct update of trigger onset w/pre-padding + adataTimes = adata.time + prepadTime = 5 + res = padding(adata, "zero", pad="relative", prepadlength=prepadTime, + unit="time", create_new=True) + resTimes = res.time + adataTimes = adata.time + for tk, timeArr in enumerate(resTimes): + assert timeArr[0] == adataTimes[tk][0] - prepadTime + assert np.array_equal(timeArr[timeArr >= 0], adataTimes[tk][adataTimes[tk] >= 0]) + + # postpadding must not change trigger onset timing + postpadTime = 5 + res = padding(adata, "zero", pad="relative", postpadlength=postpadTime, + unit="time", create_new=True) + resTimes = res.time + for tk, timeArr in enumerate(resTimes): + assert timeArr[0] == adataTimes[tk][0] + assert np.array_equal(timeArr[timeArr <= 0], adataTimes[tk][adataTimes[tk] <= 0]) # jumble axes of `AnalogData` object and compute max. trial length adata2 = generate_artificial_data(nTrials=7, nChannels=16, - equidistant=False, inmemory=False, - dimord=adata.dimord[::-1]) + equidistant=False, inmemory=False, + dimord=adata.dimord[::-1]) timeAxis2 = adata2.dimord.index("time") + chanAxis2 = adata2.dimord.index("channel") maxtrllen = 0 for trl in adata2.trials: maxtrllen = max(maxtrllen, trl.shape[timeAxis2]) + # same as above, but this time w/swapped dimensions + res2 = padding(adata2, "zero", pad="absolute", padlength=total_time, unit="time", + create_new=True, select={"trials": trialSel, "channels": chanSel}) + pad_list2 = padding(adata2, "zero", pad="absolute", padlength=total_time, + unit="time", create_new=False) + for tk, trl in enumerate(res2.trials): + adataTrl = adata2.trials[trialSel[tk]] + nSamples = pad_list2[trialSel[tk]]["pad_width"][timeAxis2, :].sum() + adataTrl.shape[timeAxis2] + assert trl.shape[timeAxis2] == nSamples + assert trl.shape[chanAxis2] == len(list(chanSel)) + # symmetric `maxlen` padding: 1 sample tolerance pad_list2 = padding(adata2, "zero", pad="maxlen", create_new=False) for tk, trl in enumerate(adata2.trials): @@ -375,6 +573,21 @@ def test_object_padding(self): trl_len = pad_list2[tk]["pad_width"][timeAxis2, :].sum() + trl.shape[timeAxis2] assert trl_len == maxtrllen + # make things maximally intersting: relative + time + non-equidistant + + # overlapping + selection + nonstandard dimord + adata3 = generate_artificial_data(nTrials=7, nChannels=16, + equidistant=False, overlapping=True, + inmemory=False, dimord=adata2.dimord) + res3 = padding(adata3, "zero", pad="absolute", padlength=total_time, unit="time", + create_new=True, select={"trials": trialSel, "channels": chanSel}) + pad_list3 = padding(adata3, "zero", pad="absolute", padlength=total_time, + unit="time", create_new=False) + for tk, trl in enumerate(res3.trials): + adataTrl = adata3.trials[trialSel[tk]] + nSamples = pad_list3[trialSel[tk]]["pad_width"][timeAxis2, :].sum() + adataTrl.shape[timeAxis2] + assert trl.shape[timeAxis2] == nSamples + assert trl.shape[chanAxis2] == len(list(chanSel)) + # `maxlen'-specific errors: `padlength` wrong type, wrong combo with `prepadlength` with pytest.raises(SPYTypeError): padding(adata, "zero", pad="maxlen", padlength=self.ns, create_new=False) @@ -384,66 +597,98 @@ def 
test_object_padding(self): padding(adata, "zero", pad="maxlen", padlength=self.ns, prepadlength=True, create_new=False) - # FIXME: implement as soon as object padding is supported: - # test absolute + time + non-equidistant! - # test relative + time + non-equidistant + overlapping! - # test data-selection via class method def test_dataselection(self): + + # Create testing objects (regular and swapped dimords) + dummy = AnalogData(data=self.data, + trialdefinition=self.trl, + samplerate=self.samplerate) + ymmud = AnalogData(data=self.data.T, + trialdefinition=self.trl, + samplerate=self.samplerate, + dimord=AnalogData._defaultDimord[::-1]) + + for obj in [dummy, ymmud]: + idx = [slice(None)] * len(obj.dimord) + timeIdx = obj.dimord.index("time") + chanIdx = obj.dimord.index("channel") + for trialSel in trialSelections: + for chanSel in chanSelections: + for timeSel in timeSelections: + kwdict = {} + kwdict["trials"] = trialSel + kwdict["channels"] = chanSel + kwdict[timeSel[0]] = timeSel[1] + cfg = StructDict(kwdict) + # data selection via class-method + `Selector` instance for indexing + selected = obj.selectdata(**kwdict) + time.sleep(0.05) + selector = Selector(obj, kwdict) + idx[chanIdx] = selector.channel + for tk, trialno in enumerate(selector.trials): + idx[timeIdx] = selector.time[tk] + assert np.array_equal(selected.trials[tk].squeeze(), + obj.trials[trialno][idx[0], :][:, idx[1]].squeeze()) + cfg.data = obj + cfg.out = AnalogData(dimord=obj.dimord) + # data selection via package function and `cfg`: ensure equality + selectdata(cfg) + assert np.array_equal(cfg.out.channel, selected.channel) + assert np.array_equal(cfg.out.data, selected.data) + time.sleep(0.05) + + # test arithmetic operations + def test_ang_arithmetic(self): + + # Create testing objects and corresponding arrays to perform arithmetics with dummy = AnalogData(data=self.data, trialdefinition=self.trl, samplerate=self.samplerate) - trialSelections = [ - "all", # enforce below selections in all trials of `dummy` - [3, 1] # minimally unordered - ] - chanSelections = [ - ["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered - [4, 2, 2, 5, 5], # repetition + unorderd - range(5, 8), # narrow range - slice(-2, None) # negative-start slice - ] - toiSelections = [ - "all", # non-type-conform string - [0.6], # single inexact match - [-0.2, 0.6, 0.9, 1.1, 1.3, 1.6, 1.8, 2.2, 2.45, 3.] 
# unordered, inexact, repetions - ] - toilimSelections = [ - [0.5, 1.5], # regular range - [1.5, 2.0], # minimal range (just two-time points) - [1.0, np.inf] # unbounded from above - ] - timeSelections = list(zip(["toi"] * len(toiSelections), toiSelections)) \ - + list(zip(["toilim"] * len(toilimSelections), toilimSelections)) - - idx = [slice(None)] * len(dummy.dimord) - timeIdx = dummy.dimord.index("time") - chanIdx = dummy.dimord.index("channel") - - for trialSel in trialSelections: - for chanSel in chanSelections: - for timeSel in timeSelections: - kwdict = {} - kwdict["trials"] = trialSel - kwdict["channels"] = chanSel - kwdict[timeSel[0]] = timeSel[1] - cfg = StructDict(kwdict) - # data selection via class-method + `Selector` instance for indexing - selected = dummy.selectdata(**kwdict) - time.sleep(0.05) - selector = Selector(dummy, kwdict) - idx[chanIdx] = selector.channel - for tk, trialno in enumerate(selector.trials): - idx[timeIdx] = selector.time[tk] - assert np.array_equal(selected.trials[tk].squeeze(), - dummy.trials[trialno][idx[0], :][:, idx[1]].squeeze()) - cfg.data = dummy - cfg.out = AnalogData(dimord=AnalogData._defaultDimord) - # data selection via package function and `cfg`: ensure equality - selectdata(cfg) - assert np.array_equal(cfg.out.channel, selected.channel) - assert np.array_equal(cfg.out.data, selected.data) - time.sleep(0.05) + ymmud = AnalogData(data=self.data.T, + trialdefinition=self.trl, + samplerate=self.samplerate, + dimord=AnalogData._defaultDimord[::-1]) + dummy2 = AnalogData(data=self.data, + trialdefinition=self.trl, + samplerate=self.samplerate) + ymmud2 = AnalogData(data=self.data.T, + trialdefinition=self.trl, + samplerate=self.samplerate, + dimord=AnalogData._defaultDimord[::-1]) + + # Perform basic arithmetic with +, -, *, / and ** (pow) + for operation in arithmetics: + + # First, ensure `dimord` is respected + with pytest.raises(SPYValueError) as spyval: + operation(dummy, ymmud) + assert "expected Syncopy 'time' x 'channel' data object" in str (spyval.value) + + _base_op_tests(dummy, ymmud, dummy2, ymmud2, None, operation) + + # Now the most complicated case: user-defined subset selections are present + kwdict = {} + kwdict["trials"] = trialSelections[1] + kwdict["channels"] = chanSelections[3] + kwdict[timeSelections[4][0]] = timeSelections[4][1] + _selection_op_tests(dummy, ymmud, dummy2, ymmud2, kwdict, operation) + + # # Go through full selection stack - WARNING: this takes > 15 minutes + # for trialSel in trialSelections: + # for chanSel in chanSelections: + # for timeSel in timeSelections: + # kwdict = {} + # kwdict["trials"] = trialSel + # kwdict["channels"] = chanSel + # kwdict[timeSel[0]] = timeSel[1] + # _selection_op_tests(dummy, ymmud, dummy2, ymmud2, kwdict, operation) + + # Finally, perform a representative chained operation to ensure chaining works + result = (dummy + dummy2) / dummy ** 3 + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, + (dummy.trials[tk] + dummy2.trials[tk]) / dummy.trials[tk] ** 3) @skip_without_acme def test_parallel(self, testcluster): @@ -452,7 +697,8 @@ def test_parallel(self, testcluster): par_tests = ["test_relative_array_padding", "test_absolute_nextpow2_array_padding", "test_object_padding", - "test_dataselection"] + "test_dataselection", + "test_ang_arithmetic"] for test in par_tests: getattr(self, test)() flush_local_cluster(testcluster) @@ -466,7 +712,7 @@ class TestSpectralData(): ns = 30 nt = 5 nf = 15 - data = np.arange(1, nc * ns * nt * nf + 1).reshape(ns, nt, nf, 
nc) + data = np.arange(1, nc * ns * nt * nf + 1, dtype="float").reshape(ns, nt, nf, nc) trl = np.vstack([np.arange(0, ns, 5), np.arange(5, ns + 5, 5), np.ones((int(ns / 5), )), @@ -511,21 +757,6 @@ def test_sd_trialretrieval(self): trl_ref = self.data2[..., start:start + 5] assert np.array_equal(dummy._get_trial(trlno), trl_ref) - # # test ``_copy_trial`` with memmap'ed data - # with tempfile.TemporaryDirectory() as tdir: - # fname = os.path.join(tdir, "dummy.npy") - # np.save(fname, self.data) - # mm = open_memmap(fname, mode="r") - # dummy = SpectralData(mm, trialdefinition=self.trl) - # for trlno, start in enumerate(range(0, self.ns, 5)): - # trl_ref = self.data[start:start + 5, ...] - # trl_tmp = dummy._copy_trial(trlno, - # dummy.filename, - # dummy.dimord, - # dummy.sampleinfo, - # None) - # assert np.array_equal(trl_tmp, trl_ref) - # del mm, dummy del dummy def test_sd_saveload(self): @@ -577,96 +808,129 @@ def test_sd_saveload(self): # test data-selection via class method def test_sd_dataselection(self): + + # Create testing objects (regular and swapped dimords) dummy = SpectralData(data=self.data, trialdefinition=self.trl, samplerate=self.samplerate, taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)]) - trialSelections = [ - "all", # enforce below selections in all trials of `dummy` - [3, 1] # minimally unordered - ] - chanSelections = [ - ["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered - [4, 2, 2, 5, 5], # repetition + unorderd - range(5, 8), # narrow range - slice(-2, None) # negative-start slice - ] - toiSelections = [ - "all", # non-type-conform string - [0.6], # single inexact match - [-0.2, 0.6, 0.9, 1.1, 1.3, 1.6, 1.8, 2.2, 2.45, 3.] # unordered, inexact, repetions - ] - toilimSelections = [ - [0.5, 1.5], # regular range - [1.5, 2.0], # minimal range (just two-time points) - [1.0, np.inf] # unbounded from above - ] - foiSelections = [ - "all", # non-type-conform string - [2.6], # single inexact match - [1.1, 1.9, 2.1, 3.9, 9.2, 11.8, 12.9, 5.1, 13.8] # unordered, inexact, repetions - ] - foilimSelections = [ - [2, 11], # regular range - [1, 2.0], # minimal range (just two-time points) - [1.0, np.inf] # unbounded from above - ] - taperSelections = [ - ["TestTaper_03", "TestTaper_01", "TestTaper_01", "TestTaper_02"], # string selection w/repetition + unordered - [0, 1, 1, 2, 3], # preserve repetition, don't convert to slice - range(2, 5), # narrow range - slice(0, 5, 2), # slice w/non-unitary step-size - ] - timeSelections = list(zip(["toi"] * len(toiSelections), toiSelections)) \ - + list(zip(["toilim"] * len(toilimSelections), toilimSelections)) - freqSelections = list(zip(["foi"] * len(foiSelections), foiSelections)) \ - + list(zip(["foilim"] * len(foilimSelections), foilimSelections)) - - idx = [slice(None)] * len(dummy.dimord) - timeIdx = dummy.dimord.index("time") - chanIdx = dummy.dimord.index("channel") - freqIdx = dummy.dimord.index("freq") - taperIdx = dummy.dimord.index("taper") - - for trialSel in trialSelections: - for chanSel in chanSelections: - for timeSel in timeSelections: - for freqSel in freqSelections: - for taperSel in taperSelections: - kwdict = {} - kwdict["trials"] = trialSel - kwdict["channels"] = chanSel - kwdict[timeSel[0]] = timeSel[1] - kwdict[freqSel[0]] = freqSel[1] - kwdict["tapers"] = taperSel - cfg = StructDict(kwdict) - # data selection via class-method + `Selector` instance for indexing - selected = dummy.selectdata(**kwdict) - time.sleep(0.05) - selector = Selector(dummy, 
kwdict) - idx[chanIdx] = selector.channel - idx[freqIdx] = selector.freq - idx[taperIdx] = selector.taper - for tk, trialno in enumerate(selector.trials): - idx[timeIdx] = selector.time[tk] - indexed = dummy.trials[trialno][idx[0], ...][:, idx[1], ...][:, :, idx[2], :][..., idx[3]] - assert np.array_equal(selected.trials[tk].squeeze(), - indexed.squeeze()) - cfg.data = dummy - cfg.out = SpectralData(dimord=SpectralData._defaultDimord) - # data selection via package function and `cfg`: ensure equality - selectdata(cfg) - assert np.array_equal(cfg.out.channel, selected.channel) - assert np.array_equal(cfg.out.freq, selected.freq) - assert np.array_equal(cfg.out.taper, selected.taper) - assert np.array_equal(cfg.out.data, selected.data) - time.sleep(0.05) + ymmud = SpectralData(data=np.transpose(self.data, [3, 2, 1, 0]), + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)], + dimord=SpectralData._defaultDimord[::-1]) + + for obj in [dummy, ymmud]: + idx = [slice(None)] * len(obj.dimord) + timeIdx = obj.dimord.index("time") + chanIdx = obj.dimord.index("channel") + freqIdx = obj.dimord.index("freq") + taperIdx = obj.dimord.index("taper") + for trialSel in trialSelections: + for chanSel in chanSelections: + for timeSel in timeSelections: + for freqSel in freqSelections: + for taperSel in taperSelections: + kwdict = {} + kwdict["trials"] = trialSel + kwdict["channels"] = chanSel + kwdict[timeSel[0]] = timeSel[1] + kwdict[freqSel[0]] = freqSel[1] + kwdict["tapers"] = taperSel + cfg = StructDict(kwdict) + # data selection via class-method + `Selector` instance for indexing + selected = obj.selectdata(**kwdict) + time.sleep(0.05) + selector = Selector(obj, kwdict) + idx[chanIdx] = selector.channel + idx[freqIdx] = selector.freq + idx[taperIdx] = selector.taper + for tk, trialno in enumerate(selector.trials): + idx[timeIdx] = selector.time[tk] + indexed = obj.trials[trialno][idx[0], ...][:, idx[1], ...][:, :, idx[2], :][..., idx[3]] + assert np.array_equal(selected.trials[tk].squeeze(), + indexed.squeeze()) + cfg.data = obj + cfg.out = SpectralData(dimord=obj.dimord) + # data selection via package function and `cfg`: ensure equality + selectdata(cfg) + assert np.array_equal(cfg.out.channel, selected.channel) + assert np.array_equal(cfg.out.freq, selected.freq) + assert np.array_equal(cfg.out.taper, selected.taper) + assert np.array_equal(cfg.out.data, selected.data) + time.sleep(0.05) + + # test arithmetic operations + def test_sd_arithmetic(self): + + # Create testing objects and corresponding arrays to perform arithmetics with + dummy = SpectralData(data=self.data, + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)]) + dummyC = SpectralData(data=np.complex64(self.data), + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)]) + ymmud = SpectralData(data=np.transpose(self.data, [3, 2, 1, 0]), + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)], + dimord=SpectralData._defaultDimord[::-1]) + dummy2 = SpectralData(data=self.data, + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in range(1, self.nt + 1)]) + ymmud2 = SpectralData(data=np.transpose(self.data, [3, 2, 1, 0]), + trialdefinition=self.trl, + samplerate=self.samplerate, + taper=["TestTaper_0{}".format(k) for k in 
range(1, self.nt + 1)], + dimord=SpectralData._defaultDimord[::-1]) + + # Perform basic arithmetic with +, -, *, / and ** (pow) + for operation in arithmetics: + + # First, ensure `dimord` is respected + with pytest.raises(SPYValueError) as spyval: + operation(dummy, ymmud) + assert "expected Syncopy 'time' x 'channel' data object" in str(spyval.value) + + _base_op_tests(dummy, ymmud, dummy2, ymmud2, dummyC, operation) + + # Now the most complicated case: user-defined subset selections are present + kwdict = {} + kwdict["trials"] = trialSelections[1] + kwdict["channels"] = chanSelections[3] + kwdict[timeSelections[4][0]] = timeSelections[4][1] + kwdict[freqSelections[4][0]] = freqSelections[4][1] + kwdict["tapers"] = taperSelections[2] + _selection_op_tests(dummy, ymmud, dummy2, ymmud2, kwdict, operation) + + # # Go through full selection stack - WARNING: this takes > 1 hour + # for trialSel in trialSelections: + # for chanSel in chanSelections: + # for timeSel in timeSelections: + # for freqSel in freqSelections: + # for taperSel in taperSelections: + # kwdict = {} + # kwdict["trials"] = trialSel + # kwdict["channels"] = chanSel + # kwdict[timeSel[0]] = timeSel[1] + # kwdict[freqSel[0]] = freqSel[1] + # kwdict["tapers"] = taperSel + # _selection_op_tests(dummy, ymmud, dummy2, ymmud2, kwdict, operation) + + # Finally, perform a representative chained operation to ensure chaining works + result = (dummy + dummy2) / dummy ** 3 + for tk, trl in enumerate(result.trials): + assert np.array_equal(trl, + (dummy.trials[tk] + dummy2.trials[tk]) / dummy.trials[tk] ** 3) @skip_without_acme def test_sd_parallel(self, testcluster): # repeat selected test w/parallel processing engine client = dd.Client(testcluster) - par_tests = ["test_sd_dataselection"] + par_tests = ["test_sd_dataselection", "test_sd_arithmetic"] for test in par_tests: getattr(self, test)() flush_local_cluster(testcluster) diff --git a/syncopy/tests/test_decorators.py b/syncopy/tests/test_decorators.py index fd2d3be79..6bf81cba9 100644 --- a/syncopy/tests/test_decorators.py +++ b/syncopy/tests/test_decorators.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# +# # Test proper functionality of Syncopy's decorator mechanics -# +# # Builtin/3rd party package imports import string @@ -17,7 +17,7 @@ def group_objects(*data, groupbychan=None, select=None): """ Dummy function that collects the `filename` property of all - input objects that contain a specific channel given by + input objects that contain a specific channel given by `groupbychan` """ group = [] @@ -31,13 +31,13 @@ def group_objects(*data, groupbychan=None, select=None): class TestSpyCalls(): - + nChan = 13 nObjs = nChan - # Generate `nChan` objects whose channel-labeling scheme obeys: + # Generate `nChan` objects whose channel-labeling scheme obeys: # ob1.channel = ["A", "B", "C", ..., "M"] - # ob2.channel = [ "B", "C", ..., "M", "N"] + # ob2.channel = [ "B", "C", ..., "M", "N"] # ob3.channel = [ "C", ..., "M", "N", "O"] # ... # ob13.channel = [ "M", "N", "O", ..., "Z"] @@ -48,23 +48,23 @@ class TestSpyCalls(): obj.channel = list(string.ascii_uppercase[n : nChan + n]) dataObjs.append(obj) data = dataObjs[0] - + def test_validcallstyles(self): - + # data positional fname, = group_objects(self.data) assert fname == self.data.filename - + # data as keyword fname, = group_objects(data=self.data) assert fname == self.data.filename - + # data in cfg cfg = StructDict() cfg.data = self.data fname, = group_objects(cfg) assert fname == self.data.filename - + # 1. data positional, 2. 
cfg positional cfg = StructDict() cfg.groupbychan = None @@ -74,38 +74,38 @@ def test_validcallstyles(self): # 1. cfg positional, 2. data positional fname, = group_objects(cfg, self.data) assert fname == self.data.filename - + # data positional, cfg as keyword fname, = group_objects(self.data, cfg=cfg) assert fname == self.data.filename - + # cfg positional, data as keyword fname, = group_objects(cfg, data=self.data) assert fname == self.data.filename - + # both keywords fname, = group_objects(cfg=cfg, data=self.data) assert fname == self.data.filename - + def test_invalidcallstyles(self): - + # expected error messages errmsg1 = "expected Syncopy data object(s) provided either via " +\ "`cfg`/keyword or positional arguments, not both" errmsg2 = "expected Syncopy data object(s) provided either via `cfg` " +\ "or as keyword argument, not both" errmsg3 = "expected either 'data' or 'dataset' in `cfg`/keywords, not both" - + # ensure things break reliably for 'data' as well as 'dataset' for key in ["data", "dataset"]: - + # data + cfg w/data cfg = StructDict() cfg[key] = self.data with pytest.raises(SPYValueError) as exc: group_objects(self.data, cfg) assert errmsg1 in str(exc.value) - + # data as positional + kwarg with pytest.raises(SPYValueError) as exc: group_objects(self.data, data=self.data) @@ -121,7 +121,7 @@ def test_invalidcallstyles(self): with pytest.raises(SPYValueError) as exc: group_objects(self.data, cfg, dataset=self.data) assert errmsg1 in str(exc.value) - + # cfg w/data + kwarg with pytest.raises(SPYValueError) as exc: group_objects(cfg, data=self.data) @@ -136,12 +136,12 @@ def test_invalidcallstyles(self): with pytest.raises(SPYValueError)as exc: group_objects(self.data, cfg, cfg=cfg) assert "expected `cfg` either as positional or keyword argument, not both" in str(exc.value) - + # keyword set via cfg and kwarg with pytest.raises(SPYValueError) as exc: group_objects(self.data, cfg, groupbychan="invalid") assert "'non-default value for groupbychan'; expected no keyword arguments" in str(exc.value) - + # both data and dataset in cfg/keywords cfg = StructDict() cfg.data = self.data @@ -157,19 +157,14 @@ def test_invalidcallstyles(self): with pytest.raises(SPYError)as exc: group_objects(data="invalid") assert "`data` must be Syncopy data object(s)!" in str(exc.value) - + # cfg is not dict/StructDict with pytest.raises(SPYTypeError)as exc: group_objects(cfg="invalid") assert "Wrong type of cfg: expected dictionary-like" in str(exc.value) - # no data input whatsoever - with pytest.raises(SPYError)as exc: - group_objects("invalid") - assert "missing mandatory argument: `data`" in str(exc.value) - def test_varargin(self): - + # data positional allFnames = group_objects(*self.dataObjs) assert allFnames == [obj.filename for obj in self.dataObjs] @@ -179,7 +174,7 @@ def test_varargin(self): cfg.data = self.dataObjs fnameList = group_objects(cfg) assert allFnames == fnameList - + # group objects by single-letter "channels" in various ways for letter in ["L", "E", "I", "A"]: letterIdx = string.ascii_uppercase.index(letter) @@ -188,7 +183,7 @@ def test_varargin(self): # data positional + keyword to get "reference" groupList = group_objects(*self.dataObjs, groupbychan=letter) assert len(groupList) == nOccurences - + # 1. data positional, 2. cfg positional cfg = StructDict() cfg.groupbychan = letter @@ -198,11 +193,11 @@ def test_varargin(self): # 1. cfg positional, 2. 
diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py
index eb37714b2..6d745a8f6 100644
--- a/syncopy/tests/test_discretedata.py
+++ b/syncopy/tests/test_discretedata.py
@@ -149,9 +149,16 @@ def test_saveload(self):
 
     # test data-selection via class method
     def test_dataselection(self):
+
+        # Create testing objects (regular and swapped dimords)
         dummy = SpikeData(data=self.data,
                           trialdefinition=self.trl,
                           samplerate=2.0)
+        ymmud = SpikeData(data=self.data[:, ::-1],
+                          trialdefinition=self.trl,
+                          samplerate=2.0,
+                          dimord=dummy.dimord[::-1])
+
         # selections are chosen so that result is not empty
         trialSelections = [
             "all",  # enforce below selections in all trials of `dummy`
@@ -180,40 +187,40 @@ def test_dataselection(self):
         timeSelections = list(zip(["toi"] * len(toiSelections), toiSelections)) \
             + list(zip(["toilim"] * len(toilimSelections), toilimSelections))
 
-        chanIdx = dummy.dimord.index("channel")
-        unitIdx = dummy.dimord.index("unit")
-        chanArr = np.arange(dummy.channel.size)
-
-        for trialSel in trialSelections:
-            for chanSel in chanSelections:
-                for unitSel in unitSelections:
-                    for timeSel in timeSelections:
-                        kwdict = {}
-                        kwdict["trials"] = trialSel
-                        kwdict["channels"] = chanSel
-                        kwdict["units"] = unitSel
-                        kwdict[timeSel[0]] = timeSel[1]
-                        cfg = StructDict(kwdict)
-                        # data selection via class-method + `Selector` instance for indexing
-                        selected = dummy.selectdata(**kwdict)
-                        selector = Selector(dummy, kwdict)
-                        tk = 0
-                        for trialno in selector.trials:
-                            if selector.time[tk]:
-                                assert np.array_equal(dummy.trials[trialno][selector.time[tk], :],
-                                                      selected.trials[tk])
-                                tk += 1
-                        assert set(selected.data[:, chanIdx]).issubset(chanArr[selector.channel])
-                        assert set(selected.channel) == set(dummy.channel[selector.channel])
-                        assert np.array_equal(selected.unit,
-                                              dummy.unit[np.unique(selected.data[:, unitIdx])])
-                        cfg.data = dummy
-                        cfg.out = SpikeData(dimord=SpikeData._defaultDimord)
-                        # data selection via package function and `cfg`: ensure equality
-                        selectdata(cfg)
-                        assert np.array_equal(cfg.out.channel, selected.channel)
-                        assert np.array_equal(cfg.out.unit, selected.unit)
-                        assert np.array_equal(cfg.out.data, selected.data)
+        for obj in [dummy, ymmud]:
+            chanIdx = obj.dimord.index("channel")
+            unitIdx = obj.dimord.index("unit")
+            chanArr = np.arange(obj.channel.size)
+            for trialSel in trialSelections:
+                for chanSel in chanSelections:
+                    for unitSel in unitSelections:
+                        for timeSel in timeSelections:
+                            kwdict = {}
+                            kwdict["trials"] = trialSel
+                            kwdict["channels"] = chanSel
+                            kwdict["units"] = unitSel
+                            kwdict[timeSel[0]] = timeSel[1]
+                            cfg = StructDict(kwdict)
+                            # data selection via class-method + `Selector` instance for indexing
+                            selected = obj.selectdata(**kwdict)
+                            selector = Selector(obj, kwdict)
+                            tk = 0
+                            for trialno in selector.trials:
+                                if selector.time[tk]:
+                                    assert np.array_equal(obj.trials[trialno][selector.time[tk], :],
+                                                          selected.trials[tk])
+                                    tk += 1
+                            assert set(selected.data[:, chanIdx]).issubset(chanArr[selector.channel])
+                            assert set(selected.channel) == set(obj.channel[selector.channel])
+                            assert np.array_equal(selected.unit,
+                                                  obj.unit[np.unique(selected.data[:, unitIdx])])
+                            cfg.data = obj
+                            cfg.out = SpikeData(dimord=obj.dimord)
+                            # data selection via package function and `cfg`: ensure equality
+                            selectdata(cfg)
+                            assert np.array_equal(cfg.out.channel, selected.channel)
+                            assert np.array_equal(cfg.out.unit, selected.unit)
+                            assert np.array_equal(cfg.out.data, selected.data)
 
     @skip_without_acme
     def test_parallel(self, testcluster):
@@ -473,9 +480,16 @@ def test_ed_trialsetting(self):
 
     # test data-selection via class method
     def test_ed_dataselection(self):
+
+        # Create testing objects (regular and swapped dimords)
         dummy = EventData(data=self.data,
                           trialdefinition=self.trl,
                           samplerate=2.0)
+        ymmud = EventData(data=self.data[:, ::-1],
+                          trialdefinition=self.trl,
+                          samplerate=2.0,
+                          dimord=dummy.dimord[::-1])
+
         # selections are chosen so that result is not empty
         trialSelections = [
             "all",  # enforce below selections in all trials of `dummy`
@@ -497,33 +511,33 @@ def test_ed_dataselection(self):
         timeSelections = list(zip(["toi"] * len(toiSelections), toiSelections)) \
             + list(zip(["toilim"] * len(toilimSelections), toilimSelections))
 
-        eventidIdx = dummy.dimord.index("eventid")
-
-        for trialSel in trialSelections:
-            for eventidSel in eventidSelections:
-                for timeSel in timeSelections:
-                    kwdict = {}
-                    kwdict["trials"] = trialSel
-                    kwdict["eventids"] = eventidSel
-                    kwdict[timeSel[0]] = timeSel[1]
-                    cfg = StructDict(kwdict)
-                    # data selection via class-method + `Selector` instance for indexing
-                    selected = dummy.selectdata(**kwdict)
-                    selector = Selector(dummy, kwdict)
-                    tk = 0
-                    for trialno in selector.trials:
-                        if selector.time[tk]:
-                            assert np.array_equal(dummy.trials[trialno][selector.time[tk], :],
-                                                  selected.trials[tk])
-                            tk += 1
-                    assert np.array_equal(selected.eventid,
-                                          dummy.eventid[np.unique(selected.data[:, eventidIdx]).astype(np.intp)])
-                    cfg.data = dummy
-                    cfg.out = EventData(dimord=EventData._defaultDimord)
-                    # data selection via package function and `cfg`: ensure equality
-                    selectdata(cfg)
-                    assert np.array_equal(cfg.out.eventid, selected.eventid)
-                    assert np.array_equal(cfg.out.data, selected.data)
+        for obj in [dummy, ymmud]:
+            eventidIdx = obj.dimord.index("eventid")
+            for trialSel in trialSelections:
+                for eventidSel in eventidSelections:
+                    for timeSel in timeSelections:
+                        kwdict = {}
+                        kwdict["trials"] = trialSel
+                        kwdict["eventids"] = eventidSel
+                        kwdict[timeSel[0]] = timeSel[1]
+                        cfg = StructDict(kwdict)
+                        # data selection via class-method + `Selector` instance for indexing
+                        selected = obj.selectdata(**kwdict)
+                        selector = Selector(obj, kwdict)
+                        tk = 0
+                        for trialno in selector.trials:
+                            if selector.time[tk]:
+                                assert np.array_equal(obj.trials[trialno][selector.time[tk], :],
+                                                      selected.trials[tk])
+                                tk += 1
+                        assert np.array_equal(selected.eventid,
+                                              obj.eventid[np.unique(selected.data[:, eventidIdx]).astype(np.intp)])
+                        cfg.data = obj
+                        cfg.out = EventData(dimord=obj.dimord)
+                        # data selection via package function and `cfg`: ensure equality
+                        selectdata(cfg)
+                        assert np.array_equal(cfg.out.eventid, selected.eventid)
+                        assert np.array_equal(cfg.out.data, selected.data)
 
     @skip_without_acme
     def test_ed_parallel(self, testcluster):
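Both DiscreteData selection tests now run every scenario twice, once on a fixture with the class's default `dimord` and once on one with reversed columns and reversed `dimord`, so the `Selector` machinery cannot rely on hard-coded column positions. The pattern in isolation (a sketch mirroring the test setup above; `data` and `trl` stand for the class fixtures):

from syncopy.datatype import SpikeData

dummy = SpikeData(data=data, trialdefinition=trl, samplerate=2.0)
ymmud = SpikeData(data=data[:, ::-1],          # reversed columns ...
                  trialdefinition=trl,
                  samplerate=2.0,
                  dimord=dummy.dimord[::-1])   # ... and reversed dimord

for obj in [dummy, ymmud]:
    # column positions are looked up via dimord, never assumed fixed
    chanIdx = obj.dimord.index("channel")
    selected = obj.selectdata(trials="all", channels=[0, 1])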
diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py
index 365ea61f1..5571c845c 100644
--- a/syncopy/tests/test_selectdata.py
+++ b/syncopy/tests/test_selectdata.py
@@ -387,6 +387,16 @@ def test_general(self):
         with pytest.raises(SPYValueError):
             Selector(ang, {"wrongkey": [1]})
 
+        # set/clear in-place data selection (both setting and clearing are idempotent,
+        # i.e., repeated execution must work, hence the double whammy)
+        ang.selectdata(trials=[3, 1])
+        ang.selectdata(trials=[3, 1])
+        ang.selectdata(clear=True)
+        ang.selectdata(clear=True)
+        with pytest.raises(SPYValueError) as spyval:
+            ang.selectdata(trials=[3, 1], clear=True)
+            assert "no data selectors if `clear = True`" in str(spyval.value)
+
         # go through all data-classes defined above
         for dclass in self.classes:
             dummy = getattr(spd, dclass)(data=self.data[dclass],
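The new block in test_general codifies the semantics of in-place selections: attaching the same selection twice or clearing twice must be no-ops, while combining concrete selectors with `clear=True` is contradictory and raises. In usage terms (assuming an AnalogData-like object `ang` as exercised in the test, with pytest and SPYValueError imported):

ang.selectdata(trials=[3, 1])   # attach an in-place selection
ang.selectdata(trials=[3, 1])   # idempotent: repeating must not fail or stack
ang.selectdata(clear=True)      # drop the selection
ang.selectdata(clear=True)      # idempotent as well

with pytest.raises(SPYValueError):
    ang.selectdata(trials=[3, 1], clear=True)   # selectors + clear: rejected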
diff --git a/syncopy/tests/test_specest.py b/syncopy/tests/test_specest.py
index f4db5fff0..3a48839cd 100644
--- a/syncopy/tests/test_specest.py
+++ b/syncopy/tests/test_specest.py
@@ -4,13 +4,12 @@
 #
 
 # Builtin/3rd party package imports
-from multiprocessing import Value
 import os
 import tempfile
 import inspect
+import psutil
 import gc
 import pytest
-import time
 import numpy as np
 import scipy.signal as scisig
 from numpy.lib.format import open_memmap
@@ -30,6 +29,10 @@
 # Decorator to decide whether or not to run dask-related tests
 skip_without_acme = pytest.mark.skipif(not __acme__, reason="acme not available")
 
+# Decorator to decide whether or not to run memory-intensive tests
+availMem = psutil.virtual_memory().total
+skip_low_mem = pytest.mark.skipif(availMem < 10 * 1024**3, reason="less than 10GB RAM available")
+
 
 # Local helper for constructing TF testing signals
 def _make_tf_signal(nChannels, nTrials, seed, fadeIn=None, fadeOut=None):
@@ -188,7 +191,7 @@ def test_allocout(self):
         # keep trials but throw away tapers
         out = SpectralData(dimord=SpectralData._defaultDimord)
         freqanalysis(self.adata, method="mtmfft", taper="dpss",
-                     keeptapers=False, output="pow", out=out)
+                     tapsmofrq=3, keeptapers=False, output="pow", out=out)
         assert out.sampleinfo.shape == (self.nTrials, 2)
         assert out.taper.size == 1
 
@@ -196,6 +199,7 @@ def test_allocout(self):
         cfg.dataset = self.adata
         cfg.out = SpectralData(dimord=SpectralData._defaultDimord)
         cfg.taper = "dpss"
+        cfg.tapsmofrq = 3
         cfg.output = "pow"
         cfg.keeptapers = False
         freqanalysis(cfg)
@@ -255,14 +259,13 @@ def test_dpss(self):
 
         # ensure default setting results in single taper
         spec = freqanalysis(self.adata, method="mtmfft",
-                            taper="dpss", output="pow", select=select)
+                            taper="dpss", tapsmofrq=3, output="pow", select=select)
         assert spec.taper.size == 1
         assert spec.channel.size == len(chanList)
 
         # specify tapers
         spec = freqanalysis(self.adata, method="mtmfft", taper="dpss",
                             tapsmofrq=7, keeptapers=True, select=select)
-        assert spec.taper.size == 7
         assert spec.channel.size == len(chanList)
 
     # non-equidistant data w/multiple tapers
@@ -392,13 +395,14 @@ def test_vdata(self):
         avdata = AnalogData(vdata, samplerate=self.fs,
                             trialdefinition=self.trialdefinition)
         spec = freqanalysis(avdata, method="mtmfft", taper="dpss",
-                            keeptapers=False, output="abs", pad="relative",
+                            tapsmofrq=3, keeptapers=False, output="abs", pad="relative",
                             padlength=npad)
         assert (np.diff(avdata.sampleinfo)[0][0] + npad) / 2 + 1 == spec.freq.size
         del avdata, vdata, dmap, spec
         gc.collect()  # force-garbage-collect object so that tempdir can be closed
 
     @skip_without_acme
+    @skip_low_mem
     def test_parallel(self, testcluster):
         # collect all tests of current class and repeat them using dask
         # (skip VirtualData tests since ``wrapper_io`` expects valid headers)
@@ -555,7 +559,7 @@ def test_tf_allocout(self):
         # keep trials but throw away tapers
         out = SpectralData(dimord=SpectralData._defaultDimord)
-        freqanalysis(self.tfData, method="mtmconvol", taper="dpss",
+        freqanalysis(self.tfData, method="mtmconvol", taper="dpss", tapsmofrq=3,
                      keeptapers=False, output="pow", toi=0.0,
                      t_ftimwin=1.0, out=out)
         assert out.sampleinfo.shape == (self.nTrials, 2)
@@ -565,6 +569,7 @@ def test_tf_allocout(self):
         cfg.dataset = self.tfData
         cfg.out = SpectralData(dimord=SpectralData._defaultDimord)
         cfg.taper = "dpss"
+        cfg.tapsmofrq = 3
         cfg.keeptapers = False
         cfg.output = "pow"
         freqanalysis(cfg)
@@ -722,8 +727,8 @@ def test_tf_toi(self):
         cfg.toi = "all"
         cfg.t_ftimwin = 0.05
         tfSpec = freqanalysis(cfg, self.tfData)
-        assert tfSpec.taper.size > 1
-        dt = 1/self.tfData.samplerate
+        assert tfSpec.taper.size >= 1
+        dt = 1 / self.tfData.samplerate
         timeArr = np.arange(cfg.select["toilim"][0], cfg.select["toilim"][1] + dt, dt)
         assert np.allclose(tfSpec.time[0], timeArr)
         cfg.toi = 1.0
@@ -767,7 +772,7 @@ def test_tf_irregular_trials(self):
         artdata = generate_artificial_data(nTrials=5, nChannels=16,
                                            equidistant=True, inmemory=False)
         tfSpec = freqanalysis(artdata, **cfg)
-        assert tfSpec.taper.size > 1
+        assert tfSpec.taper.size >= 1
         for tk, origTime in enumerate(artdata.time):
             assert np.array_equal(np.unique(np.floor(origTime)), tfSpec.time[tk])
 
@@ -784,7 +789,7 @@ def test_tf_irregular_trials(self):
         artdata = generate_artificial_data(nTrials=5, nChannels=8,
                                            equidistant=False, inmemory=False)
         tfSpec = freqanalysis(artdata, **cfg)
-        assert tfSpec.taper.size > 1
+        assert tfSpec.taper.size >= 1
         for tk, origTime in enumerate(artdata.time):
             assert np.array_equal(np.unique(np.floor(origTime)), tfSpec.time[tk])
         cfg.toi = "all"
@@ -798,7 +803,7 @@ def test_tf_irregular_trials(self):
                                             equidistant=False, inmemory=False,
                                             dimord=AnalogData._defaultDimord[::-1])
         tfSpec = freqanalysis(cfg)
-        assert tfSpec.taper.size > 1
+        assert tfSpec.taper.size >= 1
         for tk, origTime in enumerate(cfg.data.time):
             assert np.array_equal(np.unique(np.floor(origTime)), tfSpec.time[tk])
         cfg.toi = "all"
@@ -809,11 +814,11 @@ def test_tf_irregular_trials(self):
         # same + overlapping trials
         cfg.toi = 0.0
         cfg.data = generate_artificial_data(nTrials=5, nChannels=4,
-                                           equidistant=False, inmemory=False,
-                                           dimord=AnalogData._defaultDimord[::-1],
-                                           overlapping=True)
+                                            equidistant=False, inmemory=False,
+                                            dimord=AnalogData._defaultDimord[::-1],
+                                            overlapping=True)
         tfSpec = freqanalysis(cfg)
-        assert tfSpec.taper.size > 1
+        assert tfSpec.taper.size >= 1
         for tk, origTime in enumerate(cfg.data.time):
             assert np.array_equal(np.unique(np.floor(origTime)), tfSpec.time[tk])
         cfg.toi = "all"
@@ -822,6 +827,7 @@ def test_tf_irregular_trials(self):
             assert np.array_equal(origTime, tfSpec.time[tk])
 
     @skip_without_acme
+    @skip_low_mem
     def test_tf_parallel(self, testcluster):
         # collect all tests of current class and repeat them running concurrently
         client = dd.Client(testcluster)
@@ -873,7 +879,7 @@ def test_tf_parallel(self, testcluster):
                                                inmemory=False)
             for chan_per_worker in enumerate([None, chanPerWrkr]):
                 tfSpec = freqanalysis(artdata, cfg)
-                assert tfSpec.taper.size > 1
+                assert tfSpec.taper.size >= 1
 
         # overlapping trial spacing, throw away trials and tapers
         cfg.keeptapers = False
@@ -912,6 +918,7 @@ class TestWavelet():
                 "channels": range(0, int(nChannels / 2)),
                 "toilim": [-20, 60.8]}]
 
+    @skip_low_mem
     def test_wav_solution(self):
 
         # Compute TF spectrum across entire time-interval (use integer-valued
@@ -1112,6 +1119,7 @@ def test_wav_irregular_trials(self):
             assert np.array_equal(origTime, tfSpec.time[tk])
 
     @skip_without_acme
+    @skip_low_mem
     def test_wav_parallel(self, testcluster):
         # collect all tests of current class and repeat them running concurrently
         client = dd.Client(testcluster)
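The new `skip_low_mem` marker complements `skip_without_acme`: memory-hungry tests are skipped outright on small machines instead of dragging the suite into swap. The mechanism is plain pytest plus psutil, as added at the top of test_specest.py:

import psutil
import pytest

# total installed RAM in bytes; the threshold is 10 GB = 10 * 1024**3
availMem = psutil.virtual_memory().total
skip_low_mem = pytest.mark.skipif(availMem < 10 * 1024**3,
                                  reason="less than 10GB RAM available")

@skip_low_mem
def test_something_memory_hungry():
    ...

Note that `psutil.virtual_memory().total` reports installed memory, not what is currently free; a stricter gate could compare against the `.available` field instead.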
diff --git a/testdev_backend.py b/testdev_backend.py
deleted file mode 100644
index f202ade61..000000000
--- a/testdev_backend.py
+++ /dev/null
@@ -1,205 +0,0 @@
-''' This is a temporary development file '''
-
-import numpy as np
-import matplotlib.pyplot as ppl
-
-from syncopy.shared.parsers import data_parser, scalar_parser, array_parser
-from syncopy.specest.wavelet import get_optimal_wavelet_scales, wavelet
-from syncopy.specest.superlet import SuperletTransform, MorletSL, cwtSL, _get_superlet_support, superlet, compute_adaptive_order, scale_from_period
-from syncopy.specest.wavelets import Morlet
-from scipy.signal import fftconvolve
-
-
-def gen_superlet_testdata(freqs=[20, 40, 60],
-                          cycles=11, fs=1000,
-                          eps = 0):
-
-    '''
-    Harmonic superposition of multiple
-    few-cycle oscillations akin to the
-    example of Figure 3 in Moca et al. 2021 NatComm
-    '''
-
-    signal = []
-    for freq in freqs:
-
-        # 10 cycles of f1
-        tvec = np.arange(cycles / freq, step=1 / fs)
-
-        harmonic = np.cos(2 * np.pi * freq * tvec)
-        f_neighbor = np.cos(2 * np.pi * (freq + 10) * tvec)
-        packet = harmonic + f_neighbor
-
-        # 2 cycles time neighbor
-        delta_t = np.zeros(int(2 / freq * fs))
-
-        # 5 cycles break
-        pad = np.zeros(int(5 / freq * fs))
-
-        signal.extend([pad, packet, delta_t, harmonic])
-
-    # stack the packets together with some padding
-    signal.append(pad)
-    signal = np.concatenate(signal)
-
-    # additive white noise
-    if eps > 0:
-        signal = np.random.randn(len(signal)) * eps + signal
-
-    return signal
-
-
-# test the Wavelet transform
-fs = 1000
-s1 = 1 * gen_superlet_testdata(fs=fs, eps=0)  # 20Hz, 40Hz and 60Hz
-data = np.c_[3*s1, 50*s1]
-preselect = np.ones(len(s1), dtype=bool)
-preselect2 = np.ones((len(s1), 2), dtype=bool)
-pads = 0
-
-ts = np.arange(-50,50)
-morletTC = Morlet()
-morletSL = MorletSL(c_i=30)
-
-# frequencies to look at, 10th freq is around 20Hz
-freqs = np.linspace(1, 100, 50)  # up to 100Hz
-scalesTC = morletTC.scale_from_period(1 / freqs)
-# scales are cycle independent!
-scalesSL = scale_from_period(1 / freqs)
-
-# automatic diadic scales
-ssTC = get_optimal_wavelet_scales(Morlet().scale_from_period, len(s1), 1/fs)
-ssSL = get_optimal_wavelet_scales(scale_from_period, len(s1), 1/fs)
-# a multiplicative Superlet - a set of Morlets, order 1 - 30
-c_1 = 1
-cycles = c_1 * np.arange(1, 31)
-sl = [MorletSL(c) for c in cycles]
-
-res = wavelet(data,
-              preselect,
-              preselect,
-              pads,
-              pads,
-              samplerate=fs,
-              # toi='some',
-              output_fmt="pow",
-              scales=scalesTC,
-              wav=Morlet(),
-              noCompute=False)
-
-
-# unit impulse
-# data = np.zeros(500)
-# data[248:252] = 1
-spec = superlet(s1, samplerate=fs, scales=scalesSL,
-                order_max=10,
-                order_min=5,
-                adaptive=False)
-spec2 = superlet(data, samplerate=fs, scales=scalesSL, order_max=20, adaptive=False)
-
-# nc = superlet(data, samplerate=fs, scales=scalesSL, order_max=30)
-
-
-def do_slt(data, scales=scalesSL, **slkwargs):
-
-    if scales is None:
-        scales = get_optimal_wavelet_scales(scale_from_period,
-                                            len(data[:, 0]),
-                                            1 / fs)
-
-    spec = superlet(data, samplerate=fs,
-                    scales=scales,
-                    **slkwargs)
-
-    print(spec.max(),spec.shape)
-    ppl.figure()
-    extent = [0, len(s1) / fs, freqs[-1], freqs[0]]
-    ppl.imshow(np.abs(spec[...,0]), cmap='plasma', aspect='auto', extent=extent)
-    ppl.plot([0, len(s1) / fs], [20, 20], 'k--')
-    ppl.plot([0, len(s1) / fs], [40, 40], 'k--')
-    ppl.plot([0, len(s1) / fs], [60, 60], 'k--')
-
-    return spec
-
-
-def show_MorletSL(morletSL, scale):
-
-    cycle = morletSL.c_i
-    ts = _get_superlet_support(scale, 1/fs, cycle)
-    ppl.plot(ts, MorletSL(cycle)(ts, scale))
-
-
-def show_MorletTC(morletTC, scale):
-
-    M = 10 * scale * fs
-    # times to use, centred at zero
-    ts = np.arange((-M + 1) / 2.0, (M + 1) / 2.0) / fs
-    ppl.plot(ts, morletTC(ts, scale))
-
-
-def do_superlet_cwt(data, wav, scales=None):
-
-    if scales is None:
-        scales = get_optimal_wavelet_scales(scale_from_period, len(data[:,0]), 1/fs)
-
-    res = cwtSL(data,
-                wav,
-                scales=scales,
-                dt=1 / fs)
-
-    ppl.figure()
-    extent = [0, len(s1) / fs, freqs[-1], freqs[0]]
-    channel=0
-    ppl.imshow(np.abs(res[:,:, channel]), cmap='plasma', aspect='auto', extent=extent)
-    ppl.plot([0, len(s1) / fs], [20, 20], 'k--')
-    ppl.plot([0, len(s1) / fs], [40, 40], 'k--')
-    ppl.plot([0, len(s1) / fs], [60, 60], 'k--')
-
-    return res.T
-
-
-def do_normal_cwt(data, wav, scales=None):
-
-    if scales is None:
-        scales = get_optimal_wavelet_scales(wav.scale_from_period,
-                                            len(data[:,0]),
-                                            1/fs)
-    res = wavelet(data,
-                  preselect,
-                  preselect,
-                  pads,
-                  pads,
-                  samplerate=fs,
-                  # toi='some',
-                  output_fmt="pow",
-                  scales=scales,
-                  wav=wav)
-
-    ppl.figure()
-    extent = [0, len(s1) / fs, freqs[-1], freqs[0]]
-    ppl.imshow(res[:, 0, :, 0].T, cmap='plasma', aspect='auto', extent=extent)
-    ppl.plot([0, len(s1) / fs], [20, 20], 'k--')
-    ppl.plot([0, len(s1) / fs], [40, 40], 'k--')
-    ppl.plot([0, len(s1) / fs], [60, 60], 'k--')
-
-    return res[:, 0, :, :].T
-
-# do_cwt(morletSL)
-
-
-def screen_CWT(w0s= [5, 8, 12]):
-    for w0 in w0s:
-        morletTC = Morlet(w0)
-        scales = _get_optimal_wavelet_scales(morletTC, len(s1), 1/fs)
-        res = wavelet(s1[:, np.newaxis],
-                      preselect,
-                      preselect,
-                      pads,
-                      pads,
-                      samplerate=fs,
-                      toi=np.array([1,2]),
-                      scales=scales,
-                      wav=morletTC)
-
-        ppl.figure()
-        ppl.imshow(res[:, 0, :, 0].T, cmap='plasma', aspect='auto')
diff --git a/testdev_frontend.py b/testdev_frontend.py
deleted file mode 100644
index 8e7df5563..000000000
--- a/testdev_frontend.py
+++ /dev/null
@@ -1,32 +0,0 @@
-''' This is a temporary development file '''
-
-import numpy as np
-import matplotlib.pyplot as ppl
-from syncopy.specest import freqanalysis
-from syncopy.tests.misc import generate_artificial_data
-
-tdat = generate_artificial_data()
-
-# test mtmfft analysis
-r_mtm = freqanalysis(tdat)
-
-toi_ival = np.linspace(-0.5, 1, 100)
-
-#toi_ival = [0,0.2,0.5,1]
-toi_ival = 'all'
-foi = np.logspace(-1, 2.6, 25)
-# test classical wavelet analysis
-r_wav = freqanalysis(tdat, method="wavelet",
-                     toi=toi_ival,
-                     output='abs',
-                     foi=None)  #, foilim=[5, 500])
-
-# test superlet analysis
-r_sup = freqanalysis(tdat, method="superlet", toi=toi_ival,
-                     order_max=20, output='abs',
-                     order_min=1,
-                     c_1 = 5,
-                     adaptive=True)
-
-r_sup = freqanalysis(tdat, method="superlet", toi='all', order_max=30, foi=foi, output='abs',order_min=5, adaptive=True)
-#res_strials = [t for t in r_sup.trials]
diff --git a/tox.ini b/tox.ini
index 065399184..86277a0b5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py38-scipy{14,15}-{noacme, acme}
+envlist = py38-scipy15-{noacme, acme}
 requires = tox-conda
 isolated_build = True
@@ -13,8 +13,7 @@ deps =
     tqdm >= 4.31
     memory_profiler
 conda_deps=
-    scipy14: scipy >= 1.4, < 1.5
-    scipy15: scipy >= 1.5, < 1.6
+    scipy15: scipy >= 1.5
    acme: esi-acme
 conda_channels=
     defaults
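With the two testdev scripts removed, the superlet and wavelet front ends remain exercised through the public freqanalysis API (see dev_frontend.py). A minimal call matching what the deleted script did, kept here for reference:

import numpy as np
from syncopy.specest import freqanalysis
from syncopy.tests.misc import generate_artificial_data

tdat = generate_artificial_data()
foi = np.logspace(-1, 2.6, 25)   # frequencies of interest, ~0.1-400 Hz

# adaptive superlet transform, cycle orders 5-30, absolute-value spectra
r_sup = freqanalysis(tdat, method="superlet", toi="all",
                     order_min=5, order_max=30,
                     foi=foi, output="abs", adaptive=True)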