diff --git a/backends/cuda-ref/ceed-cuda-ref-operator.c b/backends/cuda-ref/ceed-cuda-ref-operator.c index 9f6d3d14b0..748237dab8 100644 --- a/backends/cuda-ref/ceed-cuda-ref-operator.c +++ b/backends/cuda-ref/ceed-cuda-ref-operator.c @@ -28,6 +28,8 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) { // Apply data CeedCallBackend(CeedFree(&impl->skip_rstr_in)); + CeedCallBackend(CeedFree(&impl->skip_rstr_out)); + CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); } @@ -97,8 +99,8 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) { //------------------------------------------------------------------------------ // Setup infields or outfields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs, - CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { +static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis, + CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { Ceed ceed; CeedQFunctionField *qf_fields; CeedOperatorField *op_fields; @@ -184,7 +186,7 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool break; } } - // Drop duplicate input restrictions + // Drop duplicate restrictions if (is_input) { for (CeedInt i = 0; i < num_fields; i++) { CeedVector vec_i; @@ -199,11 +201,31 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); skip_rstr[j] = true; } } } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + } + } + } } return CEED_ERROR_SUCCESS; } @@ -234,6 +256,8 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) { // Allocate CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); @@ -243,10 +267,10 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) { // Set up infield and outfield e_vecs and q_vecs // Infields CeedCallBackend( - CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); + CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); // Outfields - CeedCallBackend( - CeedOperatorSetupFields_Cuda(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); + CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, + num_input_fields, num_output_fields, Q, num_elem)); CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; @@ -431,7 +455,11 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + if (impl->apply_add_basis_out[i]) { + CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } else { + CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -452,6 +480,7 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec if (eval_mode == CEED_EVAL_NONE) { CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); } + if (impl->skip_rstr_out[i]) continue; // Get output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); // Restrict @@ -499,6 +528,8 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { // Allocate CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); @@ -507,11 +538,11 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { // Set up infield and outfield e_vecs and q_vecs // Infields - CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, + CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, max_num_points, num_elem)); // Outfields - CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, - max_num_points, num_elem)); + CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, + num_input_fields, num_output_fields, max_num_points, num_elem)); CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; @@ -635,8 +666,13 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], - impl->e_vecs[i + impl->num_inputs])); + if (impl->apply_add_basis_out[i]) { + CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, + impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } else { + CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], + impl->e_vecs[i + impl->num_inputs])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -657,6 +693,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, if (eval_mode == CEED_EVAL_NONE) { CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); } + if (impl->skip_rstr_out[i]) continue; // Get output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); // Restrict diff --git a/backends/cuda-ref/ceed-cuda-ref.h b/backends/cuda-ref/ceed-cuda-ref.h index f8430a1b12..ff0bbaf349 100644 --- a/backends/cuda-ref/ceed-cuda-ref.h +++ b/backends/cuda-ref/ceed-cuda-ref.h @@ -128,7 +128,7 @@ typedef struct { } CeedOperatorAssemble_Cuda; typedef struct { - bool *skip_rstr_in; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; uint64_t *input_states; // State tracking for passive inputs CeedVector *e_vecs; // E-vectors, inputs followed by outputs CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator