Skip to content

Commit

Permalink
helper functions for HPO scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Willian-Girao committed May 3, 2024
1 parent f9153ed commit 40b7091
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
58 changes: 58 additions & 0 deletions tests/test_nonsequential/utils/train_test_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,64 @@ def training_loop(device, nb_time_steps, batch_size, feature_map_size, dataloade

return epochs_x, epochs_y, epochs_acc

def training_loop_no_tqdm(device, nb_time_steps, batch_size, feature_map_size, dataloader_train, model, loss_fn, optimizer, epochs, dataloader_test):
    """
    Train `model` for `epochs` epochs without any progress-bar output and return its test accuracy.

    Arguments
    ---------
        device: device the input/target tensors are moved to.
        nb_time_steps (int): number of time-steps folded into the leading dimension of each input.
        batch_size (int): nominal batch size of the dataloaders. It is only forwarded to
            `test_no_tqdm`; the size of each batch is read from the model output so that a
            final, smaller batch is handled correctly.
        feature_map_size: indexable as (height, width, channels).
        dataloader_train: iterable yielding `(X, y)` training batches.
        model: callable model exposing `detach_neuron_states()`.
        loss_fn: callable `loss_fn(pred, y)` returning a tensor supporting `.backward()`.
        optimizer: optimizer exposing `zero_grad()` and `step()`.
        epochs (int): number of passes over `dataloader_train`.
        dataloader_test: iterable yielding `(X, y)` batches used for the final evaluation.

    Returns
    -------
        Accuracy (in percent) computed by `test_no_tqdm` on `dataloader_test` after training.
    """
    model.train()

    for _ in range(epochs):
        for X, y in dataloader_train:
            # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
            X = X.reshape(-1, feature_map_size[2], feature_map_size[0], feature_map_size[1]).to(dtype=torch.float, device=device)
            y = y.to(dtype=torch.long, device=device)

            # forward
            pred = model(X)

            # reshape the output from [Batch*Time, num_classes] into [Batch, Time, num_classes];
            # the batch dimension is derived from the output itself so a last, incomplete batch
            # (fewer than `batch_size` samples) does not break the reshape
            pred = pred.reshape(pred.shape[0] // nb_time_steps, nb_time_steps, -1)

            # accumulate all time-steps output for final prediction
            pred = pred.sum(dim=1)
            loss = loss_fn(pred, y)

            # gradient update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach the neuron states and activations from current computation graph (necessary)
            model.detach_neuron_states()

    # evaluate the trained model once and report that accuracy
    return test_no_tqdm(device, nb_time_steps, batch_size, feature_map_size, dataloader_test, model)

def test_no_tqdm(device, nb_time_steps, batch_size, feature_map_size, dataloader_test, model):
    """
    Evaluate `model` on `dataloader_test` without any progress-bar output.

    Arguments
    ---------
        device: device the input/target tensors are moved to.
        nb_time_steps (int): number of time-steps folded into the leading dimension of each input.
        batch_size (int): nominal batch size of the dataloader. Kept for interface compatibility;
            the size of each batch is read from the model output so that a final, smaller batch
            is handled correctly.
        feature_map_size: indexable as (height, width, channels).
        dataloader_test: iterable yielding `(X, y)` batches.
        model: callable mapping the reshaped input to per-time-step class scores.

    Returns
    -------
        Classification accuracy over all samples, in percent (float).
    """
    correct_predictions = []

    with torch.no_grad():
        for X, y in dataloader_test:
            # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
            X = X.reshape(-1, feature_map_size[2], feature_map_size[0], feature_map_size[1]).to(dtype=torch.float, device=device)
            y = y.to(dtype=torch.long, device=device)

            # forward
            output = model(X)

            # reshape the output from [Batch*Time, num_classes] into [Batch, Time, num_classes];
            # the batch dimension is derived from the output itself so a last, incomplete batch
            # (fewer than `batch_size` samples) does not break the reshape
            output = output.reshape(output.shape[0] // nb_time_steps, nb_time_steps, -1)

            # accumulate all time-steps output for final prediction
            output = output.sum(dim=1)

            # predicted class = index of the largest accumulated score
            pred = output.argmax(dim=1, keepdim=True)

            # per-sample correctness flags for this batch
            correct_predictions.append(pred.eq(y.view_as(pred)))

    correct_predictions = torch.cat(correct_predictions)
    return correct_predictions.sum().item() / len(correct_predictions) * 100


# This is a helper, not a test case: keep pytest from collecting it when it is
# imported into a test module (its name matches the `test_*` pattern).
test_no_tqdm.__test__ = False

def test(device, nb_time_steps, batch_size, feature_map_size, dataloader_test, model):
correct_predictions = []
with torch.no_grad():
Expand Down
26 changes: 25 additions & 1 deletion tests/test_nonsequential/utils/weight_initialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch.nn as nn
import numpy as np
import statistics

def rescale_method_1(conv_layer: nn.Conv2d, input_pool_kernel: list, lambda_: float = 1):
"""
Expand All @@ -20,6 +21,29 @@ def rescale_method_1(conv_layer: nn.Conv2d, input_pool_kernel: list, lambda_: fl

rescaling_factor = np.mean(rescaling_factors)*lambda_

print(f'recaling factor: {rescaling_factor} (computed using {len(input_pool_kernel)} kernels and lambda {lambda_})')
# print(f'method 1 - recaling factor: {rescaling_factor} (computed using {len(input_pool_kernel)} kernels and lambda {lambda_})')

conv_layer.weight.data /= rescaling_factor

def rescale_method_2(conv_layer: nn.Conv2d, input_pool_kernel: list, lambda_: float = 1):
    """
    Rescale the weights of `conv_layer` in place using `method 2`.

    Each pooling layer feeding `conv_layer` contributes one rescaling factor, the product of
    its two kernel dimensions. The harmonic mean of these factors, multiplied by `lambda_`,
    is the value the convolution weights are divided by.

    Arguments
    ---------
        conv_layer (nn.Conv2d): layer whose `weight` tensor is modified in place.
        input_pool_kernel (list): the kernels of all pooling layers feeding input to `conv_layer`;
            each entry is indexed as `kernel[0]` and `kernel[1]`.
        lambda_ (float): multiplier applied to the computed rescaling factor. If the outputs of
            the pooling are too small the rescaling might lead to vanishing gradients, which can
            be counteracted by adjusting this value.
    """
    # one factor per input pooling layer: the area of its kernel
    pool_areas = [kernel[0] * kernel[1] for kernel in input_pool_kernel]

    divisor = statistics.harmonic_mean(pool_areas) * lambda_

    conv_layer.weight.data /= divisor

0 comments on commit 40b7091

Please sign in to comment.