# coding=utf-8 # Copyright 2020 The Trax Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch REFORMER model.""" import sys from collections import namedtuple from collections.abc import Iterable from dataclasses import dataclass from functools import reduce from operator import mul from typing import Any, Optional, Union import numpy as np import torch from torch import nn from torch.autograd.function import Function from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, ModelOutput, auto_docstring, logging, ) from .configuration_reformer import ReformerConfig logger = logging.get_logger(__name__) # Define named tuples for nn.Modules here LSHSelfAttentionOutput = namedtuple("LSHSelfAttentionOutput", ["hidden_states", "attention_probs", "buckets"]) LocalSelfAttentionOutput = namedtuple("LocalSelfAttentionOutput", ["hidden_states", "attention_probs"]) AttentionOutput = namedtuple("AttentionOutput", ["hidden_states", "attention_probs", "buckets"]) ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", "attention_probs", "buckets"]) ReformerBackwardOutput = namedtuple( "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] ) ReformerEncoderOutput = namedtuple( "ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], ) class ReformerDynamicCache: """ A dynamic cache that stores past buckets instead of key/values. """ def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self.buckets_cache: list[torch.Tensor] = [] self.states_cache: list[torch.Tensor] = [] if _distributed_cache_data is not None: for buckets, states in _distributed_cache_data: self.buckets_cache.append(buckets) self.states_cache.append(states) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): return (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def __iter__(self): """ Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over keys and values """ for layer_idx in range(len(self)): yield (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) def __len__(self): """ Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds to the number of layers in the model. """ return len(self.states_cache) def update( self, buckets: torch.Tensor, states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `ReformerDynamicCache`. Return: A tuple containing the updated key and value states. """ # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += states.shape[-2] # Update the cache if states is not None: if len(self.states_cache) <= layer_idx: self.states_cache.append(states) else: self.states_cache[layer_idx] = torch.cat([self.states_cache[layer_idx], states], dim=1) if buckets is not None: if len(self.buckets_cache) <= layer_idx: self.buckets_cache.append(buckets) else: self.buckets_cache[layer_idx] = torch.cat([self.buckets_cache[layer_idx], buckets], dim=-1) else: # `ReformerLocalAttn` passes `None` to buckets as the module uses no buckets self.buckets_cache.append(torch.tensor([], device=self.states_cache[layer_idx].device)) return self.buckets_cache[layer_idx], self.states_cache[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return None def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: """Converts the `ReformerDynamicCache` instance into the its equivalent in the legacy cache format. Used for backward compatibility.""" legacy_cache = () for layer_idx in range(len(self)): buckets, states = self.buckets_cache[layer_idx], self.states_cache[layer_idx] buckets = buckets if buckets.numel() != 0 else None legacy_cache += ((buckets, states),) return legacy_cache @classmethod def from_legacy_cache( cls, past_buckets_states: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None ) -> "ReformerDynamicCache": """Converts a cache in the legacy cache format into an equivalent `ReformerDynamicCache`. Used for backward compatibility.""" cache = cls() if past_buckets_states is not None: for layer_idx in range(len(past_buckets_states)): buckets, states = past_buckets_states[layer_idx] cache.update(buckets, states, layer_idx) return cache def _stable_argsort(vector, dim): # this function scales the vector so that torch.argsort is stable. # torch.argsort is not stable on its own scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1) scale_offset = scale_offset.expand(vector.shape) scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim]) return torch.argsort(scaled_vector, dim=dim) def _get_least_common_mult_chunk_len(config): attn_types = config.attn_layers attn_types_set = set(attn_types) if len(attn_types_set) == 1 and attn_types[0] == "lsh": return config.lsh_attn_chunk_length elif len(attn_types_set) == 1 and attn_types[0] == "local": return config.local_attn_chunk_length elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}: return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length) else: raise NotImplementedError( f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " "attn layer types from ['lsh', 'local'] only." ) def _get_min_chunk_len(config): attn_types = config.attn_layers attn_types_set = set(attn_types) if len(attn_types_set) == 1 and attn_types[0] == "lsh": return config.lsh_attn_chunk_length elif len(attn_types_set) == 1 and attn_types[0] == "local": return config.local_attn_chunk_length elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}: return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) else: raise NotImplementedError( f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " "attn layer types from ['lsh', 'local'] only." ) class AxialPositionEmbeddings(nn.Module): """ Constructs axial position embeddings. Useful for very long input sequences to save memory and time. """ def __init__(self, config): super().__init__() self.axial_pos_shape = config.axial_pos_shape self.axial_pos_embds_dim = config.axial_pos_embds_dim self.dropout = config.hidden_dropout_prob self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config) self.weights = nn.ParameterList() if sum(self.axial_pos_embds_dim) != config.hidden_size: raise ValueError( f"Make sure that config.axial_pos_embds factors: {self.axial_pos_embds_dim} sum to " f"config.hidden_size: {config.hidden_size}" ) # create weights for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim): # create expanded shapes ax_shape = [1] * len(self.axial_pos_shape) ax_shape[axis] = self.axial_pos_shape[axis] ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,) # create tensor and init self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32))) def forward(self, position_ids): # broadcast weights to correct shape batch_size = position_ids.shape[0] sequence_length = position_ids.shape[1] broadcasted_weights = [ weight.expand((batch_size,) + self.axial_pos_shape + weight.shape[-1:]) for weight in self.weights ] if self.training is True: if reduce(mul, self.axial_pos_shape) != sequence_length: raise ValueError( f"If training, make sure that config.axial_pos_shape factors: {self.axial_pos_shape} multiply to " f"sequence length. Got prod({self.axial_pos_shape}) != sequence_length: {sequence_length}. " f"You might want to consider padding your sequence length to {reduce(mul, self.axial_pos_shape)} " "or changing config.axial_pos_shape." ) if self.dropout > 0: weights = torch.cat(broadcasted_weights, dim=-1) # permute weights so that 2D correctly drops dims 1 and 2 transposed_weights = weights.transpose(2, 1) # drop entire matrix of last two dims (prev dims 1 and 2) dropped_transposed_weights = nn.functional.dropout2d( transposed_weights, p=self.dropout, training=self.training ) dropped_weights = dropped_transposed_weights.transpose(2, 1) position_encodings = torch.reshape(dropped_weights, (batch_size, sequence_length, -1)) else: position_encodings = torch.cat( [torch.reshape(weight, (batch_size, sequence_length, -1)) for weight in broadcasted_weights], dim=-1, ) else: if reduce(mul, self.axial_pos_shape) < sequence_length: raise ValueError( f"Make sure that config.axial_pos_shape factors: {self.axial_pos_shape} multiply at least to " f"max(sequence_length, least_common_mult_chunk_length): max({sequence_length}, " f"{self.least_common_mult_chunk_length})." ) # compute how many columns are needed max_position_id = position_ids.max().item() required_pos_encodings_columns = -(-(max_position_id + 1) // self.axial_pos_shape[1]) # cut to columns that are needed position_encodings = torch.cat( [weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1 ) position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1])) # select correct position encodings position_encodings = torch.cat( [ torch.index_select(position_encodings[i], 0, position_ids[i]).unsqueeze(0) for i in range(batch_size) ], dim=0, ) return position_encodings class PositionEmbeddings(nn.Module): """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`.""" def __init__(self, config): super().__init__() self.dropout = config.hidden_dropout_prob self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) def forward(self, position_ids): position_embeddings = self.embedding(position_ids) position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training) return position_embeddings class ReformerEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" def __init__(self, config): super().__init__() self.max_position_embeddings = config.max_position_embeddings self.dropout = config.hidden_dropout_prob self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = ( AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) ) def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0): if input_ids is not None: input_shape = input_ids.size() device = input_ids.device else: input_shape = inputs_embeds.size()[:-1] device = inputs_embeds.device seq_length = input_shape[1] if position_ids is None: position_ids = torch.arange( start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).expand(input_shape) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) if position_ids.shape[-1] > self.max_position_embeddings: raise ValueError( f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than " f"config.max_position_embeddings {self.max_position_embeddings}." ) # dropout embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training) # add positional embeddings position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings return embeddings class EfficientAttentionMixin: """ A few utilities for nn.Modules in Reformer, to be used as a mixin. """ def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): """ Used to implement attention between consecutive chunks. Args: vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] num_chunks_before: chunks before current chunk to include in attention num_chunks_after: chunks after current chunk to include in attention Returns: tensor of shape [num_chunks, N * chunk_length, ...], where N = (1 + num_chunks_before + num_chunks_after). """ if num_chunks_before == 0 and num_chunks_after == 0: return vectors slices = [] for i in range(-num_chunks_before, num_chunks_after + 1): if i == 0: slices.append(vectors) else: slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2)) return torch.cat(slices, dim=3) def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size): """ splits hidden_size dim into attn_head_size and num_attn_heads """ new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) x = x.view(*new_x_shape) return x.transpose(2, 1) def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size): """ merges attn_head_size dim and num_attn_heads dim into hidden_size """ x = x.permute(0, 2, 1, 3) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): """ splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims """ batch_size = vectors.shape[0] split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) if len(vectors.shape) == 4: return torch.reshape(vectors, split_dim_shape + (attn_head_size,)) elif len(vectors.shape) == 3: return torch.reshape(vectors, split_dim_shape) else: raise ValueError(f"Input vector rank should be one of [3, 4], but is: {len(vectors.shape)}") class LSHSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config, layer_idx=None): super().__init__() self.config = config self.chunk_length = config.lsh_attn_chunk_length self.num_hashes = config.num_hashes self.num_buckets = config.num_buckets self.num_chunks_before = config.lsh_num_chunks_before self.num_chunks_after = config.lsh_num_chunks_after self.hash_seed = config.hash_seed self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings self.layer_idx = layer_idx self.dropout = config.lsh_attention_probs_dropout_prob self.num_attention_heads = config.num_attention_heads self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size # projection matrices self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) # save mask value here. Need fp32 and fp16 mask values self.register_buffer("self_mask_value_float16", torch.tensor(-1e3), persistent=False) self.register_buffer("self_mask_value_float32", torch.tensor(-1e5), persistent=False) self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False) self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False) def forward( self, hidden_states, attention_mask=None, head_mask=None, num_hashes=None, buckets=None, past_buckets_states=None, use_cache=False, output_attentions=False, cache_position=None, **kwargs, ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes # check if cache shall be used and that hidden states are already cached exists_cache = past_buckets_states is not None and len(past_buckets_states) > self.layer_idx if exists_cache: assert sequence_length == 1, ( "At the moment, auto-regressive language generation is only possible one word at a time. Make sure" f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed." ) # get query vector query_vectors = self.query_key(hidden_states) query_vectors = self._split_hidden_size_dim( query_vectors, self.num_attention_heads, self.attention_head_size ) past_buckets = past_buckets_states.buckets_cache[self.layer_idx] past_states = past_buckets_states.states_cache[self.layer_idx] if past_buckets.numel() != 0: key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( query_vectors=query_vectors, attention_mask=attention_mask, num_hashes=num_hashes, hidden_states=hidden_states, past_states=past_states, past_buckets=past_buckets, ) query_key_vectors = self._query_per_attn_head(key_value_hidden_states) value_vectors = self._value_per_attn_head(key_value_hidden_states) # split key & value vectors by num hashes to apply # self attention on each separately query_key_vectors = self._split_seq_length_dim_to( query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, ) value_vectors = self._split_seq_length_dim_to( value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, ) # repeat query vectors across hash dimension query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1) else: key_value_hidden_states = torch.cat([past_states, hidden_states], dim=1) query_key_vectors = self.query_key(key_value_hidden_states) value_vectors = self.value(key_value_hidden_states) else: # project hidden_states to query_key and value query_vectors = None query_key_vectors = self.query_key(hidden_states) value_vectors = self.value(hidden_states) # if query key is not already split if not exists_cache or past_buckets.numel() == 0: query_key_vectors = self._split_hidden_size_dim( query_key_vectors, self.num_attention_heads, self.attention_head_size ) value_vectors = self._split_hidden_size_dim( value_vectors, self.num_attention_heads, self.attention_head_size ) # cache buckets for next incremental decoding if exists_cache and key_value_hidden_states.shape[1] >= self.chunk_length: buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) # free memory del hidden_states assert query_key_vectors.shape[-1] == self.attention_head_size, ( f"last dim of query_key_vectors is {query_key_vectors.shape[-1]} but should be {self.attention_head_size}." ) assert value_vectors.shape[-1] == self.attention_head_size, ( f"last dim of value_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}." ) do_standard_self_attention = (sequence_length <= self.chunk_length) or ( exists_cache and past_states is not None ) # LSH attention only makes sense if chunked attention should be performed if not do_standard_self_attention: # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: self._set_num_buckets(sequence_length) # use cached buckets for backprop only if buckets is None: # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) else: # make sure buckets has correct shape for LSH attention buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes * sequence_length) assert int(buckets.shape[-1]) == num_hashes * sequence_length, ( f"last dim of buckets is {buckets.shape[-1]}, but should be {num_hashes * sequence_length}" ) sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sequence_length, buckets, num_hashes ) # make sure bucket idx is not longer then sequence length sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length # cluster query key value vectors according to hashed buckets query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes) query_key_vectors = self._split_seq_length_dim_to( query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) value_vectors = self._split_seq_length_dim_to( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0, ( "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" " `config.num_chunks_before` are set to 0." ) elif exists_cache and past_buckets.numel() != 0: # use max sequence length sorted_bucket_idx_per_hash = sorted_bucket_idx else: # get sequence length indices sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) # scale key vectors sqrt_num = np.sqrt(self.attention_head_size) key_vectors = self._len_and_dim_norm(query_key_vectors, sqrt_num) # set query_vectors to query key vectors if LSH self attention query_vectors = query_vectors if query_vectors is not None else query_key_vectors # free memory del query_key_vectors # get attention probs out_vectors, logits, attention_probs = self._attend( query_vectors=query_vectors, key_vectors=key_vectors, value_vectors=value_vectors, sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash, attention_mask=attention_mask, head_mask=head_mask, do_standard_self_attention=do_standard_self_attention, use_cache=exists_cache, ) # free memory del key_vectors, value_vectors # re-order out_vectors and logits if not do_standard_self_attention: # sort clusters back to correct ordering out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) if not do_standard_self_attention or (exists_cache and past_buckets.numel() != 0): # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ) logits = self._split_seq_length_dim_to( logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, ).unsqueeze(-1) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) # free memory del probs_vectors # free memory del logits assert out_vectors.shape == ( batch_size, self.num_attention_heads, sequence_length, self.attention_head_size, ), ( "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length," " config.attention_head_size]`." ) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) if output_attentions is False: attention_probs = () if buckets is not None: buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes, -1) return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) def _query_per_attn_head(self, hidden_states): per_head_query_key = self.query_key.weight.reshape( self.num_attention_heads, self.attention_head_size, self.hidden_size ).transpose(-2, -1) # only relevant for inference and no bias => we can use einsum here query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key) return query_key_vectors def _value_per_attn_head(self, hidden_states): per_head_value = self.value.weight.reshape( self.num_attention_heads, self.attention_head_size, self.hidden_size ).transpose(-2, -1) # only relevant for inference and no bias => we can use einsum here value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value) return value_vectors def _hash_vectors(self, vectors, num_hashes, attention_mask, increase_num_buckets=False): batch_size = vectors.shape[0] # See https://huggingface.co/papers/1509.02897 # We sample a different random rotation for each round of hashing to # decrease the probability of hash misses. if isinstance(self.num_buckets, int): assert self.num_buckets % 2 == 0, ( f"There should be an even number of buckets, but `self.num_buckets`: {self.num_buckets}" ) rotation_size = self.num_buckets num_buckets = self.num_buckets else: # Factorize the hash if self.num_buckets is a list or tuple rotation_size, num_buckets = 0, 1 for bucket_factor in self.num_buckets: assert bucket_factor % 2 == 0, ( f"The number of buckets should be even, but `num_bucket`: {bucket_factor}" ) rotation_size = rotation_size + bucket_factor num_buckets = num_buckets * bucket_factor # remove gradient vectors = vectors.detach() if self.hash_seed is not None: # for determinism torch.manual_seed(self.hash_seed) rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x num_buckets/2 random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) if isinstance(self.num_buckets, int) or len(self.num_buckets) == 1: rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1) buckets = torch.argmax(rotated_vectors, dim=-1) else: # Get the buckets for them and combine. buckets, cur_sum, cur_product = None, 0, 1 for bucket_factor in self.num_buckets: rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)] cur_sum = cur_sum + bucket_factor // 2 rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1) if buckets is None: buckets = torch.argmax(rotated_vectors_factor, dim=-1) else: buckets = buckets + (cur_product * torch.argmax(rotated_vectors_factor, dim=-1)) cur_product = cur_product * bucket_factor if attention_mask is not None and (attention_mask.sum().item() < batch_size * attention_mask.shape[-1]): # add an extra bucket for padding tokens only num_buckets = num_buckets + 1 # assign padding tokens extra bucket buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape) buckets = torch.where( buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device) ) elif increase_num_buckets: num_buckets = num_buckets + 1 # buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. offsets = torch.arange(num_hashes, device=vectors.device) offsets = (offsets * num_buckets).view((1, 1, -1, 1)) # expand to batch size and num attention heads offsets = offsets.expand((batch_size, self.num_attention_heads) + offsets.shape[-2:]) offset_buckets = (buckets + offsets).flatten(start_dim=2, end_dim=3) return offset_buckets def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): # no gradients are needed with torch.no_grad(): # hash-based sort sorted_bucket_idx = _stable_argsort(buckets, dim=-1) # create simple indices to scatter to, to have undo sort indices = ( torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) .view(1, 1, -1) .expand(sorted_bucket_idx.shape) ) # get undo sort undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) return sorted_bucket_idx, undo_sorted_bucket_idx def _set_num_buckets(self, sequence_length): # `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1 # make sure buckets are power of 2 num_buckets = 2**num_buckets_pow_2 # factorize `num_buckets` if `num_buckets` becomes too large num_buckets_limit = 2 * max( int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length, ) if num_buckets > num_buckets_limit: num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)] logger.warning(f"config.num_buckets is not set. Setting config.num_buckets to {num_buckets}...") # set num buckets in config to be properly saved self.config.num_buckets = num_buckets self.num_buckets = num_buckets def _attend( self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx_per_hash, attention_mask, head_mask, do_standard_self_attention, use_cache, ): # look at previous and following chunks if chunked attention if not do_standard_self_attention: key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots # (BS, NumAttn, NumHash x NumChunk, Chunk_L x Hidden),(BS, NumAttn, NumHash x NumChunk, Chunk_L * (Num_bef + Num_aft + 1) x Hidden) -> (BS, NumAttn, NumHash x NumChunk, Chunk_L, Chunk_L * (1 + Num_bef + Num_aft)) query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory del query_vectors, key_vectors # if chunked attention split bucket idxs to query and key if not do_standard_self_attention: query_bucket_idx = self._split_seq_length_dim_to( sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads ) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) elif use_cache and query_key_dots.ndim > 4: key_value_bucket_idx = sorted_bucket_idx_per_hash query_bucket_idx = ( key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max() ) elif use_cache and query_key_dots.ndim <= 4: query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1] key_value_bucket_idx = torch.arange( query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device )[None, None, :].expand(query_bucket_idx.shape[:2] + (-1,)) else: query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash # get correct mask values depending on precision if query_key_dots.dtype == torch.float16: self_mask_value = self.self_mask_value_float16.half() mask_value = self.mask_value_float16.half() else: self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 if not use_cache: mask = self._compute_attn_mask( query_bucket_idx, key_value_bucket_idx, attention_mask, query_key_dots.shape, do_standard_self_attention, ) if mask is not None: query_key_dots = torch.where(mask, query_key_dots, mask_value) # free memory del mask # Self mask is ALWAYS applied. # From the reformer paper (https://huggingface.co/papers/2001.04451): # " While attention to the future is not allowed, typical implementations of the # Transformer do allow a position to attend to itself. # Such behavior is undesirable in a shared-QK formulation because the dot-product # of a query vector with itself will almost always be greater than the dot product of a # query vector with a vector at another position. We therefore modify the masking # to forbid a token from attending to itself, except in situations # where a token has no other valid attention targets (e.g. the first token in a sequence) " self_mask = torch.ne(query_bucket_idx.unsqueeze(-1), key_value_bucket_idx.unsqueeze(-2)).to( query_bucket_idx.device ) # apply self_mask query_key_dots = torch.where(self_mask, query_key_dots, self_mask_value) # free memory del self_mask logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) # dots shape is `[batch_size, num_attn_heads, num_hashes * seq_len // chunk_length, chunk_length, chunk_length * (1 + num_chunks_before + num_chunks_after)]` attention_probs = torch.exp(query_key_dots - logits) # free memory del query_key_dots # dropout attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask # attend values out_vectors = torch.matmul(attention_probs, value_vectors) # free memory del value_vectors # merge chunk length if out_vectors.ndim > 4: logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) return out_vectors, logits, attention_probs def _compute_attn_mask( self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention ): # attention mask for LSH if attention_mask is not None: # if chunked attention, the attention mask has to correspond to LSH order attention_mask = attention_mask.to(torch.bool)[:, None, :] if not do_standard_self_attention: # expand attn_mask to fit with key_value_bucket_idx shape attention_mask = attention_mask[:, None, :] attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) # extract attention mask from LSH sorted key_indices attention_mask = torch.gather(attention_mask, -1, key_indices) attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape) # Causal mask if self.is_decoder is True: causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) # add attention mask if not None if attention_mask is not None: attention_mask = causal_mask * attention_mask else: attention_mask = causal_mask return attention_mask def _get_relevant_hid_states_and_buckets( self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets ): # concat hidden states hidden_states = torch.cat([past_states, hidden_states], dim=1) # batch_size hidden batch_size = hidden_states.shape[0] sequence_length = hidden_states.shape[1] # check if cached buckets include pad bucket max_bucket = self.num_buckets if isinstance(self.num_buckets, int) else reduce(mul, self.num_buckets) # if pad bucket was cached => need to increase num buckets for caching increase_num_buckets = past_buckets.max() > num_hashes * max_bucket - 1 # retrieve query buckets query_buckets = self._hash_vectors( query_vectors, num_hashes, attention_mask, increase_num_buckets=increase_num_buckets ) # concat buckets concat_buckets = torch.cat([past_buckets, query_buckets.unsqueeze(-1)], dim=-1) # hash-based sort bucket_idx = _stable_argsort(concat_buckets, dim=-1) # bucket_idx has shape: BatchSize x NumAttnHeads x NumHashes x SequenceLength assert bucket_idx.shape == ( batch_size, self.num_attention_heads, num_hashes, sequence_length, ), ( f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but" f" has shape {bucket_idx.shape}." ) # find indices of new bucket indices relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero() # expand relevant bucket indices to its chunks relevant_bucket_idx_chunk = self._expand_to_indices_in_relevant_chunk(relevant_bucket_idx, sequence_length) relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))] # adapt bucket_idx for batch and hidden states for index select offset = torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long) bucket_idx_batch_offset = sequence_length * ( batch_size * torch.div(offset, relevant_bucket_idx_chunk.shape[-1], rounding_mode="floor") ) # add batch offset relevant_bucket_idx_chunk_all_batch = relevant_bucket_idx_chunk + bucket_idx_batch_offset hidden_states = hidden_states.reshape((-1, self.hidden_size)) # select all relevant hidden states relevant_hidden_states = hidden_states.index_select(0, relevant_bucket_idx_chunk_all_batch) # reshape hidden states and bucket_idx to correct output relevant_hidden_states = relevant_hidden_states.reshape( batch_size, self.num_attention_heads, -1, self.hidden_size ) relevant_bucket_idx_chunk = relevant_bucket_idx_chunk.reshape( batch_size, self.num_attention_heads, num_hashes, -1 ) assert ( relevant_hidden_states.shape[2] == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes ), ( "There should be" f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`," f" there are {relevant_hidden_states.shape[2]} `hidden_states`." ) assert ( relevant_bucket_idx_chunk.shape[-1] == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length ), ( "There should be" f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are" f" {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`." ) return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length): # get relevant indices of where chunk starts and its size start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after) # expand start indices and add correct chunk offset via arange expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size) chunk_sequence_indices = expanded_start_indices + torch.arange( total_chunk_size, device=indices.device, dtype=torch.long ).unsqueeze(0).expand(indices.shape[0], total_chunk_size) # make sure that circular logic holds via % seq len chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length # expand indices and set indices correctly indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone() indices[:, -1] = chunk_sequence_indices return indices def _len_and_dim_norm(self, vectors, sqrt_num): """ length and attention head size dim normalization """ vectors = self._len_norm(vectors) vectors = vectors / sqrt_num return vectors def _len_norm(self, x, epsilon=1e-6): """ length normalization """ variance = torch.mean(x**2, -1, keepdim=True) norm_x = x * torch.rsqrt(variance + epsilon) return norm_x def _gather_by_expansion(self, vectors, idxs, num_hashes): """ expand dims of idxs and vectors for all hashes and gather """ expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) vectors = vectors.repeat(1, 1, num_hashes, 1) return torch.gather(vectors, 2, expanded_idxs) class ReverseSort(Function): """ After chunked attention is applied which sorted clusters, original ordering has to be restored. Since customized backward function is used for Reformer, the gradients of the output vectors have to be explicitly sorted here. """ @staticmethod def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx): # save sorted_bucket_idx for backprop with torch.no_grad(): ctx.sorted_bucket_idx = sorted_bucket_idx # undo sort to have correct order for next layer expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape) out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices) logits = torch.gather(logits, 2, undo_sorted_bucket_idx) return out_vectors, logits @staticmethod def backward(ctx, grad_out_vectors, grad_logits): # get parameters saved in ctx sorted_bucket_idx = ctx.sorted_bucket_idx expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape) # reverse sort of forward grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices) grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx) # return grad and `None` fillers for last 2 forward args return grad_out_vectors, grad_logits, None, None class LocalSelfAttention(nn.Module, EfficientAttentionMixin): def __init__(self, config, layer_idx=None): super().__init__() self.num_attention_heads = config.num_attention_heads self.chunk_length = config.local_attn_chunk_length self.num_chunks_before = config.local_num_chunks_before self.num_chunks_after = config.local_num_chunks_after self.is_decoder = config.is_decoder self.pad_token_id = config.pad_token_id self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size self.layer_idx = layer_idx # projection matrices self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False) self.dropout = config.local_attention_probs_dropout_prob # save mask value here self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False) self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False) def forward( self, hidden_states, attention_mask=None, head_mask=None, past_buckets_states=None, use_cache=False, output_attentions=False, **kwargs, ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] # check if cache shall be used and that hidden states are already cached if past_buckets_states is not None and len(past_buckets_states) > self.layer_idx: past_buckets = past_buckets_states.buckets_cache[self.layer_idx] past_states = past_buckets_states.states_cache[self.layer_idx] assert past_buckets.numel() == 0, ( "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching" " hidden_states_and_buckets." ) key_value_hidden_states = self._retrieve_relevant_hidden_states( past_states, self.chunk_length, self.num_chunks_before ) key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1) # only query vector for last token query_vectors = self.query(hidden_states) # compute key and value for relevant chunk key_vectors = self.key(key_value_hidden_states) value_vectors = self.value(key_value_hidden_states) # free memory del key_value_hidden_states else: # project hidden_states to query, key and value query_vectors = self.query(hidden_states) key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) # split last dim into `config.num_attention_heads` and `config.attention_head_size` query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) key_vectors = self._split_hidden_size_dim(key_vectors, self.num_attention_heads, self.attention_head_size) value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) assert query_vectors.shape[-1] == self.attention_head_size, ( f"last dim of query_key_vectors is {query_vectors.shape[-1]} but should be {self.attention_head_size}." ) assert key_vectors.shape[-1] == self.attention_head_size, ( f"last dim of query_key_vectors is {key_vectors.shape[-1]} but should be {self.attention_head_size}." ) assert value_vectors.shape[-1] == self.attention_head_size, ( f"last dim of query_key_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}." ) if self.chunk_length is None: assert self.num_chunks_before == 0 and self.num_chunks_after == 0, ( "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" " `config.num_chunks_before` are set to 0." ) # normalize key vectors key_vectors = key_vectors / np.sqrt(self.attention_head_size) # get sequence length indices indices = torch.arange(sequence_length, device=query_vectors.device).repeat( batch_size, self.num_attention_heads, 1 ) # if one should do normal n^2 self-attention do_standard_self_attention = sequence_length <= self.chunk_length # if input should be chunked if not do_standard_self_attention: # chunk vectors # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size query_vectors = self._split_seq_length_dim_to( query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) key_vectors = self._split_seq_length_dim_to( key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) value_vectors = self._split_seq_length_dim_to( value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) # chunk indices query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads) # append chunks before and after key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) else: query_indices = key_indices = indices # query-key matmul: QK^T query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory del query_vectors, key_vectors mask = self._compute_attn_mask( query_indices, key_indices, attention_mask, query_key_dots.shape, do_standard_self_attention ) if mask is not None: # get mask tensor depending on half precision or not if query_key_dots.dtype == torch.float16: mask_value = self.mask_value_float16.half() else: mask_value = self.mask_value_float32 query_key_dots = torch.where(mask, query_key_dots, mask_value) # free memory del mask # softmax logits = torch.logsumexp(query_key_dots, dim=-1, keepdim=True) attention_probs = torch.exp(query_key_dots - logits) # free memory del logits # dropout attention_probs = nn.functional.dropout(attention_probs, p=self.dropout, training=self.training) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask # attend values out_vectors = torch.matmul(attention_probs, value_vectors) # free memory del value_vectors # merge chunk length if not do_standard_self_attention: out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) assert out_vectors.shape == ( batch_size, self.num_attention_heads, sequence_length, self.attention_head_size, ) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) if output_attentions is False: attention_probs = () return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) def _compute_attn_mask( self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention ): # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.bool)[:, None, :] if not do_standard_self_attention: attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) # create attn_mask attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape) # Causal mask if self.is_decoder is True: causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device) # add attention mask if not None if attention_mask is not None: attention_mask = causal_mask * attention_mask else: attention_mask = causal_mask return attention_mask @staticmethod def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before): start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length return previous_hidden_states[:, start_position:] class ReformerSelfOutput(nn.Module): def __init__(self, config): super().__init__() all_head_size = config.num_attention_heads * config.attention_head_size self.dropout = config.hidden_dropout_prob self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states class ReformerAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.layer_id = layer_id self.attn_layers = config.attn_layers self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}: # get correct attn layers if self.attn_layers[self.layer_id] == "lsh": self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) else: self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) else: raise NotImplementedError( f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. " "Select attn layer types from ['lsh', 'local'] only." ) self.output = ReformerSelfOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, num_hashes=None, past_buckets_states=None, use_cache=False, orig_sequence_length=None, output_attentions=False, buckets=None, cache_position=None, ): hidden_states = self.layer_norm(hidden_states) # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, past_buckets_states=past_buckets_states, use_cache=use_cache, output_attentions=output_attentions, buckets=buckets, cache_position=cache_position, ) # add buckets if necessary if hasattr(self_attention_outputs, "buckets"): buckets = self_attention_outputs.buckets else: buckets = None # cache hidden states for future use if use_cache and past_buckets_states is not None: # padded input should not be cached during prefill states = ( hidden_states[:, :orig_sequence_length] if len(past_buckets_states.states_cache) <= self.layer_id else hidden_states ) buckets = ( buckets[:, :, :, :orig_sequence_length] if ( len(past_buckets_states.buckets_cache) <= self.layer_id and buckets is not None and orig_sequence_length > 1 ) else buckets ) buckets, hidden_states = past_buckets_states.update( buckets, states[:, :orig_sequence_length], self.layer_id ) # compute attention feed forward output attention_output = self.output(self_attention_outputs.hidden_states) return AttentionOutput( hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, ) class ReformerFeedForwardDense(nn.Module): def __init__(self, config): super().__init__() self.dropout = config.hidden_dropout_prob if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act self.dense = nn.Linear(config.hidden_size, config.feed_forward_size) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = self.act_fn(hidden_states) return hidden_states class ReformerFeedForwardOutput(nn.Module): def __init__(self, config): super().__init__() self.dropout = config.hidden_dropout_prob self.dense = nn.Linear(config.feed_forward_size, config.hidden_size) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states class ChunkReformerFeedForward(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense = ReformerFeedForwardDense(config) self.output = ReformerFeedForwardOutput(config) def forward(self, attention_output): return apply_chunking_to_forward( self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) def forward_chunk(self, hidden_states): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dense(hidden_states) return self.output(hidden_states) class ReformerLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() self.attention = ReformerAttention(config, layer_id) # dropout requires to have the same # seed for forward and backward pass self.attention_seed = None self.feed_forward_seed = None self.feed_forward = ChunkReformerFeedForward(config) def _init_attention_seed(self): """ This function sets a new seed for the attention layer to make dropout deterministic for both forward calls: 1 normal forward call and 1 forward call in backward to recalculate activations. """ # randomize seeds # use cuda generator if available if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0: # GPU device_idx = torch.cuda.current_device() self.attention_seed = torch.cuda.default_generators[device_idx].seed() else: # CPU self.attention_seed = int(torch.seed() % sys.maxsize) torch.manual_seed(self.attention_seed) def _init_feed_forward_seed(self): """ This function sets a new seed for the feed forward layer to make dropout deterministic for both forward calls: 1 normal forward call and 1 forward call in backward to recalculate activations. """ # randomize seeds # use cuda generator if available if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0: # GPU device_idx = torch.cuda.current_device() self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed() else: # CPU self.feed_forward_seed = int(torch.seed() % sys.maxsize) torch.manual_seed(self.feed_forward_seed) def forward( self, prev_attn_output, hidden_states, attention_mask=None, head_mask=None, num_hashes=None, past_buckets_states=None, use_cache=False, orig_sequence_length=None, output_attentions=False, ): with torch.no_grad(): # every forward pass we sample a different seed # for dropout and save for forward fn in backward pass # to have correct dropout if self.training: self._init_attention_seed() attn_outputs = self.attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, past_buckets_states=past_buckets_states, use_cache=use_cache, orig_sequence_length=orig_sequence_length, output_attentions=output_attentions, ) attn_output = attn_outputs.hidden_states # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # Y_1 = X_1 + f(X_2) attn_output = prev_attn_output + attn_output # free memory del prev_attn_output # every forward pass we sample a different seed # for dropout and save seed for forward fn in backward # to have correct dropout if self.training: self._init_feed_forward_seed() # Y_2 = X_2 + g(Y_1) hidden_states = hidden_states + self.feed_forward(attn_output) return ReformerOutput( attn_output=attn_output, hidden_states=hidden_states, attention_probs=attn_outputs.attention_probs, buckets=attn_outputs.buckets, ) def backward_pass( self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, attention_mask=None, head_mask=None, buckets=None, ): # Implements the backward pass for reversible ResNets. # A good blog post on how this works can be found here: # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py assert self.training, ( "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the" " model into training mode." ) with torch.enable_grad(): next_attn_output.requires_grad = True # set seed to have correct dropout torch.manual_seed(self.feed_forward_seed) # g(Y_1) res_hidden_states = self.feed_forward(next_attn_output) res_hidden_states.backward(grad_hidden_states, retain_graph=True) with torch.no_grad(): # X_2 = Y_2 - g(Y_1) hidden_states = hidden_states - res_hidden_states del res_hidden_states grad_attn_output = grad_attn_output + next_attn_output.grad next_attn_output.grad = None with torch.enable_grad(): hidden_states.requires_grad = True # set seed to have correct dropout torch.manual_seed(self.attention_seed) # f(X_2) # use cached buckets for backprob if buckets not None for LSHSelfAttention output = self.attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, ).hidden_states output.backward(grad_attn_output, retain_graph=True) with torch.no_grad(): # X_1 = Y_1 - f(X_2) attn_output = next_attn_output - output del output, next_attn_output grad_hidden_states = grad_hidden_states + hidden_states.grad hidden_states.grad = None hidden_states = hidden_states.detach() return ReformerBackwardOutput( attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, ) class _ReversibleFunction(Function): """ To prevent PyTorch from performing the usual backpropagation, a customized backward function is implemented here. This way it is made sure that no memory expensive activations are saved during the forward pass. This function is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py """ @staticmethod def forward( ctx, hidden_states, layers, attention_mask, head_mask, num_hashes, all_hidden_states, all_attentions, past_buckets_states, use_cache, orig_sequence_length, output_hidden_states, output_attentions, ): all_buckets = () # split duplicated tensor hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)): if output_hidden_states is True: all_hidden_states.append(hidden_states) layer_outputs = layer( prev_attn_output=attn_output, hidden_states=hidden_states, attention_mask=attention_mask, head_mask=layer_head_mask, num_hashes=num_hashes, past_buckets_states=past_buckets_states, use_cache=use_cache, orig_sequence_length=orig_sequence_length, output_attentions=output_attentions, ) attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states all_buckets = all_buckets + (layer_outputs.buckets,) if output_attentions: all_attentions.append(layer_outputs.attention_probs) # Add last layer if output_hidden_states is True: all_hidden_states.append(hidden_states) # attach params to ctx for backward ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) ctx.layers = layers ctx.all_buckets = all_buckets ctx.head_mask = head_mask ctx.attention_mask = attention_mask # Concatenate 2 RevNet outputs return torch.cat([attn_output, hidden_states], dim=-1) @staticmethod def backward(ctx, grad_hidden_states): grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1) # retrieve params from ctx for backward attn_output, hidden_states = ctx.saved_tensors # create tuple output = ReformerBackwardOutput( attn_output=attn_output, hidden_states=hidden_states, grad_attn_output=grad_attn_output, grad_hidden_states=grad_hidden_states, ) # free memory del grad_attn_output, grad_hidden_states, attn_output, hidden_states layers = ctx.layers all_buckets = ctx.all_buckets head_mask = ctx.head_mask attention_mask = ctx.attention_mask for idx, layer in enumerate(layers[::-1]): # pop last buckets from stack buckets = all_buckets[-1] all_buckets = all_buckets[:-1] # backprop output = layer.backward_pass( next_attn_output=output.attn_output, hidden_states=output.hidden_states, grad_attn_output=output.grad_attn_output, grad_hidden_states=output.grad_hidden_states, head_mask=head_mask[len(layers) - idx - 1], attention_mask=attention_mask, buckets=buckets, ) assert all_buckets == (), "buckets have to be empty after backpropagation" grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1) # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None class ReformerEncoder(nn.Module): def __init__(self, config): super().__init__() self.dropout = config.hidden_dropout_prob self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)]) # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states, attention_mask=None, head_mask=None, num_hashes=None, past_buckets_states=None, use_cache=False, orig_sequence_length=None, output_hidden_states=False, output_attentions=False, ): # hidden_states and attention lists to be filled if wished all_hidden_states = [] all_attentions = [] # init cached hidden states if necessary if use_cache and past_buckets_states is None: past_buckets_states = ReformerDynamicCache() elif use_cache and isinstance(past_buckets_states, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `ReformerDynamicCache` instead, e.g. " "`past_key_values=ReformerDynamicCache.from_legacy_cache(past_key_values)`." ) past_buckets_states = ReformerDynamicCache.from_legacy_cache(past_buckets_states) # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) hidden_states = _ReversibleFunction.apply( hidden_states, self.layers, attention_mask, head_mask, num_hashes, all_hidden_states, all_attentions, past_buckets_states, use_cache, orig_sequence_length, output_hidden_states, output_attentions, ) # Apply layer norm to concatenated hidden states hidden_states = self.layer_norm(hidden_states) # Apply dropout hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) next_cache = past_buckets_states if use_cache else None return ReformerEncoderOutput( hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions, past_buckets_states=next_cache, ) class ReformerOnlyLMHead(nn.Module): def __init__(self, config): super().__init__() # Reformer is using Rev Nets, thus last layer outputs are concatenated and # Layer Norm is done over 2 * hidden_size self.seq_len_dim = 1 self.chunk_size_lm_head = config.chunk_size_lm_head self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) self.decoder.bias = self.bias def forward(self, hidden_states): return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states def _tie_weights(self) -> None: # For accelerate compatibility and to not break backward compatibility if self.decoder.bias.device.type == "meta": self.decoder.bias = self.bias else: # To tie those two weights if they get disconnected (on TPU or when the bias is resized) self.bias = self.decoder.bias @auto_docstring class ReformerPreTrainedModel(PreTrainedModel): config: ReformerConfig base_model_prefix = "reformer" @property def dummy_inputs(self): input_ids = torch.tensor(DUMMY_INPUTS) input_mask = torch.tensor(DUMMY_MASK) dummy_inputs = { "input_ids": input_ids, "attention_mask": input_mask, } return dummy_inputs def _init_weights(self, module): """Initialize the weights""" if isinstance(module, AxialPositionEmbeddings): for weight in module.weights: nn.init.normal_(weight, std=self.config.axial_norm_std) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @dataclass @auto_docstring( custom_intro=""" Output type of [`ReformerModel`]. """ ) class ReformerModelOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`): Sequence of hidden-states at the last layer of the model. `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`. past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed up sequential decoding. """ last_hidden_state: torch.FloatTensor past_buckets_states: Optional[list[tuple[torch.LongTensor, torch.FloatTensor]]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None @dataclass @auto_docstring( custom_intro=""" Output type of [`ReformerModelWithLMHead`]. """ ) class ReformerModelWithLMHeadOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`. past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed up sequential decoding. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_buckets_states: Optional[list[tuple[torch.LongTensor, torch.FloatTensor]]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None @auto_docstring class ReformerModel(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config assert self.config.num_hidden_layers > 0, ( "`config.attn_layers` is empty. Select at least one attn layer form ['lsh', 'local']" ) self.embeddings = ReformerEmbeddings(config) self.encoder = ReformerEncoder(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, num_hashes: Optional[int] = None, past_buckets_states: Optional[list[tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, ReformerModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) num_hashes (`int`, *optional*): The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites the default defined in `config.num_hashes`. For more information, see `num_hashes` in [`ReformerConfig`]. past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*): List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). Contains precomputed hidden-states and buckets (only relevant for LSH Self-Attention). Can be used to speed up sequential decoding. """ use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() device = input_ids.device elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] device = inputs_embeds.device else: raise ValueError("You have to specify either input_ids or inputs_embeds") assert len(input_shape) == 2, ( f"`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {input_shape}" ) if past_buckets_states is not None: assert not self.training, "`past_buckets_states` can only be used for inference, not for training`." # prepare head mask head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True) # original sequence length for padding orig_sequence_length = input_shape[-1] # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) min_chunk_length = _get_min_chunk_len(self.config) must_pad_to_match_chunk_length = ( input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > min_chunk_length and past_buckets_states is None ) if must_pad_to_match_chunk_length: padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length if self.training is True: raise ValueError( f"If training, sequence length {input_shape[-1]} has to be a multiple of least common multiple " f"chunk_length {least_common_mult_chunk_length}. Please consider padding the input to a length " f"of {input_shape[-1] + padding_length}." ) # pad input input_ids, inputs_embeds, attention_mask, position_ids, input_shape = self._pad_to_mult_of_chunk_length( input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, input_shape=input_shape, padding_length=padding_length, padded_seq_length=least_common_mult_chunk_length, device=device, ) # start index for position encoding depends on incremental decoding if past_buckets_states is not None: start_idx_pos_encodings = past_buckets_states[0][1].shape[1] else: start_idx_pos_encodings = 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, start_idx_pos_encodings=start_idx_pos_encodings, ) encoder_outputs = self.encoder( hidden_states=embedding_output, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, past_buckets_states=past_buckets_states, use_cache=use_cache, orig_sequence_length=orig_sequence_length, output_hidden_states=output_hidden_states, output_attentions=output_attentions, ) sequence_output = encoder_outputs.hidden_states # if padding was applied if must_pad_to_match_chunk_length: sequence_output = sequence_output[:, :orig_sequence_length] past_buckets_states = encoder_outputs.past_buckets_states if use_cache else None hidden_states = encoder_outputs.all_hidden_states if output_hidden_states else None attentions = encoder_outputs.all_attentions if output_attentions else None if not return_dict: return tuple(v for v in [sequence_output, past_buckets_states, hidden_states, attentions] if v is not None) return ReformerModelOutput( last_hidden_state=sequence_output, past_buckets_states=past_buckets_states, hidden_states=hidden_states, attentions=attentions, ) def _pad_to_mult_of_chunk_length( self, input_ids, inputs_embeds=None, attention_mask=None, position_ids=None, input_shape=None, padding_length=None, padded_seq_length=None, device=None, ): logger.warning_once( f"Input ids are automatically padded from {input_shape[-1]} to {input_shape[-1] + padding_length} to be a " f"multiple of `config.chunk_length`: {padded_seq_length}" ) padded_input_ids = torch.full( (input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, ) # Extend `attention_mask` if attention_mask is not None: pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype) attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1) else: attention_mask = torch.cat( [ torch.ones(input_shape, device=device, dtype=torch.bool), torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool), ], dim=-1, ) # Extend `input_ids` with padding to match least common multiple chunk_length if input_ids is not None: input_ids = torch.cat([input_ids, padded_input_ids], dim=-1) input_shape = input_ids.size() # Pad position ids if given if position_ids is not None: padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device) padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length) position_ids = torch.cat([position_ids, padded_position_ids], dim=-1) # Extend `inputs_embeds` with padding to match least common multiple chunk_length if inputs_embeds is not None: padded_inputs_embeds = self.get_input_embeddings()(padded_input_ids) inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) input_shape = inputs_embeds.size() return input_ids, inputs_embeds, attention_mask, position_ids, input_shape @auto_docstring( custom_intro=""" Reformer Model with a `language modeling` head on top. """ ) class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): super().__init__(config) assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`." assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, ( "If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not" f" {config.local_num_chunks_after}." ) assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, ( "If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not" f" {config.lsh_num_chunks_after}." ) self.reformer = ReformerModel(config) self.lm_head = ReformerOnlyLMHead(config) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings self.lm_head.bias = new_embeddings.bias @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, num_hashes: Optional[int] = None, past_buckets_states: Optional[list[tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, CausalLMOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) num_hashes (`int`, *optional*): The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites the default defined in `config.num_hashes`. For more information, see `num_hashes` in [`ReformerConfig`]. past_buckets_states (`list[tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*): List of `tuple(torch.LongTensor, torch.FloatTensor` of length `config.n_layers`, with the first element being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)`) and the second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`). Contains precomputed hidden-states and buckets (only relevant for LSH Self-Attention). Can be used to speed up sequential decoding. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict reformer_outputs = self.reformer( input_ids, position_ids=position_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, past_buckets_states=past_buckets_states, use_cache=use_cache, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) sequence_output = reformer_outputs[0] logits = self.lm_head(sequence_output) loss = None if labels is not None: loss = self.loss_function( logits, labels, vocab_size=self.config.vocab_size, **kwargs, ) if not return_dict: output = (logits,) + reformer_outputs[1:] return ((loss,) + output) if loss is not None else output return ReformerModelWithLMHeadOutput( loss=loss, logits=logits, past_buckets_states=reformer_outputs.past_buckets_states, hidden_states=reformer_outputs.hidden_states, attentions=reformer_outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs ): # Overitten -- different expected inputs/outputs # only last token for inputs_ids if past is defined in kwargs if past_key_values is not None: input_ids = input_ids[:, -1:] model_inputs = { "input_ids": input_ids, "past_buckets_states": past_key_values, "use_cache": use_cache, "num_hashes": num_hashes, } # Attention mask is computed on ReformerModel.forward() kwargs.pop("attention_mask", None) # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: print(f"Warning: {key} is not a recognized input.") model_inputs[key] = value return model_inputs def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] for buckets, hidden_states in past_key_values: # buckets if buckets is not None and buckets.numel() > 0: reord_buckets = buckets.index_select(0, beam_idx.to(buckets.device)) else: reord_buckets = None # hidden states reord_hidden_states = hidden_states.index_select(0, beam_idx.to(hidden_states.device)) reord_past_buckets_states.append((reord_buckets, reord_hidden_states)) if isinstance(past_key_values, ReformerDynamicCache): reord_past_buckets_states = ReformerDynamicCache.from_legacy_cache(reord_past_buckets_states) return reord_past_buckets_states @auto_docstring class ReformerForMaskedLM(ReformerPreTrainedModel): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): super().__init__(config) assert not config.is_decoder, ( "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional" " self-attention." ) self.reformer = ReformerModel(config) self.lm_head = ReformerOnlyLMHead(config) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings self.lm_head.bias = new_embeddings.bias @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, num_hashes: Optional[int] = None, labels: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, MaskedLMOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) num_hashes (`int`, *optional*): The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites the default defined in `config.num_hashes`. For more information, see `num_hashes` in [`ReformerConfig`]. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels This example uses a false checkpoint since we don't have any available pretrained model for the masked language modeling task with the Reformer architecture. Example: ```python >>> import torch >>> from transformers import AutoTokenizer, ReformerForMaskedLM >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer") >>> model = ReformerForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-reformer") >>> # add mask_token >>> tokenizer.add_special_tokens({"mask_token": "[MASK]"}) # doctest: +IGNORE_RESULT >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt") >>> # resize model's embedding matrix >>> model.resize_token_embeddings(new_num_tokens=model.config.vocab_size + 1) # doctest: +IGNORE_RESULT >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> # retrieve index of [MASK] >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) >>> predicted_token = tokenizer.decode(predicted_token_id) ``` ```python >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] >>> # mask labels of non-[MASK] tokens >>> labels = torch.where( ... inputs.input_ids == tokenizer.mask_token_id, labels[:, : inputs["input_ids"].shape[-1]], -100 ... ) >>> outputs = model(**inputs, labels=labels) >>> loss = round(outputs.loss.item(), 2) ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict reformer_outputs = self.reformer( input_ids, position_ids=position_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, use_cache=False, # no causal mask output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) sequence_output = reformer_outputs[0] logits = self.lm_head(sequence_output) masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (logits,) + reformer_outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return MaskedLMOutput( loss=masked_lm_loss, logits=logits, hidden_states=reformer_outputs.hidden_states, attentions=reformer_outputs.attentions, ) @auto_docstring( custom_intro=""" Reformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """ ) class ReformerForSequenceClassification(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.reformer = ReformerModel(config) self.classifier = ReformerClassificationHead(config) if config.is_decoder is True: logger.warning("You might want to disable causal masking for sequence classification") # Initialize weights and apply final processing self.post_init() @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, num_hashes: Optional[int] = None, labels: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, SequenceClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) num_hashes (`int`, *optional*): The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites the default defined in `config.num_hashes`. For more information, see `num_hashes` in [`ReformerConfig`]. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Example of single-label classification: ```python >>> import torch >>> from transformers import AutoTokenizer, ReformerForSequenceClassification >>> tokenizer = AutoTokenizer.from_pretrained("google/reformer-crime-and-punishment") >>> model = ReformerForSequenceClassification.from_pretrained("google/reformer-crime-and-punishment") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_class_id = logits.argmax().item() >>> label = model.config.id2label[predicted_class_id] ``` ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) >>> model = ReformerForSequenceClassification.from_pretrained( ... "google/reformer-crime-and-punishment", num_labels=num_labels ... ) >>> labels = torch.tensor(1) >>> loss = model(**inputs, labels=labels).loss ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.reformer( input_ids, position_ids=position_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) sequence_output = outputs[0] logits = self.classifier(sequence_output) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class ReformerClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, config): super().__init__() self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, hidden_states, **kwargs): hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) hidden_states = self.dropout(hidden_states) hidden_states = self.dense(hidden_states) hidden_states = torch.tanh(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.out_proj(hidden_states) return hidden_states @auto_docstring class ReformerForQuestionAnswering(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.reformer = ReformerModel(config) # 2 * config.hidden_size because we use reversible residual layers self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, num_hashes: Optional[int] = None, start_positions: Optional[torch.Tensor] = None, end_positions: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. During training the input_ids sequence_length has to be a multiple of the relevant model's chunk lengths (lsh's, local's or both). During evaluation, the indices are automatically padded to be a multiple of the chunk length. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) num_hashes (`int`, *optional*): The number of hashing rounds that should be performed during bucketing. Setting this argument overwrites the default defined in `config.num_hashes`. For more information, see `num_hashes` in [`ReformerConfig`]. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict reformer_outputs = self.reformer( input_ids, position_ids=position_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, use_cache=False, # no causal mask output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) sequence_output = reformer_outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + reformer_outputs[1:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=reformer_outputs.hidden_states, attentions=reformer_outputs.attentions, ) __all__ = [ "ReformerAttention", "ReformerForMaskedLM", "ReformerForQuestionAnswering", "ReformerForSequenceClassification", "ReformerLayer", "ReformerModel", "ReformerModelWithLMHead", "ReformerPreTrainedModel", ]