Skip to content

Commit

Permalink
Refactor HuggingFace model initialization to include base model name and update tokenizer logic (#190)
Browse files: browse the repository at this point in the history
  • Loading branch information
jardinetsouffleton authored Dec 20, 2024
1 parent 64c8bc9 commit 0e060fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,14 @@ class HuggingFaceURLChatModel(HFBaseChatModel):
def __init__(
self,
model_name: str,
base_model_name: str,
model_url: str,
token: Optional[str] = None,
temperature: Optional[int] = 1e-1,
max_new_tokens: Optional[int] = 512,
n_retry_server: Optional[int] = 4,
):
super().__init__(model_name, n_retry_server)
super().__init__(model_name, base_model_name, n_retry_server)
if temperature < 1e-3:
logging.warning("Models might behave weirdly when temperature is too low.")
self.temperature = temperature
Expand Down
9 changes: 6 additions & 3 deletions src/agentlab/llm/huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ class HFBaseChatModel(AbstractChatModel):
description="The number of times to retry the server if it fails to respond",
)

def __init__(self, model_name, n_retry_server):
def __init__(self, model_name, base_model_name, n_retry_server):
super().__init__()
self.n_retry_server = n_retry_server

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if base_model_name is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
if isinstance(self.tokenizer, GPT2TokenizerFast):
logging.warning(
f"No chat template is defined for {model_name}. Resolving to the hard-coded templates."
f"No chat template is defined for {base_model_name}. Resolving to the hard-coded templates."
)
self.tokenizer = None
self.prompt_template = get_prompt_template(model_name)
Expand Down

0 comments on commit 0e060fc

Please sign in to comment.