test_interpreter.cpp revision 12391:ceeca8b41e4b
1#include <pybind11/embed.h> 2#include <catch.hpp> 3 4#include <thread> 5#include <fstream> 6#include <functional> 7 8namespace py = pybind11; 9using namespace py::literals; 10 11class Widget { 12public: 13 Widget(std::string message) : message(message) { } 14 virtual ~Widget() = default; 15 16 std::string the_message() const { return message; } 17 virtual int the_answer() const = 0; 18 19private: 20 std::string message; 21}; 22 23class PyWidget final : public Widget { 24 using Widget::Widget; 25 26 int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); } 27}; 28 29PYBIND11_EMBEDDED_MODULE(widget_module, m) { 30 py::class_<Widget, PyWidget>(m, "Widget") 31 .def(py::init<std::string>()) 32 .def_property_readonly("the_message", &Widget::the_message); 33 34 m.def("add", [](int i, int j) { return i + j; }); 35} 36 37PYBIND11_EMBEDDED_MODULE(throw_exception, ) { 38 throw std::runtime_error("C++ Error"); 39} 40 41PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { 42 auto d = py::dict(); 43 d["missing"].cast<py::object>(); 44} 45 46TEST_CASE("Pass classes and data between modules defined in C++ and Python") { 47 auto module = py::module::import("test_interpreter"); 48 REQUIRE(py::hasattr(module, "DerivedWidget")); 49 50 auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__")); 51 py::exec(R"( 52 widget = DerivedWidget("{} - {}".format(hello, x)) 53 message = widget.the_message 54 )", py::globals(), locals); 55 REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5"); 56 57 auto py_widget = module.attr("DerivedWidget")("The question"); 58 auto message = py_widget.attr("the_message"); 59 REQUIRE(message.cast<std::string>() == "The question"); 60 61 const auto &cpp_widget = py_widget.cast<const Widget &>(); 62 REQUIRE(cpp_widget.the_answer() == 42); 63} 64 65TEST_CASE("Import error handling") { 66 REQUIRE_NOTHROW(py::module::import("widget_module")); 67 REQUIRE_THROWS_WITH(py::module::import("throw_exception"), 68 "ImportError: C++ Error"); 69 REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"), 70 Catch::Contains("ImportError: KeyError")); 71} 72 73TEST_CASE("There can be only one interpreter") { 74 static_assert(std::is_move_constructible<py::scoped_interpreter>::value, ""); 75 static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, ""); 76 static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, ""); 77 static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, ""); 78 79 REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running"); 80 REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running"); 81 82 py::finalize_interpreter(); 83 REQUIRE_NOTHROW(py::scoped_interpreter()); 84 { 85 auto pyi1 = py::scoped_interpreter(); 86 auto pyi2 = std::move(pyi1); 87 } 88 py::initialize_interpreter(); 89} 90 91bool has_pybind11_internals_builtin() { 92 auto builtins = py::handle(PyEval_GetBuiltins()); 93 return builtins.contains(PYBIND11_INTERNALS_ID); 94}; 95 96bool has_pybind11_internals_static() { 97 return py::detail::get_internals_ptr() != nullptr; 98} 99 100TEST_CASE("Restart the interpreter") { 101 // Verify pre-restart state. 102 REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3); 103 REQUIRE(has_pybind11_internals_builtin()); 104 REQUIRE(has_pybind11_internals_static()); 105 106 // Restart the interpreter. 107 py::finalize_interpreter(); 108 REQUIRE(Py_IsInitialized() == 0); 109 110 py::initialize_interpreter(); 111 REQUIRE(Py_IsInitialized() == 1); 112 113 // Internals are deleted after a restart. 114 REQUIRE_FALSE(has_pybind11_internals_builtin()); 115 REQUIRE_FALSE(has_pybind11_internals_static()); 116 pybind11::detail::get_internals(); 117 REQUIRE(has_pybind11_internals_builtin()); 118 REQUIRE(has_pybind11_internals_static()); 119 120 // Make sure that an interpreter with no get_internals() created until finalize still gets the 121 // internals destroyed 122 py::finalize_interpreter(); 123 py::initialize_interpreter(); 124 bool ran = false; 125 py::module::import("__main__").attr("internals_destroy_test") = 126 py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; }); 127 REQUIRE_FALSE(has_pybind11_internals_builtin()); 128 REQUIRE_FALSE(has_pybind11_internals_static()); 129 REQUIRE_FALSE(ran); 130 py::finalize_interpreter(); 131 REQUIRE(ran); 132 py::initialize_interpreter(); 133 REQUIRE_FALSE(has_pybind11_internals_builtin()); 134 REQUIRE_FALSE(has_pybind11_internals_static()); 135 136 // C++ modules can be reloaded. 137 auto cpp_module = py::module::import("widget_module"); 138 REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3); 139 140 // C++ type information is reloaded and can be used in python modules. 141 auto py_module = py::module::import("test_interpreter"); 142 auto py_widget = py_module.attr("DerivedWidget")("Hello after restart"); 143 REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart"); 144} 145 146TEST_CASE("Subinterpreter") { 147 // Add tags to the modules in the main interpreter and test the basics. 148 py::module::import("__main__").attr("main_tag") = "main interpreter"; 149 { 150 auto m = py::module::import("widget_module"); 151 m.attr("extension_module_tag") = "added to module in main interpreter"; 152 153 REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 154 } 155 REQUIRE(has_pybind11_internals_builtin()); 156 REQUIRE(has_pybind11_internals_static()); 157 158 /// Create and switch to a subinterpreter. 159 auto main_tstate = PyThreadState_Get(); 160 auto sub_tstate = Py_NewInterpreter(); 161 162 // Subinterpreters get their own copy of builtins. detail::get_internals() still 163 // works by returning from the static variable, i.e. all interpreters share a single 164 // global pybind11::internals; 165 REQUIRE_FALSE(has_pybind11_internals_builtin()); 166 REQUIRE(has_pybind11_internals_static()); 167 168 // Modules tags should be gone. 169 REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag")); 170 { 171 auto m = py::module::import("widget_module"); 172 REQUIRE_FALSE(py::hasattr(m, "extension_module_tag")); 173 174 // Function bindings should still work. 175 REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 176 } 177 178 // Restore main interpreter. 179 Py_EndInterpreter(sub_tstate); 180 PyThreadState_Swap(main_tstate); 181 182 REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag")); 183 REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag")); 184} 185 186TEST_CASE("Execution frame") { 187 // When the interpreter is embedded, there is no execution frame, but `py::exec` 188 // should still function by using reasonable globals: `__main__.__dict__`. 189 py::exec("var = dict(number=42)"); 190 REQUIRE(py::globals()["var"]["number"].cast<int>() == 42); 191} 192 193TEST_CASE("Threads") { 194 // Restart interpreter to ensure threads are not initialized 195 py::finalize_interpreter(); 196 py::initialize_interpreter(); 197 REQUIRE_FALSE(has_pybind11_internals_static()); 198 199 constexpr auto num_threads = 10; 200 auto locals = py::dict("count"_a=0); 201 202 { 203 py::gil_scoped_release gil_release{}; 204 REQUIRE(has_pybind11_internals_static()); 205 206 auto threads = std::vector<std::thread>(); 207 for (auto i = 0; i < num_threads; ++i) { 208 threads.emplace_back([&]() { 209 py::gil_scoped_acquire gil{}; 210 locals["count"] = locals["count"].cast<int>() + 1; 211 }); 212 } 213 214 for (auto &thread : threads) { 215 thread.join(); 216 } 217 } 218 219 REQUIRE(locals["count"].cast<int>() == num_threads); 220} 221 222// Scope exit utility https://stackoverflow.com/a/36644501/7255855 223struct scope_exit { 224 std::function<void()> f_; 225 explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {} 226 ~scope_exit() { if (f_) f_(); } 227}; 228 229TEST_CASE("Reload module from file") { 230 // Disable generation of cached bytecode (.pyc files) for this test, otherwise 231 // Python might pick up an old version from the cache instead of the new versions 232 // of the .py files generated below 233 auto sys = py::module::import("sys"); 234 bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>(); 235 sys.attr("dont_write_bytecode") = true; 236 // Reset the value at scope exit 237 scope_exit reset_dont_write_bytecode([&]() { 238 sys.attr("dont_write_bytecode") = dont_write_bytecode; 239 }); 240 241 std::string module_name = "test_module_reload"; 242 std::string module_file = module_name + ".py"; 243 244 // Create the module .py file 245 std::ofstream test_module(module_file); 246 test_module << "def test():\n"; 247 test_module << " return 1\n"; 248 test_module.close(); 249 // Delete the file at scope exit 250 scope_exit delete_module_file([&]() { 251 std::remove(module_file.c_str()); 252 }); 253 254 // Import the module from file 255 auto module = py::module::import(module_name.c_str()); 256 int result = module.attr("test")().cast<int>(); 257 REQUIRE(result == 1); 258 259 // Update the module .py file with a small change 260 test_module.open(module_file); 261 test_module << "def test():\n"; 262 test_module << " return 2\n"; 263 test_module.close(); 264 265 // Reload the module 266 module.reload(); 267 result = module.attr("test")().cast<int>(); 268 REQUIRE(result == 2); 269} 270