# mypy: ignore-errors

import math
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    ReductionOpInfo,
    SampleInput,
    UnaryUfuncInfo,
)
from torch.utils._pytree import tree_flatten, tree_map


@dataclass
class ExtraOpData:
    """
    Contains info on top of the typical OpInfo data that is useful for NJT test generation.

    The process that converts the standard op_db -> an NJT-compatible op_db will attach
    this data onto each associated OpInfo entry.
    """

    # Indicates whether the associated op is a view op
    is_view: bool = False

    # Specifies the names of any dim-related args that the op takes in. This is useful
    # for NJT tests because there is often asymmetry across the supported set of dims for
    # an op; it may make sense to operate over the batch dim but not the ragged dim, for
    # example. The length of this list should match the number of relevant overloads.
    # Each item of the outer list specifies the dim argnames for one overload. Ellipses
    # should be used to indicate multi-dim support for a given overload.
    #
    # For example, squeeze() has both a dim and multi-dim overload, where the argname for
    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
    #
    # If no overload of the op accepts dim-related args, this should be None.
    dim_args: Optional[list[list[str]]] = None

    # Helper function to extract names of dim-related args.
    # Returns: tuple of (single dim argname if available, dim list argname if available)
    # If the op doesn't support dim-related args at all OR this op only has overloads
    # with multiple dim args (e.g. transpose()), then this returns (None, None).
    def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]:
        if self.dim_args is None:
            return (None, None)

        # name for the dim arg that supports a single dim
        single_dim_argname = None
        # name for the dim arg that supports a list of dims
        dimlist_argname = None
        for overload in self.dim_args:
            # only consider overloads with a single dim-related arg
            if len(overload) != 1:
                continue
            if overload[0].endswith("..."):
                dimlist_argname = overload[0].replace("...", "")
                if single_dim_argname is None:
                    single_dim_argname = dimlist_argname
            else:
                single_dim_argname = overload[0]
        return (single_dim_argname, dimlist_argname)


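# NB (illustrative only): given the entries in extra_op_data below, get_dim_argnames()
# above behaves like this:
#
#   extra_op_data["squeeze"].get_dim_argnames()    -> ("dim", "dim")
#   extra_op_data["transpose"].get_dim_argnames()  -> (None, None)
#
# squeeze() exposes "dim" as both its single-dim and dim-list argname, while transpose()
# exposes neither because its only overload takes two dim args.

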
# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
# in test generation.
extra_op_data = {
    "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
    "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
    "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "argmax": ExtraOpData(dim_args=[["dim"]]),
    "argmin": ExtraOpData(dim_args=[["dim"]]),
    "amax": ExtraOpData(dim_args=[["dim..."]]),
    "amin": ExtraOpData(dim_args=[["dim..."]]),
    "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "argsort": ExtraOpData(dim_args=[["dim"]]),
    "broadcast_to": ExtraOpData(is_view=True),
    "cat": ExtraOpData(dim_args=[["dim"]]),
    "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "conj": ExtraOpData(is_view=True),
    "contiguous": ExtraOpData(is_view=True),
    "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "cummax": ExtraOpData(dim_args=[["dim"]]),
    "cummin": ExtraOpData(dim_args=[["dim"]]),
    "cumprod": ExtraOpData(dim_args=[["dim"]]),
    "cumsum": ExtraOpData(dim_args=[["dim"]]),
    "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
    "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
    "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diff": ExtraOpData(dim_args=[["dim"]]),
    "expand": ExtraOpData(is_view=True),
    "expand_as": ExtraOpData(is_view=True),
    "fft.fft": ExtraOpData(dim_args=[["dim"]]),
    "fft.hfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.ifft": ExtraOpData(dim_args=[["dim"]]),
    "fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.irfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.rfft": ExtraOpData(dim_args=[["dim"]]),
    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
    "flip": ExtraOpData(dim_args=[["dims..."]]),
    "gather": ExtraOpData(dim_args=[["dim"]]),
    "hash_tensor": ExtraOpData(dim_args=[["dim..."]]),
    "imag": ExtraOpData(is_view=True),
    "index_add": ExtraOpData(dim_args=[["dim"]]),
    "index_copy": ExtraOpData(dim_args=[["dim"]]),
    "index_fill": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
    "index_select": ExtraOpData(dim_args=[["dim"]]),
    "kthvalue": ExtraOpData(dim_args=[["dim"]]),
    "linalg.cross": ExtraOpData(dim_args=[["dim"]]),
    "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
    "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
    "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
    "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
    "log_softmax": ExtraOpData(dim_args=[["dim"]]),
    "logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
    "masked.amax": ExtraOpData(dim_args=[["dim"]]),
    "masked.amin": ExtraOpData(dim_args=[["dim"]]),
    "masked.argmax": ExtraOpData(dim_args=[["dim"]]),
    "masked.argmin": ExtraOpData(dim_args=[["dim"]]),
    "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
    "masked.mean": ExtraOpData(dim_args=[["dim"]]),
    "masked.norm": ExtraOpData(dim_args=[["dim"]]),
    "masked.prod": ExtraOpData(dim_args=[["dim"]]),
    "masked.std": ExtraOpData(dim_args=[["dim"]]),
    "masked.sum": ExtraOpData(dim_args=[["dim"]]),
    "masked.var": ExtraOpData(dim_args=[["dim"]]),
    "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
    "median": ExtraOpData(dim_args=[["dim"]]),
    "mean": ExtraOpData(dim_args=[["dim..."]]),
    "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
    "mode": ExtraOpData(dim_args=[["dim"]]),
    "movedim": ExtraOpData(
        dim_args=[["source", "destination"], ["source...", "destination..."]]
    ),
"destination..."]] ), "nanmean": ExtraOpData(dim_args=[["dim..."]]), "nanmedian": ExtraOpData(dim_args=[["dim"]]), "nansum": ExtraOpData(dim_args=[["dim..."]]), "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]), "narrow_copy": ExtraOpData(dim_args=[["dim"]]), "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]), "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]), "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]), "positive": ExtraOpData(is_view=True), "prod": ExtraOpData(dim_args=[["dim"]]), "ravel": ExtraOpData(is_view=True), "real": ExtraOpData(is_view=True), "renorm": ExtraOpData(dim_args=[["dim"]]), "reshape": ExtraOpData(is_view=True), "reshape_as": ExtraOpData(is_view=True), "roll": ExtraOpData(dim_args=[["dims..."]]), "rot90": ExtraOpData(dim_args=[["dims..."]]), "scatter": ExtraOpData(dim_args=[["dim"]]), "scatter_add": ExtraOpData(dim_args=[["dim"]]), "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]), "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]), "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]), "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]), "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]), "select": ExtraOpData(is_view=True, dim_args=[["dim"]]), "select_scatter": ExtraOpData(dim_args=[["dim"]]), "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]), "slice_scatter": ExtraOpData(dim_args=[["dim"]]), "softmax": ExtraOpData(dim_args=[["dim"]]), "sort": ExtraOpData(dim_args=[["dim"]]), "split": ExtraOpData(is_view=True, dim_args=[["dim"]]), "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]), "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]), "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]), "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]), "stack": ExtraOpData(dim_args=[["dim"]]), "std": ExtraOpData(dim_args=[["dim..."]]), "std.unbiased": ExtraOpData(dim_args=[["dim..."]]), "sum": ExtraOpData(dim_args=[["dim..."]]), "t": ExtraOpData(is_view=True), "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]), "tensordot": ExtraOpData(dim_args=[["dims..."]]), "tile": ExtraOpData(dim_args=[["dims..."]]), "topk": ExtraOpData(dim_args=[["dim"]]), "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]), "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]), "trapezoid": ExtraOpData(dim_args=[["dim"]]), "trapz": ExtraOpData(dim_args=[["dim"]]), "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]), "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]), "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]), "unfold_copy": ExtraOpData(dim_args=[["dimension"]]), "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]), "unsafe_split": ExtraOpData(dim_args=[["dim"]]), "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]), "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]), "var": ExtraOpData(dim_args=[["dim..."]]), "var.unbiased": ExtraOpData(dim_args=[["dim..."]]), "view": ExtraOpData(is_view=True), "view_as": ExtraOpData(is_view=True), "view_as_complex": ExtraOpData(is_view=True), "view_as_real": ExtraOpData(is_view=True), } # random integer used for sizes def _rnd(): return torch.randint(3, 8, ()).item() def _raggedness_matches(nt1, nt2): return ( nt1.is_nested and nt2.is_nested and nt1._ragged_idx == nt2._ragged_idx and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx] ) # Helper function to avoid reusing the exact same tensor / NJT across SampleInputs, # as this causes autograd problems. 
# Helper function to avoid reusing the exact same tensor / NJT across SampleInputs,
# as this causes autograd problems.
def _clone(t):
    requires_grad = t.requires_grad
    return t.detach().clone().requires_grad_(requires_grad)


# Helper function to update a sample with new kwargs / name
def _update_sample(sample, new_kwargs):
    all_kwargs = dict(sample.kwargs)
    all_kwargs.update(new_kwargs)
    full_name = ", ".join(
        [sample.name, *(f"{k}={v}" for (k, v) in new_kwargs.items())]
    )
    return SampleInput(
        _clone(sample.input),
        args=sample.args,
        kwargs=all_kwargs,
        name=full_name,
    )


# Generates a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [
        [d if d is not None else _rnd() for d in dims[1:]] for _ in range(dims[0])
    ]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )


# Helper function to get a reasonable string representation of an NJT for use in
# SampleInput names.
def _describe_njt(njt) -> str:
    contig_type = "_contig" if njt.is_contiguous() else "_noncontig"
    if njt._lengths is not None and njt._offsets is not None:
        contig_type += "_holes"
    elif njt._ragged_idx != 1:
        contig_type += "_transposed"

    cached_data = "_without_seqlen_cache"
    if njt._max_seqlen_tensor is not None:
        cached_data = "_with_seqlen_cache"

    return f"{njt.dim()}D{contig_type}{cached_data}"


# Helper function to get a reasonable string representation of a given dim wrt an NJT.
def _describe_dim(njt, dim):
    if dim == 0:
        return "batch_dim"
    elif dim == njt._ragged_idx:
        return "ragged_dim"
    return "normal_dim"


# Helper function for generating a comprehensive set of NJT sample inputs.
def _sample_njts(device, dtype, requires_grad=False, dims=None):
    if dims is None:
        dims = [2, 3, 4]
    if not isinstance(dims, (list, tuple)):
        dims = [dims]

    # contiguous NJTs
    for dim in dims:
        # with min / max seqlen cached
        shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)])
        nt = random_nt_from_dims(
            shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield nt

        # without min / max seqlen cached
        values = _clone(nt.values())
        offsets = _clone(nt.offsets())
        yield torch.nested.nested_tensor_from_jagged(values, offsets).requires_grad_(
            requires_grad
        )

        # non-contiguous transposed NJT (not possible for 2D)
        if dim > 2:
            yield nt.transpose(-1, nt._ragged_idx)

        # non-contiguous with holes NJT
        values = _clone(nt.values())
        offsets = _clone(nt.offsets())
        # subtract 1 to cause holes
        lengths = _clone(offsets.diff() - 1)
        yield torch.nested.nested_tensor_from_jagged(
            values=values,
            offsets=offsets,
            lengths=lengths,
        ).requires_grad_(requires_grad)


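# NB (illustrative only): for each requested dim, _sample_njts() above yields a cached
# contiguous NJT, an uncached contiguous NJT, a transposed non-contiguous NJT (3D+),
# and a non-contiguous NJT with holes. _describe_njt() turns these into sample names
# such as "3D_contig_with_seqlen_cache" or "4D_noncontig_holes_without_seqlen_cache",
# which is what shows up in generated test names.

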
# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
# This reference unbinds the input NJT and invokes the op on each of the components,
# optionally wrapping the result in an NJT.
def unbind_reference(op, sample, wrap_output_as_njt=True):
    # first NJT in the arglist determines expected ragged structure
    nt_inp = (
        sample.input
        if sample.input.is_nested
        # TODO: look in kwargs too?
        else next(a for a in sample.args if a.is_nested)
    )

    out_ref_components = []
    for i in range(nt_inp.shape[0]):

        def _slice_input(t, i=i, inp=nt_inp):
            # any NJT with the same ragged structure as the input should
            # be sliced to pass to the reference
            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
                return t[i]
            # allow the SampleInput to tell us how to slice it for ref calculation
            elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"):
                bdim = t._batch_dim  # type: ignore[attr]
                if t.shape[bdim] == 1:
                    return t[0]
                else:
                    return t.select(bdim, i)
            else:
                return t

        inp = _slice_input(sample.input)
        args = tree_map(_slice_input, sample.args)
        kwargs = tree_map(_slice_input, sample.kwargs)

        # Handle indices in index_put
        if "index_put" in op.full_name and "indices" in kwargs:
            if len(kwargs["indices"]) > 1:
                # If after unrolling we still have indices left, use them
                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
            else:
                # If no indices are left, create them so they match the NJT implementation
                sequence_put = kwargs["indices"][0].tolist()
                if i in sequence_put:
                    kwargs["indices"] = [
                        torch.tensor(
                            list(range(inp.shape[0])),
                            dtype=torch.int32,
                            device=kwargs["indices"][0].device,
                        )
                    ]
                else:
                    kwargs["indices"] = [
                        torch.tensor(
                            [], dtype=torch.int32, device=kwargs["indices"][0].device
                        )
                    ]

        from torch.nested._internal.ops import _outer_to_inner_dim

        # Need to adjust dims to apply on NJT component
        if op._extra_op_data.dim_args is not None:
            # get all possible dim-related argnames that could be encountered for this op
            argnames = tree_map(
                lambda a: a.replace("...", ""),
                tree_flatten(op._extra_op_data.dim_args)[0],
            )
            # for all dim-related args present, convert from outer -> inner dim space
            for argname in {a for a in argnames if a in kwargs}:
                # allow the SampleInput to tell us how to canonicalize the dim kwargs
                ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()
                kwargs[argname] = _outer_to_inner_dim(
                    ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
                )

        out_ref_component = op.op(inp, *args, **kwargs)
        out_ref_components.append(out_ref_component)

    if wrap_output_as_njt:
        # handle list / tuple of outputs
        if len(out_ref_components) > 0 and isinstance(
            out_ref_components[0], (list, tuple)
        ):
            num_returns = len(out_ref_components[0])
            # ensure we get the same number of returns for each invocation
            assert all(len(o) == num_returns for o in out_ref_components)
            # construct NJTs from same index returns from each invocation
            njt_returns = [
                torch.nested.as_nested_tensor(
                    [o[r] for o in out_ref_components], layout=torch.jagged
                )
                for r in range(num_returns)
            ]
            return type(out_ref_components[0])(njt_returns)

        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)

    return out_ref_components


# Computes the reference value for a non-reduction unary op with dim-wise application.
def unary_dimwise_reference(op, sample, batchwise_reference=None):
    # extract info about the dim args this op supports
    assert op._extra_op_data.dim_args is not None
    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
    # only support a single non-list dim arg for now
    assert dimlist_argname is None
    assert single_dim_argname is not None

    if sample.kwargs[single_dim_argname] == 0:
        # unbind reference won't work for batch-wise operation; handle this case here
        assert batchwise_reference is not None
        return batchwise_reference(op, sample)

    return unbind_reference(op, sample)


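# NB (illustrative only): ignoring the dim remapping and index_put special-casing,
# unbind_reference() above computes something equivalent to this minimal sketch:
#
#   components = [op.op(c, *args, **kwargs) for c in njt.unbind()]
#   return torch.nested.as_nested_tensor(components, layout=torch.jagged)
#
# i.e. the op runs per batch component on dense tensors and the dense outputs are
# re-wrapped into an NJT for comparison against the actual NJT output.

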
# Computes the reference value for a reduction op.
def reduction_reference(op, sample):
    assert sample.input.is_nested
    # extract info about the dim args this op supports
    assert op._extra_op_data.dim_args is not None
    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None

    dim = sample.kwargs.get(
        dimlist_argname, sample.kwargs.get(single_dim_argname, None)
    )
    keepdim = sample.kwargs.get("keepdim", False)
    assert dim != 0, "reductions over just the batch dim are not supported"

    if isinstance(dim, (tuple, list)):
        reduce_on_ragged = sample.input._ragged_idx in dim
        reduce_on_batch = 0 in dim
    else:
        reduce_on_ragged = sample.input._ragged_idx == dim
        reduce_on_batch = dim == 0

    if dim is None:
        # calculate reference value by running reduction on values buffer
        return op.op(sample.input.values(), *sample.args, **sample.kwargs)

    if reduce_on_ragged and reduce_on_batch:
        # run reference directly on buffer with dims converted to inner space
        from torch.nested._internal.ops import _outer_to_inner_dim

        ref_kwargs = dict(sample.kwargs)
        assert dimlist_argname is not None
        ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
            sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True
        )
        out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
        if keepdim:
            if isinstance(out, (tuple, list)):
                # some ops return multiple things; unsqueeze all of them
                out = type(out)(o.unsqueeze(0) for o in out)
            else:
                out = out.unsqueeze(0)
        return out

    if reduce_on_ragged and not reduce_on_batch:
        # calculate reference value by running an unbind reference and stacking
        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
        if len(out_ref_components) > 0 and isinstance(
            out_ref_components[0], (tuple, list)
        ):
            # some ops return multiple things; stack all of them
            num_returns = len(out_ref_components[0])
            # ensure we get the same number of returns for each invocation
            assert all(len(o) == num_returns for o in out_ref_components)
            # stack same index returns from each invocation
            stacked_returns = [
                torch.stack([o[r] for o in out_ref_components], dim=0)
                for r in range(num_returns)
            ]
            return type(out_ref_components[0])(stacked_returns)

        return torch.stack(out_ref_components, dim=0)

    # unbind reference works for other reductions
    return unbind_reference(op, sample)


def sample_inputs_elementwise_njt_unary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt, kwargs=dict(op_kwargs), name=_describe_njt(njt))


def sample_inputs_elementwise_njt_binary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if not op_kwargs:
        op_kwargs = {}

    for njt1 in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        njt_desc = _describe_njt(njt1)
        njt2 = torch.randn_like(njt1)
        yield SampleInput(
            _clone(njt1),
            args=(njt2,),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (NT, NT)",
        )

        # broadcasting case: (B, j0, ...) with (B, 1, ...)
        dense_shape = list(njt1.shape)
        dense_shape[njt1._ragged_idx] = 1
        t = torch.randn(
            dense_shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )
        t2 = _clone(t)
        # used for slicing in unbind_reference()
        t._batch_dim = 0
        t2._batch_dim = 0
        # (NT, T)
        yield SampleInput(
            _clone(njt1),
            args=(t,),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (NT, T) broadcasting 1 over ragged",
        )
        # (T, NT)
        yield SampleInput(
            t2,
            args=(_clone(njt1),),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (T, NT) broadcasting 1 over ragged",
        )

        # broadcasting case: (B, j0, ...) with (1, 1...)
        t = torch.randn(
            [1 for _ in range(njt1.dim())],
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )
        t2 = _clone(t)
        # used for slicing in unbind_reference()
        t._batch_dim = 0
        t2._batch_dim = 0
        # (NT, T)
        yield SampleInput(
            _clone(njt1),
            args=(t,),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (NT, T) broadcasting all 1s",
        )
        # (T, NT)
        yield SampleInput(
            t2,
            args=(_clone(njt1),),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (T, NT) broadcasting all 1s",
        )

        # broadcasting case: (B, j0, ...) with (...)
        if njt1.dim() > njt1._ragged_idx + 1:
            t = torch.randn(
                njt1.shape[njt1._ragged_idx + 1 :],
                device=device,
                dtype=dtype,
                requires_grad=requires_grad,
            )
            # (NT, T)
            yield SampleInput(
                _clone(njt1),
                args=(_clone(t),),
                kwargs=dict(op_kwargs),
                name=f"{njt_desc}: (NT, T) broadcasting normal dims",
            )
            # (T, NT)
            yield SampleInput(
                _clone(t),
                args=(_clone(njt1),),
                kwargs=dict(op_kwargs),
                name=f"{njt_desc}: (T, NT) broadcasting normal dims",
            )

        # broadcasting case: (B, j0, ...) with scalar
        t = torch.randn((), device=device, dtype=dtype, requires_grad=requires_grad)
        # (NT, T)
        yield SampleInput(
            _clone(njt1),
            args=(_clone(t),),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (NT, T) broadcasting with scalar",
        )
        # (T, NT)
        yield SampleInput(
            _clone(t),
            args=(_clone(njt1),),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (T, NT) broadcasting with scalar",
        )

    # mixed broadcasting case: (B, j0, 1) with (B, 1, D)
    B = 4
    D = 16
    njt = random_nt_from_dims(
        (B, None, 1),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        layout=torch.jagged,
    )
    njt_desc = _describe_njt(njt)
    t = torch.randn(B, 1, D, device=device, dtype=dtype, requires_grad=requires_grad)
    t2 = _clone(t)
    # used for slicing in unbind_reference()
    t._batch_dim = 0
    t2._batch_dim = 0
    # (NT, T)
    yield SampleInput(
        _clone(njt),
        args=(t,),
        kwargs=dict(op_kwargs),
        name=f"{njt_desc}: (NT, T) mixed broadcasting",
    )
    # (T, NT)
    yield SampleInput(
        t2,
        args=(_clone(njt),),
        kwargs=dict(op_kwargs),
        name=f"{njt_desc}: (T, NT) mixed broadcasting",
    )


def sample_inputs_njt_reduction(
    op_info,
    device,
    dtype,
    requires_grad,
    supports_keepdim=True,
    op_kwargs=None,
    **kwargs,
):
    if not op_kwargs:
        op_kwargs = {}

    # extract info about the dim args this op supports
    assert op_info._extra_op_data.dim_args is not None
    (
        single_dim_argname,
        dimlist_argname,
    ) = op_info._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None
    supports_dimlist = dimlist_argname is not None

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        njt_desc = _describe_njt(njt)
        keepdim_values = [False, True] if supports_keepdim else [None]
        for keepdim in keepdim_values:
            keepdim_suffix = f" with keepdim={keepdim}" if supports_keepdim else ""

            # single dim-wise reduction; includes reduction over the ragged dim
            # NB: reduction over the batch dim is not supported!
            # TODO: Cover this in the set of error inputs
            for dim in range(1, njt.dim()):
                dim_desc = "normal" if dim != njt._ragged_idx else "ragged"
                yield SampleInput(
                    _clone(njt),
                    kwargs={
                        **op_kwargs,
                        single_dim_argname: dim,
                        **({"keepdim": keepdim} if supports_keepdim else {}),
                    },
                    name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}",
                )

            if supports_dimlist:
                # reduce on both batch and ragged dims
                yield SampleInput(
                    _clone(njt),
                    kwargs={
                        **op_kwargs,
                        dimlist_argname: [0, njt._ragged_idx],
                        **({"keepdim": keepdim} if supports_keepdim else {}),
                    },
                    name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}",
                )

                # reduce on batch, ragged, and other dims
                for other_dim in range(njt._ragged_idx + 1, njt.dim()):
                    yield SampleInput(
                        _clone(njt),
                        kwargs={
                            **op_kwargs,
                            dimlist_argname: [0, njt._ragged_idx, other_dim],
                            **({"keepdim": keepdim} if supports_keepdim else {}),
                        },
                        name=(
                            f"{njt_desc}: batch+ragged+dim={other_dim} "
                            f"reduction{keepdim_suffix}"
                        ),
                    )

                # reduce on two non-ragged, non-batch dims
                if njt.dim() > 3 and njt._ragged_idx == 1:
                    yield SampleInput(
                        _clone(njt),
                        kwargs={
                            **op_kwargs,
                            dimlist_argname: [njt.dim() - 2, njt.dim() - 1],
                            **({"keepdim": keepdim} if supports_keepdim else {}),
                        },
                        name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}",
                    )

                # full reduction by specifying all dims
                yield SampleInput(
                    _clone(njt),
                    kwargs={
                        **op_kwargs,
                        dimlist_argname: list(range(njt.dim())),
                        **({"keepdim": keepdim} if supports_keepdim else {}),
                    },
                    name=f"{njt_desc}: all dim reduction{keepdim_suffix}",
                )

                # TODO: Reducing on ragged dim and non-batch dim is not supported;
                # cover this in the set of error inputs.

            # full reduction
            yield SampleInput(
                _clone(njt),
                kwargs=dict(op_kwargs),
                name=f"{njt_desc}: full reduction with keepdim={keepdim}",
            )


def unsupported_sample_inputs_func(op_name):
    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
        raise RuntimeError(
            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
            "torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


def unsupported_reference(op_name):
    def _f(op, sample):
        raise RuntimeError(
            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
        )

    return _f


# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===


def sample_inputs_unary_dimwise(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    if op_kwargs is None:
        op_kwargs = {}

    # only support a single non-list dim arg for now
    assert op_info._extra_op_data is not None
    single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None
    assert dimlist_argname is None

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        for dim in range(njt.dim()):
            kwargs = {single_dim_argname: dim}
            kwargs.update(op_kwargs)
            yield SampleInput(
                _clone(njt),
                kwargs=kwargs,
                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
            )


def batchwise_reference_chunk(op, sample):
    # reference for chunk() over dim=0
    B = sample.input.size(0)
    num_chunks = sample.kwargs["chunks"]
    chunk_size = math.ceil(B / num_chunks)
    num_full_chunks = B // chunk_size
    chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
    if B % chunk_size != 0:
        # final chunk contains the leftovers
        chunk_sizes.append(B % chunk_size)

    # split unbound components into chunks according to calculated sizes
    components = list(sample.input.unbind())
    start = 0
    chunks = []
    for chunk_size in chunk_sizes:
        chunks.append(components[start : start + chunk_size])
        start += chunk_size

    # rejoin into NJT outputs
    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]


def batchwise_reference_narrow(op, sample):
    # TODO: write this!
    raise NotImplementedError


def batchwise_reference_select(op, sample):
    # reference for select() over dim=0
    return sample.input.unbind()[sample.kwargs["index"]]


def batchwise_reference_split(op, sample):
    # TODO: write this!
    raise NotImplementedError


def batchwise_reference_split_with_sizes(op, sample):
    # TODO: write this!
    raise NotImplementedError


def batchwise_reference_unflatten(op, sample):
    # TODO: write this!
    raise NotImplementedError


def batchwise_reference_unsqueeze(op, sample):
    raise ValueError("unsqueeze() is not intended to operate on the batch dim")


def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
    # NJTs of assorted forms (contiguous and non-contiguous)
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt, name=_describe_njt(njt))

    for memory_format in (torch.contiguous_format, torch.preserve_format):
        # construct a "non-contiguous with holes" NJT
        values = torch.randn(
            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
        )
        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
        njt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )
        njt_desc = _describe_njt(njt)
        yield SampleInput(
            njt,
            kwargs={"memory_format": memory_format},
            name=f"{njt_desc}: {memory_format}",
        )


def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs):
    # scalar case
    unary_func = partial(sample_inputs_elementwise_njt_unary, op_kwargs={"value": 42.0})
    yield from unary_func(op_info, device, dtype, requires_grad)

    # TODO: add Tensor case


def sample_inputs_mvl_gamma(p):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})


def sample_inputs_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_special_polygamma_n(n):
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})


def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
    for njt in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[2, 3, 4],
    ):
        other_dtypes = (
            d for d in (torch.float32, torch.half, torch.double) if d is not dtype
        )
        for other_dtype in other_dtypes:
            sample_name = f"{njt.dim()}D: {dtype} -> {other_dtype}"
            yield SampleInput(
                _clone(njt), kwargs={"dtype": other_dtype}, name=sample_name
            )

        # only include device transfer for CUDA inputs
        if "cuda" in device:
            other_device = "cpu"
            sample_name = f"{_describe_njt(njt)}: {device} -> {other_device}"
            yield SampleInput(
                _clone(njt), kwargs={"device": other_device}, name=sample_name
            )


def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
    for njt_3d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
    ):
        # (B, j1, D) x (B, D, E) => (B, j1, E)
        if njt_3d._ragged_idx == 1:
            B, D = njt_3d.shape[0], njt_3d.shape[-1]
            E = D + 2
            other = torch.randn(B, D, E, device=device, dtype=dtype)
            # used for slicing in unbind_reference()
            other._batch_dim = 0
            njt_desc = _describe_njt(njt_3d)
            yield SampleInput(
                _clone(njt_3d),
                kwargs={"mat2": other},
                name=f"{njt_desc}: (B, j, D) x (B, D, E)",
            )

    # TODO (need factory functions):
    # (B, D, j1) x (B, j1, E) => (B, D, E)


def reference_bmm(op, sample):
    # unbind reduces a dim and bmm requires 3D, so use matmul as the reference
    matmul_op = copy(op)
    matmul_op.op = torch.matmul

    # change arg name from mat2 -> other
    modified_sample = copy(sample)
    other = modified_sample.kwargs["mat2"]
    del modified_sample.kwargs["mat2"]
    modified_sample.kwargs["other"] = other
    return unbind_reference(matmul_op, modified_sample)


def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # ragged dim chunking: test a single chunks value
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            yield _update_sample(sample_input, {"chunks": 3})
        # other dim chunking: test different chunks values
        else:
            D = sample_input.input.size(sample_input.kwargs["dim"])
            for chunks in [1, D // 2, D - 1, D]:
                yield _update_sample(sample_input, {"chunks": chunks})


def sample_inputs_matmul(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    # also run bmm samples through
    for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad):
        # change arg name from mat2 -> other
        other = sample_input.kwargs["mat2"]
        del sample_input.kwargs["mat2"]
        sample_input.kwargs["other"] = other
        yield sample_input

    # 3D cases not covered by bmm
    for njt_3d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
    ):
        # (B, j1, D) x (D, E) => (B, j1, E)
        if njt_3d._ragged_idx == 1:
            D = njt_3d.shape[-1]
            E = D + 2
            njt_desc = _describe_njt(njt_3d)
            yield SampleInput(
                _clone(njt_3d),
                kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)},
                name=f"{njt_desc}: (B, j, D) x (D, E)",
            )

    # 4D cases
    for njt_4d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[4]
    ):
        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
        if njt_4d._ragged_idx == 1:
            E = njt_4d.shape[-1]
            F = E + 2
            njt_desc = _describe_njt(njt_4d)
            yield SampleInput(
                _clone(njt_4d),
                kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)},
                name=f"{njt_desc}: (B, j, D, E) x (E, F)",
            )

    # Dense x NJT cases
    for njt_3d in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[3],
    ):
        # (B, F, E) x (B, E, j1) => (B, F, j1)
        if njt_3d._ragged_idx == 2:
            B = njt_3d.shape[0]
            E = njt_3d.shape[1]
            F = E + 2
            njt_desc = _describe_njt(njt_3d)
            dense_t = torch.randn(
                B, F, E, device=device, dtype=dtype, requires_grad=requires_grad
            )
            dense_t._batch_dim = 0  # for unbind_reference()
            yield SampleInput(
                dense_t,
                args=(_clone(njt_3d),),
                name=f"{njt_desc}: (B, F, E) x (B, E, j1)",
            )

    # NJT x NJT => Dense case
    for njt_3d in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[3],
    ):
        # (B, E, j1) x (B, j1, F) => (B, E, F)
        if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous():
            B, E, _ = njt_3d.shape
            sum_j1 = len(njt_3d.values())
            other_cont = torch.randn(
                sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad
            )
            other_njt = torch.nested.nested_tensor_from_jagged(
                other_cont, njt_3d.offsets(), lengths=njt_3d._lengths
            )
            njt_desc = _describe_njt(njt_3d)
            yield SampleInput(
                _clone(njt_3d),
                kwargs={"other": _clone(other_njt)},
                name=f"{njt_desc}: (B, E, j1) x (B, j1, F)",
            )

    # TODO (need factory functions):
    # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F)


def sample_inputs_masked_select(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
    ):
        yield SampleInput(
            njt,
            kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)},
            name=_describe_njt(njt),
        )


def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # ragged dim narrowing: test a single start, length value
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            yield _update_sample(sample_input, {"start": 1, "length": 2})
        # other dim narrowing: test different start, length values
        else:
            D = sample_input.input.size(sample_input.kwargs["dim"])
            for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]:
                yield _update_sample(sample_input, {"start": start, "length": length})


def sample_inputs_nn_functional_embedding(
    op_info, device, dtype, requires_grad, **kwargs
):
    indices = torch.nested.nested_tensor(
        [
            torch.tensor([0, 2, 1, 3]),
            torch.tensor([4, 2, 1]),
            torch.tensor([6, 7, 5, 2, 4]),
        ],
        layout=torch.jagged,
        dtype=torch.int64,
        device=device,
    )

    NUM_EMBEDDINGS = 20
    EMBEDDING_DIM = 32
    weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype)

    # NB: the OpInfo entry for embedding expects weight first so the gradients
    # can be checked
    yield SampleInput(
        _clone(weight).requires_grad_(),
        args=(indices,),
    )

    yield SampleInput(
        _clone(weight).requires_grad_(),
        args=(indices,),
        kwargs={"padding_idx": 1},
    )


def sample_inputs_index_put(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        for dim in range(njt.dim()):
            indices = [
                torch.tensor(list(range(njt.size(0))), device=njt.device),
                *[
                    torch.tensor([0] * njt.size(0), device=njt.device)
                    for _ in range(dim - 1)
                ],
            ]
            njt_desc = _describe_njt(njt)
            yield SampleInput(
                _clone(njt),
                kwargs={
                    "indices": indices,
                    "values": torch.tensor(1.0, device=njt.device),
                },
                name=f"{njt_desc}: up to dim {dim - 1}",
            )

    # Non-contiguous NJT for completeness
    offsets = torch.tensor([0, 2, 5, 7], device=device)
    lengths = torch.tensor([2, 2, 2], device=device)
    indices = [
        torch.tensor([0, 1, 2], device=device),
        torch.tensor([0, 1, 1], device=device),
        torch.tensor([0, 0, 0], device=device),
    ]
    a = torch.nested.nested_tensor_from_jagged(
        torch.zeros(7, 3, device=device), offsets, lengths
    ).requires_grad_(requires_grad)

    njt_desc = _describe_njt(a)
    yield SampleInput(
        _clone(a),
        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
        name=f"{njt_desc}: all dims",
    )


def sample_inputs_nn_functional_embedding_bag(
    op_info, device, dtype, requires_grad, **kwargs
):
    for generate_per_sample_weight in (True, False):
        for mode in ("sum", "mean", "max"):
            # per_sample_weights is only supported for mode='sum'
            if mode != "sum" and generate_per_sample_weight:
                continue

            NUM_EMBEDDINGS = 10
            EMBEDDING_DIM = 32
            weight = torch.randn(
                NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device
            )

            njt = torch.nested.nested_tensor(
                [
                    torch.randint(0, NUM_EMBEDDINGS, size=(2,)),
                    torch.randint(0, NUM_EMBEDDINGS, size=(3,)),
                    torch.randint(0, NUM_EMBEDDINGS, size=(4,)),
                ],
                layout=torch.jagged,
                dtype=torch.int64,
                device=device,
            )

            per_sample_weights = None
            if generate_per_sample_weight:
                per_sample_weights = torch.randn_like(njt, dtype=dtype)

            # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
            # can be checked
            yield SampleInput(
                weight,
                args=(njt,),
                kwargs={
                    "mode": mode,
                    "per_sample_weights": per_sample_weights,
                },
            )


def reference_nn_functional_embedding_bag(op, sample):
    # run reference on a single bag at a time
    new_kwargs = dict(sample.kwargs)
    new_kwargs.update(
        {"offsets": torch.tensor([0], dtype=torch.int64, device=sample.input.device)}
    )

    # flip input / weight back to what unbind_reference() expects
    sample = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs)

    old_op = op.op
    op.op = torch.nn.functional.embedding_bag
    output = unbind_reference(op, sample, wrap_output_as_njt=False)
    op.op = old_op

    # concat bag outputs to get final output
    return torch.cat(output, dim=0)


def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5]
    ):
        # projection over a ragged dim is not currently supported
        if is_nested_int(njt.size(-1)):
            continue

        # with bias
        NUM_OUTPUT = 10
        weight = torch.randn(
            NUM_OUTPUT,
            njt.size(-1),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )
        bias = torch.randn(
            NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad
        )
        yield SampleInput(
            _clone(njt),
            kwargs={
                "weight": _clone(weight),
                "bias": _clone(bias),
            },
            name=f"{_describe_njt(njt)}: with bias",
        )

        # without bias
        yield SampleInput(
            _clone(njt),
            kwargs={
                "weight": _clone(weight),
            },
            name=f"{_describe_njt(njt)}: without bias",
        )


def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
    ):
        # Second dim is interpreted as number of channels; this should be non-ragged for now
        num_channels = njt.size(1)
        if is_nested_int(num_channels):
            continue

        # 1D weight
        weight = torch.randn(
            num_channels,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )
        yield SampleInput(
            _clone(njt),
            kwargs={
                "weight": _clone(weight),
            },
            name=f"{_describe_njt(njt)}: 1D weight",
        )

        # scalar tensor weight
        yield SampleInput(
            _clone(njt),
            kwargs={
                "weight": torch.tensor(4.2, device=device, dtype=dtype),
            },
            name=f"{_describe_njt(njt)}: scalar tensor weight",
        )


def sample_inputs_nn_functional_rms_norm(
    op_info, device, dtype, requires_grad, **kwargs
):
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
    ):
        # normalize over non-ragged dims
        for start_dim in range(njt.dim()):
            if start_dim <= njt._ragged_idx:
                continue

            normalized_shape = njt.shape[start_dim:]
            weight = torch.randn(
                normalized_shape,
                device=device,
                dtype=dtype,
                requires_grad=requires_grad,
            )

            yield SampleInput(
                _clone(njt),
                kwargs={
                    "normalized_shape": normalized_shape,
                    "weight": weight,
                },
                name=f"{_describe_njt(njt)}",
            )


sample_inputs_nn_functional_threshold = partial(
    sample_inputs_elementwise_njt_unary,
    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
)


def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # ragged dim selection: test a single index
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            yield _update_sample(sample_input, {"index": 0})
        # other dim selection: test different indices
        else:
            D = sample_input.input.size(sample_input.kwargs["dim"])
            for index in [0, D // 2, D - 1]:
                yield _update_sample(sample_input, {"index": index})


def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # ragged dim splitting: test a single split size
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            yield _update_sample(sample_input, {"split_size_or_sections": 3})
        # other dim splitting: test different split sizes
        else:
            D = sample_input.input.size(sample_input.kwargs["dim"])
            for split_size in [1, D // 2, D - 1, D]:
                yield _update_sample(
                    sample_input, {"split_size_or_sections": split_size}
                )


def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # It will never make sense to operate on the ragged dim.
        # TODO: Handle this with error_inputs
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            continue

        D = sample_input.input.size(sample_input.kwargs["dim"])
        # splits should add up to D
        split1 = torch.randint(0, D - 1, size=()).item()
        split2 = D - split1
        yield _update_sample(sample_input, {"split_sizes": [split1, split2]})


def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
    # squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
    def _get_njts():
        njt = random_nt_from_dims(
            (4, None, 1, 3, 1),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield njt

        # without min / max seqlen cached
        values = njt.values().detach().clone()
        offsets = njt.offsets().detach().clone()
        yield torch.nested.nested_tensor_from_jagged(values, offsets)

        # non-contiguous transposed
        yield njt.transpose(1, 3)

        # non-contiguous with holes
        values = njt.values().detach().clone()
        offsets = njt.offsets().detach().clone()
        # subtract 1 to cause holes
        lengths = (offsets.diff() - 1).detach().clone()
        yield torch.nested.nested_tensor_from_jagged(
            values=values,
            offsets=offsets,
            lengths=lengths,
        )

    for njt in _get_njts():
        # single dim operation
        for dim in range(njt.dim()):
            # Operation on batch / ragged dim is never expected to work.
            # TODO: Handle these via error_inputs.
            if dim == 0 or dim == njt._ragged_idx:
                continue

            yield SampleInput(
                _clone(njt),
                kwargs={"dim": dim},
                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
            )

        # multiple dim operation (pass no dim args)
        yield SampleInput(
            _clone(njt),
            name=f"{_describe_njt(njt)}: multiple dims",
        )


def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # It will never make sense to operate on the ragged dim.
        # TODO: Handle this with error_inputs
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            continue

        D = sample_input.input.size(sample_input.kwargs["dim"])
        # sizes should multiply to be D
        yield _update_sample(sample_input, {"sizes": [D, 1]})
        yield _update_sample(sample_input, {"sizes": [1, D]})
        if D % 2 == 0:
            yield _update_sample(sample_input, {"sizes": [D // 2, 2]})
            yield _update_sample(sample_input, {"sizes": [2, D // 2]})


def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        yield sample_input

        last_dim_sample = _update_sample(sample_input, {"dim": -1})
        last_dim_sample.name = (
            f"{_describe_njt(last_dim_sample.input)}: add dim to the end"
        )
        # Tell the unbind reference how to canonicalize the dim kwargs.
        # This is necessary because unsqueeze() allows for a dim after
        # the last dim to indicate an unsqueeze at the end.
        last_dim_sample.input._ndim = last_dim_sample.input.dim() + 1
        yield last_dim_sample


def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
    for sample in sample_inputs_elementwise_njt_binary(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        other = sample.args[0]
        sample.args = ()
        sample.kwargs["other"] = other
        sample.kwargs["condition"] = sample.input > 0.0
        sample.name = sample.name.replace("(", "(NT, ")
        yield sample


# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===


# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary
# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
# in alphabetical order!
njt_sample_inputs = {
    "bmm": sample_inputs_bmm,
    "chunk": sample_inputs_chunk,
    "clone": sample_inputs_clone,
    "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
    "fill": sample_inputs_fill,
    "index_put": sample_inputs_index_put,
    "masked_select": sample_inputs_masked_select,
    "matmul": sample_inputs_matmul,
    # these two don't have ReductionOpInfo entries
    "max.reduction_with_dim": sample_inputs_njt_reduction,
    "min.reduction_with_dim": sample_inputs_njt_reduction,
    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=p) for p in (1, 3, 5)},
    "narrow": sample_inputs_narrow,
    "nn.functional.embedding": sample_inputs_nn_functional_embedding,
    "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
    "nn.functional.linear": sample_inputs_nn_functional_linear,
    "nn.functional.prelu": sample_inputs_nn_functional_prelu,
    "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
    "select": sample_inputs_select,
    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
    "split": sample_inputs_split,
    "split_with_sizes": sample_inputs_split_with_sizes,
    "squeeze": sample_inputs_squeeze,
    "to": sample_inputs_to,
    "unflatten": sample_inputs_unflatten,
    "unsqueeze": sample_inputs_unsqueeze,
    "where": sample_inputs_where,
}

njt_references = {
    "bmm": reference_bmm,
    "chunk": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
    ),
    "count_nonzero": reduction_reference,
    # these two don't have ReductionOpInfo entries
    "max.reduction_with_dim": reduction_reference,
    "min.reduction_with_dim": reduction_reference,
    "narrow": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
    ),
    "nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
    "select": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_select
    ),
    "split": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_split
    ),
    "split_with_sizes": partial(
        unary_dimwise_reference,
        batchwise_reference=batchwise_reference_split_with_sizes,
    ),
    "squeeze": unbind_reference,
    "unflatten": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
    ),
    "unsqueeze": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
    ),
}


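# NB (illustrative only): to hand-roll NJT support for an op whose samples can't be
# auto-generated, register a sample inputs func (and optionally a reference) in the
# dicts above, e.g.
#
#   njt_sample_inputs["my_op"] = sample_inputs_my_op   # hypothetical op / helper names
#   njt_references["my_op"] = unbind_reference
#
# Ops listed in njt_sample_inputs without an explicit reference fall back to
# unbind_reference in translate_opinfo() below.

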
# Translates an OpInfo entry to one that operates on NJTs.
def translate_opinfo(op):
    new_op = copy(op)
    new_op.supports_njt = True
    # add some extra info for use in generating tests on the right subset of ops
    new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())

    if op.full_name in njt_sample_inputs:
        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
        new_op.ref = njt_references.get(op.full_name, unbind_reference)
    elif isinstance(op, UnaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_unary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, BinaryUfuncInfo):
        new_op.sample_inputs_func = partial(
            sample_inputs_elementwise_njt_binary, op_kwargs=None
        )
        new_op.ref = unbind_reference
    elif isinstance(op, ReductionOpInfo):
        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
        new_op.ref = reduction_reference
    # TODO: Translate the rest of the OpInfos
    else:
        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
        new_op.ref = unsupported_reference(op.full_name)
        new_op.supports_njt = False

    return new_op


njt_op_db = [translate_opinfo(op) for op in op_db]


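# NB (illustrative only; hypothetical consumer sketch): test files typically select
# from njt_op_db with the @ops decorator from common_device_type, e.g.
#
#   from torch.testing._internal.common_device_type import ops
#
#   @ops([op for op in njt_op_db if op.supports_njt])
#   def test_forward(self, device, dtype, op):
#       ...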