Skip to content

Commit

Permalink
Merge pull request #49 from stratosphereips/mari_tui_prettify
Browse files Browse the repository at this point in the history
Prepare TUI for BSides
  • Loading branch information
MariaRigaki authored Apr 2, 2024
2 parents 214a937 + a339c20 commit 5918684
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 30 deletions.
12 changes: 8 additions & 4 deletions agents/interactive_tui/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ACTION_MAPPER = {
"ScanNetwork": ActionType.ScanNetwork,
"ScanServices": ActionType.FindServices,
"FindServices": ActionType.FindServices,
"FindData": ActionType.FindData,
"ExfiltrateData": ActionType.ExfiltrateData,
"ExploitService": ActionType.ExploitService,
Expand Down Expand Up @@ -119,8 +120,11 @@ def create_mem_prompt(self, memory_list: list) -> str:
"""Summarize a list of memories into a few sentences."""
prompt = ""
if len(memory_list) > 0:
for memory in memory_list:
prompt += f"You have taken action {str(memory)} in the past.\n"
for memory, goodness in memory_list:
if goodness:
prompt += f"You have taken action {str(memory)} in the past. This action was helpful.\n"
else:
prompt += f"You have taken action {str(memory)} in the past. This action was not helpful.\n"
return prompt

def parse_response(self, llm_response: str, state: Observation.state):
Expand All @@ -136,8 +140,8 @@ def parse_response(self, llm_response: str, state: Observation.state):
# self.memories.append((action_str, action_params))

_, action = create_action_from_response(response, state)
if action_str == "ScanServices":
action_str = "FindServices"
# if action_str == "ScanServices":
# action_str = "FindServices"
action_output = (
f"You can take action {action_str} with parameters {action_params}"
)
Expand Down
63 changes: 40 additions & 23 deletions agents/interactive_tui/interactive_tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Author: Maria Rigaki - [email protected]
#
from textual.app import App, ComposeResult, Widget
from textual.widgets import Tree, Button, Log, Select, Input
from textual.widgets import Tree, Button, RichLog, Select, Input
from textual.containers import Vertical, VerticalScroll, Horizontal
from textual.validation import Function
from textual import on
Expand Down Expand Up @@ -210,10 +210,10 @@ def compose(self) -> ComposeResult:
classes="box",
)
)
yield Log(classes="box", id="textarea")
yield RichLog(classes="box", id="textarea", highlight=True, markup=True)
yield Horizontal(
Button("Take Action", variant="primary", id="act"),
Button("Hack the Planet", variant="warning", id="help"),
Button("Hack the Future", variant="warning", id="help"),
# Button("Hack the Planet", variant="primary", id="hack"),
)
# yield Footer()
Expand Down Expand Up @@ -271,11 +271,11 @@ def select_changed(self, event: Select.Changed) -> None:
case ActionType.ExploitService:
target_input.set_options(known_hosts)

services = []
services = set()
for host in state.known_services:
if host not in state.controlled_hosts:
for serv in state.known_services[host]:
services.append((serv.name, serv.name))
services.add((serv.name, serv.name))
service_input.set_options(services)

case ActionType.FindData:
Expand All @@ -284,10 +284,10 @@ def select_changed(self, event: Select.Changed) -> None:
case ActionType.ExfiltrateData:
target_input.set_options(contr_hosts)

data = []
data = set()
for host in state.known_data:
for d in state.known_data[host]:
data.append((d.id, d.id))
data.add((d.id, d.id))
data_input.set_options(data)

else:
Expand Down Expand Up @@ -330,8 +330,8 @@ def select_changed(self, event: Select.Changed) -> None:

@on(Input.Changed)
def handle_inputs(self, event: Input.Changed) -> None:
# log = self.query_one("Log")
# log.write_line(f"Input: {str(event.value)} from {event._sender.id}")
# log = self.query_one("RichLog")
# log.write(f"Input: {str(event.value)} from {event._sender.id}")

