From 18d93ab9a653b0e3683179e5dadc122062cdd1d8 Mon Sep 17 00:00:00 2001
From: Jeremy L Thompson
Date: Fri, 9 Aug 2024 13:15:36 -0600
Subject: [PATCH] basis - add MAGMA ApplyAdd

---
 backends/magma/ceed-magma-basis.c | 59 ++++++++--
 backends/magma/ceed-magma.h | 4 +
 .../jit-source/magma/magma-basis-grad-1d.h | 45 ++++++++
 .../jit-source/magma/magma-basis-grad-2d.h | 51 +++++++++
 .../jit-source/magma/magma-basis-grad-3d.h | 58 ++++++++++
 .../jit-source/magma/magma-basis-interp-1d.h | 45 ++++++++
 .../jit-source/magma/magma-basis-interp-2d.h | 41 +++++++
 .../jit-source/magma/magma-basis-interp-3d.h | 41 +++++++
 .../magma-basis-interp-deriv-nontensor.h | 106 ++++++++++++++++++
 .../jit-source/magma/magma-common-nontensor.h | 19 ++++
 .../jit-source/magma/magma-common-tensor.h | 46 ++++++++
 11 files changed, 505 insertions(+), 10 deletions(-)

diff --git a/backends/magma/ceed-magma-basis.c b/backends/magma/ceed-magma-basis.c
index de18a1a2fc..71a86d5b8d 100644
--- a/backends/magma/ceed-magma-basis.c
+++ b/backends/magma/ceed-magma-basis.c
@@ -26,7 +26,8 @@
 //------------------------------------------------------------------------------
 // Basis apply - tensor
 //------------------------------------------------------------------------------
-static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
+static int CeedBasisApplyCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
+                                    CeedVector v) {
   Ceed ceed;
   Ceed_Magma *data;
   CeedInt dim, num_comp, num_nodes, P_1d, Q_1d, P, Q;
@@ -52,7 +53,8 @@ static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTranspose
   // Read vectors
   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
   else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
-  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
+  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
+  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
 
   // Apply basis operation
   switch (e_mode) {
@@ -115,7 +117,8 @@ static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTranspose
       void *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem};
 
       if (t_mode == CEED_TRANSPOSE) {
-        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
+        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->InterpTransposeAdd : impl->InterpTranspose, grid, num_threads, num_t_col,
+                                                    1, shared_mem, args));
       } else {
         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, grid, num_threads, num_t_col, 1, shared_mem, args));
       }
@@ -192,7 +195,8 @@ static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTranspose
                       &v_elem_stride, &v_comp_stride, &v_dim_stride, &num_elem};
 
       if (t_mode == CEED_TRANSPOSE) {
-        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
+        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->GradTransposeAdd : impl->GradTranspose, grid, num_threads, num_t_col, 1,
+                                                    shared_mem, args));
       } else {
         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, grid, num_threads, num_t_col, 1, shared_mem, args));
       }
@@ -248,6 +252,16 @@ static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTranspose
   return CEED_ERROR_SUCCESS;
 }
 
+static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
+  CeedCallBackend(CeedBasisApplyCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
+  return CEED_ERROR_SUCCESS;
+}
+
+static int CeedBasisApplyAdd_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
+  CeedCallBackend(CeedBasisApplyCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
+  return CEED_ERROR_SUCCESS;
+}
+
 //------------------------------------------------------------------------------
 // Basis apply - tensor AtPoints
 //------------------------------------------------------------------------------
