diff --git a/tissues/femoral_cartilage.py b/tissues/femoral_cartilage.py index c5ed81f..fdcff32 100644 --- a/tissues/femoral_cartilage.py +++ b/tissues/femoral_cartilage.py @@ -208,16 +208,17 @@ def split_regions(self, unrolled_quantitative_map): assert ml_mask.shape == acp_mask.shape + # convert backgorund label to NaN + ml_mask[ml_mask == self.BACKGROUND_KEY] = np.nan + acp_mask[acp_mask == self.BACKGROUND_KEY] = np.nan + ml_mask = ml_mask[..., np.newaxis] acp_mask = acp_mask[..., np.newaxis] self.regions_mask = np.concatenate((ml_mask, acp_mask), axis=2) - # convert backgorund label to NaN - self.regions_mask[self.regions_mask == self.BACKGROUND_KEY] = np.nan - - assert (self.regions_mask[..., 0] == ml_mask[..., 0]).all() - assert (self.regions_mask[..., 1] == acp_mask[..., 0]).all() + assert np.allclose(self.regions_mask[..., 0], ml_mask[..., 0], equal_nan=True) + assert np.allclose(self.regions_mask[..., 1], acp_mask[..., 0], equal_nan=True) def calc_quant_vals(self, quant_map, map_type): """