eigen.h revision 11986:c12e4625ab56
1/*
2    pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
3
4    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5
6    All rights reserved. Use of this source code is governed by a
7    BSD-style license that can be found in the LICENSE file.
8*/
9
10#pragma once
11
12#include "numpy.h"
13
14#if defined(__INTEL_COMPILER)
15#  pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
16#elif defined(__GNUG__) || defined(__clang__)
17#  pragma GCC diagnostic push
18#  pragma GCC diagnostic ignored "-Wconversion"
19#  pragma GCC diagnostic ignored "-Wdeprecated-declarations"
20#endif
21
22#include <Eigen/Core>
23#include <Eigen/SparseCore>
24
25#if defined(__GNUG__) || defined(__clang__)
26#  pragma GCC diagnostic pop
27#endif
28
29#if defined(_MSC_VER)
30#pragma warning(push)
31#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
32#endif
33
34NAMESPACE_BEGIN(pybind11)
35NAMESPACE_BEGIN(detail)
36
37template <typename T> using is_eigen_dense = is_template_base_of<Eigen::DenseBase, T>;
38template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
39template <typename T> using is_eigen_ref = is_template_base_of<Eigen::RefBase, T>;
40
41// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above.  This
42// basically covers anything that can be assigned to a dense matrix but that don't have a typical
43// matrix data layout that can be copied from their .data().  For example, DiagonalMatrix and
44// SelfAdjointView fall into this category.
45template <typename T> using is_eigen_base = bool_constant<
46    is_template_base_of<Eigen::EigenBase, T>::value
47    && !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value
48>;
49
50template<typename Type>
51struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> {
52    typedef typename Type::Scalar Scalar;
53    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
54    static constexpr bool isVector = Type::IsVectorAtCompileTime;
55
56    bool load(handle src, bool) {
57        auto buf = array_t<Scalar>::ensure(src);
58        if (!buf)
59            return false;
60
61        if (buf.ndim() == 1) {
62            typedef Eigen::InnerStride<> Strides;
63            if (!isVector &&
64                !(Type::RowsAtCompileTime == Eigen::Dynamic &&
65                  Type::ColsAtCompileTime == Eigen::Dynamic))
66                return false;
67
68            if (Type::SizeAtCompileTime != Eigen::Dynamic &&
69                buf.shape(0) != (size_t) Type::SizeAtCompileTime)
70                return false;
71
72            Strides::Index n_elts = (Strides::Index) buf.shape(0);
73            Strides::Index unity = 1;
74
75            value = Eigen::Map<Type, 0, Strides>(
76                buf.mutable_data(),
77                rowMajor ? unity : n_elts,
78                rowMajor ? n_elts : unity,
79                Strides(buf.strides(0) / sizeof(Scalar))
80            );
81        } else if (buf.ndim() == 2) {
82            typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
83
84            if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
85                (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
86                return false;
87
88            value = Eigen::Map<Type, 0, Strides>(
89                buf.mutable_data(),
90                typename Strides::Index(buf.shape(0)),
91                typename Strides::Index(buf.shape(1)),
92                Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
93                        buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
94            );
95        } else {
96            return false;
97        }
98        return true;
99    }
100
101    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
102        if (isVector) {
103            return array(
104                { (size_t) src.size() },                                      // shape
105                { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) },  // strides
106                src.data()                                                    // data
107            ).release();
108        } else {
109            return array(
110                { (size_t) src.rows(),                                        // shape
111                  (size_t) src.cols() },
112                { sizeof(Scalar) * static_cast<size_t>(src.rowStride()),      // strides
113                  sizeof(Scalar) * static_cast<size_t>(src.colStride()) },
114                src.data()                                                    // data
115            ).release();
116        }
117    }
118
119    PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
120            _("[") + rows() + _(", ") + cols() + _("]]"));
121
122protected:
123    template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0>
124    static PYBIND11_DESCR rows() { return _("m"); }
125    template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0>
126    static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
127    template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0>
128    static PYBIND11_DESCR cols() { return _("n"); }
129    template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0>
130    static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
131};
132
133// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special
134// type_caster to handle argument copying/forwarding.
135template <typename CVDerived, int Options, typename StrideType>
136struct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> {
137protected:
138    using Type = Eigen::Ref<CVDerived, Options, StrideType>;
139    using Derived = typename std::remove_const<CVDerived>::type;
140    using DerivedCaster = type_caster<Derived>;
141    DerivedCaster derived_caster;
142    std::unique_ptr<Type> value;
143public:
144    bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
145    static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
146    static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
147
148    static PYBIND11_DESCR name() { return DerivedCaster::name(); }
149
150    operator Type*() { return value.get(); }
151    operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
152    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
153};
154
155// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
156// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
157template <typename Type>
158struct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> {
159protected:
160    using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
161    using MatrixCaster = type_caster<Matrix>;
162public:
163    [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
164    static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
165    static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }
166
167    static PYBIND11_DESCR name() { return MatrixCaster::name(); }
168
169    [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
170    [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
171    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
172};
173
174template<typename Type>
175struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
176    typedef typename Type::Scalar Scalar;
177    typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
178    typedef typename Type::Index Index;
179    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
180
181    bool load(handle src, bool) {
182        if (!src)
183            return false;
184
185        auto obj = reinterpret_borrow<object>(src);
186        object sparse_module = module::import("scipy.sparse");
187        object matrix_type = sparse_module.attr(
188            rowMajor ? "csr_matrix" : "csc_matrix");
189
190        if (obj.get_type() != matrix_type.ptr()) {
191            try {
192                obj = matrix_type(obj);
193            } catch (const error_already_set &) {
194                return false;
195            }
196        }
197
198        auto values = array_t<Scalar>((object) obj.attr("data"));
199        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
200        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
201        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
202        auto nnz = obj.attr("nnz").cast<Index>();
203
204        if (!values || !innerIndices || !outerIndices)
205            return false;
206
207        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
208            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
209            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
210
211        return true;
212    }
213
214    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
215        const_cast<Type&>(src).makeCompressed();
216
217        object matrix_type = module::import("scipy.sparse").attr(
218            rowMajor ? "csr_matrix" : "csc_matrix");
219
220        array data((size_t) src.nonZeros(), src.valuePtr());
221        array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
222        array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());
223
224        return matrix_type(
225            std::make_tuple(data, innerIndices, outerIndices),
226            std::make_pair(src.rows(), src.cols())
227        ).release();
228    }
229
230    PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
231            + npy_format_descriptor<Scalar>::name() + _("]"));
232};
233
234NAMESPACE_END(detail)
235NAMESPACE_END(pybind11)
236
237#if defined(_MSC_VER)
238#pragma warning(pop)
239#endif
240