benchmark.py
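"""Run a vision-language bias benchmark over one chunk of prompts.

Loads the requested VLM, scores each prompt's answer options via
next-token probabilities, and writes one CSV of results per chunk to
./results/benchmark/<model>/.

Usage:
    python benchmark.py --model <model-name> --prompt-chunk-index <i> [--task <task>]
"""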
import os
import re
import json
import torch
import warnings
import argparse
import pandas as pd
from tqdm import tqdm
from typing import Optional
from vlms import load_model
from bias_eval_utils import BiasPrompt, Prompt
from utils.configs import dataset_configs
from utils.benchmark_utils import make_dataloader, encode_option_letter

# Silence warnings
warnings.filterwarnings("ignore")


def make_prompts(prompt_chunk_index: int, base_dir: str) -> list[BiasPrompt]:
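    """Load prompt chunk ``prompt_chunk_index`` from ``base_dir``.

    Image paths in the JSON are rebased onto the configured data root so
    they resolve on the local machine, then each record is wrapped in a
    BiasPrompt.
    """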
    print(f"Loading prompt chunk {prompt_chunk_index}")
    # Get prompts
    with open(os.path.join(base_dir, f"prompts_{prompt_chunk_index}.json"), "r") as f:
        prompts = json.load(f)
    # Update the absolute image paths
    for prompt in prompts:
        path_to_image = prompt["image"]
        image_name = re.split(r"images/", path_to_image, maxsplit=1)[-1]
        prompt["image"] = os.path.join(dataset_configs["data_root"], prompt["dataset"], "images", image_name)
    prompts = [BiasPrompt(**prompt) for prompt in prompts]
    return prompts


def get_cmd_arguments() -> argparse.Namespace:
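    """Parse the command-line arguments for a single benchmark run."""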
    # Make argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--prompt-chunk-index", type=int, required=True)
    parser.add_argument("--task", type=str, default=None)
    return parser.parse_args()


def prompt_to_keys(prompt: Prompt) -> dict[str, str]:
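    """Flatten a prompt's metadata fields into a dict of result-row keys."""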
    return {
        "gender": prompt.gender,
        "value": prompt.value,
        "task": prompt.task,
        "dataset": prompt.dataset,
        "image": prompt.image,
        "unknown_option_letter": prompt.unknown_option_letter,
        "yes_option_letter": prompt.yes_option_letter,
        "no_option_letter": prompt.no_option_letter,
    }


def get_results(prompt_chunk_index: int, model_name: str, task: Optional[str] = None) -> list[dict]:
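    """Score one prompt chunk with the model named ``model_name``.

    Returns one row per prompt, combining the prompt metadata with the
    probability the model assigns to each answer option.
    """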
    # Load model
    model = load_model(model_name)
    # Make classification prompts
    if task is not None:
        base_dir = f"./data/prompts_by_task/{task}"
    else:
        base_dir = "./data/prompts"
    prompts = make_prompts(prompt_chunk_index, base_dir=base_dir)
    # Make dataloader
    dataloader = make_dataloader(prompts, model, model_name)
    # Iterate dataloader and get classification results
    results = []
    for prompt, metadata in tqdm(dataloader):
        with torch.no_grad():
            probs = model.get_next_token_probabilities(prompt)
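        # probs is assumed to be a [batch, vocab] tensor of next-token
        # probabilities, one row per prompt in the batch.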
        # Get option letter indices and read off each option's probability
        for i, prompt_metadata in enumerate(metadata):
            probs_i = probs[i]
            option_probs = dict()
            for letter, option in prompt_metadata.letter_to_option.items():
                letter_index = encode_option_letter(letter, model, model_name)
                option_probs[option] = probs_i[letter_index].item()
            keys_to_save = prompt_to_keys(prompt_metadata)
            results.append(
                {
                    **keys_to_save,
                    **option_probs,
                }
            )
    return results


if __name__ == "__main__":
    # Parse command-line arguments
    args = get_cmd_arguments()
    # Get results
    prompt_chunk_index = args.prompt_chunk_index
    results = get_results(prompt_chunk_index, args.model, task=args.task)
    # Convert results to a pandas DataFrame and save one CSV per chunk
    results_df = pd.DataFrame(results)
    save_path = os.path.join("./results/benchmark/", args.model)
    os.makedirs(save_path, exist_ok=True)
    save_filename = f"{prompt_chunk_index}.csv"
    results_df.to_csv(os.path.join(save_path, save_filename), index=False)