# test_virtual_functions.py revision 11986:c12e4625ab56
16899SN/Aimport pytest
26899SN/Aimport pybind11_tests
36899SN/Afrom pybind11_tests import ConstructorStats
46899SN/A
56899SN/A
def test_override(capture, msg):
    """Check Python-side overrides of ExampleVirt's methods.

    A plain ``ExampleVirt``, a Python subclass of it, and a subclass of that
    subclass are each passed to the ``runExampleVirt*`` helpers imported from
    ``pybind11_tests``.  Return values and captured output are compared with
    the expected text, and finally the ``ConstructorStats`` counters for
    ``ExampleVirt`` are verified.
    """
    from pybind11_tests import ExampleVirt
    from pybind11_tests import runExampleVirt, runExampleVirtBool, runExampleVirtVirtual

    class ExtendedExampleVirt(ExampleVirt):
        """Python subclass overriding run, run_bool, get_string1 and pure_virtual."""

        def __init__(self, state):
            # The parent is initialised with state + 1 (see expected output below).
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print('ExtendedExampleVirt::run(%i), calling parent..' % value)
            bumped = value + 1
            return super(ExtendedExampleVirt, self).run(bumped)

        def run_bool(self):
            print('ExtendedExampleVirt::run_bool()')
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        """Second-level subclass; adds one more to the state and overrides get_string2."""

        def __init__(self, state):
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # --- plain ExampleVirt: nothing is overridden ---------------------------
    plain = ExampleVirt(10)
    with capture:
        outcome = runExampleVirt(plain, 20)
        assert outcome == 30
    assert capture == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long

    # Without an override, calling the pure virtual raises RuntimeError.
    with pytest.raises(RuntimeError) as excinfo:
        runExampleVirtVirtual(plain)
    expected_error = 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    assert msg(excinfo.value) == expected_error

    # --- first-level Python subclass ----------------------------------------
    extended = ExtendedExampleVirt(10)
    with capture:
        outcome = runExampleVirt(extended, 20)
        assert outcome == 32
    assert capture == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    with capture:
        flag = runExampleVirtBool(extended)
        assert flag is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        runExampleVirtVirtual(extended)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # --- second-level Python subclass ---------------------------------------
    extended2 = ExtendedExampleVirt2(15)
    with capture:
        outcome = runExampleVirt(extended2, 50)
        assert outcome == 68
    assert capture == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long

    # --- instance bookkeeping -----------------------------------------------
    stats = ConstructorStats.get(ExampleVirt)
    assert stats.alive() == 3
    del plain, extended, extended2
    assert stats.alive() == 0
    assert stats.values() == ['10', '11', '17']
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
766899SN/A
776899SN/A
def test_inheriting_repeat():
    """Check method overriding across the ``*_Repeat`` and ``*_Tpl`` class families.

    For each of the A/B/C/D classes imported from ``pybind11_tests`` (in both
    the ``_Repeat`` and ``_Tpl`` flavour), and for Python subclasses that
    override some or all of ``say_something``, ``unlucky_number`` and
    ``lucky_number``, the test asserts the value each method returns and the
    string produced by ``say_everything``.

    Judging by the expected strings in this test, ``say_everything()`` yields
    ``say_something(1)`` followed by a space and ``unlucky_number()``; the
    expectations for ``DT2`` below follow that same pattern.
    """
    from pybind11_tests import A_Repeat, B_Repeat, C_Repeat, D_Repeat, A_Tpl, B_Tpl, C_Tpl, D_Tpl

    class AR(A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    # B_* used directly, without a Python subclass.
    for obj in [B_Repeat(), B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    # C_* used directly: same say_something text as B, different numbers.
    for obj in [C_Repeat(), C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(C_Repeat):
        def lucky_number(self):
            return C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    # A subclass that overrides nothing must behave exactly like its parent.
    class CT(C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    # Second-level Python subclasses that build on the parent's lucky_number.
    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    # D_* used directly.
    for obj in [D_Repeat(), D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(D_Tpl):
        def say_something(self, times):
            return "DT says:" + (' quack' * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ('QUACK' * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised.  It overrides two
    # methods and inherits lucky_number from DT.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
207
208
@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
                    reason="NCVirt test broken on ICPC")
def test_move_support():
    """Check NCVirt overrides that return NonCopyable and Movable objects.

    Two Python subclasses of ``NCVirt`` return either a freshly created
    object or one that is also kept as an attribute on ``self``.  The test
    asserts the strings produced by ``print_nc`` / ``print_movable``, expects
    a ``RuntimeError`` for the retained ``NonCopyable``, and then verifies
    the ``ConstructorStats`` counters of both value types.
    """
    from pybind11_tests import NCVirt, NonCopyable, Movable

    class NCVirtExt(NCVirt):
        def get_noncopyable(self, a, b):
            # Build a brand-new object; nothing else keeps a reference to it.
            fresh = NonCopyable(a * a, b * b)
            return fresh

        def get_movable(self, a, b):
            # Keep the object on self and hand the same object back.
            self.movable = Movable(a, b)
            return self.movable

    class NCVirtExt2(NCVirt):
        def get_noncopyable(self, a, b):
            # Retaining a reference here is expected to make print_nc raise.
            self.nc = NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # New object, not stored anywhere on the Python side.
            return Movable(a, b)

    first = NCVirtExt()
    assert first.print_nc(2, 3) == "36"
    assert first.print_movable(4, 5) == "9"
    second = NCVirtExt2()
    assert second.print_movable(7, 7) == "14"
    # The message is deliberately not checked: it differs between debug and
    # non-debug builds.
    with pytest.raises(RuntimeError):
        second.print_nc(9, 9)

    noncopyable_stats = ConstructorStats.get(NonCopyable)
    movable_stats = ConstructorStats.get(Movable)
    # One instance of each type is still alive while the objects above exist.
    assert noncopyable_stats.alive() == 1
    assert movable_stats.alive() == 1
    del first, second
    assert noncopyable_stats.alive() == 0
    assert movable_stats.alive() == 0
    assert noncopyable_stats.values() == ['4', '9', '9', '9']
    assert movable_stats.values() == ['4', '5', '7', '7']
    assert noncopyable_stats.copy_constructions == 0
    assert movable_stats.copy_constructions == 1
    assert noncopyable_stats.move_constructions >= 0
    assert movable_stats.move_constructions >= 0
257