# coding=utf-8 # Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Musicgen Melody model.""" import copy import inspect import math import random from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, GenerationMixin, GenerationMode, LogitsProcessorList, StoppingCriteriaList, ) from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils.deprecation import deprecate_kwarg from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask if TYPE_CHECKING: from ...generation.streamers import BaseStreamer logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Base class for Musicgen Melody autoregressive outputs. """ ) class MusicgenMelodyOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of conditional hidden-states representing the concatenation of the projected text encoder output and the projected audio encoder output. Used as a conditional signal. 
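Example (an illustrative sketch using a tiny, randomly-initialized decoder configuration; the shapes shown are specific to that configuration):

    ```python
    >>> import torch
    >>> from transformers import MusicgenMelodyDecoderConfig, MusicgenMelodyForCausalLM

    >>> config = MusicgenMelodyDecoderConfig(
    ...     vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, ffn_dim=64, num_codebooks=4
    ... )
    >>> model = MusicgenMelodyForCausalLM(config)

    >>> # one sequence of 5 audio codes per codebook, flattened to (batch_size * num_codebooks, sequence_length)
    >>> input_ids = torch.zeros((1 * config.num_codebooks, 5), dtype=torch.long)
    >>> outputs = model(input_ids=input_ids, use_cache=True)
    >>> outputs.logits.shape  # (batch_size * num_codebooks, sequence_length, vocab_size)
    torch.Size([4, 5, 128])
    >>> # `outputs.past_key_values` holds the cache that is re-fed to the model to speed up subsequent decoding steps
    ```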
""" loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[torch.FloatTensor] = None # Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ # transpose to get (bsz, num_codebooks, seq_len) input_ids = input_ids.transpose(1, 2) shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() if decoder_start_token_id is None: raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") shifted_input_ids[..., 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->MusicgenMelody class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int): super().__init__() self.embedding_dim = embedding_dim self.make_weights(num_positions, embedding_dim) def make_weights(self, num_embeddings: int, embedding_dim: int): emb_weights = self.get_embedding(num_embeddings, embedding_dim) if hasattr(self, "weights"): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) return emb.to(torch.get_default_dtype()) @torch.no_grad() # Ignore copy def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0): bsz, seq_len, _ = inputs_embeds.size() # Create the position ids from the input token ids. 
position_ids = (torch.arange(seq_len) + past_key_values_length).to(inputs_embeds.device) # expand embeddings if needed if seq_len > self.weights.size(0): self.make_weights(seq_len, self.embedding_dim) return self.weights.index_select(0, position_ids.view(-1)).detach() # Copied from transformers.models.bart.modeling_bart.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, head_mask: Optional[torch.Tensor] = None, **kwargs, ): if scaling is None: scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) if head_mask is not None: attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, embed_dim: int, num_heads: int, dropout: Optional[float] = 0.0, is_decoder: Optional[bool] = False, bias: Optional[bool] = True, is_causal: Optional[bool] = False, config: Optional[MusicgenMelodyConfig] = None, layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." 
) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) is_updated = False if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache curr_past_key_value = past_key_values.cross_attention_cache else: curr_past_key_value = past_key_values.self_attention_cache else: curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = 
self.out_proj(attn_output) return attn_output, attn_weights class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MusicgenMelodyDecoderConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size self.self_attn = MusicgenMelodyAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, is_causal=True, config=config, layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) self.final_layer_norm = nn.LayerNorm(self.embed_dim) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(attention_heads,)`. past_key_values (`Cache`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
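Returns:
    `tuple(torch.FloatTensor)`: the transformed hidden states of shape `(batch, seq_len, embed_dim)` and the
    self-attention weights (may be `None` when the selected attention backend does not materialize them).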
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states return hidden_states, self_attn_weights @auto_docstring # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->MusicgenMelody class MusicgenMelodyPreTrainedModel(PreTrainedModel): config: MusicgenMelodyDecoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MusicgenMelodyDecoderLayer`] """ def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop self.max_target_positions = config.max_position_embeddings self.d_model = config.hidden_size self.num_codebooks = config.num_codebooks self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 embed_dim = config.vocab_size + 1 self.embed_tokens = nn.ModuleList( [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] ) self.embed_positions = MusicgenMelodySinusoidalPositionalEmbedding( config.max_position_embeddings, config.hidden_size, ) self.layers = nn.ModuleList( [MusicgenMelodyDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @auto_docstring # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are input IDs?](../glossary#input-ids) The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `input_ids`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output. Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): Mask to avoid performing attention on conditional hidden states. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
[What are attention masks?](../glossary#attention-mask) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len) input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) bsz, num_codebooks, seq_len = input.shape input_shape = (bsz, seq_len) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] input = inputs_embeds[:, :, -1:] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum(self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)) if encoder_hidden_states is not None: # take care of attention masks if encoder_attention_mask is not None and attention_mask is None: attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device) if attention_mask is not None: if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=attention_mask.device) attention_mask = torch.cat([encoder_attention_mask, attention_mask], dim=1) # fuse encoder_hidden_states and inputs_embeds inputs_embeds = torch.cat([encoder_hidden_states, inputs_embeds], dim=1) input_shape = inputs_embeds.size()[:-1] attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, ) # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) hidden_states = inputs_embeds + positions.to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: if head_mask.size()[0] != len(self.layers): raise ValueError( f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." 
) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): continue layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions += (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, ): if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, past_key_values_length, ) elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) # while we need to create our specific block mask regardless elif attention_mask is None: attention_mask = make_flex_block_causal_mask( torch.ones( size=(input_shape), device=inputs_embeds.device, ) ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) return attention_mask # Ignore copy def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, ): # MusicgenMelody doesn't apply cross attention, hence it's ignored here # and only exists to not confuse any copy checks pass @auto_docstring # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody class MusicgenMelodyModel(MusicgenMelodyPreTrainedModel): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) self.decoder = MusicgenMelodyDecoder(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.decoder.embed_tokens def set_input_embeddings(self, value): self.decoder.embed_tokens = value @auto_docstring # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: 
Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are input IDs?](../glossary#input-ids) The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `input_ids`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output. Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): Mask to avoid performing attention on conditional hidden states. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) if not return_dict: return decoder_outputs return BaseModelOutputWithPast( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, ) @auto_docstring( custom_intro=""" The Musicgen Melody decoder model with a language modelling head on top. 
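Example (an illustrative sketch of the codebook delay pattern used during generation, built with a tiny randomly-initialized configuration; all sizes and token values are arbitrary):

    ```python
    >>> import torch
    >>> from transformers import MusicgenMelodyDecoderConfig, MusicgenMelodyForCausalLM

    >>> # MusicGen reserves the id `vocab_size` as its special BOS/PAD code, hence pad/bos of 128 for a 128-token vocabulary
    >>> config = MusicgenMelodyDecoderConfig(
    ...     vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, ffn_dim=64,
    ...     num_codebooks=4, pad_token_id=128, bos_token_id=128
    ... )
    >>> model = MusicgenMelodyForCausalLM(config)

    >>> # a prompt of 3 ids per codebook for a batch of 1, shape (batch_size * num_codebooks, seq_len)
    >>> # (in real use the first id of each codebook is the BOS/PAD code)
    >>> prompt = torch.randint(0, config.vocab_size, (config.num_codebooks, 3))
    >>> input_ids, delay_pattern_mask = model.build_delay_pattern_mask(
    ...     prompt, pad_token_id=config.pad_token_id, max_length=10
    ... )
    >>> delay_pattern_mask.shape  # one row per codebook; -1 marks positions that still have to be generated
    torch.Size([4, 10])
    >>> input_ids.shape  # the prompt, truncated at the first position that still needs to be generated
    torch.Size([4, 3])
    ```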
""" ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody,MusicGen->Musicgen Melody class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) self.model = MusicgenMelodyModel(config) self.num_codebooks = config.num_codebooks self.lm_heads = nn.ModuleList( [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.decoder.embed_tokens def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value def get_output_embeddings(self): return self.lm_heads def set_output_embeddings(self, new_embeddings): self.lm_heads = new_embeddings def set_decoder(self, decoder): self.model.decoder = decoder def get_decoder(self): return self.model.decoder @auto_docstring # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, MusicgenMelodyOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are input IDs?](../glossary#input-ids) The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `input_ids`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states representing the concatenation of the text encoder output and the processed audio encoder output. Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): Mask to avoid performing attention on conditional hidden states. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (labels is not None) and (input_ids is None and inputs_embeds is None): input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) outputs = self.model( input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) loss = None if labels is not None: # since encoder hidden states have been concatenated to the decoder hidden states, # we take the last timestamps corresponding to labels logits = lm_logits[:, :, -labels.shape[1] :] loss_fct = CrossEntropyLoss() loss = torch.zeros([], device=self.device) # per codebook cross-entropy # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243 # -100 labels are ignored labels = labels.masked_fill(labels == self.config.pad_token_id, -100) # per codebook cross-entropy for codebook in range(self.config.num_codebooks): codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) codebook_labels = labels[..., codebook].contiguous().view(-1) loss += loss_fct(codebook_logits, codebook_labels) loss = loss / self.config.num_codebooks # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return MusicgenMelodyOutputWithPast( loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) # Ignore copy def prepare_inputs_for_generation( self, input_ids, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, past_key_values=None, use_cache=True, delay_pattern_mask=None, guidance_scale=None, **kwargs, ): # Overwritten -- MusicGen has custom processing if delay_pattern_mask is None: input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, pad_token_id=self.generation_config.pad_token_id, max_length=self.generation_config.max_length, ) # apply the delay pattern mask input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) if guidance_scale is not None and guidance_scale > 1: # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these # before sampling) input_ids = input_ids.repeat((2, 1)) if attention_mask is not None: attention_mask = attention_mask.repeat((2, 1)) if encoder_hidden_states is not None: encoder_hidden_states = torch.concatenate( [encoder_hidden_states, torch.zeros_like(encoder_hidden_states)], dim=0 ) if encoder_attention_mask is not None: encoder_attention_mask = torch.concatenate( [encoder_attention_mask, torch.zeros_like(encoder_attention_mask)], dim=0 ) if past_key_values is not None: input_ids = input_ids[:,
-1:] # we only want to use conditional signal in the 1st generation step but keeping the attention mask encoder_hidden_states = None return { "input_ids": input_ids, "attention_mask": attention_mask, "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "head_mask": head_mask, "past_key_values": past_key_values, "use_cache": use_cache, } def build_delay_pattern_mask( self, input_ids: torch.LongTensor, pad_token_id: int, max_length: Optional[int] = None ): """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, seq_len)`: - [P, -1, -1, -1, -1, P, P, P] - [P, P, -1, -1, -1, -1, P, P] - [P, P, P, -1, -1, -1, -1, P] - [P, P, P, P, -1, -1, -1, -1] where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the mask is set to the value in the prompt: - [P, a, b, -1, -1, P, P, P] - [P, P, c, d, -1, -1, P, P] - [P, P, P, e, f, -1, -1, P] - [P, P, P, P, g, h, -1, -1] where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 tokens in our prediction. """ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) bsz, num_codebooks, seq_len = input_ids.shape max_length = max_length if max_length is not None else self.generation_config.max_length input_ids_shifted = ( torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 ) channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks # we only apply the mask if we have a large enough seq len - otherwise we return as is if max_length < 2 * channel_codebooks - 1: return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) # fill the shifted ids with the prompt entries, offset by the codebook idx for codebook in range(channel_codebooks): if self.config.audio_channels == 1: # mono channel - loop over the codebooks one-by-one input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] else: # left/right channels are interleaved in the generated codebooks, so handle one then the other input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook] input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1] # construct a pattern mask that indicates the positions of padding tokens for each codebook # first fill the upper triangular part (the EOS padding) delay_pattern = torch.triu( torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1 ) # then fill the lower triangular part (the BOS padding) delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool)) if self.config.audio_channels == 2: # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion delay_pattern = delay_pattern.repeat_interleave(2, dim=0) mask = ~delay_pattern.to(input_ids.device) input_ids = mask * input_ids_shifted + ~mask * pad_token_id # find the first position to 
start generating - this is the first place we have the -1 token # and will always be in the first codebook (since it has no codebook offset) first_codebook_ids = input_ids[:, 0, :] start_ids = (first_codebook_ids == -1).nonzero()[:, 1] if len(start_ids) > 0: first_start_id = min(start_ids) else: # we have no tokens that need to be filled - return entire matrix of input ids first_start_id = seq_len # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) return input_ids, pattern_mask @staticmethod def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where the mask is set to -1, and otherwise setting to the value detailed in the mask.""" seq_len = input_ids.shape[-1] decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) return input_ids @torch.no_grad() # Ignore copy def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, synced_gpus: Optional[bool] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, ): """ Generates sequences of token ids for models with a language modeling head. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation strategies and code examples, check out the [following guide](./generation_strategies). Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. 
If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed to avoid deadlocking with `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. kwargs (`dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateDecoderOnlyOutput`], - [`~generation.GenerateBeamDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects if generation_config is None: generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs` input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = input_ids.shape[0] // self.num_codebooks self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( input_ids, generation_config, model_kwargs ) # 5. Prepare `max_length` depending on other stopping criteria. 
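# Note (illustrative orders of magnitude, not from the original code): `max_length` counts audio codes per codebook.
# The 32 kHz EnCodec checkpoints used by MusicGen emit roughly 50 codes per codebook per second, so a `max_length`
# of ~1500 corresponds to about 30 seconds of audio (a few of those steps are consumed by the codebook offsets of
# the delay pattern built in step 6 below).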
input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=input_ids, input_ids_length=input_ids_length, ) # 6. Prepare `input_ids` which will be used for auto-regressive generation # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen) input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) if streamer is not None: streamer.put(input_ids.cpu()) # stash the delay mask so that we don't have to recompute it in each forward pass model_kwargs["delay_pattern_mask"] = delay_pattern_mask # 7. determine generation mode generation_mode = generation_config.get_generation_mode() # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG) if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) generation_config.guidance_scale = None # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, device=input_ids.device, ) # 10. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, **model_kwargs, ) # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) else: raise ValueError( "Got incompatible mode for generation, should be one of greedy or sampling. " "Ensure that beam search is de-activated by setting `num_beams=1`." 
) if generation_config.return_dict_in_generate: output_ids = outputs.sequences else: output_ids = outputs # apply the pattern mask to the final ids output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.num_codebooks, -1 ) if generation_config.return_dict_in_generate: outputs.sequences = output_ids return outputs else: return output_ids @auto_docstring class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): config: MusicgenMelodyConfig main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True def __init__( self, config: MusicgenMelodyConfig = None, text_encoder: Optional[PreTrainedModel] = None, audio_encoder: Optional[PreTrainedModel] = None, decoder: Optional[MusicgenMelodyForCausalLM] = None, ): r""" text_encoder (`PreTrainedModel`, *optional*): The text encoder model that encodes text into hidden states for conditioning. audio_encoder (`PreTrainedModel`, *optional*): The audio encoder model that encodes audio into hidden states for conditioning. decoder (`MusicgenMelodyForCausalLM`, *optional*): The decoder model that generates audio tokens based on conditioning signals. """ if config is None and None in (text_encoder, audio_encoder, decoder): raise ValueError( "Either a configuration has to be provided, or all three of text encoder, audio encoder and Musicgen Melody decoder." ) if config is None: config = MusicgenMelodyConfig.from_sub_models_config( text_encoder.config, audio_encoder.config, decoder.config ) else: if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") # initialize with config super().__init__(config) if text_encoder is None: text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) if audio_encoder is None: audio_encoder = AutoModel.from_config(config.audio_encoder) if decoder is None: decoder = MusicgenMelodyForCausalLM._from_config(config.decoder) self.text_encoder = text_encoder self.audio_encoder = audio_encoder self.decoder = decoder # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.text_encoder.config = self.config.text_encoder self.audio_encoder.config = self.config.audio_encoder self.decoder.config = self.config.decoder # text encoder outputs might need to be projected to different dimension for decoder if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size: self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) # audio encoder outputs after chroma extraction might need to be projected to different dimension for decoder if self.config.num_chroma != self.decoder.config.hidden_size: self.audio_enc_to_dec_proj = nn.Linear(self.config.num_chroma, self.decoder.config.hidden_size) if self.text_encoder.get_output_embeddings() is not None: raise ValueError( f"The encoder {self.text_encoder} should not have a LM Head. 
Please use a model without and LM Head" ) # Initialize projection layers weights and tie text encoder and decoder weights if set accordingly self.post_init() def _init_weights(self, module): # MusicgenMelodyForConditionalGeneration is made of PreTrainedModels that have already been initialized # Projection layers still need to be initialized. std = self.decoder.config.initializer_factor if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() def tie_weights(self): # tie text encoder & decoder if needed if self.config.tie_encoder_decoder: # tie text encoder and decoder base model decoder_base_model_prefix = self.decoder.base_model_prefix tied_weights = self._tie_encoder_decoder_weights( self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix, "text_encoder", ) # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class # attributed not an instance member, therefore modifying it will modify the entire class # Leading to issues on subsequent calls by different tests or subsequent calls. self._dynamic_tied_weights_keys = tied_weights def get_text_encoder(self): return self.text_encoder def get_encoder(self): # get the text encoder to compute the conditioning hidden-states for generation return self.get_text_encoder() def get_input_embeddings(self): return self.text_encoder.get_input_embeddings() def get_output_embeddings(self): return self.decoder.get_output_embeddings() def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) @classmethod # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration.from_sub_models_pretrained with Musicgen->MusicgenMelody, musicgen-small->musicgen-melody def from_sub_models_pretrained( cls, text_encoder_pretrained_model_name_or_path: Optional[str] = None, audio_encoder_pretrained_model_name_or_path: Optional[str] = None, decoder_pretrained_model_name_or_path: Optional[str] = None, *model_args, **kwargs, ) -> PreTrainedModel: r""" Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the library from pretrained model checkpoints. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train the model, you need to first set it back in training mode with `model.train()`. Params: text_encoder_pretrained_model_name_or_path (`str`, *optional*): Information necessary to initiate the text encoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. audio_encoder_pretrained_model_name_or_path (`str`, *optional*): Information necessary to initiate the audio encoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the decoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. 
- A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration parameter. - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration parameter. - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - To update the parent model configuration, do not use a prefix for each configuration parameter. Behaves differently depending on whether a `config` is provided or automatically loaded. Example: ```python >>> from transformers import MusicgenMelodyForConditionalGeneration >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder >>> model = MusicgenMelodyForConditionalGeneration.from_sub_models_pretrained( ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base", ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", ... decoder_pretrained_model_name_or_path="facebook/musicgen-melody", ... ) >>> # saving model after fine-tuning >>> model.save_pretrained("./musicgen-ft") >>> # load fine-tuned model >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("./musicgen-ft") ```""" kwargs_text_encoder = { argument[len("text_encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("text_encoder_") } kwargs_audio_encoder = { argument[len("audio_encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("audio_encoder_") } kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } # remove text encoder, audio encoder and decoder kwargs from kwargs for key in kwargs_text_encoder: del kwargs["text_encoder_" + key] for key in kwargs_audio_encoder: del kwargs["audio_encoder_" + key] for key in kwargs_decoder: del kwargs["decoder_" + key] # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made # by the value of the flag `is_decoder` that we need to set correctly. text_encoder = kwargs_text_encoder.pop("model", None) if text_encoder is None: if text_encoder_pretrained_model_name_or_path is None: raise ValueError( "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_text_encoder: encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " "from a decoder model. Cross-attention and causal mask are disabled." 
) encoder_config.is_decoder = False encoder_config.add_cross_attention = False kwargs_text_encoder["config"] = encoder_config text_encoder = AutoModel.from_pretrained( text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder ) audio_encoder = kwargs_audio_encoder.pop("model", None) if audio_encoder is None: if audio_encoder_pretrained_model_name_or_path is None: raise ValueError( "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_audio_encoder: encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " "from a decoder model. Cross-attention and causal mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False kwargs_audio_encoder["config"] = encoder_config audio_encoder = AutoModel.from_pretrained( audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder ) decoder = kwargs_decoder.pop("model", None) if decoder is None: if decoder_pretrained_model_name_or_path is None: raise ValueError( "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_decoder: decoder_config, kwargs_decoder = AutoConfig.from_pretrained( decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True ) if isinstance(decoder_config, MusicgenMelodyConfig): decoder_config = decoder_config.decoder if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True kwargs_decoder["config"] = decoder_config if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " "`decoder_config` to `.from_sub_models_pretrained(...)`" ) decoder = MusicgenMelodyForCausalLM.from_pretrained( decoder_pretrained_model_name_or_path, **kwargs_decoder ) # instantiate config with corresponding kwargs config = MusicgenMelodyConfig.from_sub_models_config( text_encoder.config, audio_encoder.config, decoder.config, **kwargs ) return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, input_features: Optional[torch.FloatTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, past_key_values: Optional[Cache] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[tuple, MusicgenMelodyOutputWithPast]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `decoder_input_ids`. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of conditional hidden-states representing the concatenation of the projected text encoder output and the projected audio encoder output. Used as a conditional signal and will thus be concatenated to the projected `decoder_input_ids`. labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`. Examples: ```python >>> from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration >>> import torch >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-melody") >>> model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody") >>> inputs = processor( ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], ... padding=True, ... return_tensors="pt", ... ) >>> pad_token_id = model.generation_config.pad_token_id >>> decoder_input_ids = ( ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) ... * pad_token_id ... ) >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits >>> logits.shape # (bsz * num_codebooks, encoder_len + tgt_len, vocab_size) torch.Size([8, 249, 2048]) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict kwargs_text_encoder = { argument[len("text_encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("text_encoder_") } kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } if encoder_hidden_states is None: if inputs_embeds is not None or input_ids is not None: encoder_outputs = self.text_encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs_text_encoder, ) encoder_hidden_states = encoder_outputs[0] # optionally project encoder_hidden_states if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size: encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) if attention_mask is not None and encoder_hidden_states is not None: encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] # set default audio conditional hidden states if text conditioning is provided but audio is not if encoder_hidden_states is not None and input_features is None: input_features = torch.zeros( (encoder_hidden_states.shape[0], 1, self.config.num_chroma), device=self.device, dtype=self.dtype, ) input_features[:, :, 0] = 1 if input_features is not None: audio_hidden_states = input_features # optionally project audio_hidden_states -> # (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size) if self.config.num_chroma != self.decoder.config.hidden_size: audio_hidden_states = self.audio_enc_to_dec_proj(audio_hidden_states) # pad or truncate to config.chroma_length if audio_hidden_states.shape[1] < self.config.chroma_length: n_repeat = int(math.ceil(self.config.chroma_length / audio_hidden_states.shape[1])) audio_hidden_states = audio_hidden_states.repeat(1, n_repeat, 1) else: logger.warning( f"The conditional audio signal is of length {audio_hidden_states.shape[1]}, which exceeds " f"the maximum chroma duration of {self.config.chroma_length}. " f"The audio will be truncated to {self.config.chroma_length} frames."
) audio_hidden_states = audio_hidden_states[:, : self.config.chroma_length] if encoder_hidden_states is not None: encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1) else: encoder_hidden_states = audio_hidden_states if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id ) # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_hidden_states, inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, past_key_values=past_key_values, return_dict=return_dict, labels=labels, **kwargs_decoder, ) if not return_dict: return decoder_outputs + (encoder_hidden_states,) return MusicgenMelodyOutputWithPast( loss=decoder_outputs.loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, encoder_hidden_states=encoder_hidden_states, ) def prepare_inputs_for_generation( self, decoder_input_ids, encoder_hidden_states=None, past_key_values=None, attention_mask=None, decoder_attention_mask=None, decoder_head_mask=None, use_cache=None, decoder_delay_pattern_mask=None, guidance_scale=None, **kwargs, ): # Overwritten -- MusicGen has custom processing if decoder_delay_pattern_mask is None: decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( decoder_input_ids, self.generation_config.pad_token_id, max_length=self.generation_config.max_length, ) # apply the delay pattern mask decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) if guidance_scale is not None and guidance_scale > 1: # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these # before sampling) decoder_input_ids = decoder_input_ids.repeat((2, 1)) if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] # we only want to use the conditional signal in the 1st generation step, but we keep the attention mask encoder_hidden_states = None # we also have to update the attention mask return { "input_ids": None, # encoder_hidden_states is defined.
input_ids not needed "encoder_hidden_states": encoder_hidden_states, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "decoder_head_mask": decoder_head_mask, "use_cache": use_cache, } # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._prepare_decoder_input_ids_for_generation def _prepare_decoder_input_ids_for_generation( self, batch_size: int, model_input_name: str, model_kwargs: dict[str, torch.Tensor], decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, device: Optional[torch.device] = None, ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. if model_kwargs is not None and "decoder_input_ids" in model_kwargs: decoder_input_ids = model_kwargs.pop("decoder_input_ids") elif "input_ids" in model_kwargs and model_input_name != "input_ids": decoder_input_ids = model_kwargs.pop("input_ids") else: decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device decoder_input_ids_start = ( torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) * decoder_start_token_id ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: decoder_input_ids = decoder_input_ids_start # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask return decoder_input_ids, model_kwargs def _prepare_encoder_hidden_states_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str], generation_config: GenerationConfig, ) -> dict[str, Any]: encoder_hidden_states = None # attention mask is consumed once to produce text conditional hidden states through the text encoder encoder_attention_mask = model_kwargs.pop("attention_mask") guidance_scale = generation_config.guidance_scale # 1. condition on text if inputs_tensor is not None: encoder = self.get_text_encoder() # Compatibility with Accelerate big model inference: we need the encoder outputs to be on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): encoder._hf_hook.io_same_device = True # Prepare args and kwargs from model kwargs.
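# Only kwargs the text encoder can consume are forwarded: decoder-prefixed kwargs and `use_cache` are dropped, and if the encoder's forward signature does not accept **kwargs, any remaining argument outside that signature is filtered out as well.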
irrelevant_prefix = ["decoder_", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature if not encoder_accepts_wildcard: encoder_kwargs = { argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } encoder_kwargs["output_attentions"] = generation_config.output_attentions encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states # make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor if encoder_attention_mask is not None: encoder_kwargs["attention_mask"] = encoder_attention_mask encoder_hidden_states = encoder(**encoder_kwargs).last_hidden_state # optionally project encoder_hidden_states if self.text_encoder.config.hidden_size != self.decoder.config.hidden_size: encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) # for classifier free guidance we need to add a 'null' input to our encoder hidden states if guidance_scale is not None and guidance_scale > 1: encoder_hidden_states = torch.concatenate( [encoder_hidden_states, torch.zeros_like(encoder_hidden_states)], dim=0 ) if encoder_attention_mask is not None: encoder_attention_mask = torch.concatenate( [encoder_attention_mask, torch.zeros_like(encoder_attention_mask)], dim=0 ) if encoder_attention_mask is not None: encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None] # 2. 
condition on audio audio_hidden_states = model_kwargs.get("input_features", None) if inputs_tensor is not None: if audio_hidden_states is not None: null_audio_hidden_states = torch.zeros_like(audio_hidden_states) else: null_audio_hidden_states = torch.zeros( (inputs_tensor.shape[0], 1, self.config.num_chroma), device=self.device, dtype=self.dtype ) null_audio_hidden_states[:, :, 0] = 1 if audio_hidden_states is None: audio_hidden_states = null_audio_hidden_states if audio_hidden_states is not None: # for classifier free guidance we need to add a 'null' input to our audio hidden states if guidance_scale is not None and guidance_scale > 1: audio_hidden_states = torch.concatenate([audio_hidden_states, null_audio_hidden_states], dim=0) # optionally project audio_hidden_states -> # (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size) if self.config.num_chroma != self.decoder.config.hidden_size: audio_hidden_states = self.audio_enc_to_dec_proj(audio_hidden_states) # pad or truncate to config.chroma_length if audio_hidden_states.shape[1] < self.config.chroma_length: n_repeat = int(math.ceil(self.config.chroma_length / audio_hidden_states.shape[1])) audio_hidden_states = audio_hidden_states.repeat(1, n_repeat, 1) audio_hidden_states = audio_hidden_states[:, : self.config.chroma_length] if encoder_hidden_states is not None: encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1) else: encoder_hidden_states = audio_hidden_states model_kwargs["encoder_hidden_states"] = encoder_hidden_states return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" " model.decoder.resize_token_embeddings(...))" ) def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, bos_token_id: Optional[int] = None, model_kwargs: Optional[dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" if inputs is not None: return inputs if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with # soft-prompting or in multimodal implementations built on top of decoder-only language models. batch_size = 1 for value in model_kwargs.values(): if isinstance(value, torch.Tensor): batch_size = value.shape[0] break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def freeze_audio_encoder(self): """ Freeze the audio encoder weights. """ for param in self.audio_encoder.parameters(): param.requires_grad = False self.audio_encoder._requires_grad = False def freeze_text_encoder(self): """ Freeze the text encoder weights. 
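Sets `requires_grad=False` on all text encoder parameters so that they are excluded from gradient updates during training.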
""" for param in self.text_encoder.parameters(): param.requires_grad = False self.text_encoder._requires_grad = False # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._get_decoder_start_token_id def _get_decoder_start_token_id( self, decoder_start_token_id: Optional[Union[int, list[int]]] = None, bos_token_id: Optional[int] = None ) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id ) bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id elif bos_token_id is not None: return bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, synced_gpus: Optional[bool] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, ): """ Generates sequences of token ids for models with a language modeling head. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation strategies and code examples, check out the [following guide](./generation_strategies). Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. 
synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed to avoid deadlocking with `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. kwargs (`dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.FloatTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateDecoderOnlyOutput`], - [`~generation.GenerateBeamDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects if generation_config is None: generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config, model_kwargs ) if "encoder_hidden_states" not in model_kwargs: # encoder_hidden_states are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name, generation_config ) # 5. Prepare `input_ids` which will be used for auto-regressive generation input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, decoder_start_token_id=generation_config._decoder_start_token_tensor, bos_token_id=generation_config._bos_token_tensor, device=inputs_tensor.device, ) # 6.
Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. Prepare the cache. # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. # - different models have a different cache name expected by the model (default = "past_key_values") # - `max_length`, prepared above, is used to determine the maximum cache length max_cache_length = generation_config.max_length - 1 if ( inputs_tensor.shape[1] != input_ids_length and model_input_name == "inputs_embeds" and not self.config.is_encoder_decoder ): max_cache_length += inputs_tensor.shape[1] self._prepare_cache_for_generation( generation_config, model_kwargs, generation_mode=None, batch_size=batch_size, max_cache_length=max_cache_length, ) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) # stash the delay mask so that we don't have to recompute it in each forward pass model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask # input_ids are ready to be placed on the streamer (if used) if streamer is not None: streamer.put(input_ids.cpu()) # 8. determine generation mode generation_mode = generation_config.get_generation_mode() # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG) if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) generation_config.guidance_scale = None # 10. prepare distribution pre-processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, device=input_ids.device, ) # 11. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) else: raise ValueError( "Got incompatible mode for generation, should be one of greedy or sampling. " "Ensure that beam search is de-activated by setting `num_beams=1`."
) if generation_config.return_dict_in_generate: output_ids = outputs.sequences else: output_ids = outputs # apply the pattern mask to the final ids output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.decoder.num_codebooks, -1 ) # append the frame dimension back to the audio codes output_ids = output_ids[None, ...] audio_scales = model_kwargs.get("audio_scales") if audio_scales is None: audio_scales = [None] * batch_size if self.decoder.config.audio_channels == 1: output_values = self.audio_encoder.decode( output_ids, audio_scales=audio_scales, ).audio_values else: codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales) output_values_left = codec_outputs_left.audio_values codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales) output_values_right = codec_outputs_right.audio_values output_values = torch.cat([output_values_left, output_values_right], dim=1) if generation_config.return_dict_in_generate: outputs.sequences = output_values return outputs else: return output_values __all__ = [ "MusicgenMelodyForConditionalGeneration", "MusicgenMelodyForCausalLM", "MusicgenMelodyModel", "MusicgenMelodyPreTrainedModel", ]
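# Example usage (a minimal sketch, kept as a comment so nothing is executed on import): generating audio
# conditioned on a text prompt with the pretrained "facebook/musicgen-melody" checkpoint. The checkpoint and
# processor mirror the docstring examples above; the generation kwargs (`do_sample`, `guidance_scale`,
# `max_new_tokens`) are illustrative values, not requirements.
#
#     from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration
#
#     processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
#     model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
#     inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
#     # `generate` decodes the predicted audio codes with the audio encoder and returns waveform values
#     # of shape (batch_size, num_channels, num_samples)
#     audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3.0, max_new_tokens=256)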