# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import math as pymath import warnings from typing import Any, Callable, TypeVar from .triton_compat import ( # noqa: F401 _log2, builtins_use_semantic_kwarg, libdevice, math, tl, triton, ) _T = TypeVar("_T") _LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e)) def set_driver_to_cpu(): driver = triton.runtime.driver if backend := triton.backends.backends.get("cpu", None): if isinstance(driver.active, backend.driver): # Don't re-initialize backend if it is already active return driver.set_active(backend.driver()) return # This can be a hard error once triton-cpu is merged into fbcode warnings.warn( "Could not find an active CPU backend. Generated kernels will not be executable!" ) def set_driver_to_gpu(): driver = triton.runtime.driver for name, backend in triton.backends.backends.items(): if backend.driver.is_active() and name != "cpu": # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4, # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute. if ( isinstance(driver.active, backend.driver) or hasattr(driver.active, "_obj") and isinstance(driver.active._obj, backend.driver) ): # Don't re-initialize backend if it is already active return driver.set_active(backend.driver()) return raise RuntimeError("Could not find an active GPU backend") def get_backend_options(): from triton.runtime import driver target = driver.active.get_current_target() backend = triton.compiler.compiler.make_backend(target) options = backend.parse_options(dict()) return options.__dict__ @triton.jit def promote_to_tensor(x): # Addition promotes to tensor for us return x + tl.zeros((1,), tl.int1) @triton.jit def div_floor_integer(a, b): # NOTE: a // b is C division, but we want floor division # Based on c10::div_floor_integer quot = a // b remainder = a % b fixed = tl.where(remainder != 0, quot - 1, quot) return tl.where((a < 0) != (b < 0), fixed, quot) @triton.jit def remainder_integer(a, b): # NOTE: a % b matches C division, not floor division remainder = a % b return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder) @triton.jit def is_floating(x): return promote_to_tensor(x).dtype.is_floating() @triton.jit def _prod_accumulate(a, b): return a * b @triton.jit def prod(input, axis): return tl.reduce(input, axis, _prod_accumulate) @triton.jit def minimum(a, b): mask = a < b if is_floating(a): mask |= a != a return tl.where(mask, a, b) @triton.jit def maximum(a, b): mask = a > b if is_floating(a): mask |= a != a return tl.where(mask, a, b) @triton.jit def min2(a, dim): return tl.reduce(a, dim, minimum) @triton.jit def max2(a, dim): return tl.reduce(a, dim, maximum) @triton.jit def minimum_with_index(a_value, a_index, b_value, b_index): mask = a_value < b_value equal = a_value == b_value if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value mask |= a_isnan & (not b_isnan) # Consider NaNs as equal equal |= a_isnan & b_isnan # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) @triton.jit def maximum_with_index(a_value, a_index, b_value, b_index): mask = a_value > b_value equal = a_value == b_value if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value mask |= a_isnan & (not b_isnan) # Consider NaNs as equal equal |= a_isnan & b_isnan # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) @triton.jit def min_with_index(value, index, dim): return tl.reduce((value, index), dim, minimum_with_index) @triton.jit def max_with_index(value, index, dim): return tl.reduce((value, index), dim, maximum_with_index) @triton.jit def exp(x, use_fast_math: tl.constexpr): if use_fast_math: return math.exp(x) else: return libdevice.exp(x) @triton.jit def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr): out_max = max2(lhs_max, dim) out_max_keepdim = tl.expand_dims(out_max, dim) delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim) out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim) return out_max, out_sum @triton.jit def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexpr): """ When we do combine, we assume lhs is the accumulator and rhs is the next block of data. Then rhs_sum is always 1. With that assumption, we can save some registers and computation. """ out_max = maximum(lhs_max, rhs_max) lhs_scale = tl.where( out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math) ) rhs_scale = tl.where( out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math) ) # Should be # out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale # but since rhs_sum is all 1, we can simplify it. out_sum = lhs_sum * lhs_scale + rhs_scale return out_max, out_sum @triton.jit def welford_reduce(value, mean, m2, weight, first_iteration): if first_iteration: new_weight = tl.full(weight.shape, 1, weight.dtype) new_mean = value new_m2 = tl.zeros_like(m2) else: delta = value - mean new_weight = weight + 1 new_mean = mean + delta / new_weight new_m2 = m2 + delta * (value - new_mean) return new_mean, new_m2, new_weight @triton.jit def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): delta = mean_2 - mean_1 new_weight = weight_1 + weight_2 w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) return ( mean_1 + delta * w2_over_w, m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, new_weight, ) @triton.jit def welford(mean, m2, weight, dim): return tl.reduce((mean, m2, weight), dim, welford_combine) @triton.jit def device_assert_then(cond, msg, r): tl.device_assert(cond, msg) return r @triton.jit def randint64(seed, offset, low, high): r0, r1, _r2, _r3 = tl.randint4x(seed, offset) r0 = r0.to(tl.uint64) r1 = r1.to(tl.uint64) result = r0 | (r1 << 32) size = high - low result = result % size.to(tl.uint64) result = result.to(tl.int64) + low return result @triton.jit def _any_combine(a, b): return a | b @triton.jit def any(a, dim): return tl.reduce(a, dim, _any_combine) @triton.jit def bucketize_binary_search( values: tl.tensor, boundaries_ptr: tl.tensor, BOUNDARIES_SIZE: int, BOUNDARIES_UNDERLYING_NUMEL: int, BOUNDARIES_STRIDE: int, boundary_indices: tl.tensor, indexing_dtype: tl.dtype, right: "bool", # triton can't handle the unquoted bool annotation sorter_ptr: tl.tensor, SORTER_STRIDE: int, sorter_indices: tl.tensor, ): """ See [Note: Inductor bucketize op] Inputs: ------- values: the values to bucketize. boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D. BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one individual set of boundaries). BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring any striding. BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor boundary_indices: a tensor of the same size as "values"; each element is an index into a 1-D, un-strided boundaries tensor, pointing to the first element in the set of boundaries used for that value. indexing_dtype: the dtype used for indexing into the boundaries tensor, and the return dtype. right: if true, use boundary intervals closed on the left; otherwise use intervals closed on the right. sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries, but potentially different striding. If present, this allows us to treat boundaries as sorted even if the elements of boundaries are unsorted. SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last dimension of the sorter tensor. sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices". BLOCK_SHAPE: the shape of the data block being processed. """ low = tl.zeros(values.shape, dtype=indexing_dtype) high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype) full_range = BOUNDARIES_SIZE + 1 while full_range > 1: mid = (high + low) // 2 mask = ( (mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL ).logical_and(mid < BOUNDARIES_SIZE) mid_indices = ( mid if sorter_ptr is None or SORTER_STRIDE is None else tl.load( sorter_ptr + sorter_indices + SORTER_STRIDE * mid, mask=mask, other=0, ) ) bucket_upper_bound = tl.load( boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices, mask=mask, other=0, ) if right: is_above = values >= bucket_upper_bound else: is_above = values > bucket_upper_bound low = tl.where(is_above & mask, mid + 1, low) high = tl.where(is_above, high, mid) full_range = (full_range + 1) // 2 return low @triton.jit def pack_value_flag( value, flag, DTYPE_VALUE_AS_UINT: tl.constexpr, DTYPE_PACK: tl.constexpr, ): # Workaround for triton bug, tensor.to doesn't unwrap constexpr values DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) return flag.to(DTYPE_PACK) | (uv << bitwidth) @triton.jit def unpack_value( pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT, ): # Workaround for triton bug, tensor.to doesn't unwrap constexpr values DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE) DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) return value_uint.to(DTYPE_VALUE, bitcast=True) @triton.jit def unpack_flag(pack, DTYPE_FLAG): return pack.to(DTYPE_FLAG) @triton.jit def exclusive_scan_decoupled_lookback( scratch_base, block_value, index, combine_fn, DTYPE_VALUE_AS_UINT: tl.constexpr, DTYPE_PACK: tl.constexpr, ): """Compute exclusive scan of a scalar value between blocks Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back scratch_base: Pointer to scratch space in global memory block_value: Scalar value for this block index: Scalar index of this block relative to the current scan combine_fn: Function ``(value, value) -> value`` which is scanned over DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` DTYPE_PACK: Unsigned type twice the width of block_value NOTE: This function is limited to values which are 32-bits or less because we need to pack (value, flag) into a single unsigned int. """ # Publish block sum so subsequent blocks don't get stuck waiting for us DTYPE_VALUE = block_value.dtype pack = pack_value_flag( block_value, tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), DTYPE_VALUE_AS_UINT, DTYPE_PACK, ) if index > 0: tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") # Calculate exclusive prefix scan exclusive_prefix = tl.zeros([], DTYPE_VALUE) prefix_valid = False test_target = index - 1 while test_target >= 0: # tl.atomic_load flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) while flag == 0: pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) if prefix_valid: exclusive_prefix = combine_fn(value, exclusive_prefix) else: exclusive_prefix = value prefix_valid = True if flag == 2: test_target = -1 else: test_target = test_target - 1 # Make inclusive block sum visible to other blocks if prefix_valid: inclusive_prefix = combine_fn(exclusive_prefix, block_value) else: inclusive_prefix = block_value pack = pack_value_flag( inclusive_prefix, tl.full([], 2, DTYPE_VALUE_AS_UINT), DTYPE_VALUE_AS_UINT, DTYPE_PACK, ) tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") return exclusive_prefix @triton.jit def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn): """Compute exclusive scan of a scalar value between blocks Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back scratch_base: Pointer to scratch space in global memory block_value: Scalar value for this block, must be 64-bits wide index: Scalar index of this block relative to the current scan combine_fn: Function ``(value, value) -> value`` which is scanned over init: Scalar value equal to the identity of combine_fn """ # Publish block sum so subsequent blocks don't get stuck waiting for us if index > 0: block_value_u64 = block_value.to(tl.uint64, bitcast=True) tl.store(scratch_base + 3 * index + 1, block_value_u64) tl.debug_barrier() flag_one = tl.full([], 1, tl.uint64) tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") # Calculate exclusive prefix scan exclusive_prefix = tl.zeros([], block_value.dtype) prefix_valid = False test_target = index - 1 while test_target >= 0: flag = tl.full([], 0, tl.uint64) while flag == 0: flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) value = value_u64.to(block_value.dtype, bitcast=True) if prefix_valid: exclusive_prefix = combine_fn(value, exclusive_prefix) else: exclusive_prefix = value prefix_valid = True if flag == 2: test_target = -1 else: test_target = test_target - 1 # Make inclusive block sum visible to other blocks if prefix_valid: inclusive_prefix = combine_fn(exclusive_prefix, block_value) else: inclusive_prefix = block_value inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) tl.debug_barrier() flag_two = tl.full([], 2, tl.uint64) tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") return exclusive_prefix @triton.jit def frexp(x): # TODO(isuruf): use inline_asm_elementwise here y = libdevice.ilogb(x) + 1 exponent = tl.where(x == 0, 0, y) mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) return mantissa, exponent @triton.jit def _compare_and_swap_with_index( x, idxs, rnumel, flip, i: tl.constexpr, n_dims: tl.constexpr, stable: tl.constexpr, descending: tl.constexpr, ): n_outer: tl.constexpr = x.numel >> n_dims shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) y = tl.reshape(x, shape) iy = y.to(idtype, bitcast=True) # slice left/right with 'stride' 2**(n_dims - i - 1) right_mask = tl.arange(0, 2)[None, :, None].to(idtype) left_mask = (1 - right_mask).to(idtype) ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape) iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape) ileft = tl.reshape(ileft, x.shape) iright = tl.reshape(iright, x.shape) left = ileft.to(x.dtype, bitcast=True) right = iright.to(x.dtype, bitcast=True) # idx y_idx = tl.reshape(idxs, shape) left_idx = tl.broadcast_to( tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape ) right_idx = tl.broadcast_to( tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape ) left_idx = tl.reshape(left_idx, x.shape) right_idx = tl.reshape(right_idx, x.shape) # valid if rnumel is None: left_valid_mask = tl.full(x.shape, True, tl.int1) right_valid_mask = tl.full(x.shape, True, tl.int1) else: left_valid_mask = left_idx < rnumel right_valid_mask = right_idx < rnumel # actual compare-and-swap ix = x.to(idtype, bitcast=True) # sort treats nan as having the higher value. comparisons with nan always return False. # to align with sort semantics, we need to update descending to check if right_isnan, # and ascending to check if left_isnan. left_isnan = left != left right_isnan = right != right if descending: cond = left < right if is_floating(left): if not stable: cond = cond | right_isnan else: cond = cond | (right_isnan & (~left_isnan)) else: cond = left > right if is_floating(left): if not stable: cond = cond | left_isnan else: cond = cond | (left_isnan & (~right_isnan)) if stable: # When stable sorting, tie break by index eq = left == right if is_floating(left): eq = eq | (left_isnan & right_isnan) cond = cond | (eq & (left_idx > right_idx)) cond = (right_valid_mask > left_valid_mask) | ( (right_valid_mask == left_valid_mask) & cond ) cond = (cond ^ flip).to(tl.int1) ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs)) return ret.to(x.dtype, bitcast=True), new_idxs @triton.jit def _bitonic_merge_with_index( x, idxs, rnumel, stage: tl.constexpr, alternating: tl.constexpr, n_dims: tl.constexpr, stable: tl.constexpr, descending: tl.constexpr, ): n_outer: tl.constexpr = x.numel >> n_dims tl.static_assert(stage <= n_dims) # flip denotes whether to re-arrange sub-sequences of elements in ascending or # descending order. # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage # if flip = 00110011... then all the elements will be re-arranged alternatingly (with # a stride of 2) at this stage if alternating: shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] flip = tl.reshape( tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape ) else: flip = False # perform `stage` rounds of `compare-and-swap` for i in tl.static_range(stage): x, idxs = _compare_and_swap_with_index( x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending ) return x, idxs @triton.jit def sort_with_index( x, # value idxs, # index rnumel, # number of elements dim: tl.constexpr = None, stable: tl.constexpr = tl.constexpr(False), descending: tl.constexpr = tl.constexpr(False), ): x, idxs = tl.broadcast(x, idxs) # handle default dimension or check that it is the most minor dim _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim tl.static_assert( _dim == len(x.shape) - 1, "only minor dimension is currently supported" ) # iteratively run bitonic merge-sort steps n_dims: tl.constexpr = _log2(x.shape[_dim]) for i in tl.static_range(1, n_dims + 1): x, idxs = _bitonic_merge_with_index( x, idxs, rnumel, i, alternating=i < n_dims, n_dims=n_dims, stable=stable, descending=descending, ) return x, idxs @triton.jit def select_one(x, mask, dim, keep_dims=False): idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False) ix = x.to(idtype, bitcast=True) iy = tl.sum(ix * mask, dim, keep_dims=keep_dims) return iy.to(x.dtype, bitcast=True) @triton.jit def x_grid_barrier(sem): """ Wait for all other thread blocks in grid sharing same y/z program_id to reach this barrier before returning. Args: sem: an uint32 semaphores, zero or 0x80000000 initialized. Must be unique to each y/z program ID. """ # ensure stores before this are visible tl.debug_barrier() one_i32 = 1 one_u32 = one_i32.to(tl.uint32) # type: ignore[attr-defined] expected = tl.num_programs(0).to(tl.uint32) if tl.program_id(0) == 0: nb = 0x80000000 - (expected - one_u32) else: nb = one_u32 old_arrive = tl.atomic_add(sem, nb, sem="release") bar_flipped = False while not bar_flipped: # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it current_arrive = tl.atomic_add(sem, 0, sem="acquire") # current_arrive = tl.load(sem, volatile=True) bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0 # TODO(jansel): is this needed? tl.debug_barrier() def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: """ Decorator to mark a function as a Triton built-in function. These functions are evaluated at compile time. Args: f (function): The function to be marked as a Triton built-in. Returns: function: The same function, marked as a Triton built-in. """ if builtins_use_semantic_kwarg: # support Triton before and after https://github.com/triton-lang/triton/pull/7054 # and after https://github.com/triton-lang/triton/pull/7239 def wrapper(*args, _semantic, **kwargs): kwargs["_builder"] = _semantic return f(*args, **kwargs) else: wrapper = f # type: ignore[assignment] wrapper.__triton_builtin__ = True # type: ignore[attr-defined] return wrapper @triton_builtin def constexpr_next_power_of_2( n: tl.constexpr, *, _builder: object = None ) -> tl.constexpr: """ A version triton.next_power_of_two that can be used within a kernel on constants. """ assert isinstance(n, tl.constexpr) return tl.constexpr(triton.next_power_of_2(n.value)) @triton_builtin def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr: """ Work around triton compile error: `ValueError: `other` cannot be provided without `mask`` A compile-time to check to return either `val` or `None` depending on the value of mask. """ if isinstance(mask, tl.constexpr) and mask.value is None: return tl.constexpr(None) return val