-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtest_2_Config.py
90 lines (81 loc) · 4.25 KB
/
test_2_Config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from sdmdl.sdmdl.config import Config
from sdmdl.sdmdl.occurrences import Occurrences
from sdmdl.sdmdl.gis import GIS
import yaml
import unittest
import os
class ConfigTestCase(unittest.TestCase):
    """Test cases for the Config handler class.

    Each test gets freshly built Occurrences, GIS and Config handlers rooted
    at ``test_data/root`` next to this file (see ``setUp``).
    """

    def setUp(self):
        """Build the occurrence, GIS and config handlers used by every test."""
        # Normalise to forward slashes so path comparisons against the
        # handlers' attributes behave the same on Windows and POSIX.
        here = os.path.abspath(os.path.dirname(__file__)).replace('\\', '/')
        self.root = here + '/test_data'
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()
        self.ch = Config(self.root + '/root', self.oh, self.gh)

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path``, ignoring the case where it was never created."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def test__init__(self):
        """A new Config keeps its handlers and root and starts with empty defaults."""
        self.assertEqual(self.ch.oh, self.oh)
        self.assertEqual(self.ch.gh, self.gh)
        self.assertEqual(self.ch.root, self.root + '/root')
        self.assertEqual(self.ch.config, [])
        self.assertEqual(self.ch.yml_names, ['data_path', 'occurrence_path', 'result_path', 'occurrences', 'layers',
                                             'random_seed', 'pseudo_freq', 'batchsize', 'epoch', 'model_layers',
                                             'model_dropout', 'verbose'])
        self.assertIsNone(self.ch.data_path)
        self.assertIsNone(self.ch.occ_path)
        self.assertIsNone(self.ch.result_path)
        self.assertIsNone(self.ch.yml)
        self.assertEqual(self.ch.random_seed, 0)
        self.assertEqual(self.ch.pseudo_freq, 0)
        self.assertEqual(self.ch.batchsize, 0)
        self.assertEqual(self.ch.epoch, 0)
        self.assertEqual(self.ch.model_layers, [])
        self.assertEqual(self.ch.model_dropout, [])
        self.assertIsNone(self.ch.verbose)

    def test_search_config(self):
        """search_config finds root/config.yml and raises IOError where there is none."""
        self.ch.search_config()
        self.assertEqual(self.ch.config, self.root + '/root/config.yml')
        # Use a local handler so the fixture's self.ch is not overwritten.
        with self.assertRaises(IOError):
            missing = Config(self.root + '/config', self.oh, self.gh)
            missing.search_config()

    def test_create_yaml(self):
        """create_yaml writes the default settings in the expected key order."""
        self.ch.search_config()
        out_path = self.root + '/root/test_config.yml'
        # Registered before writing so the temp file is removed even when an
        # assertion below fails (previously it leaked into test_data).
        self.addCleanup(self._remove_if_exists, out_path)
        self.ch.config = out_path
        self.ch.create_yaml()
        with open(self.ch.config, 'r') as stream:
            yml = yaml.safe_load(stream)
        # Values are checked by position, i.e. in the order the keys appear
        # in the written file.
        expected = [
            self.root + '/root',
            self.root + '/root/occurrences',
            self.root + '/root/results',
            dict(zip(self.oh.name, self.oh.path)),
            dict(zip(self.gh.names, self.gh.variables)),
            42,
            2000,
            75,
            150,
            [250, 200, 150, 100],
            [0.3, 0.5, 0.3, 0.5],
            True,
        ]
        keys = list(yml)
        for index, want in enumerate(expected):
            with self.subTest(key=keys[index]):
                self.assertEqual(yml[keys[index]], want)

    def test_read_yaml(self):
        """read_yaml loads paths, handler data and model settings from config.yml."""
        self.ch.search_config()
        self.ch.read_yaml()
        self.assertEqual(self.ch.data_path, self.root + '/root')
        self.assertEqual(self.ch.occ_path, self.root + '/root/occurrences')
        self.assertEqual(self.ch.result_path, self.root + '/root/results')
        # NOTE(review): self.ch.oh / self.ch.gh are the same objects as
        # self.oh / self.gh, so the four checks below only detect duplicate
        # names; comparing against literal expected values would be stronger.
        self.assertEqual(self.ch.oh.name, list(dict(zip(self.oh.name, self.oh.path)).keys()))
        self.assertEqual(self.ch.oh.path, list(dict(zip(self.oh.name, self.oh.path)).values()))
        self.assertEqual(self.ch.gh.names, list(dict(zip(self.gh.names, self.gh.variables)).keys()))
        self.assertEqual(self.ch.gh.variables, list(dict(zip(self.gh.names, self.gh.variables)).values()))
        self.assertEqual(self.ch.random_seed, 42)
        self.assertEqual(self.ch.pseudo_freq, 2000)
        self.assertEqual(self.ch.batchsize, 75)
        self.assertEqual(self.ch.epoch, 150)
        self.assertEqual(self.ch.model_layers, [250, 200, 150, 100])
        self.assertEqual(self.ch.model_dropout, [0.3, 0.5, 0.3, 0.5])
        self.assertEqual(self.ch.verbose, True)
if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main()