/*
    tests/test_buffers.cpp -- supporting Python's buffer protocol

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/
911986Sandreas.sandberg@arm.com
1011986Sandreas.sandberg@arm.com#include "pybind11_tests.h"
1111986Sandreas.sandberg@arm.com#include "constructor_stats.h"
1211986Sandreas.sandberg@arm.com
1312391Sjason@lowepower.comTEST_SUBMODULE(buffers, m) {
1412391Sjason@lowepower.com    // test_from_python / test_to_python:
1512391Sjason@lowepower.com    class Matrix {
1612391Sjason@lowepower.com    public:
1712391Sjason@lowepower.com        Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) {
1812391Sjason@lowepower.com            print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
1912391Sjason@lowepower.com            m_data = new float[(size_t) (rows*cols)];
2012391Sjason@lowepower.com            memset(m_data, 0, sizeof(float) * (size_t) (rows * cols));
2112391Sjason@lowepower.com        }
2211986Sandreas.sandberg@arm.com
2312391Sjason@lowepower.com        Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
2412391Sjason@lowepower.com            print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
2512391Sjason@lowepower.com            m_data = new float[(size_t) (m_rows * m_cols)];
2612391Sjason@lowepower.com            memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
2712391Sjason@lowepower.com        }
2811986Sandreas.sandberg@arm.com
2912391Sjason@lowepower.com        Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
3012391Sjason@lowepower.com            print_move_created(this);
3112391Sjason@lowepower.com            s.m_rows = 0;
3212391Sjason@lowepower.com            s.m_cols = 0;
3312391Sjason@lowepower.com            s.m_data = nullptr;
3412391Sjason@lowepower.com        }
3511986Sandreas.sandberg@arm.com
3612391Sjason@lowepower.com        ~Matrix() {
3712391Sjason@lowepower.com            print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
3812391Sjason@lowepower.com            delete[] m_data;
3912391Sjason@lowepower.com        }
4011986Sandreas.sandberg@arm.com
4112391Sjason@lowepower.com        Matrix &operator=(const Matrix &s) {
4212391Sjason@lowepower.com            print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
4312391Sjason@lowepower.com            delete[] m_data;
4412391Sjason@lowepower.com            m_rows = s.m_rows;
4512391Sjason@lowepower.com            m_cols = s.m_cols;
4612391Sjason@lowepower.com            m_data = new float[(size_t) (m_rows * m_cols)];
4712391Sjason@lowepower.com            memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
4812391Sjason@lowepower.com            return *this;
4912391Sjason@lowepower.com        }
5011986Sandreas.sandberg@arm.com
5112391Sjason@lowepower.com        Matrix &operator=(Matrix &&s) {
5212391Sjason@lowepower.com            print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
5312391Sjason@lowepower.com            if (&s != this) {
5412391Sjason@lowepower.com                delete[] m_data;
5512391Sjason@lowepower.com                m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data;
5612391Sjason@lowepower.com                s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr;
5712391Sjason@lowepower.com            }
5812391Sjason@lowepower.com            return *this;
5911986Sandreas.sandberg@arm.com        }
6011986Sandreas.sandberg@arm.com
6112391Sjason@lowepower.com        float operator()(ssize_t i, ssize_t j) const {
6212391Sjason@lowepower.com            return m_data[(size_t) (i*m_cols + j)];
6312391Sjason@lowepower.com        }
6411986Sandreas.sandberg@arm.com
6512391Sjason@lowepower.com        float &operator()(ssize_t i, ssize_t j) {
6612391Sjason@lowepower.com            return m_data[(size_t) (i*m_cols + j)];
6712391Sjason@lowepower.com        }
6811986Sandreas.sandberg@arm.com
6912391Sjason@lowepower.com        float *data() { return m_data; }
7011986Sandreas.sandberg@arm.com
7112391Sjason@lowepower.com        ssize_t rows() const { return m_rows; }
7212391Sjason@lowepower.com        ssize_t cols() const { return m_cols; }
7312391Sjason@lowepower.com    private:
7412391Sjason@lowepower.com        ssize_t m_rows;
7512391Sjason@lowepower.com        ssize_t m_cols;
7612391Sjason@lowepower.com        float *m_data;
7712391Sjason@lowepower.com    };
7812391Sjason@lowepower.com    py::class_<Matrix>(m, "Matrix", py::buffer_protocol())
7912391Sjason@lowepower.com        .def(py::init<ssize_t, ssize_t>())
8011986Sandreas.sandberg@arm.com        /// Construct from a buffer
8114299Sbbruce@ucdavis.edu        .def(py::init([](py::buffer const b) {
8211986Sandreas.sandberg@arm.com            py::buffer_info info = b.request();
8311986Sandreas.sandberg@arm.com            if (info.format != py::format_descriptor<float>::format() || info.ndim != 2)
8411986Sandreas.sandberg@arm.com                throw std::runtime_error("Incompatible buffer format!");
8512391Sjason@lowepower.com
8612391Sjason@lowepower.com            auto v = new Matrix(info.shape[0], info.shape[1]);
8712391Sjason@lowepower.com            memcpy(v->data(), info.ptr, sizeof(float) * (size_t) (v->rows() * v->cols()));
8812391Sjason@lowepower.com            return v;
8912391Sjason@lowepower.com        }))
9011986Sandreas.sandberg@arm.com
9111986Sandreas.sandberg@arm.com       .def("rows", &Matrix::rows)
9211986Sandreas.sandberg@arm.com       .def("cols", &Matrix::cols)
9311986Sandreas.sandberg@arm.com
9411986Sandreas.sandberg@arm.com        /// Bare bones interface
9512391Sjason@lowepower.com       .def("__getitem__", [](const Matrix &m, std::pair<ssize_t, ssize_t> i) {
9611986Sandreas.sandberg@arm.com            if (i.first >= m.rows() || i.second >= m.cols())
9711986Sandreas.sandberg@arm.com                throw py::index_error();
9811986Sandreas.sandberg@arm.com            return m(i.first, i.second);
9911986Sandreas.sandberg@arm.com        })
10012391Sjason@lowepower.com       .def("__setitem__", [](Matrix &m, std::pair<ssize_t, ssize_t> i, float v) {
10111986Sandreas.sandberg@arm.com            if (i.first >= m.rows() || i.second >= m.cols())
10211986Sandreas.sandberg@arm.com                throw py::index_error();
10311986Sandreas.sandberg@arm.com            m(i.first, i.second) = v;
10411986Sandreas.sandberg@arm.com        })
10511986Sandreas.sandberg@arm.com       /// Provide buffer access
10611986Sandreas.sandberg@arm.com       .def_buffer([](Matrix &m) -> py::buffer_info {
10711986Sandreas.sandberg@arm.com            return py::buffer_info(
10811986Sandreas.sandberg@arm.com                m.data(),                               /* Pointer to buffer */
10911986Sandreas.sandberg@arm.com                { m.rows(), m.cols() },                 /* Buffer dimensions */
11014299Sbbruce@ucdavis.edu                { sizeof(float) * size_t(m.cols()),     /* Strides (in bytes) for each index */
11111986Sandreas.sandberg@arm.com                  sizeof(float) }
11211986Sandreas.sandberg@arm.com            );
11311986Sandreas.sandberg@arm.com        })
11411986Sandreas.sandberg@arm.com        ;
11512391Sjason@lowepower.com
11612391Sjason@lowepower.com
11712391Sjason@lowepower.com    // test_inherited_protocol
11812391Sjason@lowepower.com    class SquareMatrix : public Matrix {
11912391Sjason@lowepower.com    public:
12012391Sjason@lowepower.com        SquareMatrix(ssize_t n) : Matrix(n, n) { }
12112391Sjason@lowepower.com    };
12212391Sjason@lowepower.com    // Derived classes inherit the buffer protocol and the buffer access function
12312391Sjason@lowepower.com    py::class_<SquareMatrix, Matrix>(m, "SquareMatrix")
12412391Sjason@lowepower.com        .def(py::init<ssize_t>());
12512391Sjason@lowepower.com
12612391Sjason@lowepower.com
12712391Sjason@lowepower.com    // test_pointer_to_member_fn
12812391Sjason@lowepower.com    // Tests that passing a pointer to member to the base class works in
12912391Sjason@lowepower.com    // the derived class.
13012391Sjason@lowepower.com    struct Buffer {
13112391Sjason@lowepower.com        int32_t value = 0;
13212391Sjason@lowepower.com
13312391Sjason@lowepower.com        py::buffer_info get_buffer_info() {
13412391Sjason@lowepower.com            return py::buffer_info(&value, sizeof(value),
13512391Sjason@lowepower.com                                   py::format_descriptor<int32_t>::format(), 1);
13612391Sjason@lowepower.com        }
13712391Sjason@lowepower.com    };
13812391Sjason@lowepower.com    py::class_<Buffer>(m, "Buffer", py::buffer_protocol())
13912391Sjason@lowepower.com        .def(py::init<>())
14012391Sjason@lowepower.com        .def_readwrite("value", &Buffer::value)
14112391Sjason@lowepower.com        .def_buffer(&Buffer::get_buffer_info);
14212391Sjason@lowepower.com
14312391Sjason@lowepower.com
14412391Sjason@lowepower.com    class ConstBuffer {
14512391Sjason@lowepower.com        std::unique_ptr<int32_t> value;
14612391Sjason@lowepower.com
14712391Sjason@lowepower.com    public:
14812391Sjason@lowepower.com        int32_t get_value() const { return *value; }
14912391Sjason@lowepower.com        void set_value(int32_t v) { *value = v; }
15012391Sjason@lowepower.com
15112391Sjason@lowepower.com        py::buffer_info get_buffer_info() const {
15212391Sjason@lowepower.com            return py::buffer_info(value.get(), sizeof(*value),
15312391Sjason@lowepower.com                                   py::format_descriptor<int32_t>::format(), 1);
15412391Sjason@lowepower.com        }
15512391Sjason@lowepower.com
15612391Sjason@lowepower.com        ConstBuffer() : value(new int32_t{0}) { };
15712391Sjason@lowepower.com    };
15812391Sjason@lowepower.com    py::class_<ConstBuffer>(m, "ConstBuffer", py::buffer_protocol())
15912391Sjason@lowepower.com        .def(py::init<>())
16012391Sjason@lowepower.com        .def_property("value", &ConstBuffer::get_value, &ConstBuffer::set_value)
16112391Sjason@lowepower.com        .def_buffer(&ConstBuffer::get_buffer_info);
16212391Sjason@lowepower.com
16312391Sjason@lowepower.com    struct DerivedBuffer : public Buffer { };
16412391Sjason@lowepower.com    py::class_<DerivedBuffer>(m, "DerivedBuffer", py::buffer_protocol())
16512391Sjason@lowepower.com        .def(py::init<>())
16612391Sjason@lowepower.com        .def_readwrite("value", (int32_t DerivedBuffer::*) &DerivedBuffer::value)
16712391Sjason@lowepower.com        .def_buffer(&DerivedBuffer::get_buffer_info);
16812391Sjason@lowepower.com
16912391Sjason@lowepower.com}
170