test_numpy_dtypes.py revision 12037:d28054ac6ec9
1import re
2import pytest
3
# Mark every test in this module with the suite's NumPy requirement.
# NOTE(review): ``pytest.requires_numpy`` is not stock pytest -- presumably a
# skip marker attached to the pytest namespace by this suite's conftest.
pytestmark = pytest.requires_numpy

# NumPy is optional: when it is missing the ImportError is ignored so the
# module can still be imported, and ``np`` is simply left undefined.
try:
    import numpy as np
except ImportError:
    pass
8
9
@pytest.fixture(scope='module')
def simple_dtype():
    """Return a 4-field structured dtype with explicit offsets (module-scoped).

    ``bool_`` (``?``), ``uint_`` (``u4``) and ``float_`` (``f4``) sit at byte
    offsets 0, 4 and 8. ``ldbl_`` is a float whose size equals the platform's
    ``longdouble`` itemsize. It is placed at offset 16 when ``longdouble``
    needs more than 4-byte alignment, and at offset 12 otherwise.
    """
    longdbl = np.dtype('longdouble')
    ldbl_offset = 12
    if longdbl.alignment > 4:
        ldbl_offset = 16
    spec = {
        'names': ['bool_', 'uint_', 'float_', 'ldbl_'],
        'formats': ['?', 'u4', 'f4', 'f{}'.format(longdbl.itemsize)],
        'offsets': [0, 4, 8, ldbl_offset],
    }
    return np.dtype(spec)
16
17
@pytest.fixture(scope='module')
def packed_dtype():
    """Return a 4-field structured dtype built from (name, format) pairs (module-scoped).

    No offsets are given, so the fields are laid out back to back without
    padding. ``ldbl_`` uses NumPy's ``'g'`` (long double) type code.
    """
    fields = [
        ('bool_', '?'),
        ('uint_', 'u4'),
        ('float_', 'f4'),
        ('ldbl_', 'g'),
    ]
    return np.dtype(fields)
21
22
def dt_fmt():
    """Return a ``str.format`` template for the text form of the 4-field dtype.

    The multi-byte fields carry the native byte-order prefix (``'<'`` on
    little-endian hosts, ``'>'`` otherwise). The template has three
    positional ``{}`` slots, filled in order:
    1. the long-double itemsize (the ``f{}`` format);
    2. the ``ldbl_`` offset;
    3. the total itemsize.
    The outer braces are doubled so they survive ``.format``.
    """
    from sys import byteorder
    order = '>' if byteorder != 'little' else '<'
    parts = [
        "{{'names':['bool_','uint_','float_','ldbl_'],",
        " 'formats':['?','%su4','%sf4','%sf{}']," % (order, order, order),
        " 'offsets':[0,4,8,{}], 'itemsize':{}}}",
    ]
    return ''.join(parts)
29
30
def simple_dtype_fmt():
    """Return the expected text form of the padded 4-field struct dtype.

    Fills the ``dt_fmt()`` template with the long-double itemsize, the
    ``ldbl_`` offset, and a total itemsize that ends right after ``ldbl_``.
    """
    ldbl = np.dtype('longdouble')
    # ldbl_ moves from offset 12 to 16 when long double needs >4-byte alignment.
    offset = 16 if ldbl.alignment > 4 else 12
    total = offset + ldbl.itemsize
    return dt_fmt().format(ldbl.itemsize, offset, total)
35
36
def packed_dtype_fmt():
    """Return the expected list-style text form of the packed 4-field dtype.

    Multi-byte fields carry the native byte-order prefix. ``ldbl_`` is
    spelled ``f<N>``, where N is the platform's long-double itemsize.
    """
    from sys import byteorder
    prefix = '<' if byteorder == 'little' else '>'
    ldbl_size = np.dtype('longdouble').itemsize
    fields = [
        ('bool_', '?'),
        ('uint_', prefix + 'u4'),
        ('float_', prefix + 'f4'),
        ('ldbl_', '%sf%d' % (prefix, ldbl_size)),
    ]
    return '[' + ', '.join("('%s', '%s')" % pair for pair in fields) + ']'
