# coding=utf-8
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 GroupViT model."""

from __future__ import annotations

import collections.abc
import math
from dataclasses import dataclass
from typing import Any

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_tensorflow_probability_available,
    logging,
    replace_return_docstrings,
)
from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig


logger = logging.get_logger(__name__)

# Soft dependency: `tfp` is only bound when the import *and* the smoke test below succeed.
# The import is attempted whether or not the availability probe reports the package; the probe
# only decides whether a failure is worth logging.
try:
    import tensorflow_probability as tfp

    # On the first call, check whether a compatible version of TensorFlow is installed
    # TensorFlow Probability depends on a recent stable release of TensorFlow
    _ = tfp.distributions.Normal(loc=0.0, scale=1.0)
except ImportError:
    if is_tensorflow_probability_available():
        logger.error(
            "GroupViT models are not usable since `tensorflow_probability` can't be loaded. "
            "It seems you have `tensorflow_probability` installed with the wrong tensorflow version. "
            "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability."
        )

_CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc"

# Additive value written at masked-out attention positions by `_expand_mask`.
LARGE_NEGATIVE = -1e8


# Adapted from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Args:
        mask (`tf.Tensor`): Mask of shape `[bsz, seq_len]`; it is cast to float, so positions equal to 1 map to 0
            and positions equal to 0 map to `LARGE_NEGATIVE`.
        tgt_len (`int`, *optional*): Target length to tile to. Defaults to the source length.

    Returns:
        `tf.Tensor`: Additive float mask of shape `[bsz, 1, tgt_len, src_len]`.
    """
    src_len = shape_list(mask)[1]
    if tgt_len is None:
        tgt_len = src_len

    one = tf.constant(1.0)
    float_mask = tf.cast(mask, dtype=one.dtype)
    tiled_mask = tf.tile(float_mask[:, None, None, :], (1, 1, tgt_len, 1))

    return (one - tiled_mask) * LARGE_NEGATIVE


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
    """
    Mean sparse categorical cross-entropy of `logits` where the target class of row `i` is `i`.

    Args:
        logits (`tf.Tensor`): Unnormalised similarity scores; the first dimension gives the number of targets.

    Returns:
        `tf.Tensor`: The cross-entropy averaged over all rows.
    """
    num_rows = shape_list(logits)[0]
    per_row_loss = keras.metrics.sparse_categorical_crossentropy(
        y_true=tf.range(num_rows), y_pred=logits, from_logits=True
    )
    return tf.math.reduce_mean(per_row_loss)


# Adapted from transformers.models.clip.modeling_tf_clip.clip_loss with clip->groupvit
def groupvit_loss(similarity: tf.Tensor) -> tf.Tensor:
    """
    Symmetric contrastive loss: the average of `contrastive_loss` on `similarity` and on its transpose.

    Args:
        similarity (`tf.Tensor`): 2D similarity matrix.

    Returns:
        `tf.Tensor`: The averaged loss.
    """
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(tf.transpose(similarity))
    return (caption_loss + image_loss) / 2.0
def _straight_through_one_hot(y_soft: tf.Tensor, dim: int) -> tf.Tensor:
    """
    Replace a soft assignment by its one-hot arg-max while keeping the gradient of the soft assignment.

    Args:
        y_soft (`tf.Tensor`): Soft assignment (softmax output).
        dim (`int`): Axis along which the arg-max is taken; may be negative.

    Returns:
        `tf.Tensor`: `y_hard - stop_gradient(y_soft) + y_soft`, where `y_hard` is one-hot along `dim`.
    """
    dims = shape_list(y_soft)
    index = tf.argmax(y_soft, dim)
    y_hard = tf.one_hot(
        index,
        depth=dims[dim],
        # TensorFlow expects axis to be -1 or non-negative (e.g. it rejects -2),
        # so the possibly negative `dim` is mapped to its non-negative equivalent.
        axis=range(len(dims))[dim],
        dtype=y_soft.dtype,
    )
    # Straight through: only the last `y_soft` term carries a gradient.
    return y_hard - tf.stop_gradient(y_soft) + y_soft


def hard_softmax(logits: tf.Tensor, dim: int) -> tf.Tensor:
    """
    Softmax along `dim` followed by a straight-through one-hot of its arg-max.

    Args:
        logits (`tf.Tensor`): Unnormalised scores.
        dim (`int`): Axis to normalise over; may be negative.

    Returns:
        `tf.Tensor`: Tensor shaped like `logits`, one-hot along `dim` (see `_straight_through_one_hot`).
    """
    return _straight_through_one_hot(stable_softmax(logits, dim), dim)


def gumbel_softmax(logits: tf.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> tf.Tensor:
    """
    Softmax of `logits` perturbed by Gumbel(0, 1) noise and divided by the temperature `tau`.

    Args:
        logits (`tf.Tensor`): Unnormalised scores.
        tau (`float`, *optional*, defaults to 1): Temperature dividing the perturbed logits.
        hard (`bool`, *optional*, defaults to `False`): If `True`, return the straight-through one-hot of the sample.
        dim (`int`, *optional*, defaults to -1): Axis to normalise over; may be negative.

    Returns:
        `tf.Tensor`: Tensor shaped like `logits`.

    Raises:
        ImportError: If `tensorflow_probability` could not be imported when this module was loaded.
    """
    # `tfp` is a soft dependency bound at module import time; fail with an actionable message
    # rather than a bare NameError when it is missing.
    if "tfp" not in globals():
        raise ImportError(
            "`gumbel_softmax` requires `tensorflow_probability`, which could not be imported. "
            "Please install a version compatible with your TensorFlow installation: "
            "https://github.com/tensorflow/probability."
        )

    gumbel_dist = tfp.distributions.Gumbel(0.0, 1.0)
    gumbels = gumbel_dist.sample(tf.shape(logits), dtype=logits.dtype)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = stable_softmax(gumbels, dim)

    if hard:
        # Straight through.
        return _straight_through_one_hot(y_soft, dim)
    # Reparametrization trick.
    return y_soft
def resize_attention_map(attentions: tf.Tensor, height: int, width: int, align_corners: bool = False) -> tf.Tensor:
    """
    Unflatten the spatial axis of an attention map and bilinearly resize it to `(height, width)`.

    Args:
        attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
        height (`int`): height of the output attention map
        width (`int`): width of the output attention map
        align_corners (`bool`, *optional*, defaults to `False`):
            When `True`, resizing goes through `tf.compat.v1.image.resize(..., align_corners=True)`; otherwise
            `tf.image.resize` is used.

    Returns:
        `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width]
    """
    # The static size of the flattened axis is required here: the feature-grid size is computed in Python.
    scale = (height * width // attentions.shape[2]) ** 0.5
    batch_size, groups, num_positions = shape_list(attentions)

    # Derive the shorter side from the scale and the other side from the token count.
    if height > width:
        feat_width = int(np.round(width / scale))
        feat_height = num_positions // feat_width
    else:
        feat_height = int(np.round(height / scale))
        feat_width = num_positions // feat_height

    # [batch_size, groups, feat_height x feat_width] -> [batch_size, groups, feat_height, feat_width]
    attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width))
    # The resize ops work on channels-last images, so `groups` is moved to the last axis.
    attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
    if align_corners:
        attentions = tf.compat.v1.image.resize(
            attentions,
            size=(height, width),
            method="bilinear",
            align_corners=align_corners,
        )
    else:
        attentions = tf.image.resize(attentions, size=(height, width), method="bilinear")
    # [batch_size, height, width, groups] -> [batch_size, groups, height, width]
    return tf.transpose(attentions, perm=(0, 3, 1, 2))


