from dataclasses import dataclass from typing import List, Any from triton._utils import validate_block_shape @dataclass class TensorDescriptor: base: Any shape: List[int] strides: List[int] block_shape: List[int] padding: str = "zero" def __post_init__(self): rank = len(self.shape) assert len(self.strides) == rank, f"rank mismatch: {self}" assert len(self.block_shape) == rank, f"rank mismatch: {self}" assert rank > 0, "rank must not be zero" assert rank <= 5, "rank cannot be more than 5" ty = type(self.base) if ty.__name__ not in ("FakeTensor", "FunctionalTensor"): assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned" validate_block_shape(self.block_shape) elem_bytes = self.base.dtype.itemsize for stride in self.strides[:-1]: assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned" assert self.strides[-1] == 1, "Last dimension must be contiguous" assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding" if self.padding == "nan": assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors" @staticmethod def from_tensor(tensor: Any, block_shape: List[int], padding="zero"): return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)