import math from typing import Optional, Union import torch from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput, Wav2Vec2BaseModelOutput, XVectorOutput, ) from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeedForward, Wav2Vec2ForSequenceClassification, Wav2Vec2Model from ..wav2vec2_conformer.modeling_wav2vec2_conformer import ( Wav2Vec2ConformerForAudioFrameClassification, Wav2Vec2ConformerForCTC, Wav2Vec2ConformerForXVector, Wav2Vec2ConformerRelPositionalEmbedding, Wav2Vec2ConformerRotaryPositionalEmbedding, Wav2Vec2ConformerSelfAttention, ) from .configuration_wav2vec2_bert import Wav2Vec2BertConfig logger = logging.get_logger(__name__) _HIDDEN_STATES_START_POSITION = 2 # Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): """ Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that stops at the corresponding element in `seq_lens`. Args: hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): The sequences to mask, where `*` is any number of sequence-specific dimensions including none. 
seq_lens (`torch.Tensor` of shape `(batch)`: Each element represents the length of the sequence at the same index in `hidden_states` Returns: `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` """ batch_size, mask_seq_len = hidden_states.shape[:2] indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) mask = hidden_states.new_ones((batch_size, mask_seq_len)) mask = mask.masked_fill(bool_mask, 0) return mask class Wav2Vec2BertRotaryPositionalEmbedding(Wav2Vec2ConformerRotaryPositionalEmbedding): def __init__(self, config): nn.Module.__init__(self) dim = config.hidden_size // config.num_attention_heads base = config.rotary_embedding_base inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) # Ignore copy self.register_buffer("inv_freq", inv_freq, persistent=False) self.cached_sequence_length = None self.cached_rotary_positional_embedding = None class Wav2Vec2BertRelPositionalEmbedding(Wav2Vec2ConformerRelPositionalEmbedding): pass class Wav2Vec2BertFeatureProjection(nn.Module): def __init__(self, config): super().__init__() self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps) self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size) self.dropout = nn.Dropout(config.feat_proj_dropout) def forward(self, hidden_states): # non-projected hidden states are needed for quantization norm_hidden_states = self.layer_norm(hidden_states) hidden_states = self.projection(norm_hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states, norm_hidden_states class Wav2Vec2BertFeedForward(Wav2Vec2FeedForward): def __init__(self, config, act_fn=None, hidden_size=None): nn.Module.__init__(self) act_fn = act_fn if act_fn is not None else config.hidden_act hidden_size = hidden_size if hidden_size is not None else config.hidden_size 
self.intermediate_dropout = nn.Dropout(config.activation_dropout) self.intermediate_dense = nn.Linear(hidden_size, config.intermediate_size) self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn self.output_dense = nn.Linear(config.intermediate_size, hidden_size) self.output_dropout = nn.Dropout(config.hidden_dropout) class Wav2Vec2BertConvolutionModule(nn.Module): """Convolution block used in the conformer block""" def __init__(self, config): super().__init__() if (config.conv_depthwise_kernel_size - 1) % 2 == 1: raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pointwise_conv1 = nn.Conv1d( config.hidden_size, 2 * config.hidden_size, kernel_size=1, stride=1, padding=0, bias=False, ) self.glu = nn.GLU(dim=1) self.depthwise_conv = nn.Conv1d( config.hidden_size, config.hidden_size, config.conv_depthwise_kernel_size, stride=1, padding=0, groups=config.hidden_size, bias=False, ) self.depthwise_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.activation = ACT2FN[config.hidden_act] self.pointwise_conv2 = nn.Conv1d( config.hidden_size, config.hidden_size, kernel_size=1, stride=1, padding=0, bias=False, ) self.dropout = nn.Dropout(config.conformer_conv_dropout) def forward(self, hidden_states, attention_mask=None): hidden_states = self.layer_norm(hidden_states) # Ensure that we do not leak padded positions in depthwise convolution if attention mask is passed. 
# Put 0 where necessary if attention_mask is not None: hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) # exchange the temporal dimension and the feature dimension hidden_states = hidden_states.transpose(1, 2) # GLU mechanism # => (batch, 2*channel, dim) hidden_states = self.pointwise_conv1(hidden_states) # => (batch, channel, dim) hidden_states = self.glu(hidden_states) # Pad the sequence entirely on the left because of causal convolution. hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0)) # 1D Depthwise Conv hidden_states = self.depthwise_conv(hidden_states) hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2) hidden_states = self.activation(hidden_states) hidden_states = self.pointwise_conv2(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states.transpose(1, 2) return hidden_states class Wav2Vec2BertSelfAttention(Wav2Vec2ConformerSelfAttention, nn.Module): """Construct an Wav2Vec2BertSelfAttention object. Can be enhanced with rotary or relative position embeddings. 
""" def __init__(self, config, is_adapter_attention=False): nn.Module.__init__(self) hidden_size = config.hidden_size if not is_adapter_attention else config.output_hidden_size self.head_size = hidden_size // config.num_attention_heads self.num_heads = config.num_attention_heads self.position_embeddings_type = config.position_embeddings_type if not is_adapter_attention else None self.linear_q = nn.Linear(hidden_size, hidden_size) self.linear_k = nn.Linear(hidden_size, hidden_size) self.linear_v = nn.Linear(hidden_size, hidden_size) self.linear_out = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(p=config.attention_dropout) if self.position_embeddings_type == "relative": # linear transformation for positional encoding self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in https://huggingface.co/papers/1901.02860 Section 3.3 self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) if self.position_embeddings_type == "relative_key": self.left_max_position_embeddings = config.left_max_position_embeddings self.right_max_position_embeddings = config.right_max_position_embeddings num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1 self.distance_embedding = nn.Embedding(num_positions, self.head_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, relative_position_embeddings: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: # self-attention mechanism batch_size, sequence_length, hidden_size = hidden_states.size() # make sure query/key states can be != value states query_key_states = hidden_states value_states = hidden_states if self.position_embeddings_type == "rotary": if 
relative_position_embeddings is None: raise ValueError( "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" ) query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) # project query_key_states and value_states query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) # => (batch, head, time1, d_k) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) if self.position_embeddings_type == "relative": if relative_position_embeddings is None: raise ValueError( "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" " 'relative'" ) # apply relative_position_embeddings to qk scores # as proposed in Transformer_XL: https://huggingface.co/papers/1901.02860 scores = self._apply_relative_embeddings( query=query, key=key, relative_position_embeddings=relative_position_embeddings ) else: scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) if self.position_embeddings_type == "relative_key": query_length, key_length = query.shape[2], key.shape[2] position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_r - position_ids_l distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings) positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings) positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) scores = scores + 
(relative_position_attn_weights / math.sqrt(self.head_size)) # apply attention_mask if necessary if attention_mask is not None: scores = scores + attention_mask # => (batch, head, time1, time2) probs = torch.softmax(scores, dim=-1) probs = self.dropout(probs) # => (batch, head, time1, d_k) hidden_states = torch.matmul(probs, value) # => (batch, time1, hidden_size) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) hidden_states = self.linear_out(hidden_states) return hidden_states, probs class Wav2Vec2BertEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): super().__init__() embed_dim = config.hidden_size dropout = config.attention_dropout # Feed-forward 1 self.ffn1_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.ffn1 = Wav2Vec2BertFeedForward(config) # Self-Attention self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.self_attn_dropout = nn.Dropout(dropout) self.self_attn = Wav2Vec2BertSelfAttention(config) # Conformer Convolution self.conv_module = Wav2Vec2BertConvolutionModule(config) # Feed-forward 2 self.ffn2_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.ffn2 = Wav2Vec2BertFeedForward(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states, attention_mask: Optional[torch.Tensor] = None, relative_position_embeddings: Optional[torch.Tensor] = None, output_attentions: bool = False, conv_attention_mask: Optional[torch.Tensor] = None, ): # 1. Feed-Forward 1 layer residual = hidden_states hidden_states = self.ffn1_layer_norm(hidden_states) hidden_states = self.ffn1(hidden_states) hidden_states = hidden_states * 0.5 + residual residual = hidden_states # 2. 
Self-Attention layer hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weigts = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, relative_position_embeddings=relative_position_embeddings, output_attentions=output_attentions, ) hidden_states = self.self_attn_dropout(hidden_states) hidden_states = hidden_states + residual # 3. Convolutional Layer residual = hidden_states hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) hidden_states = residual + hidden_states # 4. Feed-Forward 2 Layer residual = hidden_states hidden_states = self.ffn2_layer_norm(hidden_states) hidden_states = self.ffn2(hidden_states) hidden_states = hidden_states * 0.5 + residual hidden_states = self.final_layer_norm(hidden_states) return hidden_states, attn_weigts class Wav2Vec2BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config if config.position_embeddings_type == "relative": self.embed_positions = Wav2Vec2BertRelPositionalEmbedding(config) elif config.position_embeddings_type == "rotary": self.embed_positions = Wav2Vec2BertRotaryPositionalEmbedding(config) else: self.embed_positions = None self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None conv_attention_mask = attention_mask if attention_mask is not None: # make sure padded tokens output 0 hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) # extend attention_mask attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) attention_mask = attention_mask * 
torch.finfo(hidden_states.dtype).min attention_mask = attention_mask.expand( attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] ) hidden_states = self.dropout(hidden_states) if self.embed_positions is not None: relative_position_embeddings = self.embed_positions(hidden_states) else: relative_position_embeddings = None synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for i, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) dropout_probability = torch.rand([]) skip_the_layer = self.training and dropout_probability < self.config.layerdrop if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync layer_outputs = layer( hidden_states, attention_mask=attention_mask, relative_position_embeddings=relative_position_embeddings, output_attentions=output_attentions, conv_attention_mask=conv_attention_mask, ) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class Wav2Vec2BertAdapter(nn.Module): def __init__(self, config): super().__init__() # feature dim might need to be down-projected if config.output_hidden_size != config.hidden_size: self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size, eps=config.layer_norm_eps) else: self.proj = self.proj_layer_norm = None self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in 
range(config.num_adapter_layers)) self.layerdrop = config.layerdrop self.kernel_size = config.adapter_kernel_size self.stride = config.adapter_stride def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens): if seq_lens is None: return seq_lens pad = self.kernel_size // 2 seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 return seq_lens.floor() def forward(self, hidden_states, attention_mask=None): # down project hidden_states if necessary if self.proj is not None and self.proj_layer_norm is not None: hidden_states = self.proj(hidden_states) hidden_states = self.proj_layer_norm(hidden_states) sub_sampled_lengths = None if attention_mask is not None: sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device) for layer in self.layers: layerdrop_prob = torch.rand([]) sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths) if not self.training or (layerdrop_prob > self.layerdrop): hidden_states = layer( hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths ) return hidden_states class Wav2Vec2BertAdapterLayer(nn.Module): def __init__(self, config): super().__init__() embed_dim = config.output_hidden_size dropout = config.conformer_conv_dropout self.kernel_size = config.adapter_kernel_size self.stride = config.adapter_stride # 1. 
residual convolution self.residual_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.residual_conv = nn.Conv1d( embed_dim, 2 * embed_dim, self.kernel_size, stride=self.stride, padding=self.stride // 2, ) self.activation = nn.GLU(dim=1) # Self-Attention self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.self_attn_conv = nn.Conv1d( embed_dim, 2 * embed_dim, self.kernel_size, stride=self.stride, padding=self.stride // 2, ) self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True) self.self_attn_dropout = nn.Dropout(dropout) # Feed-forward self.ffn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim) def forward( self, hidden_states, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, sub_sampled_lengths: Optional[torch.Tensor] = None, ): residual = self.residual_layer_norm(hidden_states) # Apply pooling to the residual to match the sequence length of the # multi-head attention output. # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) residual = residual.transpose(1, 2) residual = self.residual_conv(residual) residual = self.activation(residual) # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) residual = residual.transpose(1, 2) hidden_states = self.self_attn_layer_norm(hidden_states) # Apply pooling before feeding to the multihead-attention layer. 
# (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) hidden_states = hidden_states.transpose(1, 2) hidden_states = self.self_attn_conv(hidden_states) hidden_states = self.activation(hidden_states) # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) hidden_states = hidden_states.transpose(1, 2) if attention_mask is not None: attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) attention_mask = _prepare_4d_attention_mask( attention_mask, hidden_states.dtype, ) # The rest of the computation is identical to a vanilla Transformer # encoder layer. hidden_states, attn_weights = self.self_attn( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = self.self_attn_dropout(hidden_states) hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.ffn_layer_norm(hidden_states) hidden_states = self.ffn(hidden_states) + residual return hidden_states @auto_docstring class Wav2Vec2BertPreTrainedModel(PreTrainedModel): config: Wav2Vec2BertConfig base_model_prefix = "wav2vec2_bert" main_input_name = "input_features" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Wav2Vec2BertSelfAttention): if hasattr(module, "pos_bias_u"): nn.init.xavier_uniform_(module.pos_bias_u) if hasattr(module, "pos_bias_v"): nn.init.xavier_uniform_(module.pos_bias_v) elif isinstance(module, Wav2Vec2BertFeatureProjection): k = math.sqrt(1 / module.projection.in_features) nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, nn.Conv1d): 
nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, Wav2Vec2BertModel): if hasattr(module, "masked_spec_embed"): module.masked_spec_embed.data.uniform_() elif isinstance( module, (Wav2Vec2BertForSequenceClassification, Wav2Vec2BertForAudioFrameClassification, Wav2Vec2BertForXVector), ): if hasattr(module, "layer_weights"): module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) elif isinstance(module, AMSoftmaxLoss): # noqa: F821 module.weight.data.normal_() # Ignore copy def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None ): """ Computes the output length of the convolutional layers """ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter def _conv_out_length(input_length, kernel_size, stride, padding): # 1D convolutional layer output length formula taken # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1 if add_adapter: padding = self.config.adapter_kernel_size // 2 for _ in range(self.config.num_adapter_layers): input_lengths = _conv_out_length( input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding ) return input_lengths def _get_feature_vector_attention_mask( self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None ): # Effectively attention_mask.sum(-1), but not inplace to be able to run # on inference mode. 
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) output_lengths = output_lengths.to(torch.long) batch_size = attention_mask.shape[0] attention_mask = torch.zeros( (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device ) # these two operations makes sure that all values before the output lengths idxs are attended to attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel): def __init__(self, config: Wav2Vec2BertConfig): Wav2Vec2BertPreTrainedModel.__init__(self, config) self.config = config self.feature_projection = Wav2Vec2BertFeatureProjection(config) # model only needs masking vector if mask prob is > 0.0 if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) self.encoder = Wav2Vec2BertEncoder(config) self.adapter = Wav2Vec2BertAdapter(config) if config.add_adapter else None self.intermediate_ffn = None if config.use_intermediate_ffn_before_adapter: self.intermediate_ffn = Wav2Vec2BertFeedForward(config, act_fn="relu") # Initialize weights and apply final processing self.post_init() def freeze_feature_extractor(self): raise AttributeError("Not needed for Wav2Vec2Bert") def freeze_feature_encoder(self): raise AttributeError("Not needed for Wav2Vec2Bert") def forward( self, input_features: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, 
Wav2Vec2BertBaseModelOutput]: r""" mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict masked extracted features in *config.proj_codevector_dim* space. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states, extract_features = self.feature_projection(input_features) hidden_states = self._mask_hidden_states( hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask ) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if self.intermediate_ffn: expanded_hidden_states = self.intermediate_ffn(hidden_states) hidden_states = hidden_states + 0.5 * expanded_hidden_states if self.adapter is not None: hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] return Wav2Vec2BertBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class Wav2Vec2BertForCTC(Wav2Vec2ConformerForCTC): def __init__(self, config, target_lang: Optional[str] = None): r""" target_lang (`str`, *optional*): Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or adapter..bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses 'eng' by default. 
""" super().__init__(config) def freeze_feature_encoder(self): raise AttributeError("Not needed for Wav2Vec2Bert") def forward( self, input_features: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[tuple, CausalLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`. """ if labels is not None and labels.max() >= self.config.vocab_size: raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.wav2vec2_bert( input_features, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones(input_features.shape[:2], device=input_features.device, dtype=torch.long) ) input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum([-1])).to(torch.long) # assuming that padded tokens are filled with -100 # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) flattened_targets = labels.masked_select(labels_mask) # ctc_loss doesn't support fp16 log_probs = nn.functional.log_softmax(logits, dim=-1, 
dtype=torch.float32).transpose(0, 1) with torch.backends.cudnn.flags(enabled=False): loss = nn.functional.ctc_loss( log_probs, flattened_targets, input_lengths, target_lengths, blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, zero_infinity=self.config.ctc_zero_infinity, ) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output return CausalLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) class Wav2Vec2BertForSequenceClassification(Wav2Vec2ForSequenceClassification): def __init__(self, config): super().__init__(config) def freeze_feature_extractor(self): raise AttributeError("Not needed for Wav2Vec2Bert") def freeze_feature_encoder(self): raise AttributeError("Not needed for Wav2Vec2Bert") def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not be updated during training. Only the classification head will be updated. """ for param in self.wav2vec2_bert.parameters(): param.requires_grad = False def forward( self, input_features: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states outputs = self.wav2vec2_bert( input_features, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.config.use_weighted_layer_sum: hidden_states = outputs[_HIDDEN_STATES_START_POSITION] hidden_states = torch.stack(hidden_states, dim=1) norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) else: hidden_states = outputs[0] hidden_states = self.projector(hidden_states) if attention_mask is None: pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClassification): def __init__(self, config): super().__init__(config) def freeze_feature_encoder(self): raise AttributeError("Not needed for Wav2Vec2Bert") def forward( self, input_features: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: 
# Frame-level classification head on top of `self.wav2vec2_bert`: `self.classifier` is applied to
# every time step of the encoder output, producing one set of logits per frame.
class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClassification):
    def __init__(self, config):
        super().__init__(config)

    def freeze_feature_encoder(self):
        # Always raises: this model does not support freezing a feature encoder.
        raise AttributeError("Not needed for Wav2Vec2Bert")

    def forward(
        self,
        input_features: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.Tensor` whose last dimension is `num_labels`, *optional*):
            Per-frame targets holding one `num_labels`-sized vector (e.g. one-hot) for every frame of the
            logits, typically of shape `(batch_size, sequence_length, num_labels)`. The vectors are reduced
            to class indices with an arg-max and a cross-entropy loss against the frame logits is computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs the output of every layer, so force the encoder to return them.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Blend all layer outputs with softmax-normalised learned weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten to one row per frame; the arg-max turns each label vector into a class index.
            target_classes = torch.argmax(labels.view(-1, self.num_labels), dim=1)
            loss = loss_fct(logits.view(-1, self.num_labels), target_classes)

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            # Prepend the loss when it was computed, so the tuple form carries the same information as
            # `TokenClassifierOutput` and matches the other task heads in this file.
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Embedding head on top of `self.wav2vec2_bert`: projected encoder states go through the `self.tdnn`
# layers, are summarised by mean/std pooling over time, and are mapped to an embedding and to logits.
class Wav2Vec2BertForXVector(Wav2Vec2ConformerForXVector):
    def __init__(self, config):
        super().__init__(config)

    def freeze_feature_encoder(self):
        # Always raises: this model does not support freezing a feature encoder.
        raise AttributeError("Not needed for Wav2Vec2Bert")

    def forward(
        self,
        input_features: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for the classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
            When given, the loss is obtained by calling `self.objective` on the logits and these labels.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # The weighted layer sum needs the output of every layer, so force the encoder to return them.
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Blend all layer outputs with softmax-normalised learned weights.
            stacked_layers = torch.stack(outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            layer_probs = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (stacked_layers * layer_probs.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is not None:
            # Pool each sequence over its own leading frames only, using the lengths reported by the
            # two length helpers (presumably the number of valid frames after the TDNN stack).
            encoder_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(encoder_lengths)
            valid_segments = [hidden_states[i, :length] for i, length in enumerate(tdnn_output_lengths)]
            mean_features = torch.stack([segment.mean(dim=0) for segment in valid_segments])
            std_features = torch.stack([segment.std(dim=0) for segment in valid_segments])
        else:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = self.objective(logits, labels) if labels is not None else None

        if return_dict:
            return XVectorOutput(
                loss=loss,
                logits=logits,
                embeddings=output_embeddings,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss first, then logits, embeddings and the encoder extras.
        output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
        return output if loss is None else (loss,) + output