
Commit

Merge branch 'master' into hyeok9855/minor-refactorings
hyeok9855 committed Oct 28, 2024
2 parents b6e042c + d2d959e commit 316a0c4
Showing 26 changed files with 867 additions and 159 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -82,7 +82,7 @@ module_PF = MLP(
module_PB = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators.
@@ -136,7 +136,7 @@ module_PF = MLP(
module_PB = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)
module_logF = MLP(
input_dim=env.preprocessor.output_dim,
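As a note on the renamed keyword above: passing module_PF.trunk to module_PB means both policies train the same hidden layers and differ only in their output heads. A minimal sketch of that pattern in plain PyTorch (toy dimensions are assumptions, and this is not the library's MLP class):

import torch
import torch.nn as nn

# Two heads over one shared trunk: P_F and P_B reuse the same hidden layers
# and differ only in their final linear layer.
trunk = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU())
pf_head = nn.Linear(128, 5)      # n_actions outputs for P_F
pb_head = nn.Linear(128, 5 - 1)  # n_actions - 1 outputs for P_B (no exit action)

x = torch.randn(8, 16)           # a batch of preprocessed states
logits_pf = pf_head(trunk(x))    # gradients flow into the shared trunk
logits_pb = pb_head(trunk(x))    # same trunk parameters, different head
assert logits_pf.shape == (8, 5) and logits_pb.shape == (8, 4)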
15 changes: 10 additions & 5 deletions src/gfn/containers/replay_buffer.py
@@ -141,7 +141,11 @@ def __init__(
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]):
def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
terminating_states: States | None = None,
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)
@@ -153,15 +157,16 @@ def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert self.terminating_states is not None
self.terminating_states.extend(self.terminating_states)
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects
@@ -171,7 +176,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
self._add_objs(training_objects)
self._add_objs(training_objects, terminating_states)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
@@ -180,7 +185,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
training_objects = training_objects[ix]

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
min_reward_in_buffer = self.training_objects.log_rewards.min() # type: ignore # FIXME
idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]

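For context, the change above replaces the buggy self.terminating_states.extend(self.terminating_states) with an explicit terminating_states argument threaded through from add(). A toy sketch of the corrected flow, with plain Python lists standing in for the library's containers and the log-reward sorting and diversity filtering omitted:

class ToyBuffer:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.training_objects: list = []
        self.terminating_states: list = []

    def _add_objs(self, training_objects, terminating_states=None):
        self.training_objects.extend(training_objects)
        self.training_objects = self.training_objects[-self.capacity:]
        if terminating_states is not None:
            # Previously the buffer extended its own terminating states with
            # themselves; now the caller's terminating states are appended.
            self.terminating_states.extend(terminating_states)
            self.terminating_states = self.terminating_states[-self.capacity:]

    def add(self, training_objects):
        terminating_states = None
        if isinstance(training_objects, tuple):
            training_objects, terminating_states = training_objects
        self._add_objs(training_objects, terminating_states)

buf = ToyBuffer(capacity=3)
buf.add(([1, 2], ["s2", "s4"]))
buf.add(([3, 4], ["s6", "s8"]))
assert buf.training_objects == [2, 3, 4]
assert buf.terminating_states == ["s4", "s6", "s8"]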
41 changes: 37 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Sequence, Union, Tuple


if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States
from gfn.states import States, DiscreteStates

import numpy as np
import torch
@@ -50,6 +51,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
when_is_done: TT["n_trajectories", torch.long] | None = None,
is_backward: bool = False,
@@ -76,6 +78,7 @@ def __init__(
is used to compute the rewards, at each call of self.log_rewards
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states if states is not None else env.states_from_batch_shape((0, 0))
@@ -315,6 +318,15 @@ def extend(self, other: Trajectories) -> None:

def to_transitions(self) -> Transitions:
"""Returns a `Transitions` object from the trajectories."""
if self.conditioning is not None:
traj_len = self.actions.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
conditioning = self.conditioning.unsqueeze(0).expand(expand_dims)[
~self.actions.is_dummy
]
else:
conditioning = None

states = self.states[:-1][~self.actions.is_dummy]
next_states = self.states[1:][~self.actions.is_dummy]
actions = self.actions[~self.actions.is_dummy]
@@ -348,6 +360,7 @@ def to_transitions(self) -> Transitions:
return Transitions(
env=self.env,
states=states,
conditioning=conditioning,
actions=actions,
is_done=is_done,
next_states=next_states,
@@ -363,7 +376,10 @@ def to_states(self) -> States:

def to_non_initial_intermediary_and_terminating_states(
self,
) -> tuple[States, States]:
) -> Union[
Tuple[States, States, torch.Tensor, torch.Tensor],
Tuple[States, States, None, None],
]:
"""Returns all intermediate and terminating `States` from the trajectories.
This is useful for the flow matching loss, that requires its inputs to be distinguished.
@@ -373,10 +389,27 @@ def to_non_initial_intermediary_and_terminating_states(
are not s0.
"""
states = self.states

if self.conditioning is not None:
traj_len = self.states.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
intermediary_conditioning = self.conditioning.unsqueeze(0).expand(
expand_dims
)[~states.is_sink_state & ~states.is_initial_state]
conditioning = self.conditioning # n_final_states == n_trajectories.
else:
intermediary_conditioning = None
conditioning = None

intermediary_states = states[~states.is_sink_state & ~states.is_initial_state]
terminating_states = self.last_states
terminating_states.log_rewards = self.log_rewards
return intermediary_states, terminating_states
return (
intermediary_states,
terminating_states,
intermediary_conditioning,
conditioning,
)


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
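The new conditioning logic in this file relies on one broadcasting trick: a per-trajectory conditioning tensor gets a leading time axis via unsqueeze(0).expand(...) so it can be indexed with the same boolean masks as the states and actions. A self-contained sketch with assumed toy shapes:

import torch

traj_len, n_trajectories, cond_dim = 5, 3, 2
conditioning = torch.randn(n_trajectories, cond_dim)   # one vector per trajectory
mask = torch.rand(traj_len, n_trajectories) > 0.3      # stands in for ~is_dummy / ~is_sink_state

# Give the conditioning a leading time axis, then select with the same mask
# used for states and actions; expand() broadcasts without copying memory.
expand_dims = (traj_len,) + tuple(conditioning.shape)          # (5, 3, 2)
cond_per_step = conditioning.unsqueeze(0).expand(expand_dims)  # (5, 3, 2)
masked_cond = cond_per_step[mask]                              # (mask.sum(), 2)
assert masked_cond.shape == (int(mask.sum()), cond_dim)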
2 changes: 2 additions & 0 deletions src/gfn/containers/transitions.py
@@ -34,6 +34,7 @@
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
is_done: TT["n_transitions", torch.bool] | None = None,
next_states: States | None = None,
@@ -65,6 +66,7 @@
`batch_shapes`.
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states
66 changes: 50 additions & 16 deletions src/gfn/gflownet/base.py
@@ -1,10 +1,9 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union
from typing import Generic, Tuple, TypeVar, Union, Any

import torch
import torch.nn as nn
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
@@ -14,6 +13,10 @@
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -32,48 +35,48 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
def sample_trajectories(
self,
env: Env,
n_samples: int,
n: int,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
) -> Trajectories:
"""Sample a specific number of complete trajectories.
Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
n: number of trajectories to be sampled.
save_logprobs: whether to save the logprobs of the actions - useful for on-policy learning.
save_estimator_outputs: whether to save the estimator outputs - useful for off-policy learning
with tempered policy
Returns:
Trajectories: sampled trajectories object.
"""

def sample_terminating_states(self, env: Env, n_samples: int) -> States:
def sample_terminating_states(self, env: Env, n: int) -> States:
"""Rolls out the parametrization's policy and returns the terminating states.
Args:
env: the environment to sample terminating states from.
n_samples: number of terminating states to be sampled.
n: number of terminating states to be sampled.
Returns:
States: sampled terminating states object.
"""
trajectories = self.sample_trajectories(
env, n_samples, save_estimator_outputs=False, save_logprobs=False
env, n, save_estimator_outputs=False, save_logprobs=False
)
return trajectories.last_states

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}
return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k}

def logz_parameters(self):
return [dict(self.named_parameters())["logZ"]]
return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k]

@abstractmethod
def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects):
def loss(self, env: Env, training_objects: Any):
"""Computes the loss given the training objects."""


@@ -93,18 +96,20 @@ def __init__(self, pf: GFNModule, pb: GFNModule):
def sample_trajectories(
self,
env: Env,
n_samples: int,
n: int,
conditioning: torch.Tensor | None = None,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
**policy_kwargs,
**policy_kwargs: Any,
) -> Trajectories:
"""Samples trajectories, optionally with specified policy kwargs."""
sampler = Sampler(estimator=self.pf)
trajectories = sampler.sample_trajectories(
env,
n_trajectories=n_samples,
save_estimator_outputs=save_estimator_outputs,
n=n,
conditioning=conditioning,
save_logprobs=save_logprobs,
save_estimator_outputs=save_estimator_outputs,
**policy_kwargs,
)

@@ -176,7 +181,20 @@ def get_pfs_and_pbs(
~trajectories.actions.is_dummy
]
else:
estimator_outputs = self.pf(valid_states)
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
@@ -196,7 +214,23 @@

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
estimator_outputs = self.pb(non_initial_valid_states)
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)
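The backward-policy branch above masks the expanded conditioning twice: first by ~is_sink_state to align it with valid_states, then by ~is_initial_state of those valid states. A small sketch with assumed shapes showing that the chained masks line up:

import torch

traj_len, n_traj, cond_dim = 4, 3, 2
conditioning = torch.randn(n_traj, cond_dim)
is_sink = torch.tensor([[False, False, False],
                        [False, False, True],
                        [False, True,  True],
                        [True,  True,  True]])
is_initial = torch.zeros(traj_len, n_traj, dtype=torch.bool)
is_initial[0] = True  # the first step of every trajectory is s0

cond = conditioning.unsqueeze(0).expand((traj_len,) + conditioning.shape)
valid_cond = cond[~is_sink]                           # aligned with valid_states
non_initial_cond = valid_cond[~is_initial[~is_sink]]  # aligned with non_initial_valid_states
assert non_initial_cond.shape[0] == int((~is_sink & ~is_initial).sum())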
(Diffs for the remaining 21 changed files are not shown here.)
