-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
72 lines (57 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# import absl.logging
# absl.logging.set_verbosity(absl.logging.ERROR)
import warnings
warnings.filterwarnings('ignore')
import datetime, time
import tensorflow as tf
# tf.compat.v1.disable_eager_execution()
from utils.tools import Logger, get_args_and_cfg
from utils.train import Trainer
from utils.distiller import Distiller
from utils.hp_search import HPSearcher
def _setup_gpu_strategy(args, cfg):
    """Pin TensorFlow to the GPU selected by ``args.cuda`` and build a strategy.

    Only the physical GPU at index ``args.cuda`` is made visible to TensorFlow,
    and memory growth is enabled on it.

    Args:
        args: parsed command-line arguments; ``args.cuda`` indexes the list of
            physical GPUs reported by TensorFlow.
        cfg: configuration mapping; only ``len(cfg['GPU'])`` is read here.

    Returns:
        A ``tf.distribute.MirroredStrategy`` when ``cfg['GPU']`` has more than
        one entry, otherwise ``None``.

    Raises:
        RuntimeError: if ``args.cuda`` does not index an available GPU.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        selected = gpus[args.cuda]
    except IndexError as err:
        raise RuntimeError(
            f"GPU index {args.cuda} requested but {len(gpus)} GPU(s) detected"
        ) from err
    tf.config.experimental.set_visible_devices(selected, 'GPU')
    tf.config.experimental.set_memory_growth(selected, True)
    if len(cfg['GPU']) <= 1:
        return None
    # Once only one physical GPU is visible, TensorFlow exposes it as logical
    # GPU:0, so the physical index args.cuda is not a valid device name here.
    # Query the logical names instead; memory growth is already configured,
    # so initializing the runtime at this point is safe.
    # NOTE(review): only one GPU is ever made visible, so this mirrors over a
    # single device even when cfg['GPU'] lists several — confirm intent.
    devices = [dev.name for dev in tf.config.list_logical_devices('GPU')]
    return tf.distribute.MirroredStrategy(devices=devices)


def _evaluate_or_train(runner, cfg):
    """Evaluate *runner* on its test split if ``cfg['TEST']`` is set, else train it.

    *runner* is a ``Trainer`` or ``Distiller``; both are used here through
    ``ds_test``, ``evaluate(dataset, "test")`` and ``train()``.
    """
    if cfg['TEST']:
        test_loss, test_metr = runner.evaluate(runner.ds_test, "test")
        print(f"Test loss: {test_loss}, Test mIoU: {test_metr}")
    else:
        runner.train()


def main():
    """Script entry point: read the config, set up the GPU, run the chosen job.

    Depending on ``cfg`` this runs one of three jobs:

    - a hyper-parameter search (``HP_SEARCH``);
    - knowledge distillation (``METHOD == 'KD'``);
    - plain training.

    Distillation and plain training either train or only evaluate on the test
    split (``TEST``). The elapsed wall-clock time of the job is printed at the
    end.
    """
    launch_time = time.time()
    # define some variables and read cfg
    args, cfg = get_args_and_cfg()
    # Run tf.functions eagerly for the 'test' run name and for the ISW method.
    if cfg['NAME'] == 'test' or cfg['METHOD'] == 'ISW':
        tf.config.run_functions_eagerly(True)
    strategy = _setup_gpu_strategy(args, cfg)
    # The launch timestamp is embedded in the log file name.
    logger = Logger(f"{cfg['LOG_PATH']}_{cfg['NAME']}_{launch_time}.txt")
    # Restart the clock so the reported duration excludes the setup above.
    start_time = time.time()
    if cfg['SEED']:
        # Sets seeds for base-python, numpy and tf, then asks TF for
        # deterministic op implementations.
        tf.keras.utils.set_random_seed(cfg['SEED'])
        tf.config.experimental.enable_op_determinism()
    if cfg['HP_SEARCH']:
        searcher = HPSearcher(args=args, cfg=cfg, logger=logger, strategy=strategy, trial=None)
        searcher.hp_search()
    else:
        runner_cls = Distiller if cfg['METHOD'] == 'KD' else Trainer
        _evaluate_or_train(runner_cls(cfg, logger, strategy), cfg)
    print(f"--- {(time.time() - start_time):.1f} seconds ---")
# Launch the pipeline only when this file is executed directly, never on import.
if __name__ == "__main__":
    main()