-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdense_retrieval.py
98 lines (85 loc) · 3.97 KB
/
dense_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from simcse import SimCSE
from typing import List, Dict, Tuple, Type, Union
import numpy as np
from typing import List, Dict, Tuple, Type, Union
import torch
from torch import Tensor, device
from numpy import ndarray
from tqdm import tqdm
class MySimCSE(SimCSE):
    """SimCSE wrapper with an extra pooling mode and id-based search results.

    Differences from the parent ``SimCSE`` that are visible in this class:

    * ``encode`` also supports ``self.pooler == "all_token_pooling"``: a mean
      over every non-padding token state after the first position.
    * ``search`` returns ``(row_id, score)`` pairs, where ``row_id`` is the
      position of the hit in the index, rather than the matched text.

    NOTE(review): ``self.model``, ``self.tokenizer``, ``self.device``,
    ``self.pooler``, ``self.index``, ``self.is_faiss_index`` and
    ``self.similarity`` come from the parent class (presumably populated by
    its constructor / ``build_index``); confirm against the installed simcse.
    """

    def encode(self, sentence: Union[str, List[str]],
               device: str = None,
               return_numpy: bool = False,
               normalize_to_unit: bool = True,
               keepdim: bool = False,
               batch_size: int = 64,
               max_length: int = 128) -> Union[ndarray, Tensor]:
        """Embed one sentence or a list of sentences with ``self.model``.

        Args:
            sentence: A single string or a list of strings.
            device: Device to run the model on; defaults to ``self.device``.
                ``self.model`` is moved to this device and left there.
            return_numpy: Return an ``ndarray`` instead of a ``Tensor``.
            normalize_to_unit: Divide each embedding by its L2 norm.
            keepdim: For a single string input, keep the leading batch
                dimension of size 1 instead of returning a bare vector.
            batch_size: Number of sentences tokenized and encoded per
                forward pass.
            max_length: Token limit passed to the tokenizer; longer inputs
                are truncated.

        Returns:
            CPU embeddings, one row per input sentence (a single vector for a
            lone string input when ``keepdim`` is False).

        Raises:
            NotImplementedError: If ``self.pooler`` is not ``"cls"``,
                ``"cls_before_pooler"`` or ``"all_token_pooling"``.
        """
        target_device = self.device if device is None else device
        self.model = self.model.to(target_device)

        single_sentence = isinstance(sentence, str)
        if single_sentence:
            sentence = [sentence]

        embedding_list = []
        with torch.no_grad():
            for start in range(0, len(sentence), batch_size):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(target_device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                embeddings = self._pool(outputs, inputs.get("attention_mask"))
                if normalize_to_unit:
                    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
                # Collect on CPU so GPU memory does not grow with the input size.
                embedding_list.append(embeddings.cpu())

        # NOTE: an empty input list leaves embedding_list empty, and torch.cat
        # raises on an empty list (same as before this change).
        embeddings = torch.cat(embedding_list, 0)
        if single_sentence and not keepdim:
            embeddings = embeddings[0]
        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings

    def _pool(self, outputs, attention_mask) -> Tensor:
        """Reduce model outputs to one vector per sentence per ``self.pooler``."""
        if self.pooler == "cls":
            return outputs.pooler_output
        if self.pooler == "cls_before_pooler":
            return outputs.last_hidden_state[:, 0]
        if self.pooler == 'all_token_pooling':
            # Mean over positions 1.. (skipping the first token, presumably
            # [CLS]). Only real tokens are counted: averaging padding positions
            # would dilute shorter sentences and make each embedding depend on
            # the longest sentence in its batch.
            hidden = outputs.last_hidden_state[:, 1:]
            if attention_mask is None:
                return hidden.mean(1)
            mask = attention_mask[:, 1:].unsqueeze(-1).to(hidden.dtype)
            return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
        raise NotImplementedError

    def search(self, queries: Union[str, List[str]],
               device: str = None,
               threshold: float = 0.6,
               top_k: int = 5) -> Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]]:
        """Find the best-matching index rows for one query or a list of queries.

        Args:
            queries: A query string or a list of query strings.
            device: Device used to encode the queries.
            threshold: Minimum similarity score for a hit to be kept.
            top_k: Maximum number of hits returned per query.

        Returns:
            For a single string, a list of ``(row_id, score)`` pairs; for a
            list of strings, one such list per query. Without faiss, hits are
            sorted by descending score; with faiss, they keep the order the
            index returned them in.
        """
        if not self.is_faiss_index:
            if isinstance(queries, list):
                # Forward threshold/top_k so every query honours the caller's
                # settings rather than the defaults.
                return [self.search(query, device, threshold, top_k) for query in queries]

            similarities = self.similarity(queries, self.index["index"], device=device).tolist()
            hits = [(i, s) for i, s in enumerate(similarities) if s >= threshold]
            hits.sort(key=lambda pair: pair[1], reverse=True)
            return hits[:top_k]

        query_vecs = self.encode(queries, device=device, normalize_to_unit=True,
                                 keepdim=True, return_numpy=True)
        distances, ids = self.index["index"].search(query_vecs.astype(np.float32), top_k)
        results = [self._pack_faiss_hits(row_scores, row_ids, threshold)
                   for row_scores, row_ids in zip(distances, ids)]
        return results if isinstance(queries, list) else results[0]

    @staticmethod
    def _pack_faiss_hits(scores, ids, threshold: float) -> List[Tuple[int, float]]:
        """Convert one faiss result row into ``(row_id, score)`` pairs above ``threshold``.

        faiss fills missing slots with id ``-1`` when fewer than ``top_k``
        vectors are available, so those placeholders are dropped. Values are
        converted to Python ``int``/``float`` to match the non-faiss path.
        """
        return [(int(i), float(s)) for i, s in zip(ids, scores)
                if i >= 0 and s >= threshold]