From 0c95d20f4525a73f487f7eae7fc5a6012b9fafc4 Mon Sep 17 00:00:00 2001 From: Damian Rouson Date: Tue, 31 Oct 2023 07:58:16 -0700 Subject: [PATCH] test(hyperparameters): add unit test & finish type --- src/inference_engine/hyperparameters_m.f90 | 20 ++++++- src/inference_engine/hyperparameters_s.f90 | 34 ++++++++---- .../network_configuration_m.f90 | 39 -------------- .../network_configuration_s.f90 | 54 ------------------- src/inference_engine_m.f90 | 1 + test/hyperparameters_test_m.f90 | 21 +++++--- 6 files changed, 58 insertions(+), 111 deletions(-) delete mode 100644 src/inference_engine/network_configuration_m.f90 delete mode 100644 src/inference_engine/network_configuration_s.f90 diff --git a/src/inference_engine/hyperparameters_m.f90 b/src/inference_engine/hyperparameters_m.f90 index dea9e6580..174545008 100644 --- a/src/inference_engine/hyperparameters_m.f90 +++ b/src/inference_engine/hyperparameters_m.f90 @@ -12,13 +12,23 @@ module hyperparameters_m character(len=:), allocatable :: optimizer_ contains procedure :: to_json + procedure :: equals + generic :: operator(==) => equals end type interface hyperparameters_t - pure module function from_json(file_) result(hyperparameters) + pure module function from_json(lines) result(hyperparameters) implicit none - type(file_t), intent(in) :: file_ + type(string_t), intent(in) :: lines(:) + type(hyperparameters_t) hyperparameters + end function + + pure module function from_components(mini_batches, learning_rate, optimizer) result(hyperparameters) + implicit none + integer, intent(in) :: mini_batches + real, intent(in) :: learning_rate + character(len=*), intent(in) :: optimizer type(hyperparameters_t) hyperparameters end function @@ -32,6 +42,12 @@ pure module function to_json(self) result(lines) type(string_t), allocatable :: lines(:) end function + elemental module function equals(lhs, rhs) result(lhs_equals_rhs) + implicit none + class(hyperparameters_t), intent(in) :: lhs, rhs + logical lhs_equals_rhs + end function + end interface end module diff --git a/src/inference_engine/hyperparameters_s.f90 b/src/inference_engine/hyperparameters_s.f90 index c1f9aae94..aadab89ed 100644 --- a/src/inference_engine/hyperparameters_s.f90 +++ b/src/inference_engine/hyperparameters_s.f90 @@ -8,15 +8,31 @@ contains + module procedure from_components + hyperparameters%mini_batches_ = mini_batches + hyperparameters%learning_rate_ = learning_rate + hyperparameters%optimizer_ = optimizer + end procedure + + module procedure equals + + real, parameter :: tolerance = 1.E-08 + + call assert(allocated(lhs%optimizer_) .and. allocated(rhs%optimizer_), "hyperparameters_s(equals): allocated optimizers") + + lhs_equals_rhs = & + lhs%mini_batches_ == rhs%mini_batches_ .and. & + lhs%optimizer_ == rhs%optimizer_ .and. & + abs(lhs%learning_rate_ - rhs%learning_rate_) <= tolerance + + end procedure + module procedure from_json - type(string_t), allocatable :: lines(:) integer l logical hyperparameters_key_found - lines = file_%lines() hyperparameters_key_found = .false. - loop_through_file: & do l=1,size(lines) if (lines(l)%get_json_key() == "hyperparameters") then hyperparameters_key_found = .true. @@ -25,24 +41,24 @@ hyperparameters%optimizer_ = lines(l+3)%get_json_value(string_t(optimizer_key), mold=string_t("")) return end if - end do loop_through_file + end do call assert(hyperparameters_key_found, "hyperparameters_s(from_json): hyperparameters_found") end procedure module procedure to_json character(len=*), parameter :: indent = repeat(" ",ncopies=4) - integer, parameter :: max_digits = 12 - character(len=max_digits) mini_batches_string, learning_rate_string + integer, parameter :: max_width= 18 + character(len=max_width) mini_batches_string, learning_rate_string write(mini_batches_string,*) self%mini_batches_ write(learning_rate_string,*) self%learning_rate_ lines = [ & string_t(indent // '"hyperparameters": {'), & - string_t(indent // indent // '"' // mini_batches_key // '": ' // mini_batches_string ), & - string_t(indent // indent // '"' // learning_rate_key // '": ' // learning_rate_string ), & - string_t(indent // indent // '"' // optimizer_key // '": "' // self%optimizer_ // '"'), & + string_t(indent // indent // '"' // mini_batches_key // '" : ' // mini_batches_string ), & + string_t(indent // indent // '"' // learning_rate_key // '" : ' // learning_rate_string ), & + string_t(indent // indent // '"' // optimizer_key // '" : "' // self%optimizer_ // '"'), & string_t(indent // '}') & ] end procedure diff --git a/src/inference_engine/network_configuration_m.f90 b/src/inference_engine/network_configuration_m.f90 deleted file mode 100644 index 6f057e7a6..000000000 --- a/src/inference_engine/network_configuration_m.f90 +++ /dev/null @@ -1,39 +0,0 @@ -module network_configuration_m - use sourcery_m, only : file_t - implicit none - - private - public :: network_configuration_t - - type network_configuration_t - private - type(string_t) activation_function_ - integer, allocatable :: nodes_per_layer_(:) - logical skip_connections_ - contains - procedure :: to_json - end type - - end type - - interface network_configuration_t - - elemental module function from_json(file_) result(network_configuration) - implicit none - type(file_t), intent(in) :: file_ - type(network_configuration_t) network_configuration - end function - - end interface - - interface - - elemental module function to_json(self) result(json_file) - implicit none - class(network_configuration_t), intent(in) :: self - type(file_t) json_file - end function - - end interface - -end module diff --git a/src/inference_engine/network_configuration_s.f90 b/src/inference_engine/network_configuration_s.f90 deleted file mode 100644 index 43249fcc4..000000000 --- a/src/inference_engine/network_configuration_s.f90 +++ /dev/null @@ -1,54 +0,0 @@ -submodule(network_configuration_m) network_configuration_s - use assert_m, only : assert - use sourcery_m, only : string_t - implicit none - - character(len=*), parameter :: activation_function_key = "activation function" - character(len=*), parameter :: nodes_per_layer_key = "nodes per layer" - character(len=*), parameter :: skip_connections_key = "skip connections" - -contains - - module procedure from_json - type(string_t), allocatable :: lines(:) - integer l - logical network configuration_key_found - - lines = file_%lines() - network configuration_key_found = .false. - - loop_through_file: & - do l=1,size(lines) - if (line(l)%get_key() == "network configuration") then - network configuration_key_found = .true. - self%activation_function_ = line(l+1)%get_json_value(activation_function_key, mold=string("")) - self%nodes_per_layer_ = line(l+2)%get_json_value(nodes_per_layer_key , mold=[integer::]) - self%skip_connections_ = line(l+2)%get_json_value(skip_connetions_key , mold=.true.) - return - end if - end do loop_through_file - - call assert(network configuration_found, "network configuration_s(from_json): network configuration_found") - end procedure - - module procedure to_json - character(len=:), parameter :: indent = repeat(" ",ncopies=4) - integer, parameter :: max_digits = 12, max_length=5 - character(len=max_digits) activation_function_string, nodes_per_layer_string, skip_connections_string - character(len=max_length) skip_connections_string - - - write(activation_function_string,*) self%activation_function_ - write(nodes_per_layer_string ,*) self%nodes_per_layer_ - write(skip_connections_string ,*) merge("true ","false", self%skip_connections_) - - lines = [ & - string_t(indent // '"network configuration": {'), & - string_t(indent // indent // '"' // activation_function_key //'": ' // activation_function_string ), & - string_t(indent // indent // '"' // nodes_per_layer_key //'": ' // nodes_per_layer_string ), & - string_t(indent // indent // '"' // skip_connections_key //'": "' // skip_connections_string // '"'), & - string_t(indent // '}') & - ] - end procedure - -end submodule network_configuration_s diff --git a/src/inference_engine_m.f90 b/src/inference_engine_m.f90 index 81e3614f9..fd3bcdfc9 100644 --- a/src/inference_engine_m.f90 +++ b/src/inference_engine_m.f90 @@ -15,5 +15,6 @@ module inference_engine_m use swish_m, only : swish_t use tensor_m, only : tensor_t use trainable_engine_m, only : trainable_engine_t + use hyperparameters_m, only : hyperparameters_t implicit none end module diff --git a/test/hyperparameters_test_m.f90 b/test/hyperparameters_test_m.f90 index ee7440741..f2d1f65b4 100644 --- a/test/hyperparameters_test_m.f90 +++ b/test/hyperparameters_test_m.f90 @@ -5,7 +5,8 @@ module hyperparameters_test_m ! External dependencies use assert_m, only : assert - use sourcery_m, only : string_t, test_t, test_result_t + use sourcery_m, only : string_t, test_t, test_result_t, file_t + use inference_engine_m, only : hyperparameters_t ! Internal dependencies use hyperparameters_m, only : hyperparameters_t @@ -32,15 +33,15 @@ function results() result(test_results) type(test_result_t), allocatable :: test_results(:) character(len=*), parameter :: longest_description = & - "writing and then reading gives input matching output for perturbed identity network" + "component-wise construction followed by conversion to and from JSON" associate( & descriptions => & [ character(len=len(longest_description)) :: & - "writing and then reading gives input matching output for perturbed identity network" & + "component-wise construction followed by conversion to and from JSON" & ], & outcomes => & - [ write_then_read_perturbed_identity() & + [ write_then_read_hyperparameters() & ] & ) call assert(size(descriptions) == size(outcomes),"hyperparameters_test_m(results): size(descriptions) == size(outcomes)") @@ -49,9 +50,15 @@ function results() result(test_results) end function - function write_then_read_perturbed_identity() result(test_passes) - logical, allocatable :: test_passes(:) - test_passes = [.true.] + function write_then_read_hyperparameters() result(test_passes) + logical test_passes + + associate(hyperparameters => hyperparameters_t(mini_batches=5, learning_rate=1., optimizer = "stochastic gradient descent")) + associate(from_json => hyperparameters_t(hyperparameters%to_json())) + test_passes = hyperparameters == from_json + end associate + end associate + end function end module hyperparameters_test_m \ No newline at end of file