test_numpy_dtypes.py revision 11986
111986Sandreas.sandberg@arm.comimport re
211986Sandreas.sandberg@arm.comimport pytest
311986Sandreas.sandberg@arm.com
411986Sandreas.sandberg@arm.comwith pytest.suppress(ImportError):
511986Sandreas.sandberg@arm.com    import numpy as np
611986Sandreas.sandberg@arm.com
711986Sandreas.sandberg@arm.com
@pytest.fixture(scope='module')
def simple_dtype():
    """Struct dtype with explicit field offsets.

    Fields: bool ``x`` at offset 0, uint32 ``y`` at offset 4 and float32 ``z``
    at offset 8 (so three unused bytes follow ``x``).
    """
    spec = {
        'names': ['x', 'y', 'z'],
        'formats': ['?', 'u4', 'f4'],
        'offsets': [0, 4, 8],
    }
    return np.dtype(spec)
1311986Sandreas.sandberg@arm.com
1411986Sandreas.sandberg@arm.com
@pytest.fixture(scope='module')
def packed_dtype():
    """Struct dtype with fields bool ``x``, uint32 ``y``, float32 ``z`` laid out back to back."""
    fields = [('x', '?'), ('y', 'u4'), ('z', 'f4')]
    return np.dtype(fields)
1811986Sandreas.sandberg@arm.com
1911986Sandreas.sandberg@arm.com
def assert_equal(actual, expected_data, expected_dtype):
    """Assert ``actual`` equals ``expected_data`` after converting it to ``expected_dtype``.

    Raises AssertionError (via ``np.testing.assert_equal``) on mismatch.
    """
    expected = np.array(expected_data, dtype=expected_dtype)
    np.testing.assert_equal(actual, expected)
2211986Sandreas.sandberg@arm.com
2311986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_format_descriptors():
    """Check the buffer format strings reported for the test structs.

    Asking for the format of the unbound struct must raise RuntimeError.
    """
    from pybind11_tests import get_format_unbound, print_format_descriptors

    with pytest.raises(RuntimeError) as excinfo:
        get_format_unbound()
    message = str(excinfo.value)
    assert re.match('^NumPy type info missing for .*UnboundStruct.*$', message)

    expected = [
        "T{?:x:3xI:y:f:z:}",
        "T{?:x:=I:y:=f:z:}",
        "T{T{?:x:3xI:y:f:z:}:a:T{?:x:=I:y:=f:z:}:b:}",
        "T{?:x:3xI:y:f:z:12x}",
        "T{8xT{?:x:3xI:y:f:z:12x}:a:8x}",
        "T{3s:a:3s:b:}",
        "T{q:e1:B:e2:}",
    ]
    assert print_format_descriptors() == expected
4111986Sandreas.sandberg@arm.com
4211986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_dtype(simple_dtype):
    """Check dtype string forms, dtype constructors/methods and a buffer dtype round-trip."""
    from pybind11_tests import (print_dtypes, test_dtype_ctors, test_dtype_methods,
                                trailing_padding_dtype, buffer_to_dtype)

    expected_strs = [
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}",
        "[('x', '?'), ('y', '<u4'), ('z', '<f4')]",
        "[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8],"
        " 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
        "{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'],"
        " 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
        "[('a', 'S3'), ('b', 'S3')]",
        "[('e1', '<i8'), ('e2', 'u1')]",
        "[('x', 'i1'), ('y', '<u8')]",
    ]
    assert print_dtypes() == expected_strs

    # Explicit-offset dtype; it appears twice in the expected constructor results.
    offset_dtype = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
                             'offsets': [1, 10], 'itemsize': 20})
    packed_ab = np.dtype([('a', 'i4'), ('b', 'f4')])
    expected_ctors = [np.dtype('int32'), np.dtype('float64'), np.dtype('bool'),
                      offset_dtype, offset_dtype, np.dtype('uint32'), packed_ab]
    assert test_dtype_ctors() == expected_ctors

    int32_dtype = np.dtype('int32')
    expected_methods = [int32_dtype, simple_dtype, False, True,
                        int32_dtype.itemsize, simple_dtype.itemsize]
    assert test_dtype_methods() == expected_methods

    # A zero-filled buffer of the padded dtype must convert back to an equal dtype.
    padded = trailing_padding_dtype()
    roundtrip = buffer_to_dtype(np.zeros(1, trailing_padding_dtype()))
    assert padded == roundtrip
