# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_wav2vec2_bert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

import math
import warnings
from typing import Optional, Union

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
    Wav2Vec2BaseModelOutput,
    XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, is_peft_available
from .configuration_wav2vec2_bert import Wav2Vec2BertConfig


class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding
    Reference : https://blog.eleuther.ai/rotary-embeddings/
    Paper: https://huggingface.co/papers/2104.09864
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        # Ignore copy
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length

        # Embeddings are computed in the dtype of the inv_freq constant
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        # Computed embeddings are cast to the dtype of the hidden state inputs
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
        return self.cached_rotary_positional_embedding


class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
    """Relative positional encoding module."""

    def __init__(self, config):
        super().__init__()
        self.max_len = config.max_source_positions
        self.d_model = config.hidden_size
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        # Reset the positional encodings
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of query vector and `j` is the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
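

# --- Illustrative usage sketch (hedged example, not part of the generated modeling code) ---
# A minimal demonstration of the rotary positional embedding module defined above: it builds a
# default `Wav2Vec2BertConfig`, runs a dummy batch of hidden states through
# `Wav2Vec2BertRotaryPositionalEmbedding`, and shows that the cos/sin table is cached and reused
# for a repeated sequence length. The batch size and sequence length are arbitrary choices.
if __name__ == "__main__":
    config = Wav2Vec2BertConfig()
    # Dummy encoder activations of shape (batch_size, sequence_length, hidden_size)
    hidden_states = torch.randn(2, 50, config.hidden_size)

    rotary = Wav2Vec2BertRotaryPositionalEmbedding(config)
    cos_sin = rotary(hidden_states)
    # Stacked cos/sin embeddings of shape
    # (2, sequence_length, 1, 1, hidden_size // num_attention_heads)
    print(cos_sin.shape)

    # A second call with the same sequence length hits the cache and returns the same tensor.
    assert rotary(hidden_states) is cos_sin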