From 1d49049a67f14b79b422dd9521f32be864f419eb Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Tue, 1 Sep 2020 15:49:49 +0530 Subject: [PATCH 01/16] Single actor critic shared params --- genrl/agents/deep/a2c/a2c.py | 30 ++++++++-- genrl/agents/deep/base/base.py | 2 + genrl/agents/deep/base/offpolicy.py | 3 +- genrl/agents/deep/ddpg/ddpg.py | 26 ++++++++- genrl/agents/deep/ppo1/ppo1.py | 29 +++++++++- genrl/core/actor_critic.py | 70 +++++++++++++++++++++++- tests/test_deep/test_agents/test_a2c.py | 8 +++ tests/test_deep/test_agents/test_ddpg.py | 26 +++++++++ tests/test_deep/test_agents/test_ppo1.py | 8 +++ 9 files changed, 188 insertions(+), 14 deletions(-) diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py index 595b9f14..e990f531 100644 --- a/genrl/agents/deep/a2c/a2c.py +++ b/genrl/agents/deep/a2c/a2c.py @@ -66,7 +66,24 @@ def _create_model(self) -> None: state_dim, action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) - if isinstance(self.network, str): + if isinstance(self.network, str) and self.shared_layers is not None: + self.ac = get_model("ac", self.network + "s")( + state_dim, + action_dim, + shared_layers=self.shared_layers, + policy_layers=self.policy_layers, + value_layers=self.value_layers, + val_type="V", + discrete=discrete, + action_lim=action_lim, + ).to(self.device) + actor_params = list(self.ac.shared.parameters()) + list( + self.ac.actor.parameters() + ) + critic_params = list(self.ac.shared.parameters()) + list( + self.ac.critic.parameters() + ) + elif isinstance(self.network, str) and self.shared_layers is None: self.ac = get_model("ac", self.network)( state_dim, action_dim, @@ -76,18 +93,21 @@ def _create_model(self) -> None: discrete=discrete, action_lim=action_lim, ).to(self.device) + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() + else: self.ac = self.network.to(self.device) - - # action_dim = self.network.action_dim + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() if self.noise is not None: self.noise = self.noise( np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim) ) - self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy) - self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) def select_action( self, state: np.ndarray, deterministic: bool = False diff --git a/genrl/agents/deep/base/base.py b/genrl/agents/deep/base/base.py index f37907a9..c2067c91 100644 --- a/genrl/agents/deep/base/base.py +++ b/genrl/agents/deep/base/base.py @@ -34,6 +34,7 @@ def __init__( create_model: bool = True, batch_size: int = 64, gamma: float = 0.99, + shared_layers=None, policy_layers: Tuple = (64, 64), value_layers: Tuple = (64, 64), lr_policy: float = 0.0001, @@ -45,6 +46,7 @@ def __init__( self.create_model = create_model self.batch_size = batch_size self.gamma = gamma + self.shared_layers = shared_layers self.policy_layers = policy_layers self.value_layers = value_layers self.lr_policy = lr_policy diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index 916a60ec..f64f3dff 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -174,7 +174,7 @@ def select_action( # add noise to output from policy network if self.noise is not None: - action += self.noise() + action = action + 
self.noise() return np.clip( action, self.env.action_space.low[0], self.env.action_space.high[0] @@ -233,7 +233,6 @@ def get_target_q_values( next_q_target_values = self.ac_target.get_value( torch.cat([next_states, next_target_actions], dim=-1) ) - target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values return target_q_values diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index 24f004b6..9adfc27a 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -62,7 +62,23 @@ def _create_model(self) -> None: np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim) ) - if isinstance(self.network, str): + if isinstance(self.network, str) and self.shared_layers is not None: + self.ac = get_model("ac", self.network + "s")( + state_dim, + action_dim, + self.shared_layers, + self.policy_layers, + self.value_layers, + "Qsa", + False, + ).to(self.device) + actor_params = list(self.ac.actor.parameters()) + list( + self.ac.shared.parameters() + ) + critic_params = list(self.ac.critic.parameters()) + list( + self.ac.shared.parameters() + ) + elif isinstance(self.network, str) and self.shared_layers is None: self.ac = get_model("ac", self.network)( state_dim, action_dim, @@ -71,13 +87,17 @@ def _create_model(self) -> None: "Qsa", False, ).to(self.device) + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() else: self.ac = self.network + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() self.ac_target = deepcopy(self.ac).to(self.device) - self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy) - self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) def update_params(self, update_interval: int) -> None: """Update parameters of the model diff --git a/genrl/agents/deep/ppo1/ppo1.py b/genrl/agents/deep/ppo1/ppo1.py index 0987d078..456aa7d1 100644 --- a/genrl/agents/deep/ppo1/ppo1.py +++ b/genrl/agents/deep/ppo1/ppo1.py @@ -66,10 +66,29 @@ def _create_model(self): state_dim, action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) - if isinstance(self.network, str): + if isinstance(self.network, str) and self.shared_layers is not None: + self.ac = get_model("ac", self.network + "s")( + state_dim, + action_dim, + shared_layers=self.shared_layers, + policy_layers=self.policy_layers, + value_layers=self.value_layers, + val_typ="V", + discrete=discrete, + action_lim=action_lim, + activation=self.activation, + ).to(self.device) + actor_params = list(self.ac.shared.parameters()) + list( + self.ac.actor.parameters() + ) + critic_params = list(self.ac.shared.parameters()) + list( + self.ac.critic.parameters() + ) + elif isinstance(self.network, str) and self.shared_layers is None: self.ac = get_model("ac", self.network)( state_dim, action_dim, + shared_layers=self.shared_layers, policy_layers=self.policy_layers, value_layers=self.value_layers, val_typ="V", @@ -77,11 +96,15 @@ def _create_model(self): action_lim=action_lim, activation=self.activation, ).to(self.device) + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() else: self.ac = self.network.to(self.device) + actor_params = self.ac.actor.parameters() + critic_params = self.ac.critic.parameters() - self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), 
lr=self.lr_policy) - self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) def select_action( self, state: np.ndarray, deterministic: bool = False diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 6214ec46..1ce61f72 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -9,7 +9,7 @@ from genrl.core.base import BaseActorCritic from genrl.core.policies import MlpPolicy from genrl.core.values import MlpValue -from genrl.utils.utils import cnn +from genrl.utils.utils import cnn, mlp class MlpActorCritic(BaseActorCritic): @@ -41,6 +41,73 @@ def __init__( self.critic = MlpValue(state_dim, action_dim, val_type, value_layers, **kwargs) +class MlpSharedActorCritic(BaseActorCritic): + """MLP Shared Actor Critic + + Attributes: + state_dim (int): State dimensions of the environment + action_dim (int): Action space dimensions of the environment + hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + val_type (str): Value type of the critic network + discrete (bool): True if the action space is discrete, else False + sac (bool): True if a SAC-like network is needed, else False + activation (str): Activation function to be used. Can be either "tanh" or "relu" + """ + + def __init__( + self, + state_dim: spaces.Space, + action_dim: spaces.Space, + shared_layers: Tuple = (32, 32), + policy_layers: Tuple = (32, 32), + value_layers: Tuple = (32, 32), + val_type: str = "V", + discrete: bool = True, + **kwargs, + ): + super(MlpSharedActorCritic, self).__init__() + self.shared = mlp([state_dim] + list(shared_layers)) + self.actor = MlpPolicy( + shared_layers[-1], action_dim, policy_layers, discrete, **kwargs + ) + self.critic = MlpValue( + shared_layers[-1], action_dim, val_type, value_layers, **kwargs + ) + self.state_dim = state_dim + self.action_dim = action_dim + + def get_features(self, state: torch.Tensor): + features = self.shared(state) + return features + + def get_action(self, state: torch.Tensor, deterministic: bool = False): + state = torch.as_tensor(state).float() + features = self.get_features(state) + action_probs = self.actor(features) + action_probs = nn.Softmax(dim=-1)(action_probs) + + if deterministic: + action = torch.argmax(action_probs, dim=-1).unsqueeze(-1).float() + distribution = None + else: + distribution = Categorical(probs=action_probs) + action = distribution.sample() + + return action, distribution + + def get_value(self, state: torch.Tensor): + state = torch.as_tensor(state).float() + if self.critic.val_type == "Qsa": + features = self.shared(state[:, :, :-1]) + features = torch.cat([features, state[:, :, -1].unsqueeze(-1)], dim=-1) + print(f"features {features.shape}") + value = self.critic(features).float().squeeze(-1) + else: + features = self.shared(state) + value = self.critic(features) + return value + + class MlpSingleActorMultiCritic(BaseActorCritic): """MLP Actor Critic @@ -220,6 +287,7 @@ def get_value(self, inp: torch.Tensor) -> torch.Tensor: "mlp": MlpActorCritic, "cnn": CNNActorCritic, "mlp12": MlpSingleActorMultiCritic, + "mlps": MlpSharedActorCritic, } diff --git a/tests/test_deep/test_agents/test_a2c.py b/tests/test_deep/test_agents/test_a2c.py index 2b012069..f731f40f 100644 --- a/tests/test_deep/test_agents/test_a2c.py +++ b/tests/test_deep/test_agents/test_a2c.py @@ -19,3 +19,11 @@ def test_a2c_cnn(): trainer = OnPolicyTrainer(algo, env, 
log_mode=["csv"], logdir="./logs", epochs=1) trainer.train() shutil.rmtree("./logs") + + +def test_a2c_shared(): + env = VectorEnv("CartPole-v0", 1) + algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128) + trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) + trainer.train() + shutil.rmtree("./logs") diff --git a/tests/test_deep/test_agents/test_ddpg.py b/tests/test_deep/test_agents/test_ddpg.py index ab309518..94670cef 100644 --- a/tests/test_deep/test_agents/test_ddpg.py +++ b/tests/test_deep/test_agents/test_ddpg.py @@ -29,3 +29,29 @@ def test_ddpg(): ) trainer.train() shutil.rmtree("./logs") + + +def test_ddpg_shared(): + env = VectorEnv("Pendulum-v0", 2) + algo = DDPG( + "mlp", + env, + batch_size=5, + noise=NormalActionNoise, + shared_layers=[1, 1], + policy_layers=[1, 1], + value_layers=[1, 1], + ) + + trainer = OffPolicyTrainer( + algo, + env, + log_mode=["csv"], + logdir="./logs", + epochs=4, + max_ep_len=200, + warmup_steps=10, + start_update=10, + ) + trainer.train() + shutil.rmtree("./logs") diff --git a/tests/test_deep/test_agents/test_ppo1.py b/tests/test_deep/test_agents/test_ppo1.py index 3e9feaf2..1bb06a22 100644 --- a/tests/test_deep/test_agents/test_ppo1.py +++ b/tests/test_deep/test_agents/test_ppo1.py @@ -19,3 +19,11 @@ def test_ppo1_cnn(): trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) trainer.train() shutil.rmtree("./logs") + + +def test_ppo1_shared(): + env = VectorEnv("CartPole-v0") + algo = PPO1("mlp", env, shared_layers=(32, 32), rollout_size=128) + trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) + trainer.train() + shutil.rmtree("./logs") From ef4a179a321ed5a4f306067a77712948d7b2e93b Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Wed, 2 Sep 2020 01:17:47 +0530 Subject: [PATCH 02/16] Shared layers for multi ACs --- genrl/agents/deep/a2c/a2c.py | 32 ++--- genrl/agents/deep/base/offpolicy.py | 2 +- genrl/agents/deep/ddpg/ddpg.py | 27 +--- genrl/agents/deep/ppo1/ppo1.py | 30 +--- genrl/agents/deep/sac/sac.py | 16 +-- genrl/agents/deep/td3/td3.py | 18 ++- genrl/core/actor_critic.py | 160 ++++++++++++++++++++- tests/test_deep/test_agents/test_custom.py | 7 + tests/test_deep/test_agents/test_ppo1.py | 8 -- tests/test_deep/test_agents/test_sac.py | 25 ++++ tests/test_deep/test_agents/test_td3.py | 26 ++++ 11 files changed, 253 insertions(+), 98 deletions(-) diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py index e990f531..86cd84ad 100644 --- a/genrl/agents/deep/a2c/a2c.py +++ b/genrl/agents/deep/a2c/a2c.py @@ -66,8 +66,11 @@ def _create_model(self) -> None: state_dim, action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) - if isinstance(self.network, str) and self.shared_layers is not None: - self.ac = get_model("ac", self.network + "s")( + if isinstance(self.network, str): + arch_type = self.network + if self.shared_layers is not None: + arch_type += "s" + self.ac = get_model("ac", arch_type)( state_dim, action_dim, shared_layers=self.shared_layers, @@ -77,37 +80,18 @@ def _create_model(self) -> None: discrete=discrete, action_lim=action_lim, ).to(self.device) - actor_params = list(self.ac.shared.parameters()) + list( - self.ac.actor.parameters() - ) - critic_params = list(self.ac.shared.parameters()) + list( - self.ac.critic.parameters() - ) - elif isinstance(self.network, str) and self.shared_layers is None: - self.ac = get_model("ac", self.network)( - state_dim, - action_dim, - 
policy_layers=self.policy_layers, - value_layers=self.value_layers, - val_type="V", - discrete=discrete, - action_lim=action_lim, - ).to(self.device) - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() else: self.ac = self.network.to(self.device) - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() if self.noise is not None: self.noise = self.noise( np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim) ) - self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) - self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) + actor_params, critic_params = self.ac.get_params() + self.optimizer_policy = opt.Adam(critic_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(actor_params, lr=self.lr_value) def select_action( self, state: np.ndarray, deterministic: bool = False diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index f64f3dff..656d7911 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -174,7 +174,7 @@ def select_action( # add noise to output from policy network if self.noise is not None: - action = action + self.noise() + action += self.noise() return np.clip( action, self.env.action_space.low[0], self.env.action_space.high[0] diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index 9adfc27a..0d09314b 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -62,8 +62,11 @@ def _create_model(self) -> None: np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim) ) - if isinstance(self.network, str) and self.shared_layers is not None: - self.ac = get_model("ac", self.network + "s")( + if isinstance(self.network, str): + arch_type = self.network + if self.shared_layers is not None: + arch_type += "s" + self.ac = get_model("ac", arch_type)( state_dim, action_dim, self.shared_layers, @@ -72,28 +75,10 @@ def _create_model(self) -> None: "Qsa", False, ).to(self.device) - actor_params = list(self.ac.actor.parameters()) + list( - self.ac.shared.parameters() - ) - critic_params = list(self.ac.critic.parameters()) + list( - self.ac.shared.parameters() - ) - elif isinstance(self.network, str) and self.shared_layers is None: - self.ac = get_model("ac", self.network)( - state_dim, - action_dim, - self.policy_layers, - self.value_layers, - "Qsa", - False, - ).to(self.device) - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() else: self.ac = self.network - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() + actor_params, critic_params = self.ac.get_params() self.ac_target = deepcopy(self.ac).to(self.device) self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) diff --git a/genrl/agents/deep/ppo1/ppo1.py b/genrl/agents/deep/ppo1/ppo1.py index 456aa7d1..7359e621 100644 --- a/genrl/agents/deep/ppo1/ppo1.py +++ b/genrl/agents/deep/ppo1/ppo1.py @@ -66,8 +66,11 @@ def _create_model(self): state_dim, action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) - if isinstance(self.network, str) and self.shared_layers is not None: - self.ac = get_model("ac", self.network + "s")( + if isinstance(self.network, str): + arch = self.network + if self.shared_layers is not None: + arch += "s" + self.ac = get_model("ac", arch)( state_dim, action_dim, shared_layers=self.shared_layers, @@ -78,31 +81,10 @@ def _create_model(self): action_lim=action_lim, 
activation=self.activation, ).to(self.device) - actor_params = list(self.ac.shared.parameters()) + list( - self.ac.actor.parameters() - ) - critic_params = list(self.ac.shared.parameters()) + list( - self.ac.critic.parameters() - ) - elif isinstance(self.network, str) and self.shared_layers is None: - self.ac = get_model("ac", self.network)( - state_dim, - action_dim, - shared_layers=self.shared_layers, - policy_layers=self.policy_layers, - value_layers=self.value_layers, - val_typ="V", - discrete=discrete, - action_lim=action_lim, - activation=self.activation, - ).to(self.device) - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() else: self.ac = self.network.to(self.device) - actor_params = self.ac.actor.parameters() - critic_params = self.ac.critic.parameters() + actor_params, critic_params = self.ac.get_params() self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) diff --git a/genrl/agents/deep/sac/sac.py b/genrl/agents/deep/sac/sac.py index b7a5572d..54c7f87b 100644 --- a/genrl/agents/deep/sac/sac.py +++ b/genrl/agents/deep/sac/sac.py @@ -76,8 +76,10 @@ def _create_model(self, **kwargs) -> None: state_dim, action_dim, discrete, _ = get_env_properties( self.env, self.network ) - - self.ac = get_model("ac", self.network + "12")( + arch = self.network + "12" + if self.shared_layers is not None: + arch += "s" + self.ac = get_model("ac", arch)( state_dim, action_dim, policy_layers=self.policy_layers, @@ -92,13 +94,9 @@ def _create_model(self, **kwargs) -> None: self.model = self.network self.ac_target = deepcopy(self.ac) - - self.critic_params = list(self.ac.critic1.parameters()) + list( - self.ac.critic2.parameters() - ) - - self.optimizer_value = opt.Adam(self.critic_params, self.lr_value) - self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), self.lr_policy) + actor_params, critic_params = self.ac.get_params() + self.optimizer_value = opt.Adam(critic_params, self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, self.lr_policy) if self.entropy_tuning: self.target_entropy = -torch.prod( diff --git a/genrl/agents/deep/td3/td3.py b/genrl/agents/deep/td3/td3.py index a9687446..5a8e83d2 100644 --- a/genrl/agents/deep/td3/td3.py +++ b/genrl/agents/deep/td3/td3.py @@ -68,10 +68,13 @@ def _create_model(self) -> None: ) if isinstance(self.network, str): - # Below, the "12" corresponds to the Single Actor, Double Critic network architecture - self.ac = get_model("ac", self.network + "12")( + arch = self.network + "12" + if self.shared_layers is not None: + arch += "s" + self.ac = get_model("ac", arch)( state_dim, action_dim, + shared_layers=self.shared_layers, policy_layers=self.policy_layers, value_layers=self.value_layers, val_type="Qsa", @@ -86,14 +89,9 @@ def _create_model(self) -> None: ) self.ac_target = deepcopy(self.ac) - - self.critic_params = list(self.ac.critic1.parameters()) + list( - self.ac.critic2.parameters() - ) - self.optimizer_value = torch.optim.Adam(self.critic_params, lr=self.lr_value) - self.optimizer_policy = torch.optim.Adam( - self.ac.actor.parameters(), lr=self.lr_policy - ) + actor_params, critic_params = self.ac.get_params() + self.optimizer_value = torch.optim.Adam(critic_params, lr=self.lr_value) + self.optimizer_policy = torch.optim.Adam(actor_params, lr=self.lr_policy) def update_params(self, update_interval: int) -> None: """Update parameters of the model diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 
1ce61f72..1b135ebd 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -29,6 +29,7 @@ def __init__( self, state_dim: spaces.Space, action_dim: spaces.Space, + shared_layers: None, policy_layers: Tuple = (32, 32), value_layers: Tuple = (32, 32), val_type: str = "V", @@ -40,6 +41,11 @@ def __init__( self.actor = MlpPolicy(state_dim, action_dim, policy_layers, discrete, **kwargs) self.critic = MlpValue(state_dim, action_dim, val_type, value_layers, **kwargs) + def get_params(self): + actor_params = self.actor.parameters() + critic_params = self.critic.parameters() + return actor_params, critic_params + class MlpSharedActorCritic(BaseActorCritic): """MLP Shared Actor Critic @@ -76,7 +82,20 @@ def __init__( self.state_dim = state_dim self.action_dim = action_dim + def get_params(self): + actor_params = list(self.shared.parameters()) + list(self.actor.parameters()) + critic_params = list(self.shared.parameters()) + list(self.critic.parameters()) + return actor_params, critic_params + def get_features(self, state: torch.Tensor): + """Extract features from the state, which is then an input to get_action and get_value + + Args: + state (:obj:`torch.Tensor`): The state(s) being passed + + Returns: + features (:obj:`torch.Tensor`): The feature(s) extracted from the state + """ features = self.shared(state) return features @@ -100,7 +119,6 @@ def get_value(self, state: torch.Tensor): if self.critic.val_type == "Qsa": features = self.shared(state[:, :, :-1]) features = torch.cat([features, state[:, :, -1].unsqueeze(-1)], dim=-1) - print(f"features {features.shape}") value = self.critic(features).float().squeeze(-1) else: features = self.shared(state) @@ -144,6 +162,137 @@ def __init__( self.action_scale = kwargs["action_scale"] if "action_scale" in kwargs else 1 self.action_bias = kwargs["action_bias"] if "action_bias" in kwargs else 0 + def get_params(self): + actor_params = self.actor.parameters() + critic_params = list(self.critic1.parameters()) + list( + self.critic2.parameters() + ) + return actor_params, critic_params + + def forward(self, x): + q1_values = self.critic1(x).squeeze(-1) + q2_values = self.critic2(x).squeeze(-1) + return (q1_values, q2_values) + + def get_action(self, state: torch.Tensor, deterministic: bool = False): + state = torch.as_tensor(state).float() + + if self.actor.sac: + mean, log_std = self.actor(state) + std = log_std.exp() + distribution = Normal(mean, std) + + action_probs = distribution.rsample() + log_probs = distribution.log_prob(action_probs) + action_probs = torch.tanh(action_probs) + + action = action_probs * self.action_scale + self.action_bias + + # enforcing action bound (appendix of SAC paper) + log_probs -= torch.log( + self.action_scale * (1 - action_probs.pow(2)) + np.finfo(np.float32).eps + ) + log_probs = log_probs.sum(1, keepdim=True) + mean = torch.tanh(mean) * self.action_scale + self.action_bias + + action = (action.float(), log_probs, mean) + else: + action = self.actor.get_action(state, deterministic=deterministic) + + return action + + def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: + """Get Values from the Critic + + Arg: + state (:obj:`torch.Tensor`): The state(s) being passed to the critics + mode (str): What values should be returned. 
Types: + "both" --> Both values will be returned + "min" --> The minimum of both values will be returned + "first" --> The value from the first critic only will be returned + + Returns: + values (:obj:`list`): List of values as estimated by each individual critic + """ + state = torch.as_tensor(state).float() + + if mode == "both": + values = self.forward(state) + elif mode == "min": + values = self.forward(state) + values = torch.min(*values).squeeze(-1) + elif mode == "first": + values = self.critic1(state) + else: + raise KeyError("Mode doesn't exist") + + return values + + +class MlpSharedSingleActorMultiCritic(BaseActorCritic): + """MLP Actor Critic + + Attributes: + state_dim (int): State dimensions of the environment + action_dim (int): Action space dimensions of the environment + hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + val_type (str): Value type of the critic network + discrete (bool): True if the action space is discrete, else False + num_critics (int): Number of critics in the architecture + sac (bool): True if a SAC-like network is needed, else False + activation (str): Activation function to be used. Can be either "tanh" or "relu" + """ + + def __init__( + self, + state_dim: spaces.Space, + action_dim: spaces.Space, + shared_layers: Tuple = (32, 32), + policy_layers: Tuple = (32, 32), + value_layers: Tuple = (32, 32), + val_type: str = "V", + discrete: bool = True, + num_critics: int = 2, + **kwargs, + ): + super(MlpSharedSingleActorMultiCritic, self).__init__() + + self.num_critics = num_critics + self.shared = mlp([state_dim] + list(shared_layers)) + self.actor = MlpPolicy( + shared_layers[-1], action_dim, policy_layers, discrete, **kwargs + ) + self.critic1 = MlpValue( + shared_layers[-1], action_dim, "Qsa", value_layers, **kwargs + ) + self.critic2 = MlpValue( + shared_layers[-1], action_dim, "Qsa", value_layers, **kwargs + ) + + self.action_scale = kwargs["action_scale"] if "action_scale" in kwargs else 1 + self.action_bias = kwargs["action_bias"] if "action_bias" in kwargs else 0 + + def get_params(self): + actor_params = list(self.actor.parameters()) + list(self.shared.parameters()) + critic_params = ( + list(self.critic1.parameters()) + + list(self.critic2.parameters()) + + list(self.shared.parameters()) + ) + return actor_params, critic_params + + def get_features(self, state: torch.Tensor): + """Extract features from the state, which is then an input to get_action and get_value + + Args: + state (:obj:`torch.Tensor`): The state(s) being passed + + Returns: + features (:obj:`torch.Tensor`): The feature(s) extracted from the state + """ + features = self.shared(state) + return features + def forward(self, x): q1_values = self.critic1(x).squeeze(-1) q2_values = self.critic2(x).squeeze(-1) @@ -151,6 +300,7 @@ def forward(self, x): def get_action(self, state: torch.Tensor, deterministic: bool = False): state = torch.as_tensor(state).float() + state = self.get_features(state) if self.actor.sac: mean, log_std = self.actor(state) @@ -190,6 +340,8 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: values (:obj:`list`): List of values as estimated by each individual critic """ state = torch.as_tensor(state).float() + x = self.get_features(state[:, :, :-1]) + state = torch.cat([x, state[:, :, -1].unsqueeze(-1)], dim=-1) if mode == "both": values = self.forward(state) @@ -240,6 +392,11 @@ def __init__( ) self.critic = MlpValue(output_size, action_dim, val_type, value_layers) + def get_params(self): + actor_params = 
list(self.feature.parameters()) + list(self.actor.parameters()) + critic_params = list(self.feature.parameters()) + list(self.critic.parameters()) + return actor_params, critic_params + def get_action( self, state: torch.Tensor, deterministic: bool = False ) -> torch.Tensor: @@ -288,6 +445,7 @@ def get_value(self, inp: torch.Tensor) -> torch.Tensor: "cnn": CNNActorCritic, "mlp12": MlpSingleActorMultiCritic, "mlps": MlpSharedActorCritic, + "mlp12s": MlpSharedSingleActorMultiCritic, } diff --git a/tests/test_deep/test_agents/test_custom.py b/tests/test_deep/test_agents/test_custom.py index c4614b70..a0c97063 100644 --- a/tests/test_deep/test_agents/test_custom.py +++ b/tests/test_deep/test_agents/test_custom.py @@ -24,6 +24,7 @@ def __init__( self, state_dim, action_dim, + shared_layers=None, policy_layers=(1, 1), value_layers=(1, 1), val_type="V", @@ -32,12 +33,18 @@ def __init__( super(custom_actorcritic, self).__init__( state_dim, action_dim, + shared_layers=shared_layers, policy_layers=policy_layers, value_layers=value_layers, val_type=val_type, **kwargs ) + def get_params(self): + actor_params = self.actor.parameters() + critic_params = self.critic.parameters() + return actor_params, critic_params + def test_custom_vpg(): env = VectorEnv("CartPole-v0", 1) diff --git a/tests/test_deep/test_agents/test_ppo1.py b/tests/test_deep/test_agents/test_ppo1.py index 1bb06a22..3e9feaf2 100644 --- a/tests/test_deep/test_agents/test_ppo1.py +++ b/tests/test_deep/test_agents/test_ppo1.py @@ -19,11 +19,3 @@ def test_ppo1_cnn(): trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) trainer.train() shutil.rmtree("./logs") - - -def test_ppo1_shared(): - env = VectorEnv("CartPole-v0") - algo = PPO1("mlp", env, shared_layers=(32, 32), rollout_size=128) - trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) - trainer.train() - shutil.rmtree("./logs") diff --git a/tests/test_deep/test_agents/test_sac.py b/tests/test_deep/test_agents/test_sac.py index 8755c5c4..3ea1cfee 100644 --- a/tests/test_deep/test_agents/test_sac.py +++ b/tests/test_deep/test_agents/test_sac.py @@ -21,3 +21,28 @@ def test_sac(): ) trainer.train() shutil.rmtree("./logs") + + +def test_sac_shared(): + env = VectorEnv("Pendulum-v0", 2) + algo = SAC( + "mlp", + env, + batch_size=5, + shared_layers=[1, 1], + policy_layers=[1, 1], + value_layers=[1, 1], + ) + + trainer = OffPolicyTrainer( + algo, + env, + log_mode=["csv"], + logdir="./logs", + epochs=5, + max_ep_len=500, + warmup_steps=10, + start_update=10, + ) + trainer.train() + shutil.rmtree("./logs") diff --git a/tests/test_deep/test_agents/test_td3.py b/tests/test_deep/test_agents/test_td3.py index e3d59491..35def46f 100644 --- a/tests/test_deep/test_agents/test_td3.py +++ b/tests/test_deep/test_agents/test_td3.py @@ -29,3 +29,29 @@ def test_td3(): ) trainer.train() shutil.rmtree("./logs") + + +def test_td3_shared(): + env = VectorEnv("Pendulum-v0", 2) + algo = TD3( + "mlp", + env, + batch_size=5, + noise=OrnsteinUhlenbeckActionNoise, + shared_layers=[1, 1], + policy_layers=[1, 1], + value_layers=[1, 1], + ) + + trainer = OffPolicyTrainer( + algo, + env, + log_mode=["csv"], + logdir="./logs", + epochs=5, + max_ep_len=500, + warmup_steps=10, + start_update=10, + ) + trainer.train() + shutil.rmtree("./logs") From 53450a8399530e68ba179cbe59ddeaa8f354a503 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Wed, 2 Sep 2020 01:28:36 +0530 Subject: [PATCH 03/16] Fix lint errors (1) --- genrl/core/actor_critic.py | 15 +++++---------- 1 
file changed, 5 insertions(+), 10 deletions(-) diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 1b135ebd..1660e132 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -90,12 +90,12 @@ def get_params(self): def get_features(self, state: torch.Tensor): """Extract features from the state, which is then an input to get_action and get_value - Args: - state (:obj:`torch.Tensor`): The state(s) being passed + Args: + state (:obj:`torch.Tensor`): The state(s) being passed - Returns: - features (:obj:`torch.Tensor`): The feature(s) extracted from the state - """ + Returns: + features (:obj:`torch.Tensor`): The feature(s) extracted from the state + """ features = self.shared(state) return features @@ -392,11 +392,6 @@ def __init__( ) self.critic = MlpValue(output_size, action_dim, val_type, value_layers) - def get_params(self): - actor_params = list(self.feature.parameters()) + list(self.actor.parameters()) - critic_params = list(self.feature.parameters()) + list(self.critic.parameters()) - return actor_params, critic_params - def get_action( self, state: torch.Tensor, deterministic: bool = False ) -> torch.Tensor: From 274aff98cb11915be2a984624f1e9c9bc22b2fa4 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Wed, 2 Sep 2020 02:14:03 +0530 Subject: [PATCH 04/16] Fixed tests --- genrl/core/actor_critic.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 1660e132..b45222de 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -392,6 +392,11 @@ def __init__( ) self.critic = MlpValue(output_size, action_dim, val_type, value_layers) + def get_params(self): + actor_params = list(self.feature.parameters()) + list(self.actor.parameters()) + critic_params = list(self.feature.parameters()) + list(self.critic.parameters()) + return actor_params, critic_params + def get_action( self, state: torch.Tensor, deterministic: bool = False ) -> torch.Tensor: From 38f95f00ee2397844fb174cda61641134adc4a04 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Thu, 3 Sep 2020 01:38:45 +0530 Subject: [PATCH 05/16] Changes to dicstrings and classes --- genrl/core/actor_critic.py | 141 +++++++++++------------ tests/test_deep/test_agents/test_ppo1.py | 8 ++ 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index b45222de..aa460be7 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -18,7 +18,8 @@ class MlpActorCritic(BaseActorCritic): Attributes: state_dim (int): State dimensions of the environment action_dim (int): Action space dimensions of the environment - hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP + value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP val_type (str): Value type of the critic network discrete (bool): True if the action space is discrete, else False sac (bool): True if a SAC-like network is needed, else False @@ -53,7 +54,9 @@ class MlpSharedActorCritic(BaseActorCritic): Attributes: state_dim (int): State dimensions of the environment action_dim (int): Action space dimensions of the environment - hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + shared_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the shared MLP + policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP + value_layers (:obj:`list` or :obj:`tuple`): 
Hidden layers in the value MLP val_type (str): Value type of the critic network discrete (bool): True if the action space is discrete, else False sac (bool): True if a SAC-like network is needed, else False @@ -100,6 +103,18 @@ def get_features(self, state: torch.Tensor): return features def get_action(self, state: torch.Tensor, deterministic: bool = False): + """Get Actions from the actor + + Arg: + state (:obj:`torch.Tensor`): The state(s) being passed to the critics + deterministic (bool): True if the action space is deterministic, else False + + Returns: + action (:obj:`list`): List of actions as estimated by the critic + distribution (): The distribution from which the action was sampled + (None if determinist + """ + state = torch.as_tensor(state).float() features = self.get_features(state) action_probs = self.actor(features) @@ -115,6 +130,14 @@ def get_action(self, state: torch.Tensor, deterministic: bool = False): return action, distribution def get_value(self, state: torch.Tensor): + """Get Values from the Critic + + Arg: + state (:obj:`torch.Tensor`): The state(s) being passed to the critics + + Returns: + values (:obj:`list`): List of values as estimated by the critic + """ state = torch.as_tensor(state).float() if self.critic.val_type == "Qsa": features = self.shared(state[:, :, :-1]) @@ -132,7 +155,8 @@ class MlpSingleActorMultiCritic(BaseActorCritic): Attributes: state_dim (int): State dimensions of the environment action_dim (int): Action space dimensions of the environment - hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP + value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP val_type (str): Value type of the critic network discrete (bool): True if the action space is discrete, else False num_critics (int): Number of critics in the architecture @@ -175,6 +199,17 @@ def forward(self, x): return (q1_values, q2_values) def get_action(self, state: torch.Tensor, deterministic: bool = False): + """Get Actions from the actor + + Arg: + state (:obj:`torch.Tensor`): The state(s) being passed to the critics + deterministic (bool): True if the action space is deterministic, else False + + Returns: + action (:obj:`list`): List of actions as estimated by the critic + distribution (): The distribution from which the action was sampled + (None if determinist + """ state = torch.as_tensor(state).float() if self.actor.sac: @@ -229,13 +264,15 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: return values -class MlpSharedSingleActorMultiCritic(BaseActorCritic): +class MlpSharedSingleActorMultiCritic(MlpSingleActorMultiCritic): """MLP Actor Critic Attributes: state_dim (int): State dimensions of the environment action_dim (int): Action space dimensions of the environment - hidden (:obj:`list` or :obj:`tuple`): Hidden layers in the MLP + shared_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the shared MLP + policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP + value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP val_type (str): Value type of the critic network discrete (bool): True if the action space is discrete, else False num_critics (int): Number of critics in the architecture @@ -250,36 +287,22 @@ def __init__( shared_layers: Tuple = (32, 32), policy_layers: Tuple = (32, 32), value_layers: Tuple = (32, 32), - val_type: str = "V", + val_type: str = "Qsa", discrete: bool = True, num_critics: int = 
2, **kwargs, ): - super(MlpSharedSingleActorMultiCritic, self).__init__() - - self.num_critics = num_critics - self.shared = mlp([state_dim] + list(shared_layers)) - self.actor = MlpPolicy( - shared_layers[-1], action_dim, policy_layers, discrete, **kwargs - ) - self.critic1 = MlpValue( - shared_layers[-1], action_dim, "Qsa", value_layers, **kwargs + super(MlpSharedSingleActorMultiCritic, self).__init__( + shared_layers[-1], + action_dim, + policy_layers, + value_layers, + val_type, + discrete, + num_critics, + **kwargs, ) - self.critic2 = MlpValue( - shared_layers[-1], action_dim, "Qsa", value_layers, **kwargs - ) - - self.action_scale = kwargs["action_scale"] if "action_scale" in kwargs else 1 - self.action_bias = kwargs["action_bias"] if "action_bias" in kwargs else 0 - - def get_params(self): - actor_params = list(self.actor.parameters()) + list(self.shared.parameters()) - critic_params = ( - list(self.critic1.parameters()) - + list(self.critic2.parameters()) - + list(self.shared.parameters()) - ) - return actor_params, critic_params + self.shared = mlp([state_dim] + list(shared_layers)) def get_features(self, state: torch.Tensor): """Extract features from the state, which is then an input to get_action and get_value @@ -293,41 +316,24 @@ def get_features(self, state: torch.Tensor): features = self.shared(state) return features - def forward(self, x): - q1_values = self.critic1(x).squeeze(-1) - q2_values = self.critic2(x).squeeze(-1) - return (q1_values, q2_values) - def get_action(self, state: torch.Tensor, deterministic: bool = False): - state = torch.as_tensor(state).float() - state = self.get_features(state) + """Get Actions from the actor - if self.actor.sac: - mean, log_std = self.actor(state) - std = log_std.exp() - distribution = Normal(mean, std) - - action_probs = distribution.rsample() - log_probs = distribution.log_prob(action_probs) - action_probs = torch.tanh(action_probs) - - action = action_probs * self.action_scale + self.action_bias - - # enforcing action bound (appendix of SAC paper) - log_probs -= torch.log( - self.action_scale * (1 - action_probs.pow(2)) + np.finfo(np.float32).eps - ) - log_probs = log_probs.sum(1, keepdim=True) - mean = torch.tanh(mean) * self.action_scale + self.action_bias - - action = (action.float(), log_probs, mean) - else: - action = self.actor.get_action(state, deterministic=deterministic) + Arg: + state (:obj:`torch.Tensor`): The state(s) being passed to the critics + deterministic (bool): True if the action space is deterministic, else False - return action + Returns: + action (:obj:`list`): List of actions as estimated by the critic + distribution (): The distribution from which the action was sampled + (None if determinist + """ + return super(MlpSharedSingleActorMultiCritic, self).get_action( + self.get_features(state), deterministic=deterministic + ) - def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: - """Get Values from the Critic + def get_value(self, state: torch.Tensor, mode="first"): + """Get Values from both the Critic Arg: state (:obj:`torch.Tensor`): The state(s) being passed to the critics @@ -342,18 +348,7 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: state = torch.as_tensor(state).float() x = self.get_features(state[:, :, :-1]) state = torch.cat([x, state[:, :, -1].unsqueeze(-1)], dim=-1) - - if mode == "both": - values = self.forward(state) - elif mode == "min": - values = self.forward(state) - values = torch.min(*values).squeeze(-1) - elif mode == "first": - values = 
self.critic1(state) - else: - raise KeyError("Mode doesn't exist") - - return values + return super(MlpSharedSingleActorMultiCritic, self).get_value(state, mode) class CNNActorCritic(BaseActorCritic): diff --git a/tests/test_deep/test_agents/test_ppo1.py b/tests/test_deep/test_agents/test_ppo1.py index 3e9feaf2..97d40791 100644 --- a/tests/test_deep/test_agents/test_ppo1.py +++ b/tests/test_deep/test_agents/test_ppo1.py @@ -19,3 +19,11 @@ def test_ppo1_cnn(): trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) trainer.train() shutil.rmtree("./logs") + + +def test_ppo1_shared(): + env = VectorEnv("CartPole-v0") + algo = PPO1("mlp", env, shared_layers=[32, 32], rollout_size=128) + trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) + trainer.train() + shutil.rmtree("./logs") From 835819e193413e59aa3b09f63cd7b385ea3a359a Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Fri, 4 Sep 2020 22:55:36 +0530 Subject: [PATCH 06/16] Renaming Multi -> Two and comments --- genrl/core/actor_critic.py | 82 +++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 24 deletions(-) diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index aa460be7..a6fb7e49 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -75,7 +75,7 @@ def __init__( **kwargs, ): super(MlpSharedActorCritic, self).__init__() - self.shared = mlp([state_dim] + list(shared_layers)) + self.shared_network = mlp([state_dim] + list(shared_layers)) self.actor = MlpPolicy( shared_layers[-1], action_dim, policy_layers, discrete, **kwargs ) @@ -86,8 +86,12 @@ def __init__( self.action_dim = action_dim def get_params(self): - actor_params = list(self.shared.parameters()) + list(self.actor.parameters()) - critic_params = list(self.shared.parameters()) + list(self.critic.parameters()) + actor_params = list(self.shared_network.parameters()) + list( + self.actor.parameters() + ) + critic_params = list(self.shared_network.parameters()) + list( + self.critic.parameters() + ) return actor_params, critic_params def get_features(self, state: torch.Tensor): @@ -99,7 +103,7 @@ def get_features(self, state: torch.Tensor): Returns: features (:obj:`torch.Tensor`): The feature(s) extracted from the state """ - features = self.shared(state) + features = self.shared_network(state) return features def get_action(self, state: torch.Tensor, deterministic: bool = False): @@ -116,8 +120,8 @@ def get_action(self, state: torch.Tensor, deterministic: bool = False): """ state = torch.as_tensor(state).float() - features = self.get_features(state) - action_probs = self.actor(features) + shared_features = self.get_features(state) + action_probs = self.actor(shared_features) action_probs = nn.Softmax(dim=-1)(action_probs) if deterministic: @@ -139,17 +143,28 @@ def get_value(self, state: torch.Tensor): values (:obj:`list`): List of values as estimated by the critic """ state = torch.as_tensor(state).float() + if self.critic.val_type == "Qsa": - features = self.shared(state[:, :, :-1]) - features = torch.cat([features, state[:, :, -1].unsqueeze(-1)], dim=-1) - value = self.critic(features).float().squeeze(-1) + # state shape = [batch_size, number of vec envs, (state_dim + action_dim)] + + # extract shared_features from just the state + # state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim] + shared_features = self.shared_network(state[:, :, : -self.action_dim]) + + # concatenate the actions to the extracted shared_features + # state[:, :, 
-action_dim:] -> [batch_size, number of vec envs, action_dim] + shared_features = torch.cat( + [shared_features, state[:, :, -self.action_dim :]], dim=-1 + ) + + value = self.critic(shared_features).float().squeeze(-1) else: - features = self.shared(state) - value = self.critic(features) + shared_features = self.shared_network(state) + value = self.critic(shared_features) return value -class MlpSingleActorMultiCritic(BaseActorCritic): +class MlpSingleActorTwoCritic(BaseActorCritic): """MLP Actor Critic Attributes: @@ -175,7 +190,7 @@ def __init__( num_critics: int = 2, **kwargs, ): - super(MlpSingleActorMultiCritic, self).__init__() + super(MlpSingleActorTwoCritic, self).__init__() self.num_critics = num_critics @@ -264,7 +279,7 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: return values -class MlpSharedSingleActorMultiCritic(MlpSingleActorMultiCritic): +class MlpSharedSingleActorTwoCritic(MlpSingleActorTwoCritic): """MLP Actor Critic Attributes: @@ -292,7 +307,7 @@ def __init__( num_critics: int = 2, **kwargs, ): - super(MlpSharedSingleActorMultiCritic, self).__init__( + super(MlpSharedSingleActorTwoCritic, self).__init__( shared_layers[-1], action_dim, policy_layers, @@ -302,7 +317,19 @@ def __init__( num_critics, **kwargs, ) - self.shared = mlp([state_dim] + list(shared_layers)) + self.shared_network = mlp([state_dim] + list(shared_layers)) + self.action_dim = action_dim + + def get_params(self): + actor_params = list(self.shared_network.parameters()) + list( + self.actor.parameters() + ) + critic_params = ( + list(self.shared_network.parameters()) + + list(self.critic1.parameters()) + + list(self.critic2.parameters()) + ) + return actor_params, critic_params def get_features(self, state: torch.Tensor): """Extract features from the state, which is then an input to get_action and get_value @@ -313,7 +340,7 @@ def get_features(self, state: torch.Tensor): Returns: features (:obj:`torch.Tensor`): The feature(s) extracted from the state """ - features = self.shared(state) + features = self.shared_network(state) return features def get_action(self, state: torch.Tensor, deterministic: bool = False): @@ -326,9 +353,9 @@ def get_action(self, state: torch.Tensor, deterministic: bool = False): Returns: action (:obj:`list`): List of actions as estimated by the critic distribution (): The distribution from which the action was sampled - (None if determinist + (None if deterministic) """ - return super(MlpSharedSingleActorMultiCritic, self).get_action( + return super(MlpSharedSingleActorTwoCritic, self).get_action( self.get_features(state), deterministic=deterministic ) @@ -346,9 +373,16 @@ def get_value(self, state: torch.Tensor, mode="first"): values (:obj:`list`): List of values as estimated by each individual critic """ state = torch.as_tensor(state).float() - x = self.get_features(state[:, :, :-1]) - state = torch.cat([x, state[:, :, -1].unsqueeze(-1)], dim=-1) - return super(MlpSharedSingleActorMultiCritic, self).get_value(state, mode) + # state shape = [batch_size, number of vec envs, (state_dim + action_dim)] + + # extract shard features for just the state + # state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim] + x = self.get_features(state[:, :, : -self.action_dim]) + + # concatenate the actions to the extracted shared features + # state[:, :, -action_dim:] -> [batch_size, number of vec envs, action_dim] + state = torch.cat([x, state[:, :, -self.action_dim :]], dim=-1) + return super(MlpSharedSingleActorTwoCritic, self).get_value(state, mode) 
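    # Editorial note (illustrative, not part of the patch): because get_params()
    # places self.shared_network.parameters() in both the returned actor and
    # critic parameter groups, gradients from the policy loss and from the value
    # loss both flow into the shared trunk. The agents rely on this when they
    # build their two optimizers from these groups, e.g. (hypothetical usage):
    #   actor_params, critic_params = self.ac.get_params()
    #   optimizer_policy = opt.Adam(actor_params, lr=lr_policy)
    #   optimizer_value = opt.Adam(critic_params, lr=lr_value)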
class CNNActorCritic(BaseActorCritic): @@ -438,9 +472,9 @@ def get_value(self, inp: torch.Tensor) -> torch.Tensor: actor_critic_registry = { "mlp": MlpActorCritic, "cnn": CNNActorCritic, - "mlp12": MlpSingleActorMultiCritic, + "mlp12": MlpSingleActorTwoCritic, "mlps": MlpSharedActorCritic, - "mlp12s": MlpSharedSingleActorMultiCritic, + "mlp12s": MlpSharedSingleActorTwoCritic, } From bf71710a386b4e76ce2d26e83531f68daa5cbf55 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Sat, 12 Sep 2020 13:00:52 +0530 Subject: [PATCH 07/16] Adding tutorial --- .isort.cfg | 2 +- .pre-commit-config.yaml | 2 +- ...ared parameters in actor critic agents.rst | 70 +++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst diff --git a/.isort.cfg b/.isort.cfg index db6b8351..4b0feff5 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,torch +known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c68e57a..990fb2a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: args: [--exclude=^((examples|docs)/.*)$] - repo: https://github.com/timothycrosley/isort - rev: 4.3.2 + rev: 5.4.2 hooks: - id: isort diff --git a/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst b/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst new file mode 100644 index 00000000..70a7b764 --- /dev/null +++ b/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst @@ -0,0 +1,70 @@ +Using Shared Parameters in Actor Critic Agents in GenRL +======================================================= + +The Actor Critic Agents use two networks, an Actor network to select an action to be taken in the current state, and a +critic network, to estimate the value of the state the agent is currently in. There are two common ways to implement +this actor critic architecture. + +The first method - Indpendent Actor and critic networks - + +.. code-block:: none + + state + / \ + + / \ + action value + +And the second method - Using a set of shared parameters to extract a feature vector from the state. The actor and the +critic network act on this feature vector to select an action and estimate the value + +.. code-block:: none + + state + | + + / \ + + / \ + action value + +GenRL provides support to incorporte this decoder network in all of the actor critic agents through a ``shared_layers`` +parameter. ``shared_layers`` takes the sizes of the mlp layers o be used, and ``None`` if no decoder network is to be +used + +As an example - in A2C - +.. code-block:: python +# The imports +from genrl.agents import A2C +from genrl.environments import VectorEnv +from genrl.trainers import OnPolicyTrainer + +# Initializing the environment +env = VectorEnv("CartPole-v0", 1) + +# Initializing the agent to be used +algo = A2C( + "mlp", + env, + policy_layers=(128,), + value_layers=(128,), + shared_layers=(32, 64), + rollout_size=128, + ) + +# Finally initializing the trainer and trainer +trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) +trainer.train() + +The above example uses and mlp of layer sizes (32, 64) as the decoder, and can be visualised as follows - +.. 
code-block:: none + + state + | + <32> + | + <64> + / \ + <128> <128> + / \ + action value \ No newline at end of file From fc356b9aaa660f72e5566fc3e76cec29076920d8 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Sat, 12 Sep 2020 13:03:10 +0530 Subject: [PATCH 08/16] Small change --- .../Using shared parameters in actor critic agents.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst b/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst index 70a7b764..82eb12b7 100644 --- a/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst +++ b/docs/source/usage/tutorials/Using shared parameters in actor critic agents.rst @@ -33,6 +33,7 @@ parameter. ``shared_layers`` takes the sizes of the mlp layers o be used, and `` used As an example - in A2C - + .. code-block:: python # The imports from genrl.agents import A2C @@ -57,6 +58,7 @@ trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1 trainer.train() The above example uses and mlp of layer sizes (32, 64) as the decoder, and can be visualised as follows - + .. code-block:: none state From 844c53da24fc4d65aa438cb8538f2bfec5498046 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Sun, 13 Sep 2020 11:52:26 +0530 Subject: [PATCH 09/16] Index --- docs/source/usage/tutorials/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/usage/tutorials/index.rst b/docs/source/usage/tutorials/index.rst index f0f4cb63..85875f44 100644 --- a/docs/source/usage/tutorials/index.rst +++ b/docs/source/usage/tutorials/index.rst @@ -9,5 +9,6 @@ Tutorials Deep/index Using Custom Policies Using A2C + Using shared parameters in actor critic agents using_vpg Saving and loading From a90e8d0e834cdbc6fe777dc0a50b01c651f9fc59 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Tue, 6 Oct 2020 03:20:36 +0530 Subject: [PATCH 10/16] CEM agent --- genrl/agents/__init__.py | 1 + genrl/agents/modelbased/__init__.py | 0 genrl/agents/modelbased/base.py | 66 +++++++++ genrl/agents/modelbased/cem/__init__.py | 0 genrl/agents/modelbased/cem/cem.py | 171 ++++++++++++++++++++++++ tests/test_deep/test_agents/test_cem.py | 12 ++ 6 files changed, 250 insertions(+) create mode 100644 genrl/agents/modelbased/__init__.py create mode 100644 genrl/agents/modelbased/base.py create mode 100644 genrl/agents/modelbased/cem/__init__.py create mode 100644 genrl/agents/modelbased/cem/cem.py create mode 100644 tests/test_deep/test_agents/test_cem.py diff --git a/genrl/agents/__init__.py b/genrl/agents/__init__.py index f1089ff9..48fdd51a 100644 --- a/genrl/agents/__init__.py +++ b/genrl/agents/__init__.py @@ -40,5 +40,6 @@ from genrl.agents.deep.sac.sac import SAC # noqa from genrl.agents.deep.td3.td3 import TD3 # noqa from genrl.agents.deep.vpg.vpg import VPG # noqa +from genrl.agents.modelbased.base import ModelBasedAgent from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa diff --git a/genrl/agents/modelbased/__init__.py b/genrl/agents/modelbased/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/agents/modelbased/base.py b/genrl/agents/modelbased/base.py new file mode 100644 index 00000000..01f50018 --- /dev/null +++ b/genrl/agents/modelbased/base.py @@ -0,0 +1,66 @@ +from abc import ABC + +import numpy as np +import torch + + +class Planner: + def __init__(self, initial_state, dynamics_model=None): + if dynamics_model is not None: + self.dynamics_model = dynamics_model + 
self.initial_state = initial_state + + def _learn_dynamics_model(self, state): + raise NotImplementedError + + def plan(self): + raise NotImplementedError + + def execute_actions(self): + raise NotImplementedError + + +class ModelBasedAgent(ABC): + def __init__(self, env, planner=None, render=False, device="cpu"): + self.env = env + self.planner = planner + self.render = render + self.device = torch.device(device) + + def plan(self): + """ + To be used to plan out a sequence of actions + """ + if self.planner is not None: + raise ValueError("Provide a planner to plan for the environment") + self.planner.plan() + + def generate_data(self): + """ + To be used to generate synthetic data via a model (may be learnt or specified beforehand) + """ + raise NotImplementedError + + def value_equivalence(self, state_space): + """ + To be used for approximate value estimation methods e.g. Value Iteration Networks + """ + raise NotImplementedError + + def update_params(self): + """ + Update the parameters (Parameters of the learnt model and/or Parameters of the policy being used) + """ + raise NotImplementedError + + def get_hyperparans(self): + raise NotImplementedError + + def get_logging_params(self): + raise NotImplementedError + + def _load_weights(self, weights): + raise NotImplementedError + + def empty_logs(self): + raise NotImplementedError diff --git a/genrl/agents/modelbased/cem/__init__.py b/genrl/agents/modelbased/cem/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py new file mode 100644 index 00000000..0b0dcf61 --- /dev/null +++ b/genrl/agents/modelbased/cem/cem.py @@ -0,0 +1,171 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +from genrl.agents import ModelBasedAgent +from genrl.core import RolloutBuffer +from genrl.utils import get_env_properties, get_model, safe_mean + + +class CEM(ModelBasedAgent): + def __init__( + self, + *args, + network: str = "mlp", + policy_layers: tuple = (100,), + percentile: int = 70, + **kwargs + ): + super(CEM, self).__init__(*args, **kwargs) + self.network = network + self.rollout_size = int(1e4) + self.rollout = RolloutBuffer(self.rollout_size, self.env) + self.policy_layers = policy_layers + self.percentile = percentile + + self._create_model() + self.empty_logs() + + def _create_model(self): + self.state_dim, self.action_dim, discrete, action_lim = get_env_properties( + self.env, self.network + ) + self.agent = get_model("p", self.network)( + self.state_dim, + self.action_dim, + self.policy_layers, + "V", + discrete, + action_lim, + ) + self.optim = torch.optim.Adam( + self.agent.parameters(), lr=1e-3 + ) # make this a hyperparam + + def plan(self, timesteps=1e4): + state = self.env.reset() + self.rollout.reset() + _, _ = self.collect_rollouts(state) + return ( + self.rollout.observations, + self.rollout.actions, + torch.sum(self.rollout.rewards).detach(), + ) + + def select_elites(self, states_batch, actions_batch, rewards_batch): + reward_threshold = np.percentile(rewards_batch, self.percentile) + print(reward_threshold) + elite_states = [ + s.unsqueeze(0) + for i in range(len(states_batch)) + if rewards_batch[i] >= reward_threshold + for s in states_batch[i] + ] + elite_actions = [ + a.unsqueeze(0) + for i in range(len(actions_batch)) + if rewards_batch[i] >= reward_threshold + for a in actions_batch[i] + ] + + return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0) + + 
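    # Worked example of the elite selection above (illustrative only): with
    # percentile=70 and a batch of episode rewards [1, 2, ..., 10],
    # np.percentile(rewards_batch, 70) = 7.3, so only the (state, action) pairs
    # from the episodes that scored 8, 9 and 10 are kept; update_params() then
    # fits the policy to these elites with a cross-entropy loss.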
def select_action(self, state): + state = torch.as_tensor(state).float() + action, dist = self.agent.get_action(state) + return action + + def update_params(self): + sess = [self.plan() for _ in range(100)] + batch_states, batch_actions, batch_rewards = zip(*sess) + elite_states, elite_actions = self.select_elites( + batch_states, batch_actions, batch_rewards + ) + print(elite_actions.shape) + action_probs = self.agent.forward(torch.as_tensor(elite_states).float()) + print(action_probs.shape) + print(self.action_dim) + loss = F.cross_entropy( + action_probs.view(-1, self.action_dim), + torch.as_tensor(elite_actions).long().view(-1), + ) + self.logs["crossentropy_loss"].append(loss.item()) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5) + self.optim.step() + + def get_traj_loss(self, values, dones): + # No need for this here + pass + + def collect_rollouts(self, state: torch.Tensor): + """Function to collect rollouts + + Collects rollouts by playing the env like a human agent and inputs information into + the rollout buffer. + + Args: + state (:obj:`torch.Tensor`): The starting state of the environment + + Returns: + values (:obj:`torch.Tensor`): Values of states encountered during the rollout + dones (:obj:`torch.Tensor`): Game over statuses of each environment + """ + for i in range(self.rollout_size): + action = self.select_action(state) + + next_state, reward, dones, _ = self.env.step(action) + + if self.render: + self.env.render() + + self.rollout.add( + state, + action.reshape(self.env.n_envs, 1), + reward, + dones, + torch.tensor(0), + torch.tensor(0), + ) + + state = next_state + + self.collect_rewards(dones, i) + + if dones: + break + + return torch.tensor(0), dones + + def collect_rewards(self, dones: torch.Tensor, timestep: int): + """Helper function to collect rewards + + Runs through all the envs and collects rewards accumulated during rollouts + + Args: + dones (:obj:`torch.Tensor`): Game over statuses of each environment + timestep (int): Timestep during rollout + """ + for i, done in enumerate(dones): + if done or timestep == self.rollout_size - 1: + self.rewards.append(self.env.episode_reward[i].detach().clone()) + self.env.reset_single_env(i) + + def get_hyperparams(self): + # return self.agent.get_hyperparams() + pass + + def get_logging_params(self): + logs = { + "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]), + "mean_reward": safe_mean(self.rewards), + } + return logs + + def empty_logs(self): + # self.agent.empty_logs() + self.logs = {} + self.logs["crossentropy_loss"] = [] + self.rewards = [] diff --git a/tests/test_deep/test_agents/test_cem.py b/tests/test_deep/test_agents/test_cem.py new file mode 100644 index 00000000..51c3c9a4 --- /dev/null +++ b/tests/test_deep/test_agents/test_cem.py @@ -0,0 +1,12 @@ +import gym + +from genrl.agents.modelbased.cem.cem import CEM +from genrl.environments import VectorEnv +from genrl.trainers import OnPolicyTrainer + + +def test_CEM(): + env = VectorEnv("CartPole-v0", 1) + algo = CEM(env, percentile=70, policy_layers=[100]) + trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) + trainer.train() From 3b2067d59710c0865a01b9588d3be1e4177273ac Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Thu, 15 Oct 2020 16:16:17 +0530 Subject: [PATCH 11/16] Training CEM without rollouts --- genrl/agents/modelbased/cem/cem.py | 64 +++++++++++++------------ tests/test_deep/test_agents/test_cem.py | 2 +- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git 
a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py index 0b0dcf61..31b0404f 100644 --- a/genrl/agents/modelbased/cem/cem.py +++ b/genrl/agents/modelbased/cem/cem.py @@ -5,7 +5,7 @@ import torch.optim as optim from genrl.agents import ModelBasedAgent -from genrl.core import RolloutBuffer +from genrl.core import MlpPolicy, RolloutBuffer from genrl.utils import get_env_properties, get_model, safe_mean @@ -15,14 +15,17 @@ def __init__( *args, network: str = "mlp", policy_layers: tuple = (100,), + lr_policy=1e-3, percentile: int = 70, + rollout_size, **kwargs ): super(CEM, self).__init__(*args, **kwargs) self.network = network - self.rollout_size = int(1e4) + self.rollout_size = rollout_size self.rollout = RolloutBuffer(self.rollout_size, self.env) self.policy_layers = policy_layers + self.lr_policy = lr_policy self.percentile = percentile self._create_model() @@ -40,31 +43,24 @@ def _create_model(self): discrete, action_lim, ) - self.optim = torch.optim.Adam( - self.agent.parameters(), lr=1e-3 - ) # make this a hyperparam + self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy) - def plan(self, timesteps=1e4): + def plan(self): state = self.env.reset() self.rollout.reset() - _, _ = self.collect_rollouts(state) - return ( - self.rollout.observations, - self.rollout.actions, - torch.sum(self.rollout.rewards).detach(), - ) + states, actions = self.collect_rollouts(state) + return (states, actions, self.rewards[-1]) def select_elites(self, states_batch, actions_batch, rewards_batch): reward_threshold = np.percentile(rewards_batch, self.percentile) - print(reward_threshold) elite_states = [ - s.unsqueeze(0) + s.unsqueeze(0).clone() for i in range(len(states_batch)) if rewards_batch[i] >= reward_threshold for s in states_batch[i] ] elite_actions = [ - a.unsqueeze(0) + a.unsqueeze(0).clone() for i in range(len(actions_batch)) if rewards_batch[i] >= reward_threshold for a in actions_batch[i] @@ -75,25 +71,24 @@ def select_elites(self, states_batch, actions_batch, rewards_batch): def select_action(self, state): state = torch.as_tensor(state).float() action, dist = self.agent.get_action(state) - return action + return action, torch.zeros((1, self.env.n_envs)), dist.log_prob(action).cpu() def update_params(self): sess = [self.plan() for _ in range(100)] - batch_states, batch_actions, batch_rewards = zip(*sess) + batch_states, batch_actions, batch_rewards = zip( + *sess + ) # map(np.array, zip(*sess)) elite_states, elite_actions = self.select_elites( batch_states, batch_actions, batch_rewards ) - print(elite_actions.shape) - action_probs = self.agent.forward(torch.as_tensor(elite_states).float()) - print(action_probs.shape) - print(self.action_dim) + action_probs = self.agent.forward(elite_states.float()) loss = F.cross_entropy( action_probs.view(-1, self.action_dim), - torch.as_tensor(elite_actions).long().view(-1), + elite_actions.long().view(-1), ) self.logs["crossentropy_loss"].append(loss.item()) loss.backward() - torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5) + # torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5) self.optim.step() def get_traj_loss(self, values, dones): @@ -113,8 +108,12 @@ def collect_rollouts(self, state: torch.Tensor): values (:obj:`torch.Tensor`): Values of states encountered during the rollout dones (:obj:`torch.Tensor`): Game over statuses of each environment """ + states = [] + actions = [] for i in range(self.rollout_size): - action = self.select_action(state) + action, value, log_probs = 
self.select_action(state) + states.append(state) + actions.append(action) next_state, reward, dones, _ = self.env.step(action) @@ -126,8 +125,8 @@ def collect_rollouts(self, state: torch.Tensor): action.reshape(self.env.n_envs, 1), reward, dones, - torch.tensor(0), - torch.tensor(0), + value, + log_probs.detach(), ) state = next_state @@ -137,7 +136,7 @@ def collect_rollouts(self, state: torch.Tensor): if dones: break - return torch.tensor(0), dones + return states, actions def collect_rewards(self, dones: torch.Tensor, timestep: int): """Helper function to collect rewards @@ -151,11 +150,15 @@ def collect_rewards(self, dones: torch.Tensor, timestep: int): for i, done in enumerate(dones): if done or timestep == self.rollout_size - 1: self.rewards.append(self.env.episode_reward[i].detach().clone()) - self.env.reset_single_env(i) + # self.env.reset_single_env(i) def get_hyperparams(self): - # return self.agent.get_hyperparams() - pass + hyperparams = { + "network": self.network, + "lr_policy": self.lr_policy, + "rollout_size": self.rollout_size, + } + return hyperparams def get_logging_params(self): logs = { @@ -165,7 +168,6 @@ def get_logging_params(self): return logs def empty_logs(self): - # self.agent.empty_logs() self.logs = {} self.logs["crossentropy_loss"] = [] self.rewards = [] diff --git a/tests/test_deep/test_agents/test_cem.py b/tests/test_deep/test_agents/test_cem.py index 51c3c9a4..6d1cf215 100644 --- a/tests/test_deep/test_agents/test_cem.py +++ b/tests/test_deep/test_agents/test_cem.py @@ -7,6 +7,6 @@ def test_CEM(): env = VectorEnv("CartPole-v0", 1) - algo = CEM(env, percentile=70, policy_layers=[100]) + algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100) trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) trainer.train() From f86b0467f1ebd98038ef864d5b294169f1945433 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Fri, 16 Oct 2020 00:08:46 +0530 Subject: [PATCH 12/16] Fix Codacy (1) --- genrl/agents/__init__.py | 3 ++- genrl/agents/modelbased/base.py | 1 - genrl/agents/modelbased/cem/cem.py | 8 ++------ tests/test_deep/test_agents/test_cem.py | 4 +--- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/genrl/agents/__init__.py b/genrl/agents/__init__.py index 34be594f..f19f6432 100644 --- a/genrl/agents/__init__.py +++ b/genrl/agents/__init__.py @@ -41,6 +41,7 @@ from genrl.agents.deep.sac.sac import SAC # noqa from genrl.agents.deep.td3.td3 import TD3 # noqa from genrl.agents.deep.vpg.vpg import VPG # noqa -from genrl.agents.modelbased.base import ModelBasedAgent +from genrl.agents.modelbased.base import ModelBasedAgent # noqa +from genrl.agents.modelbased.cem.cem import CEM # noqa from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa diff --git a/genrl/agents/modelbased/base.py b/genrl/agents/modelbased/base.py index 01f50018..ad9949c9 100644 --- a/genrl/agents/modelbased/base.py +++ b/genrl/agents/modelbased/base.py @@ -1,6 +1,5 @@ from abc import ABC -import numpy as np import torch diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py index 31b0404f..8944a9d0 100644 --- a/genrl/agents/modelbased/cem/cem.py +++ b/genrl/agents/modelbased/cem/cem.py @@ -1,11 +1,9 @@ import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F -import torch.optim as optim from genrl.agents import ModelBasedAgent -from genrl.core import MlpPolicy, RolloutBuffer +from genrl.core import RolloutBuffer from genrl.utils import get_env_properties, get_model, 
safe_mean @@ -75,9 +73,7 @@ def select_action(self, state): def update_params(self): sess = [self.plan() for _ in range(100)] - batch_states, batch_actions, batch_rewards = zip( - *sess - ) # map(np.array, zip(*sess)) + batch_states, batch_actions, batch_rewards = zip(*sess) elite_states, elite_actions = self.select_elites( batch_states, batch_actions, batch_rewards ) diff --git a/tests/test_deep/test_agents/test_cem.py b/tests/test_deep/test_agents/test_cem.py index 6d1cf215..f0d230f5 100644 --- a/tests/test_deep/test_agents/test_cem.py +++ b/tests/test_deep/test_agents/test_cem.py @@ -1,6 +1,4 @@ -import gym - -from genrl.agents.modelbased.cem.cem import CEM +from genrl.agents import CEM from genrl.environments import VectorEnv from genrl.trainers import OnPolicyTrainer From f5a189d6f8fd35dade60b1573983b692382ef0f8 Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Sun, 18 Oct 2020 02:10:44 +0530 Subject: [PATCH 13/16] Docstrings --- genrl/agents/modelbased/base.py | 28 +--- genrl/agents/modelbased/cem/cem.py | 142 +++++++++++++----- tests/test_agents/test_modelbased/__init__.py | 1 + tests/test_agents/test_modelbased/test_cem.py | 23 +++ tests/test_deep/test_agents/test_cem.py | 10 -- 5 files changed, 133 insertions(+), 71 deletions(-) create mode 100644 tests/test_agents/test_modelbased/__init__.py create mode 100644 tests/test_agents/test_modelbased/test_cem.py delete mode 100644 tests/test_deep/test_agents/test_cem.py diff --git a/genrl/agents/modelbased/base.py b/genrl/agents/modelbased/base.py index ad9949c9..cb11e020 100644 --- a/genrl/agents/modelbased/base.py +++ b/genrl/agents/modelbased/base.py @@ -2,6 +2,8 @@ import torch +from genrl.agents import BaseAgent + class Planner: def __init__(self, initial_state, dynamics_model=None): @@ -19,12 +21,10 @@ def execute_actions(self): raise NotImplementedError -class ModelBasedAgent(ABC): - def __init__(self, env, planner=None, render=False, device="cpu"): - self.env = env +class ModelBasedAgent(BaseAgent): + def __init__(self, *args, planner=None, **kwargs): + super(ModelBasedAgent, self).__init__(*args, **kwargs) self.planner = planner - self.render = render - self.device = torch.device(device) def plan(self): """ @@ -45,21 +45,3 @@ def value_equivalence(self, state_space): To be used for approximate value estimation methods e.g. 
Value Iteration Networks """ raise NotImplementedError - - def update_params(self): - """ - Update the parameters (Parameters of the learnt model and/or Parameters of the policy being used) - """ - raise NotImplementedError - - def get_hyperparans(self): - raise NotImplementedError - - def get_logging_params(self): - raise NotImplementedError - - def _load_weights(self, weights): - raise NotImplementedError - - def empty_logs(self): - raise NotImplementedError diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py index 8944a9d0..a90af5e7 100644 --- a/genrl/agents/modelbased/cem/cem.py +++ b/genrl/agents/modelbased/cem/cem.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import numpy as np import torch import torch.nn.functional as F @@ -8,13 +10,28 @@ class CEM(ModelBasedAgent): + """Cross Entropy method algorithm (CEM) + + Attributes: + network (str): The type of network to be used + env (Environment): The environment the agent is supposed to act on + create_model (bool): Whether the model of the algo should be created when initialised + policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network of the policy + lr_policy (float): learning rate of the policy + percentile (float): Top percentile of rewards to consider as elite + simulations_per_epoch (int): Number of simulations to perform before taking a gradient step + rollout_size (int): Capacity of the replay buffer + render (bool): Whether to render the environment or not + device (str): Hardware being used for training. Options: + ["cuda" -> GPU, "cpu" -> CPU] + """ + def __init__( self, *args, network: str = "mlp", - policy_layers: tuple = (100,), - lr_policy=1e-3, - percentile: int = 70, + percentile: float = 70, + simulations_per_epoch: int = 1000, rollout_size, **kwargs ): @@ -22,14 +39,17 @@ def __init__( self.network = network self.rollout_size = rollout_size self.rollout = RolloutBuffer(self.rollout_size, self.env) - self.policy_layers = policy_layers - self.lr_policy = lr_policy self.percentile = percentile + self.simulations_per_epoch = simulations_per_epoch self._create_model() self.empty_logs() def _create_model(self): + """Function to initialize the Policy + + This will create the Policy net for the CEM agent + """ self.state_dim, self.action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) @@ -44,35 +64,74 @@ def _create_model(self): self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy) def plan(self): + """Function to plan out one episode + + Returns: + states (:obj:`list` of :obj:`torch.Tensor`): Batch of states the agent encountered in the episode + actions (:obj:`list` of :obj:`torch.Tensor`): Batch of actions the agent took in the episode + rewards (:obj:`torch.Tensor`): The episode reward obtained + """ state = self.env.reset() self.rollout.reset() states, actions = self.collect_rollouts(state) return (states, actions, self.rewards[-1]) def select_elites(self, states_batch, actions_batch, rewards_batch): + """Function to select the elite states and elite actions based on the episode reward + + Args: + states_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of states + actions_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of actions + rewards_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of rewards + + Returns: + elite_states (:obj:`torch.Tensor`): Elite batch of states based on episode reward + elite_actions (:obj:`torch.Tensor`): Actions the agent took during the elite batch of states + + """ reward_threshold = 
np.percentile(rewards_batch, self.percentile) - elite_states = [ - s.unsqueeze(0).clone() - for i in range(len(states_batch)) - if rewards_batch[i] >= reward_threshold - for s in states_batch[i] - ] - elite_actions = [ - a.unsqueeze(0).clone() - for i in range(len(actions_batch)) - if rewards_batch[i] >= reward_threshold - for a in actions_batch[i] - ] - - return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0) + elite_states = torch.cat( + [ + s.unsqueeze(0).clone() + for i in range(len(states_batch)) + if rewards_batch[i] >= reward_threshold + for s in states_batch[i] + ], + dim=0, + ) + elite_actions = torch.cat( + [ + a.unsqueeze(0).clone() + for i in range(len(actions_batch)) + if rewards_batch[i] >= reward_threshold + for a in actions_batch[i] + ], + dim=0, + ) + + return elite_states, elite_actions def select_action(self, state): + """Select action given state + + Action selection policy for the Cross Entropy agent + + Args: + state (:obj:`torch.Tensor`): Current state of the agent + + Returns: + action (:obj:`torch.Tensor`): Action taken by the agent + """ state = torch.as_tensor(state).float() action, dist = self.agent.get_action(state) - return action, torch.zeros((1, self.env.n_envs)), dist.log_prob(action).cpu() + return action def update_params(self): - sess = [self.plan() for _ in range(100)] + """Updates the the Policy network of the CEM agent + + Function to update the policy network + """ + sess = [self.plan() for _ in range(self.simulations_per_epoch)] batch_states, batch_actions, batch_rewards = zip(*sess) elite_states, elite_actions = self.select_elites( batch_states, batch_actions, batch_rewards @@ -101,13 +160,13 @@ def collect_rollouts(self, state: torch.Tensor): state (:obj:`torch.Tensor`): The starting state of the environment Returns: - values (:obj:`torch.Tensor`): Values of states encountered during the rollout - dones (:obj:`torch.Tensor`): Game over statuses of each environment + states (:obj:`list`): list of states the agent encountered during the episode + actions (:obj:`list`): list of actions the agent took in the corresponding states """ states = [] actions = [] for i in range(self.rollout_size): - action, value, log_probs = self.select_action(state) + action = self.select_action(state) states.append(state) actions.append(action) @@ -116,20 +175,11 @@ def collect_rollouts(self, state: torch.Tensor): if self.render: self.env.render() - self.rollout.add( - state, - action.reshape(self.env.n_envs, 1), - reward, - dones, - value, - log_probs.detach(), - ) - state = next_state self.collect_rewards(dones, i) - if dones: + if torch.any(dones.byte()): break return states, actions @@ -146,24 +196,40 @@ def collect_rewards(self, dones: torch.Tensor, timestep: int): for i, done in enumerate(dones): if done or timestep == self.rollout_size - 1: self.rewards.append(self.env.episode_reward[i].detach().clone()) - # self.env.reset_single_env(i) - def get_hyperparams(self): + def get_hyperparams(self) -> Dict[str, Any]: + """Get relevant hyperparameters to save + + Returns: + hyperparams (:obj:`dict`): Hyperparameters to be saved + weights (:obj:`torch.Tensor`): Neural network weights + """ hyperparams = { "network": self.network, "lr_policy": self.lr_policy, "rollout_size": self.rollout_size, } - return hyperparams + return hyperparams, self.agent.state_dict() + + def _load_weights(self, weights) -> None: + self.agent.load_state_dict(weights) + + def get_logging_params(self) -> Dict[str, Any]: + """Gets relevant parameters for logging - def 
get_logging_params(self): + Returns: + logs (:obj:`dict`): Logging parameters for monitoring training + """ logs = { "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]), "mean_reward": safe_mean(self.rewards), } + + self.empty_logs() return logs def empty_logs(self): + """Empties logs""" self.logs = {} self.logs["crossentropy_loss"] = [] self.rewards = [] diff --git a/tests/test_agents/test_modelbased/__init__.py b/tests/test_agents/test_modelbased/__init__.py new file mode 100644 index 00000000..08c59ec8 --- /dev/null +++ b/tests/test_agents/test_modelbased/__init__.py @@ -0,0 +1 @@ +from tests.test_agents.test_modelbased.test_cem import TestCEM diff --git a/tests/test_agents/test_modelbased/test_cem.py b/tests/test_agents/test_modelbased/test_cem.py new file mode 100644 index 00000000..5cf76779 --- /dev/null +++ b/tests/test_agents/test_modelbased/test_cem.py @@ -0,0 +1,23 @@ +import shutil + +from genrl.agents import CEM +from genrl.environments import VectorEnv +from genrl.trainers import OnPolicyTrainer + + +class TestCEM: + def test_CEM(self): + env = VectorEnv("CartPole-v0", 1) + algo = CEM( + "mlp", + env, + percentile=70, + policy_layers=[100], + rollout_size=100, + simulations_per_epoch=100, + ) + trainer = OnPolicyTrainer( + algo, env, log_mode=["csv"], logdir="./logs", epochs=1 + ) + trainer.train() + shutil.rmtree("./logs") diff --git a/tests/test_deep/test_agents/test_cem.py b/tests/test_deep/test_agents/test_cem.py deleted file mode 100644 index f0d230f5..00000000 --- a/tests/test_deep/test_agents/test_cem.py +++ /dev/null @@ -1,10 +0,0 @@ -from genrl.agents import CEM -from genrl.environments import VectorEnv -from genrl.trainers import OnPolicyTrainer - - -def test_CEM(): - env = VectorEnv("CartPole-v0", 1) - algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100) - trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1) - trainer.train() From 4b11c16fd8b353e115fe7cded8db7c71f0d2f0eb Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Wed, 21 Oct 2020 21:47:57 +0530 Subject: [PATCH 14/16] Adding device --- genrl/agents/modelbased/cem/cem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py index a90af5e7..5451bb96 100644 --- a/genrl/agents/modelbased/cem/cem.py +++ b/genrl/agents/modelbased/cem/cem.py @@ -60,7 +60,7 @@ def _create_model(self): "V", discrete, action_lim, - ) + ).to(self.device) self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy) def plan(self): @@ -136,7 +136,7 @@ def update_params(self): elite_states, elite_actions = self.select_elites( batch_states, batch_actions, batch_rewards ) - action_probs = self.agent.forward(elite_states.float()) + action_probs = self.agent.forward(elite_states.float().to(self.device)) loss = F.cross_entropy( action_probs.view(-1, self.action_dim), elite_actions.long().view(-1), From a93a094344e3caac84be7b4a6cbedbaa47b63b4b Mon Sep 17 00:00:00 2001 From: hades-rp2010 Date: Sat, 14 Nov 2020 16:15:24 +0530 Subject: [PATCH 15/16] [WIP] Modular structure for MCTS --- genrl/agents/deep/a2c/a2c.py | 4 +- genrl/agents/modelbased/mcts/__init__.py | 0 genrl/agents/modelbased/mcts/base.py | 117 ++++++++++++++++ genrl/agents/modelbased/mcts/mcts.py | 170 +++++++++++++++++++++++ genrl/agents/modelbased/mcts/uct.py | 5 + 5 files changed, 294 insertions(+), 2 deletions(-) create mode 100644 genrl/agents/modelbased/mcts/__init__.py create mode 100644 
genrl/agents/modelbased/mcts/base.py create mode 100644 genrl/agents/modelbased/mcts/mcts.py create mode 100644 genrl/agents/modelbased/mcts/uct.py diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py index 1f94992e..9baa2bab 100644 --- a/genrl/agents/deep/a2c/a2c.py +++ b/genrl/agents/deep/a2c/a2c.py @@ -95,8 +95,8 @@ def _create_model(self) -> None: ) actor_params, critic_params = self.ac.get_params() - self.optimizer_policy = opt.Adam(critic_params, lr=self.lr_policy) - self.optimizer_value = opt.Adam(actor_params, lr=self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) def select_action( self, state: torch.Tensor, deterministic: bool = False diff --git a/genrl/agents/modelbased/mcts/__init__.py b/genrl/agents/modelbased/mcts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/agents/modelbased/mcts/base.py b/genrl/agents/modelbased/mcts/base.py new file mode 100644 index 00000000..28bfda43 --- /dev/null +++ b/genrl/agents/modelbased/mcts/base.py @@ -0,0 +1,117 @@ +import numpy as np + +from genrl.agents.modelbased.base import ModelBasedAgent, Planner + + +class TreePlanner(Planner): + def __init__(self): + self.root = None + self.observations = [] + self.horizon = horizon + self.reset() + + def plan(self, state, obs): + raise NotImplementedError() + + def get_plan(self): + actions = [] + node = self.root + while node.children: + action = node.selection_rule() + actions.append(action) + node = node.children[action] + return actions + + def step(self, state, action): + obs, reward, done, info = state.step(action) + self.observations.append(obs) + return obs, reward, done, info + + def step_tree(self, actions): + if self.strategy == "reset": + self.reset() + elif self.strategy == "subtree": + if actions: + self._step_by_subtree(actions[0]) + else: + self.reset() + else: + raise NotImplementedError + + def _step_by_subtree(self, action): + if action in self.root.children: + self.root = self.root.children[action] + self.root.parent = None + else: + self.reset() + + def get_visits(self): + visits = {} + for obs in self.observations: + if str(obs) not in visits.keys(): + visits[str(obs)] = 0 + visits[str(obs)] += 1 + + def reset(): + raise NotImplementedError + + +class Node: + def __init__(self, parent, planner): + self.parent = parent + self.planner = planner + self.children = {} + + self.visits = 0 + + def get_value(self): + raise NotImplementedError + + def expand(self, branch_factor): + self.children[a] = Node(self, planner) + + def selection_rule(self): + raise NotImplementedError + + def is_leaf(self): + return not self.children + + +class TreeSearchAgent(ModelBasedAgent): + def __init__(self, *args, horizon, **kwargs): + super(TreeSearchAgent, self).__init__(*args, **kwargs) + self.planner = self._make_planner() + self.prev_actions = [] + self.horizon = horizon + self.remaining_horizon = 0 + self.steps = 0 + + def _create_planner(self): + pass + + def plan(self, obs): + self.steps += 1 + replan = self._replan(self.prev_actions) + if replan: + env = self.env + actions = self.planner.plan(state=env, obs=obs) + else: + actions = self.prev_actions[1:] + + self.prev_actions = actions + return actions + + def _replan(self, actions): + replan = self.remaining_horizon == 0 or len(actions) <= 1 + if replan: + self.remaining_horizon = self.horizon + else: + self.remaining_horizon -= 1 + + self.planner.step_tree(actions) + return replan + + def 
reset(self): + self.planner.reset() + self.remaining_horizon = 0 + self.steps = 0 diff --git a/genrl/agents/modelbased/mcts/mcts.py b/genrl/agents/modelbased/mcts/mcts.py new file mode 100644 index 00000000..aba7544a --- /dev/null +++ b/genrl/agents/modelbased/mcts/mcts.py @@ -0,0 +1,170 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as opt + +from genrl.agents.modelbased.mcts.base import TreePlanner, TreeSearchAgent, Node + + +class MCTSAgent(TreeSearchAgent): + def __init__(self, *args, **kwargs): + super(MCTSAgent, self).__init__(*args, **kwargs) + self.planner = self._create_planner() + + def _create_planner(self): + prior_policy = None + rollout_policy = None + return MCTSPlanner(prior_policy, rollout_policy) + + def _create_model(self): + if isinstance(self.network, str): + arch_type = self.network + if self.shared_layers is not None: + arch_type += "s" + self.ac = get_model("ac", arch_type)( + state_dim, + action_dim, + shared_layers=self.shared_layers, + policy_layers=self.policy_layers, + value_layers=self.value_layers, + val_type="V", + discrete=discrete, + action_lim=action_lim, + ).to(self.device) + + actor_params, critic_params = self.ac.get_params() + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) + + def select_action(self, state) -> torch.Tensor: + # return action.detach(), value, dist.log_prob.cpu() + pass + + def get_traj_loss(self) -> None: + pass + + def evaluate_actions(self, states, actions): + # return values, dist.log_prob(actions).cpu(), dist.entropy.cpu() + pass + + def update_params(self) -> None: + # Blah blah + # policy_loss = something + # value_loss = something + pass + + +class MCTSNode(Node): + def __init__(self, *args, prior, **kwargs): + super(MCTSNode, self).__init__(*args, **kwargs) + self.value = 0 + self.prior = prior + + def selection_rule(self): + if not self.children: + return None + actions = list(self.children.keys()) + counts = np.argmax([self.children[a] for a in actions]) + return actions[np.max(counts, key=(lambda i: self.children[actions[i]].get_value()))] + + def sampling_rule(self): + if self.children: + actions = list(self.children.keys()) + idx = [self.children[a].selection_strategy(temp) for a in actions] + return actions[] + return None + + def expand(self, actions_dist): + actions, probs = actions_dist + for i in range(len(actions)): + if actions[i] not in self.children: + self.children[actions[i]] = MCTSNode(parent=self, planner=self.planner, prior=probs[i]) + + def _update(self, total_rew): + self.count += 1 + self.value += 1 / self.count * (total_rew - self.value) + + def update_branch(self, total_rew): + self._update(total_rew) + if self.parent: + self.parent.update_branch(total_rew) + + def get_child(self, action, obs=None): + child = self.children[action] + if obs is not None: + if str(obs) not in child.children: + child.children[str(obs)] = MCTSNode(parent=child, planner=self.planner, prior=0) + child = child.children[str(obs)] + return child + + def selection_strategy(self, temp=0): + if not self.parent: + return self.get_value() + return self.get_value + temp * len(self.parent.children) * self.prior / (self.count-1) + + def get_value(self): + return self.value + + def convert_visits_to_prior(self, reg=0.5): + self.count = 0 + total_count = np.sum([(child.count + 1) for child in self.children]) + for child in self.children.values(): + child.prior = reg * (child.count + 1) / total_counts + reg / 
len(self.children) + child.convert_visits_to_prior() + + +class MCTSPlanner(TreePlanner): + def __init__(self, *args, prior, rollout_policy, episodes, **kwargs): + super(MCTSPlanner, self).__init__(*args, **kwargs) + self.env = env + self.prior = prior + self.rollout_policy = rollout_policy + self.gamma = gamma + self.episodes = episodes + + def reset(self): + self.root = MCTSNode(parent=None, planner=self) + + def _mc_search(self, state, obs): + # Runs one iteration of mcts + node = self.root + total_rew = 0 + depth = 0 + terminal = False + while depth < self.horizon and node.children and not terminal: + action = node.sampling_rule() # Not so sure about this + obs, reward, terminal, _ = self.step(state, action) + total_rew += self.gamma ** depth * reward + node_obs = obs + node = node.get_child(action, node_obs) + depth += 1 + + if not terminal: + total_rew = self.eval(state, obs, total_rew, depth=depth) + node.update_branch(total_rew) + + def eval(self, state, obs, total_rew=0, depth=0): + # Run the rollout policy to yeild a sample for the value + for h in range(depth, self.horizon): + actions, probs = self.rollout_policy(state, obs) + action = None # rew Select an action + obs, reward, terminal, _ = self.step(state, action) + total_ += self.gamma ** h * reward + if np.all(terminal): + break + return total_rew + + def plan(self, obs): + for i in range(self.episodes): + self._mc_search(copy.deepcopy(state), obs) + return self.get_plan() + + def step_planner(seld, action): + if self.step_strategy == "prior": + self._step_by_prior(action) + else: + super().step_planner(action) + + def _step_by_prior(self, action): + self._step_by_subtree(action) + self.root.convert_visits_to_prior() diff --git a/genrl/agents/modelbased/mcts/uct.py b/genrl/agents/modelbased/mcts/uct.py new file mode 100644 index 00000000..46f8fb78 --- /dev/null +++ b/genrl/agents/modelbased/mcts/uct.py @@ -0,0 +1,5 @@ +from genrl.agents.modelbased.mcts.mcts import MCTSNode + + +class UCTNode(MCTSNode): + pass \ No newline at end of file From 1b1bc6102d1b57cc89c20c1eaa8f56d526579569 Mon Sep 17 00:00:00 2001 From: hades-rp2010 Date: Sat, 14 Nov 2020 16:31:46 +0530 Subject: [PATCH 16/16] Adding UCTNode --- genrl/agents/modelbased/mcts/mcts.py | 31 +++++++++++++++++++--------- genrl/agents/modelbased/mcts/uct.py | 13 +++++++++++- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/genrl/agents/modelbased/mcts/mcts.py b/genrl/agents/modelbased/mcts/mcts.py index aba7544a..a803694b 100644 --- a/genrl/agents/modelbased/mcts/mcts.py +++ b/genrl/agents/modelbased/mcts/mcts.py @@ -3,14 +3,14 @@ import torch.nn as nn import torch.optim as opt -from genrl.agents.modelbased.mcts.base import TreePlanner, TreeSearchAgent, Node +from genrl.agents.modelbased.mcts.base import Node, TreePlanner, TreeSearchAgent class MCTSAgent(TreeSearchAgent): def __init__(self, *args, **kwargs): super(MCTSAgent, self).__init__(*args, **kwargs) self.planner = self._create_planner() - + def _create_planner(self): prior_policy = None rollout_policy = None @@ -65,20 +65,25 @@ def selection_rule(self): return None actions = list(self.children.keys()) counts = np.argmax([self.children[a] for a in actions]) - return actions[np.max(counts, key=(lambda i: self.children[actions[i]].get_value()))] + return actions[ + np.max(counts, key=(lambda i: self.children[actions[i]].get_value())) + ] def sampling_rule(self): if self.children: actions = list(self.children.keys()) idx = [self.children[a].selection_strategy(temp) for a in actions] - return actions[] + 
random_idx = np.random.choice(np.argmax(idx)) + return actions[random_idx] return None def expand(self, actions_dist): actions, probs = actions_dist for i in range(len(actions)): if actions[i] not in self.children: - self.children[actions[i]] = MCTSNode(parent=self, planner=self.planner, prior=probs[i]) + self.children[actions[i]] = MCTSNode( + parent=self, planner=self.planner, prior=probs[i] + ) def _update(self, total_rew): self.count += 1 @@ -93,14 +98,18 @@ def get_child(self, action, obs=None): child = self.children[action] if obs is not None: if str(obs) not in child.children: - child.children[str(obs)] = MCTSNode(parent=child, planner=self.planner, prior=0) + child.children[str(obs)] = MCTSNode( + parent=child, planner=self.planner, prior=0 + ) child = child.children[str(obs)] return child - + def selection_strategy(self, temp=0): if not self.parent: return self.get_value() - return self.get_value + temp * len(self.parent.children) * self.prior / (self.count-1) + return self.get_value + temp * len(self.parent.children) * self.prior / ( + self.count - 1 + ) def get_value(self): return self.value @@ -109,7 +118,9 @@ def convert_visits_to_prior(self, reg=0.5): self.count = 0 total_count = np.sum([(child.count + 1) for child in self.children]) for child in self.children.values(): - child.prior = reg * (child.count + 1) / total_counts + reg / len(self.children) + child.prior = reg * (child.count + 1) / total_counts + reg / len( + self.children + ) child.convert_visits_to_prior() @@ -164,7 +175,7 @@ def step_planner(seld, action): self._step_by_prior(action) else: super().step_planner(action) - + def _step_by_prior(self, action): self._step_by_subtree(action) self.root.convert_visits_to_prior() diff --git a/genrl/agents/modelbased/mcts/uct.py b/genrl/agents/modelbased/mcts/uct.py index 46f8fb78..b6d9128d 100644 --- a/genrl/agents/modelbased/mcts/uct.py +++ b/genrl/agents/modelbased/mcts/uct.py @@ -1,5 +1,16 @@ +import numpy as np + from genrl.agents.modelbased.mcts.mcts import MCTSNode class UCTNode(MCTSNode): - pass \ No newline at end of file + def __init__(self, *args, disc_factor, **kwargs): + super(UCTNode, self).__init__(*args, **kwargs) + self.disc_factor = disc_factor + + def selection_strategy(self, temp=0): + if not self.parent: + return self.get_value() + return self.get_value() + temperature * self.prior * np.sqrt( + np.log(self.parent.count) / self.count + )
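The UCT variant added at the end scores each child by its running value estimate plus an exploration bonus that grows with the parent's visit count and shrinks with the child's own visit count, which is the standard UCB1-style rule. A small self-contained sketch of that selection rule, together with the incremental value backup the node classes above perform, is given below; this is generic code, not the genrl classes, and every name in it is illustrative:

    import numpy as np

    class ToyNode:
        """Minimal node carrying just the statistics UCT needs (illustrative only)."""

        def __init__(self, prior=1.0):
            self.value = 0.0    # running mean of returns backed up through this node
            self.count = 0      # visit count
            self.prior = prior
            self.children = {}  # action -> ToyNode

        def update(self, ret):
            # Incremental mean, the same backup idea as MCTSNode._update above.
            self.count += 1
            self.value += (ret - self.value) / self.count

    def uct_select(parent, c=1.4):
        """Pick the child action maximising value + exploration bonus (UCB1)."""
        def score(child):
            if child.count == 0:
                return float("inf")  # always try unvisited children first
            bonus = c * child.prior * np.sqrt(np.log(parent.count) / child.count)
            return child.value + bonus
        return max(parent.children, key=lambda a: score(parent.children[a]))

    # Tiny usage example: action 0 looks better on average, but the
    # exploration bonus can still favour the less-visited action 1.
    root = ToyNode()
    root.children = {0: ToyNode(), 1: ToyNode()}
    for ret in (1.0, 1.0, 0.0):
        root.children[0].update(ret)
        root.update(ret)
    root.children[1].update(0.5)
    root.update(0.5)
    print(uct_select(root))  # prints 1 with these statistics

Here the constant c plays the role that the temp argument plays in selection_strategy, and prior corresponds to the prior weight maintained on each MCTS node.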