test_numpy_array.py revision 11986
13560SN/Aimport pytest
23560SN/Aimport gc
33560SN/A
43560SN/Awith pytest.suppress(ImportError):
53560SN/A    import numpy as np
63560SN/A
73560SN/A
@pytest.fixture(scope='function')
def arr():
    """Return a fresh 2x3 little-endian uint16 array.

    Function scope matters: several tests mutate this array in place, so each
    test must receive its own copy.
    """
    values = [[1, 2, 3], [4, 5, 6]]
    return np.array(values, dtype='<u2')
113560SN/A
123560SN/A
@pytest.requires_numpy
def test_array_attributes():
    """Check the bound array attribute accessors on a 0-d and a 2-d array."""
    from pybind11_tests.array import (
        ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
    )

    def assert_invalid_axis(accessor, array, axis):
        # An out-of-range axis must raise IndexError with this exact message.
        with pytest.raises(IndexError) as excinfo:
            accessor(array, axis)
        expected = 'invalid axis: {} (ndim = {})'.format(axis, array.ndim)
        assert str(excinfo.value) == expected

    # 0-d scalar array: it has no axes at all.
    a = np.array(0, 'f8')
    assert ndim(a) == 0
    # np.array_equal also compares lengths. ``all(x == [])`` would be vacuous:
    # an empty comparison is all-True, and a length-1 result broadcasts
    # against [] to an empty array.
    assert np.array_equal(shape(a), [])
    assert np.array_equal(strides(a), [])
    assert_invalid_axis(shape, a, 0)
    assert_invalid_axis(strides, a, 0)
    assert writeable(a)
    assert size(a) == 1
    assert itemsize(a) == 8
    assert nbytes(a) == 8
    assert owndata(a)

    # Read-only 2x3 uint16 view; it does not own its data.
    a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
    a.flags.writeable = False
    assert ndim(a) == 2
    assert np.array_equal(shape(a), [2, 3])
    assert shape(a, 0) == 2
    assert shape(a, 1) == 3
    assert np.array_equal(strides(a), [6, 2])
    assert strides(a, 0) == 6
    assert strides(a, 1) == 2
    assert_invalid_axis(shape, a, 2)
    assert_invalid_axis(strides, a, 2)
    assert not writeable(a)
    assert size(a) == 6
    assert itemsize(a) == 2
    assert nbytes(a) == 12
    assert not owndata(a)
553560SN/A
563560SN/A
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [
    ([], 0),
    ([0], 0),
    ([1], 3),
    ([0, 1], 1),
    ([1, 2], 5),
])
def test_index_offset(arr, args, ret):
    """Flat element index and byte offset for a (possibly partial) index tuple."""
    from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t
    byte_offset = ret * arr.dtype.itemsize
    for index_fn in (index_at, index_at_t):
        assert index_fn(arr, *args) == ret
    for offset_fn in (offset_at, offset_at_t):
        assert offset_fn(arr, *args) == byte_offset
653560SN/A
663560SN/A
@pytest.requires_numpy
def test_dim_check_fail(arr):
    """Passing more indices than the array has dimensions raises IndexError."""
    from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                                      mutate_data, mutate_data_t)
    checked = [index_at, index_at_t, offset_at, offset_at_t, data, data_t,
               mutate_data, mutate_data_t]
    too_many = (1, 2, 3)
    for accessor in checked:
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *too_many)
        assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'
763560SN/A
773560SN/A
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
                         [([], [1, 2, 3, 4, 5, 6]),
                          ([1], [4, 5, 6]),
                          ([0, 1], [2, 3, 4, 5, 6]),
                          ([1, 2], [6])])
def test_data(arr, args, ret):
    """Typed and raw data views starting at the given index prefix."""
    from pybind11_tests.array import data, data_t
    typed_view = data_t(arr, *args)
    assert all(typed_view == ret)
    # The fixture is little-endian uint16 with small values, so in the raw
    # view the even positions carry the values and the odd positions are zero.
    even_part = data(arr, *args)[::2]
    assert all(even_part == ret)
    odd_part = data(arr, *args)[1::2]
    assert all(odd_part == 0)
