diff --git a/haiku/_src/reshape_test.py b/haiku/_src/reshape_test.py index d71a9017a..a91c01aaf 100644 --- a/haiku/_src/reshape_test.py +++ b/haiku/_src/reshape_test.py @@ -87,7 +87,7 @@ def f(inputs): # We convert `f` using `jax2tf` with undefined shape converted_f = jax2tf.convert( apply_fn, - polymorphic_shapes=[None, None, jax2tf.PolyShape("_", "T", ...)], # pytype: disable=wrong-arg-count + polymorphic_shapes=[None, None, "_, T, ..."], # pytype: disable=wrong-arg-count with_gradient=True, )