Releases: pytorch/ao
v0.8.0
Highlights
We are excited to announce the 0.8.0 release of torchao! In this release we’ve shipped the first CUTLASS kernel in torchAO which adds support for W4A8 linear operator. In addition to this, we’ve also added TTFT benchmarks to torchAO and compared different quantization + sparsity speedups for prefill / decoding.
W4A8 based on CUTLASS
A new W4A8 linear operator is implemented. It corresponds to int8_dynamic_activation_int4_weight quantization, where two 4-bit weights are packed into a single 8-bit integer value. CUTLASS is also now a sub-module of the torchao repo, so that more of its functionality can be used to implement new kernels.
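To make the packing idea concrete, here is a minimal pure-Python sketch of storing two signed 4-bit weights in one byte. The helper names and the nibble order (first weight in the low nibble) are assumptions made for illustration; the layout actually used by the CUTLASS kernel is not specified in these notes.

```python
def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed 4-bit values (each in [-8, 7]) into one byte (0..255).

    Assumption for this sketch: `lo` goes in the low nibble, `hi` in the high nibble.
    """
    for v in (lo, hi):
        if not -8 <= v <= 7:
            raise ValueError(f"{v} does not fit in a signed 4-bit integer")
    # Masking with 0xF keeps the two's-complement bit pattern of each value.
    return ((hi & 0xF) << 4) | (lo & 0xF)


def unpack_int4_pair(byte: int) -> tuple:
    """Recover the two signed 4-bit values from a packed byte."""
    def sign_extend(nibble: int) -> int:
        # Nibbles 8..15 represent the negative values -8..-1.
        return nibble - 16 if nibble >= 8 else nibble

    return sign_extend(byte & 0xF), sign_extend((byte >> 4) & 0xF)


packed = pack_int4_pair(-3, 5)
restored = unpack_int4_pair(packed)
```

Packing this way halves the storage of int4 weights relative to keeping each one in its own int8 slot, which is consistent with the smaller model size reported for the w4a8-cutlass row in the benchmark table.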
Benchmarks on A100
-q parameter | Average tokens/sec | Average Bandwidth in GB/s | Peak Memory Usage in GB | Model Size in GB |
---|---|---|---|---|
(none) | 95.24 | 258.55 | 13.90 | 13.21 |
-q int8wo | 155.31 | 1028.37 | 8.97 | 6.62 |
-q int4wo-32 | 186.70 | 774.98 | 5.31 | 4.15 |
-q int4wo-hqq | 186.47 | 774.01 | 5.04 | 4.15 |
-q int8dq | 49.64 | 328.72 | 9.44 | 6.62 |
-q w4a8-cutlass (tuned) | 119.31 | 394.86 | 4.52 | 3.31 |
Prefill performance benchmarks
We’ve added TTFT benchmarks to torchAO and compared different quantization + sparsity speedups for prefill / decoding. During prefill, we are compute bound and find that dynamic quantization offers greater speedups than weight-only quantization, which is faster for decoding. We’ve also added an option for int8 dynamic quantization that selectively uses dynamic quantization during prefill and weight-only quantization during decoding.
BC Breaking
Delete the float8-all-gather-only functionality from float8 training (#1451)
The use_fp8_all_gather_only option was an experimental flag, off by default, which was not marketed and not used by anyone as far as we know. We are removing it to simplify the code.
Before
config = Float8LinearConfig(
    ...,
    # the option below is being removed
    use_fp8_all_gather_only=True,
)
convert_to_float8_training(model, config=config, ...)
After
The use_fp8_all_gather_only option is no longer supported.
New Features
- Add TTFT benchmarks + update sparsity benchmarks (#1140)
- Gemlite integration in torchao (#1034)
- W4A8 based on CUTLASS (#880)
Improvement
quantize_
- Expose zero_point_domain as arguments (#1401)
- Add convert path for quantize_ QAT API (#1540)
- Int8 dynamic prefill weight only decode (#1436)
autoquant
- Make int8 dynamic quant in autoquant serializable (#1484)
- Additional fixes for autoquant serialization (#1486)
- Add exhaustive config option to intmm kernel (#1392)
float8 training
- [float8] Allow specifying arbitrary dtype for each tensor, enabling recipes with e4m3 in both the forward and the backward (#1378)
experimental
- Remove temp build files from torchao (#1551)
other
- Torchao setup.py with cmake (#1490)
Bug Fixes
- Fix bfloat16/float16/float32 options (#1369)
- Fix a bug in LinearActivationQuantizedTensor (#1400)
- Fix error message in float8 FSDP utils (#1423)
- Fixes observer attachment to model based on config for wanda sparsifier (#1265)
- [resubmit] Gemlite fix (#1435)
- 🐛 Fix: Memory leak in image processing endpoint (#1513)
Performance
- [float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes (#1377)
Documentation
- Update api_ref_quantization.rst (#1408)
- Update index.rst (#1409)
- Update QAT READMEs using new APIs (#1541)
Developers
New Contributors
- @sanchitintel made their first contribution in #1375
- @philipbutler made their first contribution in #1337
- @airMeng made their first contribution in #1401
- @DerekLiu35 made their first contribution in #1299
- @agrawal-aka made their first contribution in #1265
- @gmagogsfm made their first contribution in #1443
- @dongxiaolong made their first contribution in #1513
Full Changelog: v0.7.0...v0.8.0-rc2
v0.7.0
Highlights
We are excited to announce the 0.7.0 release of torchao! This release moves QAT out of prototype with improved LoRA support and more flexible APIs, and adds support for new experimental kernels such as Marlin QQQ (for CUDA), int8_dynamic_activation_intx_weight (for ARM CPU), and more!
QAT moved out of prototype, LoRA integration, new flexible APIs (#1020, #1085, #1152, #1037)
QAT has been moved out of prototype to torchao/quantization/qat to provide better API stability guarantees moving forward. In addition to the existing *QATQuantizer classes, we now also support the more flexible FakeQuantizedLinear and FakeQuantizedEmbedding modules for users to configure the exact quantization settings they wish to use during QAT.
import torch
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
# Specify quantization schemes to use during QAT
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
# Replace nn.Linear and nn.Embedding with these in your model
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
We also leveraged the new flexible APIs to build a new QAT + LoRA fine-tuning flow in torchtune. Try it out today!
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
Marlin QQQ for CUDA (#1113)
Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. For more details about Marlin QQQ, please refer to the paper.
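from torchao.dtypes import MarlinQQQLayout
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchao.quantization.quant_primitives import MappingType

# Quantize with symmetric int8 activations and int4 weights in the Marlin QQQ layout
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)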
Benchmarking results can be found in https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#marlin-qqq.
This is a prototype feature, so feel free to try it out!
int8_dynamic_activation_intx_weight Quantization for ARM CPU (#995, #1027, #1254, #1353)
We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computer with Apple silicon).
import os
import torch
from torchao.quantization import quantize_
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

# `precision` is the dtype the model runs in; these kernels require fp32
assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision"

# Build kernels in temp location, and load them in torch
# This requires an ARM CPU
from torchao.experimental.temp_build import temp_build_and_load_torchao_ops
temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental")

# Quantize model
nbit = 4
assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
group_size = 128
has_weight_zeros = False
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=group_size,
        nbit=nbit,
        has_weight_zeros=has_weight_zeros,
    ),
)
Benchmarking results can be found in https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#int8_dynamic_activation_intx_weight-quantization
We are still trying to figure out how to ship the ARM CPU kernels, so the exact API is subject to change.
BC Breaking
Rename AQT#2 LayoutType -> Layout (#1049)
Before:
from torchao.dtypes import (
BlockSparseLayoutType,
Int4CPULayoutType,
MarlinQQQLayoutType,
MarlinSparseLayoutType,
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
UintxLayoutType,
Float8LayoutType,
LayoutType,
PlainLayoutType,
)
After:
from torchao.dtypes import (
BlockSparseLayout,
Int4CPULayout,
MarlinQQQLayout,
MarlinSparseLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
Float8Layout,
Layout,
PlainLayout,
)
QAT imports after move out of prototype (#1091)
Before:
from torchao.quantization.prototype.qat import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int4WeightOnlyEmbeddingQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
After:
from torchao.quantization.qat import (
ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int4WeightOnlyEmbeddingQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
New Features
- Add BF16 stochastic rounding option for optimizers (#1124)
- Add quantize_() API support for NF4 (#1216)
- Support W4A8 Marlin kernel (#1113)
Improvements
quantize_
- Add default filtering to remove mis-aligned weights (#1194)
- Add tensor parallelism support for int4_weight_only quantization (#1120)
- Add support for asymmetric act quant for int8 dynamic quant (#1131)
- Add support for groupwise quantization for int8 weight only quantization (#1121)
- Add AQT tensor parallel for float8_dynamic_quant (#1078)
- Int8wo Embedding Quant (#1167)
- Making sure int4 weight only supports cpu as well (#1203)
- BF16 support for Quant-LLM kernel (#1147)
- Add hardware check to fp8 quant (#1314)
- Add support for quantize_() with Float8Linear module (#1344)
autoquant
- Added support for Per Tensor Scaling for Float8 Dynamic Autoquant (#1175)
- Add floating point options for autoquant and add accuracy measurement (#1355)
benchmarks
- Adding batchsize support for torchao llama benchmarks (#1182)
- Add capability of benchmarking arbitrary binary (#1107)
experimental
- Add embedding ops aten (#1129)
- Add embedding ops executorch (#1137)
- Add quantized embedding kernels to torchao (#1018)
- Allow deprecated declarations when using Parallel ExecuTorch (#1031)
- Introduce lowbit quantized linear MPS kernels (#954)
- Enable 6-bit kernel (#1027)
- Kleidi 4b blockwise gemv prototype (#997)
- Experimental 6-bit quantization for Llama in torchchat (#1094)
- Introduce 7-bit quantization for Llama in torchchat. (#1139)
- Executorch Subclass API (#966) (#995)
- 8-bit packing support (#1248)
- Experimental Enable 8-bit (#1254)
- Experimental Benchmarking (#1353)
optimizer
- [low-bit optim] Upcast everything to FP32 for internal calculations (#1068)
- [Low-bit optim] Support for dcp.save() and dcp.load() (#1217)
- Enable CPU Offload for Intel GPU (#1324)
SAM2
- SAM2.1 copy (#1172)
- SAM2 AMG server side request batching (#1197)
- More SAM2-fast server improvements (#1285)
- SAM2 Fast AMG: memory profiling and more compile (#1296)
- SAM2 AMG cli and other QoL improvements (#1336)
- SAM2 AMG cli.py on modal (#1349)
- Reduce SAM2 AMG cli startup by using deploy (#1350)
- Reduce startup time for SAM2 AMG by using torch.export (#1358)
- More batching and improved furious accuracy/performance (#1253)
- SAM2.1 and example README (#1048)
- SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#1196)
other
- Add SpinQuant to generate.py (#1069)
- SpinQuant (#983)
- SmoothQuant using tensor subclassing (#1030)
- Expose FakeQuantizeConfigs in QAT quantizers (#1214)
- Add module-swap UX for INT8 mixed-precision training (https://github.com/pytorch/...
v0.6.1
Highlights
We are excited to announce the 0.6.1 release of torchao! This release adds Auto-Round support, Float8 Axiswise scaled training, a BitNet training recipe, an implementation of AWQ and much more!
Auto-Round Support (#581)
Auto-Round is a new weight-only quantization algorithm that has achieved superior accuracy compared to GPTQ, AWQ, and OmniQuant across 11 tasks, particularly excelling in low-bit quantization (e.g., 2-bits and 3-bits). Auto-Round supports quantization from 2 to 8 bits, involves low tuning costs, and imposes no additional overhead during inference. Detailed information is available in our paper, GitHub repository, and Hugging Face low-bit quantization leaderboard.
from torchao.prototype.autoround.core import prepare_model_for_applying_auto_round_
from torchao.prototype.autoround.core import apply_auto_round
prepare_model_for_applying_auto_round_(
    model,
    is_target_module=is_target_module,
    bits=4,
    group_size=128,
    iters=200,
    device=device,
)
input_ids_lst = []
for data in dataloader:
    input_ids_lst.append(data["input_ids"].to(model_device))
multi_t_input_ids = MultiTensor(input_ids_lst)
out = model(multi_t_input_ids)
quantize_(model, apply_auto_round(), is_target_module)
Added float8 training axiswise scaling support with per-gemm-argument configuration (#940)
We added experimental support for rowwise scaled float8 gemm to torchao.float8, with per-gemm-input configurability to enable exploration of various recipes. Here is how a user can configure all-axiswise scaling:
import torchao.float8
from torchao.float8.config import Float8LinearRecipeName

# all-axiswise scaling (m is the model to convert)
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m = torchao.float8.convert_to_float8_training(m, config=config)
# or, a custom recipe by @lw where grad_weight is left in bfloat16
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
m = torchao.float8.convert_to_float8_training(m, config=config)
Early performance benchmarks show all-axiswise scaling achieves a 1.13x speedup vs bf16 on torchtitan / LLaMa 3 8B / 8 H100 GPUs (compared to 1.17x from all-tensorwise scaling in the same setup), with loss curves that match bf16 and all-tensorwise scaling. Further performance and accuracy benchmarks will follow in future releases.
Introduced BitNet b1.58 training recipe (#930)
Adds a recipe for BitNet b1.58 (https://arxiv.org/abs/2402.17764) ternary weight clamping.
from torchao.prototype.quantized_training import bitnet_training
from torchao import quantize_
model = ...
quantize_(model, bitnet_training())
Notably: Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases.
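As a rough illustration of what "ternary weights" means here, the sketch below applies the absmean scheme described in the BitNet b1.58 paper: scale by the mean absolute weight, round, and clamp to {-1, 0, 1}. The function name is hypothetical and this is a pure-Python sketch under that assumption, not the torchao implementation.

```python
def ternarize(weights, eps=1e-5):
    """Quantize a list of floats to {-1, 0, 1} using absmean scaling.

    Returns the ternary values and the scale needed to map them back.
    Sketch of the BitNet b1.58 paper's formula; not torchao's code.
    """
    # Per-tensor scale: mean of absolute values.
    scale = sum(abs(w) for w in weights) / len(weights)
    # Divide by the scale, round to the nearest integer, clamp to [-1, 1].
    ternary = [max(-1, min(1, round(w / (scale + eps)))) for w in weights]
    return ternary, scale


q, scale = ternarize([0.9, -0.05, 0.4, -1.2])
# Approximate reconstruction of the original weights.
approx = [v * scale for v in q]
```

Small weights collapse to 0 and large ones saturate at -1 or 1, so each weight carries roughly log2(3) ≈ 1.58 bits of information, which is where the "b1.58" name comes from.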
[Prototype] Implemented Activation Aware Weight Quantization AWQ (#743)
Perplexity and performance measured on A100 GPU:
Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
---|---|---|---|---|---|
Llama-2-7b-chat-hf | bfloat16 | 107.38 | 1418.93 | 13.88 | 13.21 |
Llama-2-7b-chat-hf | awq-hqq-int4 | 196.6 | 761.2 | 5.05 | 3.87 |
Llama-2-7b-chat-hf | awq-uint4 | 43.59 | 194.93 | 7.31 | 4.47 |
Llama-2-7b-chat-hf | int4wo-hqq | 209.19 | 804.32 | 4.89 | 3.84 |
Llama-2-7b-chat-hf | int4wo-64 | 201.14 | 751.42 | 4.87 | 3.74 |
Usage:
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

quant_dtype = torch.uint4
group_size = 64
calibration_limit = 10
calibration_seq_length = 1024
model = model.to(device)
# Insert observers that record activation statistics during calibration
insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
with torch.no_grad():
    for batch in calibration_data:
        model(batch.to(device))
# Quantize only the linear layers that were observed
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
New Features
- [Prototype] Added Float8 support for AQT tensor parallel (#1003)
- Added composable QAT quantizer (#938)
- Introduced torchchat quantizer (#897)
- Added INT8 mixed-precision training (#748)
- Implemented sparse marlin AQT layout (#621)
- Added a PerTensor static quant api (#787)
- Introduced uintx quant to generate and eval (#811)
- Added Float8 Weight Only and FP8 weight + dynamic activation (#740)
- Implemented Auto-Round support (#581)
- Added 2, 3, 4, 5 bit custom ops (#828)
- Introduced symmetric quantization with no clipping error in the tensor subclass based API (#845)
- Added int4 weight-only embedding QAT (#947)
- Added support for 1-bit and 6-bit quantization for Llama in torchchat (#910, #1007)
- Added a linear_observer class for doing static activation calibration (#807)
- Exposed hqq through uintx_weight_only API (#786)
- Added RowWise scaling option for Float8 dynamic activation quantization (#819)
- Added Float8 weight only to autoquant api (#866)
Improvements
- Enhanced Auto-Round functionality (#870)
- Improved FSDP support for low-bit optimizers (#538)
- Added support for using AffineQuantizedTensor with weights_only=True for torch.load (#630)
- Optimized 3-bit packing (#1029)
- Added more evaluation metrics to llama/eval.sh (#934)
- Improved eager numerics for dynamic scales in float8 (#904)
Bug fixes
- Fixed inference_mode issues (#885)
- Fixed failing FP6 benchmark (#931)
- Resolved various issues with float8 support (#918, #923)
- Fixed load state dict when device is different for low-bit optim (#1021)
Performance
- Added SM75 (Turing) support for FP6 kernel (#942)
- Implemented int8 dynamic quant + bsr support (#821)
- Added workaround to recover the perf for quantized vit in torch.compile (#926)
INT8 Mixed-Precision Training
On NVIDIA GPUs, INT8 Tensor Cores are approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has a very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. The weight is still stored in its original precision.
from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_
model = ...
# apply INT8 matmul to all 3 matmuls
quantize_(model, int8_mixed_precision_training())
# customize which matmul is left in original precision.
config = Int8MixedPrecisionTrainingConfig(
output=True,
grad_input=True,
grad_weight=False,
)
quantize_(model, int8_mixed_precision_training(config))
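The row-wise quantization mentioned above can be sketched in plain Python as follows. This assumes symmetric absmax scaling with one scale per row; the helper names are hypothetical and the actual torchao kernels may differ in detail.

```python
def quantize_rowwise_int8(matrix):
    """Quantize each row to INT8 with its own symmetric absmax scale.

    Assumption for this sketch: scale = max(|row|) / 127, one scale per row.
    """
    q_rows, scales = [], []
    for row in matrix:
        absmax = max(abs(x) for x in row)
        scale = absmax / 127 if absmax > 0 else 1.0
        # Round to the nearest integer and clamp to the INT8 range.
        q_rows.append([max(-128, min(127, round(x / scale))) for x in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_rowwise(q_rows, scales):
    """Map INT8 values back to floats using the per-row scales."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


# Two rows with very different magnitudes.
x = [[0.02, -0.01, 0.005], [50.0, -120.0, 127.0]]
q, s = quantize_rowwise_int8(x)
x_hat = dequantize_rowwise(q, s)
```

With a single scale for the whole tensor (here 127 / 127 = 1.0), every entry of the first row would round to 0. Giving each row its own scale keeps the small-magnitude row representable, which is the motivation for quantizing row-wise given the narrow INT8 range.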
End2end speed benchmark using benchmarks/quantized_training/pretrain_llama2.py
Model & GPU | bs x seq_len | Config | Tok/s | Peak mem (GB) |
---|---|---|---|---|
Llama2-7B, A100 | 8 x 2048 | BF16 (baseline) | ~4400 | 59.69 |
Llama2-7B, A100 | 8 x 2048 | INT8 mixed-precision | ~6100 (+39%) | 58.28 |
Llama2-1B, 4090 | 16 x 2048 | BF16 (baseline) | ~17,900 | 18.23 |
Llama2-1B, 4090 | 16 x 2048 | INT8 mixed-precision | ~30,700 (+72%) | 18.34 |
Docs
- Updated README with more current float8 speedup information (#816)
- Added tutorial for trainable tensor subclass (#908)
- Improved documentation for float8 unification and inference (#895, #896)
Devs
- Added compile tests to test suite (#906)
- Improved CI setup and build processes (#887)
- Added M1 wheel support (#822)
- Added more benchmarking and profiling tools (#1017)
- Renamed fpx to floatx (#877)
- Removed torchao_nightly package (#661)
- Added more lint fixes (#827)
- Added better subclass testing support (#839)
- Added CI to catch syntax errors (#861)
- Added tutorial on composing quantized subclass w/ Dtensor based TP (#785)
Security
No significant security updates in this release.
Untopiced
- Added basic SAM2 AutomaticMaskGeneration example server (#1039)
New Contributors
- @iseeyuan made their first contribution in #805
- @YihengBrianWu made their first contribution in #860
- @kshitij12345 made their first contribution in #863
- @ZainRizvi made their first contribution in #887
- @alexsamardzic made their first contribution in #899
- @vaishnavi17 made their first contribution in #911
- @tobiasvanderwerff made their first contribution in #931
- @kwen2501 made their first contribution in #937
- @y-sq made their first contribution in #912
- @jimexist made their first contribution in #969
- @danielpatrickhug made their first contribution in #914
- @ramreddymounica made their first contribution in #1007
- @yushangdi made their first contribution in h...
v0.5.0
Highlights
We are excited to announce the 0.5 release of torchao! This release adds support for memory efficient inference, float8 training and inference, int8 quantized training, HQQ, automatic mixed-precision quantization through bayesian optimization, sparse marlin, and integrations with HuggingFace, SGLang, and diffusers.
Memory Efficient Inference Support #738
We've added support for Llama 3.1 to the llama benchmarks in TorchAO and added new features and improvements as a proof of concept for memory efficient inference. These additions allow us to do 130k context length inference with Llama 3.1-8B with only 18.91 GB memory if we combine KV cache quantization, int4 weight-only quantization and a linear causal mask.
General savings depend on technique and context length as can be seen in the following graph:
Float8 Training #551
torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
With torch.compile on, current results show throughput speedups of up to 1.5x on 128 H100 GPU LLaMa 3 70B pretraining jobs (details)
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)
And for an end-to-end minimal training recipe of pretraining with float8, you can check out torchtitan.
Float8 Inference #740 #819
We have introduced two new quantization APIs for Float8 inference:
- Float8 Weight-Only Quantization: A new quant_api float8_weight_only() has been added to apply float8 weight-only symmetric per-channel quantization to linear layers.
- Float8 Dynamic Activation and Weight Quantization: A new quant_api float8_dynamic_activation_float8_weight() has been introduced to apply float8 dynamic symmetric quantization to both activations and weights of linear layers. PerTensor scaling is used by default. We have also added an option to do PerRow scaling of both activations and weights. By computing scales at a finer granularity, PerRow scaling can potentially reduce the overall quantization error and increase performance by reducing dynamic quantization overhead.
Example usage:
import torch
from torchao.quantization import quantize_, float8_weight_only, float8_dynamic_activation_float8_weight, PerRow
# Create a model
model = YourModel()
# Apply float8 weight-only quantization
quantize_(model, float8_weight_only())
# Apply float8 dynamic activation and weight quantization
quantize_(model, float8_dynamic_activation_float8_weight())
# Apply PerRow scaling to weight and activations
quantize_(linear_module, float8_dynamic_activation_float8_weight(granularity=PerRow()))
Notes:
- These new APIs are designed to work with PyTorch 2.5 and later versions.
- float8_dynamic_activation_float8_weight requires CUDA devices with compute capability 8.9 or higher for hardware acceleration.
Int8 quantized training #644 #748
@gau-nernst introduced 2 experimental works on training using INT8.
- INT8 quantized training (#644): weight is quantized to INT8 during the whole duration of training to save memory. Compute remains in high precision. To train the model effectively with only quantized weights, we use stochastic rounding for weight update. Right now, memory saving is not too competitive compared to compiled BF16 baseline.
- INT8 mixed-precision training (#748): weight is kept in the original high precision, but weight and activation are dynamically quantized to INT8 during training to utilize INT8 tensor cores. We observe up to 70% speedup for Llama2 pre-training on 4090, and 20% speedup for Llama3 pre-training on 8x A100 with FSDP2.
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_weight_only_quantized_training, int8_mixed_precision_training
model = YourModel()
# apply INT8 quantized training
quantize_(model, int8_weight_only_quantized_training())
# apply INT8 mixed-precision training
quantize_(model, int8_mixed_precision_training())
For more information and benchmark results, see README and the respective PR (#644 and #748)
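To see why stochastic rounding matters for INT8 quantized training, consider a weight stored as an integer that receives many small updates. The sketch below is a pure-Python illustration with a hypothetical `stochastic_round` helper and a made-up update size, not the torchao implementation.

```python
import math
import random


def stochastic_round(x: float, rng: random.Random) -> int:
    """Round x down or up at random so that the expected result equals x."""
    lower = math.floor(x)
    frac = x - lower  # in [0, 1): probability of rounding up
    return lower + (1 if rng.random() < frac else 0)


rng = random.Random(0)
weight_nearest = 10
weight_stochastic = 10
update = 0.1  # illustrative update, smaller than half of one integer step

for _ in range(1000):
    # Round-to-nearest discards the small update on every step.
    weight_nearest = round(weight_nearest + update)
    # Stochastic rounding applies it with probability of about 0.1 per step.
    weight_stochastic = stochastic_round(weight_stochastic + update, rng)
```

With round-to-nearest the weight never leaves its starting value, so training stalls. With stochastic rounding the weight drifts upward by about 0.1 per step on average, so the accumulated effect of the updates is preserved in expectation.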
HQQ Integration in torchao #605 #786
hqq is added to existing torchao APIs; it improves model accuracy and leverages the existing efficient kernels in torchao. We enabled hqq for the int4_weight_only API:
quantize_(model, int4_weight_only(group_size, use_hqq=True))
We also added this to the uintx api for accuracy experiments (current uintx kernels are slow):
quantize_(model, uintx_weight_only(torch.uint2, group_size, use_hqq=True))
Automatic Mixed-Precision Quantization through Bayesian Optimization #592, #694
We provided a Bayesian Optimization (BO) tool leveraging Ax to automatically search for a mixed-precision weight-only quantization configuration, i.e., the bit width and group size of intN_weight_only(bit_width, group_size) for each layer. It also includes a sensitivity analysis tool to calculate the layer-wise average Hessian trace and average Fisher information matrix trace, which is an optional step to customize and improve the BO search.
To optimize for model accuracy under a model size constraint (GB):
python BO_acc_modelsize.py --checkpoint=/tmp/Meta-Llama-3-8B --model_size_constraint=6.0
To optimize for inference throughput under a model perplexity constraint:
python BO_acc_throughput.py --checkpoint=/tmp/Meta-Llama-3-8B --ppl_constraint=7.5
For more detailed usage, please refer to this README. Compared to int8 uniform quantization on the Llama3-8B model, the mixed-precision quantization found by this tool reduces model size by 20.1% with a 2.8% perplexity reduction, and improves inference throughput by 15.1% with a 3.2% perplexity reduction.
Sparse Marlin #621, #733
@Diogo-V added support for sparse-marlin, a W4AFP16 2:4 sparse kernel, to TorchAO.
On Meta Llama 3, we observe a 25% tok/s increase (180 -> 226) compared to our existing int4-wo implementation.
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
---|---|---|---|---|---|
Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | 15.01 |
Llama-3-8B | int8dq | 8.61 | 64.75 | 9.24 | 7.52 |
Llama-3-8B | int8wo | 153.03 | 1150.80 | 10.42 | 7.52 |
Llama-3-8B | int4wo-64 | 180.80 | 763.33 | 6.88 | 4.22 |
Llama-3-8B | int4wo-64-sparse-marlin | 226.02 | 689.20 | 5.32 | 3.05 |
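For readers unfamiliar with the 2:4 pattern that sparse-marlin relies on, the sketch below zeroes two of every four consecutive weights. Selecting the two largest magnitudes is an assumption made for illustration (a common pruning heuristic), and `prune_2_of_4` is a hypothetical helper, not a torchao function.

```python
def prune_2_of_4(row):
    """Keep the 2 largest-magnitude values in each group of 4, zero the rest."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two entries with the largest absolute value.
        keep = sorted(range(4), key=lambda j: abs(group[j]), reverse=True)[:2]
        out.extend(v if j in keep else 0.0 for j, v in enumerate(group))
    return out


dense = [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.2, 1.5]
sparse = prune_2_of_4(dense)
```

Because exactly half of every group of four is zero, only the surviving values plus a small amount of position metadata need to be stored, which is consistent with the smaller model size in the int4wo-64-sparse-marlin row above.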
HuggingFace Integration
torchao is integrated into huggingface: https://huggingface.co/docs/transformers/main/en/quantization/torchao. You can now use int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight through TorchAoConfig in huggingface. Currently available in the huggingface main branch only.
SGLang Integration
torchao is also integrated into sglang (sgl-project/sglang#1341) for llama3 model, you can try out with:
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128
Supported configurations are ["int4wo-<group_size>", "int8wo", "int8dq", "fp8wo" (only available in torchao 0.5+)]
diffusers Integration
diffusers-torchao provides end-to-end inference and experimental training recipes to use torchao with diffusers in this repo. We demonstrate 53.88% speedup on Flux.1-Dev* and 27.33% speedup on CogVideoX-5b when comparing compiled quantized models against their standard bf16 counterparts.
BC Breaking
Add layout option to woq int4 api #670
# torchao 0.4.0
from torchao.quantization import quantize_, int4_weight_only
quantize_(my_model, int4_weight_only(inner_k_tiles=8))
# torchao 0.5.0
from torchao.quantization import quantize_, int4_weight_only
quant...
v0.4.0
Highlights
We are excited to announce the 0.4 release of torchao! This release adds support for KV cache quantization, quantization aware training (QAT), low bit optimizer support, composing quantization and sparsity, and more!
KV cache quantization (#532)
We've added support for KV cache quantization, showing a peak memory reduction from 19.7 -> 19.2 GB on Llama3-8B at an 8192 context length. We plan to investigate Llama3.1 next.
Quantization-Aware Training (QAT) (#383, #555)
We now support two QAT schemes for linear layers: Int8 per token dynamic activations + int4 per group weights, and int4 per group weights (using the efficient tinygemm int4 kernel after training). Users can access this feature by transforming their models before and after training using the appropriate quantizer, for example:
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)
# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)
Initial evaluation results indicate that QAT in torchao can recover up to 96% of quantized accuracy degradation on hellaswag and up to 68% of quantized perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the README and this blog post.
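The "fake quantize" operation described above can be illustrated with a small pure-Python sketch: the value is snapped to the int4 grid and immediately mapped back to a float, so the data type never changes. The function name, the fixed scale, and the symmetric [-8, 7] range are assumptions for illustration, not the torchao primitive.

```python
def fake_quantize(x: float, scale: float, qmin: int = -8, qmax: int = 7) -> float:
    """Simulate int4 quantization numerics while staying in floating point."""
    # Quantize: divide by the scale, round, and clamp to the int4 range.
    q = max(qmin, min(qmax, round(x / scale)))
    # Dequantize right away: the result is a float on the quantization grid.
    return q * scale


scale = 0.25  # illustrative fixed scale; real flows derive it from the data
outputs = [fake_quantize(v, scale) for v in (0.6, 5.0, -0.3)]
```

During QAT the model sees these rounded and clamped values in the forward pass, so it can adapt its weights to the quantization error before the convert step swaps in real quantized operations.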
Composing quantization and sparsity (#457, #473)
We've added support for composing int8 dynamic quantization with 2:4 sparsity, using the quantize_ API. We also added SAM benchmarks that show a 7% speedup over standalone sparsity / int8 dynamic quantization here.
from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
Community Contributions
low-bit optimizer support (#478, #463, #482, #484, #538)
@gau-nernst added implementations for 4-bit, 8-bit, and FP8 Adam with FSDP2/FSDP support. Our API is a drop-in replacement for torch.optim.Adam and can be used as follows:
from torchao.prototype.low_bit_optim import Adam8bit, Adam4bit, AdamFp8
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
model = ...
optim = Adam8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
For more information about low bit optimizer support please refer to our README.
Improvements to 4-bit quantization (#517, #552, #544, #479 )
@bdhirsh @jeromeku @yanbing-j @manuelcandales @larryliu0820 added torch.compile support for NF4 Tensor, custom CUDA int4 tinygemm unpacking ops, and several bugfixes to torchao
BC breaking
quantize has been renamed to quantize_ #467
# for torchao 0.4
from torchao.quantization import quantize_, int8_weight_only
quantize_(model, int8_weight_only())
# for torchao 0.3
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
apply_sparse_semi_structured has been deprecated in favor of sparsify_, which matches the quantize_ API #473
# for torchao 0.4
from torchao.sparsity import sparsify_, semi_sparse_weight
sparsify_(model, semi_sparse_weight())
# for torchao 0.3
from torchao.sparsity import apply_sparse_semi_structured
apply_sparse_semi_structured(model)
Deprecations
New Features
- Added kv_cache quantization #532
- Migrated float8_experimental to torchao.float8, enabling float8 training support #551 #529
- Added FP5 E2M2 #399
- Added 4-bit, 8-bit, and FP8 ADAM support #478 #463 #482
- Added FSDP2 support for low-bit optimizers #484
- [prototype] mixed-precision quantization and eval framework #531
- Added int4 weight-only QAT support #555, #383
- Added custom CUDA tinygemm unpacking ops #415
Improvements
- Composing quantization and sparsity now uses the unified AQT Layout #498
- Added default inductor config settings #423
- Better dtype and device handling for Int8DynActInt4WeightQuantizer and Int4WeightOnlyQuantizer #475 #479
- Enable model.to for int4/int8 weight only quantized models #486 #522
- Added more logging to TensorCoreTiledAQTLayout #520
- Added general fake_quantize_affine op with mask support #492 #500
- QAT now uses the shared fake_quantize_affine primitive #527
- Improve FSDP support for low-bit optimizers #538
- Custom op and inductor decomp registration now uses a decorator #434
- Updated torch version to no longer require unwrap_tensor_subclass #595
Bug fixes
- Fixed import for TORCH_VERSION_AFTER_* #433
- Fixed crash when PYTORCH_VERSION is not defined #455
- Added torch.compile support for NF4Tensor #544
- Added fbcode check to fix torchtune in Genie #480
- Fixed int4pack_mm error #517
- Fixed cuda device check #536
- Weight shuffling now runs on CPU for int4 quantization due to a MPS memory issue #552
- Scale and input now are the same dtype for int8 weight only quantization #534
- Fixed FP6-LLM API #595
Performance
- Added segment-anything-fast benchmarks for composed quantization + sparsity #457
- Updated low-bit Adam benchmark #481
Docs
- Updated README.md #583 #438 #445 #460
- Updated installation instructions #447 #459
- Added more docs for int4_weight_only API #469
- Added developer guide notebook #588
- Added optimized model serialization/deserialization doc #524 #525
- Added new float8 feature tracker #557
- Added static quantization tutorial for calibration-based techniques #487
Devs
- Fix numpy version in CI #537
- trymerge now uploads merge records to s3 #448
- Updated python version to 3.9 #488
- torchao no longer depends on torch #449
- benchmark_model now accepts args and kwargs and supports cpu and mps backends #586 #406
- Add git version suffix to package name #547
- Added validations to torchao #453 #454
- Parallel test support with pytest-xdist #518
- Quantizer now uses logging instead of print #472
Not user facing
- Refactored _replace_linear_8da4w #451
- Remove unused code from AQT implementation #476 #440 #441 #471
- Improved error message for lm_eval script #444
- Updated HF_TOKEN env variable #427
- Fixed typo in Quant-LLM in #450
- Add a test for map_location="cpu" in #497
- Removed sparse test collection warning #489
- Refactored layout imple...
v0.3.1
Highlights
We are excited to announce the 0.3 release of torchao! This release adds support for a new quantize API, MX format, FP6 dtype and bitpacking, 2:4 sparse accelerated training and benchmarking infra for llama2/llama3 models.
quantize API (#256)
We added a tensor subclass based quantization API; see the docs and README for usage details. It is planned to replace all existing quantization APIs in torchao for torch 2.4 and later.
Accelerated training with 2:4 sparsity (#184)
You can now accelerate training with 2:4 sparsity, using the runtime pruning + compression kernels written by xFormers. These kernels process a 4x4 sub-tile to be 2:4 sparse in both directions, to handle both the forward and backward pass when training. We see a 1.3x speedup for the MLP layers of ViT-L across a forward and backwards pass.
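To make the 2:4 pattern concrete, here is a minimal pure-Python sketch of magnitude-based 2:4 pruning along one direction only: in every group of four consecutive values, the two with the smallest magnitude are zeroed. This is an illustration under stated assumptions, not the xFormers kernel; prune_2_4 is a hypothetical name, and the real kernels additionally make each 4x4 tile sparse in both directions and compress the result.

```python
def prune_2_4(row):
    """Keep the two largest-magnitude values in each group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Rank positions by descending magnitude; ties go to the lower index.
        ranked = sorted(range(4), key=lambda j: (-abs(group[j]), j))
        keep = set(ranked[:2])
        pruned.extend(v if j in keep else 0.0 for j, v in enumerate(group))
    return pruned


example = prune_2_4([1.0, -3.0, 0.5, 2.0, 4.0, 0.1, -0.2, 0.3])
```

Every group of four in the result has exactly two non-zero entries, which is the structure that sparse tensor cores can exploit.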
MX support (#264)
We added prototype support for MX format for training and inference with a reference native PyTorch implementation of training and inference primitives for using MX accelerated matrix multiplications. The MX numerical formats are new low precision formats with recent acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
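The core idea behind MX formats is that a small block of values shares a single power-of-two scale while each element is stored in a narrow type. The sketch below is a simplified illustration of that idea only: it uses signed integer elements and picks the smallest power-of-two scale that avoids clipping. It is not the torchao implementation and does not reproduce the exact element encodings or block size from the OCP spec; the function names are hypothetical.

```python
import math


def mx_like_quantize(block, elem_bits=8):
    """Quantize a block to integers that share one power-of-two scale."""
    qmax = 2 ** (elem_bits - 1) - 1
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0] * len(block)
    # Smallest exponent e such that amax / 2**e still fits within qmax.
    shared_exp = math.ceil(math.log2(amax / qmax))
    scale = 2.0 ** shared_exp
    elems = [max(-qmax, min(qmax, round(v / scale))) for v in block]
    return shared_exp, elems


def mx_like_dequantize(shared_exp, elems):
    """Rebuild approximate values from the shared exponent and elements."""
    scale = 2.0 ** shared_exp
    return [e * scale for e in elems]


block = [0.5, -1.0, 0.25, 2.0]
shared_exp, elems = mx_like_quantize(block)
restored = mx_like_dequantize(shared_exp, elems)
```

Because the scale is a power of two, applying it only shifts exponents, which is what makes this family of formats attractive for hardware.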
Benchmarking (#276, #374)
We added a stable way to benchmark llama2 and llama3 models that includes perf/accuracy comparisons. See torchao/_models/llama/benchmarks.sh for more details.
🌟 💥 Community Contributions 🌟 💥
FP6 support (#279, #283, #358)
@gau-nernst Added support for FP6 dtype and mixed matmul FP16 x FP6 kernel with support for torch.compile. Benchmark results show a 2.3x speedup over BF16 baseline for meta-llama/Llama-2-7b-chat-hf
Bitpacking (#307, #282)
@vayuda, @melvinebenezer @CoffeeVampir3 @andreaskoepf Added support for packing/unpacking lower bit dtypes leveraging torch.compile to generate the kernels for this and added UInt2 and Bitnet tensor based on this approach.
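As a rough picture of what packing a low-bit dtype means, the sketch below stores four 2-bit values in each byte and unpacks them again. It is plain Python for illustration only; torchao generates its kernels through torch.compile, and the bit order used here (first value in the lowest bits) is an assumption rather than the library's layout. The function names are hypothetical.

```python
def pack_uint2(values):
    """Pack 2-bit integers (0..3) four to a byte, first value in the low bits."""
    if len(values) % 4 != 0:
        raise ValueError("number of values must be a multiple of 4")
    packed = bytearray()
    for start in range(0, len(values), 4):
        byte = 0
        for j, v in enumerate(values[start:start + 4]):
            if not 0 <= v <= 3:
                raise ValueError(f"value {v} does not fit in 2 bits")
            byte |= v << (2 * j)
        packed.append(byte)
    return bytes(packed)


def unpack_uint2(packed):
    """Inverse of pack_uint2: expand each byte into four 2-bit integers."""
    return [(byte >> (2 * j)) & 0b11 for byte in packed for j in range(4)]


original = [0, 1, 2, 3, 3, 3, 0, 0]
packed = pack_uint2(original)
roundtrip = unpack_uint2(packed)
```

Eight values occupy two bytes here, a 4x reduction compared with storing each value in its own uint8.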
FP8 split-gemm kernel #263
Added the kernel written by @AdnanHoque to torchao with speedups compared to the cuBLAS kernel for batch size <=16
BC Breaking
Deprecations
- Deprecate top level quantization APIs #344
1. int8 weight only quantization
apply_weight_only_int8_quant(model)
or change_linear_weights_to_int8_woqtensors(model)
-->
# for torch 2.4+
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(model)
2. int8 dynamic quantization
apply_dynamic_quant(model)
or change_linear_weights_to_int8_dqtensors(model)
-->
# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True
# for torch 2.4+
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())
# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(model)
3. int4 weight only quantization
change_linear_weights_to_int4_woqtensors(model)
-->
# for torch 2.4+
from torchao.quantization import quantize, int4_weight_only
quantize(model, int4_weight_only())
# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(model)
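For readers new to the terminology, the arithmetic behind int8 weight-only quantization (item 1 above) can be sketched without any library. The example assumes symmetric quantization with one scale per output row, which is a common choice but is stated here as an assumption, not as a description of torchao internals. All function names are hypothetical.

```python
def quantize_int8_per_row(weight):
    """Map each row of a float matrix to int8 values plus one float scale."""
    q_rows, scales = [], []
    for row in weight:
        amax = max(abs(v) for v in row)
        scale = amax / 127.0 if amax > 0 else 1.0
        q_rows.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def linear_int8_weight_only(x, q_rows, scales):
    """Compute y = W @ x with W kept as int8 and rescaled per output row."""
    return [
        scale * sum(q * xi for q, xi in zip(row, x))
        for row, scale in zip(q_rows, scales)
    ]


weight = [[0.1, -0.4, 0.2], [3.0, 1.0, -2.0]]
q_rows, scales = quantize_int8_per_row(weight)
y = linear_int8_weight_only([1.0, 1.0, 1.0], q_rows, scales)
```

Only the weights are stored in 8 bits; the activations stay in floating point, which is what distinguishes this scheme from the dynamic quantization in item 2.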
New Features
- Add quantize #256
- Add a prototype of MX format training and inference #264
- [FP6-LLM] Port splitK map from DeepSpeed #283
- Improve FP6-LLM 2+4bit weight splitting + user API #279
- Bitpacking #291
- training acceleration via runtime semi-structured sparsity #184
- Bitpackingv2 #307
- Add FP6-LLM doc and move FP6-LLM to prototype #358
- Added first bits of Uint2Tensor and BitnetTensor #282
Improvements
- Improve primitives for FP6 quant #248
- Extract eval code from GPTQ for more general usage #275
- Factor out the specific configurations to helper functions #286
- Add support for AQTLayout, PlainAQTLayout and TensorCoreTiledAQTLayout #278
- Graceful handling of cpp extensions #296
- Refactor int8 dynamic quantization with call to quantize #294
- [NF4][FSDP] return contiguous quantization_factor #298
- Refactor int4 and int8 weight only quantization to use quantize #301
#301 - Adding a quick way for users to test model eval for hf models #328
- Wrap torch.ops.quantized_decomposed to improve import errors #310
- [NF4Tensor] Switch to save for backward since are now a tensor input #323
- Refactor rest of tinygemm quant primitive ops #321
- Move some util functions from quantization.utils to torchao.utils #337
- Clean up FP6-LLM #304
- Move quant ops to utils.py #331
- FP6-LLM clean up (again) #339
- Improving hf_eval.py #342
- Generalize Model Size Code #364
- Minor upgrades to bit pack #347
- Factor out dispatch and layout registration table #360
- Add register_apply_tensor_subclass #366
- Refactor custom FPx cast #363
- Remove all dependencies except torch #369
- Enable a test for loading state_dict with tensor subclasses #389
- 073 scripts for benchmarks #372
- Add WOQ int8 test with Inductor Freeze #362
- Benchmarking updates for semi-structured sparse training #398
- add FSDP QLoRA test and revert failing PR #403
- Refactor the API for quant method argument for quantize function #400
- eval script fixes #414
Bug Fixes
- Fixed the HQQ import skip #262
- fixing autoquant bug #265
- Fix eval import after #275 #290
- Fixed f-string printing of NF4Tensors #297
- Check and fix dequantize_affine is idempotent #309
- Update old pretrained TorchVision API in ao tutorials (#313) #314
- Fix dimension issues for int4 weight only quant path #330
- Fix compile in hf_eval.py #341
- task_list to tasks in hf_eval #343
- fixing peak memory stats for benchmark #353
- Fix inductor config BC change #382
- fixing scripts #395
Performance
- FP8 splitgemm user defined triton kernel #263
- sparse benchmarking numbers #303
- Fix FP6-LLM benchmark #312
- Adding Llama to TorchAO #276
- Generalize Model Size Code #364
- eval script for llama #374
- 077 autoquant gpt fast #361
Docs
- add static folder for images + fix links #271
- Fix Readme and remove unused kernel #270
- Kernel docs #274
- Quantization Docstrings #273
- Add AffineQuantizedTensor based workflow doc and examples #277
- Add AUTOQUANT_CACHE docs for reusing the same quantization plan #329
- Update nightly build instructions #334
- add link to benchmarking script #355
- New README #392
- Minor README updates #401
- Add quantize to ...
v0.2.0
What's Changed
Highlights
Custom CPU/CUDA extension to ship CPU/CUDA binaries.
PyTorch core recently shipped a new custom op registration mechanism in torch.library. Its main benefit is that custom ops compose with as many PyTorch subsystems as possible, most notably by NOT graph breaking with torch.compile()
We've added documentation on how to register your own custom ops: https://github.com/pytorch/ao/tree/main/torchao/csrc
If you learn better by example, you can follow PR #135 to add your own custom ops to torchao.
Most notably, these instructions were leveraged by @gau-nernst to integrate some new custom ops for fp6 support #223
One key benefit of integrating your kernels in torchao directly is that, thanks to our manylinux GPU support, we can ensure that the CPU/CUDA kernels you've added will work on as many devices and CUDA versions as possible #176
A lot of prototype and community contributions
@jeromeku was our community champion merging support for
- GaLore, our first pretraining kernel, which allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
- DoRA, which has been shown to yield better fine-tuning accuracy than QLoRA. This is an area where the community can help us benchmark more thoroughly https://github.com/pytorch/ao/tree/main/torchao/prototype/dora
- Fused int4/fp16 quantized matmul, which is particularly useful for compute bound kernels, showing 4x speedups over tinygemm for larger batch sizes such as 512 https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq
@gau-nernst merged fp6 support showing up to 8x speedups over an fp16 baseline for small batch size inference #223
NF4 support for upcoming FSDP2
@weifengpy merged support for composing FSDP2 with NF4, which makes it easy to implement algorithms like QLoRA + FSDP without writing any CUDA or C++ code. This work also provides a blueprint for how to compose smaller dtypes with FSDP #150, most notably by implementing torch.chunk(). We hope the broader community uses this work to experiment more heavily at the intersection of distributed and quantization research, and that it inspires many more studies such as the ones done by Answer.ai https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html
BC breaking
Deprecations
New Features
- Match autoquant API with torch.compile (#109, #162, #175)
- [Prototype] 8da4w QAT (#138, #199, #198, #211, #154, #157, #229)
- [Prototype] GaLore (#95)
- [Prototype] DoRA (#216)
- [Prototype] HQQ (#153, #185)
- [Prototype] 2:4 sparse + int8 sparse subclass (#36)
- [Prototype] Unified quantization primitives (#159, #201, #193, #220, #227, #173, #210)
- [Prototype] Pruning primitives (#148, #194)
- [Prototype] AffineQuantizedTensor subclass (#214, #230, #243, #247, #251)
- [Prototype] Add Int4WeightOnlyQuantizer (#119)
- Custom CUDA extensions (#135, #186, #232)
- [Prototype] Add FP6 Linear (#223)
Improvements
- FSDP2 support for NF4Tensor (#118, #150, #207)
- Add save/load of int8 weight only quantized model (#122)
- Add int_scaled_mm on CPU (#121)
- Add cpu and gpu in int4wo and int4wo-gptq quantizer (#131)
- Add torch.export support to int8_dq, int8_wo, int4_wo subclasses (#146, #226, #213)
- Remove is_gpt_fast specialization from GPTQ (#172)
- Common benchmark and profile utils (#238)
Bug fixes
- Fix padding in GPTQ (#119, #120)
- Fix Int8DynActInt4WeightLinear module swap (#151)
- Fix NF4Tensor.to to use device kwarg (#158)
- Fix quantize_activation_per_token_absmax perf regression (#253)
Performance
Docs
- Update READMEs (#140, #142, #169, #155, #179, #187, #188, #200, #217, #245)
- Add https://pytorch.org/ao (#136, #145, #163, #164, #165, #168, #177, #195, #224)
CI
- Add A10G support in CI (#176)
- General CI improvements (#161, #171, #178, #180, #183, #107, #215, #244, #257, #235, #242)
- Add expecttest to requirements.txt (#225)
- Push button binary support (#241, #240, #250)
Not user facing
Security
Untopiced
New Contributors
- @Xia-Weiwen made their first contribution in #121
- @jeromeku made their first contribution in #95
- @weifengpy made their first contribution in #118
- @aakashapoorv made their first contribution in #179
- @UsingtcNower made their first contribution in #194
- @Jokeren made their first contribution in #217
- @gau-nernst made their first contribution in #223
- @janeyx99 made their first contribution in #245
- @huydhn made their first contribution in #250
- @lancerts made their first contribution in #238
Full Changelog: v0.2.0...v0.2.1
We were able to close about half of the tasks for 0.2.0; the rest will now spill over into upcoming releases. We will post a list for 0.3.0 next, which we aim to release at the end of May 2024. We want to follow a monthly release cadence until further notice.
TorchAO 0.1.0: First Release
Highlights
We’re excited to announce the release of TorchAO v0.1.0! TorchAO is a repository that hosts architecture optimization techniques such as quantization and sparsity, along with performance kernels for different backends such as CUDA and CPU. In this release, we added support for a few quantization techniques such as int4 weight only GPTQ quantization, nf4 dtype support for QLoRA, and sparsity features such as WandaSparsifier. We also added an autotuner that can tune triton integer matrix multiplication kernels on CUDA.
Note: TorchAO is currently in a pre-release state and under extensive development. The public APIs should not be considered stable. But we welcome you to try out our APIs and offerings and provide any feedback on your experience.
torchao 0.1.0 will be compatible with PyTorch 2.2.2 and 2.3.0, ExecuTorch 0.2.0 and TorchTune 0.1.0.
New Features
Quantization
- Added tensor subclass based quantization APIs: change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors and change_linear_weights_to_int4_woqtensors (#1)
- Added module based quantization APIs for int8 dynamic and weight only quantization: apply_weight_only_int8_quant and apply_dynamic_quant (#1)
- Added module swap version of int4 weight only quantization, Int4WeightOnlyQuantizer and Int4WeightOnlyGPTQQuantizer, used in TorchTune (#119, #116)
- Added int8 dynamic activation and int4 weight quantization, Int8DynActInt4WeightQuantizer and Int8DynActInt4WeightGPTQQuantizer, used in ExecuTorch (#74) (available in torch 2.3.0 and later)
Sparsity
- Added WandaSparsifier that prunes weights based on both weight magnitudes and input activation norms (#22)
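The Wanda criterion scores each weight by its magnitude multiplied by the L2 norm of the matching input feature over some calibration inputs, then drops the lowest-scoring weights within each output row. The sketch below illustrates that scoring rule in plain Python; it is not the WandaSparsifier implementation, and wanda_prune along with its arguments are hypothetical names.

```python
import math


def wanda_prune(weight, activations, sparsity=0.5):
    """Zero the lowest |w| * ||x||_2 entries in each row of weight.

    weight:      list of rows, shape (out_features, in_features)
    activations: calibration inputs, shape (num_samples, in_features)
    """
    n_in = len(weight[0])
    # L2 norm of every input feature across the calibration samples.
    norms = [
        math.sqrt(sum(sample[j] ** 2 for sample in activations))
        for j in range(n_in)
    ]
    n_prune = int(n_in * sparsity)
    pruned = []
    for row in weight:
        scores = [abs(w) * norms[j] for j, w in enumerate(row)]
        # Positions with the smallest scores are removed first.
        drop = set(sorted(range(n_in), key=lambda j: (scores[j], j))[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


weight = [[1.0, -0.2, 0.5, 0.05]]
calibration = [[0.1, 10.0, 1.0, 1.0], [0.1, 10.0, 1.0, 1.0]]
result = wanda_prune(weight, calibration)
```

In this example the largest weight (1.0) is pruned because its input feature is almost always near zero, while the much smaller weight -0.2 survives thanks to its large activations. Pure magnitude pruning would have made the opposite choice.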
Kernels
- Added autotuner for int mm Triton kernels (#41)
dtypes
Improvements
- Setup github workflow for regression testing (#50)
- Setup github workflow for torchao-nightly release (#54)
Documentation
- Added tutorials for quantizing vision transformer model (#60)
- Added tutorials for how to add an op for nf4 tensor (#54)
Notes
- We are still debugging the accuracy problem for Int8DynActInt4WeightGPTQQuantizer
- Save and load does not work well for tensor subclass based APIs yet
- We will consolidate tensor subclass and module swap based quantization APIs later
- uint4 tensor subclass is going to be merged into pytorch core in the future
- Quantization ops in quant_primitives.py will be deduplicated with similar quantize/dequantize ops in PyTorch later