Commit 7946757: enhancement
wenhuach21 committed Dec 30, 2024 (1 parent: 0f8486b)
Showing 5 changed files with 174 additions and 64 deletions.
6 changes: 3 additions & 3 deletions auto_round/autoround.py
@@ -965,7 +965,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
unwrapper_layer(self.model, wrapper_linear, layer_name, best_params)
mv_module_from_gpu(layer, self.low_cpu_mem_usage)
dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
logger.info(dump_info)
logger.debug(dump_info)

def register_act_max_hook(self, model):
def get_act_max_hook(module, input, output):
@@ -1045,7 +1045,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block"
)
logger.info(dump_info)
logger.debug(dump_info)
return output, output

if self.lr_scheduler is None:
@@ -1136,7 +1136,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
)
logger.info(dump_info)
logger.debug(dump_info)
if len(unquantized_layer_names) != 0:
logger.info(f"{unquantized_layer_names} have not been quantized")
with torch.no_grad():
181 changes: 131 additions & 50 deletions auto_round/calib_dataset.py
@@ -45,40 +45,47 @@ def register(dataset):
return register


def get_tokenizer_function(tokenizer, seqlen, apply_template=False):
def apply_chat_templte_to_samples(samples, tokenizer, seqlen):
from jinja2 import Template
chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
else tokenizer.default_chat_template
template = Template(chat_template)
rendered_messages = []
for text in samples:
message = [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": text}]
rendered_message = template.render(messages=message, add_generation_prompt=True, \
bos_token=tokenizer.bos_token)
rendered_messages.append(rendered_message)
example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
return example
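For reference, a minimal usage sketch of the new helper (not part of the commit; the model name and sample texts are illustrative, and a tokenizer with a populated chat_template is assumed):

from transformers import AutoTokenizer
from auto_round.calib_dataset import apply_chat_templte_to_samples

# Illustrative instruct model; any tokenizer that ships a chat template should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
samples = ["Explain weight-only quantization.", "What is a calibration dataset?"]

# Each sample is rendered through the chat template with a fixed system prompt,
# then the rendered strings are tokenized with truncation at seqlen.
encoded = apply_chat_templte_to_samples(samples, tokenizer, seqlen=512)
print(len(encoded["input_ids"]), len(encoded["input_ids"][0]))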


def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False):
"""Returns a default tokenizer function.
Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
seqlen to the "text" field of examples.
"""

def default_tokenizer_function(examples, apply_template=apply_template):
if not apply_template:
def default_tokenizer_function(examples, apply_chat_template=apply_chat_template):
if not apply_chat_template:
example = tokenizer(examples["text"], truncation=True, max_length=seqlen)
else:
from jinja2 import Template # pylint: disable=E0401
chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
else tokenizer.default_chat_template
template = Template(chat_template)
rendered_messages = []
for text in examples["text"]:
message = [{"role": "user", "content": text}]
rendered_message = template.render(messages=message, add_generation_prompt=True, \
bos_token=tokenizer.bos_token)
rendered_messages.append(rendered_message)
example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
example = apply_chat_templte_to_samples(examples["text"], tokenizer, seqlen)
return example

return default_tokenizer_function


@register_dataset("NeelNanda/pile-10k")
def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, apply_template=False):
def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42,
apply_chat_template=False):
"""Returns a dataloader for the specified dataset and split.
Args:
@@ -87,15 +94,15 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
data_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""
from datasets import load_dataset

split = "train"
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

calib_dataset = load_dataset(dataset_name, split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
@@ -105,7 +112,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split


@register_dataset("BAAI/CCI3-HQ")
def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_template=False):
def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False):
"""Returns a dataloader for the specified dataset and split.
Args:
@@ -114,14 +121,64 @@ def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=No
data_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""
from datasets import load_dataset

tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

calib_dataset = load_dataset(dataset_name, split='train', streaming=True)
calib_dataset = calib_dataset.take(10000)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)

return calib_dataset
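The streaming loaders above share one pattern: stream the train split, cap it with take(10000), shuffle, and tokenize lazily through a batched map. A minimal standalone sketch of that pattern, with an illustrative dataset and a stand-in processing function (neither is taken from the commit):

from datasets import load_dataset

def count_chars(batch):
    # Stand-in for the tokenizer function returned by get_tokenizer_function.
    return {"n_chars": [len(t) for t in batch["text"]]}

stream = load_dataset("NeelNanda/pile-10k", split="train", streaming=True)
stream = stream.take(1000)        # cap the number of streamed rows
stream = stream.shuffle(seed=42)  # buffered shuffle over the stream
stream = stream.map(count_chars, batched=True)

for i, row in enumerate(stream):  # rows are fetched lazily during iteration
    if i >= 3:
        break
    print(row["n_chars"])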


@register_dataset("codeparrot/github-code-clean")
def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42,
apply_chat_template=False):
"""Returns a dataloader for the specified dataset and split.
Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
data_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""

def get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=False):
"""Returns a default tokenizer function.
Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
apply_chat_template: Whether to apply chat template in tokenization.
Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of
seqlen to the "code" field of examples.
"""

def default_tokenizer_function(examples, apply_chat_template=apply_chat_template):
if not apply_chat_template:
example = tokenizer(examples["code"], truncation=True, max_length=seqlen)
else:
example = apply_chat_templte_to_samples(examples["code"], tokenizer, seqlen)
return example

return default_tokenizer_function

from datasets import load_dataset

tokenizer_function = get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

calib_dataset = load_dataset(dataset_name, split='train', streaming=True)
calib_dataset = calib_dataset.take(10000)
@@ -138,7 +195,7 @@ def get_new_chinese_title_dataset(
dataset_name="madao33/new-title-chinese",
split=None,
seed=42,
apply_template=False
apply_chat_template=False
):
"""Returns a dataloader for the specified dataset and split.
@@ -148,47 +205,37 @@ def get_new_chinese_title_dataset(
data_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""

def get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template):
def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template):
"""Returns a default tokenizer function.
Args:
tokenizer: The tokenizer to be used for tokenization.
seqlen: The maximum sequence length.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length
of seqlen to the "content" field of examples.
"""

def default_tokenizer_function(examples, apply_template=apply_template):
if not apply_template:
def default_tokenizer_function(examples, apply_chat_template=apply_chat_template):
if not apply_chat_template:
example = tokenizer(examples["content"], truncation=True, max_length=seqlen)
else:
from jinja2 import Template
chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \
else tokenizer.default_chat_template
template = Template(chat_template)
rendered_messages = []
for text in examples["text"]:
message = [{"role": "user", "content": text}]
rendered_message = template.render(messages=message, add_generation_prompt=True, \
bos_token=tokenizer.bos_token)
rendered_messages.append(rendered_message)
example = tokenizer(rendered_messages, truncation=True, max_length=seqlen)
example = apply_chat_templte_to_samples(examples["content"], tokenizer, seqlen)
return example

return default_tokenizer_function

split = "train"
from datasets import load_dataset

tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

calib_dataset = load_dataset(dataset_name, split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
@@ -198,7 +245,7 @@ def default_tokenizer_function(examples, apply_template=apply_template):


@register_dataset("mbpp")
def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_template=False):
def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_chat_template=False):
"""Returns a dataloader for the specified dataset and split.
Args:
@@ -207,14 +254,14 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42
data_name: The name of the dataset.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
"""
from datasets import load_dataset

tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

samples = []
splits = split
@@ -237,7 +284,7 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42


@register_dataset("local")
def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_template=False):
def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_chat_template=False):
"""Returns a dataloader for a custom dataset and split.
We allow the input of a json or text file containing one processed text sample per line.
@@ -247,12 +294,12 @@ def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None,
data_name: The name or path of the dataset, which is a json or jsonl file.
split: The data split to be used (e.g., "train", "test").
seed: The random seed for shuffling the dataset.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
A dataloader for a custom dataset and split, using the provided tokenizer and sequence length.
"""
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template)
tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template)

def load_local_data(data_path):
if data_path.endswith(".json"):
@@ -300,6 +347,17 @@ def load_local_data(data_path):


def get_dataset_len(dataset):
"""Calculates the length of a dataset.
Args:
dataset: The dataset object, which can be any iterable or collection.
Returns:
int: The length of the dataset.
Notes:
If the dataset does not support `len()`, the length is determined by iterating over its elements.
"""
try:
dataset_len = len(dataset)
return dataset_len
@@ -311,6 +369,19 @@ def get_dataset_len(dataset):


def select(dataset, indices):
"""Selects specific elements from a dataset based on given indices.
Args:
dataset: The dataset object to iterate over.
indices: An iterable of integers specifying the indices to select.
Yields:
Elements of the dataset corresponding to the specified indices.
Notes:
Stops iterating once the highest index in `indices` has been processed
to optimize performance.
"""
indices = set(indices)
for idx, sample in enumerate(dataset):
if idx in indices:
Expand All @@ -320,6 +391,16 @@ def select(dataset, indices):


def select_dataset(dataset, indices):
"""Selects elements from a dataset using its native `select` method, if available.
Args:
dataset: The dataset object, which may have a `select` method.
indices: An iterable of integers specifying the indices to select.
Returns:
A subset of the dataset, either using the dataset's `select` method or the
`select` function defined above as a fallback.
"""
try:
return dataset.select(indices)
except:
@@ -346,7 +427,7 @@ def get_dataloader(
seed (int, optional): The random seed for reproducibility. Defaults to 42.
bs (int, optional): The batch size. Defaults to 4.
nsamples (int, optional): The total number of samples to include. Defaults to 512.
apply_template: Whether to apply chat template in tokenization.
apply_chat_template: Whether to apply chat template in tokenization.
Returns:
DataLoader: The DataLoader for the calibrated dataset.
@@ -412,7 +493,7 @@ def concat_dataset_element(dataset):
for name in dataset_names:
split = None
do_concat = False
apply_template = False
apply_chat_template = False
if ":" in name:
split_list = name.split(":")
name, split_list = name.split(":")[0], name.split(":")[1:]
@@ -424,8 +505,8 @@ def concat_dataset_element(dataset):
data_lens[name] = int(values[0])
if key == "concat":
do_concat = False if (len(values) > 0 and values[0].lower() == 'false') else True
if key == "apply_template":
apply_template = False if (len(values) > 0 and values[0].lower() == 'false') else True
if key == "apply_chat_template":
apply_chat_template = False if (len(values) > 0 and values[0].lower() == 'false') else True
if is_local_path(name):
get_dataset = CALIB_DATASETS.get("local")
else:
@@ -443,7 +524,7 @@ def concat_dataset_element(dataset):
seed=seed,
split=split,
dataset_name=name,
apply_template=apply_template,
apply_chat_template=apply_chat_template,
)
if not isinstance(dataset, IterableDataset):
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
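Taken together, the renamed flag is requested through the ':'-separated dataset spec parsed above. A hedged end-to-end sketch (not from the commit; the model name is illustrative, and the get_dataloader arguments beyond those visible in this diff are assumed):

from transformers import AutoTokenizer
from auto_round.calib_dataset import get_dataloader

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model

# "apply_chat_template" in the spec string is picked up by the parser shown above
# and switches the tokenizer function to chat-template rendering.
dataloader = get_dataloader(
    tokenizer,
    seqlen=2048,
    dataset_name="NeelNanda/pile-10k:apply_chat_template",
    seed=42,
    bs=4,
    nsamples=128,
)
for batch in dataloader:
    # Inspect one collated batch (expected to carry input_ids and attention_mask tensors).
    print(batch["input_ids"].shape)
    break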