Commit 9d1ffd0: sync for debug
josephdviviano committed Nov 18, 2023
1 parent 740e16f commit 9d1ffd0
Showing 4 changed files with 9 additions and 7 deletions.

src/gfn/gflownet/base.py (3 additions, 1 deletion)
@@ -157,6 +157,8 @@ def get_pfs_and_pbs(
         log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
         log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

+        # TODO: Optionally zero out S1->S0.
+
         return log_pf_trajectories, log_pb_trajectories

     def get_trajectories_scores(
@@ -173,7 +175,7 @@ def get_trajectories_scores(
         total_log_pf_trajectories = log_pf_trajectories.sum(dim=0)
         total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

-        log_rewards = trajectories.log_rewards.clamp_min(self.log_reward_clip_min)  # type: ignore
+        log_rewards = trajectories.log_rewards  # .clamp_min(self.log_reward_clip_min)  # type: ignore
         if torch.any(torch.isinf(total_log_pf_trajectories)) or torch.any(
             torch.isinf(total_log_pb_trajectories)
         ):
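Note on the disabled clamp: clamp_min floors log-rewards from below before they enter the trajectory scores, keeping the loss finite when a reward underflows. A minimal standalone sketch of that behaviour (PyTorch, with made-up reward values):

    import torch

    # Hypothetical log-rewards for four trajectories; the last reward is so
    # small that its log underflows to -inf.
    log_rewards = torch.tensor([-2.3, -11.9, -250.0, float("-inf")])

    # New default elsewhere in this commit, roughly log(5e-44).
    log_reward_clip_min = -100.0

    # clamp_min floors every entry, so -inf and extreme negatives become
    # -100. This keeps downstream losses finite but biases the scores of
    # very-low-reward trajectories; commenting it out (as above) exposes
    # the raw values while debugging.
    print(log_rewards.clamp_min(log_reward_clip_min))
    # tensor([  -2.3000,  -11.9000, -100.0000, -100.0000])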

src/gfn/gflownet/detailed_balance.py (1 addition, 1 deletion)
@@ -75,7 +75,7 @@ def get_scores(

         valid_log_F_s = self.logF(states).squeeze(-1)
         if self.forward_looking:
-            log_rewards = env.log_reward(states)  # RM unsqueeze(-1)
+            log_rewards = env.log_reward(states)  # TODO: RM unsqueeze(-1) ?
             valid_log_F_s = valid_log_F_s + log_rewards

         preds = valid_log_pf_actions + valid_log_F_s
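The TODO here is a shape question: after squeeze(-1), valid_log_F_s is one-dimensional, so the log-rewards added to it must be one-dimensional as well. A toy illustration of why the unsqueeze(-1) stays removed (hypothetical shapes, assuming env.log_reward returns one scalar per state):

    import torch

    batch = 5
    valid_log_F_s = torch.randn(batch, 1).squeeze(-1)  # shape (5,)
    log_rewards = torch.randn(batch)                   # assumed shape (5,)

    # Correct: an elementwise sum over the batch.
    assert (valid_log_F_s + log_rewards).shape == (batch,)

    # With a stray unsqueeze(-1), broadcasting (5,) against (5, 1) silently
    # produces a (5, 5) matrix instead of raising an error.
    assert (valid_log_F_s + log_rewards.unsqueeze(-1)).shape == (batch, batch)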

src/gfn/gflownet/sub_trajectory_balance.py (2 additions, 2 deletions)
@@ -58,7 +58,7 @@ def __init__(
             "equal_within",
         ] = "geometric_within",
         lamda: float = 0.9,
-        log_reward_clip_min: float = -12,  # roughly log(1e-5)
+        log_reward_clip_min: float = -100,  # roughly log(5e-44)
         forward_looking: bool = False,
     ):
         super().__init__(pf, pb, on_policy=on_policy)
@@ -151,7 +151,7 @@ def get_scores(
         assert trajectories.log_rewards is not None
         log_rewards = trajectories.log_rewards[
             trajectories.when_is_done >= i
-        ].clamp_min(self.log_reward_clip_min)
+        ]  # .clamp_min(self.log_reward_clip_min)
         targets.T[is_terminal_mask[i - 1 :].T] = log_rewards

         # For now, the targets contain the log-rewards of the ending sub trajectories
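The old and new defaults match their comments; a quick check of the arithmetic (plain Python, natural log):

    import math

    print(math.log(1e-5))   # -11.51..., i.e. roughly -12 (old default)
    print(math.log(5e-44))  # -99.69..., i.e. roughly -100 (new default)

So -100 is a far looser floor than -12: it only catches rewards that are already vanishingly small, which is consistent with the clamp being commented out entirely for this debugging pass.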

src/gfn/gflownet/trajectory_balance.py (3 additions, 3 deletions)
@@ -34,7 +34,7 @@ def __init__(
         pb: GFNModule,
         on_policy: bool = False,
         init_logZ: float = 0.0,
-        log_reward_clip_min: float = -12,  # roughly log(1e-5)
+        log_reward_clip_min: float = -100,  # roughly log(5e-44)
     ):
         super().__init__(pf, pb, on_policy=on_policy)

@@ -74,11 +74,11 @@ def __init__(
         pf: GFNModule,
         pb: GFNModule,
         on_policy: bool = False,
-        log_reward_clip_min: float = -12,
+        log_reward_clip_min: float = -100,  # roughly log(5e-44)
     ):
         super().__init__(pf, pb, on_policy=on_policy)

-        self.log_reward_clip_min = log_reward_clip_min  # -12 is roughly log(1e-5)
+        self.log_reward_clip_min = log_reward_clip_min

     def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
         """Log Partition Variance loss.
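For context, the second constructor above belongs to the log-partition-variance objective named in the docstring: each trajectory's score estimates the (negative) log-partition function, and the loss penalizes the spread of those estimates. A minimal sketch of that idea, assuming scores of the form log P_F(tau) - log P_B(tau) - log R(x) are already computed:

    import torch

    # Assumed per-trajectory scores for a batch of 16 sampled trajectories.
    scores = torch.randn(16, requires_grad=True)

    # Empirical variance of the per-trajectory logZ estimates; it is zero
    # exactly when every trajectory agrees on logZ.
    loss = (scores - scores.mean()).pow(2).mean()
    loss.backward()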
