Fix formatting #207

Merged
merged 3 commits into from
Oct 31, 2024
18 changes: 9 additions & 9 deletions src/gfn/actions.py
@@ -32,12 +32,12 @@ def __init__(self, tensor: torch.Tensor):
Args:
tensor: tensors representing a batch of actions with shape (*batch_shape, *action_shape).
"""
assert tensor.shape[-len(self.action_shape):] == self.action_shape, (
f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."
)
assert (
tensor.shape[-len(self.action_shape) :] == self.action_shape
), f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."

self.tensor = tensor
self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)]
self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions:
@@ -137,13 +137,13 @@ def compare(self, other: torch.Tensor) -> torch.Tensor:

Args:
other: tensor of actions to compare, with shape (*batch_shape, *action_shape).

Returns: boolean tensor of shape batch_shape indicating whether the actions are
equal.
"""
assert other.shape == self.batch_shape + self.action_shape, (
f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
)
assert (
other.shape == self.batch_shape + self.action_shape
), f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
out = self.tensor == other
n_batch_dims = len(self.batch_shape)

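For readers skimming the diff: the assertions being re-wrapped above encode a simple contract in Actions.__init__ — the trailing dimensions of a batched actions tensor must equal the class-level action_shape, and the remaining leading dimensions are treated as the batch shape. A minimal standalone sketch of that check with plain torch (the concrete action_shape and tensor sizes below are illustrative, not taken from the library):

import torch

action_shape = (1,)                         # illustrative: one index per action
tensor = torch.zeros(4, 3, *action_shape)   # batch_shape would be (4, 3)

# Trailing dims must match action_shape exactly ...
assert tensor.shape[-len(action_shape):] == action_shape
# ... and the leading dims form the batch shape.
batch_shape = tuple(tensor.shape)[: -len(action_shape)]
print(batch_shape)  # (4, 3)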
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
2 changes: 1 addition & 1 deletion src/gfn/containers/replay_buffer.py
@@ -146,7 +146,7 @@ def __init__(
def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
terminating_states: States | None = None
terminating_states: States | None = None,
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
28 changes: 17 additions & 11 deletions src/gfn/containers/trajectories.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Union, Tuple

from typing import TYPE_CHECKING, Sequence, Tuple, Union

if TYPE_CHECKING:
from gfn.actions import Actions
@@ -92,20 +91,29 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long
assert (
self.when_is_done.shape == (self.n_trajectories,)
and self.when_is_done.dtype == torch.long
)

self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float
assert (
self._log_rewards.shape == (self.n_trajectories,)
and self._log_rewards.dtype == torch.float
)

if log_probs is not None:
assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
)
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs
self.log_probs = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
@@ -207,15 +215,13 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
)

@staticmethod
def extend_log_probs(
log_probs: torch.Tensor, new_max_length: int
) -> torch.Tensor:
def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
"""Extend the log_probs matrix by adding 0 until the required length is reached.

Args:
log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend.
new_max_length: The new length of the log_probs tensor.

Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories).

"""
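The extend_log_probs helper whose signature is collapsed onto one line above pads the per-step log-probability matrix with zeros along the time dimension, per its docstring. A rough standalone equivalent under the documented shapes (the body here is an assumption written for illustration, not the library's actual implementation):

import torch

def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
    # log_probs has shape (max_length, n_trajectories); pad with zero rows up to new_max_length.
    max_length, n_trajectories = log_probs.shape
    if new_max_length <= max_length:
        return log_probs
    padding = torch.zeros(
        new_max_length - max_length, n_trajectories, dtype=log_probs.dtype
    )
    return torch.cat([log_probs, padding], dim=0)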
22 changes: 17 additions & 5 deletions src/gfn/containers/transitions.py
@@ -84,7 +84,10 @@ def __init__(
if is_done is not None
else torch.full(size=(0,), fill_value=False, dtype=torch.bool)
)
assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool
assert (
self.is_done.shape == (self.n_transitions,)
and self.is_done.dtype == torch.bool
)

self.next_states = (
next_states
@@ -96,9 +99,15 @@ def __init__(
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float
assert (
self._log_rewards.shape == (self.n_transitions,)
and self._log_rewards.dtype == torch.float
)
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float
assert (
self.log_probs.shape == (self.n_transitions,)
and self.log_probs.dtype == torch.float
)

@property
def n_transitions(self) -> int:
@@ -186,8 +195,11 @@ def all_log_rewards(self) -> torch.Tensor:
log_rewards[~is_sink_state, 1] = torch.log(
self.env.reward(self.next_states[~is_sink_state])
)

assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float

assert (
log_rewards.shape == (self.n_transitions, 2)
and log_rewards.dtype == torch.float
)
return log_rewards

def __getitem__(self, index: int | Sequence[int]) -> Transitions:
49 changes: 21 additions & 28 deletions src/gfn/env.py
@@ -76,18 +76,18 @@ def __init__(

def states_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in a States instance.

Args:
tensor: The tensor of shape "state_shape" representing the states.

Returns:
States: An instance of States.
"""
return self.States(tensor)

def states_from_batch_shape(self, batch_shape: Tuple):
"""Returns a batch of s0 states with a given batch_shape.

Args:
batch_shape: Tuple representing the shape of the batch of states.

@@ -98,38 +98,36 @@ def states_from_batch_shape(self, batch_shape: Tuple):

def actions_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor an an Actions instance.

Args:
tensor: The tensor of shape "action_shape" representing the actions.

Returns:
Actions: An instance of Actions.
"""
return self.Actions(tensor)

def actions_from_batch_shape(self, batch_shape: Tuple):
"""Returns a batch of dummy actions with the supplied batch_shape.

Args:
batch_shape: Tuple representing the shape of the batch of actions.

Returns:
Actions: A batch of dummy actions.
"""
return self.Actions.make_dummy_actions(batch_shape)

# To be implemented by the User.
@abstractmethod
def step(
self, states: States, actions: Actions
) -> torch.Tensor:
def step(self, states: States, actions: Actions) -> torch.Tensor:
"""Function that takes a batch of states and actions and returns a batch of next
states. Does not need to check whether the actions are valid or the states are sink states.

Args:
states: A batch of states.
actions: A batch of actions.

Returns:
torch.Tensor: A batch of next states.
"""
Expand All @@ -140,11 +138,11 @@ def backward_step( # TODO: rename to backward_step, other method becomes _backw
) -> torch.Tensor:
"""Function that takes a batch of states and actions and returns a batch of previous
states. Does not need to check whether the actions are valid or the states are sink states.

Args:
states: A batch of states.
actions: A batch of actions.

Returns:
torch.Tensor: A batch of previous states.
"""
@@ -312,7 +310,7 @@ def reward(self, final_states: States) -> torch.Tensor:

Args:
final_states: A batch of final states.

Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the rewards.
"""
@@ -321,10 +319,10 @@ def log_reward(self, final_states: States) -> torch.Tensor:
def log_reward(self, final_states: States) -> torch.Tensor:
"""Calculates the log reward.
This or reward must be implemented.

Args:
final_states: A batch of final states.

Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the log rewards.
"""
@@ -386,7 +384,6 @@ def __init__(
assert dummy_action.shape == action_shape
assert exit_action.shape == action_shape


self.n_actions = n_actions # Before init, for compatibility with States.
super().__init__(
s0,
@@ -403,10 +400,10 @@ def __init__(

def states_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in a States instance & updates masks.

Args:
tensor: The tensor of shape "state_shape" representing the states.

Returns:
States: An instance of States.
"""
@@ -489,29 +486,25 @@ def _step(self, states: DiscreteStates, actions: Actions) -> States:
) # TODO: update_masks is owned by the env, not the states!!
return new_states

def get_states_indices(
self, states: DiscreteStates
) -> torch.Tensor:
def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the states in the environment.

Args:
states: The batch of states.

Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the indices of the states.
"""
raise NotImplementedError(
"The environment does not support enumeration of states"
)

def get_terminating_states_indices(
self, states: DiscreteStates
) -> torch.Tensor:
def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the terminating states in the environment.

Args:
states: The batch of states.

Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the indices of the terminating states.
"""
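As context for the docstrings touched above: states_from_tensor and actions_from_tensor are described as thin wrappers that lift raw tensors into States/Actions instances, and log_reward is the elementwise log of reward ("This or reward must be implemented"). A schematic sketch of that reward relationship (a hypothetical stand-in class operating on raw tensors, not the library's actual Env API):

import torch

class SketchEnv:
    # Hypothetical minimal stand-in: a concrete env would override reward() or log_reward().
    def reward(self, final_states_tensor: torch.Tensor) -> torch.Tensor:
        # Illustrative positive reward, one value per state in the batch.
        return final_states_tensor.abs().sum(dim=-1) + 1.0

    def log_reward(self, final_states_tensor: torch.Tensor) -> torch.Tensor:
        # Default: log of the reward, matching the documented relationship.
        return torch.log(self.reward(final_states_tensor))

env = SketchEnv()
batch = torch.randn(8, 2)              # batch_shape == (8,), state_shape == (2,) (illustrative)
print(env.log_reward(batch).shape)     # torch.Size([8])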
21 changes: 13 additions & 8 deletions src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union, Any
from typing import Any, Generic, Tuple, TypeVar, Union

import torch
import torch.nn as nn
@@ -211,7 +211,6 @@ def get_pfs_and_pbs(
# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
@@ -242,8 +241,14 @@
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pb_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
@@ -252,15 +257,15 @@ def get_trajectories_scores(
recalculate_all_logprobs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Given a batch of trajectories, calculate forward & backward policy scores.

Args:
trajectories: Trajectories to evaluate.
recalculate_all_logprobs: Whether to re-evaluate all logprobs.

Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.

"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
@@ -279,7 +284,7 @@ def get_trajectories_scores(
torch.isinf(total_log_pb_trajectories)
):
raise ValueError("Infinite logprobs found")

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
return (
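The reformatted assertions above pin down the shapes flowing through get_pfs_and_pbs and get_trajectories_scores: per-step log-probabilities have shape (max_length, n_trajectories) and the per-trajectory totals have shape (n_trajectories,). A shape-only sketch of that reduction (random values and a plain sum over the time dimension are used purely for illustration; they are not the library's exact computation):

import torch

max_length, n_trajectories = 5, 3
log_pf = torch.randn(max_length, n_trajectories)   # shape asserted in get_pfs_and_pbs
log_pb = torch.randn(max_length, n_trajectories)

total_log_pf = log_pf.sum(dim=0)   # (n_trajectories,)
total_log_pb = log_pb.sum(dim=0)   # (n_trajectories,)
assert total_log_pf.shape == (n_trajectories,)
assert total_log_pb.shape == (n_trajectories,)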