Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple precisions in CeedVector #948

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
73 changes: 43 additions & 30 deletions backends/cuda-ref/ceed-cuda-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
// Sync arrays
//------------------------------------------------------------------------------
static inline int CeedVectorSync_Cuda(const CeedVector vec,
CeedScalarType prec,
CeedMemType mem_type) {
switch (mem_type) {
case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec);
Expand Down Expand Up @@ -150,7 +151,8 @@ static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec,
// Check if has borrowed array of given type
//------------------------------------------------------------------------------
static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec,
CeedMemType mem_type, bool *has_borrowed_array_of_type) {
CeedMemType mem_type, CeedScalarType prec,
bool *has_borrowed_array_of_type) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
Expand Down Expand Up @@ -194,7 +196,7 @@ static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
// Set array from host
//------------------------------------------------------------------------------
static int CeedVectorSetArrayHost_Cuda(const CeedVector vec,
const CeedCopyMode copy_mode, CeedScalar *array) {
const CeedCopyMode copy_mode, void *array) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
Expand Down Expand Up @@ -236,7 +238,7 @@ static int CeedVectorSetArrayHost_Cuda(const CeedVector vec,
// Set array from device
//------------------------------------------------------------------------------
static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec,
const CeedCopyMode copy_mode, CeedScalar *array) {
const CeedCopyMode copy_mode, void *array) {
int ierr;
Ceed ceed;
ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
Expand Down Expand Up @@ -280,9 +282,11 @@ static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec,
// Set the array used by a vector,
// freeing any previously allocated array if applicable
//------------------------------------------------------------------------------
static int CeedVectorSetArray_Cuda(const CeedVector vec,
const CeedMemType mem_type,
const CeedCopyMode copy_mode, CeedScalar *array) {
static int CeedVectorSetArrayGeneric_Cuda(const CeedVector vec,
const CeedMemType mem_type,
const CeedScalarType prec,
const CeedCopyMode copy_mode,
void *array) {
int ierr;
Ceed ceed;
ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
Expand Down Expand Up @@ -359,8 +363,8 @@ static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
//------------------------------------------------------------------------------
// Vector Take Array
//------------------------------------------------------------------------------
static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
CeedScalar **array) {
static int CeedVectorTakeArrayGeneric_Cuda(CeedVector vec, CeedMemType mem_type,
CeedScalarType prec, void **array) {
int ierr;
Ceed ceed;
ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
Expand Down Expand Up @@ -396,7 +400,9 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
// If a different memory type is most up to date, this will perform a copy
//------------------------------------------------------------------------------
static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
const CeedMemType mem_type, CeedScalar **array) {
const CeedMemType mem_type,
const CeedScalarType prec,
void **array) {
int ierr;
Ceed ceed;
ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
Expand Down Expand Up @@ -427,21 +433,26 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
//------------------------------------------------------------------------------
// Get read-only access to a vector via the specified mem_type
//------------------------------------------------------------------------------
static int CeedVectorGetArrayRead_Cuda(const CeedVector vec,
const CeedMemType mem_type, const CeedScalar **array) {
return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
static int CeedVectorGetArrayReadGeneric_Cuda(const CeedVector vec,
const CeedMemType mem_type,
const CeedScalarType prec,
const void **array) {
return CeedVectorGetArrayCore_Cuda(vec, mem_type, prec, (void **)array);
}

//------------------------------------------------------------------------------
// Get read/write access to a vector via the specified mem_type
// Get read/write access to a vector via the specified mem_type and precision
//------------------------------------------------------------------------------
static int CeedVectorGetArray_Cuda(const CeedVector vec,
const CeedMemType mem_type, CeedScalar **array) {
const CeedMemType mem_type,
const CeedScalarType prec,
void **array) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

ierr = CeedVectorGetArrayCore_Cuda(vec, mem_type, array); CeedChkBackend(ierr);
ierr = CeedVectorGetArrayCore_Cuda(vec, mem_type, prec, array);
CeedChkBackend(ierr);

ierr = CeedVectorSetAllInvalid_Cuda(vec); CeedChkBackend(ierr);
switch (mem_type) {
Expand All @@ -457,10 +468,12 @@ static int CeedVectorGetArray_Cuda(const CeedVector vec,
}

//------------------------------------------------------------------------------
// Get write access to a vector via the specified mem_type
// Get write access to a vector via the specified mem_type and precision
//------------------------------------------------------------------------------
static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec,
const CeedMemType mem_type, CeedScalar **array) {
static int CeedVectorGetArrayWriteGeneric_Cuda(const CeedVector vec,
const CeedMemType mem_type,
const CeedScalarType prec,
void **array) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
Expand All @@ -470,7 +483,7 @@ static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec,
CeedChkBackend(ierr);
if (!has_array_of_type) {
// Allocate if array is not yet allocated
ierr = CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL);
ierr = CeedVectorSetArrayGeneric(vec, mem_type, prec, CEED_COPY_VALUES, NULL);
CeedChkBackend(ierr);
} else {
// Select dirty array
Expand All @@ -489,7 +502,7 @@ static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec,
}
}

return CeedVectorGetArray_Cuda(vec, mem_type, array);
return CeedVectorGetArrayGeneric_Cuda(vec, mem_type, prec, array);
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -756,19 +769,19 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType",
CeedVectorHasBorrowedArrayOfType_Cuda);
CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArray",
CeedVectorSetArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray",
CeedVectorTakeArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArrayGeneric",
CeedVectorSetArrayGeneric_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArrayGeneric",
CeedVectorTakeArrayGeneric_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
(int (*)())(CeedVectorSetValue_Cuda));
CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
CeedVectorGetArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
CeedVectorGetArrayRead_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite",
CeedVectorGetArrayWrite_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayGeneric",
CeedVectorGetArrayGeneric_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayReadGeneric",
CeedVectorGetArrayReadGeneric_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWriteGeneric",
CeedVectorGetArrayWriteGeneric_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Norm",
CeedVectorNorm_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal",
Expand Down
Loading