test_buffers.cpp revision 14299
1/* 2 tests/test_buffers.cpp -- supporting Pythons' buffer protocol 3 4 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 6 All rights reserved. Use of this source code is governed by a 7 BSD-style license that can be found in the LICENSE file. 8*/ 9 10#include "pybind11_tests.h" 11#include "constructor_stats.h" 12 13TEST_SUBMODULE(buffers, m) { 14 // test_from_python / test_to_python: 15 class Matrix { 16 public: 17 Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) { 18 print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 19 m_data = new float[(size_t) (rows*cols)]; 20 memset(m_data, 0, sizeof(float) * (size_t) (rows * cols)); 21 } 22 23 Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) { 24 print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 25 m_data = new float[(size_t) (m_rows * m_cols)]; 26 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); 27 } 28 29 Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) { 30 print_move_created(this); 31 s.m_rows = 0; 32 s.m_cols = 0; 33 s.m_data = nullptr; 34 } 35 36 ~Matrix() { 37 print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 38 delete[] m_data; 39 } 40 41 Matrix &operator=(const Matrix &s) { 42 print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 43 delete[] m_data; 44 m_rows = s.m_rows; 45 m_cols = s.m_cols; 46 m_data = new float[(size_t) (m_rows * m_cols)]; 47 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); 48 return *this; 49 } 50 51 Matrix &operator=(Matrix &&s) { 52 print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); 53 if (&s != this) { 54 delete[] m_data; 55 m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data; 56 s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr; 57 } 58 return *this; 59 } 60 61 float operator()(ssize_t i, ssize_t j) const { 62 return m_data[(size_t) (i*m_cols + j)]; 63 } 64 65 float &operator()(ssize_t i, ssize_t j) { 66 return m_data[(size_t) (i*m_cols + j)]; 67 } 68 69 float *data() { return m_data; } 70 71 ssize_t rows() const { return m_rows; } 72 ssize_t cols() const { return m_cols; } 73 private: 74 ssize_t m_rows; 75 ssize_t m_cols; 76 float *m_data; 77 }; 78 py::class_<Matrix>(m, "Matrix", py::buffer_protocol()) 79 .def(py::init<ssize_t, ssize_t>()) 80 /// Construct from a buffer 81 .def(py::init([](py::buffer const b) { 82 py::buffer_info info = b.request(); 83 if (info.format != py::format_descriptor<float>::format() || info.ndim != 2) 84 throw std::runtime_error("Incompatible buffer format!"); 85 86 auto v = new Matrix(info.shape[0], info.shape[1]); 87 memcpy(v->data(), info.ptr, sizeof(float) * (size_t) (v->rows() * v->cols())); 88 return v; 89 })) 90 91 .def("rows", &Matrix::rows) 92 .def("cols", &Matrix::cols) 93 94 /// Bare bones interface 95 .def("__getitem__", [](const Matrix &m, std::pair<ssize_t, ssize_t> i) { 96 if (i.first >= m.rows() || i.second >= m.cols()) 97 throw py::index_error(); 98 return m(i.first, i.second); 99 }) 100 .def("__setitem__", [](Matrix &m, std::pair<ssize_t, ssize_t> i, float v) { 101 if (i.first >= m.rows() || i.second >= m.cols()) 102 throw py::index_error(); 103 m(i.first, i.second) = v; 104 }) 105 /// Provide buffer access 106 .def_buffer([](Matrix &m) -> py::buffer_info { 107 return py::buffer_info( 108 m.data(), /* Pointer to buffer */ 109 { m.rows(), m.cols() }, /* Buffer dimensions */ 110 { sizeof(float) * size_t(m.cols()), /* Strides (in bytes) for each index */ 111 sizeof(float) } 112 ); 113 }) 114 ; 115 116 117 // test_inherited_protocol 118 class SquareMatrix : public Matrix { 119 public: 120 SquareMatrix(ssize_t n) : Matrix(n, n) { } 121 }; 122 // Derived classes inherit the buffer protocol and the buffer access function 123 py::class_<SquareMatrix, Matrix>(m, "SquareMatrix") 124 .def(py::init<ssize_t>()); 125 126 127 // test_pointer_to_member_fn 128 // Tests that passing a pointer to member to the base class works in 129 // the derived class. 130 struct Buffer { 131 int32_t value = 0; 132 133 py::buffer_info get_buffer_info() { 134 return py::buffer_info(&value, sizeof(value), 135 py::format_descriptor<int32_t>::format(), 1); 136 } 137 }; 138 py::class_<Buffer>(m, "Buffer", py::buffer_protocol()) 139 .def(py::init<>()) 140 .def_readwrite("value", &Buffer::value) 141 .def_buffer(&Buffer::get_buffer_info); 142 143 144 class ConstBuffer { 145 std::unique_ptr<int32_t> value; 146 147 public: 148 int32_t get_value() const { return *value; } 149 void set_value(int32_t v) { *value = v; } 150 151 py::buffer_info get_buffer_info() const { 152 return py::buffer_info(value.get(), sizeof(*value), 153 py::format_descriptor<int32_t>::format(), 1); 154 } 155 156 ConstBuffer() : value(new int32_t{0}) { }; 157 }; 158 py::class_<ConstBuffer>(m, "ConstBuffer", py::buffer_protocol()) 159 .def(py::init<>()) 160 .def_property("value", &ConstBuffer::get_value, &ConstBuffer::set_value) 161 .def_buffer(&ConstBuffer::get_buffer_info); 162 163 struct DerivedBuffer : public Buffer { }; 164 py::class_<DerivedBuffer>(m, "DerivedBuffer", py::buffer_protocol()) 165 .def(py::init<>()) 166 .def_readwrite("value", (int32_t DerivedBuffer::*) &DerivedBuffer::value) 167 .def_buffer(&DerivedBuffer::get_buffer_info); 168 169} 170