def get_grouping_from_attentions(attentions: tuple[tf.Tensor, ...], hw_shape: tuple[int, int]) -> tf.Tensor:
    """
    Chain the per-stage attention maps and resize the result to `hw_shape`.

    Args:
        attentions (`tuple(tf.Tensor)`): tuple of attention maps returned by `TFGroupViTVisionTransformer`
        hw_shape (`tuple(int)`): height and width of the output attention map

    Returns:
        `tf.Tensor`: the attention map of shape [batch_size, groups, height, width], with gradients stopped

    Raises:
        IndexError: If `attentions` is empty.
    """
    chained_masks = None
    for attn_masks in attentions:
        # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
        attn_masks = tf.transpose(attn_masks, perm=(0, 2, 1))
        chained_masks = attn_masks if chained_masks is None else tf.matmul(chained_masks, attn_masks)

    if chained_masks is None:
        raise IndexError("`attentions` must contain at least one attention map.")

    # Only the final stage is returned, so only that one is resized.
    # [batch_size, height x width, num_groups] -> [batch_size, num_groups, height x width]
    # -> [batch_size, num_groups, height, width]
    final_grouping = resize_attention_map(tf.transpose(chained_masks, perm=(0, 2, 1)), *hw_shape)

    return tf.stop_gradient(final_grouping)
@dataclass
class TFGroupViTModelOutput(ModelOutput):
    """
    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        segmentation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This
            is to avoid doing two interpolations and losing some quality when a user needs to resize the logits to
            the original image size as post-processing. You should always check your logits shape and resize as
            needed.
        text_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`TFGroupViTTextModel`].
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`TFGroupViTVisionModel`].
        text_model_output (`TFBaseModelOutputWithPooling`):
            The output of the [`TFGroupViTTextModel`].
        vision_model_output (`TFBaseModelOutputWithPooling`):
            The output of the [`TFGroupViTVisionModel`].
    """

    loss: tf.Tensor | None = None
    logits_per_image: tf.Tensor | None = None
    logits_per_text: tf.Tensor | None = None
    segmentation_logits: tf.Tensor | None = None
    text_embeds: tf.Tensor | None = None
    image_embeds: tf.Tensor | None = None
    text_model_output: TFBaseModelOutputWithPooling | None = None
    vision_model_output: TFBaseModelOutputWithPooling | None = None

    def to_tuple(self) -> tuple[Any, ...]:
        """Return the values of `self.keys()` as a tuple, converting the two nested model outputs to tuples too."""
        nested_outputs = ("text_model_output", "vision_model_output")
        return tuple(
            getattr(self, key).to_tuple() if key in nested_outputs else self[key] for key in self.keys()
        )


class TFGroupViTCrossAttentionLayer(keras.layers.Layer):
    """Cross-attention block: residual attention of `query` over `key`, a residual MLP, then a final layer norm."""

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)
        eps = config.layer_norm_eps
        self.attn = TFGroupViTAttention(config, name="attn")
        self.norm2 = keras.layers.LayerNormalization(epsilon=eps, name="norm2")
        self.mlp = TFGroupViTMLP(config, name="mlp")
        self.norm_post = keras.layers.LayerNormalization(epsilon=eps, name="norm_post")
        self.config = config

    def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Args:
            query (`tf.Tensor`): Tensor projected to attention queries; also the residual input.
            key (`tf.Tensor`): Tensor passed to the attention layer as `encoder_hidden_states`.

        Returns:
            `tf.Tensor`: The layer-normalised output.
        """
        attended = self.attn(query, encoder_hidden_states=key)[0]
        hidden = query + attended
        hidden = hidden + self.mlp(self.norm2(hidden))
        return self.norm_post(hidden)

    def build(self, input_shape=None):
        """Build every sub-layer under its own name scope; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        hidden_size = self.config.hidden_size
        build_plan = (
            ("attn", None),
            ("norm2", [None, None, hidden_size]),
            ("mlp", None),
            ("norm_post", [None, None, hidden_size]),
        )
        for attr_name, shape in build_plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
