# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_llava_onevision.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import Optional, Union import numpy as np import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_llava_onevision import LlavaOnevisionConfig @dataclass @auto_docstring( custom_intro=""" Base class for Llava outputs, with hidden states and attentions. """ ) class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. video_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None video_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for LlavaOnevision causal language model (or autoregressive) outputs. """ ) class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. video_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None video_hidden_states: Optional[torch.FloatTensor] = None @auto_docstring class LlavaOnevisionPreTrainedModel(PreTrainedModel): config: LlavaOnevisionConfig base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, LlavaOnevisionModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) module.image_newline.data.normal_(mean=0.0, std=embed_std) class LlavaOnevisionMultiModalProjector(nn.Module): def __init__(self, config: LlavaOnevisionConfig): super().__init__() # We have hidden_size * the number of vision feature layers num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) self.linear_1 = nn.Linear( config.vision_config.hidden_size * num_feature_layers, config.text_config.hidden_size, bias=config.multimodal_projector_bias, ) self.act = ACT2FN[config.projector_hidden_act] self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias ) def forward(self, image_features): hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. Args: image_size (`tuple`): The size of the input image in the format (width, height). grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. patch_size (`int`): The size of each image patch. Returns: tuple: The shape of the image patch grid in the format (width, height). """ if not isinstance(grid_pinpoints, list): raise TypeError("grid_pinpoints should be a list of tuples or lists") # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate if not isinstance(image_size, (list, tuple)): if not isinstance(image_size, (torch.Tensor, np.ndarray)): raise TypeError( f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" ) image_size = image_size.tolist() height, width = select_best_resolution(image_size, grid_pinpoints) return height // patch_size, width // patch_size def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): """ Calculate the number of patches after the preprocessing for images of any resolution. Args: image_size (`torch.LongTensor` or `np.ndarray` or `tuple[int, int]`): The size of the input image in the format (height, width). ? grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. patch_size (`int`): The size of each image patch. Returns: int: the number of patches """ if not isinstance(grid_pinpoints, list): raise TypeError("grid_pinpoints should be a list of tuples or lists") # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate if not isinstance(image_size, (list, tuple)): if not isinstance(image_size, (torch.Tensor, np.ndarray)): raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") image_size = image_size.tolist() best_resolution = select_best_resolution(image_size, grid_pinpoints) height, width = best_resolution num_patches = 0 # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 for i in range(0, height, patch_size): for j in range(0, width, patch_size): num_patches += 1 # add the base patch num_patches += 1 return num_patches def unpad_image(tensor, original_size): """ Unpads a PyTorch tensor of a padded and resized image. Args: tensor (`torch.Tensor`): The image tensor, assumed to be of shape (num_channels, height, width). original_size (`tuple`): The original size of the image (height, width). Returns: `torch.Tensor`: The unpadded image tensor. """ if not isinstance(original_size, (list, tuple)): if not isinstance(original_size, (torch.Tensor, np.ndarray)): raise TypeError( f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" ) original_size = original_size.tolist() original_height, original_width = original_size current_height, current_width = tensor.shape[1:] original_aspect_ratio = original_width / original_height current_aspect_ratio = current_width / current_height if original_aspect_ratio > current_aspect_ratio: scale_factor = current_width / original_width new_height = int(round(original_height * scale_factor, 7)) padding = (current_height - new_height) // 2 unpadded_tensor = tensor[:, padding : current_height - padding, :] else: scale_factor = current_height / original_height new_width = int(round(original_width * scale_factor, 7)) padding = (current_width - new_width) // 2 unpadded_tensor = tensor[:, :, padding : current_width - padding] return unpadded_tensor @auto_docstring( custom_intro=""" The Llava-Next model which consists of a vision backbone and a language model without language modeling head. """ ) class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): _checkpoint_conversion_mapping = {"language_model.model": "language_model"} def __init__(self, config): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) embed_std = 1 / math.sqrt(config.text_config.hidden_size) self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder def get_decoder(self): return self.language_model def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. Args: image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) List of image feature tensor, each contains all the visual feature of all patches. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). image_newline (`torch.Tensor` of shape `(embed_dim)`) New line embedding vector. vision_aspect_ratio (`str`, *optional*, "anyres_max_9"): Aspect ratio used when processong image features. The default value is "anyres_max_9". Returns: image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) feature_lens (`list[int]`) token length of each image in image_features """ new_image_features = [] feature_lens = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] image_feature = image_feature[1:] height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size if height * width != base_image_feature.shape[0]: raise ValueError("The number of patches is not consistent with the image size.") num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = unpad_image(image_feature, image_sizes[image_idx]) max_num_patches = int(vision_aspect_ratio.strip("anyres_max_")) channels, curr_height, curr_width = image_feature.shape ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2)) if ratio > 1.1: image_feature = image_feature[None] image_feature = nn.functional.interpolate( image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear" )[0] if image_newline is not None: image_feature = torch.cat( ( image_feature, image_newline[:, None, None] .expand(*image_feature.shape[:-1], 1) .to(image_feature.device, image_feature.dtype), ), dim=-1, ) image_feature = image_feature.flatten(1, 2).transpose(0, 1) image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] if image_newline is not None: image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) return new_image_features, feature_lens def get_image_features( self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, batch_num_images: Optional[torch.LongTensor] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). vision_feature_layer (`Union[int, list[int]]`): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. vision_feature_select_strategy (`str`): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` batch_num_images (`torch.LongTensor`, *optional*): Number of images in each sample. Returns: image_features (list[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) vision_aspect_ratio = ( vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio ) # ! infer image_num_patches from image_sizes if batch_num_images is None: # treat this as a single-image case for backward compatibility need_patching = [True] * len(image_sizes) else: need_patching = [n == 1 for n in batch_num_images for _ in range(n)] image_num_patches = [ image_size_to_num_patches( image_size=imsize, grid_pinpoints=self.config.image_grid_pinpoints, patch_size=self.config.vision_config.image_size, ) if should_patch else 1 for imsize, should_patch in zip(image_sizes, need_patching) ] if pixel_values.dim() == 5: # stacked if input is (batch_size, num_patches, num_channels, height, width) _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] pixel_values = torch.cat(_pixel_values_list, dim=0) elif pixel_values.dim() != 4: # otherwise has to be stacked from list of (num_patches, num_channels, height, width) raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") image_features = self.vision_tower(pixel_values, output_hidden_states=True) # If we have one vision feature layer, return the corresponding hidden states, # otherwise, select the hidden states of each feature layer and concatenate them if isinstance(vision_feature_layer, int): selected_image_feature = image_features.hidden_states[vision_feature_layer] else: hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] selected_image_feature = torch.cat(hs_pool, dim=-1) if vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) image_features, feature_lens = self.pack_image_features( image_features, image_sizes, image_newline=self.image_newline, vision_aspect_ratio=vision_aspect_ratio, ) return image_features def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: Optional[torch.FloatTensor] = None, video_features: Optional[torch.FloatTensor] = None, ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) special_video_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_video_mask = special_video_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_video_mask = input_ids == self.config.video_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" ) n_video_tokens = special_video_mask.sum() special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): raise ValueError( f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" ) return special_image_mask, special_video_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, batch_num_images: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, LlavaOnevisionModelOutputWithPast]: r""" image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*): The sizes of the videos in the batch, being (height, width) for each frame in the video. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". batch_num_images (`torch.LongTensor`, *optional*): Number of images in each sample. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) vision_aspect_ratio = ( vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) # Images are processed with Anyres if pixel_values is not None: image_features = self.get_image_features( pixel_values, image_sizes, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, batch_num_images=batch_num_images, ) image_features = torch.cat(image_features, dim=0) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Video are simply embedded and further pooled to decrease seq len if pixel_values_videos is not None: video_features = self.get_video_features( pixel_values_videos, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) image_newline = ( self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device) ) video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1).to(inputs_embeds.device, inputs_embeds.dtype) _, special_video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_features ) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **kwargs, ) return LlavaOnevisionModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, video_hidden_states=video_features if pixel_values_videos is not None else None, ) def get_video_features( self, pixel_values: torch.FloatTensor, vision_feature_layer: Union[int, list[int]], vision_feature_select_strategy: str, ): """ Obtains video last hidden states from the vision tower, apply multimodal projection and pooling. Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) The tensors corresponding to the input video. vision_feature_layer (`Union[int, list[int]], *optional*, defaults to -2`): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. vision_feature_select_strategy (`str`): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: video_features (list[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches and are of shape `(num_videos, video_length, embed_dim)`). """ batch_size, frames, channels, height, width = pixel_values.shape pixel_values = pixel_values.view(batch_size * frames, channels, height, width) video_features = self.vision_tower(pixel_values, output_hidden_states=True) # If we have one vision feature layer, return the corresponding hidden states, # otherwise, select the hidden states of each feature layer and concatenate them if isinstance(vision_feature_layer, int): selected_video_feature = video_features.hidden_states[vision_feature_layer] else: hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] selected_video_feature = torch.cat(hs_pool, dim=-1) if vision_feature_select_strategy == "default": selected_video_feature = selected_video_feature[:, 1:] video_features = self.multi_modal_projector(selected_video_feature) video_features = self.apply_pooling(video_features) video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) return video_features def apply_pooling(self, image_features): height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size batch_frames, seq_len, dim = image_features.shape image_features = image_features.view(batch_frames, height, width, -1) image_features = image_features.permute(0, 3, 1, 2).contiguous() height, width = image_features.shape[2:] scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)] image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear") image_features = image_features.permute(0, 2, 3, 1) image_features = image_features.view(batch_frames, -1, dim) return image_features @auto_docstring( custom_intro=""" The LLAVA-NeXT model which consists of a vision backbone and a language model. """ ) class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { "^language_model.model": "model.language_model", "^vision_tower": "model.vision_tower", "^multi_modal_projector": "model.multi_modal_projector", "^image_newline": "model.image_newline", "^language_model.lm_head": "lm_head", } _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: LlavaOnevisionConfig): super().__init__(config) self.model = LlavaOnevisionModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_decoder(self, decoder): self.model.set_decoder(decoder) def get_decoder(self): return self.model.get_decoder() def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): return self.model.pack_image_features( image_features=image_features, image_sizes=image_sizes, vision_feature_select_strategy=vision_feature_select_strategy, image_newline=image_newline, ) def get_image_features( self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, ): return self.model.get_image_features( pixel_values=pixel_values, image_sizes=image_sizes, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @property def vision_tower(self): return self.model.vision_tower @property def multi_modal_projector(self): return self.model.multi_modal_projector @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, vision_aspect_ratio: Optional[str] = None, batch_num_images: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*): The sizes of the videos in the batch, being (height, width) for each frame in the video. vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`): Aspect ratio used when processong image features. The default value is "anyres_max_9". batch_num_images (`torch.LongTensor`, *optional*): Number of images in each sample. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from PIL import Image >>> import requests >>> import torch >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", dtype="float16", device_map="cuda:0") >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") >>> conversation = [ ... { ... "role": "user", ... "content": [ ... {"type": "text", "text": "What is shown in this image?"}, ... {"type": "image"}, ... ], ... }, ... ] >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> raw_image = Image.open(requests.get(image_file, stream=True).raw) >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16) >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False) >>> processor.batch_decode(output, skip_special_tokens=True)[0] "user\n\nWhat is shown in this image?\nassistant\ncat" ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) vision_aspect_ratio = ( vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio ) outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_sizes=image_sizes, image_sizes_videos=image_sizes_videos, vision_aspect_ratio=vision_aspect_ratio, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, batch_num_images=batch_num_images, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return LlavaOnevisionCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, video_hidden_states=outputs.video_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, image_sizes=None, pixel_values_videos=None, image_sizes_videos=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values model_inputs["image_sizes"] = image_sizes model_inputs["pixel_values_videos"] = pixel_values_videos model_inputs["image_sizes_videos"] = image_sizes_videos return model_inputs @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask def get_video_features( self, pixel_values: torch.FloatTensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, ): return self.model.get_video_features( pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) __all__ = ["LlavaOnevisionModel", "LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]