Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QoL - Adding a function in the metric receiver to handle kill signals #1

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions RLGymPPO_CPP/python_scripts/metric_receiver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
import signal
import site
import sys
import json
import os

wandb_run = None

# Takes in the python executable path, the three wandb init strings, and optionally the current run ID
# Returns the ID of the run (either newly created or resumed)
def init(py_exec_path, project, group, name, id = None):
"""Takes in the python executable path, the three wandb init strings, and optionally the current run ID. Returns the ID of the run (either newly created or resumed)

Args:
py_exec_path (str): Python executable path, necessary to fix a bug where the wrong interpreter is used
project (str): Wandb project name
group (str): Wandb group name
name (str): Wandb run name
id (str, optional): Id of the wandb run, if None, a new run is created

Raises:
Exception: Failed to import wandb

Returns:
str: The id of the created or continued run
"""

global wandb_run

Expand Down Expand Up @@ -35,5 +50,24 @@ def init(py_exec_path, project, group, name, id = None):
return wandb_run.id

def add_metrics(metrics):
    """Logs metrics to the wandb run

    Args:
        metrics (Dict[str, Any]): The metrics to log
    """
    # Log exactly once per call: a second log() of the same dict would record
    # every report twice in the run history.
    wandb_run.log(metrics)


def end(_signal):
    """Runs post-mortem tasks: reports the received signal and finishes the wandb run

    Args:
        _signal (int): Received signal number
    """
    print(f"Received signal {_signal}, running post-mortem tasks")

    # wandb_run starts out as None; there is nothing to finish if no run was created.
    if wandb_run is None:
        return

    # SIGBREAK crashes wandb_run.finish on a WinError[10054].
    # SIGBREAK is only defined on Windows, so look it up defensively instead of
    # raising AttributeError on other platforms.
    sigbreak = getattr(signal.Signals, "SIGBREAK", None)
    if sigbreak is not None and _signal == sigbreak.value:
        return

    wandb_run.finish()
1 change: 1 addition & 0 deletions RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ void RLGPC::Learner::Learn() {
RG_LOG("Learner: Timestep limit of " << config.timestepLimit << " reached, stopping");
RG_LOG("\tStopping agents...");
agentMgr->StopAgents();
if(config.sendMetrics) this->metricSender->StopRun();
}

void RLGPC::Learner::AddNewExperience(GameTrajectory& gameTraj, Report& report) {
Expand Down
37 changes: 34 additions & 3 deletions RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "MetricSender.h"

#include "Timer.h"
#include <csignal>

namespace py = pybind11;
using namespace RLGPC;
Expand All @@ -12,12 +13,15 @@ RLGPC::MetricSender::MetricSender(std::string _projectName, std::string _groupNa

try {
pyMod = py::module::import("python_scripts.metric_receiver");
this->initMethod = pyMod.attr("init");
this->sendMetricsMethod = pyMod.attr("add_metrics");
this->onKillMethod = pyMod.attr("end");
} catch (std::exception& e) {
RG_ERR_CLOSE("MetricSender: Failed to import metrics receiver, exception: " << e.what());
}

try {
auto returedRunID = pyMod.attr("init")(PY_EXEC_PATH, projectName, groupName, runName, runID);
auto returedRunID = this->initMethod(PY_EXEC_PATH, projectName, groupName, runName, runID);
curRunID = returedRunID.cast<std::string>();
RG_LOG(" > " << (runID.empty() ? "Starting" : "Continuing") << " run with ID : \"" << curRunID << "\"...");

Expand All @@ -26,6 +30,10 @@ RLGPC::MetricSender::MetricSender(std::string _projectName, std::string _groupNa
}

RG_LOG(" > MetricSender initalized.");

std::signal(SIGINT, MetricSender::OnKillSignal);
std::signal(SIGTERM, MetricSender::OnKillSignal);
std::signal(SIGBREAK, MetricSender::OnKillSignal);
}

void RLGPC::MetricSender::Send(const Report& report) {
Expand All @@ -35,12 +43,35 @@ void RLGPC::MetricSender::Send(const Report& report) {
reportDict[pair.first.c_str()] = pair.second;

try {
pyMod.attr("add_metrics")(reportDict);
this->sendMetricsMethod(reportDict);
} catch (std::exception& e) {
RG_ERR_CLOSE("MetricSender: Failed to add metrics, exception: " << e.what());
}
}

RLGPC::MetricSender::~MetricSender() {
/// Ends the current metrics run by calling the Python receiver's `end` function.
///
/// 0 is passed in place of a signal number: this is a regular shutdown, not a
/// reaction to a received signal.
void RLGPC::MetricSender::StopRun() const
{
	try {
		this->onKillMethod(0);
	}
	catch (const std::exception& e) {
		// Report the operation that actually failed (stopping the run, not sending metrics).
		RG_ERR_CLOSE("MetricSender: Failed to stop run, exception: " << e.what());
	}
}

void RLGPC::MetricSender::OnKillSignal(const int signal)
{
RG_LOG("Received end signal " << signal << ".");
try {
pybind11::module pyMod = py::module::import("python_scripts.metric_receiver");
pyMod.attr("end")(signal);
}
catch (std::exception& e) {
RG_ERR_CLOSE("MetricSender: Failed during end signal handling, exception: " << e.what());
}

}

/// Destructor: performs no work beyond destroying the members.
RLGPC::MetricSender::~MetricSender() = default;
8 changes: 8 additions & 0 deletions RLGymPPO_CPP/src/public/RLGymPPO_CPP/Util/MetricSender.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@ namespace RLGPC {
std::string projectName, groupName, runName;
pybind11::module pyMod;

pybind11::object initMethod;
pybind11::object sendMetricsMethod;
pybind11::object onKillMethod;

MetricSender(std::string projectName = {}, std::string groupName = {}, std::string runName = {}, std::string runID = {});

RG_NO_COPY(MetricSender);

void Send(const Report& report);
void StopRun() const;

static void OnKillSignal(int sig);


~MetricSender();
};
Expand Down