From f12dfac0b03ad8efebdb26bed7b82365d150b7cd Mon Sep 17 00:00:00 2001 From: HokageM Date: Fri, 8 Dec 2023 10:58:41 +0100 Subject: [PATCH] Check if directory exists --- src/irlwpython/MaxEntropyDeepIRL.py | 4 +-- src/irlwpython/MaxEntropyDeepRL.py | 16 ++++++------ src/irlwpython/MaxEntropyIRL.py | 12 ++++----- .../{FigurePrinter.py => OutputHandler.py} | 25 ++++++++++++++++++- 4 files changed, 40 insertions(+), 17 deletions(-) rename src/irlwpython/{FigurePrinter.py => OutputHandler.py} (67%) diff --git a/src/irlwpython/MaxEntropyDeepIRL.py b/src/irlwpython/MaxEntropyDeepIRL.py index c4d9d28..8793966 100644 --- a/src/irlwpython/MaxEntropyDeepIRL.py +++ b/src/irlwpython/MaxEntropyDeepIRL.py @@ -5,7 +5,7 @@ import torch.optim as optim import torch.nn as nn -from irlwpython.FigurePrinter import FigurePrinter +from irlwpython.OutputHandler import OutputHandler class QNetwork(nn.Module): @@ -17,7 +17,7 @@ def __init__(self, input_size, output_size): self.relu2 = nn.ReLU() self.output_layer = nn.Linear(32, output_size) - self.printer = FigurePrinter() + self.output_hand = OutputHandler() def forward(self, state): x = self.fc1(state) diff --git a/src/irlwpython/MaxEntropyDeepRL.py b/src/irlwpython/MaxEntropyDeepRL.py index 9fbbb5e..0c8d9a0 100644 --- a/src/irlwpython/MaxEntropyDeepRL.py +++ b/src/irlwpython/MaxEntropyDeepRL.py @@ -5,7 +5,7 @@ import torch.optim as optim import torch.nn as nn -from irlwpython.FigurePrinter import FigurePrinter +from irlwpython.OutputHandler import OutputHandler class QNetwork(nn.Module): @@ -17,7 +17,7 @@ def __init__(self, input_size, output_size): self.relu2 = nn.ReLU() self.output_layer = nn.Linear(32, output_size) - self.printer = FigurePrinter() + self.output_hand = OutputHandler() def forward(self, state): x = self.fc1(state) @@ -42,7 +42,7 @@ def __init__(self, target, state_dim, action_size, feature_matrix, one_feature, self.gamma = gamma - self.printer = FigurePrinter() + self.output_hand = OutputHandler() def select_action(self, state, epsilon): """ @@ -150,16 +150,16 @@ def train(self, n_states, episodes=30000, max_steps=200, if (episode + 1) % 1000 == 0: score_avg = np.mean(scores) print('{} episode average score is {:.2f}'.format(episode, score_avg)) - self.printer.save_plot_as_png(episode_arr, scores, + self.output_hand.save_plot_as_png(episode_arr, scores, f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png") - self.printer.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png") - self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)), + self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png") + self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)), f"../heatmap/theta_{episode}_deep_RL.png") torch.save(self.q_network.state_dict(), f"../results/maxent_{episodes}_{episode}_network_main.pth") if episode == episodes - 1: - self.printer.save_plot_as_png(episode_arr, scores, + self.output_hand.save_plot_as_png(episode_arr, scores, f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png") torch.save(self.q_network.state_dict(), f"src/irlwpython/results/maxentdeep_{episodes}_q_network_RL.pth") @@ -192,6 +192,6 @@ def test(self, model_path, epsilon=0.01, repeats=100): if episode % 1 == 0: print('{} episode score is {:.2f}'.format(episode, score)) - self.printer.save_plot_as_png(episodes, scores, + self.output_hand.save_plot_as_png(episodes, scores, "src/irlwpython/learning_curves" "/test_maxentropydeep_best_model_RL_results.png") diff --git a/src/irlwpython/MaxEntropyIRL.py b/src/irlwpython/MaxEntropyIRL.py index 58d50d8..117f6fa 100644 --- a/src/irlwpython/MaxEntropyIRL.py +++ b/src/irlwpython/MaxEntropyIRL.py @@ -6,7 +6,7 @@ import numpy as np -from irlwpython.FigurePrinter import FigurePrinter +from irlwpython.OutputHandler import OutputHandler class MaxEntropyIRL: @@ -20,7 +20,7 @@ def __init__(self, target, feature_matrix, one_feature, q_table, q_learning_rate self.gamma = gamma self.n_states = n_states - self.printer = FigurePrinter() + self.output_hand = OutputHandler() def get_feature_matrix(self): """ @@ -133,12 +133,12 @@ def train(self, theta_learning_rate, episode_count=30000): if (episode + 1) % 1000 == 0: score_avg = np.mean(scores) print('{} episode score is {:.2f}'.format(episode, score_avg)) - self.printer.save_plot_as_png(episodes, scores, + self.output_hand.save_plot_as_png(episodes, scores, f"src/irlwpython/learning_curves/" f"maxent_{episode_count}_{episode}_qtable.png") - self.printer.save_heatmap_as_png(learner.reshape((20, 20)), + self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)), f"src/irlwpython/heatmap/learner_{episode}_flat.png") - self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)), + self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)), f"src/irlwpython/heatmap/theta_{episode}_flat.png") np.save(f"src/irlwpython/results/maxent_{episode}_qtable", arr=self.q_table) @@ -172,5 +172,5 @@ def test(self, repeats=100): if episode % 1 == 0: print('{} episode score is {:.2f}'.format(episode, score)) - self.printer.save_plot_as_png(episodes, scores, + self.output_hand.save_plot_as_png(episodes, scores, "src/irlwpython/learning_curves/test_maxentropy_flat.png") diff --git a/src/irlwpython/FigurePrinter.py b/src/irlwpython/OutputHandler.py similarity index 67% rename from src/irlwpython/FigurePrinter.py rename to src/irlwpython/OutputHandler.py index a491204..8557412 100644 --- a/src/irlwpython/FigurePrinter.py +++ b/src/irlwpython/OutputHandler.py @@ -1,7 +1,8 @@ import matplotlib.pyplot as plt +import os -class FigurePrinter: +class OutputHandler: def __int__(self): pass @@ -25,6 +26,11 @@ def save_heatmap_as_png(self, data, output_path, title=None, xlabel="Position", if title: plt.title(title) + target_dir = os.path.basename(output_path) + if not os.path.isdir(target_dir): + print(f"Creating directory {target_dir}") + os.mkdir(target_dir) + plt.savefig(output_path, format='png') plt.close(fig) @@ -48,5 +54,22 @@ def save_plot_as_png(self, x, y, output_path, title=None, xlabel="Episodes", yla if title: plt.title(title) + target_dir = os.path.basename(output_path) + if not os.path.isdir(target_dir): + print(f"Creating directory {target_dir}") + os.mkdir(target_dir) + plt.savefig(output_path, format='png') plt.close(fig) + + def save_network(self, network, output_path): + target_dir = os.path.basename(output_path) + if not os.path.isdir(target_dir): + print(f"Creating directory {target_dir}") + os.mkdir(target_dir) + + def save_qtable(self, qtable, output_path): + target_dir = os.path.basename(output_path) + if not os.path.isdir(target_dir): + print(f"Creating directory {target_dir}") + os.mkdir(target_dir)