-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconsensus_scoring.py
83 lines (65 loc) · 2.85 KB
/
consensus_scoring.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import os
from typing import Dict, List
import torch
from omegaconf import OmegaConf
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm
def consensus_scoring(all_candidates: List[str]) -> List[float]:
    """
    Compute a CIDEr-like consensus score for each candidate caption.

    Each caption is embedded as a TF-IDF vector over word unigrams and bigrams.
    TfidfVectorizer's default ``norm="l2"`` normalises every non-empty row, so
    the dot product of two rows is their cosine similarity. A caption's
    consensus score is its mean cosine similarity to every *other* caption in
    the list: captions that agree with the rest of the pool score higher.

    Parameters:
    - all_candidates (List[str]): A list of candidate captions.

    Returns:
    - List[float]: One consensus score per candidate, in input order. An empty
      input yields an empty list. A single candidate has no peers to agree
      with and scores 0.0.
    """
    num_candidates = len(all_candidates)
    # Guard the degenerate cases: TfidfVectorizer raises on an empty corpus, and
    # with one candidate the (n - 1) denominator below would be zero (NaN score,
    # which also isn't valid strict JSON).
    if num_candidates == 0:
        return []
    if num_candidates == 1:
        return [0.0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vectorizer = TfidfVectorizer(ngram_range=(1, 2))
    tf_idf_matrix = vectorizer.fit_transform(all_candidates)
    tf_idf_tensor = torch.as_tensor(tf_idf_matrix.toarray(), dtype=torch.float64, device=device)

    # Pairwise cosine similarities (rows are already L2-normalised).
    similarity_scores = tf_idf_tensor @ tf_idf_tensor.T

    # Remove each caption's similarity to itself. Use the real diagonal rather
    # than a constant 1: a caption with no in-vocabulary tokens (e.g. only
    # single-character words, which the default token_pattern ignores) has an
    # all-zero row whose self-similarity is 0, not 1.
    similarity_to_others = similarity_scores.sum(dim=-1) - similarity_scores.diagonal()
    consensus_scores = similarity_to_others / (num_candidates - 1)
    return consensus_scores.cpu().tolist()
def bad_format_filter(caption_list: List[str]):
    """
    Drop captions that look malformed, falling back to the full list.

    A caption is kept only if it has more than five whitespace-separated
    words, at most two commas, and at most one period. If no caption
    survives, the input list itself is returned unchanged.

    Parameters:
    - caption_list (List[str]): List of captions to filter.

    Returns:
    - List[str]: The captions that pass, or ``caption_list`` when none do.
    """

    def _is_well_formed(text: str) -> bool:
        # Presumably rejects too-short captions, comma-heavy lists and
        # multi-sentence outputs.
        word_count = len(text.split())
        return word_count > 5 and text.count(",") <= 2 and text.count(".") <= 1

    kept = list(filter(_is_well_formed, caption_list))
    if not kept:
        return caption_list
    return kept
def itm_filter(scores_dict: Dict[str, List]) -> List[str]:
    """
    Filter captions based on ITM scores, keeping the best-scoring half.

    Parameters:
    - scores_dict (Dict[str, List]): Dictionary with keys "captions" (caption
      strings) and "scores" (one ITM score per caption, in the same order;
      presumably floats -- confirm against the producer of the JSON).

    Returns:
    - List[str]: The top ``len(scores) // 2`` captions, ordered from highest to
      lowest score (ties keep their input order). At least one caption is kept
      whenever the input is non-empty, so a lone caption is not filtered away.
    """
    captions = scores_dict["captions"]
    scores = scores_dict["scores"]
    num_scores = len(scores)
    # Half the captions, rounded down -- but never zero for a non-empty input:
    # plain `num_scores // 2` drops a single caption and leaves nothing for
    # downstream scoring.
    keep = max(1, num_scores // 2) if num_scores else 0
    # sorted() is stable even with reverse=True, so equal scores keep input order.
    top_idx = sorted(range(num_scores), key=scores.__getitem__, reverse=True)[:keep]
    return [captions[i] for i in top_idx]
if __name__ == "__main__":
    # Load configs; cfg.DIR.Score is the directory the results are written to.
    cfg = OmegaConf.load("configs.yaml")

    print("Loading blip2 ITM scores...")
    # NOTE(review): the input path is hard-coded while the output directory
    # comes from the config -- confirm whether it should use cfg.DIR.Score too.
    with open("/workspace/data/scores/blip2_itm_scores.json", "r", encoding="utf-8") as f:
        blip2_itm_scores = json.load(f)
    print("Loaded blip2 ITM scores.")

    # For every entry: keep the higher-ITM-scoring captions, drop badly
    # formatted ones, then score the survivors by mutual agreement.
    result_dict = {}
    for file, itm_entry in tqdm(blip2_itm_scores.items()):
        filtered_captions = itm_filter(itm_entry)
        filtered_captions = bad_format_filter(filtered_captions)
        consensus_score = consensus_scoring(filtered_captions)
        result_dict[file] = {"captions": filtered_captions, "scores": consensus_score}

    # The output directory may not exist yet on a fresh run.
    os.makedirs(cfg.DIR.Score, exist_ok=True)
    score_file_path = os.path.join(cfg.DIR.Score, "itm_filtered_consensus.json")
    # Announce the write before it happens so the message matches reality.
    print("Scoring completed. Saving scores to", score_file_path)
    with open(score_file_path, "w", encoding="utf-8") as f:
        json.dump(result_dict, f)