eigen.h revision 11986
111986Sandreas.sandberg@arm.com/*
211986Sandreas.sandberg@arm.com    pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
311986Sandreas.sandberg@arm.com
411986Sandreas.sandberg@arm.com    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
511986Sandreas.sandberg@arm.com
611986Sandreas.sandberg@arm.com    All rights reserved. Use of this source code is governed by a
711986Sandreas.sandberg@arm.com    BSD-style license that can be found in the LICENSE file.
811986Sandreas.sandberg@arm.com*/
911986Sandreas.sandberg@arm.com
1011986Sandreas.sandberg@arm.com#pragma once
1111986Sandreas.sandberg@arm.com
1211986Sandreas.sandberg@arm.com#include "numpy.h"
1311986Sandreas.sandberg@arm.com
1411986Sandreas.sandberg@arm.com#if defined(__INTEL_COMPILER)
1511986Sandreas.sandberg@arm.com#  pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
1611986Sandreas.sandberg@arm.com#elif defined(__GNUG__) || defined(__clang__)
1711986Sandreas.sandberg@arm.com#  pragma GCC diagnostic push
1811986Sandreas.sandberg@arm.com#  pragma GCC diagnostic ignored "-Wconversion"
1911986Sandreas.sandberg@arm.com#  pragma GCC diagnostic ignored "-Wdeprecated-declarations"
2011986Sandreas.sandberg@arm.com#endif
2111986Sandreas.sandberg@arm.com
2211986Sandreas.sandberg@arm.com#include <Eigen/Core>
2311986Sandreas.sandberg@arm.com#include <Eigen/SparseCore>
2411986Sandreas.sandberg@arm.com
2511986Sandreas.sandberg@arm.com#if defined(__GNUG__) || defined(__clang__)
2611986Sandreas.sandberg@arm.com#  pragma GCC diagnostic pop
2711986Sandreas.sandberg@arm.com#endif
2811986Sandreas.sandberg@arm.com
2911986Sandreas.sandberg@arm.com#if defined(_MSC_VER)
3011986Sandreas.sandberg@arm.com#pragma warning(push)
3111986Sandreas.sandberg@arm.com#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
3211986Sandreas.sandberg@arm.com#endif
3311986Sandreas.sandberg@arm.com
3411986Sandreas.sandberg@arm.comNAMESPACE_BEGIN(pybind11)
3511986Sandreas.sandberg@arm.comNAMESPACE_BEGIN(detail)
3611986Sandreas.sandberg@arm.com
3711986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_dense = is_template_base_of<Eigen::DenseBase, T>;
3811986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
3911986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_ref = is_template_base_of<Eigen::RefBase, T>;
4011986Sandreas.sandberg@arm.com
4111986Sandreas.sandberg@arm.com// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above.  This
4211986Sandreas.sandberg@arm.com// basically covers anything that can be assigned to a dense matrix but that don't have a typical
4311986Sandreas.sandberg@arm.com// matrix data layout that can be copied from their .data().  For example, DiagonalMatrix and
4411986Sandreas.sandberg@arm.com// SelfAdjointView fall into this category.
4511986Sandreas.sandberg@arm.comtemplate <typename T> using is_eigen_base = bool_constant<
4611986Sandreas.sandberg@arm.com    is_template_base_of<Eigen::EigenBase, T>::value
4711986Sandreas.sandberg@arm.com    && !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value
4811986Sandreas.sandberg@arm.com>;
4911986Sandreas.sandberg@arm.com
5011986Sandreas.sandberg@arm.comtemplate<typename Type>
5111986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> {
5211986Sandreas.sandberg@arm.com    typedef typename Type::Scalar Scalar;
5311986Sandreas.sandberg@arm.com    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
5411986Sandreas.sandberg@arm.com    static constexpr bool isVector = Type::IsVectorAtCompileTime;
5511986Sandreas.sandberg@arm.com
5611986Sandreas.sandberg@arm.com    bool load(handle src, bool) {
5711986Sandreas.sandberg@arm.com        auto buf = array_t<Scalar>::ensure(src);
5811986Sandreas.sandberg@arm.com        if (!buf)
5911986Sandreas.sandberg@arm.com            return false;
6011986Sandreas.sandberg@arm.com
6111986Sandreas.sandberg@arm.com        if (buf.ndim() == 1) {
6211986Sandreas.sandberg@arm.com            typedef Eigen::InnerStride<> Strides;
6311986Sandreas.sandberg@arm.com            if (!isVector &&
6411986Sandreas.sandberg@arm.com                !(Type::RowsAtCompileTime == Eigen::Dynamic &&
6511986Sandreas.sandberg@arm.com                  Type::ColsAtCompileTime == Eigen::Dynamic))
6611986Sandreas.sandberg@arm.com                return false;
6711986Sandreas.sandberg@arm.com
6811986Sandreas.sandberg@arm.com            if (Type::SizeAtCompileTime != Eigen::Dynamic &&
6911986Sandreas.sandberg@arm.com                buf.shape(0) != (size_t) Type::SizeAtCompileTime)
7011986Sandreas.sandberg@arm.com                return false;
7111986Sandreas.sandberg@arm.com
7211986Sandreas.sandberg@arm.com            Strides::Index n_elts = (Strides::Index) buf.shape(0);
7311986Sandreas.sandberg@arm.com            Strides::Index unity = 1;
7411986Sandreas.sandberg@arm.com
7511986Sandreas.sandberg@arm.com            value = Eigen::Map<Type, 0, Strides>(
7611986Sandreas.sandberg@arm.com                buf.mutable_data(),
7711986Sandreas.sandberg@arm.com                rowMajor ? unity : n_elts,
7811986Sandreas.sandberg@arm.com                rowMajor ? n_elts : unity,
7911986Sandreas.sandberg@arm.com                Strides(buf.strides(0) / sizeof(Scalar))
8011986Sandreas.sandberg@arm.com            );
8111986Sandreas.sandberg@arm.com        } else if (buf.ndim() == 2) {
8211986Sandreas.sandberg@arm.com            typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
8311986Sandreas.sandberg@arm.com
8411986Sandreas.sandberg@arm.com            if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
8511986Sandreas.sandberg@arm.com                (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
8611986Sandreas.sandberg@arm.com                return false;
8711986Sandreas.sandberg@arm.com
8811986Sandreas.sandberg@arm.com            value = Eigen::Map<Type, 0, Strides>(
8911986Sandreas.sandberg@arm.com                buf.mutable_data(),
9011986Sandreas.sandberg@arm.com                typename Strides::Index(buf.shape(0)),
9111986Sandreas.sandberg@arm.com                typename Strides::Index(buf.shape(1)),
9211986Sandreas.sandberg@arm.com                Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
9311986Sandreas.sandberg@arm.com                        buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
9411986Sandreas.sandberg@arm.com            );
9511986Sandreas.sandberg@arm.com        } else {
9611986Sandreas.sandberg@arm.com            return false;
9711986Sandreas.sandberg@arm.com        }
9811986Sandreas.sandberg@arm.com        return true;
9911986Sandreas.sandberg@arm.com    }
10011986Sandreas.sandberg@arm.com
10111986Sandreas.sandberg@arm.com    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
10211986Sandreas.sandberg@arm.com        if (isVector) {
10311986Sandreas.sandberg@arm.com            return array(
10411986Sandreas.sandberg@arm.com                { (size_t) src.size() },                                      // shape
10511986Sandreas.sandberg@arm.com                { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) },  // strides
10611986Sandreas.sandberg@arm.com                src.data()                                                    // data
10711986Sandreas.sandberg@arm.com            ).release();
10811986Sandreas.sandberg@arm.com        } else {
10911986Sandreas.sandberg@arm.com            return array(
11011986Sandreas.sandberg@arm.com                { (size_t) src.rows(),                                        // shape
11111986Sandreas.sandberg@arm.com                  (size_t) src.cols() },
11211986Sandreas.sandberg@arm.com                { sizeof(Scalar) * static_cast<size_t>(src.rowStride()),      // strides
11311986Sandreas.sandberg@arm.com                  sizeof(Scalar) * static_cast<size_t>(src.colStride()) },
11411986Sandreas.sandberg@arm.com                src.data()                                                    // data
11511986Sandreas.sandberg@arm.com            ).release();
11611986Sandreas.sandberg@arm.com        }
11711986Sandreas.sandberg@arm.com    }
11811986Sandreas.sandberg@arm.com
11911986Sandreas.sandberg@arm.com    PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
12011986Sandreas.sandberg@arm.com            _("[") + rows() + _(", ") + cols() + _("]]"));
12111986Sandreas.sandberg@arm.com
12211986Sandreas.sandberg@arm.comprotected:
12311986Sandreas.sandberg@arm.com    template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0>
12411986Sandreas.sandberg@arm.com    static PYBIND11_DESCR rows() { return _("m"); }
12511986Sandreas.sandberg@arm.com    template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0>
12611986Sandreas.sandberg@arm.com    static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
12711986Sandreas.sandberg@arm.com    template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0>
12811986Sandreas.sandberg@arm.com    static PYBIND11_DESCR cols() { return _("n"); }
12911986Sandreas.sandberg@arm.com    template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0>
13011986Sandreas.sandberg@arm.com    static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
13111986Sandreas.sandberg@arm.com};
13211986Sandreas.sandberg@arm.com
13311986Sandreas.sandberg@arm.com// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special
13411986Sandreas.sandberg@arm.com// type_caster to handle argument copying/forwarding.
13511986Sandreas.sandberg@arm.comtemplate <typename CVDerived, int Options, typename StrideType>
13611986Sandreas.sandberg@arm.comstruct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> {
13711986Sandreas.sandberg@arm.comprotected:
13811986Sandreas.sandberg@arm.com    using Type = Eigen::Ref<CVDerived, Options, StrideType>;
13911986Sandreas.sandberg@arm.com    using Derived = typename std::remove_const<CVDerived>::type;
14011986Sandreas.sandberg@arm.com    using DerivedCaster = type_caster<Derived>;
14111986Sandreas.sandberg@arm.com    DerivedCaster derived_caster;
14211986Sandreas.sandberg@arm.com    std::unique_ptr<Type> value;
14311986Sandreas.sandberg@arm.compublic:
14411986Sandreas.sandberg@arm.com    bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
14511986Sandreas.sandberg@arm.com    static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
14611986Sandreas.sandberg@arm.com    static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
14711986Sandreas.sandberg@arm.com
14811986Sandreas.sandberg@arm.com    static PYBIND11_DESCR name() { return DerivedCaster::name(); }
14911986Sandreas.sandberg@arm.com
15011986Sandreas.sandberg@arm.com    operator Type*() { return value.get(); }
15111986Sandreas.sandberg@arm.com    operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
15211986Sandreas.sandberg@arm.com    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
15311986Sandreas.sandberg@arm.com};
15411986Sandreas.sandberg@arm.com
15511986Sandreas.sandberg@arm.com// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
15611986Sandreas.sandberg@arm.com// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
15711986Sandreas.sandberg@arm.comtemplate <typename Type>
15811986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> {
15911986Sandreas.sandberg@arm.comprotected:
16011986Sandreas.sandberg@arm.com    using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
16111986Sandreas.sandberg@arm.com    using MatrixCaster = type_caster<Matrix>;
16211986Sandreas.sandberg@arm.compublic:
16311986Sandreas.sandberg@arm.com    [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
16411986Sandreas.sandberg@arm.com    static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
16511986Sandreas.sandberg@arm.com    static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }
16611986Sandreas.sandberg@arm.com
16711986Sandreas.sandberg@arm.com    static PYBIND11_DESCR name() { return MatrixCaster::name(); }
16811986Sandreas.sandberg@arm.com
16911986Sandreas.sandberg@arm.com    [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
17011986Sandreas.sandberg@arm.com    [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
17111986Sandreas.sandberg@arm.com    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
17211986Sandreas.sandberg@arm.com};
17311986Sandreas.sandberg@arm.com
17411986Sandreas.sandberg@arm.comtemplate<typename Type>
17511986Sandreas.sandberg@arm.comstruct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
17611986Sandreas.sandberg@arm.com    typedef typename Type::Scalar Scalar;
17711986Sandreas.sandberg@arm.com    typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
17811986Sandreas.sandberg@arm.com    typedef typename Type::Index Index;
17911986Sandreas.sandberg@arm.com    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
18011986Sandreas.sandberg@arm.com
18111986Sandreas.sandberg@arm.com    bool load(handle src, bool) {
18211986Sandreas.sandberg@arm.com        if (!src)
18311986Sandreas.sandberg@arm.com            return false;
18411986Sandreas.sandberg@arm.com
18511986Sandreas.sandberg@arm.com        auto obj = reinterpret_borrow<object>(src);
18611986Sandreas.sandberg@arm.com        object sparse_module = module::import("scipy.sparse");
18711986Sandreas.sandberg@arm.com        object matrix_type = sparse_module.attr(
18811986Sandreas.sandberg@arm.com            rowMajor ? "csr_matrix" : "csc_matrix");
18911986Sandreas.sandberg@arm.com
19011986Sandreas.sandberg@arm.com        if (obj.get_type() != matrix_type.ptr()) {
19111986Sandreas.sandberg@arm.com            try {
19211986Sandreas.sandberg@arm.com                obj = matrix_type(obj);
19311986Sandreas.sandberg@arm.com            } catch (const error_already_set &) {
19411986Sandreas.sandberg@arm.com                return false;
19511986Sandreas.sandberg@arm.com            }
19611986Sandreas.sandberg@arm.com        }
19711986Sandreas.sandberg@arm.com
19811986Sandreas.sandberg@arm.com        auto values = array_t<Scalar>((object) obj.attr("data"));
19911986Sandreas.sandberg@arm.com        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
20011986Sandreas.sandberg@arm.com        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
20111986Sandreas.sandberg@arm.com        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
20211986Sandreas.sandberg@arm.com        auto nnz = obj.attr("nnz").cast<Index>();
20311986Sandreas.sandberg@arm.com
20411986Sandreas.sandberg@arm.com        if (!values || !innerIndices || !outerIndices)
20511986Sandreas.sandberg@arm.com            return false;
20611986Sandreas.sandberg@arm.com
20711986Sandreas.sandberg@arm.com        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
20811986Sandreas.sandberg@arm.com            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
20911986Sandreas.sandberg@arm.com            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
21011986Sandreas.sandberg@arm.com
21111986Sandreas.sandberg@arm.com        return true;
21211986Sandreas.sandberg@arm.com    }
21311986Sandreas.sandberg@arm.com
21411986Sandreas.sandberg@arm.com    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
21511986Sandreas.sandberg@arm.com        const_cast<Type&>(src).makeCompressed();
21611986Sandreas.sandberg@arm.com
21711986Sandreas.sandberg@arm.com        object matrix_type = module::import("scipy.sparse").attr(
21811986Sandreas.sandberg@arm.com            rowMajor ? "csr_matrix" : "csc_matrix");
21911986Sandreas.sandberg@arm.com
22011986Sandreas.sandberg@arm.com        array data((size_t) src.nonZeros(), src.valuePtr());
22111986Sandreas.sandberg@arm.com        array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
22211986Sandreas.sandberg@arm.com        array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());
22311986Sandreas.sandberg@arm.com
22411986Sandreas.sandberg@arm.com        return matrix_type(
22511986Sandreas.sandberg@arm.com            std::make_tuple(data, innerIndices, outerIndices),
22611986Sandreas.sandberg@arm.com            std::make_pair(src.rows(), src.cols())
22711986Sandreas.sandberg@arm.com        ).release();
22811986Sandreas.sandberg@arm.com    }
22911986Sandreas.sandberg@arm.com
23011986Sandreas.sandberg@arm.com    PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
23111986Sandreas.sandberg@arm.com            + npy_format_descriptor<Scalar>::name() + _("]"));
23211986Sandreas.sandberg@arm.com};
23311986Sandreas.sandberg@arm.com
23411986Sandreas.sandberg@arm.comNAMESPACE_END(detail)
23511986Sandreas.sandberg@arm.comNAMESPACE_END(pybind11)
23611986Sandreas.sandberg@arm.com
23711986Sandreas.sandberg@arm.com#if defined(_MSC_VER)
23811986Sandreas.sandberg@arm.com#pragma warning(pop)
23911986Sandreas.sandberg@arm.com#endif
240