From 80655f1e93bcd16e102e5230bfa42c1f3d046782 Mon Sep 17 00:00:00 2001 From: Val Kharitonov Date: Thu, 14 Dec 2023 11:24:23 -0500 Subject: [PATCH] add mistralai --- gptcli/assistant.py | 3 +++ gptcli/config.py | 1 + gptcli/gpt.py | 4 ++++ gptcli/mistral.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+) create mode 100644 gptcli/mistral.py diff --git a/gptcli/assistant.py b/gptcli/assistant.py index 07de4b1..6ed228c 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -7,6 +7,7 @@ from gptcli.completion import CompletionProvider, ModelOverrides, Message from gptcli.google import GoogleCompletionProvider from gptcli.llama import LLaMACompletionProvider +from gptcli.mistral import MistralCompletionProvider from gptcli.openai import OpenAICompletionProvider from gptcli.anthropic import AnthropicCompletionProvider @@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider: return LLaMACompletionProvider() elif model.startswith("chat-bison"): return GoogleCompletionProvider() + elif model.startswith("mistral"): + return MistralCompletionProvider() else: raise ValueError(f"Unknown model: {model}") diff --git a/gptcli/config.py b/gptcli/config.py index 3cb9070..472facf 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -20,6 +20,7 @@ class GptCliConfig: show_price: bool = True api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY") anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY") google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") log_file: Optional[str] = None diff --git a/gptcli/gpt.py b/gptcli/gpt.py index e9634e9..2d3aad7 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -15,6 +15,7 @@ import datetime import google.generativeai as genai import gptcli.anthropic +import gptcli.mistral from gptcli.assistant import ( Assistant, DEFAULT_ASSISTANTS, @@ -178,6 +179,9 @@ def main(): ) sys.exit(1) + if config.mistral_api_key: + gptcli.mistral.api_key = config.mistral_api_key + if config.anthropic_api_key: gptcli.anthropic.api_key = config.anthropic_api_key diff --git a/gptcli/mistral.py b/gptcli/mistral.py new file mode 100644 index 0000000..be3a303 --- /dev/null +++ b/gptcli/mistral.py @@ -0,0 +1,47 @@ +from typing import Iterator, List +import os +from gptcli.completion import CompletionProvider, Message +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + +api_key = os.environ.get("MISTRAL_API_KEY") + + +class MistralCompletionProvider(CompletionProvider): + def __init__(self): + self.client = MistralClient(api_key=api_key) + + def complete( + self, messages: List[Message], args: dict, stream: bool = False + ) -> Iterator[str]: + kwargs = {} + if "temperature" in args: + kwargs["temperature"] = args["temperature"] + if "top_p" in args: + kwargs["top_p"] = args["top_p"] + + messages = [ + ChatMessage(role=msg["role"], content=msg["content"]) + for msg in messages + ] + + if stream: + response_iter = self.client.chat_stream( + model=args["model"], + messages=messages, + **kwargs, + ) + + for response in response_iter: + next_choice = response.choices[0] + if next_choice.finish_reason is None and next_choice.delta.content: + yield next_choice.delta.content + else: + response = self.client.chat( + model=args["model"], + messages=messages, + **kwargs, + ) + next_choice = response.choices[0] + if next_choice.message.content: + yield next_choice.message.content