From 41702922fea09b915358e5903705c366bc018bf0 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 14:45:38 +0300 Subject: [PATCH 01/11] replaced seed with key --- optax/_src/alias.py | 17 +++++++---- optax/_src/alias_test.py | 10 +++---- optax/_src/utils.py | 4 +-- optax/contrib/_privacy.py | 25 ++++++++++++---- optax/contrib/_privacy_test.py | 17 ++++++++--- .../stochastic_gradient_estimators.py | 15 +++++----- .../stochastic_gradient_estimators_test.py | 17 ++++++----- optax/perturbations/_make_pert.py | 8 ++--- optax/transforms/_adding.py | 30 +++++++++++++++++-- optax/transforms/_adding_test.py | 6 ++-- 10 files changed, 102 insertions(+), 47 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index b20d33e1c..5163f57eb 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -26,6 +26,7 @@ from optax._src import linesearch as _linesearch from optax._src import transform from optax._src import wrappers +import chex MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] @@ -1254,9 +1255,9 @@ def lamb( def noisy_sgd( learning_rate: base.ScalarOrSchedule, + key: Optional[chex.PRNGKey] = None, eta: float = 0.01, gamma: float = 0.55, - seed: int = 0, ) -> base.GradientTransformation: r"""A variant of SGD with added noise. @@ -1284,10 +1285,10 @@ def noisy_sgd( Args: learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + key: a PRNG key used as the random key. eta: Initial variance for the Gaussian noise added to gradients. gamma: A parameter controlling the annealing of noise over time ``t``, the variance decays according to ``(1+t)**(-gamma)``. - seed: Seed for the pseudo-random generation process. Returns: The corresponding :class:`optax.GradientTransformation`. @@ -1297,7 +1298,8 @@ def noisy_sgd( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.noisy_sgd(learning_rate=0.003) + >>> key = jax.random.key(0) + >>> solver = optax.noisy_sgd(learning_rate=0.003, key) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 @@ -1317,8 +1319,13 @@ def noisy_sgd( Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep Networks `_, 2015 """ + if key is None: + raise ValueError( + "noisy_sgd optimizer requires specifying key: " + "noisy_sgd(..., key=jax.random.key(0))" + ) return combine.chain( - transform.add_noise(eta, gamma, seed), + transform.add_noise(key, eta, gamma), transform.scale_by_learning_rate(learning_rate), ) @@ -2394,7 +2401,7 @@ def lbfgs( linesearch: Optional[ base.GradientTransformationExtraArgs ] = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=20, initial_guess_strategy='one' + max_linesearch_steps=20, initial_guess_strategy="one" ), ) -> base.GradientTransformationExtraArgs: r"""L-BFGS optimizer. 
diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 067cb1292..71a3c0426 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -66,7 +66,7 @@ {'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}}, { 'opt_name': 'noisy_sgd', - 'opt_kwargs': {'learning_rate': 1e-3, 'eta': 1e-4}, + 'opt_kwargs': {'learning_rate': 1e-3, 'key': jrd.key(0), 'eta': 1e-4}, }, {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}}, { @@ -573,7 +573,7 @@ def zakharov(x, xnp): class LBFGSTest(chex.TestCase): def test_plain_preconditioning(self): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -592,7 +592,7 @@ def test_plain_preconditioning(self): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -619,7 +619,7 @@ def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_trees(self, idx: int): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 shapes = ((3, 2), (5,)) @@ -721,7 +721,7 @@ def fun_(x): def fun(x): return otu.tree_sum(jax.tree.map(fun_, x)) - key = jrd.PRNGKey(0) + key = jrd.key(0) init_array = jrd.normal(key, (2, 4)) init_tree = (init_array[0], init_array[1]) diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 75a0ebb79..64bb8414d 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -109,10 +109,10 @@ def __init__(self, loc: chex.Array, log_scale: chex.Array): self._mean.shape, self._scale.shape ) - def sample(self, shape: Sequence[int], seed: chex.PRNGKey) -> chex.Array: + def sample(self, shape: Sequence[int], key: chex.PRNGKey) -> chex.Array: sample_shape = tuple(shape) + self._param_shape return ( - jax.random.normal(seed, shape=sample_shape) * self._scale + self._mean + jax.random.normal(key, shape=sample_shape) * self._scale + self._mean ) def log_prob(self, x: chex.Array) -> chex.Array: diff --git a/optax/contrib/_privacy.py b/optax/contrib/_privacy.py index b36901014..a6edcae3e 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -17,6 +17,7 @@ from typing import Any, NamedTuple, Optional import jax +import chex from optax._src import base from optax._src import clipping from optax._src import combine @@ -33,14 +34,16 @@ class DifferentiallyPrivateAggregateState(NamedTuple): def differentially_private_aggregate( - l2_norm_clip: float, noise_multiplier: float, seed: int + l2_norm_clip: float, + noise_multiplier: float, + key: Optional[chex.PRNGKey] = None ) -> base.GradientTransformation: """Aggregates gradients based on the DPSGD algorithm. Args: l2_norm_clip: maximum L2 norm of the per-example gradients. noise_multiplier: ratio of standard deviation to the clipping norm. - seed: initial seed used for the jax.random.PRNGKey + key: a PRNG key used as the random key. Returns: A :class:`optax.GradientTransformation`. @@ -56,11 +59,16 @@ def differentially_private_aggregate( JAX using `jax.vmap`). It can still be composed with other transformations as long as it is the first in the chain. 
""" + if key is None: + raise ValueError( + "differentially_private_aggregate optimizer requires specifying key: " + "differentially_private_aggregate(..., key=jax.random.key(0))" + ) noise_std = l2_norm_clip * noise_multiplier def init_fn(params): del params - return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed)) + return DifferentiallyPrivateAggregateState(rng_key=key) def update_fn(updates, state, params=None): del params @@ -85,7 +93,7 @@ def dpsgd( learning_rate: base.ScalarOrSchedule, l2_norm_clip: float, noise_multiplier: float, - seed: int, + key: Optional[chex.PRNGKey] = None, momentum: Optional[float] = None, nesterov: bool = False, ) -> base.GradientTransformation: @@ -100,7 +108,7 @@ def dpsgd( learning_rate: A fixed global scaling factor. l2_norm_clip: Maximum L2 norm of the per-example gradients. noise_multiplier: Ratio of standard deviation to the clipping norm. - seed: Initial seed used for the jax.random.PRNGKey + key: a PRNG key used as the random key. momentum: Decay rate used by the momentum term, when it is set to `None`, then momentum is not used at all. nesterov: Whether Nesterov momentum is used. @@ -117,11 +125,16 @@ def dpsgd( batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using `jax.vmap`). """ + if key is None: + raise ValueError( + "dpsgd optimizer requires specifying key: " + "dpsgd(..., key=jax.random.key(0))" + ) return combine.chain( differentially_private_aggregate( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, - seed=seed, + key=key, ), ( transform.trace(decay=momentum, nesterov=nesterov) diff --git a/optax/contrib/_privacy_test.py b/optax/contrib/_privacy_test.py index 47e72f40e..f00b6eed3 100644 --- a/optax/contrib/_privacy_test.py +++ b/optax/contrib/_privacy_test.py @@ -19,6 +19,7 @@ import chex import jax import jax.numpy as jnp +import jax.random as jrd from optax.contrib import _privacy @@ -45,7 +46,9 @@ def setUp(self): def test_no_privacy(self): """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0 + l2_norm_clip=jnp.finfo(jnp.float32).max, + noise_multiplier=0.0, + key=jrd.key(0) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -59,7 +62,9 @@ def test_no_privacy(self): @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0) def test_clipping_norm(self, l2_norm_clip): dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42 + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + key=jrd.key(42) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -87,7 +92,9 @@ def test_clipping_norm(self, l2_norm_clip): def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): """Standard dev. 
of noise should be l2_norm_clip * noise_multiplier.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, seed=1337 + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + key=jrd.key(1337) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -103,7 +110,9 @@ def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): def test_aggregated_updates_as_input_fails(self): """Expect per-example gradients as input to this transform.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=0.1, noise_multiplier=1.1, seed=2021 + l2_norm_clip=0.1, + noise_multiplier=1.1, + key=jrd.key(2021) ) state = dp_agg.init(self.params) mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads) diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index 541b697a3..b766bd97a 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -34,6 +34,7 @@ import chex import jax +import jax.radom as jrd import jax.numpy as jnp import numpy as np from optax._src import base @@ -241,15 +242,15 @@ def measure_valued_estimation_mean( dist_samples = dist.sample((num_samples,), seed=rng) - pos_rng, neg_rng = jax.random.split(rng) - pos_sample = jax.random.weibull_min( + pos_rng, neg_rng = jrd.split(rng) + pos_sample = jrd.weibull_min( pos_rng, scale=math.sqrt(2.0), concentration=2.0, shape=dist_samples.shape ) if coupling: neg_sample = pos_sample else: - neg_sample = jax.random.weibull_min( + neg_sample = jrd.weibull_min( neg_rng, scale=math.sqrt(2.0), concentration=2.0, @@ -314,17 +315,17 @@ def measure_valued_estimation_std( dist_samples = dist.sample((num_samples,), seed=rng) - pos_rng, neg_rng = jax.random.split(rng) + pos_rng, neg_rng = jrd.split(rng) # The only difference between mean and std gradients is what we sample. - pos_sample = jax.random.double_sided_maxwell( + pos_sample = jrd.double_sided_maxwell( pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape ) if coupling: - unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape) + unif_rvs = jrd.uniform(neg_rng, dist_samples.shape) neg_sample = unif_rvs * pos_sample else: - neg_sample = jax.random.normal(neg_rng, dist_samples.shape) + neg_sample = jrd.normal(neg_rng, dist_samples.shape) # Both need to be positive in the case of the scale. 
# N x D diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index a54501c6e..3ac06630a 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized import chex import jax +import jrd as jrd import jax.numpy as jnp import numpy as np from optax._src import utils @@ -99,7 +100,7 @@ def test_constant_function(self, estimator, constant): effective_log_scale = 0.0 log_scale = effective_log_scale * _ones(data_dims) - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) jacobians = _estimator_variant(self.variant, estimator)( lambda x: jnp.array(constant), @@ -144,7 +145,7 @@ def test_linear_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -185,7 +186,7 @@ def test_quadratic_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -233,7 +234,7 @@ def test_weighted_linear( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = jnp.array(effective_mean) log_scale = jnp.array(effective_log_scale) @@ -280,7 +281,7 @@ def test_weighted_quadratic( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -342,8 +343,8 @@ def testNonPolynomialFunctionConsistencyWithPathwise( self, effective_mean, effective_log_scale, function, coupling ): num_samples = 10**5 - rng = jax.random.PRNGKey(1) - measure_rng, pathwise_rng = jax.random.split(rng) + rng = jrd.key(1) + measure_rng, pathwise_rng = jrd.split(rng) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -405,7 +406,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase): @parameterized.parameters([True, False]) def test_raises_error_for_non_gaussian(self, coupling): num_samples = 10**5 - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) function = lambda x: jnp.sum(x) ** 2 diff --git a/optax/perturbations/_make_pert.py b/optax/perturbations/_make_pert.py index ae711f5d2..434fa4d1d 100644 --- a/optax/perturbations/_make_pert.py +++ b/optax/perturbations/_make_pert.py @@ -35,11 +35,11 @@ class Normal: def sample( self, - seed: chex.PRNGKey, + key: chex.PRNGKey, sample_shape: Shape, dtype: chex.ArrayDType = float, ) -> jax.Array: - return jax.random.normal(seed, sample_shape, dtype) + return jax.random.normal(key, sample_shape, dtype) def log_prob(self, inputs: jax.Array) -> jax.Array: return -0.5 * inputs**2 @@ -50,11 +50,11 @@ class Gumbel: def sample( self, - seed: chex.PRNGKey, + key: chex.PRNGKey, sample_shape: Shape, dtype: chex.ArrayDType = float, ) -> jax.Array: - return jax.random.gumbel(seed, sample_shape, dtype) + return jax.random.gumbel(key, sample_shape, dtype) def log_prob(self, inputs: jax.Array) -> jax.Array: return -inputs - jnp.exp(-inputs) diff --git a/optax/transforms/_adding.py 
b/optax/transforms/_adding.py index 5fd759aab..5a7ccf649 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -71,18 +71,42 @@ class AddNoiseState(NamedTuple): def add_noise( - eta: float, gamma: float, seed: int + key: chex.PRNGKey, eta: float, gamma: float ) -> base.GradientTransformation: """Add gradient noise. Args: + key: a PRNG key used as the random key. eta: Base variance of the gaussian noise added to the gradient. gamma: Decay exponent for annealing of the variance. - seed: Seed for random number generation. Returns: A :class:`optax.GradientTransformation` object. + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function + >>> key = jax.random.key(0) + >>> noise = optax.add_noise(key=key, eta=0.01, gamma=0.55) + >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003) + >>> solver = optax.chain(noise, sgd) + >>> params = jnp.array([1., 2., 3.]) + >>> print('Objective function: ', f(params)) + Objective function: 14.0 + >>> opt_state = solver.init(params) + >>> for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + ... print('Objective function: {:.2E}'.format(f(params))) + Objective function: 1.38E+01 + Objective function: 1.37E+01 + Objective function: 1.35E+01 + Objective function: 1.33E+01 + Objective function: 1.32E+01 + References: Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep Networks `_, 2015 @@ -91,7 +115,7 @@ def add_noise( def init_fn(params): del params return AddNoiseState( - count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed) + count=jnp.zeros([], jnp.int32), rng_key=key ) def update_fn(updates, state, params=None): diff --git a/optax/transforms/_adding_test.py b/optax/transforms/_adding_test.py index 61305d2be..fdb9a5613 100644 --- a/optax/transforms/_adding_test.py +++ b/optax/transforms/_adding_test.py @@ -73,9 +73,9 @@ def test_add_noise_has_correct_variance_scaling(self): # Prepare to compare noise with a rescaled unit-variance substitute. 
eta = 0.3 gamma = 0.55 - seed = 314 - noise = _adding.add_noise(eta, gamma, seed) - noise_unit = _adding.add_noise(1.0, 0.0, seed) + key = jax.random.key(314) + noise = _adding.add_noise(key, eta, gamma) + noise_unit = _adding.add_noise(key, 1.0, 0.0) params = self.init_params state = noise.init(params) From e38501c171e533f50eefcb2d845cd841f0d43fda Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 14:53:01 +0300 Subject: [PATCH 02/11] fixed spaces --- optax/_src/alias.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 5163f57eb..735644303 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1320,10 +1320,10 @@ def noisy_sgd( Networks `_, 2015 """ if key is None: - raise ValueError( - "noisy_sgd optimizer requires specifying key: " - "noisy_sgd(..., key=jax.random.key(0))" - ) + raise ValueError( + "noisy_sgd optimizer requires specifying key: " + "noisy_sgd(..., key=jax.random.key(0))" + ) return combine.chain( transform.add_noise(key, eta, gamma), transform.scale_by_learning_rate(learning_rate), From 8a3a909db5f67e6f7761ba6a3aa0bb94ab184999 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 14:55:40 +0300 Subject: [PATCH 03/11] returned jax.random instead of jrd --- .../stochastic_gradient_estimators_test.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index 3ac06630a..ad85e0cd6 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -18,7 +18,6 @@ from absl.testing import parameterized import chex import jax -import jrd as jrd import jax.numpy as jnp import numpy as np from optax._src import utils @@ -100,7 +99,7 @@ def test_constant_function(self, estimator, constant): effective_log_scale = 0.0 log_scale = effective_log_scale * _ones(data_dims) - rng = jrd.key(1) + rng = jax.random.key(1) jacobians = _estimator_variant(self.variant, estimator)( lambda x: jnp.array(constant), @@ -145,7 +144,7 @@ def test_linear_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jrd.key(1) + rng = jax.random.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -186,7 +185,7 @@ def test_quadratic_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jrd.key(1) + rng = jax.random.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -234,7 +233,7 @@ def test_weighted_linear( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jrd.key(1) + rng = jax.random.key(1) mean = jnp.array(effective_mean) log_scale = jnp.array(effective_log_scale) @@ -281,7 +280,7 @@ def test_weighted_quadratic( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jrd.key(1) + rng = jax.random.key(1) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -343,8 +342,8 @@ def testNonPolynomialFunctionConsistencyWithPathwise( self, effective_mean, effective_log_scale, function, coupling ): num_samples = 10**5 - rng = jrd.key(1) - measure_rng, pathwise_rng = jrd.split(rng) + rng = jax.random.key(1) + measure_rng, 
pathwise_rng = jax.random.split(rng) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -406,7 +405,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase): @parameterized.parameters([True, False]) def test_raises_error_for_non_gaussian(self, coupling): num_samples = 10**5 - rng = jrd.key(1) + rng = jax.random.key(1) function = lambda x: jnp.sum(x) ** 2 From 5499f4748981febf507c9373b35bcfa9ce722600 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 15:10:44 +0300 Subject: [PATCH 04/11] returned jax.random instead of jrd v2 --- .../monte_carlo/stochastic_gradient_estimators.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index b766bd97a..541b697a3 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -34,7 +34,6 @@ import chex import jax -import jax.radom as jrd import jax.numpy as jnp import numpy as np from optax._src import base @@ -242,15 +241,15 @@ def measure_valued_estimation_mean( dist_samples = dist.sample((num_samples,), seed=rng) - pos_rng, neg_rng = jrd.split(rng) - pos_sample = jrd.weibull_min( + pos_rng, neg_rng = jax.random.split(rng) + pos_sample = jax.random.weibull_min( pos_rng, scale=math.sqrt(2.0), concentration=2.0, shape=dist_samples.shape ) if coupling: neg_sample = pos_sample else: - neg_sample = jrd.weibull_min( + neg_sample = jax.random.weibull_min( neg_rng, scale=math.sqrt(2.0), concentration=2.0, @@ -315,17 +314,17 @@ def measure_valued_estimation_std( dist_samples = dist.sample((num_samples,), seed=rng) - pos_rng, neg_rng = jrd.split(rng) + pos_rng, neg_rng = jax.random.split(rng) # The only difference between mean and std gradients is what we sample. - pos_sample = jrd.double_sided_maxwell( + pos_sample = jax.random.double_sided_maxwell( pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape ) if coupling: - unif_rvs = jrd.uniform(neg_rng, dist_samples.shape) + unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape) neg_sample = unif_rvs * pos_sample else: - neg_sample = jrd.normal(neg_rng, dist_samples.shape) + neg_sample = jax.random.normal(neg_rng, dist_samples.shape) # Both need to be positive in the case of the scale. 
# N x D From ea02fc7cb430b052dd8045591b60bbe9e6cd13ab Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 15:22:32 +0300 Subject: [PATCH 05/11] positional argument in docstring fixed --- optax/_src/alias.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 735644303..4e7bb1f56 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1298,8 +1298,7 @@ def noisy_sgd( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> key = jax.random.key(0) - >>> solver = optax.noisy_sgd(learning_rate=0.003, key) + >>> solver = optax.noisy_sgd(learning_rate=0.003, jax.random.key(0)) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 From 09eec47fe10081ae63e8c9c6bc5e93acf8ccd621 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 15:30:21 +0300 Subject: [PATCH 06/11] old-style PRNGKey returned --- optax/_src/alias.py | 4 ++-- optax/transforms/_adding.py | 2 +- optax/transforms/_adding_test.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 4e7bb1f56..b8b84634b 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1298,7 +1298,7 @@ def noisy_sgd( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.noisy_sgd(learning_rate=0.003, jax.random.key(0)) + >>> solver = optax.noisy_sgd(learning_rate=0.003, jax.random.PRNGKey(0)) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 @@ -1321,7 +1321,7 @@ def noisy_sgd( if key is None: raise ValueError( "noisy_sgd optimizer requires specifying key: " - "noisy_sgd(..., key=jax.random.key(0))" + "noisy_sgd(..., key=jax.random.PRNGKey(0))" ) return combine.chain( transform.add_noise(key, eta, gamma), diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 5a7ccf649..7d463a9ba 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -88,7 +88,7 @@ def add_noise( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> key = jax.random.key(0) + >>> key = jax.random.PRNGKey(0) >>> noise = optax.add_noise(key=key, eta=0.01, gamma=0.55) >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003) >>> solver = optax.chain(noise, sgd) diff --git a/optax/transforms/_adding_test.py b/optax/transforms/_adding_test.py index fdb9a5613..96e61bca0 100644 --- a/optax/transforms/_adding_test.py +++ b/optax/transforms/_adding_test.py @@ -73,7 +73,7 @@ def test_add_noise_has_correct_variance_scaling(self): # Prepare to compare noise with a rescaled unit-variance substitute. 
eta = 0.3 gamma = 0.55 - key = jax.random.key(314) + key = jax.random.PRNGKey(314) noise = _adding.add_noise(key, eta, gamma) noise_unit = _adding.add_noise(key, 1.0, 0.0) From 4562d2f87e548f6772609cd0099a1434672c5f91 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Mon, 6 Jan 2025 15:40:48 +0300 Subject: [PATCH 07/11] replaced keys in several test files --- optax/_src/alias_test.py | 10 +++++----- optax/_src/float64_test.py | 4 ++-- optax/assignment/_hungarian_algorithm_test.py | 6 +++--- optax/contrib/_common_test.py | 6 +++++- optax/contrib/_privacy.py | 4 ++-- optax/contrib/_privacy_test.py | 8 ++++---- optax/monte_carlo/control_variates.py | 2 +- .../monte_carlo/stochastic_gradient_estimators.py | 8 ++++---- .../stochastic_gradient_estimators_test.py | 14 +++++++------- optax/tree_utils/_state_utils.py | 6 ++++-- optax/tree_utils/_state_utils_test.py | 9 +++++---- 11 files changed, 42 insertions(+), 35 deletions(-) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 71a3c0426..43087bb96 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -66,7 +66,7 @@ {'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}}, { 'opt_name': 'noisy_sgd', - 'opt_kwargs': {'learning_rate': 1e-3, 'key': jrd.key(0), 'eta': 1e-4}, + 'opt_kwargs': {'learning_rate': 1e-3, 'key': jrd.PRNGKey(0), 'eta': 1e-4}, }, {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}}, { @@ -573,7 +573,7 @@ def zakharov(x, xnp): class LBFGSTest(chex.TestCase): def test_plain_preconditioning(self): - key = jrd.key(0) + key = jrd.PRNGKey(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -592,7 +592,7 @@ def test_plain_preconditioning(self): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): - key = jrd.key(0) + key = jrd.PRNGKey(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -619,7 +619,7 @@ def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_trees(self, idx: int): - key = jrd.key(0) + key = jrd.PRNGKey(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 shapes = ((3, 2), (5,)) @@ -721,7 +721,7 @@ def fun_(x): def fun(x): return otu.tree_sum(jax.tree.map(fun_, x)) - key = jrd.key(0) + key = jrd.PRNGKey(0) init_array = jrd.normal(key, (2, 4)) init_tree = (init_array[0], init_array[1]) diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py index 1eadc1c09..4de38007d 100644 --- a/optax/_src/float64_test.py +++ b/optax/_src/float64_test.py @@ -48,14 +48,14 @@ {'step_size_fn': lambda x: x * 0.1}, ), ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), - ('add_noise', transform.add_noise, {'eta': 1.0, 'gamma': 0.1, 'seed': 42}), + ('add_noise', transform.add_noise, {'key': jax.random.PRNGKey(42), 'eta': 1.0, 'gamma': 0.1}), ('apply_every_k', transform.apply_every, {}), ('adagrad', alias.adagrad, {'learning_rate': 0.1}), ('adam', alias.adam, {'learning_rate': 0.1}), ('adamw', alias.adamw, {'learning_rate': 0.1}), ('fromage', alias.fromage, {'learning_rate': 0.1}), ('lamb', alias.lamb, {'learning_rate': 0.1}), - ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1}), + ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1, 'key': jax.random.PRNGKey(0)}), ('rmsprop', alias.rmsprop, {'learning_rate': 0.1}), ('sgd', alias.sgd, {'learning_rate': 0.1}), ('sign_sgd', alias.sgd, {'learning_rate': 0.1}), diff --git a/optax/assignment/_hungarian_algorithm_test.py 
b/optax/assignment/_hungarian_algorithm_test.py index 7cf3aeb59..a724d9412 100644 --- a/optax/assignment/_hungarian_algorithm_test.py +++ b/optax/assignment/_hungarian_algorithm_test.py @@ -30,7 +30,7 @@ class HungarianAlgorithmTest(parameterized.TestCase): m=[0, 1, 2, 4, 8, 16], ) def test_hungarian_algorithm(self, n, m): - key = jrd.key(0) + key = jrd.PRNGKey(0) costs = jrd.normal(key, (n, m)) i, j = _hungarian_algorithm.hungarian_algorithm(costs) @@ -91,7 +91,7 @@ def test_hungarian_algorithm(self, n, m): m=[0, 1, 2, 4], ) def test_hungarian_algorithm_vmap(self, k, n, m): - key = jrd.key(0) + key = jrd.PRNGKey(0) costs = jrd.normal(key, (k, n, m)) with self.subTest('works under vmap'): @@ -106,7 +106,7 @@ def test_hungarian_algorithm_vmap(self, k, n, m): assert j.shape == (k, r) def test_hungarian_algorithm_jit(self): - key = jrd.key(0) + key = jrd.PRNGKey(0) costs = jrd.normal(key, (20, 30)) with self.subTest('works under jit'): diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index 0c41eac53..4b7c6e413 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -101,7 +101,11 @@ {'opt_name': 'lion', 'opt_kwargs': {'learning_rate': 1.0, 'b1': 0.99}}, { 'opt_name': 'noisy_sgd', - 'opt_kwargs': {'learning_rate': 1.0, 'eta': 1e-4}, + 'opt_kwargs': { + 'learning_rate': 1.0, + 'key': jax.random.PRNGKey(0), + 'eta': 1e-4 + }, }, {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1.0}}, { diff --git a/optax/contrib/_privacy.py b/optax/contrib/_privacy.py index a6edcae3e..50d8f746a 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -62,7 +62,7 @@ def differentially_private_aggregate( if key is None: raise ValueError( "differentially_private_aggregate optimizer requires specifying key: " - "differentially_private_aggregate(..., key=jax.random.key(0))" + "differentially_private_aggregate(..., key=jax.random.PRNGKey(0))" ) noise_std = l2_norm_clip * noise_multiplier @@ -128,7 +128,7 @@ def dpsgd( if key is None: raise ValueError( "dpsgd optimizer requires specifying key: " - "dpsgd(..., key=jax.random.key(0))" + "dpsgd(..., key=jax.random.PRNGKey(0))" ) return combine.chain( differentially_private_aggregate( diff --git a/optax/contrib/_privacy_test.py b/optax/contrib/_privacy_test.py index f00b6eed3..b2719ff04 100644 --- a/optax/contrib/_privacy_test.py +++ b/optax/contrib/_privacy_test.py @@ -48,7 +48,7 @@ def test_no_privacy(self): dp_agg = _privacy.differentially_private_aggregate( l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, - key=jrd.key(0) + key=jrd.PRNGKey(0) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -64,7 +64,7 @@ def test_clipping_norm(self, l2_norm_clip): dp_agg = _privacy.differentially_private_aggregate( l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, - key=jrd.key(42) + key=jrd.PRNGKey(42) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -94,7 +94,7 @@ def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): dp_agg = _privacy.differentially_private_aggregate( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, - key=jrd.key(1337) + key=jrd.PRNGKey(1337) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -112,7 +112,7 @@ def test_aggregated_updates_as_input_fails(self): dp_agg = _privacy.differentially_private_aggregate( l2_norm_clip=0.1, noise_multiplier=1.1, - key=jrd.key(2021) + key=jrd.PRNGKey(2021) ) state = dp_agg.init(self.params) mean_grads = jax.tree.map(lambda g: 
g.mean(0), self.per_eg_grads) diff --git a/optax/monte_carlo/control_variates.py b/optax/monte_carlo/control_variates.py index 9cfd9f649..6c6421257 100644 --- a/optax/monte_carlo/control_variates.py +++ b/optax/monte_carlo/control_variates.py @@ -310,7 +310,7 @@ def control_variates_jacobians( # gradient. # The rng has to be the same as passed to the grad_estimator above so that we # obtain the same samples. - samples = dist_builder(*params).sample((num_samples,), seed=rng) + samples = dist_builder(*params).sample((num_samples,), key=rng) # If the CV has state, update it. control_variate_state = update_state_cv( params, samples, control_variate_state diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index 541b697a3..50fd86ee1 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -85,7 +85,7 @@ def score_function_jacobians( def surrogate(params): dist = dist_builder(*params) one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x) - samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng)) + samples = jax.lax.stop_gradient(dist.sample((num_samples,), key=rng)) # We vmap the function application over samples - this ensures that the # function we use does not have to be vectorized itself. return jax.vmap(one_sample_surrogate_fn)(samples) @@ -141,7 +141,7 @@ def surrogate(params): # We vmap the function application over samples - this ensures that the # function we use does not have to be vectorized itself. dist = dist_builder(*params) - return jax.vmap(function)(dist.sample((num_samples,), seed=rng)) + return jax.vmap(function)(dist.sample((num_samples,), key=rng)) return jax.jacfwd(surrogate)(params) @@ -239,7 +239,7 @@ def measure_valued_estimation_mean( mean, log_std = dist.params std = jnp.exp(log_std) - dist_samples = dist.sample((num_samples,), seed=rng) + dist_samples = dist.sample((num_samples,), key=rng) pos_rng, neg_rng = jax.random.split(rng) pos_sample = jax.random.weibull_min( @@ -312,7 +312,7 @@ def measure_valued_estimation_std( mean, log_std = dist.params std = jnp.exp(log_std) - dist_samples = dist.sample((num_samples,), seed=rng) + dist_samples = dist.sample((num_samples,), key=rng) pos_rng, neg_rng = jax.random.split(rng) diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index ad85e0cd6..a54501c6e 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -99,7 +99,7 @@ def test_constant_function(self, estimator, constant): effective_log_scale = 0.0 log_scale = effective_log_scale * _ones(data_dims) - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) jacobians = _estimator_variant(self.variant, estimator)( lambda x: jnp.array(constant), @@ -144,7 +144,7 @@ def test_linear_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -185,7 +185,7 @@ def test_quadratic_function( ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -233,7 +233,7 @@ def test_weighted_linear( self, estimator, effective_mean, effective_log_scale, 
weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) mean = jnp.array(effective_mean) log_scale = jnp.array(effective_log_scale) @@ -280,7 +280,7 @@ def test_weighted_quadratic( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -342,7 +342,7 @@ def testNonPolynomialFunctionConsistencyWithPathwise( self, effective_mean, effective_log_scale, function, coupling ): num_samples = 10**5 - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) measure_rng, pathwise_rng = jax.random.split(rng) mean = jnp.array(effective_mean, dtype=jnp.float32) @@ -405,7 +405,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase): @parameterized.parameters([True, False]) def test_raises_error_for_non_gaussian(self, coupling): num_samples = 10**5 - rng = jax.random.key(1) + rng = jax.random.PRNGKey(1) function = lambda x: jnp.sum(x) ** 2 diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 168e4520c..ad282b88d 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -341,7 +341,8 @@ def tree_get( >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.chain( - ... optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam() + ... optax.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), + ... optax.scale_by_adam() ... ) >>> state = opt.init(params) >>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState') @@ -354,7 +355,8 @@ def tree_get( >>> import optax >>> params = jnp.array([1., 2., 3.]) >>> opt = optax.chain( - ... optax.add_noise(1.0, 0.9, 0), optax.scale_by_adam() + ... optax.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), + ... optax.scale_by_adam() ... 
  )
     >>> state = opt.init(params)
     >>> filtering = (
diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py
index 069ac5dc5..a7cd56980 100644
--- a/optax/tree_utils/_state_utils_test.py
+++ b/optax/tree_utils/_state_utils_test.py
@@ -417,7 +417,7 @@ def get_learning_rate(state):

     with self.subTest('Test filtering for specific state'):
       opt = combine.chain(
-          transform.add_noise(1.0, 0.9, 0), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
       )
       state = opt.init(params)

@@ -432,7 +432,7 @@ def filtering(path, _):

     with self.subTest('Test extracting a state'):
       opt = combine.chain(
-          transform.add_noise(1.0, 0.9, 0), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
       )
       state = opt.init(params)
       noise_state = _state_utils.tree_get(state, 'AddNoiseState')
@@ -534,7 +534,7 @@ def set_learning_rate(state, lr):

     with self.subTest('Test setting a specific state'):
       opt = combine.chain(
-          transform.add_noise(1.0, 0.9, 0), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
       )
       state = opt.init(params)

@@ -560,7 +560,8 @@ def filtering(path, _):

     with self.subTest('Test setting a state'):
       opt = combine.chain(
-          transform.add_noise(1.0, 0.9, 0), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9),
+          transform.scale_by_adam()
       )
       state = opt.init(params)
       new_noise_state = transform.AddNoiseState(

From 7500da2f5a2bb9a6a1815631302111d1f375b940 Mon Sep 17 00:00:00 2001
From: Artyom Iudin
Date: Mon, 6 Jan 2025 15:47:17 +0300
Subject: [PATCH 08/11] changed dict layout

---
 optax/_src/float64_test.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py
index 4de38007d..242ce2c9e 100644
--- a/optax/_src/float64_test.py
+++ b/optax/_src/float64_test.py
@@ -48,14 +48,22 @@
         {'step_size_fn': lambda x: x * 0.1},
     ),
     ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
-    ('add_noise', transform.add_noise, {'key': jax.random.PRNGKey(42), 'eta': 1.0, 'gamma': 0.1}),
+    (
+        'add_noise',
+        transform.add_noise,
+        {'key': jax.random.PRNGKey(42), 'eta': 1.0, 'gamma': 0.1}
+    ),
     ('apply_every_k', transform.apply_every, {}),
     ('adagrad', alias.adagrad, {'learning_rate': 0.1}),
     ('adam', alias.adam, {'learning_rate': 0.1}),
     ('adamw', alias.adamw, {'learning_rate': 0.1}),
     ('fromage', alias.fromage, {'learning_rate': 0.1}),
     ('lamb', alias.lamb, {'learning_rate': 0.1}),
-    ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1, 'key': jax.random.PRNGKey(0)}),
+    (
+        'noisy_sgd',
+        alias.noisy_sgd,
+        {'learning_rate': 0.1, 'key': jax.random.PRNGKey(0)}
+    ),
     ('rmsprop', alias.rmsprop, {'learning_rate': 0.1}),
     ('sgd', alias.sgd, {'learning_rate': 0.1}),
     ('sign_sgd', alias.sgd, {'learning_rate': 0.1}),

From 521a013d4f17653a4dd1b1854fea66e739ad9608 Mon Sep 17 00:00:00 2001
From: Artyom Iudin
Date: Mon, 6 Jan 2025 15:51:35 +0300
Subject: [PATCH 09/11] shortened lines

---
 optax/_src/alias_test.py              | 6 +++++-
 optax/tree_utils/_state_utils_test.py | 6 ++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 43087bb96..66162a9f4 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -66,7 +66,11 @@
     {'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}},
     {
         'opt_name': 'noisy_sgd',
-        'opt_kwargs': {'learning_rate': 1e-3, 'key': jrd.PRNGKey(0), 'eta': 1e-4},
+        'opt_kwargs': {
+            'learning_rate': 1e-3,
+            'key': jrd.PRNGKey(0),
+            'eta': 1e-4
+        },
     },
     {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}},
     {
diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py
index a7cd56980..427d66104 100644
--- a/optax/tree_utils/_state_utils_test.py
+++ b/optax/tree_utils/_state_utils_test.py
@@ -417,7 +417,8 @@ def get_learning_rate(state):

     with self.subTest('Test filtering for specific state'):
       opt = combine.chain(
-          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9),
+          transform.scale_by_adam()
       )
       state = opt.init(params)

@@ -432,7 +433,8 @@ def filtering(path, _):

     with self.subTest('Test extracting a state'):
       opt = combine.chain(
-          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9),
+          transform.scale_by_adam()
       )
       state = opt.init(params)
       noise_state = _state_utils.tree_get(state, 'AddNoiseState')

From 4ef4820e7bd3019363bbe98d87b1a73bdedb50dc Mon Sep 17 00:00:00 2001
From: Artyom Iudin
Date: Mon, 6 Jan 2025 15:53:11 +0300
Subject: [PATCH 10/11] shortened lines

---
 optax/tree_utils/_state_utils_test.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py
index 427d66104..8ae2e5dd4 100644
--- a/optax/tree_utils/_state_utils_test.py
+++ b/optax/tree_utils/_state_utils_test.py
@@ -536,7 +536,8 @@ def set_learning_rate(state, lr):

     with self.subTest('Test setting a specific state'):
       opt = combine.chain(
-          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9), transform.scale_by_adam()
+          transform.add_noise(jax.random.PRNGKey(0), 1.0, 0.9),
+          transform.scale_by_adam()
       )
       state = opt.init(params)


From f62e115b1439257d8e9a529bf810afae0238a141 Mon Sep 17 00:00:00 2001
From: Artyom Iudin
Date: Mon, 6 Jan 2025 16:13:03 +0300
Subject: [PATCH 11/11] fixed keyword argument in noisy_sgd

---
 optax/_src/alias.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index b8b84634b..4f6ea9c66 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -1298,7 +1298,9 @@ def noisy_sgd(
     >>> import jax
     >>> import jax.numpy as jnp
     >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
-    >>> solver = optax.noisy_sgd(learning_rate=0.003, jax.random.PRNGKey(0))
+    >>> solver = optax.noisy_sgd(
+    ...     learning_rate=0.003,
+    ...     key=jax.random.PRNGKey(0))
     >>> params = jnp.array([1., 2., 3.])
     >>> print('Objective function: ', f(params))
     Objective function: 14.0
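
Usage sketch (illustrative only, not part of the patch series; assumes the patches above are applied): after this change the noise-related transforms take an explicit PRNG key instead of an integer seed, e.g. optax.noisy_sgd(..., key=...), optax.add_noise(key, eta, gamma), and the DP-SGD helpers. A minimal end-to-end example mirroring the docstring above:

  import jax
  import jax.numpy as jnp
  import optax

  key = jax.random.PRNGKey(0)  # explicit key replaces the old integer seed argument
  solver = optax.noisy_sgd(learning_rate=0.003, key=key, eta=0.01, gamma=0.55)
  params = jnp.array([1., 2., 3.])
  opt_state = solver.init(params)
  grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)  # gradient of a simple quadratic
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)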