-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
214 lines (178 loc) · 7.36 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from process import load_data
from cae import CVAutoencoder
from train import encoder_path, decoder_path, lstm_path
# ---------------------------------------------------------------------------
# Validation configuration
# ---------------------------------------------------------------------------
startTime = 100.067  # time stamp of the first snapshot in sequence_path
dt = 0.077  # time interval between consecutive snapshots
RE = 160  # Reynolds number of the validation case (raw, not normalized)

# Single snapshot used to test CAE reconstruction on its own.
file_path = "data/40-data-100.375"

# Snapshots fed to the LSTM as its initial input window (any count works).
# These are the 21 consecutive files "validate/180-data-<t>" with
# t = startTime + k * dt printed to three decimals (100.067 ... 101.607).
sequence_path = [
    f"validate/180-data-{startTime + dt * k:.3f}" for k in range(21)
]

# Reynolds-number normalization: mean and population standard deviation of
# Re = {40, 80, 120, 160, 200} (equal number of files per Re).
re_mean, re_std = 120, 56.5685
re_norm = (RE - re_mean) / re_std
def decode_and_plot_comparison(file_path, encoder_path, decoder_path, device="cuda"):
    """Reconstruct one snapshot with the pretrained CAE and plot the result.

    The snapshot is passed through the encoder and then the decoder. For each
    channel ("u", "v", "p") a three-panel figure (input, decoded output,
    absolute error) is saved as ``cae_output_<component>.png`` in the current
    working directory and then shown.

    Args:
        file_path: Snapshot file, read with ``load_data``.
        encoder_path: Path to the saved encoder ``state_dict``.
        decoder_path: Path to the saved decoder ``state_dict``.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
    """
    device = torch.device(device)

    # Load the pretrained autoencoder. map_location lets weights that were
    # saved on a GPU be loaded onto whichever device was requested.
    ae = CVAutoencoder().to(device)
    ae.encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    ae.decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    ae.encoder.eval()
    ae.decoder.eval()

    # Add a batch dimension; expected shape [1, 3, height, width].
    snapshot = load_data(file_path).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        latent_vector = ae.encoder(snapshot)
        decoded_snapshot = ae.decoder(latent_vector)

    # Drop the batch dimension and move to NumPy for plotting.
    input_snapshot = snapshot.squeeze(0).cpu().numpy()
    output_snapshot = decoded_snapshot.squeeze(0).cpu().numpy()
    error_snapshot = np.abs(input_snapshot - output_snapshot)

    for i, component in enumerate(["u", "v", "p"]):
        fig = plt.figure(figsize=(15, 5))
        panels = (
            (input_snapshot[i], "viridis", f"Input: {component}"),
            (output_snapshot[i], "viridis", f"Decoded Output: {component}"),
            (error_snapshot[i], "inferno", f"Absolute Error: {component}"),
        )
        for position, (image, cmap, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image, cmap=cmap)
            plt.colorbar()
            plt.title(title)
        plt.tight_layout()
        # pyplot has no save(); use savefig, and call it before show() because
        # show() may close the figure, leaving only a blank one to save.
        fig.savefig(f"cae_output_{component}.png")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def _plot_prediction_step(ground_truth_field, predicted_field, step):
    """Show ground truth, prediction and absolute error for each channel at one step.

    Args:
        ground_truth_field: Channel-first array, indexed ``[channel, height, width]``.
        predicted_field: Array with the same layout as ``ground_truth_field``.
        step: 1-based prediction step, used in the panel titles.
    """
    error_field = np.abs(ground_truth_field - predicted_field)
    for i, component in enumerate(["u", "v", "p"]):
        fig = plt.figure(figsize=(15, 5))
        panels = (
            (ground_truth_field[i], "viridis", f"Ground Truth: {component} (Step {step})"),
            (predicted_field[i], "viridis", f"Predicted: {component} (Step {step})"),
            (error_field[i], "inferno", f"Absolute Error: {component} (Step {step})"),
        )
        for position, (image, cmap, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image, cmap=cmap)
            plt.colorbar()
            plt.title(title)
        plt.tight_layout()
        plt.show()
        # Release the figure so long rollouts do not accumulate open figures.
        plt.close(fig)


