Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove in-chat model overrides #100

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ optional arguments:
Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C to clear the conversation, `:r` or Ctrl-R to re-generate the last response.
To enter multi-line mode, enter a backslash `\` followed by a new line. Exit the multi-line mode by pressing ESC and then Enter.

You can override the model parameters using `--model`, `--temperature` and `--top_p` arguments at the end of your prompt. For example:

```
> What is the meaning of life? --model gpt-4 --temperature 2.0
The meaning of life is subjective and can be different for diverse human beings and unique-phil ethics.org/cultuties-/ it that reson/bdstals89im3_jrf334;mvs-bread99ef=g22me
```

The `dev` assistant is instructed to be an expert in software development and provide short responses.

```bash
Expand Down Expand Up @@ -197,7 +190,6 @@ assistants:
- { role: system, content: !include "pirate.txt" }
```


### Customize OpenAI API URL

If you are using other models compatible with the OpenAI Python SDK, you can configure them by modifying the `openai_base_url` setting in the config file or using the `OPENAI_BASE_URL` environment variable.
Expand Down
23 changes: 7 additions & 16 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from gptcli.completion import (
CompletionEvent,
CompletionProvider,
ModelOverrides,
Message,
)
from gptcli.providers.google import GoogleCompletionProvider
Expand Down Expand Up @@ -107,28 +106,20 @@ def from_config(cls, name: str, config: AssistantConfig):
def init_messages(self) -> List[Message]:
return self.config.get("messages", [])[:]

def supported_overrides(self) -> List[str]:
    """Return the names of model parameters a user may override per message."""
    return "model temperature top_p".split()

def _param(self, param: str, override_params: ModelOverrides) -> Any:
# If the param is in the override_params, use that value
# Otherwise, use the value from the config
def _param(self, param: str) -> Any:
# Use the value from the config if exists
# Otherwise, use the default value
return override_params.get(
param, self.config.get(param, CONFIG_DEFAULTS[param])
)
return self.config.get(param, CONFIG_DEFAULTS[param])

def complete_chat(
self, messages, override_params: ModelOverrides = {}, stream: bool = True
) -> Iterator[CompletionEvent]:
model = self._param("model", override_params)
def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]:
model = self._param("model")
completion_provider = get_completion_provider(model)
return completion_provider.complete(
messages,
{
"model": model,
"temperature": float(self._param("temperature", override_params)),
"top_p": float(self._param("top_p", override_params)),
"temperature": float(self._param("temperature")),
"top_p": float(self._param("top_p")),
},
stream,
)
Expand Down
62 changes: 13 additions & 49 deletions gptcli/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any, Dict, Optional, Tuple
from typing import Optional

from openai import BadRequestError, OpenAIError
from prompt_toolkit import PromptSession
Expand All @@ -11,9 +10,16 @@
from rich.markdown import Markdown
from rich.text import Text

from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT,
COMMAND_RERUN, ChatListener, InvalidArgumentError,
ResponseStreamer, UserInputProvider)
from gptcli.session import (
ALL_COMMANDS,
COMMAND_CLEAR,
COMMAND_QUIT,
COMMAND_RERUN,
ChatListener,
InvalidArgumentError,
ResponseStreamer,
UserInputProvider,
)

TERMINAL_WELCOME = """
Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear
Expand Down Expand Up @@ -113,43 +119,6 @@ def response_streamer(self) -> ResponseStreamer:
return CLIResponseStreamer(self.console, self.markdown)


def parse_args(input: str) -> Tuple[str, Dict[str, Any]]:
    """Split trailing ``--key value`` / ``--key=value`` overrides out of a prompt.

    Text wrapped in triple backticks, triple quotes, or single backticks is
    protected from argument parsing and restored afterwards (whitespace-stripped,
    re-wrapped in its original delimiter).

    Returns a tuple of (prompt with arguments removed, parsed argument dict).
    """
    fence_tokens = ['```', '"""', '`']
    protected = []

    # One alternation branch per delimiter, in priority order (``` before `);
    # DOTALL lets a fenced span cover multiple lines.
    fence_re = re.compile(
        '|'.join(re.escape(tok) + '(.*?)' + re.escape(tok) for tok in fence_tokens),
        re.DOTALL,
    )

    def stash(match):
        # Exactly one capture group is non-None: the branch whose delimiter matched.
        for idx, tok in enumerate(fence_tokens):
            body = match.group(idx + 1)
            if body is not None:
                protected.append((body, tok))
                break
        return f"__EXTRACTED_PART_{len(protected) - 1}__"

    input = fence_re.sub(stash, input)

    # Parse --key=value, --key value, or a bare --key (empty-string value).
    arg_re = r'--(\w+)(?:=(\S+)|\s+(\S+))?'
    overrides: Dict[str, Any] = {}
    found = re.findall(arg_re, input)
    if found:
        for name, eq_val, sp_val in found:
            raw = eq_val if eq_val else sp_val if sp_val else ''
            overrides[name] = raw.strip("\"'")
        input = re.sub(arg_re, "", input).strip()

    # Put the protected spans back where their placeholders sit.
    for idx, (body, tok) in enumerate(protected):
        input = input.replace(f"__EXTRACTED_PART_{idx}__", f"{tok}{body.strip()}{tok}")

    return input, overrides


class CLIFileHistory(FileHistory):
def append_string(self, string: str) -> None:
if string in ALL_COMMANDS:
Expand All @@ -163,12 +132,11 @@ def __init__(self, history_filename) -> None:
history=CLIFileHistory(history_filename)
)

def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
def get_user_input(self) -> str:
while (next_user_input := self._request_input()) == "":
pass

user_input, args = self._parse_input(next_user_input)
return user_input, args
return next_user_input

def prompt(self, multiline=False):
bindings = KeyBindings()
Expand Down Expand Up @@ -219,7 +187,3 @@ def _request_input(self):
return line

return self.prompt(multiline=True)

def _parse_input(self, input: str) -> Tuple[str, Dict[str, Any]]:
input, args = parse_args(input)
return input, args
6 changes: 0 additions & 6 deletions gptcli/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ class Message(TypedDict):
content: str


class ModelOverrides(TypedDict, total=False):
    """Optional per-call overrides for completion parameters.

    ``total=False`` makes every key optional; a key that is absent falls back
    to the assistant's config value or the global default (see ``_param``).
    """

    model: str
    temperature: float
    top_p: float


class Pricing(TypedDict):
prompt: float
response: float
Expand Down
5 changes: 2 additions & 3 deletions gptcli/composite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gptcli.completion import Message, ModelOverrides, UsageEvent
from gptcli.completion import Message, UsageEvent
from gptcli.session import ChatListener, ResponseStreamer


Expand Down Expand Up @@ -56,8 +56,7 @@ def on_chat_response(
self,
messages: List[Message],
response: Message,
overrides: ModelOverrides,
usage: Optional[UsageEvent],
):
for listener in self.listeners:
listener.on_chat_response(messages, response, overrides, usage)
listener.on_chat_response(messages, response, usage)
5 changes: 2 additions & 3 deletions gptcli/cost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from gptcli.assistant import Assistant
from gptcli.completion import Message, ModelOverrides, UsageEvent
from gptcli.completion import Message, UsageEvent
from gptcli.session import ChatListener

from rich.console import Console
Expand All @@ -22,13 +22,12 @@ def on_chat_response(
self,
messages: List[Message],
response: Message,
args: ModelOverrides,
usage: Optional[UsageEvent] = None,
):
if usage is None:
return

model = self.assistant._param("model", args)
model = self.assistant._param("model")
num_tokens = usage.total_tokens
cost = usage.cost

Expand Down
45 changes: 13 additions & 32 deletions gptcli/session.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from abc import abstractmethod
from typing_extensions import TypeGuard
from gptcli.assistant import Assistant
from gptcli.completion import (
Message,
ModelOverrides,
CompletionError,
BadRequestError,
UsageEvent,
)
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional


class ResponseStreamer:
Expand Down Expand Up @@ -45,15 +43,14 @@ def on_chat_response(
self,
messages: List[Message],
response: Message,
overrides: ModelOverrides,
usage: Optional[UsageEvent] = None,
):
pass


class UserInputProvider:
@abstractmethod
def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
def get_user_input(self) -> str:
pass


Expand Down Expand Up @@ -85,7 +82,7 @@ def __init__(
):
self.assistant = assistant
self.messages: List[Message] = assistant.init_messages()
self.user_prompts: List[Tuple[Message, ModelOverrides]] = []
self.user_prompts: List[Message] = []
self.listener = listener
self.stream = stream

Expand All @@ -103,18 +100,17 @@ def _rerun(self):
self.messages = self.messages[:-1]

self.listener.on_chat_rerun(True)
_, args = self.user_prompts[-1]
self._respond(args)
self._respond()

def _respond(self, overrides: ModelOverrides) -> bool:
def _respond(self) -> bool:
"""
Respond to the user's input and return whether the assistant's response was saved.
"""
next_response: str = ""
usage: Optional[UsageEvent] = None
try:
completion_iter = self.assistant.complete_chat(
self.messages, override_params=overrides, stream=self.stream
self.messages, stream=self.stream
)

with self.listener.response_streamer() as stream:
Expand All @@ -137,28 +133,16 @@ def _respond(self, overrides: ModelOverrides) -> bool:

next_message: Message = {"role": "assistant", "content": next_response}
self.listener.on_chat_message(next_message)
self.listener.on_chat_response(self.messages, next_message, overrides, usage)
self.listener.on_chat_response(self.messages, next_message, usage)

self.messages = self.messages + [next_message]
return True

def _validate_args(self, args: Dict[str, Any]) -> TypeGuard[ModelOverrides]:
    """Check that every override key is one the assistant supports.

    Reports an ``InvalidArgumentError`` through the listener and returns
    ``False`` on the first unsupported key; returns ``True`` when all keys
    are allowed (narrowing ``args`` to ``ModelOverrides`` for the caller).
    """
    for key in args:
        # NOTE(review): this lookup is loop-invariant and could be hoisted.
        supported_overrides = self.assistant.supported_overrides()
        if key not in supported_overrides:
            self.listener.on_error(
                InvalidArgumentError(
                    f"Invalid argument: {key}. Allowed arguments: {supported_overrides}"
                )
            )
            return False
    return True

def _add_user_message(self, user_input: str, args: ModelOverrides):
def _add_user_message(self, user_input: str):
user_message: Message = {"role": "user", "content": user_input}
self.messages = self.messages + [user_message]
self.listener.on_chat_message(user_message)
self.user_prompts.append((user_message, args))
self.user_prompts.append(user_message)

def _rollback_user_message(self):
self.messages = self.messages[:-1]
Expand All @@ -168,13 +152,10 @@ def _print_help(self):
with self.listener.response_streamer() as stream:
stream.on_next_token(COMMANDS_HELP)

def process_input(self, user_input: str, args: Dict[str, Any]):
def process_input(self, user_input: str):
"""
Process the user's input and return whether the session should continue.
"""
if not self._validate_args(args):
return True

if user_input in COMMAND_QUIT:
return False
elif user_input in COMMAND_CLEAR:
Expand All @@ -187,14 +168,14 @@ def process_input(self, user_input: str, args: Dict[str, Any]):
self._print_help()
return True

self._add_user_message(user_input, args)
response_saved = self._respond(args)
self._add_user_message(user_input)
response_saved = self._respond()
if not response_saved:
self._rollback_user_message()

return True

def loop(self, input_provider: UserInputProvider):
self.listener.on_chat_start()
while self.process_input(*input_provider.get_user_input()):
while self.process_input(input_provider.get_user_input()):
pass
Loading
Loading