test_numpy_array.cpp revision 12037
111986Sandreas.sandberg@arm.com/*
211986Sandreas.sandberg@arm.com    tests/test_numpy_array.cpp -- test core array functionality
311986Sandreas.sandberg@arm.com
411986Sandreas.sandberg@arm.com    Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
511986Sandreas.sandberg@arm.com
611986Sandreas.sandberg@arm.com    All rights reserved. Use of this source code is governed by a
711986Sandreas.sandberg@arm.com    BSD-style license that can be found in the LICENSE file.
811986Sandreas.sandberg@arm.com*/
911986Sandreas.sandberg@arm.com
1011986Sandreas.sandberg@arm.com#include "pybind11_tests.h"
1111986Sandreas.sandberg@arm.com
1211986Sandreas.sandberg@arm.com#include <pybind11/numpy.h>
1311986Sandreas.sandberg@arm.com#include <pybind11/stl.h>
1411986Sandreas.sandberg@arm.com
1511986Sandreas.sandberg@arm.com#include <cstdint>
1611986Sandreas.sandberg@arm.com#include <vector>
1711986Sandreas.sandberg@arm.com
1811986Sandreas.sandberg@arm.comusing arr = py::array;
1911986Sandreas.sandberg@arm.comusing arr_t = py::array_t<uint16_t, 0>;
2012037Sandreas.sandberg@arm.comstatic_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
2111986Sandreas.sandberg@arm.com
2211986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr data(const arr& a, Ix... index) {
2311986Sandreas.sandberg@arm.com    return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
2411986Sandreas.sandberg@arm.com}
2511986Sandreas.sandberg@arm.com
2611986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr data_t(const arr_t& a, Ix... index) {
2711986Sandreas.sandberg@arm.com    return arr(a.size() - a.index_at(index...), a.data(index...));
2811986Sandreas.sandberg@arm.com}
2911986Sandreas.sandberg@arm.com
3011986Sandreas.sandberg@arm.comarr& mutate_data(arr& a) {
3111986Sandreas.sandberg@arm.com    auto ptr = (uint8_t *) a.mutable_data();
3211986Sandreas.sandberg@arm.com    for (size_t i = 0; i < a.nbytes(); i++)
3311986Sandreas.sandberg@arm.com        ptr[i] = (uint8_t) (ptr[i] * 2);
3411986Sandreas.sandberg@arm.com    return a;
3511986Sandreas.sandberg@arm.com}
3611986Sandreas.sandberg@arm.com
3711986Sandreas.sandberg@arm.comarr_t& mutate_data_t(arr_t& a) {
3811986Sandreas.sandberg@arm.com    auto ptr = a.mutable_data();
3911986Sandreas.sandberg@arm.com    for (size_t i = 0; i < a.size(); i++)
4011986Sandreas.sandberg@arm.com        ptr[i]++;
4111986Sandreas.sandberg@arm.com    return a;
4211986Sandreas.sandberg@arm.com}
4311986Sandreas.sandberg@arm.com
4411986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr& mutate_data(arr& a, Ix... index) {
4511986Sandreas.sandberg@arm.com    auto ptr = (uint8_t *) a.mutable_data(index...);
4611986Sandreas.sandberg@arm.com    for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++)
4711986Sandreas.sandberg@arm.com        ptr[i] = (uint8_t) (ptr[i] * 2);
4811986Sandreas.sandberg@arm.com    return a;
4911986Sandreas.sandberg@arm.com}
5011986Sandreas.sandberg@arm.com
5111986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix... index) {
5211986Sandreas.sandberg@arm.com    auto ptr = a.mutable_data(index...);
5311986Sandreas.sandberg@arm.com    for (size_t i = 0; i < a.size() - a.index_at(index...); i++)
5411986Sandreas.sandberg@arm.com        ptr[i]++;
5511986Sandreas.sandberg@arm.com    return a;
5611986Sandreas.sandberg@arm.com}
5711986Sandreas.sandberg@arm.com
5811986Sandreas.sandberg@arm.comtemplate<typename... Ix> size_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); }
5911986Sandreas.sandberg@arm.comtemplate<typename... Ix> size_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); }
6011986Sandreas.sandberg@arm.comtemplate<typename... Ix> size_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); }
6111986Sandreas.sandberg@arm.comtemplate<typename... Ix> size_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); }
6211986Sandreas.sandberg@arm.comtemplate<typename... Ix> size_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); }
6311986Sandreas.sandberg@arm.comtemplate<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; }
6411986Sandreas.sandberg@arm.com
/* Binds the C++ function `name` on the submodule `sm` (which must be in scope
   at the expansion site) under the Python name "name", four times: taking
   just an argument of type `type`, and taking that argument followed by one,
   two or three int indices. Each lambda forwards its arguments to `name`. */
