#pragma once #include #include #include #include #if AT_KLEIDIAI_ENABLED() namespace at::native::kleidiai { template void kai_pack_rhs_groupwise_int4( T& kernel, const Tensor& weight_packed, const Tensor& weight, const Tensor& scales, const std::optional& bias, const int64_t n, const int64_t k, const int64_t bl, const int64_t rhs_stride, const int64_t scale_stride) { const auto& ukernel = kernel.ukernel; const size_t nr = ukernel.get_nr(); const size_t kr = ukernel.get_kr(); const size_t sr = ukernel.get_sr(); auto weight_packed_data = reinterpret_cast(weight_packed.data_ptr()); const auto weight_data = weight.data_ptr(); auto scales_data = scales.const_data_ptr(); if (weight_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); } if (scales_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( /*num_groups=*/1, n, k, nr, kr, sr, bl, (const uint8_t*)(weight_data), rhs_stride, bias_ptr, scales_data, scale_stride, weight_packed_data, 0, ¶ms); } template void kai_pack_rhs_channelwise_int4( T& kernel, const Tensor& weight_packed, const Tensor& weight, const Tensor& scales, const std::optional& bias, const int64_t n, const int64_t k) { const auto& ukernel = kernel.ukernel; const size_t nr = ukernel.get_nr(); const size_t kr = ukernel.get_kr(); const size_t sr = ukernel.get_sr(); auto weight_packed_data = reinterpret_cast(weight_packed.data_ptr()); const auto weight_data = weight.data_ptr(); const auto scales_data = scales.data_ptr(); if (weight_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); } if (scales_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( /*num_groups=*/1, n, k, nr, kr, sr, (const uint8_t*)(weight_data), (const float*)(bias_ptr), (const float*)(scales_data), weight_packed_data, 0, ¶ms); } } // namespace at::native::kleidiai #endif