Add logit bonus system
ZealanL committed Oct 26, 2024
1 parent 628f6a3 commit a8d28f0
Showing 4 changed files with 27 additions and 2 deletions.
11 changes: 10 additions & 1 deletion RLGymPPO_CPP/src/private/RLGymPPO_CPP/PPO/DiscretePolicy.h
@@ -13,7 +13,9 @@ namespace RLGPC {
 		int inputAmount;
 		int actionAmount;
 		IList layerSizes;
+
 		float temperature;
+		torch::Tensor logitBonuses;
 
 		// Min probability that an action will be taken
 		constexpr static float ACTION_MIN_PROB = 1e-11;
@@ -25,8 +27,15 @@ namespace RLGPC {
 		void CopyTo(DiscretePolicy& to);
 
 		torch::Tensor GetOutput(torch::Tensor input) {
+			auto baseOutput = seq->forward(input) / temperature;
+
+			if (logitBonuses.defined()) {
+				auto outputRange = baseOutput.max() - baseOutput.min();
+				baseOutput = baseOutput + (logitBonuses * outputRange);
+			}
+
 			return torch::nn::functional::softmax(
-				seq->forward(input) / temperature,
+				baseOutput,
 				torch::nn::functional::SoftmaxFuncOptions(-1)
 			);
 		}
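The new GetOutput scales each bonus by the spread of the raw logits (max − min) before adding it, so a bonus of 0.5 shifts an action's logit by half the current logit range regardless of how large the network's outputs are. A minimal standalone sketch of that arithmetic, with made-up logits and bonuses (example values are illustrative, not from the commit):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
	// Raw logits for 4 discrete actions (arbitrary example values).
	torch::Tensor logits = torch::tensor({2.0f, 0.5f, -1.0f, 1.0f});

	// Bonuses as fractions of the current logit range, mirroring GetOutput():
	// +0.5 favors action 0, -0.5 suppresses action 2.
	torch::Tensor logitBonuses = torch::tensor({0.5f, 0.0f, -0.5f, 0.0f});

	torch::Tensor outputRange = logits.max() - logits.min(); // = 3.0 here
	torch::Tensor biased = logits + logitBonuses * outputRange;

	std::cout << torch::softmax(logits, -1) << "\n"; // without bonuses
	std::cout << torch::softmax(biased, -1) << "\n"; // with bonuses
}
```

Because the bonus tracks the logit range, its influence on the softmax stays roughly constant as training changes the magnitude of the logits; a fixed additive constant would instead fade as the logits grow.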
8 changes: 7 additions & 1 deletion
@@ -79,9 +79,15 @@ void _RunFunc(ThreadAgent* ta) {
 		// Infer the policy to get actions for all our agents in all our games
 		Timer policyInferTimer = {};
 
+
 		if (blockConcurrentInfer)
 			mgr->inferMutex.lock();
-		auto actionResults = policy->GetAction(curObsTensorDevice, deterministic);
+		RLGPC::DiscretePolicy::ActionResult actionResults;
+		try {
+			actionResults = policy->GetAction(curObsTensorDevice, deterministic);
+		} catch (std::exception& e) {
+			RG_ERR_CLOSE("Exception during policy->GetAction(): " << e.what());
+		}
 		if (blockConcurrentInfer)
 			mgr->inferMutex.unlock();
 		if (halfPrec) {
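One side effect worth noting: GetAction can now exit via an exception, and the surrounding lock/unlock pair is manual, so the mutex would stay locked if RG_ERR_CLOSE ever returned instead of terminating the process. A scoped lock sidesteps that; a hedged sketch of the pattern (the names mirror the diff above, but this is not code from the repository):

```cpp
#include <functional>
#include <mutex>

std::mutex inferMutex;            // stands in for mgr->inferMutex
bool blockConcurrentInfer = true; // same flag as in the diff

void RunGuardedInference(const std::function<void()>& infer) {
	// std::unique_lock releases the mutex in its destructor,
	// even if infer() throws, so no manual unlock is needed.
	std::unique_lock<std::mutex> lock(inferMutex, std::defer_lock);
	if (blockConcurrentInfer)
		lock.lock();
	infer(); // e.g. policy->GetAction(...)
}
```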
8 changes: 8 additions & 0 deletions RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.cpp
@@ -1,4 +1,5 @@
 #include "Learner.h"
+#include "Learner.h"
 
 #include "../../private/RLGymPPO_CPP/Util/SkillTracker.h"
 
@@ -706,6 +707,13 @@ void RLGPC::Learner::UpdateLearningRates(float policyLR, float criticLR) {
 	ppo->UpdateLearningRates(policyLR, criticLR);
 }
 
+void RLGPC::Learner::SetLogitBonuses(RLGSC::FList bonuses) {
+	RG_ASSERT(bonuses.size() == actionAmount);
+	ppo->policy->logitBonuses = torch::tensor(bonuses, ppo->policy->device);
+	if (ppo->policyHalf)
+		ppo->policyHalf->logitBonuses = torch::tensor(bonuses, ppo->policy->device);
+}
+
 std::vector<RLGPC::Report> RLGPC::Learner::GetAllGameMetrics() {
 	std::vector<Report> reports = {};
 
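For callers, SetLogitBonuses takes one float per discrete action, and the assert guards against a mismatched list; the bonuses are copied to both the full-precision policy and, if present, the half-precision copy, so inference in either mode sees the same bias. A hedged usage sketch (the `learner` variable and the action count of 8 are assumptions for illustration, not from the commit):

```cpp
// Assumed: a configured RLGPC::Learner named `learner` whose action
// parser exposes 8 discrete actions (hypothetical count), and that
// RLGSC::FList behaves like std::vector<float>.
RLGSC::FList bonuses(8, 0.0f); // 0 = no bias
bonuses[0] = 0.25f;            // raise action 0's logit by 25% of the logit range
bonuses[7] = -0.25f;           // suppress action 7 by the same amount
learner.SetLogitBonuses(bonuses);
```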
2 changes: 2 additions & 0 deletions RLGymPPO_CPP/src/public/RLGymPPO_CPP/Learner.h
@@ -41,6 +41,8 @@ namespace RLGPC {
 
 		void UpdateLearningRates(float policyLR, float criticLR);
 
+		void SetLogitBonuses(RLGSC::FList bonuses);
+
 		std::vector<Report> GetAllGameMetrics();
 
 		void Save();
