# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViT model."""
import collections.abc
import math
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
class ViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: ViTConfig, use_mask_token: bool = False):
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class ViTPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config: ViTConfig):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
# Normalize the attention scores to probabilities.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
# Mask heads if we want to
if attention_mask is not None:
attn_weights = attn_weights * attention_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class ViTSelfAttention(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
def forward(
self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, attention_probs
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTAttention(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: set[int]):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
self_attn_output, _ = self.attention(hidden_states, head_mask)
output = self.output(self_attn_output, hidden_states)
return output
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTOutput(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTLayer(GradientCheckpointingLayer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTConfig):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTAttention(config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states_norm = self.layernorm_before(hidden_states)
attention_output = self.attention(hidden_states_norm, head_mask)
# first residual connection
hidden_states = attention_output + hidden_states
# in ViT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
# second residual connection is done here
layer_output = self.output(layer_output, hidden_states)
return layer_output
class ViTEncoder(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(hidden_states, layer_head_mask)
return BaseModelOutput(last_hidden_state=hidden_states)
@auto_docstring
class ViTPreTrainedModel(PreTrainedModel):
config: ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ViTLayer,
"attentions": ViTSelfAttention,
}
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)
if module.mask_token is not None:
module.mask_token.data.zero_()
@auto_docstring
class ViTModel(ViTPreTrainedModel):
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
r"""
add_pooling_layer (bool, *optional*, defaults to `True`):
Whether to add a pooling layer
use_mask_token (`bool`, *optional*, defaults to `False`):
Whether to use a mask token for masked image modeling.
"""
super().__init__(config)
self.config = config
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = ViTEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViTPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> ViTPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: dict[int, list[int]]):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@check_model_inputs
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
sequence_output = encoder_outputs.last_hidden_state
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
class ViTPooler(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
self.activation = ACT2FN[config.pooler_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
@auto_docstring(
custom_intro="""
ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
"""
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
def __init__(self, config: ViTConfig):
super().__init__(config)
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=config.hidden_size,
out_channels=config.encoder_stride**2 * config.num_channels,
kernel_size=1,
),
nn.PixelShuffle(config.encoder_stride),
)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> MaskedImageModelingOutput:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Examples:
```python
>>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
```"""
if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
raise ValueError(
"When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
"the reconstructed image has the same dimensions as the input. "
f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
)
outputs: BaseModelOutputWithPooling = self.vit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
**kwargs,
)
sequence_output = outputs.last_hidden_state
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
reconstructed_pixel_values = self.decoder(sequence_output)
masked_im_loss = None
if bool_masked_pos is not None:
size = self.config.image_size // self.config.patch_size
bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
mask = (
bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
.repeat_interleave(self.config.patch_size, 2)
.unsqueeze(1)
.contiguous()
)
reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
return MaskedImageModelingOutput(
loss=masked_im_loss,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@auto_docstring(
custom_intro="""
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
"""
)
class ViTForImageClassification(ViTPreTrainedModel):
def __init__(self, config: ViTConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.vit = ViTModel(config, add_pooling_layer=False)
# Classifier head
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> ImageClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
outputs: BaseModelOutputWithPooling = self.vit(
pixel_values,
head_mask=head_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
**kwargs,
)
sequence_output = outputs.last_hidden_state
pooled_output = sequence_output[:, 0, :]
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss = self.loss_function(labels, logits, self.config, **kwargs)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"]