From 413bfb4678442355ec7a51c3beb9bccc1ab7f5e4 Mon Sep 17 00:00:00 2001 From: "Balos, Cody, J" Date: Thu, 30 Nov 2023 08:04:38 -0800 Subject: [PATCH] fix handling of SUN_COMM_NULL in Fortran --- include/sundials/sundials_types.h | 8 +++++++ .../fmod/fnvector_mpimanyvector_mod.c | 2 +- .../mpiplusx/fmod/fnvector_mpiplusx_mod.c | 2 +- .../parallel/fmod/fnvector_parallel_mod.c | 6 ++--- src/sundials/fmod/fsundials_context_mod.c | 2 +- src/sundials/fmod/fsundials_logger_mod.c | 4 ++-- .../fmod/fsundials_nonlinearsolver_mod.c | 6 ++--- src/sundials/fmod/fsundials_profiler_mod.c | 2 +- src/sundials/fmod/fsundials_types_mod.c | 5 ++-- src/sundials/fmod/fsundials_types_mod.f90 | 11 +++++++-- src/sundials/sundials_context.c | 1 + src/sundials/sundials_logger.c | 1 + swig/sundials/fsundials_types_mod.i | 23 ++++++++++++++++--- 13 files changed, 54 insertions(+), 19 deletions(-) diff --git a/include/sundials/sundials_types.h b/include/sundials/sundials_types.h index ba349ae917..0d7d19f893 100644 --- a/include/sundials/sundials_types.h +++ b/include/sundials/sundials_types.h @@ -214,11 +214,19 @@ typedef int (*SUNErrHandlerFn)(int line, const char* func, const char* file, *------------------------------------------------------------------ */ + /* We don't define SUN_COMM_NULL when SWIG is processing the header + because we manually insert the wrapper code for SUN_COMM_NULL + (and %ignoring it in the SWIG code doesn't seem to work). */ + #if SUNDIALS_MPI_ENABLED +#ifndef SWIG #define SUN_COMM_NULL MPI_COMM_NULL +#endif typedef MPI_Comm SUNComm; #else +#ifndef SWIG #define SUN_COMM_NULL 0 +#endif typedef int SUNComm; #endif diff --git a/src/nvector/manyvector/fmod/fnvector_mpimanyvector_mod.c b/src/nvector/manyvector/fmod/fnvector_mpimanyvector_mod.c index 6ae3b9460f..088f0eece4 100644 --- a/src/nvector/manyvector/fmod/fnvector_mpimanyvector_mod.c +++ b/src/nvector/manyvector/fmod/fnvector_mpimanyvector_mod.c @@ -222,7 +222,7 @@ SWIGEXPORT N_Vector _wrap_FN_VMake_MPIManyVector(int const *farg1, int64_t const if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/nvector/mpiplusx/fmod/fnvector_mpiplusx_mod.c b/src/nvector/mpiplusx/fmod/fnvector_mpiplusx_mod.c index 3ba2e1d120..a788718787 100644 --- a/src/nvector/mpiplusx/fmod/fnvector_mpiplusx_mod.c +++ b/src/nvector/mpiplusx/fmod/fnvector_mpiplusx_mod.c @@ -221,7 +221,7 @@ SWIGEXPORT N_Vector _wrap_FN_VMake_MPIPlusX(int const *farg1, N_Vector farg2, vo if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/nvector/parallel/fmod/fnvector_parallel_mod.c b/src/nvector/parallel/fmod/fnvector_parallel_mod.c index f6c93c82c7..f3133f4006 100644 --- a/src/nvector/parallel/fmod/fnvector_parallel_mod.c +++ b/src/nvector/parallel/fmod/fnvector_parallel_mod.c @@ -222,7 +222,7 @@ SWIGEXPORT N_Vector _wrap_FN_VNew_Parallel(int const *farg1, int64_t const *farg if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; @@ -250,7 +250,7 @@ SWIGEXPORT N_Vector _wrap_FN_VNewEmpty_Parallel(int const *farg1, int64_t const if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; @@ -279,7 +279,7 @@ SWIGEXPORT N_Vector _wrap_FN_VMake_Parallel(int const *farg1, int64_t const *far if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/sundials/fmod/fsundials_context_mod.c b/src/sundials/fmod/fsundials_context_mod.c index 762fa185da..be6efb29c0 100644 --- a/src/sundials/fmod/fsundials_context_mod.c +++ b/src/sundials/fmod/fsundials_context_mod.c @@ -218,7 +218,7 @@ SWIGEXPORT int _wrap_FSUNContext_Create(int const *farg1, void *farg2) { if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/sundials/fmod/fsundials_logger_mod.c b/src/sundials/fmod/fsundials_logger_mod.c index 6aace0705e..f15968862e 100644 --- a/src/sundials/fmod/fsundials_logger_mod.c +++ b/src/sundials/fmod/fsundials_logger_mod.c @@ -246,7 +246,7 @@ SWIGEXPORT int _wrap_FSUNLogger_Create(int const *farg1, int const *farg2, void if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; @@ -271,7 +271,7 @@ SWIGEXPORT int _wrap_FSUNLogger_CreateFromEnv(int const *farg1, void *farg2) { if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/sundials/fmod/fsundials_nonlinearsolver_mod.c b/src/sundials/fmod/fsundials_nonlinearsolver_mod.c index f4e7d34d2c..64f26c50d6 100644 --- a/src/sundials/fmod/fsundials_nonlinearsolver_mod.c +++ b/src/sundials/fmod/fsundials_nonlinearsolver_mod.c @@ -254,13 +254,13 @@ SWIGEXPORT int _wrap_FSUNNonlinSolSetup(SUNNonlinearSolver farg1, N_Vector farg2 SUNNonlinearSolver arg1 = (SUNNonlinearSolver) 0 ; N_Vector arg2 = (N_Vector) 0 ; void *arg3 = (void *) 0 ; - SUNErrCode result; + int result; arg1 = (SUNNonlinearSolver)(farg1); arg2 = (N_Vector)(farg2); arg3 = (void *)(farg3); - result = (SUNErrCode)SUNNonlinSolSetup(arg1,arg2,arg3); - fresult = (SUNErrCode)(result); + result = (int)SUNNonlinSolSetup(arg1,arg2,arg3); + fresult = (int)(result); return fresult; } diff --git a/src/sundials/fmod/fsundials_profiler_mod.c b/src/sundials/fmod/fsundials_profiler_mod.c index 1fd2f8961f..edace7d7e6 100644 --- a/src/sundials/fmod/fsundials_profiler_mod.c +++ b/src/sundials/fmod/fsundials_profiler_mod.c @@ -246,7 +246,7 @@ SWIGEXPORT int _wrap_FSUNProfiler_Create(int const *farg1, SwigArrayWrapper *far if(flag) { arg1 = MPI_Comm_f2c((MPI_Fint)(*farg1)); } else { - arg1 = 0; + arg1 = SUN_COMM_NULL; } #else arg1 = *farg1; diff --git a/src/sundials/fmod/fsundials_types_mod.c b/src/sundials/fmod/fsundials_types_mod.c index 5e15ad4735..d40aa77129 100644 --- a/src/sundials/fmod/fsundials_types_mod.c +++ b/src/sundials/fmod/fsundials_types_mod.c @@ -216,10 +216,11 @@ #error "The Fortran bindings are only targeted at 64-bit indices" #endif + +SWIGEXPORT SWIGEXTERN int const _wrap_SUN_COMM_NULL = (int)(0); + SWIGEXPORT SWIGEXTERN int const _wrap_SUNFALSE = (int)(0); SWIGEXPORT SWIGEXTERN int const _wrap_SUNTRUE = (int)(1); -SWIGEXPORT SWIGEXTERN int const _wrap_SUN_COMM_NULL = (int)(0); - diff --git a/src/sundials/fmod/fsundials_types_mod.f90 b/src/sundials/fmod/fsundials_types_mod.f90 index c2823b52ad..41366fd513 100644 --- a/src/sundials/fmod/fsundials_types_mod.f90 +++ b/src/sundials/fmod/fsundials_types_mod.f90 @@ -24,6 +24,15 @@ module fsundials_types_mod private ! DECLARATION CONSTRUCTS + +#if SUNDIALS_MPI_ENABLED + include "mpif.h" + integer(C_INT), protected, public :: SUN_COMM_NULL = MPI_COMM_NULL +#else + integer(C_INT), protected, public, & + bind(C, name="_wrap_SUN_COMM_NULL") :: SUN_COMM_NULL +#endif + integer(C_INT), protected, public, & bind(C, name="_wrap_SUNFALSE") :: SUNFALSE integer(C_INT), protected, public, & @@ -35,7 +44,5 @@ module fsundials_types_mod end enum integer, parameter, public :: SUNOutputFormat = kind(SUN_OUTPUTFORMAT_TABLE) public :: SUN_OUTPUTFORMAT_TABLE, SUN_OUTPUTFORMAT_CSV - integer(C_INT), protected, public, & - bind(C, name="_wrap_SUN_COMM_NULL") :: SUN_COMM_NULL end module diff --git a/src/sundials/sundials_context.c b/src/sundials/sundials_context.c index 884c398694..0298f411e1 100644 --- a/src/sundials/sundials_context.c +++ b/src/sundials/sundials_context.c @@ -60,6 +60,7 @@ SUNErrCode SUNContext_Create(SUNComm comm, SUNContext* sunctx_out) do { #if SUNDIALS_LOGGING_LEVEL > 0 #if SUNDIALS_MPI_ENABLED + printf("SUNContext_Create: comm=%p, MPI_COMM_NULL=%p\n", (void*)comm, (void*)MPI_COMM_NULL); err = SUNLogger_CreateFromEnv(comm, &logger); SUNCheckCallNoRet(err); if (err) { break; } #else diff --git a/src/sundials/sundials_logger.c b/src/sundials/sundials_logger.c index b3cacf7dee..d2be8627fe 100644 --- a/src/sundials/sundials_logger.c +++ b/src/sundials/sundials_logger.c @@ -162,6 +162,7 @@ SUNErrCode SUNLogger_Create(SUNComm comm, int output_rank, SUNLogger* logger_ptr /* Attach the comm, duplicating it if MPI is used. */ #if SUNDIALS_MPI_ENABLED logger->comm = SUN_COMM_NULL; + printf("SUNLogger_Create: comm=%p, MPI_COMM_NULL=%p\n", (void*)comm, (void*)MPI_COMM_NULL); if (comm != SUN_COMM_NULL) { MPI_Comm_dup(comm, &logger->comm); diff --git a/swig/sundials/fsundials_types_mod.i b/swig/sundials/fsundials_types_mod.i index d45154be2d..87aea766f5 100644 --- a/swig/sundials/fsundials_types_mod.i +++ b/swig/sundials/fsundials_types_mod.i @@ -68,7 +68,6 @@ %apply MPI_Comm { SUNComm }; - // Insert code into the C wrapper to check that the sizes match %{ #include "sundials/sundials_types.h" @@ -82,8 +81,23 @@ #endif %} -// Process and wrap functions in the following files -%include "sundials/sundials_types.h" +// We insert the binding code for SUN_COMM_NULL ourselves because +// (1) SWIG expands SUN_COMM_NULL to its value +// (2) We need it to be equivalent to MPI_COMM_NULL when MPI is enabled + +%insert("wrapper") %{ +SWIGEXPORT SWIGEXTERN int const _wrap_SUN_COMM_NULL = (int)(0); +%} + +%insert("fdecl") %{ +#if SUNDIALS_MPI_ENABLED + include "mpif.h" + integer(C_INT), protected, public :: SUN_COMM_NULL = MPI_COMM_NULL +#else + integer(C_INT), protected, public, & + bind(C, name="_wrap_SUN_COMM_NULL") :: SUN_COMM_NULL +#endif +%} // Insert SUNDIALS copyright into generated C files. %insert(begin) @@ -119,3 +133,6 @@ ! SUNDIALS Copyright End ! --------------------------------------------------------------- %} + +// Process and wrap functions in the following files +%include "sundials/sundials_types.h"