-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ae7aa4b
commit d6c03c6
Showing
2 changed files
with
79 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
!| Unit tests for FTorch subroutines that interrogate tensors. | ||
! | ||
! * License | ||
! FTorch is released under an MIT license. | ||
! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) | ||
! file for details. | ||
module test_tensor_interrogation | ||
use funit | ||
use ftorch, only: ftorch_int, get_rank, torch_kFloat32, torch_kCPU, torch_tensor, & | ||
torch_tensor_delete, torch_tensor_empty | ||
use ftorch_test_utils, only: assert_allclose | ||
use iso_c_binding, only: c_int64_t | ||
|
||
implicit none | ||
|
||
public | ||
|
||
! Parameters common across all test cases | ||
integer, parameter :: device_type = torch_kCPU | ||
|
||
contains | ||
|
||
! Unit test for the get_rank function applied to a 1D tensor | ||
@test | ||
subroutine test_get_rank_1D() | ||
|
||
implicit none | ||
|
||
type(torch_tensor) :: tensor | ||
integer, parameter :: ndims = 1 | ||
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [6] | ||
integer, parameter :: dtype = torch_kFloat32 | ||
|
||
! Create a tensor with uninitialised values and check get_rank can correctly identify its rank | ||
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) | ||
@assertEqual(ndims, get_rank(tensor)) | ||
call torch_tensor_delete(tensor) | ||
|
||
end subroutine test_get_rank_1D | ||
|
||
! Unit test for the get_rank function applied to a 2D tensor | ||
@test | ||
subroutine test_get_rank_2D() | ||
|
||
implicit none | ||
|
||
type(torch_tensor) :: tensor | ||
integer, parameter :: ndims = 2 | ||
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [2,3] | ||
integer, parameter :: dtype = torch_kFloat32 | ||
|
||
! Create a tensor with uninitialised values and check get_rank can correctly identify its rank | ||
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) | ||
@assertEqual(ndims, get_rank(tensor)) | ||
call torch_tensor_delete(tensor) | ||
|
||
end subroutine test_get_rank_2D | ||
|
||
! Unit test for the get_rank function applied to a 3D tensor | ||
@test | ||
subroutine test_get_rank_3D() | ||
|
||
implicit none | ||
|
||
type(torch_tensor) :: tensor | ||
integer, parameter :: ndims = 3 | ||
integer(kind=c_int64_t), parameter :: tensor_shape(ndims) = [1,2,3] | ||
integer, parameter :: dtype = torch_kFloat32 | ||
|
||
! Create a tensor with uninitialised values and check get_rank can correctly identify its rank | ||
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) | ||
@assertEqual(ndims, get_rank(tensor)) | ||
call torch_tensor_delete(tensor) | ||
|
||
end subroutine test_get_rank_3D | ||
|
||
end module test_tensor_interrogation |