From 8b2bdf4e2ec251bbf40ccbefe17d2230342aeb9b Mon Sep 17 00:00:00 2001 From: ago109 Date: Tue, 16 Jul 2024 16:35:03 -0400 Subject: [PATCH] mod to rpe/mstdpet --- .../neurons/graded/rewardErrorCell.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index 2499c2b26..5fab46dc9 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -57,12 +57,12 @@ def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward, ## compute/update RPE and predictor values accum_reward = accum_reward + reward m = (Ns > 0.) * 1. - _Ns = Ns * m + (1. - m) ## mask out Ns - rpe = reward - mu/_Ns #reward - mu + #_Ns = Ns * m + (1. - m) ## mask out Ns + rpe = reward - mu #/_Ns #reward - mu if use_online_predictor: - #mu = mu * (1. - alpha) + reward * alpha - mu = mu + reward - Ns = Ns + 1. + mu = mu * (1. - alpha) + reward * alpha + # mu = mu + reward + # Ns = Ns + 1. n_ep_steps = n_ep_steps + 1 return mu, rpe, n_ep_steps, accum_reward, Ns @@ -76,16 +76,18 @@ def advance_state(self, mu, rpe, n_ep_steps, accum_reward, Ns): @staticmethod def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu, - accum_reward): + accum_reward, rpe): if use_online_predictor is False: ## total episodic reward signal r = accum_reward/n_ep_steps mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r - return mu + rpe = r - mu + return mu, rpe @resolver(_evolve) - def evolve(self, mu): + def evolve(self, mu, rpe): self.mu.set(mu) + self.rpe.set(rpe) @staticmethod def _reset(batch_size, n_units):