112037Sandreas.sandberg@arm.comimport pytest
212037Sandreas.sandberg@arm.comimport sys
312391Sjason@lowepower.comfrom pybind11_tests import stl_binders as m
412037Sandreas.sandberg@arm.com
512037Sandreas.sandberg@arm.comwith pytest.suppress(ImportError):
612037Sandreas.sandberg@arm.com    import numpy as np
712037Sandreas.sandberg@arm.com
812037Sandreas.sandberg@arm.com
def test_vector_int():
    """Exercise the list-like interface bound onto ``VectorInt``."""
    base = m.VectorInt([0, 0])
    assert len(base) == 2
    assert bool(base) is True

    # Construction from a generator expression.
    from_gen = m.VectorInt(n for n in range(5))
    assert from_gen == m.VectorInt([0, 1, 2, 3, 4])

    work = m.VectorInt([0, 0])
    assert base == work
    work[1] = 1
    assert base != work

    # append(), then three insertions at the front and one at the very end.
    work.append(2)
    for front in (1, 2, 3):
        work.insert(0, front)
    work.insert(6, 3)
    assert str(work) == "VectorInt[3, 2, 1, 0, 1, 2, 3]"
    # Inserting beyond the end raises IndexError (a Python list would clamp).
    with pytest.raises(IndexError):
        work.insert(8, 4)

    # Slice assignment from another bound vector, then slice/index deletion.
    base.append(99)
    work[2:-2] = base
    assert work == m.VectorInt([3, 2, 0, 0, 99, 2, 3])
    del work[1:3]
    assert work == m.VectorInt([3, 0, 99, 2, 3])
    del work[0]
    assert work == m.VectorInt([0, 99, 2, 3])

    # extend() from a bound vector and from a plain list.
    work.extend(m.VectorInt([4, 5]))
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5])
    work.extend([6, 7])
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7])

    # A failed extend() raises and leaves the vector untouched.
    with pytest.raises(RuntimeError):
        work.extend([8, 'a'])
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7])

    # extend() from a generator expression.
    work.extend(n for n in range(5))
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4])

    # Negative indices for read, insert and delete.
    assert work[-1] == 4
    work.insert(-1, 88)
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4])
    del work[-1]
    assert work == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88])
6612391Sjason@lowepower.com
# Skipped on PyPy because of differences in PyPy's buffer protocol.
@pytest.unsupported_on_pypy
def test_vector_buffer():
    """``VectorUChar`` can be built from a bytearray and viewed via memoryview."""
    buf = bytearray([1, 2, 3, 4])
    vec = m.VectorUChar(buf)
    assert vec[1] == 2
    vec[2] = 5

    view = memoryview(vec)  # VectorUChar exposes the buffer interface
    py3 = sys.version_info.major > 2
    # Python 2 memoryviews index as 1-char strings, Python 3 ones as ints.
    assert view[2] == (5 if py3 else '\x05')
    view[2] = 6 if py3 else '\x06'
    # A write through the view is visible in the vector itself.
    assert vec[2] == 6

    # A struct whose contents were never declared to NumPy has no buffer interface.
    with pytest.raises(RuntimeError) as err:
        m.create_undeclstruct()
    assert "NumPy type info missing for " in str(err.value)
8611986Sandreas.sandberg@arm.com
8711986Sandreas.sandberg@arm.com
@pytest.unsupported_on_pypy
@pytest.requires_numpy
def test_vector_buffer_numpy():
    """Build bound vectors from NumPy arrays and share memory with NumPy views."""
    # The element type must match: np.int32 is rejected, np.uintc (below) is accepted.
    signed = np.array([1, 2, 3, 4], dtype=np.int32)
    with pytest.raises(TypeError):
        m.VectorInt(signed)

    grid = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc)

    row = m.VectorInt(grid[0, :])
    assert len(row) == 4
    assert row[2] == 3
    # np.asarray on the vector yields a view: writes through it reach the vector.
    row_view = np.asarray(row)
    row_view[2] = 5
    assert row[2] == 5

    # A column slice (non-contiguous) is accepted as well.
    col = m.VectorInt(grid[:, 1])
    assert len(col) == 3
    assert col[2] == 10

    # Vectors of structs map onto NumPy record arrays, again sharing memory.
    structs = m.get_vectorstruct()
    assert structs[0].x == 5
    records = np.asarray(structs)
    records[1]['x'] = 99
    assert structs[1].x == 99

    aligned = np.dtype([('w', 'bool'), ('x', 'I'), ('y', 'float64'), ('z', 'bool')],
                       align=True)
    assert len(m.VectorStruct(np.zeros(3, dtype=aligned))) == 3
11612037Sandreas.sandberg@arm.com
11712037Sandreas.sandberg@arm.com
def test_vector_bool():
    """``VectorBool`` from the cross-module extension stores and prints booleans."""
    import pybind11_cross_module_tests as cm

    flags = cm.VectorBool()
    pattern = [k % 2 == 0 for k in range(10)]
    for flag in pattern:
        flags.append(flag)
    for k, flag in enumerate(pattern):
        assert flags[k] == flag
    assert str(flags) == "VectorBool[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]"
