diff --git a/PAPERS.md b/PAPERS.md
new file mode 100644
index 0000000..bef1f89
--- /dev/null
+++ b/PAPERS.md
@@ -0,0 +1,9 @@
+## Papers that use SOLT
+The aim of building SOLT was to create a tool for reproducible research.
+At MIPT-Oulu, we use SOLT in our projects:
+
+1. https://arxiv.org/abs/1907.05089
+2. https://arxiv.org/abs/1904.06236
+3. https://arxiv.org/abs/1907.08020
+4. https://arxiv.org/abs/1907.12237
+5. https://arxiv.org/abs/2003.01944
diff --git a/README.md b/README.md
index 549f548..669f1c1 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,9 @@ pip install git+https://github.com/MIPT-Oulu/solt
## Benchmark
We propose a fair benchmark based on the refactored version of the one proposed by albumentations
-team (number of images per second):
+team, but here we also convert the results into a PyTorch tensor and apply ImageNet normalization. The
+following numbers support a realistic and honest comparison between
+the libraries (number of images per second; higher is better):
| |albumentations
0.4.3|torchvision (Pillow-SIMD backend)
0.5.0|augmentor
0.2.8|solt
0.1.9|
|----------------|:------------------------------------:|:-------------------------------------------------------:|:-------------------------------:|:--------------------------:|
@@ -61,17 +63,7 @@ team (number of images per second):
|HFlipCrop | 2460 | 2902 | 2862 | **3514** |
Python and library versions: Python 3.7.0 (default, Oct 9 2018, 10:31:47) [GCC 7.3.0], numpy 1.18.1, pillow-simd 7.0.0.post3, opencv-python 4.2.0.32, scikit-image 0.16.2, scipy 1.4.1.
-Please find the details about the benchmark [here](BENCHMARK.md).
-
-## Papers that use SOLT
-The aim of building SOLT was to create a tool for reproducible research. At MIPT, we use SOLT in our projects:
-
-1. https://arxiv.org/abs/1907.05089
-2. https://arxiv.org/abs/1904.06236
-3. https://arxiv.org/abs/1907.08020
-4. https://arxiv.org/abs/1907.12237
-
-If you use SOLT and cite it in your research, please, don't hesitate to sent an email to Aleksei Tiulpin. It will be added here.
+The benchmark was run on an AMD Threadripper 1900. Please find the details about the benchmark [here](BENCHMARK.md).
## How to contribute
Follow the guidelines described [here](CONTRIBUTING.md).
@@ -83,6 +75,10 @@ Physics and Technology,
University of Oulu, Finalnd.
## How to cite
+If you use SOLT and cite it in your research, please
+don't hesitate to send an email to Aleksei Tiulpin.
+All the papers that use SOLT are listed [here](PAPERS.md).
+
```
@misc{aleksei_tiulpin_2019_3351977,
author = {Aleksei Tiulpin},
diff --git a/solt/transforms/_transforms.py b/solt/transforms/_transforms.py
index 9159717..45be9eb 100644
--- a/solt/transforms/_transforms.py
+++ b/solt/transforms/_transforms.py
@@ -838,7 +838,7 @@ class CutOut(ImageTransform):
Parameters
----------
- cutout_size : tuple or int or None
+ cutout_size : tuple or int or float or None
The size of the cutout. If None, then it is equal to 2.
data_indices : tuple or None
Indices of the images within the data container to which this transform needs to be applied.
@@ -853,34 +853,47 @@ class CutOut(ImageTransform):
def __init__(self, cutout_size=2, data_indices=None, p=0.5):
super(CutOut, self).__init__(p=p, data_indices=data_indices)
- if not isinstance(cutout_size, (int, tuple, list)):
+ if not isinstance(cutout_size, (int, tuple, list, float)):
raise TypeError("Cutout size is of an incorrect type!")
if isinstance(cutout_size, list):
cutout_size = tuple(cutout_size)
if isinstance(cutout_size, tuple):
- if not isinstance(cutout_size[0], int) or not isinstance(cutout_size[1], int):
+ if not isinstance(cutout_size[0], (int, float)) or not isinstance(cutout_size[1], (int, float)):
raise TypeError
- if isinstance(cutout_size, int):
+ if isinstance(cutout_size, (int, float)):
cutout_size = (cutout_size, cutout_size)
+ if not isinstance(cutout_size[0], type(cutout_size[1])):
+ raise TypeError("CutOut sizes must be of the same type")
self.cutout_size = cutout_size
def sample_transform(self, data: DataContainer):
h, w = super(CutOut, self).sample_transform(data)
+ if isinstance(self.cutout_size[0], float):
+ cut_size_x = int(self.cutout_size[0] * w)
+ else:
+ cut_size_x = self.cutout_size[0]
+
+ if isinstance(self.cutout_size[1], float):
+ cut_size_y = int(self.cutout_size[1] * h)
+ else:
+ cut_size_y = self.cutout_size[1]
- if self.cutout_size[0] > w or self.cutout_size[1] > h:
+ if cut_size_x > w or cut_size_y > h:
raise ValueError("Cutout size is too large!")
- self.state_dict["x"] = int(random.random() * (w - self.cutout_size[0]))
- self.state_dict["y"] = int(random.random() * (h - self.cutout_size[1]))
+ self.state_dict["x"] = int(random.random() * (w - cut_size_x))
+ self.state_dict["y"] = int(random.random() * (h - cut_size_y))
+ self.state_dict["cut_size_x"] = cut_size_x
+ self.state_dict["cut_size_y"] = cut_size_y
def __cutout_img(self, img):
img[
- self.state_dict["y"] : self.state_dict["y"] + self.cutout_size[1],
- self.state_dict["x"] : self.state_dict["x"] + self.cutout_size[0],
+ self.state_dict["y"] : self.state_dict["y"] + self.state_dict["cut_size_y"],
+ self.state_dict["x"] : self.state_dict["x"] + self.state_dict["cut_size_x"],
] = 0
return img
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 46206db..31ab88d 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -792,7 +792,7 @@ def test_crop_or_cutout_size_are_too_big(img_2x2, cutout_crop_size):
trf(dc)
-@pytest.mark.parametrize("cutout_crop_size", ["123", 2.5, (2.5, 2), (2, 2.2)])
+@pytest.mark.parametrize("cutout_crop_size", ["123", ("23", 2), (2.5, 2), (2, 2.2)])
def test_wrong_crop_size_types(cutout_crop_size):
with pytest.raises(TypeError):
slt.Crop(crop_to=cutout_crop_size)
@@ -1029,11 +1029,7 @@ def test_intensity_remap_values():
@pytest.mark.parametrize(
- "img, expected",
- [
- (img_3x3(), does_not_raise()),
- (img_3x3_rgb(), pytest.raises(ValueError)),
- ],
+ "img, expected", [(img_3x3(), does_not_raise()), (img_3x3_rgb(), pytest.raises(ValueError)),],
)
def test_intensity_remap_channels(img, expected):
trf = slt.IntensityRemap(p=1)
@@ -1146,25 +1142,29 @@ def test_different_interpolations_per_item_per_transform(img_6x6, transform_sett
@pytest.mark.parametrize(
- "img, expected",
+ "img, expected, cut_size",
[
- (img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8)),
- (img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8)),
- (img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8)),
+ (img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8), 7),
+ (img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8), 6),
+ (img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8), 6),
+ (img_7x7(), np.zeros((7, 7, 1), dtype=np.uint8), 1.0),
+ (img_6x6(), np.zeros((6, 6, 1), dtype=np.uint8), 1.0),
+ (img_6x6_rgb(), np.zeros((6, 6, 3), dtype=np.uint8), 1.0),
],
)
-def test_cutout_blacks_out_image(img, expected):
+def test_cutout_blacks_out_image(img, expected, cut_size):
dc = slc.DataContainer((img,), "I")
- trf = slc.Stream([slt.CutOut(p=1, cutout_size=6)])
+ trf = slc.Stream([slt.CutOut(p=1, cutout_size=cut_size)])
dc_res = trf(dc, return_torch=False)
assert np.array_equal(expected, dc_res.data[0])
-def test_cutout_1x1_blacks_corner_pixels_2x2_img(img_2x2):
+@pytest.mark.parametrize("cut_size", [1, (1, 1), 0.5, (0.5, 0.5)])
+def test_cutout_1x1_blacks_corner_pixels_2x2_img(img_2x2, cut_size):
dc = slc.DataContainer((img_2x2.copy(),), "I")
- trf = slc.Stream([slt.CutOut(p=1, cutout_size=1)])
+ trf = slc.Stream([slt.CutOut(p=1, cutout_size=cut_size)])
dc_res = trf(dc, return_torch=False)
equal = 0