41
42
def partial_ld_offset():
    """Return the byte offset of ``ldbl_`` in the partially-described struct.

    Starts at 12 and adds 4 when ``uint64`` needs more than 4-byte alignment.
    It then adds a fixed 8, plus another 8 when ``longdouble`` needs more
    than 8-byte alignment.
    NOTE(review): the increments presumably mirror padding in the matching
    C++ struct -- confirm against the C++ side.
    """
    offset = 12
    if np.dtype('uint64').alignment > 4:
        offset += 4
    offset += 8
    if np.dtype('longdouble').alignment > 8:
        offset += 8
    return offset
46
47
def partial_dtype_fmt():
    """Return the expected text form of the partially-described struct dtype.

    Uses the ``dt_fmt()`` template with ``ldbl_`` at ``partial_ld_offset()``
    and a total itemsize ending right after ``ldbl_``.
    """
    ldbl_size = np.dtype('longdouble').itemsize
    offset = partial_ld_offset()
    return dt_fmt().format(ldbl_size, offset, offset + ldbl_size)
52
53
def partial_nested_fmt():
    """Return the expected text form of a struct wrapping the partial struct.

    The single field ``a`` holds the partial dtype and starts at ``pad``
    bytes. ``pad`` is 8, or 16 when long double needs more than 8-byte
    alignment. The itemsize counts ``pad`` twice on top of the inner size
    (``partial_ld_offset() + long-double itemsize``).
    """
    ldbl = np.dtype('longdouble')
    pad = 16 if ldbl.alignment > 8 else 8
    inner_size = partial_ld_offset() + ldbl.itemsize
    total = 2 * pad + inner_size
    return "{'names':['a'], 'formats':[%s], 'offsets':[%d], 'itemsize':%d}" % (
        partial_dtype_fmt(), pad, total)
61
62
def assert_equal(actual, expected_data, expected_dtype):
    """Assert that ``actual`` equals ``expected_data`` built as ``expected_dtype``.

    ``expected_data`` is first turned into an array of ``expected_dtype``, so
    structured records can be written as plain tuples. Raises
    ``AssertionError`` (via ``np.testing.assert_equal``) on mismatch.
    """
    expected = np.array(expected_data, dtype=expected_dtype)
    np.testing.assert_equal(actual, expected)
65
66
def test_format_descriptors():
    """Check the buffer format strings reported for the bound structs."""
    from pybind11_tests import get_format_unbound, print_format_descriptors

    # Asking for the format of an unregistered struct must fail.
    with pytest.raises(RuntimeError) as excinfo:
        get_format_unbound()
    assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))

    ldbl = np.dtype('longdouble')
    dbl = np.dtype('double')

    # 4 bytes of padding precede the long double only when it needs >4-byte alignment.
    ldbl_pad = '4x' if ldbl.alignment > 4 else ''
    simple_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_pad + ldbl.char + ":ldbl_:}"
    packed_fmt = "T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}"

    gap = 4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ldbl.alignment > 8)
    partial_fmt = "T{?:bool_:3xI:uint_:f:float_:%dxg:ldbl_:}" % gap

    pad = str(max(8, ldbl.alignment))
    expected = [
        simple_fmt,
        packed_fmt,
        "T{" + simple_fmt + ":a:" + packed_fmt + ":b:}",
        partial_fmt,
        "T{" + pad + "x" + partial_fmt + ":a:" + pad + "x}",
        "T{3s:a:3s:b:}",
        'T{q:e1:B:e2:}',
    ]
    assert print_format_descriptors() == expected
