"""isort:skip_file""" # Import order is significant here. from . import math from . import extra from .standard import ( argmax, argmin, bitonic_merge, cdiv, cumprod, cumsum, flip, interleave, max, min, ravel, reduce_or, sigmoid, softmax, sort, sum, swizzle2d, topk, xor_sum, zeros, zeros_like, ) from .core import ( PropagateNan, TRITON_MAX_TENSOR_NUMEL, load_tensor_descriptor, store_tensor_descriptor, make_tensor_descriptor, tensor_descriptor, tensor_descriptor_type, add, advance, arange, associative_scan, assume, async_task, atomic_add, atomic_and, atomic_cas, atomic_max, atomic_min, atomic_or, atomic_xchg, atomic_xor, bfloat16, block_type, broadcast, broadcast_to, cat, cast, clamp, condition, const, constexpr, constexpr_type, debug_barrier, device_assert, device_print, dot, dot_scaled, dtype, expand_dims, float16, float32, float64, float8e4b15, float8e4nv, float8e4b8, float8e5, float8e5b16, full, gather, histogram, inline_asm_elementwise, int1, int16, int32, int64, int8, join, load, make_block_ptr, map_elementwise, max_constancy, max_contiguous, maximum, minimum, multiple_of, num_programs, permute, pi32_t, pointer_type, program_id, range, reduce, reshape, slice, split, static_assert, static_print, static_range, store, tensor, trans, tuple, tuple_type, uint16, uint32, uint64, uint8, view, void, where, ) from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, ceil) from .random import ( pair_uniform_to_normal, philox, philox_impl, rand, rand4x, randint, randint4x, randn, randn4x, uint_to_uniform_float, ) from . 
import target_info  # completes the `from . ` that ends the previous line

__all__ = [
    "PropagateNan",
    "TRITON_MAX_TENSOR_NUMEL",
    "load_tensor_descriptor",
    "store_tensor_descriptor",
    "make_tensor_descriptor",
    "tensor_descriptor",
    "abs",
    "add",
    "advance",
    "arange",
    "argmax",
    "argmin",
    "associative_scan",
    "assume",
    "async_task",
    "atomic_add",
    "atomic_and",
    "atomic_cas",
    "atomic_max",
    "atomic_min",
    "atomic_or",
    "atomic_xchg",
    "atomic_xor",
    "bfloat16",
    "bitonic_merge",
    "block_type",
    "broadcast",
    "broadcast_to",
    "cat",
    "cast",
    "cdiv",
    "ceil",
    "clamp",
    "condition",
    "const",
    "constexpr",
    "constexpr_type",
    "cos",
    "cumprod",
    "cumsum",
    "debug_barrier",
    "device_assert",
    "device_print",
    "div_rn",
    "dot",
    "dot_scaled",
    "dtype",
    "erf",
    "exp",
    "exp2",
    "expand_dims",
    "extra",
    "fdiv",
    "flip",
    "float16",
    "float32",
    "float64",
    "float8e4b15",
    "float8e4nv",
    "float8e4b8",
    "float8e5",
    "float8e5b16",
    "floor",
    "fma",
    "full",
    "gather",
    "histogram",
    "inline_asm_elementwise",
    "interleave",
    "int1",
    "int16",
    "int32",
    "int64",
    "int8",
    "join",
    "load",
    "log",
    "log2",
    "make_block_ptr",
    "map_elementwise",
    "math",
    "max",
    "max_constancy",
    "max_contiguous",
    "maximum",
    "min",
    "minimum",
    "multiple_of",
    "num_programs",
    "pair_uniform_to_normal",
    "permute",
    "philox",
    "philox_impl",
    "pi32_t",
    "pointer_type",
    "program_id",
    "rand",
    "rand4x",
    "randint",
    "randint4x",
    "randn",
    "randn4x",
    "range",
    "ravel",
    "reduce",
    "reduce_or",
    "reshape",
    "rsqrt",
    "slice",
    "sigmoid",
    "sin",
    "softmax",
    "sort",
    "split",
    "sqrt",
    "sqrt_rn",
    "static_assert",
    "static_print",
    "static_range",
    "store",
    "sum",
    "swizzle2d",
    "target_info",
    "tensor",
    "topk",
    "trans",
    "tuple",
    "uint16",
    "uint32",
    "uint64",
    "uint8",
    "uint_to_uniform_float",
    "umulhi",
    "view",
    "void",
    "where",
    "xor_sum",
    "zeros",
    "zeros_like",
]