893560SN/A
903560SN/A
@pytest.requires_numpy
def test_mutate_readonly(arr):
    """Every mutating accessor refuses to touch a read-only array."""
    from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
    arr.flags.writeable = False
    cases = [
        (mutate_data, ()),
        (mutate_data_t, ()),
        (mutate_at_t, (0, 0)),
    ]
    for mutator, indices in cases:
        with pytest.raises(RuntimeError) as excinfo:
            mutator(arr, *indices)
        assert str(excinfo.value) == 'array is not writeable'
993560SN/A
1003560SN/A
@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """at_t and mutate_at_t reject an index count different from ndim."""
    from pybind11_tests.array import at_t, mutate_at_t
    indices = [0] * dim
    expected = 'index dimension mismatch: {} (ndim = 2)'.format(dim)
    for accessor in (at_t, mutate_at_t):
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *indices)
        assert str(excinfo.value) == expected
1093560SN/A
1103560SN/A
@pytest.requires_numpy
def test_at(arr):
    """Element read via at_t and in-place increment via mutate_at_t."""
    from pybind11_tests.array import at_t, mutate_at_t

    for (row, col), expected in (((0, 2), 3), ((1, 0), 4)):
        assert at_t(arr, row, col) == expected

    # The mutations persist in ``arr``, so the second expectation already
    # includes the first increment.
    steps = [
        ((0, 2), [1, 2, 4, 4, 5, 6]),
        ((1, 0), [1, 2, 4, 5, 5, 6]),
    ]
    for (row, col), flattened in steps:
        assert all(mutate_at_t(arr, row, col).ravel() == flattened)
1203560SN/A
1213560SN/A
@pytest.requires_numpy
def test_mutate_data(arr):
    """Bulk in-place mutation starting from an optional index prefix.

    Every call mutates ``arr`` in place, so each expected result builds on the
    previous one. mutate_data doubles the affected elements and mutate_data_t
    adds one. An index prefix restricts the update to the elements from that
    position to the end of the buffer.
    """
    from pybind11_tests.array import mutate_data, mutate_data_t

    steps = [
        (mutate_data, (), [2, 4, 6, 8, 10, 12]),
        (mutate_data, (), [4, 8, 12, 16, 20, 24]),
        (mutate_data, (1,), [4, 8, 12, 32, 40, 48]),
        (mutate_data, (0, 1), [4, 16, 24, 64, 80, 96]),
        (mutate_data, (1, 2), [4, 16, 24, 64, 80, 192]),
        (mutate_data_t, (), [5, 17, 25, 65, 81, 193]),
        (mutate_data_t, (), [6, 18, 26, 66, 82, 194]),
        (mutate_data_t, (1,), [6, 18, 26, 67, 83, 195]),
        (mutate_data_t, (0, 1), [6, 19, 27, 68, 84, 196]),
        (mutate_data_t, (1, 2), [6, 19, 27, 68, 84, 197]),
    ]
    for mutator, prefix, flattened in steps:
        assert all(mutator(arr, *prefix).ravel() == flattened)
1373560SN/A
1383560SN/A
@pytest.requires_numpy
def test_bounds_check(arr):
    """Out-of-range indices raise IndexError naming the offending axis."""
    from pybind11_tests.array import (index_at, index_at_t, data, data_t,
                                      mutate_data, mutate_data_t, at_t, mutate_at_t)
    funcs = (index_at, index_at_t, data, data_t,
             mutate_data, mutate_data_t, at_t, mutate_at_t)
    bad_indices = (
        ((2, 0), 'index 2 is out of bounds for axis 0 with size 2'),
        ((0, 4), 'index 4 is out of bounds for axis 1 with size 3'),
    )
    for func in funcs:
        for indices, message in bad_indices:
            with pytest.raises(IndexError) as excinfo:
                func(arr, *indices)
            assert str(excinfo.value) == message
