Training time #3

Open
youweiliang opened this issue Nov 10, 2020 · 9 comments
@youweiliang

Hi, thanks for the implementation. Could you provide the (approximate) training time needed to produce the results in the table in the README?

@guanfuchen

Same problem here: training BYOL with this code seems slow. How can we optimize it for faster training?

@yaox12
Owner

yaox12 commented Nov 17, 2020

@guanfuchen
First, training BYOL is slow naturally, due to:

  1. With the same architecture (ResNet-50, for example), BYOL performs more than twice as many forward passes as supervised learning; see the sketch after this list.
  2. BYOL, like other self-supervised methods such as SimCLR, requires more training epochs to converge. Most of the ablation results in the paper are reported at 300 epochs, while the best performance is reached at 1000 epochs.
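For concreteness, here is a minimal sketch (not exactly my code; `online`, `target`, and `predictor` are placeholder modules) of one BYOL training step, showing the four encoder forward passes per batch of view pairs where supervised training does one:

```python
import torch
import torch.nn.functional as F

def byol_step(online, target, predictor, view1, view2):
    # The online network forwards both augmented views...
    p1 = predictor(online(view1))
    p2 = predictor(online(view2))
    # ...and the target network forwards them again, without gradients.
    with torch.no_grad():
        z1 = target(view1)
        z2 = target(view2)
    # Symmetric loss, equivalent up to constants to the paper's MSE
    # between L2-normalized predictions and target projections.
    return -(F.cosine_similarity(p1, z2, dim=-1).mean()
             + F.cosine_similarity(p2, z1, dim=-1).mean())
```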

Second, self-supervised methods usually use a large batch size, e.g., 4096, for better performance, and large-scale distributed training brings a lot of efficiency challenges:

  1. BYOL occupies more than twice the GPU memory of supervised training. That is to say, it needs twice as many GPUs to keep the same batch size. Unfortunately, communication cost or throughput becomes a major bottleneck under these circumstances.
  2. I adopt synchronized batch normalization (syncBN) in my code. It performs an extra synchronization across all GPUs during the forward pass, which is very time-consuming. While the authors do not stress using syncBN in BYOL, I found it hard to reproduce the reported results without it, so I added it back, referring to the implementation of SimCLR; see the sketch after this list.
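For reference, a minimal sketch of enabling syncBN with native PyTorch (my code may use a different syncBN implementation; this assumes the torch.distributed process group is already initialized):

```python
import torch
import torchvision

model = torchvision.models.resnet50()
# Replace every BatchNorm layer with SyncBatchNorm, which synchronizes
# batch statistics across all GPUs during the forward pass.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model.cuda())
```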

Besides, there are some points that could be optimized. The most important one is data loading: the dataloader reads raw images and then applies data augmentations, which is rather slow. You can:

  1. Use a better-optimized data augmentation library, e.g., albumentations (improves things a little). An example is here.
  2. Use newer torchvision versions (with pytorch >= 1.7.0), which can apply augmentations to tensors (in batches and on GPUs); I have not tried this, but see the sketch after this list.
  3. Use DALI for reading (converting raw images to tfrecords accelerates things a lot) and augmentation (however, some augmentations used by BYOL may not be supported by DALI).
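As a sketch of option 2 (untested, as said above; assumes torchvision >= 0.8 paired with pytorch >= 1.7):

```python
import torch
from torchvision import transforms

# Transforms are now nn.Modules that accept (batched) tensors,
# so the augmentation pipeline can run on the GPU.
gpu_aug = torch.nn.Sequential(
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
).cuda()

batch = torch.rand(32, 3, 256, 256, device="cuda")  # dummy decoded batch
out = gpu_aug(batch)  # caveat: one set of random parameters per batch
```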

@guanfuchen

Thanks, great answer. DALI is not very flexible, and yes, syncBN is important. I am considering MoCo; it seems quicker than BYOL, and it does not use syncBN. @yaox12

@guanfuchen

guanfuchen commented Dec 10, 2020

  1. Another problem: the transformation is slow; in particular, cv2.GaussianBlur is very slow, yet linear evaluation accuracy seems lower when using the PIL GaussianBlur (sketched after this list). This is the OpenSelfSup code for the BYOL GaussianBlur transformation.

  2. Can we use byol_transform_a to get the same result as byol_transform?
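For reference, a PIL-based blur along the lines of the MoCo v2 one (a sketch; `PILGaussianBlur` is just an illustrative name):

```python
import random
from PIL import Image, ImageFilter

class PILGaussianBlur:
    """MoCo v2-style blur: PIL takes only a single radius (sigma)."""
    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma

    def __call__(self, img: Image.Image) -> Image.Image:
        radius = random.uniform(*self.sigma)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
```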


@yaox12
Owner

yaox12 commented Dec 10, 2020

@guanfuchen

  1. The official implementation of MoCo also uses the PIL GaussianBlur; however, I haven't tried it because it accepts only a single parameter, sigma. I kept the implementation as close to the BYOL paper as possible.
  2. albumentations in byol_transform_a.py leads to a similar linear evaluation result, but I remember it didn't accelerate things much. The color adjustment in byol_transform_a.py may not be correct; you can refer to albumentations-team/albumentations#672 (How to make an equivalent migration from torchvision.transforms.ColorJitter to albumentations) and albumentations-team/albumentations#698 (Is there a HueSaturationLightness (to recreate torchvision.transforms.ColorJitter)?) for help. BTW, GaussianBlur has been added to torchvision now; I think it deserves a try (sketched after this list).
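A minimal sketch of what that could look like (kernel_size=23 and sigma in [0.1, 2.0] follow the BYOL paper; RandomApply reproduces the per-view blur probabilities):

```python
from torchvision import transforms

blur = transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
view1_blur = transforms.RandomApply([blur], p=1.0)  # always blur view 1
view2_blur = transforms.RandomApply([blur], p=0.1)  # rarely blur view 2
```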

@guanfuchen

Thanks @yaox12, your code is really excellent! I will try, but the cost is high. For linear evaluation I set the weight decay to 0 (sketched below), and then the accuracy is almost the same as with the original cv2 implementation.
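The change is just the optimizer setting for linear evaluation, i.e., something like (`classifier` and the learning rate are placeholders):

```python
import torch

# lr=30.0 is a common linear-eval choice, not a verified value here;
# the only point of this sketch is weight_decay=0.0.
optimizer = torch.optim.SGD(classifier.parameters(), lr=30.0,
                            momentum=0.9, weight_decay=0.0)
```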

@yaox12 yaox12 pinned this issue Dec 16, 2020
@yaox12 yaox12 reopened this Dec 16, 2020
@yaox12 yaox12 unpinned this issue Dec 16, 2020
@youweiliang
Author

@yaox12 Hi, I notice you actually did not use mixed precision in the BYOL training, since opt_level is "O0" in your train_config.yaml. Have you tried opt_level: "O1"? Would it cause a significant accuracy drop?

@yaox12
Owner

yaox12 commented Jan 21, 2021

@youweiliang See #7

@youweiliang
Author

Thanks. I have just tried torch.cuda.amp (PyTorch 1.6) for mixed-precision training, as sketched below, and also observed similar computation times with and without amp. I tend to agree with you on the communication bottleneck.
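The amp loop I tried looks roughly like this (a sketch; `loader`, `optimizer`, and the `byol_step` placeholders from the earlier sketch are illustrative):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
for view1, view2 in loader:
    optimizer.zero_grad()
    # Run the forward pass in mixed precision...
    with torch.cuda.amp.autocast():
        loss = byol_step(online, target, predictor, view1, view2)
    # ...and scale the loss to avoid underflow in fp16 gradients.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```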
