// test_interpreter.cpp revision 12391:ceeca8b41e4b
1#include <pybind11/embed.h>
2#include <catch.hpp>
3
4#include <thread>
5#include <fstream>
6#include <functional>
7
8namespace py = pybind11;
9using namespace py::literals;
10
/// Abstract C++ base class, bound to Python as `widget_module.Widget`.
///
/// Holds a message supplied at construction. `the_answer()` is pure virtual and must be
/// provided by a subclass (in these tests, a Python subclass reached through `PyWidget`).
class Widget {
public:
    /// @param message Text later returned by `the_message()`. Taken by value (sink
    ///        parameter) and moved into the member, so callers passing rvalues pay no copy.
    explicit Widget(std::string message) : message(std::move(message)) { }
    virtual ~Widget() = default;

    /// @return A copy of the message passed to the constructor.
    std::string the_message() const { return message; }

    /// @return The value supplied by the concrete subclass.
    virtual int the_answer() const = 0;

private:
    std::string message;
};
22
23class PyWidget final : public Widget {
24    using Widget::Widget;
25
26    int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
27};
28
29PYBIND11_EMBEDDED_MODULE(widget_module, m) {
30    py::class_<Widget, PyWidget>(m, "Widget")
31        .def(py::init<std::string>())
32        .def_property_readonly("the_message", &Widget::the_message);
33
34    m.def("add", [](int i, int j) { return i + j; });
35}
36
37PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
38    throw std::runtime_error("C++ Error");
39}
40
41PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
42    auto d = py::dict();
43    d["missing"].cast<py::object>();
44}
45
46TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
47    auto module = py::module::import("test_interpreter");
48    REQUIRE(py::hasattr(module, "DerivedWidget"));
49
50    auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
51    py::exec(R"(
52        widget = DerivedWidget("{} - {}".format(hello, x))
53        message = widget.the_message
54    )", py::globals(), locals);
55    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
56
57    auto py_widget = module.attr("DerivedWidget")("The question");
58    auto message = py_widget.attr("the_message");
59    REQUIRE(message.cast<std::string>() == "The question");
60
61    const auto &cpp_widget = py_widget.cast<const Widget &>();
62    REQUIRE(cpp_widget.the_answer() == 42);
63}
64
65TEST_CASE("Import error handling") {
66    REQUIRE_NOTHROW(py::module::import("widget_module"));
67    REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
68                        "ImportError: C++ Error");
69    REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
70                        Catch::Contains("ImportError: KeyError"));
71}
72
73TEST_CASE("There can be only one interpreter") {
74    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
75    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
76    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
77    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
78
79    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
80    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
81
82    py::finalize_interpreter();
83    REQUIRE_NOTHROW(py::scoped_interpreter());
84    {
85        auto pyi1 = py::scoped_interpreter();
86        auto pyi2 = std::move(pyi1);
87    }
88    py::initialize_interpreter();
89}
90
91bool has_pybind11_internals_builtin() {
92    auto builtins = py::handle(PyEval_GetBuiltins());
93    return builtins.contains(PYBIND11_INTERNALS_ID);
94};
95
96bool has_pybind11_internals_static() {
97    return py::detail::get_internals_ptr() != nullptr;
98}
99
100TEST_CASE("Restart the interpreter") {
101    // Verify pre-restart state.
102    REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
103    REQUIRE(has_pybind11_internals_builtin());
104    REQUIRE(has_pybind11_internals_static());
105
106    // Restart the interpreter.
107    py::finalize_interpreter();
108    REQUIRE(Py_IsInitialized() == 0);
109
110    py::initialize_interpreter();
111    REQUIRE(Py_IsInitialized() == 1);
112
113    // Internals are deleted after a restart.
114    REQUIRE_FALSE(has_pybind11_internals_builtin());
115    REQUIRE_FALSE(has_pybind11_internals_static());
116    pybind11::detail::get_internals();
117    REQUIRE(has_pybind11_internals_builtin());
118    REQUIRE(has_pybind11_internals_static());
119
120    // Make sure that an interpreter with no get_internals() created until finalize still gets the
121    // internals destroyed
122    py::finalize_interpreter();
123    py::initialize_interpreter();
124    bool ran = false;
125    py::module::import("__main__").attr("internals_destroy_test") =
126        py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
127    REQUIRE_FALSE(has_pybind11_internals_builtin());
128    REQUIRE_FALSE(has_pybind11_internals_static());
129    REQUIRE_FALSE(ran);
130    py::finalize_interpreter();
131    REQUIRE(ran);
132    py::initialize_interpreter();
133    REQUIRE_FALSE(has_pybind11_internals_builtin());
134    REQUIRE_FALSE(has_pybind11_internals_static());
135
136    // C++ modules can be reloaded.
137    auto cpp_module = py::module::import("widget_module");
138    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
139
140    // C++ type information is reloaded and can be used in python modules.
141    auto py_module = py::module::import("test_interpreter");
142    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
143    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
144}
145
146TEST_CASE("Subinterpreter") {
147    // Add tags to the modules in the main interpreter and test the basics.
148    py::module::import("__main__").attr("main_tag") = "main interpreter";
149    {
150        auto m = py::module::import("widget_module");
151        m.attr("extension_module_tag") = "added to module in main interpreter";
152
153        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
154    }
155    REQUIRE(has_pybind11_internals_builtin());
156    REQUIRE(has_pybind11_internals_static());
157
158    /// Create and switch to a subinterpreter.
159    auto main_tstate = PyThreadState_Get();
160    auto sub_tstate = Py_NewInterpreter();
161
162    // Subinterpreters get their own copy of builtins. detail::get_internals() still
163    // works by returning from the static variable, i.e. all interpreters share a single
164    // global pybind11::internals;
165    REQUIRE_FALSE(has_pybind11_internals_builtin());
166    REQUIRE(has_pybind11_internals_static());
167
168    // Modules tags should be gone.
169    REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
170    {
171        auto m = py::module::import("widget_module");
172        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
173
174        // Function bindings should still work.
175        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
176    }
177
178    // Restore main interpreter.
179    Py_EndInterpreter(sub_tstate);
180    PyThreadState_Swap(main_tstate);
181
182    REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
183    REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
184}
185
186TEST_CASE("Execution frame") {
187    // When the interpreter is embedded, there is no execution frame, but `py::exec`
188    // should still function by using reasonable globals: `__main__.__dict__`.
189    py::exec("var = dict(number=42)");
190    REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
191}
192
193TEST_CASE("Threads") {
194    // Restart interpreter to ensure threads are not initialized
195    py::finalize_interpreter();
196    py::initialize_interpreter();
197    REQUIRE_FALSE(has_pybind11_internals_static());
198
199    constexpr auto num_threads = 10;
200    auto locals = py::dict("count"_a=0);
201
202    {
203        py::gil_scoped_release gil_release{};
204        REQUIRE(has_pybind11_internals_static());
205
206        auto threads = std::vector<std::thread>();
207        for (auto i = 0; i < num_threads; ++i) {
208            threads.emplace_back([&]() {
209                py::gil_scoped_acquire gil{};
210                locals["count"] = locals["count"].cast<int>() + 1;
211            });
212        }
213
214        for (auto &thread : threads) {
215            thread.join();
216        }
217    }
218
219    REQUIRE(locals["count"].cast<int>() == num_threads);
220}
221
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
//
// Runs the stored callable when the guard is destroyed, i.e. when the enclosing scope is
// left normally or by an exception. An empty callable is skipped.
//
// The guard is neither copyable nor movable: a copy would run the same cleanup twice.
// NOTE: the destructor is implicitly noexcept, so a callable that throws terminates the
// program.
struct scope_exit {
    std::function<void()> f_;
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    ~scope_exit() { if (f_) f_(); }

    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
    scope_exit(scope_exit &&) = delete;
    scope_exit &operator=(scope_exit &&) = delete;
};
228
229TEST_CASE("Reload module from file") {
230    // Disable generation of cached bytecode (.pyc files) for this test, otherwise
231    // Python might pick up an old version from the cache instead of the new versions
232    // of the .py files generated below
233    auto sys = py::module::import("sys");
234    bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
235    sys.attr("dont_write_bytecode") = true;
236    // Reset the value at scope exit
237    scope_exit reset_dont_write_bytecode([&]() {
238        sys.attr("dont_write_bytecode") = dont_write_bytecode;
239    });
240
241    std::string module_name = "test_module_reload";
242    std::string module_file = module_name + ".py";
243
244    // Create the module .py file
245    std::ofstream test_module(module_file);
246    test_module << "def test():\n";
247    test_module << "    return 1\n";
248    test_module.close();
249    // Delete the file at scope exit
250    scope_exit delete_module_file([&]() {
251        std::remove(module_file.c_str());
252    });
253
254    // Import the module from file
255    auto module = py::module::import(module_name.c_str());
256    int result = module.attr("test")().cast<int>();
257    REQUIRE(result == 1);
258
259    // Update the module .py file with a small change
260    test_module.open(module_file);
261    test_module << "def test():\n";
262    test_module << "    return 2\n";
263    test_module.close();
264
265    // Reload the module
266    module.reload();
267    result = module.attr("test")().cast<int>();
268    REQUIRE(result == 2);
269}
270