"""Tests for NumPy structured-dtype support exposed by ``pybind11_tests.numpy_dtypes``.

Most expected values are strings (dtype ``str()`` output or ``^T{...}`` format
descriptors) computed at run time from host properties: the byte order
(``sys.byteorder``) and the size/alignment of ``longdouble``, ``uint64`` and
``double`` as reported by NumPy.
"""
import re
import pytest
from pybind11_tests import numpy_dtypes as m

# Apply the ``requires_numpy`` marker to every test in this module.
# NOTE(review): ``pytest.requires_numpy`` is not a stock pytest attribute;
# presumably it is attached by the project's conftest — confirm.
pytestmark = pytest.requires_numpy

# Import NumPy if it is available; an ImportError is swallowed so the module can
# still be imported. NOTE(review): ``pytest.suppress`` is presumably a
# conftest-provided alias (e.g. for contextlib.suppress) — confirm.
with pytest.suppress(ImportError):
    import numpy as np


@pytest.fixture(scope='module')
def simple_dtype():
    """Expected dtype of arrays returned by ``m.create_rec_simple``.

    Fields are bool_, uint_ (u4), float_ (f4) and ldbl_ (a float whose size is
    the host ``longdouble`` itemsize). The offsets are explicit, so the layout is
    padded: ldbl_ sits at offset 16 when ``longdouble`` alignment exceeds 4,
    otherwise at 12.
    """
    ld = np.dtype('longdouble')
    return np.dtype({'names': ['bool_', 'uint_', 'float_', 'ldbl_'],
                     'formats': ['?', 'u4', 'f4', 'f{}'.format(ld.itemsize)],
                     'offsets': [0, 4, 8, (16 if ld.alignment > 4 else 12)]})


@pytest.fixture(scope='module')
def packed_dtype():
    """Expected dtype of arrays returned by ``m.create_rec_packed``.

    Same four fields as ``simple_dtype``, given as a list of (name, format)
    tuples with no explicit offsets; ``'g'`` is NumPy's long double code.
    """
    return np.dtype([('bool_', '?'), ('uint_', 'u4'), ('float_', 'f4'), ('ldbl_', 'g')])


def dt_fmt():
    """Return a ``str.format`` template for the str() of a 4-field offset dtype.

    The host byte-order prefix ('<' or '>') is already concatenated in. The
    remaining ``{}`` placeholders are filled by callers with, in order: the
    ``longdouble`` itemsize, the ldbl_ offset, and the total itemsize.
    Doubled braces (``{{`` / ``}}``) become literal braces after formatting.
    """
    from sys import byteorder
    e = '<' if byteorder == 'little' else '>'
    return ("{{'names':['bool_','uint_','float_','ldbl_'],"
            " 'formats':['?','" + e + "u4','" + e + "f4','" + e + "f{}'],"
            " 'offsets':[0,4,8,{}], 'itemsize':{}}}")


def simple_dtype_fmt():
    """Expected str() of the simple dtype.

    ldbl_ is at offset 12, or 16 when ``longdouble`` alignment exceeds 4. The
    itemsize ends immediately after the long double.
    """
    ld = np.dtype('longdouble')
    # bool(ld.alignment > 4) is 0 or 1, so this is 12 or 16 — same rule as the
    # simple_dtype fixture.
    simple_ld_off = 12 + 4 * (ld.alignment > 4)
    return dt_fmt().format(ld.itemsize, simple_ld_off, simple_ld_off + ld.itemsize)


def packed_dtype_fmt():
    """Expected str() of the packed dtype, in NumPy's list-of-tuples form.

    ``{e}`` is the host byte-order prefix. The positional ``{}`` is the
    ``longdouble`` itemsize.
    """
    from sys import byteorder
    return "[('bool_', '?'), ('uint_', '{e}u4'), ('float_', '{e}f4'), ('ldbl_', '{e}f{}')]".format(
        np.dtype('longdouble').itemsize, e='<' if byteorder == 'little' else '>')


def partial_ld_offset():
    """Return the expected offset of ldbl_ in the 'partial' dtype.

    The offset is built up as follows:
    - start from 12;
    - add 4 when ``uint64`` alignment exceeds 4;
    - add 8;
    - add another 8 when ``longdouble`` alignment exceeds 8.

    NOTE(review): these bytes presumably correspond to C++ struct members that
    are not exposed as dtype fields — confirm against the C++ test source.
    """
    return 12 + 4 * (np.dtype('uint64').alignment > 4) + 8 + 8 * (
        np.dtype('longdouble').alignment > 8)


