
Commit

Merge pull request #234 from GFNOrg/newprecommit
Update pre-commit configuration: enforce pyright and optimize local testing
josephdviviano authored Jan 24, 2025
2 parents 59a1efa + 09ff20c commit 5db4162
Showing 20 changed files with 283 additions and 216 deletions.
35 changes: 18 additions & 17 deletions .pre-commit-config.yaml
@@ -1,38 +1,38 @@
---
default_language_version:
node: 15.1.0
repos:
- repo: https://github.com/python/black
rev: 23.7.0
hooks:
- id: black
language_version: python3.10
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
- id: flake8
- repo: https://github.com/pycqa/autoflake
rev: v2.2.0
rev: v2.3.1
hooks:
- id: autoflake
name: autoflake
entry: autoflake
language: python
"types": [python]
types: [python]
require_serial: true
args:
- "--in-place"
- "--expand-star-imports"
- "--remove-duplicate-keys"
- "--remove-unused-variables"
- "--remove-all-unused-imports"
- "--ignore-init-module-imports"
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
files: "\\.(py)$"
files: \\.py$
args: [--settings-path=pyproject.toml]
- repo: https://github.com/python/black
rev: 24.10.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.320
rev: v1.1.392.post0
hooks:
- id: pyright
name: pyright
@@ -44,7 +44,8 @@ repos:
- id: pytest-check
name: pytest-check
entry: pytest
args: [testing/]
language: python
files: testing/
pass_filenames: false
types: [python]
always_run: true
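
Because the updated pytest-check hook sets `pass_filenames: false` and `always_run: true`, pre-commit invokes pytest without file arguments on every commit instead of passing only the staged files. A minimal sketch of exercising the updated hooks locally, assuming `pre-commit` is installed in the environment:

```python
import subprocess

# Run every configured hook (autoflake, isort, black, flake8, pyright, and the
# local pytest-check hook) against the whole repository, not just staged files.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```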
8 changes: 5 additions & 3 deletions src/gfn/actions.py
@@ -40,14 +40,14 @@ def __init__(self, tensor: torch.Tensor):
self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions:
def make_dummy_actions(cls, batch_shape: tuple[int, ...]) -> Actions:
"""Creates an Actions object of dummy actions with the given batch shape."""
action_ndim = len(cls.action_shape)
tensor = cls.dummy_action.repeat(*batch_shape, *((1,) * action_ndim))
return cls(tensor)

@classmethod
def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions:
def make_exit_actions(cls, batch_shape: tuple[int, ...]) -> Actions:
"""Creates an Actions object of exit actions with the given batch shape."""
action_ndim = len(cls.action_shape)
tensor = cls.exit_action.repeat(*batch_shape, *((1,) * action_ndim))
@@ -64,7 +64,9 @@ def __repr__(self):
def device(self) -> torch.device:
return self.tensor.device

def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> Actions:
def __getitem__(
self, index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor
) -> Actions:
actions = self.tensor[index]
return self.__class__(actions)

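The annotation changes above matter under pyright: `tuple[int]` denotes a tuple of exactly one `int`, whereas batch shapes can have any number of dimensions, and `__getitem__` now accepts the full range of index types the underlying tensor supports. A minimal sketch of the `tuple[int, ...]` distinction (hypothetical function names, not part of the library):

```python
def old_style(batch_shape: tuple[int]) -> None:       # exactly one element
    ...

def new_style(batch_shape: tuple[int, ...]) -> None:  # any number of elements
    ...

new_style((16,))    # OK
new_style((4, 8))   # OK
old_style((4, 8))   # pyright error: expected "tuple[int]", got "tuple[int, int]"
```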
83 changes: 54 additions & 29 deletions src/gfn/containers/replay_buffer.py
@@ -65,40 +65,52 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects
training_objects, terminating_states = training_objects # pyright: ignore

to_add = len(training_objects)

self._is_full |= len(self) + to_add >= self.capacity

self.training_objects.extend(training_objects)
self.training_objects = self.training_objects[-self.capacity :]
self.training_objects.extend(training_objects) # pyright: ignore
self.training_objects = self.training_objects[
-self.capacity :
] # pyright: ignore

if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)
self.terminating_states.extend(terminating_states) # pyright: ignore
self.terminating_states = self.terminating_states[-self.capacity :]

def sample(self, n_trajectories: int) -> Transitions | Trajectories | tuple[States]:
"""Samples `n_trajectories` training objects from the buffer."""
if self.terminating_states is not None:
return (
self.training_objects.sample(n_trajectories),
self.terminating_states.sample(n_trajectories),
self.training_objects.sample(n_trajectories), # pyright: ignore
self.terminating_states.sample(n_trajectories), # pyright: ignore
)
return self.training_objects.sample(n_trajectories)
return self.training_objects.sample(n_trajectories) # pyright: ignore

def save(self, directory: str):
"""Saves the buffer to disk."""
self.training_objects.save(os.path.join(directory, "training_objects"))
if self.objects_type == "states":
raise ValueError("States cannot be saved")
self.training_objects.save( # pyright: ignore
os.path.join(directory, "training_objects")
)
if self.terminating_states is not None:
self.terminating_states.save(os.path.join(directory, "terminating_states"))
self.terminating_states.save( # pyright: ignore
os.path.join(directory, "terminating_states")
)

def load(self, directory: str):
"""Loads the buffer from disk."""
self.training_objects.load(os.path.join(directory, "training_objects"))
self.training_objects.load( # pyright: ignore
os.path.join(directory, "training_objects")
)
if self.terminating_states is not None:
self.terminating_states.load(os.path.join(directory, "terminating_states"))
self.terminating_states.load( # pyright: ignore
os.path.join(directory, "terminating_states")
)


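`ReplayBuffer.add` follows an extend-then-truncate pattern: new objects are appended, then the buffer is sliced down to its most recent `capacity` entries. A minimal sketch with a plain list standing in for the Trajectories/Transitions containers:

```python
# Toy illustration of the extend-then-truncate capping used in `add`.
capacity = 4
buffer = [1, 2, 3]        # existing contents
incoming = [4, 5, 6]      # new training objects

buffer.extend(incoming)
buffer = buffer[-capacity:]   # keep the most recent `capacity` items -> [3, 4, 5, 6]
```
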
class PrioritizedReplayBuffer(ReplayBuffer):
@@ -148,12 +160,14 @@ def _add_objs(
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)
self.training_objects.extend(training_objects) # pyright: ignore

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards)
self.training_objects = self.training_objects[ix]
self.training_objects = self.training_objects[-self.capacity :]
ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore
self.training_objects = self.training_objects[ix] # pyright: ignore
self.training_objects = self.training_objects[
-self.capacity :
] # pyright: ignore

# Add the terminating states to the buffer.
if self.terminating_states is not None:
@@ -169,7 +183,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects
training_objects, terminating_states = training_objects # pyright: ignore

to_add = len(training_objects)
self._is_full |= len(self) + to_add >= self.capacity
@@ -182,18 +196,22 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
else:
if (
self.training_objects.log_rewards is None
or training_objects.log_rewards is None
or training_objects.log_rewards is None # pyright: ignore
):
raise ValueError("log_rewards must be defined for prioritized replay.")

# Sort the incoming elements by their logrewards.
ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]
ix = torch.argsort(
training_objects.log_rewards, descending=True # pyright: ignore
) # pyright: ignore
training_objects = training_objects[ix] # pyright: ignore

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min() # type: ignore # FIXME
idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]
idx_bigger_rewards = (
training_objects.log_rewards >= min_reward_in_buffer # pyright: ignore
) # pyright: ignore
training_objects = training_objects[idx_bigger_rewards] # pyright: ignore

# TODO: Concatenate input with final state for conditional GFN.
# if self.is_conditional:
@@ -212,8 +230,10 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
batch = training_objects.last_states.tensor.float() # pyright: ignore
batch_dim = training_objects.last_states.batch_shape[ # pyright: ignore
0
] # pyright: ignore
batch_batch_dist = torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
batch.view(batch_dim, -1).unsqueeze(0),
@@ -225,13 +245,18 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]
idx_batch_batch = batch_batch_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_batch]
training_objects = training_objects[idx_batch_batch] # pyright: ignore

# Compute all pairwise distances between the remaining batch & buffer.
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]
batch = training_objects.last_states.tensor.float() # pyright: ignore
buffer = (
self.training_objects.last_states.tensor.float() # pyright: ignore
) # pyright: ignore
batch_dim = training_objects.last_states.batch_shape[ # pyright: ignore
0
] # pyright: ignore
tmp = self.training_objects.last_states # pyright: ignore
buffer_dim = tmp.batch_shape[0] # pyright: ignore
batch_buffer_dist = (
torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
@@ -244,7 +269,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

# Filter the batch for diverse final_states w.r.t the buffer.
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_buffer]
training_objects = training_objects[idx_batch_buffer] # pyright: ignore

# If any training object remain after filtering, add them.
if len(training_objects):
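
The cutoff-distance logic above keeps only candidate terminal states whose nearest neighbour (first within the batch, then within the buffer) lies farther away than `cutoff_distance`. A minimal sketch of the buffer-side filter, assuming the states are already flattened to `(n, d)` tensors rather than States containers:

```python
import torch

def diversity_mask(batch: torch.Tensor, buffer: torch.Tensor, cutoff_distance: float) -> torch.Tensor:
    """batch: (n, d) candidate final states; buffer: (m, d) stored final states.
    Returns a boolean mask selecting candidates far from everything in the buffer."""
    # Pairwise Euclidean distances between candidates and buffer entries: (n, m).
    dists = torch.cdist(batch.unsqueeze(0), buffer.unsqueeze(0), p=2).squeeze(0)
    # A candidate survives if even its closest buffer entry exceeds the cutoff.
    return dists.min(dim=-1).values > cutoff_distance

# Toy usage: 8 candidates, 32 buffered states, 4 features each.
mask = diversity_mask(torch.randn(8, 4), torch.randn(32, 4), cutoff_distance=1.0)
```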
40 changes: 20 additions & 20 deletions src/gfn/containers/trajectories.py
@@ -457,11 +457,13 @@ def reverse_backward_trajectories(

# Initialize new actions and states
new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
max_len + 1, len(trajectories), 1 # pyright: ignore
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(
new_states = trajectories.env.sf.repeat(
max_len + 2, len(trajectories), 1 # pyright: ignore
).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)

@@ -494,9 +496,9 @@ def reverse_backward_trajectories(

# Assign reversed actions to new_actions
new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
new_actions[
torch.arange(len(trajectories)), seq_lengths
] = trajectories.env.exit_action
new_actions[torch.arange(len(trajectories)), seq_lengths] = (
trajectories.env.exit_action
)

# Assign reversed states to new_states
assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
@@ -530,32 +532,30 @@ def reverse_backward_trajectories(
# vectorized approach's results (above) to the for-loop results (below).
if debug:
_new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
max_len + 1, len(trajectories), 1 # pyright: ignore
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
_new_states = trajectories.env.sf.repeat(
max_len + 2, len(trajectories), 1
max_len + 2, len(trajectories), 1 # pyright: ignore
).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)

for i in range(len(trajectories)):
_new_actions[
trajectories.when_is_done[i], i
] = trajectories.env.exit_action
_new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
0
_new_actions[trajectories.when_is_done[i], i] = (
trajectories.env.exit_action
)
_new_actions[: trajectories.when_is_done[i], i] = (
trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
0
)
)

_new_states[
: trajectories.when_is_done[i] + 1, i
] = trajectories.states.tensor[
: trajectories.when_is_done[i] + 1, i
].flip(
0
_new_states[: trajectories.when_is_done[i] + 1, i] = (
trajectories.states.tensor[
: trajectories.when_is_done[i] + 1, i
].flip(0)
)

assert torch.all(new_actions == _new_actions)
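
The debug branch above re-derives the vectorized reversal one trajectory at a time: the first `when_is_done[i]` actions are flipped, and the exit action is written immediately after the reversed prefix. A minimal sketch of that per-trajectory pattern with toy shapes (integer action indices instead of the library's action tensors):

```python
import torch

T, B = 5, 3                                   # max trajectory length, batch size
actions = torch.arange(T * B).view(T, B)      # (T, B) padded backward actions
when_is_done = torch.tensor([5, 3, 2])        # effective length of each trajectory
exit_action, dummy_action = torch.tensor(-1), torch.tensor(-2)

new_actions = dummy_action.repeat(T + 1, B)   # (T + 1, B), filled with dummies
for i in range(B):
    done = int(when_is_done[i])
    new_actions[done, i] = exit_action                  # exit right after the reversed prefix
    new_actions[:done, i] = actions[:done, i].flip(0)   # reverse the first `done` steps
```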
23 changes: 15 additions & 8 deletions src/gfn/env.py
@@ -208,7 +208,7 @@ def reset(
batch_shape: Optional[Union[int, Tuple[int]]] = None,
random: bool = False,
sink: bool = False,
seed: int = None,
seed: int = None, # pyright: ignore
) -> States:
"""
Instantiates a batch of initial states. random and sink cannot be both True.
@@ -306,7 +306,7 @@ def _backward_step(
new_states.tensor[valid_states_idx] = new_not_done_states_tensor

if isinstance(new_states, DiscreteStates):
self.update_masks(new_states)
self.update_masks(new_states) # pyright: ignore

return new_states

@@ -364,8 +364,8 @@ def __init__(
s0: torch.Tensor,
state_shape: Tuple,
action_shape: Tuple = (1,),
dummy_action: Optional[torch.Tensor] = None,
exit_action: Optional[torch.Tensor] = None,
dummy_action: Optional[torch.Tensor] = None, # pyright: ignore
exit_action: Optional[torch.Tensor] = None, # pyright: ignore
sf: Optional[torch.Tensor] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
@@ -430,7 +430,7 @@ def reset(
batch_shape: Optional[Union[int, Tuple[int]]] = None,
random: bool = False,
sink: bool = False,
seed: int = None,
seed: int = None, # pyright: ignore
) -> States:
"""Instantiates a batch of initial states.
@@ -487,9 +487,16 @@ class DiscreteEnvActions(Actions):
def is_action_valid(
self, states: States, actions: Actions, backward: bool = False
) -> bool:
assert states.forward_masks is not None and states.backward_masks is not None
masks_tensor = states.backward_masks if backward else states.forward_masks
return torch.gather(masks_tensor, 1, actions.tensor).all()
assert (
states.forward_masks is not None # pyright: ignore
and states.backward_masks is not None # pyright: ignore
)
masks_tensor = (
states.backward_masks # pyright: ignore
if backward
else states.forward_masks # pyright: ignore
)
return torch.gather(masks_tensor, 1, actions.tensor).all() # pyright: ignore

def _step(self, states: DiscreteStates, actions: Actions) -> States:
"""Calls the core self._step method of the parent class, and updates masks."""
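
The reworked `is_action_valid` above looks up one mask entry per state via `torch.gather` and declares the batch valid only when every selected entry is `True`. A minimal sketch with toy masks and action indices (standalone tensors rather than States/Actions objects):

```python
import torch

forward_masks = torch.tensor([[True, False, True],
                              [True, True,  False]])    # (batch, n_actions)
actions = torch.tensor([[2], [1]])                       # (batch, 1) chosen action indices

# Gather the mask entry corresponding to each chosen action, then require all True.
valid = torch.gather(forward_masks, 1, actions).all()
print(valid)  # tensor(True)
```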
