from typing import Optional

import torch
from torch import nn

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils.generic import check_model_inputs
from ..mistral.configuration_mistral import MistralConfig
from ..qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2DecoderLayer,
    Qwen2ForCausalLM,
    Qwen2ForQuestionAnswering,
    Qwen2ForSequenceClassification,
    Qwen2ForTokenClassification,
    Qwen2MLP,
    Qwen2Model,
    Qwen2PreTrainedModel,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
)


class MinistralConfig(MistralConfig, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MinistralModel`]. It is used to instantiate a
    Ministral model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Ministral-8B-Instruct-2410.
    [mistralai/Ministral-8B-Instruct-2410](https://huggingface.co/mistralai/Ministral-8B-Instruct-2410)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Ministral model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`MinistralModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out
            [this paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Ministral's sliding window attention
            allows sequences of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        layer_types (`list`, *optional*):
            Attention pattern for each layer.

    ```python
    >>> from transformers import MinistralModel, MinistralConfig

    >>> # Initializing a Ministral 8B style configuration
    >>> configuration = MinistralConfig()

    >>> # Initializing a model from the Ministral 8B style configuration
    >>> model = MinistralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ministral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=None,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        attention_dropout=0.0,
        layer_types=None,
        **kwargs,
    ):
        PretrainedConfig.__init__(
            self,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.head_dim = head_dim

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types

        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if self.sliding_window is not None else "full_attention"
            ] * num_hidden_layers


class MinistralMLP(Qwen2MLP):
    pass


class MinistralAttention(Qwen2Attention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Match Mistral: q/k/v do not have bias
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)


class MinistralRMSNorm(Qwen2RMSNorm):
    pass


class MinistralDecoderLayer(Qwen2DecoderLayer):
    pass


class MinistralPreTrainedModel(Qwen2PreTrainedModel):
    pass


class MinistralRotaryEmbedding(Qwen2RotaryEmbedding):
    pass
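# With the default configuration above (`sliding_window=4096`), `layer_types` resolves to
# `["sliding_attention"] * num_hidden_layers`, so every decoder layer is routed to the sliding-window
# mask that `MinistralModel.forward` builds below, selected via `decoder_layer.attention_type`.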
class MinistralModel(Qwen2Model):
    def __init__(self, config: MinistralConfig):
        super().__init__(config)
        del self.has_sliding_layers

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class MinistralForCausalLM(Qwen2ForCausalLM):
    pass


class MinistralForSequenceClassification(Qwen2ForSequenceClassification):
    pass


class MinistralForTokenClassification(Qwen2ForTokenClassification):
    pass


class MinistralForQuestionAnswering(Qwen2ForQuestionAnswering):
    pass


__all__ = [
    "MinistralConfig",
    "MinistralPreTrainedModel",
    "MinistralModel",
    "MinistralForCausalLM",
    "MinistralForSequenceClassification",
    "MinistralForTokenClassification",
    "MinistralForQuestionAnswering",
]
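# Minimal usage sketch, assuming the mistralai/Ministral-8B-Instruct-2410 checkpoint referenced in the
# configuration docstring is available; the classes above are reached through the standard auto classes.
# Kept as a comment since this module uses package-relative imports and is not meant to be run directly:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
#     model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
#
#     inputs = tokenizer("Hello, my name is", return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))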