# coding=utf-8
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
import torch.nn as nn

from transformers.utils.generic import TransformersKwargs

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ..olmo2.configuration_olmo2 import Olmo2Config
from ..olmo2.modeling_olmo2 import (
    Olmo2Attention,
    Olmo2DecoderLayer,
    Olmo2ForCausalLM,
    Olmo2Model,
    Olmo2PreTrainedModel,
    Olmo2RMSNorm,
    Olmo2RotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)


class Olmo3Config(Olmo2Config):
    r"""
    This is the configuration class to store the configuration of a [`Olmo3Model`]. It is used to instantiate an OLMo3
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Olmo3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window for sliding window attention.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Defaults to sliding window attention
            for 3 out of 4 layers, and full attention for every 4th layer.

    ```python
    >>> from transformers import Olmo3Model, Olmo3Config

    >>> # Initializing a Olmo3 7B style configuration
    >>> configuration = Olmo3Config()

    >>> # Initializing a model from the Olmo3 7B style configuration
    >>> model = Olmo3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "olmo3"
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        sliding_window=4096,
        layer_types=None,
        **kwargs,
    ):
        # Everything except `sliding_window` and `layer_types` is forwarded to the OLMo 2 config.
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            rms_norm_eps=rms_norm_eps,
            **kwargs,
        )

        self.sliding_window = sliding_window
        self.layer_types = layer_types
        if self.layer_types is None:
            # Default pattern: every 4th layer (1-based) uses full attention, all others use sliding window.
            self.layer_types = [
                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        rope_config_validation(self)


# Same as `Olmo2RMSNorm`; subclassed only to provide the OLMo 3 class name.
class Olmo3RMSNorm(Olmo2RMSNorm):
    pass


# Olmo3 attention is identical to OLMo 2 attention except:
# - Sliding window attention is used for 3 out of 4 layers.
class Olmo3Attention(Olmo2Attention):
    def __init__(self, config: Olmo3Config, layer_idx: int):
        """
        Build the attention module for layer `layer_idx`.

        The attention type is read from `config.layer_types[layer_idx]`; `self.sliding_window` is set to
        `config.sliding_window` for `"sliding_attention"` layers and to `None` otherwise.

        Raises:
            ValueError: If `config.layer_types` is `None`.
        """
        super().__init__(config, layer_idx=layer_idx)
        # Explicit check rather than `assert`, which is stripped when Python runs with `-O`.
        if config.layer_types is None:
            raise ValueError("`config.layer_types` must not be `None` when building `Olmo3Attention`.")
        self.attention_type = config.layer_types[layer_idx]
        self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run self-attention on `hidden_states`.

        Args:
            hidden_states: Input to the query/key/value projections.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache; when given, it is updated with this layer's key/value states and the
                returned (updated) states are used for attention.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`, where `attn_output` has gone through `o_proj` and `attn_weights` is
            whatever the selected attention implementation returned.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # q_norm / k_norm are applied to the full projection output, before it is split into heads.
        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        # Split the last dimension into (heads, head_dim), then swap axes 1 and 2.
        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Use the eager implementation unless the config selects another registered one.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,  # dropout only while training
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # `None` on full-attention layers
            **kwargs,
        )

        # Merge the trailing dimensions back into one before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


# Same as `Olmo2DecoderLayer`; subclassed only to provide the OLMo 3 class name.
class Olmo3DecoderLayer(Olmo2DecoderLayer):
    pass


# OLMo 3 RoPE is identical to OLMo 2 RoPE, except:
# - RoPE scaling is not applied to sliding window attention layers.
class Olmo3RotaryEmbedding(Olmo2RotaryEmbedding):
    def __init__(self, config: Olmo3Config, device=None, rope_type: Optional[str] = None):
        """
        Build the rotary embedding, optionally forcing a specific RoPE type.

        Args:
            config: Model configuration; `max_position_embeddings` and (optionally) `rope_scaling` are read from it.
            device: Passed to the RoPE init function.
            rope_type: When given, overrides the type found in `config.rope_scaling`. When `None`, the type is taken
                from `config.rope_scaling["rope_type"]` (or the legacy `"type"` key), falling back to `"default"`
                if `config.rope_scaling` is not a dict.

        Raises:
            ValueError: If `config.rope_scaling` is a dict that provides neither `"rope_type"` nor `"type"`.
        """
        # Initialize as a plain `nn.Module` (not via `super().__init__()`); all attributes are set below so that
        # the `rope_type` override can take effect.
        nn.Module.__init__(self)
        if rope_type is not None:
            self.rope_type = rope_type
        elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            # BC: "rope_type" was originally "type"
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # Explicit check rather than `assert` (stripped under `-O`), so a malformed `rope_scaling` dict does not
        # surface later as an opaque `KeyError: None` from the `ROPE_INIT_FUNCTIONS` lookup.
        if self.rope_type is None:
            raise ValueError(
                "Could not determine the RoPE type: `config.rope_scaling` must contain a `rope_type` "
                "(or legacy `type`) entry."
            )
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq


# Same as `Olmo2PreTrainedModel`; subclassed only to provide the OLMo 3 class name.
class Olmo3PreTrainedModel(Olmo2PreTrainedModel):
    pass


# The OLMo 3 model is identical to the OLMo 2 model, except:
# - Sliding window attention is used for 3 out of 4 layers.
# - RoPE scaling is not applied to sliding window attention layers.
class Olmo3Model(Olmo2Model):
    def __init__(self, config: Olmo3Config):
        """
        Build the decoder stack with one rotary embedding per attention type.

        `"sliding_attention"` layers get a rotary embedding forced to the `"default"` RoPE type, while
        `"full_attention"` layers get one whose type is resolved from `config`.
        """
        super().__init__(config)
        self.norm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = nn.ModuleList(
            [Olmo3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_embs = nn.ModuleDict(
            {
                "sliding_attention": Olmo3RotaryEmbedding(config=config, rope_type="default"),
                "full_attention": Olmo3RotaryEmbedding(config=config),
            }
        )
        # The single `rotary_emb` attribute is superseded by the per-attention-type `rotary_embs` above.
        del self.rotary_emb

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """
        Run the decoder stack.

        Args:
            input_ids: Token ids, embedded with `embed_tokens`. Mutually exclusive with `inputs_embeds`.
            attention_mask: Either a mask handed to the mask-creation helpers, or an already prepared dict of masks
                keyed by attention type (`"full_attention"` / `"sliding_attention"`), which is used as-is.
            position_ids: Defaults to `cache_position.unsqueeze(0)`.
            past_key_values: Existing cache. When `use_cache` is truthy and no cache is given, a `DynamicCache` is
                created.
            inputs_embeds: Pre-computed embeddings. Mutually exclusive with `input_ids`.
            cache_position: Defaults to a range starting at the cache's current sequence length (0 without a cache)
                and covering `inputs_embeds.shape[1]` positions.
            use_cache: Whether to create a cache when none is supplied.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            `BaseModelOutputWithPast` holding the normalized final hidden states and the cache.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # Computed once per attention type and shared by all layers of that type.
        position_embeddings_mapping = {
            "sliding_attention": self.rotary_embs["sliding_attention"](hidden_states, position_ids),
            "full_attention": self.rotary_embs["full_attention"](hidden_states, position_ids),
        }

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # Each layer receives the mask and rotary embeddings matching its own attention type.
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings_mapping[decoder_layer.self_attn.attention_type],
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


# Same as `Olmo2ForCausalLM`; subclassed only to provide the OLMo 3 class name.
class Olmo3ForCausalLM(Olmo2ForCausalLM):
    pass


__all__ = [
    "Olmo3Config",
    "Olmo3ForCausalLM",
    "Olmo3Model",
    "Olmo3PreTrainedModel",
]