Skip to content

Commit

Permalink
Merge pull request #45 from soupault/master
Browse files Browse the repository at this point in the history
Add IntensityRemap transform
  • Loading branch information
lext authored Feb 26, 2020
2 parents 0a75fec + 5c41bea commit 9ca08af
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from setuptools import find_packages, setup

requirements = ("numpy", "opencv-python-headless", "torch", "torchvision", "pyyaml")
requirements = ("numpy", "scipy", "opencv-python-headless", "torch", "torchvision", "pyyaml")

setup_requirements = ()

Expand Down
2 changes: 2 additions & 0 deletions solt/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SaltAndPepper,
Blur,
HSV,
IntensityRemap,
CvtColor,
Resize,
Contrast,
Expand All @@ -36,6 +37,7 @@
"SaltAndPepper",
"Blur",
"HSV",
"IntensityRemap",
"CvtColor",
"Resize",
"Contrast",
Expand Down
51 changes: 51 additions & 0 deletions solt/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import cv2
import numpy as np
import scipy
import scipy.signal

from ..core import (
BaseTransform,
Expand Down Expand Up @@ -1210,6 +1212,55 @@ def _apply_img(self, img: np.ndarray, settings: dict):
return cv2.LUT(img, self.state_dict["LUT"])


class IntensityRemap(ImageTransform):
"""Performs random intensity remapping.
Parameters
----------
p : float
Probability of applying this transform,
kernel_size: int
Size of medial filter kernel used during the generation of intensity mapping.
Higher value yield more monotonic mapping.
data_indices : tuple or None
Indices of the images within the data container to which this transform needs to be applied.
Every element within the tuple must be integer numbers.
If None, then the transform will be applied to all the images withing the DataContainer.
References
----------
.. [1] Hesse, L. S., Kuling, G., Veta, M., & Martel, A. L. (2019).
Intensity augmentation for domain transfer of whole breast
segmentation in MRI. https://arxiv.org/abs/1909.02642
"""

serializable_name = "intensity_remap"
"""How the class should be stored in the registry"""

def __init__(self, kernel_size=9, data_indices=None, p=0.5):
super(IntensityRemap, self).__init__(p=p, data_indices=data_indices)
self.kernel_size = kernel_size

def sample_transform(self, data):
m = random.sample(range(256), k=256)
m = scipy.signal.medfilt(m, kernel_size=self.kernel_size)
m = m + np.linspace(0, 255, 256)

m = m - min(m)
m = m / max(m) * 255
m = np.floor(m).astype(np.uint8)

self.state_dict = {"LUT": m}

@img_shape_checker
def _apply_img(self, img: np.ndarray, settings: dict):
if img.dtype != np.uint8:
raise ValueError("IntensityRemap supports uint8 ndarrays only")
if img.ndim == 3 and img.shape[-1] != 1:
raise ValueError("Only grayscale 2D images are supported")
return cv2.LUT(img, self.state_dict["LUT"])


class CvtColor(ImageTransform):
"""RGB to grayscale or grayscale to RGB image conversion.
Expand Down
56 changes: 56 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import random
from contextlib import ExitStack as does_not_raise

import cv2
import numpy as np
Expand Down Expand Up @@ -1004,6 +1005,61 @@ def test_hsv_doesnt_work_for_1_channel(img_6x6):
trf(dc)


def test_intensity_remap_values():
trf = slt.IntensityRemap(p=1)
img = np.arange(0, 256, 1, dtype=np.uint8).reshape((16, 16, 1))
dc = slc.DataContainer(img, "I")
out = trf(dc).data[0]

# Mapping is applied correctly
img_expected = trf.state_dict["LUT"].reshape((16, 16, 1))
np.testing.assert_array_equal(out, img_expected)

# Mapping has a positive trendline
assert np.sum(np.diff(out.astype(np.float).ravel())) > 0

# Higher kernel size yields more monotonic mapping
trf_noisy = slt.IntensityRemap(p=1, kernel_size=1)
trf_low_pass = slt.IntensityRemap(p=1, kernel_size=5)
out_noisy = trf_noisy(dc).data[0].astype(np.float)
out_low_pass = trf_low_pass(dc).data[0].astype(np.float)
std_noisy = np.std(np.diff(out_noisy.ravel()))
std_low_pass = np.std(np.diff(out_low_pass.ravel()))
assert std_low_pass < std_noisy


@pytest.mark.parametrize(
"img, expected",
[
(img_3x3(), does_not_raise()),
(img_3x3_rgb(), pytest.raises(ValueError)),
],
)
def test_intensity_remap_channels(img, expected):
trf = slt.IntensityRemap(p=1)
dc = slc.DataContainer(img, "I")

with expected:
trf(dc)


@pytest.mark.parametrize(
"dtype, expected",
[
(np.int8, pytest.raises(ValueError)),
(np.uint16, pytest.raises(ValueError)),
(np.float, pytest.raises(ValueError)),
(np.bool_, pytest.raises(ValueError)),
],
)
def test_intensity_remap_dtypes(dtype, expected):
trf = slt.IntensityRemap(p=1)
dc = slc.DataContainer(img_3x3().astype(dtype), "I")

with expected:
trf(dc)


@pytest.mark.parametrize(
"mode, img, expected",
[
Expand Down

0 comments on commit 9ca08af

Please sign in to comment.