diff --git a/docs/source/documentation/fibertube_tracking.rst b/docs/source/documentation/fibertube_tracking.rst
new file mode 100644
index 000000000..4dbcc9921
--- /dev/null
+++ b/docs/source/documentation/fibertube_tracking.rst
@@ -0,0 +1,318 @@
Introduction to the Fibertube Tracking environment through an interactive demo.
================================================================================

In this demo, you will be introduced to the main scripts of this project
as you apply them to simple data. Our main objective is to better
understand and quantify the fundamental limitations of tractography
algorithms, and how they might evolve as we approach microscopy
resolution, where individual axons can be seen. To do so, we will
evaluate tractography's ability to reconstruct individual white matter
fiber strands at simulated extreme resolutions (mimicking "infinite"
resolution).

Terminology
-----------

Here is a list of terms and definitions used in this project.

General:

- Axon: Biophysical object. Portion of the nerve cell that carries the
  electrical impulse to other neurons. (On the order of 0.1 to 1 um)
- Streamline: Virtual object. Series of 3D coordinates approximating an
  underlying fiber structure.

Fibertube Tracking:

- Fibertube: Virtual representation of an axon. Tube obtained by
  combining a streamline with a diameter.
- Centerline: Virtual object. Streamline passing through the center of
  a tubular structure.
- Fibertube segment: Cylindrical segment of a fibertube, resulting from
  the discretization of its centerline.
- Fibertube Tractography: The computational tractography method that
  reconstructs fibertubes. Contrary to traditional white matter fiber
  tractography, fibertube tractography does not rely on a discretized
  grid of fODFs or peaks. It directly tracks and reconstructs
  fibertubes, i.e. streamlines that have an associated diameter.

.. image:: https://github.com/user-attachments/assets/0286ec53-5bca-4133-93dd-22f360dfcb45
   :alt: Fibertube visualized in 3D

Methodology
-----------

This project can be split into 3 major steps:

- Preparing ground-truth data: We will be using the ground truth of
  simulated phantoms of streamlines, along with a diameter (giving us
  fibertubes), and ensuring that they are void of any collision, i.e.
  fibertubes in the simulated phantom should not intersect one another.
  Such intersections are physically impossible and do not respect the
  geometry of axons.
- Tracking and experimentation: We will perform fibertube tracking on
  our newly formed set of fibertubes with a variety of parameter
  combinations.
- Evaluation metrics computation: By passing the resulting tractogram
  through different evaluation scripts (like Tractometer), we will
  acquire connectivity and fiber reconstruction scores for each of the
  parameter combinations.

Preparing the data
------------------

To download the data required for this demo, open a terminal, move to
any location you see fit and execute the following command:

::

    wget https://scil.usherbrooke.ca/scil_test_data/dvc-store/files/md5/82/248b4888a63b0aeffc8070cc206995 -O others.zip && unzip others.zip -d Data && mv others.zip Data/others.zip && chmod -R 755 Data && cp ./Data/others/fibercup_bundles.trk ./centerlines.trk && echo 0.001 >diameters.txt

This will fetch a tractogram to act as our set of centerlines, and then
generate diameters to form our fibertubes.

``centerlines.trk`` is a subset of the FiberCup phantom ground truth:

.. image:: https://github.com/user-attachments/assets/3be43cc9-60ec-4e97-95ef-a436c32bba83
   :alt: Fibercup subset visualized in 3D
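If you would like to sanity-check the downloaded data before moving on, the
following short Python sketch is one way to do it. It is not part of the demo
scripts; it only assumes the scilpy environment is active (``dipy`` and
``numpy`` are scilpy dependencies) and uses the ``centerlines.trk`` and
``diameters.txt`` files created above:

::

    import numpy as np
    from dipy.io.streamline import load_tractogram

    # Load the centerlines in their own space ('same' = use the file's header).
    sft = load_tractogram('centerlines.trk', 'same', bbox_valid_check=False)

    # diameters.txt holds one value (in mm) shared by every fibertube here,
    # but it could also contain one diameter per centerline.
    diameters = np.loadtxt('diameters.txt', ndmin=1)

    print('Number of centerlines:', len(sft.streamlines))
    print('Diameter value(s) in mm:', np.unique(diameters))
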
The first thing to do with our data is to resample ``centerlines.trk``
so that each centerline is formed of segments no longer than 0.2 mm.

Note: This is because the next script will rely on a KDTree to find
all neighboring fibertube segments of any given point. Because the
search radius is set at the length of the longest fibertube segment,
performance drops significantly if segments are not shortened to
~0.2 mm.

To resample a tractogram, we can use this script from scilpy. Don't
forget to activate your scilpy environment first.

::

    scil_tractogram_resample_nb_points.py centerlines.trk centerlines_resampled.trk --step_size 0.2 -f

Next, we want to filter out intersecting fibertubes (collisions), to
make the data anatomically plausible and to ensure that there exists a
resolution at which no unit of space contains partial volume.

.. image:: https://github.com/user-attachments/assets/d9b0519b-c1e3-4de0-8529-92aa92041ce2
   :alt: Fibertube intersection visualized in 3D

This is accomplished using ``scil_tractogram_filter_collisions.py``.

::

    scil_tractogram_filter_collisions.py centerlines_resampled.trk diameters.txt fibertubes.trk --save_colliding --out_metrics metrics.json -v -f

After 3-5 minutes, you should get something like:

::

    ...
    ├── centerlines_resampled_obstacle.trk
    ├── centerlines_resampled_invalid.trk
    ├── fibertubes.trk
    ├── metrics.json
    ...

As you may have guessed from the output name, this script automatically
combines the diameters with the centerlines as data_per_streamline in
the output tractogram. This is why we named it ``fibertubes.trk``.

If you wish to know how many fibertubes are left after filtering, you
can run the following command:

``scil_tractogram_print_info.py fibertubes.trk``

Visualising collisions
----------------------

By calling:

::

    scil_viz_tractogram_collisions.py centerlines_resampled_invalid.trk --in_tractogram_obstacle centerlines_resampled_obstacle.trk --ref_tractogram centerlines.trk

you are able to see exactly which streamlines have been filtered
("invalid", in red) as well as the streamlines they collided with
("obstacle", in green). In white and lower opacity is the original
tractogram passed as ``--ref_tractogram``.

.. image:: https://github.com/user-attachments/assets/7ab864f5-e4a3-421b-8431-ef4a5b3150c8
   :alt: Filtered intersections visualized in 3D

Fibertube metrics
-----------------

Before we get into tracking, here is an overview of the metrics that we
saved in ``metrics.json`` (values expressed in mm):

- ``fibertube_density``: Estimate of the ratio (volume of fibertubes) /
  (total volume), where the total volume is the combined volume of all
  voxels containing at least one fibertube.
- ``min_external_distance``: Smallest distance separating two
  fibertubes, outside their diameter.
- ``max_voxel_anisotropic``: Diagonal vector of the largest possible
  anisotropic voxel that would not intersect two fibertubes.
- ``max_voxel_isotropic``: Isotropic version of max_voxel_anisotropic,
  made by using its smallest component. Ex: max_voxel_anisotropic:
  (3, 5, 5) => max_voxel_isotropic: (3, 3, 3)
- ``max_voxel_rotated``: Largest possible isotropic voxel obtainable
  with a different coordinate system. It is only usable if the entire
  tractogram is rotated according to [rotation_matrix].
  Ex: max_voxel_anisotropic: (1, 0, 0) => max_voxel_rotated:
  (0.5774, 0.5774, 0.5774)

If the option is provided, the following matrix is saved in a separate
file:

- ``rotation_matrix``: 4D transformation matrix containing the rotation
  to be applied on the tractogram to align max_voxel_rotated with the
  coordinate system (see scil_tractogram_apply_transform.py).

.. image:: https://github.com/user-attachments/assets/43cebcbe-e3b1-4ca0-999e-e042db8aa937
   :alt: Metrics (without max_voxel_rotated) visualized in 3D

.. image:: https://github.com/user-attachments/assets/924ab3f9-33da-458f-a98b-b4e88b051ae8
   :alt: max_voxel_rotated visualized in 3D

Note: This information can be useful for analyzing the reconstruction
obtained through tracking, as well as for performing track density
imaging at extreme resolutions.
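To make the relationships between these values concrete, here is a small,
illustrative Python sketch (not a scilpy script) that reads ``metrics.json``
and re-derives ``max_voxel_isotropic`` and the edge length of
``max_voxel_rotated`` from ``max_voxel_anisotropic``, following the
(1, 0, 0) example given above. The key names are assumed to match the
descriptions above:

::

    import json
    import numpy as np

    with open('metrics.json') as f:
        metrics = json.load(f)

    aniso = np.asarray(metrics['max_voxel_anisotropic'], dtype=float)

    # Isotropic voxel: limited by the smallest anisotropic component.
    iso = np.full(3, aniso.min())

    # Rotated isotropic voxel (as in the (1, 0, 0) example above): its
    # space diagonal lies along the anisotropic diagonal vector, so
    # edge = |diagonal| / sqrt(3) -> 1 / sqrt(3) = 0.5774 for (1, 0, 0).
    rotated_edge = np.linalg.norm(aniso) / np.sqrt(3)

    print('max_voxel_isotropic:', iso)
    print('max_voxel_rotated edge:', rotated_edge)
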
Performing fibertube tracking
-----------------------------

We're finally at the tracking phase! Using the script
``scil_fibertube_tracking.py``, you are able to track without relying on
a discretized grid of directions or fODFs. Instead, you will be
propagating a streamline through fibertubes and controlling the
resolution by using a ``blur_radius``. The way it works is as follows:

Seeding
~~~~~~~

A number of seeds is set randomly within the first segment of every
fibertube. We can however change the number of fibertubes that will be
tracked, as well as the number of seeds within each (see Seeding options
in the help menu).

Tracking
~~~~~~~~

When the tracking algorithm is about to select a new direction to
propagate the current streamline, it will build a sphere of radius
``blur_radius`` and pick randomly from all the fibertube segments
intersecting with it. The larger the intersection volume, the more
likely a fibertube segment is to be picked and used as a tracking
direction.

.. image:: https://github.com/user-attachments/assets/0308c206-c396-41c5-a0e1-bb69b692c101
   :alt: Visualization of the blurring sphere intersecting with segments

For more information and better visualization, watch the following
presentation: https://docs.google.com/presentation/d/1nRV2j_A8bHOcjGSHtNmD8MsA9n5pHvR8/edit#slide=id.p19

This makes fibertube tracking inherently probabilistic. Theoretically,
with a ``blur_radius`` of 0, any given set of coordinates has either a
single tracking direction, because it is within a fibertube, or no
direction at all, because it is outside of every fibertube. In fact,
this behavior won't change until the diameter of the sphere is larger
than the smallest distance separating two fibertubes. When this happens,
more than one fibertube will intersect the ``blur_radius`` sphere and
introduce a partial volume effect.

The interface of the script is very similar to
``scil_tracking_local_dev.py``, but simplified and with a ``blur_radius``
option. Let us run:

::

    scil_fibertube_tracking.py fibertubes.trk tracking.trk --blur_radius 0.1 --step_size 0.1 --nb_fibertubes 3 --out_config tracking_config.json --processes 4 -v -f

This should take a minute or two and will produce 15 streamlines. The
loading bar of each thread only updates every 100 streamlines. It may
look like it's frozen, but rest assured, it's still going!
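Before scoring the result, it may help to see the direction-selection rule
described above in code form. The sketch below is purely illustrative (the
``pick_direction`` helper is hypothetical and not part of scilpy): it picks a
propagation direction among neighboring fibertube segments with probability
proportional to their intersection volume with the blurring sphere, which is
the behavior described in the Tracking subsection:

::

    import numpy as np

    def pick_direction(segment_dirs, intersection_volumes, rng):
        """Choose a segment direction with probability proportional to
        how much each segment intersects the blurring sphere."""
        volumes = np.asarray(intersection_volumes, dtype=float)
        if volumes.sum() == 0:
            return None  # No fibertube segment intersects the sphere.
        chosen = rng.choice(len(segment_dirs), p=volumes / volumes.sum())
        direction = np.asarray(segment_dirs[chosen], dtype=float)
        return direction / np.linalg.norm(direction)

    rng = np.random.default_rng(0)
    # Toy example: the second segment overlaps the sphere 3x more than
    # the first, so it is picked ~75% of the time.
    dirs = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
    print(pick_direction(dirs, [0.1, 0.3], rng))
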
Reconstruction analysis
~~~~~~~~~~~~~~~~~~~~~~~

By using the ``scil_fibertube_score_tractogram.py`` script, you are able
to obtain measures on the quality of the fibertube tracking that was
performed.

Each streamline is associated with an "Arrival fibertube segment", which
is the closest fibertube segment to its before-last coordinate. We then
define the following terms:

VC: "Valid Connection": A streamline whose arrival fibertube segment is
the final segment of the fibertube in which it was originally seeded.

IC: "Invalid Connection": A streamline whose arrival fibertube segment is
the start or final segment of a fibertube in which it was not seeded.

NC: "No Connection": A streamline whose arrival fibertube segment is
not the start or final segment of any fibertube.

.. image:: https://github.com/user-attachments/assets/ac36d847-2363-4b23-a69b-43c9d4d40b9a
   :alt: Visualization of VC, IC and NC

The "absolute error" of a coordinate is the distance in mm between that
coordinate and the closest point on its corresponding fibertube. The
average of all coordinate absolute errors of a streamline is called the
"Mean absolute error" or "mae".

Here is a visual representation of streamlines (green) tracked along a
fibertube (only the centerline is shown, in blue) with their coordinate
absolute error (red).

.. image:: https://github.com/user-attachments/assets/62324b66-f66b-43ae-a772-086560ef713a
   :alt: Visualization of the coordinate absolute error

Computed metrics:

- vc_ratio: Number of VC divided by the number of streamlines.
- ic_ratio: Number of IC divided by the number of streamlines.
- nc_ratio: Number of NC divided by the number of streamlines.
- mae_min: Minimum MAE for the tractogram.
- mae_max: Maximum MAE for the tractogram.
- mae_mean: Average MAE for the tractogram.
- mae_med: Median MAE for the tractogram.

To score the produced tractogram, we run:

::

    scil_fibertube_score_tractogram.py fibertubes.trk tracking.trk tracking_config.json reconstruction_metrics.json -f

giving us the following output in ``reconstruction_metrics.json``:

::

    {
      "vc_ratio": 0.3333333333333333,
      "ic_ratio": 0.4,
      "nc_ratio": 0.26666666666666666,
      "mae_min": 0.004093314514974615,
      "mae_max": 10.028780087103556,
      "mae_mean": 3.055598084631571,
      "mae_med": 0.9429987731800447
    }

This data tells us that 1/3 of streamlines had the end of their own
fibertube as their arrival fibertube segment
(``"vc_ratio": 0.3333333333333333``). For 40% of streamlines, their
arrival fibertube segment was the start or end of another fibertube
(``"ic_ratio": 0.4``). Roughly 27% of streamlines had an arrival
fibertube segment that was not a start or end segment
(``"nc_ratio": 0.26666666666666666``). Lastly, we notice that the
streamline with the "worst" trajectory was on average ~10.03 mm away
from its fibertube (``"mae_max": 10.028780087103556``).

This is not very good, but it is to be expected with a ``--blur_radius``
and ``--step_size`` of 0.1. If you have a few minutes, try again with
0.01!

End of Demo
-----------
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c9dafc2d4..f22a0bd09 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -17,4 +17,5 @@ Welcome to the scilpy documentation!
documentation/construct_participants_tsv_file documentation/create_overlapping_slice_mosaics documentation/devcontainer + documentation/fibertube_tracking documentation/tractogram_registration diff --git a/requirements.txt b/requirements.txt index 23cb23241..d946292e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,8 @@ matplotlib==3.6.* PyMCubes==0.1.* nibabel==5.2.* nilearn==0.9.* +numba==0.59.1 +numba-kdtree==0.4.0 nltk==3.8.* numpy==1.25.* openpyxl==3.0.* diff --git a/scilpy/connectivity/connectivity.py b/scilpy/connectivity/connectivity.py new file mode 100644 index 000000000..dd6b75901 --- /dev/null +++ b/scilpy/connectivity/connectivity.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +import logging +import os +import threading + +from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.io.utils import is_header_compatible, get_reference_info +from dipy.tracking.streamlinespeed import length +from dipy.tracking.vox2track import _streamlines_in_mask +import h5py +import nibabel as nib +import numpy as np + +from scilpy.image.labels import get_data_as_labels +from scilpy.io.hdf5 import reconstruct_streamlines_from_hdf5 +from scilpy.tractanalysis.reproducibility_measures import \ + compute_bundle_adjacency_voxel +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from scilpy.utils.metrics_tools import compute_lesion_stats + +d = threading.local() + + +def compute_triu_connectivity_from_labels(tractogram, data_labels, + keep_background=False, + hide_labels=None): + """ + Compute a connectivity matrix. + + Parameters + ---------- + tractogram: StatefulTractogram, or list[np.ndarray] + Streamlines. A StatefulTractogram input is recommanded. + When using directly with a list of streamlines, streamlinee must be in + vox space, corner origin. + data_labels: np.ndarray + The loaded nifti image. + keep_background: Bool + By default, the background (label 0) is not included in the matrix. + If True, label 0 is kept. + hide_labels: Optional[List[int]] + If not None, streamlines ending in a voxel with a given label are + ignored (i.e. matrix is set to 0 for that label). + + Returns + ------- + matrix: np.ndarray + With use_scilpy: shape (nb_labels + 1, nb_labels + 1) + Else, shape (nb_labels, nb_labels) + ordered_labels: List + The list of labels. Name of each row / column. + start_labels: List + For each streamline, the label at starting point. + end_labels: List + For each streamline, the label at ending point. + """ + if isinstance(tractogram, StatefulTractogram): + # Vox space, corner origin + # = we can get the nearest neighbor easily. + # Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1. + tractogram.to_vox() + tractogram.to_corner() + streamlines = tractogram.streamlines + else: + streamlines = tractogram + + ordered_labels = list(np.sort(np.unique(data_labels))) + assert ordered_labels[0] >= 0, "Only accepting positive labels." + nb_labels = len(ordered_labels) + logging.debug("Computing connectivity matrix for {} labels." 
+ .format(nb_labels)) + + matrix = np.zeros((nb_labels, nb_labels), dtype=int) + start_labels = [] + end_labels = [] + + for s in streamlines: + start = ordered_labels.index( + data_labels[tuple(np.floor(s[0, :]).astype(int))]) + end = ordered_labels.index( + data_labels[tuple(np.floor(s[-1, :]).astype(int))]) + + start_labels.append(start) + end_labels.append(end) + + matrix[start, end] += 1 + if start != end: + matrix[end, start] += 1 + + matrix = np.triu(matrix) + assert matrix.sum() == len(streamlines) + + # Rejecting background + if not keep_background and ordered_labels[0] == 0: + logging.debug("Rejecting background.") + ordered_labels = ordered_labels[1:] + matrix = matrix[1:, 1:] + + # Hiding labels + if hide_labels is not None: + for label in hide_labels: + if label not in ordered_labels: + logging.warning("Cannot hide label {} because it was not in " + "the data.".format(label)) + continue + idx = ordered_labels.index(label) + nb_hidden = np.sum(matrix[idx, :]) + np.sum(matrix[:, idx]) - \ + matrix[idx, idx] + if nb_hidden > 0: + logging.warning("{} streamlines had one or both endpoints " + "in hidden label {} (line/column {})" + .format(nb_hidden, label, idx)) + matrix[idx, :] = 0 + matrix[:, idx] = 0 + else: + logging.info("No streamlines with endpoints in hidden label " + "{} (line/column {}) :)".format(label, idx)) + ordered_labels[idx] = ("Hidden label ({})".format(label)) + + return matrix, ordered_labels, start_labels, end_labels + + +def load_node_nifti(directory, in_label, out_label, ref_img): + in_filename = os.path.join(directory, + '{}_{}.nii.gz'.format(in_label, out_label)) + + if os.path.isfile(in_filename): + if not is_header_compatible(in_filename, ref_img): + raise IOError('{} do not have a compatible header'.format( + in_filename)) + return nib.load(in_filename).get_fdata(dtype=np.float32) + + return None + + +def multi_proc_compute_connectivity_matrices_from_hdf5(args): + (hdf5_filename, labels_img, comb, + compute_volume, compute_streamline_count, compute_length, + similarity_directory, metrics_data, metrics_names, lesion_data, + include_dps, weighted, min_lesion_vol) = args + return compute_connectivity_matrices_from_hdf5( + hdf5_filename, labels_img, comb[0], comb[1], + compute_volume, compute_streamline_count, compute_length, + similarity_directory, metrics_data, metrics_names, lesion_data, + include_dps, weighted, min_lesion_vol) + + +def compute_connectivity_matrices_from_hdf5( + hdf5_filename, labels_img, in_label, out_label, + compute_volume=True, compute_streamline_count=True, + compute_length=True, similarity_directory=None, metrics_data=None, + metrics_names=None, lesion_data=None, include_dps=False, + weighted=False, min_lesion_vol=0): + """ + Parameters + ---------- + hdf5_filename: str + Name of the hdf5 file containing the precomputed connections (bundles) + labels_img: np.ndarray + Data as labels + in_label: str + Name of one extremity of the bundle to analyse. + out_label: str + Name of the other extremity. Current node is {in_label}_{out_label}. + compute_volume: bool + If true, return 'volume_mm3' in the returned dictionary with the volume + of the bundle. + compute_streamline_count: bool + If true, return 'streamline_count' in the returned dictionary, with + the number of streamlines in the bundle. + compute_length: bool + If true, return 'length_mm' in the returned dictionary, with the mean + length of streamlines in the bundle. 
+ similarity_directory: str + If not None, should be a directory containing nifti files that + represent density maps for each connection, using the + _.nii.gz conventions. + Typically computed from a template (must be in the same space). + metrics_data: list[np.ndarray] + List of 3D data with metrics to use, with the list of associated metric + names. If set, the returned dictionary will contain an entry for each + name, with the mean value of each metric. + metrics_names: list[str] + The metrics names. + lesion_data: Tuple[list, np.ndarray] + The (lesion_labels, lesion_data) for lesion load analysis. If set, the + returned dictionary will contain the three entries 'lesion_volume': + the total lesion volume, 'lesion_streamline_count': the number of + streamlines passing through lesions, 'lesion_count': the number of + lesions. + include_dps: bool + If true, return an entry for each dps with the mean dps value. + weighted: bool + If true, weight the results with the density map. + min_lesion_vol: float + Minimum lesion volume for a lesion to be considered. + + Returns + ------- + final_dict: Tuple[dict, list[str]] or None + dict: {(in_label, out_label): measures_dict} + A dictionary with the node as key and as the dictionary as + described above. + dps_keys: The list of keys included from dps. + """ + if len(metrics_data) > 0: + assert len(metrics_data) == len(metrics_names) + + affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img) + + measures_to_return = {} + + # Getting the bundle from the hdf5 + with h5py.File(hdf5_filename, 'r') as hdf5_file: + key = '{}_{}'.format(in_label, out_label) + if key not in hdf5_file: + logging.debug("Connection {} not found in the hdf5".format(key)) + return None + streamlines = reconstruct_streamlines_from_hdf5(hdf5_file[key]) + if len(streamlines) == 0: + logging.debug("Connection {} contained no streamline".format(key)) + return None + logging.debug("Found {} streamlines for connection {}" + .format(len(streamlines), key)) + + # Getting dps info from the hdf5 + dps_keys = [] + if include_dps: + for dps_key in hdf5_file[key].keys(): + if dps_key not in ['data', 'offsets', 'lengths']: + if 'commit' in dps_key: + dps_values = np.sum(hdf5_file[key][dps_key]) + else: + dps_values = np.average(hdf5_file[key][dps_key]) + measures_to_return[dps_key] = dps_values + dps_keys.append(dps_key) + + # If density is not required, do not compute it + # Only required for volume, similarity and any metrics + if (compute_volume or similarity_directory is not None or + len(metrics_data) > 0): + density = compute_tract_counts_map(streamlines, dimensions) + + if compute_length: + # scil_tractogram_segment_connections_from_labels.py requires + # isotropic voxels + mean_length = np.average(length(list(streamlines))) * voxel_sizes[0] + measures_to_return['length_mm'] = mean_length + + if compute_volume: + measures_to_return['volume_mm3'] = np.count_nonzero(density) * \ + np.prod(voxel_sizes) + + if compute_streamline_count: + measures_to_return['streamline_count'] = len(streamlines) + + if similarity_directory is not None: + density_sim = load_node_nifti(similarity_directory, + in_label, out_label, labels_img) + if density_sim is None: + ba_vox = 0 + else: + ba_vox = compute_bundle_adjacency_voxel(density, density_sim) + + measures_to_return['similarity'] = ba_vox + + for metric_data, metric_name in zip(metrics_data, metrics_names): + if weighted: + avg_value = np.average(metric_data, weights=density) + else: + avg_value = np.average(metric_data[density > 0]) + 
measures_to_return[metric_name] = avg_value + + if lesion_data is not None: + lesion_labels, lesion_img = lesion_data + voxel_sizes = lesion_img.header.get_zooms()[0:3] + lesion_img.set_filename('tmp.nii.gz') + lesion_atlas = get_data_as_labels(lesion_img) + tmp_dict = compute_lesion_stats( + density.astype(bool), lesion_atlas, + voxel_sizes=voxel_sizes, single_label=True, + min_lesion_vol=min_lesion_vol, + precomputed_lesion_labels=lesion_labels) + + tmp_ind = _streamlines_in_mask(list(streamlines), + lesion_atlas.astype(np.uint8), + np.eye(3), [0, 0, 0]) + streamlines_count = len( + np.where(tmp_ind == [0, 1][True])[0].tolist()) + + if tmp_dict: + measures_to_return['lesion_vol'] = tmp_dict['lesion_total_volume'] + measures_to_return['lesion_count'] = tmp_dict['lesion_count'] + measures_to_return['lesion_streamline_count'] = streamlines_count + else: + measures_to_return['lesion_vol'] = 0 + measures_to_return['lesion_count'] = 0 + measures_to_return['lesion_streamline_count'] = 0 + + return {(in_label, out_label): measures_to_return}, dps_keys diff --git a/scilpy/connectivity/connectivity_tools.py b/scilpy/connectivity/matrix_tools.py similarity index 100% rename from scilpy/connectivity/connectivity_tools.py rename to scilpy/connectivity/matrix_tools.py diff --git a/scilpy/image/tests/test_volume_operations.py b/scilpy/image/tests/test_volume_operations.py index f48176e1f..b4949c64d 100644 --- a/scilpy/image/tests/test_volume_operations.py +++ b/scilpy/image/tests/test_volume_operations.py @@ -204,6 +204,16 @@ def test_resample_volume(): assert_equal(resampled_img.get_fdata(), ref3d) assert resampled_img.affine[0, 0] == 3 + # 4) Same test, with a fake 4th dimension + moving3d = np.stack((moving3d, moving3d), axis=-1) + moving3d_img = nib.Nifti1Image(moving3d, np.eye(4)) + resampled_img = resample_volume(moving3d_img, voxel_res=(3, 3, 3), + interp='nn') + result = resampled_img.get_fdata() + assert_equal(result[:, :, :, 0], ref3d) + assert_equal(result[:, :, :, 1], ref3d) + assert resampled_img.affine[0, 0] == 3 + def test_reshape_volume_pad(): # 3D img diff --git a/scilpy/image/volume_space_management.py b/scilpy/image/volume_space_management.py index b952090f1..a21e9998a 100644 --- a/scilpy/image/volume_space_management.py +++ b/scilpy/image/volume_space_management.py @@ -1,6 +1,13 @@ # -*- coding: utf-8 -*- import numpy as np +from numba_kdtree import KDTree +from numba import njit +from scilpy.tracking.fibertube_utils import (streamlines_to_segments, + point_in_cylinder, + sphere_cylinder_intersection) +from scilpy.tractograms.streamline_operations import \ + get_streamlines_as_fixed_array from dipy.core.interpolation import trilinear_interpolate4d, \ nearestneighbor_interpolate from dipy.io.stateful_tractogram import Origin, Space @@ -357,3 +364,198 @@ def _is_voxmm_in_bound(self, x, y, z, origin): True if position is in dataset range and false otherwise. """ return self.is_idx_in_bound(*self.voxmm_to_idx(x, y, z, origin)) + + +class FibertubeDataVolume(DataVolume): + """ + Adaptation of the scilpy.image.volume_space_management.AbstractDataVolume + interface for fibertube tracking. Instead of a spherical function, + provides direction and intersection volume of close-by fiber segments. + + Data given at initialization must have "center" origin. Additionally, + FibertubeDataVolume enforces this origin at every function call. This is + because the origin must stay coherent with the data given at + initialization and cannot change afterwards. 
+ """ + + VALID_ORIGIN = Origin.NIFTI + + def __init__(self, centerlines, diameters, reference, blur_radius, + random_generator): + """ + Parameters + ---------- + centerlines: list + Tractogram containing the fibertube centerlines + diameters: list + Diameters of each fibertube + reference: StatefulTractogram + Spatial reference used to obtain the dimensions and pixel + resolution of the data. Should be a stateful tractogram. + blur_radius: float + Radius of the blurring sphere to be used for degrading resolution. + random_generator: numpy random generator + """ + # Prepare data + if centerlines is None: + self.data = [] + self.tree = None + self.segments_indices = None + self.max_seg_length = None + return + + segments_centers, segments_indices, max_seg_length = ( + streamlines_to_segments(centerlines, False)) + self.tree = KDTree(segments_centers) + self.segments_indices = segments_indices + self.max_seg_length = max_seg_length + self.dim = reference.dimensions[:3] + self.data, _ = get_streamlines_as_fixed_array(centerlines) + self.diameters = diameters + self.max_diameter = max(diameters) + + # Rest of init + self.voxres = reference.voxel_sizes + self.blur_radius = blur_radius + self.random_generator = random_generator + + @staticmethod + def _validate_origin(origin): + if FibertubeDataVolume.VALID_ORIGIN is not origin: + raise ValueError("FibertubeDataVolume only supports origin: " + + FibertubeDataVolume.VALID_ORIGIN.value + ". " + "Given origin is: " + origin.value + ".") + + def get_value_at_idx(self, i, j, k): + i, j, k = self._clip_idx_to_bound(i, j, k) + return self._voxmm_to_value(i, j, k) + + def get_value_at_coordinate(self, x, y, z, space, origin): + FibertubeDataVolume._validate_origin(origin) + + if space == Space.VOX: + return self._voxmm_to_value(*self.vox_to_voxmm(x, y, z), origin) + elif space == Space.VOXMM: + return self._voxmm_to_value(x, y, z, origin) + else: + raise NotImplementedError("We have not prepared the DataVolume " + "to work in RASMM space yet.") + + def is_idx_in_bound(self, i, j, k): + return super().is_idx_in_bound(i, j, k) + + def is_coordinate_in_bound(self, x, y, z, space, origin): + FibertubeDataVolume._validate_origin(origin) + return super().is_coordinate_in_bound(x, y, z, space, origin) + + @staticmethod + def vox_to_idx(x, y, z, origin): + FibertubeDataVolume._validate_origin(origin) + return super(FibertubeDataVolume, + FibertubeDataVolume).vox_to_idx(x, y, z, origin) + + def voxmm_to_idx(self, x, y, z, origin): + FibertubeDataVolume._validate_origin(origin) + return super().voxmm_to_idx(x, y, z, origin) + + def vox_to_voxmm(self, x, y, z): + """ + Get mm space coordinates at position x, y, z (vox). + + Parameters + ---------- + x, y, z: floats + Position coordinate (vox) along x, y, z axis. + + Return + ------ + x, y, z: floats + mm space coordinates for position x, y, z. + """ + + # Does not depend on origin! + # In each dimension: + # In corner: 0 to 1 will become 0 to voxres. + # In center: -0.5 to 0.5 will become -0.5*voxres to 0.5*voxres. 
+ return [x * self.voxres[0], + y * self.voxres[1], + z * self.voxres[2]] + + def _clip_voxmm_to_bound(self, x, y, z, origin): + return self.vox_to_voxmm(*self._clip_vox_to_bound( + *self.voxmm_to_vox(x, y, z), origin)) + + def _vox_to_value(self, x, y, z, origin): + return self._voxmm_to_value(*self.vox_to_voxmm(x, y, z), origin) + + def _voxmm_to_value(self, x, y, z, origin): + x, y, z = self._clip_voxmm_to_bound(x, y, z, origin) + + pos = np.array([x, y, z], dtype=np.float64) + + neighbors = self.tree.query_radius( + pos, + self.blur_radius + self.max_seg_length / 2 + self.max_diameter)[0] + + return self.extract_directions(pos, neighbors, self.blur_radius, + self.segments_indices, self.data, + self.diameters, self.random_generator) + + def get_absolute_direction(self, x, y, z): + pos = np.array([x, y, z], np.float64) + + neighbors = self.tree.query_radius( + pos, + self.blur_radius + self.max_seg_length / 2 + self.max_diameter)[0] + + for segi in neighbors: + fi, pi = self.segments_indices[segi] + fiber = self.data[fi] + radius = self.diameters[fi] / 2 + + if point_in_cylinder(fiber[pi], fiber[pi+1], radius, pos): + return fiber[pi+1] - fiber[pi] + + return None + + @staticmethod + @njit + def extract_directions(pos, neighbors, blur_radius, segments_indices, + centerlines, diameters, random_generator, + volume_nb_samples=1000, + volume_nb_samples_backup=10000): + directions = [] + volumes = [] + + for segi in neighbors: + fi, pi = segments_indices[segi] + fiber = centerlines[fi] + fib_pt1 = fiber[pi] + fib_pt2 = fiber[pi+1] + dir = fib_pt2 - fib_pt1 + radius = diameters[fi] / 2 + + volume, is_estimated = sphere_cylinder_intersection( + pos, blur_radius, fib_pt1, + fib_pt2, radius, + volume_nb_samples, + random_generator) + + # Catch estimation error when using very small blur_radius. + if volume == 0 and is_estimated: + volume, _ = sphere_cylinder_intersection( + pos, blur_radius, fib_pt1, + fib_pt2, radius, + volume_nb_samples_backup, + random_generator) + + if volume > 0: + directions.append(dir / np.linalg.norm(dir)) + volumes.append(volume) + + if len(volumes) > 0: + max_vol = max(volumes) + for vol in volumes: + vol /= max_vol + + return (directions, volumes) diff --git a/scilpy/io/fetcher.py b/scilpy/io/fetcher.py index 127c4004e..6876d6e6e 100644 --- a/scilpy/io/fetcher.py +++ b/scilpy/io/fetcher.py @@ -59,7 +59,8 @@ def get_testing_files_dict(): "fodf_filtering.zip": "5985c0644321ecf81fd694fb91e2c898", "processing.zip": "eece5cdbf437b8e4b5cb89c797872e28", "surface_vtk_fib.zip": "241f3afd6344c967d7176b43e4a99a41", - "tractograms.zip": "1eb29085db974b5e58d32b13eb76fbe6" + "tractograms.zip": "1eb29085db974b5e58d32b13eb76fbe6", + "mrds.zip": "5abe6092400e11e9bb2423e2c387e774" } diff --git a/scilpy/io/streamlines.py b/scilpy/io/streamlines.py index cf26c85cc..0fc196f6a 100644 --- a/scilpy/io/streamlines.py +++ b/scilpy/io/streamlines.py @@ -75,8 +75,8 @@ def load_tractogram_with_reference(parser, args, filepath, arg_name=None): filepath: str Path of the tractogram file. arg_name: str, optional - Name of the reference argument. By default the args.ref is used. If - arg_name is given, then args.arg_name_ref will be used instead. + Name of the reference argument. By default the args.reference is used. + If arg_name is given, then args.arg_name_ref will be used instead. 
""" if is_argument_set(args, 'bbox_check'): bbox_check = args.bbox_check diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 6db0e19cd..9fd6ed351 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -22,6 +22,8 @@ from scilpy.utils.spatial import RAS_AXES_NAMES +FLOATING_POINTS_PRECISION = 12 + eddy_options = ["mb", "mb_offs", "slspec", "mporder", "s2v_lambda", "field", "field_mat", "flm", "slm", "fwhm", "niter", "s2v_niter", "cnr_maps", "residuals", "fep", "interp", "s2v_interp", @@ -276,6 +278,14 @@ def add_skip_b0_check_arg(parser, will_overwrite_with_min, '--skip_b0_check', action='store_true', help=msg) +def add_precision_arg(parser): + parser.add_argument('--precision', type=ranged_type(int, 1), + default=FLOATING_POINTS_PRECISION, + help='Precision for floating point values. Numbers ' + 'are rounded up to \nthe number of decimals ' + 'provided. [Default: %(default)s]') + + def add_verbose_arg(parser): parser.add_argument('-v', default="WARNING", const='INFO', nargs='?', choices=['DEBUG', 'INFO', 'WARNING'], dest='verbose', @@ -729,6 +739,37 @@ def check(path): check(optional_file) +def assert_inputs_dirs_exist(parser, required, optional=None): + """ + Assert that all inputs directories exist. If not, print parser's usage and + exit. + + Parameters + ---------- + parser: argparse.ArgumentParser object + Parser. + required: string or list of paths + Required paths to be checked. + optional: string or list of paths + Optional paths to be checked. + """ + def check(path): + if not os.path.isdir(path): + parser.error('Input directory {} does not exist'.format(path)) + + if isinstance(required, str): + required = [required] + + if isinstance(optional, str): + optional = [optional] + + for required_file in required: + check(required_file) + for optional_file in optional or []: + if optional_file is not None: + check(optional_file) + + def assert_outputs_exist(parser, args, required, optional=None, check_dir_exists=True): """ diff --git a/scilpy/reconst/divide.py b/scilpy/reconst/divide.py index 7ba882c91..6253a3215 100644 --- a/scilpy/reconst/divide.py +++ b/scilpy/reconst/divide.py @@ -100,7 +100,7 @@ def _gamma_data2fit(signal, gtab_infos, fit_iters=1, random_iters=50, Returns ------- best_params : np.array - Array containing the parameters of the fit. + Array containing the parameters of the fit. Shape: (4,) """ if np.sum(gtab_infos[3]) > 0 and do_multiple_s0 is True: ns = len(np.unique(gtab_infos[3])) - 1 @@ -263,25 +263,33 @@ def gamma_fit2metrics(params): def _fit_gamma_parallel(args): - data = args[0] - gtab_infos = args[1] - fit_iters = args[2] - random_iters = args[3] - do_weight_bvals = args[4] - do_weight_pa = args[5] - do_multiple_s0 = args[6] - chunk_id = args[7] - - sub_fit_array = np.zeros((data.shape[0], 4)) - for i in range(data.shape[0]): - if data[i].any(): - sub_fit_array[i] = _gamma_data2fit(data[i], gtab_infos, fit_iters, - random_iters, do_weight_bvals, - do_weight_pa, do_multiple_s0) + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. + (data, gtab_infos, fit_iters, random_iters, + do_weight_bvals, do_weight_pa, do_multiple_s0, chunk_id) = args + + sub_fit_array = _fit_gamma_loop(data, gtab_infos, fit_iters, + random_iters, do_weight_bvals, + do_weight_pa, do_multiple_s0) return chunk_id, sub_fit_array +def _fit_gamma_loop(data, gtab_infos, fit_iters, random_iters, + do_weight_bvals, do_weight_pa, do_multiple_s0): + """ + Loops on 2D data and fits each voxel separately + See _gamma_data2fit for a complete description. 
+ """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. + tmp_fit_array = np.zeros((data.shape[0], 4)) + for i in range(data.shape[0]): + if data[i].any(): + tmp_fit_array[i] = _gamma_data2fit( + data[i], gtab_infos, fit_iters, random_iters, + do_weight_bvals, do_weight_pa, do_multiple_s0) + return tmp_fit_array + + def fit_gamma(data, gtab_infos, mask=None, fit_iters=1, random_iters=50, do_weight_bvals=False, do_weight_pa=False, do_multiple_s0=False, nbr_processes=None): @@ -328,30 +336,40 @@ def fit_gamma(data, gtab_infos, mask=None, fit_iters=1, random_iters=50, or nbr_processes <= 0 else nbr_processes # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. Then separate it in chunks of len(nbr_processes). + # 1D time series voxels. data = data[mask].reshape((np.count_nonzero(mask), data_shape[3])) - chunks = np.array_split(data, nbr_processes) - - chunk_len = np.cumsum([0] + [len(c) for c in chunks]) - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_fit_gamma_parallel, - zip(chunks, - itertools.repeat(gtab_infos), - itertools.repeat(fit_iters), - itertools.repeat(random_iters), - itertools.repeat(do_weight_bvals), - itertools.repeat(do_weight_pa), - itertools.repeat(do_multiple_s0), - np.arange(len(chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. - fit_array = np.zeros((data_shape[0:3])+(4,)) - tmp_fit_array = np.zeros((np.count_nonzero(mask), 4)) - for i, fit in results: - tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + tmp_fit_array = _fit_gamma_loop(data, gtab_infos, fit_iters, + random_iters, do_weight_bvals, + do_weight_pa, do_multiple_s0) + else: + # Separate the data in chunks of len(nbr_processes). + chunks = np.array_split(data, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_fit_gamma_parallel, + zip(chunks, + itertools.repeat(gtab_infos), + itertools.repeat(fit_iters), + itertools.repeat(random_iters), + itertools.repeat(do_weight_bvals), + itertools.repeat(do_weight_pa), + itertools.repeat(do_multiple_s0), + np.arange(len(chunks)))) + pool.close() + pool.join() + + # Re-assemble the chunks together. + chunk_len = np.cumsum([0] + [len(c) for c in chunks]) + tmp_fit_array = np.zeros((np.count_nonzero(mask), 4)) + for chunk_id, fit in results: + tmp_fit_array[chunk_len[chunk_id]:chunk_len[chunk_id + 1]] = fit + + # Bring back to the original shape + fit_array = np.zeros((data_shape[0:3]) + (4,)) fit_array[mask] = tmp_fit_array return fit_array diff --git a/scilpy/reconst/fodf.py b/scilpy/reconst/fodf.py index af73a581d..652f35132 100644 --- a/scilpy/reconst/fodf.py +++ b/scilpy/reconst/fodf.py @@ -126,20 +126,27 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, def _fit_from_model_parallel(args): - model = args[0] - data = args[1] - chunk_id = args[2] + (model, data, chunk_id) = args + sub_fit_array = _fit_from_model_loop(data, model) - sub_fit_array = np.zeros((data.shape[0],), dtype='object') + return chunk_id, sub_fit_array + + +def _fit_from_model_loop(data, model): + """ + Loops on 2D data and fits each voxel separately. + See fit_from_model for more information. + """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. 
+ tmp_fit_array = np.zeros((data.shape[0],), dtype='object') for i in range(data.shape[0]): if data[i].any(): try: - sub_fit_array[i] = model.fit(data[i]) + tmp_fit_array[i] = model.fit(data[i]) except cvx.error.SolverError: coeff = np.full((len(model.n)), np.NaN) - sub_fit_array[i] = MSDeconvFit(model, coeff, None) - - return chunk_id, sub_fit_array + tmp_fit_array[i] = MSDeconvFit(model, coeff, None) + return tmp_fit_array def fit_from_model(model, data, mask=None, nbr_processes=None): @@ -181,23 +188,31 @@ def fit_from_model(model, data, mask=None, nbr_processes=None): # Ravel the first 3 dimensions while keeping the 4th intact, like a list of # 1D time series voxels. Then separate it in chunks of len(nbr_processes). data = data[mask].reshape((np.count_nonzero(mask), data_shape[3])) - chunks = np.array_split(data, nbr_processes) - - chunk_len = np.cumsum([0] + [len(c) for c in chunks]) - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_fit_from_model_parallel, - zip(itertools.repeat(model), - chunks, - np.arange(len(chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. - fit_array = np.zeros(data_shape[0:3], dtype='object') - tmp_fit_array = np.zeros((np.count_nonzero(mask)), dtype='object') - for i, fit in results: - tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + tmp_fit_array = _fit_from_model_loop(data, model) + else: + # Separate the data in chunks of len(nbr_processes). + chunks = np.array_split(data, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_fit_from_model_parallel, + zip(itertools.repeat(model), + chunks, + np.arange(len(chunks)))) + pool.close() + pool.join() + + # Re-assemble the chunks together. + chunk_len = np.cumsum([0] + [len(c) for c in chunks]) + tmp_fit_array = np.zeros((np.count_nonzero(mask)), dtype='object') + for i, fit in results: + tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit + + # Bring back to the original shape + fit_array = np.zeros(data_shape[0:3], dtype='object') fit_array[mask] = tmp_fit_array fit_array = MultiVoxelFit(model, fit_array, mask) diff --git a/scilpy/reconst/sh.py b/scilpy/reconst/sh.py index f6ff22bcc..45f8dea67 100644 --- a/scilpy/reconst/sh.py +++ b/scilpy/reconst/sh.py @@ -181,17 +181,25 @@ def compute_rish(sh, mask=None, full_basis=False): def _peaks_from_sh_parallel(args): - shm_coeff = args[0] - B = args[1] - sphere = args[2] - relative_peak_threshold = args[3] - absolute_threshold = args[4] - min_separation_angle = args[5] - npeaks = args[6] - normalize_peaks = args[7] - chunk_id = args[8] - is_symmetric = args[9] + (shm_coeff, B, sphere, relative_peak_threshold, + absolute_threshold, min_separation_angle, + npeaks, normalize_peaks, chunk_id, is_symmetric) = args + + peak_dirs, peak_values, peak_indices = _peaks_from_sh_loop( + shm_coeff, B, sphere, relative_peak_threshold, + absolute_threshold, min_separation_angle, npeaks, + normalize_peaks, is_symmetric) + return chunk_id, peak_dirs, peak_values, peak_indices + +def _peaks_from_sh_loop(shm_coeff, B, sphere, relative_peak_threshold, + absolute_threshold, min_separation_angle, npeaks, + normalize_peaks, is_symmetric): + """ + Loops on 2D (ravelled) data and fits each voxel separately. + See peaks_from_sh for a complete description of parameters. + """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. 
data_shape = shm_coeff.shape[0] peak_dirs = np.zeros((data_shape, npeaks, 3)) peak_values = np.zeros((data_shape, npeaks)) @@ -218,8 +226,7 @@ def _peaks_from_sh_parallel(args): if normalize_peaks: peak_values[idx][:n] /= peaks[0] peak_dirs[idx] *= peak_values[idx][:, None] - - return chunk_id, peak_dirs, peak_values, peak_indices + return peak_dirs, peak_values, peak_indices def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, @@ -227,7 +234,7 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, normalize_peaks=False, npeaks=5, sh_basis_type='descoteaux07', is_legacy=True, nbr_processes=None, full_basis=False, is_symmetric=True): - """Computes peaks from given spherical harmonic coefficients + """Computes peaks from given spherical harmonic coefficients. Parameters ---------- @@ -281,53 +288,65 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, peak_dirs, peak_values, peak_indices """ sh_order = order_from_ncoef(shm_coeff.shape[-1], full_basis) - B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type, full_basis, - legacy=is_legacy) + B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type, + full_basis, legacy=is_legacy) data_shape = shm_coeff.shape if mask is None: mask = np.sum(shm_coeff, axis=3).astype(bool) nbr_processes = multiprocessing.cpu_count() if nbr_processes is None \ - or nbr_processes < 0 else nbr_processes + or nbr_processes <= 0 else nbr_processes # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. Then separate it in chunks of len(nbr_processes). + # 1D time series voxels. shm_coeff = shm_coeff[mask].reshape( (np.count_nonzero(mask), data_shape[3])) - chunks = np.array_split(shm_coeff, nbr_processes) - chunk_len = np.cumsum([0] + [len(c) for c in chunks]) - - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_peaks_from_sh_parallel, - zip(chunks, - itertools.repeat(B), - itertools.repeat(sphere), - itertools.repeat(relative_peak_threshold), - itertools.repeat(absolute_threshold), - itertools.repeat(min_separation_angle), - itertools.repeat(npeaks), - itertools.repeat(normalize_peaks), - np.arange(len(chunks)), - itertools.repeat(is_symmetric))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. + + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + (tmp_peak_dirs_array, tmp_peak_values_array, + tmp_peak_indices_array) = _peaks_from_sh_loop( + shm_coeff, B, sphere, relative_peak_threshold, + absolute_threshold, min_separation_angle, npeaks, + normalize_peaks, is_symmetric) + else: + # Separate the data in chunks of len(nbr_processes). + chunks = np.array_split(shm_coeff, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_peaks_from_sh_parallel, + zip(chunks, + itertools.repeat(B), + itertools.repeat(sphere), + itertools.repeat(relative_peak_threshold), + itertools.repeat(absolute_threshold), + itertools.repeat(min_separation_angle), + itertools.repeat(npeaks), + itertools.repeat(normalize_peaks), + np.arange(len(chunks)), + itertools.repeat(is_symmetric))) + pool.close() + pool.join() + + # Re-assemble the chunk together.. 
+ chunk_len = np.cumsum([0] + [len(c) for c in chunks]) + + # tmp arrays are necessary to avoid inserting data in returned variable + # rather than the original array + tmp_peak_dirs_array = np.zeros((np.count_nonzero(mask), npeaks, 3)) + tmp_peak_values_array = np.zeros((np.count_nonzero(mask), npeaks)) + tmp_peak_indices_array = np.zeros((np.count_nonzero(mask), npeaks)) + for i, peak_dirs, peak_values, peak_indices in results: + tmp_peak_dirs_array[chunk_len[i]:chunk_len[i+1], :, :] = peak_dirs + tmp_peak_values_array[chunk_len[i]:chunk_len[i+1], :] = peak_values + tmp_peak_indices_array[chunk_len[i]:chunk_len[i+1], :] = peak_indices + + # Bring back to the original shape peak_dirs_array = np.zeros(data_shape[0:3] + (npeaks, 3)) peak_values_array = np.zeros(data_shape[0:3] + (npeaks,)) peak_indices_array = np.zeros(data_shape[0:3] + (npeaks,)) - - # tmp arrays are neccesary to avoid inserting data in returned variable - # rather than the original array - tmp_peak_dirs_array = np.zeros((np.count_nonzero(mask), npeaks, 3)) - tmp_peak_values_array = np.zeros((np.count_nonzero(mask), npeaks)) - tmp_peak_indices_array = np.zeros((np.count_nonzero(mask), npeaks)) - for i, peak_dirs, peak_values, peak_indices in results: - tmp_peak_dirs_array[chunk_len[i]:chunk_len[i+1], :, :] = peak_dirs - tmp_peak_values_array[chunk_len[i]:chunk_len[i+1], :] = peak_values - tmp_peak_indices_array[chunk_len[i]:chunk_len[i+1], :] = peak_indices - peak_dirs_array[mask] = tmp_peak_dirs_array peak_values_array[mask] = tmp_peak_values_array peak_indices_array[mask] = tmp_peak_indices_array @@ -336,15 +355,21 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, def _maps_from_sh_parallel(args): - shm_coeff = args[0] - _ = args[1] - peak_values = args[2] - peak_indices = args[3] - B = args[4] - sphere = args[5] - gfa_thr = args[6] - chunk_id = args[7] + (shm_coeff, peak_values, peak_indices, B, sphere, + gfa_thr, chunk_id) = args + res = _maps_from_sh_loop(shm_coeff, peak_values, peak_indices, B, + sphere, gfa_thr) + return chunk_id, *res + + +def _maps_from_sh_loop(shm_coeff, peak_values, peak_indices, B, sphere, + gfa_thr): + """ + Loops on 2D (ravelled) data and fits each voxel separately. + For a more complete description of parameters, see maps_from_sh. + """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. 
data_shape = shm_coeff.shape[0] nufo_map = np.zeros(data_shape) afd_max = np.zeros(data_shape) @@ -375,11 +400,11 @@ def _maps_from_sh_parallel(args): qa_map = peak_values[idx] - odf.min() global_max = max(global_max, peak_values[idx][0]) - return chunk_id, nufo_map, afd_max, afd_sum, rgb_map, \ - gfa_map, qa_map, max_odf, global_max + return (nufo_map, afd_max, afd_sum, rgb_map, + gfa_map, qa_map, max_odf, global_max) -def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere, +def maps_from_sh(shm_coeff, peak_values, peak_indices, sphere, mask=None, gfa_thr=0, sh_basis_type='descoteaux07', nbr_processes=None): """Computes maps from given SH coefficients and peaks @@ -388,8 +413,6 @@ def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere, ---------- shm_coeff : np.ndarray Spherical harmonic coefficients - peak_dirs : np.ndarray - Peak directions peak_values : np.ndarray Peak values peak_indices : np.ndarray @@ -428,33 +451,65 @@ def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere, else nbr_processes npeaks = peak_values.shape[3] + # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. Then separate it in chunks of len(nbr_processes). + # 1D time series voxels. shm_coeff = shm_coeff[mask].reshape( (np.count_nonzero(mask), data_shape[3])) - peak_dirs = peak_dirs[mask].reshape((np.count_nonzero(mask), npeaks, 3)) peak_values = peak_values[mask].reshape((np.count_nonzero(mask), npeaks)) peak_indices = peak_indices[mask].reshape((np.count_nonzero(mask), npeaks)) - shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) - peak_dirs_chunks = np.array_split(peak_dirs, nbr_processes) - peak_values_chunks = np.array_split(peak_values, nbr_processes) - peak_indices_chunks = np.array_split(peak_indices, nbr_processes) - chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) - - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_maps_from_sh_parallel, - zip(shm_coeff_chunks, - peak_dirs_chunks, - peak_values_chunks, - peak_indices_chunks, - itertools.repeat(B), - itertools.repeat(sphere), - itertools.repeat(gfa_thr), - np.arange(len(shm_coeff_chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. + + if nbr_processes == 1: + (tmp_nufo_map_array, tmp_afd_max_array, tmp_afd_sum_array, + tmp_rgb_map_array, tmp_gfa_map_array, tmp_qa_map_array, + all_time_max_odf, all_time_global_max) = _maps_from_sh_loop( + shm_coeff, peak_values, peak_indices, + B, sphere, gfa_thr) + else: + # Separate the data in chunks of len(nbr_processes). + shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) + peak_values_chunks = np.array_split(peak_values, nbr_processes) + peak_indices_chunks = np.array_split(peak_indices, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_maps_from_sh_parallel, + zip(shm_coeff_chunks, + peak_values_chunks, + peak_indices_chunks, + itertools.repeat(B), + itertools.repeat(sphere), + itertools.repeat(gfa_thr), + np.arange(len(shm_coeff_chunks)))) + pool.close() + pool.join() + + # Re-assemble the chunk together. 
+ chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) + + # tmp arrays are necessary to avoid inserting data in returned variable + # rather than the original array + tmp_nufo_map_array = np.zeros((np.count_nonzero(mask))) + tmp_afd_max_array = np.zeros((np.count_nonzero(mask))) + tmp_afd_sum_array = np.zeros((np.count_nonzero(mask))) + tmp_rgb_map_array = np.zeros((np.count_nonzero(mask), 3)) + tmp_gfa_map_array = np.zeros((np.count_nonzero(mask))) + tmp_qa_map_array = np.zeros((np.count_nonzero(mask), npeaks)) + + all_time_max_odf = -np.inf + all_time_global_max = -np.inf + for (i, nufo_map, afd_max, afd_sum, rgb_map, + gfa_map, qa_map, max_odf, global_max) in results: + all_time_max_odf = max(all_time_global_max, max_odf) + all_time_global_max = max(all_time_global_max, global_max) + + tmp_nufo_map_array[chunk_len[i]:chunk_len[i+1]] = nufo_map + tmp_afd_max_array[chunk_len[i]:chunk_len[i+1]] = afd_max + tmp_afd_sum_array[chunk_len[i]:chunk_len[i+1]] = afd_sum + tmp_rgb_map_array[chunk_len[i]:chunk_len[i+1], :] = rgb_map + tmp_gfa_map_array[chunk_len[i]:chunk_len[i+1]] = gfa_map + tmp_qa_map_array[chunk_len[i]:chunk_len[i+1], :] = qa_map + + # Bring back to the original shape nufo_map_array = np.zeros(data_shape[0:3]) afd_max_array = np.zeros(data_shape[0:3]) afd_sum_array = np.zeros(data_shape[0:3]) @@ -462,29 +517,6 @@ def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere, gfa_map_array = np.zeros(data_shape[0:3]) qa_map_array = np.zeros(data_shape[0:3] + (npeaks,)) - # tmp arrays are neccesary to avoid inserting data in returned variable - # rather than the original array - tmp_nufo_map_array = np.zeros((np.count_nonzero(mask))) - tmp_afd_max_array = np.zeros((np.count_nonzero(mask))) - tmp_afd_sum_array = np.zeros((np.count_nonzero(mask))) - tmp_rgb_map_array = np.zeros((np.count_nonzero(mask), 3)) - tmp_gfa_map_array = np.zeros((np.count_nonzero(mask))) - tmp_qa_map_array = np.zeros((np.count_nonzero(mask), npeaks)) - - all_time_max_odf = -np.inf - all_time_global_max = -np.inf - for (i, nufo_map, afd_max, afd_sum, rgb_map, - gfa_map, qa_map, max_odf, global_max) in results: - all_time_max_odf = max(all_time_global_max, max_odf) - all_time_global_max = max(all_time_global_max, global_max) - - tmp_nufo_map_array[chunk_len[i]:chunk_len[i+1]] = nufo_map - tmp_afd_max_array[chunk_len[i]:chunk_len[i+1]] = afd_max - tmp_afd_sum_array[chunk_len[i]:chunk_len[i+1]] = afd_sum - tmp_rgb_map_array[chunk_len[i]:chunk_len[i+1], :] = rgb_map - tmp_gfa_map_array[chunk_len[i]:chunk_len[i+1]] = gfa_map - tmp_qa_map_array[chunk_len[i]:chunk_len[i+1], :] = qa_map - nufo_map_array[mask] = tmp_nufo_map_array afd_max_array[mask] = tmp_afd_max_array afd_sum_array[mask] = tmp_afd_sum_array @@ -501,22 +533,28 @@ def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere, or np.array_equal(np.array([1]), afd_unique): logging.warning('All AFD_max values are 1. The peaks seem normalized.') - return(nufo_map_array, afd_max_array, afd_sum_array, - rgb_map_array, gfa_map_array, qa_map_array) + return (nufo_map_array, afd_max_array, afd_sum_array, + rgb_map_array, gfa_map_array, qa_map_array) def _convert_sh_basis_parallel(args): - sh = args[0] - B_in = args[1] - invB_out = args[2] - chunk_id = args[3] + (sh, B_in, invB_out, chunk_id) = args + sh = _convert_sh_basis_loop(sh, B_in, invB_out) + + return chunk_id, sh + +def _convert_sh_basis_loop(sh, B_in, invB_out): + """ + Loops on 2D (ravelled) data and fits each voxel separately. 
+ For a more complete description of parameters, see convert_sh_basis. + """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. for idx in range(sh.shape[0]): if sh[idx].any(): sf = np.dot(sh[idx], B_in) sh[idx] = np.dot(sf, invB_out) - - return chunk_id, sh + return sh def convert_sh_basis(shm_coeff, sphere, mask=None, @@ -580,44 +618,59 @@ def convert_sh_basis(shm_coeff, sphere, mask=None, if nbr_processes is None or nbr_processes < 0 else nbr_processes # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. Then separate it in chunks of len(nbr_processes). + # 1D time series voxels. shm_coeff = shm_coeff[mask].reshape( (np.count_nonzero(mask), data_shape[3])) - shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) - chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) - - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_convert_sh_basis_parallel, - zip(shm_coeff_chunks, - itertools.repeat(B_in), - itertools.repeat(invB_out), - np.arange(len(shm_coeff_chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. - shm_coeff_array = np.zeros(data_shape) - tmp_shm_coeff_array = np.zeros((np.count_nonzero(mask), data_shape[3])) - for i, new_shm_coeff in results: - tmp_shm_coeff_array[chunk_len[i]:chunk_len[i+1], :] = new_shm_coeff + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + tmp_shm_coeff_array = _convert_sh_basis_loop(shm_coeff, B_in, invB_out) + else: + # Separate the data in chunks of len(nbr_processes). + shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_convert_sh_basis_parallel, + zip(shm_coeff_chunks, + itertools.repeat(B_in), + itertools.repeat(invB_out), + np.arange(len(shm_coeff_chunks)))) + pool.close() + pool.join() + + # Re-assemble the chunk together. + chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) + tmp_shm_coeff_array = np.zeros((np.count_nonzero(mask), data_shape[3])) + for i, new_shm_coeff in results: + tmp_shm_coeff_array[chunk_len[i]:chunk_len[i+1], :] = new_shm_coeff + + # Bring back to the original shape + shm_coeff_array = np.zeros(data_shape) shm_coeff_array[mask] = tmp_shm_coeff_array return shm_coeff_array def _convert_sh_to_sf_parallel(args): - sh = args[0] - B_in = args[1] - new_output_dim = args[2] - chunk_id = args[3] + (sh, B_in, new_output_dim, chunk_id) = args + sf = _convert_sh_to_sf_loop(sh, new_output_dim, B_in) + return chunk_id, sf + + +def _convert_sh_to_sf_loop(sh, new_output_dim, B_in): + """ + Loops on 2D data and fits each voxel separately. + See convert_sh_to_sf for more information. + """ + # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. sf = np.zeros((sh.shape[0], new_output_dim), dtype=np.float32) for idx in range(sh.shape[0]): if sh[idx].any(): sf[idx] = np.dot(sh[idx], B_in) - return chunk_id, sf + return sf def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", @@ -670,30 +723,40 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", if mask is None: mask = np.sum(shm_coeff, axis=3).astype(bool) + output_dim = len(sphere.vertices) + new_shape = data_shape[:3] + (output_dim,) + # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. Then separate it in chunks of len(nbr_processes). 
+ # 1D time series voxels. shm_coeff = shm_coeff[mask].reshape( (np.count_nonzero(mask), data_shape[3])) - shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) - chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) - - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_convert_sh_to_sf_parallel, - zip(shm_coeff_chunks, - itertools.repeat(B_in), - itertools.repeat(len(sphere.vertices)), - np.arange(len(shm_coeff_chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together in the original shape. - new_shape = data_shape[:3] + (len(sphere.vertices),) - sf_array = np.zeros(new_shape, dtype=dtype) - tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]), - dtype=dtype) - for i, new_sf in results: - tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + tmp_sf_array = _convert_sh_to_sf_loop(shm_coeff, output_dim, B_in) + else: + # Separate the data in chunks of len(nbr_processes). + shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) + + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_convert_sh_to_sf_parallel, + zip(shm_coeff_chunks, + itertools.repeat(B_in), + itertools.repeat(output_dim), + np.arange(len(shm_coeff_chunks)))) + pool.close() + pool.join() + + # Re-assemble the chunk together. + chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) + tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]), + dtype=dtype) + for i, new_sf in results: + tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf + + # Bring back to the original shape + sf_array = np.zeros(new_shape, dtype=dtype) sf_array[mask] = tmp_sf_array return sf_array diff --git a/scilpy/segment/streamlines.py b/scilpy/segment/streamlines.py index 3d5b599a0..5dd461e65 100644 --- a/scilpy/segment/streamlines.py +++ b/scilpy/segment/streamlines.py @@ -137,6 +137,13 @@ def filter_grid_roi(sft, mask, filter_type, is_exclude, filter_distance=0, sft of rejected streamlines (if return_rejected_sft) """ + if len(sft.streamlines) == 0: + if return_sft: + if return_rejected_sft: + return np.array([]), sft, sft + return np.array([]), sft + else: + return np.array([]) if filter_distance != 0: bin_struct = generate_binary_structure(3, 2) @@ -204,7 +211,7 @@ def pre_filtering_for_geometrical_shape(sft, size, center, filter_type, center: numpy.ndarray (3) Center x/y/z of the ROI. filter_type: str - One of the 3 following choices, 'any', 'all', 'either_end', 'both_ends'. + One of the 4 following choices, 'any', 'all', 'either_end', 'both_ends'. is_in_vox: bool Value to indicate if the ROI is in voxel space. @@ -256,7 +263,7 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center, ellipsoid_center: numpy.ndarray (3) Center x/y/z of the ellipsoid. filter_type: str - One of the 3 following choices, 'any', 'all', 'either_end', 'both_ends'. + One of the 4 following choices, 'any', 'all', 'either_end', 'both_ends'. is_exclude: bool Value to indicate if the ROI is an AND (false) or a NOT (true). 
is_in_vox: bool @@ -269,6 +276,9 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center, sft: StatefulTractogram Filtered sft """ + if len(sft.streamlines) == 0: + return np.array([]), sft + pre_filtered_indices, pre_filtered_sft = \ pre_filtering_for_geometrical_shape(sft, ellipsoid_radius, ellipsoid_center, filter_type, @@ -347,7 +357,7 @@ def filter_ellipsoid(sft, ellipsoid_radius, ellipsoid_center, data_per_streamline=data_per_streamline, data_per_point=data_per_point) - return new_sft, line_based_indices + return line_based_indices, new_sft def filter_cuboid(sft, cuboid_radius, cuboid_center, @@ -362,17 +372,21 @@ def filter_cuboid(sft, cuboid_radius, cuboid_center, cuboid_center: numpy.ndarray (3) Center x/y/z of the cuboid. filter_type: str - One of the 3 following choices, 'any', 'all', 'either_end', 'both_ends'. + One of the 4 following choices: 'any', 'all', 'either_end', 'both_ends'. is_exclude: bool Value to indicate if the ROI is an AND (false) or a NOT (true). is_in_vox: bool Value to indicate if the ROI is in voxel space. Returns ------- - ids : tuple - Filtered sft. + ids : list Ids of the streamlines passing through the mask. + sft: StatefulTractogram + Filtered sft """ + if len(sft.streamlines) == 0: + return np.array([]), sft + pre_filtered_indices, pre_filtered_sft = \ pre_filtering_for_geometrical_shape(sft, cuboid_radius, cuboid_center, filter_type, @@ -444,4 +458,4 @@ def filter_cuboid(sft, cuboid_radius, cuboid_center, data_per_streamline=data_per_streamline, data_per_point=data_per_point) - return new_sft, line_based_indices + return line_based_indices, new_sft diff --git a/scilpy/segment/voting_scheme.py b/scilpy/segment/voting_scheme.py index a01d5ae0f..04f823848 100644 --- a/scilpy/segment/voting_scheme.py +++ b/scilpy/segment/voting_scheme.py @@ -96,7 +96,8 @@ def _load_bundles_dictionary(self): return model_bundles_dict, bundle_names, bundle_counts - def _find_max_in_sparse_matrix(self, bundle_id, min_vote, bundles_wise_vote): + def _find_max_in_sparse_matrix(self, bundle_id, min_vote, + bundles_wise_vote): """ Will find the maximum values of a specific row (bundle_id), make sure they are the maximum values across bundles (argmax) and above the @@ -147,22 +148,24 @@ def _save_recognized_bundles(self, sft, bundle_names, bundles_wise_vote) if not streamlines_id.size: - logging.error('{0} final recognition got {1} streamlines'.format( - bundle_names[bundle_id], len(streamlines_id))) + logging.error('{0} final recognition got {1} streamlines' + .format(bundle_names[bundle_id], + len(streamlines_id))) continue else: - logging.info('{0} final recognition got {1} streamlines'.format( - bundle_names[bundle_id], len(streamlines_id))) + logging.info('{0} final recognition got {1} streamlines' + .format(bundle_names[bundle_id], + len(streamlines_id))) # All models of the same bundle have the same basename - basename = os.path.join(self.output_directory, - os.path.splitext(bundle_names[bundle_id])[0]) + basename = os.path.join( + self.output_directory, + os.path.splitext(bundle_names[bundle_id])[0]) new_sft = sft[streamlines_id.T] new_sft.remove_invalid_streamlines() save_tractogram(new_sft, basename + extension) - curr_results_dict = {} - curr_results_dict['indices'] = streamlines_id.tolist() + curr_results_dict = {'indices': streamlines_id.tolist()} results_dict[basename] = curr_results_dict out_logfile = os.path.join(self.output_directory, 'results.json') @@ -238,16 +241,27 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None, rbx = 
RecobundlesX(tmp_memmap_filenames, clusters_indices, centroids) - # Update all RecobundlesX initialisation into a single dictionnary - pool = multiprocessing.Pool(nbr_processes) - all_recognized_dict = pool.map(single_recognize, - zip(repeat(rbx), - model_bundles_dict.keys(), - model_bundles_dict.values(), - repeat(bundle_names), - repeat([seed]))) - pool.close() - pool.join() + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + all_recognized = [] + for model_filepath, val in zip(model_bundles_dict.keys(), + model_bundles_dict.values()): + bundle_pruning_thr, model_bundle = val + all_recognized.append( + single_recognize(rbx, model_filepath, model_bundle, + bundle_pruning_thr, bundle_names, seed)) + else: + # Update all RecobundlesX initialisation into a single list + pool = multiprocessing.Pool(nbr_processes) + all_recognized = pool.map(single_recognize_parallel, + zip(repeat(rbx), + model_bundles_dict.keys(), + model_bundles_dict.values(), + repeat(bundle_names), + repeat(seed))) + pool.close() + pool.join() tmp_dir.cleanup() logging.info('RBx took {0} sec. for {1} bundles from {2} atlas'.format( @@ -260,11 +274,11 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None, len_wb_streamlines), dtype=np.uint8) - for bundle_id, recognized_indices in all_recognized_dict: - if recognized_indices is not None: - tmp_values = bundles_wise_vote[bundle_id, recognized_indices.T] + for bundle_id, recognized_ind in all_recognized: + if recognized_ind is not None: + tmp_values = bundles_wise_vote[bundle_id, recognized_ind.T] bundles_wise_vote[bundle_id, - recognized_indices.T] = tmp_values.toarray() + 1 + recognized_ind.T] = tmp_values.toarray() + 1 bundles_wise_vote = bundles_wise_vote.tocsr() # Once everything was run, save the results using a voting system @@ -283,9 +297,21 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None, round(time() - save_timer, 2))) -def single_recognize(args): +def single_recognize_parallel(args): + """Wrapper function to multiprocess recobundles execution.""" + rbx = args[0] + model_filepath = args[1] + bundle_pruning_thr, model_bundle = args[2] + bundle_names = args[3] + seed = args[4] + return single_recognize(rbx, model_filepath, model_bundle, + bundle_pruning_thr, bundle_names, seed) + + +def single_recognize(rbx, model_filepath, model_bundle, bundle_pruning_thr, + bundle_names, seed): """ - Wrapper function to multiprocess recobundles execution. + Recobundle for a single bundle. Parameters ---------- @@ -293,11 +319,10 @@ def single_recognize(args): Initialize RBx object with QBx ClusterMap as values model_filepath : str Path to the model bundle file - params : tuple - bundle_pruning_thr : float - Threshold for pruning the model bundle - streamlines: ArraySequence - Streamlines of the model bundle + model_bundle: ArraySequence + Model bundle. + bundle_pruning_thr : float + Threshold for pruning the model bundle bundle_names : list List of string with bundle names for models (to get bundle_id) seed : int @@ -305,18 +330,12 @@ def single_recognize(args): Returns ------- - transf_neighbor : tuple - bundle_id : (int) - Unique value to each bundle to identify them. - recognized_indices : (numpy.ndarray) - Streamlines indices from the original tractogram. + bundle_id : int + Unique value to each bundle to identify them. + recognized_indices : numpy.ndarray + Streamlines indices from the original tractogram. 
""" - rbx = args[0] - model_filepath = args[1] - bundle_pruning_thr = args[2][0] - model_bundle = args[2][1] - bundle_names = args[3] - np.random.seed(args[4][0]) + np.random.seed(seed) # Use for logging and finding the bundle_id shorter_tag, ext = os.path.splitext(os.path.basename(model_filepath)) diff --git a/scilpy/tracking/fibertube_utils.py b/scilpy/tracking/fibertube_utils.py new file mode 100644 index 000000000..9395ebca9 --- /dev/null +++ b/scilpy/tracking/fibertube_utils.py @@ -0,0 +1,491 @@ +import numpy as np + +from math import sqrt +from numba import njit +from scilpy.tracking.utils import tqdm_if_verbose + + +def streamlines_to_segments(streamlines, verbose=False): + """ + Separates all streamlines of a tractogram into segments that connect + each position. Then, flattens the resulting 2D array and returns it + + Parameters + ---------- + streamlines : list + Streamlines to segment. This function is compatible with streamlines + as a fixed array, as long as the padding value is a number. Padding + will also be present in the result value. + + Returns + ------- + centers : ndarray[float] + A flattened array of all the tractogram's segment centers + indices : ndarray[Tuple[int, int]] + A flattened array of all the tractogram's segment indices + max_length: float + Length of the longest segment of the whole tractogram + """ + centers = [] + indices = [] + max_length = 0. + for si, s in tqdm_if_verbose(enumerate(streamlines), verbose, + total=len(streamlines)): + centers.append((s[1:] + s[:-1]) / 2) + indices.append([(si, pi) for pi in range(len(s)-1)]) + + max_length_candidate = np.amax(np.linalg.norm(s[1:] - s[:-1], axis=-1)) + + if max_length_candidate > max_length: + max_length = float(max_length_candidate) + + centers = np.vstack(centers) + indices = np.vstack(indices) + + return (centers, indices, max_length) + + +@njit +def rotation_between_vectors_matrix(vec1, vec2): + """ + Produces a rotation matrix that aligns a 3D vector 'vec1' with another 3D + vector 'vec2'. Numba compatible. + + https://math.stackexchange.com/questions/180418/calculate- + rotation-matrix-to-align-vector-a-to-vector-b-in-3d + + Parameters + ---------- + vec1: ndarray + Vector to be rotated + vec2: ndarray + Targeted orientation + + Returns + ------- + rotation_matrix: ndarray + A transform matrix (3x3) which when applied to vec1, aligns it with + vec2. + """ + a, b = ((vec1 / np.linalg.norm(vec1)).reshape(3), + (vec2 / np.linalg.norm(vec2)).reshape(3)) + v = np.cross(a, b) + c = np.dot(a, b) + s = np.linalg.norm(v) + if s != 0: + kmat = np.array([[0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0]]) + rotation_matrix = np.eye(3) + (kmat + + kmat.dot(kmat) * ((1 - c) / (s ** 2))) + else: + rotation_matrix = np.eye(3) + return rotation_matrix + + +@njit +def sample_sphere(center, radius: float, amount: int, + rand_gen: np.random.Generator): + """ + Samples a sphere uniformly given its dimensions and the amount of samples. + + Parameters + ---------- + center: ndarray + Center coordinates of the sphere. Can be [0, 0, 0] if only the + relative displacement interests you. + radius: float + Radius of the sphere. + amount: int + Amount of samples to be produced. + rand_gen: numpy random generator + Numpy random generator used for producing samples within the sphere. + + Returns + ------- + samples: list + Array containing the coordinates of each sample. 
+ """ + samples = [] + while (len(samples) < amount): + sample = np.array([rand_gen.uniform(-radius, radius), + rand_gen.uniform(-radius, radius), + rand_gen.uniform(-radius, radius)]) + if np.linalg.norm(sample) <= radius: + samples.append(sample + center) + return samples + + +@njit +def sample_cylinder(pt1, pt2, radius: float, sample_count: int, + random_generator: np.random.Generator): + """ + Samples a cylinder uniformly given its dimensions and the amount of + samples. + + Parameters + ---------- + pt1: ndarray + First extremity of the cylinder axis + pt2: ndarray + Second extremity of the cylinder axis + radius: float + Radius of the cylinder. + sample_count: int + Amount of samples to be produced. + rand_gen: numpy random generator + Numpy random generator used for producing samples within the sphere. + + Returns + ------- + samples: list + Array containing the coordinates of each sample. + """ + samples = [] + while (len(samples) < sample_count): + axis = pt2 - pt1 + center = (pt1 + pt2) / 2 + half_length = np.linalg.norm(axis) / 2 + axis /= np.linalg.norm(axis) + reference = np.array([0., 0., 1.], dtype=axis.dtype) + + # Generation + x = random_generator.uniform(-radius, radius) + y = random_generator.uniform(-radius, radius) + z = random_generator.uniform(-half_length, half_length) + sample = np.array([x, y, z], dtype=np.float64) + + # Rotation + rotation_matrix = np.eye(4, dtype=np.float64) + rotation_matrix[:3, :3] = rotation_between_vectors_matrix( + reference, + axis).astype(np.float32) + sample = np.dot(rotation_matrix, np.append(sample, 1.))[:3] + + # Translation + sample += center + sample = sample.astype(np.float32) + + if (point_in_cylinder(pt1, pt2, radius, sample)): + samples.append(sample) + return samples + + +@njit +def point_in_cylinder(pt1, pt2, r, q): + vec = pt2 - pt1 + cond_1 = np.dot(q - pt1, vec) >= 0 + cond_2 = np.dot(q - pt2, vec) <= 0 + cond_3 = (np.linalg.norm(np.cross(q - pt1, vec)) / + np.linalg.norm(vec)) <= r + return cond_1 and cond_2 and cond_3 + + +@njit +def sphere_cylinder_intersection(sph_p, sph_r: float, cyl_p1, cyl_p2, + cyl_r: float, sample_count: int, + random_generator: np.random.Generator): + """ + Estimates the volume of intersection between a cylinder and a sphere by + sampling the cylinder. Most efficient when the cylinder is smaller than + the sphere. + + Parameters + ---------- + sph_p: ndarray + Center coordinate of the sphere. + sph_r: float + Radius of the sphere. + cyl_p1: ndarray + First point of the cylinder's center segment. + cyl_p2: ndarray + Second point of the cylinder's center segment. + cyl_r: float + Radius of the cylinder. + sample_count: int + Amount of samples to use for the estimation. + + Returns + ------- + inter_volume: float + Approximate volume of the sphere-cylinder intersection. + is_estimated: boolean + Indicates whether or not the resulting volume has been estimated. + If true, inter_volume has a probability of error due to sample count. + """ + cyl_axis = cyl_p2 - cyl_p1 + cyl_length = np.linalg.norm(cyl_axis) + + # If cylinder is completely inside the sphere. + if (np.linalg.norm(sph_p - cyl_p1) + cyl_r <= sph_r and + np.linalg.norm(sph_p - cyl_p2) + cyl_r <= sph_r): + cyl_volume = np.pi * (cyl_r ** 2) * cyl_length + return cyl_volume, False + + # If cylinder is completely outside the sphere. + _, vector, _ = dist_point_segment(cyl_p1, cyl_p2, sph_p) + if np.linalg.norm(vector) >= sph_r + cyl_r: + return 0, False + + # High probability of intersection. 
+ samples = sample_cylinder(cyl_p1, cyl_p2, cyl_r, sample_count, + random_generator) + + inter_samples = 0 + for sample in samples: + if np.linalg.norm(sph_p - sample) < sph_r: + inter_samples += 1 + + # Proportion of cylinder samples common to both shapes * cylinder volume. + cyl_volume = np.pi * (cyl_r ** 2) * cyl_length + inter_volume = (inter_samples / sample_count) * cyl_volume + + return inter_volume, True + + +@njit +def create_perpendicular(v: np.ndarray): + """ + Generates a vector perpendicular to v. + + Source: https://math.stackexchange.com/questions/133177/finding-a-unit- + vector-perpendicular-to-another-vector + Answer by Ahmed Fasih + + Parameters + ---------- + v: ndarray + Vector from which a perpendicular vector will be generated. + + Returns + ------- + vp: ndarray + Vector perpendicular to v. + """ + vp = np.array([0., 0., 0.]) + if v.all() == vp.all(): + return vp + for m in range(3): + if v[m] == 0.: + continue + n = (m + 1) % 3 + vp[n] = -v[m] + vp[m] = -v[n] + + return vp / np.linalg.norm(vp) + + +@njit +def dist_point_segment(p0, p1, q): + """ + Calculates the shortest distance between a 3D point q and a segment p0-p1. + + Parameters + ---------- + p0: ndarray + Point forming the first end of the segment. + p1: ndarray + Point forming the second end of the segment. + q: ndarray + Point coordinates. + + Returns + ------- + distance: float + Shortest distance between the two segments + v: ndarray + Vector representing the distance between the two segments. + v = Ps - q and |v| = distance + Ps: ndarray + Point coordinates on segment P that is closest to point q + """ + return dist_segment_segment(p0, p1, q, q)[:3] + + +@njit +def dist_segment_segment(P0, P1, Q0, Q1): + """ + Calculates the shortest distance between two 3D segments P0-P1 and Q0-Q1. + + Parameters + ---------- + P0: ndarray + Point forming the first end of the P segment. + P1: ndarray + Point forming the second end of the P segment. + Q0: ndarray + Point forming the first end of the Q segment. + Q1: ndarray + Point forming the second end of the Q segment. + + Returns + ------- + distance: float + Shortest distance between the two segments + v: ndarray + Vector representing the distance between the two segments. 
+ v = Ps - Qt and |v| = distance + Ps: ndarray + Point coordinates on segment P that is closest to segment Q + Qt: ndarray + Point coordinates on segment Q that is closest to segment P + + This function is a python version of the following code: + https://www.geometrictools.com/GTE/Mathematics/DistSegmentSegment.h + + Scientific source: + https://www.geometrictools.com/Documentation/DistanceLine3Line3.pdf + + """ + P1mP0 = np.subtract(P1, P0) + Q1mQ0 = np.subtract(Q1, Q0) + P0mQ0 = np.subtract(P0, Q0) + + a = np.dot(P1mP0, P1mP0) + b = np.dot(P1mP0, Q1mQ0) + c = np.dot(Q1mQ0, Q1mQ0) + d = np.dot(P1mP0, P0mQ0) + e = np.dot(Q1mQ0, P0mQ0) + det = a * c - b * b + s = t = nd = bmd = bte = ctd = bpe = ate = btd = None + + if det > 0: + bte = b * e + ctd = c * d + if bte <= ctd: # s <= 0 + s = 0 + if e <= 0: # t <= 0 + # section 6 + t = 0 + nd = -d + if nd >= a: + s = 1 + elif nd > 0: + s = nd / a + # else: s is already 0 + elif e < c: # 0 < t < 1 + # section 5 + t = e / c + else: # t >= 1 + # section 4 + t = 1 + bmd = b - d + if bmd >= a: + s = 0 + elif bmd > 0: + s = bmd / a + # else: s is already 0 + else: # s > 0 + s = bte - ctd + if s >= det: # s >= 1 + # s = 1 + s = 1 + bpe = b + e + if bpe <= 0: # t <= 0 + # section 8 + t = 0 + nd = -d + if nd <= 0: + s = 0 + elif nd < a: + s = nd / a + # else: s is already 1 + elif bpe < c: # 0 < t < 1 + # section 1 + t = bpe / c + else: # t >= 1 + # section 2 + t = 1 + bmd = b - d + if bmd <= 0: + s = 0 + elif bmd < a: + s = bmd / a + # else: s is already 1 + else: # 0 < s < 1 + ate = a * e + btd = b * d + if ate <= btd: # t <= 0 + # section 7 + t = 0 + nd = -d + if nd <= 0: + s = 0 + elif nd >= a: + s = 1 + else: + s = nd / a + else: # t > 0 + t = ate - btd + if t >= det: # t >= 1 + # section 3 + t = 1 + bmd = b - d + if bmd <= 0: + s = 0 + elif (bmd >= a): + s = 1 + else: + s = bmd / a + else: # 0 < t < 1 + # section 0 + s /= det + t /= det + + else: + # The segments are parallel. The quadratic factors to + # R(s,t) = a*(s-(b/a)*t)^2 + 2*d*(s - (b/a)*t) + f + # where a*c = b^2, e = b*d/a, f = |P0-Q0|^2, and b is not + # 0. R is constant along lines of the form s-(b/a)*t = k + # and its occurs on the line a*s - b*t + d = 0. This line + # must intersect both the s-axis and the t-axis because 'a' + # and 'b' are not 0. Because of parallelism, the line is + # also represented by -b*s + c*t - e = 0. + # + # The code determines an edge of the domain [0,1]^2 that + # intersects the minimum line, or if n1 of the edges + # intersect, it determines the closest corner to the minimum + # line. The conditionals are designed to test first for + # intersection with the t-axis (s = 0) using + # -b*s + c*t - e = 0 and then with the s-axis (t = 0) using + # a*s - b*t + d = 0. + + # When s = 0, solve c*t - e = 0 (t = e/c). + if e <= 0: # t <= 0 + # Now solve a*s - b*t + d = 0 for t = 0 (s = -d/a). + t = 0 + nd = -d + if nd <= 0: # s <= 0 + # section 6 + s = 0 + elif nd >= a: # s >= 1 + # section 8 + s = 1 + else: # 0 < s < 1 + # section 7 + s = nd / a + elif e >= c: # t >= 1 + # Now solve a*s - b*t + d = 0 for t = 1 (s = (b-d)/a). + t = 1 + bmd = b - d + if bmd <= 0: # s <= 0 + # section 4 + s = 0 + elif bmd >= a: # s >= 1 + # section 2 + s = 1 + else: # 0 < s < 1 + # section 3 + s = bmd / a + else: # 0 < t < 1 + # The point (0,e/c) is on the line and domain, so we have + # 1 point at which R is a minimum. 
+ s = 0 + t = e / c + + Ps = P0 + s * P1mP0 + Qt = Q0 + t * Q1mQ0 + v = Ps - Qt + sqr_distance = np.dot(v, v) + distance = sqrt(sqr_distance) + return (distance, v, Ps, Qt) diff --git a/scilpy/tracking/propagator.py b/scilpy/tracking/propagator.py index ce44099d8..d347095f9 100644 --- a/scilpy/tracking/propagator.py +++ b/scilpy/tracking/propagator.py @@ -11,6 +11,7 @@ from scilpy.reconst.utils import (get_sphere_neighbours, get_sh_order_and_fullness) from scilpy.tracking.utils import sample_distribution, TrackingDirection +from scilpy.image.volume_space_management import FibertubeDataVolume class PropagationStatus(Enum): @@ -498,8 +499,8 @@ def _sample_next_direction(self, pos, v_in): # Sampling one. if np.sum(sf) > 0: - v_out = directions[sample_distribution(sf, - self.line_rng_generator)] + v_out = directions[ + sample_distribution(sf, self.line_rng_generator)] else: return None elif self.algo == 'det': @@ -573,3 +574,116 @@ def _get_possible_next_dirs_det(self, pos, previous_direction): if 0 < sf[i] == np.max(sf[self.maxima_neighbours[i]]): maxima.append(self.dirs[i]) return maxima + + +class FibertubePropagator(AbstractPropagator): + """ + Simplified propagator for using fibertube data. It is probabilistic and + uses the volume of intersection between fibertube segments and the + blurring sphere as a random distribution for picking a segment. This + segment is then used as the propagation direction. + """ + def __init__(self, datavolume: FibertubeDataVolume, step_size, rk_order, + theta, space, origin): + """" + Parameters + ---------- + datavolume: FibertubeDataVolume + Trackable fibertube dataset object. + step_size: float + The step size for tracking. Important: step size should be in the + same units as the space of the tracking! + rk_order: int + Order for the Runge Kutta integration. + theta: float + Maximum angle (radians) between two steps. + space: dipy Space + Space of the streamlines during tracking. value. + origin: dipy Origin + Origin of the streamlines during tracking. All coordinates + received in the propagator's methods will be expected to respect + that origin. + + A note on space and origin: All coordinates received in the + propagator's methods will be expected to respect those values. Tracker + will verify that the propagator has the same internal values as itself. + """ + + if not (rk_order == 1 or rk_order == 2 or rk_order == 4): + raise ValueError("Invalid runge-kutta order. Is " + + str(rk_order) + ". Choices : 1, 2, 4") + + self.datavolume = datavolume + self.step_size = step_size + self.rk_order = rk_order + self.theta = theta + self.space = space + self.origin = origin + self.normalize_directions = True + # Will be reset at each new streamline. + self.line_rng_generator = None + + def reset_data(self, new_data=None): + return super().reset_data(new_data) + + def prepare_forward(self, seeding_pos, random_generator): + direction = self.datavolume.get_absolute_direction(*seeding_pos) + + # Validate seeding within a fibertube. 
+ if direction is None: + return PropagationStatus.ERROR + + self.line_rng_generator = random_generator + + return TrackingDirection(direction) + + def prepare_backward(self, line, forward_dir): + return super().prepare_backward(line, forward_dir) + + def finalize_streamline(self, last_pos, v_in): + return super().finalize_streamline(last_pos, v_in) + + def propagate(self, line, v_in): + return super().propagate(line, v_in) + + def _sample_next_direction(self, pos, v_in): + directions, volumes = self._get_possible_next_dirs(pos, v_in) + + # Sampling one. + if np.sum(volumes) > 0: + v_out = directions[ + sample_distribution(volumes, self.line_rng_generator)] + return v_out + return None + + def _get_possible_next_dirs(self, pos, v_in): + directions, volumes = ( + self.datavolume.get_value_at_coordinate(*pos, self.space, + self.origin)) + + # Angle threshold + valid_dirs = [] + valid_volumes = [] + + for i, dir in enumerate(directions): + num = np.dot(v_in, dir) + cosine = num / (np.linalg.norm(v_in) * + np.linalg.norm(dir)) + + # Flip direction if facing the wrong way + if cosine < 0: + cosine = abs(cosine) + dir = -dir + + cosine = np.clip(cosine, -1, 1) + + if (np.arccos(cosine) > self.theta): + continue + + valid_dirs.append(dir) + valid_volumes.append(volumes[i]) + + valid_dirs = np.array(valid_dirs) + valid_volumes = np.array(valid_volumes) + + return valid_dirs, valid_volumes diff --git a/scilpy/tracking/seed.py b/scilpy/tracking/seed.py index 4278897fe..3ba27697f 100644 --- a/scilpy/tracking/seed.py +++ b/scilpy/tracking/seed.py @@ -2,6 +2,7 @@ import numpy as np from dipy.io.stateful_tractogram import Space, Origin +from scilpy.tracking.fibertube_utils import sample_cylinder class SeedGenerator: @@ -264,3 +265,102 @@ def init_generator(self, rng_seed, numbers_to_skip): random_generator.random_sample(random_numbers_to_skip) return random_generator, indices + + +class FibertubeSeedGenerator(SeedGenerator): + """ + Adaptation of the scilpy.tracking.seed.SeedGenerator interface for + fibertube tracking. Generates a given number of seed within the first + segment of a given number of fibertubes. + """ + def __init__(self, centerlines, diameters, nb_seeds_per_fibertube): + """ + Parameters + ---------- + centerlines: list + Tractogram containing the fibertube centerlines + diameters: list + Diameters of each fibertube + nb_seeds_per_fibertube: int + """ + self.space = Space.VOXMM + self.origin = Origin.NIFTI + + self.centerlines = centerlines + self.diameters = diameters + self.nb_seeds_per_fibertube = nb_seeds_per_fibertube + + def init_generator(self, rng_seed, numbers_to_skip): + """ + Initialize a numpy number generator according to user's parameters. + Returns also the shuffled index of all fibertubes. + + The values are not stored in this classed, but are returned to the + user, who should add them as arguments in the methods + self.get_next_pos() + self.get_next_n_pos() + The use of this is that with multiprocessing, each process may have its + own generator, with less risk of using the wrong one when they are + managed by the user. + + Parameters + ---------- + rng_seed : int + The "seed" for the random generator. + numbers_to_skip : int + Number of seeds (i.e. voxels) to skip. Useful if you want to + continue sampling from the same generator as in a first experiment + (with a fixed rng_seed). + + Return + ------ + random_generator : numpy random generator + Initialized numpy number generator. 
+        indices : ndarray
+            Shuffled indices of current seeding map, shuffled with current
+            generator.
+        """
+        self.generator = np.random.RandomState(rng_seed)
+
+        # 1. Initializing seeding maps indices (shuffling in-place)
+        indices = np.arange(len(self.centerlines))
+        self.generator.shuffle(indices)
+
+        # 2. Generating the seed for the random sampling.
+        # Because FibertubeSeedGenerator uses rejection sampling to seed
+        # within a cylinder, we can't predict how many generator calls will
+        # be done in each thread to avoid duplicates. We instead generate a
+        # single, predictable number used as a seed for the rejection
+        # sampling.
+        while numbers_to_skip > 100000:
+            self.generator.random_sample(100000)
+            numbers_to_skip -= 100000
+        self.generator.random_sample(numbers_to_skip)
+        sampling_rng_seed = self.generator.randint(0, 2**32-1)
+        self.sampling_generator = np.random.default_rng(sampling_rng_seed)
+
+        return self.sampling_generator, indices
+
+    def get_next_pos(self, random_generator: np.random.Generator,
+                     shuffled_indices, which_seed):
+        which_fi = which_seed // self.nb_seeds_per_fibertube
+
+        fiber = self.centerlines[shuffled_indices[which_fi]]
+        radius = self.diameters[shuffled_indices[which_fi]] / 2
+
+        seed = sample_cylinder(fiber[0], fiber[1], radius, 1,
+                               random_generator)[0]
+
+        return seed[0], seed[1], seed[2]
+
+    def get_next_n_pos(self, random_generator, shuffled_indices,
+                       which_seed_start, n):
+        which_fi = which_seed_start // self.nb_seeds_per_fibertube
+
+        fiber = self.centerlines[shuffled_indices[which_fi]]
+        radius = self.diameters[shuffled_indices[which_fi]] / 2
+
+        seeds = sample_cylinder(fiber[0], fiber[1], radius, n,
+                                random_generator)
+
+        return seeds
diff --git a/scilpy/tracking/tests/test_propagator.py b/scilpy/tracking/tests/test_propagator.py
index 433c876ae..48ba877bf 100644
--- a/scilpy/tracking/tests/test_propagator.py
+++ b/scilpy/tracking/tests/test_propagator.py
@@ -4,7 +4,7 @@
 def test_class_propagator():
     """
     We will not test the tracker / propagator : too big to be tested, and only
-    used in scil_tracking_local_dev, which is intented for developping and
-    testing new parameters.
+    used in scil_tracking_local_dev and scil_fibertube_tracking, which are
+    intended for developing and testing new parameters.
     """
     pass
diff --git a/scilpy/tracking/tests/test_tracker.py b/scilpy/tracking/tests/test_tracker.py
index 6b5e50d1c..e949f79c1 100644
--- a/scilpy/tracking/tests/test_tracker.py
+++ b/scilpy/tracking/tests/test_tracker.py
@@ -4,7 +4,7 @@
 def test_class_tracker():
     """
     We will not test the tracker / propagator : too big to be tested, and only
-    used in scil_tracking_local_dev, which is intented for developping and
-    testing new parameters.
+    used in scil_tracking_local_dev and scil_fibertube_tracking, which are
+    intended for developing and testing new parameters.
""" pass diff --git a/scilpy/tractanalysis/fibertube_scoring.py b/scilpy/tractanalysis/fibertube_scoring.py new file mode 100644 index 000000000..2ce4d365b --- /dev/null +++ b/scilpy/tractanalysis/fibertube_scoring.py @@ -0,0 +1,487 @@ +import numpy as np +from numba import objmode + +from math import sqrt, acos +from numba import njit +from numba_kdtree import KDTree as nbKDTree +from scipy.spatial import KDTree +from scipy.spatial.transform import Rotation +from scilpy.tracking.fibertube_utils import (streamlines_to_segments, + point_in_cylinder, + dist_segment_segment, + dist_point_segment) +from scilpy.tracking.utils import tqdm_if_verbose +from scilpy.tractanalysis.todi import TrackOrientationDensityImaging + + +def mean_fibertube_density(sft): + """ + Estimates the average per-voxel spatial density of a set of fibertubes. + This is obtained by dividing the volume of fibertube segments present + each voxel by the the total volume of the voxel. + + Parameters + ---------- + sft: StatefulTractogram + Stateful Tractogram object containing the fibertubes. + + Returns + ------- + mean_density: float + Per-voxel spatial density, averaged for the whole tractogram. + """ + diameters = np.reshape(sft.data_per_streamline['diameters'], + len(sft.streamlines)) + mean_diameter = np.mean(diameters) + + mean_segment_lengths = [] + for streamline in sft.streamlines: + mean_segment_lengths.append( + np.mean(np.linalg.norm(streamline[1:] - streamline[:-1], axis=-1))) + mean_segment_length = np.mean(mean_segment_lengths) + # Computing mean tube density per voxel. + sft.to_vox() + # Because compute_todi expects streamline points (in voxel coordinates) + # to be in the range [0, size] rather than [-0.5, size - 0.5], we shift + # the voxel origin to corner. + sft.to_corner() + + # Computing TDI + _, data_shape, _, _ = sft.space_attributes + todi_obj = TrackOrientationDensityImaging(tuple(data_shape)) + todi_obj.compute_todi(sft.streamlines) + img = todi_obj.get_tdi() + img = todi_obj.reshape_to_3d(img) + + nb_voxels_nonzero = np.count_nonzero(img) + sum = np.sum(img, axis=-1) + sum = np.sum(sum, axis=-1) + sum = np.sum(sum, axis=-1) + + mean_seg_volume = np.pi * ((mean_diameter/2) ** 2) * mean_segment_length + + mean_seg_count = sum / nb_voxels_nonzero + mean_volume = mean_seg_count * mean_seg_volume + mean_density = mean_volume / (sft.voxel_sizes[0] * + sft.voxel_sizes[1] * + sft.voxel_sizes[2]) + + return mean_density + + +def min_external_distance(sft, verbose): + """" + Calculates the minimal distance in between two fibertubes. A RuntimeError + is thrown if a collision is detected (i.e. a negative distance is found). + Use IntersectionFinder to remove intersections from fibertubes. + + Parameters + ---------- + sft: StatefulTractogram + Stateful Tractogram object containing the fibertubes + verbose: bool + Whether to make the function verbose. + + Returns + ------- + min_external_distance: float + Minimal distance found between two fibertubes. + min_external_distance_vec: ndarray + Vector representation of min_external_distance. 
+    """
+    centerlines = sft.streamlines
+    diameters = np.reshape(sft.data_per_streamline['diameters'],
+                           len(centerlines))
+    max_diameter = np.max(diameters)
+
+    if len(centerlines) <= 1:
+        raise ValueError("Cannot compute metrics of a tractogram with a " +
+                         "single streamline or less")
+    seg_centers, seg_indices, max_seg_length = streamlines_to_segments(
+        centerlines, verbose)
+    tree = KDTree(seg_centers)
+    min_external_distance = np.inf
+    min_external_distance_vec = np.zeros(0, dtype=np.float32)
+
+    for segi, center in tqdm_if_verbose(enumerate(seg_centers), verbose,
+                                        total=len(seg_centers)):
+        si = seg_indices[segi][0]
+
+        neighbors = tree.query_ball_point(center,
+                                          max_seg_length + max_diameter,
+                                          workers=-1)
+
+        for neighbor_segi in neighbors:
+            neighbor_si = seg_indices[neighbor_segi][0]
+
+            # Skip if neighbor is our streamline
+            if neighbor_si == si:
+                continue
+
+            p0 = centerlines[si][seg_indices[segi][1]]
+            p1 = centerlines[si][seg_indices[segi][1] + 1]
+            q0 = centerlines[neighbor_si][seg_indices[neighbor_segi][1]]
+            q1 = centerlines[neighbor_si][seg_indices[neighbor_segi][1] + 1]
+
+            rp = diameters[si] / 2
+            rq = diameters[neighbor_si] / 2
+
+            distance, vector, *_ = dist_segment_segment(p0, p1, q0, q1)
+            external_distance = distance - rp - rq
+
+            if external_distance < 0:
+                raise RuntimeError(
+                    'The input streamlines contained a collision after \n'
+                    'filtering. This is unlikely to be an error of this \n'
+                    'script, and instead may be due to your original data \n'
+                    'using very high float precision. For more info on \n'
+                    'this issue, please see the documentation for \n'
+                    'scil_tractogram_filter_collisions.py.')
+
+            if (external_distance < min_external_distance):
+                min_external_distance = external_distance
+                min_external_distance_vec = (
+                    get_external_vector_from_centerline_vector(vector, rp, rq)
+                )
+
+    return min_external_distance, min_external_distance_vec
+
+
+def max_voxels(diagonal):
+    """
+    Given the vector representing the smallest distance between two
+    fibertubes, calculates the maximum sized voxels (anisotropic & isotropic)
+    without causing any partial-volume effect.
+
+    These voxels are expressed in the current 3D referential and are
+    often rendered meaningless by it. See function max_voxel_rotated for an
+    alternative.
+
+    Parameters
+    ----------
+    diagonal: ndarray
+        Vector representing the smallest distance between two
+        fibertubes.
+
+    Returns
+    -------
+    max_voxel_anisotropic: ndarray
+        Maximum sized anisotropic voxel.
+    max_voxel_isotropic: ndarray
+        Maximum sized isotropic voxel.
+    """
+    max_voxel_anisotropic = np.abs(diagonal).astype(np.float32)
+
+    # Find an isotropic voxel within the anisotropic one
+    min_edge = min(max_voxel_anisotropic)
+    max_voxel_isotropic = np.array([min_edge, min_edge, min_edge],
+                                   dtype=np.float32)
+
+    return (max_voxel_anisotropic, max_voxel_isotropic)
+
+
+def max_voxel_rotated(diagonal):
+    """
+    Given the vector representing the smallest distance between two
+    fibertubes, calculates the maximum sized voxel without causing any
+    partial-volume effect. This voxel is isotropic.
+
+    This voxel is not expressed in the current 3D referential. It will require
+    the tractogram to be rotated according to rotation_matrix for this voxel
+    to be applicable.
+
+    Parameters
+    ----------
+    diagonal: ndarray
+        Vector representing the smallest distance between two
+        fibertubes.
+ + Returns + ------- + rotation_matrix: ndarray + 3x3 rotation matrix to be applied to the tractogram to align it with + the voxel + edge: float + Edge size of the max_voxel_rotated. + """ + hyp = np.linalg.norm(diagonal) + edge = hyp / 3*sqrt(3) + + # The rotation should be such that the diagonal becomes aligned + # with [1, 1, 1] + diag = diagonal / np.linalg.norm(diagonal) + dest = [1, 1, 1] / np.linalg.norm([1, 1, 1]) + + v = np.cross(diag, dest) + v /= np.linalg.norm(v) + theta = acos(np.dot(diag, dest)) + rotation_matrix = Rotation.from_rotvec(v * theta).as_matrix() + + return (rotation_matrix, edge) + + +@njit +def get_external_vector_from_centerline_vector(vector, r1, r2): + """ + Given a vector separating two fibertube centerlines, finds a + vector that separates them from outside their diameter. + + Parameters + ---------- + vector: ndarray + Vector between two fibertube centerlines. + rp: ndarray + Radius of one of the fibertubes. + rq: ndarray + Radius of the other fibertube. + + Results + ------- + external_vector: ndarray + Vector between the two fibertubes, outside their diameter. + """ + unit_vector = vector / np.linalg.norm(vector) + external_vector = (vector - r1 * unit_vector - r2 * + unit_vector) + + return external_vector + + +@njit +def resolve_origin_seeding(seeds, centerlines, diameters): + """ + Associates given seeds to segment 0 of the fibertube in which they have + been generated. This pairing only works with fiber origin seeding. + + Parameters + ---------- + seeds: ndarray + centerlines: ndarray + Fibertube centerlines given as a fixed array + (see streamlines_as_fixed_array). + diameters: ndarray + + Return + ------ + seeds_fiber: ndarray + Array containing the fiber index of each seed. If the seed is not in a + fiber, its value will be -1. + """ + seeds_fiber = [-1] * len(seeds) + + for si, seed in enumerate(seeds): + for fi, fiber in enumerate(centerlines): + if point_in_cylinder(fiber[0], fiber[1], diameters[fi]/2, seed): + seeds_fiber[si] = fi + break + + return np.array(seeds_fiber) + + +@njit +def mean_reconstruction_error(centerlines, centerlines_length, diameters, + streamlines, streamlines_length, seeds_fiber, + return_error_tractogram=False): + """ + For each provided streamline, finds the mean distance between its + coordinates and the fibertube it has been seeded in. + + Parameters + ---------- + centerlines: ndarray + Fixed array containing ground-truth fibertube centerlines. + centerlines_length: ndarray + Fixed array containing the number of coordinates of each fibertube + centerlines. + diameters: list, + Diameters of the fibertubes + streamlines: ndarray + Fixed array containing streamlines resulting from the tracking + process. + streamlines_length: ndarray, + Fixed array containing the number of coordinates of each streamline + seeds_fiber: list + Array of the same length as there are streamlines. For every + streamline, contains the index of the fiber in which it has been + seeded. + return_error_tractogram: bool = False + + Return + ------ + mean_errors: list + Array containing the mean error for every streamline. + error_tractogram: list + Empty when return_error_tractogram is set to False. Otherwise, + contains a visual representation of the error between every streamline + and the fiber in which it has been seeded. 
+ """ + mean_errors = [] + error_tractogram = [] + + # objmode allows the execution of non numba-compatible code within a numba + # function + with objmode(centers='float64[:, :]', indices='int64[:, :]'): + centers, indices, _ = streamlines_to_segments(centerlines, False) + centers_fixed_length = len(centerlines[0])-1 + + # Building a tree for first fibertube + tree = nbKDTree(centers[:centerlines_length[0]-1]) + tree_fi = 0 + + for si, streamline_fixed in enumerate(streamlines): + streamline = streamline_fixed[:streamlines_length[si]-1] + errors = [] + + seeded_fi = seeds_fiber[si] + fiber = centerlines[seeded_fi] + radius = diameters[seeded_fi] / 2 + + # Rebuild tree for current fiber. + if tree_fi != seeded_fi: + tree = nbKDTree( + centers[centers_fixed_length * seeded_fi: + (centers_fixed_length * seeded_fi + + centerlines_length[seeded_fi] - 1)]) + + # Querying nearest neighbor for each coordinate of the streamline. + neighbors = tree.query_parallel(streamline)[1] + + for pi, point in enumerate(streamline): + nearest_index = neighbors[pi][0] + + # Retrieving the closest cylinder segment. + _, pi = indices[nearest_index] + pt1 = fiber[pi] + pt2 = fiber[pi + 1] + + # If we're within the fiber, error = 0 + if (np.linalg.norm(point - pt1) < radius or + np.linalg.norm(point - pt2) < radius or + point_in_cylinder(pt1, pt2, radius, point)): + errors.append(0.) + else: + distance, vector, segment_collision_point = dist_point_segment( + pt1, pt2, point) + errors.append(distance - radius) + + if return_error_tractogram: + fiber_collision_point = segment_collision_point - ( + vector / np.linalg.norm(vector)) * radius + error_tractogram.append([fiber_collision_point, point]) + + mean_errors.append(np.array(errors).mean()) + + return mean_errors, error_tractogram + + +@njit +def endpoint_connectivity(blur_radius, centerlines, centerlines_length, + diameters, streamlines, seeds_fiber): + """ + For every streamline, find whether or not it has reached the end segment + of its fibertube. Each streamline is associated with an "Arrival fibertube + segment", which is the closest fibertube segment to its before-last + coordinate. + + IMPORTANT: Streamlines given as input to be scored should be forward-only, + which means they are saved so that [0] is the seeding position and [-1] is + the end. + + VC: "Valid Connection": A streamline whose arrival fibertube segment is + the final segment of the fibertube in which is was originally seeded. + + IC: "Invalid Connection": A streamline whose arrival fibertube segment is + the start or final segment of a fibertube in which is was not seeded. + + NC: "No Connection": A streamline whose arrival fibertube segment is + not the start or final segment of any fibertube. + + Parameters + ---------- + blur_radius: float + Blur radius used during fibertube tracking. + centerlines: ndarray + Fixed array containing ground-truth fibertube centerlines. + centerlines_length: ndarray + Fixed array containing the number of coordinates of each fibertube + centerlines. + diameters: list, + Diameters of the fibertubes. + streamlines: ndarray + Fixed array containing streamlines resulting from the tracking + process. + streamlines_length: ndarray, + Fixed array containing the number of coordinates of each streamline + seeds_fiber: list + Array of the same length as there are streamlines. For every + streamline, contains the index of the fibertube in which it has been + seeded. 
+ + Return + ------ + vc: list + List containing the indices of all streamlines that are valid + connections. + ic: list + List containing the indices of all streamlines that are invalid + connections. + nc: list + List containing the indices of all streamlines that are no + connections. + """ + max_diameter = np.max(diameters) + + # objmode allows the execution of non numba-compatible code within a numba + # function + with objmode(centers='float64[:, :]', indices='int64[:, :]', + max_seg_length='float64'): + centers, indices, max_seg_length = streamlines_to_segments( + centerlines, False) + + tree = nbKDTree(centers) + + vc = set() + ic = set() + nc = set() + + # streamline[-2] is the last point with a valid direction + all_neighbors = tree.query_radius( + streamlines[:, -2], blur_radius + max_seg_length / 2 + max_diameter) + + for streamline_index, streamline in enumerate(streamlines): + seed_fi = seeds_fiber[streamline_index] + neighbors = all_neighbors[streamline_index] + + closest_dist = np.inf + closest_seg = 0 + + # Finding closest segment + # There will always be a neighbor to override np.inf + for segment_index in neighbors: + fibertube_index = indices[segment_index][0] + point_index = indices[segment_index][1] + + dist, _, _ = dist_point_segment( + centerlines[fibertube_index][point_index], + centerlines[fibertube_index][point_index+1], + streamline[-2]) + + if dist < closest_dist: + closest_dist = dist + closest_seg = segment_index + + fibertube_index = indices[closest_seg][0] + point_index = indices[closest_seg][1] + + # If the closest segment is the last of its centerlines + if point_index == centerlines_length[fibertube_index]-1: + if fibertube_index == seed_fi: + vc.add(streamline_index) + else: + ic.add(streamline_index) + elif point_index == 0: + ic.add(streamline_index) + else: + nc.add(streamline_index) + + return list(vc), list(ic), list(nc) diff --git a/scilpy/tractanalysis/fixel_density.py b/scilpy/tractanalysis/fixel_density.py index 55216fb5f..3773ace84 100644 --- a/scilpy/tractanalysis/fixel_density.py +++ b/scilpy/tractanalysis/fixel_density.py @@ -7,11 +7,12 @@ def _fixel_density_parallel(args): - peaks = args[0] - max_theta = args[1] - dps_key = args[2] - bundle = args[3] + (peaks, max_theta, dps_key, bundle) = args + return _fixel_density_single_bundle(bundle, peaks, max_theta, dps_key) + + +def _fixel_density_single_bundle(bundle, peaks, max_theta, dps_key): sft = load_tractogram(bundle, 'same') sft.to_vox() sft.to_corner() @@ -83,14 +84,22 @@ def fixel_density(peaks, bundles, dps_key=None, max_theta=45, if nbr_processes is None or nbr_processes <= 0 \ else nbr_processes - pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_fixel_density_parallel, - zip(itertools.repeat(peaks), - itertools.repeat(max_theta), - itertools.repeat(dps_key), - bundles)) - pool.close() - pool.join() + # Separating the case nbr_processes=1 to help get good coverage metrics + # (codecov does not deal well with multiprocessing) + if nbr_processes == 1: + results = [] + for b in bundles: + results.append( + _fixel_density_single_bundle(b, peaks, max_theta, dps_key)) + else: + pool = multiprocessing.Pool(nbr_processes) + results = pool.map(_fixel_density_parallel, + zip(itertools.repeat(peaks), + itertools.repeat(max_theta), + itertools.repeat(dps_key), + bundles)) + pool.close() + pool.join() fixel_density = np.moveaxis(np.asarray(results), 0, -1) diff --git a/scilpy/tractanalysis/mrds_along_streamlines.py b/scilpy/tractanalysis/mrds_along_streamlines.py new file mode 
100644 index 000000000..91bdeca14 --- /dev/null +++ b/scilpy/tractanalysis/mrds_along_streamlines.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +from scilpy.tractanalysis.grid_intersections import grid_intersections + + +def mrds_metrics_along_streamlines(sft, mrds_pdds, + metrics, max_theta, + length_weighting): + """ + Compute mean map for a given fixel-specific metric along streamlines. + + Parameters + ---------- + sft : StatefulTractogram + StatefulTractogram containing the streamlines needed. + mrds_pdds : ndarray (X, Y, Z, 3*N_TENSORS) + MRDS principal diffusion directions of the tensors + metrics : list of ndarray + Array of shape (X, Y, Z, N_TENSORS) containing the fixel-specific + metric of interest. + max_theta : float + Maximum angle in degrees between the fiber direction and the + MRDS principal diffusion direction. + length_weighting : bool + If True, will weigh the metric values according to segment lengths. + """ + + mrds_sum, weights = \ + mrds_metric_sums_along_streamlines(sft, mrds_pdds, + metrics, max_theta, + length_weighting) + + all_metric = mrds_sum[0] + for curr_metric in mrds_sum[1:]: + all_metric += np.abs(curr_metric) + + non_zeros = np.nonzero(all_metric) + weights_nz = weights[non_zeros] + for metric_idx in range(len(metrics)): + mrds_sum[metric_idx][non_zeros] /= weights_nz + + return mrds_sum + + +def mrds_metric_sums_along_streamlines(sft, mrds_pdds, metrics, + max_theta, length_weighting): + """ + Compute a sum map along a bundle for a given fixel-specific metric. + + Parameters + ---------- + sft : StatefulTractogram + StatefulTractogram containing the streamlines needed. + mrds_pdds : ndarray (X, Y, Z, 3*N_TENSORS) + MRDS principal diffusion directions (PDDs) of the tensors + metrics : list of ndarray (X, Y, Z, N_TENSORS) + Fixel-specific metrics. + max_theta : float + Maximum angle in degrees between the fiber direction and the + MRDS principal diffusion direction. + length_weighting : bool + If True, will weight the metric values according to segment lengths. + + Returns + ------- + metric_sum_map : np.array + fixel-specific metrics sum map. + weight_map : np.array + Segment lengths. + """ + + sft.to_vox() + sft.to_corner() + + X, Y, Z = metrics[0].shape[0:3] + metrics_sum_map = np.zeros((len(metrics), X, Y, Z)) + weight_map = np.zeros(metrics[0].shape[:-1]) + min_cos_theta = np.cos(np.radians(max_theta)) + + all_crossed_indices = grid_intersections(sft.streamlines) + for crossed_indices in all_crossed_indices: + segments = crossed_indices[1:] - crossed_indices[:-1] + seg_lengths = np.linalg.norm(segments, axis=1) + + # Remove points where the segment is zero. + # This removes numpy warnings of division by zero. 
+ non_zero_lengths = np.nonzero(seg_lengths)[0] + segments = segments[non_zero_lengths] + seg_lengths = seg_lengths[non_zero_lengths] + + # Those starting points are used for the segment vox_idx computations + seg_start = crossed_indices[non_zero_lengths] + vox_indices = (seg_start + (0.5 * segments)).astype(int) + + normalization_weights = np.ones_like(seg_lengths) + if length_weighting: + normalization_weights = seg_lengths + + normalized_seg = np.reshape(segments / seg_lengths[..., None], (-1, 3)) + + # Reshape MRDS PDDs + mrds_pdds = mrds_pdds.reshape(mrds_pdds.shape[0], + mrds_pdds.shape[1], + mrds_pdds.shape[2], -1, 3) + + for vox_idx, seg_dir, norm_weight in zip(vox_indices, + normalized_seg, + normalization_weights): + vox_idx = tuple(vox_idx) + + mrds_peak_dir = mrds_pdds[vox_idx] + + cos_theta = np.abs(np.dot(seg_dir.reshape((-1, 3)), + mrds_peak_dir.T)) + + metric_val = [0.0]*len(metrics) + if (cos_theta > min_cos_theta).any(): + fixel_idx = np.argmax(np.squeeze(cos_theta), + axis=0) # (n_segs) + + for metric_idx, curr_metric in enumerate(metrics): + metric_val[metric_idx] = curr_metric[vox_idx][fixel_idx] + + for metric_idx, curr_metric in enumerate(metrics): + metrics_sum_map[metric_idx][vox_idx] += metric_val[metric_idx] * norm_weight + weight_map[vox_idx] += norm_weight + + return metrics_sum_map, weight_map diff --git a/scilpy/tractograms/intersection_finder.py b/scilpy/tractograms/intersection_finder.py new file mode 100644 index 000000000..72d805ee6 --- /dev/null +++ b/scilpy/tractograms/intersection_finder.py @@ -0,0 +1,232 @@ +import time +import math +import logging +import numpy as np + +from scipy.spatial import KDTree +from scilpy.tracking.fibertube_utils import (streamlines_to_segments, + dist_segment_segment) +from dipy.io.stateful_tractogram import StatefulTractogram +from scilpy.tracking.utils import tqdm_if_verbose + + +class IntersectionFinder: + """ + Utility class for finding intersections in a given StatefulTractogram with + a diameter for each streamline. + """ + + FLOAT_EPSILON = 1e-7 + + def __init__(self, in_sft: StatefulTractogram, diameters: list, + verbose=False): + """ + Builds a KDTree from all the tractogram's segments + and stores data required later for filtering. + + Parameters + ---------- + in_sft : StatefulTractogram + Stateful Tractogram object containing streamlines to filter. + diameters : list + Diameters of each streamline of the tractogram. + verbose : bool + Should produce verbose output. + """ + self.diameters = diameters + self.max_diameter = np.max(diameters) + self.verbose = verbose + self.in_sft = in_sft + self.streamlines = in_sft.streamlines + self.seg_centers, self.seg_indices, self.max_seg_length = ( + streamlines_to_segments(self.streamlines, verbose)) + self.tree = KDTree(self.seg_centers) + + self._invalid = [] + self._collisions = [] + self._obstacle = [] + self._excluded = [] + + if self.max_seg_length >= 0.3: + logging.warning("The longest streamline segment is over 0.3mm. " + + "Performance may drop significantly. " + + "Resampling to ~0.2mm is recommended. " + "(See scil_tractogram_resample_nb_points.py)") + + @property + def invalid(self): + """Streamlines that hit another streamline and should be + filtered out.""" + return self._invalid + + @property + def collisions(self): + """Collision point of each invalid streamline.""" + return self._collisions + + @property + def obstacle(self): + """Streamlines hit by an invalid streamline. 
They should not + be filtered and are saved separately merely for visualization.""" + return self._obstacle + + @property + def excluded(self): + """Streamlines that don't collide, but should be excluded for + other reasons.""" + return self._excluded + + def find_intersections(self, min_distance=0): + """ + Finds intersections within the initialized data of the object + + Produces and stores: + invalid : ndarray[bool] + Bit map identifying streamlines that hit another streamline + and should be filtered out. + collisions : ndarray[float32] + Collision point of each collider. + obstacle : ndarray[bool] + Streamlines hit by invalid. They should not be filtered and + are flagged simply for visualization. + excluded : ndarray[bool] + Streamlines that don't collide, but should be excluded for + other reasons. (ex: distance does not respect min_distance) + + Parameters + ---------- + min_distance: float + If set, streamlines will be filtered more + aggressively so that even if they don\'t collide, + being below [min_distance] apart (external to their + diameter) will be interpreted as a collision. This + option is the same as filtering with a large diameter + but only saving a small diameter in out_tractogram. + (Value in mm) + """ + start_time = time.time() + streamlines = self.streamlines + + invalid = np.full((len(streamlines)), False, dtype=np.bool_) + collisions = np.zeros((len(streamlines), 3), dtype=np.float32) + obstacle = np.full((len(streamlines)), False, dtype=np.bool_) + excluded = np.full((len(streamlines)), False, dtype=np.bool_) + + # si : Streamline Index | index of streamline within the tractogram. + # pi : Point Index | index of point coordinate within a + # streamline. + # segi : Segment Index | index of streamline segment within the + # entire tractogram. + for segi, center in tqdm_if_verbose(enumerate(self.seg_centers), + self.verbose, + total=len(self.seg_centers)): + si = self.seg_indices[segi][0] + + # [Pruning 1] If current streamline has already collided or been + # excluded, skip. + if invalid[si] or excluded[si]: + continue + + neighbors = self.tree.query_ball_point( + center, + self.max_seg_length + self.max_diameter + min_distance, + workers=-1) + + for neighbor_segi in neighbors: + neighbor_si = self.seg_indices[neighbor_segi][0] + + # [Pruning 2] Skip if neighbor is our streamline + if neighbor_si == si: + continue + + # [Pruning 3] If neighbor has already collided or been + # excluded, skip. 
+ if invalid[neighbor_si] or excluded[neighbor_si]: + continue + + p0 = streamlines[si][self.seg_indices[segi][1]] + p1 = streamlines[si][self.seg_indices[segi][1] + 1] + q0 = streamlines[neighbor_si][ + self.seg_indices[neighbor_segi][1]] + q1 = streamlines[neighbor_si][ + self.seg_indices[neighbor_segi][1] + 1] + + rp = self.diameters[si] / 2 + rq = self.diameters[neighbor_si] / 2 + + distance, _, p_coll, q_coll = dist_segment_segment(p0, p1, + q0, q1) + external_distance = distance - rp - rq + + if external_distance < 0: + invalid[si] = True + # Rough estimate of collision point + collisions[si] = (p_coll + q_coll) / 2 + obstacle[neighbor_si] = True + break + if min_distance != 0 and external_distance < min_distance: + excluded[si] = True + break + + logging.debug("Finished finding intersections in " + + str(round(time.time() - start_time, 2)) + " seconds.") + + self._invalid = invalid + self._collisions = collisions + self._obstacle = obstacle + self._excluded = excluded + + def build_tractograms(self, save_colliding): + """ + Builds and saves the various tractograms obtained from + find_intersections(). + + Parameters + ---------- + save_colliding: bool + If set, will return invalid_sft and obstacle_sft in addition to + out_sft. + + Return + ------ + out_sft: StatefulTractogram + Tractogram containing final streamlines void of collision. + invalid_sft: StatefulTractogram | None + Tractogram containing the invalid streamlines that have been + removed. + obstacle_sft: StatefulTractogram | None + Tractogram containing the streamlines that the invalid + streamlines collided with. May or may not have been removed + afterwards during filtering. + """ + out_streamlines = [] + out_diameters = [] + out_collisions = [] + out_invalid = [] + out_obstacle = [] + + for si, s in tqdm_if_verbose(enumerate(self.streamlines), self.verbose, + total=len(self.streamlines)): + if self._invalid[si]: + out_invalid.append(s) + out_collisions.append(self._collisions[si]) + elif not self._excluded[si]: + out_streamlines.append(s) + out_diameters.append(self.diameters[si]) + if self._obstacle[si]: + out_obstacle.append(s) + + out_sft = StatefulTractogram.from_sft( + out_streamlines, self.in_sft, + data_per_streamline={'diameters': out_diameters}) + if save_colliding: + invalid_sft = StatefulTractogram.from_sft( + out_invalid, + self.in_sft, + data_per_streamline={'collisions': out_collisions}) + obstacle_sft = StatefulTractogram.from_sft( + out_obstacle, + self.in_sft) + return out_sft, invalid_sft, obstacle_sft + + return out_sft, None, None diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index 2036fd98d..5249b80a8 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -971,3 +971,30 @@ def get_streamlines_bounding_box(streamlines): box_max = np.maximum(box_max, np.max(s, axis=0)) return box_min, box_max + + +def get_streamlines_as_fixed_array(streamlines): + """ + Obtain streamlines as a fixed array of shape + (nbr of streamlines, max streamline length, 3). + + Useful for optimization with code precompiling. (See Numba) + + Parameters + ---------- + streamlines: list + + Return + ------ + streamlines_fixed: ndarray + Streamlines as a fixed length array, padded with 0. + lengths: ndarray + Single dimensional array of all the streamline lengths. 
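+
+    Example (illustrative): two streamlines of 3 and 5 points give a
+    streamlines_fixed of shape (2, 5, 3) and lengths of [3, 5]; only the
+    first lengths[i] points of entry i are meaningful.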
+ """ + lengths = [len(streamline) for streamline in streamlines] + streamlines_fixed = np.ndarray((len(streamlines), max(lengths), 3)) + for i, f in enumerate(streamlines_fixed): + for j, c in enumerate(streamlines[i]): + f[j] = c + + return streamlines_fixed, np.array(lengths) diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index 60b67c394..ef71b29be 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -154,6 +154,10 @@ def flip_sft(sft, flip_axes): ------- flipped_sft: StatefulTractogram """ + old_space = sft.space + old_origin = sft.origin + sft.to_vox() + sft.to_corner() if len(flip_axes) == 0: # Could return sft. But creating new SFT (or deep copy). flipped_streamlines = sft.streamlines @@ -174,6 +178,11 @@ def flip_sft(sft, flip_axes): data_per_point=sft.data_per_point, data_per_streamline=sft.data_per_streamline) + sft.to_space(old_space) + sft.to_origin(old_origin) + new_sft.to_space(old_space) + new_sft.to_origin(old_origin) + return new_sft diff --git a/scripts/scil_bundle_fixel_analysis.py b/scripts/scil_bundle_fixel_analysis.py old mode 100644 new mode 100755 diff --git a/scripts/scil_bundle_mean_fixel_mrds_metric.py b/scripts/scil_bundle_mean_fixel_mrds_metric.py new file mode 100755 index 000000000..8874a2e6b --- /dev/null +++ b/scripts/scil_bundle_mean_fixel_mrds_metric.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Given a bundle and MRDS metrics, compute the fixel-specific +metrics at each voxel intersected by the bundle. Intersected voxels are +found by computing the intersection between the voxel grid and each streamline +in the input tractogram. + +This script behaves like scil_bundle_mean_fixel_afd.py for fODFs, +but here for MRDS metrics. These latest distributions add the unique +possibility to capture fixel-based fractional anisotropy (fixel-FA), mean +diffusivity (fixel-MD), radial diffusivity (fixel-RD) and +axial diffusivity (fixel-AD). + +Fixel-specific metrics are metrics extracted from +Multi-Resolution Discrete-Search (MRDS) solutions. +There are as many values per voxel as there are fixels extracted. The +values chosen for a given voxel is the one belonging to the lobe better aligned +with the current streamline segment. + +Input files come from scil_mrds_metrics.py command. + +Output metrics will be named: [prefix]_mrds_[metric_name].nii.gz + +Please use a bundle file rather than a whole tractogram. 
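+
+Example run (file names are hypothetical placeholders; the options shown are
+the ones defined below):
+>> scil_bundle_mean_fixel_mrds_metric.py AF_left.trk mrds_pdds.nii.gz
+    --fa mrds_fa.nii.gz --md mrds_md.nii.gz --prefix AF_left
+
+With these options, the outputs would be named AF_left_mrds_fFA.nii.gz and
+AF_left_mrds_fMD.nii.gz.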
+""" + +import argparse + +import nibabel as nib +import numpy as np + +from scilpy.io.streamlines import load_tractogram_with_reference +from scilpy.io.utils import (add_overwrite_arg, + add_reference_arg, + assert_headers_compatible, + assert_inputs_exist, assert_outputs_exist) +from scilpy.tractanalysis.mrds_along_streamlines \ + import mrds_metrics_along_streamlines + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_bundle', + help='Path of the bundle file.') + p.add_argument('in_pdds', + help='Path of the MRDS PDDs volume.') + + g = p.add_argument_group(title='MRDS metrics input') + g.add_argument('--fa', + help='Path of the fixel-specific metric FA volume.') + g.add_argument('--md', + help='Path of the fixel-specific metric MD volume.') + g.add_argument('--rd', + help='Path of the fixel-specific metric RD volume.') + g.add_argument('--ad', + help='Path of the fixel-specific metric AD volume.') + + p.add_argument('--prefix', default='result', + help='Prefix of the MRDS fixel results.') + + p.add_argument('--length_weighting', action='store_true', + help='If set, will weight the values according to ' + 'segment lengths. [%(default)s]') + + p.add_argument('--max_theta', default=60, type=float, + help='Maximum angle (in degrees) condition on fixel ' + 'alignment. [%(default)s]') + + add_reference_arg(p) + add_overwrite_arg(p) + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + in_metrics = [] + out_metrics = [] + if args.fa is not None: + in_metrics.append(args.fa) + out_metrics.append('{}_mrds_fFA.nii.gz'.format(args.prefix)) + if args.ad is not None: + in_metrics.append(args.ad) + out_metrics.append('{}_mrds_fAD.nii.gz'.format(args.prefix)) + if args.rd is not None: + in_metrics.append(args.rd) + out_metrics.append('{}_mrds_fRD.nii.gz'.format(args.prefix)) + if args.md is not None: + in_metrics.append(args.md) + out_metrics.append('{}_mrds_fMD.nii.gz'.format(args.prefix)) + + if in_metrics == []: + parser.error('At least one metric is required.') + + assert_inputs_exist(parser, [args.in_bundle, + args.in_pdds], in_metrics) + assert_headers_compatible(parser, [args.in_bundle, args.in_pdds], + in_metrics) + + assert_outputs_exist(parser, args, out_metrics) + + sft = load_tractogram_with_reference(parser, args, args.in_bundle) + pdds_img = nib.load(args.in_pdds) + affine = pdds_img.affine + header = pdds_img.header + + in_metrics_data = [nib.load(metric).get_fdata(dtype=np.float32) for metric in in_metrics] + fixel_metrics =\ + mrds_metrics_along_streamlines(sft, + pdds_img.get_fdata(dtype=np.float32), + in_metrics_data, + args.max_theta, + args.length_weighting) + + for metric_id, curr_metric in enumerate(fixel_metrics): + nib.Nifti1Image(curr_metric.astype(np.float32), + affine=affine, + header=header, + dtype=np.float32).to_filename(out_metrics[metric_id]) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_connectivity_compute_matrices.py b/scripts/scil_connectivity_compute_matrices.py index 10a200769..f28a1730b 100755 --- a/scripts/scil_connectivity_compute_matrices.py +++ b/scripts/scil_connectivity_compute_matrices.py @@ -3,273 +3,135 @@ """ This script computes a variety of measures in the form of connectivity -matrices. This script is made to follow -scil_tractogram_segment_connections_from_labels.py and -uses the same labels list as input. 
- -The script expects a folder containing all relevants bundles following the -naming convention LABEL1_LABEL2.trk and a text file containing the list of -labels that should be part of the matrices. The ordering of labels in the -matrices will follow the same order as the list. -This script only generates matrices in the form of array, does not visualize -or reorder the labels (node). - -The parameter --similarity expects a folder with density maps -(LABEL1_LABEL2.nii.gz) following the same naming convention as the input -directory. -The bundles should be averaged version in the same space. This will -compute the weighted-dice between each node and their homologuous average -version. - -The parameters --metrics can be used more than once and expect a map (t1, fa, -etc.) in the same space and each will generate a matrix. The average value in -the volume occupied by the bundle will be reported in the matrices nodes. - -The parameters --maps can be used more than once and expect a folder with -pre-computed maps (LABEL1_LABEL2.nii.gz) following the same naming convention -as the input directory. Each will generate a matrix. The average non-zeros -value in the map will be reported in the matrices nodes. - -The parameters --lesion_load will compute 3 lesion(s) related matrices: -lesion_count.npy, lesion_vol.npy, lesion_sc.npy and put it inside of a -specified folder. They represent the number of lesion, the total volume of -lesion(s) and the total of streamlines going through the lesion(s) for of each -connection. Each connection can be seen as a 'bundle' and then something -similar to scil_analyse_lesion_load.py is run for each 'bundle'. +matrices. This script only generates matrices in the form of array, it does not +visualize or reorder the labels (node). + +See also +>> scil_connectivity_compute_simple_matrix.py +which simply computes the connectivity matrix (either binary or with the +streamline count), directly from the endpoints. + +In comparison, the current script A) uses more complex segmentation, and +B) outputs more matrices, using various metrics, + +A) Connections segmentations +---------------------------- +Segmenting a tractogram based on its endpoints is not as straighforward as one +could imagine. The endpoints could be outside any labelled region. This script +is made to follow +>> scil_tractogram_segment_connections_from_labels.py, +which already carefully segmented the connections. + +The current script uses 1) the same labels list as input, 2) the resulting +pre-segmented tractogram in the hdf5 format, and 3) a text file containing the +list of labels that should be part of the matrices. The ordering of labels in +the matrices will follow the same order as the list. + +B) Outputs +---------- +Each connection can be seen as a 'bundle'. + + - Streamline count. + - Length: mean streamline length (mm). + - Volume-weighted: Volume of the bundle. + - Similarity: mean density. + Uses pre-computed density maps, which can be obtained with + >> scil_connectivity_hdf5_average_density_map.py + The bundles should be averaged version in the same space. This will + compute the weighted-dice between each node and their homologuous average + version. + - Any metric: You can provide your own maps through --metrics. The average + non-zero value in the volume occupied by the bundle will be reported in + the matrices nodes. 
+ Ex: --metrics FA.niigz fa.npy --metrics T1.nii.gz t1.npy + - Lesions-related metrics: The option --lesion_load will compute 3 + lesion(s)-related matrices (saved in the chosen output directory): + lesion_count.npy, lesion_vol.npy, and lesion_sc.npy. They represent the + number of lesion, the total volume of lesion(s) and the total number of + streamlines going through the lesion(s) for each bundle. See also: + >> scil_analyse_lesion_load.py + >> scil_lesions_info.py + - Mean DPS: Mean values in the data_per_streamline of each streamline in the + bundles. Formerly: scil_compute_connectivity.py """ import argparse -import copy import itertools import logging import multiprocessing import os import coloredlogs -from dipy.io.utils import is_header_compatible, get_reference_info -from dipy.tracking.streamlinespeed import length -from dipy.tracking.vox2track import _streamlines_in_mask import h5py import nibabel as nib import numpy as np import scipy.ndimage as ndi +from scilpy.connectivity.connectivity import \ + compute_connectivity_matrices_from_hdf5, \ + multi_proc_compute_connectivity_matrices_from_hdf5 from scilpy.image.labels import get_data_as_labels -from scilpy.io.hdf5 import (assert_header_compatible_hdf5, - reconstruct_streamlines_from_hdf5) +from scilpy.io.hdf5 import assert_header_compatible_hdf5 from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - validate_nbr_processes) -from scilpy.tractanalysis.reproducibility_measures import \ - compute_bundle_adjacency_voxel -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -from scilpy.utils.metrics_tools import compute_lesion_stats - - -def load_node_nifti(directory, in_label, out_label, ref_img): - in_filename = os.path.join(directory, - '{}_{}.nii.gz'.format(in_label, out_label)) - - if os.path.isfile(in_filename): - if not is_header_compatible(in_filename, ref_img): - raise IOError('{} do not have a compatible header'.format( - in_filename)) - return nib.load(in_filename).get_fdata(dtype=np.float64) - - return None - - -def _processing_wrapper(args): - hdf5_filename = args[0] - labels_img = args[1] - in_label, out_label = args[2] - measures_to_compute = copy.copy(args[3]) - if args[4] is not None: - similarity_directory = args[4][0] - weighted = args[5] - include_dps = args[6] - min_lesion_vol = args[7] - - hdf5_file = h5py.File(hdf5_filename, 'r') - key = '{}_{}'.format(in_label, out_label) - if key not in hdf5_file: - return - streamlines = reconstruct_streamlines_from_hdf5(hdf5_file[key]) - if len(streamlines) == 0: - return - - affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img) - measures_to_return = {} - assert_header_compatible_hdf5(hdf5_file, (affine, dimensions)) - - # Precompute to save one transformation, insert later - if 'length' in measures_to_compute: - streamlines_copy = list(streamlines) - # scil_tractogram_segment_connections_from_labels.py requires - # isotropic voxels - mean_length = np.average(length(streamlines_copy))*voxel_sizes[0] - - # If density is not required, do not compute it - # Only required for volume, similarity and any metrics - if not ((len(measures_to_compute) == 1 and - ('length' in measures_to_compute or - 'streamline_count' in measures_to_compute)) or - (len(measures_to_compute) == 2 and - ('length' in measures_to_compute and - 'streamline_count' in measures_to_compute))): - - density = compute_tract_counts_map(streamlines, - 
dimensions) - - if 'volume' in measures_to_compute: - measures_to_return['volume'] = np.count_nonzero(density) * \ - np.prod(voxel_sizes) - measures_to_compute.remove('volume') - if 'streamline_count' in measures_to_compute: - measures_to_return['streamline_count'] = len(streamlines) - measures_to_compute.remove('streamline_count') - if 'length' in measures_to_compute: - measures_to_return['length'] = mean_length - measures_to_compute.remove('length') - if 'similarity' in measures_to_compute and similarity_directory: - density_sim = load_node_nifti(similarity_directory, - in_label, out_label, - labels_img) - if density_sim is None: - ba_vox = 0 - else: - ba_vox = compute_bundle_adjacency_voxel(density, density_sim) - - measures_to_return['similarity'] = ba_vox - measures_to_compute.remove('similarity') - - for measure in measures_to_compute: - # Maps - if isinstance(measure, str) and os.path.isdir(measure): - map_dirname = measure - map_data = load_node_nifti(map_dirname, - in_label, out_label, - labels_img) - measures_to_return[map_dirname] = np.average( - map_data[map_data > 0]) - elif isinstance(measure, tuple): - if not isinstance(measure[0], tuple) \ - and os.path.isfile(measure[0]): - metric_filename = measure[0] - metric_img = measure[1] - if not is_header_compatible(metric_img, labels_img): - logging.error('{} do not have a compatible header'.format( - metric_filename)) - raise IOError - - metric_data = metric_img.get_fdata(dtype=np.float64) - if weighted: - avg_value = np.average(metric_data, weights=density) - else: - avg_value = np.average(metric_data[density > 0]) - measures_to_return[metric_filename] = avg_value - # lesion - else: - lesion_filename = measure[0][0] - computed_lesion_labels = measure[0][1] - lesion_img = measure[1] - if not is_header_compatible(lesion_img, labels_img): - logging.error('{} do not have a compatible header'.format( - lesion_filename)) - raise IOError - - voxel_sizes = lesion_img.header.get_zooms()[0:3] - lesion_img.set_filename('tmp.nii.gz') - lesion_atlas = get_data_as_labels(lesion_img) - tmp_dict = compute_lesion_stats( - density.astype(bool), lesion_atlas, - voxel_sizes=voxel_sizes, single_label=True, - min_lesion_vol=min_lesion_vol, - precomputed_lesion_labels=computed_lesion_labels) - - tmp_ind = _streamlines_in_mask(list(streamlines), - lesion_atlas.astype(np.uint8), - np.eye(3), [0, 0, 0]) - streamlines_count = len( - np.where(tmp_ind == [0, 1][True])[0].tolist()) - - if tmp_dict: - measures_to_return[lesion_filename+'vol'] = \ - tmp_dict['lesion_total_volume'] - measures_to_return[lesion_filename+'count'] = \ - tmp_dict['lesion_count'] - measures_to_return[lesion_filename+'sc'] = \ - streamlines_count - else: - measures_to_return[lesion_filename+'vol'] = 0 - measures_to_return[lesion_filename+'count'] = 0 - measures_to_return[lesion_filename+'sc'] = 0 - - if include_dps: - for dps_key in hdf5_file[key].keys(): - if dps_key not in ['data', 'offsets', 'lengths']: - out_file = os.path.join(include_dps, dps_key) - if 'commit' in dps_key: - measures_to_return[out_file] = np.sum( - hdf5_file[key][dps_key]) - else: - measures_to_return[out_file] = np.average( - hdf5_file[key][dps_key]) - - return {(in_label, out_label): measures_to_return} + validate_nbr_processes, assert_inputs_dirs_exist, + assert_headers_compatible, + assert_output_dirs_exist_and_empty) def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, - formatter_class=argparse.RawTextHelpFormatter,) + formatter_class=argparse.RawTextHelpFormatter) 
p.add_argument('in_hdf5', - help='Input filename for the hdf5 container (.h5).\n' - 'Obtained from ' - 'scil_tractogram_segment_connections_from_labels.py.') + help='Input filename for the hdf5 container (.h5).') p.add_argument('in_labels', help='Labels file name (nifti).\n' - 'This generates a NxN connectivity matrix.') - p.add_argument('--volume', metavar='OUT_FILE', - help='Output file for the volume weighted matrix (.npy).') - p.add_argument('--streamline_count', metavar='OUT_FILE', + 'This generates a NxN connectivity matrix, where N \n' + 'is the number of values in in_labels.') + + g = p.add_argument_group("Output matrices options") + g.add_argument('--volume', metavar='OUT_FILE', + help='Output file for the volume weighted matrix (.npy), ' + 'computed in mm3.') + g.add_argument('--streamline_count', metavar='OUT_FILE', help='Output file for the streamline count weighted matrix ' '(.npy).') - p.add_argument('--length', metavar='OUT_FILE', - help='Output file for the length weighted matrix (.npy).') - p.add_argument('--similarity', nargs=2, + g.add_argument('--length', metavar='OUT_FILE', + help='Output file for the length weighted matrix (.npy), ' + 'weighted in mm.') + g.add_argument('--similarity', nargs=2, metavar=('IN_FOLDER', 'OUT_FILE'), help='Input folder containing the averaged bundle density\n' 'maps (.nii.gz) and output file for the similarity ' - 'weighted matrix (.npy).') - p.add_argument('--maps', nargs=2, action='append', - metavar=('IN_FOLDER', 'OUT_FILE'), - help='Input folder containing pre-computed maps (.nii.gz)\n' - 'and output file for the weighted matrix (.npy).') - p.add_argument('--metrics', nargs=2, action='append', + 'weighted matrix (.npy).\n' + 'The density maps should be named using the same ' + 'labels as in the hdf5 (LABEL1_LABEL2.nii.gz).') + g.add_argument('--metrics', nargs=2, action='append', default=[], metavar=('IN_FILE', 'OUT_FILE'), help='Input (.nii.gz). 
and output file (.npy) for a metric ' 'weighted matrix.') - p.add_argument('--lesion_load', nargs=2, metavar=('IN_FILE', 'OUT_DIR'), + g.add_argument('--lesion_load', nargs=2, metavar=('IN_FILE', 'OUT_DIR'), help='Input binary mask (.nii.gz) and output directory ' 'for all lesion-related matrices.') - p.add_argument('--min_lesion_vol', type=float, default=7, - help='Minimum lesion volume in mm3 [%(default)s].') - - p.add_argument('--density_weighting', action="store_true", - help='Use density-weighting for the metric weighted' - 'matrix.') - p.add_argument('--no_self_connection', action="store_true", - help='Eliminate the diagonal from the matrices.') - p.add_argument('--include_dps', metavar='OUT_DIR', + g.add_argument('--include_dps', metavar='OUT_DIR', help='Save matrices from data_per_streamline in the output ' 'directory.\nCOMMIT-related values will be summed ' 'instead of averaged.\nWill always overwrite files.') - p.add_argument('--force_labels_list', + + g = p.add_argument_group("Processing options") + g.add_argument('--min_lesion_vol', type=float, default=7, + help='Minimum lesion volume in mm3 [%(default)s].') + g.add_argument('--density_weighting', action="store_true", + help='Use density-weighting for the metric weighted ' + 'matrix.') + g.add_argument('--no_self_connection', action="store_true", + help='Eliminate the diagonal from the matrices.') + g.add_argument('--force_labels_list', help='Path to a labels list (.txt) in case of missing ' 'labels in the atlas.') @@ -280,86 +142,80 @@ def _build_arg_parser(): return p -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - coloredlogs.install(level=logging.getLevelName(args.verbose)) - - assert_inputs_exist(parser, [args.in_hdf5, args.in_labels], - args.force_labels_list) - - # Summarizing all options chosen by user in measures_to_compute. - measures_to_compute = [] - measures_output_filename = [] - if args.volume: - measures_to_compute.append('volume') - measures_output_filename.append(args.volume) - if args.streamline_count: - measures_to_compute.append('streamline_count') - measures_output_filename.append(args.streamline_count) - if args.length: - measures_to_compute.append('length') - measures_output_filename.append(args.length) - if args.similarity: - measures_to_compute.append('similarity') - measures_output_filename.append(args.similarity[1]) - - # Adding measures from pre-computed maps. - dict_maps_out_name = {} - if args.maps is not None: - for in_folder, out_name in args.maps: - measures_to_compute.append(in_folder) - dict_maps_out_name[in_folder] = out_name - measures_output_filename.append(out_name) - - # Adding measures from pre-computed metrics. - dict_metrics_out_name = {} +def check_inputs_outputs(parser, args): + optional_input_volumes = [] + optional_output_matrices = [args.volume, args.streamline_count, + args.length] + out_dirs = [args.include_dps] if args.metrics is not None: - for in_name, out_name in args.metrics: - # Verify that all metrics are compatible with each other - if not is_header_compatible(args.metrics[0][0], in_name): - raise IOError('Metrics {} and {} do not share a compatible ' - 'header'.format(args.metrics[0][0], in_name)) - - # This is necessary to support more than one map for weighting - measures_to_compute.append((in_name, nib.load(in_name))) - dict_metrics_out_name[in_name] = out_name - measures_output_filename.append(out_name) - - # Adding measures from lesions. 
- dict_lesion_out_name = {} + optional_input_volumes.extend([m[0] for m in args.metrics]) + optional_output_matrices.extend([m[1] for m in args.metrics]) if args.lesion_load is not None: - in_name = args.lesion_load[0] - lesion_img = nib.load(in_name) - lesion_data = get_data_as_mask(lesion_img, dtype=bool) - lesion_atlas, _ = ndi.label(lesion_data) - measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]), - nib.Nifti1Image(lesion_atlas, - lesion_img.affine))) + optional_input_volumes.append(args.lesion_load[0]) + optional_output_matrices.append(args.lesion_load[1]) + if args.similarity is not None: + optional_output_matrices.append(args.similarity[1]) + # Note. Inputs in the --similarity folder are not checked yet! + # But at least checking that the folder exists + assert_inputs_dirs_exist(parser, [], args.similarity[0]) + if args.lesion_load is not None: + out_dirs.append(args.lesion_load[1]) + + # Inputs + assert_inputs_exist(parser, [args.in_hdf5, args.in_labels], + [args.force_labels_list] + optional_input_volumes) + + # Headers + assert_headers_compatible(parser, args.in_labels, optional_input_volumes) + with h5py.File(args.in_hdf5, 'r') as hdf5: + vol = nib.load(args.in_labels) + assert_header_compatible_hdf5(hdf5, vol) + + # Outputs + assert_outputs_exist(parser, args, [], optional_output_matrices) + assert_output_dirs_exist_and_empty(parser, args, [], out_dirs) + for m in optional_output_matrices: + if m is not None and m[-4:] != '.npy': + parser.error("Expecting .npy for the output matrix, got: {}" + .format(m)) + + +def fill_matrix_and_save(measures_dict, labels_list, measure_keys, filenames): + matrix = np.zeros((len(labels_list), len(labels_list), len(measure_keys))) + + # Run one loop on node. Fill all matrices at once. + for label_key, node_values in measures_dict.items(): + in_label, out_label = label_key + for i, measure_key in enumerate(measure_keys): + in_pos = labels_list.index(in_label) + out_pos = labels_list.index(out_label) + matrix[in_pos, out_pos, i] = node_values[measure_key] + matrix[out_pos, in_pos, i] = node_values[measure_key] - out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy') - out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy') - out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy') + for i, f in enumerate(filenames): + logging.info("Saving resulting {} in file {}" + .format(measure_keys[i], f)) + np.save(f, matrix[:, :, i]) - dict_lesion_out_name[in_name+'vol'] = out_name_1 - dict_lesion_out_name[in_name+'count'] = out_name_2 - dict_lesion_out_name[in_name+'sc'] = out_name_3 - measures_output_filename.extend([out_name_1, out_name_2, out_name_3]) - # Verifying all outputs that will be used for all measures. 
- assert_outputs_exist(parser, args, measures_output_filename) - if not measures_to_compute: - parser.error('No connectivity measures were selected, nothing ' - 'to compute.') +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + coloredlogs.install(level=logging.getLevelName(args.verbose)) - logging.info('The following measures will be computed and save: {}'.format( - measures_output_filename)) + # Verifications + check_inputs_outputs(parser, args) - if args.include_dps: - if not os.path.isdir(args.include_dps): - os.makedirs(args.include_dps) - logging.info('data_per_streamline weighting is activated.') + # Verifying that at least one option is selected + compute_volume = args.volume is not None + compute_streamline_count = args.streamline_count is not None + compute_length = args.length is not None + similarity_directory = args.similarity[0] if args.similarity else None + if not (compute_volume or compute_streamline_count or compute_length or + similarity_directory is not None or len(args.metrics) > 0 or + args.lesion_load is not None or args.include_dps): + parser.error("Please select at least one output matrix to compute.") # Loading the data img_labels = nib.load(args.in_labels) @@ -369,6 +225,29 @@ def main(): else: labels_list = np.loadtxt( args.force_labels_list, dtype=np.int16).tolist() + logging.info("Found {} labels.".format(len(labels_list))) + + # Not preloading the similarity (density) files, as there are many + # (one per node). Can be loaded and discarded when treating each node. + + # Preloading the metrics here (FA, T1) to avoid reloading for each + # node! But if there are many metrics, this could be heavy to keep in + # memory, especially if multiprocessing is used. Still probably better. + metrics_data = [] + metrics_names = [] + for m in args.metrics: + metrics_names.append(m[1]) + metrics_data.append(nib.load(m[0]).get_fdata(dtype=np.float64)) + + # Preloading the lesion file + lesion_data = None + if args.lesion_load is not None: + lesion_img = nib.load(args.lesion_load[0]) + lesion_data = get_data_as_mask(lesion_img, dtype=bool) + lesion_atlas, _ = ndi.label(lesion_data) + lesion_labels = np.unique(lesion_atlas)[1:] + atlas_img = nib.Nifti1Image(lesion_atlas, lesion_img.affine) + lesion_data = (lesion_labels, atlas_img) # Finding all connectivity combo (start-finish) comb_list = list(itertools.combinations(labels_list, r=2)) @@ -377,87 +256,88 @@ def main(): # Running everything! 
nbr_cpu = validate_nbr_processes(parser, args) - measures_dict_list = [] + outputs = [] if nbr_cpu == 1: for comb in comb_list: - measures_dict_list.append(_processing_wrapper( - [args.in_hdf5, - img_labels, comb, - measures_to_compute, - args.similarity, - args.density_weighting, - args.include_dps, - args.min_lesion_vol])) + outputs.append(compute_connectivity_matrices_from_hdf5( + args.in_hdf5, img_labels, comb[0], comb[1], + compute_volume, compute_streamline_count, compute_length, + similarity_directory, metrics_data, metrics_names, + lesion_data, args.include_dps, args.density_weighting, + args.min_lesion_vol)) else: pool = multiprocessing.Pool(nbr_cpu) - measures_dict_list = pool.map(_processing_wrapper, - zip(itertools.repeat(args.in_hdf5), - itertools.repeat(img_labels), - comb_list, - itertools.repeat( - measures_to_compute), - itertools.repeat(args.similarity), - itertools.repeat( - args.density_weighting), - itertools.repeat(args.include_dps), - itertools.repeat(args.min_lesion_vol) - )) + + # Dividing the process bundle by bundle + outputs = pool.map( + multi_proc_compute_connectivity_matrices_from_hdf5, + zip(itertools.repeat(args.in_hdf5), + itertools.repeat(img_labels), + comb_list, + itertools.repeat(compute_volume), + itertools.repeat(compute_streamline_count), + itertools.repeat(compute_length), + itertools.repeat(similarity_directory), + itertools.repeat(metrics_data), + itertools.repeat(metrics_names), + itertools.repeat(lesion_data), + itertools.repeat(args.include_dps), + itertools.repeat(args.density_weighting), + itertools.repeat(args.min_lesion_vol) + )) pool.close() pool.join() # Removing None entries (combinaisons that do not exist) - # Fusing the multiprocessing output into a single dictionary - measures_dict_list = [it for it in measures_dict_list if it is not None] - if not measures_dict_list: - raise ValueError('Empty matrix, no entries to save.') - measures_dict = measures_dict_list[0] - for dix in measures_dict_list[1:]: - measures_dict.update(dix) - - if args.no_self_connection: - total_elem = len(labels_list)**2 - len(labels_list) - results_elem = len(measures_dict.keys())*2 - len(labels_list) - else: - total_elem = len(labels_list)**2 - results_elem = len(measures_dict.keys())*2 + outputs = [it for it in outputs if it is not None] + if len(outputs) == 0: + raise ValueError('No connection found at all! Matrices would be ' + 'all-zeros. Exiting.') - logging.info('Out of {} possible nodes, {} contain value'.format( - total_elem, results_elem)) + measures_dict_list = [it[0] for it in outputs] + dps_keys = [it[1] for it in outputs] - # Filling out all the matrices (symmetric) in the order of labels_list - nbr_of_measures = len(list(measures_dict.values())[0]) - matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures)) + # Verify that all bundles had the same dps_keys + if len(dps_keys) > 1 and not dps_keys[1:] == dps_keys[:-1]: + raise ValueError("DPS keys not consistant throughout the hdf5 " + "connections. 
Verify your tractograms, or do not " + "use --include_dps.") + dps_keys = dps_keys[0] - for in_label, out_label in measures_dict: - curr_node_dict = measures_dict[(in_label, out_label)] - measures_ordering = list(curr_node_dict.keys()) + # Fusing the multiprocessing output into a single dictionary + measures_dict = {} + for node in measures_dict_list: + measures_dict.update(node) - for i, measure in enumerate(curr_node_dict): - in_pos = labels_list.index(in_label) - out_pos = labels_list.index(out_label) - matrix[in_pos, out_pos, i] = curr_node_dict[measure] - matrix[out_pos, in_pos, i] = curr_node_dict[measure] - - # Saving the matrices separatly with the specified name or dps - for i, measure in enumerate(measures_ordering): - if measure == 'volume': - matrix_basename = args.volume - elif measure == 'streamline_count': - matrix_basename = args.streamline_count - elif measure == 'length': - matrix_basename = args.length - elif measure == 'similarity': - matrix_basename = args.similarity[1] - elif measure in dict_metrics_out_name: - matrix_basename = dict_metrics_out_name[measure] - elif measure in dict_maps_out_name: - matrix_basename = dict_maps_out_name[measure] - elif measure in dict_lesion_out_name: - matrix_basename = dict_lesion_out_name[measure] - else: - matrix_basename = measure - - np.save(matrix_basename, matrix[:, :, i]) + # Filling out all the matrices (symmetric) in the order of labels_list + keys = [] + filenames = [] + if compute_volume: + keys.append('volume_mm3') + filenames.append(args.volume) + if compute_length: + keys.append('length_mm') + filenames.append(args.length) + if compute_streamline_count: + keys.append('streamline_count') + filenames.append(args.streamline_count) + if similarity_directory is not None: + keys.append('similarity') + filenames.append(args.similarity[1]) + if len(args.metrics) > 0: + keys.extend(metrics_names) + filenames.extend([m[1] for m in args.metrics]) + if args.lesion_load is not None: + keys.extend(['lesion_vol', 'lesion_count', 'lesion_streamline_count']) + filenames.extend( + [os.path.join(args.lesion_load[1], 'lesion_vol.npy'), + os.path.join(args.lesion_load[1], 'lesion_count.npy'), + os.path.join(args.lesion_load[1], 'lesion_sc.npy')]) + if args.include_dps: + keys.extend(dps_keys) + filenames.extend([os.path.join(args.include_dps, "{}.npy".format(k)) + for k in dps_keys]) + fill_matrix_and_save(measures_dict, labels_list, keys, filenames) if __name__ == "__main__": diff --git a/scripts/scil_connectivity_compute_simple_matrix.py b/scripts/scil_connectivity_compute_simple_matrix.py new file mode 100644 index 000000000..2b954568b --- /dev/null +++ b/scripts/scil_connectivity_compute_simple_matrix.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Computes a very simple connectivity matrix, using the streamline count and the +position of the streamlines' endpoints. + +This script is intented for exploration of your data. For a more thorough +computation (using the longest streamline segment), and for more options about +the weights of the matrix, see: +>> scil_connectivity_compute_matrices.py + +Contrary to scil_connectivity_compute_matrices.py, works with an incomplete +parcellation (i.e. with streamlines ending in the background). + +In the output figure, 4 matrices are shown, all using the streamline count: + - Raw count + - Raw count (log view) + - Binary matrix (if at least 1 streamline connects the two regions) + - Percentage of the total streamline count. 
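+
+Example run (file names are hypothetical placeholders):
+>> scil_connectivity_compute_simple_matrix.py tractogram.trk labels.nii.gz
+    sc_matrix.npy labels_order.txt --out_fig matrices.png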
+ +You may select which matrix to save to disk (as .npy) using options --binary or +--percentage. Default ouput matrix is the raw count. +""" + +import argparse +import logging +import os.path + +import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm +from mpl_toolkits.axes_grid1 import make_axes_locatable +import nibabel as nib +import numpy as np + +from scilpy.connectivity.connectivity import \ + compute_triu_connectivity_from_labels +from scilpy.image.labels import get_data_as_labels + +from scilpy.io.streamlines import load_tractogram_with_reference +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ + add_verbose_arg, add_overwrite_arg, assert_headers_compatible, \ + add_reference_arg + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_tractogram', + help='Tractogram (trk or tck).') + p.add_argument('in_labels', + help='Input nifti volume.') + p.add_argument('out_matrix', + help="Out .npy file.") + p.add_argument('out_labels', + help="Out .txt file. Will show the ordered labels (i.e. " + "the columns and lines' tags).") + + g = p.add_argument_group("Label management options") + g.add_argument('--keep_background', action='store_true', + help="By default, the background (label 0) is not included " + "in the matrix. \nUse this option to keep it.") + g.add_argument('--hide_labels', metavar='label', nargs='+', + help="Set given labels' weights to 0 in the matrix. \n" + "Their row and columns wil be kept but set to 0.") + + g = p.add_argument_group("Figure options") + g.add_argument('--hide_fig', action='store_true', + help="If set, does not show the matrices with matplotlib " + "(you can still use --out_fig)") + g.add_argument('--out_fig', metavar='file.png', + help="If set, saves the figure to file. \nExtension can be " + "any format understood by matplotlib (ex, .png).") + + g = p.add_argument_group("Output matrix (.npy) options") + g = g.add_mutually_exclusive_group() + g.add_argument('--binary', action='store_true', + help="If set, saves the result as binary. 
Else, the " + "streamline count is saved.") + g.add_argument('--percentage', action='store_true') + + add_verbose_arg(p) + add_reference_arg(p) + add_overwrite_arg(p) + + return p + + +def prepare_figure_connectivity(matrix): + matrix = np.copy(matrix) + + fig, axs = plt.subplots(2, 2) + im = axs[0, 0].imshow(matrix) + divider = make_axes_locatable(axs[0, 0]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[0, 0].set_title("Raw streamline count") + + im = axs[0, 1].imshow(matrix + np.min(matrix[matrix > 0]), norm=LogNorm()) + divider = make_axes_locatable(axs[0, 1]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[0, 1].set_title("Raw streamline count (log view)") + + matrix = matrix / matrix.sum() * 100 + im = axs[1, 0].imshow(matrix) + divider = make_axes_locatable(axs[1, 0]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[1, 0].set_title("Percentage of the total streamline count") + + matrix = matrix > 0 + axs[1, 1].imshow(matrix) + axs[1, 1].set_title("Binary matrix: 1 if at least 1 streamline") + + plt.suptitle("Connectivity matrix: streamline count") + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + logging.getLogger().setLevel(args.verbose) + if args.verbose == 'DEBUG': + # Currently, with debug, matplotlib prints a lot of stuff. Why?? + logging.getLogger().setLevel(logging.INFO) + + # Verifications + tmp, ext = os.path.splitext(args.out_matrix) + if ext != '.npy': + p.error("out_matrix should have a .npy extension.") + + assert_inputs_exist(p, [args.in_labels, args.in_tractogram], + args.reference) + assert_headers_compatible(p, [args.in_labels, args.in_tractogram], [], + args.reference) + assert_outputs_exist(p, args, args.out_matrix, args.out_fig) + + # Loading + in_sft = load_tractogram_with_reference(p, args, args.in_tractogram) + in_img = nib.load(args.in_labels) + data_labels = get_data_as_labels(in_img) + + # Computing + matrix, ordered_labels, _, _ = \ + compute_triu_connectivity_from_labels( + in_sft, data_labels, keep_background=args.keep_background, + hide_labels=args.hide_labels) + + # Save figure will all versions of the matrix. + if (not args.hide_fig) or args.out_fig is not None: + prepare_figure_connectivity(matrix) + + if args.out_fig is not None: + plt.savefig(args.out_fig) + + # Save matrix + if args.binary: + matrix = matrix > 0 + elif args.percentage: + matrix = matrix / matrix.sum() * 100 + np.save(args.out_matrix, matrix) + + # Save labels + with open(args.out_labels, "w") as text_file: + for i, label in enumerate(ordered_labels): + text_file.write("{} = {}\n".format(i, label)) + + # Showing as last step. Everything else is done, so if user closes figure + # it's fine. 
+ if not args.hide_fig: + plt.show() + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_connectivity_graph_measures.py b/scripts/scil_connectivity_graph_measures.py index de3e691ec..8c8a37e2d 100755 --- a/scripts/scil_connectivity_graph_measures.py +++ b/scripts/scil_connectivity_graph_measures.py @@ -36,7 +36,7 @@ import logging import os -from scilpy.connectivity.connectivity_tools import evaluate_graph_measures +from scilpy.connectivity.matrix_tools import evaluate_graph_measures from scilpy.io.utils import (add_json_args, add_overwrite_arg, add_verbose_arg, diff --git a/scripts/scil_connectivity_normalize.py b/scripts/scil_connectivity_normalize.py index 518f1e635..8a3fd8a06 100755 --- a/scripts/scil_connectivity_normalize.py +++ b/scripts/scil_connectivity_normalize.py @@ -50,7 +50,7 @@ import nibabel as nib import numpy as np -from scilpy.connectivity.connectivity_tools import \ +from scilpy.connectivity.matrix_tools import \ normalize_matrix_from_values, normalize_matrix_from_parcel from scilpy.image.volume_math import normalize_max, normalize_sum, base_10_log from scilpy.io.utils import (add_overwrite_arg, diff --git a/scripts/scil_connectivity_reorder_rois.py b/scripts/scil_connectivity_reorder_rois.py index 0585dff43..9ed78ab5d 100755 --- a/scripts/scil_connectivity_reorder_rois.py +++ b/scripts/scil_connectivity_reorder_rois.py @@ -29,8 +29,8 @@ import numpy as np -from scilpy.connectivity.connectivity_tools import (compute_olo, - apply_reordering) +from scilpy.connectivity.matrix_tools import (compute_olo, + apply_reordering) from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, load_matrix_in_any_format, diff --git a/scripts/scil_fibertube_score_tractogram.py b/scripts/scil_fibertube_score_tractogram.py new file mode 100644 index 000000000..b8272c148 --- /dev/null +++ b/scripts/scil_fibertube_score_tractogram.py @@ -0,0 +1,247 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Given ground-truth fibertubes and a tractogram obtained through fibertube +tracking, computes metrics about the quality of individual fiber +reconstruction. + +IMPORTANT: Streamlines given as input to be scored should be forward-only, +which means they are saved so that [0] is the seeding position and [-1] is +the end. +TODO: Add the seed's segment index as dps, to allow different seeding methods +and forward_only=False. + +Each streamline is associated with an "Arrival fibertube segment", which is +the closest fibertube segment to its before-last coordinate. We then define +the following terms: + +VC: "Valid Connection": A streamline whose arrival fibertube segment is +the final segment of the fibertube in which is was originally seeded. + +IC: "Invalid Connection": A streamline whose arrival fibertube segment is +the start or final segment of a fibertube in which is was not seeded. + +NC: "No Connection": A streamline whose arrival fibertube segment is +not the start or final segment of any fibertube. + +The "absolute error" of a coordinate is the distance in mm between that +coordinate and the closest point on its corresponding fibertube. The average +of all coordinate absolute errors of a streamline is called the "Mean absolute +error" or "mae". + +Computed metrics: + - vc_ratio + Number of VC divided by the number of streamlines. + - ic_ratio + Number of IC divided by the number of streamlines. + - nc_ratio + Number of NC divided by the number of streamlines. + - mae_min + Minimum MAE for the tractogram. + - mae_max + Maximum MAE for the tractogram. 
+ - mae_mean + Average MAE for the tractogram. + - mae_med + Median MAE for the tractogram. + +See also: + - scil_tractogram_filter_collisions.py to prepare data for fibertube + tracking + - scil_fibertube_tracking.py to perform a fibertube tracking + - docs/source/documentation/fibertube_tracking.rst +""" + +import os +import json +import numba +import argparse +import logging +import numpy as np +import nibabel as nib + +from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin +from dipy.io.streamline import save_tractogram, load_tractogram +from scilpy.tractanalysis.fibertube_scoring import \ + resolve_origin_seeding, endpoint_connectivity, mean_reconstruction_error +from scilpy.tractograms.streamline_operations import \ + get_streamlines_as_fixed_array +from scilpy.io.utils import (assert_inputs_exist, + assert_outputs_exist, + add_overwrite_arg, + add_verbose_arg, + add_json_args) + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + p.add_argument('in_fibertubes', + help='Path to the tractogram file (must be .trk) \n' + 'containing ground-truth fibertubes. They must be: \n' + '1- Void of any collision. \n' + '2- With their respective diameter saved \n' + 'as data_per_streamline. \n' + 'For both of these requirements, see \n' + 'scil_tractogram_filter_collisions.') + + p.add_argument('in_tracking', + help='Path to the tractogram file (must be .trk) \n' + 'containing the reconstruction of ground-truth \n' + 'fibertubes made from fibertube tracking. Seeds \n' + 'used for tracking must be saved as \n' + 'data_per_streamline.') + + p.add_argument('in_config', + help='Path to a json file containing the fibertube \n' + 'parameters used for the tracking process.') + + p.add_argument('out_metrics', + help='Output file containing the computed measures and \n' + 'metrics (must be .json).') + + p.add_argument('--save_error_tractogram', action='store_true', + help='If set, a .trk file will be saved, containing a \n' + 'visual representation of all the coordinate absolute \n' + 'errors of the entire tractogram. The file name is \n' + 'derived from the out_metrics parameter.') + + p.add_argument( + '--out_tracked_fibertubes', type=str, default=None, + help='If set, the fibertubes that were used for seeding will be \n' + 'saved separately at the specified location (must be .trk or \n' + '.tck). 
This parameter is not required for scoring the tracking \n' + 'result, as the seeding information of each streamline is always \n' + 'saved as data_per_streamline.') + + add_json_args(p) + add_verbose_arg(p) + add_overwrite_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + logging.getLogger('numba').setLevel(logging.WARNING) + + if os.path.splitext(args.in_fibertubes)[1] != '.trk': + parser.error('Invalid input streamline file format (must be trk):' + + '{0}'.format(args.in_fibertubes)) + + if not nib.streamlines.is_supported(args.in_tracking): + parser.error('Invalid output streamline file format (must be trk ' + + 'or tck): {0}'.format(args.in_tracking)) + + if os.path.splitext(args.in_config)[1] != '.json': + parser.error('Invalid input streamline file format (must be json):' + + '{0}'.format(args.in_config)) + + out_metrics_no_ext, ext = os.path.splitext(args.out_metrics) + if ext != '.json': + parser.error('Invalid output file format (must be json): {0}' + .format(args.out_metrics)) + + if args.out_tracked_fibertubes: + if not nib.streamlines.is_supported(args.out_tracked_fibertubes): + parser.error('Invalid output streamline file format (must be ' + + 'trk or tck):' + + '{0}'.format(args.out_tracked_fibertubes)) + + assert_inputs_exist(parser, [args.in_fibertubes, args.in_config, + args.in_tracking]) + assert_outputs_exist(parser, args, [args.out_metrics], + [args.out_tracked_fibertubes]) + + our_space = Space.VOXMM + our_origin = Origin('center') + + logging.debug('Loading centerline tractogram & diameters') + truth_sft = load_tractogram(args.in_fibertubes, 'same', our_space, + our_origin) + centerlines = truth_sft.get_streamlines_copy() + centerlines, centerlines_length = get_streamlines_as_fixed_array( + centerlines) + + if "diameters" not in truth_sft.data_per_streamline: + parser.error('No diameters found as data per streamline on ' + + args.in_fibertubes) + diameters = np.reshape(truth_sft.data_per_streamline['diameters'], + len(centerlines)) + + logging.debug('Loading reconstructed tractogram') + in_sft = load_tractogram(args.in_tracking, 'same', our_space, + our_origin) + streamlines = in_sft.get_streamlines_copy() + streamlines, streamlines_length = get_streamlines_as_fixed_array( + streamlines) + + logging.debug("Loading seeds") + if "seeds" not in in_sft.data_per_streamline: + parser.error('No seeds found as data per streamline on ' + + args.in_tracking) + + seeds = in_sft.data_per_streamline['seeds'] + seeds_fiber = resolve_origin_seeding(seeds, centerlines, diameters) + + logging.debug("Loading config") + with open(args.in_config, 'r') as f: + config = json.load(f) + blur_radius = float(config['blur_radius']) + + if len(seeds_fiber) != len(streamlines): + raise ValueError('Could not resolve origin seeding regions') + for num in seeds_fiber: + if num == -1: + raise ValueError('Could not resolve origin seeding regions') + + if args.out_tracked_fibertubes: + # Set for removing doubles + tracked_fibertubes_indices = set(seeds_fiber) + tracked_fibertubes = [] + + for fi in tracked_fibertubes_indices: + tracked_fibertubes.append(centerlines[fi][:centerlines_length[fi]]) + + tracked_sft = StatefulTractogram.from_sft(tracked_fibertubes, + truth_sft) + save_tractogram(tracked_sft, args.out_tracked_fibertubes, + bbox_valid_check=False) + + logging.debug("Computing endpoint connectivity") + vc, ic, nc = endpoint_connectivity(blur_radius, centerlines, + 
centerlines_length, diameters, + streamlines, seeds_fiber) + + logging.debug("Computing reconstruction error") + (mean_errors, error_tractogram) = mean_reconstruction_error( + centerlines, centerlines_length, diameters, streamlines, + streamlines_length, seeds_fiber, args.save_error_tractogram) + + metrics = { + 'vc_ratio': len(vc)/len(streamlines), + 'ic_ratio': len(ic)/len(streamlines), + 'nc_ratio': len(nc)/len(streamlines), + 'mae_min': np.min(mean_errors), + 'mae_max': np.max(mean_errors), + 'mae_mean': np.mean(mean_errors), + 'mae_med': np.median(mean_errors), + } + with open(args.out_metrics, 'w') as outfile: + json.dump(metrics, outfile, + indent=args.indent, sort_keys=args.sort_keys) + + if args.save_error_tractogram: + sft = StatefulTractogram.from_sft(error_tractogram, truth_sft) + save_tractogram(sft, out_metrics_no_ext + '.trk', + bbox_valid_check=False) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_fibertube_tracking.py b/scripts/scil_fibertube_tracking.py new file mode 100644 index 000000000..6b5ca536e --- /dev/null +++ b/scripts/scil_fibertube_tracking.py @@ -0,0 +1,281 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Implementation of the fibertube tracking environment using the +architecture of scil_local_tracking_dev.py. + +Contrary to traditional white matter fiber tractography, fibertube +tractography does not rely on a discretized grid of fODFs or peaks. It +directly tracks and reconstructs fibertubes, i.e. streamlines that have an +associated diameter. + +When the tracking algorithm is about to select a new direction to propagate +the current streamline, it will build a sphere of radius blur_radius and pick +randomly from all the fibertube segments intersecting with it. The larger the +intersection volume, the more likely a fibertube segment is to be picked and +used as a tracking direction. This makes fibertube tracking inherently +probabilistic. + +Possible tracking directions are filtered to respect the aperture cone defined +by the previous tracking direction and the angular constraint. + +Seeding is done within the first segment of each fibertube. + +For a better understanding of Fibertube Tracking please see: + - docs/source/documentation/fibertube_tracking.rst +""" + +import os +import json +import time +import argparse +import logging +import numpy as np +import nibabel as nib +import dipy.core.geometry as gm + +from scilpy.tracking.seed import FibertubeSeedGenerator +from scilpy.tracking.propagator import FibertubePropagator +from scilpy.image.volume_space_management import FibertubeDataVolume +from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin +from dipy.io.streamline import load_tractogram, save_tractogram +from scilpy.tracking.tracker import Tracker +from scilpy.image.volume_space_management import DataVolume +from scilpy.io.utils import (assert_inputs_exist, + assert_outputs_exist, + add_processes_arg, + add_verbose_arg, + add_json_args, + add_overwrite_arg) + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + + p.add_argument('in_fibertubes', + help='Path to the tractogram (must be .trk) file \n' + 'containing fibertubes. They must be: \n' + '1- Void of any collision. \n' + '2- With their respective diameter saved \n' + 'as data_per_streamline. 
\n' + 'For both of these requirements, see \n' + 'scil_tractogram_filter_collisions.py.') + + p.add_argument('out_tractogram', + help='Tractogram output file (must be .trk or .tck).') + + track_g = p.add_argument_group('Tracking options') + track_g.add_argument( + '--blur_radius', type=float, default=0.1, + help='Radius of the spherical region from which the \n' + 'algorithm will determine the next direction. \n' + 'A blur_radius within [0.001, 0.5] is recommended. \n' + '[%(default)s]') + track_g.add_argument( + '--step_size', type=float, default=0.1, + help='Step size of the tracking algorithm, in mm. \n' + 'It is recommended to use the same value as the \n' + 'blur_radius, in the interval [0.001, 0.5] \n' + 'The step_size should never exceed twice the \n' + 'blur_radius. [%(default)s]') + track_g.add_argument( + '--min_length', type=float, default=10., + metavar='m', + help='Minimum length of a streamline in mm. ' + '[%(default)s]') + track_g.add_argument( + '--max_length', type=float, default=300., + metavar='M', + help='Maximum length of a streamline in mm. ' + '[%(default)s]') + track_g.add_argument( + '--theta', type=float, default=60., + help='Maximum angle between 2 steps. If the angle is ' + 'too big, streamline is \nstopped and the ' + 'following point is NOT included.\n' + '[%(default)s]') + track_g.add_argument( + '--rk_order', metavar="K", type=int, default=1, + choices=[1, 2, 4], + help="The order of the Runge-Kutta integration used \n" + 'for the step function. \n' + 'For more information, refer to the note in the \n' + 'script description. [%(default)s]') + track_g.add_argument( + '--max_invalid_nb_points', metavar='MAX', type=int, + default=0, + help='Maximum number of steps without valid \n' + 'direction, \nex: No fibertube intersecting the \n' + 'tracking sphere or max angle is reached.\n' + 'Default: 0, i.e. do not add points following ' + 'an invalid direction.') + track_g.add_argument( + '--keep_last_out_point', action='store_true', + help='If set, keep the last point (once out of the \n' + 'tracking mask) of the streamline. Default: discard \n' + 'them. This is the default in Dipy too. \n' + 'Note that points obtained after an invalid direction \n' + '(based on the propagator\'s definition of invalid) \n' + 'are never added.') + + seed_group = p.add_argument_group( + 'Seeding options', + 'When no option is provided, uses --nb_seeds_per_fibertube 5.') + seed_group.add_argument( + '--nb_seeds_per_fibertube', type=int, default=5, + help='The number of seeds planted in the first segment \n' + 'of each fibertube. The total amount of streamlines will \n' + 'be [nb_seeds_per_fibertube] * [nb_fibertubes]. [%(default)s]') + seed_group.add_argument( + '--nb_fibertubes', type=int, + help='If set, the script will only track a specified \n' + 'amount of fibers. Otherwise, the entire tractogram \n' + 'will be tracked. The total amount of streamlines \n' + 'will be [nb_seeds_per_fibertube] * [nb_fibertubes].') + + rand_g = p.add_argument_group('Random options') + rand_g.add_argument( + '--rng_seed', type=int, default=0, + help='If set, all random values will be generated \n' + 'using the specified seed. [%(default)s]') + rand_g.add_argument( + '--skip', type=int, default=0, + help="Skip the first N seeds. 
\n" + "Useful if you want to create new streamlines to " + "add to \na previously created tractogram with a " + "fixed --rng_seed.\nEx: If tractogram_1 was created " + "with -nt 1,000,000, \nyou can create tractogram_2 " + "with \n--skip 1,000,000.") + + out_g = p.add_argument_group('Output options') + out_g.add_argument( + '--out_config', default=None, type=str, + help='If set, the parameter configuration used for tracking will \n' + 'be saved at the specified location (must be .json). If not given, \n' + 'the config will be printed in the console.') + + add_json_args(out_g) + add_overwrite_arg(out_g) + add_processes_arg(p) + add_verbose_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + logging.getLogger('numba').setLevel(logging.WARNING) + + if os.path.splitext(args.in_fibertubes)[1] != '.trk': + parser.error('Invalid input streamline file format (must be trk):' + + '{0}'.format(args.in_fibertubes)) + + if not nib.streamlines.is_supported(args.out_tractogram): + parser.error('Invalid output streamline file format (must be trk ' + + 'or tck): {0}'.format(args.out_tractogram)) + + if args.out_config: + if os.path.splitext(args.out_config)[1] != '.json': + parser.error('Invalid output file format (must be json): {0}' + .format(args.out_config)) + + assert_inputs_exist(parser, [args.in_fibertubes]) + assert_outputs_exist(parser, args, [args.out_tractogram], + [args.out_config]) + + theta = gm.math.radians(args.theta) + + max_nbr_pts = int(args.max_length / args.step_size) + min_nbr_pts = max(int(args.min_length / args.step_size), 1) + + our_space = Space.VOXMM + our_origin = Origin('center') + + logging.debug('Loading tractogram & diameters') + in_sft = load_tractogram(args.in_fibertubes, 'same', our_space, our_origin) + centerlines = list(in_sft.get_streamlines_copy()) + diameters = np.reshape(in_sft.data_per_streamline['diameters'], + len(centerlines)) + + logging.debug("Instantiating datavolumes") + # The scilpy Tracker requires a mask for tracking, but fibertube tracking + # aims to eliminate grids (or masks) in tractography. Instead, the tracking + # stops when no more fibertubes are detected by the Tracker. + + # Since the scilpy Tracker requires a mask, we provide a fake one that will + # never interfere. 
+ fake_mask_data = np.ones(in_sft.dimensions) + fake_mask = DataVolume(fake_mask_data, in_sft.voxel_sizes, 'nearest') + datavolume = FibertubeDataVolume(centerlines, diameters, in_sft, + args.blur_radius, + np.random.default_rng(args.rng_seed)) + + logging.debug("Instantiating seed generator") + seed_generator = FibertubeSeedGenerator(centerlines, diameters, + args.nb_seeds_per_fibertube) + + logging.debug("Instantiating propagator") + propagator = FibertubePropagator(datavolume, args.step_size, + args.rk_order, theta, our_space, + our_origin) + + logging.debug("Instantiating tracker") + max_nbr_seeds = args.nb_seeds_per_fibertube * len(centerlines) + if args.nb_fibertubes: + if args.nb_fibertubes > len(centerlines): + raise ValueError("The provided number of seeded fibers exceeds" + + "the number of available fibertubes.") + else: + nbr_seeds = args.nb_seeds_per_fibertube * args.nb_fibertubes + else: + nbr_seeds = max_nbr_seeds + + if args.skip and nbr_seeds + args.skip > max_nbr_seeds: + raise ValueError("The number of seeds plus the number of skipped " + + "seeds requires more fibertubes than there are " + + "available.") + tracker = Tracker(propagator, fake_mask, seed_generator, nbr_seeds, + min_nbr_pts, max_nbr_pts, + args.max_invalid_nb_points, 0, + args.nbr_processes, True, 'r+', + rng_seed=args.rng_seed, + track_forward_only=True, + skip=args.skip, + verbose=args.verbose, + append_last_point=args.keep_last_out_point) + + start_time = time.time() + logging.debug("Tracking...") + streamlines, seeds = tracker.track() + str_time = "%.2f" % (time.time() - start_time) + logging.debug('Finished tracking in: ' + str_time + ' seconds') + + out_sft = StatefulTractogram.from_sft(streamlines, in_sft) + out_sft.data_per_streamline['seeds'] = seeds + save_tractogram(out_sft, args.out_tractogram) + + config = { + 'step_size': args.step_size, + 'blur_radius': args.blur_radius, + 'nb_fibertubes': (args.nb_fibertubes if args.nb_fibertubes + else len(centerlines)), + 'nb_seeds_per_fibertube': args.nb_seeds_per_fibertube + } + if args.out_config: + with open(args.out_config, 'w') as outfile: + json.dump(config, outfile, + indent=args.indent, sort_keys=args.sort_keys) + else: + print('Config:\n', + json.dumps(config, indent=args.indent, + sort_keys=args.sort_keys)) + + +if __name__ == "__main__": + main() diff --git a/scripts/scil_fodf_metrics.py b/scripts/scil_fodf_metrics.py index 7586898fe..4112a1086 100755 --- a/scripts/scil_fodf_metrics.py +++ b/scripts/scil_fodf_metrics.py @@ -163,7 +163,7 @@ def main(): # Computing maps if args.nufo or args.afd_max or args.afd_total or args.afd_sum or args.rgb: nufo_map, afd_max, afd_sum, rgb_map, \ - _, _ = maps_from_sh(data, peak_dirs, peak_values, peak_indices, + _, _ = maps_from_sh(data, peak_values, peak_indices, sphere, nbr_processes=args.nbr_processes) # Save result diff --git a/scripts/scil_frf_mean.py b/scripts/scil_frf_mean.py index 0005a665c..06838e40d 100755 --- a/scripts/scil_frf_mean.py +++ b/scripts/scil_frf_mean.py @@ -18,6 +18,7 @@ import numpy as np from scilpy.io.utils import (add_overwrite_arg, + add_precision_arg, assert_inputs_exist, add_verbose_arg, assert_outputs_exist) @@ -33,6 +34,7 @@ def _build_arg_parser(): p.add_argument('mean_frf', metavar='file', help='Path of the output mean FRF file.') + add_precision_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -66,7 +68,7 @@ def main(): final_frf = np.mean(all_frfs, axis=0) - np.savetxt(args.mean_frf, final_frf) + np.savetxt(args.mean_frf, final_frf, fmt=f"%.{args.precision}f") if 
__name__ == "__main__": diff --git a/scripts/scil_frf_memsmt.py b/scripts/scil_frf_memsmt.py index 5865ddd99..9f1a0c764 100755 --- a/scripts/scil_frf_memsmt.py +++ b/scripts/scil_frf_memsmt.py @@ -50,7 +50,8 @@ from scilpy.image.utils import extract_affine from scilpy.io.btensor import generate_btensor_input from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, +from scilpy.io.utils import (add_overwrite_arg, add_precision_arg, + add_verbose_arg, assert_inputs_exist, assert_outputs_exist, assert_roi_radii_format, add_skip_b0_check_arg, add_tolerance_arg, @@ -169,6 +170,7 @@ def _build_arg_parser(): help='Path to the output CSF frf mask file, the voxels ' 'used to compute the CSF frf.') + add_precision_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -264,7 +266,7 @@ def main(): frf_out = [args.out_wm_frf, args.out_gm_frf, args.out_csf_frf] for frf, response in zip(frf_out, responses): - np.savetxt(frf, response) + np.savetxt(frf, response, fmt=f"%.{args.precision}f") if __name__ == "__main__": diff --git a/scripts/scil_frf_msmt.py b/scripts/scil_frf_msmt.py index d6a622026..ce9bc3fcd 100755 --- a/scripts/scil_frf_msmt.py +++ b/scripts/scil_frf_msmt.py @@ -36,7 +36,8 @@ from scilpy.dwi.utils import extract_dwi_shell from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, add_skip_b0_check_arg, +from scilpy.io.utils import (add_overwrite_arg, add_precision_arg, + add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, assert_roi_radii_format, assert_headers_compatible) @@ -136,6 +137,7 @@ def _build_arg_parser(): help='Path to the output CSF frf mask file, the voxels ' 'used to compute the CSF frf.') + add_precision_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -216,7 +218,7 @@ def main(): frf_out = [args.out_wm_frf, args.out_gm_frf, args.out_csf_frf] for frf, response in zip(frf_out, responses): - np.savetxt(frf, response) + np.savetxt(frf, response, fmt=f"%.{args.precision}f") if __name__ == "__main__": diff --git a/scripts/scil_frf_set_diffusivities.py b/scripts/scil_frf_set_diffusivities.py index 2424087f2..428cbce74 100755 --- a/scripts/scil_frf_set_diffusivities.py +++ b/scripts/scil_frf_set_diffusivities.py @@ -17,6 +17,7 @@ import numpy as np from scilpy.io.utils import (add_overwrite_arg, + add_precision_arg, assert_inputs_exist, add_verbose_arg, assert_outputs_exist) @@ -43,6 +44,7 @@ def _build_arg_parser(): 'evaluated without the x 10**-4 factor. [%(default)s].' 
) + add_precision_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -59,7 +61,7 @@ def main(): frf_file = np.loadtxt(args.frf_file) response = replace_frf(frf_file, args.new_frf, args.no_factor) - np.savetxt(args.output_frf_file, response) + np.savetxt(args.output_frf_file, response, fmt=f"%.{args.precision}f") if __name__ == "__main__": diff --git a/scripts/scil_frf_ssst.py b/scripts/scil_frf_ssst.py index 925f3d584..d1f12c3aa 100755 --- a/scripts/scil_frf_ssst.py +++ b/scripts/scil_frf_ssst.py @@ -20,6 +20,7 @@ from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, + add_precision_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, assert_roi_radii_format, @@ -80,6 +81,7 @@ def _build_arg_parser(): add_b0_thresh_arg(p) add_skip_b0_check_arg(p, will_overwrite_with_min=True) + add_precision_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -117,7 +119,7 @@ def main(): min_fa_thresh=args.min_fa_thresh, min_nvox=args.min_nvox, roi_radii=roi_radii, roi_center=args.roi_center) - np.savetxt(args.frf_file, full_response) + np.savetxt(args.frf_file, full_response, fmt=f"%.{args.precision}f") if __name__ == "__main__": diff --git a/scripts/scil_mrds_metrics.py b/scripts/scil_mrds_metrics.py new file mode 100755 index 000000000..5d6230cd3 --- /dev/null +++ b/scripts/scil_mrds_metrics.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Script to compute FA/MD/RD/AD for +each Multi-ResolutionDiscrete-Search (MRDS) solution. +It will output the results in 4 different 4D files. +Each 4th dimension will correspond to each tensor in the MRDS solution. +e.g. FA of tensor D_1 will be in index 0 of the 4th dimension, + FA of tensor D_2 will be in index 1 of the 4th dimension, + FA of tensor D_3 will be in index 2 of the 4th dimension. +""" + +import logging +import numpy as np +import nibabel as nib +import argparse + +from dipy.reconst.dti import fractional_anisotropy + +from scilpy.io.image import get_data_as_mask +from scilpy.io.utils import (add_overwrite_arg, + add_verbose_arg, + assert_inputs_exist, assert_outputs_exist, + assert_headers_compatible) + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_evals', + help='MRDS eigenvalues file (Shape: [X, Y, Z, 9]).\n' + 'The last dimensions, values 1-3 are associated with ' + 'the first tensor (D_1), 4-6 with the second tensor ' + '(D_2), 7-9 with the third tensor (D_3).\n' + 'This file is one of the outputs of ' + 'scil_mrds_select_number_of_tensors.py ' + '(*_MRDS_evals.nii.gz).') + + p.add_argument('--mask', + help='Path to a binary mask.\nOnly data inside ' + 'the mask will be used for computations and ' + 'reconstruction. (Default: %(default)s)') + + p.add_argument( + '--not_all', action='store_true', + help='If set, will only save the metrics explicitly specified using ' + 'the other metrics flags. 
(Default: not set).') + + g = p.add_argument_group(title='MRDS-Metrics files flags') + g.add_argument('--fa', metavar='file', default='', + help='Output filename for the FA.') + g.add_argument('--ad', metavar='file', default='', + help='Output filename for the AD.') + g.add_argument('--rd', metavar='file', default='', + help='Output filename for the RD.') + g.add_argument('--md', metavar='file', default='', + help='Output filename for the MD.') + + add_verbose_arg(p) + add_overwrite_arg(p) + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + # Verifications + if not args.not_all: + args.fa = args.fa or 'mrds_fa.nii.gz' + args.ad = args.ad or 'mrds_ad.nii.gz' + args.rd = args.rd or 'mrds_rd.nii.gz' + args.md = args.md or 'mrds_md.nii.gz' + + assert_inputs_exist(parser, args.in_evals, args.mask) + assert_headers_compatible(parser, args.in_evals, args.mask) + assert_outputs_exist(parser, args, [], + optional=[args.fa, args.ad, args.rd, args.md]) + + evals_img = nib.load(args.in_evals) + lambdas = evals_img.get_fdata(dtype=np.float32) + + header = evals_img.header + affine = evals_img.affine + + X, Y, Z = lambdas.shape[0:3] + + # load mask + if args.mask: + mask = get_data_as_mask(nib.load(args.mask)) + else: + mask = np.ones((X, Y, Z), dtype=np.uint8) + + fa = np.zeros((X, Y, Z, 3)) + ad = np.zeros((X, Y, Z, 3)) + rd = np.zeros((X, Y, Z, 3)) + md = np.zeros((X, Y, Z, 3)) + + if args.fa: + fa = np.stack((fractional_anisotropy(lambdas[:, :, :, 0:3]), + fractional_anisotropy(lambdas[:, :, :, 3:6]), + fractional_anisotropy(lambdas[:, :, :, 6:9])), + axis=3) + nib.save(nib.Nifti1Image(fa * mask[..., None], + affine=affine, + header=header, + dtype=np.float32), + args.fa) + + if args.ad: + ad = np.stack((lambdas[:, :, :, 0], + lambdas[:, :, :, 3], + lambdas[:, :, :, 6]), + axis=3) + nib.save(nib.Nifti1Image(ad * mask[..., None], + affine=affine, + header=header, + dtype=np.float32), + args.ad) + + if args.rd: + rd = np.stack(((lambdas[:, :, :, 1] + lambdas[:, :, :, 2])/2, + (lambdas[:, :, :, 4] + lambdas[:, :, :, 5])/2, + (lambdas[:, :, :, 7] + lambdas[:, :, :, 8])/2), + axis=3) + nib.save(nib.Nifti1Image(rd * mask[..., None], + affine=affine, + header=header, + dtype=np.float32), + args.rd) + + if args.md: + md = np.stack((np.average(lambdas[:, :, :, 0:3], axis=3), + np.average(lambdas[:, :, :, 3:6], axis=3), + np.average(lambdas[:, :, :, 6:9], axis=3)), + axis=3) + nib.save(nib.Nifti1Image(md * mask[..., None], + affine=affine, + header=header, + dtype=np.float32), + args.md) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_mrds_select_number_of_tensors.py b/scripts/scil_mrds_select_number_of_tensors.py new file mode 100755 index 000000000..e2d9e98d7 --- /dev/null +++ b/scripts/scil_mrds_select_number_of_tensors.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Use the NUFO map information to select the plausible number of tensors +in the Multi-Resolution Discrete Search (MRDS). +https://link.springer.com/chapter/10.1007/978-3-031-47292-3_4 + +scil_mrds_select_number_of_tensors.py uses the output from mdtmrds command. +Some mdtmrds output files will be named differently from the expected input: + COMP_SIZE becomes signal_fraction + NUM_COMP becomes num_tensors + PDDs_CARTESIAN becomes evecs + Eigenvalues becomes evals + +mdtmrds: information available soon (not part of scilpy). 
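+
+The number of tensors kept at each voxel is driven by [in_volume]: a voxel
+with value N selects the N-tensor solution (D_N). Values above 3 are clipped
+to 3, and voxels with value 0 or outside the mask are left empty.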
+ +Input: + Inputs are a list of 5 files for each MRDS solution (D1, D2, D3). + - Signal fraction of each tensor ([in_prefix]_D[1,2,3]_signal_fraction.nii.gz) + - Eigenvalues ($in_prefix]_D[1,2,3]_evals.nii.gz) + - Isotropic ([in_prefix]_D[1,2,3]_isotropic.nii.gz) + - Number of tensors ([in_prefix]_D[1,2,3]_num_tensors.nii.gz) + - Eigenvectors ([in_prefix]_D[1,2,3]_evecs.nii.gz) + + + Example: + scil_mrds_select_number_of_tensors.py sub-01 nufo.nii.gz +""" + +import argparse +import itertools +import logging + +import nibabel as nib +import numpy as np + +from scilpy.image.labels import get_data_as_labels +from scilpy.io.image import get_data_as_mask +from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, + add_sh_basis_args, add_verbose_arg, + assert_headers_compatible, + assert_inputs_exist, assert_outputs_exist) + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_prefix', + help='Prefix used for all MRDS solutions.') + p.add_argument('in_volume', + help='Volume with the number of expected tensors.' + ' (Example: NUFO volume)') + + p.add_argument('--out_prefix', default='results', + help='Prefix of the MRDS results [%(default)s].') + p.add_argument('--mask', + help='Optional mask filename.') + + add_processes_arg(p) + add_sh_basis_args(p) + add_verbose_arg(p) + add_overwrite_arg(p) + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(args.verbose.upper()) + + mrds_files = [] + for i in range(1, 4): + mrds_files.append([args.in_prefix + '_D{}_signal_fraction.nii.gz'.format(i), + args.in_prefix + '_D{}_evals.nii.gz'.format(i), + args.in_prefix + '_D{}_isotropic.nii.gz'.format(i), + args.in_prefix + '_D{}_num_tensors.nii.gz'.format(i), + args.in_prefix + '_D{}_evecs.nii.gz'.format(i)]) + + assert_inputs_exist(parser, [args.in_volume] + [x for xs in mrds_files for x in xs], + optional=args.mask) + + output_files = ["{}_MRDS_signal_fraction.nii.gz".format(args.out_prefix), + "{}_MRDS_evals.nii.gz".format(args.out_prefix), + "{}_MRDS_isotropic.nii.gz".format(args.out_prefix), + "{}_MRDS_num_tensors.nii.gz".format(args.out_prefix), + "{}_MRDS_evecs.nii.gz".format(args.out_prefix)] + assert_outputs_exist(parser, args, output_files) + assert_headers_compatible(parser, [args.in_volume] + [x for xs in mrds_files for x in xs]) + + signal_fraction = [] + evals = [] + iso = [] + num_tensors = [] + evecs = [] + for N in range(3): + signal_fraction.append(nib.load(mrds_files[N][0]).get_fdata(dtype=np.float32)) + evals.append(nib.load(mrds_files[N][1]).get_fdata(dtype=np.float32)) + iso.append(nib.load(mrds_files[N][2]).get_fdata(dtype=np.float32)) + num_tensors.append(nib.load(mrds_files[N][3]).get_fdata(dtype=np.float32)) + evecs.append(nib.load(mrds_files[N][4]).get_fdata(dtype=np.float32)) + + # MOdel SElector MAP + mosemap_img = nib.load(args.in_volume) + mosemap = get_data_as_labels(mosemap_img) + header = mosemap_img.header + + affine = mosemap_img.affine + X, Y, Z = mosemap.shape[0:3] + + # load mask + if args.mask: + mask = get_data_as_mask(nib.load(args.mask), dtype=bool) + else: + mask = np.ones((X, Y, Z), dtype=np.uint8) + + # select data using mosemap + voxels = itertools.product(range(X), range(Y), range(Z)) + filtered_voxels = ((x, y, z) for (x, y, z) in voxels if mask[x, y, z]) + + signal_fraction_out = np.zeros((X, Y, Z, 3)) + evals_out = np.zeros((X, Y, Z, 9)) + iso_out = np.zeros((X, Y, Z, 2)) + num_tensors_out 
= np.zeros((X, Y, Z), dtype=np.uint8) + evecs_out = np.zeros((X, Y, Z, 9)) + + # select data using mosemap + for (X, Y, Z) in filtered_voxels: + N = mosemap[X, Y, Z]-1 + + # Maximum number of tensors is 3 + if N > 2: + N = 2 + + if N > -1: + signal_fraction_out[X, Y, Z, :] = signal_fraction[N][X, Y, Z, :] + evals_out[X, Y, Z, :] = evals[N][X, Y, Z, :] + iso_out[X, Y, Z, :] = iso[N][X, Y, Z, :] + num_tensors_out[X, Y, Z] = int(num_tensors[N][X, Y, Z]) + evecs_out[X, Y, Z, :] = evecs[N][X, Y, Z, :] + + # write output files + nib.save(nib.Nifti1Image(signal_fraction_out, + affine=affine, + header=header, + dtype=np.float32), output_files[0]) + nib.save(nib.Nifti1Image(evals_out, + affine=affine, + header=header, + dtype=np.float32), output_files[1]) + nib.save(nib.Nifti1Image(iso_out, + affine=affine, + header=header, + dtype=np.float32), output_files[2]) + nib.save(nib.Nifti1Image(num_tensors_out, + affine=affine, + header=header, + dtype=np.uint8), output_files[3]) + nib.save(nib.Nifti1Image(evecs_out, + affine=affine, + header=header, + dtype=np.float32), output_files[4]) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_tracking_local_dev.py b/scripts/scil_tracking_local_dev.py index 82d4d1e4d..fbc55a3d5 100755 --- a/scripts/scil_tracking_local_dev.py +++ b/scripts/scil_tracking_local_dev.py @@ -233,8 +233,9 @@ def main(): logging.info("Instantiating propagator.") # Converting step size to vox space - # We only support iso vox for now. - assert odf_sh_res[0] == odf_sh_res[1] == odf_sh_res[2] + # We only support iso vox for now but allow slightly different vox 1e-3. + assert np.allclose(np.mean(odf_sh_res[:3]), + odf_sh_res, atol=1e-03) voxel_size = odf_sh_img.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size diff --git a/scripts/scil_tractogram_add_dps.py b/scripts/scil_tractogram_add_dps.py deleted file mode 100755 index b688efe0b..000000000 --- a/scripts/scil_tractogram_add_dps.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" Add information to each streamline from a file. Can be for example -SIFT2 weights, processing information, bundle IDs, etc. - -Output must be a .trk otherwise the data will be lost. 
-""" - -import argparse -import logging - -from dipy.io.streamline import save_tractogram -import numpy as np - -from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, - add_reference_arg, - add_verbose_arg, - assert_inputs_exist, - assert_outputs_exist, - check_tract_trk, - load_matrix_in_any_format) - - -def _build_arg_parser(): - - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractogram', - help='Input tractogram (.trk or .tck).') - p.add_argument('in_dps_file', - help='File containing the data to add to streamlines.') - p.add_argument('dps_key', - help='Where to store the data in the tractogram.') - p.add_argument('out_tractogram', - help='Output tractogram (.trk).') - - add_reference_arg(p) - add_verbose_arg(p) - add_overwrite_arg(p) - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - - # I/O assertions - assert_inputs_exist(parser, [args.in_tractogram, args.in_dps_file], - args.reference) - assert_outputs_exist(parser, args, args.out_tractogram) - check_tract_trk(parser, args.out_tractogram) - - # Load tractogram - sft = load_tractogram_with_reference(parser, args, args.in_tractogram) - - # Make sure the user is not unwillingly overwritting dps - if (args.dps_key in sft.get_data_per_streamline_keys() and - not args.overwrite): - parser.error('"{}" already in data per streamline. Use -f to force ' - 'overwriting.'.format(args.dps_key)) - - # Load data and remove extraneous dimensions - data = np.squeeze(load_matrix_in_any_format(args.in_dps_file)) - - # Quick check as the built-in error from sft is not too explicit - if len(sft) != data.shape[0]: - raise ValueError('Data must have as many entries ({}) as there are' - ' streamlines ({}).'.format(data.shape[0], len(sft))) - # Add data to tractogram - sft.data_per_streamline[args.dps_key] = data - - # Save the new sft - save_tractogram(sft, args.out_tractogram) - - -if __name__ == '__main__': - main() diff --git a/scripts/scil_tractogram_dps_math.py b/scripts/scil_tractogram_dps_math.py new file mode 100755 index 000000000..5b96a5827 --- /dev/null +++ b/scripts/scil_tractogram_dps_math.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Import, extract or delete dps (data_per_streamline) information to a tractogram +file. Can be for example SIFT2 weights, processing information, bundle IDs, +tracking seeds, etc. + +Input and output tractograms must be .trk, unless you are using the 'import' +operation, in which case a .tck input tractogram is accepted. 
+ +Usage examples: + > scil_tractogram_dps_math.py tractogram.trk import "bundle_ids" + --in_dps_file my_bundle_ids.txt + > scil_tractogram_dps_math.py tractogram.trk export "seeds" + --out_dps_file seeds.npy +""" + +import nibabel as nib +import argparse +import logging + +from dipy.io.streamline import save_tractogram, load_tractogram +from scilpy.io.streamlines import load_tractogram_with_reference +import numpy as np + +from scilpy.io.utils import (add_overwrite_arg, + add_verbose_arg, + assert_inputs_exist, + assert_outputs_exist, + check_tract_trk, + load_matrix_in_any_format, + save_matrix_in_any_format) + + +def _build_arg_parser(): + + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + p.add_argument('in_tractogram', + help='Input tractogram (.trk for all operations,' + '.tck accepted for import).') + p.add_argument('operation', metavar='OPERATION', + choices=['import', 'delete', 'export'], + help='The type of operation to be performed on the\n' + 'tractogram\'s data_per_streamline at the given\n' + 'key. Must be one of the following: [%(choices)s].\n' + 'The additional arguments required for each\n' + 'operation are specified under each group below.') + p.add_argument('dps_key', type=str, + help='Key name used for the operation.') + + p.add_argument('--out_tractogram', + help='Output tractogram (.trk). Required for "import" and\n' + '"delete" operations.') + + import_args = p.add_argument_group('Operation "import" mandatory options') + import_args.add_argument('--in_dps_file', + help='File containing the data to import to\n' + 'streamlines (.txt, .npy or .mat).') + + export_args = p.add_argument_group('Operation "export" mandatory options') + export_args.add_argument('--out_dps_file', + help='File in which the extracted data will be\n' + 'saved (.txt or .npy).') + + add_verbose_arg(p) + add_overwrite_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + if args.operation == 'import': + if not nib.streamlines.is_supported(args.in_tractogram): + parser.error('Invalid input streamline file format (must be trk ' + + 'or tck): {0}'.format(args.in_tractogram)) + else: + check_tract_trk(parser, args.in_tractogram) + + if args.out_tractogram: + check_tract_trk(parser, args.out_tractogram) + + assert_inputs_exist(parser, args.in_tractogram, args.in_dps_file) + assert_outputs_exist(parser, args, [], optional=[args.out_dps_file, + args.out_tractogram]) + + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + + if args.operation == 'import': + if args.in_dps_file is None: + parser.error('The --in_dps_file option is required for ' + + 'the "import" operation.') + + if args.out_tractogram is None: + parser.error('The --out_tractogram option is required for ' + + 'the "import" operation.') + + # Make sure the user is not unwillingly overwritting dps + if (args.dps_key in sft.get_data_per_streamline_keys() and + not args.overwrite): + parser.error('"{}" already in data per streamline. 
Use -f to force' + ' overwriting.'.format(args.dps_key)) + + # Load data and remove extraneous dimensions + data = np.squeeze(load_matrix_in_any_format(args.in_dps_file)) + + # Quick check as the built-in error from sft is not too explicit + if len(sft) != data.shape[0]: + raise ValueError('Data must have as many entries ({}) as there are' + ' streamlines ({}).'.format(data.shape[0], + len(sft))) + + sft.data_per_streamline[args.dps_key] = data + + save_tractogram(sft, args.out_tractogram) + + if args.operation == 'delete': + if args.out_tractogram is None: + parser.error('The --out_tractogram option is required for ' + + 'the "delete" operation.') + + del sft.data_per_streamline[args.dps_key] + + save_tractogram(sft, args.out_tractogram) + + if args.operation == 'export': + if args.out_dps_file is None: + parser.error('The --out_dps_file option is required for ' + + 'the "export" operation.') + + # Extract data and reshape + if args.dps_key not in sft.data_per_streamline.keys(): + raise ValueError('Data does not have any data_per_streamline' + ' entry stored at this key: {}' + .format(args.dps_key)) + + data = np.squeeze(sft.data_per_streamline[args.dps_key]) + save_matrix_in_any_format(args.out_dps_file, data) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_tractogram_filter_by_roi.py b/scripts/scil_tractogram_filter_by_roi.py index 3ee4d9d2b..8bb203893 100755 --- a/scripts/scil_tractogram_filter_by_roi.py +++ b/scripts/scil_tractogram_filter_by_roi.py @@ -446,17 +446,21 @@ def main(): radius += distance * sft.space_attributes[2] if geometry == 'Ellipsoid': - filtered_sft, kept_ids = filter_ellipsoid( + kept_ids, filtered_sft = filter_ellipsoid( sft, radius, center, mode, is_exclude) else: # geometry == 'Cuboid': - filtered_sft, kept_ids = filter_cuboid( + kept_ids, filtered_sft = filter_cuboid( sft, radius, center, mode, is_exclude) logging.info('The filtering options {} resulted in {} included ' 'streamlines'.format(roi_opt, len(filtered_sft))) sft = filtered_sft - total_kept_ids = total_kept_ids[kept_ids] + if kept_ids.size == 0: + total_kept_ids = 0 + else: + total_kept_ids = total_kept_ids[kept_ids] + o_dict['streamline_count_after_criteria{}'.format(i)] = \ len(sft.streamlines) diff --git a/scripts/scil_tractogram_filter_collisions.py b/scripts/scil_tractogram_filter_collisions.py new file mode 100644 index 000000000..72134c67d --- /dev/null +++ b/scripts/scil_tractogram_filter_collisions.py @@ -0,0 +1,304 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Given an input tractogram and a text file containing a diameter for each +streamline, filters all intersecting streamlines and saves the resulting +tractogram and diameters. + +IMPORTANT: The input tractogram needs to have been resampled to segments of +at most 0.2mm. Otherwise performance will drop significantly. This is because +this script relies on a KDTree to find all neighboring streamline segments of +any given point. Because the search radius is set at the length of the longest +fibertube segment, the performance drops significantly if they are not +shortened to ~0.2mm. +(see scil_tractogram_resample_nb_points.py) + +IMPORTANT: Some tractograms, especially if old, were created with a very high +float precision. scil_tractogram_filter_collisions.py does not save its output +with such precision. This means that after filtering once and saving the +result, new collisions may be created from saving at a lower float precision. +It will require a second filtering to be truly collision-free. 
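+
+(Rounding the streamline coordinates when saving can bring two fibertubes
+that were barely separated back into contact, hence the possible need for a
+second pass.)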
+ +If you are using the --out_metrics parameter on high float precision data, the +script may even throw an error saying that not all collisions were filtered +prior to metrics computation. + +Solution: If you encounter such behaviour, we recommend you load and save your +tractogram to be filtered with up-to-date tools such as MI-Brain or the +Nibabel python library. (Which scilpy scripts use) + +---------- + +The filtering is deterministic and follows this approach: + - Pick next streamline + - Iterate over its segments + - If current segment collides with any other streamline segment given + their diameters + - Current streamline is deemed invalid and is filtered out + - Other streamline is left untouched + - Repeat + +This means that the order of the streamlines within the tractogram has a +direct impact on which streamline gets filtered out. To counter the resulting +bias, streamlines are shuffled first unless --disable_shuffling is set. + +If the --out_metrics parameter is given, several metrics about the data will +be computed (all expressed in mm): + - fibertube_density + Estimate of the following ratio: volume of fibertubes / total volume + where the total volume is the combined volume of all voxels containing + at least one fibertube. + - min_external_distance + Smallest distance separating two streamlines, outside their diameter. + - max_voxel_anisotropic + Diagonal vector of the largest possible anisotropic voxel that + would not intersect two streamlines, given their diameter. + - max_voxel_isotropic + Isotropic version of max_voxel_anisotropic made by using the smallest + component. + Ex: max_voxel_anisotropic: (3, 5, 5) => max_voxel_isotropic: (3, 3, 3) + - max_voxel_rotated + Largest possible isotropic voxel obtainable if the tractogram is + rotated. + It is only usable if the entire tractogram is rotated according to + [rotation_matrix]. + Ex: max_voxel_anisotropic: (1, 0, 0) => max_voxel_isotropic: (0, 0, 0) + => max_voxel_rotated: (0.5774, 0.5774, 0.5774) + +If the --out_rotation_matrix option is provided, the following will be saved: + - rotation_matrix + 4D transformation matrix representing the rotation to be applied on + the tractogram to align max_voxel_rotated with the coordinate system + (see scil_tractogram_apply_transform.py). + +See also: + - docs/source/documentation/fibertube_tracking.rst +""" + +import os +import json +import argparse +import logging +import numpy as np + +from scilpy.tractograms.intersection_finder import IntersectionFinder +from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.io.streamline import save_tractogram +from scilpy.io.streamlines import load_tractogram_with_reference +from scilpy.tractanalysis.fibertube_scoring import (mean_fibertube_density, + min_external_distance, + max_voxels, + max_voxel_rotated) +from scilpy.io.utils import (assert_inputs_exist, + assert_outputs_exist, + add_overwrite_arg, + add_verbose_arg, + add_json_args) + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + + p.add_argument('in_tractogram', + help='Path to the tractogram file containing the \n' + 'streamlines (must be .trk).') + + p.add_argument('in_diameters', + help='Path to a text file containing a list of \n' + 'diameters in mm. Each line corresponds \n' + 'to the identically numbered streamline. \n' + 'If unsure, refer to the diameters text file of the \n' + 'DiSCo dataset. 
If a single diameter is provided, all \n' + 'streamlines will be given this diameter.') + + p.add_argument('out_tractogram', + help='Tractogram output file free of collision (must \n' + 'be .trk). By default, the diameters will be \n' + 'saved as data_per_streamline.') + + p.add_argument('--save_colliding', action='store_true', + help='Useful for visualization. If set, the script will \n' + 'produce two other tractograms (.trk) containing \n' + 'colliding streamlines. The first one contains invalid \n' + 'streamlines that have been filtered out, along with \n' + 'their collision point as data per streamline. The \n' + 'second one contains the potentially valid streamlines \n' + 'that the first tractogram collided with. Note that the \n' + 'streamlines in the second tractogram may or may not \n' + 'have been filtered afterwards. \n' + 'Filenames are derived from [in_tractogram] with \n' + '"_invalid" appended for the first tractogram, and \n' + '"_obstacle" appended for the second tractogram.') + + p.add_argument('--out_metrics', default=None, type=str, + help='If set, metrics about the streamlines and their \n' + 'diameter will be computed after filtering and saved at \n' + 'the given location (must be .json).') + + p.add_argument('--out_rotation_matrix', default=None, type=str, + help='If set, the transformation required to align the \n' + '"max_voxel_rotated" metric with the coordinate system \n' + 'will be saved at the given location (must be .mat). \n' + 'This option requires computing all the metrics, even \n' + 'if --out_metrics is not provided. If it is provided, ' + 'metrics are not computed twice.') + + p.add_argument('--min_distance', default=0, type=float, + help='If set, streamlines will be filtered more \n' + 'aggressively so that even if they don\'t collide, \n' + 'being below [min_distance] apart (external to their \n' + 'diameter) will be interpreted as a collision. This \n' + 'option is the same as filtering with a large diameter \n' + 'but only saving a small diameter in out_tractogram. \n' + '(Value in mm) [%(default)s]') + + p.add_argument('--disable_shuffling', action='store_true', + help='If set, no shuffling will be performed before \n' + 'the filtering process. Streamlines will be picked in \n' + 'order.') + + p.add_argument('--rng_seed', type=int, default=0, + help='If set, all random values will be generated \n' + 'using the specified seed. 
[%(default)s]') + + add_json_args(p) + add_overwrite_arg(p) + add_verbose_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + logging.getLogger('numba').setLevel(logging.WARNING) + + in_tractogram_no_ext, in_tractogram_ext = os.path.splitext( + args.in_tractogram) + if in_tractogram_ext != '.trk': + raise ValueError("Invalid output streamline file format " + + "(must be trk): {0}".format(args.in_tractogram)) + + if os.path.splitext(args.out_tractogram)[1] != '.trk': + raise ValueError("Invalid output streamline file format " + + "(must be trk): {0}".format(args.out_tractogram)) + + if args.out_metrics: + if os.path.splitext(args.out_metrics)[1] != '.json': + raise ValueError("Invalid metrics output file format " + + "(must be json): {0}".format(args.out_metrics)) + + if args.out_rotation_matrix: + if os.path.splitext(args.out_rotation_matrix)[1] != '.mat': + raise ValueError("Invalid out_rotation_matrix output file" + + "format (must be mat): " + + "{0}".format(args.out_rotation_matrix)) + + outputs = [args.out_tractogram] + if args.save_colliding: + outputs.append(in_tractogram_no_ext + '_invalid.trk') + outputs.append(in_tractogram_no_ext + '_obstacle.trk') + + assert_inputs_exist(parser, args.in_tractogram) + assert_outputs_exist(parser, args, outputs, + [args.out_metrics, args.out_rotation_matrix]) + + logging.debug('Loading tractogram & diameters') + in_sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + in_sft.to_voxmm() + in_sft.to_center() + + streamlines = in_sft.get_streamlines_copy() + diameters = np.loadtxt(args.in_diameters, dtype=np.float64) + + # Test single diameter + if np.ndim(diameters) == 0: + diameters = np.full(len(streamlines), diameters) + elif diameters.shape[0] != (len(streamlines)): + raise ValueError('Number of diameters does not match the number ' + + 'of streamlines.') + + if not args.disable_shuffling: + logging.debug('Shuffling streamlines') + indexes = list(range(len(streamlines))) + gen = np.random.default_rng(args.rng_seed) + gen.shuffle(indexes) + + streamlines = streamlines[indexes] + diameters = diameters[indexes] + in_sft = StatefulTractogram.from_sft(streamlines, in_sft) + + # Casting ArraySequence as a list to improve speed + streamlines = list(streamlines) + + logging.debug('Building IntersectionFinder') + inter_finder = IntersectionFinder( + in_sft, diameters, args.verbose != 'WARNING') + + logging.debug('Finding intersections') + inter_finder.find_intersections(args.min_distance) + + logging.debug('Building new tractogram(s)') + out_sft, invalid_sft, obstacle_sft = inter_finder.build_tractograms( + args.save_colliding) + + logging.debug('Saving new tractogram(s)') + save_tractogram(out_sft, args.out_tractogram) + + if args.save_colliding: + save_tractogram( + invalid_sft, + in_tractogram_no_ext + '_invalid.trk') + + save_tractogram( + obstacle_sft, + in_tractogram_no_ext + '_obstacle.trk') + + logging.debug('Input streamline count: ' + str(len(streamlines)) + + ' | Output streamline count: ' + + str(len(out_sft.streamlines))) + + logging.debug( + str(len(streamlines) - len(out_sft.streamlines)) + + ' streamlines have been filtered') + + if args.out_metrics is not None or args.out_rotation_matrix is not None: + logging.info('Computing metrics') + + min_ext_dist, min_ext_dist_vect = ( + min_external_distance( + out_sft, + args.verbose != 'WARNING')) + max_voxel_ani, max_voxel_iso = max_voxels(min_ext_dist_vect) + 
mvr_rot, mvr_edge = max_voxel_rotated(min_ext_dist_vect) + + # Fibertube density comes last, because it changes space and origin. + mean_density = mean_fibertube_density(out_sft) + + if args.out_metrics: + metrics = { + 'mean_density': mean_density, + 'min_external_distance': min_ext_dist.tolist(), + 'max_voxel_anisotropic': max_voxel_ani.tolist(), + 'max_voxel_isotropic': max_voxel_iso.tolist(), + 'max_voxel_rotated': [mvr_edge]*3 + } + + with open(args.out_metrics, 'w') as outfile: + json.dump(metrics, outfile, + indent=args.indent, sort_keys=args.sort_keys) + + if args.out_rotation_matrix is not None: + max_voxel_rotated_transform = np.r_[np.c_[ + mvr_rot, [0, 0, 0]], [[0, 0, 0, 1]]] + with open(args.out_rotation_matrix, 'w') as outfile: + np.savetxt(outfile, max_voxel_rotated_transform) + + +if __name__ == "__main__": + main() diff --git a/scripts/scil_tractogram_segment_connections_from_labels.py b/scripts/scil_tractogram_segment_connections_from_labels.py index b2bd29351..407d69323 100755 --- a/scripts/scil_tractogram_segment_connections_from_labels.py +++ b/scripts/scil_tractogram_segment_connections_from_labels.py @@ -29,7 +29,8 @@ The segmentation process ------------------------ Segmenting a tractogram based on its endpoints is not as straighforward as one -could imagine. [EXPLAIN THE ISSUES] +could imagine. The endpoints could be outside any labelled region. + The current strategy is to keep the longest streamline segment connecting 2 regions. If the streamline crosses other gray matter regions before reaching its final connected region, the kept connection is still the longest. This is @@ -137,10 +138,11 @@ def _build_arg_parser(): help='Tractogram filename (s). Format must be one of \n' 'trk, tck, vtk, fib, dpy.\n' 'If you have many tractograms for a single subject ' - '(ex, coming \nfrom Ensemble Tracking)') + '(ex, coming \nfrom Ensemble Tracking), we will ' + 'merge them together.') p.add_argument('in_labels', help='Labels file name (nifti). Labels must have 0 as ' - 'background.') + 'background. Volumes must have isotropic voxels.') p.add_argument('out_hdf5', help='Output hdf5 file (.h5).') diff --git a/scripts/scil_viz_tractogram_collisions.py b/scripts/scil_viz_tractogram_collisions.py new file mode 100644 index 000000000..ded9d0df8 --- /dev/null +++ b/scripts/scil_viz_tractogram_collisions.py @@ -0,0 +1,126 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Visualize collisions given by scil_tractogram_filter_collisions with +the --save_colliding parameter. +""" + +import argparse + +from dipy.io.streamline import load_tractogram +from scilpy.io.streamlines import load_tractogram_with_reference +from fury import window, actor +from nibabel.streamlines import detect_format, TrkFile + +from scilpy.io.utils import (add_overwrite_arg, + add_reference_arg, + assert_inputs_exist, + assert_outputs_exist) + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + p.add_argument('in_tractogram_invalid', + help='Tractogram file containing the colliding \n' + 'streamlines that have been filtered, along their \n' + 'collision point as data_per_streamline (must be \n' + '.trk). This file is obtained from the \n' + 'scil_tractogram_filter_collisions.py script.') + + p.add_argument('--in_tractogram_obstacle', + help='Tractogram file containing the streamlines that \n' + 'that [in_tractogram_invalid] has collided with. Will \n' + 'be overlaid in the viewing window. 
This file is \n' + 'obtained from the scil_tractogram_filter_collisions.py \n' + 'script.') + + p.add_argument('--ref_tractogram', + help='Tractogram file containing the full tractogram \n' + 'as visual reference (must be .trk or .tck). It will be \n' + 'overlaid in white and very low opacity.') + + p.add_argument('--out_screenshot', default='', + help='If set, save a screenshot of the result in the \n' + 'specified filename (.png, .bmp, .jpeg or .jpg).') + + p.add_argument('--win_size', nargs=2, type=int, default=(1000, 1000)) + + add_overwrite_arg(p) + add_reference_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + assert_inputs_exist(parser, args.in_tractogram_invalid, + [args.in_tractogram_obstacle, args.ref_tractogram]) + assert_outputs_exist(parser, args, [], [args.out_screenshot]) + + tracts_format = detect_format(args.in_tractogram_invalid) + if tracts_format is not TrkFile: + raise ValueError("Invalid input streamline file format " + + "(must be trk):" + + "{0}".format(args.in_tractogram_invalid)) + + if args.in_tractogram_obstacle: + tracts_format = detect_format(args.in_tractogram_obstacle) + if tracts_format is not TrkFile: + raise ValueError("Invalid input streamline file format " + + "(must be trk):" + + "{0}".format(args.in_tractogram_invalid)) + + invalid_sft = load_tractogram(args.in_tractogram_invalid, 'same', + bbox_valid_check=False) + invalid_sft.to_voxmm() + invalid_sft.to_center() + + if 'collisions' not in invalid_sft.data_per_streamline: + parser.error('Tractogram does not contain collisions') + collisions = invalid_sft.data_per_streamline['collisions'] + + if args.in_tractogram_obstacle: + obstacle_sft = load_tractogram(args.in_tractogram_obstacle, 'same', + bbox_valid_check=False) + obstacle_sft.to_voxmm() + obstacle_sft.to_center() + if args.ref_tractogram: + full_sft = load_tractogram_with_reference(parser, args, + args.ref_tractogram) + full_sft.to_voxmm() + full_sft.to_center() + + # Make display objects and add them to canvas + s = window.Scene() + invalid_actor = actor.line(invalid_sft.streamlines, + colors=[1., 0., 0.]) + s.add(invalid_actor) + + if args.in_tractogram_obstacle: + obstacle_actor = actor.line(obstacle_sft.streamlines, + colors=[0., 1., 0.]) + s.add(obstacle_actor) + + if args.ref_tractogram: + full_actor = actor.line(full_sft.streamlines, opacity=0.03, + colors=[1., 1., 1.]) + + s.add(full_actor) + + points = actor.dot(collisions, colors=(1., 1., 1.)) + s.add(points) + + # Show and record if needed + if args.out_screenshot: + window.record(s, out_path=args.out_screenshot, size=args.win_size) + window.show(s) + + +if __name__ == '__main__': + main() diff --git a/scripts/scil_viz_tractogram_seeds.py b/scripts/scil_viz_tractogram_seeds.py index a86a0ade4..568739506 100755 --- a/scripts/scil_viz_tractogram_seeds.py +++ b/scripts/scil_viz_tractogram_seeds.py @@ -31,7 +31,7 @@ def _build_arg_parser(): help='Tractogram file (must be trk)') p.add_argument('--save', help='If set, save a screenshot of the result in the ' - 'specified filename') + 'specified filename (.png, .bmp, .jpeg or .jpg).') add_verbose_arg(p) add_overwrite_arg(p) @@ -65,7 +65,7 @@ def main(): # Make display objects streamlines_actor = actor.line(streamlines) - points = actor.dot(seeds, color=(1., 1., 1.)) + points = actor.dot(seeds, colors=(1., 1., 1.)) # Add display objects to canvas s = window.Scene() diff --git a/scripts/scil_volume_apply_transform.py b/scripts/scil_volume_apply_transform.py index 7ec3c058b..4d9ed555e 100755 
--- a/scripts/scil_volume_apply_transform.py +++ b/scripts/scil_volume_apply_transform.py @@ -41,6 +41,9 @@ def _build_arg_parser(): p.add_argument('--keep_dtype', action='store_true', help='If True, keeps the data_type of the input image ' '(in_file) when saving the output image (out_name).') + p.add_argument('--interpolation', default='linear', + choices=['linear', 'nearest'], + help='Interpolation: "linear" or "nearest". [%(default)s]') add_verbose_arg(p) add_overwrite_arg(p) @@ -75,7 +78,8 @@ def main(): # Processing, saving warped_img = apply_transform( - transfo, reference, moving, keep_dtype=args.keep_dtype) + transfo, reference, moving, keep_dtype=args.keep_dtype, + interp=args.interpolation) nib.save(warped_img, args.out_name) diff --git a/scripts/scil_volume_pairwise_comparison.py b/scripts/scil_volume_pairwise_comparison.py index cc5672f0f..8db1cf550 100755 --- a/scripts/scil_volume_pairwise_comparison.py +++ b/scripts/scil_volume_pairwise_comparison.py @@ -103,7 +103,7 @@ def compute_all_measures(args): voxel_size = np.product(voxel_size) logging.info(f"Comparing {filename_1} and {filename_2}") dict_measures = compare_volume_wrapper(data_1, data_2, voxel_size, - adjency_no_overlap, ratio) + ratio, adjency_no_overlap) return dict_measures diff --git a/scripts/scil_volume_resample.py b/scripts/scil_volume_resample.py index cf1f74d64..0a659e709 100755 --- a/scripts/scil_volume_resample.py +++ b/scripts/scil_volume_resample.py @@ -23,8 +23,7 @@ import numpy as np from scilpy.io.utils import (add_verbose_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - assert_headers_compatible) + assert_inputs_exist, assert_outputs_exist) from scilpy.image.volume_operations import resample_volume @@ -59,6 +58,9 @@ def _build_arg_parser(): choices=['nn', 'lin', 'quad', 'cubic'], help="Interpolation mode.\nnn: nearest neighbour\nlin: linear\n" "quad: quadratic\ncubic: cubic\nDefaults to linear") + p.add_argument('--enforce_voxel_size', action='store_true', + help='Enforce --voxel_size even if there is a numerical' + ' difference after resampling.') p.add_argument('--enforce_dimensions', action='store_true', help='Enforce the reference volume dimension.') @@ -78,7 +80,10 @@ def main(): assert_outputs_exist(parser, args, args.out_image) if args.enforce_dimensions and not args.ref: - parser.error("Cannot enforce dimensions without a reference image") + parser.error("Cannot enforce dimensions without a reference image.") + + if args.enforce_voxel_size and not args.voxel_size: + parser.error("Cannot enforce voxel size without a voxel size.") if args.volume_size and (not len(args.volume_size) == 1 and not len(args.volume_size) == 3): @@ -110,6 +115,23 @@ def main(): enforce_dimensions=args.enforce_dimensions) # Saving results + zooms = list(resampled_img.header.get_zooms()) + if args.voxel_size: + if len(args.voxel_size) == 1: + args.voxel_size = args.voxel_size * 3 + + if not np.array_equal(zooms[:3], args.voxel_size): + logging.warning('Voxel size is different from expected.' 
+ ' Got: %s, expected: %s', + tuple(zooms), tuple(args.voxel_size)) + if args.enforce_voxel_size: + logging.warning('Enforcing voxel size to %s', + tuple(args.voxel_size)) + zooms[0] = args.voxel_size[0] + zooms[1] = args.voxel_size[1] + zooms[2] = args.voxel_size[2] + resampled_img.header.set_zooms(tuple(zooms)) + logging.info('Saving resampled data to %s', args.out_image) nib.save(resampled_img, args.out_image) diff --git a/scripts/scil_volume_reshape.py b/scripts/scil_volume_reshape.py old mode 100644 new mode 100755 diff --git a/scripts/tests/test_aodf_metrics.py b/scripts/tests/test_aodf_metrics.py index 29e4c63ee..a242e060b 100644 --- a/scripts/tests/test_aodf_metrics.py +++ b/scripts/tests/test_aodf_metrics.py @@ -25,7 +25,7 @@ def test_execution(script_runner, monkeypatch): # Using a low resolution sphere for peak extraction reduces process time ret = script_runner.run('scil_aodf_metrics.py', in_fodf, - '--sphere', 'repulsion100') + '--sphere', 'repulsion100', '--processes', '1') assert ret.success @@ -35,7 +35,7 @@ def test_assert_not_all(script_runner, monkeypatch): f"{test_data_root}/fodf_descoteaux07_sub_unified_asym.nii.gz") ret = script_runner.run('scil_aodf_metrics.py', in_fodf, - '--not_all') + '--not_all', '--processes', '1') assert not ret.success @@ -46,7 +46,8 @@ def test_execution_not_all(script_runner, monkeypatch): ret = script_runner.run('scil_aodf_metrics.py', in_fodf, '--not_all', '--asi_map', - 'asi_map.nii.gz', '-f') + 'asi_map.nii.gz', '-f', + '--processes', '1') assert ret.success @@ -57,7 +58,8 @@ def test_assert_symmetric_input(script_runner, monkeypatch): # Using a low resolution sphere for peak extraction reduces process time ret = script_runner.run('scil_aodf_metrics.py', in_fodf, - '--sphere', 'repulsion100') + '--sphere', 'repulsion100', + '--processes', '1') assert not ret.success @@ -67,7 +69,9 @@ def test_execution_symmetric_input(script_runner, monkeypatch): f"{test_data_root}/fodf_descoteaux07_sub.nii.gz") # Using a low resolution sphere for peak extraction reduces process time + # Using multiprocessing to test this option. ret = script_runner.run('scil_aodf_metrics.py', in_fodf, '--sphere', 'repulsion100', '--not_all', - '--nufid', 'nufid.nii.gz') + '--nufid', 'nufid.nii.gz', + '--processes', '4') assert not ret.success diff --git a/scripts/tests/test_bundle_fixel_analysis.py b/scripts/tests/test_bundle_fixel_analysis.py index 614f4149d..21101de10 100644 --- a/scripts/tests/test_bundle_fixel_analysis.py +++ b/scripts/tests/test_bundle_fixel_analysis.py @@ -21,9 +21,10 @@ def test_default_parameters(script_runner, monkeypatch): in_peaks = os.path.join(SCILPY_HOME, 'commit_amico', 'peaks.nii.gz') in_bundle = os.path.join(SCILPY_HOME, 'commit_amico', 'tracking.trk') + # Using multiprocessing in this test, single in following tests. 
ret = script_runner.run('scil_bundle_fixel_analysis.py', in_peaks, '--in_bundles', in_bundle, - '--processes', '1', '-f') + '--processes', '4', '-f') assert ret.success diff --git a/scripts/tests/test_bundle_mean_fixel_mrds_metric.py b/scripts/tests/test_bundle_mean_fixel_mrds_metric.py new file mode 100644 index 000000000..efe916c2e --- /dev/null +++ b/scripts/tests/test_bundle_mean_fixel_mrds_metric.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +def test_help_option(script_runner): + ret = script_runner.run( + 'scil_bundle_mean_fixel_mrds_metric.py', '--help') + + assert ret.success diff --git a/scripts/tests/test_connectivity_compute_simple_matrix.py b/scripts/tests/test_connectivity_compute_simple_matrix.py new file mode 100644 index 000000000..33411c60e --- /dev/null +++ b/scripts/tests/test_connectivity_compute_simple_matrix.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile + +from scilpy import SCILPY_HOME +from scilpy.io.fetcher import fetch_data, get_testing_files_dict + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['tractometry.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run( + 'scil_connectivity_compute_simple_matrix.py', '--help') + assert ret.success + + +def test_script(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_labels = os.path.join(SCILPY_HOME, 'tractometry', + 'IFGWM_labels_map.nii.gz') + in_sft = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk') + + ret = script_runner.run( + 'scil_connectivity_compute_simple_matrix.py', in_sft, in_labels, + 'out_matrix.npy', 'out_labels.txt', '--hide_labels', '10', + '--percentage', '--hide_fig', '--out_fig', 'matrices.png') + assert ret.success diff --git a/scripts/tests/test_fibertube_score_tractogram.py b/scripts/tests/test_fibertube_score_tractogram.py new file mode 100644 index 000000000..a7cea0f8e --- /dev/null +++ b/scripts/tests/test_fibertube_score_tractogram.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import json +import tempfile +import numpy as np +import nibabel as nib + +from scilpy.io.streamlines import save_tractogram +from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin + +tmp_dir = tempfile.TemporaryDirectory() + + +def init_data(): + streamlines = [[[5., 1., 5.], [5., 5., 9.], [7., 9., 9.], [13., 11., 9.], + [5., 7., 7.]], [[7., 7., 7.], [9., 9., 9.]]] + + mask = np.ones((15, 15, 15)) + affine = np.eye(4) + header = nib.nifti2.Nifti2Header() + extra = { + 'affine': affine, + 'dimensions': (15, 15, 15), + 'voxel_size': 1., + 'voxel_order': "RAS" + } + mask_img = nib.nifti2.Nifti2Image(mask, affine, header, extra) + + config = { + 'step_size': 0.001, + 'blur_radius': 0.001, + 'nb_fibertubes': 2, + 'nb_seeds_per_fibertube': 1, + } + + sft_fibertubes = StatefulTractogram(streamlines, mask_img, Space.VOX, + Origin.NIFTI) + sft_fibertubes.data_per_streamline = { + "diameters": [0.002, 0.001] + } + sft_tracking = StatefulTractogram(streamlines, mask_img, Space.VOX, + Origin.NIFTI) + sft_tracking.data_per_streamline = { + "seeds": [streamlines[0][0], streamlines[1][0]] + } + + save_tractogram(sft_fibertubes, 'fibertubes.trk', True) + save_tractogram(sft_tracking, 'tracking.trk', True) + + with open('config.json', 'w') as file: + json.dump(config, file, indent=True) + + +def test_help_option(script_runner): + ret = 
script_runner.run('scil_fibertube_score_tractogram.py', '--help') + assert ret.success + + +def test_execution(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + init_data() + ret = script_runner.run('scil_fibertube_score_tractogram.py', + 'fibertubes.trk', 'tracking.trk', 'config.json', + 'metrics.json', '--save_error_tractogram', '-f') + assert ret.success diff --git a/scripts/tests/test_fibertube_tracking.py b/scripts/tests/test_fibertube_tracking.py new file mode 100644 index 000000000..ba1a407d6 --- /dev/null +++ b/scripts/tests/test_fibertube_tracking.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile +import numpy as np +import nibabel as nib + +from scilpy.io.streamlines import save_tractogram +from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin + +tmp_dir = tempfile.TemporaryDirectory() + + +def init_data(): + streamlines = [[[5., 1., 5.], [5., 5., 9.], [7., 9., 9.], [13., 11., 9.], + [5., 7., 7.]], [[7., 7., 7.], [9., 9., 9.]]] + + mask = np.ones((15, 15, 15)) + affine = np.eye(4) + header = nib.nifti2.Nifti2Header() + extra = { + 'affine': affine, + 'dimensions': (15, 15, 15), + 'voxel_size': 1., + 'voxel_order': "RAS" + } + mask_img = nib.nifti2.Nifti2Image(mask, affine, header, extra) + + sft = StatefulTractogram(streamlines, mask_img, Space.VOX, Origin.NIFTI) + sft.data_per_streamline = { + "diameters": [0.002, 0.001] + } + + save_tractogram(sft, 'tractogram.trk', True) + + +def test_help_option(script_runner): + ret = script_runner.run('scil_fibertube_tracking.py', '--help') + assert ret.success + + +def test_execution(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + init_data() + ret = script_runner.run('scil_fibertube_tracking.py', + 'tractogram.trk', 'tracking.trk', + '--min_length', '0', '-f') + assert ret.success + + +def test_execution_tracking_rk(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + init_data() + ret = script_runner.run('scil_fibertube_tracking.py', + 'tractogram.trk', 'tracking.trk', + '--blur_radius', '0.3', + '--step_size', '0.1', + '--rk_order', '2', '--min_length', '0', '-f') + assert ret.success + + +def test_execution_config(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + init_data() + ret = script_runner.run('scil_fibertube_tracking.py', + 'tractogram.trk', 'tracking.trk', + '--blur_radius', '0.3', + '--step_size', '0.1', + '--out_config', 'config.json', + '--min_length', '0', '-f') + assert ret.success + + +def test_execution_seeding(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + init_data() + ret = script_runner.run('scil_fibertube_tracking.py', + 'tractogram.trk', 'tracking.trk', + '--blur_radius', '0.3', + '--step_size', '0.1', + '--nb_fibertubes', '1', + '--nb_seeds_per_fibertube', '3', '--skip', '3', + '--min_length', '0', '-f') + assert ret.success diff --git a/scripts/tests/test_frf_mean.py b/scripts/tests/test_frf_mean.py index d04088e39..436d7c144 100644 --- a/scripts/tests/test_frf_mean.py +++ b/scripts/tests/test_frf_mean.py @@ -32,6 +32,23 @@ def test_execution_processing_msmt(script_runner, monkeypatch): assert ret.success +def test_outputs_precision(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt') + ret = script_runner.run('scil_frf_mean.py', in_frf, in_frf, 'mfrfp.txt', + '--precision', 
+    assert ret.success
+
+    expected = [
+        "0.0016 0.0004 0.0004 3076.7249",
+        "0.0012 0.0003 0.0003 3076.7249",
+        "0.0009 0.0003 0.0003 3076.7249"
+    ]
+    with open('mfrfp.txt', 'r') as result:
+        for i, line in enumerate(result.readlines()):
+            assert line.strip("\n") == expected[i]
 
 
 def test_execution_processing_bad_input(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     in_wm_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
diff --git a/scripts/tests/test_frf_memsmt.py b/scripts/tests/test_frf_memsmt.py
index 5ca12b3d7..e60de7607 100644
--- a/scripts/tests/test_frf_memsmt.py
+++ b/scripts/tests/test_frf_memsmt.py
@@ -130,6 +130,42 @@ def test_inputs_check(script_runner, monkeypatch):
     assert (not ret.success)
+
+
+def test_outputs_precision(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_dwi_lin = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                              'dwi_linear.nii.gz')
+    in_bval_lin = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                               'linear.bvals')
+    in_bvec_lin = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                               'linear.bvecs')
+    in_dwi_plan = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                               'dwi_planar.nii.gz')
+    in_bval_plan = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                                'planar.bvals')
+    in_bvec_plan = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                                'planar.bvecs')
+    in_dwi_sph = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                              'dwi_spherical.nii.gz')
+    in_bval_sph = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                               'spherical.bvals')
+    in_bvec_sph = os.path.join(SCILPY_HOME, 'btensor_testdata',
+                               'spherical.bvecs')
+    ret = script_runner.run('scil_frf_memsmt.py', 'wm_frf.txt',
+                            'gm_frf.txt', 'csf_frf.txt', '--in_dwis',
+                            in_dwi_lin, in_dwi_plan, in_dwi_sph, '--in_bvals',
+                            in_bval_lin, in_bval_plan, in_bval_sph,
+                            '--in_bvecs', in_bvec_lin, in_bvec_plan,
+                            in_bvec_sph, '--in_bdeltas', '1', '-0.5', '0',
+                            '--min_nvox', '1', '--precision', '4', '-f')
+
+    assert ret.success
+
+    for frf_file in ['wm_frf.txt', 'gm_frf.txt', 'csf_frf.txt']:
+        with open(frf_file, "r") as f:
+            for item in f.readline().strip("\n").split(" "):
+                assert len(item.split(".")[1]) == 4
 
 
 def test_execution_processing(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     in_dwi_lin = os.path.join(SCILPY_HOME, 'btensor_testdata',
diff --git a/scripts/tests/test_frf_msmt.py b/scripts/tests/test_frf_msmt.py
index 5eddf67fe..a919e5afc 100644
--- a/scripts/tests/test_frf_msmt.py
+++ b/scripts/tests/test_frf_msmt.py
@@ -75,6 +75,27 @@ def test_roi_radii_shape_parameter2(script_runner, monkeypatch):
     assert (not ret.success)
+
+
+def test_outputs_precision(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_dwi = os.path.join(SCILPY_HOME, 'commit_amico',
+                          'dwi.nii.gz')
+    in_bval = os.path.join(SCILPY_HOME, 'commit_amico',
+                           'dwi.bval')
+    in_bvec = os.path.join(SCILPY_HOME, 'commit_amico',
+                           'dwi.bvec')
+    mask = os.path.join(SCILPY_HOME, 'commit_amico', 'mask.nii.gz')
+    ret = script_runner.run('scil_frf_msmt.py', in_dwi,
+                            in_bval, in_bvec, 'wm_frf.txt', 'gm_frf.txt',
+                            'csf_frf.txt', '--mask', mask, '--min_nvox', '20',
+                            '--precision', '4', '-f')
+    assert ret.success
+
+    for frf_file in ['wm_frf.txt', 'gm_frf.txt', 'csf_frf.txt']:
+        with open(frf_file, "r") as f:
+            for item in f.readline().strip("\n").split(" "):
+                assert len(item.split(".")[1]) == 4
 
 
 def test_execution_processing(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     in_dwi = os.path.join(SCILPY_HOME, 'commit_amico',
diff --git a/scripts/tests/test_frf_set_diffusivities.py b/scripts/tests/test_frf_set_diffusivities.py
index 5783433bc..de40559fa 100644
--- a/scripts/tests/test_frf_set_diffusivities.py
+++ b/scripts/tests/test_frf_set_diffusivities.py
@@ -3,6 +3,7 @@
 
 import os
 import tempfile
+import numpy as np
 
 from scilpy import SCILPY_HOME
 from scilpy.io.fetcher import fetch_data, get_testing_files_dict
@@ -34,6 +35,24 @@ def test_execution_processing_msmt(script_runner, monkeypatch):
     assert ret.success
+
+
+def test_outputs_precision(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
+    ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
+                            '15,4,4,13,4,4,12,5,5', 'new_frf.txt',
+                            '--precision', '4', '-f')
+    assert ret.success
+
+    expected = [
+        "0.0015 0.0004 0.0004 3076.7249",
+        "0.0013 0.0004 0.0004 3076.7249",
+        "0.0012 0.0005 0.0005 3076.7249"
+    ]
+    with open('new_frf.txt', 'r') as result:
+        for i, line in enumerate(result.readlines()):
+            assert line.strip("\n") == expected[i]
 
 
 def test_execution_processing__wrong_input(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
diff --git a/scripts/tests/test_frf_ssst.py b/scripts/tests/test_frf_ssst.py
index 243657fa9..5f02d018e 100644
--- a/scripts/tests/test_frf_ssst.py
+++ b/scripts/tests/test_frf_ssst.py
@@ -60,6 +60,18 @@ def test_roi_radii_shape_parameter(script_runner, monkeypatch):
     assert (not ret.success)
+
+
+def test_outputs_precision(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    ret = script_runner.run('scil_frf_ssst.py', in_dwi,
+                            in_bval, in_bvec, 'frf.txt',
+                            '--precision', '4', '-f')
+    assert ret.success
+
+    with open("frf.txt", "r") as f:
+        for item in f.readline().strip("\n").split(" "):
+            assert len(item.split(".")[1]) == 4
 
 
 def test_execution_processing(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     ret = script_runner.run('scil_frf_ssst.py', in_dwi,
diff --git a/scripts/tests/test_gradients_apply_transform.py b/scripts/tests/test_gradients_apply_transform.py
old mode 100755
new mode 100644
diff --git a/scripts/tests/test_mrds_metrics.py b/scripts/tests/test_mrds_metrics.py
new file mode 100644
index 000000000..62b4a8cfc
--- /dev/null
+++ b/scripts/tests/test_mrds_metrics.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from scilpy import SCILPY_HOME
+from scilpy.io.fetcher import fetch_data, get_testing_files_dict
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['mrds.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_mrds_metrics.py', '--help')
+    assert ret.success
+
+
+def test_execution_mrds_all_metrics(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    in_evals = os.path.join(SCILPY_HOME,
+                            'mrds', 'sub-01_MRDS_eigenvalues.nii.gz')
+
+    # no option
+    ret = script_runner.run('scil_mrds_metrics.py',
+                            in_evals,
+                            '-f')
+    assert ret.success
+
+
+def test_execution_mrds_not_all_metrics(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    in_evals = os.path.join(SCILPY_HOME,
+                            'mrds', 'sub-01_MRDS_eigenvalues.nii.gz')
+    in_mask = os.path.join(SCILPY_HOME,
+                           'mrds', 'sub-01_mask.nii.gz')
+    # --not_all with a mask and explicitly named outputs
+    ret = script_runner.run('scil_mrds_metrics.py',
+                            in_evals,
+                            '--mask', in_mask,
+                            '--not_all',
+                            '--fa', 'sub-01_MRDS_FA.nii.gz',
+                            '--ad', 'sub-01_MRDS_AD.nii.gz',
+                            '--rd', 'sub-01_MRDS_RD.nii.gz',
+                            '--md', 'sub-01_MRDS_MD.nii.gz',
+                            '-f')
+    assert ret.success
diff --git a/scripts/tests/test_mrds_select_number_of_tensors.py b/scripts/tests/test_mrds_select_number_of_tensors.py
new file mode 100644
index 000000000..7602c3104
--- /dev/null
+++ b/scripts/tests/test_mrds_select_number_of_tensors.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from scilpy import SCILPY_HOME
+from scilpy.io.fetcher import fetch_data, get_testing_files_dict
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['mrds.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_mrds_select_number_of_tensors.py', '--help')
+    assert ret.success
+
+
+def test_execution_mrds(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    in_nufo = os.path.join(SCILPY_HOME,
+                           'mrds', 'sub-01_nufo.nii.gz')
+    # no option
+    ret = script_runner.run('scil_mrds_select_number_of_tensors.py',
+                            SCILPY_HOME + '/mrds/sub-01',
+                            in_nufo,
+                            '-f')
+    assert ret.success
+
+
+def test_execution_mrds_w_mask(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+
+    in_nufo = os.path.join(SCILPY_HOME,
+                           'mrds', 'sub-01_nufo.nii.gz')
+    in_mask = os.path.join(SCILPY_HOME, 'mrds',
+                           'sub-01_mask.nii.gz')
+
+    ret = script_runner.run('scil_mrds_select_number_of_tensors.py',
+                            SCILPY_HOME + '/mrds/sub-01',
+                            in_nufo,
+                            '--mask', in_mask,
+                            '-f')
+    assert ret.success
diff --git a/scripts/tests/test_sh_to_sf.py b/scripts/tests/test_sh_to_sf.py
old mode 100755
new mode 100644
index b0dedfbcc..8703876c3
--- a/scripts/tests/test_sh_to_sf.py
+++ b/scripts/tests/test_sh_to_sf.py
@@ -28,7 +28,8 @@ def test_execution_in_sphere(script_runner, monkeypatch):
                             'sf_724.nii.gz', '--in_bval', in_bval,
                             '--in_b0', in_b0,
                             '--out_bval', 'sf_724.bval',
                             '--out_bvec', 'sf_724.bvec',
-                            '--sphere', 'symmetric724', '--dtype', 'float32')
+                            '--sphere', 'symmetric724', '--dtype', 'float32',
+                            '--processes', '1')
     assert ret.success
 
@@ -43,14 +44,16 @@ def test_execution_in_bvec(script_runner, monkeypatch):
                             'sf_724.nii.gz', '--in_bval', in_bval,
                             '--out_bval', 'sf_724.bval',
                             '--out_bvec', 'sf_724.bvec',
-                            '--in_bvec', in_bvec, '--dtype', 'float32', '-f')
+                            '--in_bvec', in_bvec, '--dtype', 'float32', '-f',
+                            '--processes', '1')
     assert ret.success
 
     # Test that fails if no bvals is given.
     ret = script_runner.run('scil_sh_to_sf.py', in_sh,
                             'sf_724.nii.gz', '--out_bvec', 'sf_724.bvec',
-                            '--in_bvec', in_bvec, '--dtype', 'float32', '-f')
+                            '--in_bvec', in_bvec, '--dtype', 'float32', '-f',
+                            '--processes', '1')
     assert not ret.success
 
 
@@ -60,9 +63,10 @@ def test_execution_no_bval(script_runner, monkeypatch):
     in_b0 = os.path.join(SCILPY_HOME, 'processing', 'fa.nii.gz')
 
     # --sphere but no --bval
+    # Testing multiprocessing option
     ret = script_runner.run('scil_sh_to_sf.py', in_sh,
                             'sf_724.nii.gz', '--in_b0', in_b0,
                             '--out_bvec', 'sf_724.bvec', '--b0_scaling',
                             '--sphere', 'symmetric724', '--dtype', 'float32',
-                            '-f')
+                            '-f', '--processes', '4')
     assert ret.success
diff --git a/scripts/tests/test_tractogram_add_dps.py b/scripts/tests/test_tractogram_add_dps.py
deleted file mode 100644
index b5e7f05d2..000000000
--- a/scripts/tests/test_tractogram_add_dps.py
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import numpy as np
-import os
-import tempfile
-
-from dipy.io.streamline import load_tractogram
-
-from scilpy import SCILPY_HOME
-from scilpy.io.fetcher import fetch_data, get_testing_files_dict
-
-# If they already exist, this only takes 5 seconds (check md5sum)
-fetch_data(get_testing_files_dict(), keys=['filtering.zip'])
-tmp_dir = tempfile.TemporaryDirectory()
-
-
-def test_help_option(script_runner):
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            '--help')
-    assert ret.success
-
-
-def test_execution_add_dps(script_runner, monkeypatch):
-    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
-                             'bundle_4.trk')
-    sft = load_tractogram(in_bundle, 'same')
-    filename = 'vals.npy'
-    outname = 'out.trk'
-    np.save(filename, np.arange(len(sft)))
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            in_bundle, filename, 'key', outname, '-f')
-    assert ret.success
-
-
-def test_execution_add_dps_missing_vals(script_runner, monkeypatch):
-    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
-                             'bundle_4.trk')
-    sft = load_tractogram(in_bundle, 'same')
-    filename = 'vals.npy'
-    outname = 'out.trk'
-    np.save(filename, np.arange(len(sft) - 10))
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            in_bundle, filename, 'key', outname, '-f')
-    assert ret.stderr
-
-
-def test_execution_add_dps_existing_key(script_runner, monkeypatch):
-    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
-                             'bundle_4.trk')
-    sft = load_tractogram(in_bundle, 'same')
-    filename = 'vals.npy'
-    outname = 'out.trk'
-    outname2 = 'out_2.trk'
-    np.save(filename, np.arange(len(sft)))
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            in_bundle, filename, 'key', outname, '-f')
-    assert ret.success
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            outname, filename, 'key', outname2)
-    assert not ret.success
-
-
-def test_execution_add_dps_tck(script_runner, monkeypatch):
-    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
-                             'bundle_4.trk')
-    sft = load_tractogram(in_bundle, 'same')
-    filename = 'vals.npy'
-    outname = 'out.tck'
-    np.save(filename, np.arange(len(sft)))
-    ret = script_runner.run('scil_tractogram_add_dps.py',
-                            in_bundle, filename, 'key', outname, '-f')
-    assert not ret.success
diff --git a/scripts/tests/test_tractogram_dps_math.py b/scripts/tests/test_tractogram_dps_math.py
new file mode 100644
index 000000000..03b758de2
--- /dev/null
+++ b/scripts/tests/test_tractogram_dps_math.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import tempfile
+
+from dipy.io.streamline import load_tractogram, save_tractogram
+
+from scilpy import SCILPY_HOME
+from scilpy.io.fetcher import fetch_data, get_testing_files_dict
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['filtering.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            '--help')
+    assert ret.success
+
+
+def test_execution_dps_math_import(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    sft = load_tractogram(in_bundle, 'same')
+    filename = 'vals.npy'
+    outname = 'out.trk'
+    np.save(filename, np.arange(len(sft)))
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'import', 'key',
+                            '--in_dps_file', filename,
+                            '--out_tractogram', outname,
+                            '-f')
+    assert ret.success
+
+
+def test_execution_dps_math_import_with_missing_vals(script_runner,
+                                                     monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    sft = load_tractogram(in_bundle, 'same')
+    filename = 'vals.npy'
+    outname = 'out.trk'
+    np.save(filename, np.arange(len(sft) - 10))
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'import', 'key',
+                            '--in_dps_file', filename,
+                            '--out_tractogram', outname,
+                            '-f')
+    assert ret.stderr
+
+
+def test_execution_dps_math_import_with_existing_key(script_runner,
+                                                     monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    sft = load_tractogram(in_bundle, 'same')
+    filename = 'vals.npy'
+    outname = 'out.trk'
+    outname2 = 'out_2.trk'
+    np.save(filename, np.arange(len(sft)))
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'import', 'key',
+                            '--in_dps_file', filename,
+                            '--out_tractogram', outname,
+                            '-f')
+    assert ret.success
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            outname, 'import', 'key',
+                            '--in_dps_file', filename,
+                            '--out_tractogram', outname2,)
+    assert not ret.success
+
+
+def test_execution_dps_math_tck_output(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    sft = load_tractogram(in_bundle, 'same')
+    filename = 'vals.npy'
+    outname = 'out.tck'
+    np.save(filename, np.arange(len(sft)))
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'import', 'key',
+                            '--in_dps_file', filename,
+                            '--out_tractogram', outname,
+                            '-f')
+    assert not ret.success
+
+
+def test_execution_dps_math_delete(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle_no_key = os.path.join(SCILPY_HOME, 'filtering',
+                                    'bundle_4.trk')
+    in_bundle = 'bundle_4.trk'
+    sft = load_tractogram(in_bundle_no_key, 'same')
+    sft.data_per_streamline = {
+        "key": [0] * len(sft)
+    }
+    save_tractogram(sft, in_bundle)
+    outname = 'out.trk'
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'delete', 'key',
+                            '--out_tractogram', outname,
+                            '-f')
+    assert ret.success
+
+
+def test_execution_dps_math_delete_no_key(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    outname = 'out.trk'
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'delete', 'key',
+                            '--out_tractogram', outname,
+                            '-f')
+    assert not ret.success
+
+
+def test_execution_dps_math_export(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle_no_key = os.path.join(SCILPY_HOME, 'filtering',
+                                    'bundle_4.trk')
+    in_bundle = 'bundle_4.trk'
+    sft = load_tractogram(in_bundle_no_key, 'same')
+    sft.data_per_streamline = {
+        "key": [0] * len(sft)
+    }
+    save_tractogram(sft, in_bundle)
+    filename = 'out.txt'
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'export', 'key',
+                            '--out_dps_file', filename,
+                            '-f')
+    assert ret.success
+
+
+def test_execution_dps_math_export_no_key(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'filtering',
+                             'bundle_4.trk')
+    filename = 'out.txt'
+    ret = script_runner.run('scil_tractogram_dps_math.py',
+                            in_bundle, 'export', 'key',
+                            '--out_dps_file', filename,
+                            '-f')
+    assert not ret.success
diff --git a/scripts/tests/test_tractogram_filter_collisions.py b/scripts/tests/test_tractogram_filter_collisions.py
new file mode 100644
index 000000000..033563b1e
--- /dev/null
+++ b/scripts/tests/test_tractogram_filter_collisions.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+import numpy as np
+import nibabel as nib
+
+from scilpy.io.streamlines import save_tractogram
+from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin
+
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def init_data():
+    streamlines = [[[5., 1., 5.], [5., 5., 9.], [7., 9., 9.], [13., 11., 9.],
+                    [5., 7., 7.]], [[7., 7., 7.], [9., 9., 9.]]]
+
+    mask = np.ones((15, 15, 15))
+    affine = np.eye(4)
+    header = nib.nifti2.Nifti2Header()
+    extra = {
+        'affine': affine,
+        'dimensions': (15, 15, 15),
+        'voxel_size': 1.,
+        'voxel_order': "RAS"
+    }
+    mask_img = nib.nifti2.Nifti2Image(mask, affine, header, extra)
+
+    sft = StatefulTractogram(streamlines, mask_img, Space.VOX, Origin.NIFTI)
+    save_tractogram(sft, 'tractogram.trk', True)
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_tractogram_filter_collisions.py', '--help')
+    assert ret.success
+
+
+def test_execution_filtering(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    diameters = [5, 1]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '-f')
+    assert ret.success
+
+
+def test_execution_filtering_save_colliding(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    diameters = [5, 1]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '--save_colliding', '-f')
+    assert ret.success
+
+
+def test_execution_filtering_single_diameter(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    diameters = [5]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '-f')
+    assert ret.success
+
+
+def test_execution_filtering_no_shuffle(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    diameters = [5, 1]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '--disable_shuffling', '-f')
+    assert ret.success
+
+
+def test_execution_filtering_min_distance(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    diameters = [0.001, 0.001]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '--min_distance', '5', '-f')
+    assert ret.success
+
+
+def test_execution_filtering_metrics(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    # No collision, as we want to keep two streamlines for this test.
+    diameters = [0.001, 0.001]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '--out_metrics', 'metrics.json', '-f')
+    assert ret.success
+
+
+def test_execution_rotation_matrix(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    init_data()
+
+    # No collision, as we want to keep two streamlines for this test.
+    diameters = [0.001, 0.001]
+    np.savetxt('diameters.txt', diameters)
+
+    ret = script_runner.run('scil_tractogram_filter_collisions.py',
+                            'tractogram.trk', 'diameters.txt', 'clean.trk',
+                            '--out_rotation_matrix', 'rotation.mat', '-f')
+    assert ret.success
diff --git a/scripts/tests/test_tractogram_segment_bundles_with_bundleseg.py b/scripts/tests/test_tractogram_segment_with_bundleseg.py
similarity index 100%
rename from scripts/tests/test_tractogram_segment_bundles_with_bundleseg.py
rename to scripts/tests/test_tractogram_segment_with_bundleseg.py
diff --git a/scripts/tests/test_viz_tractogram_collisions.py b/scripts/tests/test_viz_tractogram_collisions.py
new file mode 100644
index 000000000..2d8819039
--- /dev/null
+++ b/scripts/tests/test_viz_tractogram_collisions.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_viz_tractogram_collisions.py', '--help')
+    assert ret.success
diff --git a/scripts/tests/test_volume_apply_transform.py b/scripts/tests/test_volume_apply_transform.py
index b5cfd749c..1b9ba4407 100644
--- a/scripts/tests/test_volume_apply_transform.py
+++ b/scripts/tests/test_volume_apply_transform.py
@@ -27,5 +27,36 @@ def test_execution_bst(script_runner, monkeypatch):
                             'output0GenericAffine.mat')
     ret = script_runner.run('scil_volume_apply_transform.py',
                             in_model, in_fa, in_aff,
-                            'template_lin.nii.gz', '--inverse')
+                            'template_lin.nii.gz', '--inverse',
+                            '-f')
+    assert ret.success
+
+
+def test_execution_interp_nearest(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_model = os.path.join(SCILPY_HOME, 'bst', 'template',
+                            'template0.nii.gz')
+    in_fa = os.path.join(SCILPY_HOME, 'bst',
+                         'fa.nii.gz')
+    in_aff = os.path.join(SCILPY_HOME, 'bst',
+                          'output0GenericAffine.mat')
+    ret = script_runner.run('scil_volume_apply_transform.py',
+                            in_model, in_fa, in_aff,
+                            'template_lin.nii.gz', '--inverse',
+                            '--interp', 'nearest', '-f')
+    assert ret.success
+
+
+def test_execution_interp_lin(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_model = os.path.join(SCILPY_HOME, 'bst', 'template',
+                            'template0.nii.gz')
+    in_fa = os.path.join(SCILPY_HOME, 'bst',
+                         'fa.nii.gz')
+    in_aff = os.path.join(SCILPY_HOME, 'bst',
+                          'output0GenericAffine.mat')
+    ret = script_runner.run('scil_volume_apply_transform.py',
+                            in_model, in_fa, in_aff,
+                            'template_lin.nii.gz', '--inverse',
+                            '--interp', 'linear', '-f')
     assert ret.success
diff --git a/scripts/tests/test_volume_resample.py b/scripts/tests/test_volume_resample.py
index 884f2cb7d..9133e5dbe 100644
--- a/scripts/tests/test_volume_resample.py
+++ b/scripts/tests/test_volume_resample.py
@@ -21,7 +21,15 @@ def test_help_option(script_runner):
 
 
 def test_execution_given_size(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     ret = script_runner.run('scil_volume_resample.py', in_img,
-                            'fa_resample.nii.gz', '--voxel_size', '2')
+                            'fa_resample_2.nii.gz', '--voxel_size', '2')
+    assert ret.success
+
+
+def test_execution_force_voxel(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    ret = script_runner.run('scil_volume_resample.py', in_img,
+                            'fa_resample_4.nii.gz', '--voxel_size', '4',
+                            '--enforce_voxel_size')
     assert ret.success
 
 
@@ -31,3 +39,12 @@ def test_execution_ref(script_runner, monkeypatch):
     ret = script_runner.run('scil_volume_resample.py', in_img,
                             'fa_resample2.nii.gz', '--ref', ref)
     assert ret.success
+
+
+def test_execution_ref_force(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    ref = os.path.join(SCILPY_HOME, 'others', 'fa_resample.nii.gz')
+    ret = script_runner.run('scil_volume_resample.py', in_img,
+                            'fa_resample_ref.nii.gz', '--ref', ref,
+                            '--enforce_dimensions')
+    assert ret.success