Skip to content

Commit

Permalink
hip - AtPoints for hip/gen
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Dec 18, 2024
1 parent 688b547 commit 9f6cf23
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 84 deletions.
2 changes: 1 addition & 1 deletion backends/cuda-gen/ceed-cuda-gen-operator-build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op) {
code << "}\n";
code << "// -----------------------------------------------------------------------------\n\n";

// View kernel for debugging
// Compile
CeedCallBackend(CeedCompile_Cuda(ceed, code.str().c_str(), &data->module, 1, "T_1D", CeedIntMax(Q_1d, data->max_P_1d)));
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, operator_name.c_str(), &data->op));
CeedCallBackend(CeedOperatorSetSetupDone(op));
Expand Down
422 changes: 339 additions & 83 deletions backends/hip-gen/ceed-hip-gen-operator-build.cpp

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions backends/hip-gen/ceed-hip-gen-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ceed/backend.h>
#include <ceed/jit-source/hip/hip-types.h>
#include <stddef.h>
#include <hip/hiprtc.h>

#include "../hip/ceed-hip-common.h"
#include "../hip/ceed-hip-compile.h"
Expand All @@ -19,17 +20,22 @@
// Destroy operator
//------------------------------------------------------------------------------
static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
Ceed ceed;
CeedOperator_Hip_gen *impl;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetData(op, &impl));
if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
CeedCallBackend(CeedFree(&impl));
CeedCallBackend(CeedDestroy(&ceed));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Apply and add to output
//------------------------------------------------------------------------------
static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
bool is_at_points;
Ceed ceed;
CeedInt num_elem, num_input_fields, num_output_fields;
CeedEvalMode eval_mode;
Expand Down Expand Up @@ -110,6 +116,39 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C
}
}

// Point coordinates, if needed
CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
if (is_at_points) {
// Coords
CeedVector vec;

CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
CeedCallBackend(CeedVectorDestroy(&vec));

// Points per elem
if (num_elem != data->points.num_elem) {
CeedInt *points_per_elem;
const CeedInt num_bytes = num_elem * sizeof(CeedInt);
CeedElemRestriction rstr_points = NULL;

data->points.num_elem = num_elem;
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
for (CeedInt e = 0; e < num_elem; e++) {
CeedInt num_points_elem;

CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
points_per_elem[e] = num_points_elem;
}
if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
CeedCallBackend(CeedFree(&points_per_elem));
}
}

// Get context data
CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

Expand Down Expand Up @@ -163,6 +202,7 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C
if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
// Check for multiple output modes
CeedInt index = -1;

for (CeedInt j = 0; j < i; j++) {
if (vec == output_vecs[j]) {
index = j;
Expand All @@ -175,6 +215,15 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C
}
}

// Restore point coordinates, if needed
if (is_at_points) {
CeedVector vec;

CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
CeedCallBackend(CeedVectorDestroy(&vec));
}

// Restore context data
CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
CeedCallBackend(CeedDestroy(&ceed));
Expand Down
1 change: 1 addition & 0 deletions backends/hip-gen/ceed-hip-gen.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ static int CeedInit_Hip_gen(const char *resource, Ceed ceed) {

CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "QFunctionCreate", CeedQFunctionCreate_Hip_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreate", CeedOperatorCreate_Hip_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreateAtPoints", CeedOperatorCreate_Hip_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "Destroy", CeedDestroy_Hip));
return CEED_ERROR_SUCCESS;
}
Expand Down
1 change: 1 addition & 0 deletions backends/hip-gen/ceed-hip-gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef struct {
Fields_Hip B;
Fields_Hip G;
CeedScalar *W;
Points_Hip points;
} CeedOperator_Hip_gen;

typedef struct {
Expand Down
32 changes: 32 additions & 0 deletions include/ceed/jit-source/hip/hip-gen-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@ inline __device__ void LoadMatrix(SharedData_Hip &data, const CeedScalar *__rest
for (CeedInt i = data.t_id; i < P * Q; i += blockDim.x * blockDim.y * blockDim.z) B[i] = d_B[i];
}

//------------------------------------------------------------------------------
// AtPoints
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// L-vector -> single point
//------------------------------------------------------------------------------
template <int NUM_COMP, int COMP_STRIDE, int NUM_PTS>
inline __device__ void ReadPoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem,
const CeedInt *__restrict__ indices, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
const CeedInt ind = indices[p + elem * NUM_PTS];

for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
r_u[comp] = d_u[ind + comp * COMP_STRIDE];
}
}

//------------------------------------------------------------------------------
// Single point -> L-vector
//------------------------------------------------------------------------------
template <int NUM_COMP, int COMP_STRIDE, int NUM_PTS>
inline __device__ void WritePoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem,
const CeedInt *__restrict__ indices, const CeedScalar *__restrict__ r_u, CeedScalar *d_u) {
if (p < points_in_elem) {
const CeedInt ind = indices[p + elem * NUM_PTS];

for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
d_u[ind + comp * COMP_STRIDE] += r_u[comp];
}
}
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions include/ceed/jit-source/hip/hip-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ typedef struct {
CeedInt *outputs[CEED_HIP_NUMBER_FIELDS];
} FieldsInt_Hip;

typedef struct {
CeedInt num_elem;
const CeedInt *num_per_elem;
const CeedInt *indices;
const CeedScalar *coords;
} Points_Hip;

typedef struct {
CeedInt t_id_x;
CeedInt t_id_y;
Expand Down

0 comments on commit 9f6cf23

Please sign in to comment.