# test_numpy_dtypes.py revision 11986
"""Tests for pybind11's NumPy structured-dtype support.

Each test imports the functions it exercises from the compiled
``pybind11_tests`` extension and checks the dtypes, format descriptor
strings and record contents those functions produce.
"""
import re

import pytest

# ``pytest.suppress`` and ``pytest.requires_numpy`` are not part of pytest
# itself; presumably the test suite's conftest.py attaches them — confirm.
# Suppressing ImportError lets this module be imported without NumPy; the
# tests are presumably skipped in that case via ``requires_numpy``.
with pytest.suppress(ImportError):
    import numpy as np


@pytest.fixture(scope='module')
def simple_dtype():
    """Return the (x: bool, y: uint32, z: float32) dtype with explicit offsets 0, 4, 8."""
    return np.dtype({'names': ['x', 'y', 'z'],
                     'formats': ['?', 'u4', 'f4'],
                     'offsets': [0, 4, 8]})


@pytest.fixture(scope='module')
def packed_dtype():
    """Return the same three fields as ``simple_dtype`` laid out without padding."""
    return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])


def assert_equal(actual, expected_data, expected_dtype):
    """Assert ``actual`` equals ``expected_data`` built as an array of ``expected_dtype``."""
    np.testing.assert_equal(actual, np.array(expected_data, dtype=expected_dtype))


@pytest.requires_numpy
def test_format_descriptors():
    """Check the format descriptor strings reported for the test structs.

    Asking for the format of ``UnboundStruct`` must raise ``RuntimeError``
    complaining that its NumPy type info is missing.
    """
    from pybind11_tests import get_format_unbound, print_format_descriptors

    with pytest.raises(RuntimeError) as excinfo:
        get_format_unbound()
    assert re.match(r'^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))

    assert print_format_descriptors() == [
        "T{?:x:3xI:y:f:z:}",
        "T{?:x:=I:y:=f:z:}",
        "T{T{?:x:3xI:y:f:z:}:a:T{?:x:=I:y:=f:z:}:b:}",
        "T{?:x:3xI:y:f:z:12x}",
        "T{8xT{?:x:3xI:y:f:z:12x}:a:8x}",
        "T{3s:a:3s:b:}",
        'T{q:e1:B:e2:}',
    ]


@pytest.requires_numpy
def test_dtype(simple_dtype):
    """Check dtype string forms, dtype constructors/methods and dtype recovery from arrays."""
    from pybind11_tests import (print_dtypes, test_dtype_ctors, test_dtype_methods,
                                trailing_padding_dtype, buffer_to_dtype)

    assert print_dtypes() == [
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}",
        "[('x', '?'), ('y', '<u4'), ('z', '<f4')]",
        # The next two literals are implicitly concatenated into one string.
        "[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8],"
        " 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
        # The next two literals are implicitly concatenated into one string.
        "{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'],"
        " 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
        "[('a', 'S3'), ('b', 'S3')]",
        "[('e1', '<i8'), ('e2', 'u1')]",
        "[('x', 'i1'), ('y', '<u8')]",
    ]

    d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
                   'offsets': [1, 10], 'itemsize': 20})
    d2 = np.dtype([('a', 'i4'), ('b', 'f4')])
    assert test_dtype_ctors() == [np.dtype('int32'), np.dtype('float64'),
                                  np.dtype('bool'), d1, d1, np.dtype('uint32'), d2]

    assert test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True,
                                    np.dtype('int32').itemsize, simple_dtype.itemsize]

    # The dtype recovered by buffer_to_dtype from an array of
    # trailing_padding_dtype() must equal trailing_padding_dtype() itself.
    assert trailing_padding_dtype() == buffer_to_dtype(np.zeros(1, trailing_padding_dtype()))


