Skip to content

Commit

Permalink
Introduce switch fixture for operator overloads unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 14, 2025
1 parent 9196917 commit 12d504a
Showing 1 changed file with 96 additions and 72 deletions.
168 changes: 96 additions & 72 deletions src/test/unit/test_tensor_operator_overloads.pf
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,62 @@ module test_tensor_operator_overloads

integer, parameter :: device_type = torch_kCPU

! Typedef holding a set of parameter values
@testParameter
type, extends(AbstractTestParameter) :: TestParametersType
logical :: switch
contains
procedure :: toString
end type TestParametersType

! Typedef for a test case with a particular set of parameters
@testCase(constructor=test_case_ctor)
type, extends (ParameterizedTestCase) :: TestCaseType
type(TestParametersType) :: param
end type TestCaseType

contains

@test
subroutine test_torch_tensor_assign()
! A fixture comprised of a full list of parameter sets
function get_parameters_full() result(params)
type(TestParametersType), allocatable :: params(:)
params = [ &
TestParametersType(.false.), &
TestParametersType(.true.) &
]
end function get_parameters_full

! A fixture comprised of a short list of parameter sets
function get_parameters_short() result(params)
type(TestParametersType), allocatable :: params(:)
params = [TestParametersType(.false.)]
end function get_parameters_short

! Constructor for the test case type
function test_case_ctor(param)
type(TestCaseType) :: test_case_ctor
type(TestParametersType) :: param
test_case_ctor%param = param
end function test_case_ctor

! Function for representing a parameter set as a string
function toString(this) result(string)
class(TestParametersType), intent(in) :: this
character(:), allocatable :: string
character(len=1) :: str
write(str,'(l1)') this%switch
string = str
end function toString

@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_assign(this)
use, intrinsic :: iso_fortran_env, only: sp => real32

implicit none
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -78,8 +124,8 @@ contains

end subroutine test_torch_tensor_assign

@test
subroutine test_torch_tensor_add()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_add(this)
use ftorch, only: operator(+)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -88,6 +134,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2, tensor3
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -149,8 +196,8 @@ contains

end subroutine test_torch_tensor_add

@test
subroutine test_torch_tensor_subtract()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_subtract(this)
use ftorch, only: operator(-)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -159,6 +206,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2, tensor3
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -222,8 +270,8 @@ contains

end subroutine test_torch_tensor_subtract

@test
subroutine test_torch_tensor_multiply()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_multiply(this)
use ftorch, only: operator(*)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -232,6 +280,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2, tensor3
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -295,8 +344,8 @@ contains

end subroutine test_torch_tensor_multiply

@test
subroutine test_torch_tensor_scalar_multiply()
@test(testParameters={get_parameters_full()})
subroutine test_torch_tensor_scalar_multiply(this)
use ftorch, only: operator(*)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -305,14 +354,15 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

type(torch_tensor) :: tensor1, tensor2, tensor3
class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
integer, parameter :: dtype = torch_kFloat32
real(wp), parameter :: scalar = 3.14
real(wp), dimension(2,3), target :: in_data
real(wp), dimension(:,:), pointer :: out_data2, out_data3
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(2,3) :: expected
logical :: test_pass

Expand All @@ -325,60 +375,45 @@ contains
! Create another two empty tensors and assign them to the products of a scalar constant and the
! first tensor using the overloaded multiplication operator (in each order)
call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
call torch_tensor_empty(tensor3, ndims, tensor_shape, dtype, device_type)
tensor2 = scalar * tensor1

! Check input array is unchanged by pre-multiplication
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
test_pass = assert_allclose(in_data, expected, test_name="test_torch_tensor_premultiply", &
rtol=1e-5)
if (.not. test_pass) then
call clean_up()
print *, "Error :: input array was changed during scalar pre-multiplication"
stop 999
if (this%param%switch) then
tensor2 = scalar * tensor1
else
tensor2 = tensor1 * scalar
end if

tensor3 = tensor1 * scalar

