Skip to content

Commit

Permalink
add mistralai
Browse files Browse the repository at this point in the history
  • Loading branch information
kharvd committed Mar 4, 2024
1 parent a957b4b commit 80655f1
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gptcli.completion import CompletionProvider, ModelOverrides, Message
from gptcli.google import GoogleCompletionProvider
from gptcli.llama import LLaMACompletionProvider
from gptcli.mistral import MistralCompletionProvider
from gptcli.openai import OpenAICompletionProvider
from gptcli.anthropic import AnthropicCompletionProvider

Expand Down Expand Up @@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
return LLaMACompletionProvider()
elif model.startswith("chat-bison"):
return GoogleCompletionProvider()
elif model.startswith("mistral"):
return MistralCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
1 change: 1 addition & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class GptCliConfig:
show_price: bool = True
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY")
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
log_file: Optional[str] = None
Expand Down
4 changes: 4 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import google.generativeai as genai
import gptcli.anthropic
import gptcli.mistral
from gptcli.assistant import (
Assistant,
DEFAULT_ASSISTANTS,
Expand Down Expand Up @@ -178,6 +179,9 @@ def main():
)
sys.exit(1)

if config.mistral_api_key:
gptcli.mistral.api_key = config.mistral_api_key

if config.anthropic_api_key:
gptcli.anthropic.api_key = config.anthropic_api_key

Expand Down
47 changes: 47 additions & 0 deletions gptcli/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Iterator, List
import os
from gptcli.completion import CompletionProvider, Message
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

api_key = os.environ.get("MISTRAL_API_KEY")


class MistralCompletionProvider(CompletionProvider):
def __init__(self):
self.client = MistralClient(api_key=api_key)

def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[str]:
kwargs = {}
if "temperature" in args:
kwargs["temperature"] = args["temperature"]
if "top_p" in args:
kwargs["top_p"] = args["top_p"]

messages = [
ChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
]

if stream:
response_iter = self.client.chat_stream(
model=args["model"],
messages=messages,
**kwargs,
)

for response in response_iter:
next_choice = response.choices[0]
if next_choice.finish_reason is None and next_choice.delta.content:
yield next_choice.delta.content
else:
response = self.client.chat(
model=args["model"],
messages=messages,
**kwargs,
)
next_choice = response.choices[0]
if next_choice.message.content:
yield next_choice.message.content

0 comments on commit 80655f1

Please sign in to comment.