@pytest.requires_numpy
def test_recarray(simple_dtype, packed_dtype):
    """Check creation and printing of simple, packed, nested and partial record arrays."""
    from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
                                print_rec_simple, print_rec_packed, print_rec_nested,
                                create_rec_partial, create_rec_partial_nested)

    elements = [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)]

    for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]:
        arr = func(0)
        assert arr.dtype == dtype
        assert_equal(arr, [], simple_dtype)
        assert_equal(arr, [], packed_dtype)

        arr = func(3)
        assert arr.dtype == dtype
        # The contents must compare equal under both layouts.
        assert_equal(arr, elements, simple_dtype)
        assert_equal(arr, elements, packed_dtype)

        if dtype == simple_dtype:
            assert print_rec_simple(arr) == [
                "s:0,0,0",
                "s:1,1,1.5",
                "s:0,2,3",
            ]
        else:
            assert print_rec_packed(arr) == [
                "p:0,0,0",
                "p:1,1,1.5",
                "p:0,2,3",
            ]

    nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])

    arr = create_rec_nested(0)
    assert arr.dtype == nested_dtype
    assert_equal(arr, [], nested_dtype)

    arr = create_rec_nested(3)
    assert arr.dtype == nested_dtype
    assert_equal(arr, [((False, 0, 0.0), (True, 1, 1.5)),
                       ((True, 1, 1.5), (False, 2, 3.0)),
                       ((False, 2, 3.0), (True, 3, 4.5))], nested_dtype)
    assert print_rec_nested(arr) == [
        "n:a=s:0,0,0;b=p:1,1,1.5",
        "n:a=s:1,1,1.5;b=p:0,2,3",
        "n:a=s:0,2,3;b=p:1,3,4.5",
    ]

    # Partial records: the dtype is wider (itemsize 24) than simple_dtype,
    # yet must expose no unnamed ('') padding field.
    arr = create_rec_partial(3)
    assert str(arr.dtype) == (
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}"
    )
    partial_dtype = arr.dtype
    assert '' not in arr.dtype.fields
    assert partial_dtype.itemsize > simple_dtype.itemsize
    assert_equal(arr, elements, simple_dtype)
    assert_equal(arr, elements, packed_dtype)

    arr = create_rec_partial_nested(3)
    assert str(arr.dtype) == (
        "{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'],"
        " 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}"
    )
    assert '' not in arr.dtype.fields
    assert '' not in arr.dtype.fields['a'][0].fields
    assert arr.dtype.itemsize > partial_dtype.itemsize
    np.testing.assert_equal(arr['a'], create_rec_partial(3))


@pytest.requires_numpy
def test_array_constructors():
    """Every variant selected by test_array_ctors must yield the int32 values 1..6.

    Variants 10-17 and 20-27 are expected as a 3x2 array, variants 30-34 and
    40-44 as a flat array.
    """
    from pybind11_tests import test_array_ctors

    data = np.arange(1, 7, dtype='int32')
    for i in range(8):
        np.testing.assert_array_equal(test_array_ctors(10 + i), data.reshape((3, 2)))
        np.testing.assert_array_equal(test_array_ctors(20 + i), data.reshape((3, 2)))
    for i in range(5):
        np.testing.assert_array_equal(test_array_ctors(30 + i), data)
        np.testing.assert_array_equal(test_array_ctors(40 + i), data)


@pytest.requires_numpy
def test_string_array():
    """Check a record array with two fixed-width (S3) byte-string fields."""
    from pybind11_tests import create_string_array, print_string_array

    arr = create_string_array(True)
    assert str(arr.dtype) == "[('a', 'S3'), ('b', 'S3')]"
    assert print_string_array(arr) == [
        "a='',b=''",
        "a='a',b='a'",
        "a='ab',b='ab'",
        "a='abc',b='abc'",
    ]
    dtype = arr.dtype
    assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc']
    assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc']
    # The dtype must not depend on the boolean argument.
    arr = create_string_array(False)
    assert dtype == arr.dtype


@pytest.requires_numpy
def test_enum_array():
    """Check a record array whose fields are an int64 and a uint8 enum."""
    from pybind11_tests import create_enum_array, print_enum_array

    arr = create_enum_array(3)
    dtype = arr.dtype
    assert dtype == np.dtype([('e1', '<i8'), ('e2', 'u1')])
    assert print_enum_array(arr) == [
        "e1=A,e2=X",
        "e1=B,e2=Y",
        "e1=A,e2=X",
    ]
    assert arr['e1'].tolist() == [-1, 1, -1]
    assert arr['e2'].tolist() == [1, 2, 1]
    # An empty array must carry the same dtype.
    assert create_enum_array(0).dtype == dtype


@pytest.requires_numpy
def test_signature(doc):
    """Check the generated signature mentions the structured element type.

    ``doc`` is presumably a project fixture that renders a function's
    signature/docstring — confirm in conftest.py.
    """
    from pybind11_tests import create_rec_nested

    assert doc(create_rec_nested) == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"


@pytest.requires_numpy
def test_scalar_conversion():
    """Check record scalars convert only to the matching struct type.

    Only f_simple and f_packed accept their own element type. Every other
    pairing must raise TypeError, and that includes f_nested with a nested
    scalar (the ``i < 2`` guard).
    """
    from pybind11_tests import (create_rec_simple, f_simple,
                                create_rec_packed, f_packed,
                                create_rec_nested, f_nested,
                                create_enum_array)

    n = 3
    arrays = [create_rec_simple(n), create_rec_packed(n),
              create_rec_nested(n), create_enum_array(n)]
    funcs = [f_simple, f_packed, f_nested]

    for i, func in enumerate(funcs):
        for j, arr in enumerate(arrays):
            if i == j and i < 2:
                assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
            else:
                with pytest.raises(TypeError) as excinfo:
                    func(arr[0])
                assert 'incompatible function arguments' in str(excinfo.value)


@pytest.requires_numpy
def test_register_dtype():
    """Registering an already-registered dtype must raise RuntimeError."""
    from pybind11_tests import register_dtype

    with pytest.raises(RuntimeError) as excinfo:
        register_dtype()
    assert 'dtype is already registered' in str(excinfo.value)