@@ -259,8 +273,8 @@ int CeedBasisApplyAtPoints_Magma(CeedBasis basis, const CeedInt num_elem, const
 //------------------------------------------------------------------------------
 // Basis apply - non-tensor
 //------------------------------------------------------------------------------
-static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
-                                         CeedVector v) {
+static int CeedBasisApplyNonTensorCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode,
+                                             CeedVector u, CeedVector v) {
   Ceed ceed;
   Ceed_Magma *data;
   CeedInt num_comp, num_nodes, num_qpts, P, Q, N;
@@ -281,7 +295,8 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
   // Read vectors
   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
   else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
-  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
+  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
+  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
 
   // Compile kernels for N as needed
   CeedInt iN = 0;
@@ -344,8 +359,10 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
                                        impl->NB_deriv_t[iN]));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_n", &impl->Interp[iN]));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN]));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_ta", &impl->InterpTransposeAdd[iN]));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_n", &impl->Deriv[iN]));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_t", &impl->DerivTranspose[iN]));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_ta", &impl->DerivTransposeAdd[iN]));
       if (!impl->Weight) {
         CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_weight_nontensor", &impl->Weight));
         CeedCallBackend(CeedFree(&weight_kernel_path));
@@ -388,7 +405,7 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
     if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
       if (e_mode == CEED_EVAL_INTERP) {
         if (t_mode == CEED_TRANSPOSE) {
-          Kernel = impl->InterpTranspose[iN];
+          Kernel = apply_add ? impl->InterpTransposeAdd[iN] : impl->InterpTranspose[iN];
           NB = impl->NB_interp_t[iN];
         } else {
           Kernel = impl->Interp[iN];
@@ -396,7 +413,7 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
         }
       } else {
         if (t_mode == CEED_TRANSPOSE) {
-          Kernel = impl->DerivTranspose[iN];
+          Kernel = apply_add ? impl->DerivTransposeAdd[iN] : impl->DerivTranspose[iN];
           NB = impl->NB_deriv_t[iN];
         } else {
           Kernel = impl->Deriv[iN];
@@ -414,7 +431,7 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
     } else {
       for (CeedInt d = 0; d < q_comp; d++) {
         if (t_mode == CEED_TRANSPOSE) {
-          const CeedScalar beta = (d > 0) ? 1.0 : 0.0;
+          const CeedScalar beta = (apply_add || (d > 0)) ? 1.0 : 0.0;
           magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, d_b + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P, data->queue);
         } else {
           magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, d_b + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue);
@@ -443,6 +460,18 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
   return CEED_ERROR_SUCCESS;
 }
 
+static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
+                                         CeedVector v) {
+  CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
+  return CEED_ERROR_SUCCESS;
+}
+
+static int CeedBasisApplyAddNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
+                                            CeedVector v) {
+  CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
+  return CEED_ERROR_SUCCESS;
+}
+
 //------------------------------------------------------------------------------
 // Destroy tensor basis
 //------------------------------------------------------------------------------
@@ -559,22 +588,28 @@ int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const
     case 1:
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_1d_kernel", &impl->InterpTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_1d_kernel", &impl->GradTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight));
       break;
     case 2:
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_2d_kernel", &impl->InterpTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_2d_kernel", &impl->GradTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight));
       break;
     case 3:
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_3d_kernel", &impl->InterpTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose));
+      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_3d_kernel", &impl->GradTransposeAdd));
       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight));
       break;
   }
@@ -588,6 +623,7 @@ int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const
   CeedCallBackend(CeedBasisSetData(basis, impl));
 
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
+  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Magma));
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Magma));
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
   return CEED_ERROR_SUCCESS;
@@ -650,6 +686,7 @@ int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_node
 
   // Register backend functions
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
+  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
   return CEED_ERROR_SUCCESS;
 }
@@ -711,6 +748,7 @@ int CeedBasisCreateHdiv_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_no
 
   // Register backend functions
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
+  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
   return CEED_ERROR_SUCCESS;
 }
@@ -772,6 +810,7 @@ int CeedBasisCreateHcurl_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_n
 
   // Register backend functions
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
+  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
   return CEED_ERROR_SUCCESS;
 }
