Skip to content

Commit

Permalink
wip
Browse files: browse the repository at this point in the history
  • Loading branch information
jeremylt committed Oct 22, 2024
1 parent e036be4 commit 0418c16
Show file tree
Hide file tree
Showing 59 changed files with 520 additions and 420 deletions.
8 changes: 5 additions & 3 deletions backends/blocked/ceed-blocked-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
if (ceed_parent) ceed = ceed_parent;
if (ceed_parent) CeedCall(CeedReferenceCopy(ceed_parent, &ceed));
}
if (is_input) {
CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
Expand Down Expand Up @@ -105,6 +105,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
// Empty case - won't occur
break;
}
CeedCallBackend(CeedDestroy(&ceed_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
}
Expand Down Expand Up @@ -190,6 +191,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
}
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -198,7 +200,6 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
//------------------------------------------------------------------------------
static int CeedOperatorSetup_Blocked(CeedOperator op) {
bool is_setup_done;
Ceed ceed;
CeedInt Q, num_input_fields, num_output_fields;
const CeedInt block_size = 8;
CeedQFunctionField *qf_input_fields, *qf_output_fields;
Expand All @@ -209,7 +210,6 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) {
CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
if (is_setup_done) return CEED_ERROR_SUCCESS;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
Expand Down Expand Up @@ -707,6 +707,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl));

// Output blocked restriction
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedVectorRestoreArray(l_vec, &l_vec_array));
CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
CeedCallBackend(CeedElemRestrictionApply(block_rstr, CEED_TRANSPOSE, l_vec, *assembled, request));
Expand Down Expand Up @@ -783,6 +784,7 @@ int CeedOperatorCreate_Blocked(CeedOperator op) {
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Blocked));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Blocked));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Blocked));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
2 changes: 1 addition & 1 deletion backends/cuda-gen/ceed-cuda-gen-operator-build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,8 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op) {

CeedCallBackend(CeedCompile_Cuda(ceed, code.str().c_str(), &data->module, 1, "T_1D", CeedIntMax(Q_1d, data->max_P_1d)));
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, operator_name.c_str(), &data->op));

CeedCallBackend(CeedOperatorSetSetupDone(op));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
2 changes: 2 additions & 0 deletions backends/cuda-gen/ceed-cuda-gen-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,

// Restore context data
CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -266,6 +267,7 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
CeedCallBackend(CeedOperatorSetData(op, impl));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda_gen));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
2 changes: 1 addition & 1 deletion backends/cuda-gen/ceed-cuda-gen-qfunction.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ int CeedQFunctionCreate_Cuda_gen(CeedQFunction qf) {
CeedCallBackend(CeedCalloc(1, &data));
CeedCallBackend(CeedQFunctionSetData(qf, data));

// Read QFunction kernel name
CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name));

CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Cuda_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Cuda_gen));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
9 changes: 9 additions & 0 deletions backends/cuda-ref/ceed-cuda-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn
CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -244,6 +245,7 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons
CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -351,6 +353,7 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con
CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -384,6 +387,7 @@ static int CeedBasisDestroy_Cuda(CeedBasis basis) {
CeedCallCuda(ceed, cudaFree(data->d_grad_1d));
CeedCallCuda(ceed, cudaFree(data->d_chebyshev_interp_1d));
CeedCallBackend(CeedFree(&data));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -403,6 +407,7 @@ static int CeedBasisDestroyNonTensor_Cuda(CeedBasis basis) {
CeedCallCuda(ceed, cudaFree(data->d_div));
CeedCallCuda(ceed, cudaFree(data->d_curl));
CeedCallBackend(CeedFree(&data));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -449,6 +454,7 @@ int CeedBasisCreateTensorH1_Cuda(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAddAtPoints", CeedBasisApplyAddAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -503,6 +509,7 @@ int CeedBasisCreateH1_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -557,6 +564,7 @@ int CeedBasisCreateHdiv_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_nod
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -611,6 +619,7 @@ int CeedBasisCreateHcurl_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt num_no
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
22 changes: 16 additions & 6 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
CeedCallCuda(ceed, cudaFree(impl->diag->d_div_out));
CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_in));
CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_out));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
}
CeedCallBackend(CeedFree(&impl->diag));

