Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
Signed-off-by: yintong-lu <yintong.lu@intel.com>
  • Loading branch information
yintong-lu committed Jun 5, 2024
1 parent 4513346 commit 38e195d
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions auto_round/calib_dataset.py
Original file line number Diff line number Diff line change
@@ -279,7 +279,6 @@ def concat_dataset_element(dataset):
split = None
do_concat = False
if ":" in name:
# name, split = name.split(":")
split_list = name.split(":")
name, split_list = name.split(":")[0], name.split(":")[1:]
do_concat = 'concat' in split_list
@@ -297,13 +296,8 @@ def concat_dataset_element(dataset):
dataset_name=name,
)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
import copy
tensor_compare = copy.deepcopy(dataset[0]['input_ids'])
logger.info(f"lyt_debug dataset1:{len(dataset)} {dataset[0].keys()} {[type(dataset[0][key]) for key in dataset[0].keys()]}, {dataset[0]['input_ids'].shape}, {dataset[0]['input_ids'][0]}")
if do_concat:
dataset = concat_dataset_element(dataset)
logger.info(f"lyt_debug dataset2:{len(dataset)} {dataset[0].keys()} {[type(dataset[0][key]) for key in dataset[0].keys()]} {len(dataset[0]['input_ids'])}, {dataset[0]['input_ids'][0]}")
logger.info(f"lyt_debug compare: {torch.all(tensor_compare == torch.tensor(dataset[0]['input_ids']))}")
dataset = dataset.filter(filter_func)

datasets.append(dataset)
Expand Down

0 comments on commit 38e195d

Please sign in to comment.