diff --git a/backends/magma/ceed-magma.h b/backends/magma/ceed-magma.h
index aa60b37b40..22dd4264b6 100644
--- a/backends/magma/ceed-magma.h
+++ b/backends/magma/ceed-magma.h
@@ -47,8 +47,10 @@ typedef struct {
   CeedMagmaModule module;
   CeedMagmaFunction Interp;
   CeedMagmaFunction InterpTranspose;
+  CeedMagmaFunction InterpTransposeAdd;
   CeedMagmaFunction Grad;
   CeedMagmaFunction GradTranspose;
+  CeedMagmaFunction GradTransposeAdd;
   CeedMagmaFunction Weight;
   CeedScalar *d_interp_1d;
   CeedScalar *d_grad_1d;
@@ -59,8 +61,10 @@ typedef struct {
   CeedMagmaModule module[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedMagmaFunction Interp[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedMagmaFunction InterpTranspose[MAGMA_NONTENSOR_KERNEL_INSTANCES];
+  CeedMagmaFunction InterpTransposeAdd[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedMagmaFunction Deriv[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedMagmaFunction DerivTranspose[MAGMA_NONTENSOR_KERNEL_INSTANCES];
+  CeedMagmaFunction DerivTransposeAdd[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedMagmaFunction Weight;
   CeedInt NB_interp[MAGMA_NONTENSOR_KERNEL_INSTANCES], NB_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
   CeedInt NB_deriv[MAGMA_NONTENSOR_KERNEL_INSTANCES], NB_deriv_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
diff --git a/include/ceed/jit-source/magma/magma-basis-grad-1d.h b/include/ceed/jit-source/magma/magma-basis-grad-1d.h
index dd21682225..cd6f8548fb 100644
--- a/include/ceed/jit-source/magma/magma-basis-grad-1d.h
+++ b/include/ceed/jit-source/magma/magma-basis-grad-1d.h
@@ -126,3 +126,48 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_
   // write V
   write_1d(sV, dV, cstrdV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
+    void magma_gradta_1d_kernel(const CeedScalar *dTinterp, const CeedScalar *dTgrad, const CeedScalar *dU, const int estrdU, const int cstrdU,
+                                const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar *sU[BASIS_NUM_COMP];
+  CeedScalar *sV[BASIS_NUM_COMP];
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sT = (CeedScalar *)shared_data;
+  CeedScalar *sW = sT + BASIS_Q * BASIS_P;
+  sU[0] = sW + ty * BASIS_NUM_COMP * (BASIS_Q + BASIS_P);
+  sV[0] = sU[0] + (BASIS_NUM_COMP * 1 * BASIS_Q);
+  for (int comp = 1; comp < BASIS_NUM_COMP; comp++) {
+    sU[comp] = sU[comp - 1] + (1 * BASIS_Q);
+    sV[comp] = sV[comp - 1] + (1 * BASIS_P);
+  }
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dTgrad, sT);
+  }
+
+  // read U
+  read_1d(dU, cstrdU, sU, tx);
+
+  __syncthreads();
+  magma_grad_1d_device(sT, sU, sV, tx);
+  __syncthreads();
+
+  // sum into V
+  sum_1d(sV, dV, cstrdV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-grad-2d.h b/include/ceed/jit-source/magma/magma-basis-grad-2d.h
index 23559716dc..b4e7e2981a 100644
--- a/include/ceed/jit-source/magma/magma-basis-grad-2d.h
+++ b/include/ceed/jit-source/magma/magma-basis-grad-2d.h
@@ -188,3 +188,54 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_
   // write V
   write_V_2d(dV + (0 * dstrdV), cstrdV, rV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_2D)) __global__
+    void magma_gradta_2d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
+                                const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar rU[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};  // here DIM_U = 1, but might be different for a fused operator
+  CeedScalar rV[1][BASIS_NUM_COMP][BASIS_P] = {0.0};  // here DIM_V = 1, but might be different for a fused operator
+  CeedScalar rTmp = 0.0;
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sTinterp = (CeedScalar *)shared_data;
+  CeedScalar *sTgrad = sTinterp + BASIS_Q * BASIS_P;
+  CeedScalar *sTmp = sTgrad + BASIS_Q * BASIS_P;
+  sTmp += ty * (BASIS_Q * BASIS_MAX_P_Q);
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dinterp1d, sTinterp);
+    read_T_trans_gm2sm(tx, dgrad1d, sTgrad);
+  }
+  __syncthreads();
+
+  /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
+     there is a sync at the end of this function */
+  read_U_2d(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
+  /* first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) */
+  magma_grad_2d_device(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
+  /* there is a sync at the end of magma_grad_2d_device */
+
+  /* read U (idim = 1 for dU, i_DIM = 0 for rU) --
+     there is a sync at the end of this function */
+  read_U_2d(dU + (1 * dstrdU), cstrdU, rU, sTmp, tx);
+  /* second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) */
+  magma_grad_2d_device(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
+  /* there is a sync at the end of magma_grad_2d_device */
+
+  // sum into V
+  sum_V_2d(dV + (0 * dstrdV), cstrdV, rV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-grad-3d.h b/include/ceed/jit-source/magma/magma-basis-grad-3d.h
index c030f8e9e5..c8028be756 100644
--- a/include/ceed/jit-source/magma/magma-basis-grad-3d.h
+++ b/include/ceed/jit-source/magma/magma-basis-grad-3d.h
@@ -225,3 +225,61 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MA
   // write V
   write_V_3d(dV + (0 * dstrdV), cstrdV, rV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MAGMA_MAXTHREADS_3D)) __global__
+    void magma_gradta_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
+                                const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar rU[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};  // here DIM_U = 1, but might be different for a fused operator
+  CeedScalar rV[1][BASIS_NUM_COMP][BASIS_P] = {0.0};  // here DIM_V = 1, but might be different for a fused operator
+  CeedScalar rTmp = 0.0;
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sTinterp = (CeedScalar *)shared_data;
+  CeedScalar *sTgrad = sTinterp + BASIS_Q * BASIS_P;
+  CeedScalar *sTmp = sTgrad + BASIS_Q * BASIS_P;
+  sTmp += ty * (max(BASIS_Q * BASIS_Q * BASIS_Q, (BASIS_Q * BASIS_Q * BASIS_P) + (BASIS_Q * BASIS_P * BASIS_P)));
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dinterp1d, sTinterp);
+    read_T_trans_gm2sm(tx, dgrad1d, sTgrad);
+  }
+  __syncthreads();
+
+  /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
+     there is a sync at the end of this function */
+  read_U_3d(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
+  /* then first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) */
+  magma_grad_3d_device(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
+  /* there is a sync at the end of magma_grad_3d_device */
+
+  /* read U (idim = 1 for dU, i_DIM = 0 for rU) --
+     there is a sync at the end of this function */
+  read_U_3d(dU + (1 * dstrdU), cstrdU, rU, sTmp, tx);
+  /* then second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) */
+  magma_grad_3d_device(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
+  /* there is a sync at the end of magma_grad_3d_device */
+
+  /* read U (idim = 2 for dU, i_DIM = 0 for rU) --
+     there is a sync at the end of this function */
+  read_U_3d(dU + (2 * dstrdU), cstrdU, rU, sTmp, tx);
+  /* then third call (i_DIM = 2, i_DIM_U = 0, i_DIM_V = 0) */
+  magma_grad_3d_device(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
+  /* there is a sync at the end of magma_grad_3d_device */
+
+  // sum into V
+  sum_V_3d(dV + (0 * dstrdV), cstrdV, rV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-interp-1d.h b/include/ceed/jit-source/magma/magma-basis-interp-1d.h
index ae8d082653..02f894ecce 100644
--- a/include/ceed/jit-source/magma/magma-basis-interp-1d.h
+++ b/include/ceed/jit-source/magma/magma-basis-interp-1d.h
@@ -126,3 +126,48 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_
   // write V
   write_1d(sV, dV, cstrdV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
+    void magma_interpta_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
+                                  const int cstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar *sU[BASIS_NUM_COMP];
+  CeedScalar *sV[BASIS_NUM_COMP];
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sT = (CeedScalar *)shared_data;
+  CeedScalar *sW = sT + BASIS_Q * BASIS_P;
+  sU[0] = sW + ty * BASIS_NUM_COMP * (BASIS_Q + BASIS_P);
+  sV[0] = sU[0] + (BASIS_NUM_COMP * 1 * BASIS_Q);
+  for (int comp = 1; comp < BASIS_NUM_COMP; comp++) {
+    sU[comp] = sU[comp - 1] + (1 * BASIS_Q);
+    sV[comp] = sV[comp - 1] + (1 * BASIS_P);
+  }
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dT, sT);
+  }
+
+  // read U
+  read_1d(dU, cstrdU, sU, tx);
+
+  __syncthreads();
+  magma_interp_1d_device(sT, sU, sV, tx);
+  __syncthreads();
+
+  // sum into V
+  sum_1d(sV, dV, cstrdV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-interp-2d.h b/include/ceed/jit-source/magma/magma-basis-interp-2d.h
index a2a41a25ae..56c8081c83 100644
--- a/include/ceed/jit-source/magma/magma-basis-interp-2d.h
+++ b/include/ceed/jit-source/magma/magma-basis-interp-2d.h
@@ -144,3 +144,44 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_
   // write V
   write_V_2d(dV, cstrdV, rV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_2D)) __global__
+    void magma_interpta_2d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
+                                  const int cstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar rU[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};  // for a non-fused operator BASIS_DIM is always 1
+  CeedScalar rV[1][BASIS_NUM_COMP][BASIS_P] = {0.0};  // for a non-fused operator BASIS_DIM is always 1
+  CeedScalar rTmp = 0.0;
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sT = (CeedScalar *)shared_data;
+  CeedScalar *sTmp = sT + BASIS_Q * BASIS_P;
+  sTmp += ty * (BASIS_Q * BASIS_MAX_P_Q);
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dT, sT);
+  }
+
+  // read U -- there is a sync at the end of this function
+  read_U_2d(dU, cstrdU, rU, sTmp, tx);
+
+  // no sync needed here -- read_U_2d already syncs at the end
+  magma_interp_2d_device(sT, rU, rV, tx, rTmp, sTmp);
+  __syncthreads();
+
+  // sum into V
+  sum_V_2d(dV, cstrdV, rV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-interp-3d.h b/include/ceed/jit-source/magma/magma-basis-interp-3d.h
index 50c7e4df4a..ac11e3f8df 100644
--- a/include/ceed/jit-source/magma/magma-basis-interp-3d.h
+++ b/include/ceed/jit-source/magma/magma-basis-interp-3d.h
@@ -172,3 +172,44 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MA
   // write V
   write_V_3d(dV, cstrdV, rV, tx);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MAGMA_MAXTHREADS_3D)) __global__
+    void magma_interpta_3d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
+                                  const int cstrdV, const int nelem) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)
+
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int elem_id = (blockIdx.x * blockDim.y) + ty;
+
+  if (elem_id >= nelem) return;
+
+  CeedScalar rU[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};  // for a non-fused operator BASIS_DIM is always 1
+  CeedScalar rV[1][BASIS_NUM_COMP][BASIS_P] = {0.0};  // for a non-fused operator BASIS_DIM is always 1
+  CeedScalar rTmp[BASIS_P] = {0.0};
+
+  // shift global memory pointers by elem stride
+  dU += elem_id * estrdU;
+  dV += elem_id * estrdV;
+
+  // assign shared memory pointers
+  CeedScalar *sT = (CeedScalar *)shared_data;
+  CeedScalar *sTmp = sT + BASIS_Q * BASIS_P;
+  sTmp += ty * (max(BASIS_Q * BASIS_Q * BASIS_MAX_P_Q, BASIS_Q * BASIS_P * BASIS_P));
+
+  // read T
+  if (ty == 0) {
+    read_T_trans_gm2sm(tx, dT, sT);
+  }
+
+  // read U (idim = 0 for dU, i_DIM = 0 for rU, u_dimstride is always 0)
+  read_U_3d(dU, cstrdU, rU, sTmp, tx);
+  // there is a sync at the end of this function
+
+  magma_interp_3d_device(sT, rU, rV, tx, rTmp, sTmp);
+  __syncthreads();
+
+  // sum into V
+  sum_V_3d(dV, cstrdV, rV, tx);
+}
diff --git a/include/ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h b/include/ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h
index f5e2df1e90..00f7212cc0 100644
--- a/include/ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h
+++ b/include/ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h
@@ -99,6 +99,52 @@ static __device__ __inline__ void magma_basis_nontensor_device_t(const int n, Ce
   }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+template
+static __device__ __inline__ void magma_basis_nontensor_device_ta(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
+                                                                  CeedScalar *shared_data) {
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int id = blockIdx.x * blockDim.y + ty;
+  const int nblocks = (n + NB - 1) / NB;
+  const int myn = min(NB, n - id * NB);
+
+  dB += id * Q * NB;
+  dC += id * P * NB;
+
+  // A is P x Q
+  CeedScalar *sA = shared_data;
+  CeedScalar *sB = shared_data + ty * Q * NB;
+
+  CeedScalar rC[NB] = {0.0};
+
+  // unrolling this loop yields dramatic performance drop using hipcc, so let the compiler decide (no pragma unroll)
+  for (int d = 0; d < Q_COMP; d++) {
+    // read A using all threads
+    CeedScalar rA[Q];
+    read_A_notrans_g2r_1D_nosync(tx, ty, dA, sA, rA);
+    __syncthreads();
+
+    // read B
+    if (id < nblocks) {
+      read_B_g2s_1D_nosync(tx, myn, dB, sB);
+    }
+    __syncthreads();
+
+    addmul_rAsBrC_1D_nosync(rA, sB, rC);
+
+    dA += P * Q;
+    dB += Q * n;
+
+    __syncthreads();
+  }
+
+  // sum into C
+  if (id < nblocks) {
+    sum_C_r2g_1D_nosync(tx, myn, rC, dC);
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 template
 static __device__ __inline__ void magma_basis_nontensor_device_n1(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
@@ -171,6 +217,42 @@ static __device__ __inline__ void magma_basis_nontensor_device_t1(const int n, C
   write_C_r2g_1D_nosync(tx, myn, rC, dC);
 }
 
+////////////////////////////////////////////////////////////////////////////////
+template
+static __device__ __inline__ void magma_basis_nontensor_device_ta1(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
+                                                                   CeedScalar *shared_data) {
+  const int tx = threadIdx.x;
+  const int ty = threadIdx.y;
+  const int id = blockIdx.x * blockDim.y + ty;
+  const int nblocks = (n + NB - 1) / NB;
+  const int myn = min(NB, n - id * NB);
+
+  dB += id * Q * NB;
+  dC += id * P * NB;
+
+  // A is P x Q
+  CeedScalar *sA = shared_data;
+  CeedScalar *sB = shared_data + ty * Q * NB;
+
+  // read A using all threads
+  CeedScalar rA[Q];
+  read_A_notrans_g2r_1D_nosync(tx, ty, dA, sA, rA);
+  __syncthreads();
+
+  // terminate threads with no work
+  if (id >= nblocks) return;
+
+  // read B
+  read_B_g2s_1D_nosync(tx, myn, dB, sB);
+  __syncthreads();
+
+  CeedScalar rC[NB];
+  mul_rAsBrC_1D_nosync(rA, sB, rC);
+
+  // sum into C
+  sum_C_r2g_1D_nosync(tx, myn, rC, dC);
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_Q, MAGMA_MAXTHREADS_1D)) __global__
     void magma_interp_nontensor_n(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
@@ -195,6 +277,18 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) _
 #endif
 }
 
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) __global__
+    void magma_interp_nontensor_ta(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
+
+#if BASIS_Q_COMP_INTERP == 1
+  magma_basis_nontensor_device_ta1(n, dA, dB, dC, (CeedScalar *)shared_data);
+#else
+  magma_basis_nontensor_device_ta(n, dA, dB, dC, (CeedScalar *)shared_data);
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_Q, MAGMA_MAXTHREADS_1D)) __global__
     void magma_deriv_nontensor_n(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
@@ -218,3 +312,15 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) _
   magma_basis_nontensor_device_t(n, dA, dB, dC, (CeedScalar *)shared_data);
 #endif
 }
+
+////////////////////////////////////////////////////////////////////////////////
+extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) __global__
+    void magma_deriv_nontensor_ta(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
+  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
+
+#if BASIS_Q_COMP_DERIV == 1
+  magma_basis_nontensor_device_ta1(n, dA, dB, dC, (CeedScalar *)shared_data);
+#else
+  magma_basis_nontensor_device_ta(n, dA, dB, dC, (CeedScalar *)shared_data);
+#endif
+}
diff --git a/include/ceed/jit-source/magma/magma-common-nontensor.h b/include/ceed/jit-source/magma/magma-common-nontensor.h
index 730acc6419..945227d145 100644
--- a/include/ceed/jit-source/magma/magma-common-nontensor.h
+++ b/include/ceed/jit-source/magma/magma-common-nontensor.h
@@ -104,6 +104,25 @@ static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int
   }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// sum into C from reg. to global
