test_numpy_array.cpp revision 14299
111986Sandreas.sandberg@arm.com/*
211986Sandreas.sandberg@arm.com    tests/test_numpy_array.cpp -- test core array functionality
311986Sandreas.sandberg@arm.com
411986Sandreas.sandberg@arm.com    Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
511986Sandreas.sandberg@arm.com
611986Sandreas.sandberg@arm.com    All rights reserved. Use of this source code is governed by a
711986Sandreas.sandberg@arm.com    BSD-style license that can be found in the LICENSE file.
811986Sandreas.sandberg@arm.com*/
911986Sandreas.sandberg@arm.com
#include "pybind11_tests.h"

#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
1611986Sandreas.sandberg@arm.com
1714299Sbbruce@ucdavis.edu// Size / dtype checks.
1814299Sbbruce@ucdavis.edustruct DtypeCheck {
1914299Sbbruce@ucdavis.edu    py::dtype numpy{};
2014299Sbbruce@ucdavis.edu    py::dtype pybind11{};
2114299Sbbruce@ucdavis.edu};
2214299Sbbruce@ucdavis.edu
2314299Sbbruce@ucdavis.edutemplate <typename T>
2414299Sbbruce@ucdavis.eduDtypeCheck get_dtype_check(const char* name) {
2514299Sbbruce@ucdavis.edu    py::module np = py::module::import("numpy");
2614299Sbbruce@ucdavis.edu    DtypeCheck check{};
2714299Sbbruce@ucdavis.edu    check.numpy = np.attr("dtype")(np.attr(name));
2814299Sbbruce@ucdavis.edu    check.pybind11 = py::dtype::of<T>();
2914299Sbbruce@ucdavis.edu    return check;
3014299Sbbruce@ucdavis.edu}
3114299Sbbruce@ucdavis.edu
3214299Sbbruce@ucdavis.edustd::vector<DtypeCheck> get_concrete_dtype_checks() {
3314299Sbbruce@ucdavis.edu    return {
3414299Sbbruce@ucdavis.edu        // Normalization
3514299Sbbruce@ucdavis.edu        get_dtype_check<std::int8_t>("int8"),
3614299Sbbruce@ucdavis.edu        get_dtype_check<std::uint8_t>("uint8"),
3714299Sbbruce@ucdavis.edu        get_dtype_check<std::int16_t>("int16"),
3814299Sbbruce@ucdavis.edu        get_dtype_check<std::uint16_t>("uint16"),
3914299Sbbruce@ucdavis.edu        get_dtype_check<std::int32_t>("int32"),
4014299Sbbruce@ucdavis.edu        get_dtype_check<std::uint32_t>("uint32"),
4114299Sbbruce@ucdavis.edu        get_dtype_check<std::int64_t>("int64"),
4214299Sbbruce@ucdavis.edu        get_dtype_check<std::uint64_t>("uint64")
4314299Sbbruce@ucdavis.edu    };
4414299Sbbruce@ucdavis.edu}
4514299Sbbruce@ucdavis.edu
4614299Sbbruce@ucdavis.edustruct DtypeSizeCheck {
4714299Sbbruce@ucdavis.edu    std::string name{};
4814299Sbbruce@ucdavis.edu    int size_cpp{};
4914299Sbbruce@ucdavis.edu    int size_numpy{};
5014299Sbbruce@ucdavis.edu    // For debugging.
5114299Sbbruce@ucdavis.edu    py::dtype dtype{};
5214299Sbbruce@ucdavis.edu};
5314299Sbbruce@ucdavis.edu
5414299Sbbruce@ucdavis.edutemplate <typename T>
5514299Sbbruce@ucdavis.eduDtypeSizeCheck get_dtype_size_check() {
5614299Sbbruce@ucdavis.edu    DtypeSizeCheck check{};
5714299Sbbruce@ucdavis.edu    check.name = py::type_id<T>();
5814299Sbbruce@ucdavis.edu    check.size_cpp = sizeof(T);
5914299Sbbruce@ucdavis.edu    check.dtype = py::dtype::of<T>();
6014299Sbbruce@ucdavis.edu    check.size_numpy = check.dtype.attr("itemsize").template cast<int>();
6114299Sbbruce@ucdavis.edu    return check;
6214299Sbbruce@ucdavis.edu}
6314299Sbbruce@ucdavis.edu
6414299Sbbruce@ucdavis.edustd::vector<DtypeSizeCheck> get_platform_dtype_size_checks() {
6514299Sbbruce@ucdavis.edu    return {
6614299Sbbruce@ucdavis.edu        get_dtype_size_check<short>(),
6714299Sbbruce@ucdavis.edu        get_dtype_size_check<unsigned short>(),
6814299Sbbruce@ucdavis.edu        get_dtype_size_check<int>(),
6914299Sbbruce@ucdavis.edu        get_dtype_size_check<unsigned int>(),
7014299Sbbruce@ucdavis.edu        get_dtype_size_check<long>(),
7114299Sbbruce@ucdavis.edu        get_dtype_size_check<unsigned long>(),
7214299Sbbruce@ucdavis.edu        get_dtype_size_check<long long>(),
7314299Sbbruce@ucdavis.edu        get_dtype_size_check<unsigned long long>(),
7414299Sbbruce@ucdavis.edu    };
7514299Sbbruce@ucdavis.edu}
7614299Sbbruce@ucdavis.edu
7714299Sbbruce@ucdavis.edu// Arrays.
7811986Sandreas.sandberg@arm.comusing arr = py::array;
7911986Sandreas.sandberg@arm.comusing arr_t = py::array_t<uint16_t, 0>;
8012037Sandreas.sandberg@arm.comstatic_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
8111986Sandreas.sandberg@arm.com
8211986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr data(const arr& a, Ix... index) {
8311986Sandreas.sandberg@arm.com    return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
8411986Sandreas.sandberg@arm.com}
8511986Sandreas.sandberg@arm.com
8611986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr data_t(const arr_t& a, Ix... index) {
8711986Sandreas.sandberg@arm.com    return arr(a.size() - a.index_at(index...), a.data(index...));
8811986Sandreas.sandberg@arm.com}
8911986Sandreas.sandberg@arm.com
9011986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr& mutate_data(arr& a, Ix... index) {
9111986Sandreas.sandberg@arm.com    auto ptr = (uint8_t *) a.mutable_data(index...);
9212391Sjason@lowepower.com    for (ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++)
9311986Sandreas.sandberg@arm.com        ptr[i] = (uint8_t) (ptr[i] * 2);
9411986Sandreas.sandberg@arm.com    return a;
9511986Sandreas.sandberg@arm.com}
9611986Sandreas.sandberg@arm.com
9711986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix... index) {
9811986Sandreas.sandberg@arm.com    auto ptr = a.mutable_data(index...);
9912391Sjason@lowepower.com    for (ssize_t i = 0; i < a.size() - a.index_at(index...); i++)
10011986Sandreas.sandberg@arm.com        ptr[i]++;
10111986Sandreas.sandberg@arm.com    return a;
10211986Sandreas.sandberg@arm.com}
10311986Sandreas.sandberg@arm.com
10412391Sjason@lowepower.comtemplate<typename... Ix> ssize_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); }
10512391Sjason@lowepower.comtemplate<typename... Ix> ssize_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); }
10612391Sjason@lowepower.comtemplate<typename... Ix> ssize_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); }
10712391Sjason@lowepower.comtemplate<typename... Ix> ssize_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); }
10812391Sjason@lowepower.comtemplate<typename... Ix> ssize_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); }
10911986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; }
11011986Sandreas.sandberg@arm.com
// Binds `name` on the submodule `sm` as four overloads: an array argument of
// type `type` followed by zero to three int indices, each forwarding to the
// C++ template of the same name. The body is wrapped in do/while(0) so that
// `def_index_fn(...);` expands to exactly one statement (safe after an
// unbraced if/else).
#define def_index_fn(name, type) \
    do { \
        sm.def(#name, [](type a) { return name(a); }); \
        sm.def(#name, [](type a, int i) { return name(a, i); }); \
        sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
        sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); }); \
    } while (0)
