# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DINOv2 model."""

import collections.abc
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
from ...utils.backbone_utils import BackboneMixin
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_dinov2 import Dinov2Config


logger = logging.get_logger(__name__)


class Dinov2Embeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.

    Learnable parameters created here:
        - `cls_token` of shape `(1, 1, hidden_size)`, prepended to every sequence of patch embeddings.
        - `mask_token` of shape `(1, hidden_size)`, only created when `config.use_mask_token` is set.
        - `position_embeddings` of shape `(1, num_patches + 1, hidden_size)`; the extra slot is for the CLS token.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.patch_embeddings = Dinov2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # one position per patch plus one for the CLS token
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.use_mask_token = config.use_mask_token
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

        Args:
            embeddings (`torch.Tensor`):
                Token embeddings with the CLS token already prepended; `shape[1] - 1` is taken as the number of
                patches and `shape[-1]` as the embedding dimension.
            height (`int`):
                Height of the input image in pixels (divided by `self.patch_size` to get the target grid height).
            width (`int`):
                Width of the input image in pixels (divided by `self.patch_size` to get the target grid width).

        Returns:
            `torch.Tensor`: The stored `position_embeddings` unchanged when no resizing is needed, otherwise the CLS
            position concatenated with the bicubically resized patch positions, of shape
            `(1, 1 + new_height * new_width, dim)`.
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        # the CLS position is kept as is; only the patch positions are resized
        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # the stored patch positions are laid out as a square grid of side sqrt(num_positions)
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        # move the embedding dimension to the channel axis: (1, dim, grid, grid)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        target_dtype = patch_pos_embed.dtype
        # interpolate in float32, then cast back to the dtype of the stored embeddings
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(torch.float32),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        ).to(dtype=target_dtype)

        # back to a flat token sequence: (1, new_height * new_width, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Embed an image batch into a token sequence: patch projection, optional mask-token substitution, CLS token,
        position encoding and dropout.

        Args:
            pixel_values (`torch.Tensor`):
                Image batch unpacked as `(batch_size, num_channels, height, width)`.
            bool_masked_pos (`torch.Tensor`, *optional*):
                Boolean mask over patch positions. Where it is true, the patch embedding is replaced by
                `mask_token`. Ignored when the module was built without a mask token.

        Returns:
            `torch.Tensor`: Embeddings with the CLS token at index 0 of dimension 1.
        """
        batch_size, _, height, width = pixel_values.shape
        # cast the input to the dtype of the projection weights before projecting
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        if bool_masked_pos is not None and self.use_mask_token:
            # unsqueeze the mask over the hidden dimension and the (1, hidden_size) mask token over the batch
            # dimension so that `torch.where` broadcasts both against the patch embeddings
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings


class Dinov2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.

    The projection is a single `nn.Conv2d` whose kernel size and stride both equal the patch size, so the patches do
    not overlap.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # accept either a single number (square) or an iterable for both sizes
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # number of whole patches that fit along each axis, multiplied together
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Project an image batch to a sequence of patch embeddings.

        Args:
            pixel_values (`torch.Tensor`):
                Image batch whose dimension 1 is the channel dimension.

        Returns:
            `torch.Tensor`: The convolution output with its spatial dimensions flattened and moved before the
            channel dimension, i.e. `(batch_size, seq_length, hidden_size)`.

        Raises:
            ValueError: If the channel dimension of `pixel_values` differs from `config.num_channels`.
        """
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        # (batch, hidden, h, w) -> (batch, hidden, h * w) -> (batch, h * w, hidden)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


