Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extract shell command from backticks #645

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ python -m venv env && source ./env/bin/activate
Install the necessary dependencies, including development and test dependencies:

```shell
pip install -e ."[dev,test]"
pip install -e ."[dev,test,litellm]"
```

### Start Coding
Expand All @@ -35,4 +35,4 @@ Before creating a pull request, run `scripts/lint.sh` and `scripts/tests.sh` to
### Code Review
After submitting your pull request, be patient and receptive to feedback from reviewers. Address any concerns they raise and collaborate to refine the code. Together, we can enhance the ShellGPT project.

Thank you once again for your contribution! We're excited to have you join us.
Thank you once again for your contribution! We're excited to have you join us.
53 changes: 53 additions & 0 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional

Expand Down Expand Up @@ -37,6 +38,7 @@ class Handler:

def __init__(self, role: SystemRole, markdown: bool) -> None:
self.role = role
self.is_shell = role.name == DefaultRoles.SHELL.value

api_base_url = cfg.get("API_BASE_URL")
self.base_url = None if api_base_url == "default" else api_base_url
Expand All @@ -45,6 +47,13 @@ def __init__(self, role: SystemRole, markdown: bool) -> None:
self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

self.backticks_start = re.compile(r"(^|[\r\n]+)```\w*[\r\n]+")
end_regex_parts = [r"[\r\n]+", "`", "`", "`", r"([\r\n]+|$)"]
self.backticks_end_prefixes = [
re.compile("".join(end_regex_parts[: i + 1]))
for i in range(len(end_regex_parts))
]

@property
def printer(self) -> Printer:
return (
Expand Down Expand Up @@ -82,6 +91,48 @@ def handle_function_call(
yield f"```text\n{result}\n```\n"
messages.append({"role": "function", "content": result, "name": name})

def _matches_end_at(self, text: str) -> tuple[bool, int]:
end_of_match = 0
for _i, regex in enumerate(self.backticks_end_prefixes):
m = regex.search(text)
if m:
end_of_match = m.end()
else:
return False, end_of_match
return True, m.start()

def _filter_chunks(
self, chunks: Generator[str, None, None]
) -> Generator[str, None, None]:
buffer = ""
inside_backticks = False
end_of_beginning = 0

for chunk in chunks:
buffer += chunk
if not inside_backticks:
m = self.backticks_start.search(buffer)
if not m:
continue
new_end_of_beginning = m.end()
if new_end_of_beginning > end_of_beginning:
end_of_beginning = new_end_of_beginning
continue
inside_backticks = True
buffer = buffer[end_of_beginning:]
if inside_backticks:
matches_end, index = self._matches_end_at(buffer)
if matches_end:
yield buffer[:index]
return
if index == len(buffer):
continue
else:
yield buffer
buffer = ""
if buffer:
yield buffer

@cache
def get_completion(
self,
Expand Down Expand Up @@ -163,4 +214,6 @@ def handle(
caching=caching,
**kwargs,
)
if self.role.name == DefaultRoles.SHELL.value:
generator = self._filter_chunks(generator)
return self.printer(generator, not disable_stream)
50 changes: 50 additions & 0 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from unittest.mock import patch

import pytest

from sgpt.config import cfg
from sgpt.role import DefaultRoles, SystemRole

Expand All @@ -22,6 +24,54 @@ def test_shell(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout


@patch("sgpt.handlers.handler.completion")
@pytest.mark.parametrize(
"prefix,suffix",
[
("", ""),
("some text before\n```powershell\n", "\n```" ""),
("```powershell\n", "\n```\nsome text after" ""),
("some text before\n```powershell\n", "\n```\nsome text after" ""),
(
"some text with ``` before\n```powershell\n",
"\n```\nsome text with ``` after" "",
),
("```powershell\n", "\n```" ""),
("```\n", "\n```" ""),
("```powershell\r\n", "\r\n```" ""),
("```\r\n", "\r\n```" ""),
("```powershell\r", "\r```" ""),
("```\r", "\r```" ""),
],
)
@pytest.mark.parametrize("group_by_size", range(10))
def test_shell_no_backticks(completion, prefix: str, suffix: str, group_by_size: int):
expected_output = "Get-Process | \nWhere-Object { $_.Port -eq 9000 }\r\n | Select-Object Id | Text \r\nwith '```' inside"
produced_output = prefix + expected_output + suffix
if group_by_size == 0:
produced_tokens = list(produced_output)
else:
produced_tokens = [
produced_output[i : i + group_by_size]
for i in range(0, len(produced_output), group_by_size)
]
assert produced_output == "".join(produced_tokens)

role = SystemRole.get(DefaultRoles.SHELL.value)
completion.return_value = mock_comp(produced_tokens)

args = {"prompt": "find pid by port 9000", "--shell": True}
result = runner.invoke(app, cmd_args(**args))

completion.assert_called_once_with(**comp_args(role, args["prompt"]))
index = result.stdout.find(expected_output)
assert index >= 0
rest = result.stdout[index + len(expected_output) :].strip()
assert "`" not in rest
assert result.exit_code == 0
assert "[E]xecute, [D]escribe, [A]bort:" == rest


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("sgpt.handlers.handler.completion")
Expand Down