Commit fd87065

Merge branch 'main' into update_1227

wenhuach21 authored Dec 30, 2024
2 parents 664de12 + 42a6eb9
Showing 4 changed files with 51 additions and 13 deletions.
42 changes: 30 additions & 12 deletions auto_round/mllm/autoround_mllm.py
@@ -14,6 +14,7 @@

 from typing import Optional, Union
 from tqdm import tqdm
+from copy import deepcopy

 import torch

@@ -24,28 +25,45 @@
     to_dtype,
     get_multimodal_block_names,
     find_matching_blocks,
-    extract_block_names_to_str
+    extract_block_names_to_str,
+    clear_memory
 )
 from ..autoround import AutoRound
 from .template import get_template, Template
+from auto_round.special_model_handler import SUPPORT_ONLY_TEXT_MODELS
 from .mllm_dataset import get_mllm_dataloader
 from ..low_cpu_mem.utils import get_layers_before_block


-def _only_text_test(model, tokenizer, device):
+def _only_text_test(model, tokenizer, device, model_type):
     """Test if the model whether can use text-only datasets."""
+
+    if model_type in SUPPORT_ONLY_TEXT_MODELS:  # save time
+        return True
+
+    new_tokenizer = deepcopy(tokenizer)
+    device = detect_device(device)
+    text = ["only text", "test"]
+    new_tokenizer.padding_side = 'left'
+    if new_tokenizer.pad_token is None:
+        new_tokenizer.pad_token = new_tokenizer.eos_token
+    inputs = new_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
     try:
-        device = detect_device(device)
-        text = ["only text", "test"]
-        tokenizer.padding_side = 'left'
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        if device.split(':')[0] != model.device.type:
-            model = model.to(device)
-        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
+        inputs = inputs.to(device)
+        model = model.to(device)
         model(**inputs)
         return True
-    except:
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            model = model.to("cpu")
+            inputs = inputs.to("cpu")
+            try:
+                model(**inputs)
+            except:
+                return False
+        return False
+    except Exception as e:
         return False

@@ -165,7 +183,7 @@ def __init__(
         if isinstance(dataset, str):
             if quant_nontext_module or \
                     (dataset in CALIB_DATASETS.keys() and not \
-                        _only_text_test(model, tokenizer, device)):
+                        _only_text_test(model, tokenizer, device, self.template.model_type)):
                 if quant_nontext_module:
                     logger.warning(f"Text only dataset cannot be used for calibrating non-text modules,"
                                    "switching to liuhaotian/llava_conv_58k")
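
Note: the rewritten probe above tokenizes with a deepcopy of the tokenizer (so the padding_side and pad_token tweaks do not leak into the caller's tokenizer), short-circuits for model types listed in SUPPORT_ONLY_TEXT_MODELS, and retries the forward pass once on CPU when it hits CUDA out of memory. Below is a minimal, self-contained sketch of the same probe pattern, not auto_round's code: the helper name and the tiny public checkpoint are illustrative assumptions, and a successful CPU retry is treated as a pass.

# Sketch of the text-only probe pattern; names and the stand-in model are
# illustrative assumptions, not auto_round's API.
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer

def can_run_text_only(model, tokenizer, device="cpu"):
    """Probe whether a plain text batch runs through the model's forward."""
    tokenizer = deepcopy(tokenizer)  # keep padding tweaks out of the caller's tokenizer
    text = ["only text", "test"]
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    try:
        model = model.to(device)
        model(**inputs.to(device))
        return True
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            # Fall back to CPU once; here a successful retry counts as a pass.
            model = model.to("cpu")
            inputs = inputs.to("cpu")
            try:
                model(**inputs)
                return True
            except Exception:
                return False
        return False
    except Exception:
        return False

model_id = "sshleifer/tiny-gpt2"  # hypothetical stand-in for a real MLLM
print(can_run_text_only(AutoModelForCausalLM.from_pretrained(model_id),
                        AutoTokenizer.from_pretrained(model_id)))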
3 changes: 2 additions & 1 deletion auto_round/script/mllm.py
@@ -268,7 +268,8 @@ def tune(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:
+            use_auto_mapping = True ##for 70B model on single card, use auto will cause some layer offload to cpu
     elif args.device == "auto":
         use_auto_mapping == True
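
Note: this hunk enables auto device mapping only when more than one CUDA device is listed, since device_map="auto" on a single card can offload layers of a ~70B model to CPU. The pre-existing context line use_auto_mapping == True on the elif branch is a comparison, not an assignment, so it appears to leave the flag unchanged. A hedged sketch of the presumably intended selection logic follows; the helper name and argument are illustrative assumptions, not auto_round's API.

def parse_use_auto_mapping(device_arg: str) -> bool:
    """Enable device_map="auto" style mapping only where it helps."""
    devices = device_arg.replace(" ", "").split(",")
    if all(d.isdigit() for d in devices):
        # Explicit GPU list: auto-map only across several cards; on a single
        # card, auto mapping can offload layers of a large model to CPU.
        return len(devices) > 1
    return device_arg == "auto"

assert parse_use_auto_mapping("0") is False
assert parse_use_auto_mapping("0,1") is True
assert parse_use_auto_mapping("auto") is True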
10 changes: 10 additions & 0 deletions auto_round/special_model_handler.py
@@ -18,6 +18,16 @@
 mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama")  # Limitations on batch_size
 skippable_cache_keys = ("past_key_value",)

+SUPPORT_ONLY_TEXT_MODELS = [
+    "phi3_v",
+    "cogvlm2",
+    "llava",
+    "qwen2_vl",
+    "deepseek_vl_v2",
+    "chatglm",
+    "idefics3"
+]
+
 def to_device(input, device=torch.device("cpu")):
     """Moves input data to the specified device.
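
Note: SUPPORT_ONLY_TEXT_MODELS holds Hugging Face model_type strings, which is what _only_text_test receives via self.template.model_type; membership skips the forward-pass probe entirely. An illustrative membership check follows; fetching model_type from AutoConfig is an assumption about how a caller might obtain it, not auto_round's code path.

from transformers import AutoConfig

SUPPORT_ONLY_TEXT_MODELS = [
    "phi3_v", "cogvlm2", "llava", "qwen2_vl",
    "deepseek_vl_v2", "chatglm", "idefics3",
]

config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
print(config.model_type in SUPPORT_ONLY_TEXT_MODELS)  # "qwen2_vl" -> True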
9 changes: 9 additions & 0 deletions test_cuda/test_support_vlms.py
@@ -261,6 +261,15 @@ def test_cogvlm(self):
         response = response.split("<|end_of_text|>")[0]
         print(response)
         shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_72b(self):
+        model_path = "/data5/models/Qwen2-VL-72B-Instruct/"
+        res = os.system(
+            f"cd .. && {self.python_path} -m auto_round --mllm "
+            f"--model {model_path} --iter 1 --nsamples 1 --bs 1 --output_dir {self.save_dir} --device {self.device}"
+        )
+        self.assertFalse(res > 0 or res == -1, msg="qwen2-72b tuning fail")
+        shutil.rmtree(self.save_dir, ignore_errors=True)

 if __name__ == "__main__":
     unittest.main()
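
Note: the new test drives the auto_round CLI end to end, and the hard-coded /data5/models/... path ties it to the test machine. A standalone version of the same invocation follows; the public checkpoint id and output directory are illustrative stand-ins for the local 72B path.

import subprocess
import sys

cmd = [
    sys.executable, "-m", "auto_round", "--mllm",
    "--model", "Qwen/Qwen2-VL-2B-Instruct",
    "--iter", "1", "--nsamples", "1", "--bs", "1",
    "--output_dir", "./tmp_autoround", "--device", "0",
]
res = subprocess.run(cmd)
print("tuning ok" if res.returncode == 0 else f"tuning failed ({res.returncode})")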
