# test_virtual_functions.py revision 11986:c12e4625ab56
import pytest
import pybind11_tests
from pybind11_tests import ConstructorStats


def test_override(capture, msg):
    """Check that Python overrides of ExampleVirt's virtual methods are used from C++.

    The ``runExampleVirt*`` helpers are driven from Python; the captured output
    shows whether the original C++ implementation or the Python override ran.
    """
    from pybind11_tests import (ExampleVirt, runExampleVirt, runExampleVirtVirtual,
                                runExampleVirtBool)

    class ExtendedExampleVirt(ExampleVirt):
        def __init__(self, state):
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print('ExtendedExampleVirt::run(%i), calling parent..' % value)
            return super(ExtendedExampleVirt, self).run(value + 1)

        def run_bool(self):
            print('ExtendedExampleVirt::run_bool()')
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # Plain base instance: the original run() executes and the pure virtual has no body.
    ex12 = ExampleVirt(10)
    with capture:
        assert runExampleVirt(ex12, 20) == 30
    assert capture == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long

    with pytest.raises(RuntimeError) as excinfo:
        runExampleVirtVirtual(ex12)
    assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'

    # One level of Python overrides; __init__ bumps the stored state to 11.
    ex12p = ExtendedExampleVirt(10)
    with capture:
        assert runExampleVirt(ex12p, 20) == 32
    assert capture == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    with capture:
        assert runExampleVirtBool(ex12p) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        runExampleVirtVirtual(ex12p)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Two levels of Python subclassing; each __init__ adds 1, so state is 17.
    ex12p2 = ExtendedExampleVirt2(15)
    with capture:
        assert runExampleVirt(ex12p2, 50) == 68
    assert capture == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long

    cstats = ConstructorStats.get(ExampleVirt)
    assert cstats.alive() == 3
    del ex12, ex12p, ex12p2
    assert cstats.alive() == 0
    assert cstats.values() == ['10', '11', '17']
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 0


def test_inheriting_repeat():
    """Check overrides across the A/B/C/D hierarchies bound in two variants.

    Each hierarchy is exposed once as ``*_Repeat`` and once as ``*_Tpl``
    (presumably per-class vs. templated trampolines -- confirm against the C++
    side).  ``say_everything()`` is not overridden in Python here, and the
    expected strings below show it combining ``say_something(1)`` and
    ``unlucky_number()`` of the most-derived override.
    """
    from pybind11_tests import A_Repeat, B_Repeat, C_Repeat, D_Repeat, A_Tpl, B_Tpl, C_Tpl, D_Tpl

    class AR(A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [B_Repeat(), B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [C_Repeat(), C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(C_Repeat):
        def lucky_number(self):
            return C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CT(C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [D_Repeat(), D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(D_Tpl):
        def say_something(self, times):
            return "DT says:" + (' quack' * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ('QUACK' * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised; check that a second
    # Python subclass level overrides DT and is what C++ dispatches to.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    # lucky_number() is not overridden by DT2, so DT's version is inherited.
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"


@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
                    reason="NCVirt test broken on ICPC")
def test_move_support():
    """Check how values returned from Python overrides of NCVirt reach C++.

    A NonCopyable returned without any other Python reference is accepted.
    Returning one that Python still references raises RuntimeError.  For
    Movable, the ConstructorStats show exactly one copy construction.
    Presumably that copy is the instance NCVirtExt keeps in ``self.movable``;
    confirm against the C++ side.
    """
    from pybind11_tests import NCVirt, NonCopyable, Movable

    class NCVirtExt(NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            nc = NonCopyable(a * a, b * b)
            return nc

        def get_movable(self, a, b):
            # Return a referenced copy
            self.movable = Movable(a, b)
            return self.movable

    class NCVirtExt2(NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return Movable(a, b)

    ncv1 = NCVirtExt()
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(NonCopyable)
    mv_stats = ConstructorStats.get(Movable)
    # One live instance each remains until ncv1/ncv2 are deleted; likely the
    # ones stored as attributes (ncv1.movable, ncv2.nc).
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ['4', '9', '9', '9']
    assert mv_stats.values() == ['4', '5', '7', '7']
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0