# test_numpy_array.py revision 11986
"""Tests for the NumPy array bindings exposed by ``pybind11_tests.array``."""
import gc

import pytest

# NOTE(review): ``pytest.suppress`` and ``pytest.requires_numpy`` are not part of
# vanilla pytest; presumably the suite's conftest injects them — confirm there.
with pytest.suppress(ImportError):
    import numpy as np


@pytest.fixture(scope='function')
def arr():
    """Return a fresh 2x3 little-endian uint16 array (rebuilt for every test)."""
    return np.array([[1, 2, 3], [4, 5, 6]], '<u2')


@pytest.requires_numpy
def test_array_attributes():
    """Check the attribute accessors on a 0-d array and on a read-only 2-D view."""
    from pybind11_tests.array import (
        ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
    )

    a = np.array(0, 'f8')
    assert ndim(a) == 0
    # Compare as lists: ``all(x == [])`` would also pass for a length-1 result,
    # because it broadcasts against the empty list to an empty (all-True) array.
    assert list(shape(a)) == []
    assert list(strides(a)) == []
    with pytest.raises(IndexError) as excinfo:
        shape(a, 0)
    assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    with pytest.raises(IndexError) as excinfo:
        strides(a, 0)
    assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    assert writeable(a)
    assert size(a) == 1
    assert itemsize(a) == 8
    assert nbytes(a) == 8
    assert owndata(a)

    a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
    a.flags.writeable = False
    assert ndim(a) == 2
    assert all(shape(a) == [2, 3])
    assert shape(a, 0) == 2
    assert shape(a, 1) == 3
    assert all(strides(a) == [6, 2])
    assert strides(a, 0) == 6
    assert strides(a, 1) == 2
    with pytest.raises(IndexError) as excinfo:
        shape(a, 2)
    assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    with pytest.raises(IndexError) as excinfo:
        strides(a, 2)
    assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    assert not writeable(a)
    assert size(a) == 6
    assert itemsize(a) == 2
    assert nbytes(a) == 12
    assert not owndata(a)


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)])
def test_index_offset(arr, args, ret):
    """index_at* give the flat element index; offset_at* give it scaled by itemsize."""
    from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t
    assert index_at(arr, *args) == ret
    assert index_at_t(arr, *args) == ret
    assert offset_at(arr, *args) == ret * arr.dtype.itemsize
    assert offset_at_t(arr, *args) == ret * arr.dtype.itemsize


@pytest.requires_numpy
def test_dim_check_fail(arr):
    """Passing more indices than the array has dimensions raises IndexError."""
    from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                                      mutate_data, mutate_data_t)
    for func in (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                 mutate_data, mutate_data_t):
        with pytest.raises(IndexError) as excinfo:
            func(arr, 1, 2, 3)
        assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
                         [([], [1, 2, 3, 4, 5, 6]),
                          ([1], [4, 5, 6]),
                          ([0, 1], [2, 3, 4, 5, 6]),
                          ([1, 2], [6])])
def test_data(arr, args, ret):
    """data_t/data expose the elements from the given index onwards."""
    from pybind11_tests.array import data, data_t
    assert all(data_t(arr, *args) == ret)
    # ``data`` is checked byte-wise: for these small '<u2' values the even
    # positions hold the element values and the odd positions are zero.
    assert all(data(arr, *args)[::2] == ret)
    assert all(data(arr, *args)[1::2] == 0)


@pytest.requires_numpy
def test_mutate_readonly(arr):
    """Mutating accessors refuse to write into a read-only array."""
    from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
    arr.flags.writeable = False
    for func, args in (mutate_data, ()), (mutate_data_t, ()), (mutate_at_t, (0, 0)):
        with pytest.raises(RuntimeError) as excinfo:
            func(arr, *args)
        assert str(excinfo.value) == 'array is not writeable'


@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """at_t/mutate_at_t require exactly ndim indices."""
    from pybind11_tests.array import at_t, mutate_at_t
    for func in at_t, mutate_at_t:
        with pytest.raises(IndexError) as excinfo:
            func(arr, *([0] * dim))
        assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim)


@pytest.requires_numpy
def test_at(arr):
    """at_t reads one element; mutate_at_t increments one element of ``arr`` in place."""
    from pybind11_tests.array import at_t, mutate_at_t

    assert at_t(arr, 0, 2) == 3
    assert at_t(arr, 1, 0) == 4

    # The second expectation includes the first increment: the fixture is mutated.
    assert all(mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
    assert all(mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])


