-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.py
168 lines (149 loc) · 7.59 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import argparse
from collections import Counter, OrderedDict
import torch
import db2 as data
import model_utils as mutils
import se_search
from cky import movesfromtree
def get_thresh_stuff(mindist, threshes):
    """Return the settings of the first threshold bucket that covers ``mindist``.

    Args:
        mindist: neighbor distance of one example; compared with ``<=``
            against each bucket bound.
        threshes: ordered mapping ``{distance_bound: settings}``. Buckets are
            tried in iteration order, so bounds should be inserted in
            ascending order for the "tightest bucket wins" behavior.

    Returns:
        The settings value of the first bucket whose bound is >= ``mindist``
        (in this script, an ``(nmoves, nne, K)`` tuple).

    Raises:
        ValueError: if ``mindist`` is larger than every bound, i.e. the
            thresholds do not cover this example. (Previously an
            ``assert False``, which is silently skipped under ``python -O``.)
    """
    for bound, vals in threshes.items():
        if mindist <= bound:
            return vals
    raise ValueError(
        "no -ne_thresh bucket covers neighbor distance %r (bounds: %r)"
        % (mindist, list(threshes))
    )
def _restrict_neighbors(db, restrict_fn, max_nes):
    """Replace each example's neighbor list with at most ``max_nes`` restricted neighbors."""
    for ii in range(len(db.val_neidxs)):
        restricted_nes = restrict_fn(db.val_neidxs[ii], db, ii)
        # An empty restriction leaves the original neighbor list in place.
        if len(restricted_nes) > 0:
            db.val_neidxs[ii] = restricted_nes[:max_nes]
        if len(db.val_neidxs[ii]) < max_nes:
            print("not enough restricted neighbors for", ii, len(db.val_neidxs[ii]))


def _parse_thresholds(flat):
    """Group a flat ``-ne_thresh`` list into ``{dist_bound: (nmoves, nne, K)}``."""
    if len(flat) % 4 != 0:
        raise ValueError(
            "-ne_thresh expects groups of 4 values (dist_bound nmoves nne K), got %d values"
            % len(flat)
        )
    threshes = OrderedDict()
    for i in range(0, len(flat), 4):
        threshes[flat[i]] = (int(flat[i + 1]), int(flat[i + 2]), int(flat[i + 3]))
    return threshes


def _decode_canvas(canvas, tokenizer):
    """Turn a hypothesis canvas (a sequence of tokens) into an output string."""
    if tokenizer is None:
        return " ".join(canvas)
    # Tokens are assumed to mark word ends with '</w>'; should be equiv to
    # the tokenizer's own decode.
    return "".join(tok.replace("</w>", " ") for tok in canvas).strip()


