# coding=utf-8
# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLIP-2 model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithPast,
    Seq2SeqLMOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    filter_out_non_signature_kwargs,
    logging,
    torch_int,
)
from ...utils.generic import OutputRecorder, check_model_inputs
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Class defining the outputs of [`Blip2ForConditionalGeneration`].
    """
)
class Blip2ForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the language model.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head of the language model.
    vision_outputs (`BaseModelOutputWithPooling`):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
        Outputs of the Q-Former (Querying Transformer).
    language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
        Outputs of the language model.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    vision_outputs: Optional[BaseModelOutputWithPooling] = None
    qformer_outputs: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    language_model_outputs: Optional[Union[CausalLMOutputWithPast, Seq2SeqLMOutput]] = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
@auto_docstring
class Blip2ImageTextMatchingModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output.
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Blip2QFormerModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Blip2VisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Blip2
class Blip2TextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden
    states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Blip2
class Blip2VisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
""" image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, config: Blip2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) self.patch_embedding = nn.Conv2d( in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embedding.shape[1] - 1 # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding class_pos_embed = self.position_embedding[:, :1] patch_pos_embed = self.position_embedding[:, 1:] dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) if interpolate_pos_encoding: position_embedding = self.interpolate_pos_encoding(embeddings, height, width) else: position_embedding = self.position_embedding embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> BLIP doesn't cast attn weights to fp32 def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: 
float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Blip2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.is_causal = False self.attention_dropout = config.attention_dropout # small tweak here compared to CLIP, no bias here self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) if config.qkv_bias: q_bias = nn.Parameter(torch.zeros(self.embed_dim)) v_bias = nn.Parameter(torch.zeros(self.embed_dim)) else: q_bias = None v_bias = None if q_bias is not None: qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) self.qkv.bias = nn.Parameter(qkv_bias) self.projection = nn.Linear(self.embed_dim, self.embed_dim) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, tgt_len, embed_dim = hidden_states.size() mixed_qkv = self.qkv(hidden_states) mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( 2, 0, 3, 1, 4 ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask=None, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scale, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.projection(attn_output) return attn_output, attn_weights # Copied from transformers.models.blip.modeling_blip.BlipMLP class Blip2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2 class Blip2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Blip2Config): super().__init__() self.embed_dim = config.hidden_size self.self_attn = 
Blip2Attention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Blip2MLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @auto_docstring def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, head_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = hidden_states + residual return hidden_states @auto_docstring class Blip2PreTrainedModel(PreTrainedModel): config: Blip2Config base_model_prefix = "blip" supports_gradient_checkpointing = True _supports_attention_backend = True _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _no_split_modules = [ "Blip2Attention", "Blip2QFormerMultiHeadAttention", "Blip2EncoderLayer", "Blip2TextEmbeddings", "T5Block", "OPTDecoderLayer", ] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=factor) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, Blip2VisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) elif isinstance( module, ( Blip2Model, Blip2TextModelWithProjection, Blip2VisionModelWithProjection, Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, ), ): module.query_tokens.data.zero_() # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2 class Blip2Encoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Blip2EncoderLayer`]. Args: config (`Blip2Config`): The corresponding vision configuration for the `Blip2Encoder`. 
""" def __init__(self, config: Blip2Config): super().__init__() self.config = config self.layers = nn.ModuleList([Blip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False @auto_docstring def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, BaseModelOutput]: hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask=attention_mask, **kwargs, ) return BaseModelOutput(last_hidden_state=hidden_states) @auto_docstring # Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 class Blip2VisionModel(Blip2PreTrainedModel): main_input_name = "pixel_values" config: Blip2VisionConfig _can_record_outputs = { "hidden_states": Blip2EncoderLayer, "attentions": Blip2Attention, } def __init__(self, config: Blip2VisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size self.embeddings = Blip2VisionEmbeddings(config) self.encoder = Blip2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @check_model_inputs @auto_docstring def forward( self, pixel_values: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, BaseModelOutputWithPooling]: if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, **kwargs, ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, ) def get_input_embeddings(self): return self.embeddings class Blip2QFormerMultiHeadAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = 
attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in `Blip2QFormerModel.forward()`)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
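        # At this point `attention_probs` has shape (batch_size, num_attention_heads, query_seq_len, key_seq_len);
        # the dropout below zeroes individual attention weights, not whole heads (head masking happens separately).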
attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) return ( context_layer, attention_probs, ) # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Blip2QFormer class Blip2QFormerSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class Blip2QFormerAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.attention = Blip2QFormerMultiHeadAttention(config, is_cross_attention) self.output = Blip2QFormerSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads ) # Prune linear layers self.attention.query = prune_linear_layer(self.attention.query, index) self.attention.key = prune_linear_layer(self.attention.key, index) self.attention.value = prune_linear_layer(self.attention.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: attn_output, _ = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, **kwargs, ) attention_output = self.output(attn_output, hidden_states) return attention_output # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Blip2QFormer class Blip2QFormerIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Blip2QFormer class Blip2QFormerOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = 
nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class Blip2QFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = Blip2QFormerAttention(config) self.layer_idx = layer_idx if layer_idx % config.cross_attention_frequency == 0: self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) self.has_cross_attention = True else: self.has_cross_attention = False if config.use_qformer_text_input: self.intermediate = Blip2QFormerIntermediate(config) self.output = Blip2QFormerOutput(config) self.intermediate_query = Blip2QFormerIntermediate(config) self.output_query = Blip2QFormerOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, query_length=0, **kwargs: Unpack[TransformersKwargs], ): attention_output = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask, **kwargs, ) if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: if encoder_hidden_states is None: raise ValueError("encoder_hidden_states must be given for cross-attention layers") query_attention_output = self.crossattention( hidden_states=query_attention_output, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, **kwargs, ) layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) return layer_output def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class Blip2QFormerEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False @can_return_tuple def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, query_length=0, **kwargs: Unpack[TransformersKwargs], ): for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] layer_head_mask = head_mask[i] if head_mask is not None else None hidden_states = layer_module( 
hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, query_length=query_length, **kwargs, ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, ) class Blip2TextEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def forward( self, input_ids: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, query_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[:, :seq_length] if input_ids is not None: input_ids = input_ids.to(self.word_embeddings.weight.device) embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings if query_embeds is not None: # `query_embeds` are kept in fp32 when we use it with Qformer if query_embeds.dtype != embeddings.dtype: query_embeds = query_embeds.to(embeddings.dtype) embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds return embeddings @auto_docstring( custom_intro=""" BLIP-2 Querying Transformer (Q-Former). """ ) class Blip2QFormerModel(Blip2PreTrainedModel): _supports_attention_backend = False # adds position on attn weights before last matmul _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False _can_record_outputs = { "hidden_states": Blip2QFormerLayer, "attentions": [ OutputRecorder(Blip2QFormerMultiHeadAttention, index=1, layer_name=".attention"), ], "cross_attentions": [ OutputRecorder(Blip2QFormerMultiHeadAttention, index=1, layer_name=".crossattention"), ], } def __init__(self, config: Blip2QFormerConfig): super().__init__(config) self.config = config self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.encoder = Blip2QFormerEncoder(config) self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: torch.Tensor, input_shape: tuple[int], device: torch.device, has_query: bool = False, ) -> torch.Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. 
input_shape (`tuple[int]`): The shape of the input to the model. device (`torch.device`): The device of the input to the model. Returns: `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask @check_model_inputs @auto_docstring def forward( self, query_embeds: torch.FloatTensor, query_length: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Hidden states to be used in the attention computation. If cross-attention, will be used for the query (i.e., key and value will use the encoder_hidden_states). query_length (`int`, *optional*): Length of the query, usually based on the number of query tokens. If no value is provided, query_length will be inferred by the query_embeds. """ query_length = ( query_length if query_length is not None else query_embeds.shape[1] if query_embeds is not None else 0 ) # `Blip2QFormerModel` is kept as fp32 query_embeds = query_embeds.to(self.layernorm.weight.dtype) embedding_output = self.layernorm(query_embeds) embedding_output = self.dropout(embedding_output) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: # Qformer and latent query tokens are kept in fp32. 
We cast `encoder_hidden_states` if not fp32 already if encoder_hidden_states.dtype != query_embeds.dtype: encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype) if isinstance(encoder_hidden_states, list): encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() else: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if isinstance(encoder_attention_mask, list): encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs: BaseModelOutput = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, query_length=query_length, **kwargs, ) sequence_output = encoder_outputs.last_hidden_state pooled_output = sequence_output[:, 0, :] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, ) @auto_docstring( custom_intro=""" BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer (Q-Former) and a language model. """ ) class Blip2Model(Blip2PreTrainedModel): config: Blip2Config main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: language_model = AutoModelForCausalLM.from_config(config.text_config) else: language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) # Update _tied_weights_keys using the base model used. 
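        # The keys are prefixed with `language_model.` so that weight tying and checkpoint loading
        # resolve them on the wrapped language model rather than on this wrapper module.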
if language_model._tied_weights_keys is not None: self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] self.language_model = language_model # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() def get_encoder(self): return self.language_model.get_encoder() def get_decoder(self): return self.language_model.get_decoder() def _tie_weights(self): if not self.config.use_decoder_only_language_model: self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared @filter_out_non_signature_kwargs() @auto_docstring def get_text_features( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, decoder_attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, legacy_output: bool = True, ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training). decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. legacy_output (`bool`, *optional*, defaults to `True`): Whether to return a model output object or a tensor of features. Returns: text_outputs (`CausalLMOutputWithPast` or `torch.FloatTensor`): The language model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`. Examples: ```python >>> import torch >>> from transformers import AutoTokenizer, Blip2Model >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b") >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b") >>> inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt") >>> with torch.inference_mode(): ... text_features = model.get_text_features(**inputs) ```""" if legacy_output: warnings.warn( "Deprecation notice: In Transformers v4.59, the default return value of `get_text_features` will change. " "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. " "To opt in to the new behavior now, set `legacy_output=False`." 
)

        if self.config.use_decoder_only_language_model:
            text_outputs: CausalLMOutputWithPast = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )
        else:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

            text_outputs: Seq2SeqLMOutput = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=labels,
                return_dict=True,
            )

        return text_outputs if legacy_output else text_outputs.logits

    @filter_out_non_signature_kwargs()
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: bool = False,
        legacy_output: bool = True,
    ) -> Union[torch.FloatTensor, BaseModelOutputWithPooling]:
        r"""
        legacy_output (`bool`, *optional*, defaults to `True`):
            Whether to return a model output object or a tensor of features.

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`):
                The vision model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, Blip2Model
        >>> from transformers.image_utils import load_image

        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     image_outputs = model.get_image_features(**inputs)
        ```"""
        if legacy_output:
            warnings.warn(
                "Deprecation notice: In Transformers v4.59, the default return value of `get_image_features` will change. "
                "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. "
                "To opt in to the new behavior now, set `legacy_output=False`."
            )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=True,
        )

        return vision_outputs if legacy_output else vision_outputs.pooler_output

    @filter_out_non_signature_kwargs()
    @auto_docstring
    def get_qformer_features(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: bool = False,
        legacy_output: bool = True,
    ) -> Union[torch.FloatTensor, BaseModelOutputWithPooling]:
        r"""
        legacy_output (`bool`, *optional*, defaults to `True`):
            Whether to return a model output object or a tensor of features.

        Returns:
            qformer_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`):
                The Q-Former outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, Blip2Model
        >>> from transformers.image_utils import load_image

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     qformer_outputs = model.get_qformer_features(**inputs)
        ```"""
        if legacy_output:
            warnings.warn(
                "Deprecation notice: In Transformers v4.59, the default return value of `get_qformer_features` will change. "
                "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. "
                "To opt in to the new behavior now, set `legacy_output=False`."
) vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True, ) image_embeds = vision_outputs.last_hidden_state # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, return_dict=True, ) return query_outputs if legacy_output else query_outputs.last_hidden_state def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) return special_image_mask @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, input_ids: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, interpolate_pos_encoding: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Blip2ForConditionalGenerationModelOutput]: r""" decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. Only relevant in case an encoder-decoder language model (like T5) is used. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import Blip2Processor, Blip2Model >>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", dtype=torch.float16) >>> model.to(device) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> prompt = "Question: how many cats are there? 
Answer:" >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) >>> outputs = model(**inputs) ```""" # step 1: forward the images through the vision encoder, # to get image embeddings of shape (batch_size, seq_len, hidden_size) vision_outputs = self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs, ) image_embeds = vision_outputs[0] # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, **kwargs, ) query_output = query_outputs[0] # Qformer is kept in fp32, we downcast the output back if needed if query_output.dtype != image_embeds.dtype: query_output = query_output.to(image_embeds.dtype) # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( special_image_mask, language_model_inputs ) if self.config.use_decoder_only_language_model: outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs, ) logits = outputs[0] loss = None # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: labels = labels.to(logits.device) logits = logits[:, -labels.size(1) :, :] # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous().to(logits.device) # Flatten the tokens loss_fct = CrossEntropyLoss(reduction="mean") loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) else: outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, return_dict=True, **kwargs, ) loss = outputs.loss logits = outputs.logits return Blip2ForConditionalGenerationModelOutput( loss=loss, logits=logits, vision_outputs=vision_outputs, qformer_outputs=query_outputs, language_model_outputs=outputs, ) @auto_docstring class Blip2TextModelWithProjection(Blip2PreTrainedModel): supports_gradient_checkpointing = False _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.embeddings = Blip2TextEmbeddings(config.qformer_config) self.qformer = Blip2QFormerModel(config.qformer_config) # text projection layer self.text_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return 
self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Blip2TextModelOutput]: r""" Examples: ```python >>> import torch >>> from transformers import AutoProcessor, Blip2TextModelWithProjection >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = Blip2TextModelWithProjection.from_pretrained( ... "Salesforce/blip2-itm-vit-g", dtype=torch.float16 ... ) >>> model.to(device) # doctest: +IGNORE_RESULT >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) >>> outputs = model(**inputs) >>> text_embeds = outputs.text_embeds >>> print(text_embeds.shape) torch.Size([2, 7, 256]) ```""" query_embeds = self.embeddings( input_ids=input_ids, position_ids=position_ids, ) text_outputs = self.qformer( query_embeds=query_embeds, query_length=0, attention_mask=attention_mask, **kwargs, ) pooled_output = text_outputs[0] pooled_output = pooled_output.to(dtype=self.text_projection.weight.dtype) text_embeds = self.text_projection(pooled_output) text_embeds = nn.functional.normalize(text_embeds, dim=-1) return Blip2TextModelOutput( text_embeds=text_embeds, last_hidden_state=text_outputs.last_hidden_state, hidden_states=text_outputs.hidden_states, attentions=text_outputs.attentions, ) @auto_docstring class Blip2VisionModelWithProjection(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel._from_config(config.qformer_config) # vision projection layer self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @can_return_tuple @auto_docstring def forward( self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Blip2VisionModelOutput]: r""" Examples: ```python >>> import torch >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection >>> from transformers.image_utils import load_image >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") >>> model = Blip2VisionModelWithProjection.from_pretrained( ... "Salesforce/blip2-itm-vit-g", dtype=torch.float16 ... ) >>> model.to(device) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = load_image(url) >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) >>> with torch.inference_mode(): ... 
outputs = model(**inputs) >>> image_embeds = outputs.image_embeds >>> print(image_embeds.shape) torch.Size([1, 32, 256]) ```""" vision_outputs = self.vision_model( pixel_values=pixel_values, **kwargs, ) pooled_output = vision_outputs[0] image_attention_mask = torch.ones(pooled_output.size()[:-1], dtype=torch.long, device=pooled_output.device) query_tokens = self.query_tokens.expand(pooled_output.shape[0], -1, -1) query_outputs = self.qformer( query_embeds=query_tokens, encoder_hidden_states=pooled_output, encoder_attention_mask=image_attention_mask, **kwargs, ) embeds = query_outputs[0] embeds = embeds.to(dtype=self.vision_projection.weight.dtype) image_embeds = self.vision_projection(embeds) image_embeds = nn.functional.normalize(image_embeds, dim=-1) return Blip2VisionModelOutput( image_embeds=image_embeds, last_hidden_state=vision_outputs.last_hidden_state, hidden_states=vision_outputs.hidden_states, attentions=vision_outputs.attentions, ) @auto_docstring( custom_intro=""" BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision encoder, Querying Transformer (Q-Former) and a language model. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token. Note that Flan-T5 checkpoints cannot be cast to float16. They are pre-trained using bfloat16. """ ) class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config: Blip2Config main_input_name = "pixel_values" _can_compile_fullgraph = True _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: language_model = AutoModelForCausalLM.from_config(config.text_config) else: language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) # Update _tied_weights_keys using the base model used. if language_model._tied_weights_keys is not None: self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] self.language_model = language_model # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: return self.language_model.get_output_embeddings() def get_encoder(self): return self.language_model.get_encoder() def get_decoder(self): return self.language_model.get_decoder() def _tie_weights(self): if not self.config.use_decoder_only_language_model: self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. 
        Check https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = False,
    ):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
        """
        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=True,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs[0]

        # Qformer is kept in fp32, we downcast the output back if needed
        if query_output.dtype != image_embeds.dtype:
            query_output = query_output.to(image_embeds.dtype)

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        if return_dict:
            return language_model_inputs, vision_outputs, query_outputs
        return language_model_inputs

    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
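
        Example (a minimal, self-contained sketch; the token id, shapes, and values are illustrative
        assumptions, not taken from a real checkpoint):

        ```python
        >>> import torch

        >>> image_token_id = 99  # hypothetical placeholder id
        >>> input_ids = torch.tensor([[99, 99, 10, 11]])  # two image placeholders followed by text tokens
        >>> inputs_embeds = torch.zeros(1, 4, 8)

        >>> # the boolean mask marks every embedding position that belongs to an image placeholder
        >>> mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        >>> # image features are then scattered into those positions, as done in `forward`
        >>> inputs_embeds = inputs_embeds.masked_scatter(mask, torch.ones(2, 8))
        >>> print(mask.sum().item(), inputs_embeds.sum().item())
        16 16.0
        ```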
""" if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) return special_image_mask @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, interpolate_pos_encoding: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Blip2ForConditionalGenerationModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be provided to serve as text prompt, which the language model can continue. Indices can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for details. [What are input IDs?](../glossary#input-ids) decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. Only relevant in case an encoder-decoder language model (like T5) is used. Examples: Prepare processor, model and image input ```python >>> from PIL import Image >>> import requests >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration >>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") >>> model = Blip2ForConditionalGeneration.from_pretrained( ... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, dtype=torch.float16 ... ) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) ``` Image captioning (without providing a text prompt): ```python >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) >>> generated_ids = model.generate(**inputs) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() >>> print(generated_text) two cats laying on a couch ``` Visual question answering (prompt = question): ```python >>> prompt = "Question: how many cats are there? Answer:" >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16) >>> generated_ids = model.generate(**inputs) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() >>> print(generated_text) two ``` Note that int8 inference is also supported through [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). This greatly reduces the amount of memory used by the model while maintaining the same performance. ```python >>> model = Blip2ForConditionalGeneration.from_pretrained( ... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, dtype=torch.bfloat16 ... 
) # doctest: +IGNORE_RESULT >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) >>> generated_ids = model.generate(**inputs) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() >>> print(generated_text) two ```""" language_model_inputs, vision_outputs, query_outputs = self.get_image_features( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True ) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( special_image_mask, language_model_inputs ) if self.config.use_decoder_only_language_model: outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs, ) logits = outputs[0] loss = None # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: labels = labels.to(logits.device) logits = logits[:, -labels.size(1) :, :] # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous().to(logits.device) # Flatten the tokens loss_fct = CrossEntropyLoss(reduction="mean") loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) else: kwargs["return_dict"] = True outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, **kwargs, ) loss = outputs.loss logits = outputs.logits return Blip2ForConditionalGenerationModelOutput( loss=loss, logits=logits, vision_outputs=vision_outputs, qformer_outputs=query_outputs, language_model_outputs=outputs, ) @torch.no_grad() def generate( self, pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ Overrides `generate` function to be able to use the model as a conditional generator. Args: pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): Input images to be processed. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Embedded representation of the inputs. Should be float, not int tokens. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the positional encoding of the image embeddings. Returns: captions (list): A list of strings of length batch_size * num_captions. 
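
        Example (a minimal sketch, reusing the checkpoint and prompt from the examples in
        [`Blip2ForConditionalGeneration.forward`]; the `max_new_tokens` value is illustrative and the
        decoded caption is not a guaranteed output):

        ```python
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration

        >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")

        >>> generated_ids = model.generate(**inputs, max_new_tokens=10)
        >>> caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        ```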
""" if hasattr(self, "hf_device_map"): # preprocess for `accelerate` self._preprocess_accelerate() batch_size = pixel_values.shape[0] image_embeds = self.vision_model( pixel_values, return_dict=True, interpolate_pos_encoding=interpolate_pos_encoding, ).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, return_dict=True, ) query_output = query_outputs.last_hidden_state # Qformer is kept in fp32, we downcast the output back if needed if query_output.dtype != image_embeds.dtype: query_output = query_output.to(image_embeds.dtype) language_model_inputs = self.language_projection(query_output) if inputs_embeds is None: if input_ids is None: image_tokens = [self.config.image_token_index] * self.config.num_query_tokens start_tokens = image_tokens + [self.config.text_config.bos_token_id] input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = input_ids.repeat(batch_size, 1) inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: inputs["input_ids"] = input_ids outputs = self.language_model.generate(**inputs, **generate_kwargs) return outputs @auto_docstring( custom_intro=""" BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to the image. 
""" ) class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.embeddings = Blip2TextEmbeddings(config.qformer_config) self.qformer = Blip2QFormerModel._from_config(config.qformer_config) # vision projection layer self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) # text projection layer self.text_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) # image text matching head self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value @auto_docstring def forward( self, pixel_values: torch.FloatTensor, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, use_image_text_matching_head: Optional[bool] = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, Blip2ImageTextMatchingModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be provided to serve as text prompt, which the language model can continue. Indices can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for details. [What are input IDs?](../glossary#input-ids) use_image_text_matching_head (`bool`, *optional*): Whether to return the Image-Text Matching or Contrastive scores. 
Examples: ```python >>> import torch >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", dtype=torch.float16) >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g") >>> model.to(device) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> text = "two cats laying on a pink blanket" >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16) >>> itm_out = model(**inputs, use_image_text_matching_head=True) >>> logits_per_image = torch.nn.functional.softmax(itm_out.logits_per_image, dim=1) >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities >>> print(f"{probs[0][0]:.1%} that image 0 is not '{text}'") 26.9% that image 0 is not 'two cats laying on a pink blanket' >>> print(f"{probs[0][1]:.1%} that image 0 is '{text}'") 73.0% that image 0 is 'two cats laying on a pink blanket' >>> texts = ["a photo of a cat", "a photo of a dog"] >>> inputs = processor(images=image, text=texts, return_tensors="pt").to(device, torch.float16) >>> itc_out = model(**inputs, use_image_text_matching_head=False) >>> logits_per_image = itc_out.logits_per_image # this is the image-text similarity score >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") 55.3% that image 0 is 'a photo of a cat' >>> print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'") 44.7% that image 0 is 'a photo of a dog' ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) image_embeds = vision_outputs[0] image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) if use_image_text_matching_head: query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) if self.config.image_token_index is not None: input_ids = input_ids[:, self.config.num_query_tokens :] else: query_attention_mask = torch.ones( query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device ) attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) query_embeds = self.embeddings( input_ids=input_ids, query_embeds=query_tokens, ) text_outputs = self.qformer( query_embeds=query_embeds, query_length=query_tokens.shape[1], attention_mask=attention_mask, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, return_dict=return_dict, ) text_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state text_embeds = text_embeds.to(dtype=self.itm_head.weight.dtype) output = self.itm_head(text_embeds[:, : query_tokens.size(1), :]) logits_per_image = output.mean(dim=1) logits_per_text = logits_per_image.t() else: query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( 
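                # Contrastive (ITC) path: the image is summarized by the learned query tokens alone
                # (no text enters this Q-Former pass); the text is encoded separately below, and the
                # score is the maximum cosine similarity between the normalized projected embeddings.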
query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, return_dict=return_dict, ) image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype) if self.config.image_token_index is not None: input_ids = input_ids[:, self.config.num_query_tokens :] attention_mask = attention_mask[:, self.config.num_query_tokens :] query_embeds = self.embeddings( input_ids=input_ids, ) text_outputs = self.qformer( query_embeds=query_embeds, query_length=0, attention_mask=attention_mask, return_dict=return_dict, ) question_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state question_embeds = question_embeds.to(dtype=self.text_projection.weight.dtype) # normalized features image_embeds = nn.functional.normalize(self.vision_projection(image_embeds), dim=-1) text_embeds = nn.functional.normalize(self.text_projection(question_embeds[:, 0, :]), dim=-1) # cosine similarity as logits logits_per_image = torch.matmul(image_embeds, text_embeds.t()) logits_per_image, _ = logits_per_image.max(dim=1) logits_per_text = logits_per_image.t() if not return_dict: output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) return output return Blip2ImageTextMatchingModelOutput( logits_per_image=logits_per_image, logits_per_text=logits_per_text, text_embeds=text_embeds, image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, ) __all__ = [ "Blip2Model", "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2VisionModel", "Blip2TextModelWithProjection", ]