#pragma once #include #ifdef __METAL__ enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic }; enum class GridSamplerPadding { Zeros, Border, Reflection }; #else #include using at::native::GridSamplerInterpolation; using at::native::GridSamplerPadding; #endif template struct GridSamplerParams { int32_t sampler_dims; ::c10::metal::array output_sizes; ::c10::metal::array output_strides; ::c10::metal::array input_sizes; ::c10::metal::array input_strides; ::c10::metal::array grid_sizes; ::c10::metal::array grid_strides; GridSamplerInterpolation interpolation_mode; GridSamplerPadding padding_mode; bool align_corners; };