diff --git a/tf2jax/_src/tf2jax.py b/tf2jax/_src/tf2jax.py index 34f60fc..6db45f5 100644 --- a/tf2jax/_src/tf2jax.py +++ b/tf2jax/_src/tf2jax.py @@ -1405,7 +1405,7 @@ def _convert_gradient_function( library[grad_fn_name] = None return - @tf.function + @tf.function(autograph=False) def tf_grad_fn(*grad_args, **grad_kwargs): fn = tf_ops.gradient_registry.lookup(grad_fn_name) return fn(None, *grad_args, **grad_kwargs)