Use patch_size instead of chunk_size as base shape for sampling #4
Merged
9 commits
8ca4afc (fercer) Changed PatchSampler to take as base the patche size instead of the i…
4720c76 (fercer) Reverted change in the computation when masks elements are relative s…
9e1e985 (fercer) Fixed spatial chunk size computation when patch sizes are grater than…
d7fb5f5 (fercer) Fixed missing patches from chunks smaller than the input image chunk …
1826378 (fercer) Padding and stride added to PatchSampler and ImageBase classes to all…
b894985 (fercer) Added tests for stride and pad parameters of PatchSampler class
aaa51f6 (fercer) Fixed patch slices generation in PatchSampler to always retrieve patc…
7885749 (fercer) Standardized patch sampling method to handle smaller and bigger mask …
a5a4097 (fercer) Added example notebook to documentation
@@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
example.py

# Translations
*.mo

194 changes: 194 additions & 0 deletions
docs/source/examples/advanced_example_pytorch_inference.md
@@ -0,0 +1,194 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.15.1
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
execution:
  timeout: 120
---

# Integration of ZarrDataset with PyTorch's DataLoader for inference (Advanced)

```python
import zarrdataset as zds

import torch
from torch.utils.data import DataLoader
```

```python
# These are images from the Image Data Resource (IDR)
# https://idr.openmicroscopy.org/ that are publicly available and were
# converted to the OME-NGFF (Zarr) format by the OME group. More examples
# can be found at Public OME-Zarr data (Nov. 2020)
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html

filenames = [
    "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr"
]
```

```python
import random
import numpy as np

# For reproducibility
np.random.seed(478963)
torch.manual_seed(478964)
random.seed(478965)
```

## Extracting patches of size 128x128 pixels from a Whole Slide Image (WSI)

Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference results.
Finally, set `allow_incomplete_patches=True` so that the PatchSampler also retrieves patches at the image edges that would otherwise be smaller than the patch size.

```python
patch_size = dict(Y=128, X=128)
pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)
```
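
To make the effect of `pad` concrete, the following minimal sketch computes the region that would be read for one patch. It assumes the padding is added symmetrically on both sides of each spatial axis, which is what the cropping later in this example suggests. `padded_region` is a hypothetical helper written for illustration; it is not part of ZarrDataset.

```python
def padded_region(top_left, patch_size, pad):
    """Return (start, stop) per axis for a patch grown by `pad` on each side.

    Hypothetical helper for illustration; it is not part of ZarrDataset.
    """
    return {
        ax: (top_left[ax] - pad[ax], top_left[ax] + patch_size[ax] + pad[ax])
        for ax in patch_size
    }


patch_size = dict(Y=128, X=128)
pad = dict(Y=16, X=16)

# A patch whose top-left corner sits at Y=128, X=256 (arbitrary example values).
region = padded_region(dict(Y=128, X=256), patch_size, pad)
print(region)  # {'Y': (112, 272), 'X': (240, 400)}
```

Under that assumption, each 128x128 patch is read as a 160x160 array, and the extra 16 pixels on every side are discarded after inference.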

Create a dataset from the list of filenames. The image data in each of those files should be stored within its group "4", matching the `data_group="4"` argument used here.

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.

```python
image_specs = zds.ImagesDatasetSpecs(
  filenames=filenames,
  data_group="4",
  source_axes="TCZYX",
  axes="YXC",
  roi="0,0,0,0,0:1,-1,1,-1,-1"
)

my_dataset = zds.ZarrDataset(image_specs,
                             patch_sampler=patch_sampler,
                             return_positions=True)
```
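
The `roi` string can be read as a region selection over the source axes. The sketch below is one plausible reading, assuming the string lists one start per axis, then a colon, then one length per axis, both in `source_axes` order, with `-1` meaning "up to the end of that axis". `roi_to_slices` is a hypothetical helper, and the actual parsing inside ZarrDataset may differ.

```python
def roi_to_slices(roi, source_axes):
    """Translate a "starts:lengths" ROI string into one slice per axis.

    Hypothetical helper; the real parsing inside ZarrDataset may differ.
    """
    starts_str, lengths_str = roi.split(":")
    starts = [int(v) for v in starts_str.split(",")]
    lengths = [int(v) for v in lengths_str.split(",")]
    slices = {}
    for ax, start, length in zip(source_axes, starts, lengths):
        # A negative length is taken to mean "everything from `start` onwards".
        stop = None if length < 0 else start + length
        slices[ax] = slice(start, stop)
    return slices


roi_slices = roi_to_slices("0,0,0,0,0:1,-1,1,-1,-1", "TCZYX")
print(roi_slices["T"], roi_slices["C"])  # slice(0, 1, None) slice(0, None, None)
```

Read this way, the ROI keeps a single time point and a single depth plane while taking every channel and the full Y and X extent, which agrees with the patch size `{'Z': 1, 'Y': 128, 'X': 128}` reported by the dataset.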

```python
my_dataset
```

    ZarrDataset (PyTorch support:True, tqdm support :True)
    Modalities: images
    Transforms order: []
    Using images modality as reference.
    Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.

Add a pre-processing step before creating the image batches, in which the input arrays are cast from int16 to float32.

```python
import torchvision

img_preprocessing = torchvision.transforms.Compose([
    zds.ToDtype(dtype=np.float32),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(127, 255)
])

my_dataset.add_transform("images", img_preprocessing)
```
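
As a rough guide to what this pipeline does to pixel values: `ToTensor` moves the channel axis to the front and leaves arrays that are not `uint8` unscaled, and `Normalize(127, 255)` then computes `(x - 127) / 255`. The sketch below restates those steps in plain NumPy so the arithmetic can be checked without PyTorch; `preprocess_like_pipeline` is a hypothetical name, and the input patch is made-up data.

```python
import numpy as np


def preprocess_like_pipeline(patch_yxc):
    """NumPy re-statement of ToDtype -> ToTensor -> Normalize(127, 255)."""
    x = patch_yxc.astype(np.float32)   # cast only, values are unchanged
    x = np.moveaxis(x, -1, 0)          # (Y, X, C) -> (C, Y, X)
    return (x - 127.0) / 255.0         # Normalize: (x - mean) / std


# Made-up 4x4 patch with three channels, every pixel equal to the mean value.
patch = np.full((4, 4, 3), 127, dtype=np.int16)
out = preprocess_like_pipeline(patch)
print(out.shape, out.dtype, float(out.max()))  # (3, 4, 4) float32 0.0
```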

```python
my_dataset
```

    ZarrDataset (PyTorch support:True, tqdm support :True)
    Modalities: images
    Transforms order: [('images',)]
    Using images modality as reference.
    Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.

## Create a DataLoader from the dataset object

ZarrDataset is compatible with PyTorch's DataLoader because it inherits from the IterableDataset class of the torch.utils.data module.

```python
my_dataloader = DataLoader(my_dataset, num_workers=0)
```
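
The following toy example illustrates the same pattern with a stand-in dataset. `ToyPatchDataset` is hypothetical and only mimics the idea of yielding a position together with a patch; it is not how ZarrDataset is implemented. It shows that, with the default `batch_size=1`, the DataLoader adds a leading batch dimension to both the positions and the patches.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyPatchDataset(IterableDataset):
    """Hypothetical stand-in for a patch dataset; not ZarrDataset itself."""

    def __iter__(self):
        for y in range(0, 256, 128):
            # One (start, stop) row per spatial axis, then the patch itself.
            pos = torch.tensor([[y, y + 128], [0, 128]])
            patch = torch.zeros(3, 128, 128)
            yield pos, patch


loader = DataLoader(ToyPatchDataset(), num_workers=0)
pos, sample = next(iter(loader))
print(pos.shape, sample.shape)  # torch.Size([1, 2, 2]) torch.Size([1, 3, 128, 128])
```

That leading dimension is why the position of a patch is read with indices such as `pos[0, 0, 0]` when iterating over a DataLoader.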

```python
import dask.array as da
import numpy as np
import zarr

z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r")

# Spatial extent of the image; the last two axes are Y and X
H = z_arr.shape[-2]
W = z_arr.shape[-1]

# Round each extent up to the next multiple of the 128-pixel patch size
pad_H = (128 - H) % 128
pad_W = (128 - W) % 128
z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128))
z_prediction
```

    <zarr.core.Array (1152, 1408) float32>
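
The expression `(128 - H) % 128` gives the number of extra pixels needed to reach the next multiple of 128, and is zero when the size is already a multiple. A small sketch with made-up sizes (the helper name `pad_to_multiple` and the values are for illustration only):

```python
def pad_to_multiple(size, multiple):
    """Extra pixels needed so that `size` becomes a multiple of `multiple`."""
    return (multiple - size) % multiple


# Hypothetical image sizes, chosen only for illustration.
for size in (1100, 1152, 1):
    print(size, "->", size + pad_to_multiple(size, 128))
# 1100 -> 1152
# 1152 -> 1152
# 1 -> 128
```

Allocating the output with this rounded-up shape means every 128x128 prediction, including those from incomplete edge patches, has a full chunk to be written into.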

Set up a simple model for illustration purposes.

```python
model = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1),
    torch.nn.ReLU()
)
```
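
```python
for i, (pos, sample) in enumerate(my_dataloader):
    # pos[0, axis] holds (start, stop) for the Y and X axes of the sampled
    # patch. The 16-pixel padding is trimmed from both ends so that only
    # the central 128x128 region is written.
    pred_pos = (
        slice(pos[0, 0, 0].item() + 16,
              pos[0, 0, 1].item() - 16),
        slice(pos[0, 1, 0].item() + 16,
              pos[0, 1, 1].item() - 16)
    )
    pred = model(sample)
    # Crop the same padding from the prediction before storing it
    z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16]
```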
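
The crop-and-stitch idea can be checked in isolation with NumPy. The sketch below is a simplified stand-in, not the ZarrDataset pipeline: it pads a random image once, cuts padded tiles by hand, applies a trivial "model" (multiplication by 2), and writes back only the central region of each result. Unlike the loop above, tile positions here are expressed in the coordinates of the padded array.

```python
import numpy as np

patch, pad = 128, 16
image = np.random.default_rng(0).random((256, 384)).astype(np.float32)
# Pad the whole image once so every padded tile lies inside the array.
padded = np.pad(image, pad, mode="reflect")
stitched = np.zeros_like(image)

for y in range(0, image.shape[0], patch):
    for x in range(0, image.shape[1], patch):
        # In padded coordinates the patch plus its margin starts at (y, x).
        tile = padded[y:y + patch + 2 * pad, x:x + patch + 2 * pad]
        pred = tile * 2.0  # stand-in for model(tile)
        # Drop the margin and write only the central patch-sized region.
        stitched[y:y + patch, x:x + patch] = pred[pad:-pad, pad:-pad]

print(np.allclose(stitched, image * 2.0))  # True
```

With a real convolutional model, the discarded margin is where border effects would show up, which is why the padding is cropped rather than stitched.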

## Visualize the result

```python
import matplotlib.pyplot as plt

plt.subplot(2, 1, 1)
plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1))
plt.subplot(2, 1, 2)
plt.imshow(z_prediction)
plt.show()
```
I think there is a typo here for group "0", should it be "4" in this example?
Thanks for noticing this @ClementCaporal! I considered this change and added it to a recent PR #8 that addresses an incorrect sampling of masked regions.
Oh Nice!
I was starting to use masked regions on Friday and started noticing strange behavior, so I just have to pull now thanks to you!
Have a good week,
Clément