[Feature] Ensemble models and algorithms (different choices for different agent groups) #159

Draft · wants to merge 4 commits into main
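
This draft adds an EnsembleAlgorithm / EnsembleAlgorithmConfig pair and an EnsembleModelConfig so that each agent group in a task can use its own algorithm and its own model. Condensed from the example scripts included below (the "agent"/"adversary" keys are the group names of the VMAS simple_tag task used in those examples):

from benchmarl.algorithms import EnsembleAlgorithmConfig, IppoConfig, MappoConfig
from benchmarl.models import EnsembleModelConfig, GnnConfig, MlpConfig

# One algorithm config and one model config per agent group.
algorithm_config = EnsembleAlgorithmConfig(
    {"agent": MappoConfig.get_from_yaml(), "adversary": IppoConfig.get_from_yaml()}
)
model_config = EnsembleModelConfig(
    {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()}
)
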
1 change: 1 addition & 0 deletions benchmarl/algorithms/__init__.py
@@ -5,6 +5,7 @@
#

from .common import Algorithm, AlgorithmConfig
from .ensemble import EnsembleAlgorithm, EnsembleAlgorithmConfig
from .iddpg import Iddpg, IddpgConfig
from .ippo import Ippo, IppoConfig
from .iql import Iql, IqlConfig
128 changes: 128 additions & 0 deletions benchmarl/algorithms/ensemble.py
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule

from torchrl.objectives import LossModule

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig

from benchmarl.models.common import ModelConfig


class EnsembleAlgorithm(Algorithm):
    def __init__(self, algorithms_map, **kwargs):
        super().__init__(**kwargs)
        self.algorithms_map = algorithms_map

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        return self.algorithms_map[group]._get_loss(group, policy_for_loss, continuous)

    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
        return self.algorithms_map[group]._get_parameters(group, loss)

    def _get_policy_for_loss(
        self, group: str, model_config: ModelConfig, continuous: bool
    ) -> TensorDictModule:
        return self.algorithms_map[group]._get_policy_for_loss(
            group, model_config, continuous
        )

    def _get_policy_for_collection(
        self, policy_for_loss: TensorDictModule, group: str, continuous: bool
    ) -> TensorDictModule:
        return self.algorithms_map[group]._get_policy_for_collection(
            policy_for_loss, group, continuous
        )

    def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        return self.algorithms_map[group].process_batch(group, batch)

    def process_loss_vals(
        self, group: str, loss_vals: TensorDictBase
    ) -> TensorDictBase:
        return self.algorithms_map[group].process_loss_vals(group, loss_vals)


@dataclass
class EnsembleAlgorithmConfig(AlgorithmConfig):

    algorithm_configs_map: Dict[str, AlgorithmConfig]

    def __post_init__(self):
        algorithm_configs = list(self.algorithm_configs_map.values())
        self._on_policy = algorithm_configs[0].on_policy()

        for algorithm_config in algorithm_configs[1:]:
            if algorithm_config.on_policy() != self._on_policy:
                raise ValueError(
                    "Algorithms in EnsembleAlgorithmConfig must either be all on_policy or all off_policy"
                )

        if (
            not self.supports_discrete_actions()
            and not self.supports_continuous_actions()
        ):
            raise ValueError(
                "Ensemble algorithm does not support discrete actions nor continuous actions."
                " Make sure that at least one type of action is supported across all the algorithms used."
            )

    def get_algorithm(self, experiment) -> Algorithm:
        return self.associated_class()(
            algorithms_map={
                group: algorithm_config.get_algorithm(experiment)
                for group, algorithm_config in self.algorithm_configs_map.items()
            },
            experiment=experiment,
        )

    @classmethod
    def get_from_yaml(cls, path: Optional[str] = None):
        raise NotImplementedError

    @staticmethod
    def associated_class() -> Type[Algorithm]:
        return EnsembleAlgorithm

    def on_policy(self) -> bool:
        return self._on_policy

    def supports_continuous_actions(self) -> bool:
        supports_continuous_actions = True
        for algorithm_config in self.algorithm_configs_map.values():
            supports_continuous_actions *= (
                algorithm_config.supports_continuous_actions()
            )
        return supports_continuous_actions

    def supports_discrete_actions(self) -> bool:
        supports_discrete_actions = True
        for algorithm_config in self.algorithm_configs_map.values():
            supports_discrete_actions *= algorithm_config.supports_discrete_actions()
        return supports_discrete_actions

    def has_independent_critic(self) -> bool:
        has_independent_critic = False
        for algorithm_config in self.algorithm_configs_map.values():
            has_independent_critic += algorithm_config.has_independent_critic()
        return has_independent_critic

    def has_centralized_critic(self) -> bool:
        has_centralized_critic = False
        for algorithm_config in self.algorithm_configs_map.values():
            has_centralized_critic += algorithm_config.has_centralized_critic()
        return has_centralized_critic

    def has_critic(self) -> bool:
        return self.has_centralized_critic() or self.has_independent_critic()
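
EnsembleAlgorithmConfig aggregates the capabilities of its member configs: all members must agree on on-policy vs off-policy (enforced in __post_init__), action-space support is the intersection across members, and the critic flags are the union. A minimal sketch of what that means for a caller (MADDPG and ISAC are used here only because both are off-policy; the comments restate the aggregation rules implemented above):

from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig

config = EnsembleAlgorithmConfig(
    {"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()}
)
config.on_policy()                    # shared by all members; mixing raises in __post_init__
config.supports_continuous_actions()  # True only if every member supports continuous actions
config.supports_discrete_actions()    # True only if every member supports discrete actions
config.has_critic()                   # True if any member has an independent or centralized critic
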
8 changes: 7 additions & 1 deletion benchmarl/models/__init__.py
@@ -5,7 +5,13 @@
#

from .cnn import Cnn, CnnConfig
-from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
+from .common import (
+    EnsembleModelConfig,
+    Model,
+    ModelConfig,
+    SequenceModel,
+    SequenceModelConfig,
+)
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .gru import Gru, GruConfig
64 changes: 64 additions & 0 deletions benchmarl/models/common.py
@@ -349,6 +349,11 @@ def get_model_state_spec(self, model_index: int = 0) -> Composite:
"""
return Composite()

def _get_model_state_spec_inner(
self, model_index: int = 0, group: str = None
) -> Composite:
return self.get_model_state_spec(model_index)

@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = (
@@ -421,6 +426,13 @@ class SequenceModelConfig(ModelConfig):
    model_configs: Sequence[ModelConfig]
    intermediate_sizes: Sequence[int]

    def __post_init__(self):
        for model_config in self.model_configs:
            if isinstance(model_config, EnsembleModelConfig):
                raise TypeError(
                    "SequenceModelConfig cannot contain EnsembleModelConfig layers, but the opposite can be done."
                )

    def get_model(
        self,
        input_spec: Composite,
@@ -522,3 +534,55 @@ def is_rnn(self) -> bool:
    @classmethod
    def get_from_yaml(cls, path: Optional[str] = None):
        raise NotImplementedError


@dataclass
class EnsembleModelConfig(ModelConfig):

    model_configs_map: Dict[str, ModelConfig]

    def get_model(self, agent_group: str, **kwargs) -> Model:
        if agent_group not in self.model_configs_map.keys():
            raise ValueError(
                f"Environment contains agent group '{agent_group}' not present in the EnsembleModelConfig configuration."
            )
        return self.model_configs_map[agent_group].get_model(
            **kwargs, agent_group=agent_group
        )

    @staticmethod
    def associated_class():
        class EnsembleModel(Model):
            pass

        return EnsembleModel

    @property
    def is_critic(self):
        if not hasattr(self, "_is_critic"):
            self._is_critic = False
        return self._is_critic

    @is_critic.setter
    def is_critic(self, value):
        self._is_critic = value
        for model_config in self.model_configs_map.values():
            model_config.is_critic = value

    def _get_model_state_spec_inner(
        self, model_index: int = 0, group: str = None
    ) -> Composite:
        return self.model_configs_map[group].get_model_state_spec(
            model_index=model_index
        )

    @property
    def is_rnn(self) -> bool:
        is_rnn = False
        for model_config in self.model_configs_map.values():
            is_rnn += model_config.is_rnn
        return is_rnn

    @classmethod
    def get_from_yaml(cls, path: Optional[str] = None):
        raise NotImplementedError
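
SequenceModelConfig now rejects ensemble layers, but the reverse nesting is allowed, so a per-group sequence of models stays possible by placing it inside the ensemble. A minimal sketch, assuming the simple_tag group names used in the examples below:

from benchmarl.models import (
    EnsembleModelConfig,
    GnnConfig,
    MlpConfig,
    SequenceModelConfig,
)

model_config = EnsembleModelConfig(
    {
        # A GNN followed by an MLP for one group, a plain MLP for the other.
        "agent": SequenceModelConfig(
            model_configs=[GnnConfig.get_from_yaml(), MlpConfig.get_from_yaml()],
            intermediate_sizes=[64],
        ),
        "adversary": MlpConfig.get_from_yaml(),
    }
)
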
6 changes: 4 additions & 2 deletions benchmarl/utils.py
@@ -80,11 +80,13 @@ def _add_rnn_transforms(

    def model_fun():
        env = env_fun()
-       spec_actor = model_config.get_model_state_spec()
        spec_actor = Composite(
            {
                group: Composite(
-                   spec_actor.expand(len(agents), *spec_actor.shape),
+                   model_config._get_model_state_spec_inner(group=group).expand(
+                       len(agents),
+                       *model_config._get_model_state_spec_inner(group=group).shape
+                   ),
                    shape=(len(agents),),
                )
                for group, agents in group_map.items()
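
The utils.py change routes recurrent state specs through the new _get_model_state_spec_inner hook, so each group can carry a different (possibly empty) model state. A minimal sketch of the per-group dispatch, assuming a recurrent GRU model for one group only:

from benchmarl.models import EnsembleModelConfig, GruConfig, MlpConfig

model_config = EnsembleModelConfig(
    {"agent": GruConfig.get_from_yaml(), "adversary": MlpConfig.get_from_yaml()}
)

# The ensemble delegates to the member config of the requested group, so
# "agent" gets the GRU hidden-state spec and "adversary" gets an empty Composite.
agent_spec = model_config._get_model_state_spec_inner(group="agent")
adversary_spec = model_config._get_model_state_spec_inner(group="adversary")
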
38 changes: 38 additions & 0 deletions examples/ensemble/ensemble_algorithm.py
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig


if __name__ == "__main__":

    # Loads from "benchmarl/conf/experiment/base_experiment.yaml"
    experiment_config = ExperimentConfig.get_from_yaml()

    # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
    task = VmasTask.SIMPLE_TAG.get_from_yaml()

    # Loads from "benchmarl/conf/model/layers/mlp.yaml"
    model_config = MlpConfig.get_from_yaml()
    critic_model_config = MlpConfig.get_from_yaml()

    algorithm_config = EnsembleAlgorithmConfig(
        {"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()}
    )

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()
44 changes: 44 additions & 0 deletions examples/ensemble/ensemble_algorithm_and_model.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from benchmarl.algorithms import EnsembleAlgorithmConfig, IppoConfig, MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import DeepsetsConfig, EnsembleModelConfig, GnnConfig, MlpConfig

if __name__ == "__main__":

    # Loads from "benchmarl/conf/experiment/base_experiment.yaml"
    experiment_config = ExperimentConfig.get_from_yaml()

    # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
    task = VmasTask.SIMPLE_TAG.get_from_yaml()

    algorithm_config = EnsembleAlgorithmConfig(
        {"agent": MappoConfig.get_from_yaml(), "adversary": IppoConfig.get_from_yaml()}
    )

    model_config = EnsembleModelConfig(
        {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()}
    )
    critic_model_config = EnsembleModelConfig(
        {
            "agent": DeepsetsConfig.get_from_yaml(),
            "adversary": MlpConfig.get_from_yaml(),
        }
    )

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()
40 changes: 40 additions & 0 deletions examples/ensemble/ensemble_model.py
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import EnsembleModelConfig, GnnConfig, MlpConfig


if __name__ == "__main__":

    # Loads from "benchmarl/conf/experiment/base_experiment.yaml"
    experiment_config = ExperimentConfig.get_from_yaml()

    # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
    task = VmasTask.SIMPLE_TAG.get_from_yaml()

    # Loads from "benchmarl/conf/algorithm/mappo.yaml"
    algorithm_config = MappoConfig.get_from_yaml()

    # Loads from "benchmarl/conf/model/layers/mlp.yaml"
    critic_model_config = MlpConfig.get_from_yaml()

    model_config = EnsembleModelConfig(
        {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()}
    )

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()