# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/sam2_video/modular_sam2_video.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_sam2_video.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional, Union

import numpy as np
import torch

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType
from ...utils.import_utils import requires
from ...video_utils import VideoInput
from .modeling_sam2_video import Sam2VideoInferenceSession


@requires(backends=("torch",))
class Sam2VideoProcessor(ProcessorMixin):
    r"""
    Constructs a SAM2 video processor which wraps a SAM2 image processor, a SAM2 video processor and a 2D points &
    bounding boxes processor into a single processor.

    [`Sam2VideoProcessor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and
    [`Sam2VideoVideoProcessor`]. See the docstring of [`~Sam2ImageProcessorFast.__call__`] and
    [`~Sam2VideoVideoProcessor.__call__`] for more information.

    Args:
        image_processor (`Sam2ImageProcessorFast`):
            An instance of [`Sam2ImageProcessorFast`].
        video_processor (`Sam2VideoVideoProcessor`):
            An instance of [`Sam2VideoVideoProcessor`].
        target_size (`int`, *optional*):
            The target size (target_size, target_size) to which the image will be resized.
        point_pad_value (`int`, *optional*, defaults to -10):
            The value used for padding input points.
    """

    attributes = ["image_processor", "video_processor"]
    image_processor_class = "Sam2ImageProcessorFast"
    video_processor_class = "Sam2VideoVideoProcessor"

    def __init__(
        self, image_processor, video_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs
    ):
        super().__init__(image_processor, video_processor, **kwargs)
        self.point_pad_value = point_pad_value
        self.target_size = target_size if target_size is not None else self.image_processor.size["height"]

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        segmentation_maps: Optional[ImageInput] = None,
        input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None,
        input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None,
        input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None,
        original_sizes: Optional[Union[list[list[float]], torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        This method uses [`Sam2ImageProcessorFast.__call__`] to prepare image(s) for the model. It also prepares 2D
        points and bounding boxes for the model if they are provided.
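        Example (a minimal, illustrative sketch; the checkpoint id used below is an assumption, not taken from this
        file):

        ```python
        >>> import numpy as np
        >>> from transformers import Sam2VideoProcessor

        >>> processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")  # hypothetical checkpoint id
        >>> image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB image standing in for a real frame
        >>> # 1 image, 1 object, 1 positive point at (x=500, y=375)
        >>> inputs = processor(images=image, input_points=[[[[500, 375]]]], input_labels=[[[1]]], return_tensors="pt")
        >>> # `inputs` now holds `pixel_values`, `original_sizes`, `input_points` and `input_labels`, among others
        ```
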
Args: images (`ImageInput`, *optional*): The image(s) to process. segmentation_maps (`ImageInput`, *optional*): The segmentation maps to process. input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*): The points to add to the frame. input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*): The labels for the points. input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*): The bounding boxes to add to the frame. original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*): The original sizes of the images. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. **kwargs: Additional keyword arguments to pass to the image processor. Returns: A [`BatchEncoding`] with the following fields: - `pixel_values` (`torch.Tensor`): The processed image(s). - `original_sizes` (`list[list[float]]`): The original sizes of the images. - `reshaped_input_sizes` (`torch.Tensor`): The reshaped input sizes of the images. - `labels` (`torch.Tensor`): The processed segmentation maps (if provided). - `input_points` (`torch.Tensor`): The processed points. - `input_labels` (`torch.Tensor`): The processed labels. - `input_boxes` (`torch.Tensor`): The processed bounding boxes. """ if images is not None: encoding_image_processor = self.image_processor( images, segmentation_maps=segmentation_maps, return_tensors=return_tensors, **kwargs, ) elif original_sizes is not None: if isinstance(original_sizes, torch.Tensor): original_sizes = original_sizes.cpu().tolist() encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors) else: raise ValueError("Either images or original_sizes must be provided") # pop arguments that are not used in the forward but used nevertheless original_sizes = encoding_image_processor["original_sizes"] # Check original_sizes is of length 1 or len(images) if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images): raise ValueError( "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size." ) # Process input points, labels, and boxes if provided if input_points is not None or input_labels is not None or input_boxes is not None: # Validate and convert inputs to standardized format processed_points = self._validate_single_input( input_points, expected_depth=4, input_name="points", expected_format="[image level, object level, point level, point coordinates]", expected_coord_size=2, ) processed_labels = self._validate_single_input( input_labels, expected_depth=3, input_name="labels", expected_format="[image level, object level, point level]", ) processed_boxes = self._validate_single_input( input_boxes, expected_depth=3, input_name="boxes", expected_format="[image level, box level, box coordinates]", expected_coord_size=4, ) # Get padding requirements for all inputs if processed_points is not None: points_max_dims = self._get_nested_dimensions(processed_points)[:3] if processed_labels is not None: labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] if processed_boxes is not None: boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] # Ensure points and labels have consistent dimensions if processed_points is not None and processed_labels is not None: if points_max_dims != labels_max_dims: raise ValueError( "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions." 
) # Check that boxes don't need padding (model limitation) if processed_boxes is not None and len(processed_boxes) >= 2: if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes): raise ValueError( "Input boxes have inconsistent dimensions that would require padding, " "but boxes cannot be padded due to model limitations. " "Please ensure all images have the same number of boxes." ) # Pad and normalize all inputs to final tensor format if processed_points is not None: padded_points = self._pad_nested_list(processed_points, points_max_dims + [2]) final_points = torch.tensor(padded_points, dtype=torch.float32) self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True) encoding_image_processor.update({"input_points": final_points}) if processed_labels is not None: padded_labels = self._pad_nested_list(processed_labels, labels_max_dims) final_labels = torch.tensor(padded_labels, dtype=torch.int64) encoding_image_processor.update({"input_labels": final_labels}) if processed_boxes is not None: final_boxes = torch.tensor(processed_boxes, dtype=torch.float32) self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True) encoding_image_processor.update({"input_boxes": final_boxes}) return encoding_image_processor def _normalize_coordinates( self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False ) -> "torch.Tensor": """ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. Args: target_size (`int`): The target size of the image. coords (`torch.Tensor`): The coordinates to be normalized. original_size (`tuple`): The original size of the image. is_bounding_box (`bool`, *optional*, defaults to `False`): Whether the coordinates are bounding boxes. """ old_h, old_w = original_size new_h, new_w = target_size, target_size coords = deepcopy(coords).float() if is_bounding_box: coords = coords.reshape(-1, 2, 2) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) if is_bounding_box: coords = coords.reshape(-1, 4) return coords def _convert_to_nested_list(self, data, expected_depth, current_depth=0): """ Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists. 
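        Example (illustrative, assuming a `Sam2VideoProcessor` instance `processor`):

        ```python
        >>> import numpy as np

        >>> processor._convert_to_nested_list(np.array([[[[500.0, 375.0]]]]), expected_depth=4)
        [[[[500.0, 375.0]]]]
        ```
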
Args: data: Input data in any format expected_depth: Expected nesting depth current_depth: Current depth in recursion Returns: Nested list representation of the data """ if data is None: return None # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array if isinstance(data, torch.Tensor): # PyTorch tensor if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor return data.numpy().tolist() else: return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] elif isinstance(data, np.ndarray): # NumPy array if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array return data.tolist() else: return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] elif isinstance(data, list): if current_depth == expected_depth: # We've reached the expected depth, return as is return data else: # Continue recursion return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] elif isinstance(data, (int, float)): return data else: raise ValueError(f"Unsupported data type: {type(data)}") def _get_nested_dimensions(self, nested_list, max_dims=None): """ Get the maximum dimensions at each level of nesting. Args: nested_list (`list`): Nested list structure. max_dims (`list`, *optional*): Current maximum dimensions (for recursion). Returns: `list`: A list of maximum dimensions for each nesting level. """ if max_dims is None: max_dims = [] if not isinstance(nested_list, list): return max_dims if len(max_dims) == 0: max_dims.append(len(nested_list)) else: max_dims[0] = max(max_dims[0], len(nested_list)) if len(nested_list) > 0: for item in nested_list: if isinstance(item, list): sub_dims = self._get_nested_dimensions(item) # Merge sub_dims into max_dims for i, dim in enumerate(sub_dims): if i + 1 >= len(max_dims): max_dims.append(dim) else: max_dims[i + 1] = max(max_dims[i + 1], dim) return max_dims def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None): """ Recursively pad a nested list to match target dimensions. Args: nested_list (`list`): Nested list to pad. target_dims (`list`): Target dimensions for each level. current_level (`int`, *optional*, defaults to 0): Current nesting level. pad_value (`int`, *optional*): Value to use for padding. Returns: `list`: The padded nested list. 
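        Example (a small illustration, assuming a `Sam2VideoProcessor` instance `processor` with the default
        `point_pad_value` of -10; the second object already has the maximum number of points, so only the first one
        gets a padded point):

        ```python
        >>> # 1 image, 2 objects with 1 and 2 points respectively, padded to target_dims [1, 2, 2, 2]
        >>> processor._pad_nested_list([[[[500, 375]], [[100, 200], [300, 400]]]], [1, 2, 2, 2])
        [[[[500, 375], [-10, -10]], [[100, 200], [300, 400]]]]
        ```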
""" if pad_value is None: pad_value = self.point_pad_value if current_level >= len(target_dims): return nested_list # Ensure we have a list if not isinstance(nested_list, list): nested_list = [nested_list] # Pad current level current_size = len(nested_list) target_size = target_dims[current_level] # Pad with appropriate values if current_level == len(target_dims) - 1: # At the coordinate level, pad with pad_value nested_list.extend([pad_value] * (target_size - current_size)) else: # At higher levels, pad with nested structures if current_size > 0: # Create appropriately sized template if current_level < len(target_dims) - 2: # For non-coordinate levels, create empty nested structure template_dims = target_dims[current_level + 1 :] template = self._create_empty_nested_structure(template_dims, pad_value) else: # For coordinate level, create list of pad_values template = [pad_value] * target_dims[current_level + 1] nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)]) else: # Create from scratch template_dims = target_dims[current_level + 1 :] template = self._create_empty_nested_structure(template_dims, pad_value) nested_list.extend([deepcopy(template) for _ in range(target_size)]) # Recursively pad sublists if current_level < len(target_dims) - 1: for i in range(len(nested_list)): if isinstance(nested_list[i], list): nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value) return nested_list def _create_empty_nested_structure(self, dims, pad_value): """ Create an empty nested structure with given dimensions filled with pad_value. Args: dims (`list`): The dimensions of the nested structure. pad_value (`int`): The value to fill the structure with. """ if len(dims) == 1: return [pad_value] * dims[0] else: return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])] def _get_nesting_level(self, input_list): """ Get the nesting level of a list structure. Args: input_list (`list`): The list to get the nesting level of. """ if isinstance(input_list, list): if len(input_list) == 0: return 1 return 1 + self._get_nesting_level(input_list[0]) elif isinstance(input_list, (np.ndarray, torch.Tensor)): # For arrays/tensors, the nesting level is the number of dimensions return len(input_list.shape) return 0 def _validate_single_input( self, data: Union[torch.Tensor, np.ndarray, list], expected_depth: int, input_name: str, expected_format: str, expected_coord_size: Optional[int] = None, ) -> list: """ Validate a single input by ensuring proper nesting and raising an error if the input is not valid. Args: data (`torch.Tensor`, `np.ndarray`, or `list`): Input data to process. expected_depth (`int`): Expected nesting depth. input_name (`str`): Name of the input for error messages. expected_format (`str`): The expected format of the input. expected_coord_size (`int`, *optional*): Expected coordinate size (2 for points, 4 for boxes, None for labels). . """ if data is None: return None # Handle tensors and numpy arrays first if isinstance(data, (torch.Tensor, np.ndarray)): # For tensors/arrays, we can directly check the number of dimensions if data.ndim != expected_depth: raise ValueError( f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions." 
                )
            elif expected_coord_size is not None:
                if data.shape[-1] != expected_coord_size:
                    raise ValueError(
                        f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
                    )
            return self._convert_to_nested_list(data, expected_depth)

        # Handle nested lists
        if isinstance(data, list):
            current_depth = self._get_nesting_level(data)
            if current_depth != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
                )
            return self._convert_to_nested_list(data, expected_depth)

    def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
        """
        Helper method to normalize coordinates in a tensor across multiple images.

        Args:
            tensor (`torch.Tensor`):
                Input tensor with coordinates.
            original_sizes (`list`):
                Original image sizes.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether coordinates are bounding boxes.
            preserve_padding (`bool`, *optional*, defaults to `False`):
                Whether to preserve padding values (for points).
        """
        if preserve_padding:
            # For points: avoid normalizing pad values
            mask = tensor != self.point_pad_value
            coord_mask = mask.all(dim=-1, keepdim=True)

        for img_idx in range(len(original_sizes)):
            if img_idx < tensor.shape[0]:
                original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
                normalized_coords = self._normalize_coordinates(
                    self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
                )
                if preserve_padding:
                    # Only update non-padded values
                    img_mask = coord_mask[img_idx]
                    tensor[img_idx] = torch.where(
                        img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
                    )
                else:
                    tensor[img_idx] = normalized_coords

    def post_process_masks(
        self,
        masks,
        original_sizes,
        mask_threshold=0.0,
        binarize=True,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        apply_non_overlapping_constraints=False,
        **kwargs,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in
                (height, width) format.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold for binarization and post-processing operations.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            max_hole_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a hole to fill.
            max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a sprinkle to remove.
            apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                Whether to apply non-overlapping constraints to the masks.

        Returns:
            (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
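
        Example (a hedged sketch; `model` is assumed to be a SAM2 model whose output exposes low-resolution masks as
        `pred_masks`, and `inputs` is assumed to come from a previous processor call — neither is verified here):

        ```python
        >>> outputs = model(**inputs)  # assumed SAM2 forward pass producing low-resolution masks
        >>> masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])
        >>> # masks are binarized (default threshold 0.0) and upscaled to each image's original (height, width)
        ```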
""" return self.image_processor.post_process_masks( masks, original_sizes, mask_threshold, binarize, max_hole_area, max_sprinkle_area, apply_non_overlapping_constraints, **kwargs, ) def init_video_session( self, video: Optional[VideoInput] = None, inference_device: Union[str, "torch.device"] = "cpu", inference_state_device: Union[str, "torch.device"] = None, processing_device: Union[str, "torch.device"] = None, video_storage_device: Union[str, "torch.device"] = None, max_vision_features_cache_size: int = 1, dtype: torch.dtype = torch.float32, ): """ Initializes a video session for inference. If a video is provided (async inference), the video will be processed and stored on the `video_storage_device`. Args: video (`VideoInput`, *optional*): The video to process. No need to provide when streaming. inference_device (`str` or `torch.device`, *optional*, defaults to "cpu"): The device to use for inference. inference_state_device (`str` or `torch.device`, *optional*): The device to store the inference state on. processing_device (`str` or `torch.device`, *optional*): The device to use for video processing. video_storage_device (`str` or `torch.device`, *optional*): The device to store the processed video frames on. max_vision_features_cache_size (`int`, *optional*, defaults to 1): The maximum number of vision features to cache. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The torch dtype to use for the whole session. """ video_storage_device = video_storage_device if video_storage_device is not None else inference_device inference_state_device = inference_state_device if inference_state_device is not None else inference_device processing_device = processing_device if processing_device is not None else inference_device pixel_values_video = None video_height = None video_width = None if video is not None: processed_video = self.video_processor(videos=video, device=processing_device, return_tensors="pt") pixel_values_video = processed_video.pixel_values_videos[0] video_height = processed_video.original_sizes[0][0] video_width = processed_video.original_sizes[0][1] inference_session = Sam2VideoInferenceSession( video=pixel_values_video, video_height=video_height, video_width=video_width, inference_device=inference_device, video_storage_device=video_storage_device, inference_state_device=inference_state_device, dtype=dtype, max_vision_features_cache_size=max_vision_features_cache_size, ) return inference_session def add_inputs_to_inference_session( self, inference_session: Sam2VideoInferenceSession, frame_idx: int, obj_ids: Union[list[int], int], input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None, input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None, input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None, input_masks: Optional[Union[np.ndarray, torch.Tensor, list[np.ndarray], list[torch.Tensor]]] = None, original_size: Optional[tuple[int, int]] = None, clear_old_inputs: bool = True, ) -> Sam2VideoInferenceSession: """ Process new points, boxes, or masks for a video frame and add them to the inference session. Args: inference_session (`Sam2VideoInferenceSession`): The inference session for the video. frame_idx (`int`): The index of the frame to process. obj_ids (`list[int]` or `int`): The object ID(s) to associate with the points or box. These can be any integers and can be reused later on to specify an object. 
input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*): The points to add to the frame. input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*): The labels for the points. input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*): The bounding boxes to add to the frame. input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`, *optional*): The mask(s) to add to the frame. original_size (`tuple[int, int]`, *optional*): The original size of the video. Provide when streaming. clear_old_inputs (`bool`, *optional*, defaults to `True`): Whether to clear old inputs for the object. """ if isinstance(obj_ids, int): obj_ids = [obj_ids] # Validate inputs if (input_points is not None) != (input_labels is not None): raise ValueError("points and labels must be provided together") if input_points is None and input_boxes is None and input_masks is None: raise ValueError("at least one of points, boxes, or masks must be provided as input") if input_masks is not None and (input_points is not None or input_boxes is not None): raise ValueError("masks cannot be provided together with points or boxes") if input_masks is not None: return self.process_new_mask_for_video_frame(inference_session, frame_idx, obj_ids, input_masks) else: return self.process_new_points_or_boxes_for_video_frame( inference_session, frame_idx, obj_ids, input_points, input_labels, input_boxes, original_size, clear_old_inputs, ) def process_new_points_or_boxes_for_video_frame( self, inference_session: Sam2VideoInferenceSession, frame_idx: int, obj_ids: list[int], input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None, input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None, input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None, original_size: Optional[tuple[int, int]] = None, clear_old_inputs: bool = True, ) -> Sam2VideoInferenceSession: """ Process new points or boxes for a video frame and add them to the inference session. Args: inference_session (`Sam2VideoInferenceSession`): The inference session for the video. frame_idx (`int`): The index of the frame to process. obj_ids (`list[int]`): The object ID(s) to associate with the points or box. These can be any integers and can be reused later on to specify an object. input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*): The points to add to the frame. input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*): The labels for the points. input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*): The bounding boxes to add to the frame. original_size (`tuple[int, int]`, *optional*): The original size of the video. Provide when streaming. clear_old_inputs (`bool`, *optional*, defaults to `True`): Whether to clear old inputs for the object. 
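
        Example (an illustrative sketch, assuming a `Sam2VideoProcessor` instance `processor`; in practice this method
        is usually reached through [`~Sam2VideoProcessor.add_inputs_to_inference_session`]):

        ```python
        >>> import numpy as np

        >>> video = np.zeros((4, 480, 640, 3), dtype=np.uint8)  # dummy 4-frame video
        >>> inference_session = processor.init_video_session(video=video, inference_device="cpu")
        >>> # register a single positive click for object id 1 on frame 0
        >>> processor.add_inputs_to_inference_session(
        ...     inference_session=inference_session,
        ...     frame_idx=0,
        ...     obj_ids=1,
        ...     input_points=[[[[210, 350]]]],
        ...     input_labels=[[[1]]],
        ... )
        ```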
""" if original_size is not None: inference_session.video_height = original_size[0] inference_session.video_width = original_size[1] elif inference_session.video_height is None or inference_session.video_width is None: raise ValueError("original_size must be provided when adding points or boxes on a first streamed frame") original_sizes = [[inference_session.video_height, inference_session.video_width]] encoded_inputs = self( input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, original_sizes=original_sizes, return_tensors="pt", ) input_points = encoded_inputs.get("input_points", None) input_labels = encoded_inputs.get("input_labels", None) input_boxes = encoded_inputs.get("input_boxes", None) if input_points is not None: if input_points.shape[1] != len(obj_ids): raise ValueError( f"Number of object ids ({len(obj_ids)}) does not match number of points ({input_points.shape[1]})" ) else: input_points = torch.zeros(1, len(obj_ids), 0, 2, dtype=torch.float32) if input_labels is not None: if input_labels.shape[1] != len(obj_ids): raise ValueError( f"Number of object ids ({len(obj_ids)}) does not match number of labels ({input_labels.shape[1]})" ) else: input_labels = torch.zeros(1, len(obj_ids), 0, dtype=torch.int32) if input_boxes is not None: if input_boxes.shape[1] != len(obj_ids): raise ValueError( f"Number of object ids ({len(obj_ids)}) does not match number of boxes ({input_boxes.shape[1]})" ) if input_boxes is not None: if not clear_old_inputs: raise ValueError( "cannot add box without clearing old points, since " "box prompt must be provided before any point prompt " "(please use clear_old_points=True instead)" ) box_coords = input_boxes.reshape(1, -1, 2, 2) box_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1) input_points = torch.cat([box_coords, input_points], dim=2) input_labels = torch.cat([box_labels, input_labels], dim=2) for obj_id, idx in zip(obj_ids, range(len(obj_ids))): obj_idx = inference_session.obj_id_to_idx(obj_id) input_points_for_obj = input_points[:, idx, :, :].unsqueeze(1) input_labels_for_obj = input_labels[:, idx, :].unsqueeze(1) # Handle existing points if not clear_old_inputs: existing_points = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None) if existing_points is not None: # Concatenate with existing points input_points_for_obj = torch.cat( [existing_points["point_coords"].to(input_points_for_obj.device), input_points_for_obj], dim=2 ) input_labels_for_obj = torch.cat( [existing_points["point_labels"].to(input_labels_for_obj.device), input_labels_for_obj], dim=2 ) point_inputs = { "point_coords": input_points_for_obj, "point_labels": input_labels_for_obj, } inference_session.add_point_inputs(obj_idx, frame_idx, point_inputs) inference_session.remove_mask_inputs(obj_idx, frame_idx) # Clear any mask inputs inference_session.obj_with_new_inputs = obj_ids def process_new_mask_for_video_frame( self, inference_session: Sam2VideoInferenceSession, frame_idx: int, obj_ids: list[int], input_masks: Union[np.ndarray, torch.Tensor, list[np.ndarray], list[torch.Tensor]], ): """ Add new mask to a frame and add them to the inference session. Args: inference_session (`Sam2VideoInferenceSession`): The inference session for the video. frame_idx (`int`): The index of the frame to process. obj_ids (`list[int]`): The object ID(s) to associate with the mask. These can be any integers and can be reused later on to specify an object. 
            input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`):
                The mask(s) to add to the frame.
        """
        if not isinstance(input_masks, list):
            input_masks = [input_masks]
        if len(input_masks) != len(obj_ids):
            raise ValueError(
                f"Number of object ids ({len(obj_ids)}) does not match number of masks ({len(input_masks)})"
            )

        for obj_id, mask in zip(obj_ids, input_masks):
            obj_idx = inference_session.obj_id_to_idx(obj_id)
            device = inference_session.inference_device

            # Process mask
            if not isinstance(mask, torch.Tensor):
                mask = torch.tensor(mask, dtype=torch.bool)
            nb_dim = mask.dim()
            if nb_dim > 4 or nb_dim < 2:
                raise ValueError(f"Mask has an unsupported number of dimensions: {nb_dim}")
            for _ in range(4 - nb_dim):
                mask = mask.unsqueeze(0)
            mask_H, mask_W = mask.shape[-2:]
            mask_inputs_orig = mask.float().to(device)

            # Resize mask if needed
            if mask_H != self.target_size or mask_W != self.target_size:
                mask_inputs = torch.nn.functional.interpolate(
                    mask_inputs_orig,
                    size=(self.target_size, self.target_size),
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,
                )
                mask_inputs = (mask_inputs >= 0.5).float()
            else:
                mask_inputs = mask_inputs_orig

            inference_session.add_mask_inputs(obj_idx, frame_idx, mask_inputs)
            inference_session.remove_point_inputs(obj_idx, frame_idx)  # Clear any point inputs

        inference_session.obj_with_new_inputs = obj_ids


__all__ = ["Sam2VideoProcessor"]