# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Pixtral."""

from typing import Optional, Union

import torch
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    logging,
)
from .image_processing_pixtral import get_resize_output_image_size


logger = logging.get_logger(__name__)


class PixtralFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    patch_size (`dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
        Size of the patches in the model, used to calculate the output image size. Can be overridden by
        `patch_size` in the `preprocess` method.
    """

    patch_size: Optional[dict[str, int]]


@auto_docstring
class PixtralImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BICUBIC
    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]
    patch_size = {"height": 16, "width": 16}
    size = {"longest_edge": 1024}
    default_to_square = True
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    valid_kwargs = PixtralFastImageProcessorKwargs
    model_input_names = ["pixel_values", "image_sizes"]

    def __init__(self, **kwargs: Unpack[PixtralFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[PixtralFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        patch_size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Resize an image. If the longest edge of the image exceeds `size["longest_edge"]`, the image is scaled
        down to match it while keeping the input aspect ratio; both output dimensions are then rounded up to
        multiples of `patch_size`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Dict containing the longest possible edge of the image.
            patch_size (`SizeDict`):
                Patch size used to calculate the size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Resampling filter to use when resizing the image.
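        Example (an illustrative sketch, not from the original source; the input shape is an assumption and
        the output size is computed by `get_resize_output_image_size`):

        ```python
        >>> from transformers.image_utils import SizeDict
        >>> processor = PixtralImageProcessorFast()
        >>> image = torch.zeros(3, 3000, 2000)
        >>> resized = processor.resize(
        ...     image, size=SizeDict(longest_edge=1024), patch_size=SizeDict(height=16, width=16)
        ... )
        >>> resized.shape[-2:]  # longest edge capped at 1024, both dims multiples of 16
        torch.Size([1024, 688])
        ```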
""" interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR if size.longest_edge: size = (size.longest_edge, size.longest_edge) elif size.height and size.width: size = (size.height, size.width) else: raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") if patch_size.height and patch_size.width: patch_size = (patch_size.height, patch_size.width) else: raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") output_size = get_resize_output_image_size(image, size=size, patch_size=patch_size) return F.resize(image, size=output_size, interpolation=interpolation, **kwargs) # Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching def _pad_for_batching( self, pixel_values: list[torch.Tensor], image_sizes: list[list[int]], ): """ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. Args: pixel_values (`list[torch.Tensor]`): An array of pixel values of each images of shape (`batch_size`, `channels`, `height`, `width`) image_sizes (`list[list[int]]`): A list of sizes for each image in `pixel_values` in (height, width) format. Returns: list[`torch.Tensor`]: The padded images. """ max_shape = (max(size[0] for size in image_sizes), max(size[1] for size in image_sizes)) pixel_values = [ torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0])) for image, size in zip(pixel_values, image_sizes) ] return torch.stack(pixel_values) def _preprocess( self, images: list["torch.Tensor"], do_resize: bool, size: SizeDict, patch_size: dict[str, int], interpolation: Optional["F.InterpolationMode"], do_center_crop: bool, crop_size: dict[str, int], do_rescale: bool, rescale_factor: float, do_normalize: bool, image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], **kwargs, ) -> BatchFeature: patch_size = get_size_dict(patch_size, default_to_square=True) patch_size = SizeDict(**patch_size) # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} for shape, stacked_images in grouped_images.items(): if do_resize: stacked_images = self.resize( image=stacked_images, size=size, patch_size=patch_size, interpolation=interpolation ) resized_images_grouped[shape] = stacked_images resized_images = reorder_images(resized_images_grouped, grouped_images_index) # Group images by size for further processing # Needed in case do_resize is False, or resize returns images with different sizes grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))] processed_images_grouped = {} for shape, stacked_images in grouped_images.items(): if do_center_crop: stacked_images = self.center_crop(stacked_images, crop_size) # Fused rescale and normalize stacked_images = self.rescale_and_normalize( stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std ) processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) padded_images = self._pad_for_batching( pixel_values=processed_images, image_sizes=batch_image_sizes, ) return BatchFeature( 
data={"pixel_values": padded_images, "image_sizes": batch_image_sizes}, tensor_type=return_tensors ) __all__ = ["PixtralImageProcessorFast"]