import pathlib from typing import Optional, Union import torch from torchvision.transforms.v2 import functional as F from transformers.models.detr.image_processing_detr_fast import DetrFastImageProcessorKwargs, DetrImageProcessorFast from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width from ...image_transforms import center_to_corners_format from ...image_utils import ( IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, AnnotationFormat, AnnotationType, ChannelDimension, ImageInput, PILImageResampling, get_image_size, validate_annotations, ) from ...processing_utils import Unpack from ...utils import ( TensorType, logging, requires_backends, ) logger = logging.get_logger(__name__) SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) def prepare_coco_detection_annotation( image, target, return_segmentation_masks: bool = False, input_data_format: Optional[Union[ChannelDimension, str]] = None, ): """ Convert the target in COCO format into the format expected by RT-DETR. """ image_height, image_width = image.size()[-2:] image_id = target["image_id"] image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) # Get all COCO annotations for the given image. annotations = target["annotations"] classes = [] area = [] boxes = [] keypoints = [] for obj in annotations: if "iscrowd" not in obj or obj["iscrowd"] == 0: classes.append(obj["category_id"]) area.append(obj["area"]) boxes.append(obj["bbox"]) if "keypoints" in obj: keypoints.append(obj["keypoints"]) classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) area = torch.as_tensor(area, dtype=torch.float32, device=image.device) iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) # guard against no boxes via resizing boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) new_target = { "image_id": image_id, "class_labels": classes[keep], "boxes": boxes[keep], "area": area[keep], "iscrowd": iscrowd[keep], "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), } if keypoints: keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) # Apply the keep mask here to filter the relevant annotations keypoints = keypoints[keep] num_keypoints = keypoints.shape[0] keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints new_target["keypoints"] = keypoints return new_target class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs): pass class RTDetrImageProcessorFast(DetrImageProcessorFast): resample = PILImageResampling.BILINEAR image_mean = IMAGENET_DEFAULT_MEAN image_std = IMAGENET_DEFAULT_STD format = AnnotationFormat.COCO_DETECTION do_convert_annotations = True do_resize = True do_rescale = True do_normalize = False do_pad = False size = {"height": 640, "width": 640} default_to_square = False model_input_names = ["pixel_values", "pixel_mask"] valid_kwargs = RTDetrFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None: # Backwards compatibility do_convert_annotations = kwargs.get("do_convert_annotations") do_normalize = kwargs.get("do_normalize") if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None: self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize BaseImageProcessorFast.__init__(self, **kwargs) def preprocess( self, images: ImageInput, annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, masks_path: Optional[Union[str, pathlib.Path]] = None, **kwargs: Unpack[RTDetrFastImageProcessorKwargs], ) -> BatchFeature: return BaseImageProcessorFast.preprocess(self, images, annotations, masks_path, **kwargs) def prepare_annotation( self, image: torch.Tensor, target: dict, format: Optional[AnnotationFormat] = None, return_segmentation_masks: Optional[bool] = None, masks_path: Optional[Union[str, pathlib.Path]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> dict: format = format if format is not None else self.format if format == AnnotationFormat.COCO_DETECTION: return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks target = prepare_coco_detection_annotation( image, target, return_segmentation_masks, input_data_format=input_data_format ) else: raise ValueError(f"Format {format} is not supported.") return target def _preprocess( self, images: list["torch.Tensor"], annotations: Optional[Union[AnnotationType, list[AnnotationType]]], masks_path: Optional[Union[str, pathlib.Path]], return_segmentation_masks: bool, do_resize: bool, size: SizeDict, interpolation: Optional["F.InterpolationMode"], do_rescale: bool, rescale_factor: float, do_normalize: bool, do_convert_annotations: bool, image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, ) -> BatchFeature: """ Preprocess an image or a batch of images so that it can be used by the model. """ if annotations is not None and isinstance(annotations, dict): annotations = [annotations] if annotations is not None and len(images) != len(annotations): raise ValueError( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." ) format = AnnotationFormat(format) if annotations is not None: validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) data = {} processed_images = [] processed_annotations = [] pixel_masks = [] # Initialize pixel_masks here for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) if annotations is not None: annotation = self.prepare_annotation( image, annotation, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path, input_data_format=ChannelDimension.FIRST, ) if do_resize: resized_image = self.resize(image, size=size, interpolation=interpolation) if annotations is not None: annotation = self.resize_annotation( annotation, orig_size=image.size()[-2:], target_size=resized_image.size()[-2:], ) image = resized_image # Fused rescale and normalize image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) if do_convert_annotations and annotations is not None: annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) processed_images.append(image) processed_annotations.append(annotation) images = processed_images annotations = processed_annotations if annotations is not None else None if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) padded_images = [] padded_annotations = [] for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} if padded_size == image.size()[-2:]: padded_images.append(image) pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) padded_annotations.append(annotation) continue image, pixel_mask, annotation = self.pad( image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations ) padded_images.append(image) padded_annotations.append(annotation) pixel_masks.append(pixel_mask) images = padded_images annotations = padded_annotations if annotations is not None else None data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) data.update({"pixel_values": torch.stack(images, dim=0)}) encoded_inputs = BatchFeature(data, tensor_type=return_tensors) if annotations is not None: encoded_inputs["labels"] = [ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations ] return encoded_inputs def post_process_object_detection( self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, use_focal_loss: bool = True, ): """ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch. Args: outputs ([`DetrObjectDetectionOutput`]): Raw outputs of the model. threshold (`float`, *optional*, defaults to 0.5): Score threshold to keep object detection predictions. target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*): Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size `(height, width)` of each image in the batch. If unset, predictions will not be resized. use_focal_loss (`bool` defaults to `True`): Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied to compute the scores of each detection, otherwise, a softmax function is used. Returns: `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. """ requires_backends(self, ["torch"]) out_logits, out_bbox = outputs.logits, outputs.pred_boxes # convert from relative cxcywh to absolute xyxy boxes = center_to_corners_format(out_bbox) if target_sizes is not None: if len(out_logits) != len(target_sizes): raise ValueError( "Make sure that you pass in as many target sizes as the batch dimension of the logits" ) if isinstance(target_sizes, list): img_h, img_w = torch.as_tensor(target_sizes).unbind(1) else: img_h, img_w = target_sizes.unbind(1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] num_top_queries = out_logits.shape[1] num_classes = out_logits.shape[2] if use_focal_loss: scores = torch.nn.functional.sigmoid(out_logits) scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1) labels = index % num_classes index = index // num_classes boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) else: scores = torch.nn.functional.softmax(out_logits)[:, :, :-1] scores, labels = scores.max(dim=-1) if scores.shape[1] > num_top_queries: scores, index = torch.topk(scores, num_top_queries, dim=-1) labels = torch.gather(labels, dim=1, index=index) boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) results = [] for score, label, box in zip(scores, labels, boxes): results.append( { "scores": score[score > threshold], "labels": label[score > threshold], "boxes": box[score > threshold], } ) return results def from_dict(self): raise NotImplementedError("No need to override this method for RT-DETR yet.") def post_process(self): raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.") def post_process_segmentation(self): raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.") def post_process_instance(self): raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.") def post_process_panoptic(self): raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.") def post_process_instance_segmentation(self): raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.") def post_process_semantic_segmentation(self): raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.") def post_process_panoptic_segmentation(self): raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.") __all__ = ["RTDetrImageProcessorFast"]