Merge branch 'main' of https://github.com/intel/auto-round
wenhuach21 committed May 29, 2024
2 parents bda3da9 + 4db22e1 commit a4de240
Showing 10 changed files with 53 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -149,8 +149,8 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
##from auto_round.auto_quantizer import AutoHfQuantizer ## uncomment it for models with quantized lm-head

##from auto_round.export import AutoHfQuantizer ## uncomment it for models with quantized lm-head
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True)
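The hunk above only switches the optional import path: AutoHfQuantizer is now imported from auto_round.auto_quantizer rather than auto_round.export. A minimal loading sketch with that import uncommented (needed for checkpoints whose lm-head is quantized) follows; the prompt and tokenizer call are illustrative additions, everything else mirrors the README snippet.

```python
from auto_round.auto_quantizer import AutoHfQuantizer  # needed only for quantized lm-head checkpoints
from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"  # assumed export directory
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, use_fast=True)

prompt = "There is a girl who likes adventure,"  # illustrative prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```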
2 changes: 1 addition & 1 deletion auto_round/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .autoround import AutoRound, AutoAdamRound, AutoOPTRound
from .version import __version__
from .version import __version__
18 changes: 15 additions & 3 deletions auto_round/auto_quantizer.py
@@ -37,6 +37,7 @@
from packaging import version
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
import transformers
from transformers.quantizers import AutoQuantizationConfig, HfQuantizer
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod
@@ -159,7 +160,7 @@ def merge_quantization_configs(
if "auto-round" in quantization_config["quant_method"]:
quantization_config = AutoRoundConfig.from_dict(quantization_config)
else:
quantization_config = AutoQuantizationConfig.from_dict(quantization_config) # pylint: disable=E1101
quantization_config = AutoQuantizationConfig.from_dict(quantization_config) # pylint: disable=E1101

if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
# special case for GPTQ / AWQ config collision
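The branch a few lines up keys on the quant_method field of the checkpoint's quantization_config. A standalone illustration of that routing decision, with a made-up config dict (the fields below are assumptions, not taken from a real checkpoint):

```python
# Illustrative only: mirrors the dispatch in merge_quantization_configs above.
quantization_config = {"quant_method": "intel/auto-round", "bits": 4, "group_size": 128}  # assumed fields

if "auto-round" in quantization_config["quant_method"]:
    handler = "AutoRoundConfig.from_dict"           # auto-round checkpoints stay in this quantizer
else:
    handler = "AutoQuantizationConfig.from_dict"    # GPTQ, AWQ, and other methods fall back to transformers
print(handler)
```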
@@ -285,7 +286,7 @@ def convert_model(self, model: nn.Module):
model (`nn.Module`):
Model to be converted
"""
from .export_to_autoround import get_layer_names_in_block
from auto_round.utils import get_layer_names_in_block

layer_names = get_layer_names_in_block(model)
quantization_config = model.config.quantization_config
@@ -334,7 +335,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
data_type = config["data_type"]
if not (bits <= 8 and data_type == "int"):
continue
from .export_to_autoround import get_autogptq_backend_config
from auto_round.export.export_to_autoround.export_to_autoround import get_autogptq_backend_config

use_triton, disable_exllama, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
backend, bits
@@ -425,3 +426,14 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None):
@property
def is_serializable(self):
return True


import transformers

transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]]
if transformers_version[0] == 4 and transformers_version[1] < 38:
logger.error("Please upgrade transformers>=4.38.0 to support lm-head quantization")

transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
from transformers import AutoModelForCausalLM as AutoModelForCausalLM
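The lines appended above patch transformers at import time. A quick, illustrative way to confirm the effect after installing auto_round (assuming transformers >= 4.38 is available):

```python
import transformers
from auto_round.auto_quantizer import AutoHfQuantizer

# Importing auto_round.auto_quantizer runs the module-level patching shown above,
# so transformers' quantizer entry points resolve to auto_round's AutoHfQuantizer
# and AutoModelForCausalLM.from_pretrained can dispatch "auto-round" configs.
assert transformers.quantizers.auto.AutoHfQuantizer is AutoHfQuantizer
assert transformers.modeling_utils.AutoHfQuantizer is AutoHfQuantizer
```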
2 changes: 1 addition & 1 deletion auto_round/autoround.py
@@ -1373,7 +1373,7 @@ def save_quantized(self, output_dir=None, format="auto_gptq", inplace=True, **kw
)
return compressed_model

def get_layer_names_in_block(self):
def get_layer_names_in_block(self): ##TODO consolidate with utils
"""Retrieves the names of layers within each block of the model.
Returns:
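The added TODO marks this method for consolidation with the copy introduced in auto_round/utils.py later in this commit. For orientation, a rough sketch of how save_quantized is typically driven; only the save_quantized signature comes from the hunk header above, the constructor arguments and the quantize() call are assumptions based on the library's usual flow:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound  # re-exported in auto_round/__init__.py

model_name = "facebook/opt-125m"  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128)  # assumed arguments
autoround.quantize()                                             # assumed tuning entry point
autoround.save_quantized(output_dir="./tmp_autoround", format="auto_gptq", inplace=True)
```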
3 changes: 2 additions & 1 deletion auto_round/export/__init__.py
@@ -16,4 +16,5 @@
from .export_to_autogptq import save_quantized_as_autogptq
from .export_to_itrex import save_quantized_as_itrex, QuantConfig
from .export_to_autoround.export_to_autoround import save_quantized_as_autoround
from .export_to_autoround import AutoHfQuantizer


8 changes: 1 addition & 7 deletions auto_round/export/export_to_autoround/__init__.py
@@ -12,11 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import transformers
from .autoround_quantizer import AutoHfQuantizer

transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.quantizers.auto.AutoQuantizationConfig = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
from transformers import AutoModelForCausalLM as AutoRoundModelForCausalLM
from .export_to_autoround import save_quantized_as_autoround

24 changes: 1 addition & 23 deletions auto_round/export/export_to_autoround/export_to_autoround.py
@@ -22,31 +22,9 @@
import transformers

from auto_round.export.register import register_format
from auto_round.utils import get_block_names, get_module, logger, set_module
from auto_round.utils import get_layer_names_in_block, get_block_names, get_module, logger, set_module


def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transformers.modeling_utils.Conv1D]):
"""Retrieves the names of layers within each block of the model.
Returns:
list: A list of strings, where each string is the name of a layer
within a block of the model.
"""
for n, m in model.named_modules():
if isinstance(m, tuple(supported_types)):
m.tmp_name = n
layers_in_block = []
block_names = get_block_names(model)
for block_name in block_names:
block = get_module(model, block_name)
for n, m in block.named_modules():
if hasattr(m, "tmp_name"):
layers_in_block.append(m.tmp_name)
for n, m in model.named_modules():
if hasattr(m, "tmp_name"):
delattr(m, "tmp_name")
return layers_in_block


def check_neq_config(config, data_type, bits, group_size, sym):
res = []
25 changes: 24 additions & 1 deletion auto_round/utils.py
@@ -34,7 +34,7 @@
logger.addHandler(fh)

import importlib

import transformers

class LazyImport(object):
"""Lazy import python module till use."""
@@ -739,3 +739,26 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):
bs = 1

return False, seqlen, bs


def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transformers.modeling_utils.Conv1D]):
"""Retrieves the names of layers within each block of the model.
Returns:
list: A list of strings, where each string is the name of a layer
within a block of the model.
"""
for n, m in model.named_modules():
if isinstance(m, tuple(supported_types)):
m.tmp_name = n
layers_in_block = []
block_names = get_block_names(model)
for block_name in block_names:
block = get_module(model, block_name)
for n, m in block.named_modules():
if hasattr(m, "tmp_name"):
layers_in_block.append(m.tmp_name)
for n, m in model.named_modules():
if hasattr(m, "tmp_name"):
delattr(m, "tmp_name")
return layers_in_block
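With the helper relocated to auto_round.utils, a minimal usage sketch (the model choice is illustrative):

```python
from transformers import AutoModelForCausalLM
from auto_round.utils import get_layer_names_in_block

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any supported HF causal LM
layer_names = get_layer_names_in_block(model)
# Expected output: fully qualified names of the Linear / Conv1D layers inside the
# decoder blocks, e.g. "model.decoder.layers.0.self_attn.q_proj" for OPT-style models.
print(layer_names)
```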
2 changes: 1 addition & 1 deletion examples/language-modeling/eval_042/evaluation.py
@@ -585,7 +585,7 @@ def evaluate(
if hasattr(config, "quantization_config"):
quantization_config = config.quantization_config
if "quant_method" in quantization_config and "auto-round" in quantization_config["quant_method"]:
from auto_round.export.export_to_autoround import AutoHfQuantizer
from auto_round.auto_quantizer import AutoHfQuantizer

test_tasks = args.tasks
if isinstance(test_tasks, str):
8 changes: 6 additions & 2 deletions examples/language-modeling/main.py
@@ -86,7 +86,7 @@
help="enable_minmax_tuning is deprecated")

parser.add_argument("--deployment_device", default='fake', type=str,
help="targeted inference acceleration platform,The options are 'fake', 'cpu' and 'gpu'."
help="targeted inference acceleration platform,The options are 'fake', 'cpu', 'gpu' and 'xpu'."
"default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.")

parser.add_argument("--scale_dtype", default='fp16',
@@ -151,7 +151,6 @@ def get_library_version(library_name):
except subprocess.CalledProcessError:
return "Library not found"


res = get_library_version("lm-eval")
if res == "0.3.0":
use_eval_legacy = True
@@ -291,6 +290,10 @@ def get_library_version(library_name):
break
if args.quant_lm_head:
weight_config[lm_head_layer_name] = {"data_type": "int"}
transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]]
if transformers_version[0] == 4 and transformers_version[1] < 38:
error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization."
raise EnvironmentError(error_message)

if args.quant_lm_head and not args.disable_low_gpu_mem_usage:
print(f"warning, disable_low_gpu_mem_usage is strongly recommended if the whole model could be loaded to "
@@ -340,3 +343,4 @@ def get_library_version(library_name):
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name)
