# util_caching.py
from itertools import chain
import ast
import hashlib
import os
import time

import pandas as pd
import torch.distributed as dist

from util_modeling import is_language_model

# In-memory cache: maps an on-disk cache CSV path to its loaded DataFrame.
cache_frame = {}


def distributed_cache_write(rank, world_size, model_name, dataset_name, icl_method, eval_set, temperature, inference_logs, adaptive_model, entry, seed):
    """Periodically persist rewrite logs, gathering them across ranks first when running distributed."""
    cache_write_steps = 5
    is_cache_write_step = len(inference_logs) % cache_write_steps == 0
    if not is_cache_write_step:
        print(f"Skipping cache write because it is not a cache write step: {len(inference_logs)} - {cache_write_steps}")
        return

    if dist.is_initialized():
        # Only the destination rank allocates the gather list; the other ranks pass None.
        distributed_rewrites_cache = [[] for _ in range(world_size)] if rank == 0 else None

        print(f"Halting to cache rewrites across ranks - Rank {rank}")
        dist.barrier()

        print("Gathering cached rewrites across ranks")
        dist.gather_object(inference_logs, distributed_rewrites_cache)
        if rank == 0:
            prev_length = len(inference_logs)
            # Flatten the per-rank log lists into a single list of entries.
            distributed_rewrites_cache = list(chain(*distributed_rewrites_cache))
            new_length = len(distributed_rewrites_cache)
            print(f"Chained distributed rewrites from {prev_length} to {new_length}")

            if len(distributed_rewrites_cache) == 0:
                print("Skipping cache writes because all entries were cache hits")
                return

            # Each rank contributes cache_write_steps logs per write, so the combined
            # write interval scales with world size.
            distributed_cache_write_steps = cache_write_steps * world_size
            write_cached_rewrites(dataset_name, adaptive_model, temperature, distributed_rewrites_cache, seed, distributed_cache_write_steps)
    else:
        write_cached_rewrites(dataset_name, adaptive_model, temperature, inference_logs, seed, cache_write_steps)
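

# A minimal sketch of the gather-and-flatten pattern used above (assumes an
# initialized process group; the names here are illustrative only):
#
#   gathered = [[] for _ in range(world_size)] if rank == 0 else None
#   dist.gather_object(local_logs, gathered)   # rank 0 receives [logs_rank_0, ..., logs_rank_n]
#   if rank == 0:
#       all_logs = list(chain(*gathered))      # flatten to one list of log dicts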


def flush_local_cache():
    """Drop the in-memory cache so the next read reloads from disk."""
    global cache_frame
    cache_frame = {}


def get_cached_rewrites(dataset_name, rewrite_model, temperature, input_prompt, seed):
    """Look up input_prompt in the on-disk rewrite cache; return the parsed rewrites, or None on a miss or error."""
    global cache_frame
    start_time = time.perf_counter()
    try:
        if not os.path.exists("cached_rewrites"):
            os.mkdir("cached_rewrites")

        cache_path = f"cached_rewrites/seed={seed}_{dataset_name}_{rewrite_model.name_or_path.replace('/', '_')}.csv"
        # Sampling temperature changes LM outputs, so it is part of the cache key.
        if is_language_model(rewrite_model.name_or_path):
            cache_path = cache_path.replace(".csv", f"_temp={temperature}.csv")

        # Lazily load the CSV for this cache path into the in-memory cache.
        if os.path.exists(cache_path) and cache_path not in cache_frame:
            cache_frame[cache_path] = pd.read_csv(cache_path, on_bad_lines="warn", engine="python")

        if cache_path in cache_frame and cache_frame[cache_path] is not None:
            # Prompts are keyed by SHA-256 hash to keep rows compact and lookups cheap.
            hashed_prompt = hashlib.sha256(input_prompt.encode()).hexdigest()
            read_frame_start = time.perf_counter()
            cached_inference = cache_frame[cache_path][cache_frame[cache_path]["prompt_hash"] == hashed_prompt]
            end_time = time.perf_counter()
            if len(cached_inference) > 0:
                print(f"Found cached rewrites for {rewrite_model.name_or_path}. Overall Latency = {round(end_time - start_time, 2)} seconds & Search Latency = {round(end_time - read_frame_start, 2)} seconds")
                # Rewrites are stored as a stringified Python list; parse it back.
                return ast.literal_eval(cached_inference.iloc[0]["rewrites"])
    except Exception as e:
        end_time = time.perf_counter()
        print(f"Error reading cached rewrites with Latency = {round(end_time - start_time, 2)}: {e}")

    return None


def write_cached_rewrites(dataset_name, rewrite_model, temperature, inference_logs, seed, cache_write_steps):
    """Append any cache-miss rewrites from inference_logs to the on-disk CSV cache."""
    # Track how long the cache write takes.
    start_time = time.perf_counter()
    try:
        os.makedirs("cached_rewrites", exist_ok=True)
        cache_path = f"cached_rewrites/seed={seed}_{dataset_name}_{rewrite_model.name_or_path.replace('/', '_')}.csv"
        if is_language_model(rewrite_model.name_or_path):
            cache_path = cache_path.replace(".csv", f"_temp={temperature}.csv")

        print(f"Inference Logs: {len(inference_logs)}")
        # Write the full log list rather than slicing by cache_write_steps; the
        # dedup below makes re-writing earlier entries harmless.
        cache_miss_entries = [
            {
                "prompt_hash": hashlib.sha256(log["style prompt"].encode()).hexdigest(),
                "prompt": log["style prompt"],
                "rewrites": log["input"],
            }
            for log in inference_logs
        ]
        cache_miss_frame = pd.DataFrame(cache_miss_entries)
        # Drop entries already present in the in-memory cache for this path.
        if cache_frame.get(cache_path) is not None:
            cache_miss_frame = cache_miss_frame[~cache_miss_frame["prompt_hash"].isin(cache_frame[cache_path]["prompt_hash"])]
        if len(cache_miss_frame) == 0:
            print("Skipping cache write because all entries were cache hits")
            return

        # Re-read the file in case another process wrote to it since our last read.
        fresh_cache_frame = pd.read_csv(cache_path, on_bad_lines="warn", engine="python") if os.path.exists(cache_path) else None
        updated_cache_frame = cache_miss_frame if fresh_cache_frame is None else pd.concat([fresh_cache_frame, cache_miss_frame])

        # Dedup on prompt hash and write the merged frame back to local storage.
        previous_length = len(updated_cache_frame)
        updated_cache_frame = updated_cache_frame.drop_duplicates(["prompt_hash"])
        print(f"Writing {len(updated_cache_frame)} rewrites to cache ({previous_length - len(updated_cache_frame)} duplicates dropped)")
        updated_cache_frame.to_csv(cache_path, index=False)
        # Keep the in-memory view in sync so subsequent reads see the new rows.
        cache_frame[cache_path] = updated_cache_frame
    except Exception as e:
        print(f"Error writing cached rewrites: {e}")

    end_time = time.perf_counter()
    print(f"Cache write took {round(end_time - start_time, 5)} seconds")