# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViT MSN (masked siamese network) model."""

import collections.abc
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_vit_msn import ViTMSNConfig


logger = logging.get_logger(__name__)


class ViTMSNEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTMSNPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.
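        For example, assuming the default configuration (224x224 images, 16x16 patches), the stored position
        embeddings form a 14x14 grid; for a 384x384 input they are resized bicubically to a 24x24 grid
        (384 // 16 = 24) before being added to the patch embeddings.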
Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False, ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) if bool_masked_pos is not None: seq_length = embeddings.shape[1] mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) # replace the masked visual tokens by mask_tokens mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings # Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN class ViTMSNPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. 
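    For example, with the default configuration (224x224 images, 16x16 patches, 3 channels), an input of shape
    `(batch_size, 3, 224, 224)` is projected to patch embeddings of shape `(batch_size, 196, hidden_size)`,
    since (224 // 16) * (224 // 16) = 196 patches.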
""" def __init__(self, config: ViTMSNConfig): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) if not interpolate_pos_encoding: if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size[0]}*{self.image_size[1]})." ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings # Copied from transformers.models.vit.modeling_vit.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): # Take the dot product between "query" and "key" to get the raw attention scores. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling # Normalize the attention scores to probabilities. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) # Mask heads if we want to if attention_mask is not None: attn_weights = attn_weights * attention_mask attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN class ViTMSNSelfAttention(nn.Module): def __init__(self, config: ViTMSNConfig): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size {config.hidden_size} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." 
) self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.dropout_prob = config.attention_probs_dropout_prob self.scaling = self.attention_head_size**-0.5 self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: batch_size = hidden_states.shape[0] new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] context_layer, attention_probs = attention_interface( self, query_layer, key_layer, value_layer, head_mask, is_causal=self.is_causal, scaling=self.scaling, dropout=0.0 if not self.training else self.dropout_prob, ) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.reshape(new_context_layer_shape) return context_layer, attention_probs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN class ViTMSNSelfOutput(nn.Module): """ The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the layernorm applied before each block. 
""" def __init__(self, config: ViTMSNConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN class ViTMSNAttention(nn.Module): def __init__(self, config: ViTMSNConfig): super().__init__() self.attention = ViTMSNSelfAttention(config) self.output = ViTMSNSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads: set[int]): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads ) # Prune linear layers self.attention.query = prune_linear_layer(self.attention.query, index) self.attention.key = prune_linear_layer(self.attention.key, index) self.attention.value = prune_linear_layer(self.attention.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: self_attn_output, _ = self.attention(hidden_states, head_mask) output = self.output(self_attn_output, hidden_states) return output # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN class ViTMSNIntermediate(nn.Module): def __init__(self, config: ViTMSNConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN class ViTMSNOutput(nn.Module): def __init__(self, config: ViTMSNConfig): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states + input_tensor return hidden_states # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN class ViTMSNLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTMSNConfig): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ViTMSNAttention(config) self.intermediate = ViTMSNIntermediate(config) self.output = ViTMSNOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> 
torch.Tensor: hidden_states_norm = self.layernorm_before(hidden_states) attention_output = self.attention(hidden_states_norm, head_mask) # first residual connection hidden_states = attention_output + hidden_states # in ViTMSN, layernorm is also applied after self-attention layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) # second residual connection is done here layer_output = self.output(layer_output, hidden_states) return layer_output # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN class ViTMSNEncoder(nn.Module): def __init__(self, config: ViTMSNConfig): super().__init__() self.config = config self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput: for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] if head_mask is not None else None hidden_states = layer_module(hidden_states, layer_head_mask) return BaseModelOutput(last_hidden_state=hidden_states) @auto_docstring class ViTMSNPreTrainedModel(PreTrainedModel): config: ViTMSNConfig base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": ViTMSNLayer, "attentions": ViTMSNSelfAttention, } # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, ViTMSNEmbeddings): module.cls_token.data.zero_() module.position_embeddings.data.zero_() if module.mask_token is not None: module.mask_token.data.zero_() @auto_docstring class ViTMSNModel(ViTMSNPreTrainedModel): def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False): r""" use_mask_token (`bool`, *optional*, defaults to `False`): Whether to use a mask token for masked image modeling. """ super().__init__(config) self.config = config self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token) self.encoder = ViTMSNEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> ViTMSNPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: """ Prunes heads of the model. 
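        For example (illustrative values), passing `{1: [0, 2]}` prunes heads 0 and 2 of encoder layer 1.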
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ViTMSNModel
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
        >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
        sequence_output = encoder_outputs.last_hidden_state
        sequence_output = self.layernorm(sequence_output)

        return BaseModelOutput(last_hidden_state=sequence_output)


# Caution: We don't have the weights for the classification head yet. This class
# is here for users who are interested in fine-tuning the base model (ViTMSNModel).
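# A minimal fine-tuning sketch (illustrative only; the label count, learning rate, and the batch tensors
# `pixel_values` and `labels` below are assumptions, not part of this module):
#
#     model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small", num_labels=10)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
#     outputs = model(pixel_values=pixel_values, labels=labels)  # labels: (batch_size,) class indices
#     outputs.loss.backward()
#     optimizer.step()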
@auto_docstring class ViTMSNForImageClassification(ViTMSNPreTrainedModel): def __init__(self, config: ViTMSNConfig) -> None: super().__init__(config) self.num_labels = config.num_labels self.vit = ViTMSNModel(config) # Classifier head self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, pixel_values: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, interpolate_pos_encoding: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> ImageClassifierOutput: r""" Examples: ```python >>> from transformers import AutoImageProcessor, ViTMSNForImageClassification >>> import torch >>> from PIL import Image >>> import requests >>> torch.manual_seed(2) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small") >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small") >>> inputs = image_processor(images=image, return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> # model predicts one of the 1000 ImageNet classes >>> predicted_label = logits.argmax(-1).item() >>> print(model.config.id2label[predicted_label]) tusker ``` """ outputs: BaseModelOutput = self.vit( pixel_values, head_mask=head_mask, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs ) sequence_output = outputs.last_hidden_state logits = self.classifier(sequence_output[:, 0, :]) loss = None if labels is not None: loss = self.loss_function(labels, logits, self.config, **kwargs) return ImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = ["ViTMSNModel", "ViTMSNForImageClassification", "ViTMSNPreTrainedModel"]
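# Note: the classification head above pools by taking the [CLS] token, i.e. `last_hidden_state[:, 0, :]`.
# A minimal feature-extraction sketch for linear probing or nearest-neighbor evaluation (illustrative;
# `image` is assumed to be a PIL image supplied by the caller):
#
#     from transformers import AutoImageProcessor, ViTMSNModel
#
#     processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
#     model = ViTMSNModel.from_pretrained("facebook/vit-msn-small").eval()
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         cls_features = model(**inputs).last_hidden_state[:, 0, :]  # (batch_size, hidden_size)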