11611986Sandreas.sandberg@arm.com
11712037Sandreas.sandberg@arm.comtemplate <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) {
11812037Sandreas.sandberg@arm.com    if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
11912037Sandreas.sandberg@arm.com    py::list l;
12012037Sandreas.sandberg@arm.com    l.append(*r.data(0, 0));
12112037Sandreas.sandberg@arm.com    l.append(*r2.mutable_data(0, 0));
12212037Sandreas.sandberg@arm.com    l.append(r.data(0, 1) == r2.mutable_data(0, 1));
12312037Sandreas.sandberg@arm.com    l.append(r.ndim());
12412037Sandreas.sandberg@arm.com    l.append(r.itemsize());
12512037Sandreas.sandberg@arm.com    l.append(r.shape(0));
12612037Sandreas.sandberg@arm.com    l.append(r.shape(1));
12712037Sandreas.sandberg@arm.com    l.append(r.size());
12812037Sandreas.sandberg@arm.com    l.append(r.nbytes());
12912037Sandreas.sandberg@arm.com    return l.release();
13012037Sandreas.sandberg@arm.com}
13112037Sandreas.sandberg@arm.com
// Backing storage for the zero-dimensional array returned by `scalar_int`,
// which is built from `&data_i`. Kept at namespace scope (static storage
// duration): a local in the module-init function would have to be captured
// by reference and would dangle once that function returned.
namespace {
int data_i = 42;
}  // namespace
13414299Sbbruce@ucdavis.edu
13512391Sjason@lowepower.comTEST_SUBMODULE(numpy_array, sm) {
13612391Sjason@lowepower.com    try { py::module::import("numpy"); }
13712391Sjason@lowepower.com    catch (...) { return; }
13811986Sandreas.sandberg@arm.com
13914299Sbbruce@ucdavis.edu    // test_dtypes
14014299Sbbruce@ucdavis.edu    py::class_<DtypeCheck>(sm, "DtypeCheck")
14114299Sbbruce@ucdavis.edu        .def_readonly("numpy", &DtypeCheck::numpy)
14214299Sbbruce@ucdavis.edu        .def_readonly("pybind11", &DtypeCheck::pybind11)
14314299Sbbruce@ucdavis.edu        .def("__repr__", [](const DtypeCheck& self) {
14414299Sbbruce@ucdavis.edu            return py::str("<DtypeCheck numpy={} pybind11={}>").format(
14514299Sbbruce@ucdavis.edu                self.numpy, self.pybind11);
14614299Sbbruce@ucdavis.edu        });
14714299Sbbruce@ucdavis.edu    sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks);
14814299Sbbruce@ucdavis.edu
14914299Sbbruce@ucdavis.edu    py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck")
15014299Sbbruce@ucdavis.edu        .def_readonly("name", &DtypeSizeCheck::name)
15114299Sbbruce@ucdavis.edu        .def_readonly("size_cpp", &DtypeSizeCheck::size_cpp)
15214299Sbbruce@ucdavis.edu        .def_readonly("size_numpy", &DtypeSizeCheck::size_numpy)
15314299Sbbruce@ucdavis.edu        .def("__repr__", [](const DtypeSizeCheck& self) {
15414299Sbbruce@ucdavis.edu            return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>").format(
15514299Sbbruce@ucdavis.edu                self.name, self.size_cpp, self.size_numpy, self.dtype);
15614299Sbbruce@ucdavis.edu        });
15714299Sbbruce@ucdavis.edu    sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks);
15814299Sbbruce@ucdavis.edu
15912391Sjason@lowepower.com    // test_array_attributes
16011986Sandreas.sandberg@arm.com    sm.def("ndim", [](const arr& a) { return a.ndim(); });
16111986Sandreas.sandberg@arm.com    sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
16212391Sjason@lowepower.com    sm.def("shape", [](const arr& a, ssize_t dim) { return a.shape(dim); });
16311986Sandreas.sandberg@arm.com    sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); });
16412391Sjason@lowepower.com    sm.def("strides", [](const arr& a, ssize_t dim) { return a.strides(dim); });
16511986Sandreas.sandberg@arm.com    sm.def("writeable", [](const arr& a) { return a.writeable(); });
16611986Sandreas.sandberg@arm.com    sm.def("size", [](const arr& a) { return a.size(); });
16711986Sandreas.sandberg@arm.com    sm.def("itemsize", [](const arr& a) { return a.itemsize(); });
16811986Sandreas.sandberg@arm.com    sm.def("nbytes", [](const arr& a) { return a.nbytes(); });
16911986Sandreas.sandberg@arm.com    sm.def("owndata", [](const arr& a) { return a.owndata(); });
17011986Sandreas.sandberg@arm.com
17112391Sjason@lowepower.com    // test_index_offset
17211986Sandreas.sandberg@arm.com    def_index_fn(index_at, const arr&);
17311986Sandreas.sandberg@arm.com    def_index_fn(index_at_t, const arr_t&);
17411986Sandreas.sandberg@arm.com    def_index_fn(offset_at, const arr&);
17511986Sandreas.sandberg@arm.com    def_index_fn(offset_at_t, const arr_t&);
17612391Sjason@lowepower.com    // test_data
17712391Sjason@lowepower.com    def_index_fn(data, const arr&);
17812391Sjason@lowepower.com    def_index_fn(data_t, const arr_t&);
17912391Sjason@lowepower.com    // test_mutate_data, test_mutate_readonly
18011986Sandreas.sandberg@arm.com    def_index_fn(mutate_data, arr&);
18111986Sandreas.sandberg@arm.com    def_index_fn(mutate_data_t, arr_t&);
18211986Sandreas.sandberg@arm.com    def_index_fn(at_t, const arr_t&);
18311986Sandreas.sandberg@arm.com    def_index_fn(mutate_at_t, arr_t&);
18411986Sandreas.sandberg@arm.com
18512391Sjason@lowepower.com    // test_make_c_f_array
18612391Sjason@lowepower.com    sm.def("make_f_array", [] { return py::array_t<float>({ 2, 2 }, { 4, 8 }); });
18712391Sjason@lowepower.com    sm.def("make_c_array", [] { return py::array_t<float>({ 2, 2 }, { 8, 4 }); });
18811986Sandreas.sandberg@arm.com
18914299Sbbruce@ucdavis.edu    // test_empty_shaped_array
19014299Sbbruce@ucdavis.edu    sm.def("make_empty_shaped_array", [] { return py::array(py::dtype("f"), {}, {}); });
19114299Sbbruce@ucdavis.edu    // test numpy scalars (empty shape, ndim==0)
19214299Sbbruce@ucdavis.edu    sm.def("scalar_int", []() { return py::array(py::dtype("i"), {}, {}, &data_i); });
19314299Sbbruce@ucdavis.edu
19412391Sjason@lowepower.com    // test_wrap
19511986Sandreas.sandberg@arm.com    sm.def("wrap", [](py::array a) {
19611986Sandreas.sandberg@arm.com        return py::array(
19711986Sandreas.sandberg@arm.com            a.dtype(),
19812391Sjason@lowepower.com            {a.shape(), a.shape() + a.ndim()},
19912391Sjason@lowepower.com            {a.strides(), a.strides() + a.ndim()},
20011986Sandreas.sandberg@arm.com            a.data(),
20111986Sandreas.sandberg@arm.com            a
20211986Sandreas.sandberg@arm.com        );
20311986Sandreas.sandberg@arm.com    });
20411986Sandreas.sandberg@arm.com
20512391Sjason@lowepower.com    // test_numpy_view
20611986Sandreas.sandberg@arm.com    struct ArrayClass {
20711986Sandreas.sandberg@arm.com        int data[2] = { 1, 2 };
20811986Sandreas.sandberg@arm.com        ArrayClass() { py::print("ArrayClass()"); }
20911986Sandreas.sandberg@arm.com        ~ArrayClass() { py::print("~ArrayClass()"); }
21011986Sandreas.sandberg@arm.com    };
21111986Sandreas.sandberg@arm.com    py::class_<ArrayClass>(sm, "ArrayClass")
21211986Sandreas.sandberg@arm.com        .def(py::init<>())
21311986Sandreas.sandberg@arm.com        .def("numpy_view", [](py::object &obj) {
21411986Sandreas.sandberg@arm.com            py::print("ArrayClass::numpy_view()");
21511986Sandreas.sandberg@arm.com            ArrayClass &a = obj.cast<ArrayClass&>();
21611986Sandreas.sandberg@arm.com            return py::array_t<int>({2}, {4}, a.data, obj);
21711986Sandreas.sandberg@arm.com        }
21811986Sandreas.sandberg@arm.com    );
21911986Sandreas.sandberg@arm.com
22012391Sjason@lowepower.com    // test_cast_numpy_int64_to_uint64
22111986Sandreas.sandberg@arm.com    sm.def("function_taking_uint64", [](uint64_t) { });
22211986Sandreas.sandberg@arm.com
22312391Sjason@lowepower.com    // test_isinstance
22411986Sandreas.sandberg@arm.com    sm.def("isinstance_untyped", [](py::object yes, py::object no) {
22511986Sandreas.sandberg@arm.com        return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no);
22611986Sandreas.sandberg@arm.com    });
22711986Sandreas.sandberg@arm.com    sm.def("isinstance_typed", [](py::object o) {
22811986Sandreas.sandberg@arm.com        return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
22911986Sandreas.sandberg@arm.com    });
23011986Sandreas.sandberg@arm.com
23112391Sjason@lowepower.com    // test_constructors
23211986Sandreas.sandberg@arm.com    sm.def("default_constructors", []() {
23311986Sandreas.sandberg@arm.com        return py::dict(
23411986Sandreas.sandberg@arm.com            "array"_a=py::array(),
23511986Sandreas.sandberg@arm.com            "array_t<int32>"_a=py::array_t<std::int32_t>(),
23611986Sandreas.sandberg@arm.com            "array_t<double>"_a=py::array_t<double>()
23711986Sandreas.sandberg@arm.com        );
23811986Sandreas.sandberg@arm.com    });
23911986Sandreas.sandberg@arm.com    sm.def("converting_constructors", [](py::object o) {
24011986Sandreas.sandberg@arm.com        return py::dict(
24111986Sandreas.sandberg@arm.com            "array"_a=py::array(o),
24211986Sandreas.sandberg@arm.com            "array_t<int32>"_a=py::array_t<std::int32_t>(o),
24311986Sandreas.sandberg@arm.com            "array_t<double>"_a=py::array_t<double>(o)
24411986Sandreas.sandberg@arm.com        );
24511986Sandreas.sandberg@arm.com    });
24612037Sandreas.sandberg@arm.com
24712391Sjason@lowepower.com    // test_overload_resolution
24812037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<double>) { return "double"; });
24912037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<float>) { return "float"; });
25012037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<int>) { return "int"; });
25112037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<unsigned short>) { return "unsigned short"; });
25212037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<long long>) { return "long long"; });
25312037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<std::complex<double>>) { return "double complex"; });
25412037Sandreas.sandberg@arm.com    sm.def("overloaded", [](py::array_t<std::complex<float>>) { return "float complex"; });
25512037Sandreas.sandberg@arm.com
25612037Sandreas.sandberg@arm.com    sm.def("overloaded2", [](py::array_t<std::complex<double>>) { return "double complex"; });
25712037Sandreas.sandberg@arm.com    sm.def("overloaded2", [](py::array_t<double>) { return "double"; });
25812037Sandreas.sandberg@arm.com    sm.def("overloaded2", [](py::array_t<std::complex<float>>) { return "float complex"; });
25912037Sandreas.sandberg@arm.com    sm.def("overloaded2", [](py::array_t<float>) { return "float"; });
26012037Sandreas.sandberg@arm.com
26112037Sandreas.sandberg@arm.com    // Only accept the exact types:
26212037Sandreas.sandberg@arm.com    sm.def("overloaded3", [](py::array_t<int>) { return "int"; }, py::arg().noconvert());
26312037Sandreas.sandberg@arm.com    sm.def("overloaded3", [](py::array_t<double>) { return "double"; }, py::arg().noconvert());
26412037Sandreas.sandberg@arm.com
26512037Sandreas.sandberg@arm.com    // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but
26612037Sandreas.sandberg@arm.com    // rather that float gets converted via the safe (conversion to double) overload:
26712037Sandreas.sandberg@arm.com    sm.def("overloaded4", [](py::array_t<long long, 0>) { return "long long"; });
26812037Sandreas.sandberg@arm.com    sm.def("overloaded4", [](py::array_t<double, 0>) { return "double"; });
26912037Sandreas.sandberg@arm.com
27012037Sandreas.sandberg@arm.com    // But we do allow conversion to int if forcecast is enabled (but only if no overload matches
27112037Sandreas.sandberg@arm.com    // without conversion)
27212037Sandreas.sandberg@arm.com    sm.def("overloaded5", [](py::array_t<unsigned int>) { return "unsigned int"; });
27312037Sandreas.sandberg@arm.com    sm.def("overloaded5", [](py::array_t<double>) { return "double"; });
27412037Sandreas.sandberg@arm.com
27512391Sjason@lowepower.com    // test_greedy_string_overload
27612037Sandreas.sandberg@arm.com    // Issue 685: ndarray shouldn't go to std::string overload
27712037Sandreas.sandberg@arm.com    sm.def("issue685", [](std::string) { return "string"; });
27812037Sandreas.sandberg@arm.com    sm.def("issue685", [](py::array) { return "array"; });
27912037Sandreas.sandberg@arm.com    sm.def("issue685", [](py::object) { return "other"; });
28012037Sandreas.sandberg@arm.com
28112391Sjason@lowepower.com    // test_array_unchecked_fixed_dims
28212037Sandreas.sandberg@arm.com    sm.def("proxy_add2", [](py::array_t<double> a, double v) {
28312037Sandreas.sandberg@arm.com        auto r = a.mutable_unchecked<2>();
28412391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
28512391Sjason@lowepower.com            for (ssize_t j = 0; j < r.shape(1); j++)
28612037Sandreas.sandberg@arm.com                r(i, j) += v;
28712037Sandreas.sandberg@arm.com    }, py::arg().noconvert(), py::arg());
28812037Sandreas.sandberg@arm.com
28912037Sandreas.sandberg@arm.com    sm.def("proxy_init3", [](double start) {
29012037Sandreas.sandberg@arm.com        py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
29112037Sandreas.sandberg@arm.com        auto r = a.mutable_unchecked<3>();
29212391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
29312391Sjason@lowepower.com        for (ssize_t j = 0; j < r.shape(1); j++)
29412391Sjason@lowepower.com        for (ssize_t k = 0; k < r.shape(2); k++)
29512037Sandreas.sandberg@arm.com            r(i, j, k) = start++;
29612037Sandreas.sandberg@arm.com        return a;
29712037Sandreas.sandberg@arm.com    });
29812037Sandreas.sandberg@arm.com    sm.def("proxy_init3F", [](double start) {
29912037Sandreas.sandberg@arm.com        py::array_t<double, py::array::f_style> a({ 3, 3, 3 });
30012037Sandreas.sandberg@arm.com        auto r = a.mutable_unchecked<3>();
30112391Sjason@lowepower.com        for (ssize_t k = 0; k < r.shape(2); k++)
30212391Sjason@lowepower.com        for (ssize_t j = 0; j < r.shape(1); j++)
30312391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
30412037Sandreas.sandberg@arm.com            r(i, j, k) = start++;
30512037Sandreas.sandberg@arm.com        return a;
30612037Sandreas.sandberg@arm.com    });
30712037Sandreas.sandberg@arm.com    sm.def("proxy_squared_L2_norm", [](py::array_t<double> a) {
30812037Sandreas.sandberg@arm.com        auto r = a.unchecked<1>();
30912037Sandreas.sandberg@arm.com        double sumsq = 0;
31012391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
31112037Sandreas.sandberg@arm.com            sumsq += r[i] * r(i); // Either notation works for a 1D array
31212037Sandreas.sandberg@arm.com        return sumsq;
31312037Sandreas.sandberg@arm.com    });
31412037Sandreas.sandberg@arm.com
31512037Sandreas.sandberg@arm.com    sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
31612037Sandreas.sandberg@arm.com        auto r = a.unchecked<2>();
31712037Sandreas.sandberg@arm.com        auto r2 = a.mutable_unchecked<2>();
31812037Sandreas.sandberg@arm.com        return auxiliaries(r, r2);
31912037Sandreas.sandberg@arm.com    });
32012037Sandreas.sandberg@arm.com
32112391Sjason@lowepower.com    // test_array_unchecked_dyn_dims
32212037Sandreas.sandberg@arm.com    // Same as the above, but without a compile-time dimensions specification:
32312037Sandreas.sandberg@arm.com    sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) {
32412037Sandreas.sandberg@arm.com        auto r = a.mutable_unchecked();
32512037Sandreas.sandberg@arm.com        if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
32612391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
32712391Sjason@lowepower.com            for (ssize_t j = 0; j < r.shape(1); j++)
32812037Sandreas.sandberg@arm.com                r(i, j) += v;
32912037Sandreas.sandberg@arm.com    }, py::arg().noconvert(), py::arg());
33012037Sandreas.sandberg@arm.com    sm.def("proxy_init3_dyn", [](double start) {
33112037Sandreas.sandberg@arm.com        py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
33212037Sandreas.sandberg@arm.com        auto r = a.mutable_unchecked();
33312037Sandreas.sandberg@arm.com        if (r.ndim() != 3) throw std::domain_error("error: ndim != 3");
33412391Sjason@lowepower.com        for (ssize_t i = 0; i < r.shape(0); i++)
33512391Sjason@lowepower.com        for (ssize_t j = 0; j < r.shape(1); j++)
33612391Sjason@lowepower.com        for (ssize_t k = 0; k < r.shape(2); k++)
33712037Sandreas.sandberg@arm.com            r(i, j, k) = start++;
33812037Sandreas.sandberg@arm.com        return a;
33912037Sandreas.sandberg@arm.com    });
34012037Sandreas.sandberg@arm.com    sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
34112037Sandreas.sandberg@arm.com        return auxiliaries(a.unchecked(), a.mutable_unchecked());
34212037Sandreas.sandberg@arm.com    });
34312037Sandreas.sandberg@arm.com
34412037Sandreas.sandberg@arm.com    sm.def("array_auxiliaries2", [](py::array_t<double> a) {
34512037Sandreas.sandberg@arm.com        return auxiliaries(a, a);
34612037Sandreas.sandberg@arm.com    });
34712391Sjason@lowepower.com
34812391Sjason@lowepower.com    // test_array_failures
34912391Sjason@lowepower.com    // Issue #785: Uninformative "Unknown internal error" exception when constructing array from empty object:
35012391Sjason@lowepower.com    sm.def("array_fail_test", []() { return py::array(py::object()); });
35112391Sjason@lowepower.com    sm.def("array_t_fail_test", []() { return py::array_t<double>(py::object()); });
35212391Sjason@lowepower.com    // Make sure the error from numpy is being passed through:
35312391Sjason@lowepower.com    sm.def("array_fail_test_negative_size", []() { int c = 0; return py::array(-1, &c); });
35412391Sjason@lowepower.com
35512391Sjason@lowepower.com    // test_initializer_list
35612391Sjason@lowepower.com    // Issue (unnumbered; reported in #788): regression: initializer lists can be ambiguous
35712391Sjason@lowepower.com    sm.def("array_initializer_list1", []() { return py::array_t<float>(1); }); // { 1 } also works, but clang warns about it
35812391Sjason@lowepower.com    sm.def("array_initializer_list2", []() { return py::array_t<float>({ 1, 2 }); });
35912391Sjason@lowepower.com    sm.def("array_initializer_list3", []() { return py::array_t<float>({ 1, 2, 3 }); });
36012391Sjason@lowepower.com    sm.def("array_initializer_list4", []() { return py::array_t<float>({ 1, 2, 3, 4 }); });
36112391Sjason@lowepower.com
36212391Sjason@lowepower.com    // test_array_resize
36312391Sjason@lowepower.com    // reshape array to 2D without changing size
36412391Sjason@lowepower.com    sm.def("array_reshape2", [](py::array_t<double> a) {
36512391Sjason@lowepower.com        const ssize_t dim_sz = (ssize_t)std::sqrt(a.size());
36612391Sjason@lowepower.com        if (dim_sz * dim_sz != a.size())
36712391Sjason@lowepower.com            throw std::domain_error("array_reshape2: input array total size is not a squared integer");
36812391Sjason@lowepower.com        a.resize({dim_sz, dim_sz});
36912391Sjason@lowepower.com    });
37012391Sjason@lowepower.com
37112391Sjason@lowepower.com    // resize to 3D array with each dimension = N
37212391Sjason@lowepower.com    sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
37312391Sjason@lowepower.com        a.resize({N, N, N}, refcheck);
37412391Sjason@lowepower.com    });
37512391Sjason@lowepower.com
37612391Sjason@lowepower.com    // test_array_create_and_resize
37712391Sjason@lowepower.com    // return 2D array with Nrows = Ncols = N
37812391Sjason@lowepower.com    sm.def("create_and_resize", [](size_t N) {
37912391Sjason@lowepower.com        py::array_t<double> a;
38012391Sjason@lowepower.com        a.resize({N, N});
38112391Sjason@lowepower.com        std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
38212391Sjason@lowepower.com        return a;
38312391Sjason@lowepower.com    });
38414299Sbbruce@ucdavis.edu
38514299Sbbruce@ucdavis.edu#if PY_MAJOR_VERSION >= 3
38614299Sbbruce@ucdavis.edu        sm.def("index_using_ellipsis", [](py::array a) {
38714299Sbbruce@ucdavis.edu            return a[py::make_tuple(0, py::ellipsis(), 0)];
38814299Sbbruce@ucdavis.edu        });
38914299Sbbruce@ucdavis.edu#endif
39012391Sjason@lowepower.com}
391