# evaluate_custom_malconv_by_ablation_on_advmal.py
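"""
Evaluate an ensemble of ablation-trained custom MalConv models on adversarial malware.

Each *.h5 checkpoint in --dir_path holds one model trained on a single ablated
slice of the input bytes. The script runs every model over the samples listed in
the adversarial-sample CSV, majority-votes the per-sample predictions, and prints
both the standard and the certified accuracy of the ensemble.
"""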
import argparse
import os

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

from secml_malware.models.my_dataloader_csv import MyDataSet
from secml_malware.smoothed_malconv import get_majority_voting_without_padding
# Maximum input length in bytes (2**21 = 2 MiB); each ablated model
# receives a slice of inp_len / total_ablations bytes.
inp_len = 2 ** 21


def main(root_dir, train_path, dir_path, dataset_size, total_ablations, batch_size=16, perturb_size=20000):
    # Alternative loading path (disabled): build a fresh Custom_MalConv and wrap it
    # in CClassifierEnd2EndMalware instead of loading full models with torch.load.
    # net = Custom_MalConv(max_input_size=int(inp_len / total_ablations), unfreeze=True)
    # net = CClassifierEnd2EndMalware(net, batch_size=batch_size)
    # net._n_features = int(inp_len / total_ablations)
    # Load every per-ablation model checkpoint (*.h5) found in dir_path.
    nets = []
    ablation_idxs = []
    for f in os.listdir(dir_path):
        if ".h5" not in f:
            continue
        # File names are expected to end in "_<ablation_idx>.h5".
        ablation_idx = int(f.split('_')[-1].split('.')[0])
        model_path = os.path.join(dir_path, f)
        print("Loading model from:", model_path)
        print("Ablation index:", ablation_idx)
        net_model = torch.load(model_path)
        nets.append(net_model)
        ablation_idxs.append(ablation_idx)

    train_preds_all_models = []
    # Evaluate: each model predicts on its own ablation slice of every sample;
    # the ensemble decision is a majority vote over the per-sample predictions.
    for net, ablation_idx in zip(nets, ablation_idxs):
        trainset = MyDataSet(root_dir, train_path, inp_len, ablation_idx, total_ablations, dataset_size)
        train_loader = DataLoader(trainset, shuffle=False, batch_size=batch_size)
        train_preds, lengths_all_train = get_predicts(net, train_loader)
        train_preds_all_models.append(train_preds)

        # Vote over the models processed so far; certified votes must additionally
        # survive any contiguous perturbation of up to perturb_size bytes.
        votes, certified_votes = get_majority_voting_without_padding(
            np.asarray(train_preds_all_models), len(train_preds_all_models[0]),
            lengths_all_train, int(inp_len / total_ablations), perturb_size)
        train_acc = get_acc(votes, train_loader)
        print("Train Accuracy (Standard): ", train_acc)
        cert_train_acc = get_acc(certified_votes, train_loader)
        print("Train Accuracy (Certified): ", cert_train_acc)


def get_predicts(net_model, data_generator):
    """Run net_model over data_generator; return flat arrays of 0/1 predictions and sample lengths."""
    net_model.eval()
    preds_all_samples = []
    lengths_all = []
    with torch.no_grad():
        for local_batch, local_labels, local_lengths in data_generator:
            # Round the raw model outputs to hard 0/1 predictions.
            preds = net_model(local_batch).cpu()
            preds = preds.round().numpy()
            preds_all_samples.extend(preds)
            lengths_all.extend(local_lengths.numpy())
    return np.asarray(preds_all_samples).flatten(), np.asarray(lengths_all).flatten()


def get_acc(votes, data_generator):
    """Accuracy of the per-sample votes against the loader's ground-truth labels.

    Relies on the loader using shuffle=False so votes and labels stay aligned.
    """
    labels = []
    for _local_batch, local_labels, _local_lengths in data_generator:
        labels.extend(local_labels.numpy())
    return accuracy_score(np.asarray(labels), votes)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Evaluate the custom MalConv ensemble, by ablation, on adversarial malware")
    parser.add_argument('--root_dir', metavar='path', required=True, help='root directory of the malware samples')
    parser.add_argument('--adv_csv_path', metavar='path', required=True, help='CSV listing the adversarial samples')
    parser.add_argument('--dir_path', metavar='path', required=True, help='directory containing the trained .h5 models')
    parser.add_argument('--dataset_size', type=int, metavar='dataset_size', required=False, default=-2, help='number of samples to load (default: -2)')
    parser.add_argument('--ablations', type=int, metavar='total_ablations', required=True, help='total number of ablations/models')
    parser.add_argument('--batch_size', type=int, metavar='batch_size', required=True)
    parser.add_argument('--perturb_size', type=int, metavar='perturb_size', required=False, default=20000, help='length of the perturbation in bytes (default: 20000)')
    args = parser.parse_args()
    main(args.root_dir, args.adv_csv_path, args.dir_path, args.dataset_size, args.ablations, args.batch_size, args.perturb_size)
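
# Example invocation (paths are hypothetical, shown only for illustration):
#   python evaluate_custom_malconv_by_ablation_on_advmal.py \
#       --root_dir data/adv_samples/ \
#       --adv_csv_path data/adv_samples.csv \
#       --dir_path trained_models/ \
#       --ablations 8 --batch_size 16 --perturb_size 20000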