Skip to content

Commit

Permalink
[BugFix, Feature] GNN position_key and velocity_key not in `obser…
Browse files Browse the repository at this point in the history
…vation_spec` (#125)

* amend

* amend

* amend

* amend

* amend

* amend

* amend

* amend

* amend
  • Loading branch information
matteobettini authored Sep 3, 2024
1 parent d260eea commit aef8d40
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 55 deletions.
2 changes: 2 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ gnn_kwargs:
aggr: "add"

position_key: null
pos_features: 0
velocity_key: null
vel_features: 0

exclude_pos_from_node_features: False
edge_radius: null
1 change: 0 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,6 @@ def _get_excluded_keys(self, group: str):
for other_group in self.group_map.keys():
if other_group != group:
excluded_keys += [other_group, ("next", other_group)]
excluded_keys += ["info", (group, "info"), ("next", group, "info")]
return excluded_keys

def _optimizer_loop(self, group: str) -> TensorDictBase:
Expand Down
145 changes: 92 additions & 53 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import warnings
from dataclasses import dataclass, MISSING
from math import prod
from typing import Optional, Type
from typing import List, Optional, Type

import torch
from tensordict import TensorDictBase
from tensordict.utils import _unravel_key_to_tuple
from tensordict.utils import _unravel_key_to_tuple, NestedKey
from torch import nn, Tensor

from benchmarl.models.common import Model, ModelConfig
Expand Down Expand Up @@ -59,16 +59,21 @@ class Gnn(Model):
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest using the "info" dict) representing the agent position. This key will be processed as a
node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
pos_features (int, optional): Needed when position_key is specified.
It has to match the last element of the shape of the tensor under position_key.
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
whether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours
in the graph.
velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest using the "info" dict) representing the agent velocity. This key will be processed as a node feature, and
it will be used to construct edge features. In particular, it will be used to compute relative velocities
(``vel_node_1 - vel_node_2``) for all neighbours in the graph.
vel_features (int, optional): Needed when velocity_key is specified.
It has to match the last element of the shape of the tensor under velocity_key.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Agents within this radius distance will be neighbours.
Expand Down Expand Up @@ -120,6 +125,8 @@ def __init__(
exclude_pos_from_node_features: Optional[bool],
velocity_key: Optional[str],
edge_radius: Optional[float],
pos_features: Optional[int],
vel_features: Optional[int],
**kwargs,
):
self.topology = topology
Expand All @@ -128,34 +135,26 @@ def __init__(
self.velocity_key = velocity_key
self.exclude_pos_from_node_features = exclude_pos_from_node_features
self.edge_radius = edge_radius
self.pos_features = pos_features
self.vel_features = vel_features

super().__init__(**kwargs)

self.pos_features = sum(
[
spec.shape[-1]
for key, spec in self.input_spec.items(True, True)
if _unravel_key_to_tuple(key)[-1] == position_key
]
) # Input keys ending with `position_key`
if self.pos_features > 0:
self.pos_features += 1 # We will add also 1-dimensional distance
self.vel_features = sum(
[
spec.shape[-1]
for key, spec in self.input_spec.items(True, True)
if _unravel_key_to_tuple(key)[-1] == velocity_key
]
) # Input keys ending with `velocity_key`
self.edge_features = self.pos_features + self.vel_features
self.input_features = sum(
[
spec.shape[-1]
for key, spec in self.input_spec.items(True, True)
if _unravel_key_to_tuple(key)[-1]
not in ((position_key) if self.exclude_pos_from_node_features else ())
if _unravel_key_to_tuple(key)[-1] not in (position_key, velocity_key)
]
) # Input keys not ending with `velocity_key` and `position_key`
) # Input keys
if self.position_key is not None and not self.exclude_pos_from_node_features:
self.input_features += self.pos_features - 1
if self.velocity_key is not None:
self.input_features += self.vel_features

self.output_features = self.output_leaf_spec.shape[-1]

if gnn_kwargs is None:
Expand Down Expand Up @@ -191,6 +190,8 @@ def __init__(
device=self.device,
n_agents=self.n_agents,
)
self._full_position_key = None
self._full_velocity_key = None

