Skip to content

Commit

Permalink
bugfix: fix for windows and arm long long int
Browse files Browse the repository at this point in the history
fixes [#183](#183)

There is an issue when building on mac (arm_64) or windows. The version
of `libtorch` exposes a torch tensors shape (`t->sizes().data()`) as a
`const long long int*` instead of just a `const long int*` like on linux
and mac (x86).

This commit adds preprocessor macro to switch between implementations
automatically detecting the correct version at CMake build stage.
  • Loading branch information
TomMelt committed Dec 2, 2024
1 parent 4d66327 commit a005817
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/fypp_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ jobs:

- name: Check ftorch.fypp matches ftorch.f90
run: |
fypp src/ftorch.fypp src/temp.f90_temp
if ! diff -q src/ftorch.f90 src/temp.f90_temp; then
echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp."
fypp src/ftorch.fypp src/temp.F90_temp
if ! diff -q src/ftorch.F90 src/temp.F90_temp; then
echo "Error: The code in ftorch.F90 does not match that expected from ftorch.fypp."
echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit."
exit 1
else
Expand Down
2 changes: 1 addition & 1 deletion examples/n_c_and_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir})
find_package(Torch REQUIRED)

# Library with C and Fortran bindings
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90)
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90)
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES
PUBLIC_HEADER "ctorch.h"
Expand Down
10 changes: 9 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,15 @@ set(CMAKE_INSTALL_RPATH $ORIGIN/${relDir})
find_package(Torch REQUIRED)

# Library with C and Fortran bindings
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90 ftorch_test_utils.f90)
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.F90 ftorch_test_utils.f90)

if(UNIX)
message(STATUS "CMAKE_SYSTEM_PROCESSOR = ${CMAKE_SYSTEM_PROCESSOR}")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
target_compile_definitions(${LIB_NAME} PRIVATE UNIX)
endif()
endif()

# Add an alias FTorch::ftorch for the library
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES
Expand Down
7 changes: 7 additions & 0 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,17 @@ int torch_tensor_get_rank(const torch_tensor_t tensor) {
return t->sizes().size();
}

#ifdef UNIX
const long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *>(tensor);
return t->sizes().data();
}
#else
const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *>(tensor);
return t->sizes().data();
}
#endif

void torch_tensor_delete(torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *>(tensor);
Expand Down
5 changes: 5 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ EXPORT_C int torch_tensor_get_rank(const torch_tensor_t tensor);
* @param Torch Tensor to determine the rank of
* @return pointer to the sizes array of the Torch Tensor
*/
#ifdef UNIX
EXPORT_C const long int *torch_tensor_get_sizes(const torch_tensor_t tensor);
#else
EXPORT_C const long long int *
torch_tensor_get_sizes(const torch_tensor_t tensor);
#endif

/**
* Function to delete a Torch Tensor to clean up
Expand Down
6 changes: 5 additions & 1 deletion src/ftorch.f90 → src/ftorch.F90
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,13 @@ end function get_rank

!> Determines the shape of a tensor.
function get_shape(self) result(sizes)
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr
class(torch_tensor), intent(in) :: self
#ifdef UNIX
integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data
#else
integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data
#endif
integer(kind=int32) :: ndims(1)
type(c_ptr) :: cptr

Expand Down
6 changes: 5 additions & 1 deletion src/ftorch.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,13 @@ contains

!> Determines the shape of a tensor.
function get_shape(self) result(sizes)
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_ptr
use, intrinsic :: iso_c_binding, only : c_int, c_long, c_long_long, c_ptr
class(torch_tensor), intent(in) :: self
#ifdef UNIX
integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor data
#else
integer(kind=c_long_long), pointer :: sizes(:) !! Pointer to tensor data
#endif
integer(kind=int32) :: ndims(1)
type(c_ptr) :: cptr

Expand Down

0 comments on commit a005817

Please sign in to comment.