Skip to content

Commit

Permalink
breakup value_type and scalar_type for Eigen types based on whether t…
Browse files Browse the repository at this point in the history
…hey have a Scalar type trait
  • Loading branch information
SteveBronder committed Nov 11, 2024
1 parent 3f065fa commit 62094ed
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
52 changes: 50 additions & 2 deletions stan/math/prim/meta/is_eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ template <typename T>
struct is_eigen
: bool_constant<is_base_pointer_convertible<Eigen::EigenBase, T>::value> {};

namespace internal {
// primary template handles types that have no nested ::type member:
template <class, class = void>
struct has_internal_trait : std::false_type {};

// specialization recognizes types that do have a nested ::type member:
template <class T>
struct has_internal_trait<T, std::void_t<Eigen::internal::traits<std::decay_t<T>>>>
: std::true_type {};

// primary template handles types that have no nested ::type member:
template <class, class = void>
struct has_scalar_trait : std::false_type {};

// specialization recognizes types that do have a nested ::type member:
template <class T>
struct has_scalar_trait<T, std::void_t<typename std::decay_t<T>::Scalar>>
: std::true_type {};

} // namespace internal


/**
* Template metaprogram defining the base scalar type of
* values stored in an Eigen matrix.
Expand All @@ -28,7 +50,7 @@ struct is_eigen
* @ingroup type_trait
*/
template <typename T>
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value && internal::has_scalar_trait<T>::value>> {
using type = scalar_type_t<typename std::decay_t<T>::Scalar>;
};

Expand All @@ -40,10 +62,36 @@ struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
* @ingroup type_trait
*/
template <typename T>
struct value_type<T, std::enable_if_t<is_eigen<T>::value>> {
struct value_type<T, std::enable_if_t<is_eigen<T>::value && internal::has_scalar_trait<T>::value>> {
using type = typename std::decay_t<T>::Scalar;
};


/**
* Template metaprogram defining the base scalar type of
* values stored in an Eigen matrix.
*
* @tparam T type to check.
* @ingroup type_trait
*/
template <typename T>
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value && !internal::has_scalar_trait<T>::value>> {
using type = scalar_type_t<typename Eigen::internal::traits<std::decay_t<T>>::Scalar>;
};

/**
* Template metaprogram defining the type of values stored in an
* Eigen matrix, vector, or row vector.
*
* @tparam T type to check
* @ingroup type_trait
*/
template <typename T>
struct value_type<T, std::enable_if_t<is_eigen<T>::value && !internal::has_scalar_trait<T>::value>> {
using type = typename Eigen::internal::traits<std::decay_t<T>>::Scalar;
};


/*! \ingroup require_eigens_types */
/*! \defgroup eigen_types eigen */
/*! \addtogroup eigen_types */
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/core/arena_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ namespace internal {
template <typename T>
struct traits<stan::math::arena_matrix<T>> {
using base = traits<Eigen::Map<T>>;
using Scalar = typename base::Scalar;
using XprKind = typename Eigen::internal::traits<std::decay_t<T>>::XprKind;
enum {
PlainObjectTypeInnerSize = base::PlainObjectTypeInnerSize,
Expand Down

0 comments on commit 62094ed

Please sign in to comment.