Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preproc hilbert + rectification + FIRWS tests #264

Merged
merged 19 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions syncopy/nwanalysis/wilson_sf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def wilson_sf(CSD, nIter=100, rtol=1e-9, direct_inversion=True):

# max relative error
CSDfac = psi @ psi.conj().transpose(0, 2, 1)
err = np.abs(CSD - CSDfac)
err = (err / np.abs(CSD)).max()
err = max_rel_err(CSD, CSDfac)
# converged
if err < rtol:
converged = True
Expand All @@ -129,8 +128,8 @@ def _psi0_initial(CSD):

nSamples = CSD.shape[1]

# perform ifft to obtain gammas.
gamma = np.fft.ifft(CSD, axis=0)
# perform (i)fft to obtain gammas.
gamma = np.fft.fft(CSD, axis=0)
gamma0 = gamma[0, ...]

# Remove any asymmetry due to rounding error.
Expand Down
177 changes: 174 additions & 3 deletions syncopy/preproc/compRoutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def process_metadata(self, data, out):
if data.selection is not None:
chanSec = data.selection.channel
trl = data.selection.trialdefinition
for row in range(trl.shape[0]):
trl[row, :2] = [row, row + 1]
else:
chanSec = slice(None)
trl = data.trialdefinition
Expand Down Expand Up @@ -270,7 +268,7 @@ def but_filtering_cF(dat,
class But_Filtering(ComputationalRoutine):

"""
Compute class that performs filtering with butterworth filters
Compute class that performs filtering with butterworth filters
of :class:`~syncopy.AnalogData` objects

Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`,
Expand Down Expand Up @@ -301,3 +299,176 @@ def process_metadata(self, data, out):

out.samplerate = data.samplerate
out.channel = np.array(data.channel[chanSec])


@unwrap_io
def rectify_cF(dat, noCompute=False, chunkShape=None):

"""
Provides straightforward rectification via `np.abs`.

dat : (N, K) :class:`numpy.ndarray`
Uniformly sampled multi-channel time-series data
noCompute : bool
If `True`, do not perform actual calculation but
instead return expected shape and :class:`numpy.dtype` of output
array.

Returns
-------
rectified : (N, K) :class:`~numpy.ndarray`
The rectified signals

Notes
-----
This method is intended to be used as
:meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
Thus, input parameters are presumed to be forwarded from a parent metafunction.
Consequently, this function does **not** perform any error checking and operates
under the assumption that all inputs have been externally validated and cross-checked.

"""

# operation does not change the shape
outShape = dat.shape
if noCompute:
return outShape, np.float32

return np.abs(dat)


class Rectify(ComputationalRoutine):

"""
Compute class that performs rectification
of :class:`~syncopy.AnalogData` objects

Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`,
see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute
classes and metafunctions.

See also
--------
syncopy.preprocessing : parent metafunction
"""

computeFunction = staticmethod(rectify_cF)

# 1st argument,the data, gets omitted
valid_kws = list(signature(rectify_cF).parameters.keys())[1:]

def process_metadata(self, data, out):

# Some index gymnastics to get trial begin/end "samples"
if data.selection is not None:
chanSec = data.selection.channel
trl = data.selection.trialdefinition
else:
chanSec = slice(None)
trl = data.trialdefinition

out.trialdefinition = trl

out.samplerate = data.samplerate
out.channel = np.array(data.channel[chanSec])


@unwrap_io
def hilbert_cF(dat, output='abs', timeAxis=0, noCompute=False, chunkShape=None):

"""
Provides Hilbert transformation with various outputs, band-pass filtering
beforehand highly recommended.

dat : (N, K) :class:`numpy.ndarray`
Uniformly sampled multi-channel time-series data
output : {'abs', 'complex', 'real', 'imag', 'absreal', 'absimag', 'angle'}
The transformation after performing the complex Hilbert transform. Choose
`'angle'` to get the phase.
timeAxis : int, optional
Index of running time axis in `dat` (0 or 1)
noCompute : bool
If `True`, do not perform actual calculation but
instead return expected shape and :class:`numpy.dtype` of output
array.

Returns
-------
rectified : (N, K) :class:`~numpy.ndarray`
The rectified signals

Notes
-----
This method is intended to be used as
:meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
Thus, input parameters are presumed to be forwarded from a parent metafunction.
Consequently, this function does **not** perform any error checking and operates
under the assumption that all inputs have been externally validated and cross-checked.

"""

out_trafo = {
'abs': lambda x: np.abs(x),
'complex': lambda x: x,
'real': lambda x: np.real(x),
'imag': lambda x: np.imag(x),
'absreal': lambda x: np.abs(np.real(x)),
'absimag': lambda x: np.abs(np.imag(x)),
'angle': lambda x: np.angle(x)
}

# Re-arrange array if necessary and get dimensional information
if timeAxis != 0:
dat = dat.T # does not copy but creates view of `dat`
else:
dat = dat

# operation does not change the shape
# but may change the number format
outShape = dat.shape
fmt = np.complex64 if output == 'complex' else np.float32
pantaray marked this conversation as resolved.
Show resolved Hide resolved
if noCompute:
return outShape, fmt

trafo = sci.hilbert(dat, axis=0)

return out_trafo[output](trafo)


class Hilbert(ComputationalRoutine):

"""
Compute class that performs Hilbert transforms
of :class:`~syncopy.AnalogData` objects

Sub-class of :class:`~syncopy.shared.computational_routine.ComputationalRoutine`,
see :doc:`/developer/compute_kernels` for technical details on Syncopy's compute
classes and metafunctions.

See also
--------
syncopy.preprocessing : parent metafunction
"""

computeFunction = staticmethod(hilbert_cF)

# 1st argument,the data, gets omitted
valid_kws = list(signature(hilbert_cF).parameters.keys())[1:]

def process_metadata(self, data, out):

# Some index gymnastics to get trial begin/end "samples"
if data.selection is not None:
chanSec = data.selection.channel
trl = data.selection.trialdefinition

else:
chanSec = slice(None)
trl = data.trialdefinition

out.trialdefinition = trl

out.samplerate = data.samplerate
out.channel = np.array(data.channel[chanSec])
80 changes: 68 additions & 12 deletions syncopy/preproc/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
check_passed_kwargs
)

from .compRoutines import But_Filtering, Sinc_Filtering
from .compRoutines import But_Filtering, Sinc_Filtering, Rectify, Hilbert

availableFilters = ('but', 'firws')
availableFilterTypes = ('lp', 'hp', 'bp', 'bs')
availableDirections = ('twopass', 'onepass', 'onepass-minphase')
availableWindows = ("hamming", "hann", "blackman")

hilbert_outputs = {'abs', 'complex', 'real', 'imag', 'absreal', 'absimag', 'angle'}


@unwrap_cfg
@unwrap_select
Expand All @@ -37,10 +39,12 @@ def preprocessing(data,
direction=None,
window="hamming",
polyremoval=None,
rectify=False,
hilbert=False,
**kwargs
):
"""
Filtering of time continuous raw data with IIR and FIR filters
Preprocessing of time continuous raw data with IIR and FIR filters

data : `~syncopy.AnalogData`
A non-empty Syncopy :class:`~syncopy.AnalogData` object
Expand Down Expand Up @@ -68,6 +72,11 @@ def preprocessing(data,
to filtering. A value of 0 corresponds to subtracting the mean
("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
least squares fit of a linear polynomial).
rectify : bool, optional
Set to `True` to rectify (after filtering)
hilbert : None or one of {'abs', 'complex', 'real', 'imag', 'absreal', 'absimag', 'angle'}
Choose one of the supported output types to perform
Hilbert transformation after filtering. Set to `'angle'` to return the phase.

Returns
-------
Expand Down Expand Up @@ -121,6 +130,9 @@ def preprocessing(data,
if polyremoval is not None:
scalar_parser(polyremoval, varname="polyremoval", ntype="int_like", lims=[0, 1])

if not isinstance(rectify, bool):
SPYValueError("either `True` or `False`", varname='rectify', actual=rectify)

# -- get trial info

# if a subset selection is present
Expand All @@ -143,6 +155,17 @@ def preprocessing(data,
# act = "non-equidistant sampling"
# raise SPYValueError(lgl, varname="data", actual=act)

# -- post processing
if rectify and hilbert:
lgl = "either rectification or Hilbert transform"
raise SPYValueError(lgl, varname="rectify/hilbert", actual=(rectify, hilbert))

# `hilbert` acts both as a switch and a parameter to set the output (like in FT)
if hilbert:
if hilbert not in hilbert_outputs:
lgl = f"one of {hilbert_outputs}"
raise SPYValueError(lgl, varname="hilbert", actual=hilbert)

# -- Method calls

# Prepare keyword dict for logging (use `lcls` to get actually provided
Expand Down Expand Up @@ -177,7 +200,8 @@ def preprocessing(data,
log_dict["order"] = order
log_dict["direction"] = direction

check_effective_parameters(But_Filtering, defaults, lcls)
check_effective_parameters(But_Filtering, defaults, lcls,
besides=('hilbert', 'rectify'))

filterMethod = But_Filtering(samplerate=data.samplerate,
filter_type=filter_type,
Expand Down Expand Up @@ -211,7 +235,7 @@ def preprocessing(data,
log_dict["direction"] = direction

check_effective_parameters(Sinc_Filtering, defaults, lcls,
besides=['filter_class'])
besides=['filter_class', 'hilbert', 'rectify'])

filterMethod = Sinc_Filtering(samplerate=data.samplerate,
filter_type=filter_type,
Expand All @@ -222,16 +246,48 @@ def preprocessing(data,
polyremoval=polyremoval,
timeAxis=timeAxis)

# ------------------------------------
# Call the chosen ComputationalRoutine
# ------------------------------------
# -------------------------------------------
# Call the chosen filter ComputationalRoutine
# -------------------------------------------

out = AnalogData(dimord=data.dimord)
filtered = AnalogData(dimord=data.dimord)
# Perform actual computation
filterMethod.initialize(data,
out._stackingDim,
data._stackingDim,
chan_per_worker=kwargs.get("chan_per_worker"),
keeptrials=True)
filterMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dict)

return out
filterMethod.compute(data, filtered, parallel=kwargs.get("parallel"), log_dict=log_dict)

# -- check for post processing flags --

if rectify:
log_dict['rectify'] = rectify
rectified = AnalogData(dimord=data.dimord)
rectCR = Rectify()
rectCR.initialize(filtered,
data._stackingDim,
chan_per_worker=kwargs.get("chan_per_worker"),
keeptrials=True)
rectCR.compute(filtered, rectified,
parallel=kwargs.get("parallel"),
log_dict=log_dict)
pantaray marked this conversation as resolved.
Show resolved Hide resolved
del filtered
return rectified
pantaray marked this conversation as resolved.
Show resolved Hide resolved

elif hilbert:
log_dict['hilbert'] = hilbert
htrafo = AnalogData(dimord=data.dimord)
hilbertCR = Hilbert(output=hilbert,
timeAxis=timeAxis)
hilbertCR.initialize(filtered, data._stackingDim,
chan_per_worker=kwargs.get("chan_per_worker"),
keeptrials=True)
hilbertCR.compute(filtered, htrafo,
parallel=kwargs.get("parallel"),
log_dict=log_dict)
del filtered
return htrafo

pantaray marked this conversation as resolved.
Show resolved Hide resolved
# no post-processing
else:
return filtered
Loading