From 71da6b5c62f19ca3dd22dd8606037d6ece3f40dd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 13 Feb 2024 11:24:33 -0500 Subject: [PATCH] changes requested for PR --- src/gfn/containers/trajectories.py | 7 +++--- src/gfn/env.py | 6 ++---- src/gfn/gflownet/base.py | 26 ++++++++++++----------- src/gfn/gflownet/detailed_balance.py | 15 +++++++------ src/gfn/gym/discrete_ebm.py | 20 ----------------- src/gfn/samplers.py | 15 +++++++------ src/gfn/states.py | 8 +++---- src/gfn/utils/common.py | 9 +++++--- src/gfn/utils/modules.py | 2 +- testing/test_samplers_and_trajectories.py | 2 +- 10 files changed, 49 insertions(+), 61 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index e2e25f6f..5b0142e6 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -77,7 +77,7 @@ def __init__( self.env = env self.is_backward = is_backward self.states = ( - states.clone() # TODO: Do we need this clone? + states if states is not None else env.States.from_batch_shape(batch_shape=(0, 0)) ) @@ -169,8 +169,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: ) if is_tensor(self.estimator_outputs): - estimator_outputs = self.estimator_outputs[:, index] - estimator_outputs = estimator_outputs[:new_max_length] + estimator_outputs = self.estimator_outputs[..., index][:new_max_length] else: estimator_outputs = None @@ -261,7 +260,7 @@ def extend(self, other: Trajectories) -> None: other_shape = np.array(other.estimator_outputs.shape) required_first_dim = max(self_shape[0], other_shape[0]) - # TODO: This should be a single reused function. + # TODO: This should be a single reused function (#154) # The size of self needs to grow to match other along dim=0. if self_shape[0] < other_shape[0]: pad_dim = required_first_dim - self_shape[0] diff --git a/src/gfn/env.py b/src/gfn/env.py index c21f958b..bf2a3d3b 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -7,6 +7,7 @@ from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States +from gfn.utils.common import set_seed # Errors NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) @@ -79,7 +80,7 @@ def reset( assert not (random and sink) if random and seed is not None: - torch.manual_seed(seed) # TODO: Improve seeding here? + set_seed(seed, performance_mode=True) if batch_shape is None: batch_shape = (1,) @@ -150,9 +151,6 @@ def step( new_not_done_states_tensor = self.maskless_step( not_done_states, not_done_actions ) - # TODO: Why is this here? Should it be removed? - # if isinstance(new_states, DiscreteStates): - # new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 0656ba64..5e04151d 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,6 +1,6 @@ +import math from abc import ABC, abstractmethod from typing import Generic, Tuple, TypeVar, Union -import math import torch import torch.nn as nn @@ -26,12 +26,15 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): """ @abstractmethod - def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories: + def sample_trajectories( + self, env: Env, n_samples: int, sample_off_policy: bool + ) -> Trajectories: """Sample a specific number of complete trajectories. Args: env: the environment to sample trajectories from. 
n_samples: number of trajectories to be sampled. + sample_off_policy: whether to sample trajectories on / off policy. Returns: Trajectories: sampled trajectories object. """ @@ -48,12 +51,6 @@ def sample_terminating_states(self, env: Env, n_samples: int) -> States: trajectories = self.sample_trajectories(env, n_samples, sample_off_policy=False) return trajectories.last_states - def pf_pb_named_parameters(self): - return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k} - - def pf_pb_parameters(self): - return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k] - def logz_named_parameters(self): return {"logZ": dict(self.named_parameters())["logZ"]} @@ -97,6 +94,12 @@ def sample_trajectories( return trajectories + def pf_pb_named_parameters(self): + return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k} + + def pf_pb_parameters(self): + return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k] + class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): def get_pfs_and_pbs( @@ -148,7 +151,7 @@ def get_pfs_and_pbs( if self.off_policy: # We re-use the values calculated in .sample_trajectories(). - if not isinstance(trajectories.estimator_outputs, type(None)): + if trajectories.estimator_outputs is not None: estimator_outputs = trajectories.estimator_outputs[ ~trajectories.actions.is_dummy ] @@ -211,9 +214,8 @@ def get_trajectories_scores( total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) log_rewards = trajectories.log_rewards - if math.isfinite(self.log_reward_clip_min) and not isinstance( - log_rewards, type(None) - ): + # TODO: log_reward_clip_min isn't defined in base (#155). + if math.isfinite(self.log_reward_clip_min) and log_rewards is not None: log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) if torch.any(torch.isinf(total_log_pf_trajectories)) or torch.any( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 818a2d8a..4cb4e6e2 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -35,12 +35,12 @@ def __init__( logF: ScalarEstimator, off_policy: bool, forward_looking: bool = False, - log_reward_clamp_min: float = -float("inf"), + log_reward_clip_min: float = -float("inf"), ): super().__init__(pf, pb, off_policy=off_policy) self.logF = logF self.forward_looking = forward_looking - self.log_reward_clamp_min = log_reward_clamp_min + self.log_reward_clip_min = log_reward_clip_min def get_scores( self, env: Env, transitions: Transitions @@ -68,10 +68,13 @@ def get_scores( if states.batch_shape != tuple(actions.batch_shape): raise ValueError("Something wrong happening with log_pf evaluations") - if self.off_policy: + if not self.off_policy: valid_log_pf_actions = transitions.log_probs else: # Evaluate the log PF of the actions sampled off policy. + # I suppose the Transitions container should then have some + # estimator_outputs attribute as well, to avoid duplication here ? + # See (#156). module_output = self.pf(states) # TODO: Inefficient duplication. valid_log_pf_actions = self.pf.to_probability_distribution( states, module_output @@ -82,8 +85,8 @@ def get_scores( valid_log_F_s = self.logF(states).squeeze(-1) if self.forward_looking: log_rewards = env.log_reward(states) # TODO: RM unsqueeze(-1) ? 
- if math.isfinite(self.log_reward_clamp_min): - log_rewards = log_rewards.clamp_min(self.log_reward_clamp_min) + if math.isfinite(self.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) valid_log_F_s = valid_log_F_s + log_rewards preds = valid_log_pf_actions + valid_log_F_s @@ -163,7 +166,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo all_log_rewards = transitions.all_log_rewards[mask] module_output = self.pf(states) pf_dist = self.pf.to_probability_distribution(states, module_output) - if self.off_policy: + if not self.off_policy: valid_log_pf_actions = transitions[mask].log_probs else: # Evaluate the log PF of the actions sampled off policy. diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index ecd05eea..a4f82735 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -116,26 +116,6 @@ def make_random_states_tensor( device=env.device, ) - # TODO: Look into make masks - I don't think this is being called. - def make_masks( - self, - ) -> Tuple[ - TT["batch_shape", "n_actions", torch.bool], - TT["batch_shape", "n_actions - 1", torch.bool], - ]: - forward_masks = torch.zeros( - self.batch_shape + (env.n_actions,), - device=env.device, - dtype=torch.bool, - ) - backward_masks = torch.zeros( - self.batch_shape + (env.n_actions - 1,), - device=env.device, - dtype=torch.bool, - ) - - return forward_masks, backward_masks - def update_masks(self) -> None: self.set_default_typing() self.forward_masks[..., : env.ndim] = self.tensor == -1 diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 92781664..56cd83de 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -91,7 +91,7 @@ def sample_trajectories( off_policy: bool, states: Optional[States] = None, n_trajectories: Optional[int] = None, - test_mode: bool = False, + debug_mode: bool = False, **policy_kwargs, ) -> Trajectories: """Sample trajectories sequentially. @@ -110,7 +110,7 @@ def sample_trajectories( parameter, `epsilon`, and `sf_bias`. In the continuous case these kwargs will be user defined. This can be used to, for example, sample off-policy. - test_mode: if True, everything gets calculated. + debug_mode: if True, everything gets calculated. Returns: A Trajectories object representing the batch of sampled trajectories. @@ -118,8 +118,8 @@ def sample_trajectories( AssertionError: When both states and n_trajectories are specified. AssertionError: When states are not linear. """ - save_estimator_outputs = off_policy or test_mode - skip_logprob_calculaion = off_policy and not test_mode + save_estimator_outputs = off_policy or debug_mode + skip_logprob_calculaion = off_policy and not debug_mode if states is None: assert ( @@ -173,7 +173,7 @@ def sample_trajectories( calculate_logprobs=False if skip_logprob_calculaion else True, **policy_kwargs, ) - if not isinstance(estimator_outputs, type(None)): + if estimator_outputs is not None: # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. estimator_outputs_padded = torch.full( @@ -202,11 +202,14 @@ def sample_trajectories( # Increment the step, determine which trajectories are finisihed, and eval # rewards. step += 1 + # new_dones means those trajectories that just finished. Because we + # pad the sink state to every short trajectory, we need to make sure + # to filter out the already done ones. 
new_dones = ( new_states.is_initial_state if self.estimator.is_backward else sink_states_mask - ) & ~dones # TODO: why is ~dones used here and again later on? Is this intentional? + ) & ~dones trajectories_dones[new_dones & ~dones] = step try: trajectories_log_rewards[new_dones & ~dones] = env.log_reward( diff --git a/src/gfn/states.py b/src/gfn/states.py index f5d63a4e..e50b6aea 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -1,6 +1,7 @@ from __future__ import annotations # This allows to use the class name in type hints from abc import ABC, abstractmethod +from copy import deepcopy from math import prod from typing import ClassVar, Optional, Sequence, cast @@ -133,7 +134,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States: """Access particular states of the batch.""" return self.__class__( self.tensor[index] - ) # TODO: Inefficient - this make a copy of the tensor! + ) # TODO: Inefficient - this might make a copy of the tensor! def __setitem__( self, index: int | Sequence[int] | Sequence[bool], states: States @@ -142,9 +143,8 @@ def __setitem__( self.tensor[index] = states.tensor def clone(self) -> States: - """Returns a clone of the current instance.""" - # TODO: Do we need to copy _log_rewards? - return self.__class__(self.tensor.detach().clone()) + """Returns a *detached* clone of the current instance using deepcopy.""" + return deepcopy(self) def flatten(self) -> States: """Flatten the batch dimension of the states. diff --git a/src/gfn/utils/common.py b/src/gfn/utils/common.py index a80890c5..75a9ffe8 100644 --- a/src/gfn/utils/common.py +++ b/src/gfn/utils/common.py @@ -72,11 +72,14 @@ def validate( return validation_info -def set_seed(seed: int) -> None: +def set_seed(seed: int, performance_mode: bool = False) -> None: """Used to control randomness.""" torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + + # These are only set when we care about reproducibility over performance. + if not performance_mode: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index f99aa22d..2ffbf54a 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -50,7 +50,7 @@ def __init__( arch.append(nn.Linear(hidden_dim, hidden_dim)) arch.append(activation()) self.torso = nn.Sequential(*arch) - self.torso.hidden_dim = hidden_dim # TODO: what is this? + self.torso.hidden_dim = hidden_dim else: self.torso = torso self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 77eca901..1ff865f5 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -80,7 +80,7 @@ def trajectory_sampling_with_return( sampler = Sampler(estimator=pf_estimator) # Test mode collects log_probs and estimator_ouputs, not encountered in the wild. - trajectories = sampler.sample_trajectories(env, off_policy=False, n_trajectories=5, test_mode=True) + trajectories = sampler.sample_trajectories(env, off_policy=False, n_trajectories=5, debug_mode=True) # trajectories = sampler.sample_trajectories(env, n_trajectories=10) # TODO - why is this duplicated? states = env.reset(batch_shape=5, random=True)
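
Usage notes (illustrative sketches, not part of the diff):

The renamed `debug_mode` flag in Sampler.sample_trajectories drives two internal switches, `save_estimator_outputs` and `skip_logprob_calculaion` (sic; spelled `skip_logprob_calculation` below). A minimal, self-contained sketch of that flag logic and the sampling modes it produces, matching the call exercised in testing/test_samplers_and_trajectories.py:

    from typing import Tuple

    def sampling_modes(off_policy: bool, debug_mode: bool) -> Tuple[bool, bool]:
        # Mirrors the flag logic introduced in Sampler.sample_trajectories.
        save_estimator_outputs = off_policy or debug_mode
        skip_logprob_calculation = off_policy and not debug_mode
        return save_estimator_outputs, skip_logprob_calculation

    # Off-policy sampling: keep the estimator outputs, defer log-prob computation.
    assert sampling_modes(off_policy=True, debug_mode=False) == (True, True)
    # Debug mode: everything gets calculated, on or off policy.
    assert sampling_modes(off_policy=False, debug_mode=True) == (True, False)
    assert sampling_modes(off_policy=True, debug_mode=True) == (True, False)
    # Plain on-policy sampling: nothing extra is stored.
    assert sampling_modes(off_policy=False, debug_mode=False) == (False, False)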
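
The new `performance_mode` argument to gfn.utils.common.set_seed only skips the cuDNN determinism flags; seeding of torch, random, and numpy is unchanged. The two call forms introduced by this patch look like:

    from gfn.utils.common import set_seed

    set_seed(0)                         # training scripts: seeds set and cuDNN made deterministic
    set_seed(0, performance_mode=True)  # hot paths, e.g. inside Env.reset(..., random=True, seed=0)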
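
pf_pb_named_parameters and pf_pb_parameters now live on PFBasedGFlowNet rather than on the GFlowNet base class. A minimal sketch of the optimizer setup that relies on them, assuming `gflownet` is a PFBasedGFlowNet that also owns a logZ parameter (e.g. a trajectory-balance model); the learning rates are illustrative only:

    import torch

    # Separate parameter groups: pf/pb networks vs. the scalar logZ estimate.
    optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=1e-3)
    optimizer.add_param_group(
        {"params": list(gflownet.logz_named_parameters().values()), "lr": 1e-1}
    )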