# coding=utf-8 # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Qwen3-VL model.""" from typing import Callable, Optional, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, is_torchdynamo_compiling, logging from ...utils.generic import check_model_inputs from ...video_utils import VideoInput from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLCausalLMOutputWithPast, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel, Qwen2_5_VLVisionBlock, ) from ..qwen2_vl.modeling_qwen2_vl import ( PatchEmbed, Qwen2VLModelOutputWithPast, Qwen2VLPreTrainedModel, TransformersKwargs, VisionAttention, VisionRotaryEmbedding, ) from 
..qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor from ..qwen3.modeling_qwen3 import ( Qwen3Attention, Qwen3DecoderLayer, Qwen3Model, apply_rotary_pos_emb, eager_attention_forward, ) logger = logging.get_logger(__name__) class Qwen3VLVisionConfig(PretrainedConfig): model_type = "qwen3_vl" base_config_key = "vision_config" def __init__( self, depth=27, hidden_size=1152, hidden_act="gelu_pytorch_tanh", intermediate_size=4304, num_heads=16, in_channels=3, patch_size=16, spatial_merge_size=2, temporal_patch_size=2, out_hidden_size=3584, num_position_embeddings=2304, deepstack_visual_indexes=[8, 16, 24], initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) self.depth = depth self.hidden_size = hidden_size self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.num_heads = num_heads self.in_channels = in_channels self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size self.out_hidden_size = out_hidden_size self.num_position_embeddings = num_position_embeddings self.initializer_range = initializer_range self.deepstack_visual_indexes = deepstack_visual_indexes class Qwen3VLTextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the Qwen3VL model. 
Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen3VLModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 22016): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer encoder. num_key_value_heads (`int`, *optional*, defaults to 32): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details, check out [this paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. head_dim (`int`, *optional*, defaults to 128): The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 128000): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. 
use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 5000000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. `factor` (`float`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. `original_max_position_embeddings` (`int`, *optional*): Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during pretraining. `attention_factor` (`float`, *optional*): Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. `beta_fast` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32. `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. `short_factor` (`list[float]`, *optional*): Only used with 'longrope'. 
The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `long_factor` (`list[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `low_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
```python >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig >>> # Initializing a Qwen3VL style configuration >>> configuration = Qwen3VLTextConfig() >>> # Initializing a model from the Qwen3-VL-7B style configuration >>> model = Qwen3VLTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "qwen3_vl_text" base_config_key = "text_config" def __init__( self, vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, head_dim=128, hidden_act="silu", max_position_embeddings=128000, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False, rope_theta=5000000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.head_dim = head_dim self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class Qwen3VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a Qwen3-VL model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`): The config object or dictionary of the text backbone. vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`): The config object or dictionary of the vision backbone. image_token_id (`int`, *optional*, defaults to 151655): The image token index to encode the image prompt. video_token_id (`int`, *optional*, defaults to 151656): The video token index to encode the image prompt. vision_start_token_id (`int`, *optional*, defaults to 151652): The start token index to encode the image prompt. vision_end_token_id (`int`, *optional*, defaults to 151653): The end token index to encode the image prompt. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie the word embeddings. 
```python >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig >>> # Initializing a Qwen3-VL style configuration >>> configuration = Qwen3VLConfig() >>> # Initializing a model from the Qwen3-VL-4B style configuration >>> model = Qwen3VLForConditionalGeneration(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "qwen3_vl" sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig} keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, text_config=None, vision_config=None, image_token_id=151655, video_token_id=151656, vision_start_token_id=151652, vision_end_token_id=151653, tie_word_embeddings=False, **kwargs, ): if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif text_config is None: self.text_config = self.sub_configs["text_config"]() self.image_token_id = image_token_id self.video_token_id = video_token_id self.vision_start_token_id = vision_start_token_id self.vision_end_token_id = vision_end_token_id super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) class Qwen3VLVisionMLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) class Qwen3VLVisionPatchEmbed(PatchEmbed): def __init__(self, config) -> None: super().__init__() self.patch_size = config.patch_size self.temporal_patch_size = 
config.temporal_patch_size self.in_channels = config.in_channels self.embed_dim = config.hidden_size kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) class Qwen3VLVisionRotaryEmbedding(VisionRotaryEmbedding): pass class Qwen3VLVisionPatchMerger(nn.Module): def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: super().__init__() self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) self.act_fn = nn.GELU() self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) return x class Qwen3VLVisionAttention(VisionAttention): def __init__(self, config: Qwen3VLVisionConfig) -> None: super().__init__() self.dim = config.hidden_size class Qwen3VLVisionBlock(Qwen2_5_VLVisionBlock): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) self.attn = Qwen3VLVisionAttention(config=config) self.mlp = Qwen3VLVisionMLP(config=config) class Qwen3VLTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Qwen3VLTextConfig, device=None): super().__init__() if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", "default") else: self.rope_type = "default" self.max_seq_len_cached = 
config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) def apply_interleaved_mrope(self, freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. Reorganizes frequency layout from chunked [TTT...HHH...WWW] to interleaved [THTHWHTHW...TT], preserving frequency continuity. args: x: (3, bs, seq_len, head_dim // 2) mrope_section: (3,) returns: x_t: (bs, seq_len, head_dim // 2) """ freqs_t = freqs[0] # just overwrite the first dimension T for dim, offset in enumerate((1, 2), start=1): # H, W length = mrope_section[dim] * 3 idx = slice(offset, length, 3) freqs_t[..., idx] = freqs[dim, ..., idx] return freqs_t @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): # In contrast to other models, Qwen3VL has different position ids for the grids # So we expand the inv_freq to shape (3, ...) 
if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class Qwen3VLTextAttention(Qwen3Attention): def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): super().__init__(config, layer_idx) del self.sliding_window def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, 
value_states, self.layer_idx, cache_kwargs)

        # Default to the eager kernel; any other configured implementation is looked up by name.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Collapse the head dimensions back into a single feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen3VLTextDecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer with the `attention_type` attribute removed."""

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        del self.attention_type

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Delegate to the parent layer's `forward`, passing every argument by keyword."""
        return super().forward(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )


class Qwen3VLModelOutputWithPast(Qwen2VLModelOutputWithPast):
    pass


class Qwen3VLPreTrainedModel(Qwen2VLPreTrainedModel):
    config: Qwen3VLConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
    # Maps each recordable output name to the module class that produces it.
    _can_record_outputs = {
        "hidden_states": Qwen3VLTextDecoderLayer,
        "attentions": Qwen3VLTextAttention,
    }


class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
    """
    Vision tower: patch embedding plus interpolated learned position embeddings, `config.depth`
    vision blocks, a final patch merger, and one extra DeepStack merger per entry of
    `config.deepstack_visual_indexes`.
    """

    config: Qwen3VLVisionConfig
    _no_split_modules = ["Qwen3VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(
            config=config,
        )

        # Learned position table, addressed as a square grid of side `num_grid_per_side`.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3VLVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        # One dedicated merger per tapped block; position i serves `deepstack_visual_indexes[i]`.
        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3VLVisionPatchMerger(
                    config=config,
                    use_postshuffle_norm=True,
                )
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )

        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Look up rotary frequencies for every patch described by `grid_thw`.

        Each row of `grid_thw` is iterated as `(num_frames, height, width)`. Patches are enumerated
        block by block (`spatial_merge_size` x `spatial_merge_size`), and the coordinates of one frame
        are repeated for every frame.
        """
        merge_size = self.spatial_merge_size

        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets
            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size,
merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            # Every frame of the same item shares the same spatial coordinates.
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        """
        Bilinearly interpolate the learned position table to each `(h, w)` grid in `grid_thw`.

        For every grid the four neighbouring table entries and their weights are gathered, the
        weighted sum is taken, the result is repeated for each of the `t` frames and finally
        reordered so that the patches of one merge block are contiguous.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        # Four parallel lists, one per corner: (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            # Fractional coordinates of the target grid inside the learned square grid.
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            # Clip so that the upper neighbour never leaves the table.
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            # Row-major flat index = row * side + column.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensor = torch.tensor(
            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
        )
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        # Back to one chunk per image/video (h * w rows each).
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            # Reorder from row-major to merge-block-major.
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input passed to the patch embedding; after embedding it is treated as
                `(seq_len, hidden_size)`.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `tuple(torch.Tensor, list[torch.Tensor])`: the merged output of the last block, and the list of
            DeepStack features produced after each block listed in `deepstack_visual_indexes`.
        """
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One sequence per frame: h * w tokens, repeated t times, then turned into cumulative offsets.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # Tap the output of the selected blocks with the merger dedicated to that block.
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)

        return hidden_states, deepstack_feature_lists


@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VL, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
    )
)
class Qwen3VLTextModel(Qwen3VLPreTrainedModel, Qwen3Model):
    config: Qwen3VLTextConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer"]

    def __init__(self, config: Qwen3VLTextConfig):
        super().__init__(config)
        del self.has_sliding_layers

    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ):
        # Adds `visual_embeds` to the rows of `hidden_states` selected by `visual_pos_masks`.
        # `hidden_states` is modified in place and also returned.
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        hidden_states[visual_pos_masks, :] = local_this
        return hidden_states

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # args for deepstack
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
            The mask of the visual positions.
        deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
            The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
The feature is extracted from the different visual encoder layers, and fed to the decoder
            hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334).
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Without explicit cache positions, continue right after the tokens already in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # A leading axis of size 4 carries the text positions in row 0 followed by the three rotary
        # axes; otherwise row 0 of the rotary positions doubles as the text positions.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = layer_outputs

            # add visual features to the hidden states of first several layers
            # (entry `i` of `deepstack_visual_embeds` is injected after decoder layer `i`)
            if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class Qwen3VLModel(Qwen2_5_VLModel):
    config: Qwen3VLConfig
    _checkpoint_conversion_mapping = {}
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
        self.language_model = Qwen3VLTextModel._from_config(config.text_config)

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""

        # Since timestamps are used to separate the frames of a video, video_grid_thw is also split
        # into one single-frame entry per frame.
        if video_grid_thw is not None:
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for
i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums = 0, 0 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums for _ in range(image_nums + video_nums): if image_token_id in input_tokens and remain_images > 0: ed_image = input_tokens.index(image_token_id, st) else: ed_image = len(input_tokens) + 1 if video_token_id in input_tokens and remain_videos > 0: ed_video = input_tokens.index(video_token_id, st) else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: t, h, w = ( image_grid_thw[image_index][0], image_grid_thw[image_index][1], image_grid_thw[image_index][2], ) image_index += 1 remain_images -= 1 ed = ed_image else: t, h, w = ( video_grid_thw[video_index][0], video_grid_thw[video_index][1], video_grid_thw[video_index][2], ) video_index += 1 remain_videos -= 1 ed = ed_video llm_grid_t, llm_grid_h, llm_grid_w = ( t.item(), h.item() // spatial_merge_size, w.item() // spatial_merge_size, ) text_len = ed - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if 
st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas else: if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( torch.arange(input_ids.shape[1], device=input_ids.device) .view(1, 1, -1) .expand(3, input_ids.shape[0], -1) ) mrope_position_deltas = torch.zeros( [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype, ) return position_ids, mrope_position_deltas def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): """ Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) return image_embeds, deepstack_image_embeds def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None ): """ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. Args: pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ # Same implementation as for images return self.get_image_features(pixel_values_videos, video_grid_thw) @auto_docstring @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Qwen3VLModelOutputWithPast]: r""" image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) image_mask = None video_mask = None if pixel_values is not None: image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) visual_pos_masks = None deepstack_visual_embeds = None if image_mask is not None and video_mask is not None: # aggregate visual_pos_masks and deepstack_visual_embeds image_mask = image_mask[..., 0] video_mask = video_mask[..., 0] visual_pos_masks = image_mask | video_mask deepstack_visual_embeds = [] image_mask_joint = image_mask[visual_pos_masks] video_mask_joint = video_mask[visual_pos_masks] for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) embed_joint[image_mask_joint, :] = img_embed embed_joint[video_mask_joint, :] = vid_embed deepstack_visual_embeds.append(embed_joint) elif image_mask is not None: image_mask = image_mask[..., 0] visual_pos_masks = image_mask deepstack_visual_embeds = deepstack_image_embeds elif video_mask is not None: video_mask = video_mask[..., 0] visual_pos_masks = 
video_mask deepstack_visual_embeds = deepstack_video_embeds if position_ids is None: attention_mask_tensor = ( attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] ) if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) # Only apply conversion for floating point tensors (inverted masks) if attention_mask_tensor.dtype.is_floating_point: attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min attention_mask_tensor = (1.0 - attention_mask_tensor).int() # Calculate RoPE index once per generation in the pre-fill stage only. # When compiling, we can't check tensor values thus we check only input length # It is safe to assume that `length!=1` means we're in pre-fill because compiled # models currently cannot do asssisted decoding prefill_compiled_stage = is_torchdynamo_compiling() and ( (input_ids is not None and input_ids.shape[1] != 1) or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) ) prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( (cache_position is not None and cache_position[0] == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0) ) if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, video_grid_thw, attention_mask=attention_mask_tensor, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape delta = ( (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) if cache_position is not None: # otherwise `deltas` is an int `0` delta = 
delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) outputs = self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, **kwargs, ) return Qwen3VLModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, rope_deltas=self.rope_deltas, ) class Qwen3VLCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast): pass class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): config: Qwen3VLConfig _checkpoint_conversion_mapping = {} @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. Example: TODO: Add example """ outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) return Qwen3VLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, rope_deltas=outputs.rope_deltas, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, use_cache=use_cache, **kwargs, ) # Qwen3VL position_ids are 
prepareed with rope_deltas in forward model_inputs["position_ids"] = None if cache_position[0] != 0: model_inputs["pixel_values"] = None model_inputs["pixel_values_videos"] = None return model_inputs class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False): pass class Qwen3VLImagesKwargs(Qwen2VLImagesKwargs): pass class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False): images_kwargs: Qwen3VLImagesKwargs videos_kwargs: Qwen3VLVideosProcessorKwargs _defaults = { "text_kwargs": { "padding": False, "return_token_type_ids": False, "return_mm_token_type_ids": False, }, "videos_kwargs": {"return_metadata": True}, } class Qwen3VLProcessor(Qwen2VLProcessor): r""" Constructs a Qwen3VL processor which wraps a Qwen3VL image processor and a Qwen2 tokenizer into a single processor. [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information. Args: image_processor ([`Qwen2VLImageProcessor`], *optional*): The image processor is a required input. tokenizer ([`Qwen2TokenizerFast`], *optional*): The tokenizer is a required input. video_processor ([`Qwen3VLVideoProcessor`], *optional*): The video processor is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. 
""" def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs) self.vision_start_token = ( "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token ) self.vision_end_token = ( "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token ) self.vision_start_token_id = ( tokenizer.vision_start_token_id if getattr(tokenizer, "vision_start_token_id", None) else tokenizer.convert_tokens_to_ids(self.vision_start_token) ) self.vision_end_token_id = ( tokenizer.vision_end_token_id if getattr(tokenizer, "vision_end_token_id", None) else tokenizer.convert_tokens_to_ids(self.vision_end_token) ) def __call__( self, images: ImageInput = None, text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, videos: VideoInput = None, **kwargs: Unpack[Qwen3VLProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. text (`str`, `list[str]`, `list[list[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). 
If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" output_kwargs = self._merge_kwargs( Qwen3VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] else: image_inputs = {} image_grid_thw = None if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] # If user has not requested video metadata, pop it if "return_metadata" not in kwargs: video_metadata = videos_inputs.pop("video_metadata") else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] else: videos_inputs = {} video_grid_thw = None if not isinstance(text, list): text = [text] text = text.copy() # below lines change text in-place if image_grid_thw is not None: merge_length = self.image_processor.merge_size**2 index = 0 for i in range(len(text)): while self.image_token in text[i]: num_image_tokens = image_grid_thw[index].prod() // merge_length text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.image_token) if video_grid_thw is not None: merge_length = self.video_processor.merge_size**2 index = 0 for i in range(len(text)): while self.video_token in text[i]: metadata = video_metadata[index] if metadata.fps is None: logger.warning_once( "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
) metadata.fps = 24 if metadata.fps is None else metadata.fps # if timestamps are not provided, calculate them curr_timestamp = self._calculate_timestamps( metadata.frames_indices, metadata.fps, self.video_processor.merge_size, ) video_placeholder = "" frame_seqlen = video_grid_thw[index][1:].prod() // merge_length for frame_idx in range(video_grid_thw[index][0]): curr_time = curr_timestamp[frame_idx] video_placeholder += f"<{curr_time:.1f} seconds>" video_placeholder += ( self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token ) if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]: text[i] = text[i].replace( f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1 ) else: # vllm may input video token directly text[i] = text[i].replace(self.video_token, video_placeholder, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.video_token) return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) if return_mm_token_type_ids: array_ids = np.array(text_inputs["input_ids"]) mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) mm_token_type_ids[array_ids == self.image_token_id] = 1 text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2): if not isinstance(indices, list): indices = indices.tolist() if len(indices) % merge_size != 0: indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size)) timestamps = [idx / video_fps for idx in indices] # @JJJYmmm frames are merged by 
self.merge_size, \ # so we need to average the timestamps between the first/last frame within the temporal patch timestamps = [ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) ] return timestamps __all__ = [ "Qwen3VLConfig", "Qwen3VLTextConfig", "Qwen3VLVisionModel", "Qwen3VLForConditionalGeneration", "Qwen3VLModel", "Qwen3VLPreTrainedModel", "Qwen3VLProcessor", "Qwen3VLTextModel", ]