From 9cc31259718cb1c57ec5d13175485c1f6d9af223 Mon Sep 17 00:00:00 2001
From: Robert Dyro
Date: Thu, 2 Jan 2025 11:21:03 +0100
Subject: [PATCH] axis should be int or None

---
 optax/losses/_classification.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 087e1ae4f..694a84f40 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -273,7 +273,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-    axis: Union[int, tuple[int, ...], None] = -1,
+    axis: Union[int, None] = -1,
     where: Union[chex.Array, None] = None,
 ) -> chex.Array:
   r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +297,7 @@ def softmax_cross_entropy_with_integer_labels(
     labels: Integers specifying the correct class for each input, with shape
       ``[batch_size]``. Class labels are assumed to be between 0 and
       ``num_classes - 1`` inclusive.
-    axis: Axis or axes along which to compute.
+    axis: Axis along which to compute.
     where: Elements to include in the computation.

   Returns:
@@ -329,6 +329,9 @@ def softmax_cross_entropy_with_integer_labels(
   """
   chex.assert_type([logits], float)
   chex.assert_type([labels], int)
+  if axis is not None and not isinstance(axis, int):
+    raise ValueError(f'axis = {axis} is unsupported. Provide an int or None.')
+
   # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
   # we avoid subtracting the normalizer from all values, just from the values
   # for the correct labels.
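
For context, a minimal usage sketch of the behavior this patch enforces. It assumes an
optax build that includes this change and that the function is reachable via the
optax.losses namespace; the tensors and values are illustrative only:

    import jax.numpy as jnp
    import optax

    logits = jnp.array([[1.0, 2.0, 0.5],
                        [0.2, 0.1, 3.0]])  # shape [batch_size=2, num_classes=3]
    labels = jnp.array([1, 2])             # integer class ids in [0, num_classes)

    # An int axis (the default, -1) remains supported.
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(
        logits, labels, axis=-1
    )

    # After this patch, a tuple axis is rejected up front with a clear error,
    # rather than being suggested by the old type annotation and failing later.
    try:
      optax.losses.softmax_cross_entropy_with_integer_labels(
          logits, labels, axis=(0, 1)
      )
    except ValueError as e:
      print(e)  # axis = (0, 1) is unsupported. Provide an int or None.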