// test_interpreter.cpp revision 14299:2fbea9df56d2
1#include <pybind11/embed.h> 2 3#ifdef _MSC_VER 4// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch 5// 2.0.1; this should be fixed in the next catch release after 2.0.1). 6# pragma warning(disable: 4996) 7#endif 8 9#include <catch.hpp> 10 11#include <thread> 12#include <fstream> 13#include <functional> 14 15namespace py = pybind11; 16using namespace py::literals; 17 18class Widget { 19public: 20 Widget(std::string message) : message(message) { } 21 virtual ~Widget() = default; 22 23 std::string the_message() const { return message; } 24 virtual int the_answer() const = 0; 25 26private: 27 std::string message; 28}; 29 30class PyWidget final : public Widget { 31 using Widget::Widget; 32 33 int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); } 34}; 35 36PYBIND11_EMBEDDED_MODULE(widget_module, m) { 37 py::class_<Widget, PyWidget>(m, "Widget") 38 .def(py::init<std::string>()) 39 .def_property_readonly("the_message", &Widget::the_message); 40 41 m.def("add", [](int i, int j) { return i + j; }); 42} 43 44PYBIND11_EMBEDDED_MODULE(throw_exception, ) { 45 throw std::runtime_error("C++ Error"); 46} 47 48PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { 49 auto d = py::dict(); 50 d["missing"].cast<py::object>(); 51} 52 53TEST_CASE("Pass classes and data between modules defined in C++ and Python") { 54 auto module = py::module::import("test_interpreter"); 55 REQUIRE(py::hasattr(module, "DerivedWidget")); 56 57 auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__")); 58 py::exec(R"( 59 widget = DerivedWidget("{} - {}".format(hello, x)) 60 message = widget.the_message 61 )", py::globals(), locals); 62 REQUIRE(locals["message"].cast<std::string>() == "Hello, World! 
- 5"); 63 64 auto py_widget = module.attr("DerivedWidget")("The question"); 65 auto message = py_widget.attr("the_message"); 66 REQUIRE(message.cast<std::string>() == "The question"); 67 68 const auto &cpp_widget = py_widget.cast<const Widget &>(); 69 REQUIRE(cpp_widget.the_answer() == 42); 70} 71 72TEST_CASE("Import error handling") { 73 REQUIRE_NOTHROW(py::module::import("widget_module")); 74 REQUIRE_THROWS_WITH(py::module::import("throw_exception"), 75 "ImportError: C++ Error"); 76 REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"), 77 Catch::Contains("ImportError: KeyError")); 78} 79 80TEST_CASE("There can be only one interpreter") { 81 static_assert(std::is_move_constructible<py::scoped_interpreter>::value, ""); 82 static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, ""); 83 static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, ""); 84 static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, ""); 85 86 REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running"); 87 REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running"); 88 89 py::finalize_interpreter(); 90 REQUIRE_NOTHROW(py::scoped_interpreter()); 91 { 92 auto pyi1 = py::scoped_interpreter(); 93 auto pyi2 = std::move(pyi1); 94 } 95 py::initialize_interpreter(); 96} 97 98bool has_pybind11_internals_builtin() { 99 auto builtins = py::handle(PyEval_GetBuiltins()); 100 return builtins.contains(PYBIND11_INTERNALS_ID); 101}; 102 103bool has_pybind11_internals_static() { 104 auto **&ipp = py::detail::get_internals_pp(); 105 return ipp && *ipp; 106} 107 108TEST_CASE("Restart the interpreter") { 109 // Verify pre-restart state. 
110 REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3); 111 REQUIRE(has_pybind11_internals_builtin()); 112 REQUIRE(has_pybind11_internals_static()); 113 REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast<int>() == 123); 114 115 // local and foreign module internals should point to the same internals: 116 REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) == 117 py::module::import("external_module").attr("internals_at")().cast<uintptr_t>()); 118 119 // Restart the interpreter. 120 py::finalize_interpreter(); 121 REQUIRE(Py_IsInitialized() == 0); 122 123 py::initialize_interpreter(); 124 REQUIRE(Py_IsInitialized() == 1); 125 126 // Internals are deleted after a restart. 127 REQUIRE_FALSE(has_pybind11_internals_builtin()); 128 REQUIRE_FALSE(has_pybind11_internals_static()); 129 pybind11::detail::get_internals(); 130 REQUIRE(has_pybind11_internals_builtin()); 131 REQUIRE(has_pybind11_internals_static()); 132 REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) == 133 py::module::import("external_module").attr("internals_at")().cast<uintptr_t>()); 134 135 // Make sure that an interpreter with no get_internals() created until finalize still gets the 136 // internals destroyed 137 py::finalize_interpreter(); 138 py::initialize_interpreter(); 139 bool ran = false; 140 py::module::import("__main__").attr("internals_destroy_test") = 141 py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; }); 142 REQUIRE_FALSE(has_pybind11_internals_builtin()); 143 REQUIRE_FALSE(has_pybind11_internals_static()); 144 REQUIRE_FALSE(ran); 145 py::finalize_interpreter(); 146 REQUIRE(ran); 147 py::initialize_interpreter(); 148 REQUIRE_FALSE(has_pybind11_internals_builtin()); 149 REQUIRE_FALSE(has_pybind11_internals_static()); 150 151 // C++ modules can be reloaded. 
152 auto cpp_module = py::module::import("widget_module"); 153 REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3); 154 155 // C++ type information is reloaded and can be used in python modules. 156 auto py_module = py::module::import("test_interpreter"); 157 auto py_widget = py_module.attr("DerivedWidget")("Hello after restart"); 158 REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart"); 159} 160 161TEST_CASE("Subinterpreter") { 162 // Add tags to the modules in the main interpreter and test the basics. 163 py::module::import("__main__").attr("main_tag") = "main interpreter"; 164 { 165 auto m = py::module::import("widget_module"); 166 m.attr("extension_module_tag") = "added to module in main interpreter"; 167 168 REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 169 } 170 REQUIRE(has_pybind11_internals_builtin()); 171 REQUIRE(has_pybind11_internals_static()); 172 173 /// Create and switch to a subinterpreter. 174 auto main_tstate = PyThreadState_Get(); 175 auto sub_tstate = Py_NewInterpreter(); 176 177 // Subinterpreters get their own copy of builtins. detail::get_internals() still 178 // works by returning from the static variable, i.e. all interpreters share a single 179 // global pybind11::internals; 180 REQUIRE_FALSE(has_pybind11_internals_builtin()); 181 REQUIRE(has_pybind11_internals_static()); 182 183 // Modules tags should be gone. 184 REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag")); 185 { 186 auto m = py::module::import("widget_module"); 187 REQUIRE_FALSE(py::hasattr(m, "extension_module_tag")); 188 189 // Function bindings should still work. 190 REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 191 } 192 193 // Restore main interpreter. 
194 Py_EndInterpreter(sub_tstate); 195 PyThreadState_Swap(main_tstate); 196 197 REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag")); 198 REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag")); 199} 200 201TEST_CASE("Execution frame") { 202 // When the interpreter is embedded, there is no execution frame, but `py::exec` 203 // should still function by using reasonable globals: `__main__.__dict__`. 204 py::exec("var = dict(number=42)"); 205 REQUIRE(py::globals()["var"]["number"].cast<int>() == 42); 206} 207 208TEST_CASE("Threads") { 209 // Restart interpreter to ensure threads are not initialized 210 py::finalize_interpreter(); 211 py::initialize_interpreter(); 212 REQUIRE_FALSE(has_pybind11_internals_static()); 213 214 constexpr auto num_threads = 10; 215 auto locals = py::dict("count"_a=0); 216 217 { 218 py::gil_scoped_release gil_release{}; 219 REQUIRE(has_pybind11_internals_static()); 220 221 auto threads = std::vector<std::thread>(); 222 for (auto i = 0; i < num_threads; ++i) { 223 threads.emplace_back([&]() { 224 py::gil_scoped_acquire gil{}; 225 locals["count"] = locals["count"].cast<int>() + 1; 226 }); 227 } 228 229 for (auto &thread : threads) { 230 thread.join(); 231 } 232 } 233 234 REQUIRE(locals["count"].cast<int>() == num_threads); 235} 236 237// Scope exit utility https://stackoverflow.com/a/36644501/7255855 238struct scope_exit { 239 std::function<void()> f_; 240 explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {} 241 ~scope_exit() { if (f_) f_(); } 242}; 243 244TEST_CASE("Reload module from file") { 245 // Disable generation of cached bytecode (.pyc files) for this test, otherwise 246 // Python might pick up an old version from the cache instead of the new versions 247 // of the .py files generated below 248 auto sys = py::module::import("sys"); 249 bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>(); 250 sys.attr("dont_write_bytecode") = true; 251 // Reset the value 
at scope exit 252 scope_exit reset_dont_write_bytecode([&]() { 253 sys.attr("dont_write_bytecode") = dont_write_bytecode; 254 }); 255 256 std::string module_name = "test_module_reload"; 257 std::string module_file = module_name + ".py"; 258 259 // Create the module .py file 260 std::ofstream test_module(module_file); 261 test_module << "def test():\n"; 262 test_module << " return 1\n"; 263 test_module.close(); 264 // Delete the file at scope exit 265 scope_exit delete_module_file([&]() { 266 std::remove(module_file.c_str()); 267 }); 268 269 // Import the module from file 270 auto module = py::module::import(module_name.c_str()); 271 int result = module.attr("test")().cast<int>(); 272 REQUIRE(result == 1); 273 274 // Update the module .py file with a small change 275 test_module.open(module_file); 276 test_module << "def test():\n"; 277 test_module << " return 2\n"; 278 test_module.close(); 279 280 // Reload the module 281 module.reload(); 282 result = module.attr("test")().cast<int>(); 283 REQUIRE(result == 2); 284} 285