Skip to content

Commit

Permalink
made cutout take percentage if the input (closes #35)
Browse files Browse the repository at this point in the history
  • Loading branch information
lext committed Mar 10, 2020
1 parent 06b9a56 commit 7af0c9a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 35 deletions.
9 changes: 9 additions & 0 deletions PAPERS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## Papers that use SOLT
The aim of building SOLT was to create a tool for reproducible research.
At MIPT-Oulu, we use SOLT in our projects:

1. https://arxiv.org/abs/1907.05089
2. https://arxiv.org/abs/1904.06236
3. https://arxiv.org/abs/1907.08020
4. https://arxiv.org/abs/1907.12237
5. https://arxiv.org/abs/2003.01944
20 changes: 8 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ pip install git+https://github.com/MIPT-Oulu/solt
## Benchmark

We propose a fair benchmark based on the refactored version of the one proposed by albumentations
team (number of images per second):
team, but here, we also convert the results into a PyTorch tensor and do the ImageNet normalization. The
following numbers support a realistic and honest comparison between
the libraries (number of images per second, the higher - the better):

| |albumentations<br><small>0.4.3</small>|torchvision (Pillow-SIMD backend)<br><small>0.5.0</small>|augmentor<br><small>0.2.8</small>|solt<br><small>0.1.9</small>|
|----------------|:------------------------------------:|:-------------------------------------------------------:|:-------------------------------:|:--------------------------:|
Expand All @@ -61,17 +63,7 @@ team (number of images per second):
|HFlipCrop | 2460 | 2902 | 2862 | **3514** |

Python and library versions: Python 3.7.0 (default, Oct 9 2018, 10:31:47) [GCC 7.3.0], numpy 1.18.1, pillow-simd 7.0.0.post3, opencv-python 4.2.0.32, scikit-image 0.16.2, scipy 1.4.1.
Please find the details about the benchmark [here](BENCHMARK.md).

## Papers that use SOLT
The aim of building SOLT was to create a tool for reproducible research. At MIPT, we use SOLT in our projects:

1. https://arxiv.org/abs/1907.05089
2. https://arxiv.org/abs/1904.06236
3. https://arxiv.org/abs/1907.08020
4. https://arxiv.org/abs/1907.12237

If you use SOLT and cite it in your research, please, don't hesitate to sent an email to Aleksei Tiulpin. It will be added here.
The code was run on AMD Threadripper 1900. Please find the details about the benchmark [here](BENCHMARK.md).

## How to contribute
Follow the guidelines described [here](CONTRIBUTING.md).
Expand All @@ -83,6 +75,10 @@ Physics and Technology,
University of Oulu, Finalnd.

## How to cite
If you use SOLT and cite it in your research, please,
don't hesitate to sent an email to Aleksei Tiulpin.
All the papers that use SOLT are listed [here](PAPERS.md).

```
@misc{aleksei_tiulpin_2019_3351977,
author = {Aleksei Tiulpin},
Expand Down
31 changes: 22 additions & 9 deletions solt/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ class CutOut(ImageTransform):
Parameters
----------
cutout_size : tuple or int or None
cutout_size : tuple or int or float or None
The size of the cutout. If None, then it is equal to 2.
data_indices : tuple or None
Indices of the images within the data container to which this transform needs to be applied.
Expand All @@ -853,34 +853,47 @@ class CutOut(ImageTransform):

def __init__(self, cutout_size=2, data_indices=None, p=0.5):
super(CutOut, self).__init__(p=p, data_indices=data_indices)
if not isinstance(cutout_size, (int, tuple, list)):
if not isinstance(cutout_size, (int, tuple, list, float)):
raise TypeError("Cutout size is of an incorrect type!")

if isinstance(cutout_size, list):
cutout_size = tuple(cutout_size)

if isinstance(cutout_size, tuple):
if not isinstance(cutout_size[0], int) or not isinstance(cutout_size[1], int):
if not isinstance(cutout_size[0], (int, float)) or not isinstance(cutout_size[1], (int, float)):
raise TypeError

if isinstance(cutout_size, int):
if isinstance(cutout_size, (int, float)):
cutout_size = (cutout_size, cutout_size)
if not isinstance(cutout_size[0], type(cutout_size[1])):
raise TypeError("CutOut sizes must be of the same type")

self.cutout_size = cutout_size

def sample_transform(self, data: DataContainer):
h, w = super(CutOut, self).sample_transform(data)
if isinstance(self.cutout_size[0], float):
cut_size_x = int(self.cutout_size[0] * w)
else:
cut_size_x = self.cutout_size[0]

if isinstance(self.cutout_size[1], float):
cut_size_y = int(self.cutout_size[1] * h)
else:
cut_size_y = self.cutout_size[1]

if self.cutout_size[0] > w or self.cutout_size[1] > h:
if cut_size_x > w or cut_size_y > h:
raise ValueError("Cutout size is too large!")

self.state_dict["x"] = int(random.random() * (w - self.cutout_size[0]))
self.state_dict["y"] = int(random.random() * (h - self.cutout_size[1]))
self.state_dict["x"] = int(random.random() * (w - cut_size_x))
self.state_dict["y"] = int(random.random() * (h - cut_size_y))
self.state_dict["cut_size_x"] = cut_size_x
self.state_dict["cut_size_y"] = cut_size_y

def __cutout_img(self, img):
img[
self.state_dict["y"] : self.state_dict["y"] + self.cutout_size[1],
self.state_dict["x"] : self.state_dict["x"] + self.cutout_size[0],
self.state_dict["y"] : self.state_dict["y"] + self.state_dict["cut_size_y"],
self.state_dict["x"] : self.state_dict["x"] + self.state_dict["cut_size_x"],
] = 0
return img

Expand Down
28 changes: 14 additions & 14 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def test_crop_or_cutout_size_are_too_big(img_2x2, cutout_crop_size):
trf(dc)


@pytest.mark.parametrize("cutout_crop_size", ["123", 2.5, (2.5, 2), (2, 2.2)])
@pytest.mark.parametrize("cutout_crop_size", ["123", ("23", 2), (2.5, 2), (2, 2.2)])
def test_wrong_crop_size_types(cutout_crop_size):
with pytest.raises(TypeError):
slt.Crop(crop_to=cutout_crop_size)
Expand Down Expand Up @@ -1029,11 +1029,7 @@ def test_intensity_remap_values():


@pytest.mark.parametrize(
"img, expected",
[
(img_3x3(), does_not_raise()),
(img_3x3_rgb(), pytest.raises(ValueError)),
],
"img, expected", [(img_3x3(), does_not_raise()), (img_3x3_rgb(), pytest.raises(ValueError)),],
)
def test_intensity_remap_channels(img, expected):
trf = slt.IntensityRemap(p=1)
Expand Down Expand Up @@ -1146,25 +1142,29 @@ def test_different_interpolations_per_item_per_transform(img_6x6, transform_sett


@pytest.mark.parametrize(
"img, expected",
"img, expected, cut_size",
[
(img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8)),
(img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8)),
(img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8)),
(img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8), 7),
(img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8), 6),
(img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8), 6),
(img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8), 1.0),
(img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8), 1.0),
(img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8), 1.0),
],
)
def test_cutout_blacks_out_image(img, expected):
def test_cutout_blacks_out_image(img, expected, cut_size):
dc = slc.DataContainer((img,), "I")
trf = slc.Stream([slt.CutOut(p=1, cutout_size=6)])
trf = slc.Stream([slt.CutOut(p=1, cutout_size=cut_size)])

dc_res = trf(dc, return_torch=False)

assert np.array_equal(expected, dc_res.data[0])


def test_cutout_1x1_blacks_corner_pixels_2x2_img(img_2x2):
@pytest.mark.parametrize("cut_size", [1, (1, 1), 0.5, (0.5, 0.5)])
def test_cutout_1x1_blacks_corner_pixels_2x2_img(img_2x2, cut_size):
dc = slc.DataContainer((img_2x2.copy(),), "I")
trf = slc.Stream([slt.CutOut(p=1, cutout_size=1)])
trf = slc.Stream([slt.CutOut(p=1, cutout_size=cut_size)])
dc_res = trf(dc, return_torch=False)

equal = 0
Expand Down

0 comments on commit 7af0c9a

Please sign in to comment.