Skip to content

Commit

Permalink
Merge branch 'main' into read_image_exif_support
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Mar 5, 2024
2 parents 6abde93 + fa5b844 commit 855d5de
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 118 deletions.
3 changes: 2 additions & 1 deletion .github/scripts/setup-env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ case $GPU_ARCH_TYPE in
;;
esac
PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}"
pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}"
# TODO: remove pinning of mpmath when https://github.com/pytorch/vision/issues/8292 is properly fixed.
pip install --progress-bar=off "mpmath<1.4" --pre torch --index-url="${PYTORCH_WHEEL_INDEX}"

if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then
python -c "import torch; exit(not torch.cuda.is_available())"
Expand Down
53 changes: 0 additions & 53 deletions .github/workflows/build-conda-macos.yml

This file was deleted.

52 changes: 0 additions & 52 deletions .github/workflows/build-wheels-macos.yml

This file was deleted.

2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def main(args):

num_classes = len(dataset.classes)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:

Expand Down
10 changes: 5 additions & 5 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
from torchvision.transforms import functional as F


def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
transforms_module = get_module(use_v2)

mixup_cutmix = []
if mixup_alpha > 0:
mixup_cutmix.append(
transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
if use_v2
else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.CutMix(alpha=mixup_alpha, num_classes=num_classes)
if use_v2
else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomCutMix(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
)
if not mixup_cutmix:
return None
Expand Down
16 changes: 16 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default():
assert_equal(result, expected)


def test_draw_keypoints_dtypes():
image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
image_float = to_dtype(image_uint8, torch.float, scale=True)

out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
out_float = utils.draw_keypoints(image_float, keypoints)

assert out_uint8.dtype == torch.uint8
assert out_uint8 is not image_uint8

assert out_float.is_floating_point()
assert out_float is not image_float

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)


def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand Down
21 changes: 15 additions & 6 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,13 @@ def draw_keypoints(

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
Expand All @@ -363,16 +363,16 @@ def draw_keypoints(
For more details, see :ref:`draw_keypoints_with_visibility`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
# validate image
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
Expand All @@ -397,6 +397,12 @@ def draw_keypoints(
f"Got {visibility.shape = } and {keypoints.shape = }"
)

original_dtype = image.dtype
if original_dtype.is_floating_point:
from torchvision.transforms.v2.functional import to_dtype # noqa

image = to_dtype(image, dtype=torch.uint8, scale=True)

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
Expand Down Expand Up @@ -428,7 +434,10 @@ def draw_keypoints(
width=width,
)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
if original_dtype.is_floating_point:
out = to_dtype(out, dtype=original_dtype, scale=True)
return out


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
Expand Down

0 comments on commit 855d5de

Please sign in to comment.