# coding=utf-8
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from ..mixtral.modeling_mixtral import MixtralModel, MixtralPreTrainedModel
from ..olmo2.modeling_olmo2 import Olmo2Attention, Olmo2RMSNorm, Olmo2RotaryEmbedding
from ..olmoe.configuration_olmoe import OlmoeConfig
from ..olmoe.modeling_olmoe import (
    OlmoeDecoderLayer,
    OlmoeForCausalLM,
    OlmoeMLP,
    OlmoeSparseMoeBlock,
)


class FlexOlmoConfig(OlmoeConfig):
    r"""
    This is the configuration class to store the configuration of a [`FlexOlmoModel`]. It is used to instantiate a
    FlexOlmo model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [allenai/FlexOlmo-7x7B-1T](https://huggingface.co/allenai/FlexOlmo-7x7B-1T).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the FlexOlmo model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`FlexOlmoModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 100277):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 100257):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 500000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 5):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 7):
            Number of routed experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
            The aux loss factor for the total loss.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.

    ```python
    >>> from transformers import FlexOlmoModel, FlexOlmoConfig

    >>> # Initializing a FlexOlmo style configuration
    >>> configuration = FlexOlmoConfig()

    >>> # Initializing a model from the FlexOlmo style configuration
    >>> model = FlexOlmoModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "flex_olmo"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel plan, keyed by module-name pattern. The attention projections use the
    # replicated ("*_rep") variants; the MLP projections use the plain column/row-wise ones.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: for each top-level submodule, its (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=100352,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-06,
        use_cache=True,
        pad_token_id=100277,
        bos_token_id=None,
        eos_token_id=100257,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        num_experts_per_tok=5,
        num_experts=7,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        **kwargs,
    ):
        # Every argument is forwarded unchanged to the OlmoE configuration; this class only
        # changes the default values.
        super().__init__(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            num_experts_per_tok=num_experts_per_tok,
            num_experts=num_experts,
            output_router_logits=output_router_logits,
            router_aux_loss_coef=router_aux_loss_coef,
            norm_topk_prob=norm_topk_prob,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Remove the `clip_qkv` attribute (presumably set by `OlmoeConfig.__init__` -- verify);
        # it is not an argument of this configuration.
        del self.clip_qkv


# FlexOlmo RMS norm reuses Olmo2 RMS norm, which handles low precision slightly differently than the original Olmoe.
class FlexOlmoRMSNorm(Olmo2RMSNorm):
    pass


# FlexOlmo rotary embedding reuses Olmo2 rotary embedding, so that the output cos and sin are returned
# as float32 rather than the input type.
class FlexOlmoRotaryEmbedding(Olmo2RotaryEmbedding):
    pass


# FlexOlmo MLP is the OlmoE MLP, reused without changes.
class FlexOlmoMLP(OlmoeMLP):
    pass


# FlexOlmo uses Olmo2 attention instead of OlmoE Attention since its `apply_rotary_pos_emb`
# implementation handles lower precision more faithfully to the Olmo codebase.
class FlexOlmoAttention(Olmo2Attention):
    pass


# FlexOlmo sparse MoE block is the OlmoE sparse MoE block, reused without changes.
class FlexOlmoSparseMoeBlock(OlmoeSparseMoeBlock):
    pass


# FlexOlmo decoder layer is identical to OlmoE decoder layer except:
# - Norm is applied after attention/feedforward rather than before.
class FlexOlmoDecoderLayer(OlmoeDecoderLayer):
    """Decoder layer that normalizes each sub-block's output before the residual add.

    Both the self-attention output and the MLP output are passed through their own RMS norm
    and only then added back onto the residual stream. No norm is applied to the sub-block
    inputs, so the inherited `input_layernorm` is removed.
    """

    def __init__(self, config: FlexOlmoConfig, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        # Statement order is kept as-is: the overrides and the deletion must follow the
        # parent constructor, which creates the attributes being replaced/removed.
        self.post_attention_layernorm = FlexOlmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = FlexOlmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = FlexOlmoAttention(config=config, layer_idx=layer_idx)
        del self.input_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run the attention and MLP sub-blocks and return the updated hidden states.

        All arguments other than `hidden_states` are forwarded untouched to the attention
        module. The second element returned by the attention module and by the MLP is
        discarded.
        """
        # Attention sub-block: attend -> norm -> add onto the incoming hidden states.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_output)

        # Feed-forward sub-block: mlp -> norm -> add onto the post-attention hidden states.
        mlp_output, _ = self.mlp(hidden_states)
        return hidden_states + self.post_feedforward_layernorm(mlp_output)


# FlexOlmo uses Mixtral model as its base instead of OlmoE model since Mixtral is more up-to-date with the rest
# of the transformers library. For example, it uses the newer mechanisms of recording submodule outputs.
class FlexOlmoPreTrainedModel(MixtralPreTrainedModel):
    pass


# FlexOlmo uses Mixtral model as its base instead of OlmoE model since Mixtral is more up-to-date with the rest
# of the transformers library. For example, it uses the newer mechanisms of recording submodule outputs.
# FlexOlmo model is identical to Mixtral model except:
# - FlexOlmo does not use sliding window attention.
class FlexOlmoModel(MixtralModel):
    # NOTE: no docstring is written on `forward` on purpose -- it is wrapped by
    # `@auto_docstring`, so the method is documented with comments instead.
    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of the two input forms must be supplied: reject "both" and "neither".
        got_ids = input_ids is not None
        got_embeds = inputs_embeds is not None
        if got_ids == got_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Create a cache lazily when caching is requested but none was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default cache positions continue from the number of tokens already in the cache.
        if cache_position is None:
            if past_key_values is None:
                first_position = 0
            else:
                first_position = past_key_values.get_seq_length()
            end_position = first_position + inputs_embeds.shape[1]
            cache_position = torch.arange(first_position, end_position, device=inputs_embeds.device)

        # Default position ids are the cache positions with a leading dimension added.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Plain causal mask only; no sliding-window variant is built here.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Rotary embeddings are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        hidden_states = inputs_embeds
        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # A MoE output type is returned; only these two fields are populated here.
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


class FlexOlmoForCausalLM(OlmoeForCausalLM):
    pass


__all__ = [
    "FlexOlmoConfig",
    "FlexOlmoForCausalLM",
    "FlexOlmoModel",
    "FlexOlmoPreTrainedModel",
]