1import re
2import pytest
3from pybind11_tests import numpy_dtypes as m
4
5pytestmark = pytest.requires_numpy
6
7with pytest.suppress(ImportError):
8    import numpy as np
9
10
@pytest.fixture(scope='module')
def simple_dtype():
    """Module-scoped fixture: record dtype with explicit field offsets.

    The fields are ``bool_``, ``uint_``, ``float_`` and ``ldbl_``. The long
    double field is placed at offset 16 when the platform's long double
    alignment exceeds four bytes, and at offset 12 otherwise.
    """
    long_double = np.dtype('longdouble')
    ldbl_offset = 16 if long_double.alignment > 4 else 12
    layout = {
        'names': ['bool_', 'uint_', 'float_', 'ldbl_'],
        'formats': ['?', 'u4', 'f4', 'f{}'.format(long_double.itemsize)],
        'offsets': [0, 4, 8, ldbl_offset],
    }
    return np.dtype(layout)
17
18
@pytest.fixture(scope='module')
def packed_dtype():
    """Module-scoped fixture: the same four fields given as a plain field list.

    No offsets are supplied here; NumPy chooses the layout for the list form.
    """
    fields = [
        ('bool_', '?'),
        ('uint_', 'u4'),
        ('float_', 'f4'),
        ('ldbl_', 'g'),
    ]
    return np.dtype(fields)
22
23
def dt_fmt():
    """Return a ``str.format`` template for a dict-style dtype string.

    The template has three positional slots, in order: the size suffix of
    the ``ldbl_`` format (appended to ``f``), the offset of ``ldbl_`` and
    the total item size. The byte-order character is taken from
    ``sys.byteorder`` and baked into the template.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'
    formats = "['?','" + endian + "u4','" + endian + "f4','" + endian + "f{}']"
    return ("{{'names':['bool_','uint_','float_','ldbl_'],"
            " 'formats':" + formats + ","
            " 'offsets':[0,4,8,{}], 'itemsize':{}}}")
30
31
def simple_dtype_fmt():
    """Return the expected ``str()`` of the simple record dtype.

    Fills the ``dt_fmt()`` template with the long double item size, the
    offset of ``ldbl_`` (16 if long double alignment exceeds 4, else 12)
    and the resulting total item size.
    """
    long_double = np.dtype('longdouble')
    ldbl_offset = 16 if long_double.alignment > 4 else 12
    total_size = ldbl_offset + long_double.itemsize
    return dt_fmt().format(long_double.itemsize, ldbl_offset, total_size)
36
37
def packed_dtype_fmt():
    """Return the expected ``str()`` of the packed record dtype.

    The result is a list-style dtype description whose multi-byte fields
    carry the native byte-order character and whose ``ldbl_`` entry uses the
    platform's long double item size.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'
    ldbl_size = np.dtype('longdouble').itemsize
    template = "[('bool_', '?'), ('uint_', '{e}u4'), ('float_', '{e}f4'), ('ldbl_', '{e}f{}')]"
    return template.format(ldbl_size, e=endian)
42
43
def partial_ld_offset():
    """Return the offset used for ``ldbl_`` in the partial dtype string.

    Starts from 12, adds 4 when ``uint64`` alignment exceeds 4, adds a fixed
    8, then adds another 8 when long double alignment exceeds 8.
    """
    offset = 12
    if np.dtype('uint64').alignment > 4:
        offset += 4
    # Fixed 8-byte step; presumably an intervening 8-byte member - confirm
    # against the C++ struct definition.
    offset += 8
    if np.dtype('longdouble').alignment > 8:
        offset += 8
    return offset
47
48
def partial_dtype_fmt():
    """Return the expected ``str()`` of the partial record dtype.

    Uses the ``dt_fmt()`` template with the ``ldbl_`` offset computed by
    ``partial_ld_offset()``; the item size is that offset plus the long
    double item size.
    """
    ldbl_size = np.dtype('longdouble').itemsize
    ldbl_offset = partial_ld_offset()
    total_size = ldbl_offset + ldbl_size
    return dt_fmt().format(ldbl_size, ldbl_offset, total_size)
53
54
def partial_nested_fmt():
    """Return the expected ``str()`` of the nested partial record dtype.

    The outer dtype has a single field ``a`` whose format is the partial
    dtype string. Its offset is 16 when long double alignment exceeds 8 and
    8 otherwise; the item size is twice that offset plus the inner
    ``ldbl_`` offset and the long double item size.
    """
    long_double = np.dtype('longdouble')
    nested_off = 16 if long_double.alignment > 8 else 8
    inner_ld_off = partial_ld_offset()
    nested_size = 2 * nested_off + inner_ld_off + long_double.itemsize
    template = "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}"
    return template.format(partial_dtype_fmt(), nested_off, nested_size)