#define def_index_fn(name, type) \
    sm.def(#name, [](type arr_arg) { return name(arr_arg); }); \
    sm.def(#name, [](type arr_arg, int i0) { return name(arr_arg, i0); }); \
    sm.def(#name, [](type arr_arg, int i0, int i1) { return name(arr_arg, i0, i1); }); \
    sm.def(#name, [](type arr_arg, int i0, int i1, int i2) { return name(arr_arg, i0, i1, i2); });
7011986Sandreas.sandberg@arm.com
7112037Sandreas.sandberg@arm.comtemplate <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) {
7212037Sandreas.sandberg@arm.com    if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
7312037Sandreas.sandberg@arm.com    py::list l;
7412037Sandreas.sandberg@arm.com    l.append(*r.data(0, 0));
7512037Sandreas.sandberg@arm.com    l.append(*r2.mutable_data(0, 0));
7612037Sandreas.sandberg@arm.com    l.append(r.data(0, 1) == r2.mutable_data(0, 1));
7712037Sandreas.sandberg@arm.com    l.append(r.ndim());
7812037Sandreas.sandberg@arm.com    l.append(r.itemsize());
7912037Sandreas.sandberg@arm.com    l.append(r.shape(0));
8012037Sandreas.sandberg@arm.com    l.append(r.shape(1));
8112037Sandreas.sandberg@arm.com    l.append(r.size());
8212037Sandreas.sandberg@arm.com    l.append(r.nbytes());
8312037Sandreas.sandberg@arm.com    return l.release();
8412037Sandreas.sandberg@arm.com}
8512037Sandreas.sandberg@arm.com
// Registers the `array` submodule exercised by the numpy array tests. It holds
// thin bindings over py::array / py::array_t accessors and constructors, a set
// of overload-resolution cases, and tests of the unchecked / mutable_unchecked
// access proxies.
//
// NOTE(review): several names ("shape", "strides", "overloaded*", "issue685",
// and every def_index_fn expansion) are registered more than once. Keep the
// registration order as is: pybind11 tries same-named overloads in the order
// they were added, so confirm against the Python tests before reordering.
test_initializer numpy_array([](py::module &m) {
    auto sm = m.def_submodule("array");

    // Basic attribute accessors. With no dimension argument, shape()/strides()
    // return a new 1-D array of a.ndim() entries built from the accessor's values.
    sm.def("ndim", [](const arr& a) { return a.ndim(); });
    sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
    sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); });
    sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); });
    sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); });
    sm.def("writeable", [](const arr& a) { return a.writeable(); });
    sm.def("size", [](const arr& a) { return a.size(); });
    sm.def("itemsize", [](const arr& a) { return a.itemsize(); });
    sm.def("nbytes", [](const arr& a) { return a.nbytes(); });
    sm.def("owndata", [](const arr& a) { return a.owndata(); });

    // Each expansion binds the helper of the same name with 0-3 int indices.
    // The *_t variants take the uint16_t typed arr_t instead of the untyped arr.
    def_index_fn(data, const arr&);
    def_index_fn(data_t, const arr_t&);
    def_index_fn(index_at, const arr&);
    def_index_fn(index_at_t, const arr_t&);
    def_index_fn(offset_at, const arr&);
    def_index_fn(offset_at_t, const arr_t&);
    def_index_fn(mutate_data, arr&);
    def_index_fn(mutate_data_t, arr_t&);
    def_index_fn(at_t, const arr_t&);
    def_index_fn(mutate_at_t, arr_t&);

    // 2x2 float arrays with explicit byte strides. For 4-byte floats, {4, 8} is
    // column-major (Fortran order) and {8, 4} is row-major (C order).
    sm.def("make_f_array", [] {
        return py::array_t<float>({ 2, 2 }, { 4, 8 });
    });

    sm.def("make_c_array", [] {
        return py::array_t<float>({ 2, 2 }, { 8, 4 });
    });

    // Re-wraps a's data pointer in a new array with the same dtype, shape and
    // strides. `a` itself is passed as the last (base) argument, presumably so
    // that the new array keeps a's memory alive.
    sm.def("wrap", [](py::array a) {
        return py::array(
            a.dtype(),
            std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
            std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
            a.data(),
            a
        );
    });

    // Local type whose constructor and destructor print, so the Python side can
    // observe its lifetime relative to arrays that view its `data` member.
    struct ArrayClass {
        int data[2] = { 1, 2 };
        ArrayClass() { py::print("ArrayClass()"); }
        ~ArrayClass() { py::print("~ArrayClass()"); }
    };

    // numpy_view() returns a 2-element int array (stride 4 bytes) over
    // ArrayClass::data, with the owning Python object passed as the base,
    // presumably to keep the instance alive while the view exists.
    py::class_<ArrayClass>(sm, "ArrayClass")
        .def(py::init<>())
        .def("numpy_view", [](py::object &obj) {
            py::print("ArrayClass::numpy_view()");
            ArrayClass &a = obj.cast<ArrayClass&>();
            return py::array_t<int>({2}, {4}, a.data, obj);
        }
    );

    // No-op whose only job is to exercise uint64_t argument conversion.
    sm.def("function_taking_uint64", [](uint64_t) { });

    // True iff `yes` passes py::isinstance<py::array> and `no` does not.
    sm.def("isinstance_untyped", [](py::object yes, py::object no) {
        return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no);
    });

    // True iff `o` passes py::isinstance for array_t<double> but not for array_t<int>.
    sm.def("isinstance_typed", [](py::object o) {
        return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
    });

    // Default-constructed array objects, keyed by a type label.
    sm.def("default_constructors", []() {
        return py::dict(
            "array"_a=py::array(),
            "array_t<int32>"_a=py::array_t<std::int32_t>(),
            "array_t<double>"_a=py::array_t<double>()
        );
    });

    // The same three array types, constructed from an arbitrary Python object `o`.
    sm.def("converting_constructors", [](py::object o) {
        return py::dict(
            "array"_a=py::array(o),
            "array_t<int32>"_a=py::array_t<std::int32_t>(o),
            "array_t<double>"_a=py::array_t<double>(o)
        );
    });

    // Overload resolution tests:
    sm.def("overloaded", [](py::array_t<double>) { return "double"; });
    sm.def("overloaded", [](py::array_t<float>) { return "float"; });
    sm.def("overloaded", [](py::array_t<int>) { return "int"; });
    sm.def("overloaded", [](py::array_t<unsigned short>) { return "unsigned short"; });
    sm.def("overloaded", [](py::array_t<long long>) { return "long long"; });
    sm.def("overloaded", [](py::array_t<std::complex<double>>) { return "double complex"; });
    sm.def("overloaded", [](py::array_t<std::complex<float>>) { return "float complex"; });

    // Same idea with complex and real overloads interleaved in a different order.
    sm.def("overloaded2", [](py::array_t<std::complex<double>>) { return "double complex"; });
    sm.def("overloaded2", [](py::array_t<double>) { return "double"; });
    sm.def("overloaded2", [](py::array_t<std::complex<float>>) { return "float complex"; });
    sm.def("overloaded2", [](py::array_t<float>) { return "float"; });

    // Only accept the exact types:
    sm.def("overloaded3", [](py::array_t<int>) { return "int"; }, py::arg().noconvert());
    sm.def("overloaded3", [](py::array_t<double>) { return "double"; }, py::arg().noconvert());

    // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but
    // rather that float gets converted via the safe (conversion to double) overload:
    sm.def("overloaded4", [](py::array_t<long long, 0>) { return "long long"; });
    sm.def("overloaded4", [](py::array_t<double, 0>) { return "double"; });

    // But we do allow conversion to int if forcecast is enabled (but only if no overload matches
    // without conversion)
    sm.def("overloaded5", [](py::array_t<unsigned int>) { return "unsigned int"; });
    sm.def("overloaded5", [](py::array_t<double>) { return "double"; });

    // Issue 685: ndarray shouldn't go to std::string overload
    sm.def("issue685", [](std::string) { return "string"; });
    sm.def("issue685", [](py::array) { return "array"; });
    sm.def("issue685", [](py::object) { return "other"; });

    // Adds `v` in place to every element of a 2-D double array through a
    // mutable_unchecked<2> proxy. noconvert() on the array argument is
    // presumably there so that `a` must already be a double array, making the
    // update visible to the caller rather than landing in a converted copy.
    sm.def("proxy_add2", [](py::array_t<double> a, double v) {
        auto r = a.mutable_unchecked<2>();
        for (size_t i = 0; i < r.shape(0); i++)
            for (size_t j = 0; j < r.shape(1); j++)
                r(i, j) += v;
    }, py::arg().noconvert(), py::arg());

    // Returns a new C-ordered 3x3x3 double array filled with start, start+1, ...
    // with the last index varying fastest (i.e. in memory order).
    sm.def("proxy_init3", [](double start) {
        py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
        auto r = a.mutable_unchecked<3>();
        for (size_t i = 0; i < r.shape(0); i++)
        for (size_t j = 0; j < r.shape(1); j++)
        for (size_t k = 0; k < r.shape(2); k++)
            r(i, j, k) = start++;
        return a;
    });
    // Same fill for an F-ordered array, with the first index varying fastest so
    // that the values again follow memory order.
    sm.def("proxy_init3F", [](double start) {
        py::array_t<double, py::array::f_style> a({ 3, 3, 3 });
        auto r = a.mutable_unchecked<3>();
        for (size_t k = 0; k < r.shape(2); k++)
        for (size_t j = 0; j < r.shape(1); j++)
        for (size_t i = 0; i < r.shape(0); i++)
            r(i, j, k) = start++;
        return a;
    });
    // Sum of squares of a 1-D double array.
    sm.def("proxy_squared_L2_norm", [](py::array_t<double> a) {
        auto r = a.unchecked<1>();
        double sumsq = 0;
        for (size_t i = 0; i < r.shape(0); i++)
            sumsq += r[i] * r(i); // Either notation works for a 1D array
        return sumsq;
    });

    // Reports auxiliaries() for a read-only unchecked<2> proxy and a
    // mutable_unchecked<2> proxy of the same array.
    sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
        auto r = a.unchecked<2>();
        auto r2 = a.mutable_unchecked<2>();
        return auxiliaries(r, r2);
    });

    // Same as the above, but without a compile-time dimensions specification:
    // these functions check ndim at runtime and throw std::domain_error on mismatch.
    sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) {
        auto r = a.mutable_unchecked();
        if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
        for (size_t i = 0; i < r.shape(0); i++)
            for (size_t j = 0; j < r.shape(1); j++)
                r(i, j) += v;
    }, py::arg().noconvert(), py::arg());
    sm.def("proxy_init3_dyn", [](double start) {
        py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
        auto r = a.mutable_unchecked();
        if (r.ndim() != 3) throw std::domain_error("error: ndim != 3");
        for (size_t i = 0; i < r.shape(0); i++)
        for (size_t j = 0; j < r.shape(1); j++)
        for (size_t k = 0; k < r.shape(2); k++)
            r(i, j, k) = start++;
        return a;
    });
    sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
        return auxiliaries(a.unchecked(), a.mutable_unchecked());
    });

    // The same auxiliaries() report, taken from the array_t itself rather than
    // from a proxy.
    sm.def("array_auxiliaries2", [](py::array_t<double> a) {
        return auxiliaries(a, a);
    });
});
268