Missing f8 dtypes #1670
Comments
@bionicles I'm not affiliated with mlx, but I think the easiest option would be some software emulation initially, as f8 can be treated like a quantization. Eventual hardware support would be super cool!
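For what it's worth, here is a minimal sketch of what that kind of software emulation could look like, using plain NumPy rather than anything MLX provides; the function name and the 448 clamping value (the max normal of the OCP e4m3 "fn" variant) are illustrative assumptions, not an existing API:

```python
import numpy as np

def fake_quantize_e4m3(x: np.ndarray) -> np.ndarray:
    """Round float32 values to roughly the nearest e4m3-representable value.

    Sketch only: ignores subnormals, NaN/inf handling, and the exact
    exponent range of the format.
    """
    x = np.clip(np.asarray(x, dtype=np.float32), -448.0, 448.0)
    m, e = np.frexp(x)              # x == m * 2**e with |m| in [0.5, 1)
    m = np.round(m * 16.0) / 16.0   # keep 4 significand bits (1 implicit + 3 stored)
    return np.ldexp(m, e).astype(np.float32)

w = np.random.randn(4).astype(np.float32)
print(w)
print(fake_quantize_e4m3(w))
```

The actual arithmetic would still run in fp16/fp32, which is the sense in which emulated fp8 saves memory but not compute.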
Hi @bionicles ! We do support 3, 4, 6 and 8 bit quantization, so you can use large models in MLX already. Regarding support for FP8 variants, I think it is quite unlikely that they will be added anytime soon. It is not technically very hard; we are already emulating bf16 for older machines that don't support it natively. But I am not sure that it provides any benefit for now: it will be significantly slower than fp32, and it will add several more types that increase the footprint of the library without being seriously used. Let me know what you think.
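As a side note on the bf16 emulation point: on hardware without native bfloat16 support, one common trick (not necessarily what MLX does internally) is to keep only the top 16 bits of a float32 with round-to-nearest-even, roughly like this sketch:

```python
import numpy as np

def round_to_bfloat16(x: np.ndarray) -> np.ndarray:
    """Simulate bfloat16 by rounding away the low 16 mantissa bits of a float32.

    Sketch only: NaN/inf and overflow near the float32 maximum are not handled.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    tie = (bits >> np.uint32(16)) & np.uint32(1)   # LSB of the surviving bits (ties to even)
    rounded = bits + np.uint32(0x7FFF) + tie
    return (rounded & np.uint32(0xFFFF0000)).view(np.float32)
```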
How do we know fp8 would be slower than fp32? It seems like fewer bits to manipulate would mean significantly less calculation, so I don't understand how fp8 could be slower than fp32; can you please clarify the rationale? Additionally, the main reason to add an fp8 dtype would be to increase the maximum achievable model parameter count: since the parameters would be smaller, we could fit more of them. I don't want to run quantized versions of models trained at higher precision; I want to train larger models at reduced precision, because then I can train them on local hardware instead of needing to depend on external services.
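To put rough numbers on the memory side of that argument, here is a back-of-the-envelope sketch; the 96 GB budget is just an assumed example, and it counts weights only, ignoring optimizer state, gradients, and activations:

```python
# Parameter budget for weights alone at different dtype widths,
# assuming a hypothetical 96 GB of unified memory available to the model.
budget_bytes = 96 * 1024**3
for name, bytes_per_param in [("fp32", 4), ("fp16/bf16", 2), ("fp8", 1)]:
    print(f"{name:>9}: ~{budget_bytes / bytes_per_param / 1e9:.0f}B parameters")
# fp32 -> ~26B, fp16/bf16 -> ~52B, fp8 -> ~103B
```

For training specifically, optimizer state and gradients usually dominate, so the real win depends on how much of that state can also be kept in 8 bits.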
Hi, the unified memory of Apple silicon devices is compelling for AI training: it often gives them significantly more memory for parameters and gradients than the best consumer or even data-center-grade discrete GPUs.
However, inspecting the data types list, the smallest float dtype I saw in mlx today was 16 bits (f16 or bf16).
Adding 8-bit floats to mlx would effectively double the maximum possible model size.
Would this be possible with software or does it need to be a hardware update?
If it is possible with software, how could we make it happen?
Some options for a sensible default exponent/mantissa split for an 8-bit float dtype could be:
- e5m2
- e4m3
- e3m4
Then the question becomes: how would we rank these and decide which one is best for the most likely use cases?
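One way to frame that ranking is to compare dynamic range against precision for each split. A quick sketch, assuming an IEEE-style layout (1 sign bit, exponent bias 2**(E-1) - 1, top exponent reserved for inf/NaN); note the OCP e4m3 "fn" variant drops inf and reaches 448 rather than the 240 shown here:

```python
# Dynamic range vs. precision for the candidate 8-bit splits.
for name, E, M in [("e5m2", 5, 2), ("e4m3", 4, 3), ("e3m4", 3, 4)]:
    bias = 2 ** (E - 1) - 1
    max_normal = (2 - 2.0 ** -M) * 2.0 ** bias   # largest finite normal
    min_normal = 2.0 ** (1 - bias)               # smallest positive normal
    eps = 2.0 ** -M                              # relative step size near 1.0
    print(f"{name}: max {max_normal:g}, min normal {min_normal:g}, eps {eps:g}")
# e5m2: max 57344, min normal 6.10352e-05, eps 0.25
# e4m3: max 240,   min normal 0.015625,    eps 0.125
# e3m4: max 15.5,  min normal 0.25,        eps 0.0625
```

For what it's worth, the FP8 formats standardized in recent GPU hardware are e4m3 and e5m2, with e4m3 typically used for weights and activations and e5m2 for gradients (which need more range), so those two would probably cover the most likely use cases.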