// numpy.h revision 11986:c12e4625ab56
/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/
9
10#pragma once
11
12#include "pybind11.h"
13#include "complex.h"
14#include <numeric>
15#include <algorithm>
16#include <array>
17#include <cstdlib>
18#include <cstring>
19#include <sstream>
20#include <string>
21#include <initializer_list>
22#include <functional>
23#include <utility>
24#include <typeindex>
25
26#if defined(_MSC_VER)
27#  pragma warning(push)
28#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
29#endif
30
31/* This will be true on all flat address space platforms and allows us to reduce the
32   whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
33   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
34   upon the library user. */
35static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
36
37NAMESPACE_BEGIN(pybind11)
38NAMESPACE_BEGIN(detail)
39template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
40template <typename type> struct is_pod_struct;
41
42struct PyArrayDescr_Proxy {
43    PyObject_HEAD
44    PyObject *typeobj;
45    char kind;
46    char type;
47    char byteorder;
48    char flags;
49    int type_num;
50    int elsize;
51    int alignment;
52    char *subarray;
53    PyObject *fields;
54    PyObject *names;
55};
56
57struct PyArray_Proxy {
58    PyObject_HEAD
59    char *data;
60    int nd;
61    ssize_t *dimensions;
62    ssize_t *strides;
63    PyObject *base;
64    PyObject *descr;
65    int flags;
66};
67
68struct PyVoidScalarObject_Proxy {
69    PyObject_VAR_HEAD
70    char *obval;
71    PyArrayDescr_Proxy *descr;
72    int flags;
73    PyObject *base;
74};
75
76struct numpy_type_info {
77    PyObject* dtype_ptr;
78    std::string format_str;
79};
80
81struct numpy_internals {
82    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
83
84    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
85        auto it = registered_dtypes.find(std::type_index(tinfo));
86        if (it != registered_dtypes.end())
87            return &(it->second);
88        if (throw_if_missing)
89            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
90        return nullptr;
91    }
92
93    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
94        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
95    }
96};
97
98inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
99    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
100}
101
102inline numpy_internals& get_numpy_internals() {
103    static numpy_internals* ptr = nullptr;
104    if (!ptr)
105        load_numpy_internals(ptr);
106    return *ptr;
107}
108
109struct npy_api {
110    enum constants {
111        NPY_C_CONTIGUOUS_ = 0x0001,
112        NPY_F_CONTIGUOUS_ = 0x0002,
113        NPY_ARRAY_OWNDATA_ = 0x0004,
114        NPY_ARRAY_FORCECAST_ = 0x0010,
115        NPY_ENSURE_ARRAY_ = 0x0040,
116        NPY_ARRAY_ALIGNED_ = 0x0100,
117        NPY_ARRAY_WRITEABLE_ = 0x0400,
118        NPY_BOOL_ = 0,
119        NPY_BYTE_, NPY_UBYTE_,
120        NPY_SHORT_, NPY_USHORT_,
121        NPY_INT_, NPY_UINT_,
122        NPY_LONG_, NPY_ULONG_,
123        NPY_LONGLONG_, NPY_ULONGLONG_,
124        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
125        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
126        NPY_OBJECT_ = 17,
127        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
128    };
129
130    static npy_api& get() {
131        static npy_api api = lookup();
132        return api;
133    }
134
135    bool PyArray_Check_(PyObject *obj) const {
136        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
137    }
138    bool PyArrayDescr_Check_(PyObject *obj) const {
139        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
140    }
141
142    PyObject *(*PyArray_DescrFromType_)(int);
143    PyObject *(*PyArray_NewFromDescr_)
144        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
145         Py_intptr_t *, void *, int, PyObject *);
146    PyObject *(*PyArray_DescrNewFromType_)(int);
147    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
148    PyTypeObject *PyArray_Type_;
149    PyTypeObject *PyVoidArrType_Type_;
150    PyTypeObject *PyArrayDescr_Type_;
151    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
152    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
153    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
154    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
155    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
156                                             Py_ssize_t *, PyObject **, PyObject *);
157    PyObject *(*PyArray_Squeeze_)(PyObject *);
158private:
159    enum functions {
160        API_PyArray_Type = 2,
161        API_PyArrayDescr_Type = 3,
162        API_PyVoidArrType_Type = 39,
163        API_PyArray_DescrFromType = 45,
164        API_PyArray_DescrFromScalar = 57,
165        API_PyArray_FromAny = 69,
166        API_PyArray_NewCopy = 85,
167        API_PyArray_NewFromDescr = 94,
168        API_PyArray_DescrNewFromType = 9,
169        API_PyArray_DescrConverter = 174,
170        API_PyArray_EquivTypes = 182,
171        API_PyArray_GetArrayParamsFromObject = 278,
172        API_PyArray_Squeeze = 136
173    };
174
175    static npy_api lookup() {
176        module m = module::import("numpy.core.multiarray");
177        auto c = m.attr("_ARRAY_API");
178#if PY_MAJOR_VERSION >= 3
179        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
180#else
181        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
182#endif
183        npy_api api;
184#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
185        DECL_NPY_API(PyArray_Type);
186        DECL_NPY_API(PyVoidArrType_Type);
187        DECL_NPY_API(PyArrayDescr_Type);
188        DECL_NPY_API(PyArray_DescrFromType);
189        DECL_NPY_API(PyArray_DescrFromScalar);
190        DECL_NPY_API(PyArray_FromAny);
191        DECL_NPY_API(PyArray_NewCopy);
192        DECL_NPY_API(PyArray_NewFromDescr);
193        DECL_NPY_API(PyArray_DescrNewFromType);
194        DECL_NPY_API(PyArray_DescrConverter);
195        DECL_NPY_API(PyArray_EquivTypes);
196        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
197        DECL_NPY_API(PyArray_Squeeze);
198#undef DECL_NPY_API
199        return api;
200    }
201};
202
203inline PyArray_Proxy* array_proxy(void* ptr) {
204    return reinterpret_cast<PyArray_Proxy*>(ptr);
205}
206
207inline const PyArray_Proxy* array_proxy(const void* ptr) {
208    return reinterpret_cast<const PyArray_Proxy*>(ptr);
209}
210
211inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
212   return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
213}
214
215inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
216   return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
217}
218
219inline bool check_flags(const void* ptr, int flag) {
220    return (flag == (array_proxy(ptr)->flags & flag));
221}
222
223NAMESPACE_END(detail)
224
225class dtype : public object {
226public:
227    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
228
229    explicit dtype(const buffer_info &info) {
230        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
231        // If info.itemsize == 0, use the value calculated from the format string
232        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
233    }
234
235    explicit dtype(const std::string &format) {
236        m_ptr = from_args(pybind11::str(format)).release().ptr();
237    }
238
239    dtype(const char *format) : dtype(std::string(format)) { }
240
241    dtype(list names, list formats, list offsets, size_t itemsize) {
242        dict args;
243        args["names"] = names;
244        args["formats"] = formats;
245        args["offsets"] = offsets;
246        args["itemsize"] = pybind11::int_(itemsize);
247        m_ptr = from_args(args).release().ptr();
248    }
249
250    /// This is essentially the same as calling numpy.dtype(args) in Python.
251    static dtype from_args(object args) {
252        PyObject *ptr = nullptr;
253        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
254            throw error_already_set();
255        return reinterpret_steal<dtype>(ptr);
256    }
257
258    /// Return dtype associated with a C++ type.
259    template <typename T> static dtype of() {
260        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
261    }
262
263    /// Size of the data type in bytes.
264    size_t itemsize() const {
265        return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
266    }
267
268    /// Returns true for structured data types.
269    bool has_fields() const {
270        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
271    }
272
273    /// Single-character type code.
274    char kind() const {
275        return detail::array_descriptor_proxy(m_ptr)->kind;
276    }
277
278private:
279    static object _dtype_from_pep3118() {
280        static PyObject *obj = module::import("numpy.core._internal")
281            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
282        return reinterpret_borrow<object>(obj);
283    }
284
285    dtype strip_padding(size_t itemsize) {
286        // Recursively strip all void fields with empty names that are generated for
287        // padding fields (as of NumPy v1.11).
288        if (!has_fields())
289            return *this;
290
291        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
292        std::vector<field_descr> field_descriptors;
293
294        for (auto field : attr("fields").attr("items")()) {
295            auto spec = field.cast<tuple>();
296            auto name = spec[0].cast<pybind11::str>();
297            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
298            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
299            if (!len(name) && format.kind() == 'V')
300                continue;
301            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
302        }
303
304        std::sort(field_descriptors.begin(), field_descriptors.end(),
305                  [](const field_descr& a, const field_descr& b) {
306                      return a.offset.cast<int>() < b.offset.cast<int>();
307                  });
308
309        list names, formats, offsets;
310        for (auto& descr : field_descriptors) {
311            names.append(descr.name);
312            formats.append(descr.format);
313            offsets.append(descr.offset);
314        }
315        return dtype(names, formats, offsets, itemsize);
316    }
317};
318
319class array : public buffer {
320public:
321    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322
323    enum {
324        c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325        f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327    };
328
329    array() : array(0, static_cast<const double *>(nullptr)) {}
330
331    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
332          const std::vector<size_t> &strides, const void *ptr = nullptr,
333          handle base = handle()) {
334        auto& api = detail::npy_api::get();
335        auto ndim = shape.size();
336        if (shape.size() != strides.size())
337            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
338        auto descr = dt;
339
340        int flags = 0;
341        if (base && ptr) {
342            if (isinstance<array>(base))
343                /* Copy flags from base (except baseship bit) */
344                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
345            else
346                /* Writable by default, easy to downgrade later on if needed */
347                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
348        }
349
350        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
351            api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
352            (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr));
353        if (!tmp)
354            pybind11_fail("NumPy: unable to create array!");
355        if (ptr) {
356            if (base) {
357                detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
358            } else {
359                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
360            }
361        }
362        m_ptr = tmp.release().ptr();
363    }
364
365    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
366          const void *ptr = nullptr, handle base = handle())
367        : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368
369    array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
370          handle base = handle())
371        : array(dt, std::vector<size_t>{ count }, ptr, base) { }
372
373    template<typename T> array(const std::vector<size_t>& shape,
374                               const std::vector<size_t>& strides,
375                               const T* ptr, handle base = handle())
376    : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377
378    template <typename T>
379    array(const std::vector<size_t> &shape, const T *ptr,
380          handle base = handle())
381        : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
382
383    template <typename T>
384    array(size_t count, const T *ptr, handle base = handle())
385        : array(std::vector<size_t>{ count }, ptr, base) { }
386
387    explicit array(const buffer_info &info)
388    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389
390    /// Array descriptor (dtype)
391    pybind11::dtype dtype() const {
392        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
393    }
394
395    /// Total number of elements
396    size_t size() const {
397        return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
398    }
399
400    /// Byte size of a single element
401    size_t itemsize() const {
402        return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
403    }
404
405    /// Total number of bytes
406    size_t nbytes() const {
407        return size() * itemsize();
408    }
409
410    /// Number of dimensions
411    size_t ndim() const {
412        return (size_t) detail::array_proxy(m_ptr)->nd;
413    }
414
415    /// Base object
416    object base() const {
417        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
418    }
419
420    /// Dimensions of the array
421    const size_t* shape() const {
422        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
423    }
424
425    /// Dimension along a given axis
426    size_t shape(size_t dim) const {
427        if (dim >= ndim())
428            fail_dim_check(dim, "invalid axis");
429        return shape()[dim];
430    }
431
432    /// Strides of the array
433    const size_t* strides() const {
434        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
435    }
436
437    /// Stride along a given axis
438    size_t strides(size_t dim) const {
439        if (dim >= ndim())
440            fail_dim_check(dim, "invalid axis");
441        return strides()[dim];
442    }
443
444    /// Return the NumPy array flags
445    int flags() const {
446        return detail::array_proxy(m_ptr)->flags;
447    }
448
449    /// If set, the array is writeable (otherwise the buffer is read-only)
450    bool writeable() const {
451        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
452    }
453
454    /// If set, the array owns the data (will be freed when the array is deleted)
455    bool owndata() const {
456        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
457    }
458
459    /// Pointer to the contained data. If index is not provided, points to the
460    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
461    template<typename... Ix> const void* data(Ix... index) const {
462        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
463    }
464
465    /// Mutable pointer to the contained data. If index is not provided, points to the
466    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
467    /// May throw if the array is not writeable.
468    template<typename... Ix> void* mutable_data(Ix... index) {
469        check_writeable();
470        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
471    }
472
473    /// Byte offset from beginning of the array to a given index (full or partial).
474    /// May throw if the index would lead to out of bounds access.
475    template<typename... Ix> size_t offset_at(Ix... index) const {
476        if (sizeof...(index) > ndim())
477            fail_dim_check(sizeof...(index), "too many indices for an array");
478        return byte_offset(size_t(index)...);
479    }
480
481    size_t offset_at() const { return 0; }
482
483    /// Item count from beginning of the array to a given index (full or partial).
484    /// May throw if the index would lead to out of bounds access.
485    template<typename... Ix> size_t index_at(Ix... index) const {
486        return offset_at(index...) / itemsize();
487    }
488
489    /// Return a new view with all of the dimensions of length 1 removed
490    array squeeze() {
491        auto& api = detail::npy_api::get();
492        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
493    }
494
495    /// Ensure that the argument is a NumPy array
496    /// In case of an error, nullptr is returned and the Python error is cleared.
497    static array ensure(handle h, int ExtraFlags = 0) {
498        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
499        if (!result)
500            PyErr_Clear();
501        return result;
502    }
503
504protected:
505    template<typename, typename> friend struct detail::npy_format_descriptor;
506
507    void fail_dim_check(size_t dim, const std::string& msg) const {
508        throw index_error(msg + ": " + std::to_string(dim) +
509                          " (ndim = " + std::to_string(ndim()) + ")");
510    }
511
512    template<typename... Ix> size_t byte_offset(Ix... index) const {
513        check_dimensions(index...);
514        return byte_offset_unsafe(index...);
515    }
516
517    template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
518        return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
519    }
520
521    template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
522
523    void check_writeable() const {
524        if (!writeable())
525            throw std::runtime_error("array is not writeable");
526    }
527
528    static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
529        auto ndim = shape.size();
530        std::vector<size_t> strides(ndim);
531        if (ndim) {
532            std::fill(strides.begin(), strides.end(), itemsize);
533            for (size_t i = 0; i < ndim - 1; i++)
534                for (size_t j = 0; j < ndim - 1 - i; j++)
535                    strides[j] *= shape[ndim - 1 - i];
536        }
537        return strides;
538    }
539
540    template<typename... Ix> void check_dimensions(Ix... index) const {
541        check_dimensions_impl(size_t(0), shape(), size_t(index)...);
542    }
543
544    void check_dimensions_impl(size_t, const size_t*) const { }
545
546    template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
547        if (i >= *shape) {
548            throw index_error(std::string("index ") + std::to_string(i) +
549                              " is out of bounds for axis " + std::to_string(axis) +
550                              " with size " + std::to_string(*shape));
551        }
552        check_dimensions_impl(axis + 1, shape + 1, index...);
553    }
554
555    /// Create array from any object -- always returns a new reference
556    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
557        if (ptr == nullptr)
558            return nullptr;
559        return detail::npy_api::get().PyArray_FromAny_(
560            ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
561    }
562};
563
564template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
565public:
566    array_t() : array(0, static_cast<const T *>(nullptr)) {}
567    array_t(handle h, borrowed_t) : array(h, borrowed) { }
568    array_t(handle h, stolen_t) : array(h, stolen) { }
569
570    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
571    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
572        if (!m_ptr) PyErr_Clear();
573        if (!is_borrowed) Py_XDECREF(h.ptr());
574    }
575
576    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
577        if (!m_ptr) throw error_already_set();
578    }
579
580    explicit array_t(const buffer_info& info) : array(info) { }
581
582    array_t(const std::vector<size_t> &shape,
583            const std::vector<size_t> &strides, const T *ptr = nullptr,
584            handle base = handle())
585        : array(shape, strides, ptr, base) { }
586
587    explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
588            handle base = handle())
589        : array(shape, ptr, base) { }
590
591    explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
592        : array(count, ptr, base) { }
593
594    constexpr size_t itemsize() const {
595        return sizeof(T);
596    }
597
598    template<typename... Ix> size_t index_at(Ix... index) const {
599        return offset_at(index...) / itemsize();
600    }
601
602    template<typename... Ix> const T* data(Ix... index) const {
603        return static_cast<const T*>(array::data(index...));
604    }
605
606    template<typename... Ix> T* mutable_data(Ix... index) {
607        return static_cast<T*>(array::mutable_data(index...));
608    }
609
610    // Reference to element at a given index
611    template<typename... Ix> const T& at(Ix... index) const {
612        if (sizeof...(index) != ndim())
613            fail_dim_check(sizeof...(index), "index dimension mismatch");
614        return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
615    }
616
617    // Mutable reference to element at a given index
618    template<typename... Ix> T& mutable_at(Ix... index) {
619        if (sizeof...(index) != ndim())
620            fail_dim_check(sizeof...(index), "index dimension mismatch");
621        return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
622    }
623
624    /// Ensure that the argument is a NumPy array of the correct dtype.
625    /// In case of an error, nullptr is returned and the Python error is cleared.
626    static array_t ensure(handle h) {
627        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
628        if (!result)
629            PyErr_Clear();
630        return result;
631    }
632
633    static bool _check(handle h) {
634        const auto &api = detail::npy_api::get();
635        return api.PyArray_Check_(h.ptr())
636               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
637    }
638
639protected:
640    /// Create array from any object -- always returns a new reference
641    static PyObject *raw_array_t(PyObject *ptr) {
642        if (ptr == nullptr)
643            return nullptr;
644        return detail::npy_api::get().PyArray_FromAny_(
645            ptr, dtype::of<T>().release().ptr(), 0, 0,
646            detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
647    }
648};
649
650template <typename T>
651struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
652    static std::string format() {
653        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
654    }
655};
656
657template <size_t N> struct format_descriptor<char[N]> {
658    static std::string format() { return std::to_string(N) + "s"; }
659};
660template <size_t N> struct format_descriptor<std::array<char, N>> {
661    static std::string format() { return std::to_string(N) + "s"; }
662};
663
664template <typename T>
665struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
666    static std::string format() {
667        return format_descriptor<
668            typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
669    }
670};
671
672NAMESPACE_BEGIN(detail)
673template <typename T, int ExtraFlags>
674struct pyobject_caster<array_t<T, ExtraFlags>> {
675    using type = array_t<T, ExtraFlags>;
676
677    bool load(handle src, bool /* convert */) {
678        value = type::ensure(src);
679        return static_cast<bool>(value);
680    }
681
682    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
683        return src.inc_ref();
684    }
685    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
686};
687
688template <typename T> struct is_std_array : std::false_type { };
689template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
690
691template <typename T>
692struct is_pod_struct {
693    enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
694           !std::is_reference<T>::value &&
695           !std::is_array<T>::value &&
696           !is_std_array<T>::value &&
697           !std::is_integral<T>::value &&
698           !std::is_enum<T>::value &&
699           !std::is_same<typename std::remove_cv<T>::type, float>::value &&
700           !std::is_same<typename std::remove_cv<T>::type, double>::value &&
701           !std::is_same<typename std::remove_cv<T>::type, bool>::value &&
702           !std::is_same<typename std::remove_cv<T>::type, std::complex<float>>::value &&
703           !std::is_same<typename std::remove_cv<T>::type, std::complex<double>>::value };
704};
705
706template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_integral<T>::value>> {
707private:
708    constexpr static const int values[8] = {
709        npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
710        npy_api::NPY_INT_,  npy_api::NPY_UINT_,  npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
711public:
712    enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
713    static pybind11::dtype dtype() {
714        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
715            return reinterpret_borrow<pybind11::dtype>(ptr);
716        pybind11_fail("Unsupported buffer format!");
717    }
718    template <typename T2 = T, enable_if_t<std::is_signed<T2>::value, int> = 0>
719    static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
720    template <typename T2 = T, enable_if_t<!std::is_signed<T2>::value, int> = 0>
721    static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
722};
723template <typename T> constexpr const int npy_format_descriptor<
724    T, enable_if_t<std::is_integral<T>::value>>::values[8];
725
726#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
727    enum { value = npy_api::NumPyName }; \
728    static pybind11::dtype dtype() { \
729        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
730            return reinterpret_borrow<pybind11::dtype>(ptr); \
731        pybind11_fail("Unsupported buffer format!"); \
732    } \
733    static PYBIND11_DESCR name() { return _(Name); } }
734DECL_FMT(float, NPY_FLOAT_, "float32");
735DECL_FMT(double, NPY_DOUBLE_, "float64");
736DECL_FMT(bool, NPY_BOOL_, "bool");
737DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
738DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
739#undef DECL_FMT
740
741#define DECL_CHAR_FMT \
742    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
743    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
744template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
745template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
746#undef DECL_CHAR_FMT
747
748template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
749private:
750    using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
751public:
752    static PYBIND11_DESCR name() { return base_descr::name(); }
753    static pybind11::dtype dtype() { return base_descr::dtype(); }
754};
755
756struct field_descriptor {
757    const char *name;
758    size_t offset;
759    size_t size;
760    size_t alignment;
761    std::string format;
762    dtype descr;
763};
764
765inline PYBIND11_NOINLINE void register_structured_dtype(
766    const std::initializer_list<field_descriptor>& fields,
767    const std::type_info& tinfo, size_t itemsize,
768    bool (*direct_converter)(PyObject *, void *&)) {
769
770    auto& numpy_internals = get_numpy_internals();
771    if (numpy_internals.get_type_info(tinfo, false))
772        pybind11_fail("NumPy: dtype is already registered");
773
774    list names, formats, offsets;
775    for (auto field : fields) {
776        if (!field.descr)
777            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
778                            field.name + "` @ " + tinfo.name());
779        names.append(PYBIND11_STR_TYPE(field.name));
780        formats.append(field.descr);
781        offsets.append(pybind11::int_(field.offset));
782    }
783    auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
784
785    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
786    // not encoded explicitly into the format string. This will supposedly
787    // get fixed in v1.12; for further details, see these:
788    // - https://github.com/numpy/numpy/issues/7797
789    // - https://github.com/numpy/numpy/pull/7798
790    // Because of this, we won't use numpy's logic to generate buffer format
791    // strings and will just do it ourselves.
792    std::vector<field_descriptor> ordered_fields(fields);
793    std::sort(ordered_fields.begin(), ordered_fields.end(),
794        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
795    size_t offset = 0;
796    std::ostringstream oss;
797    oss << "T{";
798    for (auto& field : ordered_fields) {
799        if (field.offset > offset)
800            oss << (field.offset - offset) << 'x';
801        // mark unaligned fields with '='
802        if (field.offset % field.alignment)
803            oss << '=';
804        oss << field.format << ':' << field.name << ':';
805        offset = field.offset + field.size;
806    }
807    if (itemsize > offset)
808        oss << (itemsize - offset) << 'x';
809    oss << '}';
810    auto format_str = oss.str();
811
812    // Sanity check: verify that NumPy properly parses our buffer format string
813    auto& api = npy_api::get();
814    auto arr =  array(buffer_info(nullptr, itemsize, format_str, 1));
815    if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
816        pybind11_fail("NumPy: invalid buffer descriptor!");
817
818    auto tindex = std::type_index(tinfo);
819    numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
820    get_internals().direct_conversions[tindex].push_back(direct_converter);
821}
822
823template <typename T>
824struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
825    static PYBIND11_DESCR name() { return _("struct"); }
826
827    static pybind11::dtype dtype() {
828        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
829    }
830
831    static std::string format() {
832        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
833        return format_str;
834    }
835
836    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
837        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
838                                  sizeof(T), &direct_converter);
839    }
840
841private:
842    static PyObject* dtype_ptr() {
843        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
844        return ptr;
845    }
846
847    static bool direct_converter(PyObject *obj, void*& value) {
848        auto& api = npy_api::get();
849        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
850            return false;
851        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
852            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
853                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
854                return true;
855            }
856        }
857        return false;
858    }
859};
860
// Builds a ::pybind11::detail::field_descriptor for data member `Field` of struct `T`,
// exported under the string `Name`. Its entries are, in order:
//   - name;
//   - offsetof(T, Field);
//   - size and alignment of the member's declared type;
//   - its buffer format string (format_descriptor);
//   - its NumPy dtype (npy_format_descriptor).
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        alignof(decltype(std::declval<T>().Field)),                                           \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract the name, offset, size, alignment, format descriptor and dtype for a struct
// field. The field is exported under its C++ member name (stringized via #Field).
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
871
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
//
// PYBIND11_EVAL forces its argument to be rescanned 3^5 = 243 times (nested EVAL0
// layers). Each rescan advances the deferred "recursion" below by one step, so the
// number of scans bounds how long a list can be mapped.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// Swallows the remaining arguments once the `()` end sentinel has been reached.
#define PYBIND11_MAP_END(...)
// Empty macro placed between a macro name and its argument list. The call is then
// not performed in the current scan and is deferred to the next EVAL rescan.
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// End detection: `PYBIND11_MAP_GET_END peek` only expands when `peek` is the `()`
// sentinel. It then yields an extra leading argument (0, PYBIND11_MAP_END), shifting
// PYBIND11_MAP_NEXT0's parameters so that `next` becomes PYBIND11_MAP_END.
// Otherwise the caller-supplied `next` macro is selected.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
// List variant: like PYBIND11_MAP_NEXT, but a comma is emitted before the next step.
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// One mapping step: emit f(t, x), then (unless `peek` is the sentinel) continue with
// the other step macro. LIST0 and LIST1 alternate, so neither macro appears inside
// its own expansion.
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
// The trailing `(), 0` is the end sentinel consumed by the end detection above.
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
903
// Register the NumPy structured dtype for `Type` from a list of its data member names.
// For example, PYBIND11_NUMPY_DTYPE(MyStruct, a, b) passes
// {PYBIND11_FIELD_DESCRIPTOR(MyStruct, a), PYBIND11_FIELD_DESCRIPTOR(MyStruct, b)}
// to npy_format_descriptor<Type>::register_dtype. Each member is exported under its
// C++ name.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
907
// Two-at-a-time variant of the PYBIND11_MAP_LIST machinery: each step consumes a pair
// (x1, x2) and emits f(t, x1, x2). The argument list must therefore contain an even
// number of items. End detection and deferral work exactly as for PYBIND11_MAP_LIST.
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// `peek` is the first item of the next pair; it is forwarded as the next step's x1.
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))
924
// Like PYBIND11_NUMPY_DTYPE, but the arguments come in (member, name) pairs. Each member
// is exported under the given name via PYBIND11_FIELD_DESCRIPTOR_EX. For example:
// PYBIND11_NUMPY_DTYPE_EX(MyStruct, a, "x", b, "y").
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
928
929template  <class T>
930using array_iterator = typename std::add_pointer<T>::type;
931
932template <class T>
933array_iterator<T> array_begin(const buffer_info& buffer) {
934    return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
935}
936
937template <class T>
938array_iterator<T> array_end(const buffer_info& buffer) {
939    return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
940}
941
/// Byte-level cursor over one (possibly broadcast) buffer, driven by an outer
/// multi-index iterator.
///
/// The constructor turns per-axis byte strides into "carry" strides. m_strides[d] is
/// the byte offset to add when axis d advances by one while every later axis wraps
/// from its last index back to 0. increment(d) applies exactly that offset.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// @param ptr      Address of the buffer's first element.
    /// @param strides  Per-axis byte strides (0 on broadcast axes).
    /// @param shape    Iteration shape; must have strides.size() entries.
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // 0-d case: there is no axis to advance, and back() on an empty vector is UB.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            // Step along axis j, undo the full sweep of axis i, and keep axis i's own carry.
            // This can wrap below zero (e.g. stride 0 on a broadcast axis) and is stored
            // modulo 2^N.
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    void increment(size_type dim) {
        // Carry strides may encode negative offsets in unsigned form. Convert to the
        // signed difference type so the pointer moves backwards instead of adding a
        // huge out-of-range unsigned value.
        p_ptr += static_cast<container_type::difference_type>(m_strides[dim]);
    }

    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};