class TFGroupViTAssignAttention(keras.layers.Layer):
    """Single-head attention whose weights are normalised over the query axis (axis -2)."""

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)
        hidden_size = config.hidden_size
        self.scale = hidden_size**-0.5

        self.q_proj = keras.layers.Dense(hidden_size, name="q_proj")
        self.k_proj = keras.layers.Dense(hidden_size, name="k_proj")
        self.v_proj = keras.layers.Dense(hidden_size, name="v_proj")
        self.proj = keras.layers.Dense(hidden_size, name="proj")
        self.assign_eps = config.assign_eps
        self.config = config

    def get_attn(self, attn: tf.Tensor, gumbel: bool = True, hard: bool = True, training: bool = False) -> tf.Tensor:
        """
        Normalise raw attention scores along axis -2.

        Gumbel-softmax is used only when both `gumbel` and `training` are true; otherwise `hard` selects between
        `hard_softmax` and a plain softmax.
        """
        if gumbel and training:
            return gumbel_softmax(attn, dim=-2, hard=hard)
        if hard:
            return hard_softmax(attn, dim=-2)
        return stable_softmax(attn, axis=-2)

    def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False):
        """
        Args:
            query (`tf.Tensor`): [batch_size, query_length, channels]
            key (`tf.Tensor`): [batch_size, key_length, channels]; also used as the value input.

        Returns:
            `tuple`: the projected output and the soft (non-Gumbel, non-hard) attention.
        """
        # [batch_size, query_length, channels]
        projected_query = self.q_proj(query)
        # [batch_size, key_length, channels]
        projected_key = self.k_proj(key)
        # [batch_size, key_length, channels]
        projected_value = self.v_proj(key)

        # [batch_size, query_length, key_length]
        raw_attn = tf.matmul(projected_query, projected_key, transpose_b=True) * self.scale

        attn = self.get_attn(raw_attn, training=training)
        soft_attn = self.get_attn(raw_attn, training=training, gumbel=False, hard=False)

        # Renormalise each query row over the keys; `assign_eps` keeps the denominator non-zero.
        row_sums = tf.math.reduce_sum(attn, axis=-1, keepdims=True)
        attn = attn / (row_sums + self.assign_eps)

        out = self.proj(tf.matmul(attn, projected_value))

        return out, soft_attn

    def build(self, input_shape=None):
        """Build the four projections under their own name scopes; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        for attr_name in ("q_proj", "k_proj", "v_proj", "proj"):
            dense = getattr(self, attr_name, None)
            if dense is None:
                continue
            with tf.name_scope(dense.name):
                dense.build([None, None, self.config.hidden_size])
class TFGroupViTTokenAssign(keras.layers.Layer):
    """Projects group tokens to the output groups and assigns image tokens to them."""

    def __init__(self, config: GroupViTVisionConfig, num_group_token: int, num_output_group: int, **kwargs):
        super().__init__(**kwargs)

        def make_layer_norm(name: str):
            return keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=name)

        self.num_output_group = num_output_group
        # norm on group_tokens
        self.norm_tokens = make_layer_norm("norm_tokens")

        # A scalar ratio is used for both the token-mixing and the channel MLP.
        ratio = config.assign_mlp_ratio
        if not isinstance(ratio, collections.abc.Iterable):
            ratio = (ratio, ratio)
        tokens_dim, channels_dim = (int(factor * config.hidden_size) for factor in ratio)

        self.mlp_inter = TFGroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group, name="mlp_inter")
        self.norm_post_tokens = make_layer_norm("norm_post_tokens")
        # norm on x
        self.norm_x = make_layer_norm("norm_x")
        self.pre_assign_attn = TFGroupViTCrossAttentionLayer(config, name="pre_assign_attn")

        self.assign = TFGroupViTAssignAttention(config, name="assign")
        self.norm_new_x = make_layer_norm("norm_new_x")
        self.mlp_channels = TFGroupViTMLP(
            config, config.hidden_size, channels_dim, config.hidden_size, name="mlp_channels"
        )
        self.config = config

    def project_group_token(self, group_tokens: tf.Tensor) -> tf.Tensor:
        """
        Args:
            group_tokens (tf.Tensor): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            projected_group_tokens (tf.Tensor): [batch_size, num_output_groups, channels]
        """
        # [B, num_output_groups, C] <- [B, num_group_tokens, C]
        return self.norm_post_tokens(self.mlp_inter(group_tokens))

    def call(self, image_tokens: tf.Tensor, group_tokens: tf.Tensor, training: bool = False):
        """
        Args:
            image_tokens (`tf.Tensor`): image tokens, of shape [batch_size, input_length, channels]
            group_tokens (`tf.Tensor`): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            `tuple`: the new tokens (one per output group) and the soft assignment attention.
        """
        normed_groups = self.norm_tokens(group_tokens)
        normed_image = self.norm_x(image_tokens)

        # [batch_size, num_output_groups, channels]
        projected = self.project_group_token(normed_groups)
        projected = self.pre_assign_attn(projected, normed_image)

        assigned, attention = self.assign(projected, normed_image)
        assigned = assigned + projected
        assigned = assigned + self.mlp_channels(self.norm_new_x(assigned))

        return assigned, attention

    def build(self, input_shape=None):
        """Build every sub-layer under its own name scope; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        hidden_size = self.config.hidden_size
        build_plan = (
            ("norm_tokens", [None, None, hidden_size]),
            ("mlp_inter", None),
            ("norm_post_tokens", [None, None, hidden_size]),
            ("norm_x", [None, None, hidden_size]),
            ("pre_assign_attn", None),
            ("assign", None),
            ("norm_new_x", [None, None, hidden_size]),
            ("mlp_channels", None),
        )
        for attr_name, shape in build_plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
