From 268b7dc2a1a570f596c6708a89dab48d9baea810 Mon Sep 17 00:00:00 2001 From: Shaobo Hou Date: Mon, 4 Dec 2023 09:46:19 -0800 Subject: [PATCH] Enforce autograph=False on gradient function. PiperOrigin-RevId: 587754205 --- tf2jax/_src/tf2jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2jax/_src/tf2jax.py b/tf2jax/_src/tf2jax.py index 34f60fc..6db45f5 100644 --- a/tf2jax/_src/tf2jax.py +++ b/tf2jax/_src/tf2jax.py @@ -1405,7 +1405,7 @@ def _convert_gradient_function( library[grad_fn_name] = None return - @tf.function + @tf.function(autograph=False) def tf_grad_fn(*grad_args, **grad_kwargs): fn = tf_ops.gradient_registry.lookup(grad_fn_name) return fn(None, *grad_args, **grad_kwargs)