7111986Sandreas.sandberg@arm.com
7211986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_recarray(simple_dtype, packed_dtype):
    """Exercise simple, packed, nested and partial record arrays created from C++."""
    from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
                                print_rec_simple, print_rec_packed, print_rec_nested,
                                create_rec_partial, create_rec_partial_nested)

    elements = [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)]
    layouts = (simple_dtype, packed_dtype)

    cases = [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]
    for create, expected_dtype in cases:
        empty = create(0)
        assert empty.dtype == expected_dtype
        # The values must match when expressed in either layout.
        for layout in layouts:
            assert_equal(empty, [], layout)

        filled = create(3)
        assert filled.dtype == expected_dtype
        for layout in layouts:
            assert_equal(filled, elements, layout)

        if expected_dtype == simple_dtype:
            assert print_rec_simple(filled) == ["s:0,0,0", "s:1,1,1.5", "s:0,2,3"]
        else:
            assert print_rec_packed(filled) == ["p:0,0,0", "p:1,1,1.5", "p:0,2,3"]

    nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])

    empty_nested = create_rec_nested(0)
    assert empty_nested.dtype == nested_dtype
    assert_equal(empty_nested, [], nested_dtype)

    nested = create_rec_nested(3)
    assert nested.dtype == nested_dtype
    expected_nested = [((False, 0, 0.0), (True, 1, 1.5)),
                       ((True, 1, 1.5), (False, 2, 3.0)),
                       ((False, 2, 3.0), (True, 3, 4.5))]
    assert_equal(nested, expected_nested, nested_dtype)
    assert print_rec_nested(nested) == [
        "n:a=s:0,0,0;b=p:1,1,1.5",
        "n:a=s:1,1,1.5;b=p:0,2,3",
        "n:a=s:0,2,3;b=p:1,3,4.5",
    ]

    partial = create_rec_partial(3)
    assert str(partial.dtype) == (
        "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}")
    partial_dtype = partial.dtype
    # No field with an empty name may be exposed.
    assert '' not in partial.dtype.fields
    assert partial_dtype.itemsize > simple_dtype.itemsize
    for layout in layouts:
        assert_equal(partial, elements, layout)

    partial_nested = create_rec_partial_nested(3)
    assert str(partial_nested.dtype) == (
        "{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'],"
        " 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}")
    outer_fields = partial_nested.dtype.fields
    assert '' not in outer_fields
    assert '' not in outer_fields['a'][0].fields
    assert partial_nested.dtype.itemsize > partial_dtype.itemsize
    np.testing.assert_equal(partial_nested['a'], create_rec_partial(3))
13911986Sandreas.sandberg@arm.com
14011986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_array_constructors():
    """Check every constructor variant selected by ``test_array_ctors``.

    Codes 10-17 and 20-27 yield a 3x2 array; codes 30-34 and 40-44 yield a flat one.
    """
    from pybind11_tests import test_array_ctors

    data = np.arange(1, 7, dtype='int32')
    matrix = data.reshape((3, 2))
    for offset in range(8):
        for base in (10, 20):
            np.testing.assert_array_equal(test_array_ctors(base + offset), matrix)
    for offset in range(5):
        for base in (30, 40):
            np.testing.assert_array_equal(test_array_ctors(base + offset), data)
15211986Sandreas.sandberg@arm.com
15311986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_string_array():
    """Check a record array with two fixed-width (S3) string fields."""
    from pybind11_tests import create_string_array, print_string_array

    arr = create_string_array(True)
    assert str(arr.dtype) == "[('a', 'S3'), ('b', 'S3')]"
    assert print_string_array(arr) == [
        "a='',b=''",
        "a='a',b='a'",
        "a='ab',b='ab'",
        "a='abc',b='abc'",
    ]
    dtype = arr.dtype
    expected_values = [b'', b'a', b'ab', b'abc']
    for field in ('a', 'b'):
        assert arr[field].tolist() == expected_values
    # The dtype must not depend on the flag passed to the factory.
    assert dtype == create_string_array(False).dtype
17111986Sandreas.sandberg@arm.com
17211986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_enum_array():
    """Check a record array whose fields hold enum values (int64 and uint8 underlying)."""
    from pybind11_tests import create_enum_array, print_enum_array

    arr = create_enum_array(3)
    dtype = arr.dtype
    assert dtype == np.dtype([('e1', '<i8'), ('e2', 'u1')])
    assert print_enum_array(arr) == ["e1=A,e2=X", "e1=B,e2=Y", "e1=A,e2=X"]
    for field, values in (('e1', [-1, 1, -1]), ('e2', [1, 2, 1])):
        assert arr[field].tolist() == values
    # An empty array must report the same dtype.
    assert create_enum_array(0).dtype == dtype
18811986Sandreas.sandberg@arm.com
18911986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_signature(doc):
    """The generated signature names the struct type of the returned array."""
    from pybind11_tests import create_rec_nested

    expected = "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
    assert doc(create_rec_nested) == expected
19511986Sandreas.sandberg@arm.com
19611986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_scalar_conversion():
    """Check which array scalars each ``f_*`` function accepts.

    Only simple->f_simple and packed->f_packed succeed. Every other pairing,
    including nested->f_nested, must raise TypeError.
    """
    from pybind11_tests import (create_rec_simple, f_simple,
                                create_rec_packed, f_packed,
                                create_rec_nested, f_nested,
                                create_enum_array)

    n = 3
    arrays = [create_rec_simple(n), create_rec_packed(n),
              create_rec_nested(n), create_enum_array(n)]
    funcs = [f_simple, f_packed, f_nested]
    expected = [k * 10 for k in range(n)]

    for i, func in enumerate(funcs):
        for j, arr in enumerate(arrays):
            if i != j or i >= 2:
                with pytest.raises(TypeError) as excinfo:
                    func(arr[0])
                assert 'incompatible function arguments' in str(excinfo.value)
                continue
            assert [func(arr[k]) for k in range(n)] == expected
21711986Sandreas.sandberg@arm.com
21811986Sandreas.sandberg@arm.com
@pytest.requires_numpy
def test_register_dtype():
    """``register_dtype`` must fail with RuntimeError because the dtype is already registered."""
    from pybind11_tests import register_dtype

    with pytest.raises(RuntimeError) as excinfo:
        register_dtype()
    error_text = str(excinfo.value)
    assert 'dtype is already registered' in error_text
226