if event._sender.id == "src_host":
self.src_host_input = event.value
Expand All @@ -350,12 +350,12 @@ def do_something(self, event: Button.Pressed) -> None:
Press the button to select a random action.
Right now there is only one button. If we add more we will need to distinguish them.
"""
log = self.query_one("Log")
log = self.query_one("RichLog")
if event.button.id == "act":
action = self._move(self.current_obs.state)

self.update_state(action)
self.memory_buf.append(action)

# Take the first node of TreeState which contains the tree
tree_state = self.query_one(TreeState)
tree = tree_state.children[0]
Expand All @@ -367,19 +367,19 @@ def do_something(self, event: Button.Pressed) -> None:
)
if action is not None:
# To remove the discrepancy between scan and find services
msg = f"Assistant proposes: {str(action)}"
msg = f"[bold yellow]:robot: Assistant proposes:[/bold yellow] {str(action)}"
# if event.button.id == "hack":
self.update_state(action)
self.memory_buf.append(action)
# self.memory_buf.append(action)

tree_state = self.query_one(TreeState)
tree = tree_state.children[0]
self.update_tree(tree)
else:
msg = f"Assistant proposes (invalid): {act_str}"
log.write_line(msg)
msg = f"[bold red]:robot: Assistant proposes (invalid):[/bold red] {act_str}"
log.write(msg)
self.notify(
message=msg, title="LLM Action", timeout=20, severity="warning"
message=msg, title="LLM Action", timeout=15, severity="warning"
)

# Redo if hack the planet
Expand All @@ -391,20 +391,35 @@ def do_something(self, event: Button.Pressed) -> None:
# else:
# self.repetitions = 0
else:
log.write_line("No assistant is available at the moment.")
log.write(
"[bold red]No assistant is available at the moment.[/bold red]"
)

def update_state(self, action: Action) -> None:
# Get next observation of the environment
next_observation = self.agent.make_step(action)
if next_observation.state != self.current_obs.state:
good_action = True
else:
good_action = False

# Collect reward
self.returns += next_observation.reward
# Move to next state
self.current_obs = next_observation
self.memory_buf.append((action, good_action))

if next_observation.end:
log = self.query_one("RichLog")
if next_observation.info["end_reason"] == "goal_reached":
self.notify(f"You won! Total return: {self.returns}", timeout=10)
log.write(
f"[bold green]:tada: :fireworks: :trophy: You won! Total return: {self.returns}[/bold green]",
)
self.notify(f"You won! Total return: {self.returns}", timeout=20)
else:
log.write(
f"[bold red]:x: :sob: You lost! Total return: {self.returns}[/bold red]"
)
self.notify(
f"You lost! Total return: {self.returns}",
severity="error",
Expand Down Expand Up @@ -518,7 +533,7 @@ def _generate_valid_actions(self, state: GameState) -> list:

def _move(self, state: GameState) -> Action:
action = None
log = self.query_one("Log")
log = self.query_one("RichLog")
if self.next_action == ActionType.ScanNetwork:
parameters = {
"source_host": IP(self.src_host_input),
Expand Down Expand Up @@ -548,7 +563,7 @@ def _move(self, state: GameState) -> Action:
)
elif self.next_action == ActionType.ExfiltrateData:
for host, data_items in state.known_data.items():
log.write_line(f"{str(state.known_data.items())}")
# log.write(f"{str(state.known_data.items())}")
if IP(self.src_host_input) == host:
for datum in data_items:
if self.data_input == datum.id:
Expand All @@ -561,17 +576,19 @@ def _move(self, state: GameState) -> Action:
action_type=self.next_action, params=parameters
)
else:
log.write_line(f"Invalid input: {self.next_action} with {parameters}")
log.write(
f"[bold red]Invalid input: {self.next_action} with {parameters}[/bold red]"
)
logger.info(
f"Invalid input from user: {self.next_action} with {parameters}"
)

if action is None:
action = self._random_move(state)
log.write_line(f"Random action: {str(action)}")
log.write(f"[bold yellow]Random action:[/bold yellow] {str(action)}")
logger.info(f"Random action due to error: {str(action)}")

log.write_line(f"Action selected: {str(action)}")
log.write(f"[bold blue]:woman: Action selected:[/bold blue] {str(action)}")
logger.info(f"User selected action: {str(action)}")

return action
Expand Down
6 changes: 3 additions & 3 deletions agents/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validate_action_in_state(llm_response: dict, state: GameState) -> bool:
case "ScanNetwork":
if action_params["target_network"] in known_nets:
valid = True
case "ScanServices":
case "ScanServices" | "FindServices":
if (
action_params["target_host"] in known_hosts
or action_params["target_host"] in contr_hosts
Expand Down Expand Up @@ -139,7 +139,7 @@ def create_action_from_response(llm_response: dict, state: GameState) -> tuple:
"source_host": IP(src_host),
},
)
case "ScanServices":
case "ScanServices" | "FindServices":
src_host = action_params["source_host"]
action = Action(
ActionType.FindServices,
Expand All @@ -154,7 +154,7 @@ def create_action_from_response(llm_response: dict, state: GameState) -> tuple:
src_host = action_params["source_host"]
if len(list(state.known_services[IP(target_ip)])) > 0:
for serv in state.known_services[IP(target_ip)]:
if serv.name == target_service:
if serv.name == target_service.lower():
parameters = {
"target_host": IP(target_ip),
"target_service": Service(
Expand Down

0 comments on commit 5918684

Please sign in to comment.