diff --git a/benchmarl/conf/model/layers/gnn.yaml b/benchmarl/conf/model/layers/gnn.yaml index 6d6b1921..4e9776ce 100644 --- a/benchmarl/conf/model/layers/gnn.yaml +++ b/benchmarl/conf/model/layers/gnn.yaml @@ -8,7 +8,9 @@ gnn_kwargs: aggr: "add" position_key: null +pos_features: 0 velocity_key: null +vel_features: 0 exclude_pos_from_node_features: False edge_radius: null diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 12fa8859..48154e61 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -715,7 +715,6 @@ def _get_excluded_keys(self, group: str): for other_group in self.group_map.keys(): if other_group != group: excluded_keys += [other_group, ("next", other_group)] - excluded_keys += ["info", (group, "info"), ("next", group, "info")] return excluded_keys def _optimizer_loop(self, group: str) -> TensorDictBase: diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 22556ff1..5cb0f00e 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -11,11 +11,11 @@ import warnings from dataclasses import dataclass, MISSING from math import prod -from typing import Optional, Type +from typing import List, Optional, Type import torch from tensordict import TensorDictBase -from tensordict.utils import _unravel_key_to_tuple +from tensordict.utils import _unravel_key_to_tuple, NestedKey from torch import nn, Tensor from benchmarl.models.common import Model, ModelConfig @@ -59,16 +59,21 @@ class Gnn(Model): self_loops (str): Whether the resulting adjacency matrix will have self loops. gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class - position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec - representing the agent position. This key will not be processed as a node feature, but it will used to construct - edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a + position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env + (we suggest to use the "info" dict) representing the agent position. This key will be processed as a + node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features. + In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a one-dimensional distance for all neighbours in the graph. + pos_features (int, optional): Needed when position_key is specified. + It has to match to the last element of the shape the tensor under position_key. exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided, wether to use it just to compute edge features or also include it in node features. - velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec - representing the agent velocity. This key will not be processed as a node feature, but it will used to construct - edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours - in the graph. + velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env + (we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and + it will be used to construct edge features. In particular, it will be used to compute relative velocities + (``vel_node_1 - vel_node_2``) for all neighbours in the graph. + vel_features (int, optional): Needed when velocity_key is specified. + It has to match to the last element of the shape the tensor under velocity_key. edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph. Agents within this radius distance will be neighnours. @@ -120,6 +125,8 @@ def __init__( exclude_pos_from_node_features: Optional[bool], velocity_key: Optional[str], edge_radius: Optional[float], + pos_features: Optional[int], + vel_features: Optional[int], **kwargs, ): self.topology = topology @@ -128,34 +135,26 @@ def __init__( self.velocity_key = velocity_key self.exclude_pos_from_node_features = exclude_pos_from_node_features self.edge_radius = edge_radius + self.pos_features = pos_features + self.vel_features = vel_features super().__init__(**kwargs) - self.pos_features = sum( - [ - spec.shape[-1] - for key, spec in self.input_spec.items(True, True) - if _unravel_key_to_tuple(key)[-1] == position_key - ] - ) # Input keys ending with `position_key` if self.pos_features > 0: self.pos_features += 1 # We will add also 1-dimensional distance - self.vel_features = sum( - [ - spec.shape[-1] - for key, spec in self.input_spec.items(True, True) - if _unravel_key_to_tuple(key)[-1] == velocity_key - ] - ) # Input keys ending with `velocity_key` self.edge_features = self.pos_features + self.vel_features self.input_features = sum( [ spec.shape[-1] for key, spec in self.input_spec.items(True, True) - if _unravel_key_to_tuple(key)[-1] - not in ((position_key) if self.exclude_pos_from_node_features else ()) + if _unravel_key_to_tuple(key)[-1] not in (position_key, velocity_key) ] - ) # Input keys not ending with `velocity_key` and `position_key` + ) # Input keys + if self.position_key is not None and not self.exclude_pos_from_node_features: + self.input_features += self.pos_features - 1 + if self.velocity_key is not None: + self.input_features += self.vel_features + self.output_features = self.output_leaf_spec.shape[-1] if gnn_kwargs is None: @@ -191,6 +190,8 @@ def __init__( device=self.device, n_agents=self.n_agents, ) + self._full_position_key = None + self._full_velocity_key = None def _perform_checks(self): super()._perform_checks() @@ -208,6 +209,22 @@ def _perform_checks(self): raise ValueError( "exclude_pos_from_node_features needs to be specified when position_key is provided" ) + if self.position_key is not None and self.pos_features <= 0: + raise ValueError( + f"Position key specified but pos_features is {self.pos_features}" + ) + elif self.position_key is None and self.pos_features > 0: + raise ValueError( + f"If no position_key is given, pos_features needs to be 0, got: {self.pos_features}" + ) + if self.velocity_key is not None and self.vel_features <= 0: + raise ValueError( + f"Velocity key specified but vel_features is {self.vel_features}" + ) + elif self.velocity_key is None and self.vel_features > 0: + raise ValueError( + f"If no velocity_key is given, vel_features needs to be 0, got: {self.vel_features}" + ) if not self.input_has_agent_dim: raise ValueError( @@ -247,40 +264,51 @@ def _perform_checks(self): def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key - input = torch.cat( - [ - tensordict.get(in_key) - for in_key in self.in_keys - if _unravel_key_to_tuple(in_key)[-1] - not in ( - (self.position_key) if self.exclude_pos_from_node_features else () - ) - ], - dim=-1, - ) + input = [ + tensordict.get(in_key) + for in_key in self.in_keys + if _unravel_key_to_tuple(in_key)[-1] + not in (self.position_key, self.velocity_key) + ] + + # Retrieve position if self.position_key is not None: - pos = torch.cat( - [ - tensordict.get(in_key) - for in_key in self.in_keys - if _unravel_key_to_tuple(in_key)[-1] == self.position_key - ], - dim=-1, - ) + if self._full_position_key is None: # Run once to find full key + self._full_position_key = self._get_key_terminating_with( + list(tensordict.keys(True, True)), self.position_key + ) + pos = tensordict.get(self._full_position_key) + if pos.shape[-1] != self.pos_features - 1: + raise ValueError( + f"Position key in tensordict is {pos.shape[-1]}-dimensional, " + f"while model was configured with pos_features={self.pos_features-1}" + ) + else: + pos = tensordict.get(self._full_position_key) + if not self.exclude_pos_from_node_features: + input.append(pos) else: pos = None + + # Retrieve velocity if self.velocity_key is not None: - vel = torch.cat( - [ - tensordict.get(in_key) - for in_key in self.in_keys - if _unravel_key_to_tuple(in_key)[-1] == self.velocity_key - ], - dim=-1, - ) + if self._full_velocity_key is None: # Run once to find full key + self._full_velocity_key = self._get_key_terminating_with( + list(tensordict.keys(True, True)), self.velocity_key + ) + vel = tensordict.get(self._full_velocity_key) + if vel.shape[-1] != self.vel_features: + raise ValueError( + f"Velocity key in tensordict is {vel.shape[-1]}-dimensional, " + f"while model was configured with vel_features={self.vel_features}" + ) + else: + vel = tensordict.get(self._full_velocity_key) + input.append(vel) else: vel = None + input = torch.cat(input, dim=-1) batch_size = input.shape[:-2] graph = _batch_from_dense_to_ptg( @@ -338,6 +366,15 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(self.out_key, res) return tensordict + def _get_key_terminating_with(self, keys: List[NestedKey], key: str) -> NestedKey: + for k in keys: + k_tuple = _unravel_key_to_tuple(k) + if k_tuple[-1] == key and self.agent_group in k_tuple: + return k + raise KeyError( + f"Key terminating with {key} and containing {self.agent_group} not found in keys: {keys}" + ) + def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str): if topology == "full": @@ -426,7 +463,9 @@ class GnnConfig(ModelConfig): gnn_kwargs: Optional[dict] = None position_key: Optional[str] = None + pos_features: Optional[int] = 0 velocity_key: Optional[str] = None + vel_features: Optional[int] = 0 exclude_pos_from_node_features: Optional[bool] = None edge_radius: Optional[float] = None diff --git a/test/test_models.py b/test/test_models.py index 53dfae0c..48598f2d 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -331,7 +331,7 @@ def test_gnn_edge_attrs( shape=multi_agent_obs.shape[len(batch_size) :] ), "pos": UnboundedContinuousTensorSpec( - shape=multi_agent_obs.shape[len(batch_size) :] + shape=multi_agent_pos.shape[len(batch_size) :] ), }, shape=(n_agents,), @@ -360,6 +360,7 @@ def test_gnn_edge_attrs( gnn_kwargs=None, position_key=position_key, exclude_pos_from_node_features=False, + pos_features=pos_size if position_key is not None else 0, ).get_model( input_spec=input_spec, output_spec=output_spec, @@ -391,6 +392,7 @@ def test_gnn_edge_attrs( gnn_kwargs=None, position_key=position_key, exclude_pos_from_node_features=False, + pos_features=pos_size if position_key is not None else 0, ).get_model( input_spec=input_spec, output_spec=output_spec,