Expand All @@ -92,6 +93,7 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
CeedCallCuda(ceed, cuModuleUnload(impl->asmb->module));
CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_in));
CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out));
CeedCallBackend(CeedDestroy(&ceed));
}
CeedCallBackend(CeedFree(&impl->asmb));

Expand Down Expand Up @@ -227,14 +229,14 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
}
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction.
//------------------------------------------------------------------------------
static int CeedOperatorSetup_Cuda(CeedOperator op) {
Ceed ceed;
bool is_setup_done;
CeedInt Q, num_elem, num_input_fields, num_output_fields;
CeedQFunctionField *qf_input_fields, *qf_output_fields;
Expand All @@ -245,7 +247,6 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) {
CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
if (is_setup_done) return CEED_ERROR_SUCCESS;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
Expand Down Expand Up @@ -603,14 +604,14 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec

// Return work vector
CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction.
//------------------------------------------------------------------------------
static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
Ceed ceed;
bool is_setup_done;
CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields;
CeedQFunctionField *qf_input_fields, *qf_output_fields;
Expand All @@ -621,7 +622,6 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
if (is_setup_done) return CEED_ERROR_SUCCESS;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
Expand Down Expand Up @@ -934,6 +934,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,

// Restore work vector
CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -1075,6 +1076,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op,
}

// Restore output
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -1276,6 +1278,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op) {
CeedCallCuda(ceed, cudaMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, cudaMemcpyHostToDevice));
CeedCallBackend(CeedFree(&eval_modes_in));
CeedCallBackend(CeedFree(&eval_modes_out));
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedBasisDestroy(&basis_in));
CeedCallBackend(CeedBasisDestroy(&basis_out));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -1361,6 +1364,7 @@ static inline int CeedOperatorAssembleDiagonalSetupCompile_Cuda(CeedOperator op,
num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE",
use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block));
CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal));
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedBasisDestroy(&basis_in));
CeedCallBackend(CeedBasisDestroy(&basis_out));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -1449,6 +1453,7 @@ static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVec
CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));

// Cleanup
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedVectorDestroy(&assembled_qf));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -1661,6 +1666,7 @@ static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op, CeedInt use_cee
CeedCallBackend(CeedFree(&identity));
}
CeedCallBackend(CeedFree(&eval_modes_out));
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
CeedCallBackend(CeedBasisDestroy(&basis_in));
Expand Down Expand Up @@ -1769,6 +1775,7 @@ static int CeedSingleOperatorAssemble_Cuda(CeedOperator op, CeedInt offset, Ceed
CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out));
}
}
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -2040,6 +2047,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
// Restore work vector
CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_in));
CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_out));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -2062,6 +2070,7 @@ int CeedOperatorCreate_Cuda(CeedOperator op) {
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -2080,6 +2089,7 @@ int CeedOperatorCreateAtPoints_Cuda(CeedOperator op) {
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
1 change: 1 addition & 0 deletions backends/cuda-ref/ceed-cuda-ref-qfunction-load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ extern "C" int CeedQFunctionBuildKernel_Cuda_ref(CeedQFunction qf) {
// Compile kernel
CeedCallBackend(CeedCompile_Cuda(ceed, code.str().c_str(), &data->module, 0));
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, kernel_name.c_str(), &data->QFunction));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
3 changes: 2 additions & 1 deletion backends/cuda-ref/ceed-cuda-ref-qfunction.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q, CeedVector *U, C

// Restore context
CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &data->d_c));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -95,13 +96,13 @@ int CeedQFunctionCreate_Cuda(CeedQFunction qf) {
CeedCallBackend(CeedCalloc(1, &data));
CeedCallBackend(CeedQFunctionSetData(qf, data));

// Read QFunction name
CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name));

// Register backend functions
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction", CeedQFunctionSetCUDAUserFunction_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down
4 changes: 4 additions & 0 deletions backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ct
impl->d_data = impl->d_data_owned;
}
CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -64,6 +65,7 @@ static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ct
impl->h_data = impl->h_data_owned;
}
CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -205,6 +207,7 @@ static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx
impl->d_data = data;
break;
}
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -335,6 +338,7 @@ int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedCalloc(1, &impl));
CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
return CEED_ERROR_SUCCESS;
Expand Down
Loading

0 comments on commit 0418c16

Please sign in to comment.