eigen.h revision 11986
111986Sandreas.sandberg@arm.com/* 211986Sandreas.sandberg@arm.com pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices 311986Sandreas.sandberg@arm.com 411986Sandreas.sandberg@arm.com Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 511986Sandreas.sandberg@arm.com 611986Sandreas.sandberg@arm.com All rights reserved. Use of this source code is governed by a 711986Sandreas.sandberg@arm.com BSD-style license that can be found in the LICENSE file. 811986Sandreas.sandberg@arm.com*/ 911986Sandreas.sandberg@arm.com 1011986Sandreas.sandberg@arm.com#pragma once 1111986Sandreas.sandberg@arm.com 1211986Sandreas.sandberg@arm.com#include "numpy.h" 1311986Sandreas.sandberg@arm.com 1411986Sandreas.sandberg@arm.com#if defined(__INTEL_COMPILER) 1511986Sandreas.sandberg@arm.com# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) 1611986Sandreas.sandberg@arm.com#elif defined(__GNUG__) || defined(__clang__) 1711986Sandreas.sandberg@arm.com# pragma GCC diagnostic push 1811986Sandreas.sandberg@arm.com# pragma GCC diagnostic ignored "-Wconversion" 1911986Sandreas.sandberg@arm.com# pragma GCC diagnostic ignored "-Wdeprecated-declarations" 2011986Sandreas.sandberg@arm.com#endif 2111986Sandreas.sandberg@arm.com 2211986Sandreas.sandberg@arm.com#include <Eigen/Core> 2311986Sandreas.sandberg@arm.com#include <Eigen/SparseCore> 2411986Sandreas.sandberg@arm.com 2511986Sandreas.sandberg@arm.com#if defined(__GNUG__) || defined(__clang__) 2611986Sandreas.sandberg@arm.com# pragma GCC diagnostic pop 2711986Sandreas.sandberg@arm.com#endif 2811986Sandreas.sandberg@arm.com 2911986Sandreas.sandberg@arm.com#if defined(_MSC_VER) 3011986Sandreas.sandberg@arm.com#pragma warning(push) 3111986Sandreas.sandberg@arm.com#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 3211986Sandreas.sandberg@arm.com#endif 3311986Sandreas.sandberg@arm.com 3411986Sandreas.sandberg@arm.comNAMESPACE_BEGIN(pybind11) 3511986Sandreas.sandberg@arm.comNAMESPACE_BEGIN(detail) 3611986Sandreas.sandberg@arm.com 3711986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_dense = is_template_base_of<Eigen::DenseBase, T>; 3811986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>; 3911986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_ref = is_template_base_of<Eigen::RefBase, T>; 4011986Sandreas.sandberg@arm.com 4111986Sandreas.sandberg@arm.com// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This 4211986Sandreas.sandberg@arm.com// basically covers anything that can be assigned to a dense matrix but that don't have a typical 4311986Sandreas.sandberg@arm.com// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and 4411986Sandreas.sandberg@arm.com// SelfAdjointView fall into this category. 4511986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_base = bool_constant< 4611986Sandreas.sandberg@arm.com is_template_base_of<Eigen::EigenBase, T>::value 4711986Sandreas.sandberg@arm.com && !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value 4811986Sandreas.sandberg@arm.com>; 4911986Sandreas.sandberg@arm.com 5011986Sandreas.sandberg@arm.comtemplate<typename Type> 5111986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> { 5211986Sandreas.sandberg@arm.com typedef typename Type::Scalar Scalar; 5311986Sandreas.sandberg@arm.com static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; 5411986Sandreas.sandberg@arm.com static constexpr bool isVector = Type::IsVectorAtCompileTime; 5511986Sandreas.sandberg@arm.com 5611986Sandreas.sandberg@arm.com bool load(handle src, bool) { 5711986Sandreas.sandberg@arm.com auto buf = array_t<Scalar>::ensure(src); 5811986Sandreas.sandberg@arm.com if (!buf) 5911986Sandreas.sandberg@arm.com return false; 6011986Sandreas.sandberg@arm.com 6111986Sandreas.sandberg@arm.com if (buf.ndim() == 1) { 6211986Sandreas.sandberg@arm.com typedef Eigen::InnerStride<> Strides; 6311986Sandreas.sandberg@arm.com if (!isVector && 6411986Sandreas.sandberg@arm.com !(Type::RowsAtCompileTime == Eigen::Dynamic && 6511986Sandreas.sandberg@arm.com Type::ColsAtCompileTime == Eigen::Dynamic)) 6611986Sandreas.sandberg@arm.com return false; 6711986Sandreas.sandberg@arm.com 6811986Sandreas.sandberg@arm.com if (Type::SizeAtCompileTime != Eigen::Dynamic && 6911986Sandreas.sandberg@arm.com buf.shape(0) != (size_t) Type::SizeAtCompileTime) 7011986Sandreas.sandberg@arm.com return false; 7111986Sandreas.sandberg@arm.com 7211986Sandreas.sandberg@arm.com Strides::Index n_elts = (Strides::Index) buf.shape(0); 7311986Sandreas.sandberg@arm.com Strides::Index unity = 1; 7411986Sandreas.sandberg@arm.com 7511986Sandreas.sandberg@arm.com value = Eigen::Map<Type, 0, Strides>( 7611986Sandreas.sandberg@arm.com buf.mutable_data(), 7711986Sandreas.sandberg@arm.com rowMajor ? unity : n_elts, 7811986Sandreas.sandberg@arm.com rowMajor ? n_elts : unity, 7911986Sandreas.sandberg@arm.com Strides(buf.strides(0) / sizeof(Scalar)) 8011986Sandreas.sandberg@arm.com ); 8111986Sandreas.sandberg@arm.com } else if (buf.ndim() == 2) { 8211986Sandreas.sandberg@arm.com typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides; 8311986Sandreas.sandberg@arm.com 8411986Sandreas.sandberg@arm.com if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || 8511986Sandreas.sandberg@arm.com (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) 8611986Sandreas.sandberg@arm.com return false; 8711986Sandreas.sandberg@arm.com 8811986Sandreas.sandberg@arm.com value = Eigen::Map<Type, 0, Strides>( 8911986Sandreas.sandberg@arm.com buf.mutable_data(), 9011986Sandreas.sandberg@arm.com typename Strides::Index(buf.shape(0)), 9111986Sandreas.sandberg@arm.com typename Strides::Index(buf.shape(1)), 9211986Sandreas.sandberg@arm.com Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar), 9311986Sandreas.sandberg@arm.com buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar)) 9411986Sandreas.sandberg@arm.com ); 9511986Sandreas.sandberg@arm.com } else { 9611986Sandreas.sandberg@arm.com return false; 9711986Sandreas.sandberg@arm.com } 9811986Sandreas.sandberg@arm.com return true; 9911986Sandreas.sandberg@arm.com } 10011986Sandreas.sandberg@arm.com 10111986Sandreas.sandberg@arm.com static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { 10211986Sandreas.sandberg@arm.com if (isVector) { 10311986Sandreas.sandberg@arm.com return array( 10411986Sandreas.sandberg@arm.com { (size_t) src.size() }, // shape 10511986Sandreas.sandberg@arm.com { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) }, // strides 10611986Sandreas.sandberg@arm.com src.data() // data 10711986Sandreas.sandberg@arm.com ).release(); 10811986Sandreas.sandberg@arm.com } else { 10911986Sandreas.sandberg@arm.com return array( 11011986Sandreas.sandberg@arm.com { (size_t) src.rows(), // shape 11111986Sandreas.sandberg@arm.com (size_t) src.cols() }, 11211986Sandreas.sandberg@arm.com { sizeof(Scalar) * static_cast<size_t>(src.rowStride()), // strides 11311986Sandreas.sandberg@arm.com sizeof(Scalar) * static_cast<size_t>(src.colStride()) }, 11411986Sandreas.sandberg@arm.com src.data() // data 11511986Sandreas.sandberg@arm.com ).release(); 11611986Sandreas.sandberg@arm.com } 11711986Sandreas.sandberg@arm.com } 11811986Sandreas.sandberg@arm.com 11911986Sandreas.sandberg@arm.com PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() + 12011986Sandreas.sandberg@arm.com _("[") + rows() + _(", ") + cols() + _("]]")); 12111986Sandreas.sandberg@arm.com 12211986Sandreas.sandberg@arm.comprotected: 12311986Sandreas.sandberg@arm.com template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0> 12411986Sandreas.sandberg@arm.com static PYBIND11_DESCR rows() { return _("m"); } 12511986Sandreas.sandberg@arm.com template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0> 12611986Sandreas.sandberg@arm.com static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); } 12711986Sandreas.sandberg@arm.com template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0> 12811986Sandreas.sandberg@arm.com static PYBIND11_DESCR cols() { return _("n"); } 12911986Sandreas.sandberg@arm.com template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0> 13011986Sandreas.sandberg@arm.com static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); } 13111986Sandreas.sandberg@arm.com}; 13211986Sandreas.sandberg@arm.com 13311986Sandreas.sandberg@arm.com// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special 13411986Sandreas.sandberg@arm.com// type_caster to handle argument copying/forwarding. 13511986Sandreas.sandberg@arm.comtemplate <typename CVDerived, int Options, typename StrideType> 13611986Sandreas.sandberg@arm.comstruct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> { 13711986Sandreas.sandberg@arm.comprotected: 13811986Sandreas.sandberg@arm.com using Type = Eigen::Ref<CVDerived, Options, StrideType>; 13911986Sandreas.sandberg@arm.com using Derived = typename std::remove_const<CVDerived>::type; 14011986Sandreas.sandberg@arm.com using DerivedCaster = type_caster<Derived>; 14111986Sandreas.sandberg@arm.com DerivedCaster derived_caster; 14211986Sandreas.sandberg@arm.com std::unique_ptr<Type> value; 14311986Sandreas.sandberg@arm.compublic: 14411986Sandreas.sandberg@arm.com bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; } 14511986Sandreas.sandberg@arm.com static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); } 14611986Sandreas.sandberg@arm.com static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); } 14711986Sandreas.sandberg@arm.com 14811986Sandreas.sandberg@arm.com static PYBIND11_DESCR name() { return DerivedCaster::name(); } 14911986Sandreas.sandberg@arm.com 15011986Sandreas.sandberg@arm.com operator Type*() { return value.get(); } 15111986Sandreas.sandberg@arm.com operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; } 15211986Sandreas.sandberg@arm.com template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; 15311986Sandreas.sandberg@arm.com}; 15411986Sandreas.sandberg@arm.com 15511986Sandreas.sandberg@arm.com// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can 15611986Sandreas.sandberg@arm.com// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that. 15711986Sandreas.sandberg@arm.comtemplate <typename Type> 15811986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> { 15911986Sandreas.sandberg@arm.comprotected: 16011986Sandreas.sandberg@arm.com using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>; 16111986Sandreas.sandberg@arm.com using MatrixCaster = type_caster<Matrix>; 16211986Sandreas.sandberg@arm.compublic: 16311986Sandreas.sandberg@arm.com [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); } 16411986Sandreas.sandberg@arm.com static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); } 16511986Sandreas.sandberg@arm.com static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); } 16611986Sandreas.sandberg@arm.com 16711986Sandreas.sandberg@arm.com static PYBIND11_DESCR name() { return MatrixCaster::name(); } 16811986Sandreas.sandberg@arm.com 16911986Sandreas.sandberg@arm.com [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); } 17011986Sandreas.sandberg@arm.com [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); } 17111986Sandreas.sandberg@arm.com template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; 17211986Sandreas.sandberg@arm.com}; 17311986Sandreas.sandberg@arm.com 17411986Sandreas.sandberg@arm.comtemplate<typename Type> 17511986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> { 17611986Sandreas.sandberg@arm.com typedef typename Type::Scalar Scalar; 17711986Sandreas.sandberg@arm.com typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex; 17811986Sandreas.sandberg@arm.com typedef typename Type::Index Index; 17911986Sandreas.sandberg@arm.com static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; 18011986Sandreas.sandberg@arm.com 18111986Sandreas.sandberg@arm.com bool load(handle src, bool) { 18211986Sandreas.sandberg@arm.com if (!src) 18311986Sandreas.sandberg@arm.com return false; 18411986Sandreas.sandberg@arm.com 18511986Sandreas.sandberg@arm.com auto obj = reinterpret_borrow<object>(src); 18611986Sandreas.sandberg@arm.com object sparse_module = module::import("scipy.sparse"); 18711986Sandreas.sandberg@arm.com object matrix_type = sparse_module.attr( 18811986Sandreas.sandberg@arm.com rowMajor ? "csr_matrix" : "csc_matrix"); 18911986Sandreas.sandberg@arm.com 19011986Sandreas.sandberg@arm.com if (obj.get_type() != matrix_type.ptr()) { 19111986Sandreas.sandberg@arm.com try { 19211986Sandreas.sandberg@arm.com obj = matrix_type(obj); 19311986Sandreas.sandberg@arm.com } catch (const error_already_set &) { 19411986Sandreas.sandberg@arm.com return false; 19511986Sandreas.sandberg@arm.com } 19611986Sandreas.sandberg@arm.com } 19711986Sandreas.sandberg@arm.com 19811986Sandreas.sandberg@arm.com auto values = array_t<Scalar>((object) obj.attr("data")); 19911986Sandreas.sandberg@arm.com auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices")); 20011986Sandreas.sandberg@arm.com auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr")); 20111986Sandreas.sandberg@arm.com auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); 20211986Sandreas.sandberg@arm.com auto nnz = obj.attr("nnz").cast<Index>(); 20311986Sandreas.sandberg@arm.com 20411986Sandreas.sandberg@arm.com if (!values || !innerIndices || !outerIndices) 20511986Sandreas.sandberg@arm.com return false; 20611986Sandreas.sandberg@arm.com 20711986Sandreas.sandberg@arm.com value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>( 20811986Sandreas.sandberg@arm.com shape[0].cast<Index>(), shape[1].cast<Index>(), nnz, 20911986Sandreas.sandberg@arm.com outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); 21011986Sandreas.sandberg@arm.com 21111986Sandreas.sandberg@arm.com return true; 21211986Sandreas.sandberg@arm.com } 21311986Sandreas.sandberg@arm.com 21411986Sandreas.sandberg@arm.com static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { 21511986Sandreas.sandberg@arm.com const_cast<Type&>(src).makeCompressed(); 21611986Sandreas.sandberg@arm.com 21711986Sandreas.sandberg@arm.com object matrix_type = module::import("scipy.sparse").attr( 21811986Sandreas.sandberg@arm.com rowMajor ? "csr_matrix" : "csc_matrix"); 21911986Sandreas.sandberg@arm.com 22011986Sandreas.sandberg@arm.com array data((size_t) src.nonZeros(), src.valuePtr()); 22111986Sandreas.sandberg@arm.com array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); 22211986Sandreas.sandberg@arm.com array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr()); 22311986Sandreas.sandberg@arm.com 22411986Sandreas.sandberg@arm.com return matrix_type( 22511986Sandreas.sandberg@arm.com std::make_tuple(data, innerIndices, outerIndices), 22611986Sandreas.sandberg@arm.com std::make_pair(src.rows(), src.cols()) 22711986Sandreas.sandberg@arm.com ).release(); 22811986Sandreas.sandberg@arm.com } 22911986Sandreas.sandberg@arm.com 23011986Sandreas.sandberg@arm.com PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") 23111986Sandreas.sandberg@arm.com + npy_format_descriptor<Scalar>::name() + _("]")); 23211986Sandreas.sandberg@arm.com}; 23311986Sandreas.sandberg@arm.com 23411986Sandreas.sandberg@arm.comNAMESPACE_END(detail) 23511986Sandreas.sandberg@arm.comNAMESPACE_END(pybind11) 23611986Sandreas.sandberg@arm.com 23711986Sandreas.sandberg@arm.com#if defined(_MSC_VER) 23811986Sandreas.sandberg@arm.com#pragma warning(pop) 23911986Sandreas.sandberg@arm.com#endif 240