from __future__ import annotations # remove after python 3.11 import warnings from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type import numbers from triton.runtime import driver from .._C.libtriton import ir from . import core as tl T = TypeVar('T') TensorTy = TypeVar('TensorTy') class IncompatibleTypeErrorImpl(Exception): def __init__(self, type_a, type_b): self.type_a = type_a self.type_b = type_b self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() super(IncompatibleTypeErrorImpl, self).__init__(self.message) class TritonSemantic(Generic[TensorTy]): tensor: Type[TensorTy] = tl.tensor lang = tl builder: ir.builder def __init__(self, builder): self.builder = builder # ===----------------------------------------------------------------------===## # Programming Model # ===----------------------------------------------------------------------===## def program_id(self, axis: int) -> TensorTy: if axis not in (0, 1, 2): raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") return self.tensor(self.builder.create_get_program_id(axis), tl.int32) def num_programs(self, axis: int) -> TensorTy: if axis not in (0, 1, 2): raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") return self.tensor(self.builder.create_get_num_programs(axis), tl.int32) # ===----------------------------------------------------------------------===// # Implicit Casting Utilities # ===----------------------------------------------------------------------===// def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: a_rank = a_ty.int_bitwidth b_rank = b_ty.int_bitwidth a_sn = a_ty.int_signedness b_sn = b_ty.int_signedness # Rules for signedness taken from "Usual arithmetic conversions" on # https://en.cppreference.com/w/c/language/conversion. if a_sn == b_sn: return a_ty if a_rank > b_rank else b_ty elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: return a_ty if a_rank >= b_rank else b_ty elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: return b_ty if b_rank >= a_rank else a_ty raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, div_or_mod: bool) -> tl.dtype: # 0) For scalars we follow semantics similar to PyTorch, namely: # - If the scalar is of a lower or equal kind (bool < uint < int < fp), # it doesn't participate in the promotion if a_is_scalar != b_is_scalar: scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) if scalar_ty.kind().value <= tensor_ty.kind().value: # Upcast because of 3) and 4) below! if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): return tl.float32 return tensor_ty # 1) if one operand is double, the other is implicitly # converted to double if a_ty.is_fp64() or b_ty.is_fp64(): return tl.float64 # 2) if one operand is float, the other is implicitly # converted to float if a_ty.is_fp32() or b_ty.is_fp32(): return tl.float32 # 3 ) if one operand is half, the other is implicitly converted to half # unless we're doing / or %, which do not exist natively in PTX for fp16. # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp if a_ty.is_fp16() or b_ty.is_fp16(): if div_or_mod: return tl.float32 else: return tl.float16 # 4) return bf16 only if both operands are of bf16 if a_ty.is_bf16() and b_ty.is_bf16(): if div_or_mod: return tl.float32 else: return tl.bfloat16 if a_ty.is_bf16() or b_ty.is_bf16(): return tl.float32 # 5) return fp16 if operands are different fp8 if a_ty.is_fp8() and b_ty.is_fp8(): return a_ty if a_ty == b_ty else tl.float16 if not a_ty.is_int() or not b_ty.is_int(): raise TypeError(f"unexpected type {a_ty} and {b_ty}") # 6 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" "this is unlikely to result in a useful answer. Cast them to the same signedness.") return self.integer_promote_impl(a_ty, b_ty) def to_tensor(self, x, check_type: bool = True): if isinstance(x, bool): return self.tensor(self.builder.get_int1(x), tl.int1) # Note: compile-time const integers are represented by unsigned values elif isinstance(x, int): if -2**31 <= x < 2**31: dtype = tl.int32 elif 2**31 <= x < 2**32: dtype = tl.uint32 elif -2**63 <= x < 2**63: dtype = tl.int64 elif 2**63 <= x < 2**64: dtype = tl.uint64 else: raise ValueError(f'Nonrepresentable integer {x}.') return self.scalar_constant(x, dtype=dtype) elif isinstance(x, float): min_float32 = 2**-126 max_float32 = (2 - 2**-23) * 2**127 abs_x = __builtins__['abs'](x) if abs_x == float("inf") or\ abs_x == 0.0 or \ x != x or \ min_float32 <= abs_x <= max_float32: dtype = tl.float32 else: dtype = tl.float64 return self.scalar_constant(x, dtype=dtype) elif isinstance(x, tl.constexpr): return self.to_tensor(x.value) elif isinstance(x, self.tensor): return x if check_type: raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") return x # ===----------------------------------------------------------------------===// # Binary Operators # ===----------------------------------------------------------------------===// def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: if type_a.is_ptr(): if not allow_ptr_a: raise IncompatibleTypeErrorImpl(type_a, type_b) # T* + U* with T != U if type_b.is_ptr() and (type_a != type_b): raise IncompatibleTypeErrorImpl(type_a, type_b) # T* + float if type_b.is_floating(): raise IncompatibleTypeErrorImpl(type_a, type_b) def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number, allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, div_or_mod=False) -> Tuple[TensorTy, TensorTy]: lhs_is_scalar = isinstance(lhs, numbers.Number) rhs_is_scalar = isinstance(rhs, numbers.Number) if lhs_is_scalar: lhs_scalar = lhs lhs = self.to_tensor(lhs) if rhs_is_scalar: rhs_scalar = rhs rhs = self.to_tensor(rhs) # implicit typecasting lhs_sca_ty = lhs.type.scalar rhs_sca_ty = rhs.type.scalar self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. " "Perform a explicit cast on one of them.") if ret_sca_ty.is_int(): if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()): raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}") if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()): raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}") lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty) rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty) # implicit broadcasting lhs, rhs = self.broadcast_impl_value(lhs, rhs) return lhs, rhs def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable): if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow: return lhs_sca_ty = lhs.type.scalar rhs_sca_ty = rhs.type.scalar assert lhs_sca_ty == rhs_sca_ty assert lhs_sca_ty.is_int() lhs = self.cast(lhs, tl.int64) rhs = self.cast(rhs, tl.int64) ret = binary_op(lhs, rhs, False) max_value = lhs_sca_ty.get_int_max_value() max_value = self.scalar_constant(max_value, tl.int64) min_value = lhs_sca_ty.get_int_min_value() min_value = self.scalar_constant(min_value, tl.int64) cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value)) msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" self.device_assert(cond, msg, None) def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, sanitize_overflow: bool) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): raise TypeError("cannot add pointers together") # offset + ptr # ptr + offset if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): input, other = other, input input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if input_scalar_ty.is_ptr(): other_handle = other.handle if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64: # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder) other_handle = self.builder.create_int_cast(other.handle, i64_ty, False) return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type) # float + float elif input_scalar_ty.is_floating(): return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type) # int + int elif input_scalar_ty.is_int(): if sanitize_overflow: self.binary_op_sanitize_overflow_impl(input, other, self.add) return self.tensor(self.builder.create_add(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {input_scalar_ty}") def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, sanitize_overflow: bool) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other, True, False) scalar_ty = input.type.scalar # ptr - offset if scalar_ty.is_ptr(): return self.add(input, self.minus(other), sanitize_overflow=False) # float - float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type) # int - int elif scalar_ty.is_int(): if sanitize_overflow: self.binary_op_sanitize_overflow_impl(input, other, self.sub) return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, sanitize_overflow: bool) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float * float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type) # int * int elif scalar_ty.is_int(): if sanitize_overflow: self.binary_op_sanitize_overflow_impl(input, other, self.mul) return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float / int if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): other = self.cast(other, input_scalar_ty) # int / float elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): input = self.cast(input, other_scalar_ty) # int / int (cast to tl.float32) elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): input = self.cast(input, tl.float32) other = self.cast(other, tl.float32) # float / float (cast to the highest exponent type) elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: other = self.cast(other, input_scalar_ty) else: input = self.cast(input, other_scalar_ty) # unreachable else: raise TypeError(f"unexpected type {input_scalar_ty}") return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type) def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if input_scalar_ty.is_int() and other_scalar_ty.is_int(): ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty) input = self.cast(input, ret_ty) other = self.cast(other, ret_ty) if ret_ty.is_int_signed(): return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type) else: return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {input_scalar_ty}") def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy: input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): raise TypeError("both operands of fdiv must have floating scalar type") input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True) ret = self.builder.create_fdiv(input.handle, other.handle) return self.tensor(ret, input.type) def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type) # % int elif scalar_ty.is_int(): if scalar_ty.int_signedness != other_scalar_ty.int_signedness: raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " "because they have different signedness;" "this is unlikely to result in a useful answer. Cast them to the same signedness.") if scalar_ty.is_int_signed(): return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type) else: return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") ############## # other arithmetic ops ############## def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): x, y = self.binary_op_type_checking_impl(x, y) dtype = x.dtype if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type) elif propagate_nan == tl.PropagateNan.NONE: return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type) else: raise ValueError(f"Unexpected propagate_nan {propagate_nan}") elif dtype.is_int_signed(): return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type) elif dtype.is_int_unsigned(): return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type) else: raise TypeError(f"Unexpected dtype {dtype}") def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): x, y = self.binary_op_type_checking_impl(x, y) dtype = x.dtype if dtype.is_floating(): if propagate_nan == tl.PropagateNan.ALL: return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type) elif propagate_nan == tl.PropagateNan.NONE: return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type) else: raise ValueError(f"Unexpected propagate_nan {propagate_nan}") elif dtype.is_int_signed(): return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type) elif dtype.is_int_unsigned(): return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type) else: raise TypeError(f"Unexpected dtype {dtype}") def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan): min, max = self.binary_op_type_checking_impl(min, max) x, min = self.binary_op_type_checking_impl(x, min) x, max = self.binary_op_type_checking_impl(x, max) dtype = x.dtype if dtype.is_floating(): return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) else: raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") ############## # bitwise ops ############## def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]: input, other = self.binary_op_type_checking_impl(input, other) input_sca_ty = input.type.scalar other_sca_ty = other.type.scalar if not input_sca_ty.is_int() or not other_sca_ty.is_int(): raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty) if ret_sca_ty != input_sca_ty: input = self.cast(input, ret_sca_ty) if ret_sca_ty != other_sca_ty: other = self.cast(other, ret_sca_ty) return input, other def and_(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_and(input.handle, other.handle), input.type) def or_(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_or(input.handle, other.handle), input.type) def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type) def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy: if not input.type.is_int1(): input = self.bitcast(input, tl.int1) if not other.type.is_int1(): other = self.bitcast(other, tl.int1) return self.and_(input, other) def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy: if not input.type.is_int1(): input = self.bitcast(input, tl.int1) if not other.type.is_int1(): other = self.bitcast(other, tl.int1) return self.or_(input, other) def not_(self, input: TensorTy): if not input.type.is_int1(): input = self.bitcast(input, tl.int1) return self.invert(input) def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type) def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type) def shl(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.bitwise_op_type_checking_impl(input, other) return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type) # ===----------------------------------------------------------------------===// # Unary Operators # ===----------------------------------------------------------------------===// def plus(self, input: TensorTy) -> TensorTy: return input def minus(self, input: TensorTy) -> TensorTy: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) return self.sub(_0, input, True) def invert(self, input: TensorTy) -> TensorTy: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) return self.xor_(input, _1) # ===----------------------------------------------------------------------===// # Comparison Operators # ===----------------------------------------------------------------------===// def _bool_like(self, v: TensorTy) -> tl.block_type: return v.type.with_element_ty(tl.int1) def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float > float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input)) # > int elif scalar_ty.is_int(): if scalar_ty.is_int_signed(): return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input)) else: return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float >= float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input)) # >= int elif scalar_ty.is_int(): if scalar_ty.is_int_signed(): return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input)) else: return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float < float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input)) # < int elif scalar_ty.is_int(): if scalar_ty.is_int_signed(): return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input)) else: return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float < float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input)) # < int elif scalar_ty.is_int(): if scalar_ty.is_int_signed(): return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input)) else: return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") def equal(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float == float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input)) # == int elif scalar_ty.is_int(): return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: input, other = self.binary_op_type_checking_impl(input, other) scalar_ty = input.type.scalar # float == float if scalar_ty.is_floating(): return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input)) # == int elif scalar_ty.is_int(): return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input)) raise TypeError(f"unexpected type {scalar_ty}") # ===----------------------------------------------------------------------===// # Block Creation # ===----------------------------------------------------------------------===// def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy: if not isinstance(start, int) or not isinstance(end, int): raise ValueError("arange's arguments must be of type tl.constexpr") is_start_int64 = bool(start >> 32) is_end_int64 = bool(end >> 32) if is_start_int64 or is_end_int64: raise ValueError("arange must fit in int32") if end <= start: raise ValueError("arange's end argument must be greater than the start argument") range = end - start if (range & (range - 1)) != 0: raise ValueError("arange's range must be a power of 2") shape = [range] if ret_ty is None: ret_ty = tl.block_type(tl.int32, shape) ret_ty_ir = ret_ty.to_ir(self.builder) return self.tensor(self.builder.create_make_range(ret_ty_ir, start, end), ret_ty) def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy: # scalar if dtype is None: raise ValueError("dtype must be specified when value is not a tensor") if value == 0: value = self.builder.get_null_value(dtype.to_ir(self.builder)) else: get_value_fn = getattr(self.builder, f"get_{dtype.name}") value = get_value_fn(value) return self.tensor(value, dtype) def make_scalar(self, value, dtype: tl.dtype) -> TensorTy: if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" return self.cast(value, dtype) # scalar return self.scalar_constant(value, dtype) def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy: return self.splat(self.make_scalar(value, dtype), shape) # ===----------------------------------------------------------------------===// # Shape Manipulation # ===----------------------------------------------------------------------===// def splat(self, value: TensorTy, shape: List[int]) -> TensorTy: assert not value.type.is_block(), "Cannot splat a block tensor" if len(shape) == 0: return value ret_ty = tl.block_type(value.dtype, shape) return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty) def unsplat(self, value: TensorTy) -> TensorTy: return self.tensor(self.builder.create_unsplat(value.handle), value.dtype) def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy: numel = 1 for s in dst_shape: numel *= s if input.type.numel != numel: raise ValueError("reshape() cannot change total number of elements in tensor") ret_ty = tl.block_type(input.type.scalar, dst_shape) return self.tensor(self.builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: dst_shape = [tl._unwrap_if_constexpr(x) for x in input.shape] dst_shape.insert(axis, 1) if not input.type.is_block(): return self.splat(input, shape=dst_shape) ret_ty = tl.block_type(input.type.scalar, dst_shape) return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty) def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: assert can_reorder, "current implementation of `cat` always may reorder elements" assert len(lhs.shape) == 1 ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) def join(self, a: TensorTy, b: TensorTy) -> TensorTy: a, b = self.broadcast_impl_value(a, b) # The IR can't handle joining two scalars, so upcast them to 1D tensors, # then downcast the result. was_rank_1 = a.shape == [] if was_rank_1: a = self.expand_dims(a, 0) b = self.expand_dims(b, 0) if isinstance(a.shape[-1], tl.constexpr): two = tl.constexpr(2) else: two = 2 new_shape = a.shape + [two] ret_type = tl.block_type(a.type.scalar, new_shape) ret = self.tensor(self.builder.create_join(a.handle, b.handle), ret_type) if was_rank_1: ret = self.reshape(ret, [2], can_reorder=False) return ret def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: assert (len(a.shape) > 0) assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2) new_shape = a.shape[:-1] ret_type = tl.block_type(a.type.scalar, new_shape) outLHS, outRHS = self.builder.create_split(a.handle) return ( self.tensor(outLHS, ret_type), self.tensor(outRHS, ret_type), ) def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy: if len(input.shape) != len(dims): raise ValueError("permute dims must have the same length as input shape") if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))): raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) return self.tensor(self.builder.create_trans(input.handle, dims), ret_type) def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy: if not input.type.is_block(): return self.splat(input, shape) src_shape = input.type.get_block_shapes() if len(src_shape) != len(shape): raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") if shape == src_shape: return input for i, item in enumerate(src_shape): if shape[i] != item and item != 1: raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" f" must match the existing size ({item}) at non-singleton dimension" f" {i}: {src_shape}, {shape}") ret_ty = tl.block_type(input.type.scalar, shape) return self.tensor(self.builder.create_broadcast(input.handle, shape), ret_ty) def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: lhs_ty = lhs.type rhs_ty = rhs.type # make_shape_compatible(block, scalar) if lhs_ty.is_block() and not rhs_ty.is_block(): rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar) rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty) # make_shape_compatible(scalar, block) elif not lhs_ty.is_block() and rhs_ty.is_block(): lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar) lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty) # make_shape_compatible(block, block) elif lhs_ty.is_block() and rhs_ty.is_block(): lhs_shape = lhs_ty.get_block_shapes() rhs_shape = rhs_ty.get_block_shapes() if len(lhs_shape) < len(rhs_shape): # Add new axes to lhs for _ in range(len(lhs_shape), len(rhs_shape)): lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for _ in range(len(rhs_shape), len(lhs_shape)): rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) ret_shape = [] for i, left in enumerate(lhs_shape): right = rhs_shape[i] if left == 1: ret_shape.append(right) elif (right == 1) or (right == left): ret_shape.append(left) else: raise ValueError("Cannot make_shape_compatible: incompatible dimensions " "at index " + str(i) + ": " + str(left) + " and " + str(right)) if lhs_shape != ret_shape: ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) lhs = self.tensor(self.builder.create_broadcast(lhs.handle, ret_shape), ret_ty) if rhs_shape != ret_shape: ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) rhs = self.tensor(self.builder.create_broadcast(rhs.handle, ret_shape), ret_ty) # (scalar, scalar) => returns original blocks return lhs, rhs ####### # cast ####### def _str_to_rounding_mode(self, rounding_mode: Optional[str]): if rounding_mode is None: return None if rounding_mode == 'rtne': return ir.ROUNDING_MODE.RTNE if rounding_mode == 'rtz': return ir.ROUNDING_MODE.RTZ raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy: src_ty = input.type if src_ty.is_block(): dst_ty = src_ty.with_element_ty(dst_ty.scalar) if src_ty == dst_ty: return input src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): return self.cast(input, dst_ty) # Bitcast src_bits = src_sca_ty.primitive_bitwidth dst_bits = dst_sca_ty.primitive_bitwidth if src_bits != dst_bits: raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " "data-type of size " + str(dst_bits)) return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy: src_ty = input.type src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar if src_sca_ty == dst_sca_ty: return input if src_ty.is_block(): dst_ty = src_ty.with_element_ty(dst_sca_ty) # For fp downcasting default rounding mode should be RTNE, for all other conversions it should # not be set fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding) use_custom_rounding = False if dst_sca_ty.is_floating() and src_sca_ty.is_floating( ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True else: if fp_downcast_rounding is not None: raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): assert self.builder.codegen_fns.get( "convert_custom_types") is not None, "target doesn't provide conversion for this type." return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _semantic=self) # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 # and non-default rounding modes for downcasting if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ use_custom_rounding: return self.tensor( self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty) # bf16 <=> (not fp32) if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): return self.cast(self.cast(input, tl.float32), dst_sca_ty) # Standard floating types' casting: truncation # fp64 => fp32, fp16, bf16 # fp32 => fp16, bf16 truncate_fp = src_sca_ty.is_floating() and \ dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth if truncate_fp: return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty) # Standard floating types' casting: extension # fp32 => fp64 # fp16 => fp32, fp64 # bf16 => fp32, fp64 ext_fp = src_sca_ty.is_floating() and \ dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth if ext_fp: return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty) # Casting between integer types if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() if dst_sca_ty.is_bool(): ty = input.dtype.to_ir(self.builder) _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) return self.not_equal(input, _0) else: return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend), dst_ty) # Casting standard floating types to integer types if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): if dst_sca_ty.is_bool(): ty = input.dtype.to_ir(self.builder) _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) return self.not_equal(input, _0) elif dst_sca_ty.is_int_signed(): return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty) else: return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty) # Casting integer types to standard floating types if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) else: return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) # Casting pointer types to integer types if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): bitwidth = dst_sca_ty.int_bitwidth if bitwidth == 64: return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty) if bitwidth == 1: return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64)) # Casting integer types to pointer types if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty) # Casting pointer types to pointer types if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) assert False, f'cannot cast {input} to {dst_ty}' # ===----------------------------------------------------------------------===// # Memory Operators # ===----------------------------------------------------------------------===// def _str_to_load_cache_modifier(self, cache_modifier): cache = ir.CACHE_MODIFIER.NONE # default if cache_modifier: if cache_modifier == ".ca": cache = ir.CACHE_MODIFIER.CA elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG elif cache_modifier == ".cv": cache = ir.CACHE_MODIFIER.CV else: raise ValueError(f"Cache modifier {cache_modifier} not supported") return cache def _str_to_store_cache_modifier(self, cache_modifier): cache = ir.CACHE_MODIFIER.NONE # default if cache_modifier: if cache_modifier == ".wb": cache = ir.CACHE_MODIFIER.WB elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG elif cache_modifier == ".cs": cache = ir.CACHE_MODIFIER.CS elif cache_modifier == ".wt": cache = ir.CACHE_MODIFIER.WT else: raise ValueError(f"Cache modifier {cache_modifier} not supported") return cache def _str_to_eviction_policy(self, eviction_policy): eviction = ir.EVICTION_POLICY.NORMAL # default if eviction_policy: if eviction_policy == "evict_last": eviction = ir.EVICTION_POLICY.EVICT_LAST elif eviction_policy == "evict_first": eviction = ir.EVICTION_POLICY.EVICT_FIRST else: raise ValueError(f"Eviction policy {eviction_policy} not supported") return eviction def _str_to_padding_option(self, padding_option): padding = None # default if padding_option: if padding_option == "zero": padding = ir.PADDING_OPTION.PAD_ZERO elif padding_option == "nan": padding = ir.PADDING_OPTION.PAD_NAN else: raise ValueError(f"Padding option {padding_option} not supported") return padding def _str_to_sem(self, sem_option): sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE if sem_option: if sem_option == "acquire": sem = ir.MEM_SEMANTIC.ACQUIRE elif sem_option == "release": sem = ir.MEM_SEMANTIC.RELEASE elif sem_option == "acq_rel": sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE elif sem_option == "relaxed": sem = ir.MEM_SEMANTIC.RELAXED else: raise ValueError(f"Memory semantic {sem_option} not supported") return sem def _str_to_scope(self, scope_option): scope = ir.MEM_SYNC_SCOPE.GPU if scope_option: if scope_option == "gpu": scope = ir.MEM_SYNC_SCOPE.GPU elif scope_option == "cta": scope = ir.MEM_SYNC_SCOPE.CTA elif scope_option == "sys": scope = ir.MEM_SYNC_SCOPE.SYSTEM else: raise ValueError(f"Memory semantic {scope_option} not supported") return scope def _canonicalize_boundary_check(self, boundary_check, block_shape): if boundary_check: if not hasattr(boundary_check, "__iter__"): boundary_check = [boundary_check] boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] for dim in boundary_check: assert isinstance(dim, int) and 0 <= dim < len(block_shape) assert len(boundary_check) > 0 assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" return sorted(boundary_check) return () def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): # Load by a block pointer: `pointer_type>` # Block pointer can not have `mask` and `other` arguments if mask is not None or other is not None: raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") elt_ty = ptr.type.element_ty.element_ty assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: raise ValueError("Padding option `nan` is not supported for integer block pointers") # `dst_ty` is de-referenced type of the pointer type dst_ty = ptr.type.element_ty # Check `boundary_check` argument boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) # Build IR return self.tensor( self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` if not ptr.type.scalar.is_ptr(): raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") # Check `mask`, `other`, `boundary_check`, and `padding` arguments if mask is None and other is not None: raise ValueError("`other` cannot be provided without `mask`") if padding or boundary_check: raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" "pointers or loading a scalar. Because the compiler does not know the boundary; please " "use block pointers (defined by `make_block_ptr`) instead") # For a pointer of scalar, check the type of `mask` and `other` if not ptr.type.is_block(): if mask and mask.type.is_block(): raise ValueError("Mask argument cannot be block type if pointer argument is not a block") if other and other.type.is_block(): raise ValueError("Other argument cannot be block type if pointer argument is not a block") # Make `mask` and `other` into the same shape as `ptr` if ptr.type.is_block(): if mask is not None: ptr, mask = self.broadcast_impl_value(ptr, mask) if other is not None: ptr, other = self.broadcast_impl_value(ptr, other) # Get `pointer_type` and `elt_ty` ptr_ty = ptr.type.scalar elt_ty = ptr_ty.element_ty # Treat `pointer_type` as `pointer_type` is_bool = elt_ty == tl.int1 if is_bool: elt_ty = tl.int8 ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) ptr = self.cast(ptr, ptr_ty) # Cast `other` into `elt_ty` type if other is not None: other = self.cast(other, elt_ty) # Create loaded result type `dst_ty` if ptr.type.is_block(): dst_ty = ptr.type.with_element_ty(elt_ty) else: # Load by de-referencing the pointer of scalar dst_ty = elt_ty # Build IR if mask is None: ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) else: ret = self.tensor( self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile), dst_ty) if is_bool: ret = self.cast(ret, tl.int1) return ret def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple, padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy: # Cache, eviction and padding options cache = self._str_to_load_cache_modifier(cache_modifier) eviction = self._str_to_eviction_policy(eviction_policy) padding = self._str_to_padding_option(padding_option) if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): # Load by a block pointer: `pointer_type>` return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) else: # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str, eviction_policy: str) -> TensorTy: assert isinstance(desc, tl.tensor_descriptor_base) ndim = len(desc.block_shape) assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" offsets = self._convert_to_ir_values(offsets, require_i64=False) x = self.builder.create_descriptor_load(desc.handle, offsets, self._str_to_load_cache_modifier(cache_modifier), self._str_to_eviction_policy(eviction_policy)) return self.tensor(x, desc.block_type) def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None: assert isinstance(desc, tl.tensor_descriptor_base) ndim = len(desc.block_shape) assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" assert value.shape == desc.block_shape def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) # implicitly cast to the descriptor's type value = self.cast(value, desc.dtype) offsets = self._convert_to_ir_values(offsets, require_i64=False) return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype" offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.ADD return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def _has_native_tma(self, ): target = driver.active.get_current_target() return (target.backend == "cuda" and target.arch >= 90) def _descriptor_atomic_min_max_supported(self, dtype): assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" if dtype in {tl.float16, tl.bfloat16}: assert self._has_native_tma(), "16-bit float types require native tma support" def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) self._descriptor_atomic_min_max_supported(desc.dtype) offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.MIN return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) self._descriptor_atomic_min_max_supported(desc.dtype) offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.MAX return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.AND return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.OR return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" offsets = self._convert_to_ir_values(offsets, require_i64=False) kind = ir.DESCRIPTOR_REDUCE_KIND.XOR return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy: assert isinstance(desc, tl.tensor_descriptor_base) assert cache_modifier == "", "cache modifier is not supported yet" assert eviction_policy == "", "eviction policy is not supported yet" # Validate descriptor. assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" # Validate offsets. assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}" # Validate minimum block size. assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}" dtype = desc.dtype min_cols = 32 // dtype.primitive_bitwidth * 8 assert desc.block_shape[ 1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]]) y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] x = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(self.builder)) return self.tensor(x, type) def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy: assert isinstance(desc, tl.tensor_descriptor_base) # Validate descriptor. assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" # Validate offsets. assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}" # Validate minimum block size. assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}" dtype = desc.dtype min_cols = 32 // dtype.primitive_bitwidth * 8 assert desc.block_shape[ 1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset) return self.tensor(None, tl.void) def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction): # Store by a block pointer: `pointer_type>` # Block pointers can not have the `mask` argument if mask is not None: raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") # Check same shape and element type block_shape = ptr.type.element_ty.get_block_shapes() if not val.type.is_block(): val = self.broadcast_impl_shape(val, block_shape) assert val.type.is_block(), "Value argument must be block type or a scalar" assert block_shape == val.type.get_block_shapes( ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" elt_ty = ptr.type.element_ty.element_ty assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" # Check `boundary_check` argument boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape) # Cast to target data type val = self.cast(val, elt_ty) # Build IR return self.tensor( self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void) def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction): # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` if not ptr.type.scalar.is_ptr(): raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") # Check `boundary_check` argument if boundary_check: raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " "scalar. Because the compiler does not know the boundary; please use block pointers " "(defined by `make_block_ptr`) instead") # For a pointer of scalar, check the type of `val` and `mask` if not ptr.type.is_block(): if val.type.is_block(): raise ValueError("Value argument cannot be block type if pointer argument is not a block") if mask and mask.type.is_block(): raise ValueError("Mask argument cannot be block type if pointer argument is not a block") # Make `mask` and `val` into the same shape as `ptr` if ptr.type.is_block(): val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes()) if mask is not None: mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) ptr_ty = ptr.type.scalar elt_ty = ptr_ty.element_ty # Treat `pointer_type` as `pointer_type` if elt_ty == tl.int1: elt_ty = tl.int8 ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) ptr = self.cast(ptr, ptr_ty) # Cast to target data type val = self.cast(val, elt_ty) # Build IR if mask is None: return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str, eviction_policy: str) -> TensorTy: # Cache and eviction options cache = self._str_to_store_cache_modifier(cache_modifier) eviction = self._str_to_eviction_policy(eviction_policy) if ptr.type.is_const() or ptr.type.scalar.is_const(): raise ValueError("Cannot store to a constant pointer") if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): # Store by a block pointer: `pointer_type>` return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction) else: # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction) ######### # atomic ######### def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy: sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) element_ty = ptr.type.scalar.element_ty if element_ty.primitive_bitwidth not in [16, 32, 64]: raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, op: str) -> Tuple[TensorTy, TensorTy, TensorTy]: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) if ptr.type.is_const() or ptr.type.element_ty.is_const(): raise ValueError("Cannot store to a constant pointer") element_ty = ptr.type.scalar.element_ty if element_ty is tl.float16 and op != 'add': raise ValueError("atomic_" + op + " does not support fp16") if element_ty is tl.bfloat16 and op != 'add': raise ValueError("atomic_" + op + " does not support bf16") if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16: raise ValueError("atomic_" + op + " does not support " + str(element_ty)) if ptr.type.is_block(): if mask is not None: mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) if val is not None: val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes()) val = self.cast(val, ptr.type.scalar.element_ty) if mask is None: mask_ir = self.builder.get_int1(True) mask_ty = tl.int1 if ptr.type.is_block(): mask_ty = ptr.type.with_element_ty(tl.int1) mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir) mask = self.tensor(mask_ir, mask_ty) return ptr, val, mask def _signbit(self, x: TensorTy) -> TensorTy: bitwidth = x.dtype.primitive_bitwidth idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False) ix = self.bitcast(x, idtype) signbit = self.lshr(ix, bitwidth - 1) return self.cast(signbit, tl.int1) def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) sca_ty = val.type.scalar # direct call to atomic_max for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 if sca_ty not in {tl.float32, tl.float64}: raise TypeError(f"atomic_max not supported for dtype {sca_ty}") i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 i_val = self.bitcast(val, i_type) i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 ui_val = self.bitcast(val, ui_type) ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) neg = self._signbit(val) pos = self.not_(neg) pos_ret = self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, self.and_(mask, pos).handle, sem, scope), i_val.type) neg_ret = self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, self.and_(mask, neg).handle, sem, scope), ui_val.type) ret = self.where(pos, pos_ret, neg_ret) return self.bitcast(ret, sca_ty) def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) sca_ty = val.type.scalar # direct call to atomic_min for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 if sca_ty not in {tl.float32, tl.float64}: raise TypeError(f"atomic_min not supported for dtype {sca_ty}") i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 i_val = self.bitcast(val, i_type) i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 ui_val = self.bitcast(val, ui_type) ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) neg = self._signbit(val) pos = self.not_(neg) pos_ret = self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, self.and_(mask, pos).handle, sem, scope), i_val.type) neg_ret = self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, self.and_(mask, neg).handle, sem, scope), ui_ptr.type) ret = self.where(pos, pos_ret, neg_ret) return self.bitcast(ret, sca_ty) def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) sca_ty = val.type.scalar op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type) def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg') sem = self._str_to_sem(sem) scope = self._str_to_scope(scope) return self.tensor( self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # ===----------------------------------------------------------------------===// # Linear Algebra # ===----------------------------------------------------------------------===// def _str_to_dot_input_precision(self, input_precision): assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \ f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}" input_precision = input_precision.upper() if input_precision == "TF32X3": input_precision = "TF32x3" return getattr(ir.INPUT_PRECISION, input_precision) def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy: assert lhs.type.is_block() and rhs.type.is_block() if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): # All combinations of supported fp8 x fp8 are permitted pass else: assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, tl.float64), f"Unsupported lhs dtype {lhs.dtype}" assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, tl.float64), f"Unsupported rhs dtype {rhs.dtype}" assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes: warnings.warn( "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release" ) # We upcast because there's no fp8e4b15 type in MLIR lhs = self.cast(lhs, tl.float16) rhs = self.cast(rhs, tl.float16) uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8() uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16() if uses_fp8e4b8 or uses_fp8e5b16: type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16" if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes: arch = self.builder.options.arch warnings.warn( f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. " f"Please use OCP fp8 variants on {arch} for performance") lhs = self.cast(lhs, tl.float16) rhs = self.cast(rhs, tl.float16) if input_precision is None: input_precision = self.builder.options.default_dot_input_precision input_precision = self._str_to_dot_input_precision(input_precision) lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" assert lhs.shape[-1].value == rhs.shape[ -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" assert self.builder.codegen_fns.get( "min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ and rhs.shape[-1].value >= min_dot_size[1], \ f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" if lhs.type.scalar.is_int(): assert lhs.type.scalar == tl.int8, "only int8 supported!" _0 = self.builder.get_int32(0) ret_scalar_ty = tl.int32 elif out_dtype.is_bf16(): raise ValueError( "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`" ) elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = self.builder.get_fp32(0) ret_scalar_ty = tl.float32 elif lhs.type.scalar.is_fp64(): _0 = self.builder.get_fp64(0) ret_scalar_ty = tl.float64 else: _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0) ret_scalar_ty = out_dtype M = lhs.type.shape[-2] N = rhs.type.shape[-1] K = lhs.type.shape[-1] B = lhs.type.shape[0] if lhs_rank == 3 else None ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) if acc is None: acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) else: acc_handle = acc.handle assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 if max_num_imprecise_acc is None: if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default else: max_num_imprecise_acc = 0 else: if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") return self.tensor( self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) def _str_to_fp_type(self, float_format: str): ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) if ty_enum is None: raise ValueError(f"Invalid float format: {float_format}.") return ty_enum def _bitcast_to_fp_type(self, val: TensorTy, float_format: str): """ If float_format is subbyte, make sure it's packed as uint8 and return it. Otherwise, return a tensor (perhaps bitcasting) of the specified float format. """ triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) if triton_ty is None: assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" return val if val.dtype == triton_ty: return val else: unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" return self.bitcast(val, triton_ty) def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy, rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool, lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" lhs_format: str = lhs_format.value rhs_format: str = rhs_format.value lhs_format_enum = self._str_to_fp_type(lhs_format) rhs_format_enum = self._str_to_fp_type(rhs_format) allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"} assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) lhs = self._bitcast_to_fp_type(lhs, lhs_format) rhs = self._bitcast_to_fp_type(rhs, rhs_format) assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" M, K_LHS = lhs.type.shape[-2:] K_RHS, N = rhs.type.shape[-2:] PACKED_A = 2 if lhs_format == "e2m1" else 1 PACKED_B = 2 if rhs_format == "e2m1" else 1 PACKED_A_DIM = PACKED_A * K_LHS if lhs_k_pack else K_LHS PACKED_B_DIM = PACKED_B * K_RHS if rhs_k_pack else K_RHS assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}" B = lhs.type.shape[0] if lhs_rank == 3 else None if not lhs_k_pack: M = M * PACKED_A if not rhs_k_pack: N = N * PACKED_B ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) _0 = self.builder.get_fp32(0) if acc is None: acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) else: acc_handle = acc.handle assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle return self.tensor( self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty) # ===----------------------------------------------------------------------===// # Indexing # ===----------------------------------------------------------------------===// def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy: if condition.dtype != tl.int1: warnings.warn( f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" ) condition = self.cast(condition, tl.int1) x, y = self.binary_op_type_checking_impl(x, y, True, True) # x, y are broadcasted if condition.type.is_block(): condition, x = self.broadcast_impl_value(condition, x) x, y = self.broadcast_impl_value(x, y) else: condition, _ = self.broadcast_impl_value(condition, x) ret_ty = x.type return self.tensor(self.builder.create_select(condition.handle, x.handle, y.handle), ret_ty) # ===----------------------------------------------------------------------===// # Reduction # ===----------------------------------------------------------------------=== def wrap_tensor(self, x, scalar_ty, ret_shape): if ret_shape: res_ty = tl.block_type(scalar_ty, ret_shape) else: # 0d-tensor -> scalar res_ty = scalar_ty return self.tensor(x, res_ty) def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]: if axis is None: inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs) axis = 0 # get result shape shape = inputs[0].type.shape rank = len(shape) assert axis < rank, f"reduction axis must be < inputs rank ({rank})" ret_shape = [s for i, s in enumerate(shape) if i != axis] assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis) region_builder_fn(reduce_op) assert reduce_op.verify() return tuple( self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) # ===----------------------------------------------------------------------=== # Associative Scan # ===----------------------------------------------------------------------=== def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn, reverse: bool) -> Tuple[TensorTy, ...]: shape = inputs[0].type.shape rank = len(shape) assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" if axis < 0: axis += rank for t in inputs: assert t.type.shape == shape, "all scan inputs must have the same shape" scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse) region_builder_fn(scan_op) assert scan_op.verify() return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) # ===----------------------------------------------------------------------=== # Gather # ===----------------------------------------------------------------------=== def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: assert index.dtype.is_int(), "index must be an integer tensor" rank = len(src.type.shape) assert len(index.type.shape) == rank, "source and index tensors must have the same rank" assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" if axis < 0: axis += rank for d in range(rank): if d == axis: continue assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" gather = self.builder.create_gather(src.handle, index.handle, axis) return self.wrap_tensor(gather, src.type.scalar, index.type.shape) # ===----------------------------------------------------------------------=== # Map Elementwise # ===----------------------------------------------------------------------=== def broadcast_tensors(self, *inputs): if not inputs: return () head, *tail = inputs for i in range(len(tail)): head, tail[i] = self.broadcast_impl_value(head, tail[i]) for i in range(len(tail)): head, tail[i] = self.broadcast_impl_value(head, tail[i]) return (head, *tail) def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int, region_builder_fn) -> Tuple[tl.tensor, ...]: inputs = self.broadcast_tensors(*inputs) assert len(inputs) > 0, "map_elementwise must have at least 1 input tensor" result_types = [inputs[0].type.with_element_ty(ty.scalar) for ty in result_types] elementwise_op = self.builder.create_map_elementwise( [t.handle for t in inputs], [ty.to_ir(self.builder) for ty in result_types], pack, ) region_builder_fn(elementwise_op) # assert elementwise_op.verify() return tuple(self.tensor(elementwise_op.get_result(i), ty) for i, ty in enumerate(result_types)) # ===----------------------------------------------------------------------=== # Histogram # ===----------------------------------------------------------------------=== def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy: assert len(input.shape) == 1, "histogram only supports 1D input" assert input.dtype.is_int(), "histogram only supports integer input" if mask is not None: mask = self.broadcast_impl_shape(mask, input.shape) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") mask = mask.handle return self.tensor(self.builder.create_histogram(input.handle, num_bins, mask), tl.block_type(tl.int32, [num_bins])) def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy: if max(1, len(x.shape)) != len(values): raise ValueError("Shape of input to multiple_of does not match the length of values") x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) return x def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy: if len(x.shape) != len(values): raise ValueError("Shape of input to max_contiguous does not match the length of values") x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) return x def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy: if len(x.shape) != len(values): raise ValueError("Shape of input to max_constancy does not match the length of values") x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) return x def debug_barrier(self) -> TensorTy: return self.tensor(self.builder.create_barrier(), tl.void) def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy: # It makes sense visually for prefix to end in ": "; make it so. Also, # non-empty prefixes should start with " ". if not prefix.endswith(" ") and args: prefix += " " if not prefix.endswith(": ") and args: prefix = prefix[:-1] + ": " if len(prefix) > 2 and not prefix.startswith(" "): prefix = " " + prefix new_args = [arg.handle for arg in args] is_signed = [arg.dtype.is_int_signed() for arg in args] return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void) def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy: if not self.builder.options.debug: return if mask is not None: cond = self.or_(cond, self.not_(mask)) return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void) def assume(self, cond) -> TensorTy: return self.tensor(self.builder.create_assume(cond.handle), tl.void) def _convert_elem_to_ir_value(self, elem, require_i64): if isinstance(elem, int): elem = tl.constexpr(elem) if isinstance(elem, tl.constexpr): if isinstance(elem.value, bool): return self.builder.get_int1(elem.value) if require_i64: assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ f"got a value {elem.value} which is out of the range" return self.builder.get_int64(elem.value) else: assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ f"got a value {elem.value} which is out of the range" return self.builder.get_int32(elem.value) elif isinstance(elem, tl.tensor): assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" if elem.dtype != tl.int64 and require_i64: return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(), elem.dtype.is_int_signed()) elif elem.dtype == tl.int64 and not require_i64: assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ "add a `.to(tl.int32)` or use regular indexing for 64 bit support" return elem.handle assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" def _convert_to_ir_values(self, list_like, require_i64=True): if hasattr(list_like, "__iter__"): return [self._convert_elem_to_ir_value(elem, require_i64) for elem in list_like] return [self._convert_elem_to_ir_value(list_like, require_i64)] def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy: # Convert dynamic arguments to IR values # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` shape = self._convert_to_ir_values(shape) strides = self._convert_to_ir_values(strides) offsets = self._convert_to_ir_values(offsets, require_i64=False) # Check `base` type if not base.type.is_ptr() or base.type.element_ty.is_block(): raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") # Treat `pointer_type` as `pointer_type` if base.type.element_ty == tl.int1: base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space)) # Check whether `block_shape` is static if not hasattr(block_shape, "__iter__"): block_shape = [block_shape] block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ "Expected a list of constant integers (`int32_t` range) in `block_shape`" # Check `order` if not hasattr(order, "__iter__"): order = [order] order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" # Must have same length assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ "Expected shape/strides/offsets/block_shape to have the same length" # Build value, the type is: # `pointer_type>` in Python # `tt.ptr>` in MLIR handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) def advance(self, base: TensorTy, offsets) -> TensorTy: # Convert dynamic offsets to IR values offsets = self._convert_to_ir_values(offsets, require_i64=False) # Advanced block pointer type is the same as before return self.tensor(self.builder.create_advance(base.handle, offsets), base.type) def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy], block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor: ndim = len(shape) if not (1 <= ndim <= 5): raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") if len(strides) != ndim: raise ValueError(f"Expected {ndim} strides but got {len(strides)}") if len(block_shape) != ndim: raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") assert isinstance(base.dtype, tl.pointer_type) elem_size = base.dtype.element_ty.primitive_bitwidth // 8 contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1]) if contig_dim_size * elem_size < 16: raise ValueError( f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" ) last_stride = tl._unwrap_if_constexpr(strides[-1]) if last_stride != 1: raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") shape = [self.make_scalar(x, tl.int32) for x in shape] strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides] # Check whether `block_shape` is static block_shape = tl._unwrap_shape(block_shape) assert isinstance(base.type, tl.pointer_type) type = tl.block_type(base.type.element_ty, block_shape) base_handle = base.handle is_signed_int = base.type.element_ty.is_int_signed() padding = self._str_to_padding_option(padding_option) if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: raise ValueError("Padding option `nan` is not supported for integer blocks") handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], [s.handle for s in strides], block_shape, is_signed_int, padding) return tl.tensor_descriptor(handle, shape, strides, type)