test_virtual_functions.py revision 12037:d28054ac6ec9
1import pytest
2import pybind11_tests
3from pybind11_tests import ConstructorStats
4
5
def test_override(capture, msg):
    """Python subclasses of ``ExampleVirt`` must be able to override its virtuals.

    Three phases are checked: the bare bound class (C++ implementations only,
    pure virtual raises), a Python subclass overriding ``run``, ``run_bool``,
    ``get_string1`` and ``pure_virtual``, and a second-level Python subclass
    that additionally overrides ``get_string2``.  Finally ``ConstructorStats``
    must report that all three instances are released once deleted and that
    none of them was copy-constructed.
    """
    from pybind11_tests import (ExampleVirt, runExampleVirt, runExampleVirtVirtual,
                                runExampleVirtBool)

    class ExtendedExampleVirt(ExampleVirt):
        # Passes state + 1 to the base constructor, which is why the C++ output
        # below reports state=11 for an object created with 10.
        def __init__(self, state):
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print('ExtendedExampleVirt::run(%i), calling parent..' % value)
            # Chain back into the bound base implementation with a shifted value.
            return super(ExtendedExampleVirt, self).run(value + 1)

        def run_bool(self):
            print('ExtendedExampleVirt::run_bool()')
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        # Adds one more to the state before the parent adds its own one.
        def __init__(self, state):
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # Phase 1: the bound class itself -- default strings, pure virtual raises.
    plain = ExampleVirt(10)
    with capture:
        assert runExampleVirt(plain, 20) == 30
    assert capture == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long

    with pytest.raises(RuntimeError) as excinfo:
        runExampleVirtVirtual(plain)
    assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'

    # Phase 2: one level of Python overrides.
    derived = ExtendedExampleVirt(10)
    with capture:
        assert runExampleVirt(derived, 20) == 32
    assert capture == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    with capture:
        assert runExampleVirtBool(derived) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        runExampleVirtVirtual(derived)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Phase 3: two levels of Python overrides (state 15 -> 16 -> 17).
    derived2 = ExtendedExampleVirt2(15)
    with capture:
        assert runExampleVirt(derived2, 50) == 68
    assert capture == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long

    # Lifetime bookkeeping: three live instances, none left after deletion.
    stats = ConstructorStats.get(ExampleVirt)
    assert stats.alive() == 3
    del plain, derived, derived2
    assert stats.alive() == 0
    assert stats.values() == ['10', '11', '17']
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
76
77
def test_inheriting_repeat():
    """Virtual dispatch must behave identically for the ``*_Repeat`` and ``*_Tpl`` bindings.

    Each lettered class is bound twice, as ``X_Repeat`` and ``X_Tpl``. The
    loops below assert that both variants produce the same results.
    ``say_everything()`` is presumably implemented on the C++ side. Judging by
    the expectations below, it returns ``say_something(1) + " " +
    str(unlucky_number())``. Its result therefore shows whether the call
    dispatches to the Python overrides defined here, including through
    several levels of Python subclassing.
    """
    from pybind11_tests import A_Repeat, B_Repeat, C_Repeat, D_Repeat, A_Tpl, B_Tpl, C_Tpl, D_Tpl

    # A: override only unlucky_number() and check say_everything() picks it up.
    class AR(A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    # B, C, D without Python overrides: both binding styles must agree.
    for obj in [B_Repeat(), B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [C_Repeat(), C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    # Python override that calls the bound base implementation explicitly.
    class CR(C_Repeat):
        def lucky_number(self):
            return C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    # A Python subclass with no overrides must behave exactly like C_Tpl.
    class CT(C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    # Second-level Python subclasses chaining up through their Python parent.
    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [D_Repeat(), D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    # Override every virtual, including say_something() used by say_everything().
    class DT(D_Tpl):
        def say_something(self, times):
            return "DT says:" + (' quack' * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    # Python subclass of a Python subclass: re-overrides two methods and
    # inherits lucky_number() from DT.
    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ('QUACK' * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised; check that its
    # overrides win over DT's and that DT's lucky_number() is inherited.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
207
208
# PyPy: a reference count > 1 makes the call that passes the non-copyable
# instance fail in print_nc() of the first extension object below.
@pytest.unsupported_on_pypy
@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
                    reason="NCVirt test broken on ICPC")
def test_move_support():
    """Objects returned from Python overrides must be moved or copied as appropriate.

    ``NCVirtExt`` returns a fresh ``NonCopyable`` and a ``Movable`` that it also
    keeps on ``self``; ``NCVirtExt2`` does the opposite.  Returning a
    ``NonCopyable`` that is still referenced elsewhere must raise, and the
    constructor statistics pin down exactly one ``Movable`` copy.  The order of
    the calls matters: ``ConstructorStats.values()`` lists constructor
    arguments in creation order.
    """
    from pybind11_tests import NCVirt, NonCopyable, Movable

    class NCVirtExt(NCVirt):
        def get_noncopyable(self, a, b):
            # A brand-new instance with no other Python reference to it.
            fresh = NonCopyable(a * a, b * b)
            return fresh

        def get_movable(self, a, b):
            # Also stored on self, so the returned object is shared.
            self.movable = Movable(a, b)
            return self.movable

    class NCVirtExt2(NCVirt):
        def get_noncopyable(self, a, b):
            # Kept on self: returning a shared non-copyable is expected to throw.
            self.nc = NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # A temporary that nothing else refers to.
            return Movable(a, b)

    first = NCVirtExt()
    assert first.print_nc(2, 3) == "36"
    assert first.print_movable(4, 5) == "9"

    second = NCVirtExt2()
    assert second.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        second.print_nc(9, 9)

    noncopyable_stats = ConstructorStats.get(NonCopyable)
    movable_stats = ConstructorStats.get(Movable)
    # One instance of each survives, held by first.movable and second.nc.
    assert noncopyable_stats.alive() == 1
    assert movable_stats.alive() == 1
    del first, second
    assert noncopyable_stats.alive() == 0
    assert movable_stats.alive() == 0
    assert noncopyable_stats.values() == ['4', '9', '9', '9']
    assert movable_stats.values() == ['4', '5', '7', '7']
    assert noncopyable_stats.copy_constructions == 0
    assert movable_stats.copy_constructions == 1
    assert noncopyable_stats.move_constructions >= 0
    assert movable_stats.move_constructions >= 0
260