From 5fb08873c1910f0eff7e6b0424e7ce63a365fcde Mon Sep 17 00:00:00 2001 From: tommelt Date: Fri, 15 Nov 2024 15:26:17 +0000 Subject: [PATCH] bugfix: fix for windows and arm long long int fixes [#183](https://github.com/Cambridge-ICCS/FTorch/issues/183) There is an issue when building on mac (arm_64) or windows. The version of `libtorch` exposes a torch tensors shape (`t->sizes().data()`) as a `const long long int*` instead of just a `const long int*` like on linux and mac (x86). This commit adds preprocessor macro to switch between implementations automatically detecting the correct version at CMake build stage. --- .github/workflows/fypp.yml | 6 +++--- examples/n_c_and_cpp/CMakeLists.txt | 2 +- src/CMakeLists.txt | 10 +++++++++- src/ctorch.cpp | 8 ++++++++ src/ctorch.h | 4 ++++ src/{ftorch.f90 => ftorch.F90} | 6 +++++- src/ftorch.fypp | 6 +++++- 7 files changed, 35 insertions(+), 7 deletions(-) rename src/{ftorch.f90 => ftorch.F90} (99%) diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml index fdb96047..565ad41c 100644 --- a/.github/workflows/fypp.yml +++ b/.github/workflows/fypp.yml @@ -24,9 +24,9 @@ jobs: - name: Check ftorch.fypp matches ftorch.f90 run: | - fypp src/ftorch.fypp src/temp.f90_temp - if ! diff -q src/ftorch.f90 src/temp.f90_temp; then - echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp." + fypp src/ftorch.fypp src/temp.F90_temp + if ! diff -q src/ftorch.F90 src/temp.F90_temp; then + echo "Error: The code in ftorch.F90 does not match that expected from ftorch.fypp." echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit." exit 1 else diff --git a/examples/n_c_and_cpp/CMakeLists.txt b/examples/n_c_and_cpp/CMakeLists.txt index 69a45c64..cdb07f93 100644 --- a/examples/n_c_and_cpp/CMakeLists.txt +++ b/examples/n_c_and_cpp/CMakeLists.txt @@ -35,7 +35,7 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir}) find_package(Torch REQUIRED) # Library with C and Fortran bindings -add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90) +add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90) add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME}) set_target_properties(${LIB_NAME} PROPERTIES PUBLIC_HEADER "ctorch.h" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index abfbb33a..5f08471b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -52,7 +52,15 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir}) find_package(Torch REQUIRED) # Library with C and Fortran bindings -add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90 ftorch_test_utils.f90) +add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90 ftorch_test_utils.f90) + +if(UNIX) + message(STATUS "CMAKE_SYSTEM_PROCESSOR = ${CMAKE_SYSTEM_PROCESSOR}") + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + target_compile_definitions(${LIB_NAME} PRIVATE UNIX) + endif() +endif() + # Add an alias FTorch::ftorch for the library add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME}) set_target_properties(${LIB_NAME} PROPERTIES diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 59757c0c..9280653d 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -234,11 +234,19 @@ int torch_tensor_get_rank(const torch_tensor_t tensor) return t->sizes().size(); } +#ifdef UNIX const long int* torch_tensor_get_sizes(const torch_tensor_t tensor) { auto t = reinterpret_cast(tensor); return t->sizes().data(); } +#else +const long long int* torch_tensor_get_sizes(const torch_tensor_t tensor) +{ + auto t = reinterpret_cast(tensor); + return t->sizes().data(); +} +#endif void torch_tensor_delete(torch_tensor_t tensor) { diff --git a/src/ctorch.h b/src/ctorch.h index 0b25bcf2..c4f20aff 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -125,7 +125,11 @@ EXPORT_C int torch_tensor_get_rank(const torch_tensor_t tensor); * @param Torch Tensor to determine the rank of * @return pointer to the sizes array of the Torch Tensor */ +#ifdef UNIX EXPORT_C const long int* torch_tensor_get_sizes(const torch_tensor_t tensor); +#else +EXPORT_C const long long int* torch_tensor_get_sizes(const torch_tensor_t tensor); +#endif /** * Function to delete a Torch Tensor to clean up diff --git a/src/ftorch.f90 b/src/ftorch.F90 similarity index 99% rename from src/ftorch.f90 rename to src/ftorch.F90 index f0d34abf..7e07d734 100644 --- a/src/ftorch.f90 +++ b/src/ftorch.F90 @@ -351,9 +351,13 @@ end function get_rank !> Determines the shape of a tensor. function get_shape(self) result(sizes) - use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr + use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr class(torch_tensor), intent(in) :: self +#ifdef UNIX integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data +#else + integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data +#endif integer(kind=int32) :: ndims(1) type(c_ptr) :: cptr diff --git a/src/ftorch.fypp b/src/ftorch.fypp index bca8d3e9..974bcc3e 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -318,9 +318,13 @@ contains !> Determines the shape of a tensor. function get_shape(self) result(sizes) - use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr + use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr class(torch_tensor), intent(in) :: self +#ifdef UNIX integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data +#else + integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data +#endif integer(kind=int32) :: ndims(1) type(c_ptr) :: cptr