def batched_pred(db, model, device, args, restrict_fn=None):
    """Run search over a range of validation examples and write predictions.

    One line is written to ``args.out_fi`` per example, formatted as
    ``<prediction>|||<score>`` (score with 6 decimals), followed by
    ``|||<trace>`` when ``args.get_trace`` is set.

    The example range is ``[args.debug, args.debug + 1)`` if ``args.debug >= 0``,
    else ``args.startend`` if given, else ``[0, min(args.max_preds, num_examples))``.

    Args:
        db: validation data; reads ``val_neidxs``, ``val_ne_dists``, ``protes``,
            ``val_srcs`` and ``tokenizer``. ``db.val_neidxs`` is modified in
            place by restriction and by ``-ne_thresh`` truncation.
        model: put in eval mode and handed to the search function.
        device: handed to the search function.
        args: options from the module-level ``parser``.
        restrict_fn: optional ``(neighbor_list, db, example_idx) -> list``
            used to filter each example's neighbors before searching.

    Raises:
        RuntimeError: if search returns no finished hypothesis for an example.
        ValueError: if ``-ne_thresh`` is malformed or does not cover an example.
    """
    if restrict_fn is not None:
        _restrict_neighbors(db, restrict_fn, args.restrict_nes)
    nelists, nedists, nprotes = db.val_neidxs, db.val_ne_dists, len(db.protes)
    model.eval()
    num_preds = len(db.val_srcs)
    print("num_preds", num_preds, len(nelists))
    print("nprotes", nprotes)

    if args.debug >= 0:
        start, end = args.debug, args.debug + 1
    elif args.startend:
        start, end = args.startend
    else:
        start, end = 0, min(args.max_preds, num_preds)

    if args.nusearch:
        # Previously referenced without ever being imported (NameError);
        # imported lazily so the default path does not need this module.
        import se_search0
        search = se_search0.se_search
    else:
        search = se_search.se_search

    threshes = _parse_thresholds(args.ne_thresh) if args.ne_thresh else None

    with open(args.out_fi, "w") as f:
        for ii in range(start, end, args.bsz):
            batchidxs = list(range(ii, min(ii + args.bsz, end)))
            nelist, neoffs = [], [0]  # neoffs[b] is the index of example b's first neighbor
            for idx in batchidxs:
                if threshes is not None:
                    nmoves, threshnne, K = get_thresh_stuff(nedists[idx], threshes)
                    # Keep nprotes + nne entries; presumably the prototypes come
                    # first in each list -- TODO confirm against db2.
                    nelists[idx] = nelists[idx][:nprotes + threshnne]
                else:
                    nmoves, K = args.max_moves, args.K
                nelist.extend(nelists[idx])
                neoffs.append(neoffs[-1] + len(nelists[idx]))
            # NOTE(review): the whole batch is searched with the nmoves/K of its
            # LAST example; with -ne_thresh and bsz > 1 that may not match the others.

            # ii advances by bsz, so check every index in the batch; testing
            # only ii + 1 would skip most multiples of log_interval.
            logged = [idx + 1 for idx in batchidxs if (idx + 1) % args.log_interval == 0]
            if logged:
                print("predicting line", logged[-1])

            batch_fin_hyps = search(batchidxs, nelist, neoffs, K, model, db, device,
                                    max_moves=nmoves, min_len=args.min_len,
                                    max_canvlen=args.max_canvlen,
                                    len_avg=(not args.no_len_avg),
                                    leftright=args.leftright, only_copy=args.only_copy)
            for b, exidx in enumerate(batchidxs):
                fin_hyps = batch_fin_hyps[b]
                if len(fin_hyps) == 0:
                    # Previously print + assert False, which is skipped under -O.
                    raise RuntimeError(
                        "search returned no finished hypotheses for example %d" % exidx)
                scores = torch.Tensor([hyp.score for hyp in fin_hyps])
                best = fin_hyps[scores.argmax().item()]
                f.write(_decode_canvas(best.canvas, db.tokenizer))
                f.write("|||")
                f.write("{:.6f}".format(best.score))
                if args.get_trace:
                    trace = best.get_moves(db.val_srcs[exidx], nelist, db)
                    tracestr = " ".join("=>(%s) %s" % (srcc, " ".join(thing))
                                        for srcc, thing in trace)
                    f.write("|||")
                    f.write(tracestr)
                f.write("\n")
# Command-line options for batched search-based prediction.
parser = argparse.ArgumentParser(description='Batched search-based prediction with a trained model.')
parser.add_argument('-data', type=str, default="data/wb", help='datadir')
parser.add_argument("-tokfi", default=None, type=str, help="")
parser.add_argument("-split_dashes", action='store_true', help="")
parser.add_argument('-val_src_fi', type=str, default=None, help='if diff than in data/')
parser.add_argument('-get_trace', action='store_true',
                    help='also write the move trace of each prediction (after "|||")')
parser.add_argument('-leftright', action='store_true', help='')
parser.add_argument('-nne', type=int, default=200, help='only used to load stuff')
parser.add_argument("-prote_fi", default="", type=str, help="")
parser.add_argument('-bsz', type=int, default=4, help='batch size')
parser.add_argument('-seed', type=int, default=1111, help='random seed')
parser.add_argument('-cuda', action='store_true', help='use CUDA')
parser.add_argument('-log_interval', type=int, default=200, help='report interval')
parser.add_argument('-train_from', type=str, default='', help='checkpoint to load (required)')
parser.add_argument('-K', type=int, default=1, help='beam size')
# 98% of wb < 23; 98% of iw < 32
parser.add_argument('-max_moves', type=int, default=23,
                    help='max search moves per example (ignored when -ne_thresh is given)')
