# test_virtual_functions.py revision 12391:ceeca8b41e4b
import pytest

from pybind11_tests import virtual_functions as m
from pybind11_tests import ConstructorStats


def test_override(capture, msg):
    """Python subclasses of `m.ExampleVirt` override methods invoked via `m.runExampleVirt*`

    Covers a plain `m.ExampleVirt` instance, a one-level Python subclass and a
    two-level Python subclass, then checks the `ConstructorStats` bookkeeping
    once all three instances have been deleted.
    """
    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            super(ExtendedExampleVirt, self).__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print('ExtendedExampleVirt::run(%i), calling parent..' % value)
            return super(ExtendedExampleVirt, self).run(value + 1)

        def run_bool(self):
            print('ExtendedExampleVirt::run_bool()')
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            super(ExtendedExampleVirt2, self).__init__(state + 1)

        def get_string2(self):
            return "override2"

    # Base class used directly: no Python overrides are involved
    ex12 = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(ex12, 20) == 30
    assert capture == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long

    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(ex12)
    assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'

    # One level of Python subclassing: __init__ adds 1 to state, run() adds 1 to value
    ex12p = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(ex12p, 20) == 32
    assert capture == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    with capture:
        assert m.runExampleVirtBool(ex12p) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        m.runExampleVirtVirtual(ex12p)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Two levels of Python subclassing: each __init__ adds 1 to state (15 -> 17)
    ex12p2 = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(ex12p2, 50) == 68
    assert capture == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long

    cstats = ConstructorStats.get(m.ExampleVirt)
    assert cstats.alive() == 3
    del ex12, ex12p, ex12p2
    assert cstats.alive() == 0
    assert cstats.values() == ['10', '11', '17']
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 0


def test_alias_delay_initialization1(capture):
    """`A` only initializes its trampoline class when we inherit from it

    If we just create and use an A instance directly, the trampoline initialization is
    bypassed and we only initialize an A() instead (for performance reasons).
    """
    class B(m.A):
        def __init__(self):
            super(B, self).__init__()

        def f(self):
            print("In python f()")

    # C++ version
    with capture:
        a = m.A()
        m.call_f(a)
        del a
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python version
    with capture:
        b = B()
        m.call_f(b)
        del b
        pytest.gc_collect()
    assert capture == """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """


def test_alias_delay_initialization2(capture):
    """`A2`, unlike the above, is configured to always initialize the alias

    While the extra initialization and extra class layer has small virtual dispatch
    performance penalty, it also allows us to do more things with the trampoline
    class such as defining local variables and performing construction/destruction.
    """
    class B2(m.A2):
        def __init__(self):
            super(B2, self).__init__()

        def f(self):
            print("In python B2.f()")

    # No python subclass version
    with capture:
        a2 = m.A2()
        m.call_f(a2)
        del a2
        pytest.gc_collect()
        a3 = m.A2(1)
        m.call_f(a3)
        del a3
        pytest.gc_collect()
    assert capture == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """

    # Python subclass version
    with capture:
        b2 = B2()
        m.call_f(b2)
        del b2
        pytest.gc_collect()
    assert capture == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """


# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.unsupported_on_pypy
@pytest.mark.skipif(not hasattr(m, "NCVirt"), reason="NCVirt test broken on ICPC")
def test_move_support():
    """Overrides of `m.NCVirt` returning `m.NonCopyable` / `m.Movable` instances

    `NCVirtExt` returns a fresh `NonCopyable` and a stored `Movable`;
    `NCVirtExt2` does the opposite, and returning the stored `NonCopyable`
    is expected to raise `RuntimeError`.
    """
    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            nc = m.NonCopyable(a * a, b * b)
            return nc

        def get_movable(self, a, b):
            # Return a referenced copy
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return m.Movable(a, b)

    ncv1 = NCVirtExt()
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # Only the instances stored on ncv1 (movable) and ncv2 (nc) are still referenced
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ['4', '9', '9', '9']
    assert mv_stats.values() == ['4', '5', '7', '7']
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0


def test_dispatch_issue(msg):
    """#159: virtual function dispatch has problems with similar-named functions"""
    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            with pytest.raises(RuntimeError) as excinfo:
                super(PyClass2, self).dispatch()
            assert msg(excinfo.value) == 'Tried to call pure virtual function "Base::dispatch"'

            # Dispatch on a *different* Python subclass from inside this override
            p = PyClass1()
            return m.dispatch_issue_go(p)

    b = PyClass2()
    assert m.dispatch_issue_go(b) == "Yay.."


def test_override_ref():
    """#392/397: overriding reference-returning functions"""
    o = m.OverrideTest("asdf")

    # Not allowed (see associated .cpp comment)
    # i = o.str_ref()
    # assert o.str_ref() == "asdf"
    assert o.str_value() == "asdf"

    assert o.A_value().value == "hi"
    a = o.A_ref()
    assert a.value == "hi"
    a.value = "bye"
    assert a.value == "bye"


def test_inherited_virtuals():
    """Override resolution across the `*_Repeat` / `*_Tpl` class families

    Each family is checked both through the bound classes directly and through
    Python subclasses (including subclasses of subclasses) that override some
    or all of `say_something`, `unlucky_number` and `lucky_number`.
    """
    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (' quack' * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ('QUACK' * times)

        def unlucky_number(self):
            return -3

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    # DT2 overrides two methods and inherits lucky_number() from DT; the
    # expected say_everything() value follows the "<say_something(1)> <unlucky_number()>"
    # shape seen in the DT and BT assertions in this test.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"