From 5bb460459e8d844d019f29dd18a326e0f9d4234b Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Wed, 18 Dec 2024 13:50:21 -0500 Subject: [PATCH 1/2] WIP on several fixes --- kilosort/clustering_qr.py | 9 +++++++-- kilosort/postprocessing.py | 6 +++++- kilosort/template_matching.py | 16 +++++++++++++--- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index ef659b66..93d4ea67 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -366,6 +366,7 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), clu = np.zeros(nsp, 'int32') Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs'])) + Nfilt = None nearby_chans_empty = 0 nmax = 0 prog = tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None, @@ -433,9 +434,13 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), except: logger.exception(f'Error in clustering_qr.run on center {ii}') logger.debug(f'Xd shape: {Xd.shape}') - logger.debug(f'iclust shape: {iclust.shape}') - logger.debug(f'clu shape: {clu.shape}') logger.debug(f'Nfilt: {Nfilt}') + logger.debug(f'num spikes: {nsp}') + try: + logger.debug(f'iclust shape: {iclust.shape}') + except UnboundLocalError: + logger.debug('iclust not yet assigned') + pass raise if nearby_chans_empty == len(ycent): diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index a7286aec..f662f7b7 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -33,9 +33,13 @@ def remove_duplicates(spike_times, spike_clusters, dt=15): def compute_spike_positions(st, tF, ops): '''Get x,y positions of spikes relative to probe.''' tmass = (tF**2).sum(-1) - tmass = tmass / tmass.sum(1, keepdim=True) xc = torch.from_numpy(ops['xc']).to(tmass.device) yc = torch.from_numpy(ops['yc']).to(tmass.device) + # TODO: also store distance to each of these channels, and multiply by + # tmass before summing so that far away channels are ~0 + tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]] # 1 if close enough, 0 if too far away (tbd, maybe 100ish um) + tmass = tmass * tmask + tmass = tmass / tmass.sum(1, keepdim=True) chs = ops['iCC'][:, ops['iU'][st[:,1]]].cpu() xc0 = xc[chs.T] yc0 = yc[chs.T] diff --git a/kilosort/template_matching.py b/kilosort/template_matching.py index 3089ff33..7a4113e4 100644 --- a/kilosort/template_matching.py +++ b/kilosort/template_matching.py @@ -14,15 +14,25 @@ def prepare_extract(ops, U, nC, device=torch.device('cuda')): ds = (ops['xc'] - ops['xc'][:, np.newaxis])**2 + (ops['yc'] - ops['yc'][:, np.newaxis])**2 iCC = np.argsort(ds, 0)[:nC] - iCC = torch.from_numpy(iCC).to(device) + iCC = torch.from_numpy(iCC, device=device) + iCC_mask = np.sorg(ds, 0)[:nC] + iCC_mask = iCC_mask < 10000 # 100um squared + iCC_mask = torch.from_numpy(iCC_mask, device=device) iU = torch.argmax((U**2).sum(1), -1) Ucc = U[torch.arange(U.shape[0]),:,iCC[:,iU]] - return iCC, iU, Ucc + + # iCC: nC nearest channels to each channel + # iCC_mask: 1 if above is within 100um of channel, 0 otherwise + # iU: index of max channel for each template + # Ucc: spatial PC features corresponding to iCC for each template + + return iCC, iCC_mask, iU, Ucc def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None): nC = ops['settings']['nearest_chans'] - iCC, iU, Ucc = prepare_extract(ops, U, nC, device=device) + iCC, iCC_mask, iU, Ucc = prepare_extract(ops, U, nC, device=device) ops['iCC'] = iCC + ops['iCC_mask'] = iCC_mask ops['iU'] = iU nt = ops['nt'] From 92332706f88b11f1d66f60cfe1c25c86f0b23f01 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Mon, 30 Dec 2024 12:53:31 -0500 Subject: [PATCH 2/2] Added fix for spike smearing between shanks --- kilosort/parameters.py | 12 ++++++++ kilosort/postprocessing.py | 14 +++++---- kilosort/template_matching.py | 55 +++++++++++++++++++++++++++-------- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/kilosort/parameters.py b/kilosort/parameters.py index c0896e3b..51445ed0 100644 --- a/kilosort/parameters.py +++ b/kilosort/parameters.py @@ -398,6 +398,18 @@ default of 7 bins for a 30kHz sampling rate. """ }, + + 'position_limit': { + 'gui_name': 'position limit', 'type': float, 'min': 0, 'max': np.inf, + 'exclude': [], 'default': 100, 'step': 'postprocessing', + 'description': + """ + Maximum distance (in microns) between channels that can be used + to estimate spike positions in `postprocessing.compute_spike_positions`. + This does not affect spike sorting, only how positions are estimated + after sorting is complete. + """ + }, } # Add default values to descriptions diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index f662f7b7..b4fc3560 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -32,18 +32,22 @@ def remove_duplicates(spike_times, spike_clusters, dt=15): def compute_spike_positions(st, tF, ops): '''Get x,y positions of spikes relative to probe.''' + # Determine channel weightings for nearest channels + # based on norm of PC features. Channels that are far away have 0 weight, + # determined by `ops['settings']['position_limit']`. tmass = (tF**2).sum(-1) - xc = torch.from_numpy(ops['xc']).to(tmass.device) - yc = torch.from_numpy(ops['yc']).to(tmass.device) - # TODO: also store distance to each of these channels, and multiply by - # tmass before summing so that far away channels are ~0 - tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]] # 1 if close enough, 0 if too far away (tbd, maybe 100ish um) + tmask = ops['iCC_mask'][:, ops['iU'][st[:,1]]].T.to(tmass.device) tmass = tmass * tmask tmass = tmass / tmass.sum(1, keepdim=True) + + # Get x,y coordinates of nearest channels. + xc = torch.from_numpy(ops['xc']).to(tmass.device) + yc = torch.from_numpy(ops['yc']).to(tmass.device) chs = ops['iCC'][:, ops['iU'][st[:,1]]].cpu() xc0 = xc[chs.T] yc0 = yc[chs.T] + # Estimate spike positions as weighted sum of coordinates of nearby channels. xs = (xc0 * tmass).sum(1).cpu().numpy() ys = (yc0 * tmass).sum(1).cpu().numpy() diff --git a/kilosort/template_matching.py b/kilosort/template_matching.py index 7a4113e4..69f8eaae 100644 --- a/kilosort/template_matching.py +++ b/kilosort/template_matching.py @@ -11,26 +11,54 @@ logger = logging.getLogger(__name__) -def prepare_extract(ops, U, nC, device=torch.device('cuda')): - ds = (ops['xc'] - ops['xc'][:, np.newaxis])**2 + (ops['yc'] - ops['yc'][:, np.newaxis])**2 +def prepare_extract(xc, yc, U, nC, position_limit, device=torch.device('cuda')): + """Identify desired channels based on distances and template norms. + + Parameters + ---------- + xc : np.ndarray + X-coordinates of contact positions on probe. + yc : np.ndarray + Y-coordinates of contact positions on probe. + U : torch.Tensor + TODO + nC : int + Number of nearest channels to use. + position_limit : float + Max distance (in microns) between channels that are used to estimate + spike positions in `postprocessing.compute_spike_positions`. + + Returns + ------- + iCC : np.ndarray + For each channel, indices of nC nearest channels. + iCC_mask : np.ndarray + For each channel, a 1 if the channel is within 100um and a 0 otherwise. + Used to control spike position estimate in post-processing. + iU : torch.Tensor + For each template, index of channel with greatest norm. + Ucc : torch.Tensor + For each template, spatial PC features corresponding to iCC. + + """ + ds = (xc - xc[:, np.newaxis])**2 + (yc - yc[:, np.newaxis])**2 iCC = np.argsort(ds, 0)[:nC] - iCC = torch.from_numpy(iCC, device=device) - iCC_mask = np.sorg(ds, 0)[:nC] - iCC_mask = iCC_mask < 10000 # 100um squared - iCC_mask = torch.from_numpy(iCC_mask, device=device) + iCC = torch.from_numpy(iCC).to(device) + iCC_mask = np.sort(ds, 0)[:nC] + iCC_mask = iCC_mask < position_limit**2 + iCC_mask = torch.from_numpy(iCC_mask).to(device) iU = torch.argmax((U**2).sum(1), -1) Ucc = U[torch.arange(U.shape[0]),:,iCC[:,iU]] - # iCC: nC nearest channels to each channel - # iCC_mask: 1 if above is within 100um of channel, 0 otherwise - # iU: index of max channel for each template - # Ucc: spatial PC features corresponding to iCC for each template - return iCC, iCC_mask, iU, Ucc + def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None): nC = ops['settings']['nearest_chans'] - iCC, iCC_mask, iU, Ucc = prepare_extract(ops, U, nC, device=device) + position_limit = ops['settings']['position_limit'] + iCC, iCC_mask, iU, Ucc = prepare_extract( + ops['xc'], ops['yc'], U, nC, position_limit, device=device + ) ops['iCC'] = iCC ops['iCC_mask'] = iCC_mask ops['iU'] = iU @@ -95,6 +123,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None): return st, tF, ops + def align_U(U, ops, device=torch.device('cuda')): Uex = torch.einsum('xyz, zt -> xty', U.to(device), ops['wPCA']) X = Uex.reshape(-1, ops['Nchan']).T @@ -118,6 +147,7 @@ def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')): Wall3 = Wall3.transpose(1,2).to(device) return Wall3 + def prepare_matching(ops, U): nt = ops['nt'] W = ops['wPCA'].contiguous() @@ -132,6 +162,7 @@ def prepare_matching(ops, U): return ctc + def run_matching(ops, X, U, ctc, device=torch.device('cuda')): Th = ops['Th_learned'] nt = ops['nt']