// test_numpy_array.cpp revision 12037
1/* 2 tests/test_numpy_array.cpp -- test core array functionality 3 4 Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com> 5 6 All rights reserved. Use of this source code is governed by a 7 BSD-style license that can be found in the LICENSE file. 8*/ 9 10#include "pybind11_tests.h" 11 12#include <pybind11/numpy.h> 13#include <pybind11/stl.h> 14 15#include <cstdint> 16#include <vector> 17 18using arr = py::array; 19using arr_t = py::array_t<uint16_t, 0>; 20static_assert(std::is_same<arr_t::value_type, uint16_t>::value, ""); 21 22template<typename... Ix> arr data(const arr& a, Ix... index) { 23 return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...)); 24} 25 26template<typename... Ix> arr data_t(const arr_t& a, Ix... index) { 27 return arr(a.size() - a.index_at(index...), a.data(index...)); 28} 29 30arr& mutate_data(arr& a) { 31 auto ptr = (uint8_t *) a.mutable_data(); 32 for (size_t i = 0; i < a.nbytes(); i++) 33 ptr[i] = (uint8_t) (ptr[i] * 2); 34 return a; 35} 36 37arr_t& mutate_data_t(arr_t& a) { 38 auto ptr = a.mutable_data(); 39 for (size_t i = 0; i < a.size(); i++) 40 ptr[i]++; 41 return a; 42} 43 44template<typename... Ix> arr& mutate_data(arr& a, Ix... index) { 45 auto ptr = (uint8_t *) a.mutable_data(index...); 46 for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) 47 ptr[i] = (uint8_t) (ptr[i] * 2); 48 return a; 49} 50 51template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix... index) { 52 auto ptr = a.mutable_data(index...); 53 for (size_t i = 0; i < a.size() - a.index_at(index...); i++) 54 ptr[i]++; 55 return a; 56} 57 58template<typename... Ix> size_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); } 59template<typename... Ix> size_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); } 60template<typename... Ix> size_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); } 61template<typename... Ix> size_t offset_at_t(const arr_t& a, Ix... 
idx) { return a.offset_at(idx...); } 62template<typename... Ix> size_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); } 63template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; } 64 65#define def_index_fn(name, type) \ 66 sm.def(#name, [](type a) { return name(a); }); \ 67 sm.def(#name, [](type a, int i) { return name(a, i); }); \ 68 sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \ 69 sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); }); 70 71template <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) { 72 if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); 73 py::list l; 74 l.append(*r.data(0, 0)); 75 l.append(*r2.mutable_data(0, 0)); 76 l.append(r.data(0, 1) == r2.mutable_data(0, 1)); 77 l.append(r.ndim()); 78 l.append(r.itemsize()); 79 l.append(r.shape(0)); 80 l.append(r.shape(1)); 81 l.append(r.size()); 82 l.append(r.nbytes()); 83 return l.release(); 84} 85 86test_initializer numpy_array([](py::module &m) { 87 auto sm = m.def_submodule("array"); 88 89 sm.def("ndim", [](const arr& a) { return a.ndim(); }); 90 sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); }); 91 sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); }); 92 sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); }); 93 sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); }); 94 sm.def("writeable", [](const arr& a) { return a.writeable(); }); 95 sm.def("size", [](const arr& a) { return a.size(); }); 96 sm.def("itemsize", [](const arr& a) { return a.itemsize(); }); 97 sm.def("nbytes", [](const arr& a) { return a.nbytes(); }); 98 sm.def("owndata", [](const arr& a) { return a.owndata(); }); 99 100 def_index_fn(data, const arr&); 101 def_index_fn(data_t, const arr_t&); 102 def_index_fn(index_at, const arr&); 103 def_index_fn(index_at_t, const arr_t&); 104 def_index_fn(offset_at, const arr&); 105 
def_index_fn(offset_at_t, const arr_t&); 106 def_index_fn(mutate_data, arr&); 107 def_index_fn(mutate_data_t, arr_t&); 108 def_index_fn(at_t, const arr_t&); 109 def_index_fn(mutate_at_t, arr_t&); 110 111 sm.def("make_f_array", [] { 112 return py::array_t<float>({ 2, 2 }, { 4, 8 }); 113 }); 114 115 sm.def("make_c_array", [] { 116 return py::array_t<float>({ 2, 2 }, { 8, 4 }); 117 }); 118 119 sm.def("wrap", [](py::array a) { 120 return py::array( 121 a.dtype(), 122 std::vector<size_t>(a.shape(), a.shape() + a.ndim()), 123 std::vector<size_t>(a.strides(), a.strides() + a.ndim()), 124 a.data(), 125 a 126 ); 127 }); 128 129 struct ArrayClass { 130 int data[2] = { 1, 2 }; 131 ArrayClass() { py::print("ArrayClass()"); } 132 ~ArrayClass() { py::print("~ArrayClass()"); } 133 }; 134 135 py::class_<ArrayClass>(sm, "ArrayClass") 136 .def(py::init<>()) 137 .def("numpy_view", [](py::object &obj) { 138 py::print("ArrayClass::numpy_view()"); 139 ArrayClass &a = obj.cast<ArrayClass&>(); 140 return py::array_t<int>({2}, {4}, a.data, obj); 141 } 142 ); 143 144 sm.def("function_taking_uint64", [](uint64_t) { }); 145 146 sm.def("isinstance_untyped", [](py::object yes, py::object no) { 147 return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no); 148 }); 149 150 sm.def("isinstance_typed", [](py::object o) { 151 return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o); 152 }); 153 154 sm.def("default_constructors", []() { 155 return py::dict( 156 "array"_a=py::array(), 157 "array_t<int32>"_a=py::array_t<std::int32_t>(), 158 "array_t<double>"_a=py::array_t<double>() 159 ); 160 }); 161 162 sm.def("converting_constructors", [](py::object o) { 163 return py::dict( 164 "array"_a=py::array(o), 165 "array_t<int32>"_a=py::array_t<std::int32_t>(o), 166 "array_t<double>"_a=py::array_t<double>(o) 167 ); 168 }); 169 170 // Overload resolution tests: 171 sm.def("overloaded", [](py::array_t<double>) { return "double"; }); 172 sm.def("overloaded", 
[](py::array_t<float>) { return "float"; }); 173 sm.def("overloaded", [](py::array_t<int>) { return "int"; }); 174 sm.def("overloaded", [](py::array_t<unsigned short>) { return "unsigned short"; }); 175 sm.def("overloaded", [](py::array_t<long long>) { return "long long"; }); 176 sm.def("overloaded", [](py::array_t<std::complex<double>>) { return "double complex"; }); 177 sm.def("overloaded", [](py::array_t<std::complex<float>>) { return "float complex"; }); 178 179 sm.def("overloaded2", [](py::array_t<std::complex<double>>) { return "double complex"; }); 180 sm.def("overloaded2", [](py::array_t<double>) { return "double"; }); 181 sm.def("overloaded2", [](py::array_t<std::complex<float>>) { return "float complex"; }); 182 sm.def("overloaded2", [](py::array_t<float>) { return "float"; }); 183 184 // Only accept the exact types: 185 sm.def("overloaded3", [](py::array_t<int>) { return "int"; }, py::arg().noconvert()); 186 sm.def("overloaded3", [](py::array_t<double>) { return "double"; }, py::arg().noconvert()); 187 188 // Make sure we don't do unsafe coercion (e.g. 
float to int) when not using forcecast, but 189 // rather that float gets converted via the safe (conversion to double) overload: 190 sm.def("overloaded4", [](py::array_t<long long, 0>) { return "long long"; }); 191 sm.def("overloaded4", [](py::array_t<double, 0>) { return "double"; }); 192 193 // But we do allow conversion to int if forcecast is enabled (but only if no overload matches 194 // without conversion) 195 sm.def("overloaded5", [](py::array_t<unsigned int>) { return "unsigned int"; }); 196 sm.def("overloaded5", [](py::array_t<double>) { return "double"; }); 197 198 // Issue 685: ndarray shouldn't go to std::string overload 199 sm.def("issue685", [](std::string) { return "string"; }); 200 sm.def("issue685", [](py::array) { return "array"; }); 201 sm.def("issue685", [](py::object) { return "other"; }); 202 203 sm.def("proxy_add2", [](py::array_t<double> a, double v) { 204 auto r = a.mutable_unchecked<2>(); 205 for (size_t i = 0; i < r.shape(0); i++) 206 for (size_t j = 0; j < r.shape(1); j++) 207 r(i, j) += v; 208 }, py::arg().noconvert(), py::arg()); 209 210 sm.def("proxy_init3", [](double start) { 211 py::array_t<double, py::array::c_style> a({ 3, 3, 3 }); 212 auto r = a.mutable_unchecked<3>(); 213 for (size_t i = 0; i < r.shape(0); i++) 214 for (size_t j = 0; j < r.shape(1); j++) 215 for (size_t k = 0; k < r.shape(2); k++) 216 r(i, j, k) = start++; 217 return a; 218 }); 219 sm.def("proxy_init3F", [](double start) { 220 py::array_t<double, py::array::f_style> a({ 3, 3, 3 }); 221 auto r = a.mutable_unchecked<3>(); 222 for (size_t k = 0; k < r.shape(2); k++) 223 for (size_t j = 0; j < r.shape(1); j++) 224 for (size_t i = 0; i < r.shape(0); i++) 225 r(i, j, k) = start++; 226 return a; 227 }); 228 sm.def("proxy_squared_L2_norm", [](py::array_t<double> a) { 229 auto r = a.unchecked<1>(); 230 double sumsq = 0; 231 for (size_t i = 0; i < r.shape(0); i++) 232 sumsq += r[i] * r(i); // Either notation works for a 1D array 233 return sumsq; 234 }); 235 236 
sm.def("proxy_auxiliaries2", [](py::array_t<double> a) { 237 auto r = a.unchecked<2>(); 238 auto r2 = a.mutable_unchecked<2>(); 239 return auxiliaries(r, r2); 240 }); 241 242 // Same as the above, but without a compile-time dimensions specification: 243 sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) { 244 auto r = a.mutable_unchecked(); 245 if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); 246 for (size_t i = 0; i < r.shape(0); i++) 247 for (size_t j = 0; j < r.shape(1); j++) 248 r(i, j) += v; 249 }, py::arg().noconvert(), py::arg()); 250 sm.def("proxy_init3_dyn", [](double start) { 251 py::array_t<double, py::array::c_style> a({ 3, 3, 3 }); 252 auto r = a.mutable_unchecked(); 253 if (r.ndim() != 3) throw std::domain_error("error: ndim != 3"); 254 for (size_t i = 0; i < r.shape(0); i++) 255 for (size_t j = 0; j < r.shape(1); j++) 256 for (size_t k = 0; k < r.shape(2); k++) 257 r(i, j, k) = start++; 258 return a; 259 }); 260 sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) { 261 return auxiliaries(a.unchecked(), a.mutable_unchecked()); 262 }); 263 264 sm.def("array_auxiliaries2", [](py::array_t<double> a) { 265 return auxiliaries(a, a); 266 }); 267}); 268