1523560SN/A
1533560SN/A
@pytest.requires_numpy
def test_make_c_f_array():
    """Factories yield arrays contiguous in exactly the requested order."""
    from pybind11_tests.array import (
        make_c_array, make_f_array
    )
    layouts = (
        (make_c_array, 'c_contiguous', 'f_contiguous'),
        (make_f_array, 'f_contiguous', 'c_contiguous'),
    )
    for factory, expected_flag, absent_flag in layouts:
        assert getattr(factory().flags, expected_flag)
        assert not getattr(factory().flags, absent_flag)
1633560SN/A
1643560SN/A
@pytest.requires_numpy
def test_wrap():
    """``wrap`` returns a distinct array object aliasing the input's buffer."""
    from pybind11_tests.array import wrap

    def writeback_flag(arr):
        # WRITEBACKIFCOPY superseded UPDATEIFCOPY in NumPy 1.14, and the old
        # flag was removed in NumPy 1.23 (accessing it raises AttributeError).
        # Use the new name when it exists and fall back for very old NumPy.
        flags = arr.flags
        if hasattr(flags, 'writebackifcopy'):
            return flags.writebackifcopy
        return flags.updateifcopy

    def assert_references(a, b):
        # b must be a new object sharing a's data pointer, layout and flags,
        # with a recorded as its base.
        assert a is not b
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        assert writeback_flag(a) == writeback_flag(b)
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is a
        # A write through a must be visible through b when the memory is shared.
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = np.random.random((4, 4, 4))
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Non-contiguous inputs: a transposed view, then a diagonal view of it.
    a1 = a1.transpose()
    a2 = wrap(a1)
    assert_references(a1, a2)

    a1 = a1.diagonal()
    a2 = wrap(a1)
    assert_references(a1, a2)
212
213
@pytest.requires_numpy
def test_numpy_view(capture):
    """Views from ArrayClass.numpy_view() share memory and keep their owner alive."""
    from pybind11_tests.array import ArrayClass

    with capture:
        owner = ArrayClass()
        first_view = owner.numpy_view()
        second_view = owner.numpy_view()
        assert np.all(first_view == np.array([1, 2], dtype=np.int32))
        # Dropping the Python handle does not destroy the object while views
        # remain: no destructor output is expected in this capture.
        del owner
        gc.collect()
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """

    # A write through one view is visible through the other.
    first_view[0] = 4
    first_view[1] = 3
    assert second_view[0] == 4
    assert second_view[1] == 3

    # Releasing the last views finally runs the destructor.
    with capture:
        del first_view, second_view
        gc.collect()
    assert capture == """
        ~ArrayClass()
    """
240
241
@pytest.requires_numpy
def test_cast_numpy_int64_to_uint64():
    """A uint64 parameter accepts both a Python int and a NumPy uint64 scalar."""
    from pybind11_tests.array import function_taking_uint64
    for value in (123, np.uint64(123)):
        function_taking_uint64(value)
246    function_taking_uint64(np.uint64(123))
247
248
@pytest.requires_numpy
def test_isinstance():
    """isinstance checks for untyped py::array and typed array_t."""
    from pybind11_tests.array import isinstance_untyped, isinstance_typed

    int_values = np.array([1, 2, 3])
    float_values = np.array([1.0, 2.0, 3.0])
    assert isinstance_untyped(int_values, "not an array")
    assert isinstance_typed(float_values)
255
256
@pytest.requires_numpy
def test_constructors():
    """Default constructors give empty arrays; converting ones keep the values."""
    from pybind11_tests.array import default_constructors, converting_constructors

    defaults = default_constructors()
    for empty in defaults.values():
        assert empty.size == 0
    default_dtypes = {
        "array": np.array([]).dtype,
        "array_t<int32>": np.int32,
        "array_t<double>": np.float64,
    }
    for key, dtype in default_dtypes.items():
        assert defaults[key].dtype == dtype

    results = converting_constructors([1, 2, 3])
    for converted in results.values():
        np.testing.assert_array_equal(converted, [1, 2, 3])
    converted_dtypes = {
        "array": np.int_,
        "array_t<int32>": np.int32,
        "array_t<double>": np.float64,
    }
    for key, dtype in converted_dtypes.items():
        assert results[key].dtype == dtype
274