972
973template <size_t N> class multi_array_iterator {
974public:
975    using container_type = std::vector<size_t>;
976
977    multi_array_iterator(const std::array<buffer_info, N> &buffers,
978                         const std::vector<size_t> &shape)
979        : m_shape(shape.size()), m_index(shape.size(), 0),
980          m_common_iterator() {
981
982        // Manual copy to avoid conversion warning if using std::copy
983        for (size_t i = 0; i < shape.size(); ++i)
984            m_shape[i] = static_cast<container_type::value_type>(shape[i]);
985
986        container_type strides(shape.size());
987        for (size_t i = 0; i < N; ++i)
988            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
989    }
990
991    multi_array_iterator& operator++() {
992        for (size_t j = m_index.size(); j != 0; --j) {
993            size_t i = j - 1;
994            if (++m_index[i] != m_shape[i]) {
995                increment_common_iterator(i);
996                break;
997            } else {
998                m_index[i] = 0;
999            }
1000        }
1001        return *this;
1002    }
1003
1004    template <size_t K, class T> const T& data() const {
1005        return *reinterpret_cast<T*>(m_common_iterator[K].data());
1006    }
1007
1008private:
1009
1010    using common_iter = common_iterator;
1011
1012    void init_common_iterator(const buffer_info &buffer,
1013                              const std::vector<size_t> &shape,
1014                              common_iter &iterator, container_type &strides) {
1015        auto buffer_shape_iter = buffer.shape.rbegin();
1016        auto buffer_strides_iter = buffer.strides.rbegin();
1017        auto shape_iter = shape.rbegin();
1018        auto strides_iter = strides.rbegin();
1019
1020        while (buffer_shape_iter != buffer.shape.rend()) {
1021            if (*shape_iter == *buffer_shape_iter)
1022                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
1023            else
1024                *strides_iter = 0;
1025
1026            ++buffer_shape_iter;
1027            ++buffer_strides_iter;
1028            ++shape_iter;
1029            ++strides_iter;
1030        }
1031
1032        std::fill(strides_iter, strides.rend(), 0);
1033        iterator = common_iter(buffer.ptr, strides, shape);
1034    }
1035
1036    void increment_common_iterator(size_t dim) {
1037        for (auto &iter : m_common_iterator)
1038            iter.increment(dim);
1039    }
1040
1041    container_type m_shape;
1042    container_type m_index;
1043    std::array<common_iter, N> m_common_iterator;
1044};
1045
1046template <size_t N>
1047bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
1048    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
1049        return std::max(res, buf.ndim);
1050    });
1051
1052    shape = std::vector<size_t>(ndim, 1);
1053    bool trivial_broadcast = true;
1054    for (size_t i = 0; i < N; ++i) {
1055        auto res_iter = shape.rbegin();
1056        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
1057        for (auto shape_iter = buffers[i].shape.rbegin();
1058             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {
1059
1060            if (*res_iter == 1)
1061                *res_iter = *shape_iter;
1062            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
1063                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1064
1065            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
1066        }
1067        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
1068    }
1069    return trivial_broadcast;
1070}
1071
1072template <typename Func, typename Return, typename... Args>
1073struct vectorize_helper {
1074    typename std::remove_reference<Func>::type f;
1075
1076    template <typename T>
1077    explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
1078
1079    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1080        return run(args..., make_index_sequence<sizeof...(Args)>());
1081    }
1082
1083    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1084        /* Request buffers from all parameters */
1085        const size_t N = sizeof...(Args);
1086
1087        std::array<buffer_info, N> buffers {{ args.request()... }};
1088
1089        /* Determine dimensions parameters of output array */
1090        size_t ndim = 0;
1091        std::vector<size_t> shape(0);
1092        bool trivial_broadcast = broadcast(buffers, ndim, shape);
1093
1094        size_t size = 1;
1095        std::vector<size_t> strides(ndim);
1096        if (ndim > 0) {
1097            strides[ndim-1] = sizeof(Return);
1098            for (size_t i = ndim - 1; i > 0; --i) {
1099                strides[i - 1] = strides[i] * shape[i];
1100                size *= shape[i];
1101            }
1102            size *= shape[0];
1103        }
1104
1105        if (size == 1)
1106            return cast(f(*((Args *) buffers[Index].ptr)...));
1107
1108        array_t<Return> result(shape, strides);
1109        auto buf = result.request();
1110        auto output = (Return *) buf.ptr;
1111
1112        if (trivial_broadcast) {
1113            /* Call the function */
1114            for (size_t i = 0; i < size; ++i) {
1115                output[i] = f((buffers[Index].size == 1
1116                               ? *((Args *) buffers[Index].ptr)
1117                               : ((Args *) buffers[Index].ptr)[i])...);
1118            }
1119        } else {
1120            apply_broadcast<N, Index...>(buffers, buf, index);
1121        }
1122
1123        return result;
1124    }
1125
1126    template <size_t N, size_t... Index>
1127    void apply_broadcast(const std::array<buffer_info, N> &buffers,
1128                         buffer_info &output, index_sequence<Index...>) {
1129        using input_iterator = multi_array_iterator<N>;
1130        using output_iterator = array_iterator<Return>;
1131
1132        input_iterator input_iter(buffers, output.shape);
1133        output_iterator output_end = array_end<Return>(output);
1134
1135        for (output_iterator iter = array_begin<Return>(output);
1136             iter != output_end; ++iter, ++input_iter) {
1137            *iter = f((input_iter.template data<Index, Args>())...);
1138        }
1139    }
1140};
1141
// Type name for array_t<T, Flags>: "numpy.ndarray[" + the element type caster's name + "]".
// The Flags template parameter does not affect the rendered name.
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() { return _("numpy.ndarray[") + type_caster<T>::name() + _("]"); }
};
1145
1146NAMESPACE_END(detail)
1147
1148template <typename Func, typename Return, typename... Args>
1149detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
1150    return detail::vectorize_helper<Func, Return, Args...>(f);
1151}
1152
1153template <typename Return, typename... Args>
1154detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
1155    return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
1156}
1157
// Overload for callables with an operator() (lambdas, functors).
//
// The call signature is taken from `&Func::operator()`. detail::remove_class presumably
// maps that member-function-pointer type to a plain function type. A null pointer to
// that function type is passed only to select the (const Func &, Return (*)(Args...))
// overload.
//
// Taking `&operator()` requires a single, non-template call operator (so no generic
// lambdas).
template <typename Func>
auto vectorize(Func &&f) -> decltype(
        vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type *) nullptr)) {
    return vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(
                   &std::remove_reference<Func>::type::operator())>::type *) nullptr);
}
1164
1165NAMESPACE_END(pybind11)
1166
1167#if defined(_MSC_VER)
1168#pragma warning(pop)
1169#endif
1170