# coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch VideoLlava model.""" from dataclasses import dataclass from typing import Optional, Union import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_video_llava import VideoLlavaConfig logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Base class for VideoLlava base model outputs. """ ) class VideoLlavaModelOutputWithPast(ModelOutput): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. video_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None video_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for VideoLlava causal language model (or autoregressive) outputs. """ ) class VideoLlavaCausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. video_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None video_hidden_states: Optional[torch.FloatTensor] = None # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava class VideoLlavaMultiModalProjector(nn.Module): def __init__(self, config: VideoLlavaConfig): super().__init__() # We have hidden_size * the number of vision feature layers num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) self.linear_1 = nn.Linear( config.vision_config.hidden_size * num_feature_layers, config.text_config.hidden_size, bias=config.multimodal_projector_bias, ) self.act = ACT2FN[config.projector_hidden_act] self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias ) def forward(self, image_features): hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states @auto_docstring class VideoLlavaPreTrainedModel(PreTrainedModel): config: VideoLlavaConfig base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = True _supports_attention_backend = True def _init_weights(self, module): std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) if hasattr(module, "class_embedding"): module.class_embedding.data.normal_(mean=0.0, std=std) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() @auto_docstring( custom_intro=""" The VideoLlava model which consists of a vision backbone and a language model without language modeling head. """, ) class VideoLlavaModel(VideoLlavaPreTrainedModel): _checkpoint_conversion_mapping = {"language_model.model": "language_model"} def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.video_tower = AutoModel.from_config(config.vision_config) self.image_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = VideoLlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder def get_decoder(self): return self.language_model def get_image_features( self, pixel_values_images: torch.FloatTensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) if vision_feature_select_strategy not in ["default", "full"]: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) # If we have one vision feature layer, return the corresponding hidden states, # otherwise, select the hidden states of each feature layer and concatenate them if isinstance(vision_feature_layer, int): image_outputs = image_outputs.hidden_states[vision_feature_layer] if vision_feature_select_strategy == "default": image_outputs = image_outputs[:, 1:] else: hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] # For default; crop CLS from each hidden state in the hidden state pool if vision_feature_select_strategy == "default": hs_pool = [hs[:, 1:] for hs in hs_pool] image_outputs = torch.cat(hs_pool, dim=-1) image_features = self.multi_modal_projector(image_outputs) return image_features def get_video_features( self, pixel_values_videos: torch.FloatTensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, ): """ Obtains video last hidden states from the vision tower and apply multimodal projection. Args: pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) The tensors corresponding to the input videos. vision_feature_layer (`Union[int, list[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Returns: video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`). frames (`int`): Number of frames the videos have. """ vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) video_outputs = self.video_tower(pixel_values, output_hidden_states=True) # If we have one vision feature layer, return the corresponding hidden states, # otherwise, select the hidden states of each feature layer and concatenate them if isinstance(vision_feature_layer, int): video_features = video_outputs.hidden_states[vision_feature_layer] else: hs_pool = [video_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] video_features = torch.cat(hs_pool, dim=-1) video_features = self.multi_modal_projector(video_features) return video_features, num_frames def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: Optional[torch.FloatTensor] = None, video_features: Optional[torch.FloatTensor] = None, ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) special_video_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_video_mask = special_video_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_video_mask = input_ids == self.config.video_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" ) n_video_tokens = special_video_mask.sum() special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): raise ValueError( f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0] * video_features.shape[1]}" ) return special_image_mask, special_video_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values_images: Optional[torch.FloatTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, VideoLlavaModelOutputWithPast]: r""" pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses [`VideoLlavaImageProcessor`] for processing images). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values_images is not None: image_features = self.get_image_features( pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if pixel_values_videos is not None: video_features, num_frames = self.get_video_features( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) _, special_video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_features ) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **kwargs, ) return VideoLlavaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values_images is not None else None, video_hidden_states=video_features if pixel_values_videos is not None else None, ) @auto_docstring( custom_intro=""" The VideoLlava model which consists of a vision backbone and a language model. """ ) class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { "^language_model.model": "model.language_model", "^image_tower": "model.image_tower", "^video_tower": "model.video_tower", "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.model = VideoLlavaModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_decoder(self, decoder): self.model.set_decoder(decoder) def get_decoder(self): return self.model.get_decoder() def get_image_features( self, pixel_values_images: torch.FloatTensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, ): return self.model.get_image_features( pixel_values_images=pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @property def video_tower(self): return self.model.video_tower @property def image_tower(self): return self.model.image_tower @property def multi_modal_projector(self): return self.model.multi_modal_projector @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values_images: Optional[torch.FloatTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, VideoLlavaCausalLMOutputWithPast]: r""" pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses [`VideoLlavaImageProcessor`] for processing images). labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from PIL import Image >>> import requests >>> import numpy as np >>> import av >>> from huggingface_hub import hf_hub_download >>> from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration >>> def read_video_pyav(container, indices): ... ''' ... Decode the video with PyAV decoder. ... Args: ... container (`av.container.input.InputContainer`): PyAV container. ... indices (`list[int]`): List of frame indices to decode. ... Returns: ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). ... ''' ... frames = [] ... container.seek(0) ... start_index = indices[0] ... end_index = indices[-1] ... for i, frame in enumerate(container.decode(video=0)): ... if i > end_index: ... break ... if i >= start_index and i in indices: ... frames.append(frame) ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") >>> processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") >>> prompt = "USER: