Skip to content

Commit

Permalink
Merge pull request #51 from stratosphereips/mari_tui_prettify3
Browse files Browse the repository at this point in the history
Remove random move
  • Loading branch information
MariaRigaki authored Apr 3, 2024
2 parents a536dc0 + 64acca6 commit fb36b78
Showing 1 changed file with 80 additions and 74 deletions.
154 changes: 80 additions & 74 deletions agents/interactive_tui/interactive_tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import logging
import ipaddress
from random import choice
import argparse

from assistant import LLMAssistant
Expand Down Expand Up @@ -161,6 +160,9 @@ def __init__(
)

def compose(self) -> ComposeResult:
"""
Creates the layout
"""
yield Vertical(TreeState(obs=self.current_obs), classes="box", id="tree")
yield Select(
[
Expand Down Expand Up @@ -224,7 +226,7 @@ def compose(self) -> ComposeResult:

@on(Select.Changed)
def select_changed(self, event: Select.Changed) -> None:
# Save the selections of the action parameters
"""Handles the selections of the drop down menus"""
match event._sender.id:
case "src_host":
self.src_host_input = event.value
Expand Down Expand Up @@ -334,9 +336,9 @@ def select_changed(self, event: Select.Changed) -> None:

@on(Input.Changed)
def handle_inputs(self, event: Input.Changed) -> None:
# log = self.query_one("RichLog")
# log.write(f"Input: {str(event.value)} from {event._sender.id}")

"""
Handles the manual inputs that are types by the user.
"""
if event._sender.id == "src_host":
self.src_host_input = event.value
elif event._sender.id == "network":
Expand All @@ -351,12 +353,11 @@ def handle_inputs(self, event: Input.Changed) -> None:
@on(Button.Pressed)
def do_something(self, event: Button.Pressed) -> None:
"""
Press the button to select a random action.
Right now there is only one button. If we add more we will need to distinguish them.
Handles the button events.
"""
log = self.query_one("RichLog")
if event.button.id == "act":
action = self._move(self.current_obs.state)
action = self.generate_action(self.current_obs.state)

self.update_state(action)

Expand Down Expand Up @@ -417,6 +418,9 @@ def do_something(self, event: Button.Pressed) -> None:
)

def update_state(self, action: Action) -> None:
"""
Take an action and receive the new state from the environment.
"""
# Get next observation of the environment
next_observation = self.agent.make_step(action)
if next_observation.state != self.current_obs.state:
Expand Down Expand Up @@ -493,66 +497,67 @@ def update_tree(self, tree: Widget) -> None:
for datum in new_state.known_data[host]:
node.add_leaf(f"{datum.owner} - {datum.id}")

def _generate_valid_actions(self, state: GameState) -> list:
# Generate the list of all valid actions in the current state
valid_actions = set()
for src_host in state.controlled_hosts:
# Network Scans
for network in state.known_networks:
valid_actions.add(
Action(
ActionType.ScanNetwork,
params={"target_network": network, "source_host": src_host},
)
)
# Service Scans
for host in state.known_hosts:
valid_actions.add(
Action(
ActionType.FindServices,
params={"target_host": host, "source_host": src_host},
)
)
# Service Exploits
for host, service_list in state.known_services.items():
for service in service_list:
valid_actions.add(
Action(
ActionType.ExploitService,
params={
"target_host": host,
"target_service": service,
"source_host": src_host,
},
)
)
# Data Scans
for host in state.controlled_hosts:
valid_actions.add(
Action(
ActionType.FindData,
params={"target_host": host, "source_host": src_host},
)
)

# Data Exfiltration
for src_host, data_list in state.known_data.items():
for data in data_list:
for trg_host in state.controlled_hosts:
if trg_host != src_host:
valid_actions.add(
Action(
ActionType.ExfiltrateData,
params={
"target_host": trg_host,
"source_host": src_host,
"data": data,
},
)
)
return list(valid_actions)

def _move(self, state: GameState) -> Action:
# def _generate_valid_actions(self, state: GameState) -> list:
# # Generate the list of all valid actions in the current state
# valid_actions = set()
# for src_host in state.controlled_hosts:
# # Network Scans
# for network in state.known_networks:
# valid_actions.add(
# Action(
# ActionType.ScanNetwork,
# params={"target_network": network, "source_host": src_host},
# )
# )
# # Service Scans
# for host in state.known_hosts:
# valid_actions.add(
# Action(
# ActionType.FindServices,
# params={"target_host": host, "source_host": src_host},
# )
# )
# # Service Exploits
# for host, service_list in state.known_services.items():
# for service in service_list:
# valid_actions.add(
# Action(
# ActionType.ExploitService,
# params={
# "target_host": host,
# "target_service": service,
# "source_host": src_host,
# },
# )
# )
# # Data Scans
# for host in state.controlled_hosts:
# valid_actions.add(
# Action(
# ActionType.FindData,
# params={"target_host": host, "source_host": src_host},
# )
# )

# # Data Exfiltration
# for src_host, data_list in state.known_data.items():
# for data in data_list:
# for trg_host in state.controlled_hosts:
# if trg_host != src_host:
# valid_actions.add(
# Action(
# ActionType.ExfiltrateData,
# params={
# "target_host": trg_host,
# "source_host": src_host,
# "data": data,
# },
# )
# )
# return list(valid_actions)

def generate_action(self, state: GameState) -> Action:
"""Generate a valid action from the user inputs"""
action = None
log = self.query_one("RichLog")
if self.next_action == ActionType.ScanNetwork:
Expand Down Expand Up @@ -605,8 +610,9 @@ def _move(self, state: GameState) -> Action:
)

if action is None:
action = self._random_move(state)
log.write(f"[bold yellow]Random action:[/bold yellow] {str(action)}")
log.write(
f"[bold red]Please select a valid action:[/bold bed] {str(action)}"
)
logger.info(f"Random action due to error: {str(action)}")

log.write(
Expand All @@ -616,10 +622,10 @@ def _move(self, state: GameState) -> Action:

return action

def _random_move(self, state: GameState) -> Action:
# Randomly choose from the available actions
valid_actions = self._generate_valid_actions(state)
return choice(valid_actions)
# def _random_move(self, state: GameState) -> Action:
# # Randomly choose from the available actions
# valid_actions = self._generate_valid_actions(state)
# return choice(valid_actions)

def _clear_state(self) -> None:
"""Reset the state and variables"""
Expand Down

0 comments on commit fb36b78

Please sign in to comment.