12711986Sandreas.sandberg@arm.com
12811986Sandreas.sandberg@arm.com
def test_vector_custom():
    """Vectors of a custom element type, including a vector of such vectors."""
    elements = m.VectorEl()
    for n in (1, 2):
        elements.append(m.El(n))
    assert str(elements) == "VectorEl[El{1}, El{2}]"

    nested = m.VectorVectorEl()
    nested.append(elements)
    # Reading the inner vector back out preserves its contents.
    assert str(nested[0]) == "VectorEl[El{1}, El{2}]"
13912391Sjason@lowepower.com
14012391Sjason@lowepower.com
def test_map_string_double():
    """Ordered and unordered string->double maps: item assignment, iteration, repr."""
    ordered = m.MapStringDouble()
    ordered['a'] = 1
    ordered['b'] = 2.5
    # The ordered map iterates in key order, so contents are compared directly.
    assert list(ordered) == ['a', 'b']
    assert list(ordered.items()) == [('a', 1), ('b', 2.5)]
    assert str(ordered) == "MapStringDouble{a: 1, b: 2.5}"

    unordered = m.UnorderedMapStringDouble()
    unordered['ua'] = 1.1
    unordered['ub'] = 2.6
    # No guaranteed order here: sort before comparing, and only check the repr's type name.
    assert sorted(unordered) == ['ua', 'ub']
    assert sorted(unordered.items()) == [('ua', 1.1), ('ub', 2.6)]
    assert "UnorderedMapStringDouble" in str(unordered)
15711986Sandreas.sandberg@arm.com
15811986Sandreas.sandberg@arm.com
def test_map_string_double_const():
    """The *Const map bindings accept item assignment and produce a readable repr."""
    mc = m.MapStringDoubleConst()
    mc['a'] = 10
    mc['b'] = 20.5
    assert str(mc) == "MapStringDoubleConst{a: 10, b: 20.5}"

    umc = m.UnorderedMapStringDoubleConst()
    umc['a'] = 11
    umc['b'] = 21.5

    # Iteration order is unspecified for the unordered map, so compare sorted
    # keys and only check that the repr names the bound type (as
    # test_map_string_double does for the non-const variant). Previously the
    # repr was computed and discarded, which asserted nothing.
    assert sorted(umc) == ['a', 'b']
    assert "UnorderedMapStringDoubleConst" in str(umc)
17011986Sandreas.sandberg@arm.com
17111986Sandreas.sandberg@arm.com
def _check_noncopyable_sequence(seq, n):
    """Assert ``seq`` holds ``n`` elements whose ``.value`` fields are 1..n in order."""
    for i in range(n):
        assert seq[i].value == i + 1
    # Iteration must yield the same elements, in order, and exactly n of them.
    seen = 0
    for expected, item in enumerate(seq, start=1):
        assert item.value == expected
        seen = expected
    assert seen == n


def _check_noncopyable_mapping(mapping, n):
    """Assert ``mapping`` maps keys 1..n to elements whose ``.value`` is 10 * key."""
    for key in range(1, n + 1):
        assert mapping[key].value == 10 * key
    total = 0
    for key, item in mapping.items():
        assert item.value == 10 * key
        total += item.value
    # The sum confirms items() covers every key (150 for n == 5).
    assert total == 10 * n * (n + 1) // 2


def test_noncopyable_containers():
    """Bound containers of non-copyable elements support indexing and iteration."""
    _check_noncopyable_sequence(m.get_vnc(5), 5)   # std::vector
    _check_noncopyable_sequence(m.get_dnc(5), 5)   # std::deque
    _check_noncopyable_mapping(m.get_mnc(5), 5)    # std::map
    _check_noncopyable_mapping(m.get_umnc(5), 5)   # std::unordered_map
21414299Sbbruce@ucdavis.edu
21514299Sbbruce@ucdavis.edu
def test_map_delitem():
    """``del mapping[key]`` removes the entry from ordered and unordered maps."""
    ordered = m.MapStringDouble()
    for key, value in (('a', 1), ('b', 2.5)):
        ordered[key] = value
    assert list(ordered) == ['a', 'b']
    assert list(ordered.items()) == [('a', 1), ('b', 2.5)]

    del ordered['a']
    assert list(ordered) == ['b']
    assert list(ordered.items()) == [('b', 2.5)]

    unordered = m.UnorderedMapStringDouble()
    for key, value in (('ua', 1.1), ('ub', 2.6)):
        unordered[key] = value
    # Unordered iteration: sort before comparing.
    assert sorted(unordered) == ['ua', 'ub']
    assert sorted(unordered.items()) == [('ua', 1.1), ('ub', 2.6)]

    del unordered['ua']
    assert sorted(unordered) == ['ub']
    assert sorted(unordered.items()) == [('ub', 2.6)]
236