From 7946757e7dd70ffa61916306d98ef1ce4dac8fcf Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Mon, 30 Dec 2024 10:16:10 +0800
Subject: [PATCH] enhancement

---
 auto_round/autoround.py     |   6 +-
 auto_round/calib_dataset.py | 181 ++++++++++++++++++++++++++----------
 auto_round/utils.py         |  15 +++
 docs/step_by_step.md        |  25 ++---
 test/test_calib_dataset.py  |  11 +++
 5 files changed, 174 insertions(+), 64 deletions(-)

diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index bd9ef47b..68ff8b1f 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -965,7 +965,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
         unwrapper_layer(self.model, wrapper_linear, layer_name, best_params)
         mv_module_from_gpu(layer, self.low_cpu_mem_usage)
         dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
-        logger.info(dump_info)
+        logger.debug(dump_info)

     def register_act_max_hook(self, model):
         def get_act_max_hook(module, input, output):
@@ -1045,7 +1045,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
                 f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
                 f"layers in the block"
             )
-            logger.info(dump_info)
+            logger.debug(dump_info)
             return output, output

         if self.lr_scheduler is None:
@@ -1136,7 +1136,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
             f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
-        logger.info(dump_info)
+        logger.debug(dump_info)
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py
index 62051cf4..dfb1176b 100644
--- a/auto_round/calib_dataset.py
+++ b/auto_round/calib_dataset.py
@@ -45,40 +45,47 @@ def register(dataset):
         return register


-def get_tokenizer_function(tokenizer, seqlen, apply_template=False):
+def apply_chat_template_to_samples(samples, tokenizer, seqlen):
+    from jinja2 import Template
+    chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
+        else tokenizer.default_chat_template
+    template = Template(chat_template)
+    rendered_messages = []
+    for text in samples:
+        message = [{"role": "system", "content": "You are a helpful assistant."},
+                   {"role": "user", "content": text}]
+        rendered_message = template.render(messages=message, add_generation_prompt=True, \
+                                           bos_token=tokenizer.bos_token)
+        rendered_messages.append(rendered_message)
+    example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
+    return example
+
+
+def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False):
     """Returns a default tokenizer function.

     Args:
         tokenizer: The tokenizer to be used for tokenization.
         seqlen: The maximum sequence length.
-        apply_template: Whether to apply chat template in tokenization.
+        apply_chat_template: Whether to apply chat template in tokenization.

     Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
     seqlen to the "text" field of examples.
""" - def default_tokenizer_function(examples, apply_template=apply_template): - if not apply_template: + def default_tokenizer_function(examples, apply_chat_template=apply_chat_template): + if not apply_chat_template: example = tokenizer(examples["text"], truncation=True, max_length=seqlen) else: - from jinja2 import Template # pylint: disable=E0401 - chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \ - else tokenizer.default_chat_template - template = Template(chat_template) - rendered_messages = [] - for text in examples["text"]: - message = [{"role": "user", "content": text}] - rendered_message = template.render(messages=message, add_generation_prompt=True, \ - bos_token=tokenizer.bos_token) - rendered_messages.append(rendered_message) - example = tokenizer(rendered_messages, truncation=True, max_length=seqlen) + apply_chat_templte_to_samples(examples["text"], tokenizer, seqlen) return example return default_tokenizer_function @register_dataset("NeelNanda/pile-10k") -def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, apply_template=False): +def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, + apply_chat_template=False): """Returns a dataloader for the specified dataset and split. Args: @@ -87,7 +94,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. @@ -95,7 +102,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split from datasets import load_dataset split = "train" - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) calib_dataset = load_dataset(dataset_name, split=split) calib_dataset = calib_dataset.shuffle(seed=seed) @@ -105,7 +112,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split @register_dataset("BAAI/CCI3-HQ") -def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_template=False): +def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False): """Returns a dataloader for the specified dataset and split. Args: @@ -114,14 +121,64 @@ def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=No data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. 
""" from datasets import load_dataset - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) + + calib_dataset = load_dataset(dataset_name, split='train', streaming=True) + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("codeparrot/github-code-clean") +def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42, + apply_chat_template=False): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + + def get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=False): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of + seqlen to the "text" field of examples. + """ + + def default_tokenizer_function(examples, apply_chat_template=apply_chat_template): + if not apply_chat_template: + example = tokenizer(examples["code"], truncation=True, max_length=seqlen) + else: + example = apply_chat_templte_to_samples(examples["code"], tokenizer, seqlen) + return example + + return default_tokenizer_function + + from datasets import load_dataset + + tokenizer_function = get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) calib_dataset = load_dataset(dataset_name, split='train', streaming=True) calib_dataset = calib_dataset.take(10000) @@ -138,7 +195,7 @@ def get_new_chinese_title_dataset( dataset_name="madao33/new-title-chinese", split=None, seed=42, - apply_template=False + apply_chat_template=False ): """Returns a dataloader for the specified dataset and split. @@ -148,39 +205,29 @@ def get_new_chinese_title_dataset( data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. """ - def get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template): + def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template): """Returns a default tokenizer function. Args: tokenizer: The tokenizer to be used for tokenization. seqlen: The maximum sequence length. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. 
        Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
        seqlen to the "text" field of examples.
        """

-        def default_tokenizer_function(examples, apply_template=apply_template):
-            if not apply_template:
+        def default_tokenizer_function(examples, apply_chat_template=apply_chat_template):
+            if not apply_chat_template:
                 example = tokenizer(examples["content"], truncation=True, max_length=seqlen)
             else:
-                from jinja2 import Template
-                chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
-                    else tokenizer.default_chat_template
-                template = Template(chat_template)
-                rendered_messages = []
-                for text in examples["text"]:
-                    message = [{"role": "user", "content": text}]
-                    rendered_message = template.render(messages=message, add_generation_prompt=True, \
-                        bos_token=tokenizer.bos_token)
-                    rendered_messages.append(rendered_message)
-                example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
+                example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen)
             return example

         return default_tokenizer_function
@@ -188,7 +235,7 @@ def default_tokenizer_function(examples, apply_template=apply_template):
     split = "train"
     from datasets import load_dataset

-    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
+    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

     calib_dataset = load_dataset(dataset_name, split=split)
     calib_dataset = calib_dataset.shuffle(seed=seed)
@@ -198,7 +245,7 @@


 @register_dataset("mbpp")
-def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_template=False):
+def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_chat_template=False):
     """Returns a dataloader for the specified dataset and split.

     Args:
@@ -207,14 +254,14 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42
         data_name: The name of the dataset.
         split: The data split to be used (e.g., "train", "test").
         seed: The random seed for shuffling the dataset.
-        apply_template: Whether to apply chat template in tokenization.
+        apply_chat_template: Whether to apply chat template in tokenization.

     Returns:
         A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
     """
     from datasets import load_dataset

-    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
+    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

     samples = []
     splits = split
@@ -237,7 +284,7 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42


 @register_dataset("local")
-def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_template=False):
+def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_chat_template=False):
     """Returns a dataloader for a custom dataset and split.

     We allow the input of a json or text file containing a processed text sample each line.

     Args:
         tokenizer: The tokenizer to be used for tokenization.
         seqlen: The maximum sequence length.
         data_name: The name or path of the dataset, which is a json or jsonl file.
         split: The data split to be used (e.g., "train", "test").
         seed: The random seed for shuffling the dataset.
-        apply_template: Whether to apply chat template in tokenization.
+        apply_chat_template: Whether to apply chat template in tokenization.

     Returns: A dataloader for a custom dataset and split, using the provided tokenizer and sequence length.
     """
-    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
+    tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

     def load_local_data(data_path):
         if data_path.endswith(".json"):
@@ -300,6 +347,17 @@ def load_local_data(data_path):


 def get_dataset_len(dataset):
+    """Calculates the length of a dataset.
+
+    Args:
+        dataset: The dataset object, which can be any iterable or collection.
+
+    Returns:
+        int: The length of the dataset.
+
+    Notes:
+        If the dataset does not support `len()`, iterates through it to count the number of elements.
+    """
     try:
         dataset_len = len(dataset)
         return dataset_len
@@ -311,6 +369,19 @@


 def select(dataset, indices):
+    """Selects specific elements from a dataset based on given indices.
+
+    Args:
+        dataset: The dataset object to iterate over.
+        indices: An iterable of integers specifying the indices to select.
+
+    Yields:
+        Elements of the dataset corresponding to the specified indices.
+
+    Notes:
+        Stops iterating once the highest index in `indices` has been processed
+        to optimize performance.
+    """
     indices = set(indices)
     for idx, sample in enumerate(dataset):
         if idx in indices:
@@ -320,6 +391,16 @@


 def select_dataset(dataset, indices):
+    """Selects elements from a dataset using its native `select` method, if available.
+
+    Args:
+        dataset: The dataset object, which may have a `select` method.
+        indices: An iterable of integers specifying the indices to select.
+
+    Returns:
+        A subset of the dataset, either using the dataset's `select` method or the
+        `select` function defined above as a fallback.
+    """
     try:
         return dataset.select(indices)
     except:
@@ -346,7 +427,7 @@ def get_dataloader(
         seed (int, optional): The random seed for reproducibility. Defaults to 42.
         bs (int, optional): The batch size. Defaults to 4.
         nsamples (int, optional): The total number of samples to include. Defaults to 512.
-        apply_template: Whether to apply chat template in tokenization.
+        apply_chat_template: Whether to apply chat template in tokenization.

     Returns:
         DataLoader: The DataLoader for the calibrated dataset.
@@ -412,7 +493,7 @@ def concat_dataset_element(dataset): for name in dataset_names: split = None do_concat = False - apply_template = False + apply_chat_template = False if ":" in name: split_list = name.split(":") name, split_list = name.split(":")[0], name.split(":")[1:] @@ -424,8 +505,8 @@ def concat_dataset_element(dataset): data_lens[name] = int(values[0]) if key == "concat": do_concat = False if (len(values) > 0 and values[0].lower() == 'false') else True - if key == "apply_template": - apply_template = False if (len(values) > 0 and values[0].lower() == 'false') else True + if key == "apply_chat_template": + apply_chat_template = False if (len(values) > 0 and values[0].lower() == 'false') else True if is_local_path(name): get_dataset = CALIB_DATASETS.get("local") else: @@ -443,7 +524,7 @@ def concat_dataset_element(dataset): seed=seed, split=split, dataset_name=name, - apply_template=apply_template, + apply_chat_template=apply_chat_template, ) if not isinstance(dataset, IterableDataset): dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) diff --git a/auto_round/utils.py b/auto_round/utils.py index 4db95899..94e73b28 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -1119,6 +1119,21 @@ def get_fp_layer_names(model, fp_layers): def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=None): + """Checks if a model is compatible with the AutoAWQ GEMM kernel. + + Args: + model: The model object to evaluate, typically a PyTorch model. + bits (int): The number of bits for quantization (must be 4 for compatibility). + group_size (int): The group size for quantization. + sym (bool): Whether symmetric quantization is used (not utilized in the current function logic). + layer_configs (dict, optional): A dictionary mapping layer names to configurations, where each + configuration can specify a custom number of bits for the layer. + + Returns: + tuple: A tuple containing: + - bool: `True` if the model is compatible, `False` otherwise. + - str: An error message describing why the model is incompatible, or an empty string if compatible. + """ if bits != 4: return False, f"AutoAWQ GEMM kernel only supports 4 bits" for n, m in model.named_modules(): diff --git a/docs/step_by_step.md b/docs/step_by_step.md index e66a1ffb..1d8fbd87 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -45,16 +45,19 @@ See more about loading [huggingface dataset](https://huggingface.co/docs/dataset tokens.append(token) return tokens ~~~ - -We support combination of different datasets and parametrization of calibration datasets by using "--dataset ./tmp.json: -concat,NeelNanda/pile-10k:split=train+val:num=256,mbpp:concat=True:num=128:apply_template". Both local calibration file -and huggingface dataset are supported. Through parametrization, users could specify splits of a dataset by setting " -split=split1+split2". A concatenation option could enable users to merge calibration samples, a process commonly used to -enhance calibration reliability. An 'apply_template' option would enable users to apply chat_template to calibration -data before tokenization and is widely used by instruct-models in generation. Please note that samples shorter than -args.seqlen will be dropped when concatenation option is not enabled. -Please use ',' to split datasets, ':' to split parameters of a dataset and '+' to add values for one targeted parameter. 
-
+  **Dataset combination**: We support combining different datasets and parameterizing the calibration dataset via
+  "--dataset ./tmp.json:concat,NeelNanda/pile-10k:split=train+val:num=256,mbpp:concat=True:num=128:apply_chat_template".
+  Both local calibration files and Hugging Face datasets are supported. Through parameterization, users can specify
+  dataset splits by setting "split=split1+split2".
+
+  **Samples concatenation**: The concatenation option merges calibration samples, a process commonly used to enhance
+  calibration reliability, e.g. '--dataset NeelNanda/pile-10k:concat=True'. Please note that samples shorter than
+  args.seqlen will be dropped when the concatenation option is not enabled.
+
+  **Apply chat template**: '--dataset NeelNanda/pile-10k:apply_chat_template' applies the tokenizer's chat_template to
+  the calibration data before tokenization, which is widely used by instruct models in generation.
+
+  Please use ',' to separate datasets, ':' to separate parameters of a dataset, and '+' to add values for one parameter.
+
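+  For example, the following invocation (the model is illustrative; any supported model works) mixes a local file with
+  two Hub datasets using the syntax above:
+
+  ~~~bash
+  auto-round --model facebook/opt-125m \
+      --dataset "./tmp.json,NeelNanda/pile-10k:num=256:apply_chat_template,mbpp:concat=True:num=128"
+  ~~~
+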
@@ -128,7 +131,7 @@ Please use ',' to split datasets, ':' to split parameters of a dataset and '+' t - To leverage auto-gptq marlin kernel, you need to install auto-gptq from source and export the model without sharding. ```bash - auto-round --model facebook/opt-125m --sym --bits 4 --group_size 128 --format "gptq:marlin" + auto-round --model facebook/opt-125m --sym --bits 4 --group_size 128 --format "auto_gptq:marlin" ``` - **Utilize the AdamW Optimizer:** diff --git a/test/test_calib_dataset.py b/test/test_calib_dataset.py index 7c5f3ac3..291d535c 100644 --- a/test/test_calib_dataset.py +++ b/test/test_calib_dataset.py @@ -70,6 +70,17 @@ def test_jsonl(self): ) autoround.quantize() + def test_apply_chat_template(self): + model_name = "Qwen/Qwen2.5-0.5B-Instruct" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + dataset = "NeelNanda/pile-10k:apply_chat_template" + bits, group_size, sym = 4, 128, True + autoround = AutoRound( + model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=2, seqlen=128, dataset=dataset + ) + autoround.quantize() + def test_combine_dataset(self): dataset = "NeelNanda/pile-10k" + "," + "madao33/new-title-chinese" + "," + "mbpp" bits, group_size, sym = 4, 128, True
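
Note: the chat-template path added in this patch can be exercised on its own. A minimal sketch of what `apply_chat_template_to_samples` does for a single sample, mirroring the code in `auto_round/calib_dataset.py` (the model name, prompt, and max_length here are illustrative, not part of the patch):

~~~python
# Render one calibration sample through the tokenizer's chat template,
# then tokenize the rendered prompt, as the new calibration path does.
from jinja2 import Template
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat model
chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
    else tokenizer.default_chat_template

message = [{"role": "system", "content": "You are a helpful assistant."},
           {"role": "user", "content": "Some calibration text."}]
rendered = Template(chat_template).render(messages=message, add_generation_prompt=True,
                                          bos_token=tokenizer.bos_token)
encoded = tokenizer(rendered, truncation=True, max_length=128)

print(rendered)                   # the prompt wrapped in the model's chat markup
print(encoded["input_ids"][:10])  # first few token ids fed to calibration
~~~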