// Copyright © 2022 Apple Inc.

#pragma once

#include <initializer_list>

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorIterator.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

@interface MPSGraph (PyTorchFixups)
- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor
                                                            secondaryTensor:(MPSGraphTensor*)secondaryTensor
                                                                       name:(NSString*)name;

- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor
                                                            secondaryTensor:(MPSGraphTensor*)secondaryTensor
                                                                       name:(NSString*)name;
@end

using namespace at::mps;

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
  id<MTLBuffer> getMTLBuffer() const {
    return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
  }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value{};
};

void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {
  return getMPSDataType(t.scalar_type());
}
MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const TensorBase& t) {
  return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) {
  return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const TensorBase& t) {
  return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);

MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format
Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
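// Illustrative sketch only (not part of the upstream API): a hypothetical helper showing how the
// conversion utilities above are typically combined to describe a tensor to Metal/MPS. The name
// `describeTensorForMPS_example` is an assumption made for demonstration.
static inline void describeTensorForMPS_example(const TensorBase& t) {
  MPSDataType dtype = getMPSDataType(t); // maps the tensor's ScalarType to an MPSDataType
  MPSShape* shape = getMPSShape(t); // sizes as an NSArray of NSNumbers
  id<MTLBuffer> buffer = getMTLBufferStorage(t); // Metal buffer backing the tensor's storage
  (void)dtype;
  (void)shape;
  (void)buffer;
}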
class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
  Placeholder(MPSGraphTensor* mpsGraphTensor,
              const Tensor& self,
              MPSShape* mpsShape = nullptr,
              bool gatherTensorData = true,
              MPSDataType dataType = MPSDataTypeInvalid,
              bool useMPSStridedAPI = true);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  bool isIntermediate() {
    return _value == nullptr;
  }

 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};

void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);

MPSGraph* make_mps_graph();

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);

std::string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = uint64_t;

struct MPSCachedKernel {
  MPSCachedKernel(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedKernel() {
    [_object release];
    _object = nullptr;
  }

  // Delete copy constructor and assignment
  MPSCachedKernel(const MPSCachedKernel&) = delete;
  void operator=(const MPSCachedKernel&) = delete;

  template <typename T>
  inline T* kernel() const {
    return (T*)_object;
  }

 private:
  NSObject* _object = nullptr;
};

// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph {
  MPSCachedGraph(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
    [_object release];
    _object = nullptr;
  }

  template <typename T>
  inline T* as() {
    return static_cast<T*>(this);
  }

  MPSGraph* graph() const {
    return (MPSGraph*)_object;
  }
  NSObject* object() const {
    return _object;
  }

 private:
  NSObject* _object = nullptr;
};

struct MPSUnaryCachedGraph : public MPSCachedGraph {
  MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
  MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor* gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph {
  MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};
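// Illustrative sketch only (not part of the upstream API): a hypothetical run helper showing how
// Placeholder objects are built from a cached graph's tensors and passed to runMPSGraph().
// `runUnaryGraph_example` and its argument layout are assumptions made for demonstration.
static inline void runUnaryGraph_example(MPSStream* stream,
                                         MPSUnaryCachedGraph* cachedGraph,
                                         const Tensor& self,
                                         const Tensor& output) {
  // Bind the PyTorch tensors to the graph's placeholder tensors
  Placeholder inputPlaceholder(cachedGraph->inputTensor_, self);
  Placeholder outputPlaceholder(cachedGraph->outputTensor_, output);
  NSDictionary* feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
  NSDictionary* results = @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
  // Execute the cached graph on the given stream
  runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}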
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
  MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* gradInputTensor_ = nil;
};

struct MPSKernelCache {
  typedef MPSCachedKernel* (^CreateCachedKernelBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {}
    MPSCachedKernel* cachedKernel_ = nullptr;
    std::string key_;
  };

 public:
  static MPSKernelCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSKernelCache();
    }
    return _instance_cache;
  }

  ~MPSKernelCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedKernel_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSKernelCache(const MPSKernelCache&) = delete;
  void operator=(const MPSKernelCache&) = delete;

  MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
        cachedKernel = entry.cachedKernel_;
      } else {
        cachedKernel = createCacheBlock();
        CacheEntry entry(key, cachedKernel);
        cache_.emplace(hash, entry);
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedKernel(key, createCacheBlock));
  }

  MPSCachedKernel* LookUp(const std::string& key) const {
    __block MPSCachedKernel* cachedKernel = nil;
    MPSCacheKey hash = std::hash<std::string>{}(key);
    dispatch_sync_with_rethrow(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n");
        cachedKernel = entry.cachedKernel_;
      }
    });
    return cachedKernel;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSKernelCache() {
    serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL);
  }

  static MPSKernelCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating cached kernel if missing
template <typename T>
inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function<NSObject*()> instantiate) {
  auto cache_ = MPSKernelCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedKernelAs<T>(key, ^mps::MPSCachedKernel*() {
    auto k_ = new mps::MPSCachedKernel(instantiate());
    return k_;
  });
}
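// Illustrative sketch only: caching an arbitrary Objective-C kernel object per string key with
// LookUpOrCreateCachedKernel. `getCachedKernel_example` and `makeKernelObject_example` are
// hypothetical names; the factory block is only invoked on a cache miss and the returned object
// is retained by the MPSCachedKernel wrapper.
static inline MPSCachedKernel* getCachedKernel_example(const std::string& key,
                                                       NSObject* (^makeKernelObject_example)()) {
  return LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() -> NSObject* {
    // Reached only when `key` is not yet in the kernel cache
    return makeKernelObject_example();
  });
}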
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache {
  typedef MPSCachedGraph* (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    std::string key_;
  };

 public:
  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  ~MPSGraphCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // verify the cached entry doesn't already exist
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T*>(CreateCachedGraph(key, createCacheBlock));
  }

  MPSCachedGraph* LookUp(const std::string& key) const {
    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template <typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T*>(LookUp(key));
  }

 private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // this is defined in OperationUtils.mm to not include
  // MPSProfiler.h in header OperationUtils.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating graph with a specified cache if missing
template <typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize graph
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}
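// Illustrative sketch only: the typical build-once-then-reuse pattern for a unary graph. The key
// construction and the ReLU body below are assumptions for demonstration, not a specific upstream op;
// `buildReluGraph_example` is a hypothetical name.
static inline MPSUnaryCachedGraph* buildReluGraph_example(const Tensor& self) {
  @autoreleasepool {
    // Key the cache on the op name plus the tensors' dtypes/shapes
    std::string key = "relu_example" + getTensorsStringKey({self});
    return LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](MPSGraph* mpsGraph, MPSUnaryCachedGraph* newCachedGraph) {
      // Only runs on a cache miss: declare the graph's inputs and outputs
      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
      newCachedGraph->outputTensor_ = [mpsGraph reLUWithTensor:newCachedGraph->inputTensor_ name:nil];
    });
  }
}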
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

/**
 * Returns distance from lowest to highest element offset in given tensor.
 */
size_t compute_storage_numel_distance(const TensorBase& t);

/**
 * Checks whether tensor is mapped to a contiguous area in the storage.
 */
inline bool is_dense_in_storage(const TensorBase& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

template <typename encoder_t,
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
  if (C10_UNLIKELY(t.device().type() == kCPU)) {
    if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
      // MPS does not support doubles, silently downcast CPU scalar to float
      if (C10_UNLIKELY(t.scalar_type() == kDouble)) {
        auto val = static_cast<float>(*reinterpret_cast<const double*>(t.const_data_ptr()));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) {
        auto val = static_cast<c10::complex<float>>(*reinterpret_cast<const c10::complex<double>*>(t.const_data_ptr()));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
    } else {
      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
    }
    return;
  }
  [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}

// Implementation of setBytes for containers vs trivially copiable types must be separate
// Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`,
// while trivially copiable, includes padding which if copied as Metal shader parameters
// might overwrite other values
template <typename T,
          typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float> ||
                                      (std::is_class_v<T> && std::is_trivially_copyable_v<T> &&
                                       !detail::has_size_type_v<T>)>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
  [encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

template <typename Container, typename = std::enable_if_t<detail::has_size_type_v<Container>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
  [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const MPSScalar& s, unsigned idx) {
  [encoder setBytes:&s.value length:s.size atIndex:idx];
}

static size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) {
  // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
  // But TensorIterator constructs data_ptr as if base was just a raw pointer
  // Workaround this problem by computing an offset from the start of the tensor, which works for both
  // tensor views and sliced 64-bit iterators
  return reinterpret_cast<size_t>(iter.data_ptr(idx)) -
      reinterpret_cast<size_t>(iter.tensor_base(idx).storage().data());
}

static inline void bind_iter_tensors(id<MTLComputeCommandEncoder> encoder,
                                     TensorIteratorBase& iter,
                                     std::optional<size_t> ntensors = std::nullopt) {
  for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) {
    auto& t = iter.tensor_base(idx);
    // Handle CPU scalars
    if (C10_UNLIKELY(t.device().type() == kCPU)) {
      mtl_setBuffer(encoder, t, idx);
      continue;
    }
    auto offs = iter_tensor_offset(iter, idx);
    [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
  }
}
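// Illustrative sketch only: binding a tensor and a small POD parameter to a compute encoder with the
// helpers above. `bindArgs_example` and `numel_example` are hypothetical names; the buffer indices
// are assumptions for demonstration.
static inline void bindArgs_example(id<MTLComputeCommandEncoder> encoder, const Tensor& t, uint32_t numel_example) {
  mtl_setBuffer(encoder, t, /*idx=*/0); // binds the tensor's MTLBuffer at its storage offset
  mtl_setBytes(encoder, numel_example, /*idx=*/1); // uploads a trivially copyable scalar inline
}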
namespace detail {
template <typename T>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) {
  mtl_setBytes(encoder, val, idx);
}

inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> val, unsigned idx) {
  [encoder setBuffer:val offset:0 atIndex:idx];
}

template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const Tensor& val, unsigned idx) {
  mtl_setBuffer(encoder, val, idx);
}

template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const std::optional<Tensor>& val, unsigned idx) {
  if (val.has_value()) {
    mtl_setBuffer(encoder, val.value(), idx);
  }
}

template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const TensorBase& val, unsigned idx) {
  mtl_setBuffer(encoder, val, idx);
}

// MPS does not support doubles, so cast it down to float before passing as an argument
template <>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const double& val, unsigned idx) {
  float val_f = static_cast<float>(val);
  mtl_setBytes(encoder, val_f, idx);
}
} // namespace detail

template <unsigned idx = 0, typename T>
static inline void mtl_setArgs(id<MTLComputeCommandEncoder> encoder, const T& val) {
  detail::mtl_setArg(encoder, val, idx);
}

template <unsigned idx = 0, typename T, typename... Args>
static inline void mtl_setArgs(id<MTLComputeCommandEncoder> encoder, const T& val, Args&&... args) {
  detail::mtl_setArg(encoder, val, idx);
  mtl_setArgs<idx + 1>(encoder, std::forward<Args>(args)...);
}

static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                     id<MTLComputePipelineState> cplState,
                                     NSUInteger length) {
  static_assert(sizeof(NSUInteger) == sizeof(uint64_t));
  const auto maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
  auto size = MTLSizeMake(length, 1, 1);
  auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(),
  };
}

inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}

// MPS yet to support double types, but starting from MacOS 14, supports bfloat16
inline bool supportedFloatingType(ScalarType dtype) {
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
}

inline bool supportedFloatingType(const TensorBase& t) {
  return supportedFloatingType(t.scalar_type());
}

inline bool supportedFloatingOrComplexType(ScalarType dtype) {
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
    return true;
  }
  return supportedFloatingType(dtype);
}
inline bool supportedFloatingOrComplexType(const TensorBase& t) {
  return supportedFloatingOrComplexType(t.scalar_type());
}

inline bool needsGather(const TensorBase& t) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}

} // namespace at::native::mps
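// Illustrative sketch only: encoding a hypothetical elementwise kernel with the variadic
// mtl_setArgs helper and mtl_dispatch1DJob. `dispatchUnaryKernel_example` is a made-up name and
// `pipelineState` is assumed to be a previously compiled compute pipeline; the argument order
// (output at index 0, input at index 1) is an assumption for demonstration.
static inline void dispatchUnaryKernel_example(id<MTLComputeCommandEncoder> encoder,
                                               id<MTLComputePipelineState> pipelineState,
                                               const at::Tensor& input,
                                               const at::Tensor& output) {
  using namespace at::native::mps;
  [encoder setComputePipelineState:pipelineState];
  mtl_setArgs(encoder, output, input); // binds the tensors' buffers at indices 0 and 1
  mtl_dispatch1DJob(encoder, pipelineState, output.numel()); // one thread per output element
}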