+// C is (P x NB)
+// 1D thread config. with (P x 1) threads
+// no sync at the end of the function
+template
+static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
+  if (n != NB) {
+    for (int i = 0; i < n; i++) {
+      dC[i * P + tx] += rC[i];
+    }
+  } else {
+#pragma unroll
+    for (int i = 0; i < NB; i++) {
+      dC[i * P + tx] += rC[i];
+    }
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // multiply C = A x B using 1D threads in P x 1 config
 // A (P x Q) in reg., one row per thread
diff --git a/include/ceed/jit-source/magma/magma-common-tensor.h b/include/ceed/jit-source/magma/magma-common-tensor.h
index 6c483abd9d..494afacd87 100644
--- a/include/ceed/jit-source/magma/magma-common-tensor.h
+++ b/include/ceed/jit-source/magma/magma-common-tensor.h
@@ -36,6 +36,18 @@ static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, cons
   }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// sum into V of a 1D element into global memory from sV[][] -- for all components
+// the devptr is assumed to point directly to the element
+template
+static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
+  if (tx < LENGTH) {
+    for (int comp = 0; comp < NUM_COMP; comp++) {
+      devptr[comp * compstride + tx] += sBuffer[comp][tx];
+    }
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // read U of a 2D element into registers rU[][][] -- for all components of a single dim
 // dU is assumed to be offset by elem-stride and dim-stride
@@ -107,6 +119,23 @@ static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[D
   }
 }
 
