diff --git a/download.py b/download.py index 52ec42f..3c04630 100755 --- a/download.py +++ b/download.py @@ -19,6 +19,7 @@ nltk_dir = "./nltk_data" model_name = os.getenv("MODEL_NAME", None) force_automodel = os.getenv("FORCE_AUTOMODEL", False) +trust_remote_code = os.getenv("TRUST_REMOTE_CODE", False) if not model_name: print("Fatal: MODEL_NAME is required") print( @@ -46,11 +47,13 @@ ) -def download_onnx_model(model_name: str, model_dir: str): +def download_onnx_model( + model_name: str, model_dir: str, trust_remote_code: bool = False +): # Download model and tokenizer onnx_path = Path(model_dir) ort_model = ORTModelForFeatureExtraction.from_pretrained( - model_name, from_transformers=True + model_name, from_transformers=True, trust_remote_code=trust_remote_code ) # Save model ort_model.save_pretrained(onnx_path) @@ -92,9 +95,11 @@ def quantization_config(onnx_cpu_arch: str): tokenizer.save_pretrained(onnx_path) -def download_model(model_name: str, model_dir: str): - print(f"Downloading model {model_name} from huggingface model hub") - config = AutoConfig.from_pretrained(model_name) +def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False): + print( + f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})" + ) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) model_type = config.to_dict()["model_type"] if ( @@ -114,11 +119,17 @@ def download_model(model_name: str, model_dir: str): print( f"{config.architectures[0]} not found in transformers, fallback to AutoModel" ) - model = AutoModel.from_pretrained(model_name) + model = AutoModel.from_pretrained( + model_name, trust_remote_code=trust_remote_code + ) else: - model = AutoModel.from_pretrained(model_name) + model = AutoModel.from_pretrained( + model_name, trust_remote_code=trust_remote_code + ) - tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained( + model_name, trust_remote_code=trust_remote_code + ) model.save_pretrained(model_dir) tokenizer.save_pretrained(model_dir) @@ -128,6 +139,6 @@ def download_model(model_name: str, model_dir: str): if onnx_runtime == "true": - download_onnx_model(model_name, model_dir) + download_onnx_model(model_name, model_dir, trust_remote_code) else: - download_model(model_name, model_dir) + download_model(model_name, model_dir, trust_remote_code)