replaced seed with key #1167

Open · wants to merge 11 commits into main
18 changes: 13 additions & 5 deletions optax/_src/alias.py
@@ -26,6 +26,7 @@
from optax._src import linesearch as _linesearch
from optax._src import transform
from optax._src import wrappers
import chex


MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
@@ -1254,9 +1255,9 @@ def lamb(

def noisy_sgd(
learning_rate: base.ScalarOrSchedule,
key: Optional[chex.PRNGKey] = None,
eta: float = 0.01,
gamma: float = 0.55,
seed: int = 0,
) -> base.GradientTransformation:
r"""A variant of SGD with added noise.

@@ -1284,10 +1285,10 @@ def noisy_sgd(
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
key: A PRNG key used to sample the gradient noise.
eta: Initial variance for the Gaussian noise added to gradients.
gamma: A parameter controlling the annealing of noise over time ``t``, the
variance decays according to ``(1+t)**(-gamma)``.
seed: Seed for the pseudo-random generation process.

Returns:
The corresponding :class:`optax.GradientTransformation`.
@@ -1297,7 +1298,9 @@
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> solver = optax.noisy_sgd(
... learning_rate=0.003,
... key=jax.random.PRNGKey(0))
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
@@ -1317,8 +1320,13 @@
Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep
Networks <https://arxiv.org/abs/1511.06807>`_, 2015
"""
if key is None:
raise ValueError(
"noisy_sgd optimizer requires specifying key: "
"noisy_sgd(..., key=jax.random.PRNGKey(0))"
)
return combine.chain(
transform.add_noise(eta, gamma, seed),
transform.add_noise(key, eta, gamma),
transform.scale_by_learning_rate(learning_rate),
)

@@ -2394,7 +2402,7 @@ def lbfgs(
linesearch: Optional[
base.GradientTransformationExtraArgs
] = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=20, initial_guess_strategy='one'
max_linesearch_steps=20, initial_guess_strategy="one"
),
) -> base.GradientTransformationExtraArgs:
r"""L-BFGS optimizer.
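For reviewers trying the branch locally, a minimal usage sketch of the new `noisy_sgd` signature, adapted from the doctest above (a sketch only, assuming this PR's branch is installed and that `key` is required as shown):

```python
import jax
import jax.numpy as jnp
import optax  # assumes this PR's branch of optax is installed

def f(x):
  return jnp.sum(x ** 2)  # simple quadratic objective, as in the docstring above

# The PRNG key is now passed explicitly instead of an integer seed.
solver = optax.noisy_sgd(learning_rate=0.003, key=jax.random.PRNGKey(0))
params = jnp.array([1., 2., 3.])
opt_state = solver.init(params)

for _ in range(5):
  grads = jax.grad(f)(params)
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
```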
6 changes: 5 additions & 1 deletion optax/_src/alias_test.py
@@ -66,7 +66,11 @@
{'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}},
{
'opt_name': 'noisy_sgd',
'opt_kwargs': {'learning_rate': 1e-3, 'eta': 1e-4},
'opt_kwargs': {
'learning_rate': 1e-3,
'key': jrd.PRNGKey(0),
'eta': 1e-4
},
},
{'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}},
{
12 changes: 10 additions & 2 deletions optax/_src/float64_test.py
@@ -48,14 +48,22 @@
{'step_size_fn': lambda x: x * 0.1},
),
('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
('add_noise', transform.add_noise, {'eta': 1.0, 'gamma': 0.1, 'seed': 42}),
(
'add_noise',
transform.add_noise,
{'key': jax.random.PRNGKey(42), 'eta': 1.0, 'gamma': 0.1}
),
('apply_every_k', transform.apply_every, {}),
('adagrad', alias.adagrad, {'learning_rate': 0.1}),
('adam', alias.adam, {'learning_rate': 0.1}),
('adamw', alias.adamw, {'learning_rate': 0.1}),
('fromage', alias.fromage, {'learning_rate': 0.1}),
('lamb', alias.lamb, {'learning_rate': 0.1}),
('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1}),
(
'noisy_sgd',
alias.noisy_sgd,
{'learning_rate': 0.1, 'key': jax.random.PRNGKey(0)}
),
('rmsprop', alias.rmsprop, {'learning_rate': 0.1}),
('sgd', alias.sgd, {'learning_rate': 0.1}),
('sign_sgd', alias.sgd, {'learning_rate': 0.1}),
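On this branch `add_noise` appears to take the key first, then `eta` and `gamma` (see the `transform.add_noise(key, eta, gamma)` call in `alias.py` above and the keyword arguments in this test). A hedged sketch of using the transform directly under that signature:

```python
import jax
import jax.numpy as jnp
import optax  # assumes this PR's branch of optax is installed

# Key first, then eta and gamma, matching the call in alias.py above.
add_noise = optax.add_noise(jax.random.PRNGKey(42), eta=1.0, gamma=0.1)

params = {'w': jnp.zeros(3)}
state = add_noise.init(params)
grads = {'w': jnp.ones(3)}

# Each update adds Gaussian noise whose variance decays as (1 + t) ** -gamma.
noisy_updates, state = add_noise.update(grads, state)
```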
4 changes: 2 additions & 2 deletions optax/_src/utils.py
@@ -109,10 +109,10 @@ def __init__(self, loc: chex.Array, log_scale: chex.Array):
self._mean.shape, self._scale.shape
)

def sample(self, shape: Sequence[int], seed: chex.PRNGKey) -> chex.Array:
def sample(self, shape: Sequence[int], key: chex.PRNGKey) -> chex.Array:
sample_shape = tuple(shape) + self._param_shape
return (
jax.random.normal(seed, shape=sample_shape) * self._scale + self._mean
jax.random.normal(key, shape=sample_shape) * self._scale + self._mean
)

def log_prob(self, x: chex.Array) -> chex.Array:
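After this rename, a distribution passed to the Monte Carlo estimators further down is expected to expose `sample(shape, key=...)` instead of `seed=...`. A minimal illustrative stub of that interface (the class name and the maths here are illustrative only, not part of the PR):

```python
import jax
import jax.numpy as jnp

class DiagNormal:
  """Toy diagonal Gaussian exposing the sample(shape, key=...) interface."""

  def __init__(self, mean, log_scale):
    self._mean = mean
    self._scale = jnp.exp(log_scale)

  def sample(self, shape, key):
    sample_shape = tuple(shape) + self._mean.shape
    return jax.random.normal(key, shape=sample_shape) * self._scale + self._mean

  def log_prob(self, x):
    z = (x - self._mean) / self._scale
    return jnp.sum(-0.5 * z**2 - jnp.log(self._scale) - 0.5 * jnp.log(2 * jnp.pi), -1)

samples = DiagNormal(jnp.zeros(2), jnp.zeros(2)).sample((4,), key=jax.random.PRNGKey(0))
```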
6 changes: 3 additions & 3 deletions optax/assignment/_hungarian_algorithm_test.py
@@ -30,7 +30,7 @@ class HungarianAlgorithmTest(parameterized.TestCase):
m=[0, 1, 2, 4, 8, 16],
)
def test_hungarian_algorithm(self, n, m):
key = jrd.key(0)
key = jrd.PRNGKey(0)
costs = jrd.normal(key, (n, m))

i, j = _hungarian_algorithm.hungarian_algorithm(costs)
@@ -91,7 +91,7 @@ def test_hungarian_algorithm(self, n, m):
m=[0, 1, 2, 4],
)
def test_hungarian_algorithm_vmap(self, k, n, m):
key = jrd.key(0)
key = jrd.PRNGKey(0)
costs = jrd.normal(key, (k, n, m))

with self.subTest('works under vmap'):
@@ -106,7 +106,7 @@ def test_hungarian_algorithm_vmap(self, k, n, m):
assert j.shape == (k, r)

def test_hungarian_algorithm_jit(self):
key = jrd.key(0)
key = jrd.PRNGKey(0)
costs = jrd.normal(key, (20, 30))

with self.subTest('works under jit'):
6 changes: 5 additions & 1 deletion optax/contrib/_common_test.py
@@ -101,7 +101,11 @@
{'opt_name': 'lion', 'opt_kwargs': {'learning_rate': 1.0, 'b1': 0.99}},
{
'opt_name': 'noisy_sgd',
'opt_kwargs': {'learning_rate': 1.0, 'eta': 1e-4},
'opt_kwargs': {
'learning_rate': 1.0,
'key': jax.random.PRNGKey(0),
'eta': 1e-4
},
},
{'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1.0}},
{
25 changes: 19 additions & 6 deletions optax/contrib/_privacy.py
@@ -17,6 +17,7 @@
from typing import Any, NamedTuple, Optional

import jax
import chex
from optax._src import base
from optax._src import clipping
from optax._src import combine
@@ -33,14 +34,16 @@ class DifferentiallyPrivateAggregateState(NamedTuple):


def differentially_private_aggregate(
l2_norm_clip: float, noise_multiplier: float, seed: int
l2_norm_clip: float,
noise_multiplier: float,
key: Optional[chex.PRNGKey] = None
) -> base.GradientTransformation:
"""Aggregates gradients based on the DPSGD algorithm.

Args:
l2_norm_clip: maximum L2 norm of the per-example gradients.
noise_multiplier: ratio of standard deviation to the clipping norm.
seed: initial seed used for the jax.random.PRNGKey
key: a PRNG key used to sample the added noise.

Returns:
A :class:`optax.GradientTransformation`.
@@ -56,11 +59,16 @@ def differentially_private_aggregate(
JAX using `jax.vmap`). It can still be composed with other transformations
as long as it is the first in the chain.
"""
if key is None:
raise ValueError(
"differentially_private_aggregate optimizer requires specifying key: "
"differentially_private_aggregate(..., key=jax.random.PRNGKey(0))"
)
noise_std = l2_norm_clip * noise_multiplier

def init_fn(params):
del params
return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed))
return DifferentiallyPrivateAggregateState(rng_key=key)

def update_fn(updates, state, params=None):
del params
@@ -85,7 +93,7 @@ def dpsgd(
learning_rate: base.ScalarOrSchedule,
l2_norm_clip: float,
noise_multiplier: float,
seed: int,
key: Optional[chex.PRNGKey] = None,
momentum: Optional[float] = None,
nesterov: bool = False,
) -> base.GradientTransformation:
@@ -100,7 +108,7 @@
learning_rate: A fixed global scaling factor.
l2_norm_clip: Maximum L2 norm of the per-example gradients.
noise_multiplier: Ratio of standard deviation to the clipping norm.
seed: Initial seed used for the jax.random.PRNGKey
key: A PRNG key used to sample the added noise.
momentum: Decay rate used by the momentum term, when it is set to `None`,
then momentum is not used at all.
nesterov: Whether Nesterov momentum is used.
@@ -117,11 +125,16 @@
batch dimension on the 0th axis. That is, this function expects per-example
gradients as input (which are easy to obtain in JAX using `jax.vmap`).
"""
if key is None:
raise ValueError(
"dpsgd optimizer requires specifying key: "
"dpsgd(..., key=jax.random.PRNGKey(0))"
)
return combine.chain(
differentially_private_aggregate(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
seed=seed,
key=key,
),
(
transform.trace(decay=momentum, nesterov=nesterov)
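A hedged sketch of calling the updated `dpsgd` entry point under this branch (import path `optax.contrib` assumed); per-example gradients are still expected, only the seed argument becomes a key:

```python
import jax
import jax.numpy as jnp
import optax  # assumes this PR's branch of optax is installed

optimizer = optax.contrib.dpsgd(
    learning_rate=0.1,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    key=jax.random.PRNGKey(0),  # replaces the old integer `seed`
    momentum=0.9,
)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

# dpsgd consumes per-example gradients: the leading axis is the batch dimension.
per_example_grads = {'w': jnp.ones((8, 3))}
updates, opt_state = optimizer.update(per_example_grads, opt_state, params)
params = optax.apply_updates(params, updates)
```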
17 changes: 13 additions & 4 deletions optax/contrib/_privacy_test.py
@@ -19,6 +19,7 @@
import chex
import jax
import jax.numpy as jnp
import jax.random as jrd
from optax.contrib import _privacy


@@ -45,7 +46,9 @@ def setUp(self):
def test_no_privacy(self):
"""l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0
l2_norm_clip=jnp.finfo(jnp.float32).max,
noise_multiplier=0.0,
key=jrd.PRNGKey(0)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -59,7 +62,9 @@ def test_no_privacy(self):
@parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0)
def test_clipping_norm(self, l2_norm_clip):
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
key=jrd.PRNGKey(42)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -87,7 +92,9 @@ def test_clipping_norm(self, l2_norm_clip):
def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
"""Standard dev. of noise should be l2_norm_clip * noise_multiplier."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, seed=1337
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
key=jrd.PRNGKey(1337)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -103,7 +110,9 @@ def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
def test_aggregated_updates_as_input_fails(self):
"""Expect per-example gradients as input to this transform."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=0.1, noise_multiplier=1.1, seed=2021
l2_norm_clip=0.1,
noise_multiplier=1.1,
key=jrd.PRNGKey(2021)
)
state = dp_agg.init(self.params)
mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads)
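The same pattern works for the aggregation transform on its own; a short sketch mirroring the tests above (again a sketch under the assumption that this branch is installed and exports the transform from `optax.contrib`):

```python
import jax
import jax.numpy as jnp
import optax  # assumes this PR's branch of optax is installed

dp_agg = optax.contrib.differentially_private_aggregate(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    key=jax.random.PRNGKey(2021),  # replaces the old integer `seed`
)

params = {'w': jnp.zeros(3)}
state = dp_agg.init(params)

# Per-example gradients in, a clipped and noised mean gradient out.
per_example_grads = {'w': jnp.ones((16, 3))}
mean_noisy_grads, state = dp_agg.update(per_example_grads, state)
```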
2 changes: 1 addition & 1 deletion optax/monte_carlo/control_variates.py
@@ -310,7 +310,7 @@ def control_variates_jacobians(
# gradient.
# The rng has to be the same as passed to the grad_estimator above so that we
# obtain the same samples.
samples = dist_builder(*params).sample((num_samples,), seed=rng)
samples = dist_builder(*params).sample((num_samples,), key=rng)
# If the CV has state, update it.
control_variate_state = update_state_cv(
params, samples, control_variate_state
8 changes: 4 additions & 4 deletions optax/monte_carlo/stochastic_gradient_estimators.py
@@ -85,7 +85,7 @@ def score_function_jacobians(
def surrogate(params):
dist = dist_builder(*params)
one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x)
samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng))
samples = jax.lax.stop_gradient(dist.sample((num_samples,), key=rng))
# We vmap the function application over samples - this ensures that the
# function we use does not have to be vectorized itself.
return jax.vmap(one_sample_surrogate_fn)(samples)
@@ -141,7 +141,7 @@ def surrogate(params):
# We vmap the function application over samples - this ensures that the
# function we use does not have to be vectorized itself.
dist = dist_builder(*params)
return jax.vmap(function)(dist.sample((num_samples,), seed=rng))
return jax.vmap(function)(dist.sample((num_samples,), key=rng))

return jax.jacfwd(surrogate)(params)

@@ -239,7 +239,7 @@ def measure_valued_estimation_mean(
mean, log_std = dist.params
std = jnp.exp(log_std)

dist_samples = dist.sample((num_samples,), seed=rng)
dist_samples = dist.sample((num_samples,), key=rng)

pos_rng, neg_rng = jax.random.split(rng)
pos_sample = jax.random.weibull_min(
@@ -312,7 +312,7 @@ def measure_valued_estimation_std(
mean, log_std = dist.params
std = jnp.exp(log_std)

dist_samples = dist.sample((num_samples,), seed=rng)
dist_samples = dist.sample((num_samples,), key=rng)

pos_rng, neg_rng = jax.random.split(rng)

8 changes: 4 additions & 4 deletions optax/perturbations/_make_pert.py
@@ -35,11 +35,11 @@ class Normal:

def sample(
self,
seed: chex.PRNGKey,
key: chex.PRNGKey,
sample_shape: Shape,
dtype: chex.ArrayDType = float,
) -> jax.Array:
return jax.random.normal(seed, sample_shape, dtype)
return jax.random.normal(key, sample_shape, dtype)

def log_prob(self, inputs: jax.Array) -> jax.Array:
return -0.5 * inputs**2
@@ -50,11 +50,11 @@ class Gumbel:

def sample(
self,
seed: chex.PRNGKey,
key: chex.PRNGKey,
sample_shape: Shape,
dtype: chex.ArrayDType = float,
) -> jax.Array:
return jax.random.gumbel(seed, sample_shape, dtype)
return jax.random.gumbel(key, sample_shape, dtype)

def log_prob(self, inputs: jax.Array) -> jax.Array:
return -inputs - jnp.exp(-inputs)
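For the perturbation samplers the call stays positional, only the first parameter is renamed; a tiny sketch (assuming `Gumbel` is re-exported at `optax.perturbations`; otherwise import it from `optax.perturbations._make_pert`):

```python
import jax
from optax.perturbations import Gumbel  # assumed re-export; see note above

noise = Gumbel()
# First argument is now named `key` rather than `seed`; behaviour is unchanged.
samples = noise.sample(jax.random.PRNGKey(0), sample_shape=(4, 3))
log_p = noise.log_prob(samples)
```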