# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LeViT model."""

import itertools
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_levit import LevitConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LevitForImageClassificationWithTeacher`].
    """
)
class LevitForImageClassificationWithTeacherOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores as the average of the `cls_logits` and `distillation_logits`.
    cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
        class token).
    distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
        distillation token).
    """

    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None


def _relative_position_bias_indices(
    query_resolution: int, key_resolution: int, stride: int = 1
) -> tuple[int, torch.Tensor]:
    """
    Build the lookup table that maps every (query position, key position) pair to a shared attention-bias slot.

    Positions are the cells of a `query_resolution x query_resolution` grid (queries) and of a
    `key_resolution x key_resolution` grid (keys). Query coordinates are multiplied by `stride` before being compared
    with key coordinates; pairs with the same absolute (row, column) offset share one slot. Slots are numbered in
    order of first appearance while iterating queries in the outer loop and keys in the inner loop.

    Args:
        query_resolution: Side length of the query grid.
        key_resolution: Side length of the key grid.
        stride: Factor applied to query coordinates (1 when queries and keys live on the same grid).

    Returns:
        A tuple `(num_offsets, indices)` where `num_offsets` is the number of distinct offsets and `indices` is a
        `torch.long` tensor of shape `(query_resolution**2, key_resolution**2)` holding the slot of every pair.
    """
    query_points = list(itertools.product(range(query_resolution), repeat=2))
    key_points = list(itertools.product(range(key_resolution), repeat=2))

    offset_to_slot = {}
    indices = []
    for query_row, query_col in query_points:
        for key_row, key_col in key_points:
            offset = (abs(query_row * stride - key_row), abs(query_col * stride - key_col))
            # `len(offset_to_slot)` is evaluated before insertion, so a new offset gets the next free slot.
            indices.append(offset_to_slot.setdefault(offset, len(offset_to_slot)))

    index_table = torch.tensor(indices, dtype=torch.long).view(len(query_points), len(key_points))
    return len(offset_to_slot), index_table


class LevitConvEmbeddings(nn.Module):
    """
    LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.

    A bias-free `nn.Conv2d` followed by `nn.BatchNorm2d` over the output channels.

    Note: `bn_weight_init` is accepted for signature compatibility but is not used by this implementation.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
    ):
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, embeddings):
        """Apply the convolution, then batch normalization, and return the result."""
        embeddings = self.convolution(embeddings)
        embeddings = self.batch_norm(embeddings)
        return embeddings


class LevitPatchEmbeddings(nn.Module):
    """
    LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
    `LevitConvEmbeddings`.

    Four conv+BN layers widen the channels from `config.num_channels` to `config.hidden_sizes[0]`
    (via 1/8, 1/4 and 1/2 of that width), with a Hardswish after each of the first three.
    """

    def __init__(self, config):
        super().__init__()
        self.embedding_layer_1 = LevitConvEmbeddings(
            config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_1 = nn.Hardswish()

        self.embedding_layer_2 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_2 = nn.Hardswish()

        self.embedding_layer_3 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_3 = nn.Hardswish()

        self.embedding_layer_4 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
        )
        self.num_channels = config.num_channels

    def forward(self, pixel_values):
        """
        Embed `pixel_values` and return a token sequence.

        Raises:
            ValueError: If dimension 1 of `pixel_values` differs from `config.num_channels`.

        Returns:
            The output of the last conv layer with all dimensions from index 2 on flattened and then swapped with
            the channel dimension, i.e. `(batch, tokens, channels)`.
        """
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.embedding_layer_1(pixel_values)
        embeddings = self.activation_layer_1(embeddings)
        embeddings = self.embedding_layer_2(embeddings)
        embeddings = self.activation_layer_2(embeddings)
        embeddings = self.embedding_layer_3(embeddings)
        embeddings = self.activation_layer_3(embeddings)
        embeddings = self.embedding_layer_4(embeddings)
        return embeddings.flatten(2).transpose(1, 2)


class MLPLayerWithBN(nn.Module):
    """
    Bias-free linear layer followed by `nn.BatchNorm1d` over the output features.

    Note: `bn_weight_init` is accepted for signature compatibility but is not used by this implementation.
    """

    def __init__(self, input_dim, output_dim, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
        self.batch_norm = nn.BatchNorm1d(output_dim)

    def forward(self, hidden_state):
        """Project `hidden_state`, then batch-normalize it with its first two dimensions merged."""
        hidden_state = self.linear(hidden_state)
        # BatchNorm1d normalizes over dim 1, so merge the two leading dims for the call and restore them after.
        hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
        return hidden_state


class LevitSubsample(nn.Module):
    """Spatially subsample a token sequence by keeping every `stride`-th row and column of its square grid."""

    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, hidden_state):
        """
        Subsample a `(batch, resolution * resolution, channels)` tensor.

        The sequence is viewed as a `resolution x resolution` grid, strided along both grid axes, and flattened
        back to `(batch, kept_tokens, channels)`.
        """
        batch_size, _, channels = hidden_state.shape
        hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
            :, :: self.stride, :: self.stride
        ].reshape(batch_size, -1, channels)
        return hidden_state


class LevitAttention(nn.Module):
    """
    Multi-head self-attention with a learned bias per head and per relative (row, column) offset.

    Queries, keys and values come from one `MLPLayerWithBN`; per head, queries and keys have `key_dim` features
    and values have `attention_ratio * key_dim` features.
    """

    def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads

        self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
        self.activation = nn.Hardswish()
        self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)

        num_offsets, bias_indices = _relative_position_bias_indices(resolution, resolution)

        # Per-device cache of the gathered biases, only used in eval mode (see `get_attention_biases`).
        self.attention_bias_cache = {}
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, num_offsets))
        self.register_buffer("attention_bias_idxs", bias_indices, persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Switch training mode and drop cached biases when entering training; returns `self` like `nn.Module`."""
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache
        # `nn.Module.train` returns the module so calls can be chained; keep that contract.
        return self

    def get_attention_biases(self, device):
        """
        Return the biases gathered into a `(num_heads, num_queries, num_keys)` tensor.

        In training mode the gather is recomputed on every call; otherwise it is computed once per `str(device)`
        and served from `attention_bias_cache`.
        """
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]

        device_key = str(device)
        if device_key not in self.attention_bias_cache:
            self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.attention_bias_cache[device_key]

    def forward(self, hidden_state):
        """Apply biased multi-head self-attention to a `(batch, seq_length, hidden)` tensor."""
        batch_size, seq_length, _ = hidden_state.shape
        queries_keys_values = self.queries_keys_values(hidden_state)
        query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
            [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
        )
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        attention = attention.softmax(dim=-1)
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
        hidden_state = self.projection(self.activation(hidden_state))
        return hidden_state


class LevitAttentionSubsample(nn.Module):
    """
    Attention layer that reduces the number of tokens.

    Keys and values are computed from every input token, while queries are computed from the spatially subsampled
    input (`LevitSubsample`), so the output has `resolution_out**2` tokens of width `output_dim`.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        key_dim,
        num_attention_heads,
        attention_ratio,
        stride,
        resolution_in,
        resolution_out,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
        self.resolution_out = resolution_out
        # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
        self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
        self.queries_subsample = LevitSubsample(stride, resolution_in)
        self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
        self.activation = nn.Hardswish()
        self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)

        # Per-device cache of the gathered biases, only used in eval mode (see `get_attention_biases`).
        self.attention_bias_cache = {}

        # Queries live on the subsampled grid; scaling their coordinates by `stride` puts them on the key grid.
        num_offsets, bias_indices = _relative_position_bias_indices(resolution_out, resolution_in, stride)
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, num_offsets))
        self.register_buffer("attention_bias_idxs", bias_indices, persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Switch training mode and drop cached biases when entering training; returns `self` like `nn.Module`."""
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache
        # `nn.Module.train` returns the module so calls can be chained; keep that contract.
        return self

    def get_attention_biases(self, device):
        """
        Return the biases gathered into a `(num_heads, num_queries, num_keys)` tensor.

        In training mode the gather is recomputed on every call; otherwise it is computed once per `str(device)`
        and served from `attention_bias_cache`.
        """
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]

        device_key = str(device)
        if device_key not in self.attention_bias_cache:
            self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.attention_bias_cache[device_key]

    def forward(self, hidden_state):
        """Attend from the subsampled queries to all input tokens of a `(batch, seq_length, input_dim)` tensor."""
        batch_size, seq_length, _ = hidden_state.shape
        key, value = (
            self.keys_values(hidden_state)
            .view(batch_size, seq_length, self.num_attention_heads, -1)
            .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
        )
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        query = self.queries(self.queries_subsample(hidden_state))
        query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
            0, 2, 1, 3
        )

        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        attention = attention.softmax(dim=-1)
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
        hidden_state = self.projection(self.activation(hidden_state))
        return hidden_state


class LevitMLPLayer(nn.Module):
    """
    MLP Layer with `2X` expansion in contrast to ViT with `4X`.

    Computes `linear_down(hardswish(linear_up(x)))`; the expansion width is whatever `hidden_dim` the caller passes.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
        self.activation = nn.Hardswish()
        self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)

    def forward(self, hidden_state):
        """Expand, activate and project `hidden_state` back to `input_dim` features."""
        hidden_state = self.linear_up(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.linear_down(hidden_state)
        return hidden_state


class LevitResidualLayer(nn.Module):
    """
    Residual Block for LeViT

    Returns `x + module(x)`. In training mode with `drop_rate > 0`, the branch output of each sample is zeroed
    with probability `drop_rate` and otherwise rescaled by `1 / (1 - drop_rate)`.
    """

    def __init__(self, module, drop_rate):
        super().__init__()
        self.module = module
        self.drop_rate = drop_rate

    def forward(self, hidden_state):
        """Add the (possibly stochastically dropped) branch output to `hidden_state`."""
        if self.training and self.drop_rate > 0:
            # One keep/drop decision per sample, broadcast over the two remaining dimensions.
            keep_mask = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
            keep_mask = keep_mask.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
            branch_output = self.module(hidden_state)
            # `torch.rand` uses the default dtype; cast the mask so a half-precision branch is not promoted
            # to float32 by the multiplication.
            return hidden_state + branch_output * keep_mask.to(branch_output.dtype)

        return hidden_state + self.module(hidden_state)


class LevitStage(nn.Module):
    """
    LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.

    A stage holds `depths` residual attention layers, each followed by a residual MLP when `mlp_ratio > 0`. When
    `down_ops[0] == "Subsample"`, a `LevitAttentionSubsample` (and, if `down_ops[4] > 0`, one more residual MLP) is
    appended, using `down_ops` as `[op, key_dim, num_heads, attention_ratio, mlp_ratio, stride]`.
    """

    def __init__(
        self,
        config,
        idx,
        hidden_sizes,
        key_dim,
        depths,
        num_attention_heads,
        attention_ratio,
        mlp_ratio,
        down_ops,
        resolution_in,
    ):
        super().__init__()
        self.config = config
        self.resolution_in = resolution_in
        # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
        drop_path_rate = self.config.drop_path_rate

        layers = []
        for _ in range(depths):
            layers.append(
                LevitResidualLayer(
                    LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
                    drop_path_rate,
                )
            )
            if mlp_ratio > 0:
                hidden_dim = hidden_sizes * mlp_ratio
                layers.append(LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), drop_path_rate))

        if down_ops[0] == "Subsample":
            self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
            layers.append(
                LevitAttentionSubsample(
                    # (input_dim, output_dim): widths of this stage and of the next one.
                    *self.config.hidden_sizes[idx : idx + 2],
                    key_dim=down_ops[1],
                    num_attention_heads=down_ops[2],
                    attention_ratio=down_ops[3],
                    stride=down_ops[5],
                    resolution_in=resolution_in,
                    resolution_out=self.resolution_out,
                )
            )
            # From here on `resolution_in` holds the resolution handed to the next stage.
            self.resolution_in = self.resolution_out
            if down_ops[4] > 0:
                hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
                layers.append(
                    LevitResidualLayer(LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), drop_path_rate)
                )

        self.layers = nn.ModuleList(layers)

    def get_resolution(self):
        """Return the resolution of this stage's output (the downsampled one if the stage subsamples)."""
        return self.resolution_in

    def forward(self, hidden_state):
        """Run `hidden_state` through every layer of the stage in order."""
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class LevitEncoder(nn.Module):
    """
    LeViT Encoder consisting of multiple `LevitStage` stages.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        resolution = self.config.image_size // self.config.patch_size

        # Pad a *copy* of `down_ops` with a no-op entry for the trailing stage. Appending to `config.down_ops`
        # itself would grow the caller's config every time an encoder is built from it.
        down_ops = [*self.config.down_ops, [""]]

        stages = []
        for stage_idx in range(len(config.depths)):
            stage = LevitStage(
                config,
                stage_idx,
                config.hidden_sizes[stage_idx],
                config.key_dim[stage_idx],
                config.depths[stage_idx],
                config.num_attention_heads[stage_idx],
                config.attention_ratio[stage_idx],
                config.mlp_ratio[stage_idx],
                down_ops[stage_idx],
                resolution,
            )
            # Each stage reports the resolution the next stage has to work with.
            resolution = stage.get_resolution()
            stages.append(stage)

        self.stages = nn.ModuleList(stages)

    def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
        """
        Run all stages.

        When `output_hidden_states` is set, the input of every stage plus the final output are collected. Returns
        a `BaseModelOutputWithNoAttention`, or a tuple of the non-`None` values when `return_dict` is false.
        """
        all_hidden_states = () if output_hidden_states else None

        for stage in self.stages:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)
            hidden_state = stage(hidden_state)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)
        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)


class LevitClassificationLayer(nn.Module):
    """
    LeViT Classification Layer

    `nn.BatchNorm1d` over the input features followed by a linear projection to `output_dim` logits.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_state):
        """Normalize `hidden_state` and project it to logits."""
        hidden_state = self.batch_norm(hidden_state)
        logits = self.linear(hidden_state)
        return logits


@auto_docstring
class LevitPreTrainedModel(PreTrainedModel):
    config: LevitConfig
    base_model_prefix = "levit"
    main_input_name = "pixel_values"
    _no_split_modules = ["LevitResidualLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@auto_docstring
class LevitModel(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.patch_embeddings = LevitPatchEmbeddings(config)
        self.encoder = LevitEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # Explicit arguments win; `None` falls back to the configuration.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embeddings = self.patch_embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embeddings,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
        pooled_output = last_hidden_state.mean(dim=1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """
)
class LevitForImageClassification(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.levit = LevitModel(config)

        # Classifier head
        self.classifier = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # Average the last hidden state over the token dimension before classifying.
        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            # outputs[0] / outputs[1] are the last hidden state and pooled output; keep only what follows.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
    a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning::
           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    """
)
class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.levit = LevitModel(config)

        # Classifier head
        self.classifier = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )
        self.classifier_distill = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, LevitForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # Average the last hidden state over the token dimension; both heads read the same pooled features.
        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)
        cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
        logits = (cls_logits + distill_logits) / 2

        if not return_dict:
            # outputs[0] / outputs[1] are the last hidden state and pooled output; keep only what follows.
            output = (logits, cls_logits, distill_logits) + outputs[2:]
            return output

        return LevitForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distill_logits,
            hidden_states=outputs.hidden_states,
        )


__all__ = [
    "LevitForImageClassification",
    "LevitForImageClassificationWithTeacher",
    "LevitModel",
    "LevitPreTrainedModel",
]