# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen3-MoE model."""

from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ...utils.deprecation import deprecate_kwarg
from ..llama.modeling_llama import (
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaRMSNorm,
)
from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel, load_balancing_loss_func
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer
from ..qwen3.modeling_qwen3 import Qwen3Attention
from .configuration_qwen3_moe import Qwen3MoeConfig


logger = logging.get_logger(__name__)


class Qwen3MoeAttention(Qwen3Attention):  # This is the main difference from Qwen2Moe!
    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.sliding_window = getattr(config, "sliding_window", None)


class Qwen3MoeMLP(nn.Module):
    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route each token to its `top_k` experts and combine the weighted expert outputs.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`): input hidden states.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: the routed hidden states of shape
            `(batch_size, sequence_length, hidden_dim)` and the router logits of shape
            `(batch_size * sequence_length, num_experts)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One-hot encode the selected experts to create an expert mask;
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (the selected top-k experts)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only supports torch tensors for indexing, so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits


class Qwen3MoeRMSNorm(LlamaRMSNorm):
    pass


class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer):
    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(config, layer_idx)

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # For the MoE layers, we need to unpack
        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states
        hidden_states = residual + hidden_states
        return hidden_states


class Qwen3MoeModel(MixtralModel):
    pass


class Qwen3MoeForCausalLM(MixtralForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3MoeModel(config)
        self.num_experts = config.num_experts
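    # Note: when `output_router_logits` is enabled, the forward pass below adds the
    # auxiliary load-balancing loss (`load_balancing_loss_func`), scaled by
    # `config.router_aux_loss_coef`, on top of the language-modeling loss.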
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

        >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure it resides on the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )


class Qwen3MoeForSequenceClassification(LlamaForSequenceClassification):
    pass


class Qwen3MoeForTokenClassification(LlamaForTokenClassification):
    pass


class Qwen3MoeForQuestionAnswering(LlamaForQuestionAnswering):
    pass


__all__ = [
    "Qwen3MoeForCausalLM",
    "Qwen3MoeForQuestionAnswering",
    "Qwen3MoeModel",
    "Qwen3MoePreTrainedModel",  # noqa: F822
    "Qwen3MoeForSequenceClassification",
    "Qwen3MoeForTokenClassification",
]
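

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (not part of the library API): it exercises the
# sparse MoE block above with a tiny, randomly initialized config to show the
# routing shapes. The config values below are illustrative assumptions, not
# real checkpoint hyperparameters, and the guard keeps this out of imports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tiny_config = Qwen3MoeConfig(
        hidden_size=64,
        moe_intermediate_size=32,
        num_experts=4,
        num_experts_per_tok=2,
        norm_topk_prob=True,
    )
    moe_block = Qwen3MoeSparseMoeBlock(tiny_config)

    # (batch, seq_len, hidden) input; the block flattens it to (batch * seq_len, hidden) internally.
    dummy_input = torch.randn(2, 5, tiny_config.hidden_size)
    output, router_logits = moe_block(dummy_input)

    # The output keeps the input shape; router logits are per token and per expert.
    print(output.shape)  # torch.Size([2, 5, 64])
    print(router_logits.shape)  # torch.Size([10, 4])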