Skip to content

Commit

Permalink
add GraphState class with State containing Timesteps as Batch
Browse files Browse the repository at this point in the history
  • Loading branch information
ashdtu committed Aug 20, 2024
1 parent 0764313 commit 3fbde19
Showing 1 changed file with 140 additions and 1 deletion.
141 changes: 140 additions & 1 deletion src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from abc import ABC
from copy import deepcopy
from math import prod
from typing import Callable, ClassVar, List, Optional, Sequence, cast
from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple

import torch
from torchtyping import TensorType as TT
from torch_geometric.data import Batch, Data


class States(ABC):
Expand Down Expand Up @@ -492,3 +493,141 @@ def stack_states(states: List[States]):
) + state_example.batch_shape

return stacked_states


class GraphStates(ABC):
"""
Base class for Graph as a state representation. The `GraphStates` object is a batched collection of
multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of
graph objects as states.
"""


s0: ClassVar[Data]
sf: ClassVar[Data]
node_feature_dim: ClassVar[int]
edge_feature_dim: ClassVar[int]
make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw(
NotImplementedError(
"The environment does not support initialization of random Graph states."
)
)

def __init__(self, graphs: Batch):
self.data: Batch = graphs
self.batch_shape: int = self.data.num_graphs
self._log_rewards: float = None

@classmethod
def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool=False) -> GraphStates:
if random and sink:
raise ValueError("Only one of `random` and `sink` should be True.")
if random:
data = cls.make_random_states_graph(batch_shape)
elif sink:
data = cls.make_sink_states_graph(batch_shape)
else:
data = cls.make_initial_states_graph(batch_shape)
return cls(data)

@classmethod
def make_initial_states_graph(cls, batch_shape: int) -> Batch:
data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)])
return data

@classmethod
def make_sink_states_graph(cls, batch_shape: int) -> Batch:
data = Batch.from_data_list([cls.sf for _ in range(batch_shape)])
return data

@classmethod
def make_random_states_graph(cls, batch_shape: int) -> Batch:
data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)])
return data

def __len__(self):
return self.data.batch_size

def __repr__(self):
return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and "
f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}")

def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates:
if isinstance(index, int):
out = self.__class__(Batch.from_data_list([self.data[index]]))
elif isinstance(index, (Sequence, slice)):
out = self.__class__(Batch.from_data_list(self.data.index_select(index)))
else:
raise NotImplementedError("Indexing with type {} is not implemented".format(type(index)))

if self._log_rewards is not None:
out._log_rewards = self._log_rewards[index]

return out

def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
"""
Set particular states of the Batch
"""
data_list = self.data.to_data_list()
if isinstance(index, int):
assert len(graph) == 1, "GraphStates must have a batch size of 1 for single index assignment"
data_list[index] = graph.data[0]
self.data = Batch.from_data_list(data_list)
elif isinstance(index, Sequence):
assert len(index) == len(graph), "Index and GraphState must have the same length"
for i, idx in enumerate(index):
data_list[idx] = graph.data[i]
self.data = Batch.from_data_list(data_list)
elif isinstance(index, slice):
assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length"
data_list[index] = graph.data.to_data_list()
self.data = Batch.from_data_list(data_list)
else:
raise NotImplementedError("Setters with type {} is not implemented".format(type(index)))

@property
def device(self) -> torch.device:
return self.data.get_example(0).x.device

def to(self, device: torch.device) -> GraphStates:
"""
Moves and/or casts the graph states to the specified device
"""
if self.device != device:
self.data = self.data.to(device)
return self

def clone(self) -> States:
"""Returns a *detached* clone of the current instance using deepcopy."""
return deepcopy(self)

def extend(self, other: GraphStates):
"""Concatenates to another GraphStates object along the batch dimension"""
self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list())
if self._log_rewards is not None:
assert other._log_rewards is not None
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
)


@property
def log_rewards(self) -> TT["batch_shape", torch.float]:
return self._log_rewards

@log_rewards.setter
def log_rewards(self, log_rewards: TT["batch_shape", torch.float]) -> None:
self._log_rewards = log_rewards












0 comments on commit 3fbde19

Please sign in to comment.