diff --git a/herdingspikes/detection_lightning/detect.pyx b/herdingspikes/detection_lightning/detect.pyx index b4865ed..7ce0052 100644 --- a/herdingspikes/detection_lightning/detect.pyx +++ b/herdingspikes/detection_lightning/detect.pyx @@ -433,8 +433,8 @@ class HSDetectionLightning(object): ) result: dict[str, RealArray] = { - "sample_ind": sample_ind, - "channel_ind": channel_ind, + "sample_index": sample_ind, + "channel_index": channel_ind, "amplitude": amplitude, } if self.localize: diff --git a/herdingspikes/hs2.py b/herdingspikes/hs2.py index 9f32d99..b63cebc 100644 --- a/herdingspikes/hs2.py +++ b/herdingspikes/hs2.py @@ -517,8 +517,8 @@ def DetectFromRaw(self): sp[0]["spike_shape"] = np.zeros(len(sp[0]["sample_ind"])) self.spikes = pd.DataFrame( { - "ch": sp[0]["channel_ind"], - "t": sp[0]["sample_ind"], + "ch": sp[0]["channel_index"], + "t": sp[0]["sample_index"], "Amplitude": sp[0]["amplitude"], "x": sp[0]["location"][:, 0], "y": sp[0]["location"][:, 1], @@ -540,6 +540,7 @@ def DetectFromRaw(self): else: h["cutout_length"] = 0 h.close() + return sp def PlotTracesChannels( self, @@ -1432,3 +1433,49 @@ def PlotNeighbourhood( ax[1].plot(self.spikes.Shape[i], color=(0.4, 0.4, 0.4)) ax[1].plot(np.mean(self.spikes.Shape[spInds].values, axis=0), color="k") return ax + + +def detect_peaks_lightning(recording, params=None): + """ + Detect spikes in a recording using the lightning framework. This function is compatible + with the SpikeInterface sorting components framework. Note it does not return spike locations. + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor object + params : dict + The parameters for the spike detection. If None, default parameters are used. + + Returns + ------- + peaks : np.array + Structured array with the detected peaks. Fields are: + * 'sample_index' : int + The index of the peak sample + * 'channel_index' : int + The index of the channel + * 'amplitude' : float + The amplitude of the peak + + """ + + det = HSDetectionLightning(recording, params=params) + peaks = det.DetectFromRaw() + peaks_array = np.array( + list( + tuple( + map( + tuple, + np.array( + [ + (peaks[0][k]) + for k in ["sample_index", "channel_index", "amplitude"] + ] + ).T, + ) + ) + ), + dtype=[("sample_index", "