# test_numpy_vectorize.py (revision 11986)
import pytest

# NOTE(review): `pytest.suppress` and `pytest.requires_numpy` are not stock pytest
# attributes; presumably the test suite's conftest.py attaches them — confirm.
with pytest.suppress(ImportError):
    import numpy as np


@pytest.requires_numpy
def test_vectorize(capture):
    """Vectorized bindings apply ``my_func`` element-wise with NumPy broadcasting.

    Each element visited prints one ``my_func(...)`` line, so the captured output
    pins both the computed values and the order in which elements are visited.

    :param capture: fixture used as a context manager to collect printed output
        (presumably the C++-side stdout of ``my_func`` — confirm in conftest.py).
    """
    from pybind11_tests import vectorized_func, vectorized_func2, vectorized_func3

    # Complex input through a 0-d array: 3+7j is expected back as 6+14j.
    # allclose (not isclose) because the result is compared as an array.
    assert np.allclose(vectorized_func3(np.array(3 + 7j)), [6 + 14j])

    for f in [vectorized_func, vectorized_func2]:
        # Plain Python scalars: a single call.
        with capture:
            assert np.isclose(f(1, 2, 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"

        # 0-d arrays mixed with a scalar still produce a single call.
        with capture:
            assert np.isclose(f(np.array(1), np.array(2), 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"

        # Two length-2 arrays plus a scalar that is reused for every element.
        with capture:
            assert np.allclose(f(np.array([1, 3]), np.array([2, 4]), 3), [6, 36])
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
        """

        # Two 2x3 arrays of the same shape; elements are visited row by row.
        with capture:
            a = np.array([[1, 3, 5], [7, 9, 11]])
            b = np.array([[2, 4, 6], [8, 10, 12]])
            c = 3
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
            my_func(x:int=5, y:float=6, z:float=3)
            my_func(x:int=7, y:float=8, z:float=3)
            my_func(x:int=9, y:float=10, z:float=3)
            my_func(x:int=11, y:float=12, z:float=3)
        """

        # A length-3 row `b` is broadcast across both rows of the 2x3 `a`.
        with capture:
            a = np.array([[1, 2, 3], [4, 5, 6]])
            b = np.array([2, 3, 4])
            c = 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=3, z:float=2)
            my_func(x:int=3, y:float=4, z:float=2)
            my_func(x:int=4, y:float=2, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=4, z:float=2)
        """

        # A 2x1 column `b` is broadcast across the columns of the 2x3 `a`.
        with capture:
            a = np.array([[1, 2, 3], [4, 5, 6]])
            b = np.array([[2], [3]])
            c = 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """


@pytest.requires_numpy
def test_type_selection():
    """``selective_func`` dispatches on the dtype of its array argument."""
    from pybind11_tests import selective_func

    assert selective_func(np.array([1], dtype=np.int32)) == "Int branch taken."
    assert selective_func(np.array([1.0], dtype=np.float32)) == "Float branch taken."
    assert selective_func(np.array([1.0j], dtype=np.complex64)) == "Complex float branch taken."


@pytest.requires_numpy
def test_docs(doc):
    """The generated signature of ``vectorized_func`` advertises ndarray arguments.

    :param doc: fixture returning the docstring/signature text of a bound object.
    """
    from pybind11_tests import vectorized_func

    assert doc(vectorized_func) == """
        vectorized_func(arg0: numpy.ndarray[int], arg1: numpy.ndarray[float], arg2: numpy.ndarray[float]) -> object
    """  # noqa: E501 line too long