-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustomize_get_random_skeletons.py
93 lines (84 loc) · 3.03 KB
/
customize_get_random_skeletons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
import random
# Random-skeleton baseline: for every test story, blank out as many randomly
# chosen ending words as the annotated label sequence marks with 0, and write
# the resulting skeletons to a JSON predictions file.

# Default input (annotated test set) and output (random baseline) locations.
INPUT_PATH = "data/merge_test_skeletons.json"
random_test_skes_json = "sketch_pred_results/random_test_skeletons_predictions.json"

# Placeholder inserted wherever one or more consecutive words are masked out.
BLANK = " __ "


def _random_mask(labels):
    """Return a random 0/1 mask containing as many zeros as ``labels`` does.

    The mask has the same length as ``labels``; the positions of the zeros are
    drawn with ``random.sample`` from the module-level generator, so results
    are reproducible once ``random.seed`` has been called.
    """
    blank_count = sum(1 for label in labels if label == 0)
    mask = [1] * len(labels)
    # Sampling from an explicit index list keeps the random stream identical
    # to the original script.
    for position in random.sample(list(range(len(labels))), blank_count):
        mask[position] = 0
    return mask


def _mask_to_skeleton(words, mask):
    """Render ``words`` as a skeleton string according to ``mask``.

    Words whose mask entry is 1 are kept; each run of consecutive 0 entries
    collapses into a single ``__`` placeholder. ``zip`` stops at the shorter
    of the two sequences, so surplus words or mask entries are ignored.
    """
    pieces = []
    previous_was_blank = False
    for word, keep in zip(words, mask):
        if keep == 1:
            pieces.append(" " + word)
            previous_was_blank = False
        elif not previous_was_blank:
            pieces.append(BLANK)
            previous_was_blank = True
    # Outer whitespace is trimmed; inner spacing (including the double space
    # after a placeholder) is kept exactly as the original script produced it.
    return "".join(pieces).strip()


def _random_prediction(item):
    """Build the output record with random skeletons for one input ``item``."""
    ending = item['ending']
    c_ending = item['c_ending']
    # The label sequence is stored at index 1 and its first entry is skipped.
    labels_raw = item['label_raw'][1][1:]
    labels_cf = item['label_cf'][1][1:]

    # Draw the raw mask before the counterfactual one so the sequence of
    # random numbers consumed per item stays the same as before.
    pred_ske_raw = _mask_to_skeleton(ending.strip().split(), _random_mask(labels_raw))
    pred_ske_cf = _mask_to_skeleton(c_ending.strip().split(), _random_mask(labels_cf))

    # NOTE(review): the skeleton sampled from the original ending is stored
    # under 'counterfactual_skeletons_endings' and vice versa. This mirrors
    # the original script; presumably intentional (a skeleton of one ending is
    # used to generate the other) - confirm against the downstream consumer.
    return {
        'premise': item['premise'],
        'raw_condition': item['raw_condition'],
        'ending': ending,
        'gt_raw_skeletons_ending': item['raw_skeletons_endings'][0],
        'raw_skeletons_endings': [pred_ske_cf],
        'counterfactual_condition': item['counterfactual_condition'],
        'c_ending': c_ending,
        'gt_counterfactual_skeletons_ending': item['counterfactual_skeletons_endings'][0],
        'counterfactual_skeletons_endings': [pred_ske_raw],
    }


def main(input_path=INPUT_PATH, output_path=random_test_skes_json, seed=42):
    """Generate random skeleton predictions and write them as JSON.

    Args:
        input_path: JSON file holding a list of annotated story items.
        output_path: Destination JSON file for the list of prediction records.
        seed: Seed for the module-level random generator.

    Returns:
        The list of prediction records that was written to ``output_path``.
    """
    with open(input_path, "r", encoding="utf-8") as source:
        all_data = json.load(source)

    random.seed(seed)
    all_pred = [_random_prediction(item) for item in all_data]

    # The output is opened only once all predictions exist, so a failure while
    # processing does not leave a truncated result file behind.
    with open(output_path, "w", encoding="utf-8") as sink:
        json.dump(all_pred, sink)
    return all_pred


if __name__ == "__main__":
    main()