# Reference ("eager") attention used when no other attention backend is selected:
# softmax(query @ key^T * scaling) in float32, dropout on the probabilities, an optional
# multiplicative mask, then the weighted sum of `value`. Returns the attention output with
# dimensions 1 and 2 swapped, together with the attention probabilities.
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# Multi-head self-attention: separate query/key/value projections of `hidden_states`, reshaped to
# (batch, num_heads, seq, head_size), dispatched either to `eager_attention_forward` or to the
# backend registered under `config._attn_implementation`. `head_mask` is forwarded in the
# attention-mask position of the attention function. Returns (context_layer, attention_probs).
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
class Dinov2SelfAttention(nn.Module):
    def __init__(self, config: Dinov2Config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def forward(
        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.shape[0]
        new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size

        key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
        query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        return context_layer, attention_probs


# Output projection of the attention block: a dense layer followed by dropout. `input_tensor` is
# accepted but not used in `forward`; the residual connection is added by the caller.
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Self-attention followed by its output projection, with support for pruning attention heads.
# `forward` discards the attention probabilities and returns only the projected output tensor.
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    def __init__(self, config: Dinov2Config):
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: set[int]):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        self_attn_output, _ = self.attention(hidden_states, head_mask)
        output = self.output(self_attn_output, hidden_states)
        return output


class Dinov2LayerScale(nn.Module):
    """
    Learnable per-channel scaling of a residual branch.

    `lambda1` is a vector of size `config.hidden_size` initialised to `config.layerscale_value` and multiplied
    element-wise (broadcast over the leading dimensions) with the input.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` scaled by `lambda1`."""
        return hidden_state * self.lambda1


# Stochastic depth: returns `input` untouched when `drop_prob == 0.0` or when not training.
# Otherwise draws one random value per sample (first dimension), binarizes it with probability
# `1 - drop_prob` of keeping the sample, and rescales kept samples by `1 / (1 - drop_prob)`.
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Module wrapper around `drop_path` that uses the module's own `training` flag.
# NOTE(review): `drop_prob` defaults to None, which would reach `1 - drop_prob` in `drop_path` when
# training; in this file `Dinov2Layer` only instantiates the module with a rate > 0.
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class Dinov2MLP(nn.Module):
    """
    Two-layer feed-forward network: `fc1` -> activation -> `fc2`.

    The intermediate width is `int(hidden_size * mlp_ratio)`; input and output widths are both `hidden_size`.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        # `hidden_act` may be given by name (looked up in ACT2FN) or directly as a callable
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation and `fc2` in sequence."""
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class Dinov2SwiGLUFFN(nn.Module):
    """
    Gated feed-forward network: a single input projection produces two halves `x1` and `x2`, combined as
    `silu(x1) * x2` and projected back to `hidden_size`.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        # scale the width by 2/3, then round up to the next multiple of 8
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        # `weights_in` produces both the gate and the value halves in one matmul
        self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Project, split into two halves along the last dimension, gate with SiLU and project back."""
        hidden_state = self.weights_in(hidden_state)
        x1, x2 = hidden_state.chunk(2, dim=-1)
        hidden = nn.functional.silu(x1) * x2
        return self.weights_out(hidden)


class Dinov2Layer(GradientCheckpointingLayer):
    """
    This corresponds to the Block class in the original implementation.

    Each of the two sub-blocks (attention, then MLP) is computed as
    `x + drop_path(layer_scale(sublayer(layer_norm(x))))`. The same `drop_path` module is used for both residual
    branches; it is an `nn.Identity` when `config.drop_path_rate` is not greater than 0.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the attention sub-block followed by the MLP sub-block.

        Args:
            hidden_states (`torch.Tensor`): Input token sequence.
            head_mask (`torch.Tensor`, *optional*): Mask forwarded unchanged to the attention module.

        Returns:
            `torch.Tensor`: The output of the second residual connection.
        """
        # layernorm is applied before self-attention (pre-norm)
        hidden_states_norm = self.norm1(hidden_states)
        self_attention_output = self.attention(hidden_states_norm, head_mask)
        self_attention_output = self.layer_scale1(self_attention_output)

        # first residual connection
        hidden_states = self.drop_path(self_attention_output) + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        return layer_output


class Dinov2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Dinov2Layer` modules applied sequentially."""

    def __init__(self, config: Dinov2Config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
    ) -> BaseModelOutput:
        """
        Run every layer in order.

        Args:
            hidden_states (`torch.Tensor`): Output of the embedding layer.
            head_mask (`torch.Tensor`, *optional*): Indexed by layer position; entry `i` is passed to layer `i`.
            output_hidden_states (`bool`, defaults to `False`):
                When true, the input and the output of every layer are collected and returned.

        Returns:
            `BaseModelOutput`: `last_hidden_state` is the output of the final layer; `hidden_states` is a tuple of
            `num_hidden_layers + 1` tensors (input first) or `None` when not requested.
        """
        # seeded with the embedding output, so the list is non-empty (truthy) whenever collection is enabled
        all_hidden_states = [hidden_states] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(hidden_states, layer_head_mask)
            if all_hidden_states:
                all_hidden_states.append(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
        )


# Common base class of the DINOv2 models in this file: declares the class-level flags read by
# `PreTrainedModel` and the weight initialisation scheme. The class docstring is left to
# `@auto_docstring`.
@auto_docstring
class Dinov2PreTrainedModel(PreTrainedModel):
    config: Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Dinov2Layer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # presumably read by the output-recording machinery (`check_model_inputs`) to collect attention
    # weights from this module type -- defined outside this file, TODO confirm
    _can_record_outputs = {
        "attentions": Dinov2SelfAttention,
    }

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """
        Initialize the weights of a single sub-module, depending on its type:

        - `nn.Linear` / `nn.Conv2d`: truncated normal weights with std `config.initializer_range`, zero bias.
        - `nn.LayerNorm`: weight 1, bias 0.
        - `Dinov2Embeddings`: truncated normal position embeddings and CLS token, zero mask token (if present).
        - `Dinov2LayerScale`: `lambda1` filled with `config.layerscale_value`.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            # same fp32 round-trip as above for the embedding parameters
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

            # `mask_token` only exists when the embeddings were built with `use_mask_token`
            if self.config.use_mask_token:
                module.mask_token.data.zero_()
        elif isinstance(module, Dinov2LayerScale):
            module.lambda1.data.fill_(self.config.layerscale_value)


# Bare DINOv2 model: embeddings -> encoder -> final layernorm. The pooled output is the
# layer-normalised CLS token (index 0 of the sequence). The class docstring is left to
# `@auto_docstring`.
@auto_docstring
class Dinov2Model(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        """Return the patch embedding module that consumes `pixel_values`."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.
        """
        # fall back to the configuration when the caller does not specify the flag
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs: BaseModelOutput = self.encoder(
            embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states
        )
        sequence_output = encoder_outputs.last_hidden_state
        sequence_output = self.layernorm(sequence_output)
        # the CLS token sits at index 0 of the sequence
        pooled_output = sequence_output[:, 0, :]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


# Classification model: the classifier consumes the CLS token concatenated with the mean of the
# patch tokens, which is why its input width is `2 * hidden_size`. With `config.num_labels <= 0`
# the classifier is an `nn.Identity`.
@auto_docstring(
    custom_intro="""
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        outputs: BaseModelOutputWithPooling = self.dinov2(pixel_values, head_mask=head_mask, **kwargs)

        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        # classifier input: CLS token concatenated with the average of all patch tokens
        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # `loss_function` is defined outside this file
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Backbone variant: runs the embeddings and the encoder, then turns the hidden states of the
# requested stages into feature maps (optionally layer-normalised and reshaped to
# `(batch, hidden, height // patch, width // patch)` with the CLS token removed).
@auto_docstring(
    custom_intro="""
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        super()._init_backbone(config)

        # one entry for the embedding output plus one per encoder layer
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        """Return the patch embedding module that consumes `pixel_values`."""
        return self.embeddings.patch_embeddings

    @check_model_inputs
    @auto_docstring
    def forward(
        self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        embedding_output = self.embeddings(pixel_values)
        # hidden states are always collected because the feature maps are built from them;
        # `output_hidden_states` only controls whether they are also returned to the caller
        output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
        hidden_states = output.hidden_states

        feature_maps = []
        # pair each stage name with the hidden state at the same position
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    # drop the CLS token so that only patch tokens remain
                    hidden_state = hidden_state[:, 1:]
                    # this was actually a bug in the original implementation that we copied here,
                    # cause normally the order is height, width
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
                    # channels first: (batch, hidden, height // patch, width // patch)
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps.append(hidden_state)

        return BackboneOutput(
            feature_maps=tuple(feature_maps),
            hidden_states=hidden_states if output_hidden_states else None,
        )


__all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"]