diff --git a/examples/lm1b/train.py b/examples/lm1b/train.py index e2e72df7d7..96b5d0c54c 100644 --- a/examples/lm1b/train.py +++ b/examples/lm1b/train.py @@ -582,7 +582,7 @@ def encode_strings(strs, max_len): ) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary["perplexity"] = jnp.clip( - jnp.exp(summary["loss"]), a_max=1.0e4 + jnp.exp(summary["loss"]), max=1.0e4 ) summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) @@ -598,7 +598,7 @@ def encode_strings(strs, max_len): ) # (clipped) perplexity after averaging log-perplexitie eval_results["perplexity"] = jnp.clip( - jnp.exp(eval_results["loss"]), a_max=1.0e4 + jnp.exp(eval_results["loss"]), max=1.0e4 ) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()} diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/experimental/nnx/examples/lm1b/train.py index 5bd289ac31..ed3f5986d0 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/experimental/nnx/examples/lm1b/train.py @@ -606,7 +606,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): ) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr summary['perplexity'] = jnp.clip( - jnp.exp(summary['loss']), a_max=1.0e4 + jnp.exp(summary['loss']), max=1.0e4 ) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) @@ -621,7 +621,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): ) # (clipped) perplexity after averaging log-perplexitie eval_results['perplexity'] = jnp.clip( - jnp.exp(eval_results['loss']), a_max=1.0e4 + jnp.exp(eval_results['loss']), max=1.0e4 ) writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}