#pragma once #include #include namespace at::cuda::detail { TORCH_API void f8f8bf16_grouped_mm( at::Tensor mat_a, // FP8 at::Tensor mat_b, // FP8 at::Tensor scale_a, // FP32 at::Tensor scale_b, // FP32 std::optional offs, std::optional bias, // BF16 bool use_fast_accum, at::Tensor& out); } // namespace at::cuda::detail