diff --git a/RLGymPPO_CPP/python_scripts/metric_receiver.py b/RLGymPPO_CPP/python_scripts/metric_receiver.py index cb78f89..5355042 100644 --- a/RLGymPPO_CPP/python_scripts/metric_receiver.py +++ b/RLGymPPO_CPP/python_scripts/metric_receiver.py @@ -1,6 +1,6 @@ +import signal import site import sys -import json import os wandb_run = None @@ -8,6 +8,21 @@ # Takes in the python executable path, the three wandb init strings, and optionally the current run ID # Returns the ID of the run (either newly created or resumed) def init(py_exec_path, project, group, name, id = None): + """Takes in the python executable path, the three wandb init strings, and optionally the current run ID. Returns the ID of the run (either newly created or resumed) + + Args: + py_exec_path (str): Python executable path, necessary to fix a bug where the wrong interpreter is used + project (str): Wandb project name + group (str): Wandb group name + name (str): Wandb run name + id (str, optional): Id of the wandb run, if None, a new run is created + + Raises: + Exception: Failed to import wandb + + Returns: + str: The id of the created or continued run + """ global wandb_run @@ -35,5 +50,24 @@ def init(py_exec_path, project, group, name, id = None): return wandb_run.id def add_metrics(metrics): + """Logs metrics to the wandb run + + Args: + metrics (Dict[str, Any]): The metrics to log + """ global wandb_run - wandb_run.log(metrics) \ No newline at end of file + wandb_run.log(metrics) + + +def end(_signal): + """Runs post-mortem tasks + + Args: + signal (int): Received signal + """ + print(f"Received signal {_signal}, running post-mortem tasks") + + # SIGBREAK crashes wandb_run.finish on a WinError[10054]. + + if _signal != signal.Signals.SIGBREAK.value: + wandb_run.finish() \ No newline at end of file diff --git a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp index 9cecd95..629d464 100644 --- a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp +++ b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp @@ -603,6 +603,7 @@ void RLGPC::Learner::Learn() { RG_LOG("Learner: Timestep limit of " << config.timestepLimit << " reached, stopping"); RG_LOG("\tStopping agents..."); agentMgr->StopAgents(); + if(config.sendMetrics) this->metricSender->StopRun(); } void RLGPC::Learner::AddNewExperience(GameTrajectory& gameTraj, Report& report) { diff --git a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.cpp b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.cpp index 5050160..7adc8d9 100644 --- a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.cpp +++ b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.cpp @@ -1,6 +1,7 @@ #include "MetricSender.h" #include "Timer.h" +#include namespace py = pybind11; using namespace RLGPC; @@ -12,12 +13,15 @@ RLGPC::MetricSender::MetricSender(std::string _projectName, std::string _groupNa try { pyMod = py::module::import("python_scripts.metric_receiver"); + this->initMethod = pyMod.attr("init"); + this->sendMetricsMethod = pyMod.attr("add_metrics"); + this->onKillMethod = pyMod.attr("end"); } catch (std::exception& e) { RG_ERR_CLOSE("MetricSender: Failed to import metrics receiver, exception: " << e.what()); } try { - auto returedRunID = pyMod.attr("init")(PY_EXEC_PATH, projectName, groupName, runName, runID); + auto returedRunID = this->initMethod(PY_EXEC_PATH, projectName, groupName, runName, runID); curRunID = returedRunID.cast(); RG_LOG(" > " << (runID.empty() ? "Starting" : "Continuing") << " run with ID : \"" << curRunID << "\"..."); @@ -26,6 +30,10 @@ RLGPC::MetricSender::MetricSender(std::string _projectName, std::string _groupNa } RG_LOG(" > MetricSender initalized."); + + std::signal(SIGINT, MetricSender::OnKillSignal); + std::signal(SIGTERM, MetricSender::OnKillSignal); + std::signal(SIGBREAK, MetricSender::OnKillSignal); } void RLGPC::MetricSender::Send(const Report& report) { @@ -35,12 +43,35 @@ void RLGPC::MetricSender::Send(const Report& report) { reportDict[pair.first.c_str()] = pair.second; try { - pyMod.attr("add_metrics")(reportDict); + this->sendMetricsMethod(reportDict); } catch (std::exception& e) { RG_ERR_CLOSE("MetricSender: Failed to add metrics, exception: " << e.what()); } } -RLGPC::MetricSender::~MetricSender() { +void RLGPC::MetricSender::StopRun() const +{ + try { + this->onKillMethod(0); + } + catch (std::exception& e) { + RG_ERR_CLOSE("MetricSender: Failed to add metrics, exception: " << e.what()); + } +} + +void RLGPC::MetricSender::OnKillSignal(const int signal) +{ + RG_LOG("Received end signal " << signal << "."); + try { + pybind11::module pyMod = py::module::import("python_scripts.metric_receiver"); + pyMod.attr("end")(signal); + } + catch (std::exception& e) { + RG_ERR_CLOSE("MetricSender: Failed during end signal handling, exception: " << e.what()); + } + +} + +RLGPC::MetricSender::~MetricSender() { } \ No newline at end of file diff --git a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.h b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.h index b167fd7..21498dc 100644 --- a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.h +++ b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.h @@ -8,11 +8,19 @@ namespace RLGPC { std::string projectName, groupName, runName; pybind11::module pyMod; + pybind11::object initMethod; + pybind11::object sendMetricsMethod; + pybind11::object onKillMethod; + MetricSender(std::string projectName = {}, std::string groupName = {}, std::string runName = {}, std::string runID = {}); RG_NO_COPY(MetricSender); void Send(const Report& report); + void StopRun() const; + + static void OnKillSignal(int sig); + ~MetricSender(); };