Skip to content

Commit

Permalink
fix: pass a deep copy of the tokenizer to _only_text_test; add --bs 1 to the 72B test command
Browse files: browse the repository at this point in the history
Signed-off-by: n1ck-guo <heng.guo@intel.com>
  • Loading branch information
n1ck-guo committed Dec 27, 2024
1 parent 1562f39 commit c296777
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions auto_round/mllm/autoround_mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Optional, Union
from tqdm import tqdm
from copy import deepcopy

import torch

Expand Down Expand Up @@ -42,13 +43,11 @@ def _only_text_test(model, tokenizer, device, model_type):

device = detect_device(device)
text = ["only text", "test"]
ori_padding_size = tokenizer.padding_side
tokenizer.padding_side = 'left'
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
tokenizer.padding_size = ori_padding_size

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

try:
inputs = inputs.to(device)
model = model.to(device)
Expand Down Expand Up @@ -183,7 +182,7 @@ def __init__(
if isinstance(dataset, str):
if quant_nontext_module or \
(dataset in CALIB_DATASETS.keys() and not \
_only_text_test(model, tokenizer, device, self.template.model_type)):
_only_text_test(model, deepcopy(tokenizer), device, self.template.model_type)):
if quant_nontext_module:
logger.warning(f"Text only dataset cannot be used for calibrating non-text modules,"
"switching to liuhaotian/llava_conv_58k")
Expand Down
2 changes: 1 addition & 1 deletion test_cuda/test_support_vlms.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_72b(self):
model_path = "/data5/models/Qwen2-VL-72B-Instruct/"
res = os.system(
f"cd .. && {self.python_path} -m auto_round --mllm "
f"--model {model_path} --iter 1 --nsamples 1 --output_dir {self.save_dir} --device {self.device}"
f"--model {model_path} --iter 1 --nsamples 1 --bs 1 --output_dir {self.save_dir} --device {self.device}"
)
self.assertFalse(res > 0 or res == -1, msg="qwen2-72b tuning fail")
shutil.rmtree(quantized_model_path, ignore_errors=True)
Expand Down

0 comments on commit c296777

Please sign in to comment.