From 3e5d37c982a0447c2cd92e44c574be370a5fd98b Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 7 Jan 2025 23:28:21 +0000 Subject: [PATCH 1/9] move make_flat_list_of_images and make_batched_videos to image_utils --- src/transformers/image_utils.py | 58 +++++++++++++++++++ .../models/aria/image_processing_aria.py | 27 +-------- src/transformers/models/aria/modular_aria.py | 5 +- .../models/blip/image_processing_blip.py | 5 +- .../chameleon/image_processing_chameleon.py | 27 +-------- .../models/clip/image_processing_clip.py | 4 +- .../models/colpali/modular_colpali.py | 5 +- .../models/colpali/processing_colpali.py | 27 +-------- .../image_processing_instructblipvideo.py | 29 +--------- .../llava_next/image_processing_llava_next.py | 27 +-------- .../image_processing_llava_next_video.py | 22 +------ .../image_processing_llava_onevision.py | 28 +-------- .../video_processing_llava_onevision.py | 28 ++------- .../models/paligemma/processing_paligemma.py | 28 +-------- .../qwen2_vl/image_processing_qwen2_vl.py | 51 ++-------------- .../models/siglip/image_processing_siglip.py | 4 +- .../image_processing_video_llava.py | 26 +-------- 17 files changed, 99 insertions(+), 302 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 90b5f44c563..5736918183e 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -226,6 +226,64 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: ) +def make_flat_list_of_images( + images: Union[List[ImageInput], ImageInput], +) -> ImageInput: + """ + Ensure that the input is a flat list of images. If the input is a single image, it is converted to a list of length 1. + If the input is a nested list of images, it is converted to a flat list of images. + Args: + images (`Union[List[ImageInput], ImageInput]`): + The input image. + Returns: + list: A list of images or a 4d array of images. + """ + # If the input is a nested list of images, we flatten it + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_pil_image(images): + return [images] + + elif is_valid_image(images): + if len(images.shape) == 4: + return images + elif len(images.shape) == 3: + return [images] + + raise ValueError(f"Could not make a flat list of images from {images}") + + +def make_batched_videos(videos) -> VideoInput: + """ + Ensure that the input is a list of videos. + Args: + videos (`VideoInput`): + Video or videos to turn into a list of videos. + Returns: + list: A list of videos. 
+ """ + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + return videos + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + if is_pil_image(videos[0]): + return [videos] + elif len(videos[0].shape) == 4: + return [list(video) for video in videos] + + elif is_valid_image(videos): + if is_pil_image(videos): + return [[videos]] + elif len(videos.shape) == 4: + return [list(videos)] + + raise ValueError(f"Could not make batched video from {videos}") + + def to_numpy_array(img) -> np.ndarray: if not is_valid_image(img): raise ValueError(f"Invalid image type: {type(img)}") diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 7b00665aa28..de8637eb28f 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -31,7 +31,7 @@ PILImageResampling, get_image_size, infer_channel_dimension_format, - is_valid_image, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -39,29 +39,6 @@ from ...utils import TensorType -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. 
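# --- Editor's note: illustrative sketch, not part of the patch. It assumes the
# `make_flat_list_of_images` helper exactly as added to src/transformers/image_utils.py
# in this commit, and shows the behaviour the removed per-model `make_batched_images`
# helpers are being replaced with.
import numpy as np

from transformers.image_utils import make_flat_list_of_images

image = np.zeros((16, 32, 3), dtype=np.uint8)            # a single HWC image
nested = [[image, image], [image]]                       # one sub-list of images per sample

assert len(make_flat_list_of_images(image)) == 1         # single image -> list of length 1
assert len(make_flat_list_of_images([image] * 4)) == 4   # flat list passes through unchanged
assert len(make_flat_list_of_images(nested)) == 3        # nested lists are flattened
# --- end of editor's note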
@@ -244,7 +221,7 @@ def preprocess( if max_image_size not in [490, 980]: raise ValueError("max_image_size must be either 490 or 980") - images = make_batched_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 78c6e08bdfd..f799452bc3c 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -28,6 +28,7 @@ PILImageResampling, get_image_size, infer_channel_dimension_format, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -57,7 +58,7 @@ LlamaRMSNorm, ) from ..llava.modeling_llava import LlavaCausalLMOutputWithPast -from ..llava_next.image_processing_llava_next import divide_to_patches, make_batched_images +from ..llava_next.image_processing_llava_next import divide_to_patches logger = logging.get_logger(__name__) @@ -608,7 +609,7 @@ def preprocess( if max_image_size not in [490, 980]: raise ValueError("max_image_size must be either 490 or 980") - images = make_batched_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py index 0f7683d08d1..df2aee157dc 100644 --- a/src/transformers/models/blip/image_processing_blip.py +++ b/src/transformers/models/blip/image_processing_blip.py @@ -28,7 +28,7 @@ PILImageResampling, infer_channel_dimension_format, is_scaled_image, - make_list_of_images, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -231,8 +231,7 @@ def preprocess( size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) - - images = make_list_of_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py index 4ef305c511e..c9d110ad229 100644 --- a/src/transformers/models/chameleon/image_processing_chameleon.py +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -30,7 +30,7 @@ PILImageResampling, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -44,29 +44,6 @@ import PIL -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - class ChameleonImageProcessor(BaseImageProcessor): r""" Constructs a Chameleon image processor. 
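# --- Editor's note: illustrative sketch, not part of the patch. The BLIP-style processors
# above previously called `make_list_of_images`, which only accepts a single image or a flat
# list/batch; `make_flat_list_of_images` additionally accepts one list of images per prompt.
# The snippet assumes the patched transformers.image_utils from this series.
import numpy as np

from transformers.image_utils import make_flat_list_of_images, make_list_of_images

image = np.zeros((8, 8, 3), dtype=np.uint8)
nested = [[image], [image, image]]                       # a different number of images per prompt

print(len(make_flat_list_of_images(nested)))             # 3: nested input is flattened

try:
    make_list_of_images(nested)                          # the older helper rejects nested input
except ValueError as err:
    print(f"make_list_of_images raised: {err}")
# --- end of editor's note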
@@ -275,7 +252,7 @@ def preprocess( image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_batched_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index c81451b195c..2155b306bc0 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -33,7 +33,7 @@ PILImageResampling, infer_channel_dimension_format, is_scaled_image, - make_list_of_images, + make_flat_list_of_images, to_numpy_array, valid_images, validate_kwargs, @@ -283,7 +283,7 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images = make_list_of_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py index ceb43e2d66f..2cc6dded858 100644 --- a/src/transformers/models/colpali/modular_colpali.py +++ b/src/transformers/models/colpali/modular_colpali.py @@ -20,11 +20,10 @@ IMAGE_TOKEN, PaliGemmaProcessor, build_string_from_input, - make_batched_images, ) from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, is_valid_image +from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images from ...processing_utils import ( ProcessingKwargs, Unpack, @@ -168,7 +167,7 @@ def __call__( ) for prompt, image_list in zip(texts_doc, images) ] - images = make_batched_images(images) + images = make_flat_list_of_images(images) pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] # max_length has to account for the image tokens diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py index f8d68675798..342cd0cd3d6 100644 --- a/src/transformers/models/colpali/processing_colpali.py +++ b/src/transformers/models/colpali/processing_colpali.py @@ -23,7 +23,7 @@ from typing import ClassVar, List, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, is_valid_image +from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput from ...utils import is_torch_available @@ -72,29 +72,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. 
- """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - class ColPaliProcessor(ProcessorMixin): r""" Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as @@ -230,7 +207,7 @@ def __call__( ) for prompt, image_list in zip(texts_doc, images) ] - images = make_batched_images(images) + images = make_flat_list_of_images(images) pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] # max_length has to account for the image tokens diff --git a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py index 75e07317b05..37cec22a9b3 100644 --- a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py @@ -32,40 +32,17 @@ VideoInput, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_videos, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging - - -if is_vision_available(): - import PIL +from ...utils import TensorType, filter_out_non_signature_kwargs, logging logger = logging.get_logger(__name__) -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], PIL.Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos): - if isinstance(videos, PIL.Image.Image): - return [[videos]] - elif len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - # Copied from transformers.models.blip.image_processing_blip.BlipImageProcessor with Blip->InstructBlipVideo, BLIP->InstructBLIPVideo class InstructBlipVideoImageProcessor(BaseImageProcessor): r""" @@ -198,7 +175,7 @@ def preprocess( do_convert_rgb: bool = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> PIL.Image.Image: + ) -> BatchFeature: """ Preprocess a video or batch of images/videos. diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 8e2a4f4644f..742ed4cbabd 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -37,7 +37,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_flat_list_of_images, make_list_of_images, to_numpy_array, valid_images, @@ -53,29 +53,6 @@ from PIL import Image -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. 
- - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. @@ -670,7 +647,7 @@ def preprocess( do_pad = do_pad if do_pad is not None else self.do_pad do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_batched_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py index 81f55f9373b..af6ad7a4bb7 100644 --- a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py @@ -34,7 +34,7 @@ VideoInput, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_videos, make_list_of_images, to_numpy_array, validate_preprocess_arguments, @@ -46,23 +46,7 @@ if is_vision_available(): - from PIL import Image - - -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") + pass class LlavaNextVideoImageProcessor(BaseImageProcessor): @@ -212,7 +196,7 @@ def _preprocess( do_convert_rgb: bool = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> Image.Image: + ) -> list[np.ndarray]: """ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py index 75581d25aef..22435175045 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py @@ -36,7 +36,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -51,30 +51,6 @@ from PIL import Image -# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. 
- - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - # Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ @@ -632,7 +608,7 @@ def preprocess( do_pad = do_pad if do_pad is not None else self.do_pad do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_batched_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/llava_onevision/video_processing_llava_onevision.py b/src/transformers/models/llava_onevision/video_processing_llava_onevision.py index a5aa42688e6..743e9f2df68 100644 --- a/src/transformers/models/llava_onevision/video_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/video_processing_llava_onevision.py @@ -16,6 +16,8 @@ from typing import Dict, List, Optional, Union +import numpy as np + from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( convert_to_rgb, @@ -31,37 +33,17 @@ VideoInput, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_videos, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, is_vision_available, logging +from ...utils import TensorType, logging logger = logging.get_logger(__name__) -if is_vision_available(): - from PIL import Image - - -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image) or len(videos[0].shape) == 3: - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - class LlavaOnevisionVideoProcessor(BaseImageProcessor): r""" Constructs a LLaVa-Onevisino-Video video processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame. 
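# --- Editor's note: illustrative sketch, not part of the patch. It assumes the shared
# `make_batched_videos` helper exactly as added to src/transformers/image_utils.py in this
# commit, which replaces the per-model copy removed above.
import numpy as np

from transformers.image_utils import make_batched_videos

video = np.zeros((8, 16, 16, 3), dtype=np.uint8)          # one video: 8 HWC frames

batch = make_batched_videos(video)                        # single 4D array -> one video
print(len(batch), len(batch[0]))                          # 1 8

batch = make_batched_videos([video, video])               # list of 4D arrays -> two videos
print(len(batch), len(batch[0]))                          # 2 8
# --- end of editor's note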
@@ -138,7 +120,7 @@ def _preprocess( do_convert_rgb: bool = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> Image.Image: + ) -> list[np.ndarray]: """ Args: images (`ImageInput`): diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index f2d0afed946..ac4b98e70b0 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -19,7 +19,7 @@ from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, is_valid_image +from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images from ...processing_utils import ( ImagesKwargs, ProcessingKwargs, @@ -99,30 +99,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" -# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched video from {images}") - - class PaliGemmaProcessor(ProcessorMixin): r""" Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. @@ -297,7 +273,7 @@ def __call__( ) for prompt, image_list in zip(text, images) ] - images = make_batched_images(images) + images = make_flat_list_of_images(images) else: expanded_samples = [] for sample in text: diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py index b8656a91031..51b657327c3 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py @@ -40,62 +40,19 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_videos, + make_flat_list_of_images, make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, is_vision_available, logging +from ...utils import TensorType, logging logger = logging.get_logger(__name__) -if is_vision_available(): - from PIL import Image - - -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. 
- """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched images from {images}") - - -# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - def smart_resize( height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 ): @@ -392,7 +349,7 @@ def preprocess( do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb if images is not None: - images = make_batched_images(images) + images = make_flat_list_of_images(images) if videos is not None: videos = make_batched_videos(videos) diff --git a/src/transformers/models/siglip/image_processing_siglip.py b/src/transformers/models/siglip/image_processing_siglip.py index b87adb7492d..d5826878065 100644 --- a/src/transformers/models/siglip/image_processing_siglip.py +++ b/src/transformers/models/siglip/image_processing_siglip.py @@ -30,7 +30,7 @@ PILImageResampling, infer_channel_dimension_format, is_scaled_image, - make_list_of_images, + make_flat_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -181,7 +181,7 @@ def preprocess( image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_list_of_images(images) + images = make_flat_list_of_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 4e978346176..dbb10548577 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -34,38 +34,18 @@ VideoInput, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_batched_videos, make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils import TensorType, filter_out_non_signature_kwargs, logging logger = logging.get_logger(__name__) -if is_vision_available(): - import PIL - - -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], PIL.Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return 
[list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - class VideoLlavaImageProcessor(BaseImageProcessor): r""" Constructs a CLIP image processor. @@ -208,7 +188,7 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, - ) -> PIL.Image.Image: + ) -> BatchFeature: """ Preprocess an image or batch of images. From a95e445acd5be38a713278cd5e9f37453137307f Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 7 Jan 2025 23:49:39 +0000 Subject: [PATCH 2/9] remove unnecessary is_vision_available --- .../llava_next_video/image_processing_llava_next_video.py | 6 +----- tests/models/colpali/test_modeling_colpali.py | 4 ---- .../instructblipvideo/test_modeling_instructblipvideo.py | 6 +----- tests/models/pixtral/test_modeling_pixtral.py | 5 ----- 4 files changed, 2 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py index af6ad7a4bb7..3ec8d9db069 100644 --- a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py @@ -39,16 +39,12 @@ to_numpy_array, validate_preprocess_arguments, ) -from ...utils import TensorType, is_vision_available, logging +from ...utils import TensorType, logging logger = logging.get_logger(__name__) -if is_vision_available(): - pass - - class LlavaNextVideoImageProcessor(BaseImageProcessor): r""" Constructs a LLaVa-NeXT-Video video processor. Based on [`CLIPImageProcessor`] with incorporation of processing each video frame. 
diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py index 646726ac700..4dcbf94be40 100644 --- a/tests/models/colpali/test_modeling_colpali.py +++ b/tests/models/colpali/test_modeling_colpali.py @@ -26,7 +26,6 @@ from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from transformers import ( is_torch_available, - is_vision_available, ) from transformers.models.colpali.configuration_colpali import ColPaliConfig from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput @@ -43,9 +42,6 @@ if is_torch_available(): import torch -if is_vision_available(): - pass - class ColPaliForRetrievalModelTester: def __init__( diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 3be5f89325c..ef79f152f9e 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -39,7 +39,7 @@ slow, torch_device, ) -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -58,10 +58,6 @@ from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoVisionModel -if is_vision_available(): - pass - - class InstructBlipVideoVisionModelTester: def __init__( self, diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index 3e5667caf45..a30454f6562 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -20,7 +20,6 @@ PixtralVisionConfig, PixtralVisionModel, is_torch_available, - is_vision_available, ) from transformers.testing_utils import ( require_torch, @@ -35,10 +34,6 @@ import torch -if is_vision_available(): - pass - - class PixtralVisionModelTester: def __init__( self, From 423e9d4191d82a0d3ef8f88f933ce3e93801a77a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 8 Jan 2025 00:21:24 +0000 Subject: [PATCH 3/9] move make_nested_list_of_images to image_utils --- src/transformers/image_utils.py | 50 +++++++++++++++++-- .../idefics2/image_processing_idefics2.py | 37 +------------- .../idefics3/image_processing_idefics3.py | 38 +------------- .../models/mllama/image_processing_mllama.py | 40 +-------------- .../pixtral/image_processing_pixtral.py | 38 +------------- 5 files changed, 54 insertions(+), 149 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 5736918183e..7b10a023fc9 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -158,6 +158,10 @@ def is_valid_image(img): return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img) +def is_valid_list_of_images(images: List): + return images and all(is_valid_image(image) for image in images) + + def valid_images(imgs): # If we have an list of images, make sure every image is valid if isinstance(imgs, (list, tuple)): @@ -189,7 +193,7 @@ def is_scaled_image(image: np.ndarray) -> bool: def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: """ - Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1. 
+ Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1. If the input is a batch of images, it is converted to a list of images. Args: @@ -230,7 +234,7 @@ def make_flat_list_of_images( images: Union[List[ImageInput], ImageInput], ) -> ImageInput: """ - Ensure that the input is a flat list of images. If the input is a single image, it is converted to a list of length 1. + Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1. If the input is a nested list of images, it is converted to a flat list of images. Args: images (`Union[List[ImageInput], ImageInput]`): @@ -239,10 +243,48 @@ def make_flat_list_of_images( list: A list of images or a 4d array of images. """ # If the input is a nested list of images, we flatten it - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + if ( + isinstance(images, (list, tuple)) + and all(isinstance(images_i, (list, tuple)) for images_i in images) + and all(is_valid_list_of_images(images_i) for images_i in images) + ): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + return images + + elif is_pil_image(images): + return [images] + + elif is_valid_image(images): + if len(images.shape) == 4: + return images + elif len(images.shape) == 3: + return [images] + + raise ValueError(f"Could not make a flat list of images from {images}") + + +def make_nested_list_of_images( + images: Union[List[ImageInput], ImageInput], +) -> ImageInput: + """ + Ensure that the output is a nested list of images. + Args: + images (`Union[List[ImageInput], ImageInput]`): + The input image. + Returns: + list: A list of images or a 4d array of images. + """ + # If the input is a nested list of images, we flatten it + if ( + isinstance(images, (list, tuple)) + and all(isinstance(images_i, (list, tuple)) for images_i in images) + and all(is_valid_list_of_images(images_i) for images_i in images) + ): return [img for img_list in images for img in img_list] - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): return images elif is_pil_image(images): diff --git a/src/transformers/models/idefics2/image_processing_idefics2.py b/src/transformers/models/idefics2/image_processing_idefics2.py index 65d5a828541..927aba761c4 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2.py +++ b/src/transformers/models/idefics2/image_processing_idefics2.py @@ -29,7 +29,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_nested_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -77,39 +77,6 @@ def get_resize_output_image_size(image, size, input_data_format) -> Tuple[int, i return height, width -def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: - """ - Convert a single image or a list of images to a list of numpy arrays. - - Args: - images (`ImageInput`): - A single image or a list of images. - - Returns: - A list of numpy arrays. 
- """ - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): - images = [images] - # If it's a list of batches, it's already in the right format - elif ( - isinstance(images, (list, tuple)) - and len(images) > 0 - and isinstance(images[0], (list, tuple)) - and len(images[0]) > 0 - and is_valid_image(images[0][0]) - ): - pass - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." - ) - return images - - # Copied from transformers.models.detr.image_processing_detr.max_across_indices def max_across_indices(values: Iterable[Any]) -> List[Any]: """ @@ -504,7 +471,7 @@ def preprocess( do_pad = do_pad if do_pad is not None else self.do_pad do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting - images_list = make_list_of_images(images) + images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): raise ValueError( diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index df71a8bf0e8..b8b30609b84 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -29,7 +29,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_nested_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -141,40 +141,6 @@ def get_resize_output_image_size( return height, width -# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images -def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: - """ - Convert a single image or a list of images to a list of numpy arrays. - - Args: - images (`ImageInput`): - A single image or a list of images. - - Returns: - A list of numpy arrays. - """ - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): - images = [images] - # If it's a list of batches, it's already in the right format - elif ( - isinstance(images, (list, tuple)) - and len(images) > 0 - and isinstance(images[0], (list, tuple)) - and len(images[0]) > 0 - and is_valid_image(images[0][0]) - ): - pass - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." 
- ) - return images - - # Copied from transformers.models.detr.image_processing_detr.max_across_indices def max_across_indices(values: Iterable[Any]) -> List[Any]: """ @@ -720,7 +686,7 @@ def preprocess( do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb do_pad = do_pad if do_pad is not None else self.do_pad - images_list = make_list_of_images(images) + images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): raise ValueError( diff --git a/src/transformers/models/mllama/image_processing_mllama.py b/src/transformers/models/mllama/image_processing_mllama.py index 3c852589672..9ff077f1501 100644 --- a/src/transformers/models/mllama/image_processing_mllama.py +++ b/src/transformers/models/mllama/image_processing_mllama.py @@ -33,8 +33,8 @@ ImageInput, PILImageResampling, infer_channel_dimension_format, - is_valid_image, is_vision_available, + make_nested_list_of_images, to_numpy_array, validate_preprocess_arguments, ) @@ -514,42 +514,6 @@ def convert_to_rgb(image: ImageInput) -> ImageInput: return alpha_composite -# Modified from transformers.models.idefics2.image_processing_idefics2.make_list_of_images -def make_list_of_images(images: ImageInput) -> List[List[Optional[np.ndarray]]]: - """ - Convert a single image or a list of images to a list of numpy arrays. - - Args: - images (`ImageInput`): - A single image or a list of images. - - Returns: - A list of numpy arrays. - """ - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - output_images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - output_images = [images] - # If it's a list of batches, it's already in the right format - elif ( - isinstance(images, (list, tuple)) - and all(isinstance(images_i, (list, tuple)) for images_i in images) - and any(is_valid_list_of_images(images_i) for images_i in images) - ): - output_images = images - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." - ) - return output_images - - -def is_valid_list_of_images(images: List): - return images and all(is_valid_image(image) for image in images) - - def _validate_size(size: Dict[str, int]) -> None: if not ("height" in size and "width" in size): raise ValueError(f"Argument `size` must be a dictionary with keys 'height' and 'width'. 
Got: {size}") @@ -726,7 +690,7 @@ def preprocess( # extra validation _validate_mllama_preprocess_arguments(do_resize, size, do_pad, max_image_tiles) - images_list = make_list_of_images(images) + images_list = make_nested_list_of_images(images) if self.do_convert_rgb: images_list = [[convert_to_rgb(image) for image in images] for images in images_list] diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 6d83e0c4647..5ea750945c8 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -31,7 +31,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, - is_valid_image, + make_nested_list_of_images, to_numpy_array, valid_images, validate_kwargs, @@ -99,40 +99,6 @@ def _recursive_to(obj, device, *args, **kwargs): return self -# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images -def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: - """ - Convert a single image or a list of images to a list of numpy arrays. - - Args: - images (`ImageInput`): - A single image or a list of images. - - Returns: - A list of numpy arrays. - """ - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): - images = [images] - # If it's a list of batches, it's already in the right format - elif ( - isinstance(images, (list, tuple)) - and len(images) > 0 - and isinstance(images[0], (list, tuple)) - and len(images[0]) > 0 - and is_valid_image(images[0][0]) - ): - pass - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." - ) - return images - - # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. 
def convert_to_rgb(image: ImageInput) -> ImageInput: """ @@ -449,7 +415,7 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images_list = make_list_of_images(images) + images_list = make_nested_list_of_images(images) if not valid_images(images_list[0]): raise ValueError( From 948f93beb91932a2ab0b287276bd83bf0377ef79 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 8 Jan 2025 00:25:50 +0000 Subject: [PATCH 4/9] fix fast pixtral image processor --- .../models/pixtral/image_processing_pixtral_fast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py index 082e255c843..5dab12f9d2c 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -26,6 +26,7 @@ get_image_size, get_image_type, infer_channel_dimension_format, + make_nested_list_of_images, validate_fast_preprocess_arguments, validate_kwargs, ) @@ -41,7 +42,6 @@ BatchMixFeature, convert_to_rgb, get_resize_output_image_size, - make_list_of_images, ) @@ -271,7 +271,7 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images_list = make_list_of_images(images) + images_list = make_nested_list_of_images(images) image_type = get_image_type(images_list[0][0]) if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: From 30a2d541deaa471cf81699b88af299811ecee359 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 8 Jan 2025 00:36:57 +0000 Subject: [PATCH 5/9] fix import mllama --- src/transformers/models/mllama/processing_mllama.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py index 5905f3313f7..4e8f788cf70 100644 --- a/src/transformers/models/mllama/processing_mllama.py +++ b/src/transformers/models/mllama/processing_mllama.py @@ -20,16 +20,13 @@ import numpy as np from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput +from ...image_utils import ImageInput, make_nested_list_of_images from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( PreTokenizedInput, TextInput, ) -# TODO: Can we do it that way or its better include as "Copied from ..." -from .image_processing_mllama import make_list_of_images - class MllamaImagesKwargs(ImagesKwargs, total=False): max_image_tiles: Optional[int] @@ -292,7 +289,7 @@ def __call__( n_images_in_images = [0] if images is not None: - images = make_list_of_images(images) + images = make_nested_list_of_images(images) n_images_in_images = [len(sample) for sample in images] if text is not None: From a4a90aa81e04419d958b7a1d59fd9500cec6986f Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 8 Jan 2025 00:44:21 +0000 Subject: [PATCH 6/9] fix make_nested_list_of_images --- src/transformers/image_utils.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 7b10a023fc9..ed248be3072 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -276,27 +276,24 @@ def make_nested_list_of_images( Returns: list: A list of images or a 4d array of images. 
""" - # If the input is a nested list of images, we flatten it - if ( + # If it's a single image, convert it to a list of lists + if is_valid_image(images): + output_images = [[images]] + # If it's a list of images, it's a single batch, so convert it to a list of lists + elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + output_images = [images] + # If it's a list of batches, it's already in the right format + elif ( isinstance(images, (list, tuple)) and all(isinstance(images_i, (list, tuple)) for images_i in images) and all(is_valid_list_of_images(images_i) for images_i in images) ): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - return images - - elif is_pil_image(images): - return [images] - - elif is_valid_image(images): - if len(images.shape) == 4: - return images - elif len(images.shape) == 3: - return [images] - - raise ValueError(f"Could not make a flat list of images from {images}") + output_images = images + else: + raise ValueError( + "Invalid input type. Must be a single image, a list of images, or a list of batches of images." + ) + return output_images def make_batched_videos(videos) -> VideoInput: From 75834180e5b6e49c38b1be31c62627f405c109fe Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 9 Jan 2025 17:05:08 +0000 Subject: [PATCH 7/9] add tests --- src/transformers/image_utils.py | 56 +++++---- tests/utils/test_image_utils.py | 211 +++++++++++++++++++++++++++++++- 2 files changed, 242 insertions(+), 25 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index ed248be3072..b623201360e 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -207,7 +207,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: return images # Either the input is a single image, in which case we create a list of length 1 - if isinstance(images, PIL.Image.Image): + if is_pil_image(images): # PIL images are never batched return [images] @@ -250,17 +250,17 @@ def make_flat_list_of_images( ): return [img for img_list in images for img in img_list] - elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - return images - - elif is_pil_image(images): - return [images] - - elif is_valid_image(images): - if len(images.shape) == 4: + if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + if is_pil_image(images[0]) or images[0].ndim == 3: return images - elif len(images.shape) == 3: + if images[0].ndim == 4: + return [img for img_list in images for img in img_list] + + if is_valid_image(images): + if is_pil_image(images) or images.ndim == 3: return [images] + if images.ndim == 4: + return images raise ValueError(f"Could not make a flat list of images from {images}") @@ -274,26 +274,34 @@ def make_nested_list_of_images( images (`Union[List[ImageInput], ImageInput]`): The input image. Returns: - list: A list of images or a 4d array of images. + list: A list of list of images or a list of 4d array of images. 
""" - # If it's a single image, convert it to a list of lists - if is_valid_image(images): - output_images = [[images]] - # If it's a list of images, it's a single batch, so convert it to a list of lists - elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - output_images = [images] # If it's a list of batches, it's already in the right format - elif ( + if ( isinstance(images, (list, tuple)) and all(isinstance(images_i, (list, tuple)) for images_i in images) and all(is_valid_list_of_images(images_i) for images_i in images) ): - output_images = images - else: - raise ValueError( - "Invalid input type. Must be a single image, a list of images, or a list of batches of images." - ) - return output_images + return images + + # If it's a list of images, it's a single batch, so convert it to a list of lists + if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + if is_pil_image(images[0]) or images[0].ndim == 3: + return [images] + if images[0].ndim == 4: + return images + + # If it's a single image, convert it to a list of lists + if is_pil_image(images): + return [[images]] + + if is_valid_image(images): + if is_pil_image(images) or images.ndim == 3: + return [[images]] + if images.ndim == 4: + return [images] + + raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.") def make_batched_videos(videos) -> VideoInput: diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py index 1fa84aa5db2..ded6236a475 100644 --- a/tests/utils/test_image_utils.py +++ b/tests/utils/test_image_utils.py @@ -28,7 +28,13 @@ from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL from transformers import is_torch_available, is_vision_available -from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images +from transformers.image_utils import ( + ChannelDimension, + get_channel_dimension_axis, + make_flat_list_of_images, + make_list_of_images, + make_nested_list_of_images, +) from transformers.testing_utils import is_flaky, require_torch, require_vision @@ -115,6 +121,21 @@ def test_conversion_array_to_array(self): self.assertEqual(array5.shape, (3, 16, 32)) self.assertTrue(np.array_equal(array5, array1)) + def test_make_list_of_images_pil(self): + # Test a single image is converted to a list of 1 image + pil_image = get_random_image(16, 32) + images_list = make_list_of_images(pil_image) + self.assertIsInstance(images_list, list) + self.assertEqual(len(images_list), 1) + self.assertIsInstance(images_list[0], PIL.Image.Image) + + # Test a list of images is not modified + images = [get_random_image(16, 32) for _ in range(4)] + images_list = make_list_of_images(images) + self.assertIsInstance(images_list, list) + self.assertEqual(len(images_list), 4) + self.assertIsInstance(images_list[0], PIL.Image.Image) + def test_make_list_of_images_numpy(self): # Test a single image is converted to a list of 1 image images = np.random.randint(0, 256, (16, 32, 3)) @@ -167,6 +188,194 @@ def test_make_list_of_images_torch(self): self.assertTrue(np.array_equal(images_list[0], images[0])) self.assertIsInstance(images_list, list) + def test_make_flat_list_of_images_pil(self): + # Test a single image is converted to a list of 1 image + pil_image = get_random_image(16, 32) + images_list = make_flat_list_of_images(pil_image) + self.assertIsInstance(images_list, list) + self.assertEqual(len(images_list), 1) + self.assertIsInstance(images_list[0], 
PIL.Image.Image) + + # Test a list of images is not modified + images = [get_random_image(16, 32) for _ in range(4)] + images_list = make_flat_list_of_images(images) + self.assertIsInstance(images_list, list) + self.assertEqual(len(images_list), 4) + self.assertIsInstance(images_list[0], PIL.Image.Image) + + # Test a nested list of images is flattened + images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)] + images_list = make_flat_list_of_images(images) + self.assertIsInstance(images_list, list) + self.assertEqual(len(images_list), 4) + self.assertIsInstance(images_list[0], PIL.Image.Image) + + def test_make_flat_list_of_images_numpy(self): + # Test a single image is converted to a list of 1 image + images = np.random.randint(0, 256, (16, 32, 3)) + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 1) + self.assertTrue(np.array_equal(images_list[0], images)) + self.assertIsInstance(images_list, list) + + # Test a batch of images is unchanged + images = np.random.randint(0, 256, (4, 16, 32, 3)) + images_array = make_flat_list_of_images(images) + self.assertEqual(len(images_array), 4) + self.assertIsInstance(images_array, np.ndarray) + self.assertTrue(np.array_equal(images_array, images)) + + # Test a list of images is not modified + images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + + # Test list of batched images is flattened + images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 8) + self.assertTrue(np.array_equal(images_list[0], images[0][0])) + self.assertIsInstance(images_list, list) + self.assertIsInstance(images_list[0], np.ndarray) + + # Test nested list of images is flattened + images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0][0])) + self.assertIsInstance(images_list, list) + + @require_torch + def test_make_flat_list_of_images_torch(self): + # Test a single image is converted to a list of 1 image + images = torch.randint(0, 256, (16, 32, 3)) + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 1) + self.assertTrue(np.array_equal(images_list[0], images)) + self.assertIsInstance(images_list, list) + + # Test a batch of images is unchanged + images = torch.randint(0, 256, (4, 16, 32, 3)) + images_array = make_flat_list_of_images(images) + self.assertEqual(len(images_array), 4) + self.assertIsInstance(images_array, torch.Tensor) + self.assertTrue(np.array_equal(images_array, images)) + + # Test a list of images is not modified + images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0])) + self.assertIsInstance(images_list, list) + + # Test list of batched images is flattened + images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 8) + self.assertTrue(np.array_equal(images_list[0], images[0][0])) + self.assertIsInstance(images_list, list) + 
self.assertIsInstance(images_list[0], torch.Tensor) + + # Test nested list of images is flattened + images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)] + images_list = make_flat_list_of_images(images) + self.assertEqual(len(images_list), 4) + self.assertTrue(np.array_equal(images_list[0], images[0][0])) + self.assertIsInstance(images_list, list) + + def test_make_nested_list_of_images_pil(self): + # Test a single image is converted to a nested list of 1 image + pil_image = get_random_image(16, 32) + images_list = make_nested_list_of_images(pil_image) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list[0]), 1) + self.assertIsInstance(images_list[0][0], PIL.Image.Image) + + # Test a list of images is converted to a nested list of images + images = [get_random_image(16, 32) for _ in range(4)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 1) + self.assertEqual(len(images_list[0]), 4) + self.assertIsInstance(images_list[0][0], PIL.Image.Image) + + # Test a nested list of images is not modified + images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 2) + self.assertEqual(len(images_list[0]), 2) + self.assertIsInstance(images_list[0][0], PIL.Image.Image) + + def test_make_nested_list_of_images_numpy(self): + # Test a single image is converted to a nested list of 1 image + images = np.random.randint(0, 256, (16, 32, 3)) + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 1) + self.assertTrue(np.array_equal(images_list[0][0], images)) + + # Test a batch of images is converted to a nested list of images + images = np.random.randint(0, 256, (4, 16, 32, 3)) + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list, list) + self.assertIsInstance(images_list[0], np.ndarray) + self.assertEqual(len(images_list), 1) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0])) + + # Test a list of images is converted to a nested list of images + images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 1) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0])) + + # Test a nested list of images is left unchanged + images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 2) + self.assertEqual(len(images_list[0]), 2) + self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + + @require_torch + def test_make_nested_list_of_images_torch(self): + # Test a single image is converted to a nested list of 1 image + images = torch.randint(0, 256, (16, 32, 3)) + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list[0]), 1) + self.assertTrue(np.array_equal(images_list[0][0], images)) + + # Test a batch of images is converted to a nested list of images + images = torch.randint(0, 256, (4, 16, 32, 3)) + images_list = 
make_nested_list_of_images(images) + self.assertIsInstance(images_list, list) + self.assertIsInstance(images_list[0], torch.Tensor) + self.assertEqual(len(images_list), 1) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0])) + + # Test a list of images is converted to a nested list of images + images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 1) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0])) + + # Test a nested list of images is left unchanged + images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertEqual(len(images_list), 2) + self.assertEqual(len(images_list[0]), 2) + self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + @require_torch def test_conversion_torch_to_array(self): feature_extractor = ImageFeatureExtractionMixin() From f0da40bbafce3a065cbfb3b639dab0cccede9e5c Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 10 Jan 2025 15:31:18 +0000 Subject: [PATCH 8/9] convert 4d arrays/tensors to list --- src/transformers/image_utils.py | 6 ++-- tests/utils/test_image_utils.py | 56 ++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index b623201360e..7acfb6edbce 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -260,7 +260,7 @@ def make_flat_list_of_images( if is_pil_image(images) or images.ndim == 3: return [images] if images.ndim == 4: - return images + return list(images) raise ValueError(f"Could not make a flat list of images from {images}") @@ -289,7 +289,7 @@ def make_nested_list_of_images( if is_pil_image(images[0]) or images[0].ndim == 3: return [images] if images[0].ndim == 4: - return images + return [list(image) for image in images] # If it's a single image, convert it to a list of lists if is_pil_image(images): @@ -299,7 +299,7 @@ def make_nested_list_of_images( if is_pil_image(images) or images.ndim == 3: return [[images]] if images.ndim == 4: - return [images] + return [list(images)] raise ValueError("Invalid input type. 
Must be a single image, a list of images, or a list of batches of images.")
 
diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py
index ded6236a475..29cd413d8d8 100644
--- a/tests/utils/test_image_utils.py
+++ b/tests/utils/test_image_utils.py
@@ -218,12 +218,13 @@ def test_make_flat_list_of_images_numpy(self):
         self.assertTrue(np.array_equal(images_list[0], images))
         self.assertIsInstance(images_list, list)
 
-        # Test a batch of images is unchanged
+        # Test a 4d array of images is changed to a list of images
         images = np.random.randint(0, 256, (4, 16, 32, 3))
-        images_array = make_flat_list_of_images(images)
-        self.assertEqual(len(images_array), 4)
-        self.assertIsInstance(images_array, np.ndarray)
-        self.assertTrue(np.array_equal(images_array, images))
+        images_list = make_flat_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertIsInstance(images_list, list)
+        self.assertIsInstance(images_list[0], np.ndarray)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
 
         # Test a list of images is not modified
         images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
@@ -232,7 +233,7 @@ def test_make_flat_list_of_images_numpy(self):
         self.assertTrue(np.array_equal(images_list[0], images[0]))
         self.assertIsInstance(images_list, list)
 
-        # Test list of batched images is flattened
+        # Test list of 4d array images is flattened
         images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
         images_list = make_flat_list_of_images(images)
         self.assertEqual(len(images_list), 8)
@@ -256,12 +257,13 @@ def test_make_flat_list_of_images_torch(self):
         self.assertTrue(np.array_equal(images_list[0], images))
         self.assertIsInstance(images_list, list)
 
-        # Test a batch of images is unchanged
+        # Test a 4d tensor of images is changed to a list of images
         images = torch.randint(0, 256, (4, 16, 32, 3))
-        images_array = make_flat_list_of_images(images)
-        self.assertEqual(len(images_array), 4)
-        self.assertIsInstance(images_array, torch.Tensor)
-        self.assertTrue(np.array_equal(images_array, images))
+        images_list = make_flat_list_of_images(images)
+        self.assertEqual(len(images_list), 4)
+        self.assertIsInstance(images_list, list)
+        self.assertIsInstance(images_list[0], torch.Tensor)
+        self.assertTrue(np.array_equal(images_list[0], images[0]))
 
         # Test a list of images is not modified
         images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
@@ -270,7 +272,7 @@ def test_make_flat_list_of_images_torch(self):
         self.assertTrue(np.array_equal(images_list[0], images[0]))
         self.assertIsInstance(images_list, list)
 
-        # Test list of batched images is flattened
+        # Test list of 4d tensors of images is flattened
         images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
         images_list = make_flat_list_of_images(images)
         self.assertEqual(len(images_list), 8)
@@ -317,11 +319,11 @@ def test_make_nested_list_of_images_numpy(self):
         self.assertEqual(len(images_list), 1)
         self.assertTrue(np.array_equal(images_list[0][0], images))
 
-        # Test a batch of images is converted to a nested list of images
+        # Test a 4d array of images is converted to a nested list of images
         images = np.random.randint(0, 256, (4, 16, 32, 3))
         images_list = make_nested_list_of_images(images)
-        self.assertIsInstance(images_list, list)
-        self.assertIsInstance(images_list[0], np.ndarray)
+        self.assertIsInstance(images_list[0], list)
+        self.assertIsInstance(images_list[0][0], np.ndarray)
         self.assertEqual(len(images_list), 1)
         self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0])) @@ -342,6 +344,15 @@ def test_make_nested_list_of_images_numpy(self): self.assertEqual(len(images_list[0]), 2) self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + # Test a list of 4d array images is converted to a nested list of images + images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertIsInstance(images_list[0][0], np.ndarray) + self.assertEqual(len(images_list), 2) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + @require_torch def test_make_nested_list_of_images_torch(self): # Test a single image is converted to a nested list of 1 image @@ -351,11 +362,11 @@ def test_make_nested_list_of_images_torch(self): self.assertEqual(len(images_list[0]), 1) self.assertTrue(np.array_equal(images_list[0][0], images)) - # Test a batch of images is converted to a nested list of images + # Test a 4d tensor of images is converted to a nested list of images images = torch.randint(0, 256, (4, 16, 32, 3)) images_list = make_nested_list_of_images(images) - self.assertIsInstance(images_list, list) - self.assertIsInstance(images_list[0], torch.Tensor) + self.assertIsInstance(images_list[0], list) + self.assertIsInstance(images_list[0][0], torch.Tensor) self.assertEqual(len(images_list), 1) self.assertEqual(len(images_list[0]), 4) self.assertTrue(np.array_equal(images_list[0][0], images[0])) @@ -376,6 +387,15 @@ def test_make_nested_list_of_images_torch(self): self.assertEqual(len(images_list[0]), 2) self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + # Test a list of 4d tensor images is converted to a nested list of images + images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] + images_list = make_nested_list_of_images(images) + self.assertIsInstance(images_list[0], list) + self.assertIsInstance(images_list[0][0], torch.Tensor) + self.assertEqual(len(images_list), 2) + self.assertEqual(len(images_list[0]), 4) + self.assertTrue(np.array_equal(images_list[0][0], images[0][0])) + @require_torch def test_conversion_torch_to_array(self): feature_extractor = ImageFeatureExtractionMixin() From 77ed530fc3d1eb87b00d698c0addb0d7905c60fd Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 14 Jan 2025 20:17:03 +0000 Subject: [PATCH 9/9] add test_make_batched_videos --- src/transformers/image_utils.py | 11 ++-- tests/utils/test_image_utils.py | 110 ++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 7acfb6edbce..105b7ef26fb 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -292,9 +292,6 @@ def make_nested_list_of_images( return [list(image) for image in images] # If it's a single image, convert it to a list of lists - if is_pil_image(images): - return [[images]] - if is_valid_image(images): if is_pil_image(images) or images.ndim == 3: return [[images]] @@ -317,15 +314,15 @@ def make_batched_videos(videos) -> VideoInput: return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if is_pil_image(videos[0]): + if is_pil_image(videos[0]) or videos[0].ndim == 3: return [videos] - elif len(videos[0].shape) == 4: + elif videos[0].ndim == 4: return [list(video) for video in videos] elif is_valid_image(videos): - if is_pil_image(videos): + if 
is_pil_image(videos) or videos.ndim == 3:
             return [[videos]]
-        elif len(videos.shape) == 4:
+        elif videos.ndim == 4:
             return [list(videos)]
 
     raise ValueError(f"Could not make batched video from {videos}")
diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py
index 29cd413d8d8..d4ce1435a14 100644
--- a/tests/utils/test_image_utils.py
+++ b/tests/utils/test_image_utils.py
@@ -31,6 +31,7 @@
 from transformers.image_utils import (
     ChannelDimension,
     get_channel_dimension_axis,
+    make_batched_videos,
     make_flat_list_of_images,
     make_list_of_images,
     make_nested_list_of_images,
@@ -396,6 +397,115 @@ def test_make_nested_list_of_images_torch(self):
         self.assertEqual(len(images_list[0]), 4)
         self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
 
+    def test_make_batched_videos_pil(self):
+        # Test a single image is converted to a list of 1 video with 1 frame
+        pil_image = get_random_image(16, 32)
+        videos_list = make_batched_videos(pil_image)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list[0]), 1)
+        self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
+
+        # Test a list of images is converted to a list of 1 video
+        images = [get_random_image(16, 32) for _ in range(4)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 1)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
+
+        # Test a nested list of images is not modified
+        images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 2)
+        self.assertEqual(len(videos_list[0]), 2)
+        self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
+
+    def test_make_batched_videos_numpy(self):
+        # Test a single image is converted to a list of 1 video with 1 frame
+        images = np.random.randint(0, 256, (16, 32, 3))
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 1)
+        self.assertTrue(np.array_equal(videos_list[0][0], images))
+
+        # Test a 4d array of images is converted to a list of 1 video
+        images = np.random.randint(0, 256, (4, 16, 32, 3))
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertIsInstance(videos_list[0][0], np.ndarray)
+        self.assertEqual(len(videos_list), 1)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
+
+        # Test a list of images is converted to a list of 1 video
+        images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 1)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
+
+        # Test a nested list of images is left unchanged
+        images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 2)
+        self.assertEqual(len(videos_list[0]), 2)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
+
+        # Test a list of 4d array images is converted to a list of videos
+        images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertIsInstance(videos_list[0][0], np.ndarray)
+        self.assertEqual(len(videos_list), 2)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
+
+    @require_torch
+    def test_make_batched_videos_torch(self):
+        # Test a single image is converted to a list of 1 video with 1 frame
+        images = torch.randint(0, 256, (16, 32, 3))
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list[0]), 1)
+        self.assertTrue(np.array_equal(videos_list[0][0], images))
+
+        # Test a 4d tensor of images is converted to a list of 1 video
+        images = torch.randint(0, 256, (4, 16, 32, 3))
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertIsInstance(videos_list[0][0], torch.Tensor)
+        self.assertEqual(len(videos_list), 1)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
+
+        # Test a list of images is converted to a list of 1 video
+        images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 1)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
+
+        # Test a nested list of images is left unchanged
+        images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertEqual(len(videos_list), 2)
+        self.assertEqual(len(videos_list[0]), 2)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
+
+        # Test a list of 4d tensor images is converted to a list of videos
+        images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
+        videos_list = make_batched_videos(images)
+        self.assertIsInstance(videos_list[0], list)
+        self.assertIsInstance(videos_list[0][0], torch.Tensor)
+        self.assertEqual(len(videos_list), 2)
+        self.assertEqual(len(videos_list[0]), 4)
+        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
+
     @require_torch
     def test_conversion_torch_to_array(self):
         feature_extractor = ImageFeatureExtractionMixin()
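
For reference, a rough sketch of the behavior these tests pin down once the series is applied. It is illustrative only: the zero-valued frames and their shapes are arbitrary, and it assumes a checkout that already contains these patches.

import numpy as np

from transformers.image_utils import (
    make_batched_videos,
    make_flat_list_of_images,
    make_nested_list_of_images,
)

# Arbitrary placeholder data: one HxWxC frame and one 4d "clip" of 4 frames.
frame = np.zeros((16, 32, 3), dtype=np.uint8)
clip = np.zeros((4, 16, 32, 3), dtype=np.uint8)

# Flat list: nested lists and 4d arrays are unrolled into a single list of images.
assert len(make_flat_list_of_images([[frame, frame], [frame]])) == 3
assert len(make_flat_list_of_images(clip)) == 4

# Nested list: always a list of lists of images, one inner list per sample.
assert len(make_nested_list_of_images(frame)[0]) == 1
assert len(make_nested_list_of_images(clip)[0]) == 4

# Batched videos: always a list of videos, each video being a list of frames.
assert len(make_batched_videos(clip)) == 1
assert len(make_batched_videos([clip, clip])) == 2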