AtPoints for */gen #1715
Merged
Conversation
jeremylt force-pushed the jeremy/at-points-gen branch 3 times, most recently from f84f050 to 368f613 on December 11, 2024 at 20:44
Giving wrong results, but the code is compiling:

#define T_1D 5
#include <ceed/jit-source/cuda/cuda-jit.h>
// Tensor basis source
#include <ceed/jit-source/cuda/cuda-shared-basis-tensor-templates.h>
// AtPoints basis source
#include <ceed/jit-source/cuda/cuda-shared-basis-tensor-at-points-templates.h>
// CodeGen operator source
#include <ceed/jit-source/cuda/cuda-gen-templates.h>
#undef CEED_Q_VLA
#define CEED_Q_VLA 1
// User QFunction source
#include "/home/jeremy/Dev/libCEED/tests/t590-operator.h"
// -----------------------------------------------------------------------------
// Operator Kernel
//
// d_[in,out]_i: CeedVector device array
// r_[in,out]_e_i: Element vector register
// r_[in,out]_q_i: Quadrature space vector register
// r_[in,out]_c_i: AtPoints Chebyshev coefficients register
// r_[in,out]_s_i: Quadrature space slice vector register
//
// s_B_[in,out]_i: Interpolation matrix, shared memory
// s_G_[in,out]_i: Gradient matrix, shared memory
// -----------------------------------------------------------------------------
extern "C" __global__ void CeedKernelCudaGenOperator_mass(CeedInt num_elem, void* ctx, FieldsInt_Cuda indices, Fields_Cuda fields, Fields_Cuda B, Fields_Cuda G, CeedScalar *W, Points_Cuda points) {
const CeedScalar *d_in_0 = fields.inputs[0];
CeedScalar *d_out_0 = fields.outputs[0];
const CeedInt dim = 2;
const CeedInt Q_1d = 5;
const CeedInt max_num_points = 4;
extern __shared__ CeedScalar slice[];
SharedData_Cuda data;
data.t_id_x = threadIdx.x;
data.t_id_y = threadIdx.y;
data.t_id_z = threadIdx.z;
data.t_id = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
data.slice = slice + data.t_id_z*T_1D*T_1D;
// Input field constants and basis data
// -- Input field 0
const CeedInt P_1d_in_0 = 3;
const CeedInt num_comp_in_0 = 1;
// EvalMode: interpolation
__shared__ CeedScalar s_B_in_0[15];
LoadMatrix<P_1d_in_0, Q_1d>(data, B.inputs[0], s_B_in_0);
// Output field constants and basis data
// -- Output field 0
const CeedInt P_1d_out_0 = 3;
const CeedInt num_comp_out_0 = 1;
// EvalMode: interpolation
__shared__ CeedScalar s_B_out_0[15];
LoadMatrix<P_1d_out_0, Q_1d>(data, B.outputs[0], s_B_out_0);
// Element loop
__syncthreads();
for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {
// Scratch restriction buffer space
CeedScalar r_e_scratch[9];
// -- Input field restrictions and basis actions
// ---- Input field 0
CeedScalar *r_e_in_0 = r_e_scratch;
const CeedInt l_size_in_0 = 49;
// CompStride: 1
ReadLVecStandard2d<num_comp_in_0, 1, P_1d_in_0>(data, l_size_in_0, elem, indices.inputs[0], d_in_0, r_e_in_0);
// EvalMode: interpolation
CeedScalar r_c_in_0[num_comp_in_0*1];
InterpTensor2d<num_comp_in_0, P_1d_in_0, Q_1d>(data, r_e_in_0, s_B_in_0, r_c_in_0);
// -- Output field setup
// ---- Output field 0
CeedScalar r_c_out_0[num_comp_out_0*1];
for (CeedInt i = 0; i < num_comp_out_0*Q_1d; i++) {
r_c_out_0[i] = 0.0;
}
// Note: Using batches of points
const CeedInt point_loop_bound = (blockDim.x * blockDim.y) * ceil(1.0 * max_num_points / (blockDim.x * blockDim.y));
#pragma unroll
for (CeedInt i = threadIdx.x + threadIdx.y * blockDim.x; i < point_loop_bound; i += blockDim.x * blockDim.y) {
const CeedInt p = i % max_num_points;
// -- Coordinates
CeedScalar r_x[dim];
ReadPoint<dim, max_num_points>(data, elem, p, max_num_points, 1, num_elem * max_num_points, max_num_points, points.coords, r_x);
// -- Input fields
// ---- Input field 0
// EvalMode: interpolation
CeedScalar r_s_in_0[num_comp_in_0];
InterpAtPoints2d<num_comp_in_0, max_num_points, Q_1d>(data, i, r_c_in_0, r_x, r_s_in_0);
// -- Output fields
// ---- Output field 0
CeedScalar r_s_out_0[num_comp_out_0];
// -- QFunction inputs and outputs
// ---- Inputs
CeedScalar *inputs[1];
// ------ Input field 0
inputs[0] = r_s_in_0;
// ---- Outputs
CeedScalar *outputs[1];
// ------ Output field 0
outputs[0] = r_s_out_0;
// -- Apply QFunction
mass(ctx, 1, inputs, outputs);
// -- Output fields
// ---- Output field 0
// EvalMode: interpolation
if (i > points.num_per_elem[elem]) {
for (CeedInt j = 0; j < num_comp_out_0; j++) r_s_out_0[j] = 0.0;
}
InterpTransposeAtPoints2d<num_comp_out_0, max_num_points, Q_1d>(data, i, r_s_out_0, r_x, r_c_out_0);
}
__syncthreads();
// -- Output field basis action and restrictions
// ---- Output field 0
// EvalMode: interpolation
CeedScalar *r_e_out_0 = r_e_scratch;
InterpTransposeTensor2d<num_comp_out_0, P_1d_out_0, Q_1d>(data, r_c_out_0, s_B_out_0, r_e_out_0);
const CeedInt l_size_out_0 = 49;
// CompStride: 1
WriteLVecStandard2d<num_comp_out_0, 1, P_1d_out_0>(data, l_size_out_0, elem, indices.outputs[0], r_e_out_0, d_out_0);
}
}
// -----------------------------------------------------------------------------
jeremylt force-pushed the jeremy/at-points-gen branch 2 times, most recently from 4924280 to 38bf749 on December 13, 2024 at 00:32
Only diagonal assembly is failing for the core tests - I'm suspicious of the delegation? Edit: Nope, assembly is correct but the true values are incorrect for gen? But t590-t593 pass?
jeremylt force-pushed the jeremy/at-points-gen branch 14 times, most recently from 4f6a4a3 to dbb7f4f on December 16, 2024 at 21:22
jeremylt force-pushed the jeremy/at-points-gen branch 7 times, most recently from 867d284 to 9f6cf23 on December 18, 2024 at 17:44
jeremylt force-pushed the jeremy/at-points-gen branch from 9f6cf23 to 3a2968d on December 18, 2024 at 18:03
t590 works, now on to t591-t594
Labels: AtPoints, Operator