# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_deepseek_vl.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Union import torch import torch.nn as nn from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_deepseek_vl import DeepseekVLConfig @dataclass @auto_docstring( custom_intro=""" Base class for DeepseekVL model's outputs that may also contain a past key/values (to speed up sequential decoding). """ ) class DeepseekVLBaseModelOutputWithPast(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver """ last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Base class for DeepseekVL causal language model (or autoregressive) outputs. """ ) class DeepseekVLCausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[tuple[torch.FloatTensor]] = None class DeepseekVLAligner(nn.Module): def __init__(self, config): super().__init__() self.config = config in_features = config.vision_config.hidden_size out_features = config.text_config.hidden_size self.linear1 = nn.Linear(in_features, out_features) self.activation = nn.GELU() self.linear2 = nn.Linear(out_features, out_features) def forward(self, vision_encodings: torch.Tensor) -> torch.Tensor: x = self.linear1(vision_encodings) x = self.activation(x) x = self.linear2(x) return x @auto_docstring class DeepseekVLPreTrainedModel(PreTrainedModel): config: DeepseekVLConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = True _supports_param_buffer_assignment = False def _init_weights(self, module): """Initialize the weights""" # Required only for Linear layer in DeepseekVLAligner if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: module.bias.data.zero_() @auto_docstring class DeepseekVLModel(DeepseekVLPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.vision_model = AutoModel.from_config(config.vision_config) self.aligner = DeepseekVLAligner(config) self.language_model = AutoModel.from_config(config=config.text_config) self.gradient_checkpointing = False # Initialize weights and apply final processing. self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_image_features(self, pixel_values): image_embeds = self.vision_model(pixel_values) image_embeds = self.aligner(image_embeds.last_hidden_state) return image_embeds def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if inputs_embeds[special_image_mask].numel() != image_features.numel(): n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) return special_image_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_attention_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) lm_output = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) return DeepseekVLBaseModelOutputWithPast( last_hidden_state=lm_output.last_hidden_state, past_key_values=lm_output.past_key_values, hidden_states=lm_output.hidden_states, attentions=lm_output.attentions, image_hidden_states=image_embeds if pixel_values is not None else None, ) class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] _can_compile_fullgraph = True def __init__(self, config: DeepseekVLConfig): super().__init__(config) self.config = config self.model = DeepseekVLModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) # Initialize weights and apply final processing. self.post_init() def get_input_embeddings(self): return self.model.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.model.language_model.set_input_embeddings(value) def prepare_embeddings_for_image_generation(self) -> torch.Tensor: raise AttributeError("Not needed for DeepseekVL") @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return DeepseekVLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, pixel_values=None, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values return model_inputs __all__ = ["DeepseekVLPreTrainedModel", "DeepseekVLModel", "DeepseekVLForConditionalGeneration"]