91
92
def test_dtype(simple_dtype):
    """Check dtype descriptions, dtype constructors and dtype methods exposed from C++."""
    from pybind11_tests import (print_dtypes, test_dtype_ctors, test_dtype_methods,
                                trailing_padding_dtype, buffer_to_dtype)
    from sys import byteorder
    order = '>' if byteorder != 'little' else '<'

    described = print_dtypes()
    assert described == [
        simple_dtype_fmt(),
        packed_dtype_fmt(),
        "[('a', {}), ('b', {})]".format(simple_dtype_fmt(), packed_dtype_fmt()),
        partial_dtype_fmt(),
        partial_nested_fmt(),
        "[('a', 'S3'), ('b', 'S3')]",
        "[('e1', '%si8'), ('e2', 'u1')]" % order,
        "[('x', 'i1'), ('y', '%su8')]" % order,
    ]

    # One dtype with explicit offsets/itemsize, one packed (int32, float32) record.
    offset_dtype = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
                             'offsets': [1, 10], 'itemsize': 20})
    packed_pair = np.dtype([('a', 'i4'), ('b', 'f4')])
    assert test_dtype_ctors() == [np.dtype('int32'), np.dtype('float64'),
                                  np.dtype('bool'), offset_dtype, offset_dtype,
                                  np.dtype('uint32'), packed_pair]

    int32 = np.dtype('int32')
    assert test_dtype_methods() == [int32, simple_dtype, False, True,
                                    int32.itemsize, simple_dtype.itemsize]

    # Round-trip a zero-filled array through a buffer: the dtype must survive.
    round_tripped = buffer_to_dtype(np.zeros(1, trailing_padding_dtype()))
    assert trailing_padding_dtype() == round_tripped
120
121
def test_recarray(simple_dtype, packed_dtype):
    """Check record arrays built in C++ from simple, packed, nested and partial structs.

    Each array's dtype is compared with the expected layout and its contents
    with the expected records. Where a printer exists, the C++-side text
    output is compared as well.
    """
    from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
                                print_rec_simple, print_rec_packed, print_rec_nested,
                                create_rec_partial, create_rec_partial_nested)

    # Four consecutive records and how the C++ printers render them. The
    # simple/packed arrays hold the first three; nested rows pair neighbours.
    records = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5),
               (False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5)]
    texts = ["0,0,0,-0", "1,1,1.5,-2.5", "0,2,3,-5", "1,3,4.5,-7.5"]
    elements = records[:3]
    layouts = (simple_dtype, packed_dtype)

    for make, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]:
        empty = make(0)
        assert empty.dtype == dtype
        for layout in layouts:
            assert_equal(empty, [], layout)

        filled = make(3)
        assert filled.dtype == dtype
        # The values must match whichever of the two layouts they are cast to.
        for layout in layouts:
            assert_equal(filled, elements, layout)

        if dtype == simple_dtype:
            printer, tag = print_rec_simple, "s:"
        else:
            printer, tag = print_rec_packed, "p:"
        assert printer(filled) == [tag + text for text in texts[:3]]

    nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])

    empty = create_rec_nested(0)
    assert empty.dtype == nested_dtype
    assert_equal(empty, [], nested_dtype)

    nested = create_rec_nested(3)
    assert nested.dtype == nested_dtype
    assert_equal(nested, list(zip(records, records[1:])), nested_dtype)
    assert print_rec_nested(nested) == [
        "n:a=s:" + first + ";b=p:" + second for first, second in zip(texts, texts[1:])
    ]

    partial = create_rec_partial(3)
    assert str(partial.dtype) == partial_dtype_fmt()
    partial_dtype = partial.dtype
    # No unnamed filler fields may appear in the dtype.
    assert '' not in partial_dtype.fields
    assert partial_dtype.itemsize > simple_dtype.itemsize
    assert_equal(partial, elements, simple_dtype)
    assert_equal(partial, elements, packed_dtype)

    partial_nested = create_rec_partial_nested(3)
    assert str(partial_nested.dtype) == partial_nested_fmt()
    assert '' not in partial_nested.dtype.fields
    assert '' not in partial_nested.dtype.fields['a'][0].fields
    assert partial_nested.dtype.itemsize > partial_dtype.itemsize
    np.testing.assert_equal(partial_nested['a'], create_rec_partial(3))
