Skip to content

Commit

Permalink
edit to rpe
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 16, 2024
1 parent 6cc4841 commit c4ed156
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions ngclearn/components/neurons/graded/rewardErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,19 @@ def _reset(batch_size, n_units):
restVals = jnp.zeros((batch_size, n_units))
mu = restVals
rpe = restVals
accum_reward = restVals
n_ep_steps = jnp.zeros((batch_size, 1))
Ns = jnp.zeros((batch_size, 1))
return mu, rpe, accum_reward, n_ep_steps, Ns
resetMask = jnp.zeros((batch_size, 1))
accum_reward = resetMask
reward = resetMask
n_ep_steps = resetMask
Ns = resetMask
return mu, rpe, accum_reward, reward, n_ep_steps, Ns

@resolver(_reset)
def reset(self, mu, rpe, accum_reward, n_ep_steps, Ns):
def reset(self, mu, rpe, accum_reward, reward, n_ep_steps, Ns):
self.mu.set(mu)
self.rpe.set(rpe)
self.accum_reward.set(accum_reward)
self.reward.set(reward)
self.n_ep_steps.set(n_ep_steps)
self.Ns.set(Ns)

Expand Down

0 comments on commit c4ed156

Please sign in to comment.