diff --git a/sgpt/app.py b/sgpt/app.py index 9dc506db..83f48e92 100644 --- a/sgpt/app.py +++ b/sgpt/app.py @@ -87,7 +87,7 @@ def main( show_chat: str = typer.Option( None, help="Show all messages from provided chat id.", - callback=ChatHandler.show_messages, + callback=ChatHandler.show_messages_callback, rich_help_panel="Chat Options", ), list_chats: bool = typer.Option( diff --git a/sgpt/handlers/chat_handler.py b/sgpt/handlers/chat_handler.py index f17cdcd1..9919df56 100644 --- a/sgpt/handlers/chat_handler.py +++ b/sgpt/handlers/chat_handler.py @@ -6,7 +6,7 @@ from click import BadArgumentUsage from ..config import cfg -from ..role import SystemRole +from ..role import DefaultRoles, SystemRole from ..utils import option_callback from .handler import Handler @@ -102,15 +102,6 @@ def __init__(self, chat_id: str, role: SystemRole) -> None: self.validate() - @classmethod - def list_ids(cls, value: str) -> None: - if not value: - return - # Prints all existing chat IDs to the console. - for chat_id in cls.chat_session.list(): - typer.echo(chat_id) - raise typer.Exit() - @property def initiated(self) -> bool: return self.chat_session.exists(self.chat_id) @@ -127,21 +118,31 @@ def is_same_role(self) -> bool: @classmethod @option_callback + def list_ids(cls, value: str) -> None: + # Prints all existing chat IDs to the console. + for chat_id in cls.chat_session.list(): + typer.echo(chat_id) + + @classmethod def show_messages(cls, chat_id: str) -> None: # Prints all messages from a specified chat ID to the console. for index, message in enumerate(cls.chat_session.get_messages(chat_id)): color = "magenta" if index % 2 == 0 else "green" typer.secho(message, fg=color) + @classmethod + @option_callback + def show_messages_callback(cls, chat_id: str) -> None: + cls.show_messages(chat_id) + def validate(self) -> None: if self.initiated: - # print("initial message:", self.initial_message) chat_role_name = self.role.get_role_name(self.initial_message) if not chat_role_name: raise BadArgumentUsage( f'Could not determine chat role of "{self.chat_id}"' ) - if self.role.name == "default": + if self.role.name == DefaultRoles.DEFAULT.value: # If user didn't pass chat mode, we will use the one that was used to initiate the chat. self.role = SystemRole.get(chat_role_name) else: diff --git a/sgpt/role.py b/sgpt/role.py index dee6c031..e36dec16 100644 --- a/sgpt/role.py +++ b/sgpt/role.py @@ -56,7 +56,7 @@ def __init__( self.name = name if variables: role = role.format(**variables) - self.role = ROLE_TEMPLATE.format(name=name, role=role) + self.role = role @classmethod def create_defaults(cls) -> None: @@ -143,6 +143,8 @@ def _save(self) -> None: f'Role "{self.name}" already exists, overwrite it?', abort=True, ) + + self.role = ROLE_TEMPLATE.format(name=self.name, role=self.role) self._file_path.write_text(json.dumps(self.__dict__), encoding="utf-8") def delete(self) -> None: