From 14a983df16e616aea95b0b9d47c9f1f44465ac7a Mon Sep 17 00:00:00 2001 From: Damian Rouson Date: Wed, 4 Oct 2023 09:46:14 -0700 Subject: [PATCH] feat(example): train ICAR saturated mixing ratio --- example/learn-saturated-mixing-ratio.f90 | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/learn-saturated-mixing-ratio.f90 b/example/learn-saturated-mixing-ratio.f90 index 79c5d71f7..bd91a6fd4 100644 --- a/example/learn-saturated-mixing-ratio.f90 +++ b/example/learn-saturated-mixing-ratio.f90 @@ -24,7 +24,7 @@ program train_saturated_mixture_ratio call system_clock(counter_start, clock_rate) block - integer, parameter :: num_epochs = 10000000, num_mini_batches = 6 + integer, parameter :: max_num_epochs = 10000000, num_mini_batches = 6 integer num_pairs ! number of input/output pairs type(mini_batch_t), allocatable :: mini_batches(:) @@ -35,7 +35,7 @@ program train_saturated_mixture_ratio real, allocatable :: cost(:), random_numbers(:) integer io_status, network_unit, plot_unit integer, parameter :: io_success=0, diagnostics_print_interval = 1000, network_save_interval = 10000 - integer, parameter :: nodes_per_layer(*) = [2, 31, 31, 1] + integer, parameter :: nodes_per_layer(*) = [2, 72, 1] real, parameter :: cost_tolerance = 1.E-08 call random_init(image_distinct=.true., repeatable=.true.) @@ -81,7 +81,7 @@ program train_saturated_mixture_ratio print *, " Epoch | Cost Function| System_Clock | Nodes per Layer" allocate(random_numbers(2:size(input_output_pairs))) - do e = previous_epoch + 1, previous_epoch + num_epochs + do e = previous_epoch + 1, previous_epoch + max_num_epochs call random_number(random_numbers) call shuffle(input_output_pairs, random_numbers) mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))] @@ -91,7 +91,7 @@ program train_saturated_mixture_ratio associate( & cost_avg => sum(cost)/size(cost), & cumulative_clock_time => previous_clock_time + real(counter_end - counter_start) / real(clock_rate), & - loop_ending => e == previous_epoch + num_epochs & + loop_ending => e == previous_epoch + max_num_epochs & ) write_and_exit_if_converged: & if (cost_avg < cost_tolerance) then