# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Ernie 4.5 model"""

import torch
from torch import nn

from ...modeling_rope_utils import dynamic_rope_update
from ...utils import auto_docstring, can_return_tuple
from ..glm.modeling_glm import rotate_half
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaRotaryEmbedding,
)
from .configuration_ernie4_5 import Ernie4_5Config


class Ernie4_5RotaryEmbedding(LlamaRotaryEmbedding):
    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling  # keeping it in full precision

        return cos, sin


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
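
    Example (shape-only sketch; in the model, `cos`/`sin` come from `Ernie4_5RotaryEmbedding` rather than
    `torch.randn`, which is used here purely to illustrate the expected shapes and broadcasting):

    ```python
    >>> import torch
    >>> q = torch.randn(1, 2, 4, 8)  # [batch_size, heads, seq_len, head_dim]
    >>> k = torch.randn(1, 2, 4, 8)
    >>> cos = torch.randn(1, 4, 8)  # [batch_size, seq_len, head_dim]
    >>> sin = torch.randn(1, 4, 8)
    >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
    >>> q_embed.shape
    torch.Size([1, 2, 4, 8])
    ```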
""" # glm rope style (with full dim) and full precision original_dtype = q.dtype cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) # Interleave them instead of usual shape cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) return q_embed.to(original_dtype), k_embed.to(original_dtype) class Ernie4_5MLP(LlamaMLP): def __init__(self, config: Ernie4_5Config): super().__init__(config) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) class Ernie4_5Attention(LlamaAttention): def __init__(self, config: Ernie4_5Config, layer_idx: int): super().__init__(config, layer_idx) self.attention_dropout = 0.0 self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) class Ernie4_5ForCausalLM(LlamaForCausalLM): @can_return_tuple @auto_docstring def forward(self, **super_kwargs): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ super().forward(**super_kwargs) __all__ = [ "Ernie4_5ForCausalLM", "Ernie4_5Model", # noqa: F822 "Ernie4_5PreTrainedModel", # noqa: F822 ]