Skip to content

Commit

Permalink
Merge pull request #16 from LazaroHurtado/bug_fixes
Browse files Browse the repository at this point in the history
using class alias for anthropic provider
  • Loading branch information
kedarchandrayan authored Mar 6, 2024
2 parents f2eb3de + 110a68b commit 2fb5f04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from .model import ModelProvider

from anthropic import AsyncAnthropic, Anthropic
from anthropic import Anthropic as AnthropicModel
from anthropic import AsyncAnthropic
from typing import Optional

class Anthropic(ModelProvider):
Expand All @@ -22,7 +23,7 @@ def __init__(self, model_name: str = "claude", api_key: str = None):
self.api_key = api_key or os.getenv('ANTHROPIC_API_KEY')

self.model = AsyncAnthropic(api_key=self.api_key)
self.enc = Anthropic().get_tokenizer()
self.tokenizer = AnthropicModel().get_tokenizer()

# Generate the prompt structure for the Anthropic model
# Replace the following file with the appropriate prompt structure
Expand All @@ -43,8 +44,8 @@ def generate_prompt(self, context: str, retrieval_question: str) -> str | list[d
context=context)

def encode_text_to_tokens(self, text: str) -> list[int]:
    """Tokenize *text* with the Anthropic tokenizer and return its token ids.

    The diff residue contained both the pre- and post-rename return lines;
    only the post-change behavior (the renamed ``self.tokenizer`` attribute)
    is kept, matching the commit's intent.
    """
    # self.tokenizer is set in __init__ via AnthropicModel().get_tokenizer();
    # its encode() returns an Encoding object whose .ids holds the int ids.
    return self.tokenizer.encode(text).ids

def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
    """Decode *tokens* back to text via the Anthropic tokenizer.

    Only the first ``context_length`` tokens are decoded; when it is None,
    ``tokens[:None]`` keeps the whole sequence.
    """
    # Uses the renamed self.tokenizer attribute (post-change line of the
    # diff); the stale self.enc line from the pre-change side is dropped.
    return self.tokenizer.decode(tokens[:context_length])
6 changes: 3 additions & 3 deletions src/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, model_name: str = "gpt-3.5-turbo-0125", api_key: str = None):
self.api_key = api_key or os.getenv('OPENAI_API_KEY')

self.model = AsyncOpenAI(api_key=self.api_key)
self.enc = tiktoken.encoding_for_model(self.model_name)
self.tokenizer = tiktoken.encoding_for_model(self.model_name)

async def evaluate_model(self, prompt: str) -> str:
response = await self.model.chat.completions.create(
Expand All @@ -46,7 +46,7 @@ def generate_prompt(self, context: str, retrieval_question: str) -> str | list[d
}]

def encode_text_to_tokens(self, text: str) -> list[int]:
    """Encode *text* into token ids with the model's tiktoken encoding.

    The diff residue contained both the old ``self.enc`` and the renamed
    ``self.tokenizer`` return lines; only the post-change line is kept.
    """
    # self.tokenizer is tiktoken.encoding_for_model(self.model_name); its
    # encode() returns the list of ids directly (no .ids attribute, unlike
    # the Anthropic tokenizer).
    return self.tokenizer.encode(text)

def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
    """Decode *tokens* back to text via the tiktoken encoding.

    Only the first ``context_length`` tokens are decoded; when it is None,
    ``tokens[:None]`` keeps the whole sequence.
    """
    # Keeps only the post-change self.tokenizer line of the diff; the stale
    # pre-change self.enc line is dropped.
    return self.tokenizer.decode(tokens[:context_length])

0 comments on commit 2fb5f04

Please sign in to comment.