test_buffers.cpp revision 11986
/*
    tests/test_buffers.cpp -- supporting Python's buffer protocol

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/
911986Sandreas.sandberg@arm.com
1011986Sandreas.sandberg@arm.com#include "pybind11_tests.h"
1111986Sandreas.sandberg@arm.com#include "constructor_stats.h"
1211986Sandreas.sandberg@arm.com
1311986Sandreas.sandberg@arm.comclass Matrix {
1411986Sandreas.sandberg@arm.compublic:
1511986Sandreas.sandberg@arm.com    Matrix(size_t rows, size_t cols) : m_rows(rows), m_cols(cols) {
1611986Sandreas.sandberg@arm.com        print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
1711986Sandreas.sandberg@arm.com        m_data = new float[rows*cols];
1811986Sandreas.sandberg@arm.com        memset(m_data, 0, sizeof(float) * rows * cols);
1911986Sandreas.sandberg@arm.com    }
2011986Sandreas.sandberg@arm.com
2111986Sandreas.sandberg@arm.com    Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
2211986Sandreas.sandberg@arm.com        print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
2311986Sandreas.sandberg@arm.com        m_data = new float[m_rows * m_cols];
2411986Sandreas.sandberg@arm.com        memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
2511986Sandreas.sandberg@arm.com    }
2611986Sandreas.sandberg@arm.com
2711986Sandreas.sandberg@arm.com    Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
2811986Sandreas.sandberg@arm.com        print_move_created(this);
2911986Sandreas.sandberg@arm.com        s.m_rows = 0;
3011986Sandreas.sandberg@arm.com        s.m_cols = 0;
3111986Sandreas.sandberg@arm.com        s.m_data = nullptr;
3211986Sandreas.sandberg@arm.com    }
3311986Sandreas.sandberg@arm.com
3411986Sandreas.sandberg@arm.com    ~Matrix() {
3511986Sandreas.sandberg@arm.com        print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
3611986Sandreas.sandberg@arm.com        delete[] m_data;
3711986Sandreas.sandberg@arm.com    }
3811986Sandreas.sandberg@arm.com
3911986Sandreas.sandberg@arm.com    Matrix &operator=(const Matrix &s) {
4011986Sandreas.sandberg@arm.com        print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
4111986Sandreas.sandberg@arm.com        delete[] m_data;
4211986Sandreas.sandberg@arm.com        m_rows = s.m_rows;
4311986Sandreas.sandberg@arm.com        m_cols = s.m_cols;
4411986Sandreas.sandberg@arm.com        m_data = new float[m_rows * m_cols];
4511986Sandreas.sandberg@arm.com        memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
4611986Sandreas.sandberg@arm.com        return *this;
4711986Sandreas.sandberg@arm.com    }
4811986Sandreas.sandberg@arm.com
4911986Sandreas.sandberg@arm.com    Matrix &operator=(Matrix &&s) {
5011986Sandreas.sandberg@arm.com        print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
5111986Sandreas.sandberg@arm.com        if (&s != this) {
5211986Sandreas.sandberg@arm.com            delete[] m_data;
5311986Sandreas.sandberg@arm.com            m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data;
5411986Sandreas.sandberg@arm.com            s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr;
5511986Sandreas.sandberg@arm.com        }
5611986Sandreas.sandberg@arm.com        return *this;
5711986Sandreas.sandberg@arm.com    }
5811986Sandreas.sandberg@arm.com
5911986Sandreas.sandberg@arm.com    float operator()(size_t i, size_t j) const {
6011986Sandreas.sandberg@arm.com        return m_data[i*m_cols + j];
6111986Sandreas.sandberg@arm.com    }
6211986Sandreas.sandberg@arm.com
6311986Sandreas.sandberg@arm.com    float &operator()(size_t i, size_t j) {
6411986Sandreas.sandberg@arm.com        return m_data[i*m_cols + j];
6511986Sandreas.sandberg@arm.com    }
6611986Sandreas.sandberg@arm.com
6711986Sandreas.sandberg@arm.com    float *data() { return m_data; }
6811986Sandreas.sandberg@arm.com
6911986Sandreas.sandberg@arm.com    size_t rows() const { return m_rows; }
7011986Sandreas.sandberg@arm.com    size_t cols() const { return m_cols; }
7111986Sandreas.sandberg@arm.comprivate:
7211986Sandreas.sandberg@arm.com    size_t m_rows;
7311986Sandreas.sandberg@arm.com    size_t m_cols;
7411986Sandreas.sandberg@arm.com    float *m_data;
7511986Sandreas.sandberg@arm.com};
7611986Sandreas.sandberg@arm.com
7711986Sandreas.sandberg@arm.comtest_initializer buffers([](py::module &m) {
7811986Sandreas.sandberg@arm.com    py::class_<Matrix> mtx(m, "Matrix");
7911986Sandreas.sandberg@arm.com
8011986Sandreas.sandberg@arm.com    mtx.def(py::init<size_t, size_t>())
8111986Sandreas.sandberg@arm.com        /// Construct from a buffer
8211986Sandreas.sandberg@arm.com        .def("__init__", [](Matrix &v, py::buffer b) {
8311986Sandreas.sandberg@arm.com            py::buffer_info info = b.request();
8411986Sandreas.sandberg@arm.com            if (info.format != py::format_descriptor<float>::format() || info.ndim != 2)
8511986Sandreas.sandberg@arm.com                throw std::runtime_error("Incompatible buffer format!");
8611986Sandreas.sandberg@arm.com            new (&v) Matrix(info.shape[0], info.shape[1]);
8711986Sandreas.sandberg@arm.com            memcpy(v.data(), info.ptr, sizeof(float) * v.rows() * v.cols());
8811986Sandreas.sandberg@arm.com        })
8911986Sandreas.sandberg@arm.com
9011986Sandreas.sandberg@arm.com       .def("rows", &Matrix::rows)
9111986Sandreas.sandberg@arm.com       .def("cols", &Matrix::cols)
9211986Sandreas.sandberg@arm.com
9311986Sandreas.sandberg@arm.com        /// Bare bones interface
9411986Sandreas.sandberg@arm.com       .def("__getitem__", [](const Matrix &m, std::pair<size_t, size_t> i) {
9511986Sandreas.sandberg@arm.com            if (i.first >= m.rows() || i.second >= m.cols())
9611986Sandreas.sandberg@arm.com                throw py::index_error();
9711986Sandreas.sandberg@arm.com            return m(i.first, i.second);
9811986Sandreas.sandberg@arm.com        })
9911986Sandreas.sandberg@arm.com       .def("__setitem__", [](Matrix &m, std::pair<size_t, size_t> i, float v) {
10011986Sandreas.sandberg@arm.com            if (i.first >= m.rows() || i.second >= m.cols())
10111986Sandreas.sandberg@arm.com                throw py::index_error();
10211986Sandreas.sandberg@arm.com            m(i.first, i.second) = v;
10311986Sandreas.sandberg@arm.com        })
10411986Sandreas.sandberg@arm.com       /// Provide buffer access
10511986Sandreas.sandberg@arm.com       .def_buffer([](Matrix &m) -> py::buffer_info {
10611986Sandreas.sandberg@arm.com            return py::buffer_info(
10711986Sandreas.sandberg@arm.com                m.data(),                               /* Pointer to buffer */
10811986Sandreas.sandberg@arm.com                sizeof(float),                          /* Size of one scalar */
10911986Sandreas.sandberg@arm.com                py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
11011986Sandreas.sandberg@arm.com                2,                                      /* Number of dimensions */
11111986Sandreas.sandberg@arm.com                { m.rows(), m.cols() },                 /* Buffer dimensions */
11211986Sandreas.sandberg@arm.com                { sizeof(float) * m.rows(),             /* Strides (in bytes) for each index */
11311986Sandreas.sandberg@arm.com                  sizeof(float) }
11411986Sandreas.sandberg@arm.com            );
11511986Sandreas.sandberg@arm.com        })
11611986Sandreas.sandberg@arm.com        ;
11711986Sandreas.sandberg@arm.com});
118