Skip to content

Commit

Permalink
Merge branch 'dev' into add-gpu-load-for-pydicom
Browse files Browse the repository at this point in the history
  • Loading branch information
yiheng-wang-nv authored Jan 23, 2025
2 parents 35d1f91 + df1ba5d commit 43162aa
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 8 deletions.
25 changes: 20 additions & 5 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"


def _get_ngc_private_base_url(repo: str) -> str:
Expand Down Expand Up @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
return name


def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:
if not has_requests:
raise ValueError("requests package is required, please install it.")
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response.raise_for_status()
model_info = json.loads(response.text)

if not isinstance(model_info, dict) or "modelFiles" not in model_info:
raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")

model_files = model_info["modelFiles"]
return [f["path"] for f in model_files]


def _download_from_ngc(
download_path: Path,
filename: str,
Expand All @@ -229,12 +244,12 @@ def _download_from_ngc(
# ensure prefix is contained
filename = _add_ngc_prefix(filename, prefix=prefix)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
extract_path = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
filepath = download_path / filename
filepath.mkdir(parents=True, exist_ok=True)
for file in _get_all_download_files(url):
download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)


def _download_from_ngc_private(
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
mean and std on each channel separately.
When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
be the number of image channels if they are not None.
If the input is not of floating point type, it will be converted to float32
Args:
subtrahend: the amount to subtract by (usually the mean).
Expand Down Expand Up @@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.divisor is not None and len(self.divisor) != len(img):
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

if not img.dtype.is_floating_point:
img, *_ = convert_data_type(img, dtype=torch.float32)

for i, d in enumerate(img):
img[i] = self._normalize( # type: ignore
d,
Expand Down
7 changes: 6 additions & 1 deletion tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,12 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
@SkipIfNoModule("kvikio")
@parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
if torch.__version__.endswith("nv24.8"):
# related issue: https://github.com/Project-MONAI/MONAI/issues/8274
# for this version, use randint test case to avoid the issue
test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()
else:
test_image = np.random.rand(128, 128, 128)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_normalize_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,27 @@ def test_channel_wise(self, im_type):
normalized = normalizer(input_data)
assert_allclose(normalized, im_type(expected), type_test="tensor")

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise_int(self, im_type):
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))
expected = np.array(
[
[
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
[0.7242068, 1.0138896, 1.3035723, 1.593255],
],
[
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
[0.7242068, 1.0138896, 1.3035723, 1.593255],
],
]
)
normalized = normalizer(input_data)
assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7) # tolerance

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value_errors(self, im_type):
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
Expand Down
11 changes: 9 additions & 2 deletions tests/test_zarr_avg_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@
from torch.nn.functional import pad

from monai.inferers import ZarrAvgMerger
from monai.utils import optional_import
from monai.utils import get_package_version, optional_import, version_geq
from tests.utils import assert_allclose

np.seterr(divide="ignore", invalid="ignore")
zarr, has_zarr = optional_import("zarr")
if has_zarr:
if version_geq(get_package_version("zarr"), "3.0.0"):
directory_store = zarr.storage.LocalStore("test.zarr")
else:
directory_store = zarr.storage.DirectoryStore("test.zarr")
else:
directory_store = None
numcodecs, has_numcodecs = optional_import("numcodecs")

TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)
Expand Down Expand Up @@ -154,7 +161,7 @@

# explicit directory store
TEST_CASE_10_DIRECTORY_STORE = [
dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")),
dict(merged_shape=TENSOR_4x4.shape, store=directory_store),
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
Expand Down

0 comments on commit 43162aa

Please sign in to comment.