@pytest.requires_numpy
def test_mutate_data(arr):
    """mutate_data doubles, and mutate_data_t increments, every element from the index on."""
    from pybind11_tests.array import mutate_data, mutate_data_t

    # Each expectation builds on the previous call's result (same fixture array).
    assert all(mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12])
    assert all(mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24])
    assert all(mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48])
    assert all(mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96])
    assert all(mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192])

    assert all(mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193])
    assert all(mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194])
    assert all(mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195])
    assert all(mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196])
    assert all(mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])


@pytest.requires_numpy
def test_bounds_check(arr):
    """Out-of-range indices raise IndexError naming the offending axis and size."""
    from pybind11_tests.array import (index_at, index_at_t, data, data_t,
                                      mutate_data, mutate_data_t, at_t, mutate_at_t)
    funcs = (index_at, index_at_t, data, data_t,
             mutate_data, mutate_data_t, at_t, mutate_at_t)
    for func in funcs:
        with pytest.raises(IndexError) as excinfo:
            func(arr, 2, 0)
        assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2'
        with pytest.raises(IndexError) as excinfo:
            func(arr, 0, 4)
        assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'


@pytest.requires_numpy
def test_make_c_f_array():
    """make_c_array is C-contiguous only; make_f_array is Fortran-contiguous only."""
    from pybind11_tests.array import (
        make_c_array, make_f_array
    )
    assert make_c_array().flags.c_contiguous
    assert not make_c_array().flags.f_contiguous
    assert make_f_array().flags.f_contiguous
    assert not make_f_array().flags.c_contiguous


@pytest.requires_numpy
def test_wrap():
    """wrap() returns a distinct array object that shares the input's buffer."""
    from pybind11_tests.array import wrap

    def assert_references(a, b):
        """Assert ``b`` is a non-owning alias of ``a`` with identical layout and flags."""
        assert a is not b
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # UPDATEIFCOPY was deprecated in NumPy 1.14 in favour of WRITEBACKIFCOPY
        # and later removed, so compare whichever flag this NumPy provides.
        if hasattr(a.flags, 'writebackifcopy'):
            assert a.flags.writebackifcopy == b.flags.writebackifcopy
        else:
            assert a.flags.updateifcopy == b.flags.updateifcopy
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is a
        # A write through ``a`` must be visible through ``b`` (shared memory).
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.random.random((4, 4, 4))
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = a1.transpose()
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = a1.diagonal()
    a2 = wrap(a1)
    assert_references(a1, a2)


@pytest.requires_numpy
def test_numpy_view(capture):
    """Views from ArrayClass.numpy_view keep the owner alive until all are released."""
    from pybind11_tests.array import ArrayClass
    with capture:
        ac = ArrayClass()
        ac_view_1 = ac.numpy_view()
        ac_view_2 = ac.numpy_view()
        assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
        del ac
        gc.collect()
    # No destructor output yet: the views still reference the object.
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """
    ac_view_1[0] = 4
    ac_view_1[1] = 3
    assert ac_view_2[0] == 4
    assert ac_view_2[1] == 3
    with capture:
        del ac_view_1
        del ac_view_2
        gc.collect()
    assert capture == """
        ~ArrayClass()
    """


@pytest.requires_numpy
def test_cast_numpy_int64_to_uint64():
    """function_taking_uint64 accepts both a Python int and a np.uint64 scalar."""
    from pybind11_tests.array import function_taking_uint64
    function_taking_uint64(123)
    function_taking_uint64(np.uint64(123))


@pytest.requires_numpy
def test_isinstance():
    """isinstance checks for untyped and typed array wrappers."""
    from pybind11_tests.array import isinstance_untyped, isinstance_typed

    assert isinstance_untyped(np.array([1, 2, 3]), "not an array")
    assert isinstance_typed(np.array([1.0, 2.0, 3.0]))


@pytest.requires_numpy
def test_constructors():
    """Default constructors give empty arrays; converting ones take the expected dtype."""
    from pybind11_tests.array import default_constructors, converting_constructors

    defaults = default_constructors()
    for a in defaults.values():
        assert a.size == 0
    assert defaults["array"].dtype == np.array([]).dtype
    assert defaults["array_t<int32>"].dtype == np.int32
    assert defaults["array_t<double>"].dtype == np.float64

    results = converting_constructors([1, 2, 3])
    for a in results.values():
        np.testing.assert_array_equal(a, [1, 2, 3])
    assert results["array"].dtype == np.int_
    assert results["array_t<int32>"].dtype == np.int32
    assert results["array_t<double>"].dtype == np.float64