#pragma once #include #include namespace at::native { using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double); using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&); using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double); DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub) DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub) DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub) // TensorAccessor when it is defined to work around undefined... template static TensorAccessor conditional_accessor_1d(const Tensor& t) { if (! t.defined()) { return TensorAccessor(nullptr, nullptr, nullptr); } return t.accessor(); } template static scalar_t* conditional_data_ptr(const Tensor& t) { if constexpr (std::is_const_v) { return t.defined() ? t.contiguous().const_data_ptr() : nullptr; } else { return t.defined() ? t.contiguous().data_ptr() : nullptr; } } } // namespace at::native