Merge pull request #256 from stratosphereips/assign-reward-at-episode-end

Assign reward at episode end
ondrej-lukas authored Nov 18, 2024
2 parents 00ce5be + f23692c commit 696c8fe
Showing 10 changed files with 178 additions and 63 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -134,8 +134,8 @@ env:
## Task configuration
The task configuration part (section `coordinator[agents]`) defines the starting and goal position of the attacker and the type of defender that is used.

### Attacker configuration (`attackers`)
Configuration of the attacking agents. Consists of two parts:
### Attacker configuration (`Attacker`)
Configuration of the attacking agents. Consists of three parts:
1. Goal definition (`goal`) which describes the `GameState` properties that must be fulfilled to award `goal_reward` to the attacker:
- `known_networks:`(set)
- `known_hosts`(set)
@@ -154,11 +154,14 @@ Configuration of the attacking agents. Consists of two parts:
- `known_data`(dict)

The initial network configuration must assign at least **one** controlled host to the attacker in the network. Any item in `controlled_hosts` is copied to `known_hosts`, so there is no need to include these in both sets. `known_networks` is also extended with a set of **all** networks accessible from the `controlled_hosts`.
3. Definition of the maximum allowed number of steps:
- `max_steps:`(int)

Example attacker configuration:
```YAML
agents:
Attacker:
max_steps: 100
goal:
randomize_goal_every_episode: False
known_networks: []
@@ -179,7 +182,7 @@ agents:
known_data: {}
known_blocks: {}
```
### Defender configuration (`defenders`)
### Defender configuration (`Defender`)
Currently, the defender **is** a separate agent.

If you want a defender in the game, you must connect a defender agent. For playing without a defender, leave the section empty.
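For reference, a minimal sketch of a task configuration that plays without a defender (the exact key layout and the empty-mapping syntax are assumptions based on the attacker example above, not taken verbatim from this repository):

```YAML
# Hedged sketch: no defender in the game, so the Defender entry is left empty.
agents:
  Attacker:
    max_steps: 100
    goal:
      randomize_goal_every_episode: False
      known_networks: []
  Defender: {}
```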
155 changes: 111 additions & 44 deletions coordinator.py
@@ -188,7 +188,7 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
self._starting_positions_per_role = self._get_starting_position_per_role()
self._win_conditions_per_role = self._get_win_condition_per_role()
self._goal_description_per_role = self._get_goal_description_per_role()
self._steps_limit = self._world.task_config.get_max_steps()
self._steps_limit_per_role = self._get_max_steps_per_role()
self._use_global_defender = self._world.task_config.get_use_global_defender()
# player information
self.agents = {}
@@ -201,18 +201,19 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
self._agent_starting_position = {}
# current state per agent_addr (GameState)
self._agent_states = {}
# goal reach status per agent_addr (bool)
self._agent_goal_reached = {}
self._agent_episode_ends = {}
self._agent_detected = {}
# agent status dict {agent_addr: string}
self._agent_statuses = {}
# agent status dict {agent_addr: int}
self._agent_rewards = {}
# trajectories per agent_addr
self._agent_trajectories = {}

@property
def episode_end(self)->bool:
# Terminate episode if at least one player wins or reaches the timeout
self.logger.debug(f"End evaluation: {self._agent_episode_ends.values()}")
return all(self._agent_episode_ends.values())
# Episode ends ONLY IF all agents with defined max_steps reached the end of the episode
exists_active_player = any(status == "playing_active" for status in self._agent_statuses.values())
self.logger.debug(f"End evaluation: {self._agent_statuses.items()} - Episode end:{not exists_active_player}")
return not exists_active_player

@property
def config_file_hash(self):
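The reworked `episode_end` property above drops the per-agent `_agent_episode_ends` flags in favour of status strings. To make the new rule concrete, here is a small standalone sketch (not part of the commit) of the status-based check: the episode keeps running as long as at least one agent is still `"playing_active"`, i.e. an agent whose role has `max_steps` defined and that has not yet finished.

```python
# Standalone illustration of the status-based episode-end rule introduced in this commit.
# The statuses and the "no playing_active agent left" condition mirror the diff above;
# the agent addresses and the helper function itself are made up for this example.
def episode_end(agent_statuses: dict) -> bool:
    # The episode ends only when no agent is still actively playing.
    return not any(status == "playing_active" for status in agent_statuses.values())

statuses = {
    ("192.168.1.2", 5001): "playing_active",  # attacker: max_steps defined, still playing
    ("192.168.1.3", 5002): "playing",         # defender: no max_steps, cannot force the end
}
print(episode_end(statuses))   # False: the attacker is still active

statuses[("192.168.1.2", 5001)] = "goal_reached"
print(episode_end(statuses))   # True: no "playing_active" agent remains
```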
@@ -273,8 +274,13 @@ async def run(self):
self._reset_requests[agent] = False
self._agent_steps[agent] = 0
self._agent_states[agent] = self._world.create_state_from_view(self._agent_starting_position[agent])
self._agent_goal_reached[agent] = self._goal_reached(agent)
self._agent_episode_ends[agent] = False
self._agent_rewards.pop(agent, None)
if self._steps_limit_per_role[self.agents[agent][1]]:
# This agent can force episode end (has max_steps defined for its role)
self._agent_statuses[agent] = "playing_active"
else:
# This agent can NOT force episode end (no max_steps defined for its role)
self._agent_statuses[agent] = "playing"
output_message_dict = self._create_response_to_reset_game_action(agent)
msg_json = self.convert_msg_dict_to_json(output_message_dict)
# Send to answer_queue
@@ -307,9 +313,13 @@ def _initialize_new_player(self, agent_addr:tuple, agent_name:str, agent_role:st
self._reset_requests[agent_addr] = False
self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
self._agent_states[agent_addr] = self._world.create_state_from_view(self._agent_starting_position[agent_addr])
self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)
self._agent_detected[agent_addr] = self._check_detection(agent_addr, None)
self._agent_episode_ends[agent_addr] = False

if self._steps_limit_per_role[agent_role]:
# This agent can force episode end (has max_steps defined for its role)
self._agent_statuses[agent_addr] = "playing_active"
else:
# This agent can NOT force episode end (no max_steps defined for its role)
self._agent_statuses[agent_addr] = "playing"
if self._world.task_config.get_store_trajectories() or self._use_global_defender:
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
@@ -323,10 +333,10 @@ def _remove_player(self, agent_addr:tuple)->dict:
agent_info = {}
if agent_addr in self.agents:
agent_info["state"] = self._agent_states.pop(agent_addr)
agent_info["goal_reached"] = self._agent_goal_reached.pop(agent_addr)
agent_info["status"] = self._agent_statuses.pop(agent_addr)
agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
agent_info["episode_end"] = self._agent_episode_ends.pop(agent_addr)
agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None)
agent_info["agent_info"] = self.agents.pop(agent_addr)
self.logger.debug(f"\t{agent_info}")
else:
@@ -376,6 +386,19 @@ def _get_goal_description_per_role(self)->dict:
self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}")
return goal_descriptions

def _get_max_steps_per_role(self)->dict:
"""
Method for finding the maximum number of steps per episode for each agent role in the game.
"""
max_steps = {}
for agent_role in self.ALLOWED_ROLES:
try:
max_steps[agent_role] = self._world.task_config.get_max_steps(agent_role)
except KeyError:
max_steps[agent_role] = None
self.logger.info(f"Max steps in episode for '{agent_role}': {max_steps[agent_role]}")
return max_steps
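A usage sketch of the lookup above (not part of the commit): roles without a `max_steps` entry in the task configuration end up with `None` and therefore never force the episode to end. The config class below is a hypothetical stand-in for `task_config`; only the `get_max_steps(role)` call and the KeyError fallback are taken from the diff.

```python
# Hypothetical stand-in for the task configuration used by _get_max_steps_per_role.
# Only the Attacker defines max_steps, matching the README example above; the class
# name and the raw dict layout are assumptions made for this illustration.
class FakeTaskConfig:
    def __init__(self, agents: dict):
        self._agents = agents

    def get_max_steps(self, role: str) -> int:
        return self._agents[role]["max_steps"]  # raises KeyError when undefined

config = FakeTaskConfig({"Attacker": {"max_steps": 100}, "Defender": {}})

max_steps = {}
for role in ("Attacker", "Defender"):
    try:
        max_steps[role] = config.get_max_steps(role)
    except KeyError:
        max_steps[role] = None

print(max_steps)  # {'Attacker': 100, 'Defender': None}
```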

def _process_join_game_action(self, agent_addr: tuple, action: Action) -> dict:
""" "
Method for processing Action of type ActionType.JoinGame
@@ -386,14 +409,13 @@ def _process_join_game_action(self, agent_addr: tuple, action: Action) -> dict:
agent_role = action.parameters["agent_info"].role
if agent_role in self.ALLOWED_ROLES:
initial_observation = self._initialize_new_player(agent_addr, agent_name, agent_role)
max_steps = self._world._max_steps if agent_role == "Attacker" else None
output_message_dict = {
"to_agent": agent_addr,
"status": str(GameStatus.CREATED),
"observation": observation_as_dict(initial_observation),
"message": {
"message": f"Welcome {agent_name}, registred as {agent_role}",
"max_steps": max_steps,
"max_steps": self._steps_limit_per_role[agent_role],
"goal_description": self._goal_description_per_role[agent_role],
"num_actions": self._world.num_actions,
"configuration_hash": self._CONFIG_FILE_HASH
@@ -436,8 +458,9 @@ def _create_response_to_reset_game_action(self, agent_addr: tuple) -> dict:
"observation": observation_as_dict(new_observation),
"message": {
"message": "Resetting Game and starting again.",
"max_steps": self._world._max_steps,
"goal_description": self._goal_description_per_role[self.agents[agent_addr][1]]
"max_steps": self._steps_limit_per_role[self.agents[agent_addr][1]],
"goal_description": self._goal_description_per_role[self.agents[agent_addr][1]],
"configuration_hash": self._CONFIG_FILE_HASH
},
}
return output_message_dict
@@ -491,24 +514,34 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
current_state = self._agent_states[agent_addr]
# Build new Observation for the agent
self._agent_states[agent_addr] = self._world.step(current_state, action, agent_addr)
self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)

self._agent_detected[agent_addr] = self._check_detection(agent_addr, action)

# check timeout
if self._max_steps_reached(agent_addr):
self._agent_statuses[agent_addr] = "max_steps"
# check detection
if self._check_detection(agent_addr, action):
self._agent_statuses[agent_addr] = "blocked"
self._agent_detected[agent_addr] = True
# check goal
if self._goal_reached(agent_addr):
self._agent_statuses[agent_addr] = "goal_reached"
# add reward for taking a step
reward = self._world._rewards["step"]

obs_info = {}
end_reason = None
if self._agent_goal_reached[agent_addr]:
reward += self._world._rewards["goal"]
self._agent_episode_ends[agent_addr] = True
if self._agent_statuses[agent_addr] == "goal_reached":
self._assign_end_rewards()
reward += self._agent_rewards[agent_addr]
end_reason = "goal_reached"
obs_info = {'end_reason': "goal_reached"}
elif self._timeout_reached(agent_addr):
self._agent_episode_ends[agent_addr] = True
elif self._agent_statuses[agent_addr] == "max_steps":
self._assign_end_rewards()
reward += self._agent_rewards[agent_addr]
obs_info = {"end_reason": "max_steps"}
end_reason = "max_steps"
elif self._agent_detected[agent_addr]:
reward += self._world._rewards["detection"]
elif self._agent_statuses[agent_addr] == "blocked":
self._assign_end_rewards()
reward += self._agent_rewards[agent_addr]
self._agent_episode_ends[agent_addr] = True
obs_info = {"end_reason": "max_steps"}

@@ -524,6 +557,7 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
"status": str(GameStatus.OK),
}
else:
self._assign_end_rewards()
self.logger.error(f"{self.episode_end}, {self._agent_episode_ends}")
output_message_dict = self._generate_episode_end_message(agent_addr)
return output_message_dict
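One consequence of the ordering in the step handler above, worth spelling out: the agent status is overwritten in the order max_steps → blocked → goal_reached, so an agent that reaches its goal on the very last allowed step still ends with `"goal_reached"`. A minimal sketch of that precedence (the three boolean flags stand in for `_max_steps_reached`, `_check_detection` and `_goal_reached`; everything else is invented for the example):

```python
# Minimal illustration of the per-step status precedence in _process_generic_action.
def resolve_status(max_steps_hit: bool, detected: bool, goal_hit: bool) -> str:
    status = "playing_active"
    if max_steps_hit:
        status = "max_steps"
    if detected:
        status = "blocked"
    if goal_hit:
        status = "goal_reached"   # checked last, so it wins over the other two
    return status

print(resolve_status(max_steps_hit=True, detected=False, goal_hit=True))   # goal_reached
print(resolve_status(max_steps_hit=True, detected=False, goal_hit=False))  # max_steps
```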
@@ -533,15 +567,8 @@ def _generate_episode_end_message(self, agent_addr:tuple)->dict:
Method for generating a response when an agent attempts to make a step after the episode has ended.
"""
current_observation = self._agent_observations[agent_addr]
reward = 0 # TODO
end_reason = ""
if self._agent_goal_reached[agent_addr]:
end_reason = "goal_reached"
elif self._timeout_reached(agent_addr):
end_reason = "max_steps"
else:
end_reason = "game_lost"
reward += self._world._rewards["detection"]
reward = self._agent_rewards[agent_addr]
end_reason = self._agent_statuses[agent_addr]
new_observation = Observation(
current_observation.state,
reward=reward,
@@ -586,7 +613,7 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
if len(matching_keys) == len(goal_dict.keys()):
return True
except KeyError:
#some keys are missing in the known_dict
# some keys are missing in the known_dict
return False
return False

@@ -615,18 +642,58 @@ def _check_detection(self, agent_addr:tuple, last_action:Action)->bool:
self.logger.info("\tNot detected!")
return detection

def _timeout_reached(self, agent_addr:tuple) ->bool:
def _max_steps_reached(self, agent_addr:tuple) ->bool:
"""
Checks if the agent reached the max allowed steps for its role (if a step limit is defined).
"""
self.logger.debug(f"Checking timout for {self.agents[agent_addr]}")
if self.agents[agent_addr][1] == "Attacker":
if self._agent_steps[agent_addr] >= self._steps_limit:
agent_role = self.agents[agent_addr][1]
if self._steps_limit_per_role[agent_role]:
if self._agent_steps[agent_addr] >= self._steps_limit_per_role[agent_role]:
self.logger.info("Timeout reached by {self.agents[agent_addr]}!")
return True
else:
self.logger.debug(f"No max steps defined for role {agent_role}")
return False

def _assign_end_rewards(self)->None:
"""
Method which assigns rewards to each agent that has finished playing.
"""
is_episode_over = self.episode_end
for agent, status in self._agent_statuses.items():
if agent not in self._agent_rewards.keys(): # reward has not been assigned yet
agent_name, agent_role = self.agents[agent]
if agent_role == "Attacker":
match status:
case "goal_reached":
self._agent_rewards[agent] = self._world._rewards["goal"]
case "max_steps":
self._agent_rewards[agent] = 0
case "blocked":
self._agent_rewards[agent] = self._world._rewards["detection"]
self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
elif agent_role == "Defender":
if self._agent_statuses[agent] == "max_steps": #defender was responsible for the end
raise NotImplementedError
self._agent_rewards[agent] = 0
else:
if is_episode_over: #only assign defender's reward when episode ends
sucessful_attacks = list(self._agent_statuses.values()).count("goal_reached")
if sucessful_attacks > 0:
self._agent_rewards[agent] = sucessful_attacks*self._world._rewards["detection"]
self._agent_statuses[agent] = "game_lost"
else: #no successful attacker
self._agent_rewards[agent] = self._world._rewards["goal"]
self._agent_statuses[agent] = "goal_reached"
self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
else:
if is_episode_over:
self._agent_rewards[agent] = 0
self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")



__version__ = "v0.2.2"


@@ -668,7 +735,7 @@ def _timeout_reached(self, agent_addr:tuple) ->bool:
action="store",
required=False,
type=str,
default="INFO",
default="WARNING",
)

args = parser.parse_args()