# coding=utf-8
# Copyright 2025 the Cohere Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch AyaVision model."""

from typing import Optional, Union

import torch
from torch import nn

from transformers.models.llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaForConditionalGeneration,
    LlavaModel,
    LlavaModelOutputWithPast,
    LlavaPreTrainedModel,
    TransformersKwargs,
)

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ...utils.generic import check_model_inputs
from .configuration_aya_vision import AyaVisionConfig


logger = logging.get_logger(__name__)


class AyaVisionMultiModalProjector(nn.Module):
    """
    Project vision features to the text hidden size.

    The pipeline is: pixel shuffle (spatial downsampling into channels) -> LayerNorm -> linear ->
    SwiGLU gating -> linear to `config.text_config.hidden_size`.
    """

    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        # Falls back to the text hidden size when the config does not define an intermediate size.
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", config.text_config.hidden_size
        )
        # After pixel shuffle the channel dimension grows by `downsample_factor**2`.
        self.layernorm = nn.LayerNorm(
            config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
        )

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size * (config.downsample_factor**2),
            self.alignment_intermediate_size,
            bias=True,
        )

        self.act = ACT2FN["silu"]  # SwiGLU uses SiLU activation
        # For SwiGLU, project down to half size since we split intermediate dim
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        """
        Apply pixel shuffle, normalization and the SwiGLU MLP to `image_features`.

        Args:
            image_features (`torch.Tensor`):
                3-D tensor unpacked as `(batch_size, seq_length, feature_dim)` by `pixel_shuffle`.

        Returns:
            `torch.Tensor`: 4-D tensor whose last dimension is `config.text_config.hidden_size`; the leading
            dimensions are those produced by `pixel_shuffle`.
        """
        image_features = self.pixel_shuffle(image_features)
        image_features = self.layernorm(image_features)
        hidden_states = self.linear_1(image_features)

        # Split along last dimension and apply SwiGLU
        x, gate = hidden_states.chunk(2, dim=-1)
        hidden_states = self.act(gate) * x

        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features):  # B, S, D
        """
        Trade spatial resolution for channel depth by a factor of `self.downsample_factor` per side.

        The sequence is viewed as a square `width x height` grid (side `int(seq_length**0.5)`).
        NOTE(review): nothing here validates that `seq_length` is a perfect square or that the side is
        divisible by `downsample_factor` - callers are assumed to guarantee both.

        Returns:
            `torch.Tensor` of shape
            `(batch_size, width // factor, height // factor, feature_dim * factor**2)`.
        """
        batch_size, seq_length, feature_dim = image_features.shape
        height = width = int(seq_length**0.5)
        # (B, S, D) -> (B, W, H, D)
        image_features = image_features.reshape(image_features.shape[0], width, height, -1)
        channels = image_features.shape[-1]
        # Fold `factor` neighbours along H into channels: (B, W, H / f, D * f)
        image_features = image_features.reshape(
            batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
        )
        # (B, H / f, W, D * f)
        image_features = image_features.permute(0, 2, 1, 3)
        # Fold `factor` neighbours along W into channels: (B, H / f, W / f, D * f**2)
        image_features = image_features.reshape(
            batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
        )
        # (B, W / f, H / f, D * f**2)
        image_features = image_features.permute(0, 2, 1, 3)
        return image_features


class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
    _can_compile_fullgraph = False
    _can_record_outputs = {
        "hidden_states": "DecoderLayer",
        "attentions": "Attention",
    }


class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass


class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):
    pass


class AyaVisionModel(LlavaModel):
    # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
               The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features. Defaults to `self.config.vision_feature_layer`.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`. Defaults to
                `self.config.vision_feature_select_strategy`.
            kwargs:
                Extra keyword arguments forwarded to the vision tower; entries whose value is `None` are dropped.
        Returns:
            image_features (`torch.Tensor`): Output of `self.multi_modal_projector` applied to the selected
            vision hidden states.
        Raises:
            ValueError: If the resolved `vision_feature_select_strategy` is neither `"default"` nor `"full"`.
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if vision_feature_select_strategy not in ["default", "full"]:
            # Report the value that was actually validated (which may be a per-call override),
            # not the config default.
            raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")

        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)

        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                # Drop the first token of the sequence (the CLS token, per the comment below).
                selected_image_feature = selected_image_feature[:, 1:]
        else:
            hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            # For default; crop CLS from each hidden state in the hidden state pool
            if vision_feature_select_strategy == "default":
                hs_pool = [hs[:, 1:] for hs in hs_pool]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionModelOutputWithPast]:
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        # True when both are missing or both are given: exactly one must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Stays `None` for text-only calls so the output field is well defined in both paths.
        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            # Overwrite the embeddings at the masked positions with the projected image features.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        return AyaVisionModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )


class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
        >>> import torch

        >>> torch_device = "cuda:0"
        >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
        >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
        ...             },
        ...             {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
        ...         ],
        ...     }
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
        ... ).to(model.device)

        >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
        >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ```"""
        # Propagate the parent's result; without `return` this method would always yield `None`.
        return super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            labels=labels,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            image_sizes=image_sizes,
            **kwargs,
        )


__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]