62
63
def assert_equal(actual, expected_data, expected_dtype):
    """Assert that ``actual`` equals ``expected_data`` viewed as ``expected_dtype``.

    ``expected_data`` is first converted with ``np.array(..., dtype=...)``
    and the comparison is delegated to ``np.testing.assert_equal``.
    """
    expected = np.array(expected_data, dtype=expected_dtype)
    np.testing.assert_equal(actual, expected)
66
67
def test_format_descriptors():
    """Check buffer format descriptor strings reported by the module.

    Requesting the format of the unbound struct must raise a RuntimeError
    with the "type info missing" message; the registered structs must
    produce the exact descriptor list built below, with padding derived
    from the platform's double and long double alignment.
    """
    with pytest.raises(RuntimeError) as excinfo:
        m.get_format_unbound()
    message = str(excinfo.value)
    assert re.match('^NumPy type info missing for .*UnboundStruct.*$', message)

    long_double = np.dtype('longdouble')
    double = np.dtype('double')

    # Simple struct: 4 padding bytes before the long double only when its
    # alignment exceeds 4.
    ldbl_pad = '4x' if long_double.alignment > 4 else ''
    simple_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_pad + long_double.char + ":ldbl_:}"
    packed_fmt = "^T{?:bool_:I:uint_:f:float_:g:ldbl_:}"

    # Partial struct: number of skipped bytes ahead of the long double.
    skipped = (4 * (double.alignment > 4) + double.itemsize +
               8 * (long_double.alignment > 8))
    partial_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + str(skipped) + "xg:ldbl_:}"
    nested_pad = str(max(8, long_double.alignment))

    expected = [
        simple_fmt,
        packed_fmt,
        "^T{" + simple_fmt + ":a:" + packed_fmt + ":b:}",
        partial_fmt,
        "^T{" + nested_pad + "x" + partial_fmt + ":a:" + nested_pad + "x}",
        "^T{3s:a:3s:b:}",
        "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
        '^T{q:e1:B:e2:}',
        '^T{Zf:cflt:Zd:cdbl:}',
    ]
    assert m.print_format_descriptors() == expected
