#pragma once #ifdef __OBJC__ #include typedef id MTLLibrary_t; typedef id MTLFunction_t; typedef id MTLComputePipelineState_t; typedef id MTLComputeCommandEncoder_t; #else typedef void MTLCompileOptions; typedef void* MTLLibrary_t; typedef void* MTLFunction_t; typedef void* MTLComputePipelineState_t; typedef void* MTLComputeCommandEncoder_t; #endif #include #include #include #include #include #include #include #include // Forward declaration of TensorBase and TensorIteratorBase namespace at { class TensorBase; struct TensorIteratorBase; } // namespace at namespace at::native::mps { namespace detail { template class has_size_type { template static constexpr std::true_type check(typename U::size_type*); template static constexpr std::false_type check(...); public: static constexpr bool value = decltype(check(nullptr))::value; }; template constexpr bool has_size_type_v = has_size_type::value; } // namespace detail // Returns `gpuAddress` of respective `id` plus storage offset void* get_tensor_gpu_address(const at::TensorBase&); class MetalKernelFunction { public: MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_); ~MetalKernelFunction(); MetalKernelFunction(MetalKernelFunction&) = delete; // Shader properties uint64_t getMaxThreadsPerThreadgroup() const; uint64_t getThreadExecutionWidth() const; uint64_t getStaticThreadGroupMemoryLength() const; void runCommandBlock(std::function f); // Methods below should be called from runCommandBlock function void startEncoding(); void setArg(unsigned idx, const at::TensorBase& t); void setArg(unsigned idx, const void* ptr, uint64_t size); template < typename T, typename = std::enable_if_t< std::is_integral_v || std::is_same_v || (std::is_class_v && std::is_trivially_copyable_v && !detail::has_size_type_v)>> inline void setArg(unsigned idx, const T val) { setArg(idx, &val, sizeof(T)); } template < typename Container, typename = std::enable_if_t>> inline void setArg(unsigned idx, const Container& values) { setArg( idx, values.data(), values.size() * sizeof(typename Container::value_type)); } void dispatch( uint64_t length, std::optional groupSize = std::nullopt); void dispatch( c10::ArrayRef length, c10::OptionalArrayRef groupSize = std::nullopt); private: MTLComputePipelineState_t cps; MTLFunction_t func; MTLComputeCommandEncoder_t encoder = nullptr; }; class MetalShaderLibrary { public: MetalShaderLibrary(std::string src) : shaderSource(std::move(src)), nparams(0), compile_options(nullptr) {} MetalShaderLibrary(std::string src, unsigned nparams_) : shaderSource(std::move(src)), nparams(nparams_), compile_options(nullptr) {} MetalShaderLibrary( std::string src, unsigned nparams_, MTLCompileOptions* compile_options_) : shaderSource(std::move(src)), nparams(nparams_), compile_options(compile_options_) {} MetalShaderLibrary(const MetalShaderLibrary&) = delete; virtual ~MetalShaderLibrary(); std::vector getFunctionNames(); std::shared_ptr getKernelFunction( const std::string& name); inline MTLComputePipelineState_t getPipelineStateForFunc( const std::string& fname) { return getLibraryPipelineState(getLibrary(), fname).first; } MTLComputePipelineState_t getPipelineStateForFunc( const std::string& fname, const std::initializer_list& params) { return getLibraryPipelineState(getLibrary(params), fname).first; } inline MTLFunction_t getMTLFunction(const std::string& fname) { return getLibraryPipelineState(getLibrary(), fname).second; } MTLFunction_t getMTLFunction( const std::string& fname, const std::initializer_list& params) { return getLibraryPipelineState(getLibrary(params), fname).second; } static MetalShaderLibrary& getBundledLibrary(); void exec_unary_kernel( TensorIteratorBase& iter, const std::string& name, const std::optional alpha = std::nullopt, const std::optional scalar_arg_type = std::nullopt); void exec_binary_kernel( TensorIteratorBase& iter, const std::string& name, const std::optional alpha = std::nullopt, const std::optional scalar_arg_type = std::nullopt); protected: virtual MTLLibrary_t getLibrary(); virtual MTLLibrary_t getLibrary( const std::initializer_list& params); MTLLibrary_t library = nullptr; private: std::pair getLibraryPipelineState( MTLLibrary_t lib, const std::string& fname); MTLLibrary_t compileLibrary(const std::string& src); std::string shaderSource; unsigned nparams; MTLCompileOptions* compile_options; std::unordered_map libMap; std::unordered_map< std::string, std::pair> cplMap; }; class DynamicMetalShaderLibrary : public MetalShaderLibrary { public: DynamicMetalShaderLibrary(const std::string& src) : MetalShaderLibrary(src) { // Compile right away getLibrary(); } ~DynamicMetalShaderLibrary() override; }; } // namespace at::native::mps