Skip to content

Commit

Permalink
fix build errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Nov 10, 2023
1 parent fd8bae8 commit c34036d
Showing 1 changed file with 34 additions and 94 deletions.
128 changes: 34 additions & 94 deletions backends/ref/ceed-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,8 @@ static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVe
//------------------------------------------------------------------------------
// Setup Input/Output Fields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedElemRestriction *block_rstr,
CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields,
CeedInt Q) {
static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs,
CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) {
Ceed ceed;
CeedSize e_size, q_size;
CeedInt max_num_points, num_comp, size, P;
Expand Down Expand Up @@ -588,71 +587,16 @@ static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op

CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
if (eval_mode != CEED_EVAL_WEIGHT) {
Ceed ceed_rstr;
CeedSize l_size;
CeedInt num_elem, elem_size, comp_stride;
CeedRestrictionType rstr_type;
CeedElemRestriction rstr;
CeedElemRestriction elem_rstr;

CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
if (rstr_type == CEED_RESTRICTION_POINTS) {
// -- Only vector needed if field is at points
CeedCallBackend(CeedElemRestrictionReferenceCopy(rstr, &block_rstr[i + start_e]));
CeedCallBackend(CeedVectorCreate(ceed, num_comp * max_num_points, &e_vecs_full[i + start_e]));
} else {
// -- FEM fields need blocked restriction
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed_rstr));
CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));

switch (rstr_type) {
case CEED_RESTRICTION_STANDARD: {
const CeedInt *offsets = NULL;

CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, CEED_MEM_HOST,
CEED_COPY_VALUES, offsets, &block_rstr[i + start_e]));
CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
} break;
case CEED_RESTRICTION_ORIENTED: {
const bool *orients = NULL;
const CeedInt *offsets = NULL;

CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
CeedCallBackend(CeedElemRestrictionGetOrientations(rstr, CEED_MEM_HOST, &orients));
CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
CEED_MEM_HOST, CEED_COPY_VALUES, offsets, orients, &block_rstr[i + start_e]));
CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr, &orients));
} break;
case CEED_RESTRICTION_CURL_ORIENTED: {
const CeedInt8 *curl_orients = NULL;
const CeedInt *offsets = NULL;

CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr, CEED_MEM_HOST, &curl_orients));
CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients,
&block_rstr[i + start_e]));
CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr, &curl_orients));
} break;
case CEED_RESTRICTION_STRIDED: {
CeedInt strides[3];

CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides));
CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed_rstr, num_elem, elem_size, block_size, num_comp, l_size, strides,
&block_rstr[i + start_e]));
} break;
case CEED_RESTRICTION_POINTS:
// Empty case - won't occur
break;
}
CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e]));
}
}

Expand Down Expand Up @@ -694,7 +638,6 @@ static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op
static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) {
bool is_setup_done;
Ceed ceed;
Ceed_Ref *ceed_impl;
CeedInt Q, num_input_fields, num_output_fields;
CeedQFunctionField *qf_input_fields, *qf_output_fields;
CeedQFunction qf;
Expand All @@ -705,7 +648,6 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) {
if (is_setup_done) return CEED_ERROR_SUCCESS;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedGetData(ceed, &ceed_impl));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
Expand All @@ -714,7 +656,6 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) {
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

// Allocate
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_rstr));
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));

CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
Expand All @@ -728,11 +669,10 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) {

// Set up infield and outfield pointer arrays
// Infields
CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0,
num_input_fields, Q));
CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q));
// Outfields
CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out,
num_input_fields, num_output_fields, Q));
CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields,
num_output_fields, Q));

