Skip to content

Commit

Permalink
Add unit test for torch_tensor_get_device_index on CUDA device
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 15, 2025
1 parent 445a927 commit 602d8f1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ add_pfunit_ctest(test_tensor_interrogation
TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(test_operator_overloads
TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch)
if(ENABLE_CUDA)
add_pfunit_ctest(test_operator_overloads_cuda
TEST_SOURCES test_tensor_operator_overloads_cuda.pf LINK_LIBRARIES FTorch::ftorch)
endif()
2 changes: 1 addition & 1 deletion src/test/unit/test_tensor_interrogation.pf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
!| Unit tests for FTorch subroutines that interrogate tensors.
!| Unit tests for FTorch subroutines that interrogate tensors on the CPU.
!
! * License
! FTorch is released under an MIT license.
Expand Down
45 changes: 45 additions & 0 deletions src/test/unit/test_tensor_interrogation_cuda.pf
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
!| Unit tests for FTorch subroutines that interrogate tensors on a CUDA device.
!
! * License
! FTorch is released under an MIT license.
! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
! file for details.
module test_tensor_interrogation_cuda

implicit none

public

contains

! Unit test for the torch_tensor_get_device_index function applied to a tensor on a CUDA device
@test
subroutine test_torch_tensor_get_device_index_default()
use funit
use ftorch, only: torch_kFloat32, torch_kCPU, torch_tensor, torch_tensor_delete &
torch_tensor_empty, torch_tensor_get_device_index
use iso_c_binding, only: c_int64_t

implicit none

type(torch_tensor) :: tensor
integer, parameter :: ndims = 1
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [1]
integer, parameter :: dtype = torch_kFloat32
integer, parameter :: device_type = torch_kCUDA
integer, parameter :: expected = 0

! Create an empty tensor on the CPU with the default device index
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type)

! Check that torch_tensor_get_device_index can get the device index
if (expected /= torch_tensor_get_device_index(tensor)) then
call torch_tensor_delete(tensor)
print *, "Error :: torch_tensor_get_device_index returned incorrect CUDA device index"
stop 999
end if
call torch_tensor_delete(tensor)

end subroutine test_torch_tensor_get_device_index_default

end module test_tensor_interrogation_cuda

0 comments on commit 602d8f1

Please sign in to comment.