From fe400bb2bf4ba1dbe6ea4353fd0002500eec243c Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 17 Dec 2024 03:58:40 -0800 Subject: [PATCH] Remove usage of jax2tf.PolyShape in haiku. jax2tf.PolyShape has been deprecated since January 2024. Instead, we can use simple strings. PiperOrigin-RevId: 707033476 --- haiku/_src/reshape_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haiku/_src/reshape_test.py b/haiku/_src/reshape_test.py index d71a9017a..a91c01aaf 100644 --- a/haiku/_src/reshape_test.py +++ b/haiku/_src/reshape_test.py @@ -87,7 +87,7 @@ def f(inputs): # We convert `f` using `jax2tf` with undefined shape converted_f = jax2tf.convert( apply_fn, - polymorphic_shapes=[None, None, jax2tf.PolyShape("_", "T", ...)], # pytype: disable=wrong-arg-count + polymorphic_shapes=[None, None, "_, T, ..."], # pytype: disable=wrong-arg-count with_gradient=True, )