Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for nested images to LLava and VipLLava #35558

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
106 changes: 104 additions & 2 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def is_valid_image(img):
return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)


def is_valid_list_of_images(images: List):
    """Return `True` if `images` is a non-empty list in which every element is a valid image."""
    # bool() normalizes the short-circuit value (e.g. `[]` or `None`) so this
    # predicate always returns a bool instead of leaking the falsy input back.
    return bool(images) and all(is_valid_image(image) for image in images)


def valid_images(imgs):
# If we have an list of images, make sure every image is valid
if isinstance(imgs, (list, tuple)):
Expand Down Expand Up @@ -189,7 +193,7 @@ def is_scaled_image(image: np.ndarray) -> bool:

def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
"""
Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1.
If the input is a batch of images, it is converted to a list of images.

Args:
Expand All @@ -203,7 +207,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
return images

# Either the input is a single image, in which case we create a list of length 1
if isinstance(images, PIL.Image.Image):
if is_pil_image(images):
# PIL images are never batched
return [images]

Expand All @@ -226,6 +230,104 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
)


def make_flat_list_of_images(
    images: Union[List[ImageInput], ImageInput],
) -> ImageInput:
    """
    Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of
    length 1. If the input is a nested list of images, it is converted to a flat list of images.

    Args:
        images (`Union[List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A flat list of images.
    """
    # Nested list of images (a list of per-sample image lists): flatten one level.
    if (
        isinstance(images, (list, tuple))
        and all(isinstance(images_i, (list, tuple)) for images_i in images)
        and all(is_valid_list_of_images(images_i) for images_i in images)
    ):
        return [img for img_list in images for img in img_list]

    # Flat list of images: PIL images and 3D arrays are single images and pass
    # through; 4D arrays are batches and are unpacked into individual images.
    if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
        if is_pil_image(images[0]) or images[0].ndim == 3:
            return images
        if images[0].ndim == 4:
            return [img for img_list in images for img in img_list]

    # Single image: wrap in a list; a lone 4D array is treated as a batch.
    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == 3:
            return [images]
        if images.ndim == 4:
            return list(images)

    raise ValueError(f"Could not make a flat list of images from {images}")


def make_nested_list_of_images(
    images: Union[List[ImageInput], ImageInput],
) -> ImageInput:
    """
    Ensure that the output is a nested list of images.

    Args:
        images (`Union[List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of list of images or a list of 4d array of images.
    """
    if isinstance(images, (list, tuple)):
        # A list of batches is already in the expected nested format.
        if all(isinstance(batch, (list, tuple)) for batch in images) and all(
            is_valid_list_of_images(batch) for batch in images
        ):
            return images

        # A flat list of images is a single batch; 4D array entries are each a batch.
        if is_valid_list_of_images(images):
            first = images[0]
            if is_pil_image(first) or first.ndim == 3:
                return [images]
            if first.ndim == 4:
                return [list(batch) for batch in images]

    elif is_valid_image(images):
        # A lone image becomes a one-element batch; a single 4D array is one batch.
        if is_pil_image(images) or images.ndim == 3:
            return [[images]]
        if images.ndim == 4:
            return [list(images)]

    raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")


def make_batched_videos(videos) -> VideoInput:
    """
    Ensure that the input is a list of videos.

    Args:
        videos (`VideoInput`):
            Video or videos to turn into a list of videos.

    Returns:
        list: A list of videos.

    Raises:
        ValueError: If the input cannot be interpreted as video(s), including empty lists.
    """
    # Emptiness guards prevent IndexError on `[]` / `[[]]` so invalid input
    # always surfaces as the ValueError below instead.
    if (
        isinstance(videos, (list, tuple))
        and videos
        and isinstance(videos[0], (list, tuple))
        and videos[0]
        and is_valid_image(videos[0][0])
    ):
        # Already a list of videos (each video is a list of frames).
        return videos

    elif isinstance(videos, (list, tuple)) and videos and is_valid_image(videos[0]):
        # A flat list of frames is a single video; 4D entries are each a video.
        if is_pil_image(videos[0]) or videos[0].ndim == 3:
            return [videos]
        elif videos[0].ndim == 4:
            return [list(video) for video in videos]

    elif is_valid_image(videos):
        # A single frame becomes a one-frame video; a 4D array is one video.
        if is_pil_image(videos) or videos.ndim == 3:
            return [[videos]]
        elif videos.ndim == 4:
            return [list(videos)]

    raise ValueError(f"Could not make batched video from {videos}")


def to_numpy_array(img) -> np.ndarray:
if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
Expand Down
27 changes: 2 additions & 25 deletions src/transformers/models/aria/image_processing_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,14 @@
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_valid_image,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType


def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images.

    Raises:
        ValueError: If the input cannot be interpreted as image(s), including empty lists.
    """
    # Emptiness guards prevent IndexError on `[]` / `[[]]` so invalid input
    # always surfaces as the ValueError below instead.
    if (
        isinstance(images, (list, tuple))
        and images
        and isinstance(images[0], (list, tuple))
        and images[0]
        and is_valid_image(images[0][0])
    ):
        # Nested list of images: flatten one level.
        return [img for img_list in images for img in img_list]

    elif isinstance(images, (list, tuple)) and images and is_valid_image(images[0]):
        # Already a flat list of images.
        return images

    elif is_valid_image(images):
        # Single image: wrap in a list.
        return [images]

    # Fixed copy-paste error: the message previously said "batched video".
    raise ValueError(f"Could not make batched images from {images}")


def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
"""
Divides an image into patches of a specified size.
Expand Down Expand Up @@ -244,7 +221,7 @@ def preprocess(
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")

images = make_batched_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/aria/modular_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
Expand Down Expand Up @@ -57,7 +58,7 @@
LlamaRMSNorm,
)
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
from ..llava_next.image_processing_llava_next import divide_to_patches, make_batched_images
from ..llava_next.image_processing_llava_next import divide_to_patches


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -608,7 +609,7 @@ def preprocess(
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")

images = make_batched_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/blip/image_processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
Expand Down Expand Up @@ -231,8 +231,7 @@ def preprocess(

size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)

images = make_list_of_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand Down
27 changes: 2 additions & 25 deletions src/transformers/models/chameleon/image_processing_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
Expand All @@ -44,29 +44,6 @@
import PIL


def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images.

    Raises:
        ValueError: If the input cannot be interpreted as image(s), including empty lists.
    """
    # Emptiness guards prevent IndexError on `[]` / `[[]]` so invalid input
    # always surfaces as the ValueError below instead.
    if (
        isinstance(images, (list, tuple))
        and images
        and isinstance(images[0], (list, tuple))
        and images[0]
        and is_valid_image(images[0][0])
    ):
        # Nested list of images: flatten one level.
        return [img for img_list in images for img in img_list]

    elif isinstance(images, (list, tuple)) and images and is_valid_image(images[0]):
        # Already a flat list of images.
        return images

    elif is_valid_image(images):
        # Single image: wrap in a list.
        return [images]

    # Fixed copy-paste error: the message previously said "batched video".
    raise ValueError(f"Could not make batched images from {images}")


class ChameleonImageProcessor(BaseImageProcessor):
r"""
Constructs a Chameleon image processor.
Expand Down Expand Up @@ -275,7 +252,7 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

images = make_batched_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/clip/image_processing_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
Expand Down Expand Up @@ -283,7 +283,7 @@ def preprocess(

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_list_of_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/colpali/modular_colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
IMAGE_TOKEN,
PaliGemmaProcessor,
build_string_from_input,
make_batched_images,
)

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import (
ProcessingKwargs,
Unpack,
Expand Down Expand Up @@ -168,7 +167,7 @@ def __call__(
)
for prompt, image_list in zip(texts_doc, images)
]
images = make_batched_images(images)
images = make_flat_list_of_images(images)
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

# max_length has to account for the image tokens
Expand Down
27 changes: 2 additions & 25 deletions src/transformers/models/colpali/processing_colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import ClassVar, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from ...utils import is_torch_available
Expand Down Expand Up @@ -72,29 +72,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i
return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"


def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image.

    Returns:
        list: A list of images.

    Raises:
        ValueError: If the input cannot be interpreted as image(s), including empty lists.
    """
    # Emptiness guards prevent IndexError on `[]` / `[[]]` so invalid input
    # always surfaces as the ValueError below instead.
    if (
        isinstance(images, (list, tuple))
        and images
        and isinstance(images[0], (list, tuple))
        and images[0]
        and is_valid_image(images[0][0])
    ):
        # Nested list of images: flatten one level.
        return [img for img_list in images for img in img_list]

    elif isinstance(images, (list, tuple)) and images and is_valid_image(images[0]):
        # Already a flat list of images.
        return images

    elif is_valid_image(images):
        # Single image: wrap in a list.
        return [images]

    # Fixed copy-paste error: the message previously said "batched video".
    raise ValueError(f"Could not make batched images from {images}")


class ColPaliProcessor(ProcessorMixin):
r"""
Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
Expand Down Expand Up @@ -230,7 +207,7 @@ def __call__(
)
for prompt, image_list in zip(texts_doc, images)
]
images = make_batched_images(images)
images = make_flat_list_of_images(images)
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

# max_length has to account for the image tokens
Expand Down
Loading