Skip to content

Commit

Permalink
Add support for TRUST_REMOTE_CODE in download ONNX model method
Browse files — browse the repository at this point in the history
  • Loading branch information
antas-marcin committed Oct 4, 2024
1 parent c405a7e commit d7d8312
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import sys
import nltk
from config import TRUST_REMOTE_CODE
from transformers import (
AutoModel,
AutoTokenizer,
Expand All @@ -20,6 +19,7 @@
# Directory where NLTK data files are downloaded/stored.
nltk_dir = "./nltk_data"
# Name of the Hugging Face model to download; required (checked just below).
model_name = os.getenv("MODEL_NAME", None)
# Env vars arrive as strings: os.getenv returns the raw string when set, so
# FORCE_AUTOMODEL=false / TRUST_REMOTE_CODE=0 would otherwise be truthy.
# Parse them into real booleans, accepting the common "1"/"true" spellings.
force_automodel = os.getenv("FORCE_AUTOMODEL", "").strip().lower() in ("1", "true")
trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "").strip().lower() in ("1", "true")
if not model_name:
print("Fatal: MODEL_NAME is required")
print(
Expand Down Expand Up @@ -47,11 +47,13 @@
)


def download_onnx_model(model_name: str, model_dir: str):
def download_onnx_model(
model_name: str, model_dir: str, trust_remote_code: bool = False
):
# Download model and tokenizer
onnx_path = Path(model_dir)
ort_model = ORTModelForFeatureExtraction.from_pretrained(
model_name, from_transformers=True
model_name, from_transformers=True, trust_remote_code=trust_remote_code
)
# Save model
ort_model.save_pretrained(onnx_path)
Expand Down Expand Up @@ -92,10 +94,13 @@ def quantization_config(onnx_cpu_arch: str):
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(onnx_path)


def download_model(model_name: str, model_dir: str, trust_remote_code: bool = False):
print(f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})")
print(
f"Downloading model {model_name} from huggingface model hub ({trust_remote_code=})"
)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model_type = config.to_dict()['model_type']
model_type = config.to_dict()["model_type"]

if (
model_type is not None and model_type == "t5"
Expand All @@ -111,12 +116,20 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa
klass_architecture = getattr(mod, config.architectures[0])
model = klass_architecture.from_pretrained(model_name)
except AttributeError:
print(f"{config.architectures[0]} not found in transformers, fallback to AutoModel")
model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
print(
f"{config.architectures[0]} not found in transformers, fallback to AutoModel"
)
model = AutoModel.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
else:
model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModel.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)

model.save_pretrained(model_dir)
tokenizer.save_pretrained(model_dir)
Expand All @@ -126,6 +139,6 @@ def download_model(model_name: str, model_dir: str, trust_remote_code: bool = Fa


# Dispatch on the requested runtime: export to ONNX via optimum, or download
# the plain transformers model/tokenizer. onnx_runtime is an env-derived
# string flag, hence the string comparison. Both paths now receive the
# trust_remote_code flag so custom-code models can be fetched either way.
if onnx_runtime == "true":
    download_onnx_model(model_name, model_dir, trust_remote_code)
else:
    download_model(model_name, model_dir, trust_remote_code)

0 comments on commit d7d8312

Please sign in to comment.