def recursive_validation_with_plots(
    encoder_path,
    decoder_path,
    lstm_path,
    file_paths,
    re_value,
    num_predictions,
    ground_truth_dir,
    output_dir="predicted_images",
    device="cuda",
    plot_after=5,  # Number of recursive predictions to wait before plotting
    start_time=startTime,
    ground_truth_re=None,
):
    """Roll the CAE+LSTM forward recursively and compare each step to ground truth.

    The snapshots in ``file_paths`` are encoded into an initial latent window.
    At every step:
      1. The LSTM predicts the next latent vector from the window.
      2. The decoder turns it into a flow field.
      3. The field is saved as ``<output_dir>/predicted_<step>.npz`` (keys
         ``u_x``, ``u_y``, ``p``).
      4. Its MSE against the matching ground-truth snapshot is printed.
      5. The window slides forward: the oldest latent is dropped and the
         prediction appended.

    Args:
        encoder_path: Path to the saved encoder ``state_dict``.
        decoder_path: Path to the saved decoder ``state_dict``.
        lstm_path: Path to the saved LSTM ``state_dict``.
        file_paths: Initial snapshot files in time order, spaced by the
            module-level ``dt``.
        re_value: Reynolds number fed to the LSTM (callers may pass it normalized).
        num_predictions: Number of recursive prediction steps.
        ground_truth_dir: Directory holding ``<Re>-data-<time>`` snapshot files.
        output_dir: Directory for the ``.npz`` predictions; created if missing.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
        plot_after: Plot every ``plot_after`` steps, and always on the last
            step. ``0`` plots only the last step.
        start_time: Time stamp of the FIRST snapshot in ``file_paths``.
        ground_truth_re: Reynolds-number label used in ground-truth file
            names. Defaults to ``re_value``. Pass the raw Re (e.g. 160) when
            ``re_value`` is normalized.

    Raises:
        ValueError: If ``file_paths`` is empty.
    """
    if not file_paths:
        raise ValueError("file_paths must contain at least one snapshot")

    device = torch.device(device)

    # Load pretrained models. map_location lets GPU-saved weights load on any device.
    ae = CVAutoencoder().to(device)
    ae.encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    ae.decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    ae.encoder.eval()
    ae.decoder.eval()
    lstm = ae.lstm
    lstm.load_state_dict(torch.load(lstm_path, map_location=device))
    lstm.eval()

    # np.savez does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    seq_length = len(file_paths)
    re_label = re_value if ground_truth_re is None else ground_truth_re

    # Encode the initial window under no_grad. Otherwise these latents would
    # carry an autograd graph that stays alive for the whole rollout.
    with torch.no_grad():
        initial_snapshots = torch.stack([load_data(fp) for fp in file_paths]).to(device)
        latents = torch.stack(
            [ae.encoder(snapshot.unsqueeze(0)).squeeze(0) for snapshot in initial_snapshots]
        )  # [seq_length, latent_dim]

    # Same Reynolds number at every timestep of the window: [1, seq_length, 1].
    re_value_tensor = torch.full(
        (1, seq_length, 1), float(re_value), device=device, dtype=torch.float32
    )

    for step in range(1, num_predictions + 1):
        with torch.no_grad():
            next_latent = lstm(latents.unsqueeze(0), re_value_tensor)
            # Keep only the LSTM output at the last timestep: [1, latent_dim].
            predicted_latent = next_latent[:, -1, :]
            predicted_field = ae.decoder(predicted_latent)
        # Drop the batch dimension: [3, height, width].
        predicted_field_np = predicted_field.cpu().numpy()[0]

        np.savez(
            os.path.join(output_dir, f"predicted_{step}.npz"),
            u_x=predicted_field_np[0],
            u_y=predicted_field_np[1],
            p=predicted_field_np[2],
        )

        # The input window ends at start_time + dt * (seq_length - 1), so
        # prediction `step` lies `step` snapshots beyond that.
        target_time = start_time + dt * (seq_length - 1 + step)
        ground_truth_file = os.path.join(
            ground_truth_dir, f"{re_label}-data-{target_time:.3f}"
        )
        ground_truth_field = load_data(ground_truth_file).detach().cpu().numpy()

        mse = np.mean((predicted_field_np - ground_truth_field) ** 2)
        print(f"Step {step}: MSE with ground truth: {mse}")

        # plot_after == 0 would otherwise raise ZeroDivisionError.
        is_periodic_plot = plot_after != 0 and step % plot_after == 0
        if is_periodic_plot or step == num_predictions:
            _plot_prediction_step(ground_truth_field, predicted_field_np, step)

        # Slide the window: drop the oldest latent and append the prediction.
        latents = torch.cat((latents[1:], predicted_latent), dim=0)
# Command-line entry point. It replaces the former "uncomment to run" code.
# Importing this module runs nothing.
#   python validate.py --cae    -> CAE reconstruction of `file_path`
#   python validate.py --lstm   -> recursive CAE+LSTM rollout from `sequence_path`
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Validate the CAE and the CAE+LSTM models."
    )
    parser.add_argument(
        "--cae", action="store_true",
        help="test CAE reconstruction of file_path",
    )
    parser.add_argument(
        "--lstm", action="store_true",
        help="test CAE+LSTM recursive prediction from sequence_path",
    )
    parser.add_argument(
        "--device", default="cuda",
        help="torch device (default: cuda)",
    )
    parser.add_argument(
        "--num-predictions", type=int, default=100,
        help="number of recursive predictions (default: 100)",
    )
    parser.add_argument(
        "--plot-after", type=int, default=10,
        help="plot every N predictions (default: 10)",
    )
    args = parser.parse_args()

    if not (args.cae or args.lstm):
        parser.print_help()

    if args.cae:
        decode_and_plot_comparison(file_path, encoder_path, decoder_path, device=args.device)

    if args.lstm:
        recursive_validation_with_plots(
            encoder_path,
            decoder_path,
            lstm_path,
            file_paths=sequence_path,  # Initial snapshots
            re_value=re_norm,  # Normalized Reynolds number
            num_predictions=args.num_predictions,
            ground_truth_dir="data/",
            output_dir="predicted_snapshots",
            device=args.device,
            plot_after=args.plot_after,
            start_time=startTime,
        )