test_eigen.py revision 11986
111986Sandreas.sandberg@arm.comimport pytest
211986Sandreas.sandberg@arm.com
# numpy is optional at import time: if it is missing, ``np`` and ``ref`` are simply
# not defined and the module can still be collected. The tests that need them carry
# ``requires_eigen_and_*`` markers (presumably skip markers set up in conftest).
# A plain try/except is used instead of ``pytest.suppress``, which is not part of
# pytest's API and only exists if conftest monkeypatches it in.
try:
    import numpy as np
except ImportError:
    pass
else:
    # 5x6 reference matrix: every dense and sparse binding below must reproduce
    # exactly these values (see assert_equal_ref / assert_sparse_equal_ref).
    ref = np.array([[ 0,  3,  0,  0,  0, 11],
                    [22,  0,  0,  0, 17, 11],
                    [ 7,  5,  0,  1,  0, 11],
                    [ 0,  0,  0,  0,  0, 11],
                    [ 0,  0, 14,  0,  8, 11]])
1111986Sandreas.sandberg@arm.com
1211986Sandreas.sandberg@arm.com
def assert_equal_ref(mat):
    """Assert that ``mat`` holds exactly the values of the module-level ``ref``.

    The comparison is delegated to ``numpy.testing.assert_array_equal``, which
    raises ``AssertionError`` with an element-wise report on mismatch.
    """
    expected = ref
    np.testing.assert_array_equal(mat, expected)
1511986Sandreas.sandberg@arm.com
1611986Sandreas.sandberg@arm.com
def assert_sparse_equal_ref(sparse_mat):
    """Densify ``sparse_mat`` via ``.todense()`` and compare it against ``ref``.

    Any mismatch is reported by ``assert_equal_ref``.
    """
    dense = sparse_mat.todense()
    assert_equal_ref(dense)
1911986Sandreas.sandberg@arm.com
2011986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_fixed():
    """The ``fixed_*`` bindings return ``ref`` and pass it through unchanged.

    The ``_r``/``_c`` suffixes presumably denote row-/column-major storage; every
    producer is fed through every passthrough to cover matching and crossed orders.
    """
    from pybind11_tests import fixed_r, fixed_c, fixed_passthrough_r, fixed_passthrough_c

    # Direct returns, column-suffixed first.
    for producer in (fixed_c, fixed_r):
        assert_equal_ref(producer())

    # passthrough(producer()): matching suffixes first, then the crossed pairs.
    round_trips = (
        (fixed_passthrough_r, fixed_r),
        (fixed_passthrough_c, fixed_c),
        (fixed_passthrough_r, fixed_c),
        (fixed_passthrough_c, fixed_r),
    )
    for passthrough, producer in round_trips:
        assert_equal_ref(passthrough(producer()))
3111986Sandreas.sandberg@arm.com
3211986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_dense():
    """The ``dense_*`` bindings return ``ref`` and pass it through unchanged.

    As with the fixed-size variants, each producer is routed through each
    passthrough (``_r``/``_c`` presumably being row-/column-major storage).
    """
    from pybind11_tests import dense_r, dense_c, dense_passthrough_r, dense_passthrough_c

    for make in (dense_r, dense_c):
        assert_equal_ref(make())

    for passthrough, make in ((dense_passthrough_r, dense_r),
                              (dense_passthrough_c, dense_c),
                              (dense_passthrough_r, dense_c),
                              (dense_passthrough_c, dense_r)):
        assert_equal_ref(passthrough(make()))
4311986Sandreas.sandberg@arm.com
4411986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_nonunit_stride_from_python():
    """Strided numpy views passed into the bindings are read correctly.

    Every ``double_*`` binding is expected to return its input multiplied by 2.
    The inputs are row/column slices of a 3x3 array and 2-D slices of a 3x3x3
    array, so most of them are not packed contiguously in memory.
    """
    from pybind11_tests import double_row, double_col, double_mat_cm, double_mat_rm

    counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3))
    first_row = counting_mat[0, :]
    # Column of a C-ordered array: consecutive elements are 3 floats apart.
    first_col = counting_mat[:, 0]
    assert np.array_equal(double_row(first_row), 2.0 * first_row)
    assert np.array_equal(double_col(first_row), 2.0 * first_row)
    assert np.array_equal(double_row(first_col), 2.0 * first_col)
    assert np.array_equal(double_col(first_col), 2.0 * first_col)

    counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3))
    # Fixing a different axis each time yields 2-D views with different strides.
    slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
    for slice_idx, mat in enumerate(slices):
        # Name the failing slice in the message (slice_idx was previously unused).
        label = "slices[{}]".format(slice_idx)
        assert np.array_equal(double_mat_cm(mat), 2.0 * mat), "double_mat_cm " + label
        assert np.array_equal(double_mat_rm(mat), 2.0 * mat), "double_mat_rm " + label
