-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathcifar10_train.py
34 lines (27 loc) · 1.17 KB
/
cifar10_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from keras_audio.library.cifar10 import Cifar10AudioClassifier
from keras_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found
def load_audio_path_label_pairs(max_allowed_pairs=None):
    """Build ``(audio_path, label)`` pairs from the two GTZAN list files.

    ``download_gtzan_genres_if_not_found('./very_large_data/gtzan')`` is
    called first (presumably it fetches the dataset when it is absent --
    confirm against the helper). Audio paths are then read from
    ``./data/lists/test_songs_gtzan_list.txt`` (one per line, each prefixed
    with ``./very_large_data/``) and integer labels from
    ``./data/lists/test_gt_gtzan_list.txt`` (one per line). The n-th
    non-blank path is paired with the n-th non-blank label.

    Args:
        max_allowed_pairs: Upper bound on the number of pairs returned.
            ``None`` (the default) means no limit.

    Returns:
        A list of ``(audio_path, label)`` tuples, where ``audio_path`` is a
        ``str`` and ``label`` is an ``int``. If the two list files differ in
        length, pairing stops at the shorter one.

    Raises:
        OSError: If either list file cannot be opened.
        ValueError: If a non-blank label line is not a valid integer.
    """
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')

    # Blank lines are skipped so that a stray empty line does not turn into
    # a bogus './very_large_data/' entry and shift every later pairing.
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        audio_paths = [
            './very_large_data/' + line.strip()
            for line in file
            if line.strip()
        ]

    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        # Lazy generator: labels are parsed only as they are consumed, and
        # it must be consumed while the file is still open.
        labels = (int(line) for line in file if line.strip())
        # zip() stops at the shorter input, so surplus label lines can no
        # longer index past the end of audio_paths.
        for audio_path, label in zip(audio_paths, labels):
            if max_allowed_pairs is not None and len(pairs) >= max_allowed_pairs:
                break
            pairs.append((audio_path, label))
    return pairs
def main():
    """Load the audio/label pairs and fit a ``Cifar10AudioClassifier``.

    Prints how many pairs were loaded, then calls ``fit`` with a batch size
    of 8, 100 epochs and ``model_dir_path='./models'`` (presumably the
    directory where the trained model is written -- confirm against the
    classifier's ``fit`` implementation).
    """
    training_pairs = load_audio_path_label_pairs()
    print('loaded: ', len(training_pairs))

    # Training hyper-parameters, gathered in one place.
    fit_options = {
        'model_dir_path': './models',
        'batch_size': 8,
        'epochs': 100,
    }

    trainer = Cifar10AudioClassifier()
    trainer.fit(training_pairs, **fit_options)
# Start training only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()