Skip to content

Commit

Permalink
sample trajectories bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 16, 2023
1 parent d9fa884 commit 2a16b1b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def sample_trajectories(
env: Env,
states: Optional[States] = None,
n_trajectories: Optional[int] = None,
**policy_kwargs: Optional[dict],
**policy_kwargs,
) -> Trajectories:
"""Sample trajectories sequentially.
Expand Down Expand Up @@ -131,8 +131,9 @@ def sample_trajectories(
# TODO: Retrieve module outputs here, and stack them along the trajectory
# length.
# TODO: Optionally submit module outputs to skip re-estimation.
actions[~dones] = valid_actions
valid_actions, actions_log_probs = self.sample_actions(env, states[~dones], **policy_kwargs)
actions[~dones] = valid_actions

log_probs[~dones] = actions_log_probs
trajectories_actions += [actions]
trajectories_logprobs += [log_probs]
Expand Down

0 comments on commit 2a16b1b

Please sign in to comment.