Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CeedBasisApplyAdd #1644

Merged
merged 9 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions backends/cuda-ref/ceed-cuda-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
//------------------------------------------------------------------------------
// Basis apply - tensor
//------------------------------------------------------------------------------
int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, CeedVector v) {
static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
CeedVector u, CeedVector v) {
Ceed ceed;
CeedInt Q_1d, dim;
const CeedInt is_transpose = t_mode == CEED_TRANSPOSE;
Expand All @@ -33,10 +34,11 @@ int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMo
// Get read/write access to u, v
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose) {
if (is_transpose && !apply_add) {
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
Expand Down Expand Up @@ -83,11 +85,23 @@ int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMo
return CEED_ERROR_SUCCESS;
}

// Backend "Apply" for tensor bases: forwards every argument unchanged to
// CeedBasisApplyCore_Cuda with apply_add = false. With the flag unset, the
// core routine obtains v through CeedVectorGetArrayWrite and enters its
// "Clear v for transpose operation" branch when t_mode is CEED_TRANSPOSE.
static int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyCore_Cuda(basis, false, num_elem, t_mode, eval_mode, u, v));
return CEED_ERROR_SUCCESS;
}

// Backend "ApplyAdd" for tensor bases: forwards every argument unchanged to
// CeedBasisApplyCore_Cuda with apply_add = true. With the flag set, the core
// routine obtains v through CeedVectorGetArray (instead of
// CeedVectorGetArrayWrite) and skips its "Clear v for transpose operation"
// branch.
static int CeedBasisApplyAdd_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyCore_Cuda(basis, true, num_elem, t_mode, eval_mode, u, v));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Basis apply - tensor AtPoints
//------------------------------------------------------------------------------
int CeedBasisApplyAtPoints_Cuda(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
CeedVector x_ref, CeedVector u, CeedVector v) {
static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, const CeedInt num_elem, const CeedInt *num_points,
CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
Ceed ceed;
CeedInt Q_1d, dim, max_num_points = num_points[0];
const CeedInt is_transpose = t_mode == CEED_TRANSPOSE;
Expand Down Expand Up @@ -158,10 +172,11 @@ int CeedBasisApplyAtPoints_Cuda(CeedBasis basis, const CeedInt num_elem, const C
CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x));
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose) {
if (is_transpose && !apply_add) {
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
Expand Down Expand Up @@ -200,11 +215,23 @@ int CeedBasisApplyAtPoints_Cuda(CeedBasis basis, const CeedInt num_elem, const C
return CEED_ERROR_SUCCESS;
}

// Backend "ApplyAtPoints" for tensor bases: forwards every argument unchanged
// to CeedBasisApplyAtPointsCore_Cuda with apply_add = false. With the flag
// unset, the core routine obtains v through CeedVectorGetArrayWrite and enters
// its "Clear v for transpose operation" branch when t_mode is CEED_TRANSPOSE.
static int CeedBasisApplyAtPoints_Cuda(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyAtPointsCore_Cuda(basis, false, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
return CEED_ERROR_SUCCESS;
}

// Backend "ApplyAddAtPoints" for tensor bases: forwards every argument
// unchanged to CeedBasisApplyAtPointsCore_Cuda with apply_add = true. With the
// flag set, the core routine obtains v through CeedVectorGetArray (instead of
// CeedVectorGetArrayWrite) and skips its "Clear v for transpose operation"
// branch.
static int CeedBasisApplyAddAtPoints_Cuda(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyAtPointsCore_Cuda(basis, true, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Basis apply - non-tensor
//------------------------------------------------------------------------------
int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
CeedVector v) {
static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
CeedVector u, CeedVector v) {
Ceed ceed;
CeedInt num_nodes, num_qpts;
const CeedInt is_transpose = t_mode == CEED_TRANSPOSE;
Expand All @@ -222,10 +249,11 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
// Get read/write access to u, v
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose) {
if (is_transpose && !apply_add) {
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
Expand Down Expand Up @@ -291,6 +319,18 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
return CEED_ERROR_SUCCESS;
}

// Backend "Apply" for non-tensor bases (registered by the H1, H(div) and
// H(curl) constructors in this file): forwards every argument unchanged to
// CeedBasisApplyNonTensorCore_Cuda with apply_add = false. With the flag
// unset, the core routine obtains v through CeedVectorGetArrayWrite and enters
// its "Clear v for transpose operation" branch when t_mode is CEED_TRANSPOSE.
static int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyNonTensorCore_Cuda(basis, false, num_elem, t_mode, eval_mode, u, v));
return CEED_ERROR_SUCCESS;
}

// Backend "ApplyAdd" for non-tensor bases (registered by the H1, H(div) and
// H(curl) constructors in this file): forwards every argument unchanged to
// CeedBasisApplyNonTensorCore_Cuda with apply_add = true. With the flag set,
// the core routine obtains v through CeedVectorGetArray (instead of
// CeedVectorGetArrayWrite) and skips its "Clear v for transpose operation"
// branch.
static int CeedBasisApplyAddNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
CeedVector v) {
// NOTE(review): CeedCallBackend presumably returns early on a non-success code - confirm against the macro definition
CeedCallBackend(CeedBasisApplyNonTensorCore_Cuda(basis, true, num_elem, t_mode, eval_mode, u, v));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy tensor basis
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -374,7 +414,9 @@ int CeedBasisCreateTensorH1_Cuda(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const

// Register backend functions
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAddAtPoints", CeedBasisApplyAddAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Cuda));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -434,6 +476,7 @@ int CeedBasisCreateH1_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes

// Register backend functions
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -493,6 +536,7 @@ int CeedBasisCreateHdiv_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_nod

// Register backend functions
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -552,6 +596,7 @@ int CeedBasisCreateHcurl_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_no

// Register backend functions
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
return CEED_ERROR_SUCCESS;
}
Expand Down
Loading