Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable autoquant for CPU userbenchmark #2547

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
choices=["int8dynamic", "int8weightonly", "int4weightonly", "auto_quant"],
help="Apply quantization to the model before running it",
)
parser.add_argument(
Expand Down Expand Up @@ -183,25 +183,36 @@ def apply_torchdynamo_args(

if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
)
if model.device == "cuda":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep original behavior in here, suggest to change as

if model.device == "cpu":
    xxx
else:
    original path

from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
)

torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
elif model.device == "cpu" and model.test == "eval":
if args.quantization == "auto_quant":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add else statement to warining that the other quant modes are still not support for cpu device

module, example_inputs = model.get_module()
with torch.no_grad():
module=torchao.autoquant(torch.compile(module, mode='max-autotune'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need compile the module before autoquant? and why we need to hard code the compile mode as max-autotune? Can we follow similar style with original part?

if isinstance(example_inputs, dict):
module(**example_inputs)
else:
module(*example_inputs)
model.set_module(module)

if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down