# test_numpy_array.py revision 11986
"""Tests for the NumPy array bindings exposed by ``pybind11_tests.array``."""
import pytest
import gc

# ``pytest.suppress`` and ``pytest.requires_numpy`` are not stock pytest
# attributes; presumably this suite's conftest installs them.  Suppressing
# ImportError presumably lets the module still import when NumPy is missing.
with pytest.suppress(ImportError):
    import numpy as np


@pytest.fixture(scope='function')
def arr():
    """Return a fresh 2x3 little-endian uint16 (``'<u2'``) array.

    Function scope matters: several tests below mutate this array in place
    (flags and contents), so each test needs its own copy.
    """
    return np.array([[1, 2, 3], [4, 5, 6]], '<u2')


@pytest.requires_numpy
def test_array_attributes():
    """Check shape/stride/flag accessors on a 0-d array and a read-only 2-d view."""
    from pybind11_tests.array import (
        ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
    )

    # 0-d float64 array: it has no axes, so any axis argument is invalid.
    a = np.array(0, 'f8')
    assert ndim(a) == 0
    assert all(shape(a) == [])
    assert all(strides(a) == [])
    with pytest.raises(IndexError) as excinfo:
        shape(a, 0)
    assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    with pytest.raises(IndexError) as excinfo:
        strides(a, 0)
    assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    assert writeable(a)
    assert size(a) == 1
    assert itemsize(a) == 8
    assert nbytes(a) == 8
    assert owndata(a)

    # Read-only view of a 2x3 uint16 array: row stride is 3 elements * 2 bytes,
    # and a view does not own its data buffer.
    a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
    a.flags.writeable = False
    assert ndim(a) == 2
    assert all(shape(a) == [2, 3])
    assert shape(a, 0) == 2
    assert shape(a, 1) == 3
    assert all(strides(a) == [6, 2])
    assert strides(a, 0) == 6
    assert strides(a, 1) == 2
    with pytest.raises(IndexError) as excinfo:
        shape(a, 2)
    assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    with pytest.raises(IndexError) as excinfo:
        strides(a, 2)
    assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    assert not writeable(a)
    assert size(a) == 6
    assert itemsize(a) == 2
    assert nbytes(a) == 12
    assert not owndata(a)


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)])
def test_index_offset(arr, args, ret):
    """Flat element index and byte offset for partial and full index tuples.

    ``ret`` is the expected flat (row-major) element index into the 2x3
    fixture; the expected byte offset is that index times the item size.
    """
    from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t
    assert index_at(arr, *args) == ret
    assert index_at_t(arr, *args) == ret
    assert offset_at(arr, *args) == ret * arr.dtype.itemsize
    assert offset_at_t(arr, *args) == ret * arr.dtype.itemsize


@pytest.requires_numpy
def test_dim_check_fail(arr):
    """Passing more indices than the array has dimensions raises ``IndexError``."""
    from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                                      mutate_data, mutate_data_t)
    funcs = (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
             mutate_data, mutate_data_t)
    for func in funcs:
        with pytest.raises(IndexError) as excinfo:
            func(arr, 1, 2, 3)
        assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
                         [([], [1, 2, 3, 4, 5, 6]),
                          ([1], [4, 5, 6]),
                          ([0, 1], [2, 3, 4, 5, 6]),
                          ([1, 2], [6])])
def test_data(arr, args, ret):
    """Data views starting at the element addressed by ``args``.

    ``data_t`` yields typed elements.  Judging by the assertions, ``data``
    yields one entry per byte: with the little-endian ``'<u2'`` fixture, even
    positions hold each value's low byte and odd positions its (zero) high byte.
    """
    from pybind11_tests.array import data, data_t
    assert all(data_t(arr, *args) == ret)
    assert all(data(arr, *args)[::2] == ret)
    assert all(data(arr, *args)[1::2] == 0)


@pytest.requires_numpy
def test_mutate_readonly(arr):
    """Mutating accessors reject a non-writeable array with ``RuntimeError``."""
    from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
    arr.flags.writeable = False
    for func, args in (mutate_data, ()), (mutate_data_t, ()), (mutate_at_t, (0, 0)):
        with pytest.raises(RuntimeError) as excinfo:
            func(arr, *args)
        assert str(excinfo.value) == 'array is not writeable'


@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """``at_t``/``mutate_at_t`` reject an index count different from ndim (2)."""
    from pybind11_tests.array import at_t, mutate_at_t
    for func in at_t, mutate_at_t:
        with pytest.raises(IndexError) as excinfo:
            func(arr, *([0] * dim))
        assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim)


@pytest.requires_numpy
def test_at(arr):
    """Read single elements, then increment single elements in place.

    Per the expected values, ``mutate_at_t`` adds one to the addressed element
    and the change persists in ``arr`` across calls.
    """
    from pybind11_tests.array import at_t, mutate_at_t

    assert at_t(arr, 0, 2) == 3
    assert at_t(arr, 1, 0) == 4

    assert all(mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
    assert all(mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])


