Add Graphs as States #210

Status: Open. The pull request wants to merge 62 commits into base: master from alip67's graph-states branch (per the merge commits below). The diff shown reflects changes from 54 of the 62 commits.

Commits (62):
9ae28b2: including Graphs as States for torchgfn (alip67, Nov 6, 2024)
de6ab1c: add GraphEnv (younik, Nov 7, 2024)
24e23e8: add deps and reformat (younik, Nov 7, 2024)
1f7b220: add test, fix errors, add valid action check (younik, Nov 8, 2024)
63e4f1c: fix formatting (younik, Nov 8, 2024)
8034fb2: add GraphAction (younik, Nov 14, 2024)
d179671: fix batching mechanism (younik, Nov 14, 2024)
e018f4e: Merge branch 'GFNOrg:master' into graph-states (alip67, Nov 15, 2024)
7ff96d5: add support for EXIT action (younik, Nov 16, 2024)
cf482da: Merge branch 'graph-states' of https://github.com/alip67/torchgfn int… (younik, Nov 16, 2024)
dacbbf7: add GraphActionPolicyEstimator (younik, Nov 19, 2024)
98ea448: Merge branch 'GFNOrg:master' into graph-states (alip67, Nov 19, 2024)
e74e500: Sampler integration work (younik, Nov 22, 2024)
a862bb4: Merge branch 'graph-states' of https://github.com/alip67/torchgfn int… (younik, Nov 22, 2024)
5e64c84: use TensorDict (younik, Nov 26, 2024)
81f8b71: solve some errors (younik, Nov 28, 2024)
34781ef: use tensordict in actions (younik, Nov 28, 2024)
3e584f2: handle sf (younik, Dec 2, 2024)
d5e438f: remove Data (younik, Dec 3, 2024)
fba5d50: categorical action type (younik, Dec 6, 2024)
478bd14: change batching (younik, Dec 10, 2024)
dd80f28: fix stacking (younik, Dec 11, 2024)
616551c: fix graph stacking (younik, Dec 11, 2024)
77611d4: fix test graph env (younik, Dec 12, 2024)
5874ff6: add ring example (younik, Dec 19, 2024)
9d42332: remove check edge_features (younik, Dec 20, 2024)
2d44242: fix GraphStates set (younik, Dec 20, 2024)
173d4fb: remove debug (younik, Dec 20, 2024)
7265857: fix add_edge action (younik, Dec 20, 2024)
2b3208f: fix edge_index after get (younik, Dec 20, 2024)
b84246f: push updated code (younik, Dec 22, 2024)
fa0d22a: add rendering (younik, Dec 27, 2024)
27d192a: fix gradient propagation (younik, Jan 6, 2025)
5d99739: Merge remote-tracking branch 'origin/master' into graph-states (younik, Jan 12, 2025)
f4fc3ab: fix formatting (younik, Jan 12, 2025)
8f1c62c: address comments (younik, Jan 12, 2025)
6482834: fix test (younik, Jan 12, 2025)
6db601d: fix test (younik, Jan 13, 2025)
c7f8243: fix pre-commit (younik, Jan 13, 2025)
c3df427: Merge remote-tracking branch 'origin/master' into graph-states (younik, Jan 13, 2025)
78b729a: fix merging issues (younik, Jan 13, 2025)
38dd2b0: fix toml (younik, Jan 13, 2025)
12c49b7: add dep & address issue (younik, Jan 13, 2025)
fe237ed: fix toml (younik, Jan 13, 2025)
9bbc48d: fix pyproject.toml (younik, Jan 13, 2025)
5e4fc4e: address comments (younik, Jan 14, 2025)
705b4cc: add tests for action (younik, Jan 15, 2025)
d765330: fix test after added dummy action (younik, Jan 15, 2025)
4ee6987: add GraphPreprocessor (younik, Jan 15, 2025)
fe9713c: added TODO (younik, Jan 15, 2025)
1425eb6: add complete masks (younik, Jan 19, 2025)
36c42ec: pre-commit hook (younik, Jan 19, 2025)
5747e97: adress comments (younik, Jan 19, 2025)
e9f9951: pre-commit (younik, Jan 19, 2025)
406cfca: address comments (younik, Jan 20, 2025)
08e519b: fix ring example (younik, Jan 24, 2025)
17e07ad: make edge_index global (younik, Jan 29, 2025)
46e3698: make edge_index global (younik, Jan 29, 2025)
1c33d98: Merge remote-tracking branch 'origin/graph-states-fix' into graph-states (younik, Jan 29, 2025)
e6d909b: fix test_env (younik, Jan 29, 2025)
da66adb: add global edge + pair programming session (younik, Feb 4, 2025)
7130281: pair programming session (younik, Feb 5, 2025)
pyproject.toml (4 additions, 0 deletions)

```diff
@@ -23,6 +23,7 @@ einops = ">=0.6.1"
 numpy = ">=1.21.2"
 python = "^3.10"
 torch = ">=1.9.0"
+tensordict = ">=0.6.1"

 # dev dependencies.
 black = { version = "24.3", optional = true }
@@ -44,6 +45,7 @@ wandb = { version = "*", optional = true }
 scikit-learn = {version = "*", optional = true }
 scipy = { version = "*", optional = true }
 matplotlib = { version = "*", optional = true }
+torch_geometric = { version = ">=2.6.1", optional = true }

 [tool.poetry.extras]
 dev = [
@@ -59,6 +61,7 @@ dev = [
     "sphinx",
     "tox",
     "flake8",
+    "torch_geometric",
 ]

 scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"]
@@ -80,6 +83,7 @@ all = [
     "tox",
     "tqdm",
     "wandb",
+    "torch_geometric",
 ]

 [tool.poetry.urls]
```
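As a quick sanity check that the new dependency floors resolve after installation (a sketch, not part of the PR):

```python
# Verifies that the new runtime dependency (tensordict) and the optional
# extra (torch_geometric) import correctly; the version floors below mirror
# the pyproject.toml changes above.
import tensordict
import torch_geometric

print("tensordict:", tensordict.__version__)             # expected >= 0.6.1
print("torch_geometric:", torch_geometric.__version__)   # expected >= 2.6.1
```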
src/gfn/actions.py (155 additions, 0 deletions)

```diff
@@ -1,10 +1,12 @@
 from __future__ import annotations  # This allows using the class name in type hints

+import enum
 from abc import ABC
 from math import prod
 from typing import ClassVar, Sequence

 import torch
+from tensordict import TensorDict


 class Actions(ABC):
```

The second hunk (@@ -168,3 +170,156 @@) appends the new graph action classes after `Actions.is_exit`; everything after the context lines is new code:

```python
            *self.batch_shape, *((1,) * len(self.__class__.action_shape))
        )
        return self.compare(exit_actions_tensor)


class GraphActionType(enum.IntEnum):
    ADD_NODE = 0
    ADD_EDGE = 1
    EXIT = 2
    DUMMY = 3


class GraphActions(Actions):
    """Actions for graph-based environments.

    Each action is one of:
    - ADD_NODE: add a node with the given features
    - ADD_EDGE: add an edge between two nodes, with the given features
    - EXIT: terminate the trajectory
    - DUMMY: padding action used to align batches

    Attributes:
        features_dim: dimension of the node/edge features
        tensor: a TensorDict containing:
            - action_type: type of each action (a GraphActionType)
            - features: features for the node/edge
            - edge_index: source and target nodes of the edge
    """

    features_dim: ClassVar[int]

    def __init__(self, tensor: TensorDict):
        """Initializes a GraphActions object from a TensorDict with keys:

        - action_type: a tensor of GraphActionType values indicating the type of each action.
        - features: a tensor of shape (*batch_shape, features_dim) representing the features
          of the nodes or of the edges, depending on the action type. May be omitted when
          all actions are EXIT or DUMMY.
        - edge_index: a tensor of shape (*batch_shape, 2) representing the edge to add.
          This must be defined if and only if the action type is GraphActionType.ADD_EDGE.
        """
        self.batch_shape = tensor["action_type"].shape
        features = tensor.get("features", None)
        if features is None:
            assert torch.all(
                torch.logical_or(
                    tensor["action_type"] == GraphActionType.EXIT,
                    tensor["action_type"] == GraphActionType.DUMMY,
                )
            )
            features = torch.zeros((*self.batch_shape, self.features_dim))
        edge_index = tensor.get("edge_index", None)
        if edge_index is None:
            assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE)
            edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long)

        self.tensor = TensorDict(
            {
                "action_type": tensor["action_type"],
                "features": features,
                "edge_index": edge_index,
            },
            batch_size=self.batch_shape,
        )

    def __repr__(self):
        return f"""GraphActions object with {self.batch_shape} actions."""

    @property
    def device(self) -> torch.device:
        """Returns the device of the features tensor."""
        return self.tensor.device

    def __len__(self) -> int:
        """Returns the number of actions in the batch."""
        return prod(self.batch_shape)

    def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
        """Get particular actions of the batch."""
        return GraphActions(self.tensor[index])

    def __setitem__(
        self, index: int | Sequence[int] | Sequence[bool], action: GraphActions
    ) -> None:
        """Set particular actions of the batch."""
        self.tensor[index] = action.tensor

    def compare(self, other: GraphActions) -> torch.Tensor:
        """Compares the actions to another GraphActions object.

        Args:
            other: GraphActions object to compare.

        Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
        """
        compare = torch.all(self.tensor == other.tensor, dim=-1)
        # Features must match unless the action is EXIT, and edge_index must
        # match only for ADD_EDGE actions. The parentheses matter here: `|`
        # binds more tightly than `==`/`!=` in Python.
        return (
            compare["action_type"]
            & ((self.action_type == GraphActionType.EXIT) | compare["features"])
            & ((self.action_type != GraphActionType.ADD_EDGE) | compare["edge_index"])
        )

    @property
    def is_exit(self) -> torch.Tensor:
        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
        return self.action_type == GraphActionType.EXIT

    @property
    def is_dummy(self) -> torch.Tensor:
        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
        return self.action_type == GraphActionType.DUMMY

    @property
    def action_type(self) -> torch.Tensor:
        """Returns the action type tensor."""
        return self.tensor["action_type"]

    @property
    def features(self) -> torch.Tensor:
        """Returns the features tensor."""
        return self.tensor["features"]

    @property
    def edge_index(self) -> torch.Tensor:
        """Returns the edge index tensor."""
        return self.tensor["edge_index"]

    @classmethod
    def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions:
        """Creates a GraphActions object of dummy actions with the given batch shape."""
        return cls(
            TensorDict(
                {
                    "action_type": torch.full(
                        batch_shape, fill_value=GraphActionType.DUMMY
                    ),
                },
                batch_size=batch_shape,
            )
        )

    @classmethod
    def make_exit_actions(cls, batch_shape: tuple[int]) -> GraphActions:
        """Creates a GraphActions object of exit actions with the given batch shape."""
        return cls(
            TensorDict(
                {
                    "action_type": torch.full(
                        batch_shape, fill_value=GraphActionType.EXIT
                    ),
                },
                batch_size=batch_shape,
            )
        )
```
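For orientation, here is a minimal usage sketch (not part of the diff). It subclasses `GraphActions` to pin `features_dim`, which `GraphEnv.make_actions_class()` would normally inject; the subclass name is illustrative only:

```python
import torch
from tensordict import TensorDict

from gfn.actions import GraphActions, GraphActionType


class MyGraphActions(GraphActions):
    features_dim = 4  # normally set by GraphEnv.make_actions_class()


# A batch of two actions: one ADD_NODE (with features) and one EXIT.
actions = MyGraphActions(
    TensorDict(
        {
            "action_type": torch.tensor(
                [GraphActionType.ADD_NODE, GraphActionType.EXIT]
            ),
            "features": torch.rand(2, 4),  # ignored for the EXIT action
        },
        batch_size=(2,),
    )
)

print(actions.is_exit)   # tensor([False,  True])
print(actions.is_dummy)  # tensor([False, False])
print(len(actions))      # 2
```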
src/gfn/containers/trajectories.py (1 addition, 1 deletion)

```diff
@@ -104,7 +104,7 @@ def __init__(
             assert (
                 log_probs.shape == (self.max_length, self.n_trajectories)
                 and log_probs.dtype == torch.float
-            )
+            ), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
         else:
             log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
         self.log_probs: torch.Tensor = log_probs
```
src/gfn/env.py (76 additions, 7 deletions)

```diff
@@ -2,10 +2,11 @@
 from typing import Optional, Tuple, Union

 import torch
+from tensordict import TensorDict

-from gfn.actions import Actions
+from gfn.actions import Actions, GraphActions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
-from gfn.states import DiscreteStates, States
+from gfn.states import DiscreteStates, GraphStates, States
 from gfn.utils.common import set_seed

 # Errors
@@ -260,22 +261,24 @@ def _step(
                 "Some actions are not valid in the given states. See `is_action_valid`."
             )

+        # Set to the sink state when the action is exit.
         new_sink_states_idx = actions.is_exit
-        new_states.tensor[new_sink_states_idx] = self.sf
+        sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),))
+        new_states[new_sink_states_idx] = self.States(sf_tensor)
         new_sink_states_idx = ~valid_states_idx | new_sink_states_idx
         assert new_sink_states_idx.shape == states.batch_shape

         not_done_states = new_states[~new_sink_states_idx]
         not_done_actions = actions[~new_sink_states_idx]

         new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
-        if not isinstance(new_not_done_states_tensor, torch.Tensor):
+        if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)):
             raise Exception(
-                "User implemented env.step function *must* return a torch.Tensor!"
+                "User implemented env.step function *must* return a torch.Tensor or TensorDict!"
             )

-        new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor
+        new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor)
         return new_states
```

Review thread on the `sf_tensor` change:

- Collaborator: Curious about the reason for this change? Is it specific to GraphStates?
- younik (Jan 12, 2025): The reason is how graphs are represented in the tensor, i.e.:

      tensor = TensorDict({
          'node_features': shape (N, F1),
          'edge_features': shape (M, F2),
          'edge_index': shape (2, M),
      })

  Notice that `tensor[some_index]` doesn't make sense, and doesn't work. There is more complex behavior, defined in `GraphStates.__setitem__`, to do it correctly.
- Collaborator: I think a comment would be worth adding here.
```diff
@@ -303,7 +306,7 @@ def _backward_step(
         # Calculate the backward step, and update only the states which are not Done.
         new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
-        new_states.tensor[valid_states_idx] = new_not_done_states_tensor
+        new_states[valid_states_idx] = self.States(new_not_done_states_tensor)

         if isinstance(new_states, DiscreteStates):
             self.update_masks(new_states)
```
The final hunk (@@ -565,3 +568,69 @@) adds the new `GraphEnv` base class after the existing `terminating_states` property; everything after the context lines is new code:

```python
        raise NotImplementedError(
            "The environment does not support enumeration of states"
        )


class GraphEnv(Env):
    """Base class for graph-based environments."""

    def __init__(
        self,
        s0: TensorDict,
        sf: Optional[TensorDict] = None,
        device_str: Optional[str] = None,
        preprocessor: Optional[Preprocessor] = None,
    ):
        """Initializes a graph-based environment.

        Args:
            s0: The initial graph state.
            sf: The final graph state.
            device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
                inferred from s0.
            preprocessor: a Preprocessor object that converts raw graph states to a tensor
                that can be fed into a neural network. Defaults to None, in which case
                the IdentityPreprocessor is used.
        """
        self.s0 = s0.to(device_str)
        self.features_dim = s0["node_feature"].shape[-1]
        self.sf = sf

        self.States = self.make_states_class()
        self.Actions = self.make_actions_class()

        self.preprocessor = preprocessor
        self.is_discrete = False

    def make_states_class(self) -> type[GraphStates]:
        env = self

        class GraphEnvStates(GraphStates):
            s0 = env.s0
            sf = env.sf
            make_random_states_graph = env.make_random_states_tensor

        return GraphEnvStates

    def make_actions_class(self) -> type[GraphActions]:
        """The default GraphActions class factory for graph environments.

        Returns a class that inherits from GraphActions and implements assumed
        methods. This method should be overwritten to achieve more
        environment-specific Actions functionality.
        """
        env = self

        class DefaultGraphAction(GraphActions):
            features_dim = env.features_dim

        return DefaultGraphAction

    @abstractmethod
    def step(self, states: GraphStates, actions: Actions) -> torch.Tensor:
        """Function that takes a batch of graph states and actions and returns a batch
        of next graph states."""

    @abstractmethod
    def backward_step(self, states: GraphStates, actions: Actions) -> torch.Tensor:
        """Function that takes a batch of graph states and actions and returns a batch
        of previous graph states."""
```

Review thread on `self.sf = sf`:

- Collaborator: What happens when sf is initialized to None?
- Collaborator: The method `GraphStates.is_sink_state` doesn't work, as it checks the values in the (expected) tensordict.
- Collaborator: If we want to support this case, I suggest doing it in another PR.
- Collaborator: Perhaps we could have a special NoneTensorDict GraphState which acts like None but passes the relevant checks?
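To make the subclassing contract concrete, here is a skeleton of a custom GraphEnv, as a sketch: the class name and the `edge_feature`/`edge_index` keys are illustrative assumptions (only `node_feature` is read by `__init__` above), and the actual graph-update logic in `step`/`backward_step` is elided (see `GraphBuilding` in `src/gfn/gym` for a full implementation).

```python
import torch
from tensordict import TensorDict

from gfn.actions import GraphActions
from gfn.env import GraphEnv
from gfn.states import GraphStates


class ToyGraphEnv(GraphEnv):
    """Hypothetical environment that grows a graph from an empty initial state."""

    def __init__(self):
        s0 = TensorDict(
            {
                "node_feature": torch.zeros(0, 4),  # no nodes yet, 4-dim features
                "edge_feature": torch.zeros(0, 4),  # assumed key (see thread above)
                "edge_index": torch.zeros(2, 0, dtype=torch.long),
            },
            batch_size=(),
        )
        super().__init__(s0=s0, device_str="cpu")

    def step(self, states: GraphStates, actions: GraphActions) -> TensorDict:
        # Apply ADD_NODE / ADD_EDGE / EXIT to the batched graph TensorDict
        # and return the updated one (elided in this sketch).
        ...

    def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDict:
        # Undo the corresponding forward action (elided in this sketch).
        ...
```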
src/gfn/gflownet/flow_matching.py (4 additions, 4 deletions)

```diff
@@ -33,10 +33,10 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
     def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
         super().__init__()

-        assert isinstance(  # TODO: need a more flexible type check.
-            logF,
-            DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator,
-        ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator"
+        # assert isinstance(  # TODO: need a more flexible type check.
+        #     logF,
+        #     DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator,
+        # ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator"
         self.logF = logF
         self.alpha = alpha
```

Review thread on the commented-out check:

- Collaborator: Can you leave a TODO here for what to replace this check with? I think it would be helpful for users to know if they've submitted an appropriate estimator for logF.
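In the spirit of that review comment, one possible replacement for the removed assert is a duck-typed guard, sketched below; the helper name and the assumed `expected_output_dim` attribute are illustrative, not the library's confirmed API.

```python
# Sketch: accept any estimator exposing the interface FMGFlowNet relies on,
# rather than hard-coding two concrete classes.
def _validate_logF(logF) -> None:
    required = ("__call__", "expected_output_dim")  # assumed interface
    missing = [name for name in required if not hasattr(logF, name)]
    if missing:
        raise TypeError(
            f"logF is missing {missing}; expected an estimator such as "
            "DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator."
        )
```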
src/gfn/gym/__init__.py (1 addition, 0 deletions)

```diff
@@ -1,3 +1,4 @@
 from gfn.gym.box import Box
 from gfn.gym.discrete_ebm import DiscreteEBM
+from gfn.gym.graph_building import GraphBuilding
 from gfn.gym.hypergrid import HyperGrid
```
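With this line in place, the new environment is importable from the package root (its constructor arguments are not shown in this diff):

```python
from gfn.gym import GraphBuilding  # new in this PR, alongside Box, DiscreteEBM, HyperGrid
```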