diff --git a/tissues/femoral_cartilage.py b/tissues/femoral_cartilage.py index 88ae14c..c5ed81f 100644 --- a/tissues/femoral_cartilage.py +++ b/tissues/femoral_cartilage.py @@ -39,6 +39,9 @@ class FemoralCartilage(Tissue): LATERAL_KEY = 1 SAGGITAL_KEYS = [MEDIAL_KEY, LATERAL_KEY] + # Background Key + BACKGROUND_KEY = 3 + def __init__(self, weights_dir=None, medial_to_lateral=None): """ :param weights_dir: Directory to weights files @@ -171,6 +174,7 @@ def split_regions(self, unrolled_quantitative_map): """ # create unrolled mask from unrolled map + unrolled_quantitative_map = np.nan_to_num(unrolled_quantitative_map) unrolled_mask_indexes = np.nonzero(unrolled_quantitative_map) unrolled_mask = np.zeros((unrolled_quantitative_map.shape[0], unrolled_quantitative_map.shape[1])) unrolled_mask[unrolled_mask_indexes] = 1 @@ -178,13 +182,13 @@ def split_regions(self, unrolled_quantitative_map): # find the center of mass of the unrolled mask center_of_mass = sni.measurements.center_of_mass(unrolled_mask) - unrolled_mask[np.where(unrolled_mask < 1)] = 3 + unrolled_mask[np.where(unrolled_mask < 1)] = self.BACKGROUND_KEY lateral_mask = np.copy(unrolled_mask)[:, 0:np.int(np.around(center_of_mass[1]))] medial_mask = np.copy(unrolled_mask)[:, np.int(np.around(center_of_mass[1])):] - lateral_mask[np.where(lateral_mask < 3)] = self.LATERAL_KEY - medial_mask[np.where(medial_mask < 3)] = self.MEDIAL_KEY + lateral_mask[np.where(lateral_mask < self.BACKGROUND_KEY)] = self.LATERAL_KEY + medial_mask[np.where(medial_mask < self.BACKGROUND_KEY)] = self.MEDIAL_KEY if self.medial_to_lateral: ml_mask = np.concatenate((medial_mask, lateral_mask), axis=1) @@ -196,9 +200,9 @@ def split_regions(self, unrolled_quantitative_map): central_mask = np.copy(unrolled_mask)[np.int(center_of_mass[0]):np.int(center_of_mass[0]) + 10, :] posterior_mask = np.copy(unrolled_mask)[np.int(center_of_mass[0]) + 10:, :] - anterior_mask[np.where(anterior_mask < 3)] = self.ANTERIOR_KEY - posterior_mask[np.where(posterior_mask < 3)] = self.POSTERIOR_KEY - central_mask[np.where(central_mask < 3)] = self.CENTRAL_KEY + anterior_mask[np.where(anterior_mask < self.BACKGROUND_KEY)] = self.ANTERIOR_KEY + posterior_mask[np.where(posterior_mask < self.BACKGROUND_KEY)] = self.POSTERIOR_KEY + central_mask[np.where(central_mask < self.BACKGROUND_KEY)] = self.CENTRAL_KEY acp_mask = np.concatenate((anterior_mask, central_mask, posterior_mask), axis=0) @@ -209,6 +213,9 @@ def split_regions(self, unrolled_quantitative_map): self.regions_mask = np.concatenate((ml_mask, acp_mask), axis=2) + # convert backgorund label to NaN + self.regions_mask[self.regions_mask == self.BACKGROUND_KEY] = np.nan + assert (self.regions_mask[..., 0] == ml_mask[..., 0]).all() assert (self.regions_mask[..., 1] == acp_mask[..., 0]).all()