test_buffers.cpp revision 11986
12600SN/A/*
    tests/test_buffers.cpp -- supporting Python's buffer protocol
32600SN/A
42600SN/A    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
52600SN/A
62600SN/A    All rights reserved. Use of this source code is governed by a
72600SN/A    BSD-style license that can be found in the LICENSE file.
82600SN/A*/
92600SN/A
#include "pybind11_tests.h"
#include "constructor_stats.h"

#include <algorithm>
#include <cstddef>
#include <cstring>
122600SN/A
132600SN/Aclass Matrix {
142600SN/Apublic:
152600SN/A    Matrix(size_t rows, size_t cols) : m_rows(rows), m_cols(cols) {
162600SN/A        print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
172600SN/A        m_data = new float[rows*cols];
182600SN/A        memset(m_data, 0, sizeof(float) * rows * cols);
192600SN/A    }
202600SN/A
212600SN/A    Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
222600SN/A        print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
232600SN/A        m_data = new float[m_rows * m_cols];
242600SN/A        memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
252600SN/A    }
262600SN/A
272665Ssaidi@eecs.umich.edu    Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
282665Ssaidi@eecs.umich.edu        print_move_created(this);
292600SN/A        s.m_rows = 0;
302600SN/A        s.m_cols = 0;
318229Snate@binkert.org        s.m_data = nullptr;
3211793Sbrandon.potter@amd.com    }
332600SN/A
346329Sgblack@eecs.umich.edu    ~Matrix() {
3513988Sgabeblack@google.com        print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
362600SN/A        delete[] m_data;
372680Sktlim@umich.edu    }
382600SN/A
392600SN/A    Matrix &operator=(const Matrix &s) {
4011794Sbrandon.potter@amd.com        print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
412600SN/A        delete[] m_data;
422600SN/A        m_rows = s.m_rows;
432600SN/A        m_cols = s.m_cols;
442600SN/A        m_data = new float[m_rows * m_cols];
452600SN/A        memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
4613988Sgabeblack@google.com        return *this;
4713988Sgabeblack@google.com    }
4813988Sgabeblack@google.com
4913988Sgabeblack@google.com    Matrix &operator=(Matrix &&s) {
5013988Sgabeblack@google.com        print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
5113988Sgabeblack@google.com        if (&s != this) {
5213988Sgabeblack@google.com            delete[] m_data;
5313988Sgabeblack@google.com            m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data;
5413988Sgabeblack@google.com            s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr;
5513988Sgabeblack@google.com        }
5613988Sgabeblack@google.com        return *this;
5713988Sgabeblack@google.com    }
5813988Sgabeblack@google.com
5913988Sgabeblack@google.com    float operator()(size_t i, size_t j) const {
6013988Sgabeblack@google.com        return m_data[i*m_cols + j];
6113988Sgabeblack@google.com    }
6213988Sgabeblack@google.com
6313988Sgabeblack@google.com    float &operator()(size_t i, size_t j) {
6413988Sgabeblack@google.com        return m_data[i*m_cols + j];
6513988Sgabeblack@google.com    }
6613988Sgabeblack@google.com
6713988Sgabeblack@google.com    float *data() { return m_data; }
6813988Sgabeblack@google.com
6913988Sgabeblack@google.com    size_t rows() const { return m_rows; }
7013988Sgabeblack@google.com    size_t cols() const { return m_cols; }
7113988Sgabeblack@google.comprivate:
722600SN/A    size_t m_rows;
732600SN/A    size_t m_cols;
742600SN/A    float *m_data;
7513995Sbrandon.potter@amd.com};
762600SN/A
776701Sgblack@eecs.umich.edutest_initializer buffers([](py::module &m) {
7813995Sbrandon.potter@amd.com    py::class_<Matrix> mtx(m, "Matrix");
796701Sgblack@eecs.umich.edu
802600SN/A    mtx.def(py::init<size_t, size_t>())
812600SN/A        /// Construct from a buffer
822600SN/A        .def("__init__", [](Matrix &v, py::buffer b) {
8314014Sciro.santilli@arm.com            py::buffer_info info = b.request();
842600SN/A            if (info.format != py::format_descriptor<float>::format() || info.ndim != 2)
852600SN/A                throw std::runtime_error("Incompatible buffer format!");
862600SN/A            new (&v) Matrix(info.shape[0], info.shape[1]);
8714024Sgabeblack@google.com            memcpy(v.data(), info.ptr, sizeof(float) * v.rows() * v.cols());
882600SN/A        })
892600SN/A
902600SN/A       .def("rows", &Matrix::rows)
912600SN/A       .def("cols", &Matrix::cols)
922600SN/A
932600SN/A        /// Bare bones interface
942600SN/A       .def("__getitem__", [](const Matrix &m, std::pair<size_t, size_t> i) {
952600SN/A            if (i.first >= m.rows() || i.second >= m.cols())
962600SN/A                throw py::index_error();
9713570Sbrandon.potter@amd.com            return m(i.first, i.second);
9813570Sbrandon.potter@amd.com        })
992600SN/A       .def("__setitem__", [](Matrix &m, std::pair<size_t, size_t> i, float v) {
1002600SN/A            if (i.first >= m.rows() || i.second >= m.cols())
1012600SN/A                throw py::index_error();
1022600SN/A            m(i.first, i.second) = v;
1032600SN/A        })
1042600SN/A       /// Provide buffer access
1052600SN/A       .def_buffer([](Matrix &m) -> py::buffer_info {
1062600SN/A            return py::buffer_info(
1072600SN/A                m.data(),                               /* Pointer to buffer */
1082600SN/A                sizeof(float),                          /* Size of one scalar */
1092600SN/A                py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
1102600SN/A                2,                                      /* Number of dimensions */
1115748SSteve.Reinhardt@amd.com                { m.rows(), m.cols() },                 /* Buffer dimensions */
1122600SN/A                { sizeof(float) * m.rows(),             /* Strides (in bytes) for each index */
1132600SN/A                  sizeof(float) }
1142600SN/A            );
1152600SN/A        })
1162600SN/A        ;
1172600SN/A});
1182600SN/A