// eigen.h revision 11986:c12e4625ab56
1/* 2 pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices 3 4 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 6 All rights reserved. Use of this source code is governed by a 7 BSD-style license that can be found in the LICENSE file. 8*/ 9 10#pragma once 11 12#include "numpy.h" 13 14#if defined(__INTEL_COMPILER) 15# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) 16#elif defined(__GNUG__) || defined(__clang__) 17# pragma GCC diagnostic push 18# pragma GCC diagnostic ignored "-Wconversion" 19# pragma GCC diagnostic ignored "-Wdeprecated-declarations" 20#endif 21 22#include <Eigen/Core> 23#include <Eigen/SparseCore> 24 25#if defined(__GNUG__) || defined(__clang__) 26# pragma GCC diagnostic pop 27#endif 28 29#if defined(_MSC_VER) 30#pragma warning(push) 31#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 32#endif 33 34NAMESPACE_BEGIN(pybind11) 35NAMESPACE_BEGIN(detail) 36 37template <typename T> using is_eigen_dense = is_template_base_of<Eigen::DenseBase, T>; 38template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>; 39template <typename T> using is_eigen_ref = is_template_base_of<Eigen::RefBase, T>; 40 41// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This 42// basically covers anything that can be assigned to a dense matrix but that don't have a typical 43// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and 44// SelfAdjointView fall into this category. 
45template <typename T> using is_eigen_base = bool_constant< 46 is_template_base_of<Eigen::EigenBase, T>::value 47 && !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value 48>; 49 50template<typename Type> 51struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> { 52 typedef typename Type::Scalar Scalar; 53 static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; 54 static constexpr bool isVector = Type::IsVectorAtCompileTime; 55 56 bool load(handle src, bool) { 57 auto buf = array_t<Scalar>::ensure(src); 58 if (!buf) 59 return false; 60 61 if (buf.ndim() == 1) { 62 typedef Eigen::InnerStride<> Strides; 63 if (!isVector && 64 !(Type::RowsAtCompileTime == Eigen::Dynamic && 65 Type::ColsAtCompileTime == Eigen::Dynamic)) 66 return false; 67 68 if (Type::SizeAtCompileTime != Eigen::Dynamic && 69 buf.shape(0) != (size_t) Type::SizeAtCompileTime) 70 return false; 71 72 Strides::Index n_elts = (Strides::Index) buf.shape(0); 73 Strides::Index unity = 1; 74 75 value = Eigen::Map<Type, 0, Strides>( 76 buf.mutable_data(), 77 rowMajor ? unity : n_elts, 78 rowMajor ? n_elts : unity, 79 Strides(buf.strides(0) / sizeof(Scalar)) 80 ); 81 } else if (buf.ndim() == 2) { 82 typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides; 83 84 if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || 85 (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) 86 return false; 87 88 value = Eigen::Map<Type, 0, Strides>( 89 buf.mutable_data(), 90 typename Strides::Index(buf.shape(0)), 91 typename Strides::Index(buf.shape(1)), 92 Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar), 93 buf.strides(rowMajor ? 
1 : 0) / sizeof(Scalar)) 94 ); 95 } else { 96 return false; 97 } 98 return true; 99 } 100 101 static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { 102 if (isVector) { 103 return array( 104 { (size_t) src.size() }, // shape 105 { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) }, // strides 106 src.data() // data 107 ).release(); 108 } else { 109 return array( 110 { (size_t) src.rows(), // shape 111 (size_t) src.cols() }, 112 { sizeof(Scalar) * static_cast<size_t>(src.rowStride()), // strides 113 sizeof(Scalar) * static_cast<size_t>(src.colStride()) }, 114 src.data() // data 115 ).release(); 116 } 117 } 118 119 PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() + 120 _("[") + rows() + _(", ") + cols() + _("]]")); 121 122protected: 123 template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0> 124 static PYBIND11_DESCR rows() { return _("m"); } 125 template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0> 126 static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); } 127 template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0> 128 static PYBIND11_DESCR cols() { return _("n"); } 129 template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0> 130 static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); } 131}; 132 133// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special 134// type_caster to handle argument copying/forwarding. 
135template <typename CVDerived, int Options, typename StrideType> 136struct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> { 137protected: 138 using Type = Eigen::Ref<CVDerived, Options, StrideType>; 139 using Derived = typename std::remove_const<CVDerived>::type; 140 using DerivedCaster = type_caster<Derived>; 141 DerivedCaster derived_caster; 142 std::unique_ptr<Type> value; 143public: 144 bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; } 145 static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); } 146 static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); } 147 148 static PYBIND11_DESCR name() { return DerivedCaster::name(); } 149 150 operator Type*() { return value.get(); } 151 operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; } 152 template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; 153}; 154 155// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can 156// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that. 
157template <typename Type> 158struct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> { 159protected: 160 using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>; 161 using MatrixCaster = type_caster<Matrix>; 162public: 163 [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); } 164 static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); } 165 static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); } 166 167 static PYBIND11_DESCR name() { return MatrixCaster::name(); } 168 169 [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); } 170 [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); } 171 template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; 172}; 173 174template<typename Type> 175struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> { 176 typedef typename Type::Scalar Scalar; 177 typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex; 178 typedef typename Type::Index Index; 179 static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; 180 181 bool load(handle src, bool) { 182 if (!src) 183 return false; 184 185 auto obj = reinterpret_borrow<object>(src); 186 object sparse_module = module::import("scipy.sparse"); 187 object matrix_type = sparse_module.attr( 188 rowMajor ? 
"csr_matrix" : "csc_matrix"); 189 190 if (obj.get_type() != matrix_type.ptr()) { 191 try { 192 obj = matrix_type(obj); 193 } catch (const error_already_set &) { 194 return false; 195 } 196 } 197 198 auto values = array_t<Scalar>((object) obj.attr("data")); 199 auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices")); 200 auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr")); 201 auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); 202 auto nnz = obj.attr("nnz").cast<Index>(); 203 204 if (!values || !innerIndices || !outerIndices) 205 return false; 206 207 value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>( 208 shape[0].cast<Index>(), shape[1].cast<Index>(), nnz, 209 outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); 210 211 return true; 212 } 213 214 static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { 215 const_cast<Type&>(src).makeCompressed(); 216 217 object matrix_type = module::import("scipy.sparse").attr( 218 rowMajor ? "csr_matrix" : "csc_matrix"); 219 220 array data((size_t) src.nonZeros(), src.valuePtr()); 221 array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); 222 array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr()); 223 224 return matrix_type( 225 std::make_tuple(data, innerIndices, outerIndices), 226 std::make_pair(src.rows(), src.cols()) 227 ).release(); 228 } 229 230 PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") 231 + npy_format_descriptor<Scalar>::name() + _("]")); 232}; 233 234NAMESPACE_END(detail) 235NAMESPACE_END(pybind11) 236 237#if defined(_MSC_VER) 238#pragma warning(pop) 239#endif 240