def partial_dtype_fmt():
    """Expected str() of the dtype of arrays returned by ``m.create_rec_partial``."""
    ld = np.dtype('longdouble')
    partial_ld_off = partial_ld_offset()
    return dt_fmt().format(ld.itemsize, partial_ld_off, partial_ld_off + ld.itemsize)


def partial_nested_fmt():
    """Expected str() of the dtype of arrays returned by ``m.create_rec_partial_nested``.

    The dtype has a single field 'a' holding the partial dtype at offset 8 (16
    when ``longdouble`` alignment exceeds 8). The total itemsize is that offset
    twice plus the end of the partial dtype's ldbl_ field.
    """
    ld = np.dtype('longdouble')
    partial_nested_off = 8 + 8 * (ld.alignment > 8)
    partial_ld_off = partial_ld_offset()
    partial_nested_size = partial_nested_off * 2 + partial_ld_off + ld.itemsize
    return "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}".format(
        partial_dtype_fmt(), partial_nested_off, partial_nested_size)


def assert_equal(actual, expected_data, expected_dtype):
    """Assert ``actual`` equals ``expected_data`` converted to an array of ``expected_dtype``."""
    np.testing.assert_equal(actual, np.array(expected_data, dtype=expected_dtype))


def test_format_descriptors():
    """Check the ``^T{...}`` format descriptors reported by ``m.print_format_descriptors()``.

    Also checks that ``m.get_format_unbound()`` raises RuntimeError naming
    UnboundStruct.
    """
    with pytest.raises(RuntimeError) as excinfo:
        m.get_format_unbound()
    assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))

    ld = np.dtype('longdouble')
    # 4 pad bytes ('4x') precede the long double when its alignment exceeds 4,
    # matching the 16-vs-12 offset used for simple_dtype.
    ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
    ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
    dbl = np.dtype('double')
    # Pad-byte count before ldbl_ in the partial struct. It is derived from
    # double size/alignment and long double alignment.
    partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" +
                   str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
                   "xg:ldbl_:}")
    # Leading/trailing padding around the nested partial struct.
    nested_extra = str(max(8, ld.alignment))
    assert m.print_format_descriptors() == [
        ss_fmt,
        "^T{?:bool_:I:uint_:f:float_:g:ldbl_:}",
        "^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}",
        partial_fmt,
        "^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
        "^T{3s:a:3s:b:}",
        "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
        '^T{q:e1:B:e2:}',
        '^T{Zf:cflt:Zd:cdbl:}'
    ]


def test_dtype(simple_dtype):
    """Check dtypes reported and constructed by the extension module.

    Covers str() of every dtype from ``m.print_dtypes()``, the dtypes produced
    by ``m.test_dtype_ctors()`` and ``m.test_dtype_methods()``, and the
    buffer round-trip of ``m.trailing_padding_dtype()``.
    """
    from sys import byteorder
    e = '<' if byteorder == 'little' else '>'

    assert m.print_dtypes() == [
        simple_dtype_fmt(),
        packed_dtype_fmt(),
        "[('a', {}), ('b', {})]".format(simple_dtype_fmt(), packed_dtype_fmt()),
        partial_dtype_fmt(),
        partial_nested_fmt(),
        "[('a', 'S3'), ('b', 'S3')]",
        # Byte-order prefixes are concatenated directly; .format(e=e) only
        # collapses the doubled braces (no '{e}' placeholder remains).
        ("{{'names':['a','b','c','d'], " +
         "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('" + e + "f4', (4, 2))], " +
         "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e),
        "[('e1', '" + e + "i8'), ('e2', 'u1')]",
        "[('x', 'i1'), ('y', '" + e + "u8')]",
        "[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]"
    ]

    # Offset dtype with explicit offsets and padding up to itemsize 20.
    d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
                   'offsets': [1, 10], 'itemsize': 20})
    d2 = np.dtype([('a', 'i4'), ('b', 'f4')])
    assert m.test_dtype_ctors() == [np.dtype('int32'), np.dtype('float64'),
                                    np.dtype('bool'), d1, d1, np.dtype('uint32'), d2]

    assert m.test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True,
                                      np.dtype('int32').itemsize, simple_dtype.itemsize]

    # The dtype must survive a round-trip: build a buffer of it, then ask the
    # extension for that buffer's dtype.
    assert m.trailing_padding_dtype() == m.buffer_to_dtype(np.zeros(1, m.trailing_padding_dtype()))