92
93
def test_dtype(simple_dtype):
    """Check dtype strings, dtype constructors and dtype methods of the module.

    Compares ``m.print_dtypes()`` against the expected string for each
    registered struct, then verifies the dtype constructor and method
    helpers, and finally that the trailing-padding dtype equals the dtype
    obtained back from a one-element buffer of that dtype.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'

    simple_str = simple_dtype_fmt()
    packed_str = packed_dtype_fmt()
    # The .format() call only collapses the doubled braces; the byte-order
    # character is spliced in by concatenation.
    array_str = (
        "{{'names':['a','b','c','d'], " +
        "'formats':[('S4', (3,)),('" + endian + "i4', (2,)),('u1', (3,)),('" +
        endian + "f4', (4, 2))], " +
        "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=endian)

    expected_strings = [
        simple_str,
        packed_str,
        "[('a', {}), ('b', {})]".format(simple_str, packed_str),
        partial_dtype_fmt(),
        partial_nested_fmt(),
        "[('a', 'S3'), ('b', 'S3')]",
        array_str,
        "[('e1', '" + endian + "i8'), ('e2', 'u1')]",
        "[('x', 'i1'), ('y', '" + endian + "u8')]",
        "[('cflt', '" + endian + "c8'), ('cdbl', '" + endian + "c16')]",
    ]
    assert m.print_dtypes() == expected_strings

    explicit = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
                         'offsets': [1, 10], 'itemsize': 20})
    pair = np.dtype([('a', 'i4'), ('b', 'f4')])
    expected_ctors = [
        np.dtype('int32'),
        np.dtype('float64'),
        np.dtype('bool'),
        explicit,
        explicit,
        np.dtype('uint32'),
        pair,
    ]
    assert m.test_dtype_ctors() == expected_ctors

    int32 = np.dtype('int32')
    expected_methods = [int32, simple_dtype, False, True,
                        int32.itemsize, simple_dtype.itemsize]
    assert m.test_dtype_methods() == expected_methods

    trailing = m.trailing_padding_dtype()
    round_tripped = m.buffer_to_dtype(np.zeros(1, m.trailing_padding_dtype()))
    assert trailing == round_tripped
123
124
def test_recarray(simple_dtype, packed_dtype):
    """Check record arrays created by the module for each struct flavour.

    Covers the simple and packed structs (empty and three-element arrays,
    plus their printed form), the nested struct, and the partial and
    partial-nested structs, whose dtypes must not expose an unnamed field.
    """
    elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]

    # (factory, printer, expected dtype, expected printed rows)
    cases = [
        (m.create_rec_simple, m.print_rec_simple, simple_dtype,
         ["s:0,0,0,-0", "s:1,1,1.5,-2.5", "s:0,2,3,-5"]),
        (m.create_rec_packed, m.print_rec_packed, packed_dtype,
         ["p:0,0,0,-0", "p:1,1,1.5,-2.5", "p:0,2,3,-5"]),
    ]
    for create, show, dtype, printed in cases:
        empty = create(0)
        assert empty.dtype == dtype
        assert_equal(empty, [], simple_dtype)
        assert_equal(empty, [], packed_dtype)

        filled = create(3)
        assert filled.dtype == dtype
        assert_equal(filled, elements, simple_dtype)
        assert_equal(filled, elements, packed_dtype)

        assert show(filled) == printed

    nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])

    empty_nested = m.create_rec_nested(0)
    assert empty_nested.dtype == nested_dtype
    assert_equal(empty_nested, [], nested_dtype)

    # Each nested record pairs one row with the row that follows it.
    rows = elements + [(True, 3, 4.5, -7.5)]
    nested_rows = list(zip(rows, rows[1:]))

    nested = m.create_rec_nested(3)
    assert nested.dtype == nested_dtype
    assert_equal(nested, nested_rows, nested_dtype)
    assert m.print_rec_nested(nested) == [
        "n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5",
        "n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5",
        "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5",
    ]

    partial = m.create_rec_partial(3)
    partial_dtype = partial.dtype
    assert str(partial_dtype) == partial_dtype_fmt()
    assert '' not in partial_dtype.fields
    assert partial_dtype.itemsize > simple_dtype.itemsize
    assert_equal(partial, elements, simple_dtype)
    assert_equal(partial, elements, packed_dtype)

    partial_nested = m.create_rec_partial_nested(3)
    outer_dtype = partial_nested.dtype
    assert str(outer_dtype) == partial_nested_fmt()
    assert '' not in outer_dtype.fields
    assert '' not in outer_dtype.fields['a'][0].fields
    assert outer_dtype.itemsize > partial_dtype.itemsize
    np.testing.assert_equal(partial_nested['a'], m.create_rec_partial(3))
183
184
def test_array_constructors():
    """Check every array constructor variant exposed by ``m.test_array_ctors``.

    Variants 10-17 and 20-27 must yield the values 1..6 as a 3x2 array;
    variants 30-34 and 40-44 must yield them as a flat array.
    """
    flat = np.arange(1, 7, dtype='int32')
    matrix = flat.reshape((3, 2))

    for i in range(8):
        for base in (10, 20):
            np.testing.assert_array_equal(m.test_array_ctors(base + i), matrix)

    for i in range(5):
        for base in (30, 40):
            np.testing.assert_array_equal(m.test_array_ctors(base + i), flat)
193
194
def test_string_array():
    """Check the string-struct array: dtype, printed form and field contents.

    Also verifies that ``create_string_array(False)`` produces the same
    dtype as ``create_string_array(True)``.
    """
    filled = m.create_string_array(True)
    filled_dtype = filled.dtype
    assert str(filled_dtype) == "[('a', 'S3'), ('b', 'S3')]"
    assert m.print_string_array(filled) == [
        "a='',b=''",
        "a='a',b='a'",
        "a='ab',b='ab'",
        "a='abc',b='abc'",
    ]

    expected_bytes = [b'', b'a', b'ab', b'abc']
    for field in ('a', 'b'):
        assert filled[field].tolist() == expected_bytes

    other = m.create_string_array(False)
    assert filled_dtype == other.dtype
209
210
def test_array_array():
    """Check the struct whose members are fixed-size arrays.

    Verifies the dtype string, the printed representation of three records,
    the contents of fields ``a`` and ``b``, and that an empty array shares
    the same dtype.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'

    arr = m.create_array_array(3)
    dtype_template = (
        "{{'names':['a','b','c','d'], " +
        "'formats':[('S4', (3,)),('" + endian +
        "i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " +
        "'offsets':[0,12,20,24], 'itemsize':56}}")
    assert str(arr.dtype) == dtype_template.format(e=endian)

    expected_printed = [
        "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1},"
        "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}",
        "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001},"
        "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}",
        "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001},"
        "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}",
    ]
    assert m.print_array_array(arr) == expected_printed

    expected_a = [
        [b'ABCD', b'KLMN', b'UVWX'],
        [b'WXYZ', b'GHIJ', b'QRST'],
        [b'STUV', b'CDEF', b'MNOP'],
    ]
    assert arr['a'].tolist() == expected_a
    assert arr['b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]]

    empty = m.create_array_array(0)
    assert empty.dtype == arr.dtype
233
234
def test_enum_array():
    """Check the enum struct array: dtype, printed names and stored integers.

    Also verifies that an empty array has the same dtype.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'

    arr = m.create_enum_array(3)
    enum_dtype = arr.dtype
    expected_dtype = np.dtype([('e1', endian + 'i8'), ('e2', 'u1')])
    assert enum_dtype == expected_dtype

    assert m.print_enum_array(arr) == [
        "e1=A,e2=X",
        "e1=B,e2=Y",
        "e1=A,e2=X",
    ]
    assert arr['e1'].tolist() == [-1, 1, -1]
    assert arr['e2'].tolist() == [1, 2, 1]

    empty = m.create_enum_array(0)
    assert empty.dtype == enum_dtype
250
251
def test_complex_array():
    """Check the complex struct array: dtype, printed form and field values.

    Also verifies that an empty array has the same dtype.
    """
    import sys
    endian = '<' if sys.byteorder == 'little' else '>'

    arr = m.create_complex_array(3)
    complex_dtype = arr.dtype
    expected_dtype = np.dtype([('cflt', endian + 'c8'), ('cdbl', endian + 'c16')])
    assert complex_dtype == expected_dtype

    assert m.print_complex_array(arr) == [
        "c:(0,0.25),(0.5,0.75)",
        "c:(1,1.25),(1.5,1.75)",
        "c:(2,2.25),(2.5,2.75)",
    ]
    assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
    assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]

    empty = m.create_complex_array(0)
    assert empty.dtype == complex_dtype
267
268
def test_signature(doc):
    """Check the signature text the ``doc`` fixture reports for ``create_rec_nested``."""
    expected = "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
    assert doc(m.create_rec_nested) == expected
272
273
def test_scalar_conversion():
    """Check which record scalars each ``f_*`` function accepts.

    Only the first two functions, applied to elements of the array at the
    same index, are expected to succeed and return 0, 10, 20. Every other
    function/array pairing must raise a TypeError mentioning incompatible
    function arguments.
    """
    n = 3
    arrays = [m.create_rec_simple(n), m.create_rec_packed(n),
              m.create_rec_nested(n), m.create_enum_array(n)]
    funcs = [m.f_simple, m.f_packed, m.f_nested]
    expected = [k * 10 for k in range(n)]

    for i, func in enumerate(funcs):
        for j, arr in enumerate(arrays):
            if i != j or i >= 2:
                with pytest.raises(TypeError) as excinfo:
                    func(arr[0])
                assert 'incompatible function arguments' in str(excinfo.value)
                continue
            assert [func(arr[k]) for k in range(n)] == expected
288
289
def test_register_dtype():
    """Check that ``m.register_dtype()`` raises for an already registered dtype."""
    with pytest.raises(RuntimeError) as excinfo:
        m.register_dtype()
    message = str(excinfo.value)
    assert 'dtype is already registered' in message
294
295
@pytest.unsupported_on_pypy
def test_str_leak():
    """Check that ``m.dtype_wrapper`` does not leak a reference to its string.

    The reference count of the format string is sampled before the call and
    again after the resulting dtype is dropped and garbage is collected; the
    two samples must match.
    """
    import sys
    fmt = "f4"
    pytest.gc_collect()
    baseline = sys.getrefcount(fmt)
    wrapped = m.dtype_wrapper(fmt)
    assert wrapped is np.dtype("f4")
    # Drop the dtype before re-sampling so it cannot hold the string alive.
    del wrapped
    pytest.gc_collect()
    assert sys.getrefcount(fmt) == baseline
307
308
def test_compare_buffer_info():
    """Check that every entry returned by ``m.compare_buffer_info()`` is truthy."""
    results = m.compare_buffer_info()
    assert all(results)
311