// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // Internal utilities for dealing with NumPy #pragma once #include "arrow/python/numpy_init.h" #include "arrow/python/numpy_interop.h" #include "arrow/status.h" #include "arrow/python/platform.h" #include #include #include namespace arrow { namespace py { /// Indexing convenience for interacting with strided 1-dim ndarray objects template class Ndarray1DIndexer { public: typedef int64_t size_type; Ndarray1DIndexer() : arr_(NULLPTR), data_(NULLPTR) {} explicit Ndarray1DIndexer(PyArrayObject* arr) : Ndarray1DIndexer() { arr_ = arr; ARROW_DCHECK_EQ(1, PyArray_NDIM(arr)) << "Only works with 1-dimensional arrays"; data_ = reinterpret_cast(PyArray_DATA(arr)); stride_ = PyArray_STRIDES(arr)[0]; } ~Ndarray1DIndexer() = default; int64_t size() const { return PyArray_SIZE(arr_); } const T* data() const { return reinterpret_cast(data_); } bool is_strided() const { return stride_ != sizeof(T); } T& operator[](size_type index) { return *reinterpret_cast(data_ + index * stride_); } const T& operator[](size_type index) const { return *reinterpret_cast(data_ + index * stride_); } private: PyArrayObject* arr_; uint8_t* data_; int64_t stride_; }; // Handling of Numpy Types by their static numbers // (the NPY_TYPES enum and related defines) static inline std::string GetNumPyTypeName(int npy_type) { #define TYPE_CASE(TYPE, NAME) \ case NPY_##TYPE: \ return NAME; switch (npy_type) { TYPE_CASE(BOOL, "bool") TYPE_CASE(INT8, "int8") TYPE_CASE(INT16, "int16") TYPE_CASE(INT32, "int32") TYPE_CASE(INT64, "int64") #if !NPY_INT32_IS_INT TYPE_CASE(INT, "intc") #endif #if !NPY_INT64_IS_LONG_LONG TYPE_CASE(LONGLONG, "longlong") #endif TYPE_CASE(UINT8, "uint8") TYPE_CASE(UINT16, "uint16") TYPE_CASE(UINT32, "uint32") TYPE_CASE(UINT64, "uint64") #if !NPY_INT32_IS_INT TYPE_CASE(UINT, "uintc") #endif #if !NPY_INT64_IS_LONG_LONG TYPE_CASE(ULONGLONG, "ulonglong") #endif TYPE_CASE(FLOAT16, "float16") TYPE_CASE(FLOAT32, "float32") TYPE_CASE(FLOAT64, "float64") TYPE_CASE(DATETIME, "datetime64") TYPE_CASE(TIMEDELTA, "timedelta64") TYPE_CASE(OBJECT, "object") TYPE_CASE(VOID, "void") default: break; } #undef TYPE_CASE std::stringstream ss; ss << "unrecognized type (" << npy_type << ") in GetNumPyTypeName"; return ss.str(); } #define TYPE_VISIT_INLINE(TYPE) \ case NPY_##TYPE: \ return visitor->template Visit(arr); template inline Status VisitNumpyArrayInline(PyArrayObject* arr, VISITOR* visitor) { switch (PyArray_TYPE(arr)) { TYPE_VISIT_INLINE(BOOL); TYPE_VISIT_INLINE(INT8); TYPE_VISIT_INLINE(UINT8); TYPE_VISIT_INLINE(INT16); TYPE_VISIT_INLINE(UINT16); TYPE_VISIT_INLINE(INT32); TYPE_VISIT_INLINE(UINT32); TYPE_VISIT_INLINE(INT64); TYPE_VISIT_INLINE(UINT64); #if !NPY_INT32_IS_INT TYPE_VISIT_INLINE(INT); TYPE_VISIT_INLINE(UINT); #endif #if !NPY_INT64_IS_LONG_LONG TYPE_VISIT_INLINE(LONGLONG); TYPE_VISIT_INLINE(ULONGLONG); #endif TYPE_VISIT_INLINE(FLOAT16); TYPE_VISIT_INLINE(FLOAT32); TYPE_VISIT_INLINE(FLOAT64); TYPE_VISIT_INLINE(DATETIME); TYPE_VISIT_INLINE(TIMEDELTA); TYPE_VISIT_INLINE(OBJECT); } return Status::NotImplemented("NumPy type not implemented: ", GetNumPyTypeName(PyArray_TYPE(arr))); } #undef TYPE_VISIT_INLINE namespace internal { inline bool PyFloatScalar_Check(PyObject* obj) { if (has_numpy()) { return PyFloat_Check(obj) || PyArray_IsScalar(obj, Floating); } else { return PyFloat_Check(obj); } } inline bool PyIntScalar_Check(PyObject* obj) { if (has_numpy()) { return PyLong_Check(obj) || PyArray_IsScalar(obj, Integer); } else { return PyLong_Check(obj); } } inline bool PyBoolScalar_Check(PyObject* obj) { if (has_numpy()) { return PyBool_Check(obj) || PyArray_IsScalar(obj, Bool); } else { return PyBool_Check(obj); } } static inline PyArray_Descr* GetSafeNumPyDtype(int type) { if (type == NPY_DATETIME || type == NPY_TIMEDELTA) { // It is not safe to mutate the result of DescrFromType for datetime and // timedelta descriptors return PyArray_DescrNewFromType(type); } else { return PyArray_DescrFromType(type); } } } // namespace internal } // namespace py } // namespace arrow