111986Sandreas.sandberg@arm.com/* 211986Sandreas.sandberg@arm.com tests/test_buffers.cpp -- supporting Pythons' buffer protocol 311986Sandreas.sandberg@arm.com 411986Sandreas.sandberg@arm.com Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 511986Sandreas.sandberg@arm.com 611986Sandreas.sandberg@arm.com All rights reserved. Use of this source code is governed by a 711986Sandreas.sandberg@arm.com BSD-style license that can be found in the LICENSE file. 811986Sandreas.sandberg@arm.com*/ 911986Sandreas.sandberg@arm.com 1011986Sandreas.sandberg@arm.com#include "pybind11_tests.h" 1111986Sandreas.sandberg@arm.com#include "constructor_stats.h" 1211986Sandreas.sandberg@arm.com 1312391Sjason@lowepower.comTEST_SUBMODULE(buffers, m) { 1412391Sjason@lowepower.com // test_from_python / test_to_python: 1512391Sjason@lowepower.com class Matrix { 1612391Sjason@lowepower.com public: 1712391Sjason@lowepower.com Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) { 1812391Sjason@lowepower.com print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 1912391Sjason@lowepower.com m_data = new float[(size_t) (rows*cols)]; 2012391Sjason@lowepower.com memset(m_data, 0, sizeof(float) * (size_t) (rows * cols)); 2112391Sjason@lowepower.com } 2211986Sandreas.sandberg@arm.com 2312391Sjason@lowepower.com Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) { 2412391Sjason@lowepower.com print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 2512391Sjason@lowepower.com m_data = new float[(size_t) (m_rows * m_cols)]; 2612391Sjason@lowepower.com memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); 2712391Sjason@lowepower.com } 2811986Sandreas.sandberg@arm.com 2912391Sjason@lowepower.com Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) { 3012391Sjason@lowepower.com print_move_created(this); 3112391Sjason@lowepower.com s.m_rows = 0; 3212391Sjason@lowepower.com s.m_cols = 0; 3312391Sjason@lowepower.com s.m_data = nullptr; 3412391Sjason@lowepower.com } 3511986Sandreas.sandberg@arm.com 3612391Sjason@lowepower.com ~Matrix() { 3712391Sjason@lowepower.com print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 3812391Sjason@lowepower.com delete[] m_data; 3912391Sjason@lowepower.com } 4011986Sandreas.sandberg@arm.com 4112391Sjason@lowepower.com Matrix &operator=(const Matrix &s) { 4212391Sjason@lowepower.com print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 4312391Sjason@lowepower.com delete[] m_data; 4412391Sjason@lowepower.com m_rows = s.m_rows; 4512391Sjason@lowepower.com m_cols = s.m_cols; 4612391Sjason@lowepower.com m_data = new float[(size_t) (m_rows * m_cols)]; 4712391Sjason@lowepower.com memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); 4812391Sjason@lowepower.com return *this; 4912391Sjason@lowepower.com } 5011986Sandreas.sandberg@arm.com 5112391Sjason@lowepower.com Matrix &operator=(Matrix &&s) { 5212391Sjason@lowepower.com print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 5312391Sjason@lowepower.com if (&s != this) { 5412391Sjason@lowepower.com delete[] m_data; 5512391Sjason@lowepower.com m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data; 5612391Sjason@lowepower.com s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr; 5712391Sjason@lowepower.com } 5812391Sjason@lowepower.com return *this; 5911986Sandreas.sandberg@arm.com } 6011986Sandreas.sandberg@arm.com 6112391Sjason@lowepower.com float operator()(ssize_t i, ssize_t j) const { 6212391Sjason@lowepower.com return m_data[(size_t) (i*m_cols + j)]; 6312391Sjason@lowepower.com } 6411986Sandreas.sandberg@arm.com 6512391Sjason@lowepower.com float &operator()(ssize_t i, ssize_t j) { 6612391Sjason@lowepower.com return m_data[(size_t) (i*m_cols + j)]; 6712391Sjason@lowepower.com } 6811986Sandreas.sandberg@arm.com 6912391Sjason@lowepower.com float *data() { return m_data; } 7011986Sandreas.sandberg@arm.com 7112391Sjason@lowepower.com ssize_t rows() const { return m_rows; } 7212391Sjason@lowepower.com ssize_t cols() const { return m_cols; } 7312391Sjason@lowepower.com private: 7412391Sjason@lowepower.com ssize_t m_rows; 7512391Sjason@lowepower.com ssize_t m_cols; 7612391Sjason@lowepower.com float *m_data; 7712391Sjason@lowepower.com }; 7812391Sjason@lowepower.com py::class_<Matrix>(m, "Matrix", py::buffer_protocol()) 7912391Sjason@lowepower.com .def(py::init<ssize_t, ssize_t>()) 8011986Sandreas.sandberg@arm.com /// Construct from a buffer 8114299Sbbruce@ucdavis.edu .def(py::init([](py::buffer const b) { 8211986Sandreas.sandberg@arm.com py::buffer_info info = b.request(); 8311986Sandreas.sandberg@arm.com if (info.format != py::format_descriptor<float>::format() || info.ndim != 2) 8411986Sandreas.sandberg@arm.com throw std::runtime_error("Incompatible buffer format!"); 8512391Sjason@lowepower.com 8612391Sjason@lowepower.com auto v = new Matrix(info.shape[0], info.shape[1]); 8712391Sjason@lowepower.com memcpy(v->data(), info.ptr, sizeof(float) * (size_t) (v->rows() * v->cols())); 8812391Sjason@lowepower.com return v; 8912391Sjason@lowepower.com })) 9011986Sandreas.sandberg@arm.com 9111986Sandreas.sandberg@arm.com .def("rows", &Matrix::rows) 9211986Sandreas.sandberg@arm.com .def("cols", &Matrix::cols) 9311986Sandreas.sandberg@arm.com 9411986Sandreas.sandberg@arm.com /// Bare bones interface 9512391Sjason@lowepower.com .def("__getitem__", [](const Matrix &m, std::pair<ssize_t, ssize_t> i) { 9611986Sandreas.sandberg@arm.com if (i.first >= m.rows() || i.second >= m.cols()) 9711986Sandreas.sandberg@arm.com throw py::index_error(); 9811986Sandreas.sandberg@arm.com return m(i.first, i.second); 9911986Sandreas.sandberg@arm.com }) 10012391Sjason@lowepower.com .def("__setitem__", [](Matrix &m, std::pair<ssize_t, ssize_t> i, float v) { 10111986Sandreas.sandberg@arm.com if (i.first >= m.rows() || i.second >= m.cols()) 10211986Sandreas.sandberg@arm.com throw py::index_error(); 10311986Sandreas.sandberg@arm.com m(i.first, i.second) = v; 10411986Sandreas.sandberg@arm.com }) 10511986Sandreas.sandberg@arm.com /// Provide buffer access 10611986Sandreas.sandberg@arm.com .def_buffer([](Matrix &m) -> py::buffer_info { 10711986Sandreas.sandberg@arm.com return py::buffer_info( 10811986Sandreas.sandberg@arm.com m.data(), /* Pointer to buffer */ 10911986Sandreas.sandberg@arm.com { m.rows(), m.cols() }, /* Buffer dimensions */ 11014299Sbbruce@ucdavis.edu { sizeof(float) * size_t(m.cols()), /* Strides (in bytes) for each index */ 11111986Sandreas.sandberg@arm.com sizeof(float) } 11211986Sandreas.sandberg@arm.com ); 11311986Sandreas.sandberg@arm.com }) 11411986Sandreas.sandberg@arm.com ; 11512391Sjason@lowepower.com 11612391Sjason@lowepower.com 11712391Sjason@lowepower.com // test_inherited_protocol 11812391Sjason@lowepower.com class SquareMatrix : public Matrix { 11912391Sjason@lowepower.com public: 12012391Sjason@lowepower.com SquareMatrix(ssize_t n) : Matrix(n, n) { } 12112391Sjason@lowepower.com }; 12212391Sjason@lowepower.com // Derived classes inherit the buffer protocol and the buffer access function 12312391Sjason@lowepower.com py::class_<SquareMatrix, Matrix>(m, "SquareMatrix") 12412391Sjason@lowepower.com .def(py::init<ssize_t>()); 12512391Sjason@lowepower.com 12612391Sjason@lowepower.com 12712391Sjason@lowepower.com // test_pointer_to_member_fn 12812391Sjason@lowepower.com // Tests that passing a pointer to member to the base class works in 12912391Sjason@lowepower.com // the derived class. 13012391Sjason@lowepower.com struct Buffer { 13112391Sjason@lowepower.com int32_t value = 0; 13212391Sjason@lowepower.com 13312391Sjason@lowepower.com py::buffer_info get_buffer_info() { 13412391Sjason@lowepower.com return py::buffer_info(&value, sizeof(value), 13512391Sjason@lowepower.com py::format_descriptor<int32_t>::format(), 1); 13612391Sjason@lowepower.com } 13712391Sjason@lowepower.com }; 13812391Sjason@lowepower.com py::class_<Buffer>(m, "Buffer", py::buffer_protocol()) 13912391Sjason@lowepower.com .def(py::init<>()) 14012391Sjason@lowepower.com .def_readwrite("value", &Buffer::value) 14112391Sjason@lowepower.com .def_buffer(&Buffer::get_buffer_info); 14212391Sjason@lowepower.com 14312391Sjason@lowepower.com 14412391Sjason@lowepower.com class ConstBuffer { 14512391Sjason@lowepower.com std::unique_ptr<int32_t> value; 14612391Sjason@lowepower.com 14712391Sjason@lowepower.com public: 14812391Sjason@lowepower.com int32_t get_value() const { return *value; } 14912391Sjason@lowepower.com void set_value(int32_t v) { *value = v; } 15012391Sjason@lowepower.com 15112391Sjason@lowepower.com py::buffer_info get_buffer_info() const { 15212391Sjason@lowepower.com return py::buffer_info(value.get(), sizeof(*value), 15312391Sjason@lowepower.com py::format_descriptor<int32_t>::format(), 1); 15412391Sjason@lowepower.com } 15512391Sjason@lowepower.com 15612391Sjason@lowepower.com ConstBuffer() : value(new int32_t{0}) { }; 15712391Sjason@lowepower.com }; 15812391Sjason@lowepower.com py::class_<ConstBuffer>(m, "ConstBuffer", py::buffer_protocol()) 15912391Sjason@lowepower.com .def(py::init<>()) 16012391Sjason@lowepower.com .def_property("value", &ConstBuffer::get_value, &ConstBuffer::set_value) 16112391Sjason@lowepower.com .def_buffer(&ConstBuffer::get_buffer_info); 16212391Sjason@lowepower.com 16312391Sjason@lowepower.com struct DerivedBuffer : public Buffer { }; 16412391Sjason@lowepower.com py::class_<DerivedBuffer>(m, "DerivedBuffer", py::buffer_protocol()) 16512391Sjason@lowepower.com .def(py::init<>()) 16612391Sjason@lowepower.com .def_readwrite("value", (int32_t DerivedBuffer::*) &DerivedBuffer::value) 16712391Sjason@lowepower.com .def_buffer(&DerivedBuffer::get_buffer_info); 16812391Sjason@lowepower.com 16912391Sjason@lowepower.com} 170