+//////////////////////////////////////////////////////////////////////////////// +// sum into V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim +// dV is assumed to be offset by elem-stride and dim-stride +// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE] +// i_DIM specifies which dimension is being written to in dV +// rV_SIZE can be different from P (e.g. max(P, Q)) +template +static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) { + if (tx < Q) { + for (int comp = 0; comp < NUM_COMP; comp++) { + for (int j = 0; j < Q; j++) { + dV[comp * compstride + j * Q + tx] += rV[i_DIM][comp][j]; + } + } + } +} + //////////////////////////////////////////////////////////////////////////////// // read U of a 3D element into registers rU[][][] -- for all components of a single dim // dU is assumed to be offset by elem-stride and dim-stride @@ -178,6 +207,23 @@ static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[D } } +//////////////////////////////////////////////////////////////////////////////// +// sum into V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim +// dV is assumed to point directly to the element (i.e. already offset by elem-stride) +// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE] +// i_DIM specifies which dimension is being written to in dV +// rV_SIZE can be different from P (e.g. max(P, Q)) +template +static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) { + if (tx < (Q * Q)) { + for (int comp = 0; comp < NUM_COMP; comp++) { + for (int j = 0; j < Q; j++) { + dV[comp * compstride + j * (Q * Q) + tx] += rV[i_DIM][comp][j]; + } + } + } +} + //////////////////////////////////////////////////////////////////////////////// // reads T (no-trans) into shared memory // T is B x J