From 58bac0ac375732dce358cf85583ae7fe3632b8cf Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 12 Dec 2023 04:33:32 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 590154234 --- tests/perturbations_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/perturbations_test.py b/tests/perturbations_test.py index 5375ec34..66259e27 100644 --- a/tests/perturbations_test.py +++ b/tests/perturbations_test.py @@ -333,7 +333,7 @@ def test_rank_finite_diff(self): delta_num = (sq_loss_plus_h - sq_loss_minus_h) / (2 * eps) delta_lin = jnp.sum(gradient_square_rank * h) - self.assertArraysAllClose(delta_num, delta_lin, atol=5e-2) + self.assertArraysAllClose(delta_num, delta_lin, atol=1e-1, rtol=.5) class PerturbationsMaxTest(test_util.JaxoptTestCase): @@ -571,7 +571,7 @@ def test_noise_iid(self, control_variate): rngs_batch) self.assertArraysAllClose(pert_scalar_repeat[0], pert_scalar_repeat[1], - atol=2e-2) + atol=1e-1, rtol=1e-2) delta_noise = pert_scalar_repeat[0] - pert_scalar_repeat[1] self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0) @@ -752,7 +752,7 @@ def test_noise_iid(self): pert_repeat = jax.vmap(pert_fun)(theta_batch_repeat, rngs_batch) self.assertArraysAllClose(pert_repeat[0], pert_repeat[1], - atol=2e-2) + atol=5e-2, rtol=5e-2) delta_noise = pert_repeat[0] - pert_repeat[1] self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0)