Skip to content

Commit

Permalink
Merge pull request #251 from stratosphereips/role-dependent-timeout
Browse files Browse the repository at this point in the history
Role dependent timeout
  • Loading branch information
ondrej-lukas authored Nov 5, 2024
2 parents 4602d2f + b02a4cb commit 0b24242
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
17 changes: 15 additions & 2 deletions coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
self._agent_episode_ends[agent_addr] = True
end_reason = "goal_reached"
obs_info = {'end_reason': "goal_reached"}
elif self._agent_steps[agent_addr] >= self._steps_limit:
elif self._timeout_reached(agent_addr):
self._agent_episode_ends[agent_addr] = True
obs_info = {"end_reason": "max_steps"}
end_reason = "max_steps"
Expand Down Expand Up @@ -532,7 +532,7 @@ def _generate_episode_end_message(self, agent_addr:tuple)->dict:
end_reason = ""
if self._agent_goal_reached[agent_addr]:
end_reason = "goal_reached"
elif self._agent_steps[agent_addr] >= self._world.timeout:
elif self._timeout_reached(agent_addr):
end_reason = "max_steps"
else:
end_reason = "game_lost"
Expand Down Expand Up @@ -609,6 +609,19 @@ def _check_detection(self, agent_addr:tuple, last_action:Action)->bool:
else:
self.logger.info("\tNot detected!")
return detection

def _timeout_reached(self, agent_addr:tuple) ->bool:
"""
Checks if the agent reached the max allowed steps. Only applies to role 'Attacker'
"""
self.logger.debug(f"Checking timout for {self.agents[agent_addr]}")
if self.agents[agent_addr][1] == "Attacker":
if self._agent_steps[agent_addr] >= self._steps_limit:
self.logger.info("Timeout reached by {self.agents[agent_addr]}!")
return True
else:
return False

__version__ = "v0.2.2"


Expand Down
2 changes: 1 addition & 1 deletion tests/manual/three_nets/three_net_testing_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ env:
# random_seed: 42
scenario: 'three_nets'
use_global_defender: False
max_steps: 50
max_steps: 15
use_dynamic_addresses: False
use_firewall: True
save_trajectories: False
Expand Down
4 changes: 4 additions & 0 deletions tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace
python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace
python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace

# run ruff check as well
echo "Running RUFF check: in ${PWD}"
ruff check --output-format=github --select=E9,F4,F6,F7,F8,N8 --ignore=F405 --target-version=py310 --line-length=120 .

24 changes: 21 additions & 3 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_class_init(self):
assert coord._agent_states == {}
assert coord._agent_goal_reached == {}
assert coord._agent_episode_ends == {}
assert type(coord._actions_queue) == queue.Queue
assert type(coord._answers_queue) == queue.Queue
assert type(coord._actions_queue) is queue.Queue
assert type(coord._answers_queue) is queue.Queue

def test_initialize_new_player(self, coordinator_init):
coord = coordinator_init
Expand Down Expand Up @@ -201,4 +201,22 @@ def test_check_goal_empty(self, coordinator_init):
"known_data":{},
"known_blocks":{}
}
assert coordinator_init._check_goal(game_state, win_conditions) is True
assert coordinator_init._check_goal(game_state, win_conditions) is True

def test_timeout(self, coordinator_registered_player):
coord, init_result = coordinator_registered_player
action = Action(
ActionType.ScanNetwork,
params={
"source_host": IP("192.168.2.2"),
"target_network": Network("192.168.1.0", 24),
},
)
result = init_result
for _ in range(15):
result = coord._process_generic_action(("192.168.1.1", "3300"), action)
assert result["to_agent"] == ("192.168.1.1", "3300")
assert result["status"] == "GameStatus.OK"
assert init_result["observation"]["state"] != result["observation"]["state"]
assert result["observation"]["end"]
assert result["observation"]["info"]["end_reason"] == "max_steps"

0 comments on commit 0b24242

Please sign in to comment.