
Commit a7bb287

Run pre-commit
noahfarr committed Nov 17, 2024
1 parent 30d8fb3 commit a7bb287
Showing 1 changed file with 16 additions and 47 deletions.
cleanrl/crossq_continuous_action.py (63 changes: 16 additions & 47 deletions)

@@ -188,13 +188,9 @@ def thunk():
 class SoftQNetwork(nn.Module):
     def __init__(self, env):
         super().__init__()
-        self.bn1 = BatchRenorm1d(
-            np.array(env.single_observation_space.shape).prod()
-            + np.prod(env.single_action_space.shape)
-        )
+        self.bn1 = BatchRenorm1d(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape))
         self.fc1 = nn.Linear(
-            np.array(env.single_observation_space.shape).prod()
-            + np.prod(env.single_action_space.shape),
+            np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape),
             2048,
         )
         self.bn2 = BatchRenorm1d(2048)
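
For context on the layer being constructed here: CrossQ drops target networks and instead stabilizes the critics with Batch Renormalization. A minimal sketch of the idea behind a BatchRenorm1d layer, following the standard Ioffe (2017) formulation; the constructor arguments and the r_max/d_max values are illustrative, and the file's actual implementation may differ in detail:

    import torch
    import torch.nn as nn

    class BatchRenorm1dSketch(nn.Module):
        """Batch Renormalization: BatchNorm whose train-time activations are
        corrected toward the running statistics, shrinking the train/eval gap."""

        def __init__(self, num_features, momentum=0.01, eps=1e-5, r_max=3.0, d_max=5.0):
            super().__init__()
            self.register_buffer("running_mean", torch.zeros(num_features))
            self.register_buffer("running_var", torch.ones(num_features))
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
            self.momentum, self.eps = momentum, eps
            self.r_max, self.d_max = r_max, d_max

        def forward(self, x):
            if self.training:
                mean = x.mean(dim=0)
                var = x.var(dim=0, unbiased=False)
                std = (var + self.eps).sqrt()
                running_std = (self.running_var + self.eps).sqrt()
                # r and d pull batch statistics toward the running ones; both are
                # treated as constants by the gradient (detach).
                r = (std / running_std).detach().clamp(1.0 / self.r_max, self.r_max)
                d = ((mean - self.running_mean) / running_std).detach().clamp(-self.d_max, self.d_max)
                x_hat = (x - mean) / std * r + d
                self.running_mean += self.momentum * (mean.detach() - self.running_mean)
                self.running_var += self.momentum * (var.detach() - self.running_var)
            else:
                x_hat = (x - self.running_mean) / (self.running_var + self.eps).sqrt()
            return self.weight * x_hat + self.bias
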
@@ -252,9 +248,7 @@ def forward(self, x):
         mean = self.fc_mean(x)
         log_std = self.fc_logstd(x)
         log_std = torch.tanh(log_std)
-        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (
-            log_std + 1
-        )  # From SpinUp / Denis Yarats
+        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
         return mean, log_std

     def get_action(self, x, train=False):
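
The rescaled line in this hunk is the usual squashing trick: tanh bounds the raw output in (-1, 1), and the affine map stretches that interval onto [LOG_STD_MIN, LOG_STD_MAX]. A standalone check, with illustrative bound values (the file defines its own constants):

    import torch

    LOG_STD_MIN, LOG_STD_MAX = -5, 2  # illustrative bounds

    raw = torch.linspace(-10.0, 10.0, 5)
    log_std = torch.tanh(raw)  # squash into (-1, 1)
    log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
    print(log_std)  # every value lies in [-5, 2]; raw 0 maps to the midpoint -1.5
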
Expand Down Expand Up @@ -303,8 +297,7 @@ def get_action(self, x, train=False):
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s"
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
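
The collapsed string in the hunk above builds a small markdown table that TensorBoard renders in its Text tab. For example, with a hypothetical stand-in for vars(args):

    params = {"env_id": "Hopper-v4", "seed": 1}  # hypothetical values
    table = "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in params.items()]))
    print(table)
    # |param|value|
    # |-|-|
    # |env_id|Hopper-v4|
    # |seed|1|
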
@@ -316,12 +309,8 @@ def get_action(self, x, train=False):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

     # env setup
-    envs = gym.vector.SyncVectorEnv(
-        [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]
-    )
-    assert isinstance(
-        envs.single_action_space, gym.spaces.Box
-    ), "only continuous action space is supported"
+    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
+    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

     max_action = float(envs.single_action_space.high[0])

@@ -337,9 +326,7 @@ def get_action(self, x, train=False):

     # Automatic entropy tuning
     if args.autotune:
-        target_entropy = -torch.prod(
-            torch.Tensor(envs.single_action_space.shape).to(device)
-        ).item()
+        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
         alpha = log_alpha.exp().item()
         a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
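
Here target_entropy is set to -dim(A), the SAC default, and alpha is then adapted during training. The update itself sits in an elided part of the diff; a sketch of the standard SAC-style temperature step, reusing the names defined above (the exact call signature in the file may differ):

    with torch.no_grad():
        _, log_pi, _ = actor.get_action(data.observations)
    # Gradient pushes alpha up when policy entropy is below target, down otherwise.
    alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()
    a_optimizer.zero_grad()
    alpha_loss.backward()
    a_optimizer.step()
    alpha = log_alpha.exp().item()
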
@@ -361,9 +348,7 @@ def get_action(self, x, train=False):
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         if global_step < args.learning_starts:
-            actions = np.array(
-                [envs.single_action_space.sample() for _ in range(envs.num_envs)]
-            )
+            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
         else:
             actions, _, _ = actor.get_action(torch.Tensor(obs).to(device), train=False)
             actions = actions.detach().cpu().numpy()
@@ -374,15 +359,9 @@ def get_action(self, x, train=False):
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         if "final_info" in infos:
             for info in infos["final_info"]:
-                print(
-                    f"global_step={global_step}, episodic_return={info['episode']['r']}"
-                )
-                writer.add_scalar(
-                    "charts/episodic_return", info["episode"]["r"], global_step
-                )
-                writer.add_scalar(
-                    "charts/episodic_length", info["episode"]["l"], global_step
-                )
+                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
+                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
+                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                 break

         # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
@@ -399,9 +378,7 @@ def get_action(self, x, train=False):
         if global_step > args.learning_starts:
             data = rb.sample(args.batch_size)
             with torch.no_grad():
-                next_state_actions, next_state_log_pi, _ = actor.get_action(
-                    data.next_observations, train=False
-                )
+                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations, train=False)

             cat_obs = torch.cat((data.observations, data.next_observations), dim=0)
             cat_actions = torch.cat((data.actions, next_state_actions), dim=0)
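
These two cat lines are the heart of the CrossQ trick: current and next (state, action) pairs are concatenated into one batch so each critic's batch-renorm layers see both distributions in a single forward pass, which is what lets the algorithm drop target networks. Roughly, the elided lines that follow do something like the sketch below (it assumes the critics take a train flag like the actor does):

    qf1_values = qf1(cat_obs, cat_actions, train=True)  # one pass over 2 * batch_size rows
    qf2_values = qf2(cat_obs, cat_actions, train=True)
    qf1_value, qf1_next = torch.chunk(qf1_values, 2, dim=0)  # first half: Q(s, a)
    qf2_value, qf2_next = torch.chunk(qf2_values, 2, dim=0)  # second half: Q(s', a')
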
@@ -419,9 +396,7 @@ def get_action(self, x, train=False):
             qf1_next = qf1_next.detach()
             qf2_next = qf2_next.detach()
             min_qf_next = torch.min(qf1_next, qf2_next) - alpha * next_state_log_pi
-            next_q_value = data.rewards.flatten() + (
-                1 - data.dones.flatten()
-            ) * args.gamma * (min_qf_next).view(-1)
+            next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next).view(-1)

             qf1_loss = F.mse_loss(qf1_value, next_q_value)
             qf2_loss = F.mse_loss(qf2_value, next_q_value)
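
In equation form, the target assembled in this hunk is the entropy-regularized Bellman backup, computed without a target network since qf1_next and qf2_next come from the same critics (detached above):

    y = r + gamma * (1 - done) * (min(Q1(s', a'), Q2(s', a')) - alpha * log pi(a' | s'))

and each critic then regresses to y with an MSE loss.
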
@@ -456,12 +431,8 @@ def get_action(self, x, train=False):
                     alpha = log_alpha.exp().item()

             if global_step % 100 == 0:
-                writer.add_scalar(
-                    "losses/qf1_values", qf1_value.mean().item(), global_step
-                )
-                writer.add_scalar(
-                    "losses/qf2_values", qf2_value.mean().item(), global_step
-                )
+                writer.add_scalar("losses/qf1_values", qf1_value.mean().item(), global_step)
+                writer.add_scalar("losses/qf2_values", qf2_value.mean().item(), global_step)
                 writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                 writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                 writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
@@ -474,9 +445,7 @@ def get_action(self, x, train=False):
                     global_step,
                 )
                 if args.autotune:
-                    writer.add_scalar(
-                        "losses/alpha_loss", alpha_loss.item(), global_step
-                    )
+                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

     envs.close()
     writer.close()