6211986Sandreas.sandberg@arm.com
6311986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_nonunit_stride_to_python():
    """Diagonals and blocks of ``ref`` computed by the bindings match numpy's.

    ``np.array_equal`` is used instead of ``np.all(a == b)``: the latter
    broadcasts, so a result of the wrong shape can pass. For example, a
    1-element result compared with an empty expected diagonal broadcasts to an
    empty array, and ``np.all`` of an empty array is True.
    """
    from pybind11_tests import diagonal, diagonal_1, diagonal_n, block

    assert np.array_equal(diagonal(ref), ref.diagonal())
    assert np.array_equal(diagonal_1(ref), ref.diagonal(1))
    # ref is 5x6, so offsets -5 and 6 select empty diagonals (edge cases).
    for i in range(-5, 7):
        assert np.array_equal(diagonal_n(ref, i), ref.diagonal(i)), "diagonal_n({})".format(i)

    assert np.array_equal(block(ref, 2, 1, 3, 3), ref[2:5, 1:4])
    assert np.array_equal(block(ref, 1, 4, 4, 2), ref[1:, 4:])
    assert np.array_equal(block(ref, 1, 4, 3, 2), ref[1:4, 4:])
7611986Sandreas.sandberg@arm.com
7711986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_eigen_ref_to_python():
    """Each ``choleskyN`` binding recovers the same lower-triangular factor.

    The input matrix equals L * L^T for the expected factor L below, so every
    variant must return L exactly.
    """
    from pybind11_tests import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6

    expected = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
    chols = [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]
    for number, chol in zip(range(1, len(chols) + 1), chols):
        # Build a fresh input for every call in case a binding modifies its argument.
        spd = np.array([[1, 2, 4], [2, 13, 23], [4, 23, 77]])
        mymat = chol(spd)
        assert np.all(mymat == expected), "cholesky{}".format(number)
8611986Sandreas.sandberg@arm.com
8711986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_special_matrix_objects():
    """Diagonal and triangle-mirrored results come back as the expected dense arrays.

    ``symmetric_lower``/``symmetric_upper`` must mirror the lower/upper triangle
    of a non-symmetric input across the diagonal.
    """
    from pybind11_tests import incr_diag, symmetric_upper, symmetric_lower

    assert np.all(incr_diag(7) == np.diag([1, 2, 3, 4, 5, 6, 7]))

    asymm = np.arange(1, 17).reshape(4, 4)
    # Keep one triangle (diagonal included) and reflect its strict part onto the other.
    symm_lower = np.tril(asymm) + np.tril(asymm, -1).T
    symm_upper = np.triu(asymm) + np.triu(asymm, 1).T

    assert np.all(symmetric_lower(asymm) == symm_lower)
    assert np.all(symmetric_upper(asymm) == symm_upper)
10711986Sandreas.sandberg@arm.com
10811986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_numpy
def test_dense_signature(doc):
    """Generated signatures describe dense arguments as ``numpy.ndarray[float32[rows, cols]]``.

    ``doc`` is a fixture (presumably from conftest) that yields a callable's
    docstring in comparable form.
    """
    from pybind11_tests import double_col, double_row, double_mat_rm

    col_sig = """
        double_col(arg0: numpy.ndarray[float32[m, 1]]) -> numpy.ndarray[float32[m, 1]]
    """
    row_sig = """
        double_row(arg0: numpy.ndarray[float32[1, n]]) -> numpy.ndarray[float32[1, n]]
    """
    mat_sig = """
        double_mat_rm(arg0: numpy.ndarray[float32[m, n]]) -> numpy.ndarray[float32[m, n]]
    """

    assert doc(double_col) == col_sig
    assert doc(double_row) == row_sig
    assert doc(double_mat_rm) == mat_sig
12211986Sandreas.sandberg@arm.com
12311986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_scipy
def test_sparse():
    """The ``sparse_*`` bindings return ``ref`` and pass it through unchanged.

    Each source is routed through each passthrough (csr- and csc-based, per the
    signatures checked in test_sparse_signature).
    """
    from pybind11_tests import sparse_r, sparse_c, sparse_passthrough_r, sparse_passthrough_c

    assert_sparse_equal_ref(sparse_r())
    assert_sparse_equal_ref(sparse_c())

    # Matching (passthrough, source) pairs first, then the crossed ones.
    pairs = [(sparse_passthrough_r, sparse_r), (sparse_passthrough_c, sparse_c),
             (sparse_passthrough_r, sparse_c), (sparse_passthrough_c, sparse_r)]
    for passthrough, source in pairs:
        assert_sparse_equal_ref(passthrough(source()))
13411986Sandreas.sandberg@arm.com
13511986Sandreas.sandberg@arm.com
@pytest.requires_eigen_and_scipy
def test_sparse_signature(doc):
    """Generated signatures describe sparse arguments as scipy csr/csc matrices of float32.

    ``doc`` is a fixture (presumably from conftest) that yields a callable's
    docstring in comparable form.
    """
    from pybind11_tests import sparse_passthrough_r, sparse_passthrough_c

    csr_sig = """
        sparse_passthrough_r(arg0: scipy.sparse.csr_matrix[float32]) -> scipy.sparse.csr_matrix[float32]
    """  # noqa: E501 line too long
    csc_sig = """
        sparse_passthrough_c(arg0: scipy.sparse.csc_matrix[float32]) -> scipy.sparse.csc_matrix[float32]
    """  # noqa: E501 line too long

    assert doc(sparse_passthrough_r) == csr_sig
    assert doc(sparse_passthrough_c) == csc_sig
146