Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Principal Curves #338

Merged
merged 8 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions brainlit/algorithms/generate_fragments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from brainlit.algorithms.generate_fragments.tube_seg import *
from brainlit.algorithms.generate_fragments.adaptive_thresh import *
from brainlit.algorithms.generate_fragments.state_generation import *
from brainlit.algorithms.generate_fragments.pcurve import *
147 changes: 147 additions & 0 deletions brainlit/algorithms/generate_fragments/pcurve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2020, zsteve
# Written entirely by https://github.com/zsteve
# https://github.com/zsteve/pcurvepy
# This code was slightly modified by https://github.com/CaseyWeiner, 2022
# CaseyWeiner added an "s_factor" parameter to init to allow smoothing factor
# customization in the univariate spline interpolation step.

import sklearn
import numpy as np
from sklearn.decomposition import PCA
from scipy.interpolate import UnivariateSpline
from typing import Tuple


class PrincipalCurve:
def __init__(self, k: int = 3, s_factor: int = 1) -> None:
self.k = k
self.p = None
self.s = None
self.p_interp = None
self.s_interp = None
self.s_factor = s_factor

def project(
self, X: np.ndarray, p: np.ndarray, s: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get interpolating s values for projection of X onto the curve defined by (p, s)

Args:
X (np.ndarray): data
p (np.ndarray): curve points
s (np.ndarray): curve parameterisation

Returns:
np.ndarray: interpolating parameter values
np.ndarray: projected points on curve
np.ndarray: sum of square distances
"""
s_interp = np.zeros(X.shape[0])
p_interp = np.zeros(X.shape)
d_sq = 0

for i in range(0, X.shape[0]):
z = X[i, :]
seg_proj = (
((p[1:] - p[0:-1]).T)
* np.einsum("ij,ij->i", z - p[0:-1], p[1:] - p[0:-1])
/ np.power(np.linalg.norm(p[1:] - p[0:-1], axis=1), 2)
).T # compute parallel component
proj_dist = (z - p[0:-1]) - seg_proj # compute perpendicular component
dist_endpts = np.minimum(
np.linalg.norm(z - p[0:-1], axis=1), np.linalg.norm(z - p[1:], axis=1)
)
dist_seg = np.maximum(np.linalg.norm(proj_dist, axis=1), dist_endpts)

idx_min = np.argmin(dist_seg)
q = seg_proj[idx_min]
s_interp[i] = (
np.linalg.norm(q) / np.linalg.norm(p[idx_min + 1, :] - p[idx_min, :])
) * (s[idx_min + 1] - s[idx_min]) + s[idx_min]
p_interp[i] = (s_interp[i] - s[idx_min]) * (
p[idx_min + 1, :] - p[idx_min, :]
) + p[idx_min, :]
d_sq = d_sq + np.linalg.norm(proj_dist[idx_min]) ** 2

return (s_interp, p_interp, d_sq)

def renorm_parameterisation(self, p: int) -> np.ndarray:
"""Renormalise curve to unit speed

Args:
p (np.ndarray): curve points

Returns:
np.ndarray: new parameterisation
"""
seg_lens = np.linalg.norm(p[1:] - p[0:-1], axis=1)
s = np.zeros(p.shape[0])
s[1:] = np.cumsum(seg_lens)
s = s / sum(seg_lens)
return s

def fit(
self,
X: np.ndarray,
p: np.ndarray = None,
w: np.ndarray = None,
max_iter: int = 10,
tol: float = 1e-3,
) -> None:
"""Fit principal curve to data

Args:
X (np.ndarray): data
p (np.ndarray): starting curve (optional)
w (np.ndarray): data weights (optional)
max_iter (int): maximum number of iterations
tol (float): tolerance for stopping condition
"""
pca = sklearn.decomposition.PCA(n_components=X.shape[1])
pca.fit(X)
pc1 = pca.components_[:, 0]
if p is None:
p = np.kron(np.dot(X, pc1) / np.dot(pc1, pc1), pc1).reshape(
X.shape
) # starting point for iteration
order = np.argsort(
[np.linalg.norm(p[0, :] - p[i, :]) for i in range(0, p.shape[0])]
)
p = p[order]
s = self.renorm_parameterisation(p)

p_interp = np.zeros(X.shape)
s_interp = np.zeros(X.shape[0])
d_sq_old = np.Inf

for i in range(0, max_iter):
s_interp, p_interp, d_sq = self.project(X, p, s)

if np.abs(d_sq - d_sq_old) < tol:
break
d_sq_old = d_sq

order = np.argsort(s_interp)
# s_interp = s_interp[order]
# X = X[order, :]

s_in = len(s_interp) * self.s_factor

spline = [
UnivariateSpline(s_interp[order], X[order, j], s=s_in, k=self.k, w=w)
for j in range(0, X.shape[1])
] # Alter k, s

p = np.zeros((len(s_interp), X.shape[1]))
for j in range(0, X.shape[1]):
p[:, j] = spline[j](s_interp[order])

idx = [i for i in range(0, p.shape[0] - 1) if (p[i] != p[i + 1]).any()]
p = p[idx, :]
s = self.renorm_parameterisation(p)

self.s = s
self.p = p
self.p_interp = p_interp
self.s_interp = s_interp
return
40 changes: 40 additions & 0 deletions brainlit/algorithms/generate_fragments/state_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import networkx as nx
from typing import List, Tuple

from brainlit.algorithms.generate_fragments import pcurve


class state_generation:
def __init__(
Expand Down Expand Up @@ -448,6 +450,44 @@ def _endpoints_from_coords_neighbors(self, coords: np.ndarray) -> List[list]:

return ends

def _pc_endpoints_from_coords_neighbors(self, coords: np.ndarray) -> List[list]:
    """Compute endpoints of fragment with Principal Curves.

    Args:
        coords (np.array): coordinates of voxels in the fragment

    Returns:
        list: endpoints of the fragment

    References
    ----------
    .. [1] Hastie, Trevor, and Werner Stuetzle. “Principal Curves.”
        Journal of the American Statistical Association, vol. 84, no. 406,
        [American Statistical Association, Taylor & Francis, Ltd.], 1989,
        pp. 502–16, https://doi.org/10.2307/2289936.
    .. [2] Principal Curves Code written by zsteve,
        https://github.com/zsteve, https://github.com/zsteve/pcurvepy
    """

    ends = []

    # Make sure x, y, z ascending & don't repeat
    sorter = np.lexsort((coords[:, 2], coords[:, 1], coords[:, 0]))
    coords = coords[sorter]
    coords = np.unique(coords, axis=0)

    p_curve = pcurve.PrincipalCurve(k=1, s_factor=5)
    p_curve.fit(coords, max_iter=50)
    pc = p_curve.p

    # Round curve points to the nearest integer voxel.
    pc = np.floor(pc + 0.5)
    # Keep only rounded curve points that are actual fragment voxels.
    # BUG FIX: `i in coords` on a 2-D ndarray means `(coords == i).any()`,
    # which is True whenever ANY single coordinate value matches anywhere in
    # the array — not row membership. Test the full row instead.
    pc_frag_list = [i for i in pc if (coords == i).all(axis=1).any()]

    if not pc_frag_list:
        # No rounded curve point landed on a fragment voxel: fall back to the
        # lexicographic extremes of the fragment rather than raising IndexError.
        return [coords[0], coords[-1]]

    ends.append(pc_frag_list[0])
    ends.append(pc_frag_list[-1])

    return ends

def _compute_states_thread(
self, corner1: List[int], corner2: List[int]
) -> List[tuple]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
### functionality checks ###
############################

# 100x3 integer array whose i-th row is [i, i, i]: a straight diagonal
# "fragment" used to exercise the principal-curve endpoint helper.
test_coords = np.repeat(np.arange(100).reshape(-1, 1), 3, axis=1)


def test_state_generation():
sg = state_generation(
Expand All @@ -65,5 +73,7 @@ def test_state_generation():
print(G.nodes[node])
assert len(G.nodes) == 9 # 2 states per fragment plus one soma state

sg._pc_endpoints_from_coords_neighbors(test_coords)

sg.compute_edge_weights()
sg.compute_bfs()