@pytest.requires_numpy
def test_mutate_data(arr):
    """Bulk mutation from an optional start index.

    Per the expected values, ``mutate_data`` doubles every element from the
    addressed element onward and ``mutate_data_t`` adds one; results
    accumulate because each call mutates ``arr`` in place.
    """
    from pybind11_tests.array import mutate_data, mutate_data_t

    assert all(mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12])
    assert all(mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24])
    assert all(mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48])
    assert all(mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96])
    assert all(mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192])

    assert all(mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193])
    assert all(mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194])
    assert all(mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195])
    assert all(mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196])
    assert all(mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])


@pytest.requires_numpy
def test_bounds_check(arr):
    """Every accessor rejects an out-of-range index on either axis with ``IndexError``."""
    from pybind11_tests.array import (index_at, index_at_t, data, data_t,
                                      mutate_data, mutate_data_t, at_t, mutate_at_t)
    funcs = (index_at, index_at_t, data, data_t,
             mutate_data, mutate_data_t, at_t, mutate_at_t)
    for func in funcs:
        with pytest.raises(IndexError) as excinfo:
            func(arr, 2, 0)
        assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2'
        with pytest.raises(IndexError) as excinfo:
            func(arr, 0, 4)
        assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'


@pytest.requires_numpy
def test_make_c_f_array():
    """``make_c_array``/``make_f_array`` return C-ordered and Fortran-ordered arrays."""
    from pybind11_tests.array import (
        make_c_array, make_f_array
    )
    assert make_c_array().flags.c_contiguous
    assert not make_c_array().flags.f_contiguous
    assert make_f_array().flags.f_contiguous
    assert not make_f_array().flags.c_contiguous


@pytest.requires_numpy
def test_wrap():
    """``wrap`` returns a distinct array object that aliases its input's buffer."""
    from pybind11_tests.array import wrap

    def assert_references(a, b):
        """Assert ``b`` is a separate array sharing ``a``'s memory, layout and flags."""
        assert a is not b
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # NumPy 1.14 deprecated UPDATEIFCOPY in favour of WRITEBACKIFCOPY and
        # later releases removed the old attribute entirely, so compare
        # whichever of the two this NumPy provides.
        copy_flag = 'writebackifcopy' if hasattr(a.flags, 'writebackifcopy') else 'updateifcopy'
        assert getattr(a.flags, copy_flag) == getattr(b.flags, copy_flag)
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is a
        # Writes through `a` must be visible through `b` (shared buffer).
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    # 1-d int16 array that owns its data.
    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Fortran-ordered 2-d array.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Read-only C-ordered 2-d array.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 3-d array, then derived views with non-default strides.
    a1 = np.random.random((4, 4, 4))
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = a1.transpose()
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = a1.diagonal()
    a2 = wrap(a1)
    assert_references(a1, a2)


@pytest.requires_numpy
def test_numpy_view(capture):
    """Views from ``ArrayClass.numpy_view`` keep the owning instance alive.

    The captured output shows the destructor only after the last view is
    released, and writes through one view are visible through the other.
    """
    from pybind11_tests.array import ArrayClass
    with capture:
        ac = ArrayClass()
        ac_view_1 = ac.numpy_view()
        ac_view_2 = ac.numpy_view()
        assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
        del ac
        gc.collect()
    # No "~ArrayClass()" yet: the views still reference the instance.
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """
    ac_view_1[0] = 4
    ac_view_1[1] = 3
    assert ac_view_2[0] == 4
    assert ac_view_2[1] == 3
    with capture:
        del ac_view_1
        del ac_view_2
        gc.collect()
    assert capture == """
        ~ArrayClass()
    """


@pytest.requires_numpy
def test_cast_numpy_int64_to_uint64():
    """``function_taking_uint64`` must accept a Python int and an ``np.uint64``.

    NOTE(review): despite the test name, no ``np.int64`` value is passed here.
    """
    from pybind11_tests.array import function_taking_uint64
    function_taking_uint64(123)
    function_taking_uint64(np.uint64(123))


@pytest.requires_numpy
def test_isinstance():
    """``isinstance_untyped``/``isinstance_typed`` return truthy for these inputs.

    Presumably ``isinstance_untyped`` checks that its first argument is an
    array and its second is not — confirm against the C++ binding.
    """
    from pybind11_tests.array import isinstance_untyped, isinstance_typed

    assert isinstance_untyped(np.array([1, 2, 3]), "not an array")
    assert isinstance_typed(np.array([1.0, 2.0, 3.0]))


@pytest.requires_numpy
def test_constructors():
    """Default and converting constructors of the untyped and typed array bindings.

    Default-constructed arrays are empty.  The untyped one has the same dtype
    as ``np.array([])``; typed ones use their declared element type.  Converting
    ``[1, 2, 3]`` preserves the values.  The untyped result gets ``np.int_``,
    and typed results get their element type.
    """
    from pybind11_tests.array import default_constructors, converting_constructors

    defaults = default_constructors()
    for a in defaults.values():
        assert a.size == 0
    assert defaults["array"].dtype == np.array([]).dtype
    assert defaults["array_t<int32>"].dtype == np.int32
    assert defaults["array_t<double>"].dtype == np.float64

    results = converting_constructors([1, 2, 3])
    for a in results.values():
        np.testing.assert_array_equal(a, [1, 2, 3])
    assert results["array"].dtype == np.int_
    assert results["array_t<int32>"].dtype == np.int32
    assert results["array_t<double>"].dtype == np.float64