Skip to content

Commit

Permalink
Merge pull request #284 from artivis/fix/cast
Browse files Browse the repository at this point in the history
Fix casting float->double
  • Loading branch information
artivis authored Mar 9, 2024
2 parents 66408f7 + afe43aa commit 974a241
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 14 deletions.
35 changes: 35 additions & 0 deletions include/manif/impl/cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef _MANIF_MANIF_IMPL_CAST_H_
#define _MANIF_MANIF_IMPL_CAST_H_

namespace manif {
namespace internal {

template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl {
template <typename T>
static auto run(const T& o) -> typename T::template LieGroupTemplate<NewScalar> {
return typename T::template LieGroupTemplate<NewScalar>(
o.coeffs().template cast<NewScalar>()
);
}
};

template <typename Derived, typename NewScalar>
struct CastEvaluator : CastEvaluatorImpl<Derived, NewScalar> {
using Base = CastEvaluatorImpl<Derived, NewScalar>;

CastEvaluator(const Derived& xptr) : xptr_(xptr) {}

auto run() const -> typename Derived::template LieGroupTemplate<NewScalar> {
return Base::run(xptr_);
}

protected:

const Derived& xptr_;
};

} // namespace internal
} // namespace manif

#endif // _MANIF_MANIF_IMPL_CAST_H_
5 changes: 4 additions & 1 deletion include/manif/impl/lie_group_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "manif/impl/eigen.h"
#include "manif/impl/tangent_base.h"
#include "manif/impl/assignment_assert.h"
#include "manif/impl/cast.h"

#include "manif/constants.h"

Expand Down Expand Up @@ -415,7 +416,9 @@ template <class _NewScalar>
typename LieGroupBase<_Derived>::template LieGroupTemplate<_NewScalar>
LieGroupBase<_Derived>::cast() const
{
return LieGroupTemplate<_NewScalar>(coeffs().template cast<_NewScalar>());
return internal::CastEvaluator<
typename internal::traits<_Derived>::Base, _NewScalar
>(derived()).run();
}

template <typename _Derived>
Expand Down
11 changes: 11 additions & 0 deletions include/manif/impl/se2/SE2_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,17 @@ struct AssignmentEvaluatorImpl<SE2Base<Derived>>
}
};

//! @brief Cast specialization for SE2Base objects.
template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl<SE2Base<Derived>, NewScalar> {
template <typename T>
static auto run(const T& o) -> typename Derived::template LieGroupTemplate<NewScalar> {
return typename Derived::template LieGroupTemplate<NewScalar>(
NewScalar(o.x()), NewScalar(o.y()), NewScalar(o.angle())
);
}
};

} /* namespace internal */
} /* namespace manif */

Expand Down
16 changes: 15 additions & 1 deletion include/manif/impl/se3/SE3_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ struct RandomEvaluatorImpl<SE3Base<Derived>>
}
};

//! @brief Assignment assert specialization for SE2Base objects
//! @brief Assignment assert specialization for SE3Base objects
template <typename Derived>
struct AssignmentEvaluatorImpl<SE3Base<Derived>>
{
Expand All @@ -461,6 +461,20 @@ struct AssignmentEvaluatorImpl<SE3Base<Derived>>
}
};

//! @brief Cast specialization for SE3Base objects.
template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl<SE3Base<Derived>, NewScalar> {
template <typename T>
static auto run(const T& o) -> typename Derived::template LieGroupTemplate<NewScalar> {
const typename SE3Base<Derived>::QuaternionDataType q = o.quat();
const typename SE3Base<Derived>::Translation t = o.translation();

return typename Derived::template LieGroupTemplate<NewScalar>(
t.template cast<NewScalar>(), q.template cast<NewScalar>().normalized()
);
}
};

} /* namespace internal */
} /* namespace manif */

Expand Down
17 changes: 17 additions & 0 deletions include/manif/impl/se_2_3/SE_2_3_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,23 @@ struct AssignmentEvaluatorImpl<SE_2_3Base<Derived>>
}
};

//! @brief Cast specialization for SE_2_3Base objects.
template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl<SE_2_3Base<Derived>, NewScalar> {
template <typename T>
static auto run(const T& o) -> typename Derived::template LieGroupTemplate<NewScalar> {
const typename SE_2_3Base<Derived>::QuaternionDataType q = o.quat();
const typename SE_2_3Base<Derived>::Translation t = o.translation();
const typename SE_2_3Base<Derived>::LinearVelocity v = o.linearVelocity();

return typename Derived::template LieGroupTemplate<NewScalar>(
t.template cast<NewScalar>(),
q.template cast<NewScalar>().normalized(),
v.template cast<NewScalar>()
);
}
};

} /* namespace internal */
} /* namespace manif */