# Adapted from transformers.models.vit.modeling_tf_vit.TFViTPatchEmbeddings with ViT->GroupViT
class TFGroupViTPatchEmbeddings(keras.layers.Layer):
    """
    Converts `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch embeddings of shape
    `(batch_size, seq_length, hidden_size)` using a strided convolution.
    """

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        def as_pair(value):
            # Scalars describe a square; iterables are kept as given.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)

        # hidden_size is a member as it will be required in the call method
        self.hidden_size = config.hidden_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.num_channels = config.num_channels
        self.config = config

        self.projection = keras.layers.Conv2D(
            filters=self.hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last",
            use_bias=True,
            kernel_initializer=get_initializer(self.config.initializer_range),
            bias_initializer="zeros",
            name="projection",
        )

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        """
        Args:
            pixel_values (`tf.Tensor`): Images of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                When `True`, the image-size check against the configuration is skipped.

        Returns:
            `tf.Tensor`: Patch embeddings of shape `(batch_size, num_patches, hidden_size)`.

        Raises:
            ValueError: In eager mode, if the channel count or (unless interpolating) the image size differs from
                the configuration.
        """
        batch_size, num_channels, height, width = shape_list(pixel_values)

        # The shape checks run only in eager mode.
        if tf.executing_eagerly():
            if num_channels != self.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                )
            if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
                )

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        channels_last = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
        projected = self.projection(channels_last)

        # Change the 2D spatial dimensions to a single temporal dimension.
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
        # In the TFGroupViTVisionEmbeddings the embeddings from this layer will be layer normalized
        # LayerNormalization layer needs to have static last dimension (otherwise the test_keras_save_load fails with symbolic tensors)
        # This is why we have used the hidden_size in the reshape method
        return tf.reshape(tensor=projected, shape=(batch_size, num_patches, self.hidden_size))

    def build(self, input_shape=None):
        """Build the convolution under its own name scope; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        projection = getattr(self, "projection", None)
        if projection is not None:
            with tf.name_scope(projection.name):
                projection.build([None, None, None, self.num_channels])
# Adapted from transformers.vit.modeling_tf_vit.TFViTEmbeddings
class TFGroupViTVisionEmbeddings(keras.layers.Layer):
    """
    Construct the position and patch embeddings.
    """

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)

        self.patch_embeddings = TFGroupViTPatchEmbeddings(config, name="patch_embeddings")
        self.dropout = keras.layers.Dropout(rate=config.dropout, name="dropout")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        self.config = config

    def build(self, input_shape=None):
        """Create the position-embedding weight and build the sub-layers; a no-op once the layer is built."""
        # The guard comes first so that a repeated `build` call cannot add a second `position_embeddings` weight.
        if self.built:
            return

        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = self.add_weight(
            shape=(1, num_patches, self.config.hidden_size),
            initializer="zeros",
            trainable=True,
            name="position_embeddings",
        )
        self.built = True

        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_size])

    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings (`tf.Tensor`): Patch embeddings of shape `(batch_size, num_patches, dim)`.
            height (`int`): Height of the input image in pixels.
            width (`int`): Width of the input image in pixels.

        Returns:
            `tf.Tensor`: Position embeddings of shape `(1, num_patches, dim)` matching the input resolution.
        """
        batch_size, num_patches, dim = shape_list(embeddings)
        num_positions = shape_list(self.position_embeddings)[1]

        if num_patches == num_positions and height == width:
            return self.position_embeddings

        # Take patch and grid sizes from the patch-embedding layer, which has already normalised them to
        # (height, width) pairs; this also covers non-square patches and pre-training resolutions. For a square,
        # integer configuration the grid side equals `int(sqrt(num_positions))`.
        patch_height, patch_width = self.patch_embeddings.patch_size
        image_height, image_width = self.patch_embeddings.image_size
        grid_height = image_height // patch_height
        grid_width = image_width // patch_width

        h0 = height // patch_height
        w0 = width // patch_width
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(self.position_embeddings, shape=(1, grid_height, grid_width, dim)),
            size=(h0, w0),
            method="bicubic",
        )
        return tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        """
        Args:
            pixel_values (`tf.Tensor`): Images of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to resize the position embeddings to the input resolution.

        Returns:
            `tf.Tensor`: Layer-normalised patch embeddings plus position embeddings, after dropout.
        """
        _, _, height, width = shape_list(pixel_values)
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        embeddings = self.layernorm(embeddings)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        return self.dropout(embeddings)
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->GroupViT
class TFGroupViTTextEmbeddings(keras.layers.Layer):
    """Token embeddings plus learned position embeddings."""

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size

        self.config = config

    def build(self, input_shape: tf.TensorShape = None):
        """Create the token and position embedding tables, each under its own name scope."""
        init_std = self.config.initializer_factor * self.config.initializer_range

        with tf.name_scope("token_embedding"):
            self.weight = self.add_weight(
                shape=(self.config.vocab_size, self.embed_dim),
                initializer=get_initializer(init_std),
                trainable=True,
                name="weight",
            )

        with tf.name_scope("position_embedding"):
            self.position_embedding = self.add_weight(
                shape=(self.config.max_position_embeddings, self.embed_dim),
                initializer=get_initializer(init_std),
                trainable=True,
                name="embeddings",
            )

        super().build(input_shape)

    def call(
        self,
        input_ids: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Args:
            input_ids (`tf.Tensor`, *optional*): Token ids; looked up only when `inputs_embeds` is not given.
            position_ids (`tf.Tensor`, *optional*): Position ids; default to `0 .. seq_length - 1`.
            inputs_embeds (`tf.Tensor`, *optional*): Pre-computed token embeddings.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.

        Raises:
            ValueError: If neither `input_ids` nor `inputs_embeds` is provided.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # All dimensions except the embedding one.
        leading_dims = shape_list(inputs_embeds)[:-1]

        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(start=0, limit=leading_dims[-1]), axis=0)

        position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
        # Repeat the position embeddings along the batch axis.
        position_embeds = tf.tile(input=position_embeds, multiples=(leading_dims[0], 1, 1))

        return inputs_embeds + position_embeds
class TFGroupViTStage(keras.layers.Layer):
    """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""

    def __init__(
        self,
        config: GroupViTVisionConfig,
        depth: int,
        num_prev_group_token: int,
        num_group_token: int,
        num_output_group: int,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.config = config
        self.depth = depth
        self.num_group_token = num_group_token
        self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(depth)]

        if num_group_token > 0:
            self.downsample = TFGroupViTTokenAssign(
                config=config,
                num_group_token=num_group_token,
                num_output_group=num_output_group,
                name="downsample",
            )
        else:
            self.downsample = None

        if num_prev_group_token > 0 and num_group_token > 0:
            # Maps the previous stage's group tokens to this stage's number of group tokens.
            self.group_projector = [
                keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="group_projector.0"),
                TFGroupViTMixerMLP(
                    config, num_prev_group_token, config.hidden_size // 2, num_group_token, name="group_projector.1"
                ),
            ]
        else:
            self.group_projector = None

    def build(self, input_shape=None):
        """Create the group-token weight (if any) and build the sub-layers; a no-op once the layer is built."""
        # The guard comes first so that a repeated `build` call cannot add a second `group_token` weight.
        if self.built:
            return

        if self.num_group_token > 0:
            self.group_token = self.add_weight(
                shape=(1, self.num_group_token, self.config.hidden_size),
                initializer="zeros",
                trainable=True,
                name="group_token",
            )
        else:
            self.group_token = None
        self.built = True

        if getattr(self, "downsample", None) is not None:
            with tf.name_scope(self.downsample.name):
                self.downsample.build(None)
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
        if getattr(self, "group_projector", None) is not None:
            with tf.name_scope(self.group_projector[0].name):
                self.group_projector[0].build([None, None, self.config.hidden_size])
            with tf.name_scope(self.group_projector[1].name):
                self.group_projector[1].build(None)

    @property
    def with_group_token(self):
        """Whether this stage owns a group-token weight (only defined once the layer has been built)."""
        return self.group_token is not None

    def split_x(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor | None]:
        """Split off the trailing `num_group_token` tokens; returns `(x, None)` when there are no group tokens."""
        if self.with_group_token:
            return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
        return x, None

    def concat_x(self, x: tf.Tensor, group_token: tf.Tensor | None = None) -> tf.Tensor:
        """Append `group_token` to `x` along the sequence axis; returns `x` unchanged when it is `None`."""
        if group_token is None:
            return x
        return tf.concat([x, group_token], axis=1)

    def call(
        self,
        hidden_states: tf.Tensor,
        prev_group_token: tf.Tensor | None = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> tuple[tf.Tensor, ...]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            prev_group_token (`tf.Tensor`, *optional*):
                Group tokens of the previous stage. Required when this stage has a group projector; they are
                projected and added to this stage's own group tokens.
            output_attentions (`bool`, *optional*):
                Whether or not to return the grouping tensors of Grouping block.

        Returns:
            `tuple`: `(x, group_token)`, followed by the grouping attention (or `None` when the stage does not
            downsample) if `output_attentions` is set.

        Raises:
            ValueError: If the stage has a group projector but `prev_group_token` is `None`.
        """
        if self.with_group_token:
            group_token = tf.tile(self.group_token, multiples=(shape_list(hidden_states)[0], 1, 1))
            if self.group_projector is not None:
                if prev_group_token is None:
                    raise ValueError("`prev_group_token` must be provided for a stage that has a group projector.")
                for layer in self.group_projector:
                    prev_group_token = layer(prev_group_token)
                group_token = group_token + prev_group_token
        else:
            group_token = None

        x = hidden_states

        # Image tokens and group tokens go through the encoder layers as one sequence.
        cat_x = self.concat_x(x, group_token)
        for layer in self.layers:
            layer_out = layer(
                cat_x,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=None,
            )
            cat_x = layer_out[0]

        x, group_token = self.split_x(cat_x)

        attention = None
        if self.downsample is not None:
            x, attention = self.downsample(x, group_token)

        outputs = (x, group_token)
        if output_attentions:
            outputs = outputs + (attention,)

        return outputs
