Skip to content

Commit

Permalink
op - ReferenceCopy for CeedOperatorFieldGet*
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 23, 2024
1 parent 229d7ba commit b377dfe
Show file tree
Hide file tree
Showing 13 changed files with 612 additions and 229 deletions.
56 changes: 43 additions & 13 deletions backends/blocked/ceed-blocked-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
// Empty case - won't occur
break;
}
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
}

Expand All @@ -122,6 +123,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisDestroy(&basis));
e_size = (CeedSize)P * num_comp * block_size;
CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
q_size = (CeedSize)Q * size * block_size;
Expand All @@ -132,6 +134,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
q_size = (CeedSize)Q * block_size;
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
}
Expand All @@ -154,7 +157,11 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
}
CeedCallBackend(CeedVectorDestroy(&vec_j));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
}
CeedCallBackend(CeedVectorDestroy(&vec_i));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
Expand All @@ -176,7 +183,11 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
apply_add_basis[i] = true;
e_data_out_indices[j] = i;
}
CeedCallBackend(CeedVectorDestroy(&vec_j));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
}
CeedCallBackend(CeedVectorDestroy(&vec_i));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
}
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -259,13 +270,15 @@ static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, Ceed
CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX],
CeedOperator_Blocked *impl, CeedRequest *request) {
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active;
uint64_t state;
CeedEvalMode eval_mode;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
is_active = vec == CEED_VECTOR_ACTIVE;
if (is_active) {
if (skip_active) continue;
else vec = in_vec;
}
Expand All @@ -282,6 +295,7 @@ static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, Ceed
// Get evec
CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i]));
}
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
}
return CEED_ERROR_SUCCESS;
}
Expand All @@ -300,15 +314,19 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc

// Skip active input
if (skip_active) {
bool is_active;
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) continue;
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (is_active) continue;
}

// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
// Basis action
Expand All @@ -324,6 +342,7 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * elem_size * num_comp]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
case CEED_EVAL_WEIGHT:
break; // No action
Expand All @@ -347,6 +366,7 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
// Basis action
switch (eval_mode) {
Expand All @@ -365,6 +385,7 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
} else {
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
}
CeedCallBackend(CeedBasisDestroy(&basis));
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -386,10 +407,13 @@ static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, Ce

// Skip active inputs
if (skip_active) {
bool is_active;
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) continue;
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (is_active) continue;
}
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
Expand Down Expand Up @@ -470,18 +494,21 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed

// Output restriction
for (CeedInt i = 0; i < num_output_fields; i++) {
bool is_active;
CeedVector vec;

if (impl->skip_rstr_out[i]) continue;
// Restore evec
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields]));
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
// Active
if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
if (is_active) vec = out_vec;
// Restrict
CeedCallBackend(
CeedElemRestrictionApply(impl->block_rstr[i + impl->num_inputs], CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request));
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
}

// Restore input arrays
Expand Down Expand Up @@ -533,14 +560,14 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedInt field_size;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
// Check if active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
qf_size_in += field_size;
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
impl->qf_size_in = qf_size_in;
Expand All @@ -552,13 +579,13 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedInt field_size;
CeedVector vec;

// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Check if active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
qf_size_out += field_size;
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
impl->qf_size_out = qf_size_out;
Expand Down Expand Up @@ -601,13 +628,15 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o

// Assemble QFunction
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active;
CeedInt field_size;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
// Check if active input
if (vec != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (!is_active) continue;
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
for (CeedInt field = 0; field < field_size; field++) {
// Set current portion of input to 1.0
Expand All @@ -633,6 +662,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size));
l_vec_array += field_size * Q * block_size; // Advance the pointer by the size of the output
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
// Apply QFunction
CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
Expand Down Expand Up @@ -664,12 +694,12 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
for (CeedInt out = 0; out < num_output_fields; out++) {
CeedVector vec;

// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
// Check if active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL));
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
}

Expand Down
11 changes: 11 additions & 0 deletions backends/cuda-gen/ceed-cuda-gen-operator-build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
CeedCheck(*Q_1d == 0 || field_Q_1d == *Q_1d, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
*Q_1d = field_Q_1d;
}
CeedCallBackend(CeedBasisDestroy(&basis));
}
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedBasis basis;
Expand All @@ -77,6 +78,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
CeedCheck(*Q_1d == 0 || field_Q_1d == *Q_1d, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
*Q_1d = field_Q_1d;
}
CeedCallBackend(CeedBasisDestroy(&basis));
}

// Only use 3D collocated gradient parallelization strategy when gradient is computed
Expand All @@ -96,6 +98,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
CeedCallBackend(CeedBasisGetData(basis, &basis_data));
*use_3d_slices = basis_data->d_collo_grad_1d && (was_grad_found ? *use_3d_slices : true);
was_grad_found = true;
CeedCallBackend(CeedBasisDestroy(&basis));
}
}
for (CeedInt i = 0; i < num_output_fields; i++) {
Expand All @@ -110,6 +113,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
CeedCallBackend(CeedBasisGetData(basis, &basis_data));
*use_3d_slices = basis_data->d_collo_grad_1d && (was_grad_found ? *use_3d_slices : true);
was_grad_found = true;
CeedCallBackend(CeedBasisDestroy(&basis));
}
}
}
Expand Down Expand Up @@ -138,6 +142,7 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
}
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedOperatorFieldGetBasis(op_field, &basis));
if (basis != CEED_BASIS_NONE) {
CeedCallBackend(CeedBasisGetData(basis, &basis_data));
Expand All @@ -150,6 +155,7 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
code << " const CeedInt " << P_name << " = " << (basis == CEED_BASIS_NONE ? Q_1d : P_1d) << ";\n";
code << " const CeedInt num_comp" << var_suffix << " = " << num_comp << ";\n";
}
CeedCallBackend(CeedBasisDestroy(&basis));

// Load basis data
code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
Expand Down Expand Up @@ -224,6 +230,7 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
if (basis != CEED_BASIS_NONE) {
CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
}
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_field, &eval_mode));

// Restriction
Expand Down Expand Up @@ -291,6 +298,7 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
<< strides[2] << ">(data, elem, r_e" << var_suffix << ", d" << var_suffix << ");\n";
}
}
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -313,6 +321,7 @@ static int CeedOperatorBuildKernelBasis_Cuda_gen(std::ostringstream &code, CeedO
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
}
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedOperatorFieldGetBasis(op_field, &basis));
if (basis != CEED_BASIS_NONE) {
CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
Expand Down Expand Up @@ -392,6 +401,7 @@ static int CeedOperatorBuildKernelBasis_Cuda_gen(std::ostringstream &code, CeedO
// LCOV_EXCL_STOP
}
}
CeedCallBackend(CeedBasisDestroy(&basis));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -480,6 +490,7 @@ static int CeedOperatorBuildKernelQFunction_Cuda_gen(std::ostringstream &code, C
code << " readSliceQuadsOffset3d<num_comp" << var_suffix << ", " << comp_stride << ", " << Q_name << ">(data, l_size" << var_suffix
<< ", elem, q, indices.inputs[" << i << "], d" << var_suffix << ", r_s" << var_suffix << ");\n";
}
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
break;
case CEED_EVAL_INTERP:
code << " CeedScalar r_s" << var_suffix << "[num_comp" << var_suffix << "];\n";
Expand Down
Loading

0 comments on commit b377dfe

Please sign in to comment.