// numpy.h revision 12037:d28054ac6ec9
1/* 2 pybind11/numpy.h: Basic NumPy support, vectorize() wrapper 3 4 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 6 All rights reserved. Use of this source code is governed by a 7 BSD-style license that can be found in the LICENSE file. 8*/ 9 10#pragma once 11 12#include "pybind11.h" 13#include "complex.h" 14#include <numeric> 15#include <algorithm> 16#include <array> 17#include <cstdlib> 18#include <cstring> 19#include <sstream> 20#include <string> 21#include <initializer_list> 22#include <functional> 23#include <utility> 24#include <typeindex> 25 26#if defined(_MSC_VER) 27# pragma warning(push) 28# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 29#endif 30 31/* This will be true on all flat address space platforms and allows us to reduce the 32 whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size 33 and dimension types (e.g. shape, strides, indexing), instead of inflicting this 34 upon the library user. */ 35static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); 36 37NAMESPACE_BEGIN(pybind11) 38 39class array; // Forward declaration 40 41NAMESPACE_BEGIN(detail) 42template <typename type, typename SFINAE = void> struct npy_format_descriptor; 43 44struct PyArrayDescr_Proxy { 45 PyObject_HEAD 46 PyObject *typeobj; 47 char kind; 48 char type; 49 char byteorder; 50 char flags; 51 int type_num; 52 int elsize; 53 int alignment; 54 char *subarray; 55 PyObject *fields; 56 PyObject *names; 57}; 58 59struct PyArray_Proxy { 60 PyObject_HEAD 61 char *data; 62 int nd; 63 ssize_t *dimensions; 64 ssize_t *strides; 65 PyObject *base; 66 PyObject *descr; 67 int flags; 68}; 69 70struct PyVoidScalarObject_Proxy { 71 PyObject_VAR_HEAD 72 char *obval; 73 PyArrayDescr_Proxy *descr; 74 int flags; 75 PyObject *base; 76}; 77 78struct numpy_type_info { 79 PyObject* dtype_ptr; 80 std::string format_str; 81}; 82 83struct numpy_internals { 84 std::unordered_map<std::type_index, numpy_type_info> 
registered_dtypes; 85 86 numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { 87 auto it = registered_dtypes.find(std::type_index(tinfo)); 88 if (it != registered_dtypes.end()) 89 return &(it->second); 90 if (throw_if_missing) 91 pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); 92 return nullptr; 93 } 94 95 template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) { 96 return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing); 97 } 98}; 99 100inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { 101 ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals"); 102} 103 104inline numpy_internals& get_numpy_internals() { 105 static numpy_internals* ptr = nullptr; 106 if (!ptr) 107 load_numpy_internals(ptr); 108 return *ptr; 109} 110 111struct npy_api { 112 enum constants { 113 NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, 114 NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, 115 NPY_ARRAY_OWNDATA_ = 0x0004, 116 NPY_ARRAY_FORCECAST_ = 0x0010, 117 NPY_ARRAY_ENSUREARRAY_ = 0x0040, 118 NPY_ARRAY_ALIGNED_ = 0x0100, 119 NPY_ARRAY_WRITEABLE_ = 0x0400, 120 NPY_BOOL_ = 0, 121 NPY_BYTE_, NPY_UBYTE_, 122 NPY_SHORT_, NPY_USHORT_, 123 NPY_INT_, NPY_UINT_, 124 NPY_LONG_, NPY_ULONG_, 125 NPY_LONGLONG_, NPY_ULONGLONG_, 126 NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, 127 NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, 128 NPY_OBJECT_ = 17, 129 NPY_STRING_, NPY_UNICODE_, NPY_VOID_ 130 }; 131 132 static npy_api& get() { 133 static npy_api api = lookup(); 134 return api; 135 } 136 137 bool PyArray_Check_(PyObject *obj) const { 138 return (bool) PyObject_TypeCheck(obj, PyArray_Type_); 139 } 140 bool PyArrayDescr_Check_(PyObject *obj) const { 141 return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); 142 } 143 144 PyObject *(*PyArray_DescrFromType_)(int); 145 PyObject *(*PyArray_NewFromDescr_) 146 (PyTypeObject *, PyObject *, int, Py_intptr_t *, 147 Py_intptr_t *, void *, int, 
PyObject *); 148 PyObject *(*PyArray_DescrNewFromType_)(int); 149 PyObject *(*PyArray_NewCopy_)(PyObject *, int); 150 PyTypeObject *PyArray_Type_; 151 PyTypeObject *PyVoidArrType_Type_; 152 PyTypeObject *PyArrayDescr_Type_; 153 PyObject *(*PyArray_DescrFromScalar_)(PyObject *); 154 PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); 155 int (*PyArray_DescrConverter_) (PyObject *, PyObject **); 156 bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); 157 int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, 158 Py_ssize_t *, PyObject **, PyObject *); 159 PyObject *(*PyArray_Squeeze_)(PyObject *); 160 int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); 161private: 162 enum functions { 163 API_PyArray_Type = 2, 164 API_PyArrayDescr_Type = 3, 165 API_PyVoidArrType_Type = 39, 166 API_PyArray_DescrFromType = 45, 167 API_PyArray_DescrFromScalar = 57, 168 API_PyArray_FromAny = 69, 169 API_PyArray_NewCopy = 85, 170 API_PyArray_NewFromDescr = 94, 171 API_PyArray_DescrNewFromType = 9, 172 API_PyArray_DescrConverter = 174, 173 API_PyArray_EquivTypes = 182, 174 API_PyArray_GetArrayParamsFromObject = 278, 175 API_PyArray_Squeeze = 136, 176 API_PyArray_SetBaseObject = 282 177 }; 178 179 static npy_api lookup() { 180 module m = module::import("numpy.core.multiarray"); 181 auto c = m.attr("_ARRAY_API"); 182#if PY_MAJOR_VERSION >= 3 183 void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); 184#else 185 void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr()); 186#endif 187 npy_api api; 188#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; 189 DECL_NPY_API(PyArray_Type); 190 DECL_NPY_API(PyVoidArrType_Type); 191 DECL_NPY_API(PyArrayDescr_Type); 192 DECL_NPY_API(PyArray_DescrFromType); 193 DECL_NPY_API(PyArray_DescrFromScalar); 194 DECL_NPY_API(PyArray_FromAny); 195 DECL_NPY_API(PyArray_NewCopy); 196 DECL_NPY_API(PyArray_NewFromDescr); 197 
DECL_NPY_API(PyArray_DescrNewFromType); 198 DECL_NPY_API(PyArray_DescrConverter); 199 DECL_NPY_API(PyArray_EquivTypes); 200 DECL_NPY_API(PyArray_GetArrayParamsFromObject); 201 DECL_NPY_API(PyArray_Squeeze); 202 DECL_NPY_API(PyArray_SetBaseObject); 203#undef DECL_NPY_API 204 return api; 205 } 206}; 207 208inline PyArray_Proxy* array_proxy(void* ptr) { 209 return reinterpret_cast<PyArray_Proxy*>(ptr); 210} 211 212inline const PyArray_Proxy* array_proxy(const void* ptr) { 213 return reinterpret_cast<const PyArray_Proxy*>(ptr); 214} 215 216inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { 217 return reinterpret_cast<PyArrayDescr_Proxy*>(ptr); 218} 219 220inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { 221 return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr); 222} 223 224inline bool check_flags(const void* ptr, int flag) { 225 return (flag == (array_proxy(ptr)->flags & flag)); 226} 227 228template <typename T> struct is_std_array : std::false_type { }; 229template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { }; 230template <typename T> struct is_complex : std::false_type { }; 231template <typename T> struct is_complex<std::complex<T>> : std::true_type { }; 232 233template <typename T> using is_pod_struct = all_of< 234 std::is_pod<T>, // since we're accessing directly in memory we need a POD type 235 satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum> 236>; 237 238template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; } 239template <size_t Dim = 0, typename Strides, typename... Ix> 240size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) { 241 return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...); 242} 243 244/** Proxy class providing unsafe, unchecked const access to array data. 
This is constructed through 245 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims` 246 * will be -1 for dimensions determined at runtime. 247 */ 248template <typename T, ssize_t Dims> 249class unchecked_reference { 250protected: 251 static constexpr bool Dynamic = Dims < 0; 252 const unsigned char *data_; 253 // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to 254 // make large performance gains on big, nested loops, but requires compile-time dimensions 255 conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>> 256 shape_, strides_; 257 const size_t dims_; 258 259 friend class pybind11::array; 260 // Constructor for compile-time dimensions: 261 template <bool Dyn = Dynamic> 262 unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>) 263 : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} { 264 for (size_t i = 0; i < dims_; i++) { 265 shape_[i] = shape[i]; 266 strides_[i] = strides[i]; 267 } 268 } 269 // Constructor for runtime dimensions: 270 template <bool Dyn = Dynamic> 271 unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims) 272 : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} 273 274public: 275 /** Unchecked const reference access to data at the given indices. For a compile-time known 276 * number of dimensions, this requires the correct number of arguments; for run-time 277 * dimensionality, this is not checked (and so is up to the caller to use safely). 278 */ 279 template <typename... Ix> const T &operator()(Ix... 
index) const { 280 static_assert(sizeof...(Ix) == Dims || Dynamic, 281 "Invalid number of indices for unchecked array reference"); 282 return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...)); 283 } 284 /** Unchecked const reference access to data; this operator only participates if the reference 285 * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. 286 */ 287 template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>> 288 const T &operator[](size_t index) const { return operator()(index); } 289 290 /// Pointer access to the data at the given indices. 291 template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); } 292 293 /// Returns the item size, i.e. sizeof(T) 294 constexpr static size_t itemsize() { return sizeof(T); } 295 296 /// Returns the shape (i.e. size) of dimension `dim` 297 size_t shape(size_t dim) const { return shape_[dim]; } 298 299 /// Returns the number of dimensions of the array 300 size_t ndim() const { return dims_; } 301 302 /// Returns the total number of elements in the referenced array, i.e. the product of the shapes 303 template <bool Dyn = Dynamic> 304 enable_if_t<!Dyn, size_t> size() const { 305 return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>()); 306 } 307 template <bool Dyn = Dynamic> 308 enable_if_t<Dyn, size_t> size() const { 309 return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>()); 310 } 311 312 /// Returns the total number of bytes used by the referenced data. Note that the actual span in 313 /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). 
314 size_t nbytes() const { 315 return size() * itemsize(); 316 } 317}; 318 319template <typename T, ssize_t Dims> 320class unchecked_mutable_reference : public unchecked_reference<T, Dims> { 321 friend class pybind11::array; 322 using ConstBase = unchecked_reference<T, Dims>; 323 using ConstBase::ConstBase; 324 using ConstBase::Dynamic; 325public: 326 /// Mutable, unchecked access to data at the given indices. 327 template <typename... Ix> T& operator()(Ix... index) { 328 static_assert(sizeof...(Ix) == Dims || Dynamic, 329 "Invalid number of indices for unchecked array reference"); 330 return const_cast<T &>(ConstBase::operator()(index...)); 331 } 332 /** Mutable, unchecked access data at the given index; this operator only participates if the 333 * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is 334 * exactly equivalent to `obj(index)`. 335 */ 336 template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>> 337 T &operator[](size_t index) { return operator()(index); } 338 339 /// Mutable pointer access to the data at the given indices. 340 template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); } 341}; 342 343template <typename T, size_t Dim> 344struct type_caster<unchecked_reference<T, Dim>> { 345 static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); 346}; 347template <typename T, size_t Dim> 348struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {}; 349 350NAMESPACE_END(detail) 351 352class dtype : public object { 353public: 354 PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); 355 356 explicit dtype(const buffer_info &info) { 357 dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); 358 // If info.itemsize == 0, use the value calculated from the format string 359 m_ptr = descr.strip_padding(info.itemsize ? 
info.itemsize : descr.itemsize()).release().ptr(); 360 } 361 362 explicit dtype(const std::string &format) { 363 m_ptr = from_args(pybind11::str(format)).release().ptr(); 364 } 365 366 dtype(const char *format) : dtype(std::string(format)) { } 367 368 dtype(list names, list formats, list offsets, size_t itemsize) { 369 dict args; 370 args["names"] = names; 371 args["formats"] = formats; 372 args["offsets"] = offsets; 373 args["itemsize"] = pybind11::int_(itemsize); 374 m_ptr = from_args(args).release().ptr(); 375 } 376 377 /// This is essentially the same as calling numpy.dtype(args) in Python. 378 static dtype from_args(object args) { 379 PyObject *ptr = nullptr; 380 if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr) 381 throw error_already_set(); 382 return reinterpret_steal<dtype>(ptr); 383 } 384 385 /// Return dtype associated with a C++ type. 386 template <typename T> static dtype of() { 387 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype(); 388 } 389 390 /// Size of the data type in bytes. 391 size_t itemsize() const { 392 return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize; 393 } 394 395 /// Returns true for structured data types. 396 bool has_fields() const { 397 return detail::array_descriptor_proxy(m_ptr)->names != nullptr; 398 } 399 400 /// Single-character type code. 401 char kind() const { 402 return detail::array_descriptor_proxy(m_ptr)->kind; 403 } 404 405private: 406 static object _dtype_from_pep3118() { 407 static PyObject *obj = module::import("numpy.core._internal") 408 .attr("_dtype_from_pep3118").cast<object>().release().ptr(); 409 return reinterpret_borrow<object>(obj); 410 } 411 412 dtype strip_padding(size_t itemsize) { 413 // Recursively strip all void fields with empty names that are generated for 414 // padding fields (as of NumPy v1.11). 
415 if (!has_fields()) 416 return *this; 417 418 struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; 419 std::vector<field_descr> field_descriptors; 420 421 for (auto field : attr("fields").attr("items")()) { 422 auto spec = field.cast<tuple>(); 423 auto name = spec[0].cast<pybind11::str>(); 424 auto format = spec[1].cast<tuple>()[0].cast<dtype>(); 425 auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>(); 426 if (!len(name) && format.kind() == 'V') 427 continue; 428 field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); 429 } 430 431 std::sort(field_descriptors.begin(), field_descriptors.end(), 432 [](const field_descr& a, const field_descr& b) { 433 return a.offset.cast<int>() < b.offset.cast<int>(); 434 }); 435 436 list names, formats, offsets; 437 for (auto& descr : field_descriptors) { 438 names.append(descr.name); 439 formats.append(descr.format); 440 offsets.append(descr.offset); 441 } 442 return dtype(names, formats, offsets, itemsize); 443 } 444}; 445 446class array : public buffer { 447public: 448 PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) 449 450 enum { 451 c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_, 452 f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_, 453 forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ 454 }; 455 456 array() : array(0, static_cast<const double *>(nullptr)) {} 457 458 array(const pybind11::dtype &dt, const std::vector<size_t> &shape, 459 const std::vector<size_t> &strides, const void *ptr = nullptr, 460 handle base = handle()) { 461 auto& api = detail::npy_api::get(); 462 auto ndim = shape.size(); 463 if (shape.size() != strides.size()) 464 pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); 465 auto descr = dt; 466 467 int flags = 0; 468 if (base && ptr) { 469 if (isinstance<array>(base)) 470 /* Copy flags from base (except ownership bit) */ 471 flags = 
reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; 472 else 473 /* Writable by default, easy to downgrade later on if needed */ 474 flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; 475 } 476 477 auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_( 478 api.PyArray_Type_, descr.release().ptr(), (int) ndim, 479 reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())), 480 reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())), 481 const_cast<void *>(ptr), flags, nullptr)); 482 if (!tmp) 483 pybind11_fail("NumPy: unable to create array!"); 484 if (ptr) { 485 if (base) { 486 api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); 487 } else { 488 tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); 489 } 490 } 491 m_ptr = tmp.release().ptr(); 492 } 493 494 array(const pybind11::dtype &dt, const std::vector<size_t> &shape, 495 const void *ptr = nullptr, handle base = handle()) 496 : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } 497 498 array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, 499 handle base = handle()) 500 : array(dt, std::vector<size_t>{ count }, ptr, base) { } 501 502 template<typename T> array(const std::vector<size_t>& shape, 503 const std::vector<size_t>& strides, 504 const T* ptr, handle base = handle()) 505 : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { } 506 507 template <typename T> 508 array(const std::vector<size_t> &shape, const T *ptr, 509 handle base = handle()) 510 : array(shape, default_strides(shape, sizeof(T)), ptr, base) { } 511 512 template <typename T> 513 array(size_t count, const T *ptr, handle base = handle()) 514 : array(std::vector<size_t>{ count }, ptr, base) { } 515 516 explicit array(const buffer_info &info) 517 : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } 518 519 /// Array descriptor (dtype) 520 pybind11::dtype dtype() const { 
521 return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr); 522 } 523 524 /// Total number of elements 525 size_t size() const { 526 return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>()); 527 } 528 529 /// Byte size of a single element 530 size_t itemsize() const { 531 return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize; 532 } 533 534 /// Total number of bytes 535 size_t nbytes() const { 536 return size() * itemsize(); 537 } 538 539 /// Number of dimensions 540 size_t ndim() const { 541 return (size_t) detail::array_proxy(m_ptr)->nd; 542 } 543 544 /// Base object 545 object base() const { 546 return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base); 547 } 548 549 /// Dimensions of the array 550 const size_t* shape() const { 551 return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions); 552 } 553 554 /// Dimension along a given axis 555 size_t shape(size_t dim) const { 556 if (dim >= ndim()) 557 fail_dim_check(dim, "invalid axis"); 558 return shape()[dim]; 559 } 560 561 /// Strides of the array 562 const size_t* strides() const { 563 return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides); 564 } 565 566 /// Stride along a given axis 567 size_t strides(size_t dim) const { 568 if (dim >= ndim()) 569 fail_dim_check(dim, "invalid axis"); 570 return strides()[dim]; 571 } 572 573 /// Return the NumPy array flags 574 int flags() const { 575 return detail::array_proxy(m_ptr)->flags; 576 } 577 578 /// If set, the array is writeable (otherwise the buffer is read-only) 579 bool writeable() const { 580 return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); 581 } 582 583 /// If set, the array owns the data (will be freed when the array is deleted) 584 bool owndata() const { 585 return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); 586 } 587 588 /// Pointer to the contained data. 
If index is not provided, points to the 589 /// beginning of the buffer. May throw if the index would lead to out of bounds access. 590 template<typename... Ix> const void* data(Ix... index) const { 591 return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...)); 592 } 593 594 /// Mutable pointer to the contained data. If index is not provided, points to the 595 /// beginning of the buffer. May throw if the index would lead to out of bounds access. 596 /// May throw if the array is not writeable. 597 template<typename... Ix> void* mutable_data(Ix... index) { 598 check_writeable(); 599 return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...)); 600 } 601 602 /// Byte offset from beginning of the array to a given index (full or partial). 603 /// May throw if the index would lead to out of bounds access. 604 template<typename... Ix> size_t offset_at(Ix... index) const { 605 if (sizeof...(index) > ndim()) 606 fail_dim_check(sizeof...(index), "too many indices for an array"); 607 return byte_offset(size_t(index)...); 608 } 609 610 size_t offset_at() const { return 0; } 611 612 /// Item count from beginning of the array to a given index (full or partial). 613 /// May throw if the index would lead to out of bounds access. 614 template<typename... Ix> size_t index_at(Ix... index) const { 615 return offset_at(index...) / itemsize(); 616 } 617 618 /** Returns a proxy object that provides access to the array's data without bounds or 619 * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with 620 * care: the array must not be destroyed or reshaped for the duration of the returned object, 621 * and the caller must take care not to access invalid dimensions or dimension indices. 
622 */ 623 template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() { 624 if (Dims >= 0 && ndim() != (size_t) Dims) 625 throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + 626 "; expected " + std::to_string(Dims)); 627 return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim()); 628 } 629 630 /** Returns a proxy object that provides const access to the array's data without bounds or 631 * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the 632 * underlying array have the `writable` flag. Use with care: the array must not be destroyed or 633 * reshaped for the duration of the returned object, and the caller must take care not to access 634 * invalid dimensions or dimension indices. 635 */ 636 template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const { 637 if (Dims >= 0 && ndim() != (size_t) Dims) 638 throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + 639 "; expected " + std::to_string(Dims)); 640 return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim()); 641 } 642 643 /// Return a new view with all of the dimensions of length 1 removed 644 array squeeze() { 645 auto& api = detail::npy_api::get(); 646 return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr)); 647 } 648 649 /// Ensure that the argument is a NumPy array 650 /// In case of an error, nullptr is returned and the Python error is cleared. 
651 static array ensure(handle h, int ExtraFlags = 0) { 652 auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags)); 653 if (!result) 654 PyErr_Clear(); 655 return result; 656 } 657 658protected: 659 template<typename, typename> friend struct detail::npy_format_descriptor; 660 661 void fail_dim_check(size_t dim, const std::string& msg) const { 662 throw index_error(msg + ": " + std::to_string(dim) + 663 " (ndim = " + std::to_string(ndim()) + ")"); 664 } 665 666 template<typename... Ix> size_t byte_offset(Ix... index) const { 667 check_dimensions(index...); 668 return detail::byte_offset_unsafe(strides(), size_t(index)...); 669 } 670 671 void check_writeable() const { 672 if (!writeable()) 673 throw std::domain_error("array is not writeable"); 674 } 675 676 static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) { 677 auto ndim = shape.size(); 678 std::vector<size_t> strides(ndim); 679 if (ndim) { 680 std::fill(strides.begin(), strides.end(), itemsize); 681 for (size_t i = 0; i < ndim - 1; i++) 682 for (size_t j = 0; j < ndim - 1 - i; j++) 683 strides[j] *= shape[ndim - 1 - i]; 684 } 685 return strides; 686 } 687 688 template<typename... Ix> void check_dimensions(Ix... index) const { 689 check_dimensions_impl(size_t(0), shape(), size_t(index)...); 690 } 691 692 void check_dimensions_impl(size_t, const size_t*) const { } 693 694 template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... 
index) const { 695 if (i >= *shape) { 696 throw index_error(std::string("index ") + std::to_string(i) + 697 " is out of bounds for axis " + std::to_string(axis) + 698 " with size " + std::to_string(*shape)); 699 } 700 check_dimensions_impl(axis + 1, shape + 1, index...); 701 } 702 703 /// Create array from any object -- always returns a new reference 704 static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { 705 if (ptr == nullptr) 706 return nullptr; 707 return detail::npy_api::get().PyArray_FromAny_( 708 ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); 709 } 710}; 711 712template <typename T, int ExtraFlags = array::forcecast> class array_t : public array { 713public: 714 using value_type = T; 715 716 array_t() : array(0, static_cast<const T *>(nullptr)) {} 717 array_t(handle h, borrowed_t) : array(h, borrowed) { } 718 array_t(handle h, stolen_t) : array(h, stolen) { } 719 720 PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead") 721 array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) { 722 if (!m_ptr) PyErr_Clear(); 723 if (!is_borrowed) Py_XDECREF(h.ptr()); 724 } 725 726 array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) { 727 if (!m_ptr) throw error_already_set(); 728 } 729 730 explicit array_t(const buffer_info& info) : array(info) { } 731 732 array_t(const std::vector<size_t> &shape, 733 const std::vector<size_t> &strides, const T *ptr = nullptr, 734 handle base = handle()) 735 : array(shape, strides, ptr, base) { } 736 737 explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr, 738 handle base = handle()) 739 : array(shape, ptr, base) { } 740 741 explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) 742 : array(count, ptr, base) { } 743 744 constexpr size_t itemsize() const { 745 return sizeof(T); 746 } 747 748 template<typename... Ix> size_t index_at(Ix... index) const { 749 return offset_at(index...) 
/ itemsize(); 750 } 751 752 template<typename... Ix> const T* data(Ix... index) const { 753 return static_cast<const T*>(array::data(index...)); 754 } 755 756 template<typename... Ix> T* mutable_data(Ix... index) { 757 return static_cast<T*>(array::mutable_data(index...)); 758 } 759 760 // Reference to element at a given index 761 template<typename... Ix> const T& at(Ix... index) const { 762 if (sizeof...(index) != ndim()) 763 fail_dim_check(sizeof...(index), "index dimension mismatch"); 764 return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize()); 765 } 766 767 // Mutable reference to element at a given index 768 template<typename... Ix> T& mutable_at(Ix... index) { 769 if (sizeof...(index) != ndim()) 770 fail_dim_check(sizeof...(index), "index dimension mismatch"); 771 return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); 772 } 773 774 /** Returns a proxy object that provides access to the array's data without bounds or 775 * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with 776 * care: the array must not be destroyed or reshaped for the duration of the returned object, 777 * and the caller must take care not to access invalid dimensions or dimension indices. 778 */ 779 template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() { 780 return array::mutable_unchecked<T, Dims>(); 781 } 782 783 /** Returns a proxy object that provides const access to the array's data without bounds or 784 * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying 785 * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped 786 * for the duration of the returned object, and the caller must take care not to access invalid 787 * dimensions or dimension indices. 
788 */ 789 template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const { 790 return array::unchecked<T, Dims>(); 791 } 792 793 /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert 794 /// it). In case of an error, nullptr is returned and the Python error is cleared. 795 static array_t ensure(handle h) { 796 auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr())); 797 if (!result) 798 PyErr_Clear(); 799 return result; 800 } 801 802 static bool check_(handle h) { 803 const auto &api = detail::npy_api::get(); 804 return api.PyArray_Check_(h.ptr()) 805 && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr()); 806 } 807 808protected: 809 /// Create array from any object -- always returns a new reference 810 static PyObject *raw_array_t(PyObject *ptr) { 811 if (ptr == nullptr) 812 return nullptr; 813 return detail::npy_api::get().PyArray_FromAny_( 814 ptr, dtype::of<T>().release().ptr(), 0, 0, 815 detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); 816 } 817}; 818 819template <typename T> 820struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> { 821 static std::string format() { 822 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format(); 823 } 824}; 825 826template <size_t N> struct format_descriptor<char[N]> { 827 static std::string format() { return std::to_string(N) + "s"; } 828}; 829template <size_t N> struct format_descriptor<std::array<char, N>> { 830 static std::string format() { return std::to_string(N) + "s"; } 831}; 832 833template <typename T> 834struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> { 835 static std::string format() { 836 return format_descriptor< 837 typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format(); 838 } 839}; 840 841NAMESPACE_BEGIN(detail) 842template <typename T, int ExtraFlags> 843struct pyobject_caster<array_t<T, 
ExtraFlags>> { 844 using type = array_t<T, ExtraFlags>; 845 846 bool load(handle src, bool convert) { 847 if (!convert && !type::check_(src)) 848 return false; 849 value = type::ensure(src); 850 return static_cast<bool>(value); 851 } 852 853 static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { 854 return src.inc_ref(); 855 } 856 PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name()); 857}; 858 859template <typename T> 860struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> { 861 static bool compare(const buffer_info& b) { 862 return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr()); 863 } 864}; 865 866template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> { 867private: 868 // NB: the order here must match the one in common.h 869 constexpr static const int values[15] = { 870 npy_api::NPY_BOOL_, 871 npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, 872 npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_, 873 npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, 874 npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ 875 }; 876 877public: 878 static constexpr int value = values[detail::is_fmt_numeric<T>::index]; 879 880 static pybind11::dtype dtype() { 881 if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) 882 return reinterpret_borrow<pybind11::dtype>(ptr); 883 pybind11_fail("Unsupported buffer format!"); 884 } 885 template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0> 886 static PYBIND11_DESCR name() { 887 return _<std::is_same<T, bool>::value>(_("bool"), 888 _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()); 889 } 890 template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0> 891 static PYBIND11_DESCR name() { 892 return 
_<std::is_same<T, float>::value || std::is_same<T, double>::value>( 893 _("float") + _<sizeof(T)*8>(), _("longdouble")); 894 } 895 template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0> 896 static PYBIND11_DESCR name() { 897 return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>( 898 _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex")); 899 } 900}; 901 902#define PYBIND11_DECL_CHAR_FMT \ 903 static PYBIND11_DESCR name() { return _("S") + _<N>(); } \ 904 static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } 905template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT }; 906template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT }; 907#undef PYBIND11_DECL_CHAR_FMT 908 909template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> { 910private: 911 using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>; 912public: 913 static PYBIND11_DESCR name() { return base_descr::name(); } 914 static pybind11::dtype dtype() { return base_descr::dtype(); } 915}; 916 917struct field_descriptor { 918 const char *name; 919 size_t offset; 920 size_t size; 921 size_t alignment; 922 std::string format; 923 dtype descr; 924}; 925 926inline PYBIND11_NOINLINE void register_structured_dtype( 927 const std::initializer_list<field_descriptor>& fields, 928 const std::type_info& tinfo, size_t itemsize, 929 bool (*direct_converter)(PyObject *, void *&)) { 930 931 auto& numpy_internals = get_numpy_internals(); 932 if (numpy_internals.get_type_info(tinfo, false)) 933 pybind11_fail("NumPy: dtype is already registered"); 934 935 list names, formats, offsets; 936 for (auto field : fields) { 937 if (!field.descr) 938 pybind11_fail(std::string("NumPy: unsupported field dtype: `") + 939 field.name + "` @ " + tinfo.name()); 940 
names.append(PYBIND11_STR_TYPE(field.name)); 941 formats.append(field.descr); 942 offsets.append(pybind11::int_(field.offset)); 943 } 944 auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); 945 946 // There is an existing bug in NumPy (as of v1.11): trailing bytes are 947 // not encoded explicitly into the format string. This will supposedly 948 // get fixed in v1.12; for further details, see these: 949 // - https://github.com/numpy/numpy/issues/7797 950 // - https://github.com/numpy/numpy/pull/7798 951 // Because of this, we won't use numpy's logic to generate buffer format 952 // strings and will just do it ourselves. 953 std::vector<field_descriptor> ordered_fields(fields); 954 std::sort(ordered_fields.begin(), ordered_fields.end(), 955 [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); 956 size_t offset = 0; 957 std::ostringstream oss; 958 oss << "T{"; 959 for (auto& field : ordered_fields) { 960 if (field.offset > offset) 961 oss << (field.offset - offset) << 'x'; 962 // mark unaligned fields with '^' (unaligned native type) 963 if (field.offset % field.alignment) 964 oss << '^'; 965 oss << field.format << ':' << field.name << ':'; 966 offset = field.offset + field.size; 967 } 968 if (itemsize > offset) 969 oss << (itemsize - offset) << 'x'; 970 oss << '}'; 971 auto format_str = oss.str(); 972 973 // Sanity check: verify that NumPy properly parses our buffer format string 974 auto& api = npy_api::get(); 975 auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); 976 if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) 977 pybind11_fail("NumPy: invalid buffer descriptor!"); 978 979 auto tindex = std::type_index(tinfo); 980 numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; 981 get_internals().direct_conversions[tindex].push_back(direct_converter); 982} 983 984template <typename T, typename SFINAE> struct npy_format_descriptor { 985 
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype"); 986 987 static PYBIND11_DESCR name() { return make_caster<T>::name(); } 988 989 static pybind11::dtype dtype() { 990 return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); 991 } 992 993 static std::string format() { 994 static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str; 995 return format_str; 996 } 997 998 static void register_dtype(const std::initializer_list<field_descriptor>& fields) { 999 register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type), 1000 sizeof(T), &direct_converter); 1001 } 1002 1003private: 1004 static PyObject* dtype_ptr() { 1005 static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr; 1006 return ptr; 1007 } 1008 1009 static bool direct_converter(PyObject *obj, void*& value) { 1010 auto& api = npy_api::get(); 1011 if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) 1012 return false; 1013 if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) { 1014 if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { 1015 value = ((PyVoidScalarObject_Proxy *) obj)->obval; 1016 return true; 1017 } 1018 } 1019 return false; 1020 } 1021}; 1022 1023#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ 1024 ::pybind11::detail::field_descriptor { \ 1025 Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \ 1026 alignof(decltype(std::declval<T>().Field)), \ 1027 ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \ 1028 ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \ 1029 } 1030 1031// Extract name, offset and format descriptor for a struct field 1032#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field) 1033 1034// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro 1035// (C) William Swanson, Paul Fultz 
1036#define PYBIND11_EVAL0(...) __VA_ARGS__ 1037#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__))) 1038#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__))) 1039#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__))) 1040#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__))) 1041#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__))) 1042#define PYBIND11_MAP_END(...) 1043#define PYBIND11_MAP_OUT 1044#define PYBIND11_MAP_COMMA , 1045#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END 1046#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT 1047#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0) 1048#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next) 1049#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround 1050#define PYBIND11_MAP_LIST_NEXT1(test, next) \ 1051 PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) 1052#else 1053#define PYBIND11_MAP_LIST_NEXT1(test, next) \ 1054 PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) 1055#endif 1056#define PYBIND11_MAP_LIST_NEXT(test, next) \ 1057 PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) 1058#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \ 1059 f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__) 1060#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \ 1061 f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__) 1062// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ... 1063#define PYBIND11_MAP_LIST(f, t, ...) \ 1064 PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0)) 1065 1066#define PYBIND11_NUMPY_DTYPE(Type, ...) 
\ 1067 ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \ 1068 ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)}) 1069 1070#ifdef _MSC_VER 1071#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ 1072 PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) 1073#else 1074#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ 1075 PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) 1076#endif 1077#define PYBIND11_MAP2_LIST_NEXT(test, next) \ 1078 PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) 1079#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \ 1080 f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__) 1081#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \ 1082 f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__) 1083// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ... 1084#define PYBIND11_MAP2_LIST(f, t, ...) \ 1085 PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0)) 1086 1087#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) 
\ 1088 ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \ 1089 ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)}) 1090 1091template <class T> 1092using array_iterator = typename std::add_pointer<T>::type; 1093 1094template <class T> 1095array_iterator<T> array_begin(const buffer_info& buffer) { 1096 return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr)); 1097} 1098 1099template <class T> 1100array_iterator<T> array_end(const buffer_info& buffer) { 1101 return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size); 1102} 1103 1104class common_iterator { 1105public: 1106 using container_type = std::vector<size_t>; 1107 using value_type = container_type::value_type; 1108 using size_type = container_type::size_type; 1109 1110 common_iterator() : p_ptr(0), m_strides() {} 1111 1112 common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape) 1113 : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) { 1114 m_strides.back() = static_cast<value_type>(strides.back()); 1115 for (size_type i = m_strides.size() - 1; i != 0; --i) { 1116 size_type j = i - 1; 1117 value_type s = static_cast<value_type>(shape[i]); 1118 m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; 1119 } 1120 } 1121 1122 void increment(size_type dim) { 1123 p_ptr += m_strides[dim]; 1124 } 1125 1126 void* data() const { 1127 return p_ptr; 1128 } 1129 1130private: 1131 char* p_ptr; 1132 container_type m_strides; 1133}; 1134 1135template <size_t N> class multi_array_iterator { 1136public: 1137 using container_type = std::vector<size_t>; 1138 1139 multi_array_iterator(const std::array<buffer_info, N> &buffers, 1140 const std::vector<size_t> &shape) 1141 : m_shape(shape.size()), m_index(shape.size(), 0), 1142 m_common_iterator() { 1143 1144 // Manual copy to avoid conversion warning if using std::copy 1145 for (size_t i = 0; i < shape.size(); ++i) 1146 m_shape[i] = 
static_cast<container_type::value_type>(shape[i]); 1147 1148 container_type strides(shape.size()); 1149 for (size_t i = 0; i < N; ++i) 1150 init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); 1151 } 1152 1153 multi_array_iterator& operator++() { 1154 for (size_t j = m_index.size(); j != 0; --j) { 1155 size_t i = j - 1; 1156 if (++m_index[i] != m_shape[i]) { 1157 increment_common_iterator(i); 1158 break; 1159 } else { 1160 m_index[i] = 0; 1161 } 1162 } 1163 return *this; 1164 } 1165 1166 template <size_t K, class T> const T& data() const { 1167 return *reinterpret_cast<T*>(m_common_iterator[K].data()); 1168 } 1169 1170private: 1171 1172 using common_iter = common_iterator; 1173 1174 void init_common_iterator(const buffer_info &buffer, 1175 const std::vector<size_t> &shape, 1176 common_iter &iterator, container_type &strides) { 1177 auto buffer_shape_iter = buffer.shape.rbegin(); 1178 auto buffer_strides_iter = buffer.strides.rbegin(); 1179 auto shape_iter = shape.rbegin(); 1180 auto strides_iter = strides.rbegin(); 1181 1182 while (buffer_shape_iter != buffer.shape.rend()) { 1183 if (*shape_iter == *buffer_shape_iter) 1184 *strides_iter = static_cast<size_t>(*buffer_strides_iter); 1185 else 1186 *strides_iter = 0; 1187 1188 ++buffer_shape_iter; 1189 ++buffer_strides_iter; 1190 ++shape_iter; 1191 ++strides_iter; 1192 } 1193 1194 std::fill(strides_iter, strides.rend(), 0); 1195 iterator = common_iter(buffer.ptr, strides, shape); 1196 } 1197 1198 void increment_common_iterator(size_t dim) { 1199 for (auto &iter : m_common_iterator) 1200 iter.increment(dim); 1201 } 1202 1203 container_type m_shape; 1204 container_type m_index; 1205 std::array<common_iter, N> m_common_iterator; 1206}; 1207 1208enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; 1209 1210// Populates the shape and number of dimensions for the set of buffers. 
Returns a broadcast_trivial 1211// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a 1212// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage 1213// buffer; returns `non_trivial` otherwise. 1214template <size_t N> 1215broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, size_t &ndim, std::vector<size_t> &shape) { 1216 ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) { 1217 return std::max(res, buf.ndim); 1218 }); 1219 1220 shape.clear(); 1221 shape.resize(ndim, 1); 1222 1223 // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or 1224 // the full size). 1225 for (size_t i = 0; i < N; ++i) { 1226 auto res_iter = shape.rbegin(); 1227 auto end = buffers[i].shape.rend(); 1228 for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { 1229 const auto &dim_size_in = *shape_iter; 1230 auto &dim_size_out = *res_iter; 1231 1232 // Each input dimension can either be 1 or `n`, but `n` values must match across buffers 1233 if (dim_size_out == 1) 1234 dim_size_out = dim_size_in; 1235 else if (dim_size_in != 1 && dim_size_in != dim_size_out) 1236 pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); 1237 } 1238 } 1239 1240 bool trivial_broadcast_c = true; 1241 bool trivial_broadcast_f = true; 1242 for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { 1243 if (buffers[i].size == 1) 1244 continue; 1245 1246 // Require the same number of dimensions: 1247 if (buffers[i].ndim != ndim) 1248 return broadcast_trivial::non_trivial; 1249 1250 // Require all dimensions be full-size: 1251 if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) 1252 return broadcast_trivial::non_trivial; 1253 1254 // Check for C contiguity (but only if previous inputs were also C 
contiguous) 1255 if (trivial_broadcast_c) { 1256 size_t expect_stride = buffers[i].itemsize; 1257 auto end = buffers[i].shape.crend(); 1258 for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); 1259 trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { 1260 if (expect_stride == *stride_iter) 1261 expect_stride *= *shape_iter; 1262 else 1263 trivial_broadcast_c = false; 1264 } 1265 } 1266 1267 // Check for Fortran contiguity (if previous inputs were also F contiguous) 1268 if (trivial_broadcast_f) { 1269 size_t expect_stride = buffers[i].itemsize; 1270 auto end = buffers[i].shape.cend(); 1271 for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); 1272 trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { 1273 if (expect_stride == *stride_iter) 1274 expect_stride *= *shape_iter; 1275 else 1276 trivial_broadcast_f = false; 1277 } 1278 } 1279 } 1280 1281 return 1282 trivial_broadcast_c ? broadcast_trivial::c_trivial : 1283 trivial_broadcast_f ? broadcast_trivial::f_trivial : 1284 broadcast_trivial::non_trivial; 1285} 1286 1287template <typename Func, typename Return, typename... Args> 1288struct vectorize_helper { 1289 typename std::remove_reference<Func>::type f; 1290 static constexpr size_t N = sizeof...(Args); 1291 1292 template <typename T> 1293 explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { } 1294 1295 object operator()(array_t<Args, array::forcecast>... args) { 1296 return run(args..., make_index_sequence<N>()); 1297 } 1298 1299 template <size_t ... Index> object run(array_t<Args, array::forcecast>&... args, index_sequence<Index...> index) { 1300 /* Request buffers from all parameters */ 1301 std::array<buffer_info, N> buffers {{ args.request()... 
}}; 1302 1303 /* Determine dimensions parameters of output array */ 1304 size_t ndim = 0; 1305 std::vector<size_t> shape(0); 1306 auto trivial = broadcast(buffers, ndim, shape); 1307 1308 size_t size = 1; 1309 std::vector<size_t> strides(ndim); 1310 if (ndim > 0) { 1311 if (trivial == broadcast_trivial::f_trivial) { 1312 strides[0] = sizeof(Return); 1313 for (size_t i = 1; i < ndim; ++i) { 1314 strides[i] = strides[i - 1] * shape[i - 1]; 1315 size *= shape[i - 1]; 1316 } 1317 size *= shape[ndim - 1]; 1318 } 1319 else { 1320 strides[ndim-1] = sizeof(Return); 1321 for (size_t i = ndim - 1; i > 0; --i) { 1322 strides[i - 1] = strides[i] * shape[i]; 1323 size *= shape[i]; 1324 } 1325 size *= shape[0]; 1326 } 1327 } 1328 1329 if (size == 1) 1330 return cast(f(*reinterpret_cast<Args *>(buffers[Index].ptr)...)); 1331 1332 array_t<Return> result(shape, strides); 1333 auto buf = result.request(); 1334 auto output = (Return *) buf.ptr; 1335 1336 /* Call the function */ 1337 if (trivial == broadcast_trivial::non_trivial) { 1338 apply_broadcast<Index...>(buffers, buf, index); 1339 } else { 1340 for (size_t i = 0; i < size; ++i) 1341 output[i] = f((reinterpret_cast<Args *>(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...); 1342 } 1343 1344 return result; 1345 } 1346 1347 template <size_t... 
Index> 1348 void apply_broadcast(const std::array<buffer_info, N> &buffers, 1349 buffer_info &output, index_sequence<Index...>) { 1350 using input_iterator = multi_array_iterator<N>; 1351 using output_iterator = array_iterator<Return>; 1352 1353 input_iterator input_iter(buffers, output.shape); 1354 output_iterator output_end = array_end<Return>(output); 1355 1356 for (output_iterator iter = array_begin<Return>(output); 1357 iter != output_end; ++iter, ++input_iter) { 1358 *iter = f((input_iter.template data<Index, Args>())...); 1359 } 1360 } 1361}; 1362 1363template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> { 1364 static PYBIND11_DESCR name() { 1365 return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]"); 1366 } 1367}; 1368 1369NAMESPACE_END(detail) 1370 1371template <typename Func, typename Return, typename... Args> 1372detail::vectorize_helper<Func, Return, Args...> 1373vectorize(const Func &f, Return (*) (Args ...)) { 1374 return detail::vectorize_helper<Func, Return, Args...>(f); 1375} 1376 1377template <typename Return, typename... Args> 1378detail::vectorize_helper<Return (*) (Args ...), Return, Args...> 1379vectorize(Return (*f) (Args ...)) { 1380 return vectorize<Return (*) (Args ...), Return, Args...>(f, f); 1381} 1382 1383template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type> 1384auto vectorize(Func &&f) -> decltype( 1385 vectorize(std::forward<Func>(f), (FuncType *) nullptr)) { 1386 return vectorize(std::forward<Func>(f), (FuncType *) nullptr); 1387} 1388 1389NAMESPACE_END(pybind11) 1390 1391#if defined(_MSC_VER) 1392#pragma warning(pop) 1393#endif 1394