From 2781a6bc7efbafd3049ab26f7e3c4ad1950fb160 Mon Sep 17 00:00:00 2001
From: Matteo Hessel
Date: Tue, 8 Jun 2021 06:54:36 -0700
Subject: [PATCH] Add LARS optimiser.

PiperOrigin-RevId: 378143672
---
 optax/__init__.py              |  2 ++
 optax/_src/alias.py            | 59 ++++++++++++++++++++++++++++++++--
 optax/_src/equivalence_test.py | 13 ++++++--
 optax/_src/transform.py        |  8 +++--
 4 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/optax/__init__.py b/optax/__init__.py
index d0724346c..1144b496a 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -21,6 +21,7 @@
 from optax._src.alias import dpsgd
 from optax._src.alias import fromage
 from optax._src.alias import lamb
+from optax._src.alias import lars
 from optax._src.alias import noisy_sgd
 from optax._src.alias import radam
 from optax._src.alias import rmsprop
@@ -192,6 +193,7 @@
     "InjectHyperparamsState",
     "join_schedules",
     "lamb",
+    "lars",
     "log_cosh",
     "lookahead",
     "LookaheadParams",
diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 626984b3a..c4aa28f10 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -22,9 +22,11 @@
 from optax._src import combine
 from optax._src import privacy
 from optax._src import transform
+from optax._src import wrappers


 ScalarOrSchedule = Union[float, base.Schedule]
+MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]


 def _scale_by_learning_rate(learning_rate: ScalarOrSchedule):
@@ -214,13 +216,62 @@ def fromage(
   )


+def lars(
+    learning_rate: ScalarOrSchedule,
+    weight_decay: float = 0.,
+    weight_decay_mask: MaskOrFn = True,
+    trust_coefficient: float = 0.001,
+    eps: float = 0.,
+    trust_ratio_mask: MaskOrFn = True,
+    momentum: float = 0.9,
+    nesterov: bool = False,
+) -> base.GradientTransformation:
+  """The LARS optimiser.
+
+  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
+  larger batch sizes. LARS later inspired the LAMB optimiser.
+
+  References:
+    You et al, 2017: https://arxiv.org/abs/1708.03888
+
+  Args:
+    learning_rate: this is a fixed global scaling factor.
+    weight_decay (default `0.`): strength of the weight decay regularization.
+    weight_decay_mask: a tree with same structure as (or a prefix of) the params
+      PyTree, or a Callable that returns such a pytree given the params/updates.
+      The leaves should be booleans, `True` for leaves/subtrees you want to
+      apply the transformation to, and `False` for those you want to skip.
+    trust_coefficient: a multiplier for the trust ratio.
+    eps: optional additive constant in the trust ratio denominator.
+    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
+      PyTree, or a Callable that returns such a pytree given the params/updates.
+      The leaves should be booleans, `True` for leaves/subtrees you want to
+      apply the transformation to, and `False` for those you want to skip.
+    momentum: the decay rate for momentum.
+    nesterov: whether to use Nesterov momentum.
+
+  Returns:
+    the corresponding `GradientTransformation`.
+  """
+  return combine.chain(
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
+      wrappers.masked(
+          inner=transform.scale_by_trust_ratio(
+              trust_coefficient=trust_coefficient, eps=eps),
+          mask=trust_ratio_mask),
+      _scale_by_learning_rate(learning_rate),
+      transform.trace(decay=momentum, nesterov=nesterov),
+  )
+
+
 def lamb(
     learning_rate: ScalarOrSchedule,
     b1: float = 0.9,
     b2: float = 0.999,
     eps: float = 1e-6,
     eps_root: float = 0.0,
-    weight_decay: float = 0.
+    weight_decay: float = 0.,
+    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
 ) -> base.GradientTransformation:
   """The LAMB optimiser.

@@ -243,13 +294,17 @@ def lamb(
       the square root (as in RMSProp), to avoid dividing by zero when rescaling.
       This is needed for instance when computing (meta-)gradients through Adam.
     weight_decay (default `0.`): strength of the weight decay regularization.
+    mask: a tree with same structure as (or a prefix of) the params PyTree,
+      or a Callable that returns such a pytree given the params/updates.
+      The leaves should be booleans, `True` for leaves/subtrees you want to
+      apply the transformation to, and `False` for those you want to skip.

   Returns:
     the corresponding `GradientTransformation`.
   """
   return combine.chain(
       transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
-      transform.add_decayed_weights(weight_decay),
+      transform.add_decayed_weights(weight_decay=weight_decay, mask=mask),
       transform.scale_by_trust_ratio(),
       _scale_by_learning_rate(learning_rate),
   )
diff --git a/optax/_src/equivalence_test.py b/optax/_src/equivalence_test.py
index eb2727a86..6632c8025 100644
--- a/optax/_src/equivalence_test.py
+++ b/optax/_src/equivalence_test.py
@@ -94,8 +94,10 @@ class FlaxOptimizersEquivalenceTest(chex.TestCase):

   def setUp(self):
     super().setUp()
-    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
-    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))
+    self.init_params = (
+        jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
+    self.per_step_updates = (
+        jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

   @parameterized.named_parameters(
       ('sgd',
@@ -125,6 +127,13 @@ def setUp(self):
       ('lamb',
        alias.lamb(LR),
        optim.LAMB(LR)),
+      ('lars',
+       alias.lars(
+           LR, weight_decay=.5, trust_coefficient=0.003,
+           momentum=0.9, eps=1e-3),
+       optim.LARS(
+           LR, weight_decay=.5, trust_coefficient=0.003,
+           beta=0.9, eps=1e-3)),
   )

   def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 6f353b6c0..8e9eab9dc 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -617,7 +617,9 @@ class ScaleByTrustRatioState(NamedTuple):


 def scale_by_trust_ratio(
-    min_norm: float = 0.0
+    min_norm: float = 0.0,
+    trust_coefficient: float = 1.,
+    eps: float = 0.,
 ) -> base.GradientTransformation:
   """Scale updates by trust ratio.

@@ -626,6 +628,8 @@ def scale_by_trust_ratio(

   Args:
     min_norm: minimum norm for params and gradient norms; by default is zero.
+    trust_coefficient: a multiplier for the trust ratio.
+    eps: additive constant added to the denominator for numerical stability.

   Returns:
     An (init_fn, update_fn) tuple.
@@ -643,7 +647,7 @@ def _scale_update(update, param):
       # Clip norms to minimum value, by default no clipping.
       param_norm = numerics.safe_norm(param, min_norm)
       update_norm = numerics.safe_norm(update, min_norm)
-      trust_ratio = param_norm / update_norm
+      trust_ratio = trust_coefficient * param_norm / (update_norm + eps)

       # If no minimum norm clipping is used
       # Set trust_ratio to 1 in case where parameters would never be updated.
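
Usage sketch: the example below is not part of the patch itself. It is a minimal illustration of how the new `optax.lars` alias added above can be driven once the change lands; the parameter shapes, loss function, data, step count, and hyperparameter values are assumptions chosen for demonstration only, not taken from the patch or its tests.

    # Minimal sketch: a few gradient steps with the new optax.lars alias.
    # All values below (shapes, loss, hyperparameters) are illustrative
    # assumptions, not part of the patch.
    import jax
    import jax.numpy as jnp
    import optax

    # Toy parameters for a small linear model.
    params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}

    def loss_fn(p, x, y):
      # Toy mean-squared-error loss, purely for demonstration.
      pred = x @ p['w'] + p['b']
      return jnp.mean((pred - y) ** 2)

    optimizer = optax.lars(
        learning_rate=0.1,
        weight_decay=1e-4,
        trust_coefficient=0.001,
        momentum=0.9)
    opt_state = optimizer.init(params)

    x, y = jnp.ones((8, 3)), jnp.zeros((8, 2))
    for _ in range(3):
      grads = jax.grad(loss_fn)(params, x, y)
      # add_decayed_weights and scale_by_trust_ratio both read the current
      # parameters, so params must be passed to update() with the gradients.
      updates, opt_state = optimizer.update(grads, opt_state, params)
      params = optax.apply_updates(params, updates)

The same pattern applies to the new `mask` argument of `lamb`, which reuses the masking machinery shown in `lars`; the `trust_coefficient` and `eps` arguments are simply threaded through to the updated `scale_by_trust_ratio` transform.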