Skip to content

Commit

Permalink
feat(network_configuration): add type & JSON I/O
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 31, 2023
1 parent 0c95d20 commit 829b967
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/inference_engine/network_configuration_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module network_configuration_m
use sourcery_m, only : string_t, file_t
implicit none

private
public :: network_configuration_t

type network_configuration_t
private
logical :: skip_connections_ = .false.
integer, allocatable :: nodes_per_layer_(:)
character(len=:), allocatable :: activation_function_
contains
procedure :: to_json
procedure :: equals
generic :: operator(==) => equals
end type

interface network_configuration_t

pure module function from_json(lines) result(network_configuration)
implicit none
type(string_t), intent(in) :: lines(:)
type(network_configuration_t) network_configuration
end function

pure module function from_components(skip_connections, nodes_per_layer, activation_function) result(network_configuration)
implicit none
logical, intent(in) :: skip_connections
integer, intent(in) :: nodes_per_layer(:)
character(len=*), intent(in) :: activation_function
type(network_configuration_t) network_configuration
end function

end interface

interface

pure module function to_json(self) result(lines)
implicit none
class(network_configuration_t), intent(in) :: self
type(string_t), allocatable :: lines(:)
end function

elemental module function equals(lhs, rhs) result(lhs_equals_rhs)
implicit none
class(network_configuration_t), intent(in) :: lhs, rhs
logical lhs_equals_rhs
end function

end interface

end module
68 changes: 68 additions & 0 deletions src/inference_engine/network_configuration_s.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
submodule(network_configuration_m) network_configuration_s
use assert_m, only : assert
use sourcery_m, only : csv
implicit none

character(len=*), parameter :: skip_connections_key = "skip connections"
character(len=*), parameter :: nodes_per_layer_key = "nodes per layer"
character(len=*), parameter :: activation_function_key = "activation function"

contains

module procedure from_components
network_configuration%skip_connections_ = skip_connections
network_configuration%nodes_per_layer_ = nodes_per_layer
network_configuration%activation_function_ = activation_function
end procedure

module procedure equals

call assert(allocated(lhs%activation_function_) .and. allocated(rhs%activation_function_), "network_configuration_s(equals): allocated activation_functions")

lhs_equals_rhs = &
lhs%skip_connections_ .eqv. rhs%skip_connections_ .and. &
lhs%activation_function_ == rhs%activation_function_ .and. &
all(lhs%nodes_per_layer_ == rhs%nodes_per_layer_)

end procedure

module procedure from_json
integer l
logical network_configuration_key_found

network_configuration_key_found = .false.

do l=1,size(lines)
if (lines(l)%get_json_key() == "network configuration") then
network_configuration_key_found = .true.
network_configuration%skip_connections_ = lines(l+1)%get_json_value(string_t(skip_connections_key), mold=.true.)
network_configuration%nodes_per_layer_ = lines(l+2)%get_json_integer_array(string_t(nodes_per_layer_key), mold=[0,1])
network_configuration%activation_function_ = lines(l+3)%get_json_value(string_t(activation_function_key), mold=string_t(""))
return
end if
end do

call assert(network_configuration_key_found, "network_configuration_s(from_json): network_configuration_found")
end procedure

module procedure to_json
character(len=*), parameter :: indent = repeat(" ",ncopies=4)
integer, parameter :: max_logical_width= 6, char_per_elem = 10, brackets = 2
character(len=max_logical_width) skip_connections_string
character(len=:), allocatable :: nodes_per_layer_string

allocate(character(len=size(self%nodes_per_layer_)*char_per_elem + brackets) :: nodes_per_layer_string)

write(skip_connections_string,*) trim(merge("true ","false",self%skip_connections_))
write(nodes_per_layer_string, csv) self%nodes_per_layer_

lines = [ &
string_t(indent // '"network configuration": {'), &
string_t(indent // indent // '"' // skip_connections_key // '" : ' // skip_connections_string ), &
string_t(indent // indent // '"' // nodes_per_layer_key // '" : [' // trim(nodes_per_layer_string) // ']' ), &
string_t(indent // indent // '"' // activation_function_key // '" : "' // self%activation_function_ // '"'), &
string_t(indent // '}') &
]
end procedure

end submodule network_configuration_s
1 change: 1 addition & 0 deletions src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ module inference_engine_m
use tensor_m, only : tensor_t
use trainable_engine_m, only : trainable_engine_t
use hyperparameters_m, only : hyperparameters_t
use network_configuration_m, only : network_configuration_t
implicit none
end module
3 changes: 3 additions & 0 deletions test/main.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ program main
use asymmetric_engine_test_m, only : asymmetric_engine_test_t
use trainable_engine_test_m, only : trainable_engine_test_t
use hyperparameters_test_m, only : hyperparameters_test_t
use network_configuration_test_m, only : network_configuration_test_t
implicit none

type(inference_engine_test_t) inference_engine_test
type(asymmetric_engine_test_t) asymmetric_engine_test
type(trainable_engine_test_t) trainable_engine_test
type(hyperparameters_test_t) hyperparameters_test
type(network_configuration_test_t) network_configuration_test
real t_start, t_finish

integer :: passes=0, tests=0
Expand All @@ -21,6 +23,7 @@ program main
call asymmetric_engine_test%report(passes, tests)
call trainable_engine_test%report(passes, tests)
call hyperparameters_test%report(passes, tests)
call network_configuration_test%report(passes, tests)
#ifndef __INTEL_FORTRAN
block
use netCDF_file_test_m, only : netCDF_file_test_t
Expand Down
65 changes: 65 additions & 0 deletions test/network_configuration_test_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
module network_configuration_test_m
!! Test network_configuration_t object I/O and construction

! External dependencies
use assert_m, only : assert
use sourcery_m, only : string_t, test_t, test_result_t, file_t
use inference_engine_m, only : network_configuration_t

! Internal dependencies
use network_configuration_m, only : network_configuration_t

implicit none

private
public :: network_configuration_test_t

type, extends(test_t) :: network_configuration_test_t
contains
procedure, nopass :: subject
procedure, nopass :: results
end type

contains

pure function subject() result(specimen)
character(len=:), allocatable :: specimen
specimen = "A network_configuration_t object"
end function

function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

character(len=*), parameter :: longest_description = &
"component-wise construction followed by conversion to and from JSON"

associate( &
descriptions => &
[ character(len=len(longest_description)) :: &
"component-wise construction followed by conversion to and from JSON" &
], &
outcomes => &
[ write_then_read_network_configuration() &
] &
)
call assert(size(descriptions) == size(outcomes),"network_configuration_test_m(results): size(descriptions) == size(outcomes)")
test_results = test_result_t(descriptions, outcomes)
end associate

end function

function write_then_read_network_configuration() result(test_passes)
logical test_passes

associate(constructed_from_components=> &
network_configuration_t(skip_connections=.false., nodes_per_layer=[2,72,2], activation_function="sigmoid"))
associate(constructed_from_json => network_configuration_t(constructed_from_components%to_json()))
test_passes = constructed_from_components == constructed_from_json
end associate
end associate

end function

end module network_configuration_test_m

0 comments on commit 829b967

Please sign in to comment.