Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace hardcoded class index with logit_class_dim argument #177

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

wiseodd
Copy link
Collaborator

@wiseodd wiseodd commented Apr 26, 2024

Closes #163

Please wait until #144 is merged.

@wiseodd wiseodd added the enhancement New feature or request label Apr 26, 2024
@wiseodd wiseodd added this to the 0.2 milestone Apr 26, 2024
@wiseodd wiseodd self-assigned this Apr 26, 2024
@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 26, 2024

Still WIP test cases

@wiseodd wiseodd marked this pull request as ready for review April 27, 2024 01:53
@wiseodd wiseodd requested review from runame and aleximmer April 27, 2024 01:54
@wiseodd wiseodd changed the base branch from main to mc-subset2 April 27, 2024 17:25
Base automatically changed from mc-subset2 to main April 27, 2024 18:53
@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 27, 2024

Ready to review!

Discussion points:

  1. Do BackPACK, ASDL, Asdfghjkl even support multiple output dims? That is, if we flatten logits = logits.view(-1, logits.size(logit_class_dim)), do they even compute the correct quantities?
  2. Curvlinops and torch.func interfaces always assume logit_class_dim = 1. Do we want to make them respect logit_class_dim? I don't think flattening as above is the correct approach, right?

@wiseodd
Copy link
Collaborator Author

wiseodd commented Jun 17, 2024

Merged with main and ready for review!

Copy link
Collaborator

@runame runame left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I think @f-dangel mentioned that BackPack doesn't support multiple output dims, so we have to rearrange to the equivalent 2d output/label. ASDL assumes the class dim is -1 and flattens everything, so again, we have to rearrange to the equivalent 2d output. Same probably holds for asdfghjkl. With other words, we can't just flatten the output, but have to transpose correctly (probably most readable with einops.rearrange). See here, for how we handle this in curvlinops with assumption that logit_class_dim=1 for CE loss and logit_class_dim=-1 for MSE loss.
  2. For functorch, it should be possible to also rearrange the logits/labels. For curvlinops, the issue is that we don't have access to the logits, so to make it work here we would have to wrap the forward pass of the model to fix the logit class dim at the end of the forward pass. So maybe we should just raise an informative error, telling the user how to modify the forward pass of the model to be compatible with curvlinops.

To make sure we get this right, we have to test for equivalence of computation, see my comment. Also, if we only want logit_class_dim to affect classification tasks, we should maybe verify at initialization that likelihood="classification" if logit_class_dim is not the default value and raise a ValueError otherwise.

dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1))
dummy
if self.loss_type == LOSS_MSE
else dummy.view(-1, dummy.size(self.logit_class_dim))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think flattening will have the intended effect; same for the other changes below. See my main comment.

)
@pytest.mark.parametrize("method", ["full", "kron", "diag"])
@pytest.mark.parametrize("logit_class_dim", [-1, 1000])
def test_logit_class_dim_class(backend_cls, method, logit_class_dim, model, class_Xy):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this has to be tested for actual equivalence of the computation, i.e. pytest.mark.parametrize the tests of the backends with logit_class_dim.

@wiseodd
Copy link
Collaborator Author

wiseodd commented Jul 4, 2024

It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models.

Considering v0.2 is all about LLMs, let's defer this to v0.3!

@wiseodd wiseodd removed this from the 0.2 milestone Jul 4, 2024
@runame
Copy link
Collaborator

runame commented Jul 4, 2024

It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models.

Considering v0.2 is all about LLMs, let's defer this to v0.3!

Ok, for now maybe we can add a note in the README and the docstring that clearly states how multi dim outputs are handled?

@wiseodd
Copy link
Collaborator Author

wiseodd commented Jul 4, 2024

So what do you have in mind regarding the wording? Something like this in README.md?

## Caveats

- Currently, this library always assumes that the model has an 
  output tensor of shape `(batch_size, ..., n_classes)`, so in 
  the case of image outputs, you need to rearrange from NCHW to NHWC.

@runame
Copy link
Collaborator

runame commented Jul 4, 2024

So what do you have in mind regarding the wording? Something like this in README.md?

Yes, this is exactly what I was thinking!

@wiseodd wiseodd added this to the 0.3 milestone Jul 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Conform to PyTorch convention in the loss
2 participants