Restructure SSL methods in Bolts #929

Open
Atharva-Phatak opened this issue Oct 31, 2022 · 8 comments
Labels
enhancement New feature or request help wanted Extra attention is needed


@Atharva-Phatak
Contributor

Atharva-Phatak commented Oct 31, 2022

🚀 Feature


Bolts implements several SSL methods, but each one comes with its own custom model and its own LightningModule.

Here's what @senarvi and I have in mind for restructuring the SSL methods in Bolts:

  • A single place for SSL projection heads (SSL models are usually the same except for their projection heads)
  • A single place for SSL losses
  • A single place for transforms as well
  • A general PL module for training any custom SSL model :)
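To make the first point concrete, the heads that would live in that single place can be small, swappable modules. Below is a minimal sketch of a SimCLR-style projection head (Linear, ReLU, Linear). The class name `MLPProjectionHead` and its default dimensions are assumptions for illustration, not an existing Bolts API.

```python
import torch
from torch import nn


class MLPProjectionHead(nn.Module):
    """Hypothetical SimCLR-style projection head: Linear -> ReLU -> Linear.

    Methods that share an encoder would differ mainly in this module.
    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, input_dim) feature vectors produced by the encoder
        return self.net(features)
```

Because the head only sees the encoder's feature vectors, switching methods would mostly mean swapping this module (and the loss) while the encoder stays the same.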
@Atharva-Phatak Atharva-Phatak added enhancement New feature or request help wanted Extra attention is needed labels Oct 31, 2022
@Atharva-Phatak Atharva-Phatak self-assigned this Oct 31, 2022
@senarvi
Contributor

senarvi commented Nov 1, 2022

By "General PL module for training any custom SSL model", do you mean a command-line application?

@senarvi
Contributor

senarvi commented Nov 1, 2022

Currently the SSL models take the name of the backbone, such as "resnet18", in the constructor. I would suggest that they also accept any network (nn.Module) that will be used directly as the encoder, like in my pull request.
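A minimal sketch of an encoder argument that accepts either form. `build_encoder` is a hypothetical helper, and the name branch assumes torchvision 0.13 or newer (for the `weights` argument). For torchvision ResNets it replaces the `fc` classification layer with `nn.Identity()`, so the encoder returns pooled feature vectors.

```python
from typing import Union

from torch import nn


def build_encoder(encoder: Union[str, nn.Module]) -> nn.Module:
    """Hypothetical helper: accept a backbone name or a ready-made network."""
    if isinstance(encoder, nn.Module):
        # Use the user's network as-is; it is expected to output feature vectors.
        return encoder
    if isinstance(encoder, str):
        import torchvision  # only needed when a name is given

        # Look up a torchvision constructor by name, e.g. "resnet18", randomly initialized.
        network = getattr(torchvision.models, encoder)(weights=None)
        # torchvision ResNets end in a classification layer called `fc`; replacing it
        # with an identity keeps the pooled features.
        if hasattr(network, "fc"):
            network.fc = nn.Identity()
        return network
    raise TypeError(f"encoder must be a str or nn.Module, got {type(encoder).__name__}")
```

Passing a module keeps full control with the user, while the string form preserves the current constructor behavior.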

In order to use the learned encoder in a downstream task, the user probably wants to load the weights into a classification or detection model for fine-tuning. What's your idea of how this should be done? I have overridden the on_load_checkpoint() method in the model class to translate the parameter names (keys) in the state_dict, so that I can load a checkpoint that was trained using SSL. If the same encoder is used in both models, only the root name needs to be translated, i.e. the name of the attribute in the model class that stores the network, for example encoder_q in MoCo. I wonder if we can make this easier for the user. Should we at least always use the same attribute name, i.e. rename encoder_q to encoder?
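The key translation described above can be sketched as a plain dictionary operation. `rename_state_dict_prefix` is a hypothetical helper; in a LightningModule it could be called from `on_load_checkpoint(self, checkpoint)`, e.g. `checkpoint["state_dict"] = rename_state_dict_prefix(checkpoint["state_dict"], "encoder_q.", "encoder.")`.

```python
from collections import OrderedDict
from typing import Any, Dict


def rename_state_dict_prefix(state_dict: Dict[str, Any], old_prefix: str, new_prefix: str) -> Dict[str, Any]:
    """Replace a leading attribute name in state_dict keys, e.g. "encoder_q." -> "encoder.".

    Keys that do not start with old_prefix are copied unchanged; order is preserved.
    """
    renamed: Dict[str, Any] = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed
```

Keys that belong to SSL-only submodules (for MoCo, e.g. the momentum encoder `encoder_k`) are left untouched, so the downstream model would still need to drop them or load with `strict=False`.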

@Atharva-Phatak
Contributor Author

Hi @senarvi, I will quickly put together a rough example and let you know. Give me some time; I have a few nice ideas :). LightningCLI runs may fail because there was an issue with jsonargparse :)

@matsumotosan
Contributor

Just a note on the SSL transforms. I've been working on moving all of the transforms to pl_bolts/transforms/self_supervised #904

@Atharva-Phatak
Contributor Author

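@senarvi Here is a simple idea:

```python
import pytorch_lightning as pl
import torch
from torchvision.models import resnet50

# Hypothetical module paths and class names for the proposed Bolts layout.
from lightning.bolts.ssl.heads import SomeSSLMethodHead
from lightning.bolts.ssl.criterion import SomeSSLCriterion


class CustomLightningModule(pl.LightningModule):
    def __init__(self, backbone):
        super().__init__()
        # Drop the backbone's last child module (e.g. the classification layer).
        self.backbone = torch.nn.Sequential(*(list(backbone.children())[:-1]))
        self.head = SomeSSLMethodHead()  # projection head from Bolts (arguments omitted)
        self.criterion = SomeSSLCriterion()  # SSL loss from Bolts (arguments omitted)

    def forward(self, batch):
        x = self.backbone(batch)
        x = torch.flatten(x, 1)  # e.g. (N, C, 1, 1) -> (N, C) after global pooling
        projection = self.head(x)
        return projection

    def training_step(self, batch, batch_idx):
        projection = self(batch)
        loss = self.criterion(projection, batch)  # loss function
        self.log("train_loss", loss)  # further logging options omitted
        return loss
```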

Essentially, here is what is happening:

  • The custom SSL heads are implemented in Lightning Bolts.
  • The custom SSL losses are implemented in Lightning Bolts.
  • The user brings a pre-trained backbone from timm or elsewhere.
  • All we need to implement are the heads and the other tools that SSL methods require.
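`SomeSSLCriterion` above is a placeholder. As one concrete example of a loss that could live in such a shared module, here is a minimal sketch of SimCLR's NT-Xent loss. Note that it assumes the criterion receives the projections of two augmented views of the same images, rather than `(projection, batch)` as in the sketch above; `nt_xent_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss, where z1[i] and z2[i] are projections of two views of image i."""
    n = z1.shape[0]
    # L2-normalize so that dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    logits = z @ z.t() / temperature  # (2N, 2N)
    # A sample must not count as its own negative: remove the diagonal from the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive for row i is its other view: i + N in the first half, i - N in the second.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, targets)
```

The loss is small when each projection is closest to its own partner view and large when it is closer to another image's view.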

cc @Borda
Let me know what you think!

@senarvi
Contributor

senarvi commented Nov 24, 2022

@Atharva-Phatak I have some questions.

  1. Is the CustomLightningModule class something that the user would write, and Bolts would offer only the heads and criterions? Or do you suggest that Bolts would offer such modules?

  2. I'm trying to keep the terminology consistent in MoCo. Do you think backbone is a better name for the encoder? At least the SimCLR and MoCo papers seem to talk about an encoder.

  3. What's the purpose of list(backbone.children())[:-1]? I suspect this won't work in general for an arbitrary backbone.

  4. resnet50 is imported, but not used. Should it be?

  5. MoCo uses two different encoders. One of them is the running average of the parameters of the other one and it needs to be updated during the training step. If you're suggesting that we have a single module that supports both SimCLR and MoCo by just switching the criterion, I think it won't be flexible enough. But in any case we can have the heads and criterions in a central place.
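For reference on point 5, the extra state MoCo needs is a key encoder whose parameters track an exponential moving average of the query encoder. A minimal sketch, with `momentum_update` as a hypothetical helper name and 0.999 as the default momentum (the value used in the MoCo paper):

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def momentum_update(online: nn.Module, target: nn.Module, momentum: float = 0.999) -> None:
    """MoCo-style update: target <- momentum * target + (1 - momentum) * online.

    Assumes both modules have the same architecture, so parameters pair up in order.
    """
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(momentum).add_(p_online, alpha=1.0 - momentum)


# The key (momentum) encoder starts as a copy of the query encoder and is not trained by SGD.
encoder_q = nn.Linear(4, 2)
encoder_k = copy.deepcopy(encoder_q)
for p in encoder_k.parameters():
    p.requires_grad = False
```

Because this update has to run once per training step, a generic module that only swaps the criterion would need an extra hook for it, which is the flexibility concern raised above.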

@matsumotosan
Contributor

@Atharva-Phatak @senarvi are you still working on this?

I'm about done moving SSL transforms to their own dedicated module. Just waiting on CI to be fixed to test and merge.

@Atharva-Phatak
Contributor Author

That looks like a good start. I am a little busy with my last semester; once your PR is fixed, we can move on to the next stage.

@stale stale bot added the won't fix This will not be worked on label Mar 18, 2023
@Lightning-Universe Lightning-Universe deleted a comment from stale bot Mar 18, 2023
@stale stale bot removed the won't fix This will not be worked on label Mar 18, 2023