# test_eigen.py revision 11986
"""Tests for the Eigen matrix bindings exposed by ``pybind11_tests``.

NOTE(review): ``pytest.suppress``, ``pytest.requires_eigen_and_numpy``,
``pytest.requires_eigen_and_scipy`` and the ``doc`` fixture are not defined
in this file -- presumably they are injected by the project's ``conftest.py``;
confirm before running this module on its own.
"""
import pytest

with pytest.suppress(ImportError):
    import numpy as np

    # Reference matrix that the bound functions are expected to return.
    # Kept inside the import guard: it needs ``np``, so defining it outside
    # would raise NameError at import time when numpy is unavailable.
    ref = np.array([[0, 3, 0, 0, 0, 11],
                    [22, 0, 0, 0, 17, 11],
                    [7, 5, 0, 1, 0, 11],
                    [0, 0, 0, 0, 0, 11],
                    [0, 0, 14, 0, 8, 11]])


def assert_equal_ref(mat):
    """Assert that ``mat`` is element-wise equal to the module-level ``ref``."""
    np.testing.assert_array_equal(mat, ref)


def assert_sparse_equal_ref(sparse_mat):
    """Densify ``sparse_mat`` via ``todense()`` and compare it against ``ref``."""
    assert_equal_ref(sparse_mat.todense())


@pytest.requires_eigen_and_numpy
def test_fixed():
    """The ``fixed_*`` producers and passthroughs must all yield ``ref``."""
    from pybind11_tests import (
        fixed_r, fixed_c, fixed_passthrough_r, fixed_passthrough_c)

    assert_equal_ref(fixed_c())
    assert_equal_ref(fixed_r())
    # Every producer/passthrough pairing, including the mixed _r/_c ones.
    assert_equal_ref(fixed_passthrough_r(fixed_r()))
    assert_equal_ref(fixed_passthrough_c(fixed_c()))
    assert_equal_ref(fixed_passthrough_r(fixed_c()))
    assert_equal_ref(fixed_passthrough_c(fixed_r()))


@pytest.requires_eigen_and_numpy
def test_dense():
    """The ``dense_*`` producers and passthroughs must all yield ``ref``."""
    from pybind11_tests import (
        dense_r, dense_c, dense_passthrough_r, dense_passthrough_c)

    assert_equal_ref(dense_r())
    assert_equal_ref(dense_c())
    # Every producer/passthrough pairing, including the mixed _r/_c ones.
    assert_equal_ref(dense_passthrough_r(dense_r()))
    assert_equal_ref(dense_passthrough_c(dense_c()))
    assert_equal_ref(dense_passthrough_r(dense_c()))
    assert_equal_ref(dense_passthrough_c(dense_r()))


@pytest.requires_eigen_and_numpy
def test_nonunit_stride_from_python():
    """Row/column views and 2-D slices of a 3-D array must be doubled."""
    from pybind11_tests import (
        double_row, double_col, double_mat_cm, double_mat_rm)

    counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3))
    first_row = counting_mat[0, :]
    first_col = counting_mat[:, 0]
    assert np.array_equal(double_row(first_row), 2.0 * first_row)
    assert np.array_equal(double_col(first_row), 2.0 * first_row)
    assert np.array_equal(double_row(first_col), 2.0 * first_col)
    assert np.array_equal(double_col(first_col), 2.0 * first_col)

    counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3))
    # One 2-D slice taken along each of the three axes.
    slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
    for ref_mat in slices:
        assert np.array_equal(double_mat_cm(ref_mat), 2.0 * ref_mat)
        assert np.array_equal(double_mat_rm(ref_mat), 2.0 * ref_mat)


@pytest.requires_eigen_and_numpy
def test_nonunit_stride_to_python():
    """Diagonals and sub-blocks returned by the bindings must match numpy."""
    from pybind11_tests import diagonal, diagonal_1, diagonal_n, block

    assert np.all(diagonal(ref) == ref.diagonal())
    assert np.all(diagonal_1(ref) == ref.diagonal(1))
    for i in range(-5, 7):
        assert np.all(diagonal_n(ref, i) == ref.diagonal(i)), "diagonal_n({})".format(i)

    assert np.all(block(ref, 2, 1, 3, 3) == ref[2:5, 1:4])
    assert np.all(block(ref, 1, 4, 4, 2) == ref[1:, 4:])
    assert np.all(block(ref, 1, 4, 3, 2) == ref[1:4, 4:])


