-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathmake_dataset.py
121 lines (88 loc) · 3.96 KB
/
make_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import os
import random
import joblib
def get_arguments():
    """Parse command-line arguments for dataset creation.

    Returns:
        argparse.Namespace with:
            data_dir (list[str] | None): root directory(s) to scan for samples.
            dataset_name (str | None): basename for the output dataset files
                (saved as datasets/<name>_train.pkl and datasets/<name>_valid.pkl).
            train_ratio (float): fraction of files used for training
                (default 0.9); the remainder is validation.
    """
    # NOTE(review): a previous unused `_str_to_bool` helper was removed — no
    # boolean argument is defined by this parser.
    parser = argparse.ArgumentParser(
        description='Find all wav, aiff, and mp3 files and create dataset file.')
    parser.add_argument('--data_dir', type=str, nargs='*',
                        help='Root directory(s) in which to look for samples. '
                             'Samples can be in nested directories.')
    parser.add_argument('--dataset_name', type=str,
                        help='Base name for the saved dataset files '
                             '(datasets/<name>_train.pkl and datasets/<name>_valid.pkl).')
    parser.add_argument('--train_ratio', type=float, default=0.9,
                        help='Percentage of (randomly chosen) files to use for training. '
                             'Remaining ones are validation.')
    return parser.parse_args()
def get_data_subset(track_ids, full_data):
    """Return a dataset dict restricted to the given track ids.

    The result has the same structure as *full_data* ('track_ids',
    'categories', 'audio_paths', 'category_names'): audio paths and every
    per-category mapping are filtered down to *track_ids*, while the
    category-name mapping is passed through unchanged.
    """
    paths_subset = {tid: full_data['audio_paths'][tid] for tid in track_ids}
    categories_subset = {
        cat_key: {tid: per_track[tid] for tid in track_ids}
        for cat_key, per_track in full_data['categories'].items()
    }
    return {
        'track_ids': track_ids,
        'categories': categories_subset,
        'audio_paths': paths_subset,
        'category_names': full_data['category_names'],
    }
def main():
    """Find audio files under the given roots, build a dataset dict and
    save random train/valid splits as joblib pickles under ./datasets."""
    args = get_arguments()

    # Accepted extensions, compared case-insensitively (so '.WAV', '.Wav',
    # etc. all match — supersedes the old explicit upper/lower list).
    audio_extensions = {'.wav', '.aiff', '.aif', '.mp3', '.aac'}

    # Collect every matching audio file path under each root directory.
    audio_files = []
    for root_dir in args.data_dir:
        for dir_name, _subdirs, file_names in os.walk(root_dir, topdown=False):
            for fname in file_names:
                if os.path.splitext(fname)[1].lower() in audio_extensions:
                    # os.path.join instead of '%s/%s' so paths are portable.
                    audio_files.append(os.path.join(dir_name, fname))
    print(f'Total number of samples found: {len(audio_files)}')

    # Build dataset. Each sample gets a unique track id derived from its
    # filename; clashes are resolved by appending 'x' until unique.
    track_ids = []
    seen_ids = set()  # O(1) membership instead of scanning the list (O(n^2))
    audio_paths = dict()
    for sample_path in audio_files:
        track_id = os.path.splitext(os.path.basename(sample_path))[0]
        while track_id in seen_ids:
            track_id += 'x'
        seen_ids.add(track_id)
        audio_paths[track_id] = sample_path
        track_ids.append(track_id)

    # TODO: Add proper support for category data later; for now just pass
    # empty dicts for compatibility.
    dataset = {
        'track_ids': track_ids,
        'categories': dict(),
        'audio_paths': audio_paths,
        'category_names': dict(),
    }

    # Train/valid split: shuffle in place, then the first train_ratio
    # fraction is training and the remainder validation.
    split_index = int(args.train_ratio * len(track_ids))
    random.shuffle(track_ids)
    track_ids_train = track_ids[:split_index]
    track_ids_valid = track_ids[split_index:]
    print(f'Splitting {len(track_ids)} samples into {len(track_ids_train)} '
          f'training and {len(track_ids_valid)} validation samples.')
    dataset_train = get_data_subset(track_ids_train, dataset)
    dataset_valid = get_data_subset(track_ids_valid, dataset)

    print('Saving dataset files.')
    # exist_ok avoids the exists()/makedirs() race of the original.
    os.makedirs('datasets', exist_ok=True)
    joblib.dump(dataset_train, f'datasets/{args.dataset_name}_train.pkl')
    joblib.dump(dataset_valid, f'datasets/{args.dataset_name}_valid.pkl')
    print('Done.')
# Script entry point: run the dataset-building pipeline when executed directly.
if __name__ == '__main__':
    main()