From 4a50788e3e0924525f1eb7bc4d64917a59f5acfd Mon Sep 17 00:00:00 2001
From: twsl <45483159+twsl@users.noreply.github.com>
Date: Wed, 25 Dec 2024 16:45:28 +0100
Subject: [PATCH] Update __init__.py

---
 torchvision/tv_tensors/__init__.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 1ba47f60a36..82f57c74aac 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,3 +1,5 @@
+from typing import TypeVar
+
 import torch
 
 from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
@@ -7,12 +9,14 @@
 from ._tv_tensor import TVTensor
 from ._video import Video
 
+TVTensorLike = TypeVar("TVTensorLike", TVTensor, BoundingBoxes, Image, Mask, Video)
+
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee, *, like, **kwargs):
+def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike:
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
     If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of