Skip to content

Commit

Permalink
REPL conversation bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
TheR1D committed Dec 22, 2023
1 parent 7ac1f98 commit 6ffb6ba
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(
show_chat: str = typer.Option(
None,
help="Show all messages from provided chat id.",
callback=ChatHandler.show_messages,
callback=ChatHandler.show_messages_callback,
rich_help_panel="Chat Options",
),
list_chats: bool = typer.Option(
Expand Down
25 changes: 13 additions & 12 deletions sgpt/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from click import BadArgumentUsage

from ..config import cfg
from ..role import SystemRole
from ..role import DefaultRoles, SystemRole
from ..utils import option_callback
from .handler import Handler

Expand Down Expand Up @@ -102,15 +102,6 @@ def __init__(self, chat_id: str, role: SystemRole) -> None:

self.validate()

@classmethod
def list_ids(cls, value: str) -> None:
if not value:
return
# Prints all existing chat IDs to the console.
for chat_id in cls.chat_session.list():
typer.echo(chat_id)
raise typer.Exit()

@property
def initiated(self) -> bool:
return self.chat_session.exists(self.chat_id)
Expand All @@ -127,21 +118,31 @@ def is_same_role(self) -> bool:

@classmethod
@option_callback
def list_ids(cls, value: str) -> None:
# Prints all existing chat IDs to the console.
for chat_id in cls.chat_session.list():
typer.echo(chat_id)

@classmethod
def show_messages(cls, chat_id: str) -> None:
# Prints all messages from a specified chat ID to the console.
for index, message in enumerate(cls.chat_session.get_messages(chat_id)):
color = "magenta" if index % 2 == 0 else "green"
typer.secho(message, fg=color)

@classmethod
@option_callback
def show_messages_callback(cls, chat_id: str) -> None:
cls.show_messages(chat_id)

def validate(self) -> None:
if self.initiated:
# print("initial message:", self.initial_message)
chat_role_name = self.role.get_role_name(self.initial_message)
if not chat_role_name:
raise BadArgumentUsage(
f'Could not determine chat role of "{self.chat_id}"'
)
if self.role.name == "default":
if self.role.name == DefaultRoles.DEFAULT.value:
# If user didn't pass chat mode, we will use the one that was used to initiate the chat.
self.role = SystemRole.get(chat_role_name)
else:
Expand Down
4 changes: 3 additions & 1 deletion sgpt/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.name = name
if variables:
role = role.format(**variables)
self.role = ROLE_TEMPLATE.format(name=name, role=role)
self.role = role

@classmethod
def create_defaults(cls) -> None:
Expand Down Expand Up @@ -143,6 +143,8 @@ def _save(self) -> None:
f'Role "{self.name}" already exists, overwrite it?',
abort=True,
)

self.role = ROLE_TEMPLATE.format(name=self.name, role=self.role)
self._file_path.write_text(json.dumps(self.__dict__), encoding="utf-8")

def delete(self) -> None:
Expand Down

0 comments on commit 6ffb6ba

Please sign in to comment.