From 445a92712acab423d64059cc3f338e849843bcfe Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Wed, 15 Jan 2025 11:06:46 +0000 Subject: [PATCH] Add unit test for torch_tensor_get_device_index on CPU --- src/test/unit/test_tensor_interrogation.pf | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/test/unit/test_tensor_interrogation.pf b/src/test/unit/test_tensor_interrogation.pf index 8bd53861..578cc77e 100644 --- a/src/test/unit/test_tensor_interrogation.pf +++ b/src/test/unit/test_tensor_interrogation.pf @@ -19,6 +19,32 @@ module test_tensor_interrogation contains + ! Unit test for the torch_tensor_get_device_index function applied to a tensor on the CPU + @test + subroutine test_torch_tensor_get_device_index_default() + use ftorch, only: torch_tensor_get_device_index + + implicit none + + type(torch_tensor) :: tensor + integer, parameter :: ndims = 1 + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [1] + integer, parameter :: dtype = torch_kFloat32 + integer, parameter :: expected = -1 + + ! Create an empty tensor on the CPU with the default device index + call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) + + ! Check that torch_tensor_get_device_index can get the device index + if (expected /= torch_tensor_get_device_index(tensor)) then + call torch_tensor_delete(tensor) + print *, "Error :: torch_tensor_get_device_index returned incorrect CPU device index" + stop 999 + end if + call torch_tensor_delete(tensor) + + end subroutine test_torch_tensor_get_device_index_default + ! Unit test for the get_rank method of a 1D tensor @test subroutine test_get_rank_1D()