Skip to content

Commit

Permalink
Remove usage of jax2tf.PolyShape in haiku.
Browse files Browse the repository at this point in the history
jax2tf.PolyShape has been deprecated since January 2024.
Instead, we can use simple strings.

PiperOrigin-RevId: 707033476
  • Loading branch information
gnecula authored and copybara-github committed Dec 17, 2024
1 parent b3541fa commit fe400bb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion haiku/_src/reshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def f(inputs):
# We convert `f` using `jax2tf` with undefined shape
converted_f = jax2tf.convert(
apply_fn,
polymorphic_shapes=[None, None, jax2tf.PolyShape("_", "T", ...)], # pytype: disable=wrong-arg-count
polymorphic_shapes=[None, None, "_, T, ..."], # pytype: disable=wrong-arg-count
with_gradient=True,
)

Expand Down

0 comments on commit fe400bb

Please sign in to comment.