Skip to content

Commit

Permalink
test(hyperparameters): add unit test & finish type
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 31, 2023
1 parent fd2cc99 commit 0c95d20
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 111 deletions.
20 changes: 18 additions & 2 deletions src/inference_engine/hyperparameters_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,23 @@ module hyperparameters_m
character(len=:), allocatable :: optimizer_
contains
procedure :: to_json
procedure :: equals
generic :: operator(==) => equals
end type

interface hyperparameters_t

pure module function from_json(file_) result(hyperparameters)
pure module function from_json(lines) result(hyperparameters)
implicit none
type(file_t), intent(in) :: file_
type(string_t), intent(in) :: lines(:)
type(hyperparameters_t) hyperparameters
end function

pure module function from_components(mini_batches, learning_rate, optimizer) result(hyperparameters)
implicit none
integer, intent(in) :: mini_batches
real, intent(in) :: learning_rate
character(len=*), intent(in) :: optimizer
type(hyperparameters_t) hyperparameters
end function

Expand All @@ -32,6 +42,12 @@ pure module function to_json(self) result(lines)
type(string_t), allocatable :: lines(:)
end function

elemental module function equals(lhs, rhs) result(lhs_equals_rhs)
implicit none
class(hyperparameters_t), intent(in) :: lhs, rhs
logical lhs_equals_rhs
end function

end interface

end module
34 changes: 25 additions & 9 deletions src/inference_engine/hyperparameters_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,31 @@

contains

module procedure from_components
hyperparameters%mini_batches_ = mini_batches
hyperparameters%learning_rate_ = learning_rate
hyperparameters%optimizer_ = optimizer
end procedure

module procedure equals

real, parameter :: tolerance = 1.E-08

call assert(allocated(lhs%optimizer_) .and. allocated(rhs%optimizer_), "hyperparameters_s(equals): allocated optimizers")

lhs_equals_rhs = &
lhs%mini_batches_ == rhs%mini_batches_ .and. &
lhs%optimizer_ == rhs%optimizer_ .and. &
abs(lhs%learning_rate_ - rhs%learning_rate_) <= tolerance

end procedure

module procedure from_json
type(string_t), allocatable :: lines(:)
integer l
logical hyperparameters_key_found

lines = file_%lines()
hyperparameters_key_found = .false.

loop_through_file: &
do l=1,size(lines)
if (lines(l)%get_json_key() == "hyperparameters") then
hyperparameters_key_found = .true.
Expand All @@ -25,24 +41,24 @@
hyperparameters%optimizer_ = lines(l+3)%get_json_value(string_t(optimizer_key), mold=string_t(""))
return
end if
end do loop_through_file
end do

call assert(hyperparameters_key_found, "hyperparameters_s(from_json): hyperparameters_found")
end procedure

module procedure to_json
character(len=*), parameter :: indent = repeat(" ",ncopies=4)
integer, parameter :: max_digits = 12
character(len=max_digits) mini_batches_string, learning_rate_string
integer, parameter :: max_width= 18
character(len=max_width) mini_batches_string, learning_rate_string

write(mini_batches_string,*) self%mini_batches_
write(learning_rate_string,*) self%learning_rate_

lines = [ &
string_t(indent // '"hyperparameters": {'), &
string_t(indent // indent // '"' // mini_batches_key // '": ' // mini_batches_string ), &
string_t(indent // indent // '"' // learning_rate_key // '": ' // learning_rate_string ), &
string_t(indent // indent // '"' // optimizer_key // '": "' // self%optimizer_ // '"'), &
string_t(indent // indent // '"' // mini_batches_key // '" : ' // mini_batches_string ), &
string_t(indent // indent // '"' // learning_rate_key // '" : ' // learning_rate_string ), &
string_t(indent // indent // '"' // optimizer_key // '" : "' // self%optimizer_ // '"'), &
string_t(indent // '}') &
]
end procedure
Expand Down
39 changes: 0 additions & 39 deletions src/inference_engine/network_configuration_m.f90

This file was deleted.

54 changes: 0 additions & 54 deletions src/inference_engine/network_configuration_s.f90

This file was deleted.

1 change: 1 addition & 0 deletions src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ module inference_engine_m
use swish_m, only : swish_t
use tensor_m, only : tensor_t
use trainable_engine_m, only : trainable_engine_t
use hyperparameters_m, only : hyperparameters_t
implicit none
end module
21 changes: 14 additions & 7 deletions test/hyperparameters_test_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ module hyperparameters_test_m

! External dependencies
use assert_m, only : assert
use sourcery_m, only : string_t, test_t, test_result_t
use sourcery_m, only : string_t, test_t, test_result_t, file_t
use inference_engine_m, only : hyperparameters_t

! Internal dependencies
use hyperparameters_m, only : hyperparameters_t
Expand All @@ -32,15 +33,15 @@ function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

character(len=*), parameter :: longest_description = &
"writing and then reading gives input matching output for perturbed identity network"
"component-wise construction followed by conversion to and from JSON"

associate( &
descriptions => &
[ character(len=len(longest_description)) :: &
"writing and then reading gives input matching output for perturbed identity network" &
"component-wise construction followed by conversion to and from JSON" &
], &
outcomes => &
[ write_then_read_perturbed_identity() &
[ write_then_read_hyperparameters() &
] &
)
call assert(size(descriptions) == size(outcomes),"hyperparameters_test_m(results): size(descriptions) == size(outcomes)")
Expand All @@ -49,9 +50,15 @@ function results() result(test_results)

end function

function write_then_read_perturbed_identity() result(test_passes)
logical, allocatable :: test_passes(:)
test_passes = [.true.]
function write_then_read_hyperparameters() result(test_passes)
logical test_passes

associate(hyperparameters => hyperparameters_t(mini_batches=5, learning_rate=1., optimizer = "stochastic gradient descent"))
associate(from_json => hyperparameters_t(hyperparameters%to_json()))
test_passes = hyperparameters == from_json
end associate
end associate

end function

end module hyperparameters_test_m

0 comments on commit 0c95d20

Please sign in to comment.