Adaptive Learning Rate Clipping (ALRC)

Repository for the preprint/paper "Adaptive Learning Rate Clipping Stabilizes Learning".

This repository contains source code for CIFAR-10 supersampling experiments with squared and quartic errors, and stable and unstably high learning rates. An implementation of the ALRC algorithm is in alrc.py. Source code for partial-STEM is here.

Example learning curves for stable and unstably high learning rates. ALRC stabilizes learning by preventing loss spikes and otherwise has little effect. Learning curves are 500-iteration boxcar averaged. Results are similar for low- and high-order loss functions, different batch sizes, and different optimizers.

Description

ALRC is a simple, computationally inexpensive algorithm that stabilizes learning by limiting loss spikes. It can be applied to any neural network trained with gradient descent. In practice, it improves the training of neural networks where learning is destabilized by loss spikes and otherwise has little effect.

Example

ALRC can be applied like any other neural network layer and is robust to hyperparameter choices. The only hyperparameters that need to be provided are estimates of the mean and mean squared loss at the start of training. Any sensible overestimates are fine: even if they are an order of magnitude too high, the ALRC algorithm will decay them to the correct values.

# Roughly estimate the first two raw moments of your loss function
mu1_start_estimate = ...  # Your estimate
mu2_start_estimate = ...  # Your estimate

# It's fine to overestimate
overestimate_factor = 3
mu1_start_estimate *= overestimate_factor
mu2_start_estimate *= overestimate_factor**2

loss = my_loss_fn( ... )  # Apply neural network and compute the loss
loss = alrc(loss, mu1_start=mu1_start_estimate, mu2_start=mu2_start_estimate)  # Apply ALRC

Note that mu2_start should be larger than mu1_start**2, so that the implied variance estimate, mu2_start - mu1_start**2, is positive.
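For concreteness, below is a minimal sketch of how an ALRC-style loss-clipping layer could be written in PyTorch. It is not the repository's TensorFlow implementation in alrc.py; the class name, the n_sigma and decay defaults, and the choice to update the moment estimates with the unclipped loss are illustrative assumptions. The idea is to track exponential moving averages of the loss and squared loss, and to scale down any loss that spikes above the running mean by more than a few standard deviations.

import torch

class ALRCSketch:
    """Illustrative ALRC-style loss clipping (a sketch, not the repository's alrc.py)."""

    def __init__(self, mu1_start, mu2_start, n_sigma=3.0, decay=0.999):
        self.mu1 = float(mu1_start)  # running estimate of E[loss]
        self.mu2 = float(mu2_start)  # running estimate of E[loss**2]
        self.n_sigma = n_sigma       # clip losses above mu1 + n_sigma*sigma
        self.decay = decay           # EMA decay; generous start estimates decay towards the true moments

    def __call__(self, loss):
        raw = float(loss.detach())   # unclipped scalar loss value, used to update the moments
        sigma = max(self.mu2 - self.mu1**2, 0.0) ** 0.5
        threshold = self.mu1 + self.n_sigma * sigma

        # Rescale a spiking loss so its value (and hence its gradient) is
        # capped at the threshold; .detach() keeps the scale factor constant
        # with respect to the network parameters.
        if raw > threshold:
            loss = loss * (threshold / loss.detach())

        # Exponential moving averages of the first two raw moments of the loss.
        self.mu1 = self.decay * self.mu1 + (1.0 - self.decay) * raw
        self.mu2 = self.decay * self.mu2 + (1.0 - self.decay) * raw**2
        return loss

# Usage: wrap the scalar training loss before calling backward()
# alrc = ALRCSketch(mu1_start=mu1_start_estimate, mu2_start=mu2_start_estimate)
# loss = alrc(loss)
# loss.backward()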

When Should I Use ALRC?

Use ALRC if learning is destabilized by loss spikes. This is common when training with small batch sizes, unstably high learning rates, or high-order loss functions. It might also help if your dataset contains unusual or mislabelled examples that cause loss spikes.

ALRC can also be used to safeguard against potential loss spikes. Anecdotally, this was the situation in our partial-STEM experiments. Large loss spikes would sometimes occur partway through training, which made results difficult to compare. ALRC prevented these spikes, making training more consistent so that different experiments could be compared.

Training data

Our training dataset containing 161069 crops from STEM images is available here.

Contact

Jeffrey Ede: j.m.ede@warwick.ac.uk
Richard Beanland: r.beanland@warwick.ac.uk