test_virtual_functions.py revision 14299
1import pytest
2
3from pybind11_tests import virtual_functions as m
4from pybind11_tests import ConstructorStats
5
6
def test_override(capture, msg):
    """Python subclasses can override virtual and pure-virtual methods of ExampleVirt.

    Exercises dispatch from C++ into Python overrides, delegation back to the
    C++ base implementation via ``super()``, a two-level Python subclass chain,
    and finally the ConstructorStats bookkeeping for the C++ class.
    """
    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            # The C++ base is constructed with state + 1 (see the values() check below).
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print('ExtendedExampleVirt::run(%i), calling parent..' % value)
            # Delegate to the C++ implementation with value + 1.
            return super(ExtendedExampleVirt, self).run(value + 1)

        def run_bool(self):
            print('ExtendedExampleVirt::run_bool()')
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            # Adds another +1 on top of ExtendedExampleVirt's own +1.
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # A plain C++ instance: the original run() executes, and the pure virtual raises.
    base = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(base, 20) == 30
    assert capture == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long

    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(base)
    assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'

    # One Python level: state 10 -> 11 and run(20) forwards run(21); 11 + 21 == 32.
    derived = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(derived, 20) == 32
    assert capture == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long

    with capture:
        assert m.runExampleVirtBool(derived) is False
    assert capture == "ExtendedExampleVirt::run_bool()"

    with capture:
        m.runExampleVirtVirtual(derived)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Two Python levels: state 15 -> 17 and run(50) forwards run(51); 17 + 51 == 68.
    derived2 = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(derived2, 50) == 68
    assert capture == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long

    stats = ConstructorStats.get(m.ExampleVirt)
    assert stats.alive() == 3
    del base, derived, derived2
    assert stats.alive() == 0
    # Constructor values are recorded in creation order.
    assert stats.values() == ['10', '11', '17']
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
74
75
def test_alias_delay_initialization1(capture):
    """`A` builds its trampoline (alias) class only for Python subclasses.

    Instantiating ``m.A`` directly skips the trampoline entirely and constructs a
    plain ``A`` (a performance optimization), so no ``PyA`` output appears. A
    Python subclass, by contrast, goes through ``PyA``.
    """
    class Derived(m.A):
        def __init__(self):
            super(Derived, self).__init__()

        def f(self):
            print("In python f()")

    # Direct use of the C++ class: only A.f() is printed.
    with capture:
        plain = m.A()
        m.call_f(plain)
        del plain
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python subclass: the PyA trampoline is constructed, forwards f() to Python,
    # and is destroyed once the object is collected.
    with capture:
        sub = Derived()
        m.call_f(sub)
        del sub
        pytest.gc_collect()
    assert capture == """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
109
110
def test_alias_delay_initialization2(capture):
    """`A2` always constructs its alias (trampoline), unlike `A`.

    The extra class layer adds a small virtual-dispatch cost, but in exchange
    the trampoline can hold its own state and run its own
    construction/destruction logic, which shows up in the captured output.
    """
    class Derived2(m.A2):
        def __init__(self):
            super(Derived2, self).__init__()

        def f(self):
            print("In python B2.f()")

    # No Python subclass: both the default and the one-argument constructor
    # still go through PyA2, whose f() falls back to A2.f().
    with capture:
        for ctor_args in ((), (1,)):
            inst = m.A2(*ctor_args)
            m.call_f(inst)
            del inst
            pytest.gc_collect()
    assert capture == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """

    # Python subclass: PyA2 forwards f() to the Python override.
    with capture:
        sub = Derived2()
        m.call_f(sub)
        del sub
        pytest.gc_collect()
    assert capture == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
158
159
# Skipped on PyPy: there the NonCopyable returned from the override has a
# reference count > 1, which makes the print_nc() call on the first extension fail.
@pytest.unsupported_on_pypy
@pytest.mark.skipif(not hasattr(m, "NCVirt"), reason="NCVirt test broken on ICPC")
def test_move_support():
    """Values returned from Python overrides are moved or copied into C++.

    Covers a freshly built NonCopyable (accepted), a NonCopyable still referenced
    by ``self`` (rejected with RuntimeError), and Movable instances that are
    either kept on ``self`` or returned without being stored.
    """
    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Build a fresh instance; nothing else keeps a reference to it.
            fresh = m.NonCopyable(a * a, b * b)
            return fresh

        def get_movable(self, a, b):
            # Return an instance that self also keeps referencing.
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Holding on to the instance makes the C++ side raise.
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Hand back a new instance without storing it anywhere.
            return m.Movable(a, b)

    first = NCVirtExt()
    assert first.print_nc(2, 3) == "36"
    assert first.print_movable(4, 5) == "9"
    second = NCVirtExt2()
    assert second.print_movable(7, 7) == "14"
    # Only the exception type is checked: the message differs between debug
    # and non-debug builds.
    with pytest.raises(RuntimeError):
        second.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # One survivor of each kind: second.nc and first.movable.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del first, second
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ['4', '9', '9', '9']
    assert mv_stats.values() == ['4', '5', '7', '7']
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
208
209
def test_dispatch_issue(msg):
    """#159: virtual function dispatch has problems with similar-named functions"""
    class Inner(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class Outer(m.DispatchIssue):
        def dispatch(self):
            # The base implementation is pure virtual, so calling it must raise...
            expected = 'Tried to call pure virtual function "Base::dispatch"'
            with pytest.raises(RuntimeError) as excinfo:
                super(Outer, self).dispatch()
            assert msg(excinfo.value) == expected

            # ...while a nested dispatch from inside this override must still
            # reach the other subclass's dispatch().
            return m.dispatch_issue_go(Inner())

    outer = Outer()
    assert m.dispatch_issue_go(outer) == "Yay.."
227
228
def test_override_ref():
    """#392/397: overriding reference-returning functions"""
    ovr = m.OverrideTest("asdf")

    # str_ref() is intentionally not called here; the associated .cpp explains
    # why that is not allowed. Only the by-value string accessor is checked.
    assert ovr.str_value() == "asdf"

    assert ovr.A_value().value == "hi"

    # The object returned by A_ref() can be read and its value reassigned.
    ref = ovr.A_ref()
    assert ref.value == "hi"
    ref.value = "bye"
    assert ref.value == "bye"
243
244
def test_inherited_virtuals():
    """Virtual overrides through multi-level C++ and Python class hierarchies.

    Each hierarchy is exercised in two flavours (``*_Repeat`` and ``*_Tpl``;
    presumably different trampoline-writing styles, see the .cpp). The test also
    covers Python subclasses of Python subclasses. As the assertions below show,
    ``say_everything()`` combines ``say_something(1)`` and ``unlucky_number()``,
    so it checks that the most-derived override is reached.
    """
    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    # Extend the C++ lucky_number() by calling the base class explicitly.
    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    # A Python subclass that overrides nothing behaves like the C++ class.
    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    # Chained Python overrides: (888.0 + 1.25) * 10.
    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (' quack' * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ('QUACK' * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised. It re-overrides two methods
    # that DT already overrides and inherits DT's lucky_number(), so
    # say_everything() must pick up DT2's overrides rather than DT's.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
372
373
def test_issue_1454():
    """Regression test for #1454.

    Acquiring/releasing the GIL on another thread used to crash under Python 2.7.
    """
    # Same-thread GIL round trip first...
    m.test_gil()
    # ...then the variant named for another thread, which was the crashing case in #1454.
    m.test_gil_from_thread()
378