def str_to_ty(name, c):
    """Translate a type signature into the corresponding type object.

    Accepted forms of ``name``:

    * a Python tuple: each element is converted recursively (with the same
      ``c``) and the results are wrapped in ``tuple_type``. If the tuple's
      class defines ``_fields`` in its own ``__dict__``, that value is passed
      along as the second argument, otherwise ``None``.
    * ``"*<ty>"`` / ``"*k<ty>"``: a ``pointer_type`` to ``<ty>``; the ``k``
      directly after the ``*`` sets ``const=True``.
    * ``"tensordesc<ty[d0, d1, ...]>"``, optionally with ``,<layout>`` after
      the closing ``]``: a tensor descriptor type. A non-empty layout selects
      the gluon descriptor type.
    * anything starting with ``"constexpr"``: ``constexpr_type(c)``.
    * otherwise one of the scalar abbreviations in the table at the end of
      this function.

    ``c`` is only consumed by the ``constexpr`` case; it is forwarded
    unchanged through tuples and pointers, and replaced by ``None`` for the
    element type of a tensor descriptor.

    Raises ``KeyError`` when ``name`` is a string matching none of the forms
    above, and ``IndexError`` for an empty string or a bare ``"*"``.
    """
    # This module's global `tuple` is the one imported from `.core`, so the
    # Python built-in has to be fetched explicitly for the isinstance check.
    from builtins import tuple as builtin_tuple

    if isinstance(name, builtin_tuple):
        field_names = type(name).__dict__.get("_fields", None)
        members = [str_to_ty(member, c) for member in name]
        return tuple_type(members, field_names)

    if name[0] == "*":
        pointee = name[1:]
        is_const = pointee[0] == "k"
        if is_const:
            pointee = pointee[1:]
        pointee_ty = str_to_ty(pointee, c)
        return pointer_type(element_ty=pointee_ty, const=is_const)

    if name.startswith("tensordesc"):
        # Text between the first "<" and the next "<" (if any), with trailing
        # ">" characters removed, e.g. "fp16[64, 64]" or "fp16[64, 64],<layout>".
        spec = name.split("<")[1].rstrip(">")
        elem_name, tail = spec.split("[", maxsplit=1)
        dims_text, tail = tail.split("]", maxsplit=1)
        dims = [int(tok.strip()) for tok in dims_text.split(",")]
        layout_src = tail.lstrip(",")

        elem_ty = str_to_ty(elem_name, None)
        rank = len(dims)
        shape_ty = tuple_type([int32] * rank)
        # FIXME: Last dim stride should be constexpr(1)
        stride_ty = tuple_type([int64] * rank)
        block = block_type(elem_ty, dims)

        if not layout_src:
            return tensor_descriptor_type(block, shape_ty, stride_ty)

        # A layout suffix is present: build the gluon flavour of the descriptor.
        from triton.experimental.gluon.language._layouts import NVMMASharedLayout
        from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as gluon_tensor_descriptor_type

        # NOTE(review): eval() executes the layout text as Python. Python adds
        # `__builtins__` to a globals dict that lacks it, so this mapping does
        # not sandbox the expression; `name` must come from a trusted source.
        layout = eval(layout_src, {"NVMMASharedLayout": NVMMASharedLayout})
        assert isinstance(layout, NVMMASharedLayout)
        return gluon_tensor_descriptor_type(block, shape_ty, stride_ty, layout)

    if name.startswith("constexpr"):
        return constexpr_type(c)

    # Scalar abbreviations. "u1" and "B" both resolve to int1, as does "i1".
    scalar_types = {
        # floating point
        "fp8e4nv": float8e4nv,
        "fp8e4b8": float8e4b8,
        "fp8e5": float8e5,
        "fp8e5b16": float8e5b16,
        "fp8e4b15": float8e4b15,
        "fp16": float16,
        "bf16": bfloat16,
        "fp32": float32,
        "fp64": float64,
        # signed integers
        "i1": int1,
        "i8": int8,
        "i16": int16,
        "i32": int32,
        "i64": int64,
        # unsigned integers
        "u1": int1,
        "u8": uint8,
        "u16": uint16,
        "u32": uint32,
        "u64": uint64,
        # boolean
        "B": int1,
    }
    return scalar_types[name]