! Check input array is unchanged by post-multiplication
! Check input array is unchanged by scalar multiplication
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
test_pass = assert_allclose(in_data, expected, test_name="test_torch_tensor_postmultiply", &
test_pass = assert_allclose(in_data, expected, test_name="test_torch_tensor_scalar_multiply", &
rtol=1e-5)
if (.not. test_pass) then
call clean_up()
print *, "Error :: input array was changed during scalar post-multiplication"
print *, "Error :: input array was changed during scalar multiplication"
stop 999
end if

! Extract Fortran arrays from the assigned tensors and compare the data in the tensors to the
! scaled input arrays
call torch_tensor_to_array(tensor2, out_data2, shape(in_data))
call torch_tensor_to_array(tensor3, out_data3, shape(in_data))
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = scalar * in_data
test_pass = assert_allclose(out_data2, expected, test_name="test_torch_tensor_premultiply")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data2))
test_pass = assert_allclose(out_data3, expected, test_name="test_torch_tensor_postmultiply")
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_scalar_multiply")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data3))
@assertEqual(shape(expected), shape(out_data))

call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data2)
nullify(out_data3)
nullify(out_data)
call torch_tensor_delete(tensor1)
call torch_tensor_delete(tensor2)
call torch_tensor_delete(tensor3)
end subroutine clean_up

end subroutine test_torch_tensor_scalar_multiply

@test
subroutine test_torch_tensor_divide()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_divide(this)
use ftorch, only: operator(/)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -387,6 +422,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2, tensor3
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -448,8 +484,8 @@ contains

end subroutine test_torch_tensor_divide

@test
subroutine test_torch_tensor_scalar_divide()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_scalar_divide(this)
use ftorch, only: operator(/)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -458,6 +494,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down Expand Up @@ -511,8 +548,8 @@ contains

end subroutine test_torch_tensor_scalar_divide

@test
subroutine test_torch_tensor_square()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_square(this)
use ftorch, only: operator(**)
use, intrinsic :: iso_fortran_env, only: sp => real32

Expand All @@ -521,13 +558,14 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

type(torch_tensor) :: tensor1, tensor2, tensor3
class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
integer, parameter :: dtype = torch_kFloat32
real(wp), dimension(2,3), target :: in_data
real(wp), dimension(:,:), pointer :: out_data2, out_data3
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(2,3) :: expected
logical :: test_pass

Expand All @@ -541,8 +579,11 @@ contains
! integer exponent and float exponent, respectively, using the overloaded exponentiation
! operator
call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
call torch_tensor_empty(tensor3, ndims, tensor_shape, dtype, device_type)
tensor2 = tensor1 ** 2
if (this%param%switch) then
tensor2 = tensor1 ** 2
else
tensor2 = tensor1 ** 2.0
end if

! Check input array is unchanged by pre-multiplication
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
Expand All @@ -554,47 +595,29 @@ contains
stop 999
end if

tensor3 = tensor1 ** 2.0

! Check input array is unchanged by pre-multiplication
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
test_pass = assert_allclose(in_data, expected, test_name="test_torch_tensor_square_float", &
rtol=1e-5)
if (.not. test_pass) then
call clean_up()
print *, "Error :: input array was changed during floating point exponentation"
stop 999
end if

! Extract Fortran arrays from the assigned tensors and compare the data in the tensors to the
! squared input array
call torch_tensor_to_array(tensor2, out_data2, shape(in_data))
call torch_tensor_to_array(tensor3, out_data3, shape(in_data))
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = in_data ** 2
test_pass = assert_allclose(out_data2, expected, test_name="test_torch_tensor_square_int")
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_square_int")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data2))
test_pass = assert_allclose(out_data3, expected, test_name="test_torch_tensor_square_float")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data3))
@assertEqual(shape(expected), shape(out_data))

call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data2)
nullify(out_data3)
nullify(out_data)
call torch_tensor_delete(tensor1)
call torch_tensor_delete(tensor2)
call torch_tensor_delete(tensor3)
end subroutine clean_up

end subroutine test_torch_tensor_square

@test
subroutine test_torch_tensor_sqrt()
@test(testParameters={get_parameters_short()})
subroutine test_torch_tensor_sqrt(this)
use ftorch, only: operator(**)
use, intrinsic :: iso_fortran_env, only: sp => real32
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t
Expand All @@ -604,6 +627,7 @@ contains
! Set working precision for reals
integer, parameter :: wp = sp

class(TestCaseType), intent(inout) :: this
type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
Expand Down

0 comments on commit 12d504a

Please sign in to comment.