From a8d28f02432040c32d32dcfd82f7d67a92171987 Mon Sep 17 00:00:00 2001
From: ZealanL
Date: Sat, 26 Oct 2024 00:10:26 -0700
Subject: [PATCH] Add logit bonus system

---
 .../src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h  | 11 ++++++++++-
 .../private/RLGymPPO_CPP/Threading/ThreadAgent.cpp |  8 +++++++-
 RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp   |  8 ++++++++
 RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h     |  2 ++
 4 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/RLGymPPO_CPP/src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h b/RLGymPPO_CPP/src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h
index 009aeac..4265880 100644
--- a/RLGymPPO_CPP/src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h
+++ b/RLGymPPO_CPP/src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h
@@ -13,7 +13,9 @@ namespace RLGPC {
 		int inputAmount;
 		int actionAmount;
 		IList layerSizes;
+
 		float temperature;
+		torch::Tensor logitBonuses;
 
 		// Min probability that an action will be taken
 		constexpr static float ACTION_MIN_PROB = 1e-11;
@@ -25,8 +27,15 @@ namespace RLGPC {
 		void CopyTo(DiscretePolicy& to);
 
 		torch::Tensor GetOutput(torch::Tensor input) {
+			auto baseOutput = seq->forward(input) / temperature;
+
+			if (logitBonuses.defined()) {
+				auto outputRange = baseOutput.max() - baseOutput.min();
+				baseOutput = baseOutput + (logitBonuses * outputRange);
+			}
+
 			return torch::nn::functional::softmax(
-				seq->forward(input) / temperature,
+				baseOutput,
 				torch::nn::functional::SoftmaxFuncOptions(-1)
 			);
 		}
diff --git a/RLGymPPO_CPP/src/private/RLGymPPO_CPP/Threading/ThreadAgent.cpp b/RLGymPPO_CPP/src/private/RLGymPPO_CPP/Threading/ThreadAgent.cpp
index 72a0a58..478d2b9 100644
--- a/RLGymPPO_CPP/src/private/RLGymPPO_CPP/Threading/ThreadAgent.cpp
+++ b/RLGymPPO_CPP/src/private/RLGymPPO_CPP/Threading/ThreadAgent.cpp
@@ -79,9 +79,15 @@ void _RunFunc(ThreadAgent* ta) {
 
 		// Infer the policy to get actions for all our agents in all our games
 		Timer policyInferTimer = {};
+
 		if (blockConcurrentInfer)
 			mgr->inferMutex.lock();
-		auto actionResults = policy->GetAction(curObsTensorDevice, deterministic);
+		RLGPC::DiscretePolicy::ActionResult actionResults;
+		try {
+			actionResults = policy->GetAction(curObsTensorDevice, deterministic);
+		} catch (std::exception& e) {
+			RG_ERR_CLOSE("Exception during policy->GetAction(): " << e.what());
+		}
 		if (blockConcurrentInfer)
 			mgr->inferMutex.unlock();
 		if (halfPrec) {
diff --git a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp
index 9cecd95..b84ef3a 100644
--- a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp
+++ b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp
@@ -1,4 +1,5 @@
 #include "Learner.h"
+#include "Learner.h"
 
 #include "../../private/RLGymPPO_CPP/Util/SkillTracker.h"
 
@@ -706,6 +707,13 @@ void RLGPC::Learner::UpdateLearningRates(float policyLR, float criticLR) {
 	ppo->UpdateLearningRates(policyLR, criticLR);
 }
 
+void RLGPC::Learner::SetLogitBonuses(RLGSC::FList bonuses) {
+	RG_ASSERT(bonuses.size() == actionAmount);
+	ppo->policy->logitBonuses = torch::tensor(bonuses, ppo->policy->device);
+	if (ppo->policyHalf)
+		ppo->policyHalf->logitBonuses = torch::tensor(bonuses, ppo->policy->device);
+}
+
 std::vector<Report> RLGPC::Learner::GetAllGameMetrics() {
 	std::vector<Report> reports = {};
 
diff --git a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h
index 268830f..58fc899 100644
--- a/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h
+++ b/RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h
@@ -41,6 +41,8 @@ namespace RLGPC {
 
 		void UpdateLearningRates(float policyLR, float criticLR);
 
+		void SetLogitBonuses(RLGSC::FList bonuses);
+
 		std::vector<Report> GetAllGameMetrics();
 
 		void Save();
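
Notes (not part of the patch): the logit bonus system lets a caller nudge the policy toward or away from specific actions at inference time. Learner::SetLogitBonuses takes an RLGSC::FList with one bonus per action (RG_ASSERT enforces bonuses.size() == actionAmount) and copies it to both the full-precision and half-precision policies; DiscretePolicy::GetOutput then scales the bonuses by the spread of the raw logits (max - min) before the softmax, so a bonus of 0.1 shifts that action's logit by 0.1 times the current logit range rather than by a fixed amount. Below is a minimal standalone libtorch sketch of that scaling behavior; the action count and bonus values are illustrative assumptions and the snippet is independent of RLGymPPO_CPP.

// Standalone sketch of the logit-bonus scaling used in the patched DiscretePolicy::GetOutput()
#include <torch/torch.h>
#include <iostream>

int main() {
	float temperature = 1.0f;

	// Stand-in raw policy outputs for one state (8 actions, values chosen arbitrarily)
	torch::Tensor logits = torch::tensor({ 2.0f, 0.5f, -1.0f, 0.0f, 1.0f, -0.5f, 0.25f, 1.5f }) / temperature;

	// One bonus per action: favor action 0, discourage action 2, leave the rest alone
	torch::Tensor logitBonuses = torch::tensor({ 0.1f, 0.0f, -0.1f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f });

	// Same math as the patch: bonuses are scaled by the spread of the logits,
	// so their effect stays proportional regardless of how large the raw outputs are
	torch::Tensor outputRange = logits.max() - logits.min();
	torch::Tensor biased = logits + (logitBonuses * outputRange);

	std::cout << "base probs:\n"   << torch::softmax(logits, -1) << "\n";
	std::cout << "biased probs:\n" << torch::softmax(biased, -1) << "\n";
	return 0;
}

With the patch applied, the equivalent runtime call is learner.SetLogitBonuses({...}) with one entry per action; passing a tensor of zeros (or never calling it, leaving logitBonuses undefined) reproduces the original behavior.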