112391Sjason@lowepower.com#include <pybind11/embed.h> 214299Sbbruce@ucdavis.edu 314299Sbbruce@ucdavis.edu#ifdef _MSC_VER 414299Sbbruce@ucdavis.edu// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch 514299Sbbruce@ucdavis.edu// 2.0.1; this should be fixed in the next catch release after 2.0.1). 614299Sbbruce@ucdavis.edu# pragma warning(disable: 4996) 714299Sbbruce@ucdavis.edu#endif 814299Sbbruce@ucdavis.edu 912391Sjason@lowepower.com#include <catch.hpp> 1012391Sjason@lowepower.com 1112391Sjason@lowepower.com#include <thread> 1212391Sjason@lowepower.com#include <fstream> 1312391Sjason@lowepower.com#include <functional> 1412391Sjason@lowepower.com 1512391Sjason@lowepower.comnamespace py = pybind11; 1612391Sjason@lowepower.comusing namespace py::literals; 1712391Sjason@lowepower.com 1812391Sjason@lowepower.comclass Widget { 1912391Sjason@lowepower.compublic: 2012391Sjason@lowepower.com Widget(std::string message) : message(message) { } 2112391Sjason@lowepower.com virtual ~Widget() = default; 2212391Sjason@lowepower.com 2312391Sjason@lowepower.com std::string the_message() const { return message; } 2412391Sjason@lowepower.com virtual int the_answer() const = 0; 2512391Sjason@lowepower.com 2612391Sjason@lowepower.comprivate: 2712391Sjason@lowepower.com std::string message; 2812391Sjason@lowepower.com}; 2912391Sjason@lowepower.com 3012391Sjason@lowepower.comclass PyWidget final : public Widget { 3112391Sjason@lowepower.com using Widget::Widget; 3212391Sjason@lowepower.com 3312391Sjason@lowepower.com int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); } 3412391Sjason@lowepower.com}; 3512391Sjason@lowepower.com 3612391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(widget_module, m) { 3712391Sjason@lowepower.com py::class_<Widget, PyWidget>(m, "Widget") 3812391Sjason@lowepower.com .def(py::init<std::string>()) 3912391Sjason@lowepower.com .def_property_readonly("the_message", &Widget::the_message); 4012391Sjason@lowepower.com 4112391Sjason@lowepower.com m.def("add", [](int i, int j) { return i + j; }); 4212391Sjason@lowepower.com} 4312391Sjason@lowepower.com 4412391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(throw_exception, ) { 4512391Sjason@lowepower.com throw std::runtime_error("C++ Error"); 4612391Sjason@lowepower.com} 4712391Sjason@lowepower.com 4812391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { 4912391Sjason@lowepower.com auto d = py::dict(); 5012391Sjason@lowepower.com d["missing"].cast<py::object>(); 5112391Sjason@lowepower.com} 5212391Sjason@lowepower.com 5312391Sjason@lowepower.comTEST_CASE("Pass classes and data between modules defined in C++ and Python") { 5412391Sjason@lowepower.com auto module = py::module::import("test_interpreter"); 5512391Sjason@lowepower.com REQUIRE(py::hasattr(module, "DerivedWidget")); 5612391Sjason@lowepower.com 5712391Sjason@lowepower.com auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__")); 5812391Sjason@lowepower.com py::exec(R"( 5912391Sjason@lowepower.com widget = DerivedWidget("{} - {}".format(hello, x)) 6012391Sjason@lowepower.com message = widget.the_message 6112391Sjason@lowepower.com )", py::globals(), locals); 6212391Sjason@lowepower.com REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5"); 6312391Sjason@lowepower.com 6412391Sjason@lowepower.com auto py_widget = module.attr("DerivedWidget")("The question"); 6512391Sjason@lowepower.com auto message = py_widget.attr("the_message"); 6612391Sjason@lowepower.com REQUIRE(message.cast<std::string>() == "The question"); 6712391Sjason@lowepower.com 6812391Sjason@lowepower.com const auto &cpp_widget = py_widget.cast<const Widget &>(); 6912391Sjason@lowepower.com REQUIRE(cpp_widget.the_answer() == 42); 7012391Sjason@lowepower.com} 7112391Sjason@lowepower.com 7212391Sjason@lowepower.comTEST_CASE("Import error handling") { 7312391Sjason@lowepower.com REQUIRE_NOTHROW(py::module::import("widget_module")); 7412391Sjason@lowepower.com REQUIRE_THROWS_WITH(py::module::import("throw_exception"), 7512391Sjason@lowepower.com "ImportError: C++ Error"); 7612391Sjason@lowepower.com REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"), 7712391Sjason@lowepower.com Catch::Contains("ImportError: KeyError")); 7812391Sjason@lowepower.com} 7912391Sjason@lowepower.com 8012391Sjason@lowepower.comTEST_CASE("There can be only one interpreter") { 8112391Sjason@lowepower.com static_assert(std::is_move_constructible<py::scoped_interpreter>::value, ""); 8212391Sjason@lowepower.com static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, ""); 8312391Sjason@lowepower.com static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, ""); 8412391Sjason@lowepower.com static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, ""); 8512391Sjason@lowepower.com 8612391Sjason@lowepower.com REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running"); 8712391Sjason@lowepower.com REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running"); 8812391Sjason@lowepower.com 8912391Sjason@lowepower.com py::finalize_interpreter(); 9012391Sjason@lowepower.com REQUIRE_NOTHROW(py::scoped_interpreter()); 9112391Sjason@lowepower.com { 9212391Sjason@lowepower.com auto pyi1 = py::scoped_interpreter(); 9312391Sjason@lowepower.com auto pyi2 = std::move(pyi1); 9412391Sjason@lowepower.com } 9512391Sjason@lowepower.com py::initialize_interpreter(); 9612391Sjason@lowepower.com} 9712391Sjason@lowepower.com 9812391Sjason@lowepower.combool has_pybind11_internals_builtin() { 9912391Sjason@lowepower.com auto builtins = py::handle(PyEval_GetBuiltins()); 10012391Sjason@lowepower.com return builtins.contains(PYBIND11_INTERNALS_ID); 10112391Sjason@lowepower.com}; 10212391Sjason@lowepower.com 10312391Sjason@lowepower.combool has_pybind11_internals_static() { 10414299Sbbruce@ucdavis.edu auto **&ipp = py::detail::get_internals_pp(); 10514299Sbbruce@ucdavis.edu return ipp && *ipp; 10612391Sjason@lowepower.com} 10712391Sjason@lowepower.com 10812391Sjason@lowepower.comTEST_CASE("Restart the interpreter") { 10912391Sjason@lowepower.com // Verify pre-restart state. 11012391Sjason@lowepower.com REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3); 11112391Sjason@lowepower.com REQUIRE(has_pybind11_internals_builtin()); 11212391Sjason@lowepower.com REQUIRE(has_pybind11_internals_static()); 11314299Sbbruce@ucdavis.edu REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast<int>() == 123); 11414299Sbbruce@ucdavis.edu 11514299Sbbruce@ucdavis.edu // local and foreign module internals should point to the same internals: 11614299Sbbruce@ucdavis.edu REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) == 11714299Sbbruce@ucdavis.edu py::module::import("external_module").attr("internals_at")().cast<uintptr_t>()); 11812391Sjason@lowepower.com 11912391Sjason@lowepower.com // Restart the interpreter. 12012391Sjason@lowepower.com py::finalize_interpreter(); 12112391Sjason@lowepower.com REQUIRE(Py_IsInitialized() == 0); 12212391Sjason@lowepower.com 12312391Sjason@lowepower.com py::initialize_interpreter(); 12412391Sjason@lowepower.com REQUIRE(Py_IsInitialized() == 1); 12512391Sjason@lowepower.com 12612391Sjason@lowepower.com // Internals are deleted after a restart. 12712391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_builtin()); 12812391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_static()); 12912391Sjason@lowepower.com pybind11::detail::get_internals(); 13012391Sjason@lowepower.com REQUIRE(has_pybind11_internals_builtin()); 13112391Sjason@lowepower.com REQUIRE(has_pybind11_internals_static()); 13214299Sbbruce@ucdavis.edu REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) == 13314299Sbbruce@ucdavis.edu py::module::import("external_module").attr("internals_at")().cast<uintptr_t>()); 13412391Sjason@lowepower.com 13512391Sjason@lowepower.com // Make sure that an interpreter with no get_internals() created until finalize still gets the 13612391Sjason@lowepower.com // internals destroyed 13712391Sjason@lowepower.com py::finalize_interpreter(); 13812391Sjason@lowepower.com py::initialize_interpreter(); 13912391Sjason@lowepower.com bool ran = false; 14012391Sjason@lowepower.com py::module::import("__main__").attr("internals_destroy_test") = 14112391Sjason@lowepower.com py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; }); 14212391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_builtin()); 14312391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_static()); 14412391Sjason@lowepower.com REQUIRE_FALSE(ran); 14512391Sjason@lowepower.com py::finalize_interpreter(); 14612391Sjason@lowepower.com REQUIRE(ran); 14712391Sjason@lowepower.com py::initialize_interpreter(); 14812391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_builtin()); 14912391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_static()); 15012391Sjason@lowepower.com 15112391Sjason@lowepower.com // C++ modules can be reloaded. 15212391Sjason@lowepower.com auto cpp_module = py::module::import("widget_module"); 15312391Sjason@lowepower.com REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3); 15412391Sjason@lowepower.com 15512391Sjason@lowepower.com // C++ type information is reloaded and can be used in python modules. 15612391Sjason@lowepower.com auto py_module = py::module::import("test_interpreter"); 15712391Sjason@lowepower.com auto py_widget = py_module.attr("DerivedWidget")("Hello after restart"); 15812391Sjason@lowepower.com REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart"); 15912391Sjason@lowepower.com} 16012391Sjason@lowepower.com 16112391Sjason@lowepower.comTEST_CASE("Subinterpreter") { 16212391Sjason@lowepower.com // Add tags to the modules in the main interpreter and test the basics. 16312391Sjason@lowepower.com py::module::import("__main__").attr("main_tag") = "main interpreter"; 16412391Sjason@lowepower.com { 16512391Sjason@lowepower.com auto m = py::module::import("widget_module"); 16612391Sjason@lowepower.com m.attr("extension_module_tag") = "added to module in main interpreter"; 16712391Sjason@lowepower.com 16812391Sjason@lowepower.com REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 16912391Sjason@lowepower.com } 17012391Sjason@lowepower.com REQUIRE(has_pybind11_internals_builtin()); 17112391Sjason@lowepower.com REQUIRE(has_pybind11_internals_static()); 17212391Sjason@lowepower.com 17312391Sjason@lowepower.com /// Create and switch to a subinterpreter. 17412391Sjason@lowepower.com auto main_tstate = PyThreadState_Get(); 17512391Sjason@lowepower.com auto sub_tstate = Py_NewInterpreter(); 17612391Sjason@lowepower.com 17712391Sjason@lowepower.com // Subinterpreters get their own copy of builtins. detail::get_internals() still 17812391Sjason@lowepower.com // works by returning from the static variable, i.e. all interpreters share a single 17912391Sjason@lowepower.com // global pybind11::internals; 18012391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_builtin()); 18112391Sjason@lowepower.com REQUIRE(has_pybind11_internals_static()); 18212391Sjason@lowepower.com 18312391Sjason@lowepower.com // Modules tags should be gone. 18412391Sjason@lowepower.com REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag")); 18512391Sjason@lowepower.com { 18612391Sjason@lowepower.com auto m = py::module::import("widget_module"); 18712391Sjason@lowepower.com REQUIRE_FALSE(py::hasattr(m, "extension_module_tag")); 18812391Sjason@lowepower.com 18912391Sjason@lowepower.com // Function bindings should still work. 19012391Sjason@lowepower.com REQUIRE(m.attr("add")(1, 2).cast<int>() == 3); 19112391Sjason@lowepower.com } 19212391Sjason@lowepower.com 19312391Sjason@lowepower.com // Restore main interpreter. 19412391Sjason@lowepower.com Py_EndInterpreter(sub_tstate); 19512391Sjason@lowepower.com PyThreadState_Swap(main_tstate); 19612391Sjason@lowepower.com 19712391Sjason@lowepower.com REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag")); 19812391Sjason@lowepower.com REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag")); 19912391Sjason@lowepower.com} 20012391Sjason@lowepower.com 20112391Sjason@lowepower.comTEST_CASE("Execution frame") { 20212391Sjason@lowepower.com // When the interpreter is embedded, there is no execution frame, but `py::exec` 20312391Sjason@lowepower.com // should still function by using reasonable globals: `__main__.__dict__`. 20412391Sjason@lowepower.com py::exec("var = dict(number=42)"); 20512391Sjason@lowepower.com REQUIRE(py::globals()["var"]["number"].cast<int>() == 42); 20612391Sjason@lowepower.com} 20712391Sjason@lowepower.com 20812391Sjason@lowepower.comTEST_CASE("Threads") { 20912391Sjason@lowepower.com // Restart interpreter to ensure threads are not initialized 21012391Sjason@lowepower.com py::finalize_interpreter(); 21112391Sjason@lowepower.com py::initialize_interpreter(); 21212391Sjason@lowepower.com REQUIRE_FALSE(has_pybind11_internals_static()); 21312391Sjason@lowepower.com 21412391Sjason@lowepower.com constexpr auto num_threads = 10; 21512391Sjason@lowepower.com auto locals = py::dict("count"_a=0); 21612391Sjason@lowepower.com 21712391Sjason@lowepower.com { 21812391Sjason@lowepower.com py::gil_scoped_release gil_release{}; 21912391Sjason@lowepower.com REQUIRE(has_pybind11_internals_static()); 22012391Sjason@lowepower.com 22112391Sjason@lowepower.com auto threads = std::vector<std::thread>(); 22212391Sjason@lowepower.com for (auto i = 0; i < num_threads; ++i) { 22312391Sjason@lowepower.com threads.emplace_back([&]() { 22412391Sjason@lowepower.com py::gil_scoped_acquire gil{}; 22512391Sjason@lowepower.com locals["count"] = locals["count"].cast<int>() + 1; 22612391Sjason@lowepower.com }); 22712391Sjason@lowepower.com } 22812391Sjason@lowepower.com 22912391Sjason@lowepower.com for (auto &thread : threads) { 23012391Sjason@lowepower.com thread.join(); 23112391Sjason@lowepower.com } 23212391Sjason@lowepower.com } 23312391Sjason@lowepower.com 23412391Sjason@lowepower.com REQUIRE(locals["count"].cast<int>() == num_threads); 23512391Sjason@lowepower.com} 23612391Sjason@lowepower.com 23712391Sjason@lowepower.com// Scope exit utility https://stackoverflow.com/a/36644501/7255855 23812391Sjason@lowepower.comstruct scope_exit { 23912391Sjason@lowepower.com std::function<void()> f_; 24012391Sjason@lowepower.com explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {} 24112391Sjason@lowepower.com ~scope_exit() { if (f_) f_(); } 24212391Sjason@lowepower.com}; 24312391Sjason@lowepower.com 24412391Sjason@lowepower.comTEST_CASE("Reload module from file") { 24512391Sjason@lowepower.com // Disable generation of cached bytecode (.pyc files) for this test, otherwise 24612391Sjason@lowepower.com // Python might pick up an old version from the cache instead of the new versions 24712391Sjason@lowepower.com // of the .py files generated below 24812391Sjason@lowepower.com auto sys = py::module::import("sys"); 24912391Sjason@lowepower.com bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>(); 25012391Sjason@lowepower.com sys.attr("dont_write_bytecode") = true; 25112391Sjason@lowepower.com // Reset the value at scope exit 25212391Sjason@lowepower.com scope_exit reset_dont_write_bytecode([&]() { 25312391Sjason@lowepower.com sys.attr("dont_write_bytecode") = dont_write_bytecode; 25412391Sjason@lowepower.com }); 25512391Sjason@lowepower.com 25612391Sjason@lowepower.com std::string module_name = "test_module_reload"; 25712391Sjason@lowepower.com std::string module_file = module_name + ".py"; 25812391Sjason@lowepower.com 25912391Sjason@lowepower.com // Create the module .py file 26012391Sjason@lowepower.com std::ofstream test_module(module_file); 26112391Sjason@lowepower.com test_module << "def test():\n"; 26212391Sjason@lowepower.com test_module << " return 1\n"; 26312391Sjason@lowepower.com test_module.close(); 26412391Sjason@lowepower.com // Delete the file at scope exit 26512391Sjason@lowepower.com scope_exit delete_module_file([&]() { 26612391Sjason@lowepower.com std::remove(module_file.c_str()); 26712391Sjason@lowepower.com }); 26812391Sjason@lowepower.com 26912391Sjason@lowepower.com // Import the module from file 27012391Sjason@lowepower.com auto module = py::module::import(module_name.c_str()); 27112391Sjason@lowepower.com int result = module.attr("test")().cast<int>(); 27212391Sjason@lowepower.com REQUIRE(result == 1); 27312391Sjason@lowepower.com 27412391Sjason@lowepower.com // Update the module .py file with a small change 27512391Sjason@lowepower.com test_module.open(module_file); 27612391Sjason@lowepower.com test_module << "def test():\n"; 27712391Sjason@lowepower.com test_module << " return 2\n"; 27812391Sjason@lowepower.com test_module.close(); 27912391Sjason@lowepower.com 28012391Sjason@lowepower.com // Reload the module 28112391Sjason@lowepower.com module.reload(); 28212391Sjason@lowepower.com result = module.attr("test")().cast<int>(); 28312391Sjason@lowepower.com REQUIRE(result == 2); 28412391Sjason@lowepower.com} 285