Skip to content

Commit

Permalink
Merge pull request #253 from stratosphereips/send-config-file-id-to-a…
Browse files Browse the repository at this point in the history
…gent

Send config file id to agent
  • Loading branch information
ondrej-lukas authored Nov 5, 2024
2 parents 0b24242 + eb7ae3d commit 79f7e26
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
17 changes: 11 additions & 6 deletions coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from env.worlds.network_security_game_real_world import NetworkSecurityEnvironmentRealWorld
from env.worlds.aidojo_world import AIDojoWorld
from env.game_components import Action, Observation, ActionType, GameStatus, GameState
from utils.utils import observation_as_dict, get_logging_level
from utils.utils import observation_as_dict, get_logging_level, get_file_hash
from pathlib import Path
import os
import signal
Expand Down Expand Up @@ -174,6 +174,7 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
self._answers_queue = answers_queue
self.ALLOWED_ROLES = allowed_roles
self.logger = logging.getLogger("AIDojo-Coordinator")

# world definition
match world_type:
case "netsecenv":
Expand All @@ -183,9 +184,7 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
case _:
self._world = AIDojoWorld(net_sec_config)
self.world_type = world_type



self._CONFIG_FILE_HASH = get_file_hash(net_sec_config)
self._starting_positions_per_role = self._get_starting_position_per_role()
self._win_conditions_per_role = self._get_win_condition_per_role()
self._goal_description_per_role = self._get_goal_description_per_role()
Expand Down Expand Up @@ -215,6 +214,10 @@ def episode_end(self)->bool:
self.logger.debug(f"End evaluation: {self._agent_episode_ends.values()}")
return all(self._agent_episode_ends.values())

@property
def config_file_hash(self):
return self._CONFIG_FILE_HASH

def convert_msg_dict_to_json(self, msg_dict)->str:
try:
# Convert message into string representation
Expand Down Expand Up @@ -383,15 +386,17 @@ def _process_join_game_action(self, agent_addr: tuple, action: Action) -> dict:
agent_role = action.parameters["agent_info"].role
if agent_role in self.ALLOWED_ROLES:
initial_observation = self._initialize_new_player(agent_addr, agent_name, agent_role)
max_steps = self._world._max_steps if agent_role == "Attacker" else None
output_message_dict = {
"to_agent": agent_addr,
"status": str(GameStatus.CREATED),
"observation": observation_as_dict(initial_observation),
"message": {
"message": f"Welcome {agent_name}, registred as {agent_role}",
"max_steps": self._world._max_steps,
"max_steps": max_steps,
"goal_description": self._goal_description_per_role[agent_role],
"num_actions": self._world.num_actions
"num_actions": self._world.num_actions,
"configuration_hash": self._CONFIG_FILE_HASH
},
}
else:
Expand Down
1 change: 1 addition & 0 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_join(self, coordinator_init):
assert "max_steps" in result["message"].keys()
assert "goal_description" in result["message"].keys()
assert not result["observation"]["end"]
assert "configuration_hash" in result["message"].keys()

# def test_reset(self, coordinator_registered_player):
# coord, _ = coordinator_registered_player
Expand Down
14 changes: 14 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Utility functions for then env and for the agents
# Author: Sebastian Garcia. [email protected]
# Author: Ondrej Lukas, [email protected]
#import configparser
import yaml
import sys
Expand All @@ -16,6 +17,19 @@
import csv
from random import randint
import json
import hashlib

def get_file_hash(filepath, hash_func='sha256', chunk_size=4096):
"""
Computes hash of a given file.
"""
hash_algorithm = hashlib.new(hash_func)
with open(filepath, 'rb') as file:
chunk = file.read(chunk_size)
while chunk:
hash_algorithm.update(chunk)
chunk = file.read(chunk_size)
return hash_algorithm.hexdigest()

def read_replay_buffer_from_csv(csvfile:str)->list:
"""
Expand Down

0 comments on commit 79f7e26

Please sign in to comment.