diff --git a/src/providers/anthropic.py b/src/providers/anthropic.py
index 5bacde58..b5c8d165 100644
--- a/src/providers/anthropic.py
+++ b/src/providers/anthropic.py
@@ -2,7 +2,8 @@
 from .model import ModelProvider
-from anthropic import AsyncAnthropic, Anthropic
+from anthropic import Anthropic as AnthropicModel
+from anthropic import AsyncAnthropic
 from typing import Optional
 
 
 class Anthropic(ModelProvider):
 
@@ -22,7 +23,7 @@ def __init__(self, model_name: str = "claude", api_key: str = None):
         self.api_key = api_key or os.getenv('ANTHROPIC_API_KEY')
         self.model = AsyncAnthropic(api_key=self.api_key)
 
-        self.enc = Anthropic().get_tokenizer()
+        self.tokenizer = AnthropicModel().get_tokenizer()
 
         # Generate the prompt structure for the Anthropic model
         # Replace the following file with the appropriate prompt structure
@@ -43,8 +44,8 @@ def generate_prompt(self, context: str, retrieval_question: str) -> str | list[d
                                 context=context)
 
     def encode_text_to_tokens(self, text: str) -> list[int]:
-        return self.enc.encode(text).ids
+        return self.tokenizer.encode(text).ids
 
     def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
         # Assuming you have a different decoder for Anthropic
-        return self.enc.decode(tokens[:context_length])
\ No newline at end of file
+        return self.tokenizer.decode(tokens[:context_length])
\ No newline at end of file
diff --git a/src/providers/openai.py b/src/providers/openai.py
index 494dcdbe..82cd9026 100644
--- a/src/providers/openai.py
+++ b/src/providers/openai.py
@@ -20,7 +20,7 @@ def __init__(self, model_name: str = "gpt-3.5-turbo-0125", api_key: str = None):
         self.api_key = api_key or os.getenv('OPENAI_API_KEY')
         self.model = AsyncOpenAI(api_key=self.api_key)
 
-        self.enc = tiktoken.encoding_for_model(self.model_name)
+        self.tokenizer = tiktoken.encoding_for_model(self.model_name)
 
     async def evaluate_model(self, prompt: str) -> str:
         response = await self.model.chat.completions.create(
@@ -46,7 +46,7 @@ def generate_prompt(self, context: str, retrieval_question: str) -> str | list[d
         }]
 
     def encode_text_to_tokens(self, text: str) -> list[int]:
-        return self.enc.encode(text)
+        return self.tokenizer.encode(text)
 
     def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
-        return self.enc.decode(tokens[:context_length])
\ No newline at end of file
+        return self.tokenizer.decode(tokens[:context_length])
\ No newline at end of file