class TFGroupViTMLP(keras.layers.Layer):
    """Two dense layers with the configured activation in between."""

    def __init__(
        self,
        config: GroupViTVisionConfig,
        hidden_size: int | None = None,
        intermediate_size: int | None = None,
        output_size: int | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.config = config
        self.activation_fn = get_tf_activation(config.hidden_act)

        # Unspecified sizes fall back to the configuration; the output size falls back to the input size.
        if hidden_size is None:
            hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        if output_size is None:
            output_size = hidden_size

        self.fc1 = keras.layers.Dense(intermediate_size, name="fc1")
        self.fc2 = keras.layers.Dense(output_size, name="fc2")
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply `fc1`, the activation, then `fc2` to `hidden_states`."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))

    def build(self, input_shape=None):
        """Build both dense layers under their own name scopes; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        build_plan = (
            ("fc1", "hidden_size"),
            ("fc2", "intermediate_size"),
        )
        for layer_attr, input_dim_attr in build_plan:
            dense = getattr(self, layer_attr, None)
            if dense is None:
                continue
            with tf.name_scope(dense.name):
                dense.build([None, None, getattr(self, input_dim_attr)])


class TFGroupViTMixerMLP(TFGroupViTMLP):
    """`TFGroupViTMLP` applied across axis 1 instead of the last axis."""

    def call(self, x, training: bool = False):
        """Swap the last two axes, run the parent MLP, and swap them back."""
        swapped = tf.transpose(x, perm=(0, 2, 1))
        mixed = super().call(hidden_states=swapped)
        return tf.transpose(mixed, perm=(0, 2, 1))
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPAttention
class TFGroupViTAttention(keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = self.embed_dim // self.num_attention_heads
        if self.embed_dim % self.num_attention_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_attention_heads})."
            )

        factor = config.initializer_factor
        in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
        out_proj_std = (self.embed_dim**-0.5) * factor

        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        def make_projection(name: str, std: float):
            return keras.layers.Dense(units=self.embed_dim, kernel_initializer=get_initializer(std), name=name)

        self.q_proj = make_projection("q_proj", in_proj_std)
        self.k_proj = make_projection("k_proj", in_proj_std)
        self.v_proj = make_projection("v_proj", in_proj_std)

        self.dropout = keras.layers.Dropout(rate=config.attention_dropout)

        self.out_proj = make_projection("out_proj", out_proj_std)

    # Adapted from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        """Split the last axis into heads: [batch, seq, all_head_size] -> [batch, heads, seq, head_size]."""
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        per_head = tf.reshape(
            tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)
        )

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        causal_attention_mask: tf.Tensor | None = None,
        output_attentions: bool | None = None,
        encoder_hidden_states: tf.Tensor | None = None,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states (`tf.Tensor`): Source of the queries (and of keys/values for self-attention).
            attention_mask (`tf.Tensor`, *optional*): Additive mask added to the attention scores.
            causal_attention_mask (`tf.Tensor`, *optional*): Additive mask added before `attention_mask`.
            output_attentions (`bool`, *optional*): Whether to also return the attention probabilities.
            encoder_hidden_states (`tf.Tensor`, *optional*):
                When given, keys and values are computed from it (cross-attention).

        Returns:
            `tuple`: `(attention_output,)`, or `(attention_output, attention_probs)` when `output_attentions` is
            truthy. The returned probabilities are the ones computed before dropout.
        """
        batch_size = shape_list(hidden_states)[0]
        # Cross-attention reads keys/values from the encoder states, self-attention from the input itself.
        key_value_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        mixed_query_layer = self.q_proj(inputs=hidden_states)
        mixed_key_layer = self.k_proj(inputs=key_value_states)
        mixed_value_layer = self.v_proj(inputs=key_value_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # The causal mask is applied first, then the attention mask.
        for additive_mask in (causal_attention_mask, attention_mask):
            if additive_mask is not None:
                attention_scores = tf.add(attention_scores, additive_mask)

        # Normalize the attention scores to probabilities.
        _attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=_attention_probs)

        context = tf.matmul(attention_probs, value_layer)
        context = tf.transpose(context, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, embed_dim)
        context = tf.reshape(tensor=context, shape=(batch_size, -1, self.embed_dim))

        attention_output = self.out_proj(context)

        # In TFBert, attention weights are returned after dropout.
        # However, in CLIP, they are returned before dropout.
        if output_attentions:
            return (attention_output, _attention_probs)
        return (attention_output,)

    def build(self, input_shape=None):
        """Build the four projections under their own name scopes; a no-op once the layer is built."""
        if self.built:
            return
        self.built = True
        for attr_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            projection = getattr(self, attr_name, None)
            if projection is None:
                continue
            with tf.name_scope(projection.name):
                projection.build([None, None, self.embed_dim])
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPEncoderLayer with CLIP->GroupViT
class TFGroupViTEncoderLayer(keras.layers.Layer):
    """Transformer block: layer-norm + self-attention, then layer-norm + MLP, each added back to its input."""

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.self_attn = TFGroupViTAttention(config, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        self.mlp = TFGroupViTMLP(config, name="mlp")
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`tf.Tensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`):
                Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned
                tensors for more detail.
        """
        # Attention sub-block (normalisation happens before the attention).
        attn_results = self.self_attn(
            hidden_states=self.layer_norm1(inputs=hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            training=training,
        )
        after_attention = hidden_states + attn_results[0]

        # Feed-forward sub-block.
        mlp_result = self.mlp(hidden_states=self.layer_norm2(inputs=after_attention))
        block_output = after_attention + mlp_result

        # Attention probabilities (if any) follow the hidden states.
        return (block_output,) + attn_results[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True

        norm_shape = [None, None, self.embed_dim]
        # (attribute name, shape handed to its build()), in creation order.
        plan = (
            ("self_attn", None),
            ("layer_norm1", norm_shape),
            ("mlp", None),
            ("layer_norm2", norm_shape),
        )
        for attr_name, shape in plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPEncoder
class TFGroupViTTextEncoder(keras.layers.Layer):
    """Stack of `config.num_hidden_layers` `TFGroupViTEncoderLayer` blocks applied in sequence."""

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> tuple | TFBaseModelOutput:
        """
        Run every encoder layer over `hidden_states`.

        Returns:
            A `TFBaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the non-`None` values among
            (last hidden state, all hidden states, all attentions). When collected, the hidden states include the
            input to the first layer as well as the output of the last one.
        """
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
                # Forward the flag explicitly so the layers' dropout follows this call's mode.
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        """Build each encoder layer under its own name scope; a no-op once built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFGroupViTVisionEncoder(keras.layers.Layer):
    """Sequence of `TFGroupViTStage`s; each stage hands its hidden states and group tokens to the next."""

    def __init__(self, config: GroupViTVisionConfig, **kwargs) -> None:
        super().__init__(**kwargs)

        stages = []
        for i, depth in enumerate(config.depths):
            # The first stage has no predecessor, hence no previous group tokens.
            prev_groups = config.num_output_groups[i - 1] if i > 0 else 0
            stages.append(
                TFGroupViTStage(
                    config=config,
                    depth=depth,
                    num_group_token=config.num_group_tokens[i],
                    num_output_group=config.num_output_groups[i],
                    num_prev_group_token=prev_groups,
                    name=f"stages_._{i}",
                )
            )
        self.stages = stages

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: bool,
        output_attentions: bool,
        return_dict: bool,
        training: bool = False,
    ) -> tuple | TFBaseModelOutput:
        collect_states = bool(output_hidden_states)
        collect_groupings = bool(output_attentions)

        all_hidden_states = () if collect_states else None
        all_groupings = () if collect_groupings else None
        group_tokens = None

        for stage in self.stages:
            if collect_states:
                all_hidden_states += (hidden_states,)

            stage_result = stage(hidden_states, group_tokens, output_attentions)
            hidden_states, group_tokens = stage_result[0], stage_result[1]

            # A stage may report no grouping; only keep the ones that exist.
            if collect_groupings and stage_result[2] is not None:
                all_groupings += (stage_result[2],)

        if collect_states:
            all_hidden_states += (hidden_states,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
            )
        return tuple(item for item in (hidden_states, all_hidden_states, all_groupings) if item is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for stage in getattr(self, "stages", None) or ():
            with tf.name_scope(stage.name):
                stage.build(None)
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder
class TFGroupViTTextTransformer(keras.layers.Layer):
    """Text tower: embeddings, causal encoder, final layer norm, and pooling at the EOS position."""

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = TFGroupViTTextEmbeddings(config, name="embeddings")
        self.encoder = TFGroupViTTextEncoder(config, name="encoder")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id
        self.embed_dim = config.hidden_size

    def call(
        self,
        input_ids: TFModelInputType,
        attention_mask: tf.Tensor,
        position_ids: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        batch_size, seq_length = shape_list(input_ids)

        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype)

        # check attention mask and invert
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask = _expand_mask(attention_mask)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = self.final_layer_norm(inputs=encoder_outputs[0])

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            eos_positions = tf.math.argmax(input_ids, axis=-1)
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            eos_positions = tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1)

        # Pair every row index with its EOS position and pick that token's features.
        row_indices = tf.range(batch_size, dtype=tf.int64)
        pooled_output = tf.gather_nd(
            params=sequence_output,
            indices=tf.stack(values=(row_indices, eos_positions), axis=1),
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):
        # It is possible with an unspecified sequence length for seq_length to be
        # a runtime value, which is unsupported by tf.constant. Per the TensorFlow
        # docs, tf.fill can handle runtime dynamic shapes:
        # https://www.tensorflow.org/api_docs/python/tf/fill
        zero_diagonal = tf.cast(tf.fill((seq_length,), 0.0), dtype)

        # Start from an additive 2D mask in which every position is masked.
        mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)

        # Keep the masking value on and above the diagonal (entries below it become 0), then clear the
        # diagonal too, so only the strictly upper triangle stays masked.
        # TIP: think the 2D matrix as the space of (query_seq, key_seq)
        mask = tf.linalg.band_part(mask, 0, -1)
        mask = tf.linalg.set_diag(mask, diagonal=zero_diagonal)

        return tf.broadcast_to(input=mask, shape=(batch_size, 1, seq_length, seq_length))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # (attribute name, shape handed to its build()), in creation order.
        plan = (
            ("embeddings", None),
            ("encoder", None),
            ("final_layer_norm", [None, None, self.embed_dim]),
        )
        for attr_name, shape in plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPVisionTransformer
class TFGroupViTVisionTransformer(keras.layers.Layer):
    """Vision tower: embeddings, staged encoder, final layer norm, and mean pooling over axis 1."""

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = TFGroupViTVisionEmbeddings(config, name="embeddings")
        self.encoder = TFGroupViTVisionEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        self.embed_dim = config.hidden_size

    def call(
        self,
        pixel_values: TFModelInputType,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> tuple | TFBaseModelOutputWithPooling:
        """
        Encode `pixel_values`.

        Returns:
            A `TFBaseModelOutputWithPooling` when `return_dict` is truthy, otherwise a tuple starting with
            (normalized last hidden state, pooled output) followed by the encoder's remaining outputs.
        """
        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            # Forwarded explicitly, matching how the text transformer calls its encoder.
            training=training,
        )

        last_hidden_state = encoder_outputs[0]

        # normalize the last hidden state
        last_hidden_state = self.layernorm(last_hidden_state)
        pooled_output = tf.math.reduce_mean(last_hidden_state, axis=1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the embeddings, encoder and layer norm under their own name scopes; a no-op once built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.embed_dim])
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextMainLayer with CLIP->GroupViT
class TFGroupViTTextMainLayer(keras.layers.Layer):
    """Serializable wrapper around `TFGroupViTTextTransformer` that validates and unpacks its inputs."""

    config_class = GroupViTTextConfig

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.text_model = TFGroupViTTextTransformer(config, name="text_model")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.text_model.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        embeddings = self.text_model.embeddings
        embeddings.weight = value
        # Keep the recorded vocabulary size in sync with the new weight's first dimension.
        embeddings.vocab_size = shape_list(value)[0]

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        ids_shape = shape_list(input_ids)

        # Without a mask, every position is attended to.
        if attention_mask is None:
            attention_mask = tf.fill(dims=ids_shape, value=1)

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        text_model = getattr(self, "text_model", None)
        if text_model is None:
            return
        with tf.name_scope(text_model.name):
            text_model.build(None)
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPVisionMainLayer with CLIP->GroupViT
class TFGroupViTVisionMainLayer(keras.layers.Layer):
    """Serializable wrapper around `TFGroupViTVisionTransformer` that validates and unpacks its inputs."""

    config_class = GroupViTVisionConfig

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.vision_model = TFGroupViTVisionTransformer(config, name="vision_model")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.vision_model.embeddings

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        vision_model = getattr(self, "vision_model", None)
        if vision_model is None:
            return
        with tf.name_scope(vision_model.name):
            vision_model.build(None)
@keras_serializable
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPMainLayer
class TFGroupViTMainLayer(keras.layers.Layer):
    """Joint text/vision layer: both towers, their projection heads, and the learned logit scale."""

    config_class = GroupViTConfig

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        if not isinstance(config.text_config, GroupViTTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, GroupViTVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        self.config = config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.projection_intermediate_dim = config.projection_intermediate_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = TFGroupViTTextTransformer(text_config, name="text_model")
        self.vision_model = TFGroupViTVisionTransformer(vision_config, name="vision_model")

        # Each projection head is Dense -> BatchNorm -> ReLU -> Dense, applied layer by layer.
        self.visual_projection = [
            keras.layers.Dense(self.projection_intermediate_dim, name="visual_projection.0"),
            keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.9, epsilon=1e-5),
            keras.layers.ReLU(name="visual_projection.2"),
            keras.layers.Dense(self.projection_dim, name="visual_projection.3"),
        ]
        self.text_projection = [
            keras.layers.Dense(self.projection_intermediate_dim, name="text_projection.0"),
            keras.layers.BatchNormalization(name="text_projection.1", momentum=0.9, epsilon=1e-5),
            keras.layers.ReLU(name="text_projection.2"),
            keras.layers.Dense(self.projection_dim, name="text_projection.3"),
        ]

    def build(self, input_shape=None):
        """Create `logit_scale` and build every sub-layer under its own name scope; a no-op once built."""
        # Guard first: creating the weight before this check would add a duplicate `logit_scale`
        # every time build() is invoked again.
        if self.built:
            return
        self.built = True

        self.logit_scale = self.add_weight(
            shape=(1,),
            initializer=keras.initializers.Constant(self.config.logit_scale_init_value),
            trainable=True,
            name="logit_scale",
        )

        if getattr(self, "text_model", None) is not None:
            with tf.name_scope(self.text_model.name):
                self.text_model.build(None)
        if getattr(self, "vision_model", None) is not None:
            with tf.name_scope(self.vision_model.name):
                self.vision_model.build(None)
        # Index 2 of each head is the ReLU, which is not built explicitly here.
        if getattr(self, "visual_projection", None) is not None:
            with tf.name_scope(self.visual_projection[0].name):
                self.visual_projection[0].build([None, None, None, self.vision_embed_dim])
            with tf.name_scope(self.visual_projection[1].name):
                self.visual_projection[1].build((None, self.projection_intermediate_dim))
            with tf.name_scope(self.visual_projection[3].name):
                self.visual_projection[3].build([None, None, None, self.projection_intermediate_dim])
        if getattr(self, "text_projection", None) is not None:
            with tf.name_scope(self.text_projection[0].name):
                self.text_projection[0].build([None, None, None, self.text_embed_dim])
            with tf.name_scope(self.text_projection[1].name):
                self.text_projection[1].build((None, self.projection_intermediate_dim))
            with tf.name_scope(self.text_projection[3].name):
                self.text_projection[3].build([None, None, None, self.projection_intermediate_dim])

    @unpack_inputs
    def get_text_features(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> tf.Tensor:
        """Return the text tower's pooled output passed through the text projection head.

        Raises:
            ValueError: if `input_ids` is `None`.
        """
        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        input_shape = shape_list(input_ids)

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = text_outputs[1]
        for layer in self.text_projection:
            pooled_output = layer(pooled_output)

        text_features = pooled_output
        return text_features

    @unpack_inputs
    def get_image_features(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> tf.Tensor:
        """Return the vision tower's pooled output passed through the visual projection head.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = vision_outputs[1]
        for layer in self.visual_projection:
            pooled_output = layer(pooled_output)

        image_features = pooled_output
        return image_features

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        pixel_values: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        return_loss: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_segmentation: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFGroupViTModelOutput | tuple[tf.Tensor]:
        """Compute image/text similarity logits and, optionally, segmentation logits and the contrastive loss.

        Setting `output_segmentation` forces `output_attentions` on, because the segmentation logits are
        derived from the vision tower's attentions.

        Raises:
            ValueError: if `input_ids` or `pixel_values` is `None`.
        """
        if input_ids is None:
            raise ValueError("You have to specify either input_ids")
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        input_shape = shape_list(input_ids)

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)
        if output_segmentation:
            output_attentions = True
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        image_embeds = vision_outputs[1]
        for layer in self.visual_projection:
            image_embeds = layer(image_embeds)

        text_embeds = text_outputs[1]
        for layer in self.text_projection:
            text_embeds = layer(text_embeds)

        # normalized features
        image_embeds = image_embeds / tf.norm(image_embeds, axis=-1, keepdims=True)
        text_embeds = text_embeds / tf.norm(text_embeds, axis=-1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
        logits_per_image = tf.transpose(logits_per_text)

        seg_logits = None
        if output_segmentation:
            # grouped features
            # [batch_size_image, num_group, hidden_size]
            image_group_embeds = vision_outputs[0]
            # [batch_size_image*num_group, hidden_size]
            image_group_embeds = tf.reshape(image_group_embeds, shape=(-1, shape_list(image_group_embeds)[-1]))
            for layer in self.visual_projection:
                image_group_embeds = layer(image_group_embeds)
            # The attentions sit one slot later when hidden states are returned as well.
            if output_hidden_states:
                attentions = vision_outputs[3]
            else:
                attentions = vision_outputs[2]
            # [batch_size_image, num_group, height, width]
            grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

            # normalized features
            image_group_embeds = image_group_embeds / tf.norm(
                tensor=image_group_embeds, ord="euclidean", axis=-1, keepdims=True
            )
            # [batch_size_image x num_group, batch_size_text]
            logits_per_image_group = tf.matmul(image_group_embeds, text_embeds, transpose_b=True) * logit_scale
            # [batch_size_image, batch_size_text, num_group]
            # shape_list() falls back to the dynamic shape where the static one is unknown, so these
            # reshapes also work when a dimension (e.g. the batch size) is only known at run time.
            logits_per_image_group = tf.reshape(
                logits_per_image_group, shape=(shape_list(image_embeds)[0], -1, shape_list(text_embeds)[0])
            )
            logits_per_image_group = tf.transpose(logits_per_image_group, perm=(0, 2, 1))

            # [batch_size_image, batch_size_text, height x width]
            grouping_shape = shape_list(grouping)
            flatten_grouping = tf.reshape(grouping, shape=(grouping_shape[0], grouping_shape[1], -1))

            # [batch_size_image, batch_size_text, height, width]
            seg_logits = tf.matmul(logits_per_image_group, flatten_grouping) * logit_scale
            seg_shape = shape_list(seg_logits)
            seg_logits = tf.reshape(
                seg_logits, shape=(seg_shape[0], seg_shape[1], grouping_shape[2], grouping_shape[3])
            )

        loss = None
        if return_loss:
            loss = groupvit_loss(logits_per_text)[None, ...]

        if not return_dict:
            if seg_logits is not None:
                output = (
                    logits_per_image,
                    logits_per_text,
                    seg_logits,
                    text_embeds,
                    image_embeds,
                    text_outputs,
                    vision_outputs,
                )
            else:
                output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return TFGroupViTModelOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            segmentation_logits=seg_logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
class TFGroupViTPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GroupViTConfig
    base_model_prefix = "groupvit"


# Shared introduction for the model class docstrings (grammar corrected: "models", "matters", "accept").
GROUPVIT_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    TF 2.0 models accept two formats as inputs:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional arguments.

    This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
    tensors in the first argument of the model call function: `model(inputs)`.

    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
    first positional argument :

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Args:
        config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation for the text-only models; `{0}` is filled with the input shape via `.format(...)`.
# Markup corrected: balanced backticks and a comma between the accepted `input_ids` types.
GROUPVIT_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
# Argument documentation for the vision-only models (markup corrected: balanced backticks around `False`).
GROUPVIT_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
# Argument documentation for the joint text+vision model; `{0}` is filled with the text input shape via
# `.format(...)`. Markup corrected: balanced backticks and commas between the accepted input types.
GROUPVIT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
See `hidden_states` under returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the config will be used instead. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in eager mode, in graph mode the value will always be set to True. training (`bool`, *optional*, defaults to `False``): Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation). """ class TFGroupViTTextModel(TFGroupViTPreTrainedModel): config_class = GroupViTTextConfig main_input_name = "input_ids" def __init__(self, config: GroupViTTextConfig, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.groupvit = TFGroupViTTextMainLayer(config, name="groupvit") @unpack_inputs @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTTextConfig) def call( self, input_ids: TFModelInputType | None = None, attention_mask: np.ndarray | tf.Tensor | None = None, position_ids: np.ndarray | tf.Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, training: bool = False, ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: r""" Returns: Examples: ```python >>> from transformers import CLIPTokenizer, TFGroupViTTextModel >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> model = TFGroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" outputs = self.groupvit( input_ids=input_ids, 
attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, ) return outputs def build(self, input_shape=None): if self.built: return self.built = True if getattr(self, "groupvit", None) is not None: with tf.name_scope(self.groupvit.name): self.groupvit.build(None) class TFGroupViTVisionModel(TFGroupViTPreTrainedModel): config_class = GroupViTVisionConfig main_input_name = "pixel_values" def __init__(self, config: GroupViTVisionConfig, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.groupvit = TFGroupViTVisionMainLayer(config, name="groupvit") @unpack_inputs @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTVisionConfig) def call( self, pixel_values: TFModelInputType | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, training: bool = False, ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, TFGroupViTVisionModel >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> model = TFGroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="tf") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" outputs = self.groupvit( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, ) return outputs def build(self, input_shape=None): 
if self.built: return self.built = True if getattr(self, "groupvit", None) is not None: with tf.name_scope(self.groupvit.name): self.groupvit.build(None) @add_start_docstrings(GROUPVIT_START_DOCSTRING) class TFGroupViTModel(TFGroupViTPreTrainedModel): config_class = GroupViTConfig def __init__(self, config: GroupViTConfig, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.groupvit = TFGroupViTMainLayer(config, name="groupvit") @unpack_inputs @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def get_text_features( self, input_ids: TFModelInputType | None = None, attention_mask: np.ndarray | tf.Tensor | None = None, position_ids: np.ndarray | tf.Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, training: bool = False, ) -> tf.Tensor: r""" Returns: text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`TFGroupViTTextModel`]. 
Examples: ```python >>> from transformers import CLIPTokenizer, TFGroupViTModel >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf") >>> text_features = model.get_text_features(**inputs) ```""" text_features = self.groupvit.get_text_features( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, ) return text_features @unpack_inputs @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING) def get_image_features( self, pixel_values: TFModelInputType | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, training: bool = False, ) -> tf.Tensor: r""" Returns: image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`TFGroupViTVisionModel`]. 
Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, TFGroupViTModel >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="tf") >>> image_features = model.get_image_features(**inputs) ```""" image_features = self.groupvit.get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, ) return image_features @unpack_inputs @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFGroupViTModelOutput, config_class=GroupViTConfig) def call( self, input_ids: TFModelInputType | None = None, pixel_values: TFModelInputType | None = None, attention_mask: np.ndarray | tf.Tensor | None = None, position_ids: np.ndarray | tf.Tensor | None = None, return_loss: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, output_segmentation: bool | None = None, return_dict: bool | None = None, training: bool = False, ) -> TFGroupViTModelOutput | tuple[tf.Tensor]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, TFGroupViTModel >>> import tensorflow as tf >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor( ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True ... 
) >>> outputs = model(**inputs) >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score >>> probs = tf.math.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities ```""" outputs = self.groupvit( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, return_loss=return_loss, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_segmentation=output_segmentation, return_dict=return_dict, training=training, ) return outputs def serving_output(self, output: TFGroupViTModelOutput) -> TFGroupViTModelOutput: # TODO: As is this currently fails with saved_model=True, because # TensorFlow cannot trace through nested dataclasses. Reference: # https://github.com/huggingface/transformers/pull/16886 return output def build(self, input_shape=None): if self.built: return self.built = True if getattr(self, "groupvit", None) is not None: with tf.name_scope(self.groupvit.name): self.groupvit.build(None) __all__ = ["TFGroupViTModel", "TFGroupViTPreTrainedModel", "TFGroupViTTextModel", "TFGroupViTVisionModel"]