From 9f3323be336d8ac633d9a7c61d61828ab3d42025 Mon Sep 17 00:00:00 2001 From: tommelt Date: Fri, 15 Nov 2024 15:26:17 +0000 Subject: [PATCH] bugfix: fix for windows and arm long long int fixes [#183](https://github.com/Cambridge-ICCS/FTorch/issues/183) There is an issue when building on mac (arm_64) or windows. The version of `libtorch` exposes a torch tensors shape (`t->sizes().data()`) as a `const long long int*` instead of just a `const long int*` like on linux and mac (x86). This commit adds preprocessor macro to switch between implementations automatically detecting the correct version at CMake build stage. --- .github/workflows/fypp_checks.yml | 6 +++--- examples/n_c_and_cpp/CMakeLists.txt | 2 +- src/CMakeLists.txt | 9 ++++++++- src/ctorch.cpp | 7 +++++++ src/ctorch.h | 5 +++++ src/{ftorch.f90 => ftorch.F90} | 6 +++++- src/ftorch.fypp | 6 +++++- 7 files changed, 34 insertions(+), 7 deletions(-) rename src/{ftorch.f90 => ftorch.F90} (99%) diff --git a/.github/workflows/fypp_checks.yml b/.github/workflows/fypp_checks.yml index e6e1e9ff..90dc7a1a 100644 --- a/.github/workflows/fypp_checks.yml +++ b/.github/workflows/fypp_checks.yml @@ -30,9 +30,9 @@ jobs: - name: Check ftorch.fypp matches ftorch.f90 run: | - fypp src/ftorch.fypp src/temp.f90_temp - if ! diff -q src/ftorch.f90 src/temp.f90_temp; then - echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp." + fypp src/ftorch.fypp src/temp.F90_temp + if ! diff -q src/ftorch.F90 src/temp.F90_temp; then + echo "Error: The code in ftorch.F90 does not match that expected from ftorch.fypp." echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit." exit 1 else diff --git a/examples/n_c_and_cpp/CMakeLists.txt b/examples/n_c_and_cpp/CMakeLists.txt index 69a45c64..cdb07f93 100644 --- a/examples/n_c_and_cpp/CMakeLists.txt +++ b/examples/n_c_and_cpp/CMakeLists.txt @@ -35,7 +35,7 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir}) find_package(Torch REQUIRED) # Library with C and Fortran bindings -add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90) +add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90) add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME}) set_target_properties(${LIB_NAME} PROPERTIES PUBLIC_HEADER "ctorch.h" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index abfbb33a..1043ac6e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -52,7 +52,14 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir}) find_package(Torch REQUIRED) # Library with C and Fortran bindings -add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90 ftorch_test_utils.f90) +add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90 ftorch_test_utils.f90) + +if(UNIX) + if(NOT APPLE) # only add definition for linux (not apple which is also unix) + target_compile_definitions(${LIB_NAME} PRIVATE UNIX) + endif() +endif() + # Add an alias FTorch::ftorch for the library add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME}) set_target_properties(${LIB_NAME} PROPERTIES diff --git a/src/ctorch.cpp b/src/ctorch.cpp index efdc2f0b..9f513b37 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -224,10 +224,17 @@ int torch_tensor_get_rank(const torch_tensor_t tensor) { return t->sizes().size(); } +#ifdef UNIX const long int *torch_tensor_get_sizes(const torch_tensor_t tensor) { auto t = reinterpret_cast(tensor); return t->sizes().data(); } +#else +const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) { + auto t = reinterpret_cast(tensor); + return t->sizes().data(); +} +#endif void torch_tensor_delete(torch_tensor_t tensor) { auto t = reinterpret_cast(tensor); diff --git a/src/ctorch.h b/src/ctorch.h index 9b45102d..d7d39490 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -128,7 +128,12 @@ EXPORT_C int torch_tensor_get_rank(const torch_tensor_t tensor); * @param Torch Tensor to determine the rank of * @return pointer to the sizes array of the Torch Tensor */ +#ifdef UNIX EXPORT_C const long int *torch_tensor_get_sizes(const torch_tensor_t tensor); +#else +EXPORT_C const long long int * +torch_tensor_get_sizes(const torch_tensor_t tensor); +#endif /** * Function to delete a Torch Tensor to clean up diff --git a/src/ftorch.f90 b/src/ftorch.F90 similarity index 99% rename from src/ftorch.f90 rename to src/ftorch.F90 index 2bbe1391..7300f737 100644 --- a/src/ftorch.f90 +++ b/src/ftorch.F90 @@ -367,9 +367,13 @@ end function get_rank !> Determines the shape of a tensor. function get_shape(self) result(sizes) - use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr + use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr class(torch_tensor), intent(in) :: self +#ifdef UNIX integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data +#else + integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data +#endif integer(kind=int32) :: ndims(1) type(c_ptr) :: cptr diff --git a/src/ftorch.fypp b/src/ftorch.fypp index d6dbf4a4..cb00bff6 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -334,9 +334,13 @@ contains !> Determines the shape of a tensor. function get_shape(self) result(sizes) - use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr + use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr class(torch_tensor), intent(in) :: self +#ifdef UNIX integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data +#else + integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data +#endif integer(kind=int32) :: ndims(1) type(c_ptr) :: cptr