@pytest.requires_eigen_and_numpy
def test_eigen_ref_to_python():
    """Each ``cholesky*`` binding must return the expected lower factor."""
    from pybind11_tests import (
        cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6)

    chols = [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]
    # start=1 so the failure message names the matching ``choleskyN`` binding.
    for i, chol in enumerate(chols, start=1):
        mymat = chol(np.array([[1, 2, 4], [2, 13, 23], [4, 23, 77]]))
        assert np.all(mymat == np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])), "cholesky{}".format(i)


@pytest.requires_eigen_and_numpy
def test_special_matrix_objects():
    """Check ``incr_diag`` and the ``symmetric_lower``/``symmetric_upper`` views."""
    from pybind11_tests import incr_diag, symmetric_upper, symmetric_lower

    assert np.all(incr_diag(7) == np.diag([1, 2, 3, 4, 5, 6, 7]))

    asymm = np.array([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12],
                      [13, 14, 15, 16]])
    # Build the expected results on independent copies of ``asymm``.
    symm_lower = np.array(asymm)
    symm_upper = np.array(asymm)
    for i in range(4):
        for j in range(i + 1, 4):
            # Mirror the lower triangle upwards / the upper triangle downwards.
            symm_lower[i, j] = symm_lower[j, i]
            symm_upper[j, i] = symm_upper[i, j]

    assert np.all(symmetric_lower(asymm) == symm_lower)
    assert np.all(symmetric_upper(asymm) == symm_upper)


@pytest.requires_eigen_and_numpy
def test_dense_signature(doc):
    """The dense bindings must expose the expected signature text.

    NOTE(review): ``doc`` is a fixture defined outside this file; presumably
    it normalizes whitespace before comparing -- confirm against conftest.
    """
    from pybind11_tests import double_col, double_row, double_mat_rm

    assert doc(double_col) == """
        double_col(arg0: numpy.ndarray[float32[m, 1]]) -> numpy.ndarray[float32[m, 1]]
    """
    assert doc(double_row) == """
        double_row(arg0: numpy.ndarray[float32[1, n]]) -> numpy.ndarray[float32[1, n]]
    """
    assert doc(double_mat_rm) == """
        double_mat_rm(arg0: numpy.ndarray[float32[m, n]]) -> numpy.ndarray[float32[m, n]]
    """


@pytest.requires_eigen_and_scipy
def test_sparse():
    """The ``sparse_*`` producers and passthroughs must densify to ``ref``."""
    from pybind11_tests import (
        sparse_r, sparse_c, sparse_passthrough_r, sparse_passthrough_c)

    assert_sparse_equal_ref(sparse_r())
    assert_sparse_equal_ref(sparse_c())
    # Every producer/passthrough pairing, including the mixed _r/_c ones.
    assert_sparse_equal_ref(sparse_passthrough_r(sparse_r()))
    assert_sparse_equal_ref(sparse_passthrough_c(sparse_c()))
    assert_sparse_equal_ref(sparse_passthrough_r(sparse_c()))
    assert_sparse_equal_ref(sparse_passthrough_c(sparse_r()))


@pytest.requires_eigen_and_scipy
def test_sparse_signature(doc):
    """The sparse passthroughs must expose the expected signature text.

    NOTE(review): ``doc`` is a fixture defined outside this file; presumably
    it normalizes whitespace before comparing -- confirm against conftest.
    """
    from pybind11_tests import sparse_passthrough_r, sparse_passthrough_c

    assert doc(sparse_passthrough_r) == """
        sparse_passthrough_r(arg0: scipy.sparse.csr_matrix[float32]) -> scipy.sparse.csr_matrix[float32]
    """  # noqa: E501 line too long
    assert doc(sparse_passthrough_c) == """
        sparse_passthrough_c(arg0: scipy.sparse.csc_matrix[float32]) -> scipy.sparse.csc_matrix[float32]
    """  # noqa: E501 line too long