Skip to content

Commit

Permalink
initial implementation to explicitly specify segmentation masks to be…
Browse files Browse the repository at this point in the history
… used for extraction
  • Loading branch information
sophiamaedler committed Jan 9, 2025
1 parent d53eab4 commit 1a2a1b2
Showing 1 changed file with 45 additions and 23 deletions.
68 changes: 45 additions & 23 deletions src/scportrait/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,24 +267,53 @@ def _get_segmentation_info(self):
f"Found no segmentation masks with key {self.segmentation_key}. Cannot proceed with extraction."
)

# get relevant segmentation masks to perform extraction on
nucleus_key = f"{self.segmentation_key}_nucleus"
# initialize default values to track what should be extracted
self.nucleus_key = None
self.cytosol_key = None
self.extract_nucleus_mask = False
self.extract_cytosol_mask = False

if "segmentation_mask" in self.config:
allowed_mask_values = ["nucleus", "cytosol"]
allowed_mask_values = [f"{self.segmentation_key}_{x}" for x in allowed_mask_values]

if isinstance(self.config["segmentation_mask"], str):
assert (self.config["segmentation_mask"] in allowed_mask_values)

if "nucleus" in self.main_segmenation_mask:
self.nucleus_key = self.main_segmenation_mask
self.extract_nucleus_mask = True

elif "cytosol" in self.main_segmenation_mask:
self.cytosol_key = self.main_segmenation_mask
self.extract_cytosol_mask = True
else:
raise ValueError(f"Segmentation mask {self.main_segmenation_mask} is not a valid mask to extract from.")

if nucleus_key in relevant_masks:
self.extract_nucleus_mask = True
self.nucleus_key = nucleus_key
else:
self.extract_nucleus_mask = False
self.nucleus_key = None
elif isinstance(self.config["segmentation_mask"], list):
assert all(x in allowed_mask_values for x in self.config["segmentation_mask"])

cytosol_key = f"{self.segmentation_key}_cytosol"
for x in self.config["segmentation_mask"]:
if "nucleus" in x:
self.nucleus_key = x
self.extract_nucleus_mask = True
if "cytosol" in x:
self.cytosol_key = x
self.extract_cytosol_mask = True

if cytosol_key in relevant_masks:
self.extract_cytosol_mask = True
self.cytosol_key = cytosol_key
else:
self.extract_cytosol_mask = False
self.cytosol_key = None
# get relevant segmentation masks to perform extraction on
nucleus_key = f"{self.segmentation_key}_nucleus"

if nucleus_key in relevant_masks:
self.extract_nucleus_mask = True
self.nucleus_key = nucleus_key

cytosol_key = f"{self.segmentation_key}_cytosol"

if cytosol_key in relevant_masks:
self.extract_cytosol_mask = True
self.cytosol_key = cytosol_key

self.n_masks = np.sum([self.extract_nucleus_mask, self.extract_cytosol_mask])
self.masks = [x for x in [self.nucleus_key, self.cytosol_key] if x is not None]
Expand Down Expand Up @@ -661,25 +690,18 @@ def _transfer_tempmmap_to_hdf5(self):
# self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly

_, c, x, y = _tmp_single_cell_data.shape
print(_tmp_single_cell_data.shape)
print(self.image_size)
print(keep_index.shape)
single_cell_data = hf.create_dataset(
"single_cell_data",
shape=(len(keep_index), c, x, y),
chunks=(1, 1, self.image_size, self.image_size),
# compression=self.compression_type,
compression='gzip', #was lzf, gzip works
compression=self.compression_type,
dtype=np.float16,
# rdcc_nbytes=5242880000, # 5gb 1024 * 1024 * 5000
# rdcc_w0=1,
# rdcc_nslots=50000,
)

# populate the dataset in a loop to avoid loading the entire dataset into memory;
# this is required so that large datasets can be processed without running into memory issues
for ix, i in enumerate(keep_index):
single_cell_data[ix] = _tmp_single_cell_data[i]
single_cell_data[ix] = _tmp_single_cell_data[i]

self.log("single-cell data created")
del single_cell_data
Expand Down

0 comments on commit 1a2a1b2

Please sign in to comment.