
General question: How can finetuning flux in fp8 work with gradient in bf16? #6394

Open
robinchm opened this issue Jan 8, 2025 · 0 comments
Labels
User Support A user needs help with something, probably not a bug.

Comments


robinchm commented Jan 8, 2025

Your question

This is a general question, to which I could not find an answer in the code or on Google, so I am asking here for help in understanding the following:

The flux model can be finetuned with its transformer (UNet) weights in fp8, data in fp32, and autocast to bf16. While I have verified that this indeed works, I cannot reproduce it with simple code.

For example:

import torch

# weight stored in fp8, data in fp32
a = torch.zeros((1,), dtype=torch.float8_e4m3fn, device="cuda:0")
b = torch.zeros((1,), dtype=torch.float32, device="cuda:0")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    c = a * b  # raises the promotion error quoted below

The above code fails with RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float. But somehow this just works when finetuning flux. Where is the magic? Many thanks!
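
My best guess (I have not confirmed this in the trainer code, so the snippet below is only a sketch of the idea) is that fp8 is used purely as a storage dtype, and the weights are explicitly upcast to the compute dtype right before the op, so autocast never has to promote a Float8 tensor:

import torch

a = torch.zeros((1,), dtype=torch.float8_e4m3fn, device="cuda:0")  # fp8 storage
b = torch.zeros((1,), dtype=torch.float32, device="cuda:0")        # fp32 data

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # hypothetical: upcast the fp8 weight to bf16 before the compute,
    # so the kernel only ever sees bf16 * fp32 (which promotes fine)
    c = a.to(torch.bfloat16) * b

Is something like this (e.g. a custom linear layer or patched forward that dequantizes fp8 weights on the fly) what is actually happening here?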

Logs

No response

Other

Using torch 2.5.1, CUDA 12.4, on an RTX 4090.

@robinchm robinchm added the User Support A user needs help with something, probably not a bug. label Jan 8, 2025