parser.add_argument('-max_canvlen', type=int, default=200, help='')
parser.add_argument('-out_fi', type=str, default="preds.out", help='file to write predictions to')
parser.add_argument('-val_nefi', type=str, default="val-nes.txt", help='')
parser.add_argument('-mtgt_fi', type=str, default="masked-train-tgt.txt", help='')
parser.add_argument('-no_len_avg', action='store_true',
                    help='pass len_avg=False to the search')
parser.add_argument('-min_len', type=int, default=0, help='INCLUDES <tgt>, </tgt> (so maybe add 2)')
parser.add_argument('-ne_thresh', nargs='+', type=float, default=None,
                    help='repeated groups of 4 values: dist_bound nmoves nne K; each example '
                         'uses the first group whose dist_bound >= its neighbor distance '
                         '(list groups in ascending dist_bound)')
parser.add_argument('-max_preds', type=int, default=4000, help='max number of examples to predict')
parser.add_argument('-only_copy', action='store_true', help='')
parser.add_argument('-restrict_mode', default=None, choices=['yes', 'no', None], type=str)
parser.add_argument('-restrict_nes', type=int, default=5,
                    help='max neighbors kept per example after restriction')
parser.add_argument('-debug', type=int, default=-1,
                    help='predict only this example index (-2 skips prediction)')
# Exactly two values: batched_pred unpacks them as (start, end).
parser.add_argument('-startend', nargs=2, type=int, default=None,
                    help='predict examples with start <= index < end')
parser.add_argument('-nusearch', action='store_true', help='use se_search0 instead of se_search')
if __name__ == "__main__":
    import inspect  # used only for the torch.load compatibility check below

    args = parser.parse_args()
    args.arbl = False
    args.min_nes = 1  # not used, just for backward compatibility
    print(args)
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with -cuda")
    device = torch.device("cuda" if args.cuda else "cpu")
    torch.manual_seed(args.seed)
    # CLI validation via parser.error rather than assert (asserts vanish under -O).
    if not args.train_from:
        parser.error("-train_from is required")

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    load_kwargs = {"map_location": device}
    # The checkpoint also holds an options object under "opt" (presumably an
    # argparse Namespace), which newer torch releases refuse to unpickle with
    # their weights_only=True default. Older releases lack the keyword entirely.
    # NOTE: this unpickles arbitrary objects -- only load trusted checkpoints.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    saved_stuff = torch.load(args.train_from, **load_kwargs)

    saved_args = saved_stuff["opt"]
    args.enclose, args.sel_firstlast_idxing = saved_args.enclose, saved_args.sel_firstlast_idxing
    args.vocopts = saved_args.vocopts
    print("enclose", args.enclose, "firstlast", args.sel_firstlast_idxing)
    db = data.ValDB(args)
    assert db.sel_firstlast_idxing
    mod_ctor = mutils.BartThing
    # Older checkpoints predate the leftright option; fall back to the CLI value.
    if not hasattr(saved_args, "leftright"):
        saved_args.leftright = args.leftright
    model = mod_ctor(len(db.d), db.d.gen_voc_size, saved_args)
    model.load_state_dict(saved_stuff["sd"])
    model = model.to(device)

    if args.restrict_mode:
        import eval_utils as eutils
        yes = args.restrict_mode == "yes"
        if "e2e" not in args.data:
            parser.error("-restrict_mode requires an e2e -data directory")
        restrictor = eutils.E2ERestrictor(args, multi_sentence=True, yes=yes)
        restrict_fn = restrictor.restrict
        prote_keepers = restrict_fn(db.protes, db, 0)
        print("nprote_keepers", len(prote_keepers))
        assert prote_keepers or not args.prote_fi
    else:
        restrict_fn = None

    # -debug -2 loads everything but skips prediction.
    if args.debug != -2:
        with torch.no_grad():
            batched_pred(db, model, device, args, restrict_fn=restrict_fn)