Skip to content

Commit

Permalink
sycl - fix regressions
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Oct 9, 2024
1 parent a014212 commit dbb6954
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 46 deletions.
8 changes: 4 additions & 4 deletions backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));

// Set field constants
if (eval_mode != CEED_EVAL_WEIGHT) {
Expand Down Expand Up @@ -334,9 +334,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));

// Set field constants
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
Expand Down Expand Up @@ -401,8 +401,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));

// Restriction
if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
Expand Down Expand Up @@ -677,8 +677,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
// Basis action
code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
switch (eval_mode) {
Expand Down
98 changes: 56 additions & 42 deletions backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,13 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) {
CeedCallSycl(ceed, sycl::free(impl->diag->d_interp_out, sycl_data->sycl_context));
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_in, sycl_data->sycl_context));
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_out, sycl_data->sycl_context));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));

CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_in));
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_out));
}
CeedCallBackend(CeedFree(&impl->diag));

Expand All @@ -115,7 +118,7 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
Ceed ceed;
CeedSize q_size;
bool is_strided, skip_restriction;
CeedInt dim, size;
CeedInt size;
CeedOperatorField *op_fields;
CeedQFunctionField *qf_fields;

Expand All @@ -133,7 +136,6 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
CeedEvalMode eval_mode;
CeedVector vec;
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));

Expand Down Expand Up @@ -183,20 +185,21 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
break;
case CEED_EVAL_GRAD:
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
CeedCallBackend(CeedBasisDestroy(&basis));
q_size = (CeedSize)num_elem * Q * size;
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
break;
case CEED_EVAL_WEIGHT: // Only on input fields
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
case CEED_EVAL_WEIGHT: {
CeedBasis basis;

// Note: only on input fields
q_size = (CeedSize)num_elem * Q;
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
case CEED_EVAL_DIV:
break; // TODO: Not implemented
case CEED_EVAL_CURL:
Expand Down Expand Up @@ -463,8 +466,8 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
// Restrict
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
if (is_active) vec = out_vec;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_e_in], vec, request));
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
Expand Down Expand Up @@ -637,6 +640,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
Ceed_Sycl *sycl_data;
CeedInt num_input_fields, num_output_fields, num_eval_mode_in = 0, num_comp = 0, dim = 1, num_eval_mode_out = 0;
CeedEvalMode *eval_mode_in = NULL, *eval_mode_out = NULL;
CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
CeedBasis basis_in = NULL, basis_out = NULL;
CeedQFunctionField *qf_fields;
CeedQFunction qf;
Expand All @@ -655,14 +659,19 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedEvalMode eval_mode;
CeedBasis basis;
CeedEvalMode eval_mode;
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in));
CeedCheck(rstr_in == elem_rstr, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly");
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND,
"Backend does not implement operator diagonal assembly with multiple active bases");
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
CeedCheck(basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator diagonal assembly with multiple active bases");
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
switch (eval_mode) {
case CEED_EVAL_NONE:
Expand All @@ -684,6 +693,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));

// Determine active output basis
CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
Expand All @@ -693,26 +703,30 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedEvalMode eval_mode;
CeedBasis basis;
CeedEvalMode eval_mode;
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out));
CeedCheck(rstr_out == elem_rstr, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly");
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
"Backend does not implement operator diagonal assembly with multiple active bases");
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
CeedCheck(basis_out == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator diagonal assembly with multiple active bases");
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
switch (eval_mode) {
case CEED_EVAL_NONE:
case CEED_EVAL_INTERP:
CeedCallBackend(CeedRealloc(num_eval_mode_in + 1, &eval_mode_in));
eval_mode_in[num_eval_mode_in] = eval_mode;
num_eval_mode_in += 1;
CeedCallBackend(CeedRealloc(num_eval_mode_out + 1, &eval_mode_out));
eval_mode_out[num_eval_mode_out] = eval_mode;
num_eval_mode_out += 1;
break;
case CEED_EVAL_GRAD:
CeedCallBackend(CeedRealloc(num_eval_mode_in + dim, &eval_mode_in));
for (CeedInt d = 0; d < dim; d++) eval_mode_in[num_eval_mode_in + d] = eval_mode;
num_eval_mode_in += dim;
CeedCallBackend(CeedRealloc(num_eval_mode_out + dim, &eval_mode_out));
for (CeedInt d = 0; d < dim; d++) eval_mode_out[num_eval_mode_out + d] = eval_mode;
num_eval_mode_out += dim;
break;
case CEED_EVAL_WEIGHT:
case CEED_EVAL_DIV:
Expand All @@ -729,8 +743,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
CeedCallBackend(CeedCalloc(1, &impl->diag));
CeedOperatorDiag_Sycl *diag = impl->diag;

diag->basis_in = basis_in;
diag->basis_out = basis_out;
CeedCallBackend(CeedBasisReferenceCopy(basis_in, &diag->basis_in));
CeedCallBackend(CeedBasisReferenceCopy(basis_out, &diag->basis_out));
diag->h_eval_mode_in = eval_mode_in;
diag->h_eval_mode_out = eval_mode_out;
diag->num_eval_mode_in = num_eval_mode_in;
Expand All @@ -740,6 +754,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
CeedInt num_nodes, num_qpts;
CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
diag->num_nodes = num_nodes;
diag->num_qpts = num_qpts;
diag->num_comp = num_comp;
Expand Down Expand Up @@ -801,13 +816,12 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
copy_events.push_back(eval_mode_out_copy);

// Restriction
{
CeedElemRestriction rstr_out;
CeedCallBackend(CeedElemRestrictionReferenceCopy(rstr_out, &diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));

CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, NULL, &rstr_out));
diag->diag_rstr = rstr_out;
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
}
// Cleanup
CeedCallBackend(CeedBasisDestroy(&basis_in));
CeedCallBackend(CeedBasisDestroy(&basis_out));

// Wait for all copies to complete and handle exceptions
CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
Expand Down Expand Up @@ -1020,16 +1034,16 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &elem_rstr));
if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size));
CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
CeedCheck(basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
if (eval_mode != CEED_EVAL_NONE) {
CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in));
Expand Down Expand Up @@ -1058,14 +1072,14 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis));
CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
"Backend does not implement operator assembly with multiple active bases");
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &elem_rstr));
if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out));
CeedCheck(rstr_out == rstr_in, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly");
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis));
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
CeedCheck(basis_out == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
if (eval_mode != CEED_EVAL_NONE) {
CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out));
Expand Down

0 comments on commit dbb6954

Please sign in to comment.