diff --git a/download.py b/download.py
index 74efda9..3c04630 100755
--- a/download.py
+++ b/download.py
@@ -3,7 +3,6 @@
 import os
 import sys
 import nltk
-from config import TRUST_REMOTE_CODE
 from transformers import (
     AutoModel,
     AutoTokenizer,
@@ -20,6 +19,7 @@
 nltk_dir = "./nltk_data"
 model_name = os.getenv("MODEL_NAME", None)
 force_automodel = os.getenv("FORCE_AUTOMODEL", False)
+trust_remote_code = os.getenv("TRUST_REMOTE_CODE", False)
 if not model_name:
     print("Fatal: MODEL_NAME is required")
     print(
@@ -47,11 +47,13 @@
 )
 
 
-def download_onnx_model(model_name: str, model_dir: str):
+def download_onnx_model(
+    model_name: str, model_dir: str, trust_remote_code: bool = False
+):
     # Download model and tokenizer
     onnx_path = Path(model_dir)
     ort_model = ORTModelForFeatureExtraction.from_pretrained(
-        model_name, from_transformers=True
+        model_name, from_transformers=True, trust_remote_code=trust_remote_code
     )
     # Save model
     ort_model.save_pretrained(onnx_path)
@@ -92,10 +94,13 @@ def quantization_config(onnx_cpu_arch: str):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.save_pretrained(onnx_path)
 
+
 def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False):
-    print(f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})")
+    print(
+        f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})"
+    )
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
-    model_type = config.to_dict()['model_type']
+    model_type = config.to_dict()["model_type"]
     if (
         model_type is not None and model_type == "t5"
     ) or force_automodel:
@@ -111,12 +116,20 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa
             klass_architecture = getattr(mod, config.architectures[0])
             model = klass_architecture.from_pretrained(model_name)
         except AttributeError:
-            print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel")
-            model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
+            print(
+                f"{config.architectures[0]} not found in transformers, fallback to AutoModel"
+            )
+            model = AutoModel.from_pretrained(
+                model_name, trust_remote_code=trust_remote_code
+            )
     else:
-        model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
+        model = AutoModel.from_pretrained(
+            model_name, trust_remote_code=trust_remote_code
+        )
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, trust_remote_code=trust_remote_code
+    )
     model.save_pretrained(model_dir)
     tokenizer.save_pretrained(model_dir)
 
@@ -126,6 +139,6 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa
 if onnx_runtime == "true":
-    download_onnx_model(model_name, model_dir)
+    download_onnx_model(model_name, model_dir, trust_remote_code)
 else:
-    download_model(model_name, model_dir, TRUST_REMOTE_CODE)
+    download_model(model_name, model_dir, trust_remote_code)
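
Note on the new TRUST_REMOTE_CODE handling: os.getenv() returns the raw string whenever the variable is set, so trust_remote_code will be truthy for any non-empty value, including "false" and "0". If strict boolean semantics are wanted, a parser along the lines of the sketch below would do it; env_flag is a hypothetical helper, not part of this diff.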
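
    # Minimal sketch of explicit boolean parsing for the env var.
    # env_flag is a hypothetical helper, not present in download.py.
    import os

    def env_flag(name: str, default: bool = False) -> bool:
        """Interpret an environment variable as a boolean flag."""
        value = os.getenv(name)
        if value is None:
            return default
        # Treat only these spellings as true; everything else is false.
        return value.strip().lower() in ("1", "true", "t", "yes", "y")

    trust_remote_code = env_flag("TRUST_REMOTE_CODE")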
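
With explicit parsing, setting TRUST_REMOTE_CODE=false in the environment resolves to False instead of silently enabling the flag; under the diff as written, any non-empty setting enables it. The same string-vs-bool caveat applies to the pre-existing force_automodel = os.getenv("FORCE_AUTOMODEL", False) line, which follows the identical pattern.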