
Missing f8 dtypes #1670

Open · bionicles opened this issue Dec 7, 2024 · 3 comments

@bionicles

Hi, the unified memory of Apple silicon devices is compelling for AI training, and it often gives them significantly more memory for parameters and gradients than the best consumer or even data-center-grade discrete GPUs.

However, inspecting the data types list, the smallest float dtype I can find in mlx today is 16 bits (float16 or bfloat16).

Adding 8-bit floats to mlx would effectively double the maximum possible model size relative to 16-bit floats.

Would this be possible in software, or does it need a hardware update?

If it is possible with software, how could we make it happen?

Some sensible default exponent/mantissa (e/m) splits for an 8-bit float dtype could be:

e5m2
e4m3
e3m4

Then the question becomes: how would we rank these and decide which one is best for the most likely use cases?
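
One hedged way to start that ranking (an editor's illustration, not from the thread) is to compute the dynamic range and relative precision each split gives, assuming plain IEEE-754-style semantics with a sign bit, biased exponent, and implicit leading one; note that the widely used e4m3fn variant bends these rules slightly (no infinities, max finite value 448), so treat the numbers as approximate:

```python
# Rough comparison of the candidate 8-bit float layouts, assuming
# IEEE-754-style semantics (sign bit, biased exponent, implicit leading 1).

def fp8_stats(exp_bits: int, man_bits: int):
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = 2 ** exp_bits - 2 - bias          # top exponent code reserved for inf/nan
    max_normal = (2 - 2.0 ** -man_bits) * 2.0 ** max_exp
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -man_bits                  # relative step size near 1.0
    return max_normal, min_normal, epsilon

for name, (e, m) in {"e5m2": (5, 2), "e4m3": (4, 3), "e3m4": (3, 4)}.items():
    max_n, min_n, eps = fp8_stats(e, m)
    print(f"{name}: max ~{max_n:g}, min normal ~{min_n:g}, eps ~{eps:g}")
```

Roughly, e5m2 trades precision for range (which is why it is often used for gradients), e3m4 does the opposite, and e4m3 sits in between and is the common choice for weights and activations.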

@EricLBuehler

@bionicles I'm not affiliated with mlx, but I think the easiest option would initially be some software emulation, since FP8 can be treated like a quantization scheme.

Eventual hardware support would be super cool!
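
To illustrate the emulation idea (an editor's sketch in NumPy, not MLX code): "fake" FP8 keeps values in float32 storage but rounds them to the nearest e4m3-representable value, which is roughly what treating FP8 as a quantization scheme amounts to. The helper below is hypothetical and simplifies subnormal handling:

```python
import numpy as np

def fake_fp8_e4m3(x: np.ndarray) -> np.ndarray:
    """Round float32 values to the nearest IEEE-style e4m3 value (result stays float32).

    Simplified: subnormals are flushed to zero and the e4m3fn extended range
    (max 448) is not modeled.
    """
    x = np.asarray(x, dtype=np.float32)
    mantissa_bits = 3
    # Split into mantissa in [0.5, 1) and exponent, round the mantissa to
    # 1 + mantissa_bits significant bits, then reassemble.
    m, e = np.frexp(x)
    m = np.round(m * 2.0 ** (mantissa_bits + 1)) / 2.0 ** (mantissa_bits + 1)
    y = np.ldexp(m, e)
    # Clamp to the e4m3 normal range (~±240) and flush tiny values to zero.
    max_normal, min_normal = 240.0, 2.0 ** -6
    y = np.clip(y, -max_normal, max_normal)
    y = np.where(np.abs(y) < min_normal, 0.0, y)
    return y.astype(np.float32)

w = np.random.randn(4).astype(np.float32)
print(w, "->", fake_fp8_e4m3(w))
```

An actual emulation would also pack the rounded values into a uint8 container for storage and upcast to fp16/fp32 around each op, which is where most of the emulation overhead comes from.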

@angeloskath
Member

Hi @bionicles! We already support 3, 4, 6 and 8 bit quantization, so you can use large models in MLX today.
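
For reference, a minimal sketch of that quantization path, assuming the mx.quantize / mx.dequantize signatures documented around the time of this thread (group-wise affine quantization; check the current docs for the exact API):

```python
import mlx.core as mx

w = mx.random.normal((4096, 4096))  # e.g. a weight matrix

# Group-wise affine quantization: 8 bits per weight, one scale/bias per
# group of 64 consecutive values along the last axis.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)

# Dequantize to inspect the error; in practice mx.quantized_matmul or
# nn.QuantizedLinear consume the packed weights directly.
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=8)
print(mx.max(mx.abs(w - w_hat)))
```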

Regarding support for FP8 variants: I think it is quite unlikely that they will be added anytime soon. It is not technically very hard, since we already emulate bf16 for older machines that don't support it natively, but I am not sure it provides any benefit for now. It will be significantly slower than fp32, and it will add several more types that increase the footprint of the library without being seriously used.

Let me know what you think.

@bionicles
Author

bionicles commented Jan 12, 2025

> It will be significantly slower than fp32

How do we know fp8 would be slower than fp32? It seems like significantly less calculation would be involved if there were fewer bits to manipulate, so I don't understand how fp8 could be slower than fp32. Can you please clarify the rationale?

Additionally, the main reason to add an fp8 dtype would be to increase the maximum achievable model parameter count: since the parameters would be smaller, could we fit more of them?

I don't want to run quantized versions of models trained at higher precision; I want to train larger models at reduced precision, because then I can train them on local hardware instead of depending on external providers.
