
Commit

Merge branch 'master' into cythonization
CHrlS98 committed Dec 18, 2024
2 parents f0b630c + c352faf commit 45493c3
Showing 32 changed files with 729 additions and 461 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
@@ -7,6 +7,7 @@ on:
pull_request:
branches:
- master
merge_group:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -70,7 +71,7 @@ jobs:
.test_reports/
coverage:
runs-on: scilus-runners
runs-on: ubuntu-latest
if: github.repository == 'scilus/scilpy'
needs: test

4 changes: 2 additions & 2 deletions docs/source/documentation/construct_participants_tsv_file.rst
@@ -1,4 +1,4 @@
Instructions to write the tsv files "participants.tsv" for the script scil_group_comparison.py
Instructions to write the tsv files "participants.tsv" for the script scil_stats_group_comparison.py
===============================================================================================

The TSV file should follow the BIDS `specification <https://bids-specification.readthedocs.io/en/stable/03-modality-agnostic-files.html#participants-file>`_.
@@ -12,7 +12,7 @@ participant_id categorical_var_1 categorical_var_2 ...

(ex: participant_id sex nb_children)

The categorical variable name are the "group_by" variable that can be called by scil_group_comparison.py
The categorical variable name are the "group_by" variable that can be called by scil_stats_group_comparison.py
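For illustration, a minimal sketch of building such a file in Python (the file name follows the BIDS convention; the column names mirror the example header above and the subject values are hypothetical):

import csv

# Tab-separated, one row per subject; the first column must be participant_id.
rows = [("participant_id", "sex", "nb_children"),
        ("sub-01", "F", "2"),
        ("sub-02", "M", "0")]
with open("participants.tsv", "w", newline="") as f:
    csv.writer(f, delimiter="\t").writerows(rows)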

Specific row
------------
43 changes: 20 additions & 23 deletions scilpy/connectivity/connectivity.py
@@ -10,14 +10,18 @@
import h5py
import nibabel as nib
import numpy as np
from scipy.ndimage import map_coordinates

from scilpy.image.labels import get_data_as_labels
from scilpy.io.hdf5 import reconstruct_streamlines_from_hdf5
from scilpy.tractanalysis.reproducibility_measures import \
compute_bundle_adjacency_voxel
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractograms.streamline_operations import \
resample_streamlines_num_points
from scilpy.utils.metrics_tools import compute_lesion_stats


d = threading.local()


@@ -31,8 +35,8 @@ def compute_triu_connectivity_from_labels(tractogram, data_labels,
----------
tractogram: StatefulTractogram, or list[np.ndarray]
Streamlines. A StatefulTractogram input is recommanded.
When using directly with a list of streamlines, streamlinee must be in
vox space, corner origin.
When using directly with a list of streamlines, streamlines must be in
vox space, center origin.
data_labels: np.ndarray
The loaded nifti image.
keep_background: Bool
@@ -55,12 +59,12 @@
For each streamline, the label at ending point.
"""
if isinstance(tractogram, StatefulTractogram):
# Vox space, corner origin
# = we can get the nearest neighbor easily.
# Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1.
tractogram.to_vox()
tractogram.to_corner()
streamlines = tractogram.streamlines
# vox space, center origin: compatible with map_coordinates
sfs_2_pts = resample_streamlines_num_points(tractogram, 2)
sfs_2_pts.to_vox()
sfs_2_pts.to_center()
streamlines = sfs_2_pts.streamlines

else:
streamlines = tractogram

@@ -71,23 +75,16 @@
.format(nb_labels))

matrix = np.zeros((nb_labels, nb_labels), dtype=int)
start_labels = []
end_labels = []

for s in streamlines:
start = ordered_labels.index(
data_labels[tuple(np.floor(s[0, :]).astype(int))])
end = ordered_labels.index(
data_labels[tuple(np.floor(s[-1, :]).astype(int))])

start_labels.append(start)
end_labels.append(end)
labels = map_coordinates(data_labels, streamlines._data.T, order=0)
start_labels = labels[0::2]
end_labels = labels[1::2]

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1
# sort each pair of labels for start to be smaller than end
start_labels, end_labels = zip(*[sorted(pair) for pair in
zip(start_labels, end_labels)])

matrix = np.triu(matrix)
np.add.at(matrix, (start_labels, end_labels), 1)
assert matrix.sum() == len(streamlines)

# Rejecting background
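As an aside, a self-contained sketch of the vectorized pattern introduced above — nearest-neighbour label lookup with map_coordinates, sorting each (start, end) pair, then accumulating the upper triangle with np.add.at. The toy label volume and endpoint coordinates below are made up:

import numpy as np
from scipy.ndimage import map_coordinates

# Toy label volume: label 0 on one side, label 1 on the other.
data_labels = np.zeros((4, 4, 4), dtype=int)
data_labels[2:, :, :] = 1

# Endpoints of two streamlines (vox space, center origin), flattened as
# [start_1, end_1, start_2, end_2].
endpoints = np.array([[0.2, 1.0, 1.0],
                      [3.4, 1.0, 1.0],
                      [3.0, 2.0, 2.0],
                      [0.4, 2.0, 2.0]])

labels = map_coordinates(data_labels, endpoints.T, order=0)
start_labels, end_labels = labels[0::2], labels[1::2]

# Sort each pair so start <= end, then accumulate the upper triangle.
start_labels, end_labels = zip(*[sorted(p) for p in
                                 zip(start_labels, end_labels)])
matrix = np.zeros((2, 2), dtype=int)
np.add.at(matrix, (start_labels, end_labels), 1)
assert matrix[0, 1] == 2   # both toy streamlines join label 0 and label 1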
@@ -249,7 +246,7 @@ def compute_connectivity_matrices_from_hdf5(

if compute_volume:
measures_to_return['volume_mm3'] = np.count_nonzero(density) * \
np.prod(voxel_sizes)
np.prod(voxel_sizes)

if compute_streamline_count:
measures_to_return['streamline_count'] = len(streamlines)
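A toy check of the volume computation above, with a hypothetical voxel size:

import numpy as np

density = np.zeros((3, 3, 3))
density.flat[:5] = 1                       # 5 voxels visited by the bundle
voxel_sizes = (2.0, 2.0, 2.0)              # hypothetical 2 mm isotropic voxels
volume_mm3 = np.count_nonzero(density) * np.prod(voxel_sizes)
assert volume_mm3 == 40.0                  # 5 voxels * 8 mm^3 each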
3 changes: 1 addition & 2 deletions scilpy/io/fetcher.py
@@ -5,9 +5,8 @@
import logging
import os
import pathlib
import zipfile

import requests
import zipfile

from scilpy import SCILPY_HOME

2 changes: 1 addition & 1 deletion scilpy/io/hdf5.py
@@ -105,7 +105,7 @@ def reconstruct_sft_from_hdf5(hdf5_handle, group_keys, space=Space.VOX,
for sub_key in hdf5_handle[group_key].keys():
if sub_key not in ['data', 'offsets', 'lengths']:
data = hdf5_handle[group_key][sub_key]
if data.shape == hdf5_handle[group_key]['offsets']:
if data.shape == hdf5_handle[group_key]['offsets'].shape:
# Discovered dps
if load_dps:
if i == 0 or not merge_groups:
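A small sketch of the check fixed above (the file name and keys are hypothetical): a data_per_streamline entry is recognised by having the same shape as the 'offsets' dataset, so the shapes must be compared, not the dataset object itself.

import h5py
import numpy as np

with h5py.File('example.h5', 'w') as f:
    grp = f.create_group('bundle_0')
    grp.create_dataset('offsets', data=np.arange(10))
    grp.create_dataset('my_dps', data=np.random.rand(10))

    data = grp['my_dps']
    # The old line compared a shape tuple to the Dataset object itself,
    # which does not express the intended shape equality.
    assert data.shape == grp['offsets'].shape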
67 changes: 67 additions & 0 deletions scilpy/tracking/seed.py
@@ -364,3 +364,70 @@ def get_next_n_pos(self, random_generator, shuffled_indices,
random_generator)

return seeds


class CustomSeedsDispenser(SeedGenerator):
"""
Adaptation of the scilpy.tracking.seed.SeedGenerator interface for
using already generated, custom seeds.
"""
def __init__(self, custom_seeds, space=Space('vox'),
origin=Origin('center')):
"""
Custom seeds need to be in the same space and origin as the ODFs used
for tracking.
Parameters
----------
custom_seeds: list
Custom seeding coordinates.
space: Space (optional)
The Dipy space in which the seeds were saved.
Default: Space.Vox or 'vox'
origin: Origin (optional)
The Dipy origin in which the seeds were saved.
Default: Origin.NIFTI or 'center'
"""
self.origin = origin
self.space = space
self.seeds = custom_seeds
self.i = 0

def init_generator(self, rng_seed, numbers_to_skip):
"""
Does not do anything. Simulates SeedGenerator's implementation for
retro-compatibility.
Parameters
----------
rng_seed : int
The "seed" for the random generator.
numbers_to_skip : int
Number of seeds (i.e. voxels) to skip. Useful if you want to
continue sampling from the same generator as in a first experiment
(with a fixed rng_seed).
Return
------
random_generator : numpy random generator
Initialized numpy number generator.
indices : ndarray
Empty list for interface retro-compatibility
"""
self.i = numbers_to_skip

return np.random.default_rng(rng_seed), []

def get_next_pos(self, random_generator: np.random.Generator,
shuffled_indices, which_seed):
seed = self.seeds[self.i]
self.i += 1

return seed[0], seed[1], seed[2]

def get_next_n_pos(self, random_generator, shuffled_indices,
which_seed_start, n):
seeds = self.seeds[self.i:self.i+n]
self.i += n

return seeds
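A possible usage sketch of the new dispenser (the seed coordinates are made up; as the docstring notes, they must be in the same space and origin as the ODFs used for tracking):

import numpy as np
from scilpy.tracking.seed import CustomSeedsDispenser

seeds = [np.array([10.5, 20.0, 30.0]),
         np.array([11.0, 21.5, 29.0]),
         np.array([12.0, 19.0, 31.0])]

dispenser = CustomSeedsDispenser(seeds)                  # defaults: vox space, center origin
rng, indices = dispenser.init_generator(rng_seed=0, numbers_to_skip=0)
x, y, z = dispenser.get_next_pos(rng, indices, which_seed=0)           # first seed
batch = dispenser.get_next_n_pos(rng, indices, which_seed_start=1, n=2)  # next two seeds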
6 changes: 6 additions & 0 deletions scilpy/tracking/utils.py
@@ -126,6 +126,12 @@ def add_seeding_options(p):
help='Number of seeds per voxel.')
seed_sub_exclusive.add_argument('--nt', type=int,
help='Total number of seeds to use.')
seed_sub_exclusive.add_argument(
'--in_custom_seeds', type=str,
help='Path to a file containing a list of custom seeding \n'
'coordinates (.txt, .mat or .npy). They should be in \n'
'voxel space. In the case of a text file, each line should \n'
'contain a single seed, written in the format: [x, y, z].')
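A sketch of preparing a seeds file accepted by the new --in_custom_seeds option (file names and coordinates are hypothetical; seeds are expected in voxel space, and the text format follows the help string above):

import numpy as np

seeds = np.array([[10.5, 20.0, 30.0],
                  [11.0, 21.5, 29.0]])

np.save('my_seeds.npy', seeds)                   # .npy variant
with open('my_seeds.txt', 'w') as f:             # .txt variant, one seed per line
    for x, y, z in seeds:
        f.write(f'[{x}, {y}, {z}]\n')            # format: [x, y, z]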


def add_out_options(p):
7 changes: 2 additions & 5 deletions scilpy/tractograms/streamline_operations.py
@@ -408,19 +408,16 @@ def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf,
valid_length_ids = np.logical_and(lengths >= min_length,
lengths <= max_length)
filtered_sft = sft[valid_length_ids]

if return_rejected:
rejected_sft = sft[~valid_length_ids]
else:
valid_length_ids = []
valid_length_ids = np.array([], dtype=bool)
filtered_sft = sft

# Return to original space
sft.to_space(orig_space)
filtered_sft.to_space(orig_space)

if return_rejected:
rejected_sft.to_space(orig_space)
rejected_sft = sft[~valid_length_ids]
return filtered_sft, valid_length_ids, rejected_sft
else:
return filtered_sft, valid_length_ids
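A minimal sketch of why the fallback ids above are now an empty boolean array rather than a plain list (toy arrays only, no scilpy objects): negation and boolean indexing stay well-defined in the degenerate case, which the new unit test below exercises with an empty tractogram.

import numpy as np

lengths = np.array([])                      # e.g. an empty tractogram
valid_length_ids = np.array([], dtype=bool)

rejected = lengths[~valid_length_ids]       # empty selection, no error
assert rejected.size == 0
# With a plain Python list, ~[] would raise a TypeError instead.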
12 changes: 12 additions & 0 deletions scilpy/tractograms/tests/test_streamline_operations.py
@@ -7,6 +7,7 @@
import pytest
from dipy.io.streamline import load_tractogram
from dipy.tracking.streamlinespeed import length
from dipy.io.stateful_tractogram import StatefulTractogram

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
@@ -174,6 +175,17 @@ def test_filter_streamlines_by_length():
# Test that streamlines shorter than 100 and longer than 120 were removed.
assert np.all(lengths >= min_length) and np.all(lengths <= max_length)

# === 4. Return rejected streamlines with empty sft ===
empty_sft = short_sft[[]] # Empty sft from short_sft (chosen arbitrarily)
filtered_sft, _, rejected = \
filter_streamlines_by_length(empty_sft, min_length=min_length,
max_length=max_length,
return_rejected=True)
assert isinstance(filtered_sft, StatefulTractogram)
assert isinstance(rejected, StatefulTractogram)
assert len(filtered_sft) == 0
assert len(rejected) == 0


def test_filter_streamlines_by_total_length_per_dim():
long_sft = load_tractogram(in_long_sft, in_ref)
77 changes: 0 additions & 77 deletions scilpy/utils/metrics_tools.py
@@ -3,8 +3,6 @@
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
@@ -302,78 +300,3 @@ def get_bundle_metrics_mean_std_per_point(streamlines, bundle_name,
label_stats['mean'] = float(label_mean)
label_stats['std'] = float(label_std)
return stats


def plot_metrics_stats(means, stds, title=None, xlabel=None,
ylabel=None, figlabel=None, fill_color=None,
display_means=False):
"""
Plots the mean of a metric along n points with the standard deviation.
Parameters
----------
means: Numpy 1D (or 2D) array of size n
Mean of the metric along n points.
stds: Numpy 1D (or 2D) array of size n
Standard deviation of the metric along n points.
title: string
Title of the figure.
xlabel: string
Label of the X axis.
ylabel: string
Label of the Y axis (suggestion: the metric name).
figlabel: string
Label of the figure (only metadata in the figure object returned).
fill_color: string
Hexadecimal RGB color filling the region between mean ± std. The
hexadecimal RGB color should be formatted as #RRGGBB
display_means: bool
Display the subjects means as semi-transparent line
Return
------
The figure object.
"""
matplotlib.style.use('ggplot')

fig, ax = plt.subplots()

# Set optional information to the figure, if required.
if title is not None:
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if figlabel is not None:
fig.set_label(figlabel)

if means.ndim > 1:
mean = np.average(means, axis=1)
std = np.average(stds, axis=1)
alpha = 0.5
else:
mean = np.array(means).ravel()
std = np.array(stds).ravel()
alpha = 0.9

dim = np.arange(1, len(mean)+1, 1)

if len(mean) <= 20:
ax.xaxis.set_ticks(dim)

ax.set_xlim(0, len(mean)+1)

if means.ndim > 1 and display_means:
for i in range(means.shape[-1]):
ax.plot(dim, means[:, i], color="k", linewidth=1,
solid_capstyle='round', alpha=0.1)

# Plot the mean line.
ax.plot(dim, mean, color="k", linewidth=5, solid_capstyle='round')

# Plot the std
plt.fill_between(dim, mean - std, mean + std,
facecolor=fill_color, alpha=alpha)

plt.close(fig)
return fig
8 changes: 4 additions & 4 deletions scilpy/viz/backends/fury.py
@@ -148,7 +148,7 @@ def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio):
Ratio between viewport's width and height.
"""

scene.projection('parallel')
scene.projection(proj_type='parallel')
camera = initialize_camera(
orientation, slice_index, volume_shape, aspect_ratio)
scene.set_camera(position=camera[CamParams.VIEW_POS],
Expand All @@ -162,7 +162,7 @@ def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio):


def create_scene(actors, orientation, slice_index, volume_shape, aspect_ratio,
bg_color=(0, 0, 0)):
*, bg_color=(0, 0, 0)):
"""
Create a 3D scene containing actors fitting inside a grid. The camera is
placed based on the orientation supplied by the user. The projection mode
@@ -201,7 +201,7 @@ def create_scene(actors, orientation, slice_index, volume_shape, aspect_ratio,
return scene


def create_interactive_window(scene, window_size, interactor,
def create_interactive_window(scene, window_size, interactor, *,
title="Viewer", open_window=True):
"""
Create a 3D window with the content of scene, equiped with an interactor.
@@ -226,7 +226,7 @@ def create_interactive_window(scene, window_size, interactor,
Object from Fury containing the 3D scene interactor.
"""

showm = window.ShowManager(scene, title=title,
showm = window.ShowManager(scene=scene, title=title,
size=window_size,
reset_camera=False,
interactor_style=interactor)
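A small illustration of the signature change above, using a simplified, hypothetical function: parameters after the bare * must be passed by keyword, which guards against silent positional mix-ups when optional arguments such as bg_color or title are added or reordered.

def create_window(scene, window_size, *, title="Viewer", open_window=True):
    """Simplified stand-in for the real function."""
    return scene, window_size, title, open_window

create_window("scene", (800, 600), title="Demo")   # OK
# create_window("scene", (800, 600), "Demo")       # TypeError: takes 2 positional arguments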