def test_recarray(simple_dtype, packed_dtype):
    """Check record arrays created by the extension for the simple, packed, nested,
    partial and partial-nested structs: dtypes, contents and printed values.
    """
    elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]

    for func, dtype in [(m.create_rec_simple, simple_dtype), (m.create_rec_packed, packed_dtype)]:
        # Empty array: dtype must still match.
        arr = func(0)
        assert arr.dtype == dtype
        assert_equal(arr, [], simple_dtype)
        assert_equal(arr, [], packed_dtype)

        # Three records: values compare equal whether the expected data is
        # built with the padded (simple) or packed layout.
        arr = func(3)
        assert arr.dtype == dtype
        assert_equal(arr, elements, simple_dtype)
        assert_equal(arr, elements, packed_dtype)

        if dtype == simple_dtype:
            assert m.print_rec_simple(arr) == [
                "s:0,0,0,-0",
                "s:1,1,1.5,-2.5",
                "s:0,2,3,-5"
            ]
        else:
            assert m.print_rec_packed(arr) == [
                "p:0,0,0,-0",
                "p:1,1,1.5,-2.5",
                "p:0,2,3,-5"
            ]

    nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])

    arr = m.create_rec_nested(0)
    assert arr.dtype == nested_dtype
    assert_equal(arr, [], nested_dtype)

    arr = m.create_rec_nested(3)
    assert arr.dtype == nested_dtype
    assert_equal(arr, [((False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5)),
                       ((True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)),
                       ((False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5))], nested_dtype)
    assert m.print_rec_nested(arr) == [
        "n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5",
        "n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5",
        "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5"
    ]

    # 'Partial' struct: a larger itemsize than simple_dtype, but no field with
    # an empty name is exposed for the extra bytes.
    arr = m.create_rec_partial(3)
    assert str(arr.dtype) == partial_dtype_fmt()
    partial_dtype = arr.dtype
    assert '' not in arr.dtype.fields
    assert partial_dtype.itemsize > simple_dtype.itemsize
    assert_equal(arr, elements, simple_dtype)
    assert_equal(arr, elements, packed_dtype)

    # Nested partial struct: same checks one level down, and field 'a' must
    # equal a freshly created partial array.
    arr = m.create_rec_partial_nested(3)
    assert str(arr.dtype) == partial_nested_fmt()
    assert '' not in arr.dtype.fields
    assert '' not in arr.dtype.fields['a'][0].fields
    assert arr.dtype.itemsize > partial_dtype.itemsize
    np.testing.assert_equal(arr['a'], m.create_rec_partial(3))


def test_array_constructors():
    """Check ``m.test_array_ctors(k)`` against the int32 values 1..6.

    For k in 10..17 and 20..27 the result is compared with the values shaped
    (3, 2). For k in 30..34 and 40..44 it is compared with the flat array.
    """
    data = np.arange(1, 7, dtype='int32')
    for i in range(8):
        np.testing.assert_array_equal(m.test_array_ctors(10 + i), data.reshape((3, 2)))
        np.testing.assert_array_equal(m.test_array_ctors(20 + i), data.reshape((3, 2)))
    for i in range(5):
        np.testing.assert_array_equal(m.test_array_ctors(30 + i), data)
        np.testing.assert_array_equal(m.test_array_ctors(40 + i), data)


def test_string_array():
    """Check a record array of two 3-byte string fields.

    The dtype must be the same whether ``create_string_array`` is called with
    True or False.
    """
    arr = m.create_string_array(True)
    assert str(arr.dtype) == "[('a', 'S3'), ('b', 'S3')]"
    assert m.print_string_array(arr) == [
        "a='',b=''",
        "a='a',b='a'",
        "a='ab',b='ab'",
        "a='abc',b='abc'"
    ]
    dtype = arr.dtype
    assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc']
    assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc']
    arr = m.create_string_array(False)
    assert dtype == arr.dtype


