System role optimisations for GPT4+ models
TheR1D committed Dec 19, 2023
1 parent 1c58566 commit 3b9b0bd
Showing 8 changed files with 107 additions and 157 deletions.
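For orientation, a minimal before/after sketch of what this commit changes in the request payload (illustrative values only, not lines taken from the diff): previously the role text was folded into the user prompt through `PROMPT_TEMPLATE`; after this commit it is sent as a dedicated system message built from `ROLE_TEMPLATE`.

```python
# Illustrative sketch only -- simplified, not the project's actual code paths.
role_text = "Provide only bash commands for Linux without any description."
prompt = "list files sorted by size"

# Before: the role was embedded into the user message via PROMPT_TEMPLATE.
old_messages = [
    {
        "role": "user",
        "content": f"###\nRole name: shell\n{role_text}\n\nRequest: {prompt}\n###\nCommand:",
    }
]

# After: the role becomes a proper system message ("You are {name}\n{role}"),
# which current GPT-4 class models follow reliably.
new_messages = [
    {"role": "system", "content": f"You are Shell Command Generator\n{role_text}"},
    {"role": "user", "content": prompt},
]
```

This is also why the `SYSTEM_ROLES` setting and the `expecting` suffixes ("Answer", "Command", "Description", "Code") disappear throughout the diff below.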
4 changes: 0 additions & 4 deletions README.md
@@ -341,17 +341,13 @@ REQUEST_TIMEOUT=60
DEFAULT_MODEL=gpt-3.5-turbo
# Default color for OpenAI completions.
DEFAULT_COLOR=magenta
# Force use system role messages (not recommended).
SYSTEM_ROLES=false
# When in --shell mode, default to "Y" for no input.
DEFAULT_EXECUTE_SHELL_CMD=false
# Disable streaming of responses
DISABLE_STREAMING=false
```
Possible options for `DEFAULT_COLOR`: black, red, green, yellow, blue, magenta, cyan, white, bright_black, bright_red, bright_green, bright_yellow, bright_blue, bright_magenta, bright_cyan, bright_white.
Switch `SYSTEM_ROLES` to force use [system roles](https://help.openai.com/en/articles/7042661-chatgpt-api-transition-guide) messages, this is not recommended, since it doesn't perform well with current GPT models.
### Full list of arguments
```text
╭─ Arguments ─────────────────────────────────────────────────────────────────────────────────────────────────╮
4 changes: 2 additions & 2 deletions sgpt/app.py
@@ -25,14 +25,14 @@ def main(
help="Large language model to use.",
),
temperature: float = typer.Option(
0.1,
0.0,
min=0.0,
max=2.0,
help="Randomness of generated output.",
),
top_probability: float = typer.Option(
1.0,
min=0.1,
min=0.0,
max=1.0,
help="Limits highest probable tokens (words).",
),
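These two Typer options parametrise the completion request; as a rough sketch of where such values end up, assuming the standard `openai` 1.x Python SDK rather than the project's own client wrapper:

```python
# Rough sketch, assuming the openai>=1.0 Python SDK; not shell_gpt's actual client code.
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment
completion = client.chat.completions.create(
    model="gpt-4-1106-preview",
    messages=[{"role": "user", "content": "How do I list open ports?"}],
    temperature=0.0,  # new default: deterministic output
    top_p=1.0,        # top_probability; its lower bound is now 0.0
)
print(completion.choices[0].message.content)
```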
3 changes: 1 addition & 2 deletions sgpt/config.py
@@ -21,11 +21,10 @@
"CHAT_CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
"CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
"REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-3.5-turbo"),
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4-1106-preview"),
"OPENAI_API_HOST": os.getenv("OPENAI_API_HOST", "https://api.openai.com"),
"DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
"ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
"SYSTEM_ROLES": os.getenv("SYSTEM_ROLES", "false"),
"DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
"DISABLE_STREAMING": os.getenv("DISABLE_STREAMING", "false")
# New features might add their own config variables here.
17 changes: 4 additions & 13 deletions sgpt/handlers/chat_handler.py
@@ -7,6 +7,7 @@

from ..config import cfg
from ..role import SystemRole
from ..utils import option_callback
from .handler import Handler

CHAT_CACHE_LENGTH = int(cfg.get("CHAT_CACHE_LENGTH"))
@@ -117,28 +118,22 @@ def initiated(self) -> bool:
@property
def initial_message(self) -> str:
chat_history = self.chat_session.get_messages(self.chat_id)
index = 1 if cfg.get("SYSTEM_ROLES") == "true" else 0
return chat_history[index] if chat_history else ""
return chat_history[0] if chat_history else ""

@property
def is_same_role(self) -> bool:
# TODO: Should be optimized for REPL mode.
return self.role.same_role(self.initial_message)

@classmethod
@option_callback
def show_messages_callback(cls, chat_id: str) -> None:
if not chat_id:
return
cls.show_messages(chat_id)
raise typer.Exit()

@classmethod
def show_messages(cls, chat_id: str) -> None:
# Prints all messages from a specified chat ID to the console.
for index, message in enumerate(cls.chat_session.get_messages(chat_id)):
# Remove output type from the message, e.g. "text\nCommand:" -> "text"
if message.startswith("user:"):
message = "\n".join(message.splitlines()[:-1])
color = "magenta" if index % 2 == 0 else "green"
typer.secho(message, fg=color)

@@ -160,13 +155,9 @@ def validate(self) -> None:
f'since it was initiated as "{chat_role_name}" chat.'
)

def make_prompt(self, prompt: str) -> str:
prompt = prompt.strip()
return self.role.make_prompt(prompt, not self.initiated)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
messages = []
if not self.initiated and cfg.get("SYSTEM_ROLES") == "true":
if not self.initiated:
messages.append({"role": "system", "content": self.role.role})
messages.append({"role": "user", "content": prompt})
return messages
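The net effect for chat sessions: the system message is appended exactly once, when the chat is first initiated, and each later turn adds only a user message. A simplified standalone sketch of that flow (hypothetical helper, not the real `ChatHandler`):

```python
# Simplified sketch of the chat flow after this change; not the actual ChatHandler.
from typing import Dict, List

def build_messages(history: List[Dict[str, str]], role_text: str, prompt: str) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    if not history:  # chat not yet initiated -> seed it with the system role
        messages.append({"role": "system", "content": role_text})
    messages.append({"role": "user", "content": prompt})
    return messages

role = "You are Shell Command Generator\nProvide only bash commands for Linux."
history: List[Dict[str, str]] = []
history += build_messages(history, role, "list files")       # system + user
history += build_messages(history, role, "only *.py files")  # user only
# history[0] is the system message, which is what initial_message / same_role inspect.
```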
12 changes: 4 additions & 8 deletions sgpt/handlers/default_handler.py
@@ -14,13 +14,9 @@ def __init__(self, role: SystemRole) -> None:
super().__init__(role)
self.role = role

def make_prompt(self, prompt: str) -> str:
prompt = prompt.strip()
return self.role.make_prompt(prompt, initial=True)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
messages = []
if cfg.get("SYSTEM_ROLES") == "true":
messages.append({"role": "system", "content": self.role.role})
messages.append({"role": "user", "content": prompt})
messages = [
{"role": "system", "content": self.role.role},
{"role": "user", "content": prompt},
]
return messages
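For one-off (non-chat) requests the payload now always starts with the role as a system message, with no `SYSTEM_ROLES` toggle involved; for example (illustrative content, using one of the new default role names):

```python
# Illustrative payload produced by the new make_messages for a one-off request.
messages = [
    {
        "role": "system",
        "content": "You are Shell Command Generator\nProvide only bash commands for Linux without any description.",
    },
    {"role": "user", "content": "find files larger than 100 MB"},
]
```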
5 changes: 1 addition & 4 deletions sgpt/handlers/handler.py
@@ -15,17 +15,14 @@ def __init__(self, role: SystemRole) -> None:
self.role = role
self.color = cfg.get("DEFAULT_COLOR")

def make_prompt(self, prompt: str) -> str:
raise NotImplementedError

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError

def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
yield from self.client.get_completion(**kwargs)

def handle(self, prompt: str, **kwargs: Any) -> str:
messages = self.make_messages(self.make_prompt(prompt))
messages = self.make_messages(prompt.strip())
full_completion = ""
stream = cfg.get("DISABLE_STREAMING") == "false"
if not stream:
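With `make_prompt` removed, the base `Handler` strips the prompt and hands it straight to `make_messages`, which remains the single extension point for subclasses. A minimal sketch of that contract (caching and streaming omitted; the completion call is a hypothetical stand-in):

```python
# Minimal sketch of the Handler contract after this change; caching/streaming omitted.
from typing import Any, Dict, List

class Handler:
    def make_messages(self, prompt: str) -> List[Dict[str, str]]:
        raise NotImplementedError  # each subclass shapes its own payload

    def complete(self, messages: List[Dict[str, str]], **kwargs: Any) -> str:
        raise NotImplementedError  # hypothetical stand-in for the OpenAI client call

    def handle(self, prompt: str, **kwargs: Any) -> str:
        # No intermediate make_prompt step any more -- the raw prompt goes straight in.
        return self.complete(self.make_messages(prompt.strip()), **kwargs)

class EchoHandler(Handler):
    """Toy subclass showing the single remaining hook."""
    def make_messages(self, prompt: str) -> List[Dict[str, str]]:
        return [{"role": "user", "content": prompt}]

    def complete(self, messages: List[Dict[str, str]], **kwargs: Any) -> str:
        return messages[-1]["content"]  # echo back instead of calling a model

print(EchoHandler().handle("  hello  "))  # -> "hello"
```

Keeping `make_messages` as the only hook means chat and one-off handlers now differ solely in whether the system message is sent again.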
158 changes: 64 additions & 94 deletions sgpt/role.py
@@ -16,34 +16,29 @@
SHELL_ROLE = """Provide only {shell} commands for {os} without any description.
If there is a lack of details, provide most logical solution.
Ensure the output is a valid shell command.
If multiple steps required try to combine them together."""
If multiple steps required try to combine them together.
Provide only plain text without Markdown formatting.
Do not provide markdown formatting such as ```.
"""

DESCRIBE_SHELL_ROLE = """Provide a terse, single sentence description
of the given shell command. Provide only plain text without Markdown formatting.
Do not show any warnings or information regarding your capabilities.
If you need to store any data, assume it will be stored in the chat."""
DESCRIBE_SHELL_ROLE = """Provide a terse, single sentence description of the given shell command.
Describe each argument and option of the command, in plain text (no markdown).
Provide only plain text without Markdown formatting.
Do not include symbols such as `."""

CODE_ROLE = """Provide only code as output without any description.
IMPORTANT: Provide only plain text without Markdown formatting.
IMPORTANT: Do not include markdown formatting such as ```.
Provide only code in plain text format without Markdown formatting.
Do not include symbols such as ``` or ```python.
If there is a lack of details, provide most logical solution.
You are not allowed to ask for more details.
Ignore any potential risk of errors or confusion."""
For example if the prompt is "Hello world Python", you should return "print('Hello world')"."""

DEFAULT_ROLE = """You are Command Line App ShellGPT, a programming and system administration assistant.
DEFAULT_ROLE = """You are programming and system administration assistant.
You are managing {os} operating system with {shell} shell.
Provide only plain text without Markdown formatting.
Do not show any warnings or information regarding your capabilities.
If you need to store any data, assume it will be stored in the chat."""


PROMPT_TEMPLATE = """###
Role name: {name}
{role}
If you need to store any data, assume it will be stored in the conversation."""

Request: {request}
###
{expecting}:"""
ROLE_TEMPLATE = "You are {name}\n{role}"


class SystemRole:
@@ -53,58 +48,26 @@ def __init__(
self,
name: str,
role: str,
expecting: str,
variables: Optional[Dict[str, str]] = None,
) -> None:
self.storage.mkdir(parents=True, exist_ok=True)
self.name = name
self.expecting = expecting
self.variables = variables
if variables:
# Variables are for internal use only.
role = role.format(**variables)
self.role = role
self.role = ROLE_TEMPLATE.format(name=name, role=role)

@classmethod
def create_defaults(cls) -> None:
cls.storage.parent.mkdir(parents=True, exist_ok=True)
variables = {"shell": cls.shell_name(), "os": cls.os_name()}
variables = {"shell": cls._shell_name(), "os": cls._os_name()}
for default_role in (
SystemRole("default", DEFAULT_ROLE, "Answer", variables),
SystemRole("shell", SHELL_ROLE, "Command", variables),
SystemRole("describe_shell", DESCRIBE_SHELL_ROLE, "Description", variables),
SystemRole("code", CODE_ROLE, "Code"),
SystemRole("ShellGPT", DEFAULT_ROLE, variables),
SystemRole("Shell Command Generator", SHELL_ROLE, variables),
SystemRole("Shell Command Describer", DESCRIBE_SHELL_ROLE, variables),
SystemRole("Code Generator", CODE_ROLE),
):
if not default_role.exists:
default_role.save()

@classmethod
def os_name(cls) -> str:
current_platform = platform.system()
if current_platform == "Linux":
return "Linux/" + distro_name(pretty=True)
if current_platform == "Windows":
return "Windows " + platform.release()
if current_platform == "Darwin":
return "Darwin/MacOS " + platform.mac_ver()[0]
return current_platform

@classmethod
def shell_name(cls) -> str:
current_platform = platform.system()
if current_platform in ("Windows", "nt"):
is_powershell = len(getenv("PSModulePath", "").split(pathsep)) >= 3
return "powershell.exe" if is_powershell else "cmd.exe"
return basename(getenv("SHELL", "/bin/sh"))

@classmethod
def get_role_name(cls, initial_message: str) -> Optional[str]:
if not initial_message:
return None
message_lines = initial_message.splitlines()
if "###" in message_lines[0]:
return message_lines[1].split("Role name: ")[1].strip()
return None
if not default_role._exists:
default_role._save()

@classmethod
def get(cls, name: str) -> "SystemRole":
@@ -117,12 +80,8 @@ def get(cls, name: str) -> "SystemRole":
@option_callback
def create(cls, name: str) -> None:
role = typer.prompt("Enter role description")
expecting = typer.prompt(
"Enter expecting result, e.g. answer, code, \
shell command, command description, etc."
)
role = cls(name, role, expecting)
role.save()
role = cls(name, role)
role._save()

@classmethod
@option_callback
@@ -140,58 +99,69 @@ def list(cls, _value: str) -> None:
def show(cls, name: str) -> None:
typer.echo(cls.get(name).role)

@property
def exists(self) -> bool:
return self.file_path.exists()
@classmethod
def get_role_name(cls, initial_message: str) -> Optional[str]:
if not initial_message:
return None
message_lines = initial_message.splitlines()
if "You are" in message_lines[0]:
return message_lines[0].split("You are ")[1].strip()
return None

@classmethod
def _os_name(cls) -> str:
current_platform = platform.system()
if current_platform == "Linux":
return "Linux/" + distro_name(pretty=True)
if current_platform == "Windows":
return "Windows " + platform.release()
if current_platform == "Darwin":
return "Darwin/MacOS " + platform.mac_ver()[0]
return current_platform

@classmethod
def _shell_name(cls) -> str:
current_platform = platform.system()
if current_platform in ("Windows", "nt"):
is_powershell = len(getenv("PSModulePath", "").split(pathsep)) >= 3
return "powershell.exe" if is_powershell else "cmd.exe"
return basename(getenv("SHELL", "/bin/sh"))

@property
def system_message(self) -> Dict[str, str]:
return {"role": "system", "content": self.role}
def _exists(self) -> bool:
return self._file_path.exists()

@property
def file_path(self) -> Path:
def _file_path(self) -> Path:
return self.storage / f"{self.name}.json"

def save(self) -> None:
if self.exists:
def _save(self) -> None:
if self._exists:
typer.confirm(
f'Role "{self.name}" already exists, overwrite it?',
abort=True,
)
self.file_path.write_text(json.dumps(self.__dict__), encoding="utf-8")
self._file_path.write_text(json.dumps(self.__dict__), encoding="utf-8")

def delete(self) -> None:
if self.exists:
if self._exists:
typer.confirm(
f'Role "{self.name}" exist, delete it?',
abort=True,
)
self.file_path.unlink()

def make_prompt(self, request: str, initial: bool) -> str:
if initial:
prompt = PROMPT_TEMPLATE.format(
name=self.name,
role=self.role,
request=request,
expecting=self.expecting,
)
else:
prompt = f"{request}\n{self.expecting}:"

return prompt
self._file_path.unlink()

def same_role(self, initial_message: str) -> bool:
if not initial_message:
return False
return True if f"Role name: {self.name}" in initial_message else False
return True if f"You are {self.name}" in initial_message else False


class DefaultRoles(Enum):
DEFAULT = "default"
SHELL = "shell"
DESCRIBE_SHELL = "describe_shell"
CODE = "code"
DEFAULT = "ShellGPT"
SHELL = "Shell Command Generator"
DESCRIBE_SHELL = "Shell Command Describer"
CODE = "Code Generator"

@classmethod
def check_get(cls, shell: bool, describe_shell: bool, code: bool) -> SystemRole:
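After this refactor the role name itself is the marker used to recognise which role started a chat: `ROLE_TEMPLATE` prefixes every role with `You are {name}`, and `get_role_name` / `same_role` look for that prefix in the first stored message. A small standalone sketch of the round trip, mirroring the logic in the diff:

```python
# Standalone sketch of the "You are {name}" round trip behind the new role matching.
from typing import Optional

ROLE_TEMPLATE = "You are {name}\n{role}"

def get_role_name(initial_message: str) -> Optional[str]:
    if not initial_message:
        return None
    first_line = initial_message.splitlines()[0]
    if "You are" in first_line:
        return first_line.split("You are ")[1].strip()
    return None

system_message = ROLE_TEMPLATE.format(
    name="Shell Command Generator",
    role="Provide only bash commands for Linux without any description.",
)
assert get_role_name(system_message) == "Shell Command Generator"
assert "You are Shell Command Generator" in system_message  # what same_role() checks
```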