Skip to content

Commit

Permalink
Add unit test for torch_tensor_get_device_index on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 15, 2025
1 parent 9623a90 commit 445a927
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/test/unit/test_tensor_interrogation.pf
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,32 @@ module test_tensor_interrogation

contains

! Unit test for the torch_tensor_get_device_index function applied to a tensor on the CPU
@test
subroutine test_torch_tensor_get_device_index_default()
use ftorch, only: torch_tensor_get_device_index

implicit none

type(torch_tensor) :: tensor
integer, parameter :: ndims = 1
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [1]
integer, parameter :: dtype = torch_kFloat32
integer, parameter :: expected = -1

! Create an empty tensor on the CPU with the default device index
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type)

! Check that torch_tensor_get_device_index can get the device index
if (expected /= torch_tensor_get_device_index(tensor)) then
call torch_tensor_delete(tensor)
print *, "Error :: torch_tensor_get_device_index returned incorrect CPU device index"
stop 999
end if
call torch_tensor_delete(tensor)

end subroutine test_torch_tensor_get_device_index_default

! Unit test for the get_rank method of a 1D tensor
@test
subroutine test_get_rank_1D()
Expand Down

0 comments on commit 445a927

Please sign in to comment.