-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata_to_pickle.py
55 lines (47 loc) · 1.6 KB
/
data_to_pickle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# based on ideas from https://github.com/voicy-ai/DialogStateTracking/blob/master/data/data_utils.py
import os
import pickle as pkl
import data_utils
# Input directory: the candidates file and the dialog task files are read
# from here by prepare_data().
DATA_DIR = './data/babi_dialog/'
# Output directory: prepare_data() writes its pickle files here.
P_DATA_DIR = './data/babi_dialog_pkl/'

# Create the output directory at import time. exist_ok=True makes this a
# single call instead of a separate existence check followed by a create,
# which could fail if the directory appeared between the two steps.
os.makedirs(P_DATA_DIR, exist_ok=True)
def prepare_data(task_id, is_oov=False):
    """Load one dialog task via ``data_utils`` and pickle it to ``P_DATA_DIR``.

    Two files are written, both named after ``task_id``:

    * ``<task_id>.data.pkl`` -- a dict with the keys ``'candidates'``,
      ``'train'``, ``'test'`` and ``'val'``.
    * ``<task_id>.metadata.pkl`` -- whatever ``data_utils.build_vocab``
      returns, extended with the ``'candid2idx'`` and ``'idx2candid'``
      mappings.

    When ``is_oov`` is true the file names carry an ``_oov`` suffix after
    the task id (``<task_id>_oov.data.pkl`` and so on).

    Args:
        task_id: Identifier of the task to load; passed through to
            ``data_utils`` and converted with ``str()`` for the file names.
        is_oov: Forwarded to ``data_utils.load_dialog_task`` as ``isOOV``
            and used to choose the output file names.
            NOTE(review): presumably selects the out-of-vocabulary variant
            of the test set -- confirm against ``data_utils``.

    Returns:
        None. The results are only written to disk.
    """
    # get candidates (restaurants)
    candidates, candid2idx, idx2candid = data_utils.load_candidates(
        task_id=task_id,
        candidates_f=os.path.join(DATA_DIR, 'dialog-babi-candidates.txt'))

    # get data
    train, test, val = data_utils.load_dialog_task(
        data_dir=DATA_DIR,
        task_id=task_id,
        candid_dic=candid2idx,
        isOOV=is_oov)

    # get metadata (built from all three splits plus the candidates)
    metadata = data_utils.build_vocab(train + test + val, candidates)

    # Both output files share the same stem; only the OOV marker varies.
    suffix = '_oov' if is_oov else ''
    out_stem = os.path.join(P_DATA_DIR, str(task_id) + suffix)

    # write data to file
    data_ = {
        'candidates': candidates,
        'train': train,
        'test': test,
        'val': val,
    }
    with open(out_stem + '.data.pkl', 'wb') as f:
        pkl.dump(data_, f)

    # save metadata to disk, including the candidate <-> index mappings
    metadata['candid2idx'] = candid2idx
    metadata['idx2candid'] = idx2candid
    with open(out_stem + '.metadata.pkl', 'wb') as f:
        pkl.dump(metadata, f)