Replace hardcoded class index with logit_class_dim argument #177
Conversation
Still WIP test cases.

Ready to review! Discussion points:

Merged with
- I think @f-dangel mentioned that BackPACK doesn't support multiple output dims, so we have to rearrange to the equivalent 2D output/label. ASDL assumes the class dim is `-1` and flattens everything, so again, we have to rearrange to the equivalent 2D output. The same probably holds for asdfghjkl. In other words, we can't just flatten the output, but have to transpose correctly (probably most readable with `einops.rearrange`). See here for how we handle this in curvlinops, with the assumption that `logit_class_dim=1` for CE loss and `logit_class_dim=-1` for MSE loss. (A minimal sketch of such a rearrangement follows this list.)
- For functorch, it should be possible to also rearrange the logits/labels. For curvlinops, the issue is that we don't have access to the logits, so to make it work here we would have to wrap the forward pass of the model to fix the logit class dim at the end of the forward pass. So maybe we should just raise an informative error, telling the user how to modify the forward pass of the model to be compatible with curvlinops.
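For concreteness, a minimal sketch of such a rearrangement (this is not the PR's code; the helper name and the `movedim`-based variant of the `einops.rearrange` idea are just an illustration):

```python
import torch


def logits_to_2d(logits: torch.Tensor, labels: torch.Tensor, logit_class_dim: int):
    """Move the class dim last and collapse all remaining dims into the batch dim."""
    n_classes = logits.shape[logit_class_dim]
    logits_2d = logits.movedim(logit_class_dim, -1).reshape(-1, n_classes)
    labels_1d = labels.reshape(-1)  # labels carry no class dim, so plain flattening is fine
    return logits_2d, labels_1d
```

E.g. NCHW logits with `logit_class_dim=1` become `(N*H*W, C)` rows with matching flattened labels, which is the 2D layout the backends above expect.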
To make sure we get this right, we have to test for equivalence of computation, see my comment. Also, if we only want `logit_class_dim` to affect classification tasks, we should maybe verify at initialization that `likelihood="classification"` if `logit_class_dim` is not the default value, and raise a `ValueError` otherwise.
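A possible shape for that check, as a sketch only (the attribute names and the default value of `-1` are assumptions, not the library's actual API):

```python
class _LaplaceSketch:
    """Toy stand-in illustrating the suggested init-time validation."""

    def __init__(self, likelihood: str, logit_class_dim: int = -1):
        # A non-default logit_class_dim only makes sense for classification,
        # so reject other likelihoods early with an informative error.
        if logit_class_dim != -1 and likelihood != "classification":
            raise ValueError(
                "logit_class_dim is only supported with likelihood='classification', "
                f"but got likelihood={likelihood!r}."
            )
        self.likelihood = likelihood
        self.logit_class_dim = logit_class_dim
```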
-    dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1))
+    dummy
+    if self.loss_type == LOSS_MSE
+    else dummy.view(-1, dummy.size(self.logit_class_dim))
I don't think flattening will have the intended effect; same for the other changes below. See my main comment.
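A tiny self-contained illustration of the problem (shapes are arbitrary, chosen only for demonstration):

```python
import torch

logits = torch.randn(2, 3, 4, 4)  # (N, C, H, W) with the class dim at position 1

# Plain flattening: each row is just 3 consecutive values from memory,
# not the class logits of one spatial position.
flat_wrong = logits.view(-1, logits.size(1))

# Moving the class dim to the end first gives rows that are genuine class vectors.
flat_right = logits.movedim(1, -1).reshape(-1, logits.size(1))
```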
)
@pytest.mark.parametrize("method", ["full", "kron", "diag"])
@pytest.mark.parametrize("logit_class_dim", [-1, 1000])
def test_logit_class_dim_class(backend_cls, method, logit_class_dim, model, class_Xy):
I think this has to be tested for actual equivalence of the computation, i.e. `pytest.mark.parametrize` the tests of the backends with `logit_class_dim`.
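One way such a parametrized equivalence test could look; this is a heavily hedged sketch: the fixtures, the `full(X, y) -> (loss, H)` convention, and the `logit_class_dim` keyword on the backend constructor are assumptions about this PR, not established API:

```python
import pytest
import torch


@pytest.mark.parametrize("logit_class_dim", [-1, 1])
def test_full_matches_default_layout(backend_cls, model, class_Xy, logit_class_dim):
    X, y = class_Xy
    # Reference: the current behavior with the class dim hardcoded to the last axis.
    loss_ref, H_ref = backend_cls(model, "classification").full(X, y)
    # Same data routed through the new argument; for 2D logits both settings
    # address the same axis, so the results must agree exactly.
    loss_new, H_new = backend_cls(
        model, "classification", logit_class_dim=logit_class_dim
    ).full(X, y)
    torch.testing.assert_close(loss_ref, loss_new)
    torch.testing.assert_close(H_ref, H_new)
```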
It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models. Considering v0.2 is all about LLMs, let's defer this to v0.3!
OK, for now maybe we can add a note in the README and the docstring that clearly states how multi-dim outputs are handled?
So what do you have in mind regarding the wording? Something like this in README.md?

## Caveats

- Currently, this library always assumes that the model has an output tensor of shape `(batch_size, ..., n_classes)`, so in the case of image outputs, you need to rearrange from NCHW to NHWC.
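The caveat could even point to a one-line wrapper; the class below is illustrative only, not part of the library:

```python
from torch import nn


class ChannelsLastOutput(nn.Module):
    """Wrap a model with image-shaped outputs so NCHW logits become NHWC."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)             # (batch_size, n_classes, H, W)
        return out.permute(0, 2, 3, 1)  # (batch_size, H, W, n_classes)
```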
Yes, this is exactly what I was thinking!
Closes #163
Please wait until #144 is merged.