# coding=utf-8
# Copyright 2021 NVIDIA The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SegFormer model."""

import math
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from .configuration_segformer import SegformerConfig


logger = logging.get_logger(__name__)


class SegFormerImageClassifierOutput(ImageClassifierOutput):
    """
    Base class for outputs of image classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Segformer
class SegformerDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class SegformerOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        # padding of `patch_size // 2` together with the given stride produces the patch projection
        self.proj = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )

        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        """
        Project `pixel_values` with the convolution and flatten the result into a token sequence.

        Returns:
            A tuple `(embeddings, height, width)` where `embeddings` has shape `(batch_size, height*width,
            hidden_size)` and `height`/`width` are the spatial dimensions of the convolution output.
        """
        embeddings = self.proj(pixel_values)
        _, _, height, width = embeddings.shape
        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
        # this can be fed to a Transformer layer
        embeddings = embeddings.flatten(2).transpose(1, 2)
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width


class SegformerEfficientSelfAttention(nn.Module):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://huggingface.co/papers/2102.12122)."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            # strided convolution that shrinks the key/value sequence, followed by a layer norm
            self.sr = nn.Conv2d(
                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
            )
            self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states,
        height,
        width,
        output_attentions=False,
    ):
        """
        Compute multi-head self-attention over `hidden_states` of shape `(batch_size, seq_len, num_channels)`.

        Queries are computed from the full sequence. When `sr_ratio > 1`, keys and values are computed from a
        spatially reduced copy of the sequence, which requires `height * width == seq_len`.

        Returns:
            `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is truthy.
        """
        batch_size, _, num_channels = hidden_states.shape
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        if self.sr_ratio > 1:
            # Reshape to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
            # Apply sequence reduction
            hidden_states = self.sr(hidden_states)
            # Reshape back to (batch_size, seq_len, num_channels)
            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
            hidden_states = self.layer_norm(hidden_states)

        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        # merge the heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class SegformerSelfOutput(nn.Module):
    """Linear projection followed by dropout, applied to the output of the self-attention."""

    def __init__(self, config, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # `input_tensor` is accepted but not used here: the residual connection is added in `SegformerLayer`.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class SegformerAttention(nn.Module):
    """Efficient self-attention followed by its output projection, with support for head pruning."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.self = SegformerEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the attention heads listed in `heads` from the query/key/value and output projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, height, width, output_attentions=False):
        """Run self-attention and the output projection; returns `(attention_output, *optional_attentions)`."""
        self_outputs = self.self(hidden_states, height, width, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class SegformerDWConv(nn.Module):
    """3x3 depthwise convolution (`groups=dim`) applied to a token sequence viewed as a 2D feature map."""

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        batch_size, seq_len, num_channels = hidden_states.shape
        # (batch_size, seq_len, num_channels) -> (batch_size, num_channels, height, width)
        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        hidden_states = self.dwconv(hidden_states)
        # back to (batch_size, seq_len, num_channels)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        return hidden_states


class SegformerMixFFN(nn.Module):
    """Feed-forward block: linear -> depthwise convolution -> activation -> dropout -> linear -> dropout."""

    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        # `hidden_features` defaults to None, which `nn.Linear` cannot accept; fall back to `in_features`
        # in the same way as `out_features`.
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = SegformerDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, height, width):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class SegformerLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.attention = SegformerAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.layer_norm_2 = nn.LayerNorm(hidden_size)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

    def forward(self, hidden_states, height, width, output_attentions=False):
        """
        Apply pre-norm self-attention and the Mix-FFN, each wrapped in a residual connection.

        Returns:
            `(layer_output,)`, followed by the attention probabilities when `output_attentions` is truthy.
        """
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
            height,
            width,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection (with stochastic depth)
        attention_output = self.drop_path(attention_output)
        hidden_states = attention_output + hidden_states

        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        # second residual connection (with stochastic depth)
        mlp_output = self.drop_path(mlp_output)
        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


class SegformerEncoder(nn.Module):
    """Hierarchical encoder: each stage is a patch embedding, a list of `SegformerLayer`s and a layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # stochastic depth decay rule
        drop_path_decays = [
            x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
        ]

        # patch embeddings
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                SegformerOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
            )
        self.patch_embeddings = nn.ModuleList(embeddings)

        # Transformer blocks
        blocks = []
        cur = 0  # index into `drop_path_decays` of the first layer of the current stage
        for i in range(config.num_encoder_blocks):
            # each block consists of layers
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    SegformerLayer(
                        config,
                        hidden_size=config.hidden_sizes[i],
                        num_attention_heads=config.num_attention_heads[i],
                        drop_path=drop_path_decays[cur + j],
                        sequence_reduction_ratio=config.sr_ratios[i],
                        mlp_ratio=config.mlp_ratios[i],
                    )
                )
            blocks.append(nn.ModuleList(layers))

        self.block = nn.ModuleList(blocks)

        # Layer norms
        self.layer_norm = nn.ModuleList(
            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run `pixel_values` through every encoder stage.

        The output of each stage is reshaped to `(batch_size, num_channels, height, width)`; the last stage is only
        reshaped when `config.reshape_last_stage` is truthy, otherwise it stays `(batch_size, height*width,
        num_channels)`.

        Returns:
            A `BaseModelOutput`, or, when `return_dict` is falsy, a tuple of the non-`None` values among the last
            hidden state, all hidden states and all attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        batch_size = pixel_values.shape[0]
        last_stage_idx = len(self.patch_embeddings) - 1

        hidden_states = pixel_values
        for idx, (embedding_layer, block_layer, norm_layer) in enumerate(
            zip(self.patch_embeddings, self.block, self.layer_norm)
        ):
            # first, obtain patch embeddings
            hidden_states, height, width = embedding_layer(hidden_states)
            # second, send embeddings through blocks
            for blk in block_layer:
                layer_outputs = blk(hidden_states, height, width, output_attentions)
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)
            # third, apply layer norm
            hidden_states = norm_layer(hidden_states)
            # fourth, optionally reshape back to (batch_size, num_channels, height, width)
            if idx != last_stage_idx or self.config.reshape_last_stage:
                hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


# Base class shared by all SegFormer models: declares the config type, the base model prefix and the main input
# name, and defines how module weights are initialized.
@auto_docstring
class SegformerPreTrainedModel(PreTrainedModel):
    config: SegformerConfig
    base_model_prefix = "segformer"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# Bare SegFormer encoder without any task-specific head.
@auto_docstring
class SegformerModel(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # hierarchical Transformer encoder
        self.encoder = SegformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel

        `layer_num` is the running index of a layer across all encoder stages, in stage order.
        """
        # The encoder keeps its layers in `block`, a ModuleList of per-stage ModuleLists (it has no `layer`
        # attribute), so flatten the stages in order to address every layer with a single index.
        flat_layers = [layer_module for stage in self.encoder.block for layer_module in stage]
        for layer, heads in heads_to_prune.items():
            flat_layers[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        # arguments left as None fall back to the values stored in the config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
    states) e.g. for ImageNet.
    """
)
class SegformerForImageClassification(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.segformer = SegformerModel(config)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SegFormerImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # convert last hidden states to (batch_size, height*width, hidden_size)
        batch_size = sequence_output.shape[0]
        if self.config.reshape_last_stage:
            # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
            sequence_output = sequence_output.permute(0, 2, 3, 1)
        sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])

        # global average pooling
        sequence_output = sequence_output.mean(dim=1)

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SegFormerImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class SegformerMLP(nn.Module):
    """
    Linear Embedding.
    """

    def __init__(self, config: SegformerConfig, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, config.decoder_hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # flatten everything after the channel dim and move channels last before projecting:
        # (batch_size, num_channels, height, width) -> (batch_size, height*width, num_channels)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class SegformerDecodeHead(SegformerPreTrainedModel):
    """All-MLP decode head: projects every encoder stage to a common width, upsamples, fuses and classifies."""

    def __init__(self, config):
        super().__init__(config)
        # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
        mlps = []
        for i in range(config.num_encoder_blocks):
            mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i])
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_hidden_size * config.num_encoder_blocks,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        """
        Compute segmentation logits from the per-stage encoder feature maps.

        Args:
            encoder_hidden_states: sequence with one hidden state per encoder stage. The first entry must be a 4D
                feature map; every stage is upsampled to its spatial size.

        Returns:
            Logits with `config.num_labels` channels at the spatial size of the first stage.
        """
        batch_size = encoder_hidden_states[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
            if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
                # A non-reshaped stage is (batch_size, seq_len, num_channels): the spatial size must be derived
                # from the sequence length (dim 1), not from the channel dim. Assumes a square feature map.
                height = width = int(math.sqrt(encoder_hidden_state.shape[1]))
                encoder_hidden_state = (
                    encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
                )

            # unify channel dimension
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # upsample
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
            )
            all_hidden_states += (encoder_hidden_state,)

        # stages are concatenated along the channel dim in reverse order (last stage first)
        hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # logits are of shape (batch_size, num_labels, height/4, width/4)
        logits = self.classifier(hidden_states)

        return logits


@auto_docstring(
    custom_intro="""
    SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.
    """
)
class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.segformer = SegformerModel(config)
        self.decode_head = SegformerDecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        >>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
        >>> list(logits.shape)
        [1, 150, 128, 128]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None and self.config.num_labels < 1:
            # the message states the bound that the condition above actually enforces
            raise ValueError(f"Number of labels should be >=1: {self.config.num_labels}")

        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.decode_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            if self.config.num_labels > 1:
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                loss = loss_fct(upsampled_logits, labels)
            elif self.config.num_labels == 1:
                # binary case: zero out the per-pixel loss for negative labels and for the ignore index
                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
                loss_fct = BCEWithLogitsLoss(reduction="none")
                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
                loss = (loss * valid_mask).mean()

        if not return_dict:
            # hidden states were always requested from the encoder; only expose them when the caller asked
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = [
    "SegformerDecodeHead",
    "SegformerForImageClassification",
    "SegformerForSemanticSegmentation",
    "SegformerLayer",
    "SegformerModel",
    "SegformerPreTrainedModel",
]