diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index afc6d158..7392e666 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -406,13 +406,14 @@ class HuggingFaceURLChatModel(HFBaseChatModel):
     def __init__(
         self,
         model_name: str,
+        base_model_name: str,
         model_url: str,
         token: Optional[str] = None,
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
     ):
-        super().__init__(model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 364221b5..32f12082 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -40,14 +40,17 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server):
         super().__init__()
         self.n_retry_server = n_retry_server
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if base_model_name is None:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
         if isinstance(self.tokenizer, GPT2TokenizerFast):
             logging.warning(
-                f"No chat template is defined for {model_name}. Resolving to the hard-coded templates."
+                f"No chat template is defined for {base_model_name}. Resolving to the hard-coded templates."
             )
             self.tokenizer = None
             self.prompt_template = get_prompt_template(model_name)
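
For context, a minimal usage sketch of the new parameter: when a fine-tuned checkpoint is served without its own chat template, the tokenizer is loaded from the base model instead. The model identifiers and endpoint URL below are hypothetical placeholders, not part of this PR.

from agentlab.llm.chat_api import HuggingFaceURLChatModel

# Hypothetical names and URL, for illustration only.
chat_model = HuggingFaceURLChatModel(
    model_name="my-org/llama-2-7b-finetuned",         # checkpoint served at model_url
    base_model_name="meta-llama/Llama-2-7b-chat-hf",  # source of the tokenizer / chat template
    model_url="http://localhost:8080",
    temperature=0.1,
    max_new_tokens=512,
)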