# coding=utf-8 # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Mask2Former model.""" import math import warnings from dataclasses import dataclass from typing import Optional, Union import numpy as np import torch from torch import Tensor, nn from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import auto_docstring, is_accelerate_available, logging from ...utils.backbone_utils import load_backbone from .configuration_mask2former import Mask2FormerConfig if is_scipy_available(): from scipy.optimize import linear_sum_assignment if is_accelerate_available(): from accelerate import PartialState from accelerate.utils import reduce logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Mask2Former's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns the mask features and the multiscale features. """ ) class Mask2FormerPixelDecoderOutput(ModelOutput): r""" multi_scale_features (`tuple(torch.FloatTensor)`): Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height, width)` from the Multi-Scale Deformable Attention based Pixel Decoder. mask_features (`torch.FloatTensor`): Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder Layer. attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights from the pixel decoder. Returned when `output_attentions=True` is passed or when `config.output_attentions=True`. """ multi_scale_features: Optional[tuple[torch.FloatTensor]] = None mask_features: Optional[torch.FloatTensor] = None attentions: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Base class for outputs of the Transformer decoder. This class adds two attributes to BaseModelOutputWithCrossAttentions for mask prediction logits and a tuple of intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. """ ) class Mask2FormerMaskedAttentionDecoderOutput(BaseModelOutputWithCrossAttentions): r""" hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. Returned when `output_hidden_states=True`.
attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Returned when `output_attentions=True`. masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`): Tuple of mask predictions from all layers of the transformer decoder. intermediate_hidden_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. """ last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[torch.FloatTensor] = None masks_queries_logits: Optional[tuple[torch.FloatTensor]] = None intermediate_hidden_states: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Mask2Former's pixel level module output. It returns the output of the encoder (optional) and all hidden states (multi-scale features) from the `decoder`. By default, the `encoder` is a Swin Backbone and the `decoder` is a Multi-Scale Deformable Attention based decoder. The `decoder_last_hidden_state` are the **per-pixel embeddings** while `decoder_hidden_states` refer to multi-scale feature maps produced using **multi-scaling strategy** defined in the paper. """ ) class Mask2FormerPixelLevelModuleOutput(ModelOutput): r""" encoder_last_hidden_state (`torch.FloatTensor`): Last hidden states (final feature map of shape `(batch_size, num_channels, height, width)`) of the last stage of the encoder. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also called feature maps) of the model at the output of each stage. Returned if output_hidden_states is set to True. decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)): 1/4 scale features from the last Pixel Decoder Layer. decoder_hidden_states (`tuple(torch.FloatTensor)`): Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also called feature maps) of the model at the output of each stage. """ encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None decoder_last_hidden_state: Optional[torch.FloatTensor] = None decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Class for outputs of [`Mask2FormerModel`]. This class returns all the needed hidden states to compute the logits. """ ) class Mask2FormerModelOutput(ModelOutput): r""" encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Last hidden states (final feature map) of the last stage of the encoder model (backbone). Returned when `output_hidden_states=True` is passed. pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Last hidden states (final feature map) of the last stage of the pixel decoder model. transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`): Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`. 
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder model at the output of each stage. Returned when `output_hidden_states=True` is passed. pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel decoder model at the output of each stage. Returned when `output_hidden_states=True` is passed. transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the transformer decoder at the output of each stage. Returned when `output_hidden_states=True` is passed. transformer_decoder_intermediate_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`): Mask predictions from each layer in the transformer decoder. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed): Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Self attention weights from the transformer decoder. """ encoder_last_hidden_state: Optional[torch.FloatTensor] = None pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None pixel_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None transformer_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None transformer_decoder_intermediate_states: Optional[tuple[torch.FloatTensor]] = None masks_queries_logits: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Class for outputs of [`Mask2FormerForUniversalSegmentation`]. This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see [`~Mask2FormerImageProcessor`] for details regarding usage. """ ) class Mask2FormerForUniversalSegmentationOutput(ModelOutput): r""" loss (`torch.Tensor`, *optional*): The computed loss, returned when labels are present. class_queries_logits (`torch.FloatTensor`): A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each query. Note the `+ 1` is needed because we incorporate the null class. masks_queries_logits (`torch.FloatTensor`): A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each query.
auxiliary_logits (`list[Dict(str, torch.FloatTensor)]`, *optional*): List of class and mask predictions from each layer of the transformer decoder. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Last hidden states (final feature map) of the last stage of the encoder model (backbone). pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Last hidden states (final feature map) of the last stage of the pixel decoder model. transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`): Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder model at the output of each stage. pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel decoder model at the output of each stage. transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the transformer decoder at the output of each stage. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Self and Cross Attentions weights from transformer decoder. """ loss: Optional[torch.FloatTensor] = None class_queries_logits: Optional[torch.FloatTensor] = None masks_queries_logits: Optional[torch.FloatTensor] = None auxiliary_logits: Optional[list[dict[str, torch.FloatTensor]]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None pixel_decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None attentions: Optional[tuple[torch.FloatTensor]] = None # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py def sample_point( input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs ) -> torch.Tensor: """ A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. 
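Note that `torch.nn.functional.grid_sample` expects coordinates in `[-1, 1]`, so the `[0, 1]`-normalized `point_coordinates` are rescaled internally via `2.0 * point_coordinates - 1.0` before sampling.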
Args: input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): A tensor that contains feature maps on a height * width grid point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width, 2)): A tensor that contains [0, 1] * [0, 1] normalized point coordinates add_dim (`bool`): Boolean value to keep track of the added dimension Returns: point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, height_grid, width_grid)): A tensor that contains features for points in `point_coordinates`. """ if point_coordinates.dim() == 3: add_dim = True point_coordinates = point_coordinates.unsqueeze(2) # use nn.functional.grid_sample to get features for points in `point_coordinates` via bilinear interpolation point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) if add_dim: point_features = point_features.squeeze(3) return point_features # Copied from transformers.models.maskformer.modeling_maskformer.dice_loss def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: r""" Compute the DICE loss, similar to generalized IOU for masks, as follows: $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$ In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows: $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$ Args: inputs (`torch.Tensor`): A tensor representing a mask. labels (`torch.Tensor`): A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class). num_masks (`int`): The number of masks present in the current batch, used for normalization. Returns: `torch.Tensor`: The computed loss. """ probs = inputs.sigmoid().flatten(1) numerator = 2 * (probs * labels).sum(-1) denominator = probs.sum(-1) + labels.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) loss = loss.sum() / num_masks return loss def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor: r""" Args: inputs (`torch.Tensor`): A float tensor of arbitrary shape. labels (`torch.Tensor`): A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class). num_masks (`int`): The number of masks present in the current batch, used for normalization. Returns: loss (`torch.Tensor`): The computed loss. """ criterion = nn.BCEWithLogitsLoss(reduction="none") cross_entropy_loss = criterion(inputs, labels) loss = cross_entropy_loss.mean(1).sum() / num_masks return loss # Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: """ A pairwise version of the dice loss, see `dice_loss` for usage. Args: inputs (`torch.Tensor`): A tensor representing a mask. labels (`torch.Tensor`): A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class). Returns: `torch.Tensor`: The computed loss between each pair.
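The returned tensor has shape `(num_queries, num_targets)`, where entry `(i, j)` is the dice loss between predicted mask `i` and ground truth mask `j`.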
""" inputs = inputs.sigmoid().flatten(1) numerator = 2 * torch.matmul(inputs, labels.T) # using broadcasting to get a [num_queries, NUM_CLASSES] matrix denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] loss = 1 - (numerator + 1) / (denominator + 1) return loss def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: r""" A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. Args: inputs (`torch.Tensor`): A tensor representing a mask. labels (`torch.Tensor`): A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs (0 for the negative class and 1 for the positive class). Returns: loss (`torch.Tensor`): The computed loss between each pairs. """ height_and_width = inputs.shape[1] criterion = nn.BCEWithLogitsLoss(reduction="none") cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T) loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) loss = loss_pos + loss_neg return loss # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py class Mask2FormerHungarianMatcher(nn.Module): """This class computes an assignment between the labels and the predictions of the network. For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as non-objects). """ def __init__( self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 ): """Creates the matcher Params: cost_class (`float`, *optional*, defaults to 1.0): Relative weight of the classification error in the matching cost. cost_mask (`float`, *optional*, defaults to 1.0): This is the relative weight of the focal loss of the binary mask in the matching cost. cost_dice (`float`, *optional*, defaults to 1.0): This is the relative weight of the dice loss of the binary mask in the matching cost. num_points (`int`, *optional*, defaults to 12544): No. of points to sample on which the mask loss will be calculated. The same set of K points are uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite matching. """ super().__init__() if cost_class == 0 and cost_mask == 0 and cost_dice == 0: raise ValueError("All costs can't be 0") self.num_points = num_points self.cost_class = cost_class self.cost_mask = cost_mask self.cost_dice = cost_dice @torch.no_grad() def forward( self, masks_queries_logits: torch.Tensor, class_queries_logits: torch.Tensor, mask_labels: torch.Tensor, class_labels: torch.Tensor, ) -> list[tuple[Tensor]]: """ Params: masks_queries_logits (`torch.Tensor`): A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. class_queries_logits (`torch.Tensor`): A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. class_labels (`torch.Tensor`): A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels. mask_labels (`torch.Tensor`): A tensor of dim `num_target_boxes, height, width` containing the target masks. 
Returns: matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) where: - index_i is the indices of the selected predictions (in order) - index_j is the indices of the corresponding selected labels (in order) For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes). """ indices: list[tuple[np.array]] = [] # iterate through batch size batch_size = masks_queries_logits.shape[0] for i in range(batch_size): pred_probs = class_queries_logits[i].softmax(-1) pred_mask = masks_queries_logits[i] # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted. cost_class = -pred_probs[:, class_labels[i]] target_mask = mask_labels[i].to(pred_mask) target_mask = target_mask[:, None] pred_mask = pred_mask[:, None] # Sample ground truth and predicted masks point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device) target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1) target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1) pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1) pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1) # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels) cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels) cost_dice = pair_wise_dice_loss(pred_mask, target_mask) # final cost matrix cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible`` cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10)) cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10)) cost_matrix = torch.nan_to_num(cost_matrix, 0) # do the assignment using the hungarian algorithm in scipy assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu()) indices.append(assigned_indices) # It could be stacked in one tensor matched_indices = [ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices ] return matched_indices # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py class Mask2FormerLoss(nn.Module): def __init__(self, config: Mask2FormerConfig, weight_dict: dict[str, float]): """ The Mask2Former Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and mask) Args: config (`Mask2FormerConfig`): The configuration for Mask2Former model also containing loss calculation specific parameters. weight_dict (`dict[str, float]`): A dictionary of weights to be applied to the different losses. 
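For example (illustrative of the usual setup rather than a guaranteed contract), [`Mask2FormerForUniversalSegmentation`] typically builds this dictionary from the config as `{"loss_cross_entropy": config.class_weight, "loss_mask": config.mask_weight, "loss_dice": config.dice_weight}`.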
""" super().__init__() requires_backends(self, ["scipy"]) self.num_labels = config.num_labels self.weight_dict = weight_dict # Weight to apply to the null class self.eos_coef = config.no_object_weight empty_weight = torch.ones(self.num_labels + 1) empty_weight[-1] = self.eos_coef self.register_buffer("empty_weight", empty_weight) # pointwise mask loss parameters self.num_points = config.train_num_points self.oversample_ratio = config.oversample_ratio self.importance_sample_ratio = config.importance_sample_ratio self.matcher = Mask2FormerHungarianMatcher( cost_class=config.class_weight, cost_dice=config.dice_weight, cost_mask=config.mask_weight, num_points=self.num_points, ) def _max_by_axis(self, sizes: list[list[int]]) -> list[int]: maxes = sizes[0] for sublist in sizes[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes # Adapted from nested_tensor_from_tensor_list() in original implementation def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]: # get the maximum size in the batch max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) # compute final size batch_shape = [len(tensors)] + max_size batch_size, _, height, width = batch_shape dtype = tensors[0].dtype device = tensors[0].device padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) # pad the tensors to the size of the biggest one for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) padding_mask[: tensor.shape[1], : tensor.shape[2]] = False return padded_tensors, padding_masks def loss_labels( self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array] ) -> dict[str, Tensor]: """Compute the losses related to the labels using cross entropy. Args: class_queries_logits (`torch.Tensor`): A tensor of shape `batch_size, num_queries, num_labels` class_labels (`list[torch.Tensor]`): List of class labels of shape `(labels)`. indices (`tuple[np.array])`: The indices computed by the Hungarian matcher. Returns: `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. """ pred_logits = class_queries_logits batch_size, num_queries, _ = pred_logits.shape criterion = nn.CrossEntropyLoss(weight=self.empty_weight) idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) target_classes_o = torch.cat( [target[j] for target, (_, j) in zip(class_labels, indices)] ) # shape of (batch_size, num_queries) target_classes = torch.full( (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device ) target_classes[idx] = target_classes_o # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) pred_logits_transposed = pred_logits.transpose(1, 2) loss_ce = criterion(pred_logits_transposed, target_classes) losses = {"loss_cross_entropy": loss_ce} return losses def loss_masks( self, masks_queries_logits: torch.Tensor, mask_labels: list[torch.Tensor], indices: tuple[np.array], num_masks: int, ) -> dict[str, torch.Tensor]: """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. 
Args: masks_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, height, width)`. mask_labels (`list[torch.Tensor]`): List of mask labels of shape `(labels, height, width)`. indices (`tuple[np.array]`): The indices computed by the Hungarian matcher. num_masks (`int`): The number of masks, used for normalization. Returns: losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth masks. - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks. """ src_idx = self._get_predictions_permutation_indices(indices) tgt_idx = self._get_targets_permutation_indices(indices) # shape (batch_size * num_queries, height, width) pred_masks = masks_queries_logits[src_idx] # shape (batch_size, num_queries, height, width) # pad all and stack the targets to the num_labels dimension target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) target_masks = target_masks[tgt_idx] # No need to upsample predictions as we are using normalized coordinates pred_masks = pred_masks[:, None] target_masks = target_masks[:, None] # Sample point coordinates with torch.no_grad(): point_coordinates = self.sample_points_using_uncertainty( pred_masks, lambda logits: self.calculate_uncertainty(logits), self.num_points, self.oversample_ratio, self.importance_sample_ratio, ) point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) losses = { "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), "loss_dice": dice_loss(point_logits, point_labels, num_masks), } del pred_masks del target_masks return losses def _get_predictions_permutation_indices(self, indices): # Permute predictions following indices batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) predictions_indices = torch.cat([src for (src, _) in indices]) return batch_indices, predictions_indices def _get_targets_permutation_indices(self, indices): # Permute labels following indices batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) target_indices = torch.cat([tgt for (_, tgt) in indices]) return batch_indices, target_indices def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: """ In the Mask2Former paper, uncertainty is estimated as the L1 distance between 0.0 and the logit prediction in `logits` for the foreground class. Args: logits (`torch.Tensor`): A tensor of shape (R, 1, ...) of class-agnostic (or class-specific) predictions, where R is the total number of predicted masks in all images. The values are logits. Returns: scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most uncertain locations having the highest uncertainty score. """ uncertainty_scores = -(torch.abs(logits)) return uncertainty_scores def sample_points_using_uncertainty( self, logits: torch.Tensor, uncertainty_function, num_points: int, oversample_ratio: int, importance_sample_ratio: float, ) -> torch.Tensor: """ This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The uncertainty is calculated for each point using the passed `uncertainty_function` that takes the points' logit predictions as input.
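Concretely, `num_points * oversample_ratio` candidate points are first drawn uniformly at random, the `importance_sample_ratio * num_points` most uncertain candidates are kept, and the remaining points are re-drawn uniformly at random.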
Args: logits (`torch.Tensor`): Logit predictions for P points. uncertainty_function: A function that takes logit predictions for P points and returns their uncertainties. num_points (`int`): The number of points P to sample. oversample_ratio (`int`): Oversampling parameter. importance_sample_ratio (`float`): Ratio of points that are sampled via importance sampling. Returns: point_coordinates (`torch.Tensor`): Coordinates for P sampled points. """ num_boxes = logits.shape[0] num_points_sampled = int(num_points * oversample_ratio) # Get random point coordinates point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) # Get sampled prediction value for the point coordinates point_logits = sample_point(logits, point_coordinates, align_corners=False) # Calculate the uncertainties based on the sampled prediction values of the points point_uncertainties = uncertainty_function(point_logits) num_uncertain_points = int(importance_sample_ratio * num_points) num_random_points = num_points - num_uncertain_points idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) idx += shift[:, None] point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) if num_random_points > 0: point_coordinates = torch.cat( [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], dim=1, ) return point_coordinates def forward( self, masks_queries_logits: torch.Tensor, class_queries_logits: torch.Tensor, mask_labels: list[torch.Tensor], class_labels: list[torch.Tensor], auxiliary_predictions: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """ This performs the loss computation. Args: masks_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, height, width)`. class_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, num_labels)`. mask_labels (`list[torch.Tensor]`): List of mask labels of shape `(labels, height, width)`. class_labels (`list[torch.Tensor]`): List of class labels of shape `(labels)`. auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*): If `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], then it contains the logits from the inner layers of the Mask2FormerMaskedAttentionDecoder. Returns: losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys: - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth masks. - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks. If `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], the dictionary contains additional losses for each auxiliary prediction.
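The keys of the auxiliary losses are suffixed with the index of the corresponding intermediate decoder layer, e.g. `loss_mask_0` and `loss_dice_0` for the first auxiliary prediction.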
""" # retrieve the matching between the outputs of the last layer and the labels indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) # compute the average number of target masks for normalization purposes num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) # get all the losses losses: dict[str, Tensor] = { **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), **self.loss_labels(class_queries_logits, class_labels, indices), } # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. if auxiliary_predictions is not None: for idx, aux_outputs in enumerate(auxiliary_predictions): masks_queries_logits = aux_outputs["masks_queries_logits"] class_queries_logits = aux_outputs["class_queries_logits"] loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} losses.update(loss_dict) return losses def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: """ Computes the average number of target masks across the batch, for normalization purposes. """ num_masks = sum(len(classes) for classes in class_labels) num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device) world_size = 1 if is_accelerate_available(): if PartialState._shared_state != {}: num_masks = reduce(num_masks) world_size = PartialState().num_processes num_masks = torch.clamp(num_masks / world_size, min=1) return num_masks # Copied from transformers.models.oneformer.modeling_oneformer.multi_scale_deformable_attention def multi_scale_deformable_attention( value: Tensor, value_spatial_shapes: Union[Tensor, list[tuple]], sampling_locations: Tensor, attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width # -> batch_size*num_heads, hidden_dim, height, width value_l_ = ( value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) ) # batch_size, num_queries, num_heads, num_points, 2 # -> batch_size, num_heads, num_queries, num_points, 2 # -> batch_size*num_heads, num_queries, num_points, 2 sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) # batch_size*num_heads, hidden_dim, num_queries, num_points sampling_value_l_ = nn.functional.grid_sample( value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False ) sampling_value_list.append(sampling_value_l_) # (batch_size, num_queries, num_heads, num_levels, num_points) # -> (batch_size, num_heads, num_queries, num_levels, num_points) # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) attention_weights = attention_weights.transpose(1, 2).reshape( batch_size * num_heads, 1, num_queries, num_levels * num_points ) output = ( (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) .sum(-1) .view(batch_size, num_heads * hidden_dim, num_queries) ) return output.transpose(1, 2).contiguous() # Copied from 
transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with MaskFormer->Mask2Former class Mask2FormerSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__( self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None ): super().__init__() if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize self.scale = 2 * math.pi if scale is None else scale @compile_compatible_method_lru_cache(maxsize=1) def forward( self, shape: torch.Size, device: Union[torch.device, str], dtype: torch.dtype, mask: Optional[Tensor] = None, ) -> Tensor: if mask is None: mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) not_mask = (~mask).to(dtype) y_embed = not_mask.cumsum(1) x_embed = not_mask.cumsum(2) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos # Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module): """ Multiscale deformable attention as proposed in Deformable DETR. """ def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): super().__init__() if embed_dim % num_heads != 0: raise ValueError( f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" ) dim_per_head = embed_dim // num_heads # check if dim_per_head is power of 2 if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): warnings.warn( "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" " implementation." 
) self.im2col_step = 128 self.d_model = embed_dim self.n_levels = n_levels self.n_heads = num_heads self.n_points = n_points self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) self.value_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim) def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): return tensor if position_embeddings is None else tensor + position_embeddings def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states=None, encoder_attention_mask=None, position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): # add position embeddings to the hidden states before projecting to queries and keys if position_embeddings is not None: hidden_states = self.with_pos_embed(hidden_states, position_embeddings) batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape total_elements = sum(height * width for height, width in spatial_shapes_list) if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" ) value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 ) attention_weights = self.attention_weights(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels * self.n_points ) attention_weights = nn.functional.softmax(attention_weights, -1).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.tensor( [[shape[1], shape[0]] for shape in spatial_shapes_list], dtype=torch.long, device=reference_points.device, ) sampling_locations = ( reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :] ) elif reference_points.shape[-1] == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights) output = self.output_proj(output) return output, attention_weights class Mask2FormerPixelDecoderEncoderLayer(nn.Module): def __init__(self, config: Mask2FormerConfig): super().__init__() self.embed_dim = config.feature_size self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, n_levels=3, n_points=4, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = nn.functional.relu self.activation_dropout = config.dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim) self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim) 
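# fc1 and fc2 form the feed-forward block; final_layer_norm below is applied after its residual connection in the forward pass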
self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): """ Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Input to the layer. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Attention mask. position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings, to be added to `hidden_states`. reference_points (`torch.FloatTensor`, *optional*): Reference points. spatial_shapes_list (`list` of `tuple`): Spatial shapes of the backbone feature maps as a list of tuples. level_start_index (`torch.LongTensor`, *optional*): Level start index. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) if self.training: if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights.transpose(1, 0),) return outputs # Modified from from transformers.models.detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->Mask2FormerPixelDecoderEncoderOnly class Mask2FormerPixelDecoderEncoderOnly(nn.Module): """ Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a [`Mask2FormerPixelDecoderEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers. Args: config: Mask2FormerConfig """ def __init__(self, config: Mask2FormerConfig): super().__init__() self.config = config self.dropout = config.dropout self.layers = nn.ModuleList( [Mask2FormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)] ) @staticmethod def get_reference_points(spatial_shapes_list, valid_ratios, device): """ Get reference points for each feature map. Used in decoder. Args: spatial_shapes_list (`list` of `tuple`): Spatial shapes of the backbone feature maps as a list of tuples. 
valid_ratios (`torch.FloatTensor`): Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`. device (`torch.device`): Device on which to create the tensors. Returns: `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` """ reference_points_list = [] for lvl, (height, width) in enumerate(spatial_shapes_list): ref_y, ref_x = torch.meshgrid( torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), indexing="ij", ) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def forward( self, inputs_embeds=None, attention_mask=None, position_embeddings=None, spatial_shapes_list=None, level_start_index=None, valid_ratios=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: - 1 for pixel features that are real (i.e. **not masked**), - 0 for pixel features that are padding (i.e. **masked**). [What are attention masks?](../glossary#attention-mask) position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Position embeddings that are added to the queries and keys in each self-attention layer. spatial_shapes_list (`list` of `tuple`): Spatial shapes of each feature map as a list of tuples. level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): Starting index of each feature map. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): Ratio of valid area in each feature level. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states = inputs_embeds reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device) all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, encoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states.transpose(1, 0),) layer_outputs = encoder_layer( hidden_states, attention_mask, position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states += (hidden_states.transpose(1, 0),) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions ) # Modified from from transformers.models.detr.modeling_deformable_detr.DeformableDetrModel with DeformableDetrModel->Mask2FormerPixelDecoder class Mask2FormerPixelDecoder(nn.Module): def __init__(self, config: Mask2FormerConfig, feature_channels): super().__init__() self.config = config feature_dim = config.feature_size mask_dim = config.mask_feature_size num_pos_features = feature_dim // 2 self.position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=num_pos_features, normalize=True) self.num_feature_levels = 3 transformer_in_channels = feature_channels[-self.num_feature_levels :] self.transformer_feature_strides = config.feature_strides[-self.num_feature_levels :] self.feature_channels = feature_channels self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, feature_dim)) # Create input projection layers if self.num_feature_levels > 1: input_projections_list = [] for in_channels in transformer_in_channels[::-1]: input_projections_list.append( nn.Sequential( nn.Conv2d(in_channels, feature_dim, kernel_size=1), nn.GroupNorm(32, feature_dim), ) ) self.input_projections = nn.ModuleList(input_projections_list) else: self.input_projections = nn.ModuleList( [ nn.Sequential( nn.Conv2d(transformer_in_channels[-1], feature_dim, kernel_size=1), nn.GroupNorm(32, feature_dim), ) ] ) self.encoder = Mask2FormerPixelDecoderEncoderOnly(config) self.mask_projection = nn.Conv2d(feature_dim, mask_dim, kernel_size=1, stride=1, padding=0) # Extra FPN levels stride = min(self.transformer_feature_strides) self.common_stride = config.common_stride self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) lateral_convs = [] output_convs = [] for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): lateral_conv = nn.Sequential( nn.Conv2d(in_channels, feature_dim, kernel_size=1, bias=False), nn.GroupNorm(32, feature_dim), ) output_conv = nn.Sequential( nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False), nn.GroupNorm(32, feature_dim), nn.ReLU(), ) self.add_module(f"adapter_{idx + 1}", lateral_conv) self.add_module(f"layer_{idx + 1}", output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Order convolutional layers from low to high resolution 
self.lateral_convolutions = lateral_convs[::-1] self.output_convolutions = output_convs[::-1] def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(~mask[:, :, 0], 1) valid_width = torch.sum(~mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio def forward( self, features, encoder_outputs=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # Apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) input_embeds = [] position_embeddings = [] for level, x in enumerate(features[::-1][: self.num_feature_levels]): input_embeds.append(self.input_projections[level](x)) position_embeddings.append(self.position_embedding(x.shape, x.device, x.dtype)) masks = [ torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds ] # Prepare encoder inputs (by flattening) spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds] input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1) spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device) masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1) position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings] level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)] level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1) # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder if encoder_outputs is None: encoder_outputs = self.encoder( inputs_embeds=input_embeds_flat, attention_mask=masks_flat, position_embeddings=level_pos_embed_flat, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs.last_hidden_state batch_size = last_hidden_state.shape[0] # We compute level_start_index_list separately from the tensor version level_start_index # to avoid iterating over a tensor which breaks torch.compile/export. 
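# e.g. for spatial_shapes_list = [(h0, w0), (h1, w1), (h2, w2)] this yields [0, h0*w0, h0*w0 + h1*w1]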
level_start_index_list = [0] for height, width in spatial_shapes_list[:-1]: level_start_index_list.append(level_start_index_list[-1] + height * width) split_sizes = [None] * self.num_feature_levels for i in range(self.num_feature_levels): if i < self.num_feature_levels - 1: split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i] else: split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i] encoder_output = torch.split(last_hidden_state, split_sizes, dim=1) # Compute final features outputs = [ x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1]) for i, x in enumerate(encoder_output) ] # Append extra FPN levels to outputs, ordered from low to high resolution for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]): lateral_conv = self.lateral_convolutions[idx] output_conv = self.output_convolutions[idx] current_fpn = lateral_conv(feature) # Following FPN implementation, we use nearest upsampling here out = current_fpn + nn.functional.interpolate( outputs[-1], size=current_fpn.shape[-2:], mode="bilinear", align_corners=False ) out = output_conv(out) outputs.append(out) num_cur_levels = 0 multi_scale_features = [] for out in outputs: if num_cur_levels < self.num_feature_levels: multi_scale_features.append(out) num_cur_levels += 1 return Mask2FormerPixelDecoderOutput( mask_features=self.mask_projection(outputs[-1]), multi_scale_features=tuple(multi_scale_features), attentions=encoder_outputs.attentions, ) class Mask2FormerPixelLevelModule(nn.Module): def __init__(self, config: Mask2FormerConfig): """ Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image Segmentation](https://huggingface.co/papers/2112.01527). It runs the input image through a backbone and a pixel decoder, generating multi-scale feature maps and pixel embeddings. Args: config ([`Mask2FormerConfig`]): The configuration used to instantiate this model. """ super().__init__() self.encoder = load_backbone(config) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: backbone_features = self.encoder(pixel_values).feature_maps decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states) return Mask2FormerPixelLevelModuleOutput( encoder_last_hidden_state=backbone_features[-1], encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None, decoder_last_hidden_state=decoder_output.mask_features, decoder_hidden_states=decoder_output.multi_scale_features, ) # Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->Mask2Former class Mask2FormerAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and keys (as explained in the DETR paper). """ def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads if self.head_dim * num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {num_heads})." 
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        key_value_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None
        position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None
        key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None
        key_value_position_embeddings = (
            key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None
        )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size, target_len, embed_dim = hidden_states.size()

        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # add key-value position embeddings to the key value states
        if key_value_position_embeddings is not None:
            key_value_states_original = key_value_states
            key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            attn_weights += attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output).permute(1, 0, 2)

        return attn_output, attn_weights_reshaped


class Mask2FormerMaskedAttentionDecoderLayer(GradientCheckpointingLayer):
    """
    The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention, as well as FFN
    blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked
    attention` block that restricts the attention to localized features centered around predicted segments, which
    leads to faster convergence and improved performance. The order of the self- and cross (i.e. masked)-attention
    blocks has also been swapped in Mask2FormerMaskedAttentionDecoder compared to a standard DetrDecoder, as an
    optimization.

    Args:
        config (`Mask2FormerConfig`):
            The configuration used to initialize the Mask2FormerMaskedAttentionDecoder.
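
    Example (rough usage sketch, not part of the public API; shapes assume the default `Mask2FormerConfig`, i.e.
    `hidden_dim=256`, and toy per-level feature lengths):

    ```python
    >>> import torch
    >>> from transformers import Mask2FormerConfig

    >>> config = Mask2FormerConfig()
    >>> layer = Mask2FormerMaskedAttentionDecoderLayer(config)
    >>> num_queries, batch_size, hidden_dim = 100, 1, config.hidden_dim
    >>> # queries and per-level pixel decoder features are sequence-first: (seq_len, batch_size, hidden_dim)
    >>> queries = torch.rand(num_queries, batch_size, hidden_dim)
    >>> query_pos = torch.rand(num_queries, batch_size, hidden_dim)
    >>> level_features = [torch.rand(length, batch_size, hidden_dim) for length in (1024, 256, 64)]
    >>> level_pos = [torch.rand(length, batch_size, hidden_dim) for length in (1024, 256, 64)]
    >>> outputs = layer(
    ...     queries,
    ...     level_index=0,
    ...     position_embeddings=level_pos,
    ...     query_position_embeddings=query_pos,
    ...     encoder_hidden_states=level_features,
    ...     encoder_attention_mask=None,
    ... )
    >>> outputs[0].shape
    torch.Size([100, 1, 256])
    ```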
""" def __init__(self, config: Mask2FormerConfig): super().__init__() self.config = config self.embed_dim = self.config.hidden_dim self.pre_norm = self.config.pre_norm self.self_attn = Mask2FormerAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.dropout, is_decoder=True, ) self.dropout = self.config.dropout self.activation_fn = ACT2FN[self.config.activation_function] self.activation_dropout = self.config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout) self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward) self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post( self, hidden_states: torch.Tensor, level_index: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[torch.Tensor] = None, query_position_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ): # Masked(Cross)-Attention Block cross_attn_weights = None self_attn_weights = None residual = hidden_states hidden_states, cross_attn_weights = self.cross_attn( query=self.with_pos_embed(hidden_states, query_position_embeddings), key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), value=encoder_hidden_states[level_index], attn_mask=encoder_attention_mask, key_padding_mask=None, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.cross_attn_layer_norm(hidden_states) # Self Attention Block residual = hidden_states hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=query_position_embeddings, attention_mask=None, output_attentions=True, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs def forward_pre( self, hidden_states: torch.Tensor, level_index: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[torch.Tensor] = None, query_position_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ): # Masked(Cross)-Attention Block cross_attn_weights = None self_attn_weights = None residual = hidden_states hidden_states = self.cross_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = 
self.cross_attn( query=self.with_pos_embed(hidden_states, query_position_embeddings), key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), value=encoder_hidden_states[level_index], attn_mask=encoder_attention_mask, key_padding_mask=None, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Self Attention Block residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=query_position_embeddings, attention_mask=None, output_attentions=True, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs def forward( self, hidden_states: torch.Tensor, level_index: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[torch.Tensor] = None, query_position_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ): """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(seq_len, batch, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(1, seq_len, tgt_len, src_len)`. position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings that are added to the keys in the masked-attention layer. query_position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings that are added to the queries and keys in the self-attention layer. encoder_hidden_states (`torch.FloatTensor`): Cross attention input to the layer of shape `(seq_len, batch, embed_dim)`. encoder_attention_mask (`torch.FloatTensor`): Encoder attention mask of size`(1, seq_len, tgt_len, src_len)`. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ if self.pre_norm: outputs = self.forward_pre( hidden_states=hidden_states, level_index=level_index, position_embeddings=position_embeddings, query_position_embeddings=query_position_embeddings, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) else: outputs = self.forward_post( hidden_states=hidden_states, level_index=level_index, position_embeddings=position_embeddings, query_position_embeddings=query_position_embeddings, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) return outputs class Mask2FormerMaskedAttentionDecoder(nn.Module): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Mask2FormerMaskedAttentionDecoderLayer`]. 
The decoder updates the query embeddings through multiple cross (masked) and self-attention layers. The decoder uses a new **masked attention** mechanism instead of the standard cross-attention, which extracts localized features by constraining cross-attention to within the foreground region of the predicted mask for each query, instead of attending to the full feature map. Args: config (`Mask2FormerConfig`): Configuration used to instantiate Mask2FormerMaskedAttentionDecoder. """ def __init__(self, config: Mask2FormerConfig): super().__init__() self.config = config self.mask_feature_size = config.mask_feature_size self.dropout = config.dropout self.layerdrop = config.dropout self.num_feature_levels = 3 # level embedding (3 scales) self.decoder_layers = config.decoder_layers - 1 self.layers = nn.ModuleList( [Mask2FormerMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)] ) self.layernorm = nn.LayerNorm(config.hidden_dim) self.mask_predictor = Mask2FormerMaskPredictor( hidden_size=config.hidden_dim, num_heads=config.num_attention_heads, mask_feature_size=self.mask_feature_size, ) self.gradient_checkpointing = False def forward( self, inputs_embeds: Optional[torch.Tensor] = None, multi_stage_positional_embeddings: Optional[torch.Tensor] = None, pixel_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, query_position_embeddings: Optional[torch.Tensor] = None, feature_size_list: Optional[list] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): The query embeddings that are passed into the decoder. multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): Position embeddings that are added to the keys in each cross(masked)-attention layer. pixel_embeddings (`torch.FloatTensor`): Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder. query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross(masked)-attention of the decoder. feature_size_list (`list[torch.Size]`): This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
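
        Example (rough shape sketch; assumes the private classes defined in this module are in scope and uses toy
        feature sizes; the multi-scale features, their positional embeddings and `feature_size_list` would normally
        come from the pixel decoder):

        ```python
        >>> import torch
        >>> from transformers import Mask2FormerConfig

        >>> config = Mask2FormerConfig()
        >>> decoder = Mask2FormerMaskedAttentionDecoder(config)
        >>> batch_size, num_queries, hidden_dim = 1, config.num_queries, config.hidden_dim
        >>> sizes = [(32, 32), (16, 16), (8, 8)]  # hypothetical 1/32, 1/16 and 1/8 scale resolutions
        >>> encoder_hidden_states = [torch.rand(h * w, batch_size, hidden_dim) for h, w in sizes]
        >>> positional_embeddings = [torch.rand(h * w, batch_size, hidden_dim) for h, w in sizes]
        >>> pixel_embeddings = torch.rand(batch_size, config.mask_feature_size, 64, 64)  # 1/4 scale mask features
        >>> outputs = decoder(
        ...     inputs_embeds=torch.rand(num_queries, batch_size, hidden_dim),
        ...     multi_stage_positional_embeddings=positional_embeddings,
        ...     pixel_embeddings=pixel_embeddings,
        ...     encoder_hidden_states=encoder_hidden_states,
        ...     query_position_embeddings=torch.rand(num_queries, batch_size, hidden_dim),
        ...     feature_size_list=sizes,
        ...     return_dict=True,
        ... )
        >>> outputs.last_hidden_state.shape
        torch.Size([1, 100, 256])
        >>> outputs.masks_queries_logits[-1].shape
        torch.Size([1, 100, 64, 64])
        ```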
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is not None: hidden_states = inputs_embeds # intermediate hidden states with layernorm applied - required for predicting class logits intermediate = () # decoder layers all_hidden_states = () if output_hidden_states else None attentions = () if output_attentions else None # intermediate mask predictions from transformer decoder layers intermediate_mask_predictions = () intermediate_hidden_states = self.layernorm(inputs_embeds) intermediate += (intermediate_hidden_states,) predicted_mask, attention_mask = self.mask_predictor( intermediate_hidden_states, pixel_embeddings, feature_size_list[0] ) intermediate_mask_predictions += (predicted_mask,) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = torch.rand([]) if self.training and (dropout_probability < self.layerdrop): continue level_index = idx % self.num_feature_levels where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) # Multiply the attention mask instead of indexing to avoid issue in torch.export. attention_mask = attention_mask * where.unsqueeze(-1) layer_outputs = decoder_layer( hidden_states, level_index, None, # attention_mask multi_stage_positional_embeddings, query_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=attention_mask, output_attentions=output_attentions, ) intermediate_hidden_states = self.layernorm(layer_outputs[0]) predicted_mask, attention_mask = self.mask_predictor( intermediate_hidden_states, pixel_embeddings, feature_size_list[(idx + 1) % self.num_feature_levels], ) intermediate_mask_predictions += (predicted_mask,) # add intermediate hidden states with layer norm applied which will be used for predicting class logits intermediate += (intermediate_hidden_states,) hidden_states = layer_outputs[0] if output_attentions: attentions += (layer_outputs[1],) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) hidden_states = hidden_states.transpose(1, 0) if not return_dict: outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] return tuple(v for v in outputs if v is not None) return Mask2FormerMaskedAttentionDecoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=attentions, intermediate_hidden_states=intermediate, masks_queries_logits=intermediate_mask_predictions, ) # Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock with MaskFormer->Mask2Former class Mask2FormerPredictionBlock(nn.Module): def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: super().__init__() self.layers = [nn.Linear(in_dim, out_dim), activation] # Maintain submodule indexing as if part of a Sequential block for i, layer in enumerate(self.layers): self.add_module(str(i), layer) def forward(self, input: Tensor) -> Tensor: hidden_state = input for layer in self.layers: hidden_state = layer(hidden_state) return hidden_state class Mask2FormerMLPPredictionHead(nn.Module): def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, 
num_layers: int = 3): """ A classic Multi Layer Perceptron (MLP). Args: input_dim (`int`): The input dimensions. hidden_dim (`int`): The hidden dimensions. output_dim (`int`): The output dimensions. num_layers (int, *optional*, defaults to 3): The number of layers. """ super().__init__() in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] self.layers = [] for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): activation = nn.ReLU() if i < num_layers - 1 else nn.Identity() layer = Mask2FormerPredictionBlock(in_dim, out_dim, activation=activation) self.layers.append(layer) # Provide backwards compatibility from when the class inherited from nn.Sequential # In nn.Sequential subclasses, the name given to the layer is its index in the sequence. # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g. # self.my_layer_name = Layer() # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register # explicitly self.add_module(str(i), layer) def forward(self, input: Tensor) -> Tensor: hidden_state = input for layer in self.layers: hidden_state = layer(hidden_state) return hidden_state class Mask2FormerMaskPredictor(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Tensor): """ This class is used to get the predicted mask for a given Mask2FormerMaskedAttentionDecoder layer. It also generates the binarized attention mask associated with the given predicted mask. The attention mask obtained using predicted mask of the (l-1)th decoder layer is fed to the cross(masked)-attention block of the next decoder layer as input. Args: hidden_size (`int`): The feature dimension of the Mask2FormerMaskedAttentionDecoder num_heads (`int`): The number of heads used in the Mask2FormerMaskedAttentionDecoder mask_feature_size (`torch.Tensor`): one of the output dimensions of the predicted masks for each query """ super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.mask_embedder = Mask2FormerMLPPredictionHead(self.hidden_size, self.hidden_size, mask_feature_size) def forward( self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: Optional[int] = None ): mask_embeddings = self.mask_embedder(outputs.transpose(0, 1)) # Sum up over the channels outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings) attention_mask = nn.functional.interpolate( outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False ) attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1) attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool() attention_mask = attention_mask.detach() return outputs_mask, attention_mask class Mask2FormerTransformerModule(nn.Module): """ The Mask2Former's transformer module. 
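
    It projects the multi-scale features coming out of the pixel decoder to `hidden_dim`, adds level embeddings and
    sine position embeddings, and runs the masked-attention transformer decoder on top of the learned queries.

    Example (rough shape sketch; assumes the private classes in this module are in scope and that
    `in_features == config.hidden_dim`, so the input projections are identity mappings):

    ```python
    >>> import torch
    >>> from transformers import Mask2FormerConfig

    >>> config = Mask2FormerConfig()
    >>> module = Mask2FormerTransformerModule(in_features=config.hidden_dim, config=config)
    >>> multi_scale_features = [torch.rand(1, config.hidden_dim, size, size) for size in (8, 16, 32)]
    >>> mask_features = torch.rand(1, config.hidden_dim, 64, 64)  # 1/4 scale pixel embeddings
    >>> outputs = module(multi_scale_features, mask_features)
    >>> outputs.last_hidden_state.shape
    torch.Size([1, 100, 256])
    >>> outputs.masks_queries_logits[-1].shape
    torch.Size([1, 100, 64, 64])
    ```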
""" def __init__(self, in_features: int, config: Mask2FormerConfig): super().__init__() hidden_dim = config.hidden_dim self.num_feature_levels = 3 self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) self.queries_features = nn.Embedding(config.num_queries, hidden_dim) self.input_projections = [] for _ in range(self.num_feature_levels): if in_features != hidden_dim or config.enforce_input_projection: self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) else: self.input_projections.append(nn.Sequential()) self.decoder = Mask2FormerMaskedAttentionDecoder(config=config) self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) def forward( self, multi_scale_features: list[Tensor], mask_features: Tensor, output_hidden_states: bool = False, output_attentions: bool = False, ) -> Mask2FormerMaskedAttentionDecoderOutput: multi_stage_features = [] multi_stage_positional_embeddings = [] size_list = [] for i in range(self.num_feature_levels): size_list.append(multi_scale_features[i].shape[-2:]) multi_stage_positional_embeddings.append( self.position_embedder( multi_scale_features[i].shape, multi_scale_features[i].device, multi_scale_features[i].dtype, None ).flatten(2) ) multi_stage_features.append( self.input_projections[i](multi_scale_features[i]).flatten(2) + self.level_embed.weight[i][None, :, None] ) # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels) multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) _, batch_size, _ = multi_stage_features[0].shape # [num_queries, batch_size, num_channels] query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) decoder_output = self.decoder( inputs_embeds=query_features, multi_stage_positional_embeddings=multi_stage_positional_embeddings, pixel_embeddings=mask_features, encoder_hidden_states=multi_stage_features, query_position_embeddings=query_embeddings, feature_size_list=size_list, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=True, ) return decoder_output @auto_docstring class Mask2FormerPreTrainedModel(PreTrainedModel): config: Mask2FormerConfig base_model_prefix = "model" main_input_name = "pixel_values" def _init_weights(self, module: nn.Module): xavier_std = self.config.init_xavier_std std = self.config.init_std if isinstance(module, Mask2FormerTransformerModule): if module.input_projections is not None: for input_projection in module.input_projections: if not isinstance(input_projection, nn.Sequential): nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) nn.init.constant_(input_projection.bias, 0) elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) .view(module.n_heads, 1, 1, 2) .repeat(1, module.n_levels, module.n_points, 1) ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = 
nn.Parameter(grid_init.view(-1)) nn.init.constant_(module.attention_weights.weight.data, 0.0) nn.init.constant_(module.attention_weights.bias.data, 0.0) nn.init.xavier_uniform_(module.value_proj.weight.data) nn.init.constant_(module.value_proj.bias.data, 0.0) nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer): for p in module.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p, gain=xavier_std) module.cross_attn.in_proj_bias.data.zero_() elif isinstance(module, Mask2FormerPixelDecoder): nn.init.normal_(module.level_embed, std=0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.weight.data.fill_(1.0) module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if hasattr(module, "reference_points"): nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) nn.init.constant_(module.reference_points.bias.data, 0.0) @auto_docstring class Mask2FormerModel(Mask2FormerPreTrainedModel): main_input_name = "pixel_values" def __init__(self, config: Mask2FormerConfig): super().__init__(config) self.pixel_level_module = Mask2FormerPixelLevelModule(config) self.transformer_module = Mask2FormerTransformerModule(in_features=config.feature_size, config=config) self.post_init() @auto_docstring def forward( self, pixel_values: Tensor, pixel_mask: Optional[Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Mask2FormerModelOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, _, height, width = pixel_values.shape if pixel_mask is None: pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) pixel_level_module_output = self.pixel_level_module( pixel_values=pixel_values, output_hidden_states=output_hidden_states ) transformer_module_output = self.transformer_module( multi_scale_features=pixel_level_module_output.decoder_hidden_states, mask_features=pixel_level_module_output.decoder_last_hidden_state, output_hidden_states=True, output_attentions=output_attentions, ) encoder_hidden_states = None pixel_decoder_hidden_states = None transformer_decoder_hidden_states = None transformer_decoder_intermediate_states = None if output_hidden_states: encoder_hidden_states = pixel_level_module_output.encoder_hidden_states pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states transformer_decoder_hidden_states = transformer_module_output.hidden_states transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states output = Mask2FormerModelOutput( encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state, pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, encoder_hidden_states=encoder_hidden_states, 
pixel_decoder_hidden_states=pixel_decoder_hidden_states, transformer_decoder_hidden_states=transformer_decoder_hidden_states, transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, attentions=transformer_module_output.attentions, masks_queries_logits=transformer_module_output.masks_queries_logits, ) if not return_dict: output = tuple(v for v in output.values() if v is not None) return output @auto_docstring( custom_intro=""" The Mask2Former Model with heads on top for instance/semantic/panoptic segmentation. """ ) class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel): main_input_name = "pixel_values" def __init__(self, config: Mask2FormerConfig): super().__init__(config) self.model = Mask2FormerModel(config) self.weight_dict: dict[str, float] = { "loss_cross_entropy": config.class_weight, "loss_mask": config.mask_weight, "loss_dice": config.dice_weight, } self.class_predictor = nn.Linear(config.hidden_dim, config.num_labels + 1) self.criterion = Mask2FormerLoss(config=config, weight_dict=self.weight_dict) self.post_init() def get_loss_dict( self, masks_queries_logits: Tensor, class_queries_logits: Tensor, mask_labels: Tensor, class_labels: Tensor, auxiliary_predictions: dict[str, Tensor], ) -> dict[str, Tensor]: loss_dict: dict[str, Tensor] = self.criterion( masks_queries_logits=masks_queries_logits, class_queries_logits=class_queries_logits, mask_labels=mask_labels, class_labels=class_labels, auxiliary_predictions=auxiliary_predictions, ) # weight each loss by `self.weight_dict[]` including auxiliary losses for key, weight in self.weight_dict.items(): for loss_key, loss in loss_dict.items(): if key in loss_key: loss *= weight return loss_dict def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor: return sum(loss_dict.values()) def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor): auxiliary_logits: list[dict[str, Tensor]] = [] for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]): auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}) return auxiliary_logits @auto_docstring def forward( self, pixel_values: Tensor, mask_labels: Optional[list[Tensor]] = None, class_labels: Optional[list[Tensor]] = None, pixel_mask: Optional[Tensor] = None, output_hidden_states: Optional[bool] = None, output_auxiliary_logits: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Mask2FormerForUniversalSegmentationOutput: r""" mask_labels (`list[torch.Tensor]`, *optional*): List of mask labels of shape `(num_labels, height, width)` to be fed to a model class_labels (`list[torch.LongTensor]`, *optional*): list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. output_auxiliary_logits (`bool`, *optional*): Whether or not to output auxiliary logits. Examples: Instance segmentation example: ```python >>> from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation >>> from PIL import Image >>> import requests >>> import torch >>> # Load Mask2Former trained on COCO instance segmentation dataset >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance") >>> model = Mask2FormerForUniversalSegmentation.from_pretrained( ... "facebook/mask2former-swin-small-coco-instance" ... 
) >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = image_processor(image, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> # Model predicts class_queries_logits of shape `(batch_size, num_queries)` >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` >>> class_queries_logits = outputs.class_queries_logits >>> masks_queries_logits = outputs.masks_queries_logits >>> # Perform post-processing to get instance segmentation map >>> pred_instance_map = image_processor.post_process_instance_segmentation( ... outputs, target_sizes=[(image.height, image.width)] ... )[0] >>> print(pred_instance_map.shape) torch.Size([480, 640]) ``` Semantic segmentation example: ```python >>> from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation >>> from PIL import Image >>> import requests >>> import torch >>> # Load Mask2Former trained on ADE20k semantic segmentation dataset >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-semantic") >>> model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-ade-semantic") >>> url = ( ... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" ... ) >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = image_processor(image, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> # Model predicts class_queries_logits of shape `(batch_size, num_queries)` >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` >>> class_queries_logits = outputs.class_queries_logits >>> masks_queries_logits = outputs.masks_queries_logits >>> # Perform post-processing to get semantic segmentation map >>> pred_semantic_map = image_processor.post_process_semantic_segmentation( ... outputs, target_sizes=[(image.height, image.width)] ... )[0] >>> print(pred_semantic_map.shape) torch.Size([512, 683]) ``` Panoptic segmentation example: ```python >>> from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation >>> from PIL import Image >>> import requests >>> import torch >>> # Load Mask2Former trained on CityScapes panoptic segmentation dataset >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-cityscapes-panoptic") >>> model = Mask2FormerForUniversalSegmentation.from_pretrained( ... "facebook/mask2former-swin-small-cityscapes-panoptic" ... ) >>> url = "https://cdn-media.huggingface.co/Inference-API/Sample-results-on-the-Cityscapes-dataset-The-above-images-show-how-our-method-can-handle.png" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = image_processor(image, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> # Model predicts class_queries_logits of shape `(batch_size, num_queries)` >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` >>> class_queries_logits = outputs.class_queries_logits >>> masks_queries_logits = outputs.masks_queries_logits >>> # Perform post-processing to get panoptic segmentation map >>> pred_panoptic_map = image_processor.post_process_panoptic_segmentation( ... outputs, target_sizes=[(image.height, image.width)] ... 
)[0]["segmentation"] >>> print(pred_panoptic_map.shape) torch.Size([338, 676]) ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( pixel_values=pixel_values, pixel_mask=pixel_mask, output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, output_attentions=output_attentions, return_dict=True, ) loss, loss_dict, auxiliary_logits = None, None, None class_queries_logits = () for decoder_output in outputs.transformer_decoder_intermediate_states: class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) class_queries_logits += (class_prediction,) masks_queries_logits = outputs.masks_queries_logits auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) if mask_labels is not None and class_labels is not None: loss_dict = self.get_loss_dict( masks_queries_logits=masks_queries_logits[-1], class_queries_logits=class_queries_logits[-1], mask_labels=mask_labels, class_labels=class_labels, auxiliary_predictions=auxiliary_logits, ) loss = self.get_loss(loss_dict) encoder_hidden_states = None pixel_decoder_hidden_states = None transformer_decoder_hidden_states = None if output_hidden_states: encoder_hidden_states = outputs.encoder_hidden_states pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states output_auxiliary_logits = ( self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits ) if not output_auxiliary_logits: auxiliary_logits = None output = Mask2FormerForUniversalSegmentationOutput( loss=loss, class_queries_logits=class_queries_logits[-1], masks_queries_logits=masks_queries_logits[-1], auxiliary_logits=auxiliary_logits, encoder_last_hidden_state=outputs.encoder_last_hidden_state, pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state, transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, encoder_hidden_states=encoder_hidden_states, pixel_decoder_hidden_states=pixel_decoder_hidden_states, transformer_decoder_hidden_states=transformer_decoder_hidden_states, attentions=outputs.attentions, ) if not return_dict: output = tuple(v for v in output.values() if v is not None) if loss is not None: output = (loss) + output return output __all__ = ["Mask2FormerForUniversalSegmentation", "Mask2FormerModel", "Mask2FormerPreTrainedModel"]