-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit test for torch_tensor_get_device_index on CUDA device
- Loading branch information
1 parent
445a927
commit 602d8f1
Showing
3 changed files
with
50 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
!| Unit tests for FTorch subroutines that interrogate tensors on a CUDA device. | ||
! | ||
! * License | ||
! FTorch is released under an MIT license. | ||
! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) | ||
! file for details. | ||
module test_tensor_interrogation_cuda | ||
|
||
implicit none | ||
|
||
public | ||
|
||
contains | ||
|
||
! Unit test for the torch_tensor_get_device_index function applied to a tensor on a CUDA device | ||
@test | ||
subroutine test_torch_tensor_get_device_index_default() | ||
use funit | ||
use ftorch, only: torch_kFloat32, torch_kCPU, torch_tensor, torch_tensor_delete & | ||
torch_tensor_empty, torch_tensor_get_device_index | ||
use iso_c_binding, only: c_int64_t | ||
|
||
implicit none | ||
|
||
type(torch_tensor) :: tensor | ||
integer, parameter :: ndims = 1 | ||
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [1] | ||
integer, parameter :: dtype = torch_kFloat32 | ||
integer, parameter :: device_type = torch_kCUDA | ||
integer, parameter :: expected = 0 | ||
|
||
! Create an empty tensor on the CPU with the default device index | ||
call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type) | ||
|
||
! Check that torch_tensor_get_device_index can get the device index | ||
if (expected /= torch_tensor_get_device_index(tensor)) then | ||
call torch_tensor_delete(tensor) | ||
print *, "Error :: torch_tensor_get_device_index returned incorrect CUDA device index" | ||
stop 999 | ||
end if | ||
call torch_tensor_delete(tensor) | ||
|
||
end subroutine test_torch_tensor_get_device_index_default | ||
|
||
end module test_tensor_interrogation_cuda |