From d6c03c6896fecf63571a385e24b7f7286c51901b Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 14 Jan 2025 16:55:08 +0000 Subject: [PATCH] Add unit tests for get_rank --- src/test/unit/CMakeLists.txt | 2 + src/test/unit/test_tensor_interrogation.pf | 77 ++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 src/test/unit/test_tensor_interrogation.pf diff --git a/src/test/unit/CMakeLists.txt b/src/test/unit/CMakeLists.txt index dee34a44..4aa9dc15 100644 --- a/src/test/unit/CMakeLists.txt +++ b/src/test/unit/CMakeLists.txt @@ -10,5 +10,7 @@ find_package(PFUNIT REQUIRED) add_pfunit_ctest(test_constructors TEST_SOURCES test_constructors.pf LINK_LIBRARIES FTorch::ftorch) +add_pfunit_ctest(test_tensor_interrogation + TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch) add_pfunit_ctest(test_operator_overloads TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch) diff --git a/src/test/unit/test_tensor_interrogation.pf b/src/test/unit/test_tensor_interrogation.pf new file mode 100644 index 00000000..c6f128fc --- /dev/null +++ b/src/test/unit/test_tensor_interrogation.pf @@ -0,0 +1,77 @@ +!| Unit tests for FTorch subroutines that interrogate tensors. +! +! * License +! FTorch is released under an MIT license. +! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) +! file for details. +module test_tensor_interrogation + use funit + use ftorch, only: ftorch_int, get_rank, torch_kFloat32, torch_kCPU, torch_tensor, & + torch_tensor_delete, torch_tensor_empty + use ftorch_test_utils, only: assert_allclose + use iso_c_binding, only: c_int64_t + + implicit none + + public + + ! Parameters common across all test cases + integer, parameter :: device_type = torch_kCPU + +contains + + ! Unit test for the get_rank function applied to a 1D tensor + @test + subroutine test_get_rank_1D() + + implicit none + + type(torch_tensor) :: tensor + integer, parameter :: ndims = 1 + integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [6] + integer, parameter :: dtype = torch_kFloat32 + + ! Create a tensor with uninitialised values and check get_rank can correctly identify its rank + call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) + @assertEqual(ndims, get_rank(tensor)) + call torch_tensor_delete(tensor) + + end subroutine test_get_rank_1D + + ! Unit test for the get_rank function applied to a 2D tensor + @test + subroutine test_get_rank_2D() + + implicit none + + type(torch_tensor) :: tensor + integer, parameter :: ndims = 2 + integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [2,3] + integer, parameter :: dtype = torch_kFloat32 + + ! Create a tensor with uninitialised values and check get_rank can correctly identify its rank + call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) + @assertEqual(ndims, get_rank(tensor)) + call torch_tensor_delete(tensor) + + end subroutine test_get_rank_2D + + ! Unit test for the get_rank function applied to a 3D tensor + @test + subroutine test_get_rank_3D() + + implicit none + + type(torch_tensor) :: tensor + integer, parameter :: ndims = 3 + integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [1,2,3] + integer, parameter :: dtype = torch_kFloat32 + + ! Create a tensor with uninitialised values and check get_rank can correctly identify its rank + call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) + @assertEqual(ndims, get_rank(tensor)) + call torch_tensor_delete(tensor) + + end subroutine test_get_rank_3D + +end module test_tensor_interrogation