-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_hyper_param_search.py
115 lines (88 loc) · 4.06 KB
/
main_hyper_param_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import copy
import glob
import os
import shutil
import time
import warnings
from datetime import datetime
from multiprocessing import Pool

import numpy as np
from tqdm import tqdm

from IO import config
from IO.result import DetectionResult
from core import detector
from core import freq_transform
from core import signal_generator
# Suppress every FutureWarning for this process. This runs at import time, so it
# is already active before any experiment starts.
# NOTE(review): which library emits these warnings is not visible here -- confirm
# the suppression is still wanted.
warnings.simplefilter(action='ignore', category=FutureWarning)
def run(usr_config):
    """Simulate one configuration, score the detector and save the result.

    Steps, in order:
      1. seed NumPy's global RNG with ``usr_config.seed``;
      2. draw an input signal and its labels from ``InputSignalGenerator``;
      3. move the signal to the frequency domain with
         ``freq_transform.transform_all``;
      4. keep columns ``0 .. int(block_size / 2)`` of the result (presumably
         the non-negative-frequency bins of a ``block_size``-point transform --
         TODO confirm);
      5. compute the ROC and scores with ``HarmonicEstimator.get_roc``;
      6. wrap them in a ``DetectionResult`` and save it to
         ``usr_config.result_dir``.

    Args:
        usr_config: Parsed configuration. It must provide ``seed``,
            ``signal``, ``noise``, ``freq_transform_method``, ``detection``
            and ``result_dir``.

    Returns:
        None. The result is persisted by ``DetectionResult.save``.
    """
    np.random.seed(usr_config.seed)

    # Labelled time-domain input for this configuration.
    generator = signal_generator.InputSignalGenerator(usr_config.signal, usr_config.noise)
    samples, labels = generator.get()

    # Only the first output (squared magnitude, per the original naming) is used;
    # the second return value is discarded.
    spectrum, _ = freq_transform.transform_all(
        samples, usr_config.freq_transform_method, usr_config.signal
    )
    n_keep = int(usr_config.signal.block_size / 2) + 1
    half_spectrum = spectrum[:, :n_keep]

    estimator = detector.HarmonicEstimator(usr_config.detection, generator)
    roc, scores = estimator.get_roc(half_spectrum, labels)

    DetectionResult(roc=roc, usr_configs=usr_config, scores=scores).save(usr_config.result_dir)
def experiment(usr_configs_template, hyper_param, results_root='/home/hgeng4/THESIS/results'):
    """Run one hyper-parameter point over a small sweep of test frequencies.

    A private deep copy of ``usr_configs_template`` is specialised with the
    values from ``hyper_param``. ``run`` is then called once per test
    frequency, and every run writes into one result directory derived from the
    hyper-parameters. If that directory already exists and holds exactly one
    ``*tar`` file per test frequency, the point is treated as finished and
    skipped. Otherwise the stale directory is deleted and the point is
    recomputed.

    Args:
        usr_configs_template: Parsed template configuration. It is NOT
            modified.
        hyper_param: One search-space point. It must provide ``block_size``,
            ``num_blocks_avg``, ``phi``, ``noise_level``,
            ``freq_transform_method`` and ``detection_method``.
        results_root: Root under which the per-point result directory is
            created. Defaults to the path that was previously hard-coded.

    Returns:
        None.
    """
    # Work on a copy so the caller's template is never changed, even when this
    # function is called repeatedly in the same process.
    cfg = copy.deepcopy(usr_configs_template)
    signal_cfg = cfg.signal

    signal_cfg.block_size = hyper_param.block_size
    signal_cfg.num_blocks_avg = hyper_param.num_blocks_avg
    signal_cfg.phases = [hyper_param.phi]
    # Hop equals block_size * num_blocks_avg (presumably non-overlapping groups
    # of averaged blocks -- TODO confirm).
    signal_cfg.hop_size = signal_cfg.block_size * signal_cfg.num_blocks_avg
    cfg.noise.init_args.top = hyper_param.noise_level
    cfg.noise.init_args.steady_state = hyper_param.noise_level
    cfg.freq_transform_method.name = hyper_param.freq_transform_method
    cfg.detection.name = hyper_param.detection_method
    cfg.result_dir = os.path.join(
        results_root,
        'Fmethod_{}'.format(hyper_param.freq_transform_method),
        'detection_{}'.format(hyper_param.detection_method),
        'phi_{}'.format(hyper_param.phi),
        'N_{}'.format(hyper_param.block_size),
        'L_{}'.format(hyper_param.num_blocks_avg),
        'inde_noise_level_{}'.format(hyper_param.noise_level),
    )

    # Fractional bin positions k = 3.0, 3.25, ..., 4.0, converted to Hz via f = k / N * fs.
    test_k = np.linspace(3, 4, 5)
    test_f = test_k / signal_cfg.block_size * signal_cfg.fs

    if os.path.exists(cfg.result_dir):
        finished = glob.glob(os.path.join(cfg.result_dir, '*tar'))
        # The expected count must follow the sweep size. The previous hard-coded
        # 101 never matched the 5-point sweep, so completed work was always
        # deleted and recomputed.
        # NOTE(review): assumes each run() leaves exactly one *tar file -- confirm
        # against DetectionResult.save.
        if len(finished) == len(test_f):
            print('[INFO]: Directories completed from previous runs',
                  cfg.result_dir)
            return
        print('[INFO]: Found incomplete directories from previous runs, deleting :',
              cfg.result_dir)
        shutil.rmtree(cfg.result_dir)

    print('[INFO]: Simulating: ', cfg.result_dir)
    for f in test_f:
        signal_cfg.freqs = [f]
        run(cfg)
