Skip to content

Commit

Permalink
Added support for Anthropic API (Claude)
Browse files Browse the repository at this point in the history
  • Loading branch information
Arceuid731 committed Jun 26, 2024
1 parent a1e800c commit 92258e8
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 3 deletions.
113 changes: 112 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import google.generativeai as genai
import httpx
import openai
import anthropic
from fastapi import FastAPI, Query, Request, Response
from fastapi.responses import StreamingResponse
from google.generativeai import GenerativeModel
Expand Down Expand Up @@ -342,7 +343,6 @@ def get_models(self):
]
return {"default_models": default_models, "models": models}


class GeminiChatBot(ChatBotAbc):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -446,6 +446,110 @@ def get_models(self):
]
return {"default_models": default_models, "models": models}

class AnthropicChatBot(ChatBotAbc):
@classmethod
def is_start_available(cls):
return os.environ.get("ANTHROPIC_API_KEY")

def __init__(self) -> None:
super().__init__()
logger.info("Using Anthropic API")
self.anthropic_client = anthropic.AsyncAnthropic(
api_key=os.environ.get("ANTHROPIC_API_KEY")
)

async def chat_completions(self, raycast_data: dict):
messages = self.__build_anthropic_messages(raycast_data)
model = raycast_data["model"]
temperature = os.environ.get("TEMPERATURE", 0.5)

try:
response = await self.anthropic_client.messages.create(
model=model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=temperature,
stream=True
)
async for chunk in response:
if chunk.type == "content_block_delta":
yield f'data: {json_dumps({"text": chunk.delta.text})}\n\n'
elif chunk.type == "message_stop":
yield f'data: {json_dumps({"text": "", "finish_reason": "stop"})}\n\n'
except Exception as e:
logger.error(f"Anthropic error: {e}")
yield f'data: {json_dumps({"text": str(e), "finish_reason": "error"})}\n\n'

def __build_anthropic_messages(self, raycast_data: dict):
anthropic_messages = []
for msg in raycast_data["messages"]:
content = {}
if "system_instructions" in msg["content"]:
anthropic_messages.append({"role": "system", "content": msg["content"]["system_instructions"]})
if "command_instructions" in msg["content"]:
anthropic_messages.append({"role": "system", "content": msg["content"]["command_instructions"]})
if "text" in msg["content"]:
anthropic_messages.append({"role": msg["author"], "content": msg["content"]["text"]})
return anthropic_messages

async def translate_completions(self, raycast_data: dict):
messages = [
{"role": "system", "content": f"Translate the following text to {raycast_data['target']}:"},
{"role": "user", "content": raycast_data["q"]},
]
model = os.environ.get("ANTHROPIC_TRANSLATE_MODEL", "claude-3-opus-20240229")

try:
response = await self.anthropic_client.messages.create(
model=model,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.8,
stream=True
)
async for chunk in response:
if chunk.type == "content_block_delta":
yield chunk.delta.text
except Exception as e:
logger.error(f"Anthropic translation error: {e}")
yield f"Error: {str(e)}"

def get_models(self):
default_models = _get_default_model_dict("claude-3-5-sonnet-20240620")
models = [
{
"id": "claude-3-5-sonnet-20240620",
"model": "claude-3-5-sonnet-20240620",
"name": "Claude 3.5 Sonnet",
"provider": "anthropic",
"provider_name": "Anthropic",
"provider_brand": "anthropic",
"context": 32,
**_get_model_extra_info("claude-3-5-sonnet-20240620"),
},
{
"id": "claude-3-opus-20240229",
"model": "claude-3-opus-20240229",
"name": "Claude 3 Opus",
"provider": "anthropic",
"provider_name": "Anthropic",
"provider_brand": "anthropic",
"context": 32,
**_get_model_extra_info("claude-3-opus-20240229"),
},
{
"id": "claude-3-sonnet-20240229",
"model": "claude-3-sonnet-20240229",
"name": "Claude 3 Sonnet",
"provider": "anthropic",
"provider_name": "Anthropic",
"provider_brand": "anthropic",
"context": 16,
**_get_model_extra_info("claude-3-sonnet-20240229"),
},
]
return {"default_models": default_models, "models": models}


MODELS_DICT = {}
MODELS_AVAILABLE = []
Expand All @@ -464,6 +568,13 @@ def get_models(self):
MODELS_AVAILABLE.extend(_models["models"])
DEFAULT_MODELS.update(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})
if AnthropicChatBot.is_start_available():
logger.info("Anthropic API is available")
_bot = AnthropicChatBot()
_models = _bot.get_models()
MODELS_AVAILABLE.extend(_models["models"])
DEFAULT_MODELS.update(_models["default_models"])
MODELS_DICT.update({model["model"]: _bot for model in _models["models"]})


def _get_bot(model_id):
Expand Down
1 change: 1 addition & 0 deletions local_docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ docker build -t raycast .
docker run --rm -it \
$([[ -n $OPENAI_API_KEY ]] && echo -n "-e OPENAI_API_KEY=$OPENAI_API_KEY") \
$([[ -n $GOOGLE_API_KEY ]] && echo -n "-e GOOGLE_API_KEY=$GOOGLE_API_KEY") \
$([[ -n $ANTHROPIC_API_KEY ]] && echo -n "-e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY") \
$([[ -f .env ]] && echo -n "--env-file .env") \
-p 443:443 \
--dns 1.1.1.1 \
Expand Down
Loading

0 comments on commit 92258e8

Please sign in to comment.