# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/florence2/modular_florence2.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_florence2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 Microsoft and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import Any, Callable, Optional, Union import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, logging, ) from ..auto import AutoModel from .configuration_florence2 import Florence2Config, Florence2VisionConfig if is_torch_available(): import torch logger = logging.get_logger(__name__) def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize output = input.div(keep_prob) * random_tensor return output class Florence2VisionDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: return f"p={self.drop_prob}" class Florence2VisionLearnedAbsolutePositionEmbedding2D(nn.Module): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, config: Florence2Config): super().__init__() num_pos = config.vision_config.max_position_embeddings embedding_dim = config.vision_config.embed_dim[-1] self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) def forward(self, pixel_values, pixel_mask=None): height, width = pixel_values.shape[-2:] width_values = torch.arange(width, device=pixel_values.device) height_values = torch.arange(height, device=pixel_values.device) x_emb = self.column_embeddings(width_values) y_emb = self.row_embeddings(height_values) pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) return pos class Florence2VisionPositionalEmbeddingCosine1D(nn.Module): """ This module generates 1D cosine positional embeddings using precomputed sinusoidal functions. """ def __init__(self, config: Florence2Config): super().__init__() self.embed_dim = config.vision_config.embed_dim[-1] self.max_seq_len = config.vision_config.max_temporal_embeddings pos_idx_to_embed = torch.empty((self.max_seq_len, self.embed_dim)) sine, cosine = self.get_sinusoid_embeddings( max_positions=self.max_seq_len, embed_dim=self.embed_dim, ) pos_idx_to_embed[:, 0::2] = sine pos_idx_to_embed[:, 1::2] = cosine # Save the positional embeddings in a constant buffer. self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) @staticmethod def get_sinusoid_embeddings(max_positions: int, embed_dim: int): half_dim = embed_dim // 2 emb = math.log(10000) / half_dim emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) return torch.sin(emb), torch.cos(emb) def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: len_seq = seq_embeds.size(1) if len_seq > self.max_seq_len: raise ValueError(f"Maximum sequence length {self.max_seq_len}, got {len_seq}") pos_embeds = self.pos_idx_to_embed[0:len_seq, :] return pos_embeds class Florence2VisionMLP(nn.Module): def __init__(self, config: Florence2VisionConfig, stage_idx: int): super().__init__() self.config = config self.activation_fn = ACT2FN[config.activation_function] self.fc1 = nn.Linear(config.embed_dim[stage_idx], int(config.embed_dim[stage_idx] * config.mlp_ratio)) self.fc2 = nn.Linear(int(config.embed_dim[stage_idx] * config.mlp_ratio), config.embed_dim[stage_idx]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class Florence2VisionConvEmbed(nn.Module): """Image to Patch Embedding""" def __init__(self, config: Florence2VisionConfig, stage_idx: int): super().__init__() self.config = config self.stage_idx = stage_idx self.patch_size = config.patch_size[stage_idx] self.in_channels = config.in_channels if stage_idx == 0 else config.embed_dim[stage_idx - 1] self.embed_dim = config.embed_dim[stage_idx] self.stride = config.patch_stride[stage_idx] self.padding = config.patch_padding[stage_idx] self.pre_norm = config.patch_prenorm[stage_idx] self.conv = nn.Conv2d( self.in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.stride, padding=self.padding, ) dim_norm = self.in_channels if self.pre_norm else self.embed_dim self.norm = nn.LayerNorm(dim_norm) def forward(self, hidden_states: torch.Tensor): if self.norm and self.pre_norm: hidden_states = hidden_states.permute(0, 2, 3, 1) hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 1, 2) hidden_states = self.conv(hidden_states) if self.norm and not self.pre_norm: hidden_states = hidden_states.permute(0, 2, 3, 1) hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 1, 2) return hidden_states def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, head_mask: Optional[torch.Tensor] = None, **kwargs, ): if scaling is None: scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) if head_mask is not None: attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Florence2VisionChannelAttention(nn.Module): def __init__(self, config: Florence2VisionConfig, stage_idx: int): super().__init__() self.config = config self.dim = config.embed_dim[stage_idx] self.groups = config.num_groups[stage_idx] self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) self.proj = nn.Linear(self.dim, self.dim) self.is_causal = False def forward(self, hidden_states: torch.Tensor): batch_size, num_tokens, hidden_size = hidden_states.shape # Reshape for grouped channel attention qkv = self.qkv(hidden_states).reshape(batch_size, num_tokens, 3, self.groups, hidden_size // self.groups) qkv = qkv.permute(2, 0, 3, 4, 1) query, key, value = qkv.unbind(0) scale = num_tokens**-0.5 # Channel-to-channel attention within groups: attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] hidden_states, _ = attention_interface( self, query, key, value, attention_mask=None, scaling=scale, ) hidden_states = hidden_states.permute(0, 3, 2, 1) hidden_states = hidden_states.reshape(batch_size, num_tokens, hidden_size) # Final projection hidden_states = self.proj(hidden_states) return hidden_states class Florence2VisionChannelBlock(nn.Module): def __init__( self, config: Florence2VisionConfig, stage_idx: int, drop_path_rate: float, ): super().__init__() self.config = config dim_in = config.embed_dim[stage_idx] self.conv1 = nn.Conv2d( dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in, ) self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) self.channel_attn = Florence2VisionChannelAttention(config=config, stage_idx=stage_idx) self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.conv2 = nn.Conv2d( dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in, ) self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() def forward(self, hidden_states: torch.Tensor): batch_size, embed_dim, height, width = hidden_states.shape # First channel block: Depthwise Conv + Channel Attention hidden_states = self.conv1(hidden_states) + hidden_states hidden_states = hidden_states.flatten(2).transpose(1, 2) residual = hidden_states # Channel group attention self-attention mechanism hidden_states = self.norm1(hidden_states) hidden_states = self.channel_attn(hidden_states) hidden_states = residual + self.drop_path1(hidden_states) hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) # Second channel block: Depthwise Conv + FFN hidden_states = self.conv2(hidden_states) + hidden_states hidden_states = hidden_states.flatten(2).transpose(1, 2) residual = hidden_states # FFN hidden_states = self.norm2(hidden_states) hidden_states = self.ffn(hidden_states) hidden_states = residual + self.drop_path2(hidden_states) hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) return hidden_states class Florence2VisionWindowAttention(nn.Module): def __init__(self, config: Florence2VisionConfig, stage_idx: int): super().__init__() self.config = config self.dim = config.embed_dim[stage_idx] self.window_size = config.window_size self.num_heads = config.num_heads[stage_idx] head_dim = self.dim // self.num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(self.dim, self.dim * 3, bias=config.qkv_bias) self.proj = nn.Linear(self.dim, self.dim) self.is_causal = False def forward(self, hidden_states: torch.Tensor): batch_size, height, width, embed_dim = hidden_states.shape # Pad the input if necessary pad_left = pad_top = 0 pad_right = (self.window_size - width % self.window_size) % self.window_size pad_bottom = (self.window_size - height % self.window_size) % self.window_size hidden_states = F.pad(hidden_states, (0, 0, pad_left, pad_right, pad_top, pad_bottom)) _, padded_height, padded_width, _ = hidden_states.shape # Partition input into non-overlapping windows (for local spatial attention in DaViT) hidden_states = hidden_states.view( batch_size, padded_height // self.window_size, self.window_size, padded_width // self.window_size, self.window_size, embed_dim, ) windowed_hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size * self.window_size, embed_dim) # Generate Q, K, V for each window num_windows_per_batch, num_tokens_per_window, embed_dim = windowed_hidden_states.shape qkv = self.qkv(windowed_hidden_states).reshape( num_windows_per_batch, num_tokens_per_window, 3, self.num_heads, embed_dim // self.num_heads ) qkv = qkv.permute(2, 0, 3, 1, 4) query, key, value = qkv.unbind(0) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] windowed_hidden_states, _ = attention_interface( self, query, key, value, attention_mask=None, scaling=self.scale, ) windowed_hidden_states = windowed_hidden_states.view(num_windows_per_batch, num_tokens_per_window, embed_dim) windowed_hidden_states = self.proj(windowed_hidden_states) # Merge windows back to original spatial layout windowed_hidden_states = windowed_hidden_states.view(-1, self.window_size, self.window_size, embed_dim) hidden_states = windowed_hidden_states.view( -1, padded_height // self.window_size, padded_width // self.window_size, self.window_size, self.window_size, embed_dim, ) hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous() hidden_states = hidden_states.view(-1, padded_height, padded_width, embed_dim) hidden_states = hidden_states[:, :height, :width, :].contiguous() hidden_states = hidden_states.view(batch_size, height * width, embed_dim) return hidden_states class Florence2VisionSpatialBlock(nn.Module): def __init__( self, config: Florence2VisionConfig, stage_idx: int, drop_path_rate: float, ): super().__init__() self.conv1 = nn.Conv2d( config.embed_dim[stage_idx], config.embed_dim[stage_idx], kernel_size=3, padding=1, groups=config.embed_dim[stage_idx], ) self.norm1 = nn.LayerNorm(config.embed_dim[stage_idx]) self.window_attn = Florence2VisionWindowAttention(config=config, stage_idx=stage_idx) self.drop_path1 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.conv2 = nn.Conv2d( config.embed_dim[stage_idx], config.embed_dim[stage_idx], kernel_size=3, padding=1, groups=config.embed_dim[stage_idx], ) self.norm2 = nn.LayerNorm(config.embed_dim[stage_idx]) self.ffn = Florence2VisionMLP(config=config, stage_idx=stage_idx) self.drop_path2 = Florence2VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() def forward(self, hidden_states: torch.Tensor): batch_size, embed_dim, height, width = hidden_states.shape # First spatial mixing block: Conv + Window Attention hidden_states = self.conv1(hidden_states) + hidden_states hidden_states = hidden_states.flatten(2).transpose(1, 2) residual = hidden_states # Spatial Window-based self-attention mechanism hidden_states = self.norm1(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, embed_dim) hidden_states = self.window_attn(hidden_states) hidden_states = residual + self.drop_path1(hidden_states) hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) # Second spatial mixing block: Conv + FFN hidden_states = self.conv2(hidden_states) + hidden_states hidden_states = hidden_states.flatten(2).transpose(1, 2) residual = hidden_states # FFN hidden_states = self.norm2(hidden_states) hidden_states = self.ffn(hidden_states) hidden_states = residual + self.drop_path2(hidden_states) hidden_states = hidden_states.transpose(1, 2).view(batch_size, embed_dim, height, width) return hidden_states class Florence2VisionBlock(nn.Module): def __init__( self, config: Florence2VisionConfig, stage_idx: int, spatial_drop_path_rate: float, channel_drop_path_rate: float, ): super().__init__() self.spatial_block = Florence2VisionSpatialBlock( config=config, stage_idx=stage_idx, drop_path_rate=spatial_drop_path_rate, ) self.channel_block = Florence2VisionChannelBlock( config=config, stage_idx=stage_idx, drop_path_rate=channel_drop_path_rate, ) def forward(self, hidden_states: torch.Tensor): hidden_states = self.spatial_block(hidden_states) hidden_states = self.channel_block(hidden_states) return hidden_states @auto_docstring class Florence2VisionPreTrainedModel(PreTrainedModel): config_class = Florence2VisionConfig main_input_name = "pixel_values" _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True _can_compile_fullgraph = True @auto_docstring class Florence2VisionBackbone(Florence2VisionPreTrainedModel): def __init__(self, config: Florence2VisionConfig): super().__init__(config) self.config = config self.embed_dim = config.embed_dim self.num_heads = config.num_heads self.num_groups = config.num_groups self.num_stages = len(self.embed_dim) if not (self.num_stages == len(self.num_heads) == len(self.num_groups)): raise ValueError( f"Expected self.num_stages ({self.num_stages}) == " f"len(self.num_heads) ({len(self.num_heads)}) == " f"len(self.num_groups) ({len(self.num_groups)})" ) dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths) * 2, device="cpu")] depth_offset = 0 convs = [] blocks = [] for stage_idx in range(self.num_stages): conv_embed = Florence2VisionConvEmbed( config=config, stage_idx=stage_idx, ) convs.append(conv_embed) block = nn.ModuleList( Florence2VisionBlock( config=config, stage_idx=stage_idx, spatial_drop_path_rate=dpr[depth_offset + block_idx * 2], channel_drop_path_rate=dpr[depth_offset + block_idx * 2 + 1], ) for block_idx in range(config.depths[stage_idx]) ) blocks.append(block) depth_offset += config.depths[stage_idx] * 2 self.convs = nn.ModuleList(convs) self.blocks = nn.ModuleList(blocks) # Initialize weights and apply final processing self.post_init() def forward(self, hidden_states: torch.Tensor): for conv, block in zip(self.convs, self.blocks): hidden_states = conv(hidden_states) for layer in block: hidden_states = layer(hidden_states) return hidden_states class Florence2MultiModalProjector(nn.Module): def __init__(self, config: Florence2Config): super().__init__() self.vision_embedding_dim = config.vision_config.embed_dim[-1] self.vision_projection_dim = config.vision_config.projection_dim self.image_projection = nn.Linear(self.vision_embedding_dim, self.vision_projection_dim, bias=False) self.image_proj_norm = nn.LayerNorm(self.vision_projection_dim) self.image_position_embed = Florence2VisionLearnedAbsolutePositionEmbedding2D(config=config) self.visual_temporal_embed = Florence2VisionPositionalEmbeddingCosine1D(config=config) def forward(self, image_features): position_features = image_features + self.image_position_embed(image_features) position_features = position_features.flatten(2).transpose(1, 2) temporal_features = self.visual_temporal_embed(position_features[:, :1, :]) temporal_features = temporal_features.unsqueeze(1) visual_token_features = position_features + temporal_features visual_token_features = visual_token_features.unsqueeze(1) spatial_image_features = visual_token_features.mean(dim=2) temporal_image_features = visual_token_features.mean(dim=1) image_features = torch.cat([spatial_image_features, temporal_image_features], dim=1) image_features = self.image_projection(image_features) image_features = self.image_proj_norm(image_features) return image_features @dataclass @auto_docstring( custom_intro=""" Base class for Florence-2 base model's outputs that also contains : pre-computed hidden states that can speed up sequential decoding. """ ) class Florence2Seq2SeqModelOutput(Seq2SeqModelOutput): r""" image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for Florence-2 model's outputs that also contains : pre-computed hidden states that can speed up sequential decoding. """ ) class Florence2Seq2SeqLMOutput(Seq2SeqLMOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_image_tokens, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None @auto_docstring class Florence2PreTrainedModel(PreTrainedModel): config: Florence2Config base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = False config_class = Florence2Config @auto_docstring( custom_intro=""" Florence-2 is a vision model for captioning, detection, and segmentation. """ ) class Florence2Model(Florence2PreTrainedModel): _checkpoint_conversion_mapping = {} _tied_weights_keys = [ "language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", ] def __init__(self, config: Florence2Config): super().__init__(config) self.vision_tower = Florence2VisionBackbone(config=config.vision_config) self.multi_modal_projector = Florence2MultiModalProjector(config) self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder def get_decoder(self): return self.language_model.get_decoder() def get_image_features(self, pixel_values: torch.Tensor, **kwargs): """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ image_features = self.vision_tower(pixel_values, **kwargs) image_embeds = self.multi_modal_projector(image_features) return image_embeds def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_features = image_features.shape[0] * image_features.shape[1] if inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) return special_image_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[tuple, Florence2Seq2SeqModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_features = self.get_image_features(pixel_values) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) encoder_outputs = self.language_model.encoder( attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) if decoder_input_ids is None: decoder_start_token_id = self.config.text_config.decoder_start_token_id decoder_input_ids = torch.ones((inputs_embeds.size()[0], 1), dtype=torch.long, device=inputs_embeds.device) decoder_input_ids *= decoder_start_token_id decoder_outputs = self.language_model.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, return_dict=True, ) return Florence2Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) def get_encoder(self): return self.language_model.get_encoder() def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids @auto_docstring( custom_intro=""" Florence-2 is a vision model for captioning, detection, and segmentation. """ ) class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} _tied_weights_keys = [ "model.language_model.encoder.embed_tokens.weight", "model.language_model.decoder.embed_tokens.weight", "lm_head.weight", ] def __init__(self, config: Florence2Config): super().__init__(config) self.model = Florence2Model(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_decoder(self, decoder): self.model.set_decoder(decoder) def get_decoder(self): return self.model.get_decoder() def get_image_features(self, pixel_values: torch.Tensor, **kwargs): return self.model.get_image_features(pixel_values=pixel_values, **kwargs) # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @property def vision_tower(self): return self.model.vision_tower @property def multi_modal_projector(self): return self.model.multi_modal_projector @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Florence2Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") >>> prompt = "" >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(text=prompt, images=image, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(**inputs, max_length=100) >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "A green car parked in front of a yellow building." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.text_config.pad_token_id, self.config.text_config.decoder_start_token_id ) outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, # **kwargs, ## TODO: add back when Bart attention is refactored and takes kwargs ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return Florence2Seq2SeqLMOutput( loss=loss, logits=logits, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values return model_inputs def get_encoder(self): return self.model.get_encoder() def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): return self.model.get_placeholder_mask( input_ids=input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str], generation_config, ) -> dict[str, Any]: # override to handle merging image and text embeddings before passing to language encoder inputs_embeds = model_kwargs.pop("inputs_embeds", None) pixel_values = model_kwargs.pop("pixel_values", None) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(inputs_tensor) if pixel_values is not None: image_features = self.get_image_features(pixel_values) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask( inputs_tensor, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) model_kwargs["inputs_embeds"] = inputs_embeds model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation( None, model_kwargs, model_input_name, generation_config ) model_kwargs.pop("inputs_embeds", None) return model_kwargs __all__ = [ "Florence2Model", "Florence2ForConditionalGeneration", "Florence2PreTrainedModel", "Florence2VisionBackbone", "Florence2VisionPreTrainedModel", ]