Added Adam optimizer implementation (#150)
* Added Adam optimizer implementation

* Called rms from the optimizers module

* Minor fixes to improve Adam performance

* Suggested Changes

* Remove dead code; format to <=80 columns

* Use associate instead of explicit allocation for m_hat and v_hat; formatting

* Added convergence test for Adam

* AdamW: Adam with weight decay modification

* AdamW Modifications

* Fixed failing test

* Add notes; clean up; make more internal parameters private

* AdamW changes

* Flexible weight decay regularization

* Formatting

---------

Co-authored-by: milancurcic <caomaco@gmail.com>
Spnetic-5 and milancurcic authored Jul 26, 2023
1 parent e9bfbd6 commit 6adc1c2
Showing 6 changed files with 230 additions and 74 deletions.
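
For reference, the update rule added by this commit can be sketched as a single Adam/AdamW step on a scalar weight. This is the textbook form, with hyperparameter names mirroring the new adam constructor; the actual nf_optimizers implementation is not rendered on this page and may differ in small details (for example, exactly where epsilon enters the denominator).

program adam_step_demo
  ! One textbook Adam/AdamW step on a scalar weight w for f(w) = w**2.
  implicit none
  real :: w, g, m, v, m_hat, v_hat
  real, parameter :: learning_rate = 0.001
  real, parameter :: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8
  real, parameter :: weight_decay_l2 = 0.0        ! classic L2 penalty
  real, parameter :: weight_decay_decoupled = 0.0 ! AdamW-style decay
  integer :: t

  w = 1
  m = 0
  v = 0
  t = 1

  g = 2 * w                      ! gradient of f(w) = w**2
  g = g + weight_decay_l2 * w    ! L2 regularization modifies the gradient

  m = beta1 * m + (1 - beta1) * g      ! first moment estimate
  v = beta2 * v + (1 - beta2) * g**2   ! second moment estimate
  m_hat = m / (1 - beta1**t)           ! bias correction
  v_hat = v / (1 - beta2**t)

  ! Adam step; the decoupled (AdamW) term bypasses the adaptive scaling,
  ! which is what distinguishes AdamW from Adam with L2 regularization.
  w = w - learning_rate &
    * (m_hat / (sqrt(v_hat) + epsilon) + weight_decay_decoupled * w)

  print *, 'w after one step:', w

end program adam_step_demo
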
4 changes: 2 additions & 2 deletions README.md
@@ -18,7 +18,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
* Training and inference of dense (fully connected) and convolutional neural
networks
* Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum,
and RMSProp
RMSProp, Adam, AdamW
* More than a dozen activation functions and their derivatives
* Loading dense and convolutional models from Keras HDF5 (.h5) files
* Data-based parallelism
@@ -287,4 +287,4 @@ group.

Neural-fortran has been used successfully in over a dozen published studies.
See all papers that cite it
[here](https://scholar.google.com/scholar?cites=7315840714744905948).
[here](https://scholar.google.com/scholar?cites=7315840714744905948).
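
With this commit the adam constructor is exported from the top-level nf module alongside sgd and rmsprop (see the src/nf.f90 change below). A minimal usage sketch follows; the network construction and data are placeholders in the spirit of the quadratic example, whose full setup is not part of this diff, and the hyperparameter values are arbitrary.

program adam_usage_sketch
  ! Sketch: train a tiny dense network with the new Adam optimizer.
  use nf, only: dense, input, network, adam
  implicit none
  type(network) :: net
  real :: x(100), y(100)
  integer :: i, n

  ! Placeholder model and data, in the spirit of example/quadratic.f90.
  net = network([input(1), dense(3), dense(1)])
  call random_number(x)
  y = x**2

  do n = 1, 100                  ! epochs
    do i = 1, size(x)
      call net % forward([x(i)])
      call net % backward([y(i)])
      call net % update( &
        adam( &
          learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8, &
          weight_decay_l2=0.0, weight_decay_decoupled=0.0 &
        ) &
      )
    end do
  end do

  print *, 'Prediction at x = 0.5:', net % predict([0.5])

end program adam_usage_sketch
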
193 changes: 129 additions & 64 deletions example/quadratic.f90
@@ -4,10 +4,10 @@ program quadratic_fit
! descent.
use nf, only: dense, input, network
use nf_dense_layer, only: dense_layer
use nf_optimizers, only: sgd
use nf_optimizers, only: sgd, rmsprop, adam

implicit none
type(network) :: net(6)
type(network) :: net(9)

! Training parameters
integer, parameter :: num_epochs = 1000
@@ -16,6 +16,9 @@ program quadratic_fit
integer, parameter :: batch_size = 100
real, parameter :: learning_rate = 0.01
real, parameter :: decay_rate = 0.9
real, parameter :: beta1 = 0.85
real, parameter :: beta2 = 0.95
real, parameter :: epsilon = 1e-8

! Input and output data
real, allocatable :: x(:), y(:) ! training data
@@ -51,19 +54,46 @@ program quadratic_fit
call sgd_optimizer(net(1), x, y, xtest, ytest, learning_rate, num_epochs)

! SGD, momentum
call sgd_optimizer(net(2), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9)
call sgd_optimizer( &
net(2), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9 &
)

! SGD, momentum with Nesterov
call sgd_optimizer(net(3), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9, nesterov=.true.)
call sgd_optimizer( &
net(3), x, y, xtest, ytest, learning_rate, num_epochs, &
momentum=0.9, nesterov=.true. &
)

! Batch SGD optimizer
call batch_gd_optimizer(net(4), x, y, xtest, ytest, learning_rate, num_epochs)

! Mini-batch SGD optimizer
call minibatch_gd_optimizer(net(5), x, y, xtest, ytest, learning_rate, num_epochs, batch_size)
call minibatch_gd_optimizer( &
net(5), x, y, xtest, ytest, learning_rate, num_epochs, batch_size &
)

! RMSProp optimizer
call rmsprop_optimizer(net(6), x, y, xtest, ytest, learning_rate, num_epochs, decay_rate)
call rmsprop_optimizer( &
net(6), x, y, xtest, ytest, learning_rate, num_epochs, decay_rate &
)

! Adam optimizer
call adam_optimizer( &
net(7), x, y, xtest, ytest, learning_rate, num_epochs, &
beta1, beta2, epsilon &
)

! Adam optimizer with L2 regularization
call adam_optimizer( &
net(8), x, y, xtest, ytest, learning_rate, num_epochs, &
beta1, beta2, epsilon, weight_decay_l2=1e-4 &
)

! Adam optimizer with decoupled weight decay regularization
call adam_optimizer( &
net(9), x, y, xtest, ytest, learning_rate, num_epochs, &
beta1, beta2, epsilon, weight_decay_decoupled=1e-5 &
)
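! Note: weight_decay_l2 corresponds to the classic L2 penalty, folded into
! the gradient before the moment estimates, whereas weight_decay_decoupled
! corresponds to AdamW-style decay applied directly to the weights after
! the adaptive step (textbook behavior; see nf_optimizers for the exact
! handling in this library).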

contains

@@ -73,7 +103,9 @@ real elemental function quadratic(x) result(y)
y = (x**2 / 2 + x / 2 + 1) / 2
end function quadratic

subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, momentum, nesterov)
subroutine sgd_optimizer( &
net, x, y, xtest, ytest, learning_rate, num_epochs, momentum, nesterov &
)
! In the stochastic gradient descent (SGD) optimizer, we run the forward
! and backward passes and update the weights for each training sample,
! one at a time.
@@ -109,12 +141,19 @@ subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, mom
do i = 1, size(x)
call net % forward([x(i)])
call net % backward([y(i)])
call net % update(sgd(learning_rate=learning_rate, momentum=momentum_value, nesterov=nesterov_value))
call net % update( &
sgd( &
learning_rate=learning_rate, &
momentum=momentum_value, &
nesterov=nesterov_value &
) &
)
end do

if (mod(n, num_epochs / 10) == 0) then
ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', &
n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
end if

end do
@@ -123,7 +162,9 @@ subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, mom

end subroutine sgd_optimizer

subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs)
subroutine batch_gd_optimizer( &
net, x, y, xtest, ytest, learning_rate, num_epochs &
)
! Like the stochastic gradient descent (SGD) optimizer, except that here we
! accumulate the weight gradients for all training samples and update the
! weights once per epoch.
@@ -147,7 +188,8 @@ subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs

if (mod(n, num_epochs / 10) == 0) then
ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', &
n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
end if

end do
@@ -156,7 +198,9 @@ subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs

end subroutine batch_gd_optimizer

subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, batch_size)
subroutine minibatch_gd_optimizer( &
net, x, y, xtest, ytest, learning_rate, num_epochs, batch_size &
)
! Like the batch SGD optimizer, except that here we accumulate the weight
! gradients over a number of mini batches and update the weights once per
! mini batch.
!
@@ -203,7 +247,8 @@ subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_ep

if (mod(n, num_epochs / 10) == 0) then
ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', &
n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
end if

end do
@@ -212,79 +257,36 @@ subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_ep

end subroutine minibatch_gd_optimizer

subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, decay_rate)
subroutine rmsprop_optimizer( &
net, x, y, xtest, ytest, learning_rate, num_epochs, decay_rate &
)
! RMSprop optimizer for updating weights using root mean square
type(network), intent(inout) :: net
real, intent(in) :: x(:), y(:)
real, intent(in) :: xtest(:), ytest(:)
real, intent(in) :: learning_rate, decay_rate
integer, intent(in) :: num_epochs
integer :: i, j, n
real, parameter :: epsilon = 1e-8 ! Small constant to avoid division by zero
real, allocatable :: ypred(:)

! Define a dedicated type to store the RMSprop gradients.
! This is needed because array sizes vary between layers and we need to
! keep track of gradients for each layer over time.
! For now this works only for dense layers.
! We will need to define a similar type for conv2d layers.
type :: rms_gradient_dense
real, allocatable :: dw(:,:)
real, allocatable :: db(:)
end type rms_gradient_dense

type(rms_gradient_dense), allocatable :: rms(:)

print '(a)', 'RMSProp optimizer'
print '(34("-"))'

! Here we allocate the array of RMS gradient derived types.
! We need one for each dense layer, however we will allocate it to the
! length of all layers as it will make housekeeping easier.
allocate(rms(size(net % layers)))

do n = 1, num_epochs

do i = 1, size(x)
call net % forward([x(i)])
call net % backward([y(i)])
end do

! RMSprop update rule
do j = 1, size(net % layers)
select type (this_layer => net % layers(j) % p)
type is (dense_layer)

! If this is our first time here for this layer, allocate the
! internal RMS gradient arrays and initialize them to zero.
if (.not. allocated(rms(j) % dw)) then
allocate(rms(j) % dw, mold=this_layer % dw)
allocate(rms(j) % db, mold=this_layer % db)
rms(j) % dw = 0
rms(j) % db = 0
end if

! Update the RMS gradients using the RMSprop moving average rule
rms(j) % dw = decay_rate * rms(j) % dw + (1 - decay_rate) * this_layer % dw**2
rms(j) % db = decay_rate * rms(j) % db + (1 - decay_rate) * this_layer % db**2

! Update weights and biases using the RMSprop update rule
this_layer % weights = this_layer % weights - learning_rate &
/ sqrt(rms(j) % dw + epsilon) * this_layer % dw
this_layer % biases = this_layer % biases - learning_rate &
/ sqrt(rms(j) % db + epsilon) * this_layer % db

! We have updated the weights and biases, so we need to reset the
! gradients to zero for the next epoch.
this_layer % dw = 0
this_layer % db = 0

end select
end do
call net % update( &
rmsprop(learning_rate=learning_rate, decay_rate=decay_rate) &
)

if (mod(n, num_epochs / 10) == 0) then
ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', &
n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
end if

end do
@@ -293,6 +295,69 @@ subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs,

end subroutine rmsprop_optimizer

subroutine adam_optimizer( &
net, x, y, xtest, ytest, learning_rate, num_epochs, beta1, beta2, epsilon, &
weight_decay_l2, weight_decay_decoupled &
)
! Adam optimizer
type(network), intent(inout) :: net
real, intent(in) :: x(:), y(:)
real, intent(in) :: xtest(:), ytest(:)
real, intent(in) :: learning_rate, beta1, beta2, epsilon
real, intent(in), optional :: weight_decay_l2
real, intent(in), optional :: weight_decay_decoupled
integer, intent(in) :: num_epochs
real, allocatable :: ypred(:)
integer :: i, n
real :: weight_decay_l2_val
real :: weight_decay_decoupled_val

! Set default values for weight_decay_l2
if (.not. present(weight_decay_l2)) then
weight_decay_l2_val = 0.0
else
weight_decay_l2_val = weight_decay_l2
end if

! Set default values for weight_decay_decoupled
if (.not. present(weight_decay_decoupled)) then
weight_decay_decoupled_val = 0.0
else
weight_decay_decoupled_val = weight_decay_decoupled
end if

print '(a)', 'Adam optimizer'
print '(34("-"))'

do n = 1, num_epochs
do i = 1, size(x)
call net % forward([x(i)])
call net % backward([y(i)])
end do

call net % update( &
adam( &
learning_rate=learning_rate, &
beta1=beta1, &
beta2=beta2, &
epsilon=epsilon, &
weight_decay_l2=weight_decay_l2_val, &
weight_decay_decoupled=weight_decay_decoupled_val &
) &
)

if (mod(n, num_epochs / 10) == 0) then
ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', &
n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
end if

end do

print *, ''

end subroutine adam_optimizer

subroutine shuffle(arr)
! Shuffle an array using the Fisher-Yates algorithm.
integer, intent(inout) :: arr(:)
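
The body of the shuffle routine is collapsed above. For reference, a minimal Fisher-Yates shuffle of an integer array can be sketched as follows; this is an illustrative version, not necessarily the code in this commit.

program shuffle_demo
  implicit none
  integer :: a(10), i
  a = [(i, i = 1, 10)]
  call fisher_yates_shuffle(a)
  print *, a
contains
  subroutine fisher_yates_shuffle(arr)
    ! Shuffle an array in place: walk it from the end and swap each element
    ! with a randomly chosen element at or before it.
    integer, intent(inout) :: arr(:)
    integer :: i, j, tmp
    real :: u
    do i = size(arr), 2, -1
      call random_number(u)
      j = 1 + int(u * i)   ! uniform random index in 1..i
      tmp = arr(i)
      arr(i) = arr(j)
      arr(j) = tmp
    end do
  end subroutine fisher_yates_shuffle
end program shuffle_demo
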
2 changes: 1 addition & 1 deletion fpm.toml
@@ -1,5 +1,5 @@
name = "neural-fortran"
version = "0.13.0"
version = "0.14.0"
license = "MIT"
author = "Milan Curcic"
maintainer = "milancurcic@hey.com"
2 changes: 1 addition & 1 deletion src/nf.f90
@@ -5,7 +5,7 @@ module nf
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop
use nf_optimizers, only: sgd, rmsprop, adam
use nf_activation, only: activation_function, elu, exponential, &
gaussian, linear, relu, leaky_relu, &
sigmoid, softmax, softplus, step, tanhf, &