Skip to content

Commit

Permalink
Implemented a2d_tuple, a2d_get, a2d_forward (#122)
Browse files Browse the repository at this point in the history
* use custom implementation of std::forward and call it a2d_forward

* drop std::forward, std::tuple, std::get and implemented a2d_{forward,tuple,get} to facilitate migrating to CUDA

* add A2D_FUNCTION to tuple/forward/get functions
  • Loading branch information
aaronyicongfu authored Nov 22, 2024
1 parent 391daf0 commit b046455
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 62 deletions.
124 changes: 124 additions & 0 deletions include/a2dtuple.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#ifndef A2D_TUPLE_H
#define A2D_TUPLE_H

#include <type_traits>
#include <utility>

#include "a2ddefs.h"

namespace A2D {

template <class _Tp>
A2D_FUNCTION inline constexpr _Tp &&a2d_forward(
typename std::remove_reference<_Tp>::type &__t) noexcept {
return static_cast<_Tp &&>(__t);
}

/*
* The implementation of the tuple below is adopted from the following article:
* https://medium.com/@mortificador/implementing-std-tuple-in-c-17-3cc5c6da7277
*
* This custom tuple is useful when using a2d on CUDA becasue std::tuple is not
* supported
* */
// Actual implementation for a type
template <std::size_t _index, typename T>
class _tuple_impl {
using value_type =
typename std::remove_const<typename std::remove_reference<T>::type>::type;

public:
A2D_FUNCTION _tuple_impl(value_type &v) : val(v) {}

A2D_FUNCTION _tuple_impl(value_type &&v) : val(std::move(v)) {}

A2D_FUNCTION _tuple_impl() : val(T{}) {}

A2D_FUNCTION value_type &a2d_get() { return val; }
A2D_FUNCTION const value_type &a2d_get() const { return val; }

private:
T val;
};

// general template, will be used only when there is no arguments
template <std::size_t _index, typename... types>
class _tuple_recurr_base {};

// This is a partial specialization, so as long as there is at least one
// argument this specialization is preferred to the
// _tuple_recurr_base<std::size_t, typename ...types>
template <std::size_t _index, typename L, typename... types>
class _tuple_recurr_base<_index, L, types...>
// : public _tuple_impl<_index, typename std::remove_reference<L>::type>,
: public _tuple_impl<_index, L>,
public _tuple_recurr_base<_index + 1, types...> {
public:
// Default Constructor that takes in no objects
A2D_FUNCTION _tuple_recurr_base()
// : _tuple_impl<_index, typename std::remove_reference<L>::type>(),
: _tuple_impl<_index, L>(), _tuple_recurr_base<_index + 1, types...>() {}

template <typename CL, typename... CArgs>
A2D_FUNCTION _tuple_recurr_base(CL &&arg, CArgs &&...args)
// : _tuple_impl<_index, typename std::remove_reference<CL>::type>(
: _tuple_impl<_index,
typename std::conditional<
std::is_reference<L>::value, CL,
typename std::remove_reference<CL>::type>::type>(arg),
_tuple_recurr_base<_index + 1, types...>(args...) {}
};

template <typename L, typename... types>
class a2d_tuple : public _tuple_recurr_base<0, L, types...> {
public:
// Default Constructor that takes in no objects
A2D_FUNCTION a2d_tuple() : _tuple_recurr_base<0, L, types...>() {}

// The constructor uses the same recursion as the inheritance
template <typename... CArgs>
A2D_FUNCTION a2d_tuple(CArgs &&...args)
: _tuple_recurr_base<0, L, types...>(a2d_forward<CArgs>(args)...) {}
//
};

// template deduction guideline
template <typename... CArgs>
a2d_tuple(CArgs... args) -> a2d_tuple<CArgs...>;

// extract_type_at is a class that, given a list of types and an index, defines
// a type member with the type of the index given from the list (zero based
// index). E.g. extract<1, int, double, float>::type == double For this we
// define ::type recursively, until we hit index zero, at that point there is a
// specialization that defines the member ::type, and stops the recursion
template <std::size_t index, typename L, typename... Args>
struct extract_type_at {
using type = typename extract_type_at<index - 1, Args...>::type;
};

// This is the stop type. If the index is zero, we define the member type to be
// the correspondent type
template <typename L, typename... Args>
struct extract_type_at<0, L, Args...> {
using type = L;
};

// Method to get the value of a tuple, given an index
// We cast the tuple to the base class that corresponds to the index
// and type for that index
template <std::size_t index, typename... Args>
A2D_FUNCTION auto &a2d_get(a2d_tuple<Args...> &t) {
return (static_cast<_tuple_impl<
index, typename extract_type_at<index, Args...>::type> &>(t))
.a2d_get();
}
template <std::size_t index, typename... Args>
A2D_FUNCTION const auto &a2d_get(const a2d_tuple<Args...> &t) {
return (static_cast<const _tuple_impl<
index, typename extract_type_at<index, Args...>::type> &>(t))
.a2d_get();
}

} // namespace A2D

#endif // A2D_TUPLE_H
8 changes: 4 additions & 4 deletions include/ad/a2dscalarops.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ template <class Expr, class T>
class EvalExpr {
public:
A2D_FUNCTION EvalExpr(Expr&& expr, ADObj<T>& out)
: expr(std::forward<Expr>(expr)), out(out) {}
: expr(a2d_forward<Expr>(expr)), out(out) {}

A2D_FUNCTION void eval() {
expr.eval();
Expand Down Expand Up @@ -41,14 +41,14 @@ class EvalExpr {

template <class Expr, class T>
auto Eval(Expr&& expr, ADObj<T>& out) {
return EvalExpr<Expr, T>(std::forward<Expr>(expr), out);
return EvalExpr<Expr, T>(a2d_forward<Expr>(expr), out);
}

template <class Expr, class T>
class EvalExpr2 {
public:
A2D_FUNCTION EvalExpr2(Expr&& expr, A2DObj<T>& out)
: expr(std::forward<Expr>(expr)), out(out) {}
: expr(a2d_forward<Expr>(expr)), out(out) {}

A2D_FUNCTION void eval() {
expr.eval();
Expand Down Expand Up @@ -85,7 +85,7 @@ class EvalExpr2 {

template <class Expr, class T>
auto Eval(Expr&& expr, A2DObj<T>& out) {
return EvalExpr2<Expr, T>(std::forward<Expr>(expr), out);
return EvalExpr2<Expr, T>(a2d_forward<Expr>(expr), out);
}

namespace Test {
Expand Down
25 changes: 12 additions & 13 deletions include/ad/a2dstack.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
#define A2D_STACK_H

#include "../a2ddefs.h"
#include "../a2dtuple.h"
#include "a2dobj.h"
#include "a2dtuple.h"

namespace A2D {

template <class... Operations>
class OperationStack {
public:
using StackTuple = std::tuple<Operations...>;
using StackTuple = a2d_tuple<Operations...>;
static constexpr index_t num_ops = sizeof...(Operations);

A2D_FUNCTION OperationStack(Operations &&...s)
: stack(std::forward<Operations>(s)...) {
// printf("in stack constructor\n");
: stack(a2d_forward<Operations>(s)...) {
eval_<0>();
}

Expand Down Expand Up @@ -67,56 +68,55 @@ class OperationStack {

template <index_t index>
A2D_FUNCTION void eval_() {
// printf("evaluating the stack\n");
std::get<index>(stack).eval();
a2d_get<index>(stack).eval();
if constexpr (index < num_ops - 1) {
eval_<index + 1>();
}
}

template <index_t index>
A2D_FUNCTION void bzero_() {
std::get<index>(stack).bzero();
a2d_get<index>(stack).bzero();
if constexpr (index < num_ops - 1) {
bzero_<index + 1>();
}
}

template <index_t index>
A2D_FUNCTION void forward_() {
std::get<index>(stack).template forward<ADorder::FIRST>();
a2d_get<index>(stack).template forward<ADorder::FIRST>();
if constexpr (index < num_ops - 1) {
forward_<index + 1>();
}
}

template <index_t index>
A2D_FUNCTION void reverse_() {
std::get<index>(stack).reverse();
a2d_get<index>(stack).reverse();
if constexpr (index) {
reverse_<index - 1>();
}
}

template <index_t index>
A2D_FUNCTION void hzero_() {
std::get<index>(stack).hzero();
a2d_get<index>(stack).hzero();
if constexpr (index < num_ops - 1) {
hzero_<index + 1>();
}
}

template <index_t index>
A2D_FUNCTION void hforward_() {
std::get<index>(stack).template forward<ADorder::SECOND>();
a2d_get<index>(stack).template forward<ADorder::SECOND>();
if constexpr (index < num_ops - 1) {
hforward_<index + 1>();
}
}

template <index_t index>
A2D_FUNCTION void hreverse_() {
std::get<index>(stack).hreverse();
a2d_get<index>(stack).hreverse();
if constexpr (index) {
hreverse_<index - 1>();
}
Expand All @@ -134,8 +134,7 @@ class OperationStack {
*/
template <class... Operations>
A2D_FUNCTION auto MakeStack(Operations &&...s) {
// printf("in make stack\n");
return OperationStack<Operations...>(std::forward<Operations>(s)...);
return OperationStack<Operations...>(a2d_forward<Operations>(s)...);
}

/**
Expand Down
46 changes: 27 additions & 19 deletions include/ad/a2dvartuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <type_traits>

#include "../a2ddefs.h"
#include "../a2dtuple.h"
#include "a2dobj.h"

namespace A2D {

Expand Down Expand Up @@ -59,19 +61,19 @@ class VarTupleBase {
A2D_FUNCTION T& get_value_(TupleObj& var, const I comp) {
if constexpr (__is_scalar_type<First>::value) {
if (comp == 0) {
return std::get<index>(var);
return a2d_get<index>(var);
} else if constexpr (sizeof...(Remain) == 0) {
return std::get<index>(var);
return a2d_get<index>(var);
} else {
return get_value_<I, index + 1, TupleObj, Remain...>(var, comp - 1);
}
} else {
if constexpr (sizeof...(Remain) == 0) {
return std::get<index>(var)[comp];
return a2d_get<index>(var)[comp];
} else {
if constexpr (First::ncomp > 0) {
if (comp < First::ncomp) {
return std::get<index>(var)[comp];
return a2d_get<index>(var)[comp];
} else {
return get_value_<I, index + 1, TupleObj, Remain...>(
var, comp - First::ncomp);
Expand All @@ -89,20 +91,20 @@ class VarTupleBase {
const I comp) const {
if constexpr (__is_scalar_type<First>::value) {
if (comp == 0) {
return std::get<index>(var);
return a2d_get<index>(var);
} else if constexpr (sizeof...(Remain) == 0) {
return std::get<index>(var);
return a2d_get<index>(var);
} else {
return get_value_const_<I, index + 1, TupleObj, Remain...>(var,
comp - 1);
}
} else {
if constexpr (sizeof...(Remain) == 0) {
return std::get<index>(var)[comp];
return a2d_get<index>(var)[comp];
} else {
if constexpr (First::ncomp > 0) {
if (comp < First::ncomp) {
return std::get<index>(var)[comp];
return a2d_get<index>(var)[comp];
} else {
return get_value_const_<I, index + 1, TupleObj, Remain...>(
var, comp - First::ncomp);
Expand All @@ -118,9 +120,9 @@ class VarTupleBase {
A2D_FUNCTION void set_values_(TupleObj& var, const First& f,
const Remain&... r) {
if constexpr (__is_scalar_type<First>::value) {
std::get<index>(var) = f;
a2d_get<index>(var) = f;
} else if constexpr (First::ncomp > 0) {
First& val = std::get<index>(var);
First& val = a2d_get<index>(var);
for (index_t i = 0; i < First::ncomp; i++) {
val[i] = f[i];
}
Expand All @@ -134,9 +136,9 @@ class VarTupleBase {
A2D_FUNCTION void get_values_(const TupleObj& var, First& f,
Remain&... r) const {
if constexpr (__is_scalar_type<First>::value) {
f = std::get<index>(var);
f = a2d_get<index>(var);
} else if constexpr (First::ncomp > 0) {
const First& val = std::get<index>(var);
const First& val = a2d_get<index>(var);
for (index_t i = 0; i < First::ncomp; i++) {
f[i] = val[i];
}
Expand All @@ -149,9 +151,9 @@ class VarTupleBase {
template <index_t index, class TupleObj, class First, class... Remain>
A2D_FUNCTION void zero_(TupleObj& var) {
if constexpr (__is_scalar_type<First>::value) {
std::get<index>(var) = T(0.0);
a2d_get<index>(var) = T(0.0);
} else if constexpr (First::ncomp > 0) {
std::get<index>(var).zero();
a2d_get<index>(var).zero();
}
if constexpr (sizeof...(Remain) > 0) {
zero_<index + 1, TupleObj, Remain...>(var);
Expand All @@ -161,10 +163,10 @@ class VarTupleBase {
template <index_t index, class TupleObj, class First, class... Remain>
A2D_FUNCTION void set_rand_(TupleObj& var, const T low, const T high) {
if constexpr (__is_scalar_type<First>::value) {
std::get<index>(var) =
a2d_get<index>(var) =
low + (high - low) * (static_cast<double>(rand()) / RAND_MAX);
} else if constexpr (First::ncomp > 0) {
First& val = std::get<index>(var);
First& val = a2d_get<index>(var);
for (index_t i = 0; i < First::ncomp; i++) {
val[i] = low + (high - low) * (static_cast<double>(rand()) / RAND_MAX);
}
Expand All @@ -178,9 +180,12 @@ class VarTupleBase {
template <typename T, class... Vars>
class VarTuple : public VarTupleBase<T, Vars...> {
public:
using VarTupleObj = std::tuple<Vars...>;
using VarTupleObj = a2d_tuple<Vars...>;

// Default constructor that takes in no arguments
A2D_FUNCTION VarTuple() {}

// A2D_FUNCTION VarTuple() {}
A2D_FUNCTION VarTuple(const Vars&... s) {
this->template set_values_<0, VarTupleObj, Vars...>(var, s...);
}
Expand Down Expand Up @@ -238,7 +243,10 @@ A2D_FUNCTION auto MakeVarTuple(Vars&... s) {
template <typename T, class... Vars>
class TieTuple : public VarTupleBase<T, Vars...> {
public:
using VarTupleObj = std::tuple<Vars&...>;
using VarTupleObj = a2d_tuple<Vars&...>;

// Default constructor that takes in no arguments
A2D_FUNCTION TieTuple() {}

A2D_FUNCTION TieTuple(Vars&... s) : var(s...) {}

Expand Down Expand Up @@ -299,4 +307,4 @@ A2D_FUNCTION auto MakeTieTuple(Vars&... s) {

} // namespace A2D

#endif // A2D_VAR_TUPLE_H
#endif // A2D_VAR_TUPLE_H
Loading

0 comments on commit b046455

Please sign in to comment.