build_passage_embeddings.py
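"""Encode every passage in an LMDB passage DB with a trained biencoder's
passage encoder and store the resulting embeddings in a faiss index on disk."""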
from argparse import ArgumentParser, Namespace

import faiss
import numpy as np
from torch.nn.parallel import DataParallel
from tqdm import tqdm

from soseki.biencoder.modeling import BiencoderLightningModule
from soseki.passage_db.lmdb_passage_db import LMDBPassageDB
from soseki.utils.data_utils import batch_iter


def main(args: Namespace):
    # load the biencoder
    biencoder = BiencoderLightningModule.load_from_checkpoint(args.biencoder_file, map_location="cpu")
    biencoder.freeze()
    biencoder.question_encoder = None  # to free up memory
    biencoder.passage_encoder.eval()
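
    # device placement: with --device_ids, encode on the first listed GPU and,
    # if several ids are given, split each batch across them via DataParallel;
    # without --device_ids, the encoder stays on CPU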
    if args.device_ids is not None:
        device_ids = args.device_ids
        biencoder.passage_encoder.to(device_ids[0])
        if len(device_ids) > 1:
            biencoder.passage_encoder = DataParallel(biencoder.passage_encoder, device_ids=device_ids)
    else:
        device_ids = []

    # load the passage db
    passage_db = LMDBPassageDB(args.passage_db_file)

    index = None
    is_binary = biencoder.hparams["binary"]

    # iterate over passages in the passage db
    with tqdm(total=len(passage_db)) as pbar:
        for passages in batch_iter(passage_db, batch_size=args.batch_size):
            # get embeddings of the passages
            titles = [passage.title for passage in passages]
            texts = [passage.text for passage in passages]
            encoder_inputs = dict(
                biencoder.tokenization.tokenize_passages(
                    titles,
                    texts,
                    padding=True,
                    truncation="only_second",
                    max_length=args.max_passage_length,
                    return_tensors="pt",
                )
            )
            if device_ids:
                encoder_inputs = {key: tensor.to(device_ids[0]) for key, tensor in encoder_inputs.items()}
            embeddings = biencoder.passage_encoder(encoder_inputs).cpu().numpy()
            if index is None:
                # initialize the faiss index with the dimensionality of the embeddings
                dim_size = embeddings.shape[1]
                if is_binary:
                    index = faiss.IndexBinaryIDMap2(faiss.IndexBinaryFlat(dim_size))
                else:
                    index = faiss.IndexIDMap2(faiss.IndexFlatIP(dim_size))

            # format the embeddings for indexing
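            # binary case: binarize each dimension by sign and pack 8 bits per byte,
            # since faiss binary indexes store dim_size-bit vectors as dim_size / 8 bytes;
            # float case: faiss float indexes expect float32 input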
            if is_binary:
                embeddings = np.where(embeddings < 0, 0, 1)
                embeddings = np.packbits(embeddings).reshape(embeddings.shape[0], -1)
            else:
                embeddings = embeddings.astype(np.float32)

            # add the embeddings and the corresponding passage ids to the faiss index
            passage_ids = np.array([passage.id for passage in passages], dtype=np.int64)
            index.add_with_ids(embeddings, passage_ids)
            pbar.update(len(passages))

    # write the faiss index to file
    if is_binary:
        faiss.write_index_binary(index, args.output_file)
    else:
        faiss.write_index(index, args.output_file)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--biencoder_file", type=str, required=True)
    parser.add_argument("--passage_db_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--max_passage_length", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--device_ids", type=int, nargs="+")
    args = parser.parse_args()
    main(args)
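
# Example invocation (a sketch; the file names below are hypothetical
# placeholders, not paths shipped with the repository):
#
#   python build_passage_embeddings.py \
#       --biencoder_file biencoder.ckpt \
#       --passage_db_file passage_db.lmdb \
#       --output_file passage_embeddings.idx \
#       --device_ids 0 1 \
#       --batch_size 256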