test_interpreter.cpp revision 14299:2fbea9df56d2
#include <pybind11/embed.h>

#ifdef _MSC_VER
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
// 2.0.1; this should be fixed in the next catch release after 2.0.1).
#  pragma warning(disable: 4996)
#endif

#include <catch.hpp>

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <functional>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
14
15namespace py = pybind11;
16using namespace py::literals;
17
// Abstract C++ base class bound to Python as `widget_module.Widget`; concrete behaviour for
// `the_answer` is supplied by a subclass (here, the Python-overridable `PyWidget`).
class Widget {
public:
    /// Takes the message by value and moves it into place (sink parameter).
    /// `explicit`: a std::string should never silently turn into a Widget.
    explicit Widget(std::string message) : message(std::move(message)) { }
    virtual ~Widget() = default;

    /// Returns a copy of the message given at construction.
    std::string the_message() const { return message; }
    /// Must be implemented by derived classes.
    virtual int the_answer() const = 0;

private:
    std::string message;
};
29
30class PyWidget final : public Widget {
31    using Widget::Widget;
32
33    int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
34};
35
36PYBIND11_EMBEDDED_MODULE(widget_module, m) {
37    py::class_<Widget, PyWidget>(m, "Widget")
38        .def(py::init<std::string>())
39        .def_property_readonly("the_message", &Widget::the_message);
40
41    m.def("add", [](int i, int j) { return i + j; });
42}
43
44PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
45    throw std::runtime_error("C++ Error");
46}
47
48PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
49    auto d = py::dict();
50    d["missing"].cast<py::object>();
51}
52
53TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
54    auto module = py::module::import("test_interpreter");
55    REQUIRE(py::hasattr(module, "DerivedWidget"));
56
57    auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
58    py::exec(R"(
59        widget = DerivedWidget("{} - {}".format(hello, x))
60        message = widget.the_message
61    )", py::globals(), locals);
62    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
63
64    auto py_widget = module.attr("DerivedWidget")("The question");
65    auto message = py_widget.attr("the_message");
66    REQUIRE(message.cast<std::string>() == "The question");
67
68    const auto &cpp_widget = py_widget.cast<const Widget &>();
69    REQUIRE(cpp_widget.the_answer() == 42);
70}
71
72TEST_CASE("Import error handling") {
73    REQUIRE_NOTHROW(py::module::import("widget_module"));
74    REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
75                        "ImportError: C++ Error");
76    REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
77                        Catch::Contains("ImportError: KeyError"));
78}
79
80TEST_CASE("There can be only one interpreter") {
81    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
82    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
83    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
84    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
85
86    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
87    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
88
89    py::finalize_interpreter();
90    REQUIRE_NOTHROW(py::scoped_interpreter());
91    {
92        auto pyi1 = py::scoped_interpreter();
93        auto pyi2 = std::move(pyi1);
94    }
95    py::initialize_interpreter();
96}
97
98bool has_pybind11_internals_builtin() {
99    auto builtins = py::handle(PyEval_GetBuiltins());
100    return builtins.contains(PYBIND11_INTERNALS_ID);
101};
102
103bool has_pybind11_internals_static() {
104    auto **&ipp = py::detail::get_internals_pp();
105    return ipp && *ipp;
106}
107
108TEST_CASE("Restart the interpreter") {
109    // Verify pre-restart state.
110    REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
111    REQUIRE(has_pybind11_internals_builtin());
112    REQUIRE(has_pybind11_internals_static());
113    REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast<int>() == 123);
114
115    // local and foreign module internals should point to the same internals:
116    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
117            py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
118
119    // Restart the interpreter.
120    py::finalize_interpreter();
121    REQUIRE(Py_IsInitialized() == 0);
122
123    py::initialize_interpreter();
124    REQUIRE(Py_IsInitialized() == 1);
125
126    // Internals are deleted after a restart.
127    REQUIRE_FALSE(has_pybind11_internals_builtin());
128    REQUIRE_FALSE(has_pybind11_internals_static());
129    pybind11::detail::get_internals();
130    REQUIRE(has_pybind11_internals_builtin());
131    REQUIRE(has_pybind11_internals_static());
132    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
133            py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
134
135    // Make sure that an interpreter with no get_internals() created until finalize still gets the
136    // internals destroyed
137    py::finalize_interpreter();
138    py::initialize_interpreter();
139    bool ran = false;
140    py::module::import("__main__").attr("internals_destroy_test") =
141        py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
142    REQUIRE_FALSE(has_pybind11_internals_builtin());
143    REQUIRE_FALSE(has_pybind11_internals_static());
144    REQUIRE_FALSE(ran);
145    py::finalize_interpreter();
146    REQUIRE(ran);
147    py::initialize_interpreter();
148    REQUIRE_FALSE(has_pybind11_internals_builtin());
149    REQUIRE_FALSE(has_pybind11_internals_static());
150
151    // C++ modules can be reloaded.
152    auto cpp_module = py::module::import("widget_module");
153    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
154
155    // C++ type information is reloaded and can be used in python modules.
156    auto py_module = py::module::import("test_interpreter");
157    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
158    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
159}
160
161TEST_CASE("Subinterpreter") {
162    // Add tags to the modules in the main interpreter and test the basics.
163    py::module::import("__main__").attr("main_tag") = "main interpreter";
164    {
165        auto m = py::module::import("widget_module");
166        m.attr("extension_module_tag") = "added to module in main interpreter";
167
168        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
169    }
170    REQUIRE(has_pybind11_internals_builtin());
171    REQUIRE(has_pybind11_internals_static());
172
173    /// Create and switch to a subinterpreter.
174    auto main_tstate = PyThreadState_Get();
175    auto sub_tstate = Py_NewInterpreter();
176
177    // Subinterpreters get their own copy of builtins. detail::get_internals() still
178    // works by returning from the static variable, i.e. all interpreters share a single
179    // global pybind11::internals;
180    REQUIRE_FALSE(has_pybind11_internals_builtin());
181    REQUIRE(has_pybind11_internals_static());
182
183    // Modules tags should be gone.
184    REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
185    {
186        auto m = py::module::import("widget_module");
187        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
188
189        // Function bindings should still work.
190        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
191    }
192
193    // Restore main interpreter.
194    Py_EndInterpreter(sub_tstate);
195    PyThreadState_Swap(main_tstate);
196
197    REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
198    REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
199}
200
201TEST_CASE("Execution frame") {
202    // When the interpreter is embedded, there is no execution frame, but `py::exec`
203    // should still function by using reasonable globals: `__main__.__dict__`.
204    py::exec("var = dict(number=42)");
205    REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
206}
207
208TEST_CASE("Threads") {
209    // Restart interpreter to ensure threads are not initialized
210    py::finalize_interpreter();
211    py::initialize_interpreter();
212    REQUIRE_FALSE(has_pybind11_internals_static());
213
214    constexpr auto num_threads = 10;
215    auto locals = py::dict("count"_a=0);
216
217    {
218        py::gil_scoped_release gil_release{};
219        REQUIRE(has_pybind11_internals_static());
220
221        auto threads = std::vector<std::thread>();
222        for (auto i = 0; i < num_threads; ++i) {
223            threads.emplace_back([&]() {
224                py::gil_scoped_acquire gil{};
225                locals["count"] = locals["count"].cast<int>() + 1;
226            });
227        }
228
229        for (auto &thread : threads) {
230            thread.join();
231        }
232    }
233
234    REQUIRE(locals["count"].cast<int>() == num_threads);
235}
236
// Minimal scope guard: invokes the stored callable (if any) when the guard is destroyed.
// Adapted from https://stackoverflow.com/a/36644501/7255855
struct scope_exit {
    std::function<void()> f_;

    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    // The destructor is implicitly noexcept, so an exception escaping f_ terminates the program.
    ~scope_exit() { if (f_) f_(); }

