From 037a11ed9fa555b951483653bce7ed73bdf777a1 Mon Sep 17 00:00:00 2001 From: Arjun Desai Date: Fri, 5 Oct 2018 14:22:54 -0700 Subject: [PATCH] adding normalized map support --- defaults.py | 4 +++- msk/knee.py | 2 +- pipeline.py | 2 +- scan_sequences/cones.py | 10 ++++------ scan_sequences/cube_quant.py | 9 ++++----- scan_sequences/dess.py | 6 +++--- scripts/test-script | 2 +- test_pipeline.py | 4 ++-- tissues/femoral_cartilage.py | 34 ++++++++++++++++++++++++++++------ tissues/tissue.py | 2 +- utils/quant_vals.py | 25 +++++++++++++------------ 11 files changed, 61 insertions(+), 39 deletions(-) diff --git a/defaults.py b/defaults.py index a867043..678157a 100644 --- a/defaults.py +++ b/defaults.py @@ -1 +1,3 @@ -DEFAULT_BATCH_SIZE = 32 \ No newline at end of file +DEFAULT_BATCH_SIZE = 32 + +FIX_VISUALIZATION_BOUNDS = True \ No newline at end of file diff --git a/msk/knee.py b/msk/knee.py index 1c7aa55..11900e7 100644 --- a/msk/knee.py +++ b/msk/knee.py @@ -1,5 +1,5 @@ from tissues.femoral_cartilage import FemoralCartilage -from utils.quant_vals import QuantitativeValue as QV +from utils.quant_vals import QuantitativeValues as QV from utils import io_utils KNEE_KEY = 'knee' diff --git a/pipeline.py b/pipeline.py index 1b8a705..a67c485 100644 --- a/pipeline.py +++ b/pipeline.py @@ -11,7 +11,7 @@ from models.get_model import get_model from tissues.femoral_cartilage import FemoralCartilage -from utils.quant_vals import QuantitativeValue as QV, get_qv +from utils.quant_vals import QuantitativeValues as QV, get_qv import file_constants as fc from msk import knee diff --git a/scan_sequences/cones.py b/scan_sequences/cones.py index 2b7b479..25f71c9 100644 --- a/scan_sequences/cones.py +++ b/scan_sequences/cones.py @@ -13,7 +13,7 @@ __EXPECTED_NUM_ECHO_TIMES__ = 4 __R_SQUARED_THRESHOLD__ = 0.9 -__INITIAL_P0_VALS__ = (1.0, 1/30.0) +__INITIAL_T2_STAR_VAL__ = 30.0 # ms __T2_STAR_LOWER_BOUND__ = 0 __T2_STAR_UPPER_BOUND__ = np.inf @@ -56,7 +56,6 @@ def interregister(self, target_path, mask_path=None): raw_filepaths = dict() echo_time_inds = natsorted(list(subvolumes.keys())) - print(echo_time_inds) for i in range(len(echo_time_inds)): raw_filepath = os.path.join(temp_raw_dirpath, '%03d.nii.gz' % i) @@ -107,7 +106,6 @@ def interregister(self, target_path, mask_path=None): transformation_files = reg_output.transform warped_files = [(base_echo_time, reg_output.warped_file)] - print(raw_filepaths) files = [] for echo_time_ind in raw_filepaths.keys(): filepath = raw_filepaths[echo_time_ind] @@ -146,8 +144,8 @@ def save_data(self, save_dirpath): if self.t2star_map is not None: data = {'data': self.t2star_map} - io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValue.T2_STAR.name.lower()), data) - io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValue.T2_STAR.name.lower()), + io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValues.T2_STAR.name.lower()), data) + io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValues.T2_STAR.name.lower()), self.t2star_map, self.pixel_spacing) # Save interregistered files @@ -183,7 +181,7 @@ def generate_t2_star_map(self): svs.append(svr) svs = np.concatenate(svs) - vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_P0_VALS__) + vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_T2_STAR_VAL__) map_unfiltered = vals.reshape(original_shape) r_squared = r_squared.reshape(original_shape) diff --git a/scan_sequences/cube_quant.py b/scan_sequences/cube_quant.py index 99c7c7b..5b9cc7b 100644 --- a/scan_sequences/cube_quant.py +++ b/scan_sequences/cube_quant.py @@ -12,7 +12,7 @@ __EXPECTED_NUM_SPIN_LOCK_TIMES__ = 4 __R_SQUARED_THRESHOLD__ = 0.9 -__INITIAL_P0_VALS__ = (1.0, 1/30.0) +__INITIAL_T1_RHO_VAL__ = 70.0 __T1_RHO_LOWER_BOUND__ = 0 __T1_RHO_UPPER_BOUND__ = np.inf @@ -136,9 +136,8 @@ def generate_t1_rho_map(self): svs.append(svr) svs = np.concatenate(svs) - print(svs.shape) - vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_P0_VALS__) + vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_T1_RHO_VAL__) map_unfiltered = vals.reshape(original_shape) r_squared = r_squared.reshape(original_shape) @@ -213,8 +212,8 @@ def save_data(self, save_dirpath): if self.t1rho_map is not None: data = {'data': self.t1rho_map} - io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValue.T1_RHO.name.lower()), data) - io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValue.T1_RHO.name.lower()), self.t1rho_map, + io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValues.T1_RHO.name.lower()), data) + io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValues.T1_RHO.name.lower()), self.t1rho_map, self.pixel_spacing) # Save interregistered files diff --git a/scan_sequences/dess.py b/scan_sequences/dess.py index 71e358a..dadb751 100644 --- a/scan_sequences/dess.py +++ b/scan_sequences/dess.py @@ -4,7 +4,7 @@ from scan_sequences.scans import TargetSequence from utils import dicom_utils, im_utils, io_utils -from utils.quant_vals import QuantitativeValue +from utils.quant_vals import QuantitativeValues class Dess(TargetSequence): NAME = 'dess' @@ -162,8 +162,8 @@ def save_data(self, save_dirpath): save_dirpath = self.__save_dir__(save_dirpath) data = {'data': self.t2map} - io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % QuantitativeValue.T2.name.lower()), data) - io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % QuantitativeValue.T2.name.lower()), self.t2map, self.pixel_spacing) + io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % QuantitativeValues.T2.name.lower()), data) + io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % QuantitativeValues.T2.name.lower()), self.t2map, self.pixel_spacing) # write echos for i in range(len(self.subvolumes)): diff --git a/scripts/test-script b/scripts/test-script index 959af2c..0715fd6 100755 --- a/scripts/test-script +++ b/scripts/test-script @@ -1,5 +1,5 @@ #!/bin/bash -WEIGHTS_DIRECTORY="" +WEIGHTS_DIRECTORY="/Users/arjundesai/Documents/stanford/research/msk_pipeline_raw/weights" if [ -z "$WEIGHTS_DIRECTORY" ]; then echo "Please define WEIGHTS_DIRECTORY in script. Use the absolute path" exit 125 diff --git a/test_pipeline.py b/test_pipeline.py index a5d27de..f0d73e6 100644 --- a/test_pipeline.py +++ b/test_pipeline.py @@ -9,7 +9,7 @@ from scan_sequences.cube_quant import CubeQuant from utils import io_utils, dicom_utils -from utils.quant_vals import QuantitativeValue +from utils.quant_vals import QuantitativeValues import file_constants as fc @@ -117,7 +117,7 @@ def test_t2_map_load(self): scan = pipeline.handle_dess(vargin) - scan.tissues[0].calc_quant_vals(scan.t2map, QuantitativeValue.T2) + scan.tissues[0].calc_quant_vals(scan.t2map, QuantitativeValues.T2) print(scan.t2map.shape) diff --git a/tissues/femoral_cartilage.py b/tissues/femoral_cartilage.py index 3e71191..f593225 100644 --- a/tissues/femoral_cartilage.py +++ b/tissues/femoral_cartilage.py @@ -9,9 +9,15 @@ import nipy.labs.mask as nlm -from utils.quant_vals import QuantitativeValue +from utils.quant_vals import QuantitativeValues import matplotlib.pyplot as plt +import defaults +import warnings + +BOUNDS = {QuantitativeValues.T2: 100.0, + QuantitativeValues.T1_RHO: 150.0, + QuantitativeValues.T2_STAR: 100.0} class FemoralCartilage(Tissue): ID = 1 @@ -246,9 +252,10 @@ def calc_quant_vals(self, quant_map, map_type): sagital_keys = ['anterior', 'central', 'posterior'] df = pd.DataFrame(data=np.transpose(tissue_values), index=sagital_keys, columns=pd.MultiIndex.from_tuples(zip(depth_keys, coronal_keys))) - maps = [{'title': 'T2 deep', 'data': deep, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2deep.png'}, - {'title': 'T2 superficial', 'data': superficial, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2superficial.png'}, - {'title': 'T2 total', 'data': total, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2total.png'}] + qv_name = map_type.name + maps = [{'title': '%s deep' % qv_name, 'data': deep, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_deep.png' % qv_name}, + {'title': '%s superficial' % qv_name, 'data': superficial, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_superficial.png' % qv_name}, + {'title': '%s total' % qv_name, 'data': total, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_total.png' % qv_name}] self.__store_quant_vals__(maps, df, map_type) @@ -261,7 +268,7 @@ def __save_quant_data__(self, dirpath): q_names = [] dfs = [] - for quant_val in QuantitativeValue: + for quant_val in QuantitativeValues: if quant_val.name not in self.quant_vals.keys(): continue @@ -276,12 +283,27 @@ def __save_quant_data__(self, dirpath): ylabel = 'Angle (binned)' title = q_map_data['title'] data_map = q_map_data['data'] + plt.clf() - plt.imshow(data_map, cmap='jet') + + upper_bound = BOUNDS[quant_val] + is_picture_written = False + if defaults.FIX_VISUALIZATION_BOUNDS: + if np.sum(data_map <= upper_bound) == 0: + plt.imshow(data_map, cmap='jet', vmin=0.0, vmax=BOUNDS[quant_val]) + is_picture_written = True + else: + warnings.warn('%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale.' + % (quant_val.name, upper_bound)) + + if not is_picture_written: + plt.imshow(data_map, cmap='jet') + plt.xlabel(xlabel) plt.ylabel(ylabel) plt.title(title) plt.colorbar() + plt.savefig(filepath) if len(dfs) > 0: diff --git a/tissues/tissue.py b/tissues/tissue.py index 3234258..42341f1 100644 --- a/tissues/tissue.py +++ b/tissues/tissue.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import os from utils import io_utils -from utils.quant_vals import QuantitativeValue +from utils.quant_vals import QuantitativeValues import cv2 import numpy as np diff --git a/utils/quant_vals.py b/utils/quant_vals.py index 2ab14fc..d7bee80 100644 --- a/utils/quant_vals.py +++ b/utils/quant_vals.py @@ -9,19 +9,19 @@ __R_SQUARED_THRESHOLD__ = 0.9 -class QuantitativeValue(Enum): +class QuantitativeValues(Enum): T1_RHO = 1 T2 = 2 T2_STAR = 3 def get_qv(id): - for qv in QuantitativeValue: + for qv in QuantitativeValues: if qv.name.lower() == id or qv.value == id: return qv -def fit_mono_exp(x, y, p0=None): +def __fit_mono_exp__(x, y, p0=None): def func(t, a, b): exp = np.exp(b * t) return a * exp @@ -40,8 +40,9 @@ def func(t, a, b): return popt, r_squared -def fit_monoexp_tc(x, ys, p0): - vals = np.zeros([1, ys.shape[-1]]) +def fit_monoexp_tc(x, ys, tc0): + p0 = (1.0, -1/tc0) + time_constants = np.zeros([1, ys.shape[-1]]) r_squared = np.zeros([1, ys.shape[-1]]) warned_negative = False @@ -56,17 +57,17 @@ def fit_monoexp_tc(x, ys, p0): continue try: - params, r2 = fit_mono_exp(x, y, p0=p0) - t1_rho = abs(params[-1]) + params, r2 = __fit_mono_exp__(x, y, p0=p0) + tc = 1 / abs(params[-1]) except RuntimeError: - t1_rho, r2 = (np.nan, 0.0) + tc, r2 = (np.nan, 0.0) - vals[..., i] = t1_rho + time_constants[..., i] = tc r_squared[..., i] = r2 - return vals, r_squared + return time_constants, r_squared if __name__ == '__main__': - print(type(QuantitativeValue.T1_RHO.name)) - print(QuantitativeValue.T1_RHO.value== 1) \ No newline at end of file + print(type(QuantitativeValues.T1_RHO.name)) + print(QuantitativeValues.T1_RHO.value == 1) \ No newline at end of file