# coding=utf-8 # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch SAM 2 model.""" from dataclasses import dataclass from typing import Callable, Optional, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN from ...image_processing_utils import BatchFeature, get_size_dict from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs from ...image_utils import ( IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, ChannelDimension, ImageInput, PILImageResampling, SizeDict, pil_torch_interpolation_mapping, ) from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( ModelOutput, TensorType, auto_docstring, logging, ) from ...utils.generic import TransformersKwargs, check_model_inputs from ..auto import AutoModel from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding from ..sam.image_processing_sam_fast import SamImageProcessorFast from ..sam.modeling_sam import ( SamLayerNorm, SamMaskDecoder, SamMaskEmbedding, SamModel, SamPromptEncoder, SamTwoWayAttentionBlock, SamTwoWayTransformer, eager_attention_forward, ) from ..vitdet.modeling_vitdet import window_partition, window_unpartition from .configuration_sam2 import ( Sam2Config, Sam2HieraDetConfig, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig, ) logger = logging.get_logger(__name__) class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): r""" mask_size (`dict[str, int]`, *optional*): The size `{"height": int, "width": int}` to resize the segmentation maps to. """ mask_size: Optional[dict[str, int]] @auto_docstring class Sam2ImageProcessorFast(SamImageProcessorFast): resample = PILImageResampling.BILINEAR image_mean = IMAGENET_DEFAULT_MEAN image_std = IMAGENET_DEFAULT_STD size = {"height": 1024, "width": 1024} mask_size = {"height": 256, "width": 256} do_resize = True do_rescale = True do_normalize = True do_convert_rgb = True valid_kwargs = Sam2FastImageProcessorKwargs # modular artefacts do_pad = None pad_size = None mask_pad_size = None def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]): BaseImageProcessorFast.__init__(self, **kwargs) def pad_image(self): raise NotImplementedError("No pad_image for SAM 2.") def _get_preprocess_shape(self): raise NotImplementedError("No _get_preprocess_shape for SAM 2.") def resize(self): raise NotImplementedError("No need to override resize for SAM 2.") def _preprocess( self, images: list["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]], **kwargs, ) -> "torch.Tensor": return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values def _preprocess_image_like_inputs( self, images: ImageInput, segmentation_maps: Optional[ImageInput], do_convert_rgb: bool, input_data_format: ChannelDimension, device: Optional[Union[str, "torch.device"]] = None, **kwargs: Unpack[Sam2FastImageProcessorKwargs], ) -> BatchFeature: """ Preprocess image-like inputs. """ images = self._prepare_image_like_inputs( images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device ) original_sizes = [image.shape[-2:] for image in images] images_kwargs = kwargs.copy() pixel_values = self._preprocess(images, **images_kwargs) reshaped_input_sizes = [image.shape[-2:] for image in images] data = { "pixel_values": pixel_values, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes, } if segmentation_maps is not None: processed_segmentation_maps = self._prepare_image_like_inputs( images=segmentation_maps, expected_ndims=2, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST, ) segmentation_maps_kwargs = kwargs.copy() segmentation_maps_kwargs.update( { "do_normalize": False, "do_rescale": False, "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], "size": segmentation_maps_kwargs.pop("mask_size"), } ) processed_segmentation_maps = self._preprocess( images=processed_segmentation_maps, **segmentation_maps_kwargs ) data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) def _further_process_kwargs( self, size: Optional[SizeDict] = None, mask_size: Optional[SizeDict] = None, default_to_square: Optional[bool] = None, image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, data_format: Optional[ChannelDimension] = None, **kwargs, ) -> dict: """ Update kwargs that need further processing before being validated Can be overridden by subclasses to customize the processing of kwargs. """ if kwargs is None: kwargs = {} if size is not None: size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if mask_size is not None: mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) if isinstance(image_mean, list): image_mean = tuple(image_mean) if isinstance(image_std, list): image_std = tuple(image_std) if data_format is None: data_format = ChannelDimension.FIRST kwargs["size"] = size kwargs["mask_size"] = mask_size kwargs["image_mean"] = image_mean kwargs["image_std"] = image_std kwargs["data_format"] = data_format # torch resize uses interpolation instead of resample # Check if resample is an int before checking if it's an instance of PILImageResampling # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. resample = kwargs.pop("resample") kwargs["interpolation"] = ( pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample ) return kwargs def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor: """ Apply non-overlapping constraints to the object scores in pred_masks. Here we keep only the highest scoring object at each spatial location in pred_masks. """ batch_size = pred_masks.size(0) if batch_size == 1: return pred_masks device = pred_masks.device # "max_obj_inds": object index of the object with the highest score at each location max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] keep = max_obj_inds == batch_obj_inds # suppress overlapping regions' scores below -10.0 so that the foreground regions # don't overlap (here sigmoid(-10.0)=4.5398e-05) pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) return pred_masks def post_process_masks( self, masks, original_sizes, mask_threshold=0.0, binarize=True, max_hole_area=0.0, max_sprinkle_area=0.0, apply_non_overlapping_constraints=False, **kwargs, ): """ Remove padding and upscale masks to the original image size. Args: masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`): Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. mask_threshold (`float`, *optional*, defaults to 0.0): Threshold for binarization and post-processing operations. binarize (`bool`, *optional*, defaults to `True`): Whether to binarize the masks. max_hole_area (`float`, *optional*, defaults to 0.0): The maximum area of a hole to fill. max_sprinkle_area (`float`, *optional*, defaults to 0.0): The maximum area of a sprinkle to fill. apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): Whether to apply non-overlapping constraints to the masks. Returns: (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. """ if isinstance(original_sizes, (torch.Tensor, np.ndarray)): original_sizes = original_sizes.tolist() # TODO: add connected components kernel for postprocessing output_masks = [] for i, original_size in enumerate(original_sizes): if isinstance(masks[i], np.ndarray): masks[i] = torch.from_numpy(masks[i]) elif not isinstance(masks[i], torch.Tensor): raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False) if apply_non_overlapping_constraints: interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask) if binarize: interpolated_mask = interpolated_mask > mask_threshold output_masks.append(interpolated_mask) return output_masks @dataclass @auto_docstring(custom_intro="Base class for the vision encoder's outputs.") class Sam2VisionEncoderOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. fpn_hidden_states (`tuple(torch.FloatTensor)`): Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. fpn_position_encoding (`tuple(torch.FloatTensor)`): Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each stage. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: Optional[torch.FloatTensor] = None fpn_hidden_states: Optional[torch.FloatTensor] = None fpn_position_encoding: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None @dataclass @auto_docstring(custom_intro="Base class for the Sam2 model's output.") class Sam2ImageSegmentationOutput(ModelOutput): r""" iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): The Intersection over Union (IoU) scores of the predicted masks. pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed by the processor to be brought to the original image size. object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): Logits for the object score, indicating if an object is present. image_embeddings (`tuple(torch.FloatTensor)`): The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each tensor has shape `(batch_size, channels, height, width)`. vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the vision model at the output of each stage. vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the vision model. mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the mask decoder. """ iou_scores: Optional[torch.FloatTensor] = None pred_masks: Optional[torch.FloatTensor] = None object_score_logits: Optional[torch.FloatTensor] = None image_embeddings: tuple[torch.FloatTensor, ...] = None vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None class Sam2PatchEmbeddings(nn.Module): r""" Turns pixel values into patch embeddings for transformer consumption. Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details. Returns: embeddings (`torch.FloatTensor`): Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding """ def __init__(self, config: Sam2HieraDetConfig): super().__init__() num_channels = config.num_channels hidden_size = config.hidden_size self.projection = nn.Conv2d( num_channels, hidden_size, kernel_size=config.patch_kernel_size, stride=config.patch_stride, padding=config.patch_padding, ) def forward(self, pixel_values): _, num_channels, height, width = pixel_values.shape embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) return embeddings class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding): pass class Sam2VisionNeck(nn.Module): def __init__(self, config: Sam2VisionConfig): super().__init__() self.config = config self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True) self.convs = nn.ModuleList() for in_channels in config.backbone_channel_list: self.convs.append( nn.Conv2d( in_channels=in_channels, out_channels=config.fpn_hidden_size, kernel_size=config.fpn_kernel_size, stride=config.fpn_stride, padding=config.fpn_padding, ), ) self.fpn_top_down_levels = config.fpn_top_down_levels def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: fpn_hidden_states = () fpn_position_encoding = () # forward in top-down order (from low to high resolution) n = len(self.convs) - 1 for i in range(n, -1, -1): lateral_features = hidden_states[i].permute(0, 3, 1, 2) lateral_features = self.convs[n - i](lateral_features) if i not in self.fpn_top_down_levels or i == n: prev_features = lateral_features else: top_down_features = F.interpolate( prev_features.to(dtype=torch.float32), scale_factor=2.0, mode="nearest", align_corners=None, antialias=False, ).to(lateral_features.dtype) prev_features = lateral_features + top_down_features prev_position_encoding = self.position_encoding( prev_features.shape, prev_features.device, prev_features.dtype ).to(prev_features.dtype) fpn_hidden_states += (prev_features,) fpn_position_encoding += (prev_position_encoding,) return fpn_hidden_states, fpn_position_encoding def do_pool(x: torch.Tensor, query_stride: Optional[int] = None) -> torch.Tensor: if query_stride is None: return x # (B, H, W, C) -> (B, C, H, W) x = x.permute(0, 3, 1, 2) x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False) # (B, C, H', W') -> (B, H', W', C) x = x.permute(0, 2, 3, 1) return x class Sam2MultiScaleAttention(nn.Module): def __init__( self, config: Sam2HieraDetConfig, dim: int, dim_out: int, num_attention_heads: int, query_stride: Optional[tuple[int, int]] = None, ): super().__init__() self.config = config self.dim = dim self.dim_out = dim_out self.query_stride = query_stride self.num_attention_heads = num_attention_heads head_dim = dim_out // num_attention_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim_out * 3) self.proj = nn.Linear(dim_out, dim_out) self.is_causal = False def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (B, H * W, 3, nHead, C) qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1) # q, k, v with shape (B, H * W, nheads, C) query, key, value = torch.unbind(qkv, 2) attn_weights = (query * self.scale) @ key.transpose(-2, -1) attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) # Q pooling (for downsample at stage changes) if self.query_stride: query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride) height, width = query.shape[1:3] # downsampled shape query = query.reshape(batch_size, height * width, self.num_attention_heads, -1) # transpose query, key, value to (B, nHead, H * W, C) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, query, key, value, attention_mask=None, is_causal=self.is_causal, scaling=self.scale, **kwargs, ) attn_output = attn_output.reshape(batch_size, height, width, -1) attn_output = self.proj(attn_output) return attn_output class Sam2FeedForward(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, activation: str = "relu", sigmoid_output: bool = False, ): super().__init__() self.num_layers = num_layers self.activation = ACT2FN[activation] self.proj_in = nn.Linear(input_dim, hidden_dim) self.proj_out = nn.Linear(hidden_dim, output_dim) self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) self.sigmoid_output = sigmoid_output def forward(self, hidden_states): hidden_states = self.proj_in(hidden_states) hidden_states = self.activation(hidden_states) for layer in self.layers: hidden_states = self.activation(layer(hidden_states)) hidden_states = self.proj_out(hidden_states) if self.sigmoid_output: hidden_states = F.sigmoid(hidden_states) return hidden_states class Sam2MultiScaleBlock(GradientCheckpointingLayer): def __init__( self, config: Sam2HieraDetConfig, stage_idx: int, block_idx: int, total_block_idx: int, ): super().__init__() # take embed dim from previous stage if first block of stage self.dim = ( config.embed_dim_per_stage[stage_idx - 1] if stage_idx > 0 and block_idx == 0 else config.embed_dim_per_stage[stage_idx] ) self.dim_out = config.embed_dim_per_stage[stage_idx] self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) # take window size from previous stage if first block of stage self.window_size = ( config.window_size_per_stage[stage_idx - 1] if stage_idx > 0 and block_idx == 0 else config.window_size_per_stage[stage_idx] ) self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size # use query stride for first block of stage if stage is a query pool stage self.query_stride = ( config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None ) self.attn = Sam2MultiScaleAttention( config, self.dim, self.dim_out, num_attention_heads=config.num_attention_heads_per_stage[stage_idx], query_stride=self.query_stride, ) self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps) self.mlp = Sam2FeedForward( self.dim_out, int(self.dim_out * config.mlp_ratio), self.dim_out, num_layers=2, activation=config.hidden_act, ) if self.dim != self.dim_out: self.proj = nn.Linear(self.dim, self.dim_out) def forward( self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states # batch_size, height, width, channel hidden_states = self.layer_norm1(hidden_states) # Skip connection if self.dim != self.dim_out: residual = do_pool(self.proj(hidden_states), self.query_stride) # Window partition window_size = self.window_size if self.window_size > 0: H, W = hidden_states.shape[1], hidden_states.shape[2] hidden_states, pad_hw = window_partition(hidden_states, window_size) # Window Attention + Q Pooling (if stage change) attn_output = self.attn( hidden_states=hidden_states, **kwargs, ) hidden_states = attn_output if self.query_stride: # Shapes have changed due to Q pooling window_size = self.window_size // self.query_stride[0] H, W = residual.shape[1:3] pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size pad_hw = (H + pad_h, W + pad_w) # Reverse window partition if self.window_size > 0: hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W)) hidden_states = residual + hidden_states layernorm_output = self.layer_norm2(hidden_states) hidden_states = hidden_states + self.mlp(layernorm_output) return hidden_states @dataclass @auto_docstring( custom_intro=""" Hiera model's outputs that also contains a pooling of the last hidden states. """ ) class Sam2HieraDetModelOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): hidden-states at the output of the last layer of the model. intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`): Sequence of hidden-states at the output of the intermediate layers of the model. """ last_hidden_state: Optional[torch.FloatTensor] = None intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None @auto_docstring class Sam2PreTrainedModel(PreTrainedModel): config_class = Sam2Config base_model_prefix = "sam2" main_input_name = "pixel_values" _supports_sdpa = True _supports_flash_attn_2 = True _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): module.weight.data.fill_(1.0) module.bias.data.zero_() if isinstance(module, Sam2HieraDetModel): if module.pos_embed is not None: module.pos_embed.data.zero_() if module.pos_embed_window is not None: module.pos_embed_window.data.zero_() if isinstance(module, Sam2Model): if module.no_memory_embedding is not None: module.no_memory_embedding.data.zero_() class Sam2HieraDetModel(Sam2PreTrainedModel): config_class = Sam2HieraDetConfig main_input_name = "pixel_values" _can_record_outputs = { "hidden_states": Sam2MultiScaleBlock, "attentions": Sam2MultiScaleAttention, } def __init__(self, config: Sam2HieraDetConfig): super().__init__(config) self.patch_embed = Sam2PatchEmbeddings(config) # Windowed positional embedding (https://huggingface.co/papers/2311.05613) self.pos_embed = nn.Parameter( torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size) ) self.pos_embed_window = nn.Parameter( torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0]) ) self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist() self.blocks = nn.ModuleList() total_block_idx = 0 for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage): for block_idx in range(blocks_per_stage): block = Sam2MultiScaleBlock( config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx ) self.blocks.append(block) total_block_idx += 1 def get_input_embeddings(self): return self.patch_embed def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor: h, w = hw window_embed = self.pos_embed_window pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) pos_embed = pos_embed.permute(0, 2, 3, 1) return pos_embed @check_model_inputs def forward( self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Sam2HieraDetModelOutput]: if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.patch_embed(pixel_values) hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3]) intermediate_hidden_states = () for i, block_module in enumerate(self.blocks): hidden_states = block_module(hidden_states, **kwargs) if i in self.stage_ends: intermediate_hidden_states = intermediate_hidden_states + (hidden_states,) return Sam2HieraDetModelOutput( last_hidden_state=hidden_states, intermediate_hidden_states=intermediate_hidden_states, ) @auto_docstring( custom_intro=""" The vision model from Sam without any head or projection on top. """ ) class Sam2VisionModel(Sam2PreTrainedModel): config_class = Sam2VisionConfig main_input_name = "pixel_values" _can_record_outputs = { "hidden_states": Sam2MultiScaleBlock, "attentions": Sam2MultiScaleAttention, } def __init__(self, config: Sam2VisionConfig): super().__init__(config) self.config = config self.backbone = AutoModel.from_config(config.backbone_config) self.neck = Sam2VisionNeck(config) self.num_feature_levels = config.num_feature_levels self.post_init() def get_input_embeddings(self): return self.backbone.get_input_embeddings() @check_model_inputs def forward( self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Sam2VisionEncoderOutput]: if pixel_values is None: raise ValueError("You have to specify pixel_values") # Forward through backbone backbone_output = self.backbone(pixel_values, **kwargs) hidden_states = backbone_output.last_hidden_state intermediate_hidden_states = backbone_output.intermediate_hidden_states fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] return Sam2VisionEncoderOutput( last_hidden_state=hidden_states, fpn_hidden_states=fpn_hidden_states, fpn_position_encoding=fpn_position_encoding, ) class Sam2PositionalEmbedding(nn.Module): def __init__(self, config: Sam2PromptEncoderConfig): super().__init__() self.scale = config.scale positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2)) self.register_buffer("positional_embedding", positional_embedding) def forward(self, input_coords, input_shape=None): """Positionally encode points that are normalized to [0,1].""" coordinates = input_coords.clone() if input_shape is not None: coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] coordinates.to(torch.float32) # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coordinates = 2 * coordinates - 1 coordinates = coordinates.to(self.positional_embedding.dtype) coordinates = coordinates @ self.positional_embedding coordinates = 2 * np.pi * coordinates # outputs d_1 x ... x d_n x channel shape return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) class Sam2MaskEmbedding(SamMaskEmbedding): pass class Sam2PromptEncoder(SamPromptEncoder): def __init__(self, config: Sam2PromptEncoderConfig): nn.Module.__init__(self) self.shared_embedding = Sam2PositionalEmbedding(config) self.mask_embed = Sam2MaskEmbedding(config) self.no_mask_embed = nn.Embedding(1, config.hidden_size) self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) self.input_image_size = config.image_size self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size) self.hidden_size = config.hidden_size self.not_a_point_embed = nn.Embedding(1, config.hidden_size) def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) input_shape = (self.input_image_size, self.input_image_size) point_embedding = self.shared_embedding(points, input_shape) # torch.where and expanding the labels tensor is required by the ONNX export point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) # This is required for the ONNX export. The dtype, device need to be explicitly # specified as otherwise torch.onnx.export interprets as double point_embedding = torch.where( labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding), ) # Add point embeddings for labels >= 0 point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """Embeds box prompts.""" boxes += 0.5 # Shift to center of pixel coords = boxes.view(*boxes.shape[:2], 2, 2) # add padding point for consistency with the original implementation coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) corner_embedding[:, :, 0, :] += self.point_embed.weight[2] corner_embedding[:, :, 1, :] += self.point_embed.weight[3] corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) return corner_embedding class Sam2Attention(nn.Module): """ SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__(self, config, downsample_rate=None): super().__init__() downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate self.config = config self.hidden_size = config.hidden_size self.internal_dim = config.hidden_size // downsample_rate self.num_attention_heads = config.num_attention_heads self.head_dim = self.internal_dim // config.num_attention_heads self.scaling = self.head_dim**-0.5 self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_similarity: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: # Input projections batch_size, point_batch_size = query.shape[:2] new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) query = self.q_proj(query).view(*new_shape).transpose(1, 2) key = self.k_proj(key).view(*new_shape).transpose(1, 2) value = self.v_proj(value).view(*new_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query, key, value, attention_mask=attention_similarity, dropout=0.0, scaling=self.scaling, is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape( batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer): def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False): nn.Module.__init__(self) self.self_attn = Sam2Attention(config, downsample_rate=1) self.layer_norm1 = nn.LayerNorm(config.hidden_size) self.cross_attn_token_to_image = Sam2Attention(config) self.layer_norm2 = nn.LayerNorm(config.hidden_size) self.mlp = Sam2FeedForward( config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers ) self.layer_norm3 = nn.LayerNorm(config.hidden_size) self.layer_norm4 = nn.LayerNorm(config.hidden_size) self.cross_attn_image_to_token = Sam2Attention(config) self.skip_first_layer_pe = skip_first_layer_pe class Sam2TwoWayTransformer(SamTwoWayTransformer): pass class Sam2LayerNorm(SamLayerNorm): pass class Sam2MaskDecoder(SamMaskDecoder): def __init__(self, config: Sam2MaskDecoderConfig): super().__init__(config) del self.iou_prediction_head self.iou_prediction_head = Sam2FeedForward( self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth, sigmoid_output=True, ) self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) self.obj_score_token = nn.Embedding(1, self.hidden_size) self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3) self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh def _get_stability_scores(self, mask_logits): """ Compute stability scores of the mask logits based on the IoU between upper and lower thresholds. """ mask_logits = mask_logits.flatten(-2) stability_delta = self.dynamic_multimask_stability_delta area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) return stability_scores def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): """ When outputting a single mask, if the stability score from the current single-mask output (based on output token 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score. This is intended to ensure a valid mask for both clicking and tracking. """ # The best mask from multimask output tokens (1~3) multimask_logits = all_mask_logits[:, :, 1:, :, :] multimask_iou_scores = all_iou_scores[:, :, 1:] best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) best_scores_inds_expanded = best_scores_inds_expanded.expand( -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) ) best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] # The mask from singlemask output token 0 and its stability score singlemask_logits = all_mask_logits[:, :, 0:1, :, :] singlemask_iou_scores = all_iou_scores[:, :, 0:1] stability_scores = self._get_stability_scores(singlemask_logits) is_stable = stability_scores >= self.dynamic_multimask_stability_thresh # Dynamically fall back to best multimask output upon low stability scores. mask_logits_out = torch.where( is_stable[..., None, None].expand_as(singlemask_logits), singlemask_logits, best_multimask_logits, ) iou_scores_out = torch.where( is_stable.expand_as(singlemask_iou_scores), singlemask_iou_scores, best_multimask_iou_scores, ) return mask_logits_out, iou_scores_out def forward( self, image_embeddings: torch.Tensor, image_positional_embeddings: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool, high_resolution_features: list[torch.Tensor], attention_similarity: Optional[torch.Tensor] = None, target_embedding: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. Args: image_embeddings (`torch.Tensor`): The embeddings from the image encoder. image_positional_embeddings (`torch.Tensor`): Positional encoding with the shape of image_embeddings. sparse_prompt_embeddings (`torch.Tensor`): The embeddings of the points and boxes. dense_prompt_embeddings (`torch.Tensor`): The embeddings of the mask inputs. multimask_output (`bool`): Whether to return multiple masks or a single mask. high_resolution_features (`list[torch.Tensor]`, *optional*): The high-resolution features from the vision encoder. attention_similarity (`torch.Tensor`, *optional*): The attention similarity tensor. target_embedding (`torch.Tensor`, *optional*): The target embedding. """ batch_size, num_channels, height, width = image_embeddings.shape point_batch_size = sparse_prompt_embeddings.shape[1] # Concatenate output tokens output_tokens = torch.cat( [ self.obj_score_token.weight, self.iou_token.weight, self.mask_tokens.weight, ], dim=0, ) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) if sparse_prompt_embeddings.shape[0] != 0: tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) else: tokens = output_tokens point_embeddings = tokens.to(self.iou_token.weight.dtype) # Expand per-image data in batch direction to be per-mask image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) # Run the transformer point_embeddings, image_embeddings = self.transformer( point_embeddings=point_embeddings, image_embeddings=image_embeddings, image_positional_embeddings=image_positional_embeddings, attention_similarity=attention_similarity, target_embedding=target_embedding, **kwargs, ) iou_token_out = point_embeddings[:, :, 1, :] mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens image_embeddings = image_embeddings.transpose(2, 3).view( batch_size * point_batch_size, num_channels, height, width ) feat_s0, feat_s1 = high_resolution_features feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) hyper_in_list: list[torch.Tensor] = [] for i in range(self.num_mask_tokens): current_mlp = self.output_hypernetworks_mlps[i] hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] hyper_in = torch.stack(hyper_in_list, dim=2) _, num_channels, height, width = upscaled_embedding.shape upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) # Select the correct mask or masks for output if multimask_output: mask_slice = slice(1, None) masks = masks[:, :, mask_slice, :, :] iou_pred = iou_pred[:, :, mask_slice] elif self.dynamic_multimask_via_stability and not self.training: mask_slice = slice(0, 1) masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) else: mask_slice = slice(0, 1) masks = masks[:, :, mask_slice, :, :] iou_pred = iou_pred[:, :, mask_slice] sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape return masks, iou_pred, sam_tokens_out, object_score_logits @auto_docstring( custom_intro=""" Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and input points and labels, boxes, or masks. """ ) class Sam2Model(SamModel): _keys_to_ignore_on_load_unexpected = [ r"^memory_.*", r"^mask_downsample.*", r"^object_pointer_proj.*", r"^temporal_positional_encoding_projection_layer.*", "no_memory_positional_encoding", "no_object_pointer", "occlusion_spatial_embedding_parameter", ] def __init__(self, config: Sam2Config): PreTrainedModel.__init__(self, config) self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config) self.vision_encoder = AutoModel.from_config(config.vision_config) self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config) # The module using it is not a PreTrainedModel subclass so we need this config.mask_decoder_config._attn_implementation = config._attn_implementation self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config) self.num_feature_levels = config.vision_config.num_feature_levels self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes # a single token to indicate no memory embedding from previous frames self.hidden_dim = config.vision_config.fpn_hidden_size self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) self.post_init() def get_image_wide_positional_embeddings(self) -> torch.Tensor: size = self.prompt_encoder.image_embedding_size target_device = self.shared_image_embedding.positional_embedding.device target_dtype = self.shared_image_embedding.positional_embedding.dtype grid = torch.ones(size, device=target_device, dtype=target_dtype) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 y_embed = y_embed / size[0] x_embed = x_embed / size[1] positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width @torch.no_grad() def get_image_embeddings( self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs], ) -> list[torch.Tensor]: r""" Returns the image embeddings by passing the pixel values through the vision encoder. Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Input pixel values """ batch_size = pixel_values.shape[0] feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) # add no memory embedding to the last feature map feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding # reshape feature maps to the same shape as the backbone feature sizes image_embeddings = [ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) ] return image_embeddings def get_image_features( self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs], ) -> tuple[ list[torch.Tensor], list[torch.Tensor], Optional[tuple[torch.FloatTensor, ...]], Optional[tuple[torch.FloatTensor, ...]], ]: r""" Extract and preprocess image features using the vision encoder. Args: pixel_values (`torch.FloatTensor`): Input pixel values of shape `(batch_size, num_channels, height, width)`. Returns: `tuple`: A tuple containing: - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level. - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder. - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder. """ vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder( pixel_values, **kwargs, ) feature_maps = vision_outputs.fpn_hidden_states feature_maps_position_embeddings = vision_outputs.fpn_position_encoding # precompute projected level 0 and level 1 features in SAM decoder # to avoid running it again on every SAM click feature_maps = list(feature_maps) feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) # flatten NxCxHxW to HWxNxC feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] feature_maps_position_embeddings = [ feature_map_position_embedding.flatten(2).permute(2, 0, 1) for feature_map_position_embedding in feature_maps_position_embeddings ] return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions @check_model_inputs @auto_docstring def forward( self, pixel_values: Optional[torch.FloatTensor] = None, input_points: Optional[torch.FloatTensor] = None, input_labels: Optional[torch.LongTensor] = None, input_boxes: Optional[torch.FloatTensor] = None, input_masks: Optional[torch.LongTensor] = None, image_embeddings: Optional[torch.FloatTensor] = None, multimask_output: bool = True, attention_similarity: Optional[torch.FloatTensor] = None, target_embedding: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Sam2ImageSegmentationOutput: r""" input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much better results. The points can be obtained by passing a list of list of list to the processor that will create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per input point), the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) coordinates of the point. If a different number of points is passed either for each image, or for each mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the computation of the embedding will be skipped for these points using the labels. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the official implementation, there are 3 types of labels - `1`: the point is a point that contains the object of interest - `0`: the point is a point that does not contain the object of interest - `-1`: the point corresponds to the background We added the label: - `-10`: the point is a padding point, thus should be ignored by the prompt encoder The padding labels should be automatically done by the processor. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the order (`x1`, `y1`, `x2`, `y2`): - `x1`: the x coordinate of the top left point of the input box - `y1`: the y coordinate of the top left point of the input box - `x2`: the x coordinate of the bottom right point of the input box - `y2`: the y coordinate of the bottom right point of the input box input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the `pixel_values`. multimask_output (`bool`, *optional*): In the original implementation and paper, the model always outputs 3 masks per image (or per point / per bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the "best" mask, by specifying `multimask_output=False`. attention_similarity (`torch.FloatTensor`, *optional*): Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). target_embedding (`torch.FloatTensor`, *optional*): Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoModel, AutoProcessor >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny") >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny") >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") >>> input_points = [[[400, 650]]] # 2D location of a window on the car >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") >>> # Get segmentation mask >>> outputs = model(**inputs) >>> # Postprocess masks >>> masks = processor.post_process_masks( ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] ... ) ``` """ if not ((pixel_values is None) ^ (image_embeddings is None)): raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") if input_points is not None and input_boxes is not None: if input_points.shape[1] != input_boxes.shape[1]: raise ValueError( f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." ) image_positional_embeddings = self.get_image_wide_positional_embeddings() # repeat with batch size batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) vision_attentions = None vision_hidden_states = None if pixel_values is not None: feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( pixel_values, **kwargs, ) # add no memory embedding to the last feature map feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding # reshape feature maps to the same shape as the backbone feature sizes image_embeddings = [ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) ] if input_points is not None and input_labels is None: input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) if input_points is None and input_boxes is None: # If no points are provide, pad with an empty point (with label -1) input_points = torch.zeros( batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device ) input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device) if input_masks is not None: # If mask_inputs is provided, downsize it into low-res mask input if needed # and feed it as a dense mask prompt into the SAM mask encoder if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: input_masks = F.interpolate( input_masks.float(), size=self.prompt_encoder.mask_input_size, align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ).to(input_masks.dtype) sparse_embeddings, dense_embeddings = self.prompt_encoder( input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, input_masks=input_masks, ) low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder( image_embeddings=image_embeddings[-1], image_positional_embeddings=image_positional_embeddings, sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, high_resolution_features=image_embeddings[:-1], attention_similarity=attention_similarity, target_embedding=target_embedding, **kwargs, ) return Sam2ImageSegmentationOutput( iou_scores=iou_scores, pred_masks=low_res_multimasks, object_score_logits=object_score_logits, image_embeddings=image_embeddings, vision_hidden_states=vision_hidden_states, vision_attentions=vision_attentions, ) __all__ = [ "Sam2Model", "Sam2VisionModel", "Sam2PreTrainedModel", "Sam2ImageProcessorFast", "Sam2HieraDetModel", ]