/*
    tests/test_pickling.cpp -- pickle support

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#include "pybind11_tests.h"

#include <stdexcept>
#include <string>
#include <utility>

1212391Sjason@lowepower.comTEST_SUBMODULE(pickling, m) {
1312391Sjason@lowepower.com    // test_roundtrip
1412391Sjason@lowepower.com    class Pickleable {
1512391Sjason@lowepower.com    public:
1612391Sjason@lowepower.com        Pickleable(const std::string &value) : m_value(value) { }
1712391Sjason@lowepower.com        const std::string &value() const { return m_value; }
1811986Sandreas.sandberg@arm.com
1912391Sjason@lowepower.com        void setExtra1(int extra1) { m_extra1 = extra1; }
2012391Sjason@lowepower.com        void setExtra2(int extra2) { m_extra2 = extra2; }
2112391Sjason@lowepower.com        int extra1() const { return m_extra1; }
2212391Sjason@lowepower.com        int extra2() const { return m_extra2; }
2312391Sjason@lowepower.com    private:
2412391Sjason@lowepower.com        std::string m_value;
2512391Sjason@lowepower.com        int m_extra1 = 0;
2612391Sjason@lowepower.com        int m_extra2 = 0;
2712391Sjason@lowepower.com    };
2811986Sandreas.sandberg@arm.com
2912391Sjason@lowepower.com    class PickleableNew : public Pickleable {
3012391Sjason@lowepower.com    public:
3112391Sjason@lowepower.com        using Pickleable::Pickleable;
3212391Sjason@lowepower.com    };
3311986Sandreas.sandberg@arm.com
3411986Sandreas.sandberg@arm.com    py::class_<Pickleable>(m, "Pickleable")
3511986Sandreas.sandberg@arm.com        .def(py::init<std::string>())
3611986Sandreas.sandberg@arm.com        .def("value", &Pickleable::value)
3711986Sandreas.sandberg@arm.com        .def("extra1", &Pickleable::extra1)
3811986Sandreas.sandberg@arm.com        .def("extra2", &Pickleable::extra2)
3911986Sandreas.sandberg@arm.com        .def("setExtra1", &Pickleable::setExtra1)
4011986Sandreas.sandberg@arm.com        .def("setExtra2", &Pickleable::setExtra2)
4111986Sandreas.sandberg@arm.com        // For details on the methods below, refer to
4211986Sandreas.sandberg@arm.com        // http://docs.python.org/3/library/pickle.html#pickling-class-instances
4311986Sandreas.sandberg@arm.com        .def("__getstate__", [](const Pickleable &p) {
4411986Sandreas.sandberg@arm.com            /* Return a tuple that fully encodes the state of the object */
4511986Sandreas.sandberg@arm.com            return py::make_tuple(p.value(), p.extra1(), p.extra2());
4611986Sandreas.sandberg@arm.com        })
4711986Sandreas.sandberg@arm.com        .def("__setstate__", [](Pickleable &p, py::tuple t) {
4811986Sandreas.sandberg@arm.com            if (t.size() != 3)
4911986Sandreas.sandberg@arm.com                throw std::runtime_error("Invalid state!");
5011986Sandreas.sandberg@arm.com            /* Invoke the constructor (need to use in-place version) */
5111986Sandreas.sandberg@arm.com            new (&p) Pickleable(t[0].cast<std::string>());
5211986Sandreas.sandberg@arm.com
5311986Sandreas.sandberg@arm.com            /* Assign any additional state */
5411986Sandreas.sandberg@arm.com            p.setExtra1(t[1].cast<int>());
5511986Sandreas.sandberg@arm.com            p.setExtra2(t[2].cast<int>());
5611986Sandreas.sandberg@arm.com        });
5711986Sandreas.sandberg@arm.com
5812391Sjason@lowepower.com    py::class_<PickleableNew, Pickleable>(m, "PickleableNew")
5912391Sjason@lowepower.com        .def(py::init<std::string>())
6012391Sjason@lowepower.com        .def(py::pickle(
6112391Sjason@lowepower.com            [](const PickleableNew &p) {
6212391Sjason@lowepower.com                return py::make_tuple(p.value(), p.extra1(), p.extra2());
6312391Sjason@lowepower.com            },
6412391Sjason@lowepower.com            [](py::tuple t) {
6512391Sjason@lowepower.com                if (t.size() != 3)
6612391Sjason@lowepower.com                    throw std::runtime_error("Invalid state!");
6712391Sjason@lowepower.com                auto p = PickleableNew(t[0].cast<std::string>());
6812391Sjason@lowepower.com
6912391Sjason@lowepower.com                p.setExtra1(t[1].cast<int>());
7012391Sjason@lowepower.com                p.setExtra2(t[2].cast<int>());
7112391Sjason@lowepower.com                return p;
7212391Sjason@lowepower.com            }
7312391Sjason@lowepower.com        ));
7412391Sjason@lowepower.com
7512037Sandreas.sandberg@arm.com#if !defined(PYPY_VERSION)
7612391Sjason@lowepower.com    // test_roundtrip_with_dict
7712391Sjason@lowepower.com    class PickleableWithDict {
7812391Sjason@lowepower.com    public:
7912391Sjason@lowepower.com        PickleableWithDict(const std::string &value) : value(value) { }
8012391Sjason@lowepower.com
8112391Sjason@lowepower.com        std::string value;
8212391Sjason@lowepower.com        int extra;
8312391Sjason@lowepower.com    };
8412391Sjason@lowepower.com
8512391Sjason@lowepower.com    class PickleableWithDictNew : public PickleableWithDict {
8612391Sjason@lowepower.com    public:
8712391Sjason@lowepower.com        using PickleableWithDict::PickleableWithDict;
8812391Sjason@lowepower.com    };
8912391Sjason@lowepower.com
9011986Sandreas.sandberg@arm.com    py::class_<PickleableWithDict>(m, "PickleableWithDict", py::dynamic_attr())
9111986Sandreas.sandberg@arm.com        .def(py::init<std::string>())
9211986Sandreas.sandberg@arm.com        .def_readwrite("value", &PickleableWithDict::value)
9311986Sandreas.sandberg@arm.com        .def_readwrite("extra", &PickleableWithDict::extra)
9411986Sandreas.sandberg@arm.com        .def("__getstate__", [](py::object self) {
9511986Sandreas.sandberg@arm.com            /* Also include __dict__ in state */
9611986Sandreas.sandberg@arm.com            return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__"));
9711986Sandreas.sandberg@arm.com        })
9811986Sandreas.sandberg@arm.com        .def("__setstate__", [](py::object self, py::tuple t) {
9911986Sandreas.sandberg@arm.com            if (t.size() != 3)
10011986Sandreas.sandberg@arm.com                throw std::runtime_error("Invalid state!");
10111986Sandreas.sandberg@arm.com            /* Cast and construct */
10211986Sandreas.sandberg@arm.com            auto& p = self.cast<PickleableWithDict&>();
10312037Sandreas.sandberg@arm.com            new (&p) PickleableWithDict(t[0].cast<std::string>());
10411986Sandreas.sandberg@arm.com
10511986Sandreas.sandberg@arm.com            /* Assign C++ state */
10611986Sandreas.sandberg@arm.com            p.extra = t[1].cast<int>();
10711986Sandreas.sandberg@arm.com
10811986Sandreas.sandberg@arm.com            /* Assign Python state */
10911986Sandreas.sandberg@arm.com            self.attr("__dict__") = t[2];
11011986Sandreas.sandberg@arm.com        });
11112391Sjason@lowepower.com
11212391Sjason@lowepower.com    py::class_<PickleableWithDictNew, PickleableWithDict>(m, "PickleableWithDictNew")
11312391Sjason@lowepower.com        .def(py::init<std::string>())
11412391Sjason@lowepower.com        .def(py::pickle(
11512391Sjason@lowepower.com            [](py::object self) {
11612391Sjason@lowepower.com                return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__"));
11712391Sjason@lowepower.com            },
11812391Sjason@lowepower.com            [](const py::tuple &t) {
11912391Sjason@lowepower.com                if (t.size() != 3)
12012391Sjason@lowepower.com                    throw std::runtime_error("Invalid state!");
12112391Sjason@lowepower.com
12212391Sjason@lowepower.com                auto cpp_state = PickleableWithDictNew(t[0].cast<std::string>());
12312391Sjason@lowepower.com                cpp_state.extra = t[1].cast<int>();
12412391Sjason@lowepower.com
12512391Sjason@lowepower.com                auto py_state = t[2].cast<py::dict>();
12612391Sjason@lowepower.com                return std::make_pair(cpp_state, py_state);
12712391Sjason@lowepower.com            }
12812391Sjason@lowepower.com        ));
12912037Sandreas.sandberg@arm.com#endif
13012391Sjason@lowepower.com}
131