# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_phi4_multimodal.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from typing import Callable, Optional, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_int from ...utils.deprecation import deprecate_kwarg from ...utils.generic import TransformersKwargs, check_model_inputs from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig class Phi4MultimodalVisionMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states def simple_eager_attention_forward( module: nn.Module, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Phi4MultimodalVisionAttention(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads self.scaling = self.head_dim**-0.5 self.is_causal = True self.attention_dropout = config.attention_dropout self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = Phi4MultimodalVisionAttention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Phi4MultimodalVisionMLP(config) @auto_docstring def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class Phi4MultimodalVisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Phi4MultimodalVisionEncoderLayer`]. Args: config: Phi4MultimodalVisionConfig """ def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList( [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False # Ignore copy @auto_docstring def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, **kwargs, ) return BaseModelOutput(last_hidden_state=hidden_states) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") @auto_docstring class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): config: Phi4MultimodalVisionConfig base_model_prefix = "phi4_vision" supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Phi4MultimodalVisionEncoderLayer, "attentions": Phi4MultimodalVisionAttention, } def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Phi4MultimodalVisionEmbeddings): width = ( self.config.hidden_size if isinstance(self.config, Phi4MultimodalVisionConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Phi4MultimodalVisionAttention): nn.init.normal_(module.q_proj.weight) nn.init.normal_(module.k_proj.weight) nn.init.normal_(module.v_proj.weight) nn.init.normal_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, Phi4MultimodalVisionMLP): nn.init.normal_(module.fc1.weight) nn.init.normal_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): nn.init.normal_(module.probe.data) nn.init.normal_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class Phi4MultimodalVisionEmbeddings(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.patch_size = config.patch_size self.num_patches_per_side = config.image_size // self.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] num_positions = self.position_embedding.weight.shape[0] # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: batch_size = pixel_values.size(0) patch_embeds = self.patch_embedding(pixel_values) embeddings = patch_embeds.flatten(2).transpose(1, 2) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = Phi4MultimodalVisionMLP(config) def forward(self, hidden_state, attention_mask): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention( query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask )[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): config: Phi4MultimodalVisionConfig main_input_name = "pixel_values" def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__(config) self.config = config self.embeddings = Phi4MultimodalVisionEmbeddings(config) self.encoder = Phi4MultimodalVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embedding @check_model_inputs def forward( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: batch_size = pixel_values.size(0) if patch_attention_mask is None: patch_attention_mask = torch.ones( size=( batch_size, pixel_values.size(2) // self.config.patch_size, pixel_values.size(3) // self.config.patch_size, ), dtype=torch.bool, device=pixel_values.device, ) hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): attention_mask = None else: attention_mask = ( _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if self.config._attn_implementation != "flash_attention_2" else patch_attention_mask ) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, **kwargs, ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head( hidden_state=last_hidden_state, attention_mask=patch_attention_mask, ) return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, ) class Phi4MultimodalImageEmbedding(nn.Module): """Image embedding.""" def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config self.layer_idx = config.vision_config.feature_layer self.crop_size = config.vision_config.crop_size self.image_dim_out = config.vision_config.hidden_size n_patches = config.vision_config.image_size // config.vision_config.patch_size if n_patches % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) n_patches += 1 self.num_img_tokens = (n_patches // 2) ** 2 self.drop = nn.Dropout(config.embd_pdrop) self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: img_processor_output = self.img_processor( img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True ) img_feature = img_processor_output.hidden_states[self.layer_idx] patch_feature = img_feature # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: patch_feature = self.img_processor_padding(patch_feature) patch_feature = self.image_token_compression(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) return patch_feature def forward( self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_pixel_values: torch.FloatTensor, image_sizes: Optional[torch.Tensor] = None, image_attention_mask: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) target_device = self.img_projection_up.bias.device target_dtype = self.img_projection_up.bias.dtype batch_size = image_pixel_values.shape[0] img_features = self.get_img_features( image_pixel_values.flatten(0, 1), attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device), ) base_feat_size = int(np.sqrt(img_features.shape[1])) img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) image_sizes = image_sizes.view(-1, 2) output_imgs = [] for idx in range(batch_size): height, width = image_sizes[idx] height_ratio = height // self.crop_size width_ratio = width // self.crop_size area_ratio = height_ratio * width_ratio global_img = img_features[idx, :1] global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) sub_img = img_features[idx, 1:] sub_img = sub_img[:area_ratio] sub_img = ( sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) .transpose(1, 2) .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) .contiguous() ) if image_attention_mask is not None: reshaped_image_attention_mask = ( image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) .transpose(1, 2) .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) ) useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) else: temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) # Merge global and sub output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) img_set_tensor = [] for output_img in output_imgs: output_img = output_img.to(device=target_device, dtype=target_dtype) img_feature_proj = self.img_projection_up(output_img) img_feature_proj = nn.functional.gelu(img_feature_proj) img_feature_proj = self.img_projection_down(img_feature_proj) img_set_tensor.append(img_feature_proj) merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) with torch.no_grad(): positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): image_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_img_set_tensor, accumulate=False ) image_embeds = self.drop(image_embeds) return image_embeds ########################################################## AUDIO ############################################# class Phi4MultimodalAudioMLP(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.layer_norm = nn.LayerNorm(config.hidden_size) self.act_fn = ACT2FN[config.activation] self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) up_states = self.gate_up_proj(hidden_states) up_states, gate = up_states.chunk(2, dim=-1) up_states = up_states * self.act_fn(gate) up_states = self.dropout(up_states) hidden_states = self.down_proj(up_states) out = self.dropout(hidden_states) return out class Phi4MultimodalAudioAttention(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.dropout_rate self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs, ): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output class Phi4MultimodalAudioDepthWiseSeparableConv1d(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): super().__init__() self.dw_conv = nn.Conv1d( config.hidden_size, config.hidden_size * config.depthwise_multiplier, config.kernel_size, 1, padding=padding, groups=config.hidden_size, ) self.pw_conv = nn.Conv1d( config.hidden_size * config.depthwise_multiplier, config.depthwise_separable_out_channel, 1, 1, 0 ) def forward(self, hidden_states): return self.pw_conv(self.dw_conv(hidden_states)) class Phi4MultimodalAudioGluPointWiseConv(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.output_dim = config.ext_pw_out_channel self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) self.glu_act = ACT2FN[config.conv_glu_type] self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) def forward(self, hidden_states): # we assume the input always has the #channel (#dim) in the last dimension of the # tensor, so need to switch the dimension first for 1D-Conv case hidden_states = hidden_states.permute([0, 2, 1]) hidden_states = self.ext_pw_conv_1d(hidden_states) out = hidden_states[:, 0 : self.output_dim, :] + self.b1 out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) return out.permute([0, 2, 1]) class Phi4MultimodalAudioConvModule(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.kernel_size = config.kernel_size self.layer_norm = nn.LayerNorm(config.hidden_size) self.glu = Phi4MultimodalAudioGluPointWiseConv(config) self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeparableConv1d(config, padding=config.kernel_size - 1) self.act = ACT2FN[config.conv_activation] self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states: torch.Tensor): hidden_states = self.glu(self.layer_norm(hidden_states)) hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) if self.kernel_size > 1: hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] hidden_states = self.act(hidden_states) hidden_states = self.ext_pw_conv_1d(hidden_states) out = self.dropout(hidden_states.permute([0, 2, 1])) return out class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.feed_forward_in = Phi4MultimodalAudioMLP(config) self.self_attn = Phi4MultimodalAudioAttention(config) self.conv = Phi4MultimodalAudioConvModule(config) self.feed_forward_out = Phi4MultimodalAudioMLP(config) self.layer_norm_att = nn.LayerNorm(config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, ): residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) hidden_states = self.layer_norm_att(residual) hidden_states = residual + self.self_attn(hidden_states, attention_mask) hidden_states = hidden_states + self.conv(hidden_states) hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) out = self.layer_norm(hidden_states) return out class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.subsampling_factor = config.time_reduction self.sampling_num = int(math.log2(self.subsampling_factor)) self.act_fn = ACT2FN[config.nemo_activation] conv_channels = config.nemo_conv_channels layers = [ nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), self.act_fn, ] for _ in range(self.sampling_num - 1): layers.extend( [ nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), self.act_fn, ] ) # Aggregate the layers self.conv = torch.nn.Sequential(*layers) self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): # Unsqueeze Channel Axis hidden_states = hidden_states.unsqueeze(1) hidden_states = self.conv(hidden_states) # Flatten Channel and Frequency Axes b, _, t, _ = hidden_states.size() hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) if mask is None: return hidden_states, None max_audio_length = hidden_states.shape[1] feature_lens = mask.sum(1) padding_length = torch.ceil(feature_lens / self.subsampling_factor) arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) return hidden_states, pad_mask.unsqueeze(1) class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.max_distance = config.bias_max_distance self.symmetric = config.bias_symmetric self.num_buckets = self.max_distance if not config.bias_symmetric: self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) def forward(self, x): # instantiate bias compatible with shape of x max_pos = x.size(1) context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None] memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :] relative_position = memory_position - context_position # clipping to a maximum distance using ops that play well with ONNX export relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) relative_position = relative_position.masked_fill( relative_position > self.max_distance - 1, self.max_distance - 1 ) # mapping from relative position to index in the bias parameter bias_idx = relative_position bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2 att_bias = self.bias_values(bias_idx) att_bias = att_bias.permute(2, 0, 1).unsqueeze(0) return att_bias class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.register_buffer("global_mean", torch.zeros(config.input_size)) self.register_buffer("global_invstd", torch.ones(config.input_size)) def forward(self, x): return (x - self.global_mean) * self.global_invstd @auto_docstring class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): config: Phi4MultimodalAudioConfig supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): module.b1.data.zero_() module.b2.data.zero_() def unfold_tensor(tensor, max_seq_len): """ For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: tensor: N, T, D """ _, _, D = tensor.shape tensor = tensor.transpose(-1, -2) # N x D x 1 x T => N x (D x max_seq_len) x T' tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len)) new_bsz, _, slen = tensor.shape tensor = tensor.view(new_bsz, -1, max_seq_len, slen) tensor = tensor.permute(0, 3, 2, 1) tensor = tensor.view(-1, max_seq_len, D).contiguous() return tensor def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): """ The function is very important for Transformer Transducer Streaming mode Args: xs_len (int): sequence length chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] left_window (int): how many left chunks can be seen right_window (int): how many right chunks can be seen. It is used for chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model """ chunk_start_idx = torch.Tensor(chunk_start_idx).long() start_pad = torch.nn.functional.pad( chunk_start_idx, (1, 0) ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] end_pad = torch.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] seq_range = torch.arange(0, x_len).unsqueeze(-1) idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) idx_right = idx + right_window idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) boundary_right = end_pad[idx_right] mask_right = seq_range_expand < boundary_right.unsqueeze(-1) return mask_left & mask_right class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__(config) self.config = config self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) self.encoders = nn.ModuleList( [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)] ) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size) # avoid randomness when run evaluation or decoding if self.training and np.random.rand() > 0.5: # Either first or last chunk is not complete. # If only the last one is not complete, EOS is not effective chunk_start_idx = seq_len - chunk_start_idx chunk_start_idx = chunk_start_idx[::-1] chunk_start_idx = chunk_start_idx[:-1] chunk_start_idx = np.insert(chunk_start_idx, 0, 0) enc_streaming_mask = ( adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) .unsqueeze(0) .expand([batch_size, -1, -1]) ) return enc_streaming_mask def forward_embeddings(self, hidden_states, masks): """Forwarding the inputs through the top embedding layers""" seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) if seq_len <= 0: raise ValueError( f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short." ) batch_size = hidden_states.shape[0] enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk) enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) hidden_states, masks = self.embed(hidden_states, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: hs_mask = masks & streaming_mask elif masks is not None: hs_mask = masks else: hs_mask = streaming_mask return hidden_states, hs_mask, masks def calculate_hs_mask(self, hidden_states, device, mask): max_audio_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = torch.arange(0, max_audio_length, device=device).expand( padding_length.size(0), -1 ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): hidden_states = self.encoder_embedding(hidden_states) hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) unfolded = False bs, seq_len, _ = hidden_states.shape max_seq_len = 500 # maximum position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len unfolded = True # the unfold op will drop residual frames, pad it to the multiple of max_seq_len if seq_len % max_seq_len > 0: chunk_pad_size = max_seq_len - (seq_len % max_seq_len) else: chunk_pad_size = 0 if chunk_pad_size > 0: hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0) hidden_states = hidden_states_pad.to(hidden_states.device) hidden_states = unfold_tensor(hidden_states, max_seq_len) masks_unfold = None if mask is not None: # revise hs_mask here because the previous calculated hs_mask did not consider extra pad subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( subsampled_pad_mask, (0, chunk_pad_size), "constant", False ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor hs_mask = self.calculate_hs_mask( hidden_states, hidden_states.device, masks_unfold ) # calculate hs_mask based on the unfolded pad mask relative_attention_bias = self.relative_attention_bias_layer(hidden_states) attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: hidden_states = layer(hidden_states, attention_mask) if unfolded: embed_dim = hidden_states.shape[-1] hidden_states = hidden_states.reshape(bs, -1, embed_dim) # if we ever padded before unfolding, we need to remove the padding if chunk_pad_size > 0: hidden_states = hidden_states[:, :-chunk_pad_size, :] return hidden_states class Phi4MultimodalAudioEmbedding(nn.Module): def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config self.layer_idx = config.audio_config.feature_layer self.drop = nn.Dropout(config.embd_pdrop) self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) self.up_proj_for_speech = nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) self.up_proj_for_vision_speech = nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) def forward( self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, audio_input_features: torch.FloatTensor, audio_embed_sizes=None, audio_attention_mask=None, audio_projection_mode="speech", ) -> torch.FloatTensor: with torch.no_grad(): positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech down_proj = ( self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech ) target_device = up_proj.bias.device target_dtype = up_proj.bias.dtype audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) audio_embeds = down_proj(audio_encoder_hidden_states) merged_audio_embeds = torch.cat( [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 ) merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): audio_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_audio_embeds, accumulate=False ) audio_embeds = self.drop(audio_embeds) return audio_embeds @use_kernel_forward_from_hub("RMSNorm") class Phi4MultimodalRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Phi4MultimodalRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Phi4MultimodalMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: up_states = self.gate_up_proj(hidden_states) gate, up_states = up_states.chunk(2, dim=-1) up_states = up_states * self.activation_fn(gate) return self.down_proj(up_states) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) return q_embed, k_embed class Phi4MultimodalAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) qkv = self.qkv_proj(hidden_states) query_pos = self.config.num_attention_heads * self.head_dim query_states = qkv[..., :query_pos] key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Phi4MultimodalConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx) self.mlp = Phi4MultimodalMLP(config) self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.config = config self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama return hidden_states class Phi4MultimodalFeatureEmbedding(nn.Module): """Image-audio embedding.""" def __init__(self, config: Phi4MultimodalConfig) -> None: super().__init__() self.config = config self.image_token_id = config.vision_config.image_token_id self.audio_token_id = config.audio_config.audio_token_id self.image_embed = Phi4MultimodalImageEmbedding(config) self.audio_embed = Phi4MultimodalAudioEmbedding(config) def forward( self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_pixel_values: Optional[torch.FloatTensor] = None, audio_input_features: Optional[torch.FloatTensor] = None, image_sizes=None, image_attention_mask=None, audio_embed_sizes=None, audio_attention_mask=None, ) -> torch.FloatTensor: with torch.no_grad(): image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) non_image_position_mask = ~image_position_mask image_embeds = None audio_embeds = None if image_pixel_values is not None and (input_ids == self.image_token_id).any(): image_embeds = self.image_embed( input_ids, inputs_embeds, image_pixel_values=image_pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, ) if audio_input_features is not None and (input_ids == self.audio_token_id).any(): audio_projection_mode = "vision" if image_pixel_values is not None else "speech" audio_embeds = self.audio_embed( input_ids, inputs_embeds, audio_input_features=audio_input_features, audio_embed_sizes=audio_embed_sizes, audio_attention_mask=audio_attention_mask, audio_projection_mode=audio_projection_mode, ) # merge image and audio if image_embeds is not None and audio_embeds is not None: inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask elif image_embeds is not None: inputs_embeds = image_embeds elif audio_embeds is not None: inputs_embeds = audio_embeds return inputs_embeds class Phi4MultimodalRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Phi4MultimodalConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @auto_docstring class Phi4MultimodalPreTrainedModel(PreTrainedModel): config: Phi4MultimodalConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Phi4MultimodalDecoderLayer, "attentions": Phi4MultimodalAttention, } _version = "0.0.5" def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalImageEmbedding): module.global_img_feature_extensor.data.zero_() module.sub_img_feature_extensor.data.zero_() @auto_docstring class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): def __init__(self, config: Phi4MultimodalConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config) self.gradient_checkpointing = False self.embed_dropout = nn.Dropout(config.embd_pdrop) self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config) # Initialize weights and apply final processing self.post_init() @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, image_attention_mask=None, audio_input_features: Optional[torch.FloatTensor] = None, audio_embed_sizes=None, audio_attention_mask=None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> BaseModelOutputWithPast: r""" image_pixel_values (`torch.FloatTensor`, *optional*): If the input contains images, these correspond to the pixel values after transformations (as returned by the Processor) image_sizes (`torch.LongTensor`, *optional*): If the input contains images, these correspond to size of each image. image_attention_mask (`torch.LongTensor`, *optional*): Attention mask for the images. audio_input_features (`torch.FloatTensor`, *optional*): If the input contains audio samples, these correspond to the values after transformation (as returned by the Processor). audio_embed_sizes (`torch.Tensor`, *optional*): Size of the audio inputs. audio_attention_mask (`torch.Tensor, *optional*): Attention mask for the audio inputs. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens_extend( input_ids, inputs_embeds, image_pixel_values=image_pixel_values, audio_input_features=audio_input_features, image_sizes=image_sizes, image_attention_mask=image_attention_mask, audio_embed_sizes=audio_embed_sizes, audio_attention_mask=audio_attention_mask, ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask causal_mask = mask_function( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) @auto_docstring class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = Phi4MultimodalModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, image_attention_mask=None, audio_input_features: Optional[torch.FloatTensor] = None, audio_embed_sizes=None, audio_attention_mask=None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> CausalLMOutputWithPast: r""" image_pixel_values (`torch.FloatTensor`, *optional*): If the input contains images, these correspond to the pixel values after transformations (as returned by the Processor) image_sizes (`torch.LongTensor`, *optional*): If the input contains images, these correspond to size of each image. image_attention_mask (`torch.LongTensor`, *optional*): Attention mask for the images. audio_input_features (`torch.FloatTensor`, *optional*): If the input contains audio samples, these correspond to the values after transformation (as returned by the Processor). audio_embed_sizes (`torch.Tensor`, *optional*): Size of the audio inputs. audio_attention_mask (`torch.Tensor, *optional*): Attention mask for the audio inputs. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA") >>> tokenizer = AutoTokenizer.from_pretrained("TBA") >>> prompt = "This is an example script ." >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, image_pixel_values=image_pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, audio_input_features=audio_input_features, audio_embed_sizes=audio_embed_sizes, audio_attention_mask=audio_attention_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, image_pixel_values=None, image_sizes=None, image_attention_mask=None, audio_input_features=None, audio_embed_sizes=None, audio_attention_mask=None, cache_position=None, position_ids=None, use_cache=True, logits_to_keep=0, **kwargs, ): # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the # process # When the first time input length reached long and short factor switching point, enforce re-compute cache # It will cause downside of slower at this single token position, however, better than current failure. if ( past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 ): past_length = cache_position[0] if past_length <= self.config.original_max_position_embeddings: past_key_values = None model_inputs = super().prepare_inputs_for_generation( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, image_pixel_values=image_pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, audio_input_features=audio_input_features, audio_embed_sizes=audio_embed_sizes, audio_attention_mask=audio_attention_mask, cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, logits_to_keep=logits_to_keep, **kwargs, ) return model_inputs __all__ = [ "Phi4MultimodalAudioPreTrainedModel", "Phi4MultimodalAudioModel", "Phi4MultimodalVisionPreTrainedModel", "Phi4MultimodalVisionModel", "Phi4MultimodalPreTrainedModel", "Phi4MultimodalModel", "Phi4MultimodalForCausalLM", ]