From 6adc1c25efe36f142b9611570ad618f8d2c9429b Mon Sep 17 00:00:00 2001 From: Saurabh Suresh Powar <66636289+Spnetic-5@users.noreply.github.com> Date: Wed, 26 Jul 2023 10:14:06 -0400 Subject: [PATCH] Added Adam optimizer implementation (#150) * Added Adam optimizer implementation * called rms from optimizers module * Minor fixes to improve adam performance * Suggested Changes * Remove dead code; format to <=80 columns * Use associate instead of explicit allocation for m_hat and v_hat; formatting * Added convergency test for Adam * AdamW: Adam with decay weights modification * AdamW Modifications * Fixed failing test * Add notes; clean up; make more internal parameters private * AdamW changes * flexible weight decay regularization * Formatting --------- Co-authored-by: milancurcic --- README.md | 4 +- example/quadratic.f90 | 193 ++++++++++++++++++++++++++------------- fpm.toml | 2 +- src/nf.f90 | 2 +- src/nf/nf_optimizers.f90 | 78 +++++++++++++++- test/test_optimizers.f90 | 25 ++++- 6 files changed, 230 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 6be58112..0f3fc465 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). * Training and inference of dense (fully connected) and convolutional neural networks * Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum, - and RMSProp + RMSProp, Adam, AdamW * More than a dozen activation functions and their derivatives * Loading dense and convolutional models from Keras HDF5 (.h5) files * Data-based parallelism @@ -287,4 +287,4 @@ group. Neural-fortran has been used successfully in over a dozen published studies. See all papers that cite it -[here](https://scholar.google.com/scholar?cites=7315840714744905948). \ No newline at end of file +[here](https://scholar.google.com/scholar?cites=7315840714744905948). diff --git a/example/quadratic.f90 b/example/quadratic.f90 index 168a6aed..90394f07 100644 --- a/example/quadratic.f90 +++ b/example/quadratic.f90 @@ -4,10 +4,10 @@ program quadratic_fit ! descent. use nf, only: dense, input, network use nf_dense_layer, only: dense_layer - use nf_optimizers, only: sgd + use nf_optimizers, only: sgd, rmsprop, adam implicit none - type(network) :: net(6) + type(network) :: net(9) ! Training parameters integer, parameter :: num_epochs = 1000 @@ -16,6 +16,9 @@ program quadratic_fit integer, parameter :: batch_size = 100 real, parameter :: learning_rate = 0.01 real, parameter :: decay_rate = 0.9 + real, parameter :: beta1 = 0.85 + real, parameter :: beta2 = 0.95 + real, parameter :: epsilon = 1e-8 ! Input and output data real, allocatable :: x(:), y(:) ! training data @@ -51,19 +54,46 @@ program quadratic_fit call sgd_optimizer(net(1), x, y, xtest, ytest, learning_rate, num_epochs) ! SGD, momentum - call sgd_optimizer(net(2), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9) + call sgd_optimizer( & + net(2), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9 & + ) ! SGD, momentum with Nesterov - call sgd_optimizer(net(3), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9, nesterov=.true.) + call sgd_optimizer( & + net(3), x, y, xtest, ytest, learning_rate, num_epochs, & + momentum=0.9, nesterov=.true. & + ) ! Batch SGD optimizer call batch_gd_optimizer(net(4), x, y, xtest, ytest, learning_rate, num_epochs) ! 
Mini-batch SGD optimizer - call minibatch_gd_optimizer(net(5), x, y, xtest, ytest, learning_rate, num_epochs, batch_size) + call minibatch_gd_optimizer( & + net(5), x, y, xtest, ytest, learning_rate, num_epochs, batch_size & + ) ! RMSProp optimizer - call rmsprop_optimizer(net(6), x, y, xtest, ytest, learning_rate, num_epochs, decay_rate) + call rmsprop_optimizer( & + net(6), x, y, xtest, ytest, learning_rate, num_epochs, decay_rate & + ) + + ! Adam optimizer + call adam_optimizer( & + net(7), x, y, xtest, ytest, learning_rate, num_epochs, & + beta1, beta2, epsilon & + ) + + ! Adam optimizer with L2 regularization + call adam_optimizer( & + net(8), x, y, xtest, ytest, learning_rate, num_epochs, & + beta1, beta2, epsilon, weight_decay_l2=1e-4 & + ) + + ! Adam optimizer with decoupled weight decay regularization + call adam_optimizer( & + net(9), x, y, xtest, ytest, learning_rate, num_epochs, & + beta1, beta2, epsilon, weight_decay_decoupled=1e-5 & + ) contains @@ -73,7 +103,9 @@ real elemental function quadratic(x) result(y) y = (x**2 / 2 + x / 2 + 1) / 2 end function quadratic - subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, momentum, nesterov) + subroutine sgd_optimizer( & + net, x, y, xtest, ytest, learning_rate, num_epochs, momentum, nesterov & + ) ! In the stochastic gradient descent (SGD) optimizer, we run the forward ! and backward passes and update the weights for each training sample, ! one at a time. @@ -109,12 +141,19 @@ subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, mom do i = 1, size(x) call net % forward([x(i)]) call net % backward([y(i)]) - call net % update(sgd(learning_rate=learning_rate, momentum=momentum_value, nesterov=nesterov_value)) + call net % update( & + sgd( & + learning_rate=learning_rate, & + momentum=momentum_value, & + nesterov=nesterov_value & + ) & + ) end do if (mod(n, num_epochs / 10) == 0) then ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))] - print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest) + print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', & + n, num_epochs, sum((ypred - ytest)**2) / size(ytest) end if end do @@ -123,7 +162,9 @@ subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, mom end subroutine sgd_optimizer - subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs) + subroutine batch_gd_optimizer( & + net, x, y, xtest, ytest, learning_rate, num_epochs & + ) ! Like the stochastic gradient descent (SGD) optimizer, except that here we ! accumulate the weight gradients for all training samples and update the ! weights once per epoch. @@ -147,7 +188,8 @@ subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs if (mod(n, num_epochs / 10) == 0) then ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))] - print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest) + print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', & + n, num_epochs, sum((ypred - ytest)**2) / size(ytest) end if end do @@ -156,7 +198,9 @@ subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs end subroutine batch_gd_optimizer - subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, batch_size) + subroutine minibatch_gd_optimizer( & + net, x, y, xtest, ytest, learning_rate, num_epochs, batch_size & + ) ! Like the batch SGD optimizer, except that here we accumulate the weight ! 
over a number of mini batches and update the weights once per mini batch. ! @@ -203,7 +247,8 @@ subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_ep if (mod(n, num_epochs / 10) == 0) then ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))] - print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest) + print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', & + n, num_epochs, sum((ypred - ytest)**2) / size(ytest) end if end do @@ -212,7 +257,9 @@ subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_ep end subroutine minibatch_gd_optimizer - subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, decay_rate) + subroutine rmsprop_optimizer( & + net, x, y, xtest, ytest, learning_rate, num_epochs, decay_rate & + ) ! RMSprop optimizer for updating weights using root mean square type(network), intent(inout) :: net real, intent(in) :: x(:), y(:) @@ -220,29 +267,11 @@ subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, real, intent(in) :: learning_rate, decay_rate integer, intent(in) :: num_epochs integer :: i, j, n - real, parameter :: epsilon = 1e-8 ! Small constant to avoid division by zero real, allocatable :: ypred(:) - ! Define a dedicated type to store the RMSprop gradients. - ! This is needed because array sizes vary between layers and we need to - ! keep track of gradients for each layer over time. - ! For now this works only for dense layers. - ! We will need to define a similar type for conv2d layers. - type :: rms_gradient_dense - real, allocatable :: dw(:,:) - real, allocatable :: db(:) - end type rms_gradient_dense - - type(rms_gradient_dense), allocatable :: rms(:) - print '(a)', 'RMSProp optimizer' print '(34("-"))' - ! Here we allocate the array or RMS gradient derived types. - ! We need one for each dense layer, however we will allocate it to the - ! length of all layers as it will make housekeeping easier. - allocate(rms(size(net % layers))) - do n = 1, num_epochs do i = 1, size(x) @@ -250,41 +279,14 @@ subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, call net % backward([y(i)]) end do - ! RMSprop update rule - do j = 1, size(net % layers) - select type (this_layer => net % layers(j) % p) - type is (dense_layer) - - ! If this is our first time here for this layer, allocate the - ! internal RMS gradient arrays and initialize them to zero. - if (.not. allocated(rms(j) % dw)) then - allocate(rms(j) % dw, mold=this_layer % dw) - allocate(rms(j) % db, mold=this_layer % db) - rms(j) % dw = 0 - rms(j) % db = 0 - end if - - ! Update the RMS gradients using the RMSprop moving average rule - rms(j) % dw = decay_rate * rms(j) % dw + (1 - decay_rate) * this_layer % dw**2 - rms(j) % db = decay_rate * rms(j) % db + (1 - decay_rate) * this_layer % db**2 - - ! Update weights and biases using the RMSprop update rule - this_layer % weights = this_layer % weights - learning_rate & - / sqrt(rms(j) % dw + epsilon) * this_layer % dw - this_layer % biases = this_layer % biases - learning_rate & - / sqrt(rms(j) % db + epsilon) * this_layer % db - - ! We have updated the weights and biases, so we need to reset the - ! gradients to zero for the next epoch. 
- this_layer % dw = 0 - this_layer % db = 0 - - end select - end do + call net % update( & + rmsprop(learning_rate=learning_rate, decay_rate=decay_rate) & + ) if (mod(n, num_epochs / 10) == 0) then ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))] - print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest) + print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', & + n, num_epochs, sum((ypred - ytest)**2) / size(ytest) end if end do @@ -293,6 +295,69 @@ subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, end subroutine rmsprop_optimizer + subroutine adam_optimizer( & + net, x, y, xtest, ytest, learning_rate, num_epochs, beta1, beta2, epsilon, & + weight_decay_l2, weight_decay_decoupled & + ) + ! Adam optimizer + type(network), intent(inout) :: net + real, intent(in) :: x(:), y(:) + real, intent(in) :: xtest(:), ytest(:) + real, intent(in) :: learning_rate, beta1, beta2, epsilon + real, intent(in), optional :: weight_decay_l2 + real, intent(in), optional :: weight_decay_decoupled + integer, intent(in) :: num_epochs + real, allocatable :: ypred(:) + integer :: i, n + real :: weight_decay_l2_val + real :: weight_decay_decoupled_val + + ! Set default values for weight_decay_l2 + if (.not. present(weight_decay_l2)) then + weight_decay_l2_val = 0.0 + else + weight_decay_l2_val = weight_decay_l2 + end if + + ! Set default values for weight_decay_decoupled + if (.not. present(weight_decay_decoupled)) then + weight_decay_decoupled_val = 0.0 + else + weight_decay_decoupled_val = weight_decay_decoupled + end if + + print '(a)', 'Adam optimizer' + print '(34("-"))' + + do n = 1, num_epochs + do i = 1, size(x) + call net % forward([x(i)]) + call net % backward([y(i)]) + end do + + call net % update( & + adam( & + learning_rate=learning_rate, & + beta1=beta1, & + beta2=beta2, & + epsilon=epsilon, & + weight_decay_l2=weight_decay_l2_val, & + weight_decay_decoupled=weight_decay_decoupled_val & + ) & + ) + + if (mod(n, num_epochs / 10) == 0) then + ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))] + print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', & + n, num_epochs, sum((ypred - ytest)**2) / size(ytest) + end if + + end do + + print *, '' + + end subroutine adam_optimizer + subroutine shuffle(arr) ! Shuffle an array using the Fisher-Yates algorithm. integer, intent(inout) :: arr(:) diff --git a/fpm.toml b/fpm.toml index a02d33a5..36242ef4 100644 --- a/fpm.toml +++ b/fpm.toml @@ -1,5 +1,5 @@ name = "neural-fortran" -version = "0.13.0" +version = "0.14.0" license = "MIT" author = "Milan Curcic" maintainer = "milancurcic@hey.com" diff --git a/src/nf.f90 b/src/nf.f90 index 82fcb80f..f26e99ba 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -5,7 +5,7 @@ module nf use nf_layer_constructors, only: & conv2d, dense, flatten, input, maxpool2d, reshape use nf_network, only: network - use nf_optimizers, only: sgd, rmsprop + use nf_optimizers, only: sgd, rmsprop, adam use nf_activation, only: activation_function, elu, exponential, & gaussian, linear, relu, leaky_relu, & sigmoid, softmax, softplus, step, tanhf, & diff --git a/src/nf/nf_optimizers.f90 b/src/nf/nf_optimizers.f90 index 3b7dc012..43f9d3b6 100644 --- a/src/nf/nf_optimizers.f90 +++ b/src/nf/nf_optimizers.f90 @@ -13,7 +13,7 @@ module nf_optimizers implicit none private - public :: optimizer_base_type, sgd, rmsprop + public :: optimizer_base_type, sgd, rmsprop, adam type, abstract :: optimizer_base_type real :: learning_rate = 0.01 @@ -43,22 +43,50 @@ end subroutine minimize !! 
Stochastic Gradient Descent optimizer real :: momentum = 0 logical :: nesterov = .false. - real, allocatable :: velocity(:) + real, allocatable, private :: velocity(:) contains procedure :: init => init_sgd procedure :: minimize => minimize_sgd end type sgd type, extends(optimizer_base_type) :: rmsprop - !! RMSProp optimizer + !! RMSProp optimizer by Hinton et al. (2012) + !! + !! Hinton, G., Srivastava, N. and Swersky, K., 2012. Neural networks for + !! machine learning lecture 6a overview of mini-batch gradient descent. + !! Cited on 2023-07-19, 14(8), p.2. Available at: + !! http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf real :: decay_rate = 0.9 real :: epsilon = 1e-8 - real, allocatable :: rms_gradient(:) + real, allocatable, private :: rms_gradient(:) contains procedure :: init => init_rmsprop procedure :: minimize => minimize_rmsprop end type rmsprop + type, extends(optimizer_base_type) :: adam + !! Adam optimizer by Kingma and Ba (2014), with optional decoupled weight + !! decay regularization (AdamW) by Loshchilov and Hutter (2017). + !! + !! Kingma, D.P. and Ba, J., 2014. Adam: A method for stochastic + !! optimization. arXiv preprint arXiv:1412.6980. + !! https://arxiv.org/abs/1412.6980 + !! + !! Loshchilov, I. and Hutter, F., 2017. Decoupled weight decay + !! regularization. arXiv preprint arXiv:1711.05101. + !! https://arxiv.org/abs/1711.05101 + real :: beta1 = 0.9 + real :: beta2 = 0.999 + real :: epsilon = 1e-8 + real :: weight_decay_l2 = 0 ! L2 regularization (Adam) + real :: weight_decay_decoupled = 0 ! decoupled weight decay regularization (AdamW) + real, allocatable, private :: m(:), v(:) + integer, private :: t = 0 + contains + procedure :: init => init_adam + procedure :: minimize => minimize_adam + end type adam + contains impure elemental subroutine init_sgd(self, num_params) @@ -123,4 +151,46 @@ pure subroutine minimize_rmsprop(self, param, gradient) end subroutine minimize_rmsprop + + impure elemental subroutine init_adam(self, num_params) + class(adam), intent(inout) :: self + integer, intent(in) :: num_params + if (.not. allocated(self % m)) then + allocate(self % m(num_params), self % v(num_params)) + self % m = 0 + self % v = 0 + end if + end subroutine init_adam + + + pure subroutine minimize_adam(self, param, gradient) + !! Concrete implementation of an Adam optimizer update rule. + class(adam), intent(inout) :: self + real, intent(inout) :: param(:) + real, intent(in) :: gradient(:) + + self % t = self % t + 1 + + ! If weight_decay_l2 > 0, use L2 regularization; + ! otherwise, default to regular Adam. + associate(g => gradient + self % weight_decay_l2 * param) + self % m = self % beta1 * self % m + (1 - self % beta1) * g + self % v = self % beta2 * self % v + (1 - self % beta2) * g**2 + end associate + + ! Compute bias-corrected first and second moment estimates. + associate( & + m_hat => self % m / (1 - self % beta1**self % t), & + v_hat => self % v / (1 - self % beta2**self % t) & + ) + + ! Update parameters. 
+ param = param & + - self % learning_rate * m_hat / (sqrt(v_hat) + self % epsilon) & + - self % weight_decay_decoupled * param + + end associate + + end subroutine minimize_adam + end module nf_optimizers diff --git a/test/test_optimizers.f90 b/test/test_optimizers.f90 index dc41f912..1c3e5c3c 100644 --- a/test/test_optimizers.f90 +++ b/test/test_optimizers.f90 @@ -1,10 +1,10 @@ program test_optimizers - use nf, only: dense, input, network, rmsprop, sgd + use nf, only: dense, input, network, rmsprop, sgd, adam use iso_fortran_env, only: stderr => error_unit implicit none - type(network) :: net(4) + type(network) :: net(5) real, allocatable :: x(:), y(:) real, allocatable :: ypred(:) integer, parameter :: num_iterations = 1000 @@ -96,6 +96,27 @@ program test_optimizers ok = .false. end if + ! Test Adam optimizer + converged = .false. + + do n = 0, num_iterations + + call net(5) % forward(x) + call net(5) % backward(y) + call net(5) % update(optimizer=adam(learning_rate=0.01, beta1=0.9, beta2=0.999)) + + ypred = net(5) % predict(x) + converged = check_convergence(y, ypred) + if (converged) exit + + end do + + if (.not. converged) then + write(stderr, '(a)') 'adam should converge in simple training.. failed' + ok = .false. + end if + + if (ok) then print '(a)', 'test_optimizers: All tests passed.' else
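
Note on the update rule implemented by minimize_adam above: with the effective gradient g = gradient + weight_decay_l2 * param, the moment estimates evolve as

  m = beta1 * m + (1 - beta1) * g
  v = beta2 * v + (1 - beta2) * g**2

and the bias-corrected step at iteration t is

  param = param &
    - learning_rate * (m / (1 - beta1**t)) / (sqrt(v / (1 - beta2**t)) + epsilon) &
    - weight_decay_decoupled * param

so weight_decay_l2 folds the penalty into the gradient (Adam with L2 regularization), while weight_decay_decoupled subtracts it directly from the parameters (AdamW).

For readers who want to try the new optimizer outside the bundled quadratic example and test, the sketch below shows the calling pattern this patch exposes: construct an adam instance and pass it to net % update. It is illustrative only; the tiny 1-3-1 network, the single training sample, and the iteration count are placeholders, and the network([...]) constructor is assumed from the pre-existing library API rather than shown in this diff.

  program adam_usage_sketch
    ! Minimal, assumed usage of the adam optimizer added by this patch.
    use nf, only: dense, input, network, adam
    implicit none
    type(network) :: net
    integer :: n

    ! Hypothetical tiny network; any existing layer layout is updated the same way.
    net = network([input(1), dense(3), dense(1)])

    do n = 1, 100
      call net % forward([0.5])
      call net % backward([0.6])
      ! Plain Adam by default; set weight_decay_l2 for L2 regularization,
      ! or weight_decay_decoupled for AdamW-style decoupled weight decay.
      call net % update( &
        optimizer=adam( &
          learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8, &
          weight_decay_decoupled=1e-5 &
        ) &
      )
    end do

  end program adam_usage_sketch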