mod to rpe/mstdpet
ago109 committed Jul 16, 2024
1 parent 91b7e38 commit 8b2bdf4
Showing 1 changed file with 10 additions and 8 deletions.
ngclearn/components/neurons/graded/rewardErrorCell.py (18 changes: 10 additions & 8 deletions)
@@ -57,12 +57,12 @@ def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
         ## compute/update RPE and predictor values
         accum_reward = accum_reward + reward
         m = (Ns > 0.) * 1.
-        _Ns = Ns * m + (1. - m) ## mask out Ns
-        rpe = reward - mu/_Ns #reward - mu
+        #_Ns = Ns * m + (1. - m) ## mask out Ns
+        rpe = reward - mu #/_Ns #reward - mu
         if use_online_predictor:
-            #mu = mu * (1. - alpha) + reward * alpha
-            mu = mu + reward
-            Ns = Ns + 1.
+            mu = mu * (1. - alpha) + reward * alpha
+            # mu = mu + reward
+            # Ns = Ns + 1.
         n_ep_steps = n_ep_steps + 1
         return mu, rpe, n_ep_steps, accum_reward, Ns
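For context, the online branch now keeps the reward predictor mu as an exponential moving average rather than a count-normalized running sum, so the per-step RPE is simply the raw reward minus the current predictor. A minimal standalone sketch of that per-step update (plain Python with made-up toy values; not the component's actual API):

    # Sketch of the online RPE/predictor update; alpha and the reward stream
    # below are illustrative values, not defaults from ngclearn.
    alpha = 0.1   # EMA step size
    mu = 0.0      # reward predictor
    for reward in [1.0, 0.0, 1.0, 1.0]:
        rpe = reward - mu                          # reward prediction error
        mu = mu * (1. - alpha) + reward * alpha    # EMA predictor update
        print(f"rpe={rpe:.3f}, mu={mu:.3f}")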

@@ -76,16 +76,18 @@ def advance_state(self, mu, rpe, n_ep_steps, accum_reward, Ns):
 
     @staticmethod
     def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
-                accum_reward):
+                accum_reward, rpe):
         if use_online_predictor is False:
             ## total episodic reward signal
             r = accum_reward/n_ep_steps
             mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
-        return mu
+            rpe = r - mu
+        return mu, rpe
 
     @resolver(_evolve)
-    def evolve(self, mu):
+    def evolve(self, mu, rpe):
         self.mu.set(mu)
+        self.rpe.set(rpe)
 
     @staticmethod
     def _reset(batch_size, n_units):
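When the online predictor is disabled, _evolve now also emits an episode-level RPE: the mean episodic reward is folded into mu with an EMA of window length ema_window_len, and rpe becomes the gap between that mean and the updated predictor. A rough illustration of that path (names mirror the diff, but all values are assumed and this is not the component API):

    # Sketch of the episodic (use_online_predictor=False) path; all values
    # below are made up for illustration.
    ema_window_len = 10.
    mu = 0.5                                # predictor carried across episodes
    accum_reward, n_ep_steps = 12.0, 20     # stats from one finished episode
    r = accum_reward / n_ep_steps           # mean episodic reward
    mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
    rpe = r - mu                            # episode-level prediction error
    print(f"mu={mu:.3f}, rpe={rpe:.3f}")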