184
185
def test_array_constructors():
    """Every numbered array-constructor case must produce the int32 values 1..6."""
    from pybind11_tests import test_array_ctors

    data = np.arange(1, 7, dtype='int32')
    matrix = data.reshape((3, 2))
    # Cases 10-17 and 20-27 expect a 3x2 array; cases 30-34 and 40-44 a flat one.
    cases = [(base + i, matrix) for i in range(8) for base in (10, 20)]
    cases += [(base + i, data) for i in range(5) for base in (30, 40)]
    for case, expected in cases:
        np.testing.assert_array_equal(test_array_ctors(case), expected)
196
197
def test_string_array():
    """Check a C++-built record array with two fixed-size ``S3`` string fields."""
    from pybind11_tests import create_string_array, print_string_array

    arr = create_string_array(True)
    assert str(arr.dtype) == "[('a', 'S3'), ('b', 'S3')]"
    values = ['', 'a', 'ab', 'abc']
    assert print_string_array(arr) == ["a='{0}',b='{0}'".format(v) for v in values]
    dtype = arr.dtype
    as_bytes = [v.encode() for v in values]
    for field in ('a', 'b'):
        assert arr[field].tolist() == as_bytes
    # Calling with False must yield an array of the same dtype.
    assert dtype == create_string_array(False).dtype
214
215
def test_enum_array():
    """Check a C++-built record array whose two fields hold enum values."""
    from pybind11_tests import create_enum_array, print_enum_array
    from sys import byteorder
    order = '>' if byteorder != 'little' else '<'

    arr = create_enum_array(3)
    dtype = arr.dtype
    assert dtype == np.dtype([('e1', order + 'i8'), ('e2', 'u1')])
    assert print_enum_array(arr) == ["e1=A,e2=X", "e1=B,e2=Y", "e1=A,e2=X"]
    expected_fields = {'e1': [-1, 1, -1], 'e2': [1, 2, 1]}
    for name in ('e1', 'e2'):
        assert arr[name].tolist() == expected_fields[name]
    # An empty array must carry the same dtype.
    assert create_enum_array(0).dtype == dtype
232
233
def test_signature(doc):
    """The docstring signature must name the element struct of the returned array.

    NOTE(review): ``doc`` is a fixture from outside this file, presumably
    returning the function's docstring/signature text.
    """
    from pybind11_tests import create_rec_nested

    expected = "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
    assert doc(create_rec_nested) == expected
238
239
def test_scalar_conversion():
    """Record scalars convert only for the matching simple/packed function.

    Every other array/function pairing, including nested records passed to
    ``f_nested``, must be rejected with ``TypeError``.
    """
    from pybind11_tests import (create_rec_simple, f_simple,
                                create_rec_packed, f_packed,
                                create_rec_nested, f_nested,
                                create_enum_array)

    n = 3
    arrays = [create_rec_simple(n), create_rec_packed(n),
              create_rec_nested(n), create_enum_array(n)]
    funcs = [f_simple, f_packed, f_nested]

    for i, func in enumerate(funcs):
        for j, arr in enumerate(arrays):
            if i != j or i >= 2:
                with pytest.raises(TypeError) as excinfo:
                    func(arr[0])
                assert 'incompatible function arguments' in str(excinfo.value)
                continue
            results = [func(arr[k]) for k in range(n)]
            assert results == list(range(0, 10 * n, 10))
259
260
def test_register_dtype():
    """``register_dtype()`` must raise RuntimeError reporting a duplicate registration.

    NOTE(review): presumably the dtype it registers was already registered
    when the extension module loaded -- confirm on the C++ side.
    """
    from pybind11_tests import register_dtype

    with pytest.raises(RuntimeError) as excinfo:
        register_dtype()
    message = str(excinfo.value)
    assert 'dtype is already registered' in message
267
268
def test_compare_buffer_info():
    """Every check returned by the C++ ``compare_buffer_info`` helper must be truthy."""
    # No per-test ``@pytest.requires_numpy`` here: the module-level
    # ``pytestmark = pytest.requires_numpy`` already applies it to every test
    # in this module, so repeating it was redundant.
    from pybind11_tests import compare_buffer_info
    assert all(compare_buffer_info())
273