def multi_run_wrapper(args):
    """Adapt a single pool task object to ``experiment``'s positional signature.

    Pool mapping APIs hand each worker one object per task. Here that object
    is a ``[usr_configs_template, hyper_param]`` pair, which is spread into
    ``experiment``. The return value of ``experiment`` is passed through.
    """
    positional = tuple(args)
    return experiment(*positional)
def run_single_machine(usr_configs_template, search_space, host_name=None, processes=4):
    """Run every search-space point in a local process pool and time the sweep.

    Each point is paired with ``usr_configs_template`` and dispatched through
    ``multi_run_wrapper``. The progress bar advances as individual
    experiments finish. Start and end timestamps and the elapsed wall-clock
    seconds are printed.

    Args:
        usr_configs_template: Template configuration handed to every experiment.
        search_space: Iterable of hyper-parameter points. It is consumed once.
        host_name: Currently unused. Kept for interface compatibility.
        processes: Number of worker processes. Defaults to 4, the previously
            hard-coded value.

    Raises:
        Exception: The first exception raised by any experiment. It is
            re-raised only after every experiment has been attempted, so one
            failing point does not cancel the rest of the sweep.
    """
    total_args = [[usr_configs_template, point] for point in search_space]

    start = time.time()
    print('starting at', datetime.now())
    first_error = None
    # The context manager guarantees the worker processes are cleaned up.
    with Pool(processes=processes) as pool:
        # imap_unordered yields as soon as each task finishes, so the bar moves
        # live. The previous tqdm(pool.map(...)) only started iterating after
        # every task was already done.
        outcomes = pool.imap_unordered(multi_run_wrapper, total_args)
        for _ in tqdm(range(len(total_args))):
            try:
                next(outcomes)
            except Exception as err:
                # Remember the failure but keep draining, so the remaining
                # experiments still run to completion.
                if first_error is None:
                    first_error = err
    end = time.time()
    print('end at', datetime.now())
    print('total time in second:', end - start)

    if first_error is not None:
        raise first_error
if __name__ == '__main__':
    # Script entry point: load the template config and search space, then run
    # the whole sweep on this machine. The FutureWarning filter is installed
    # once at module import time and is not repeated here.
    # NOTE(review): the YAML paths are relative to the current working
    # directory, not to this file.
    usr_configs_template = config.parse_config('./yaml/template.yaml')
    search_space = config.parse_search_space('./yaml/search_space.yaml')
    run_single_machine(usr_configs_template, search_space, host_name=None)