Skip to content

Commit

Permalink
Add unit tests for get_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 14, 2025
1 parent ae7aa4b commit d6c03c6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ find_package(PFUNIT REQUIRED)

add_pfunit_ctest(test_constructors
TEST_SOURCES test_constructors.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(test_tensor_interrogation
TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(test_operator_overloads
TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch)
77 changes: 77 additions & 0 deletions src/test/unit/test_tensor_interrogation.pf
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
!| Unit tests for FTorch subroutines that interrogate tensors.
!
! * License
! FTorch is released under an MIT license.
! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
! file for details.
module test_tensor_interrogation
use funit
use ftorch, only: ftorch_int, get_rank, torch_kFloat32, torch_kCPU, torch_tensor, &
torch_tensor_delete, torch_tensor_empty
use ftorch_test_utils, only: assert_allclose
use iso_c_binding, only: c_int64_t

implicit none

public

! Parameters common across all test cases
integer, parameter :: device_type = torch_kCPU

contains

! Unit test for the get_rank function applied to a 1D tensor
@test
subroutine test_get_rank_1D()

implicit none

type(torch_tensor) :: tensor
integer, parameter :: ndims = 1
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [6]
integer, parameter :: dtype = torch_kFloat32

! Create a tensor with uninitialised values and check get_rank can correctly identify its rank
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type)
@assertEqual(ndims, get_rank(tensor))
call torch_tensor_delete(tensor)

end subroutine test_get_rank_1D

! Unit test for the get_rank function applied to a 2D tensor
@test
subroutine test_get_rank_2D()

implicit none

type(torch_tensor) :: tensor
integer, parameter :: ndims = 2
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [2,3]
integer, parameter :: dtype = torch_kFloat32

! Create a tensor with uninitialised values and check get_rank can correctly identify its rank
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type)
@assertEqual(ndims, get_rank(tensor))
call torch_tensor_delete(tensor)

end subroutine test_get_rank_2D

! Unit test for the get_rank function applied to a 3D tensor
@test
subroutine test_get_rank_3D()

implicit none

type(torch_tensor) :: tensor
integer, parameter :: ndims = 3
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [1,2,3]
integer, parameter :: dtype = torch_kFloat32

! Create a tensor with uninitialised values and check get_rank can correctly identify its rank
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type)
@assertEqual(ndims, get_rank(tensor))
call torch_tensor_delete(tensor)

end subroutine test_get_rank_3D

end module test_tensor_interrogation

0 comments on commit d6c03c6

Please sign in to comment.