// Identity QFunctions
if (impl->is_identity_qf) {
Expand Down Expand Up @@ -761,20 +701,22 @@ static inline int CeedOperatorSetupInputsAtPoints_Ref(CeedInt num_input_fields,
CeedOperatorField *op_input_fields, CeedVector in_vec, CeedScalar *e_data[2 * CEED_FIELD_MAX],
CeedOperator_Ref *impl, CeedRequest *request) {
for (CeedInt i = 0; i < num_input_fields; i++) {
uint64_t state;
CeedEvalMode eval_mode;
CeedVector vec;
uint64_t state;
CeedEvalMode eval_mode;
CeedVector vec;
CeedElemRestriction elem_rstr;

CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
} else {
// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
if (vec != CEED_VECTOR_ACTIVE) {
// Restrict
CeedCallBackend(CeedVectorGetState(vec, &state));
if (state != impl->input_states[i]) {
CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
impl->input_states[i] = state;
}
// Get evec
Expand Down Expand Up @@ -802,9 +744,9 @@ static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_poin
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active_input = false;
CeedInt elem_size, size, num_comp;
CeedRestrictionType rstr_type;
CeedEvalMode eval_mode;
CeedVector vec;
CeedRestrictionType rstr_type;
CeedElemRestriction elem_rstr;
CeedBasis basis;

Expand All @@ -815,15 +757,15 @@ static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_poin

// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetType(impl->block_rstr[i], &rstr_type));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
// Restrict block active input
if (is_active_input) {
if (elem_type == CEED_RESTRICTION_POINTS) {
CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(impl->block_rstr[i], e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request))
if (rstr_type == CEED_RESTRICTION_POINTS) {
CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
} else {
CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[i], e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
}
}
// Basis action
Expand Down Expand Up @@ -862,9 +804,9 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi
CeedOperator op, CeedVector out_vec, CeedVector point_coords_elem, CeedOperator_Ref *impl,
CeedRequest *request) {
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedRestrictionType rstr_type;
CeedEvalMode eval_mode;
CeedVector vec;
CeedRestrictionType rstr_type;
CeedElemRestriction elem_rstr;
CeedBasis basis;

Expand Down Expand Up @@ -893,15 +835,14 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi
}
// Restrict output block
// Get output vector
CeedCallBackend(CeedElemRestrictionGetType(impl->block_rstr[i + impl->num_inputs], &rstr_type));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
// Restrict
if (elem_type == CEED_RESTRICTION_POINTS) {
CeedCallBackend(
CeedElemRestrictionApplyAtPointsInElement(impl->block_rstr[i + impl->num_inputs], e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request))
if (rstr_type == CEED_RESTRICTION_POINTS) {
CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request));
} else {
CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[i + impl->num_inputs], e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request));
CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request));
}
}
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -931,7 +872,6 @@ static inline int CeedOperatorRestoreInputsAtPoints_Ref(CeedInt num_input_fields
//------------------------------------------------------------------------------
static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
Ceed ceed;
Ceed_Ref *ceed_impl;
CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem;
CeedEvalMode eval_mode;
CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0};
Expand All @@ -943,10 +883,8 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,
CeedOperator_Ref *impl;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedGetData(ceed, &ceed_impl));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
Expand All @@ -956,9 +894,10 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,

// Restriction only operator
if (impl->is_identity_rstr_op) {
for (CeedInt b = 0; b < num_blocks; b++) {
CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[0], b, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[0], request));
CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[1], b, CEED_TRANSPOSE, impl->e_vecs_in[0], out_vec, request));
// TODO: Fix this up
for (CeedInt e = 0; e < num_elem; e++) {
// CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[0], e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[0], request));
// CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[1], e, CEED_TRANSPOSE, impl->e_vecs_in[0], out_vec, request));
}
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -991,7 +930,7 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,

// Input basis apply
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec,
false, e_data, impl, request));
impl->point_coords_elem, false, e_data, impl, request));

// Q function
if (!impl->is_identity_qf) {
Expand All @@ -1000,7 +939,7 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,

// Output basis apply and restriction
CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields,
num_output_fields, op, out_vec, impl, request));
num_output_fields, op, out_vec, impl->point_coords_elem, impl, request));

num_points_offset += num_points;
}
Expand All @@ -1010,7 +949,7 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,

// Cleanup point coordinates
CeedCallBackend(CeedVectorDestroy(&point_coords));
CeedCallBackend(CeedElemRestriction(&rstr_points));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -1076,6 +1015,7 @@ int CeedOperatorCreateAtPoints_Ref(CeedOperator op) {
Ceed ceed;
CeedOperator_Ref *impl;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedCalloc(1, &impl));
CeedCallBackend(CeedOperatorSetData(op, impl));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref));
Expand Down

0 comments on commit c34036d

Please sign in to comment.