111986Sandreas.sandberg@arm.com/* 211986Sandreas.sandberg@arm.com tests/test_numpy_vectorize.cpp -- auto-vectorize functions over NumPy array 311986Sandreas.sandberg@arm.com arguments 411986Sandreas.sandberg@arm.com 511986Sandreas.sandberg@arm.com Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 611986Sandreas.sandberg@arm.com 711986Sandreas.sandberg@arm.com All rights reserved. Use of this source code is governed by a 811986Sandreas.sandberg@arm.com BSD-style license that can be found in the LICENSE file. 911986Sandreas.sandberg@arm.com*/ 1011986Sandreas.sandberg@arm.com 1111986Sandreas.sandberg@arm.com#include "pybind11_tests.h" 1211986Sandreas.sandberg@arm.com#include <pybind11/numpy.h> 1311986Sandreas.sandberg@arm.com 1411986Sandreas.sandberg@arm.comdouble my_func(int x, float y, double z) { 1511986Sandreas.sandberg@arm.com py::print("my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z)); 1611986Sandreas.sandberg@arm.com return (float) x*y*z; 1711986Sandreas.sandberg@arm.com} 1811986Sandreas.sandberg@arm.com 1912391Sjason@lowepower.comTEST_SUBMODULE(numpy_vectorize, m) { 2012391Sjason@lowepower.com try { py::module::import("numpy"); } 2112391Sjason@lowepower.com catch (...) { return; } 2211986Sandreas.sandberg@arm.com 2312391Sjason@lowepower.com // test_vectorize, test_docs, test_array_collapse 2411986Sandreas.sandberg@arm.com // Vectorize all arguments of a function (though non-vector arguments are also allowed) 2511986Sandreas.sandberg@arm.com m.def("vectorized_func", py::vectorize(my_func)); 2611986Sandreas.sandberg@arm.com 2711986Sandreas.sandberg@arm.com // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization) 2811986Sandreas.sandberg@arm.com m.def("vectorized_func2", 2911986Sandreas.sandberg@arm.com [](py::array_t<int> x, py::array_t<float> y, float z) { 3011986Sandreas.sandberg@arm.com return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y); 3111986Sandreas.sandberg@arm.com } 3211986Sandreas.sandberg@arm.com ); 3311986Sandreas.sandberg@arm.com 3411986Sandreas.sandberg@arm.com // Vectorize a complex-valued function 3512391Sjason@lowepower.com m.def("vectorized_func3", py::vectorize( 3612391Sjason@lowepower.com [](std::complex<double> c) { return c * std::complex<double>(2.f); } 3712391Sjason@lowepower.com )); 3811986Sandreas.sandberg@arm.com 3912391Sjason@lowepower.com // test_type_selection 4012391Sjason@lowepower.com // Numpy function which only accepts specific data types 4111986Sandreas.sandberg@arm.com m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; }); 4211986Sandreas.sandberg@arm.com m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; }); 4311986Sandreas.sandberg@arm.com m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; }); 4412037Sandreas.sandberg@arm.com 4512037Sandreas.sandberg@arm.com 4612391Sjason@lowepower.com // test_passthrough_arguments 4712391Sjason@lowepower.com // Passthrough test: references and non-pod types should be automatically passed through (in the 4812391Sjason@lowepower.com // function definition below, only `b`, `d`, and `g` are vectorized): 4912391Sjason@lowepower.com struct NonPODClass { 5012391Sjason@lowepower.com NonPODClass(int v) : value{v} {} 5112391Sjason@lowepower.com int value; 5212391Sjason@lowepower.com }; 5312391Sjason@lowepower.com py::class_<NonPODClass>(m, "NonPODClass").def(py::init<int>()); 5412391Sjason@lowepower.com m.def("vec_passthrough", py::vectorize( 5512391Sjason@lowepower.com [](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) { 5612391Sjason@lowepower.com return *a + b + c.at(0) + d + e + f.value + g; 5712391Sjason@lowepower.com } 5812391Sjason@lowepower.com )); 5912391Sjason@lowepower.com 6012391Sjason@lowepower.com // test_method_vectorization 6112391Sjason@lowepower.com struct VectorizeTestClass { 6212391Sjason@lowepower.com VectorizeTestClass(int v) : value{v} {}; 6312391Sjason@lowepower.com float method(int x, float y) { return y + (float) (x + value); } 6412391Sjason@lowepower.com int value = 0; 6512391Sjason@lowepower.com }; 6612391Sjason@lowepower.com py::class_<VectorizeTestClass> vtc(m, "VectorizeTestClass"); 6712391Sjason@lowepower.com vtc .def(py::init<int>()) 6812391Sjason@lowepower.com .def_readwrite("value", &VectorizeTestClass::value); 6912391Sjason@lowepower.com 7012391Sjason@lowepower.com // Automatic vectorizing of methods 7112391Sjason@lowepower.com vtc.def("method", py::vectorize(&VectorizeTestClass::method)); 7212391Sjason@lowepower.com 7312391Sjason@lowepower.com // test_trivial_broadcasting 7412037Sandreas.sandberg@arm.com // Internal optimization test for whether the input is trivially broadcastable: 7512037Sandreas.sandberg@arm.com py::enum_<py::detail::broadcast_trivial>(m, "trivial") 7612037Sandreas.sandberg@arm.com .value("f_trivial", py::detail::broadcast_trivial::f_trivial) 7712037Sandreas.sandberg@arm.com .value("c_trivial", py::detail::broadcast_trivial::c_trivial) 7812037Sandreas.sandberg@arm.com .value("non_trivial", py::detail::broadcast_trivial::non_trivial); 7912037Sandreas.sandberg@arm.com m.def("vectorized_is_trivial", []( 8012037Sandreas.sandberg@arm.com py::array_t<int, py::array::forcecast> arg1, 8112037Sandreas.sandberg@arm.com py::array_t<float, py::array::forcecast> arg2, 8212037Sandreas.sandberg@arm.com py::array_t<double, py::array::forcecast> arg3 8312037Sandreas.sandberg@arm.com ) { 8412391Sjason@lowepower.com ssize_t ndim; 8512391Sjason@lowepower.com std::vector<ssize_t> shape; 8612037Sandreas.sandberg@arm.com std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }}; 8712037Sandreas.sandberg@arm.com return py::detail::broadcast(buffers, ndim, shape); 8812037Sandreas.sandberg@arm.com }); 8912391Sjason@lowepower.com} 90