Fix mypy errors and add mypy check to CI (#914)
fepegar authored Jul 4, 2022
1 parent c0b1914 commit 9150afc
Showing 54 changed files with 474 additions and 342 deletions.
27 changes: 0 additions & 27 deletions .github/workflows/lint.yml

This file was deleted.

13 changes: 11 additions & 2 deletions .mypy.ini
@@ -1,13 +1,22 @@
[mypy]

[mypy-PIL.*]
ignore_missing_imports = True

[mypy-SimpleITK.*]
ignore_missing_imports = True

[mypy-duecredit.*]
ignore_missing_imports = True

[mypy-matplotlib.*]
ignore_missing_imports = True

[mypy-nibabel.*]
ignore_missing_imports = True

[mypy-tqdm.*]
[mypy-scipy.*]
ignore_missing_imports = True

[mypy-PIL.*]
[mypy-tqdm.*]
ignore_missing_imports = True
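
Each [mypy-&lt;package&gt;.*] section with ignore_missing_imports = True tells mypy not to error on imports of third-party packages that ship no type stubs (PIL, SimpleITK, nibabel, and so on). A minimal sketch of the kind of code these overrides unblock; the module below is illustrative and not part of the commit:

# Illustrative module, not part of the commit: nibabel ships without type
# stubs, so without the [mypy-nibabel.*] override above mypy reports an error
# such as "Cannot find implementation or library stub for module named
# 'nibabel'" on this import.
import nibabel as nib


def load_affine(path: str):
    # Typical use of an untyped dependency inside a type-checked code base.
    img = nib.load(path)
    return img.affine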
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -61,4 +61,4 @@ repos:
rev: v2.34.0
hooks:
- id: pyupgrade
args: ['--py37-plus']
args: ['--py37-plus', '--keep-runtime-typing']
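
The added --keep-runtime-typing flag stops pyupgrade from rewriting typing constructs such as Optional[int] into PEP 604 syntax (int | None) in files that use "from __future__ import annotations", as src/torchio/data/dataset.py starts to do below; on Python 3.7 the new syntax breaks anything that evaluates annotations at runtime. A minimal sketch, assuming that rationale:

from __future__ import annotations

from typing import Optional


def crop(margin: Optional[int] = None) -> int:  # left as-is by --keep-runtime-typing
    return 0 if margin is None else margin

# Without the flag, pyupgrade could rewrite the hint to `int | None`. The file
# would still import on Python 3.7, but tools that evaluate annotations at
# runtime (typing.get_type_hints, pydantic, ...) would then fail.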
8 changes: 5 additions & 3 deletions src/torchio/data/dataset.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import copy
from typing import Sequence, Optional, Callable, Iterable, Dict
from typing import Sequence, Optional, Callable, Iterable, List

from torch.utils.data import Dataset

@@ -94,14 +96,14 @@ def __getitem__(self, index: int) -> Subject:
return subject

@classmethod
def from_batch(cls: 'SubjectsDataset', batch: Dict) -> 'SubjectsDataset':
def from_batch(cls, batch: dict) -> SubjectsDataset:
"""Instantiate a dataset from a batch generated by a data loader.
Args:
batch: Dictionary generated by a data loader, containing data that
can be converted to instances of :class:`~.torchio.Subject`.
"""
subjects = get_subjects_from_batch(batch)
subjects: List[Subject] = get_subjects_from_batch(batch)
return cls(subjects)

def dry_iter(self):
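
A short usage sketch of the updated from_batch classmethod (standard torchio API; the shapes and batch size are illustrative, and collation details depend on the PyTorch version):

import torch
import torchio as tio

subjects = [
    tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)))
    for _ in range(4)
]
dataset = tio.SubjectsDataset(subjects)
loader = torch.utils.data.DataLoader(dataset, batch_size=2)

batch = next(iter(loader))                       # plain dict built by the loader
rebuilt = tio.SubjectsDataset.from_batch(batch)  # List[Subject] -> SubjectsDataset
print(len(rebuilt))  # 2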
39 changes: 27 additions & 12 deletions src/torchio/data/image.py
@@ -16,6 +16,7 @@
TypeData,
TypeDataAffine,
TypeTripletInt,
TypeQuartetInt,
TypeTripletFloat,
TypeDirection3D,
)
@@ -235,6 +236,8 @@ def affine(self) -> np.ndarray:
if self._loaded or self._is_dir() or self._is_multipath():
affine = self[AFFINE]
else:
assert self.path is not None
assert isinstance(self.path, (str, Path))
affine = read_affine(self.path)
return affine

@@ -247,13 +250,18 @@ def type(self) -> str: # noqa: A003
return self[TYPE]

@property
def shape(self) -> Tuple[int, int, int, int]:
def shape(self) -> TypeQuartetInt:
"""Tensor shape as :math:`(C, W, H, D)`."""
custom_reader = self.reader is not read_image
multipath = not isinstance(self.path, (str, Path))
if self._loaded or custom_reader or multipath or self.path.is_dir():
shape = tuple(self.data.shape)
multipath = self._is_multipath()
if isinstance(self.path, Path):
is_dir = self.path.is_dir()
shape: TypeQuartetInt
if self._loaded or custom_reader or multipath or is_dir:
channels, si, sj, sk = self.data.shape
shape = channels, si, sj, sk
else:
assert isinstance(self.path, (str, Path))
shape = read_shape(self.path)
return shape

@@ -289,18 +297,20 @@ def direction(self) -> TypeDirection3D:
_, _, direction = get_sitk_metadata_from_ras_affine(
self.affine, lps=False,
)
return direction
return direction # type: ignore[return-value]

@property
def spacing(self) -> Tuple[float, float, float]:
"""Voxel spacing in mm."""
_, spacing = get_rotation_and_spacing_from_affine(self.affine)
return tuple(spacing)
sx, sy, sz = spacing
return sx, sy, sz

@property
def origin(self) -> Tuple[float, float, float]:
"""Center of first voxel in array, in mm."""
return tuple(self.affine[:3, 3])
ox, oy, oz = self.affine[:3, 3]
return ox, oy, oz

@property
def itemsize(self):
@@ -421,17 +431,17 @@ def _parse_single_path(

def _parse_path(
self,
path: Union[TypePath, Sequence[TypePath], None],
path: Optional[Union[TypePath, Sequence[TypePath]]],
) -> Optional[Union[Path, List[Path]]]:
if path is None:
return None
elif isinstance(path, dict):
# https://github.com/fepegar/torchio/pull/838
raise TypeError('The path argument cannot be a dictionary')
elif self._is_paths_sequence(path):
return [self._parse_single_path(p) for p in path]
return [self._parse_single_path(p) for p in path] # type: ignore[union-attr] # noqa: E501
else:
return self._parse_single_path(path)
return self._parse_single_path(path) # type: ignore[arg-type]

def _parse_tensor(
self,
@@ -510,7 +520,12 @@ def load(self) -> None:
"""
if self._loaded:
return
paths = self.path if self._is_multipath() else [self.path]

paths: List[Path]
if self._is_multipath():
paths = self.path # type: ignore[assignment]
else:
paths = [self.path] # type: ignore[list-item]
tensor, affine = self.read_and_check(paths[0])
tensors = [tensor]
for path in paths[1:]:
@@ -786,7 +801,7 @@ def __init__(self, *args, **kwargs):

def count_nonzero(self) -> int:
"""Get the number of voxels that are not 0."""
return self.data.count_nonzero().item()
return int(self.data.count_nonzero().item())

def count_labels(self) -> Dict[int, int]:
"""Get the number of voxels in each label."""
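
For reference, a small sketch of the properties touched above, using in-memory images with the default identity affine (the printed values follow from that default):

import torch
import torchio as tio

image = tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8))
print(image.shape)    # (1, 8, 8, 8), i.e. (C, W, H, D) as plain ints
print(image.spacing)  # (1.0, 1.0, 1.0)
print(image.origin)   # (0.0, 0.0, 0.0)

label = tio.LabelMap(tensor=torch.randint(0, 2, (1, 8, 8, 8)))
print(label.count_nonzero())  # a built-in int after the cast above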
14 changes: 9 additions & 5 deletions src/torchio/data/inference/aggregator.py
@@ -1,5 +1,5 @@
import warnings
from typing import Tuple
from typing import Optional, Tuple

import torch
import numpy as np
@@ -32,11 +32,11 @@ def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
subject = sampler.subject
self.volume_padded = sampler.padding_mode is not None
self.spatial_shape = subject.spatial_shape
self._output_tensor = None
self._output_tensor: Optional[torch.Tensor] = None
self.patch_overlap = sampler.patch_overlap
self._parse_overlap_mode(overlap_mode)
self.overlap_mode = overlap_mode
self._avgmask_tensor = None
self._avgmask_tensor: Optional[torch.Tensor] = None

@staticmethod
def _parse_overlap_mode(overlap_mode):
@@ -127,6 +127,7 @@ def add_batch(
)
raise RuntimeError(message)
self._initialize_output_tensor(batch)
assert isinstance(self._output_tensor, torch.Tensor)
if self.overlap_mode == 'crop':
for patch, location in zip(batch, locations):
cropped_patch, new_location = self._crop_patch(
@@ -143,6 +144,7 @@ def add_batch(
] = cropped_patch
elif self.overlap_mode == 'average':
self._initialize_avgmask_tensor(batch)
assert isinstance(self._avgmask_tensor, torch.Tensor)
for patch, location in zip(batch, locations):
i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
self._output_tensor[
@@ -160,6 +162,7 @@ def add_batch(

def get_output_tensor(self) -> torch.Tensor:
"""Get the aggregated volume after dense inference."""
assert isinstance(self._output_tensor, torch.Tensor)
if self._output_tensor.dtype == torch.int64:
message = (
'Medical image frameworks such as ITK do not support int64.'
@@ -168,6 +171,7 @@ def get_output_tensor(self) -> torch.Tensor:
warnings.warn(message, RuntimeWarning)
self._output_tensor = self._output_tensor.type(torch.int32)
if self.overlap_mode == 'average':
assert isinstance(self._avgmask_tensor, torch.Tensor)
# true_divide is used instead of / in case the PyTorch version is
# old and one of the operands is int:
# https://github.com/fepegar/torchio/issues/526
@@ -180,7 +184,7 @@ def get_output_tensor(self) -> torch.Tensor:
from ...transforms import Crop
border = self.patch_overlap // 2
cropping = border.repeat(2)
crop = Crop(cropping)
return crop(output)
crop = Crop(cropping) # type: ignore[arg-type]
return crop(output) # type: ignore[return-value]
else:
return output
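
The aggregator is used together with GridSampler for dense patch-based inference; a hedged sketch following the pattern in the torchio documentation (the identity "model" and the sizes are stand-ins):

import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 32, 32, 32)))
sampler = tio.inference.GridSampler(subject, patch_size=16, patch_overlap=4)
loader = torch.utils.data.DataLoader(sampler, batch_size=2)
aggregator = tio.inference.GridAggregator(sampler, overlap_mode='average')

with torch.no_grad():
    for patches_batch in loader:
        inputs = patches_batch['t1'][tio.DATA]
        logits = inputs  # stand-in for model(inputs)
        aggregator.add_batch(logits, patches_batch[tio.LOCATION])

output = aggregator.get_output_tensor()  # torch.Tensor of shape (1, 32, 32, 32)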
32 changes: 23 additions & 9 deletions src/torchio/data/io.py
@@ -14,6 +14,9 @@
TypeDataAffine,
TypeDirection,
TypeTripletFloat,
TypeDoubletInt,
TypeTripletInt,
TypeQuartetInt,
)


@@ -87,18 +90,28 @@ def _read_dicom(directory: TypePath):
return image


def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
def read_shape(path: TypePath) -> TypeQuartetInt:
reader = sitk.ImageFileReader()
reader.SetFileName(str(path))
reader.ReadImageInformation()
num_channels = reader.GetNumberOfComponents()
spatial_shape = reader.GetSize()
num_dimensions = reader.GetDimension()
assert 2 <= num_dimensions <= 4
if num_dimensions == 2:
spatial_shape = *spatial_shape, 1
elif num_dimensions == 4: # assume bad NIfTI
*spatial_shape, num_channels = spatial_shape
shape = (num_channels,) + tuple(spatial_shape)
spatial_shape_2d: TypeDoubletInt = reader.GetSize()
assert len(spatial_shape_2d) == 2
si, sj = spatial_shape_2d
sk = 1
elif num_dimensions == 4:
# We assume bad NIfTI file (channels encoded as spatial dimension)
spatial_shape_4d: TypeQuartetInt = reader.GetSize()
assert len(spatial_shape_4d) == 4
si, sj, sk, num_channels = spatial_shape_4d
elif num_dimensions == 3:
spatial_shape_3d: TypeTripletInt = reader.GetSize()
assert len(spatial_shape_3d) == 3
si, sj, sk = spatial_shape_3d
shape = num_channels, si, sj, sk
return shape
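
A quick sketch of what the reworked read_shape returns for a plain 3-D file; the file name and the use of SimpleITK for writing are illustrative only:

import SimpleITK as sitk
from torchio.data.io import read_shape

image = sitk.Image(8, 8, 8, sitk.sitkFloat32)  # single-channel 3-D image
sitk.WriteImage(image, '/tmp/example.nii.gz')
print(read_shape('/tmp/example.nii.gz'))       # (1, 8, 8, 8) == (C, W, H, D)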


@@ -130,7 +143,7 @@ def write_image(


def _write_nibabel(
tensor: TypeData,
tensor: torch.Tensor,
affine: TypeData,
path: TypePath,
) -> None:
@@ -384,10 +397,11 @@ def get_sitk_metadata_from_ras_affine(
origin_array = origin_lps if lps else origin_ras
direction_array = direction_lps if lps else direction_ras
direction_array = direction_array.flatten()
# The following are to comply with typing hints
# (there must be prettier ways to do this)
# The following are to comply with mypy
# (although there must be prettier ways to do this)
ox, oy, oz = origin_array
sx, sy, sz = spacing_array
direction: TypeDirection
if is_2d:
d1, d2, d3, d4 = direction_array
direction = d1, d2, d3, d4
(Diffs for the remaining changed files in this commit are not shown here.)