# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_got_ocr2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections from dataclasses import dataclass from typing import Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers.utils.generic import check_model_inputs from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig class GotOcr2MLPBlock(nn.Module): def __init__(self, config): super().__init__() self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) self.act = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.lin1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.lin2(hidden_states) return hidden_states class GotOcr2VisionAttention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__(self, config, window_size): super().__init__() input_size = ( (config.image_size // config.patch_size, config.image_size // config.patch_size) if window_size == 0 else (window_size, window_size) ) self.num_attention_heads = config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads self.scale = head_dim**-0.5 self.dropout = config.attention_dropout self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) self.proj = nn.Linear(config.hidden_size, config.hidden_size) self.use_rel_pos = config.use_rel_pos if self.use_rel_pos: if input_size is None: raise ValueError("Input size must be provided if using relative positional encoding.") # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of the query. k_size (int): size of key k. rel_pos (`torch.Tensor`): relative position embeddings (L, channel). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def get_decomposed_rel_pos( self, query: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: tuple[int, int], k_size: tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py Args: query (`torch.Tensor`): query q in the attention layer with shape (batch_size, query_height * query_width, channel). rel_pos_h (`torch.Tensor`): relative position embeddings (Lh, channel) for height axis. rel_pos_w (`torch.Tensor`): relative position embeddings (Lw, channel) for width axis. q_size (tuple): spatial sequence size of query q with (query_height, query_width). k_size (tuple): spatial sequence size of key k with (key_height, key_width). Returns: decomposed_rel_pos (`torch.Tensor`): decomposed relative position embeddings. """ query_height, query_width = q_size key_height, key_width = k_size relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) batch_size, _, dim = query.shape reshaped_query = query.reshape(batch_size, query_height, query_width, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] return decomposed_rel_pos def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) qkv = ( self.qkv(hidden_states) .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) .permute(2, 0, 3, 1, 4) ) # q, k, v with shape (batch_size * nHead, height * width, channel) query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) attn_weights = (query * self.scale) @ key.transpose(-2, -1) if self.use_rel_pos: decomposed_rel_pos = self.get_decomposed_rel_pos( query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) ) decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) attn_weights = attn_weights + decomposed_rel_pos attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) attn_output = self.proj(attn_output) return attn_output, attn_weights class GotOcr2VisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attn = GotOcr2VisionAttention(config, window_size) self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = GotOcr2MLPBlock(config) self.window_size = window_size def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]: """ Args: Partition into non-overlapping windows with padding if needed. hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window size. Returns: windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. (pad_height, pad_width): padded height and width before partition """ batch_size, height, width, channel = hidden_states.shape pad_h = (window_size - height % window_size) % window_size pad_w = (window_size - width % window_size) % window_size hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) pad_height, pad_width = height + pad_h, width + pad_w hidden_states = hidden_states.reshape( batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel ) windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) return windows, (pad_height, pad_width) def window_unpartition( self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] ) -> torch.Tensor: """ Args: Window unpartition into original sequences and removing padding. hidden_states (tensor): input tokens with [batch_size * num_windows, window_size, window_size, channel]. window_size (int): window size. padding_shape (Tuple): padded height and width (pad_height, pad_width). original_shape (Tuple): original height and width (height, width) before padding. Returns: hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. """ pad_height, pad_width = padding_shape height, width = original_shape batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) hidden_states = windows.reshape( batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 ) hidden_states = ( hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) ) hidden_states = hidden_states[:, :height, :width, :].contiguous() return hidden_states def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) # Window partition if self.window_size > 0: height, width = hidden_states.shape[1], hidden_states.shape[2] hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) hidden_states, attn_weights = self.attn( hidden_states=hidden_states, ) # Reverse window partition if self.window_size > 0: hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) hidden_states = residual + hidden_states layernorm_output = self.layer_norm2(hidden_states) hidden_states = hidden_states + self.mlp(layernorm_output) return hidden_states @auto_docstring class GotOcr2PreTrainedModel(PreTrainedModel): config: GotOcr2Config base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = False _can_compile_fullgraph = True _supports_flex_attn = False _supports_attention_backend = True def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: module.rel_pos_h.data.zero_() module.rel_pos_w.data.zero_() elif isinstance(module, GotOcr2VisionEncoder): if module.pos_embed is not None: module.pos_embed.data.zero_() @dataclass @auto_docstring( custom_intro=""" Base class for got_ocr2 vision model's outputs that also contains image embeddings obtained by applying the projection layer to the pooler_output. """ ) class GotOcr2VisionEncoderOutput(ModelOutput): r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The image embeddings obtained by applying the projection layer to the pooler_output. """ image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None class GotOcr2PatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) return embeddings class GotOcr2LayerNorm(nn.LayerNorm): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): super().__init__(normalized_shape, eps=eps, **kwargs) if data_format not in ["channels_last", "channels_first"]: raise NotImplementedError(f"Unsupported data format: {data_format}") self.data_format = data_format def forward(self, features: torch.Tensor) -> torch.Tensor: """ Args: features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) """ if self.data_format == "channels_first": features = features.permute(0, 2, 3, 1) features = super().forward(features) features = features.permute(0, 3, 1, 2) else: features = super().forward(features) return features class GotOcr2VisionNeck(nn.Module): def __init__(self, config: GotOcr2VisionConfig): super().__init__() self.config = config self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) self.layer_norm1 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) self.layer_norm2 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") def forward(self, hidden_states): hidden_states = hidden_states.permute(0, 3, 1, 2) hidden_states = self.conv1(hidden_states) hidden_states = self.layer_norm1(hidden_states) hidden_states = self.conv2(hidden_states) hidden_states = self.layer_norm2(hidden_states) return hidden_states class GotOcr2VisionEncoder(GotOcr2PreTrainedModel): _can_record_outputs = {"hidden_states": GotOcr2VisionLayer, "attentions": GotOcr2VisionAttention} def __init__(self, config: GotOcr2VisionConfig): super().__init__(config) self.config = config self.image_size = config.image_size self.patch_embed = GotOcr2PatchEmbeddings(config) self.pos_embed = None if config.use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros( 1, config.image_size // config.patch_size, config.image_size // config.patch_size, config.hidden_size, ) ) self.layers = nn.ModuleList() for i in range(config.num_hidden_layers): layer = GotOcr2VisionLayer( config, window_size=config.window_size if i not in config.global_attn_indexes else 0, ) self.layers.append(layer) self.neck = GotOcr2VisionNeck(config) self.gradient_checkpointing = False def get_input_embeddings(self): return self.patch_embed @check_model_inputs def forward( self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs] ) -> GotOcr2VisionEncoderOutput: if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.patch_embed(pixel_values) if self.pos_embed is not None: hidden_states = hidden_states + self.pos_embed for layer_module in self.layers: hidden_states = layer_module(hidden_states) hidden_states = self.neck(hidden_states) return GotOcr2VisionEncoderOutput( last_hidden_state=hidden_states, ) class GotOcr2MultiModalProjector(nn.Module): def __init__(self, config: GotOcr2Config): super().__init__() vision_output_channels = config.vision_config.output_channels language_hidden_size = config.text_config.hidden_size self.conv_upsampler1 = nn.Conv2d( vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False ) self.conv_upsampler2 = nn.Conv2d( vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False ) self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size) def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor: hidden_state = self.conv_upsampler1(vision_embeddings) hidden_state = self.conv_upsampler2(hidden_state) hidden_state = hidden_state.flatten(2).permute(0, 2, 1) hidden_state = self.multimodal_projector(hidden_state) return hidden_state @dataclass @auto_docstring( custom_intro=""" Base class for GotOcr2 causal language model (or autoregressive) outputs. """ ) class GotOcr2CausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for GotOcr2 outputs, with hidden states and attentions. """ ) class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None @auto_docstring( custom_intro=""" The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head. """ ) class GotOcr2Model(GotOcr2PreTrainedModel): _checkpoint_conversion_mapping = {"language_model.model": "language_model"} def __init__(self, config: GotOcr2Config): super().__init__(config) self.vision_tower = GotOcr2VisionEncoder(config.vision_config) self.multi_modal_projector = GotOcr2MultiModalProjector(config) self.language_model = AutoModel.from_config(config.text_config) self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder def get_decoder(self): return self.language_model def get_image_features( self, pixel_values: torch.FloatTensor, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): """ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is equal to the length of multimodal features. If the lengths are different, an error is raised. """ if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_features = image_features.shape[0] * image_features.shape[1] if inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) return special_image_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, GotOcr2ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **kwargs, ) return GotOcr2ModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) @auto_docstring( custom_intro=""" The GOT_OCR2 model which consists of a vision backbone and a language model. """ ) class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { "^language_model.model": "model.language_model", "^vision_tower": "model.vision_tower", "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: GotOcr2Config): super().__init__(config) self.model = GotOcr2Model(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_decoder(self, decoder): self.model.set_decoder(decoder) def get_decoder(self): return self.model.get_decoder() def get_image_features( self, pixel_values: torch.FloatTensor, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, **kwargs, ): return self.model.get_image_features( pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, **kwargs, ) # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @property def vision_tower(self): return self.model.vision_tower @property def multi_modal_projector(self): return self.model.multi_modal_projector @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, GotOcr2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer >>> model = GotOcr2ForConditionalGeneration.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf").to("cuda") >>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda") >>> # Generate >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) >>> generate_ids = model.generate( ... **inputs, ... do_sample=False, ... tokenizer = processor.tokenizer, ... stop_strings='<|im_end|>', ... streamer=streamer, ... max_new_tokens=4096, ... ) "You should keep in mind what features from the module should be used, especially when you're planning to sell a template." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return GotOcr2CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values return model_inputs __all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]