def test_array_array():
    """Check a record whose fields are fixed-size sub-arrays (S4, i4, u1, f4)."""
    from sys import byteorder
    e = '<' if byteorder == 'little' else '>'

    arr = m.create_array_array(3)
    # The i4 prefix is concatenated while the f4 prefix comes from the '{e}'
    # placeholder filled by .format(e=e); both produce the host prefix.
    assert str(arr.dtype) == (
        "{{'names':['a','b','c','d'], " +
        "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " +
        "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e)
    assert m.print_array_array(arr) == [
        "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," +
        "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}",
        "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," +
        "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}",
        "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," +
        "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}",
    ]
    assert arr['a'].tolist() == [[b'ABCD', b'KLMN', b'UVWX'],
                                 [b'WXYZ', b'GHIJ', b'QRST'],
                                 [b'STUV', b'CDEF', b'MNOP']]
    assert arr['b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]]
    assert m.create_array_array(0).dtype == arr.dtype


def test_enum_array():
    """Check a record with two enum-valued fields.

    Fields are e1 stored as i8 and e2 stored as u1. Checks the printed names
    and the raw integer values.
    """
    from sys import byteorder
    e = '<' if byteorder == 'little' else '>'

    arr = m.create_enum_array(3)
    dtype = arr.dtype
    assert dtype == np.dtype([('e1', e + 'i8'), ('e2', 'u1')])
    assert m.print_enum_array(arr) == [
        "e1=A,e2=X",
        "e1=B,e2=Y",
        "e1=A,e2=X"
    ]
    assert arr['e1'].tolist() == [-1, 1, -1]
    assert arr['e2'].tolist() == [1, 2, 1]
    assert m.create_enum_array(0).dtype == dtype


def test_complex_array():
    """Check a record with a complex64 field (cflt) and a complex128 field (cdbl)."""
    from sys import byteorder
    e = '<' if byteorder == 'little' else '>'

    arr = m.create_complex_array(3)
    dtype = arr.dtype
    assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')])
    assert m.print_complex_array(arr) == [
        "c:(0,0.25),(0.5,0.75)",
        "c:(1,1.25),(1.5,1.75)",
        "c:(2,2.25),(2.5,2.75)"
    ]
    assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
    assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
    assert m.create_complex_array(0).dtype == dtype


def test_signature(doc):
    """Check the generated signature docstring of ``m.create_rec_nested``.

    NOTE(review): ``doc`` is a fixture presumably defined in the project's
    conftest — confirm.
    """
    assert doc(m.create_rec_nested) == \
        "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"


def test_scalar_conversion():
    """Check which array scalars each ``f_*`` function accepts.

    ``f_simple`` accepts only scalars from the simple array, and ``f_packed``
    only scalars from the packed array; each returns 10 * k for record k.
    Every other pairing, including ``f_nested`` with a nested record, must
    raise TypeError.
    """
    n = 3
    arrays = [m.create_rec_simple(n), m.create_rec_packed(n),
              m.create_rec_nested(n), m.create_enum_array(n)]
    funcs = [m.f_simple, m.f_packed, m.f_nested]

    for i, func in enumerate(funcs):
        for j, arr in enumerate(arrays):
            # Only the first two functions are expected to succeed, and only
            # on their matching array.
            if i == j and i < 2:
                assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
            else:
                with pytest.raises(TypeError) as excinfo:
                    func(arr[0])
                assert 'incompatible function arguments' in str(excinfo.value)


def test_register_dtype():
    """``m.register_dtype()`` must raise RuntimeError because the dtype is already registered."""
    with pytest.raises(RuntimeError) as excinfo:
        m.register_dtype()
    assert 'dtype is already registered' in str(excinfo.value)


@pytest.unsupported_on_pypy
def test_str_leak():
    """Check that ``m.dtype_wrapper`` does not change the refcount of its format string.

    Also checks that the result is the very same object as ``np.dtype("f4")``.
    NOTE(review): ``pytest.unsupported_on_pypy`` and ``pytest.gc_collect`` are
    presumably conftest-provided helpers — confirm.
    """
    from sys import getrefcount
    fmt = "f4"
    pytest.gc_collect()
    start = getrefcount(fmt)
    d = m.dtype_wrapper(fmt)
    assert d is np.dtype("f4")
    del d
    pytest.gc_collect()
    # The refcount must return to its starting value once the result is released.
    assert getrefcount(fmt) == start


def test_compare_buffer_info():
    """Every result returned by ``m.compare_buffer_info()`` must be truthy."""
    assert all(m.compare_buffer_info())