changes requested for PR
josephdviviano committed Feb 13, 2024
1 parent a6601d7 commit 71da6b5
Showing 10 changed files with 49 additions and 61 deletions.
7 changes: 3 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -77,7 +77,7 @@ def __init__(
self.env = env
self.is_backward = is_backward
self.states = (
states.clone() # TODO: Do we need this clone?
states
if states is not None
else env.States.from_batch_shape(batch_shape=(0, 0))
)
@@ -169,8 +169,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
)

if is_tensor(self.estimator_outputs):
estimator_outputs = self.estimator_outputs[:, index]
estimator_outputs = estimator_outputs[:new_max_length]
estimator_outputs = self.estimator_outputs[..., index][:new_max_length]
else:
estimator_outputs = None

@@ -261,7 +260,7 @@ def extend(self, other: Trajectories) -> None:
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function.
# TODO: This should be a single reused function (#154)
# The size of self needs to grow to match other along dim=0.
if self_shape[0] < other_shape[0]:
pad_dim = required_first_dim - self_shape[0]
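A minimal sketch of the single reusable padding helper that the TODO above (#154) asks for (not part of this commit; the name pad_dim0_to and the default fill value are illustrative assumptions):

import torch

def pad_dim0_to(tensor: torch.Tensor, target_len: int, value: float = 0.0) -> torch.Tensor:
    """Pad `tensor` along dim=0 with `value` until its first dimension has length `target_len`."""
    pad_len = target_len - tensor.shape[0]
    if pad_len <= 0:
        return tensor
    padding = torch.full(
        (pad_len, *tensor.shape[1:]), value, dtype=tensor.dtype, device=tensor.device
    )
    return torch.cat([tensor, padding], dim=0)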
6 changes: 2 additions & 4 deletions src/gfn/env.py
@@ -7,6 +7,7 @@
from gfn.actions import Actions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, States
from gfn.utils.common import set_seed

# Errors
NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
@@ -79,7 +80,7 @@ def reset(
assert not (random and sink)

if random and seed is not None:
torch.manual_seed(seed) # TODO: Improve seeding here?
set_seed(seed, performance_mode=True)

if batch_shape is None:
batch_shape = (1,)
@@ -150,9 +151,6 @@ def step(
new_not_done_states_tensor = self.maskless_step(
not_done_states, not_done_actions
)
# TODO: Why is this here? Should it be removed?
# if isinstance(new_states, DiscreteStates):
# new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)

new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor

26 changes: 14 additions & 12 deletions src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union
import math

import torch
import torch.nn as nn
@@ -26,12 +26,15 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
"""

@abstractmethod
def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories:
def sample_trajectories(
self, env: Env, n_samples: int, sample_off_policy: bool
) -> Trajectories:
"""Sample a specific number of complete trajectories.
Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
sample_off_policy: whether the trajectories should be sampled off policy (True) or on policy (False).
Returns:
Trajectories: sampled trajectories object.
"""
@@ -48,12 +51,6 @@ def sample_terminating_states(self, env: Env, n_samples: int) -> States:
trajectories = self.sample_trajectories(env, n_samples, sample_off_policy=False)
return trajectories.last_states

def pf_pb_named_parameters(self):
return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k}

def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}

@@ -97,6 +94,12 @@ def sample_trajectories(

return trajectories

def pf_pb_named_parameters(self):
return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k}

def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]
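A hedged sketch of how these parameter helpers are typically consumed, assuming a concrete gflownet instance that also defines a logZ parameter; the learning rates below are illustrative, not from this commit:

import torch

optimizer = torch.optim.Adam(
    [
        {"params": gflownet.pf_pb_parameters(), "lr": 1e-3},
        {"params": list(gflownet.logz_named_parameters().values()), "lr": 1e-1},
    ]
)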


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
def get_pfs_and_pbs(
@@ -148,7 +151,7 @@ def get_pfs_and_pbs(

if self.off_policy:
# We re-use the values calculated in .sample_trajectories().
if not isinstance(trajectories.estimator_outputs, type(None)):
if trajectories.estimator_outputs is not None:
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
@@ -211,9 +214,8 @@ def get_trajectories_scores(
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
if math.isfinite(self.log_reward_clip_min) and not isinstance(
log_rewards, type(None)
):
# TODO: log_reward_clip_min isn't defined in base (#155).
if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

if torch.any(torch.isinf(total_log_pf_trajectories)) or torch.any(
15 changes: 9 additions & 6 deletions src/gfn/gflownet/detailed_balance.py
@@ -35,12 +35,12 @@ def __init__(
logF: ScalarEstimator,
off_policy: bool,
forward_looking: bool = False,
log_reward_clamp_min: float = -float("inf"),
log_reward_clip_min: float = -float("inf"),
):
super().__init__(pf, pb, off_policy=off_policy)
self.logF = logF
self.forward_looking = forward_looking
self.log_reward_clamp_min = log_reward_clamp_min
self.log_reward_clip_min = log_reward_clip_min

def get_scores(
self, env: Env, transitions: Transitions
@@ -68,10 +68,13 @@ def get_scores(

if states.batch_shape != tuple(actions.batch_shape):
raise ValueError("Something wrong happening with log_pf evaluations")
if self.off_policy:
if not self.off_policy:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
# Ideally, the Transitions container would also carry an
# estimator_outputs attribute so that this duplication can be
# avoided (see #156).
module_output = self.pf(states) # TODO: Inefficient duplication.
valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
@@ -82,8 +85,8 @@
valid_log_F_s = self.logF(states).squeeze(-1)
if self.forward_looking:
log_rewards = env.log_reward(states) # TODO: remove unsqueeze(-1)?
if math.isfinite(self.log_reward_clamp_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clamp_min)
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
valid_log_F_s = valid_log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
@@ -163,7 +166,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo
all_log_rewards = transitions.all_log_rewards[mask]
module_output = self.pf(states)
pf_dist = self.pf.to_probability_distribution(states, module_output)
if self.off_policy:
if not self.off_policy:
valid_log_pf_actions = transitions[mask].log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
20 changes: 0 additions & 20 deletions src/gfn/gym/discrete_ebm.py
@@ -116,26 +116,6 @@ def make_random_states_tensor(
device=env.device,
)

# TODO: Look into make masks - I don't think this is being called.
def make_masks(
self,
) -> Tuple[
TT["batch_shape", "n_actions", torch.bool],
TT["batch_shape", "n_actions - 1", torch.bool],
]:
forward_masks = torch.zeros(
self.batch_shape + (env.n_actions,),
device=env.device,
dtype=torch.bool,
)
backward_masks = torch.zeros(
self.batch_shape + (env.n_actions - 1,),
device=env.device,
dtype=torch.bool,
)

return forward_masks, backward_masks

def update_masks(self) -> None:
self.set_default_typing()
self.forward_masks[..., : env.ndim] = self.tensor == -1
15 changes: 9 additions & 6 deletions src/gfn/samplers.py
@@ -91,7 +91,7 @@ def sample_trajectories(
off_policy: bool,
states: Optional[States] = None,
n_trajectories: Optional[int] = None,
test_mode: bool = False,
debug_mode: bool = False,
**policy_kwargs,
) -> Trajectories:
"""Sample trajectories sequentially.
@@ -110,16 +110,16 @@
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
test_mode: if True, everything gets calculated.
debug_mode: if True, log probs and estimator outputs are always calculated and stored, regardless of the off-policy setting.
Returns: A Trajectories object representing the batch of sampled trajectories.
Raises:
AssertionError: When both states and n_trajectories are specified.
AssertionError: When states are not linear.
"""
save_estimator_outputs = off_policy or test_mode
skip_logprob_calculation = off_policy and not test_mode
save_estimator_outputs = off_policy or debug_mode
skip_logprob_calculation = off_policy and not debug_mode
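For reference, the two flags above combine as follows (derived directly from the two boolean expressions):

# off_policy  debug_mode | save_estimator_outputs  skip_logprob_calculation
# False       False      | False                   False
# False       True       | True                    False
# True        False      | True                    True
# True        True       | True                    False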

if states is None:
assert (
@@ -173,7 +173,7 @@
calculate_logprobs=False if skip_logprob_calculation else True,
**policy_kwargs,
)
if not isinstance(estimator_outputs, type(None)):
if estimator_outputs is not None:
# Place estimator outputs into a stackable tensor. Note that this
# will be replaced with torch.nested.nested_tensor in the future.
estimator_outputs_padded = torch.full(
@@ -202,11 +202,14 @@
# Increment the step, determine which trajectories have just finished, and
# evaluate their rewards.
step += 1
# new_dones marks the trajectories that have just finished. Because the
# sink state is padded onto every shorter trajectory, the trajectories
# that were already done must be filtered out.
new_dones = (
new_states.is_initial_state
if self.estimator.is_backward
else sink_states_mask
) & ~dones # TODO: why is ~dones used here and again later on? Is this intentional?
) & ~dones
trajectories_dones[new_dones & ~dones] = step
try:
trajectories_log_rewards[new_dones & ~dones] = env.log_reward(
8 changes: 4 additions & 4 deletions src/gfn/states.py
@@ -1,6 +1,7 @@
from __future__ import annotations # This allows to use the class name in type hints

from abc import ABC, abstractmethod
from copy import deepcopy
from math import prod
from typing import ClassVar, Optional, Sequence, cast

@@ -133,7 +134,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States:
"""Access particular states of the batch."""
return self.__class__(
self.tensor[index]
) # TODO: Inefficient - this make a copy of the tensor!
) # TODO: Inefficient - this might make a copy of the tensor!

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], states: States
@@ -142,9 +143,8 @@ def __setitem__(
self.tensor[index] = states.tensor

def clone(self) -> States:
"""Returns a clone of the current instance."""
# TODO: Do we need to copy _log_rewards?
return self.__class__(self.tensor.detach().clone())
"""Returns a *detached* clone of the current instance using deepcopy."""
return deepcopy(self)

def flatten(self) -> States:
"""Flatten the batch dimension of the states.
9 changes: 6 additions & 3 deletions src/gfn/utils/common.py
@@ -72,11 +72,14 @@ def validate(
return validation_info


def set_seed(seed: int) -> None:
def set_seed(seed: int, performance_mode: bool = False) -> None:
"""Used to control randomness."""
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# These are only set when we care about reproducibility over performance.
if not performance_mode:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
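An illustrative usage sketch of the updated helper (the seed value is arbitrary):

set_seed(42)                         # full reproducibility: also pins the cuDNN flags
set_seed(42, performance_mode=True)  # skips the cuDNN flags, as env.reset() now does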
2 changes: 1 addition & 1 deletion src/gfn/utils/modules.py
@@ -50,7 +50,7 @@ def __init__(
arch.append(nn.Linear(hidden_dim, hidden_dim))
arch.append(activation())
self.torso = nn.Sequential(*arch)
self.torso.hidden_dim = hidden_dim # TODO: what is this?
self.torso.hidden_dim = hidden_dim
else:
self.torso = torso
self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim)
2 changes: 1 addition & 1 deletion testing/test_samplers_and_trajectories.py
@@ -80,7 +80,7 @@ def trajectory_sampling_with_return(

sampler = Sampler(estimator=pf_estimator)
# Debug mode collects log_probs and estimator_outputs, which are not encountered in the wild.
trajectories = sampler.sample_trajectories(env, off_policy=False, n_trajectories=5, test_mode=True)
trajectories = sampler.sample_trajectories(env, off_policy=False, n_trajectories=5, debug_mode=True)
# trajectories = sampler.sample_trajectories(env, n_trajectories=10) # TODO - why is this duplicated?

states = env.reset(batch_shape=5, random=True)
