#include <pybind11/embed.h>

#ifdef _MSC_VER
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
// 2.0.1; this should be fixed in the next catch release after 2.0.1).
#  pragma warning(disable: 4996)
#endif

#include <catch.hpp>

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <functional>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace py = pybind11;
using namespace py::literals;
1712391Sjason@lowepower.com
/// Abstract C++ base class bound into Python as `widget_module.Widget`.
///
/// Holds a message fixed at construction. `the_answer()` is pure virtual; in this file it
/// is implemented on the Python side via the `PyWidget` trampoline.
class Widget {
public:
    /// @param message Text later returned by `the_message()`. Taken by value and moved
    ///                into the member, so callers passing a temporary pay no extra copy.
    Widget(std::string message) : message(std::move(message)) { }
    virtual ~Widget() = default;

    /// @return A copy of the message given to the constructor.
    std::string the_message() const { return message; }

    /// Supplied by subclasses (here: Python subclasses, dispatched through `PyWidget`).
    virtual int the_answer() const = 0;

private:
    std::string message;
};
2912391Sjason@lowepower.com
3012391Sjason@lowepower.comclass PyWidget final : public Widget {
3112391Sjason@lowepower.com    using Widget::Widget;
3212391Sjason@lowepower.com
3312391Sjason@lowepower.com    int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
3412391Sjason@lowepower.com};
3512391Sjason@lowepower.com
3612391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(widget_module, m) {
3712391Sjason@lowepower.com    py::class_<Widget, PyWidget>(m, "Widget")
3812391Sjason@lowepower.com        .def(py::init<std::string>())
3912391Sjason@lowepower.com        .def_property_readonly("the_message", &Widget::the_message);
4012391Sjason@lowepower.com
4112391Sjason@lowepower.com    m.def("add", [](int i, int j) { return i + j; });
4212391Sjason@lowepower.com}
4312391Sjason@lowepower.com
4412391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(throw_exception, ) {
4512391Sjason@lowepower.com    throw std::runtime_error("C++ Error");
4612391Sjason@lowepower.com}
4712391Sjason@lowepower.com
4812391Sjason@lowepower.comPYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
4912391Sjason@lowepower.com    auto d = py::dict();
5012391Sjason@lowepower.com    d["missing"].cast<py::object>();
5112391Sjason@lowepower.com}
5212391Sjason@lowepower.com
5312391Sjason@lowepower.comTEST_CASE("Pass classes and data between modules defined in C++ and Python") {
5412391Sjason@lowepower.com    auto module = py::module::import("test_interpreter");
5512391Sjason@lowepower.com    REQUIRE(py::hasattr(module, "DerivedWidget"));
5612391Sjason@lowepower.com
5712391Sjason@lowepower.com    auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
5812391Sjason@lowepower.com    py::exec(R"(
5912391Sjason@lowepower.com        widget = DerivedWidget("{} - {}".format(hello, x))
6012391Sjason@lowepower.com        message = widget.the_message
6112391Sjason@lowepower.com    )", py::globals(), locals);
6212391Sjason@lowepower.com    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
6312391Sjason@lowepower.com
6412391Sjason@lowepower.com    auto py_widget = module.attr("DerivedWidget")("The question");
6512391Sjason@lowepower.com    auto message = py_widget.attr("the_message");
6612391Sjason@lowepower.com    REQUIRE(message.cast<std::string>() == "The question");
6712391Sjason@lowepower.com
6812391Sjason@lowepower.com    const auto &cpp_widget = py_widget.cast<const Widget &>();
6912391Sjason@lowepower.com    REQUIRE(cpp_widget.the_answer() == 42);
7012391Sjason@lowepower.com}
7112391Sjason@lowepower.com
7212391Sjason@lowepower.comTEST_CASE("Import error handling") {
7312391Sjason@lowepower.com    REQUIRE_NOTHROW(py::module::import("widget_module"));
7412391Sjason@lowepower.com    REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
7512391Sjason@lowepower.com                        "ImportError: C++ Error");
7612391Sjason@lowepower.com    REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
7712391Sjason@lowepower.com                        Catch::Contains("ImportError: KeyError"));
7812391Sjason@lowepower.com}
7912391Sjason@lowepower.com
8012391Sjason@lowepower.comTEST_CASE("There can be only one interpreter") {
8112391Sjason@lowepower.com    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
8212391Sjason@lowepower.com    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
8312391Sjason@lowepower.com    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
8412391Sjason@lowepower.com    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
8512391Sjason@lowepower.com
8612391Sjason@lowepower.com    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
8712391Sjason@lowepower.com    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
8812391Sjason@lowepower.com
8912391Sjason@lowepower.com    py::finalize_interpreter();
9012391Sjason@lowepower.com    REQUIRE_NOTHROW(py::scoped_interpreter());
9112391Sjason@lowepower.com    {
9212391Sjason@lowepower.com        auto pyi1 = py::scoped_interpreter();
9312391Sjason@lowepower.com        auto pyi2 = std::move(pyi1);
9412391Sjason@lowepower.com    }
9512391Sjason@lowepower.com    py::initialize_interpreter();
9612391Sjason@lowepower.com}
9712391Sjason@lowepower.com
9812391Sjason@lowepower.combool has_pybind11_internals_builtin() {
9912391Sjason@lowepower.com    auto builtins = py::handle(PyEval_GetBuiltins());
10012391Sjason@lowepower.com    return builtins.contains(PYBIND11_INTERNALS_ID);
10112391Sjason@lowepower.com};
10212391Sjason@lowepower.com
10312391Sjason@lowepower.combool has_pybind11_internals_static() {
10414299Sbbruce@ucdavis.edu    auto **&ipp = py::detail::get_internals_pp();
10514299Sbbruce@ucdavis.edu    return ipp && *ipp;
10612391Sjason@lowepower.com}
10712391Sjason@lowepower.com
10812391Sjason@lowepower.comTEST_CASE("Restart the interpreter") {
10912391Sjason@lowepower.com    // Verify pre-restart state.
11012391Sjason@lowepower.com    REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
11112391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_builtin());
11212391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_static());
11314299Sbbruce@ucdavis.edu    REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast<int>() == 123);
11414299Sbbruce@ucdavis.edu
11514299Sbbruce@ucdavis.edu    // local and foreign module internals should point to the same internals:
11614299Sbbruce@ucdavis.edu    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
11714299Sbbruce@ucdavis.edu            py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
11812391Sjason@lowepower.com
11912391Sjason@lowepower.com    // Restart the interpreter.
12012391Sjason@lowepower.com    py::finalize_interpreter();
12112391Sjason@lowepower.com    REQUIRE(Py_IsInitialized() == 0);
12212391Sjason@lowepower.com
12312391Sjason@lowepower.com    py::initialize_interpreter();
12412391Sjason@lowepower.com    REQUIRE(Py_IsInitialized() == 1);
12512391Sjason@lowepower.com
12612391Sjason@lowepower.com    // Internals are deleted after a restart.
12712391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_builtin());
12812391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_static());
12912391Sjason@lowepower.com    pybind11::detail::get_internals();
13012391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_builtin());
13112391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_static());
13214299Sbbruce@ucdavis.edu    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
13314299Sbbruce@ucdavis.edu            py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
13412391Sjason@lowepower.com
13512391Sjason@lowepower.com    // Make sure that an interpreter with no get_internals() created until finalize still gets the
13612391Sjason@lowepower.com    // internals destroyed
13712391Sjason@lowepower.com    py::finalize_interpreter();
13812391Sjason@lowepower.com    py::initialize_interpreter();
13912391Sjason@lowepower.com    bool ran = false;
14012391Sjason@lowepower.com    py::module::import("__main__").attr("internals_destroy_test") =
14112391Sjason@lowepower.com        py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
14212391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_builtin());
14312391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_static());
14412391Sjason@lowepower.com    REQUIRE_FALSE(ran);
14512391Sjason@lowepower.com    py::finalize_interpreter();
14612391Sjason@lowepower.com    REQUIRE(ran);
14712391Sjason@lowepower.com    py::initialize_interpreter();
14812391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_builtin());
14912391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_static());
15012391Sjason@lowepower.com
15112391Sjason@lowepower.com    // C++ modules can be reloaded.
15212391Sjason@lowepower.com    auto cpp_module = py::module::import("widget_module");
15312391Sjason@lowepower.com    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
15412391Sjason@lowepower.com
15512391Sjason@lowepower.com    // C++ type information is reloaded and can be used in python modules.
15612391Sjason@lowepower.com    auto py_module = py::module::import("test_interpreter");
15712391Sjason@lowepower.com    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
15812391Sjason@lowepower.com    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
15912391Sjason@lowepower.com}
16012391Sjason@lowepower.com
16112391Sjason@lowepower.comTEST_CASE("Subinterpreter") {
16212391Sjason@lowepower.com    // Add tags to the modules in the main interpreter and test the basics.
16312391Sjason@lowepower.com    py::module::import("__main__").attr("main_tag") = "main interpreter";
16412391Sjason@lowepower.com    {
16512391Sjason@lowepower.com        auto m = py::module::import("widget_module");
16612391Sjason@lowepower.com        m.attr("extension_module_tag") = "added to module in main interpreter";
16712391Sjason@lowepower.com
16812391Sjason@lowepower.com        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
16912391Sjason@lowepower.com    }
17012391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_builtin());
17112391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_static());
17212391Sjason@lowepower.com
17312391Sjason@lowepower.com    /// Create and switch to a subinterpreter.
17412391Sjason@lowepower.com    auto main_tstate = PyThreadState_Get();
17512391Sjason@lowepower.com    auto sub_tstate = Py_NewInterpreter();
17612391Sjason@lowepower.com
17712391Sjason@lowepower.com    // Subinterpreters get their own copy of builtins. detail::get_internals() still
17812391Sjason@lowepower.com    // works by returning from the static variable, i.e. all interpreters share a single
17912391Sjason@lowepower.com    // global pybind11::internals;
18012391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_builtin());
18112391Sjason@lowepower.com    REQUIRE(has_pybind11_internals_static());
18212391Sjason@lowepower.com
18312391Sjason@lowepower.com    // Modules tags should be gone.
18412391Sjason@lowepower.com    REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
18512391Sjason@lowepower.com    {
18612391Sjason@lowepower.com        auto m = py::module::import("widget_module");
18712391Sjason@lowepower.com        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
18812391Sjason@lowepower.com
18912391Sjason@lowepower.com        // Function bindings should still work.
19012391Sjason@lowepower.com        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
19112391Sjason@lowepower.com    }
19212391Sjason@lowepower.com
19312391Sjason@lowepower.com    // Restore main interpreter.
19412391Sjason@lowepower.com    Py_EndInterpreter(sub_tstate);
19512391Sjason@lowepower.com    PyThreadState_Swap(main_tstate);
19612391Sjason@lowepower.com
19712391Sjason@lowepower.com    REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
19812391Sjason@lowepower.com    REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
19912391Sjason@lowepower.com}
20012391Sjason@lowepower.com
20112391Sjason@lowepower.comTEST_CASE("Execution frame") {
20212391Sjason@lowepower.com    // When the interpreter is embedded, there is no execution frame, but `py::exec`
20312391Sjason@lowepower.com    // should still function by using reasonable globals: `__main__.__dict__`.
20412391Sjason@lowepower.com    py::exec("var = dict(number=42)");
20512391Sjason@lowepower.com    REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
20612391Sjason@lowepower.com}
20712391Sjason@lowepower.com
20812391Sjason@lowepower.comTEST_CASE("Threads") {
20912391Sjason@lowepower.com    // Restart interpreter to ensure threads are not initialized
21012391Sjason@lowepower.com    py::finalize_interpreter();
21112391Sjason@lowepower.com    py::initialize_interpreter();
21212391Sjason@lowepower.com    REQUIRE_FALSE(has_pybind11_internals_static());
21312391Sjason@lowepower.com
21412391Sjason@lowepower.com    constexpr auto num_threads = 10;
21512391Sjason@lowepower.com    auto locals = py::dict("count"_a=0);
21612391Sjason@lowepower.com
21712391Sjason@lowepower.com    {
21812391Sjason@lowepower.com        py::gil_scoped_release gil_release{};
21912391Sjason@lowepower.com        REQUIRE(has_pybind11_internals_static());
22012391Sjason@lowepower.com
22112391Sjason@lowepower.com        auto threads = std::vector<std::thread>();
22212391Sjason@lowepower.com        for (auto i = 0; i < num_threads; ++i) {
22312391Sjason@lowepower.com            threads.emplace_back([&]() {
22412391Sjason@lowepower.com                py::gil_scoped_acquire gil{};
22512391Sjason@lowepower.com                locals["count"] = locals["count"].cast<int>() + 1;
22612391Sjason@lowepower.com            });
22712391Sjason@lowepower.com        }
22812391Sjason@lowepower.com
22912391Sjason@lowepower.com        for (auto &thread : threads) {
23012391Sjason@lowepower.com            thread.join();
23112391Sjason@lowepower.com        }
23212391Sjason@lowepower.com    }
23312391Sjason@lowepower.com
23412391Sjason@lowepower.com    REQUIRE(locals["count"].cast<int>() == num_threads);
23512391Sjason@lowepower.com}
23612391Sjason@lowepower.com
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
/// Runs the stored callable (if one was given) when the object goes out of scope.
///
/// Copying is disabled: an implicit copy would run the same cleanup once per copy
/// (e.g. deleting a file twice). Construct it directly where the cleanup belongs.
struct scope_exit {
    std::function<void()> f_;
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    ~scope_exit() { if (f_) f_(); }

    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
};
24312391Sjason@lowepower.com
24412391Sjason@lowepower.comTEST_CASE("Reload module from file") {
24512391Sjason@lowepower.com    // Disable generation of cached bytecode (.pyc files) for this test, otherwise
24612391Sjason@lowepower.com    // Python might pick up an old version from the cache instead of the new versions
24712391Sjason@lowepower.com    // of the .py files generated below
24812391Sjason@lowepower.com    auto sys = py::module::import("sys");
24912391Sjason@lowepower.com    bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
25012391Sjason@lowepower.com    sys.attr("dont_write_bytecode") = true;
25112391Sjason@lowepower.com    // Reset the value at scope exit
25212391Sjason@lowepower.com    scope_exit reset_dont_write_bytecode([&]() {
25312391Sjason@lowepower.com        sys.attr("dont_write_bytecode") = dont_write_bytecode;
25412391Sjason@lowepower.com    });
25512391Sjason@lowepower.com
25612391Sjason@lowepower.com    std::string module_name = "test_module_reload";
25712391Sjason@lowepower.com    std::string module_file = module_name + ".py";
25812391Sjason@lowepower.com
25912391Sjason@lowepower.com    // Create the module .py file
26012391Sjason@lowepower.com    std::ofstream test_module(module_file);
26112391Sjason@lowepower.com    test_module << "def test():\n";
26212391Sjason@lowepower.com    test_module << "    return 1\n";
26312391Sjason@lowepower.com    test_module.close();
26412391Sjason@lowepower.com    // Delete the file at scope exit
26512391Sjason@lowepower.com    scope_exit delete_module_file([&]() {
26612391Sjason@lowepower.com        std::remove(module_file.c_str());
26712391Sjason@lowepower.com    });
26812391Sjason@lowepower.com
26912391Sjason@lowepower.com    // Import the module from file
27012391Sjason@lowepower.com    auto module = py::module::import(module_name.c_str());
27112391Sjason@lowepower.com    int result = module.attr("test")().cast<int>();
27212391Sjason@lowepower.com    REQUIRE(result == 1);
27312391Sjason@lowepower.com
27412391Sjason@lowepower.com    // Update the module .py file with a small change
27512391Sjason@lowepower.com    test_module.open(module_file);
27612391Sjason@lowepower.com    test_module << "def test():\n";
27712391Sjason@lowepower.com    test_module << "    return 2\n";
27812391Sjason@lowepower.com    test_module.close();
27912391Sjason@lowepower.com
28012391Sjason@lowepower.com    // Reload the module
28112391Sjason@lowepower.com    module.reload();
28212391Sjason@lowepower.com    result = module.attr("test")().cast<int>();
28312391Sjason@lowepower.com    REQUIRE(result == 2);
28412391Sjason@lowepower.com}
285