Skip to content

Commit

Permalink
Implement configurable LLM/embeddings providers
Browse files Browse the repository at this point in the history
  • Loading branch information
Vidminas committed Feb 4, 2024
1 parent a25dfc0 commit 495ca50
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 138 deletions.
Empty file added chat_app/apis/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions chat_app/apis/base_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC
from langchain_core.messages import BaseMessage
from requests import Session

session = Session()


class BaseEmbeddingsAPI(ABC):
"""
Generic abstract class that contains endpoints and route handling for embeddings providers
"""

def __init__(self):
self.session = session

def get_embeddings(self) -> None:
pass


class BaseLLMAPI(ABC):
"""
Generic abstract class that contains endpoints and route handling for LLM providers
"""

def __init__(self):
self.session = session

def get_llm_models(self) -> list[str]:
pass

def chat_completion(self, selected_llm: str, messages: list[BaseMessage]) -> str:
pass
51 changes: 51 additions & 0 deletions chat_app/apis/demo_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from urllib.parse import urljoin
from langchain.schema import messages_to_dict
from langchain_core.messages import BaseMessage
from .base_api import BaseEmbeddingsAPI, BaseLLMAPI


class DemoEmbeddingsAPI(BaseEmbeddingsAPI):
def __init__(self, embeddings_provider_url: str):
super().__init__()
self.embeddings_provider_url = (
embeddings_provider_url
if embeddings_provider_url.endswith("/")
else f"{embeddings_provider_url}/"
)

def get_embeddings(self) -> None:
return super().get_embeddings()

def __str__(self):
return f"Demo embeddings provider: {self.embeddings_provider_url}"


class DemoLLMAPI(BaseLLMAPI):
def __init__(self, llm_provider_url: str):
super().__init__()
self.llm_provider_url = (
llm_provider_url
if llm_provider_url.endswith("/")
else f"{llm_provider_url}/"
)

def get_llm_models(self) -> list[str]:
response = self.session.get(urljoin(self.llm_provider_url, "models/"))
if not response.is_redirect:
response.raise_for_status()
return response.json()

def chat_completion(self, selected_llm: str, messages: list[BaseMessage]) -> str:
response = self.session.post(
urljoin(self.llm_provider_url, "completions/"),
json={
"model": selected_llm,
"messages": messages_to_dict(messages),
},
)
if not response.is_redirect:
response.raise_for_status()
return response.text

def __str__(self):
return f"Demo LLM provider: {self.llm_provider_url}"
Loading

0 comments on commit 495ca50

Please sign in to comment.