According to the docstring of softmax_cross_entropy_with_integer_labels, axis accepts None, int, or tuple[int, ...]. However, the snippet below demonstrates that passing a tuple[int, int] as axis raises an exception.
import jax.numpy as jnp
import optax
xs = jnp.ones((1, 2, 3, 4))
ys = jnp.zeros(xs.shape[:-2], dtype=jnp.int32)
mask = jnp.ones_like(xs, dtype=jnp.bool)
# FAIL: TypeError: 'tuple' object cannot be interpreted as an integer
optax.softmax_cross_entropy_with_integer_labels(xs, ys, (-2, -1), mask)
The reason is that the implementation relies on take_along_axis, which accepts a scalar axis but not a tuple (see optax/optax/losses/_classification.py, lines 335 to 343 in 1e08bcc).
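To make the single-axis restriction concrete, here is a minimal sketch of the gather pattern with an int axis. It is an illustration under assumed shapes and made-up values, not a copy of the optax source.

```python
import jax
import jax.numpy as jnp

# Illustrative data (not from the issue): (batch, seq, num_classes).
logits = jnp.arange(24.0).reshape(2, 3, 4)
labels = jnp.array([[0, 1, 2], [3, 0, 1]])  # class ids along axis=-1

# take_along_axis expects indices of the same rank as logits and ONE int axis,
# so the labels get a trailing singleton dimension before the gather.
label_logits = jnp.take_along_axis(logits, jnp.expand_dims(labels, -1), axis=-1)
label_logits = jnp.squeeze(label_logits, -1)  # back to shape (batch, seq)

# Cross-entropy with integer labels: logsumexp over classes minus the true-class logit.
loss = jax.nn.logsumexp(logits, axis=-1) - label_logits
```

With a tuple such as (-2, -1) there is no single axis along which to gather, which is why the label semantics have to be defined first.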
There are two options from my perspective.
1. The axis parameter of softmax_cross_entropy_with_integer_labels should be either None or int.
2. Fix the bug and implement a "vector" axis parameter. However, the semantics of labels then become unclear. Specifically, if axis is (-2, -1) as in the example above, what shape should labels have, and how does one specify a label for a 2-dimensional slice of logits? The obvious solution is to append another dimension with len(axis) elements to labels (i.e. shape (1, 2, 2) in the example above). Another solution is to treat the elements of labels as flat indices into the 2-dimensional slices of logits.
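The two candidate label encodings for axis=(-2, -1) can be sketched as follows. The label values and the names H, W, multi, and flat are hypothetical, and row-major flattening is an assumption about what "flat indices" would mean.

```python
import jax.numpy as jnp

H, W = 3, 4  # sizes of the two class axes, matching xs.shape[-2:] in the snippet

# Encoding A: an extra trailing dimension of len(axis) == 2 holding (row, col).
multi = jnp.array([[[0, 1], [2, 3]]], dtype=jnp.int32)  # shape (1, 2, 2)

# Encoding B: one flat index into the row-major flattened (H, W) slice.
flat = multi[..., 0] * W + multi[..., 1]                # shape (1, 2)

# Converting back from the flat index to (row, col).
recovered = jnp.stack([flat // W, flat % W], axis=-1)   # shape (1, 2, 2)
```

Both encodings carry the same information. The flat form keeps labels at the batch shape, while the multi-index form mirrors the tuple passed as axis.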
Great catch! I think solution 2. can be rather ambiguous like you're pointing out. Given the nature of the cross-entropy loss it's probably not too much to ask the user to reshape the label axes to a single axis.
I'm working on a fix (using your solution 1.) here: #1164
I'm also leaning towards 1. (restricting the argument to int or None) because one_hot doesn't support axis tuples and this function simulates an explicit one_hot application.
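The equivalence with an explicit one_hot application can be checked with a small sketch. The values are made up, and jax.nn.one_hot is used here as the assumed one_hot in question; its axis argument is a single int.

```python
import jax
import jax.numpy as jnp

# Illustrative data: (batch, num_classes) logits and one class id per row.
logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 1.0, 3.0]])
labels = jnp.array([0, 2])

log_probs = jax.nn.log_softmax(logits, axis=-1)

# Explicit one-hot formulation: the class dimension is placed at a single int axis.
one_hot = jax.nn.one_hot(labels, logits.shape[-1], axis=-1)
loss_one_hot = -jnp.sum(one_hot * log_probs, axis=-1)

# Integer-label formulation: gather the true-class log-probability along that axis.
loss_gather = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
```

The two losses agree, which illustrates why restricting axis to None or int keeps the integer-label function aligned with its one-hot counterpart.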