Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
HokageM committed Dec 31, 2023
1 parent f12dfac commit b0d290e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/irlwpython/MaxEntropyDeepIRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, target, state_dim, action_size, feature_matrix, one_feature,
self.theta_learning_rate = theta_learning_rate
self.theta = theta

self.printer = FigurePrinter()
self.printer = OutputHandler()

def select_action(self, state, epsilon):
"""
Expand Down
13 changes: 7 additions & 6 deletions src/irlwpython/MaxEntropyDeepRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,17 @@ def train(self, n_states, episodes=30000, max_steps=200,
score_avg = np.mean(scores)
print('{} episode average score is {:.2f}'.format(episode, score_avg))
self.output_hand.save_plot_as_png(episode_arr, scores,
f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png")
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png")
f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png")
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)),
f"../heatmap/learner_{episode}_deep_RL.png")
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
f"../heatmap/theta_{episode}_deep_RL.png")
f"../heatmap/theta_{episode}_deep_RL.png")

torch.save(self.q_network.state_dict(), f"../results/maxent_{episodes}_{episode}_network_main.pth")

if episode == episodes - 1:
self.output_hand.save_plot_as_png(episode_arr, scores,
f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png")
f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png")

torch.save(self.q_network.state_dict(), f"src/irlwpython/results/maxentdeep_{episodes}_q_network_RL.pth")

Expand Down Expand Up @@ -193,5 +194,5 @@ def test(self, model_path, epsilon=0.01, repeats=100):
print('{} episode score is {:.2f}'.format(episode, score))

self.output_hand.save_plot_as_png(episodes, scores,
"src/irlwpython/learning_curves"
"/test_maxentropydeep_best_model_RL_results.png")
"src/irlwpython/learning_curves"
"/test_maxentropydeep_best_model_RL_results.png")
10 changes: 5 additions & 5 deletions src/irlwpython/MaxEntropyIRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def train(self, theta_learning_rate, episode_count=30000):
score_avg = np.mean(scores)
print('{} episode score is {:.2f}'.format(episode, score_avg))
self.output_hand.save_plot_as_png(episodes, scores,
f"src/irlwpython/learning_curves/"
f"maxent_{episode_count}_{episode}_qtable.png")
f"src/irlwpython/learning_curves/"
f"maxent_{episode_count}_{episode}_qtable.png")
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)),
f"src/irlwpython/heatmap/learner_{episode}_flat.png")
f"src/irlwpython/heatmap/learner_{episode}_flat.png")
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
f"src/irlwpython/heatmap/theta_{episode}_flat.png")
f"src/irlwpython/heatmap/theta_{episode}_flat.png")

np.save(f"src/irlwpython/results/maxent_{episode}_qtable", arr=self.q_table)

Expand Down Expand Up @@ -173,4 +173,4 @@ def test(self, repeats=100):
print('{} episode score is {:.2f}'.format(episode, score))

self.output_hand.save_plot_as_png(episodes, scores,
"src/irlwpython/learning_curves/test_maxentropy_flat.png")
"src/irlwpython/learning_curves/test_maxentropy_flat.png")

0 comments on commit b0d290e

Please sign in to comment.