#pragma once #include // N is the maximum allowed number of dimensions in the input and outputs. The // maximum allowed pooling dimensions is N-2, because the input may have up to 2 // leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is // the default. template struct PoolingParams { int32_t dims; int32_t pooling_dims; ::c10::metal::array input_sizes; ::c10::metal::array input_strides; ::c10::metal::array output_sizes; ::c10::metal::array output_strides; ::c10::metal::array indices_sizes; ::c10::metal::array indices_strides; ::c10::metal::array kernel_size; ::c10::metal::array stride; ::c10::metal::array padding; ::c10::metal::array dilation; bool return_indices; }; template struct AvgPoolingParams { int32_t dims; int32_t pooling_dims; ::c10::metal::array input_sizes; ::c10::metal::array input_strides; ::c10::metal::array output_sizes; ::c10::metal::array output_strides; ::c10::metal::array kernel_size; ::c10::metal::array stride; ::c10::metal::array padding; bool count_include_pad; bool has_divisor_override; int32_t divisor_override; }; template struct PoolingBackwardParams { int32_t dims; int32_t pooling_dims; ::c10::metal::array grad_input_sizes; ::c10::metal::array grad_input_strides; ::c10::metal::array grad_output_sizes; ::c10::metal::array grad_output_strides; ::c10::metal::array indices_strides; }; template struct MaxUnpoolingParams { int32_t dims; int32_t pooling_dims; ::c10::metal::array input_sizes; ::c10::metal::array input_strides; ::c10::metal::array output_sizes; ::c10::metal::array output_strides; ::c10::metal::array indices_strides; };