Skip to content

Commit

Permalink
fixes RandSpatialCropSamples random states (#1086)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Oct 4, 2020
1 parent 06fb955 commit 9f51893
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 21 deletions.
7 changes: 7 additions & 0 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ def __init__(
self.num_samples = num_samples
self.cropper = RandSpatialCrop(roi_size, random_center, random_size)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "Randomizable":
super().set_random_state(seed=seed, state=state)
self.cropper.set_random_state(state=self.R)
return self

def randomize(self, data: Optional[Any] = None) -> None:
pass

Expand Down
7 changes: 7 additions & 0 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ def __init__(
self.num_samples = num_samples
self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "Randomizable":
super().set_random_state(seed=seed, state=state)
self.cropper.set_random_state(state=self.R)
return self

def randomize(self, data: Optional[Any] = None) -> None:
pass

Expand Down
63 changes: 53 additions & 10 deletions tests/test_rand_spatial_crop_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,67 @@
from monai.transforms import RandSpatialCropSamples

TEST_CASE_1 = [
{"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True},
np.random.randint(0, 2, size=[3, 3, 3, 3]),
(3, 3, 3, 3),
{"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False},
np.arange(192).reshape(3, 4, 4, 4),
[(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],
np.array(
[
[
[[21, 22, 23], [25, 26, 27], [29, 30, 31]],
[[37, 38, 39], [41, 42, 43], [45, 46, 47]],
[[53, 54, 55], [57, 58, 59], [61, 62, 63]],
],
[
[[85, 86, 87], [89, 90, 91], [93, 94, 95]],
[[101, 102, 103], [105, 106, 107], [109, 110, 111]],
[[117, 118, 119], [121, 122, 123], [125, 126, 127]],
],
[
[[149, 150, 151], [153, 154, 155], [157, 158, 159]],
[[165, 166, 167], [169, 170, 171], [173, 174, 175]],
[[181, 182, 183], [185, 186, 187], [189, 190, 191]],
],
]
),
]

TEST_CASE_2 = [
{"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False},
np.random.randint(0, 2, size=[3, 3, 3, 3]),
(3, 3, 3, 3),
{"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True},
np.arange(192).reshape(3, 4, 4, 4),
[(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],
np.array(
[
[
[[21, 22, 23], [25, 26, 27], [29, 30, 31]],
[[37, 38, 39], [41, 42, 43], [45, 46, 47]],
[[53, 54, 55], [57, 58, 59], [61, 62, 63]],
],
[
[[85, 86, 87], [89, 90, 91], [93, 94, 95]],
[[101, 102, 103], [105, 106, 107], [109, 110, 111]],
[[117, 118, 119], [121, 122, 123], [125, 126, 127]],
],
[
[[149, 150, 151], [153, 154, 155], [157, 158, 159]],
[[165, 166, 167], [169, 170, 171], [173, 174, 175]],
[[181, 182, 183], [185, 186, 187], [189, 190, 191]],
],
]
),
]


class TestRandSpatialCropSamples(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, input_param, input_data, expected_shape):
result = RandSpatialCropSamples(**input_param)(input_data)
for item in result:
self.assertTupleEqual(item.shape, expected_shape)
def test_shape(self, input_param, input_data, expected_shape, expected_last_item):
xform = RandSpatialCropSamples(**input_param)
xform.set_random_state(1234)
result = xform(input_data)

np.testing.assert_equal(len(result), input_param["num_samples"])
for item, expected in zip(result, expected_shape):
self.assertTupleEqual(item.shape, expected)
np.testing.assert_allclose(result[-1], expected_last_item)


if __name__ == "__main__":
Expand Down
58 changes: 47 additions & 11 deletions tests/test_rand_spatial_crop_samplesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,61 @@
from monai.transforms import RandSpatialCropSamplesd

TEST_CASE_1 = [
{"keys": ["img", "seg"], "num_samples": 4, "roi_size": [3, 3, 3], "random_center": True},
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3]), "seg": np.random.randint(0, 2, size=[3, 3, 3, 3])},
(3, 3, 3, 3),
{"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True},
{"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)},
[(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)],
{
"img": np.array(
[
[[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]],
[[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]],
[[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]],
]
),
"seg": np.array(
[
[[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]],
[[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]],
[[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]],
]
),
},
]

TEST_CASE_2 = [
{"keys": ["img", "seg"], "num_samples": 8, "roi_size": [3, 3, 3], "random_center": False},
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3]), "seg": np.random.randint(0, 2, size=[3, 3, 3, 3])},
(3, 3, 3, 3),
{"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False},
{"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)},
[(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)],
{
"img": np.array(
[
[[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]],
[[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]],
[[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]],
]
),
"seg": np.array(
[
[[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]],
[[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]],
[[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]],
]
),
},
]


class TestRandSpatialCropSamplesd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, input_param, input_data, expected_shape):
result = RandSpatialCropSamplesd(**input_param)(input_data)
for item in result:
self.assertTupleEqual(item["img"].shape, expected_shape)
self.assertTupleEqual(item["seg"].shape, expected_shape)
def test_shape(self, input_param, input_data, expected_shape, expected_last):
xform = RandSpatialCropSamplesd(**input_param)
xform.set_random_state(1234)
result = xform(input_data)
for item, expected in zip(result, expected_shape):
self.assertTupleEqual(item["img"].shape, expected)
self.assertTupleEqual(item["seg"].shape, expected)
np.testing.assert_allclose(item["img"], expected_last["img"])
np.testing.assert_allclose(item["seg"], expected_last["seg"])


if __name__ == "__main__":
Expand Down

0 comments on commit 9f51893

Please sign in to comment.