Commit 00cab17

trajectories and transitions now initialize log_rewards correctly, so extend always works, which means we no longer need the extra condition (but we are leaving in the sanity check for now).

josephdviviano committed Apr 2, 2024
1 parent 184c5f5 commit 00cab17
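
Why this works: torch.cat along dim=0 treats a zero-length tensor as a no-op, so once __init__ guarantees log_rewards is an empty tensor rather than None, extend can always concatenate. A minimal standalone PyTorch sketch (not the torchgfn API; the variable names are illustrative):

import torch

# Empty log_rewards buffer, as __init__ now creates when log_rewards is None.
log_rewards = torch.full(size=(0,), fill_value=0, dtype=torch.float)

# Extending is just a concatenation; the zero-length tensor contributes
# nothing, so no None check is needed on the first extend.
incoming = torch.tensor([-1.2, -0.7, -3.4])
log_rewards = torch.cat((log_rewards, incoming), dim=0)

print(log_rewards)  # tensor([-1.2000, -0.7000, -3.4000])
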
Showing 2 changed files with 9 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/gfn/containers/trajectories.py
@@ -90,7 +90,11 @@ def __init__(
             if when_is_done is not None
             else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
         )
-        self._log_rewards = log_rewards
+        self._log_rewards = (
+            log_rewards
+            if log_rewards is not None
+            else torch.full(size=(0,), fill_value=0, dtype=torch.float)
+        )
         self.log_probs = (
             log_probs
             if log_probs is not None
@@ -246,10 +250,7 @@ def extend(self, other: Trajectories) -> None:
                 (self._log_rewards, other._log_rewards),
                 dim=0,
             )
-        # If the trajectories object does not yet have `log_rewards` assigned but the
-        # external trajectory has log_rewards, simply assign them over.
-        elif self._log_rewards is None and other._log_rewards is not None:
-            self._log_rewards = other._log_rewards
+        # Will not be None if object is initialized as empty.
         else:
             self._log_rewards = None

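Distilled, the log_rewards branch of Trajectories.extend now reduces to the pattern below (a sketch of the diff above, not the full method; the helper name is hypothetical):

import torch

def extend_log_rewards(mine, other):
    # Both sides carry log_rewards: concatenate along the batch dimension.
    if mine is not None and other is not None:
        return torch.cat((mine, other), dim=0)
    # Will not be None if the object is initialized as empty; kept only as
    # a sanity check for now, per the commit message.
    return None
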
4 changes: 3 additions & 1 deletion src/gfn/containers/transitions.py
@@ -90,7 +90,7 @@ def __init__(
             len(self.next_states.batch_shape) == 1
             and self.states.batch_shape == self.next_states.batch_shape
         )
-        self._log_rewards = log_rewards
+        self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
         self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
 
     @property
@@ -208,6 +208,8 @@ def extend(self, other: Transitions) -> None:
         self.actions.extend(other.actions)
         self.is_done = torch.cat((self.is_done, other.is_done), dim=0)
         self.next_states.extend(other.next_states)
+
+        # Concatenate log_rewards of the trajectories.
         if self._log_rewards is not None and other._log_rewards is not None:
             self._log_rewards = torch.cat(
                 (self._log_rewards, other._log_rewards), dim=0
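
The same pattern holds for Transitions, with torch.zeros(0) as the empty default. A standalone sketch (not the Transitions API) of accumulating log rewards across batches:

import torch

buffer = torch.zeros(0)  # empty default, as in __init__ above

for batch_log_rewards in (torch.tensor([-0.5]), torch.tensor([-1.0, -2.0])):
    # Always safe: the first cat is against a zero-length tensor.
    buffer = torch.cat((buffer, batch_log_rewards), dim=0)

assert buffer.tolist() == [-0.5, -1.0, -2.0]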