    // A copy would produce a second guard running the same cleanup action again. Deleting the
    // copy operations also suppresses the implicit move operations.
    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
};
243
244TEST_CASE("Reload module from file") {
245    // Disable generation of cached bytecode (.pyc files) for this test, otherwise
246    // Python might pick up an old version from the cache instead of the new versions
247    // of the .py files generated below
248    auto sys = py::module::import("sys");
249    bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
250    sys.attr("dont_write_bytecode") = true;
251    // Reset the value at scope exit
252    scope_exit reset_dont_write_bytecode([&]() {
253        sys.attr("dont_write_bytecode") = dont_write_bytecode;
254    });
255
256    std::string module_name = "test_module_reload";
257    std::string module_file = module_name + ".py";
258
259    // Create the module .py file
260    std::ofstream test_module(module_file);
261    test_module << "def test():\n";
262    test_module << "    return 1\n";
263    test_module.close();
264    // Delete the file at scope exit
265    scope_exit delete_module_file([&]() {
266        std::remove(module_file.c_str());
267    });
268
269    // Import the module from file
270    auto module = py::module::import(module_name.c_str());
271    int result = module.attr("test")().cast<int>();
272    REQUIRE(result == 1);
273
274    // Update the module .py file with a small change
275    test_module.open(module_file);
276    test_module << "def test():\n";
277    test_module << "    return 2\n";
278    test_module.close();
279
280    // Reload the module
281    module.reload();
282    result = module.attr("test")().cast<int>();
283    REQUIRE(result == 2);
284}
285