#pragma once #include namespace at::native::mps { void _fused_adam_mps_impl_( TensorList params, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList state_steps, const double lr, const double beta1, const double beta2, const double weight_decay, const double eps, const bool maximize, const std::optional& grad_scale, const std::optional& found_inf); void _fused_adam_mps_impl_( TensorList params, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList state_steps, const Tensor& lr, const double beta1, const double beta2, const double weight_decay, const double eps, const bool maximize, const std::optional& grad_scale, const std::optional& found_inf); } // namespace at::native::mps