def _perform_checks(self):
super()._perform_checks()
Expand All @@ -208,6 +209,22 @@ def _perform_checks(self):
raise ValueError(
"exclude_pos_from_node_features needs to be specified when position_key is provided"
)
if self.position_key is not None and self.pos_features <= 0:
raise ValueError(
f"Position key specified but pos_features is {self.pos_features}"
)
elif self.position_key is None and self.pos_features > 0:
raise ValueError(
f"If no position_key is given, pos_features needs to be 0, got: {self.pos_features}"
)
if self.velocity_key is not None and self.vel_features <= 0:
raise ValueError(
f"Velocity key specified but vel_features is {self.vel_features}"
)
elif self.velocity_key is None and self.vel_features > 0:
raise ValueError(
f"If no velocity_key is given, vel_features needs to be 0, got: {self.vel_features}"
)

if not self.input_has_agent_dim:
raise ValueError(
Expand Down Expand Up @@ -247,40 +264,51 @@ def _perform_checks(self):

def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# Gather in_key
input = torch.cat(
[
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1]
not in (
(self.position_key) if self.exclude_pos_from_node_features else ()
)
],
dim=-1,
)
input = [
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1]
not in (self.position_key, self.velocity_key)
]

# Retrieve position
if self.position_key is not None:
pos = torch.cat(
[
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1] == self.position_key
],
dim=-1,
)
if self._full_position_key is None: # Run once to find full key
self._full_position_key = self._get_key_terminating_with(
list(tensordict.keys(True, True)), self.position_key
)
pos = tensordict.get(self._full_position_key)
if pos.shape[-1] != self.pos_features - 1:
raise ValueError(
f"Position key in tensordict is {pos.shape[-1]}-dimensional, "
f"while model was configured with pos_features={self.pos_features-1}"
)
else:
pos = tensordict.get(self._full_position_key)
if not self.exclude_pos_from_node_features:
input.append(pos)
else:
pos = None

# Retrieve velocity
if self.velocity_key is not None:
vel = torch.cat(
[
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1] == self.velocity_key
],
dim=-1,
)
if self._full_velocity_key is None: # Run once to find full key
self._full_velocity_key = self._get_key_terminating_with(
list(tensordict.keys(True, True)), self.velocity_key
)
vel = tensordict.get(self._full_velocity_key)
if vel.shape[-1] != self.vel_features:
raise ValueError(
f"Velocity key in tensordict is {vel.shape[-1]}-dimensional, "
f"while model was configured with vel_features={self.vel_features}"
)
else:
vel = tensordict.get(self._full_velocity_key)
input.append(vel)
else:
vel = None

input = torch.cat(input, dim=-1)
batch_size = input.shape[:-2]

graph = _batch_from_dense_to_ptg(
Expand Down Expand Up @@ -338,6 +366,15 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict.set(self.out_key, res)
return tensordict

def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey:
    """Find the first key that ends with ``key`` and mentions this model's agent group.

    Each candidate is converted to a tuple with ``_unravel_key_to_tuple``; a
    candidate matches when the last element of that tuple equals ``key`` and
    ``self.agent_group`` is one of the tuple's elements.

    Args:
        keys: candidate keys to scan, in order.
        key: the terminal element the returned key must end with.

    Returns:
        The first matching entry of ``keys``, returned unchanged (not the tuple form).

    Raises:
        KeyError: if no candidate matches.
    """
    group = self.agent_group
    for candidate in keys:
        path = _unravel_key_to_tuple(candidate)
        # Check the terminal element first, then group membership.
        if path[-1] != key:
            continue
        if group in path:
            return candidate
    raise KeyError(
        f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}"
    )


def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str):
if topology == "full":
Expand Down Expand Up @@ -426,7 +463,9 @@ class GnnConfig(ModelConfig):
gnn_kwargs: Optional[dict] = None

position_key: Optional[str] = None
pos_features: Optional[int] = 0
velocity_key: Optional[str] = None
vel_features: Optional[int] = 0
exclude_pos_from_node_features: Optional[bool] = None
edge_radius: Optional[float] = None

Expand Down
4 changes: 3 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_gnn_edge_attrs(
shape=multi_agent_obs.shape[len(batch_size) :]
),
"pos": UnboundedContinuousTensorSpec(
shape=multi_agent_obs.shape[len(batch_size) :]
shape=multi_agent_pos.shape[len(batch_size) :]
),
},
shape=(n_agents,),
Expand Down Expand Up @@ -360,6 +360,7 @@ def test_gnn_edge_attrs(
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
pos_features=pos_size if position_key is not None else 0,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down Expand Up @@ -391,6 +392,7 @@ def test_gnn_edge_attrs(
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
pos_features=pos_size if position_key is not None else 0,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down

0 comments on commit aef8d40

Please sign in to comment.