Expand Down
9 changes: 9 additions & 0 deletions include/manif/impl/so2/SO2_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,15 @@ struct AssignmentEvaluatorImpl<SO2Base<Derived>>
}
};

//! @brief Cast specialization for SO2Base objects.
template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl<SO2Base<Derived>, NewScalar> {
template <typename T>
static auto run(const T& o) -> typename Derived::template LieGroupTemplate<NewScalar> {
return typename Derived::template LieGroupTemplate<NewScalar>(NewScalar(o.angle()));
}
};

} /* namespace internal */
} /* namespace manif */

Expand Down
15 changes: 14 additions & 1 deletion include/manif/impl/so3/SO3_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ struct RandomEvaluatorImpl<SO3Base<Derived>>
}
};

//! @brief Assignment assert specialization for SE2Base objects
//! @brief Assignment assert specialization for SO3Base objects
template <typename Derived>
struct AssignmentEvaluatorImpl<SO3Base<Derived>>
{
Expand All @@ -421,6 +421,19 @@ struct AssignmentEvaluatorImpl<SO3Base<Derived>>
}
};

//! @brief Cast specialization for SO3Base objects.
template <typename Derived, typename NewScalar>
struct CastEvaluatorImpl<SO3Base<Derived>, NewScalar> {
template <typename T>
static auto run(const T& o) -> typename Derived::template LieGroupTemplate<NewScalar> {
const typename SO3Base<Derived>::QuaternionDataType q = o.quat();

return typename Derived::template LieGroupTemplate<NewScalar>(
q.template cast<NewScalar>().normalized()
);
}
};

} /* namespace internal */
} /* namespace manif */

Expand Down
25 changes: 22 additions & 3 deletions test/common_tester.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@
TEST_P(TEST_##manifold##_TESTER, TEST_##manifold##_SMALL_ADJ) \
{ evalSmallAdj(); } \
TEST_F(TEST_##manifold##_TESTER, TEST_##manifold##_IDENTITY_ACT_POINT) \
{ evalIdentityActPoint(); }
{ evalIdentityActPoint(); } \
TEST_P(TEST_##manifold##_TESTER, TEST_##manifold##_CAST) \
{ evalCast(); }

#define MANIF_TEST_JACOBIANS(manifold) \
using manifold##JacobiansTester = JacobianTester<manifold>; \
Expand Down Expand Up @@ -703,6 +705,23 @@ class CommonTester
EXPECT_EIGEN_NEAR(pin, pout);
}

void evalCast() {
using NewScalar = typename std::conditional<
std::is_same<Scalar, float>::value, double, float
>::type;

EXPECT_NO_THROW(
auto state = getState().template cast<NewScalar>();
);

int i=0;
EXPECT_NO_THROW(
for (; i < 10000; ++i) {
auto state = LieGroup::Random().template cast<NewScalar>();
}
) << "+= failed at iteration " << i;
}

protected:

// relax eps for float type
Expand Down Expand Up @@ -1044,8 +1063,8 @@ class JacobianTester
Jrinv = tan.rjacinv();
Jlinv = tan.ljacinv();

EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jr*Jrinv);
EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jl*Jlinv);
EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jr*Jrinv, tol_);
EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jl*Jlinv, tol_);
}

void evalActJac()
Expand Down
4 changes: 2 additions & 2 deletions test/se2/gtest_se2_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ TEST(TEST_SE2, TEST_SE2_MAP_CAST)

EXPECT_DOUBLE_EQ(4, se2d.x());
EXPECT_DOUBLE_EQ(2, se2d.y());
EXPECT_DOUBLE_EQ(MANIF_PI, se2d.angle());
EXPECT_DOUBLE_EQ(MANIF_PI, std::abs(se2d.angle()));

SE2f se2f = se2d.cast<float>();

EXPECT_FLOAT_EQ(4, se2f.x());
EXPECT_FLOAT_EQ(2, se2f.y());
EXPECT_FLOAT_EQ(MANIF_PI, se2f.angle());
EXPECT_FLOAT_EQ(MANIF_PI, std::abs(se2f.angle()));
}

TEST(TEST_SE2, TEST_SE2_MAP_IDENTITY)
Expand Down
6 changes: 0 additions & 6 deletions test/so3/gtest_so3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,6 @@ TEST(TEST_SO3, TEST_SO3_NORMALIZE)

#endif

MANIF_TEST(SO3f);

MANIF_TEST_MAP(SO3f);

MANIF_TEST_JACOBIANS(SO3f);

MANIF_TEST(SO3d);

MANIF_TEST_MAP(SO3d);
Expand Down

0 comments on commit 974a241

Please sign in to comment.