-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpopulate_cache.py
32 lines (24 loc) · 1023 Bytes
/
populate_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import hashlib
import json
import os
import shutil

import pandas as pd
import wandb
from datasets import Dataset, DatasetDict, load_dataset
from tqdm import tqdm
# Register tqdm's progress-bar integration with pandas (adds `progress_apply`
# and related methods to pandas objects).
# NOTE(review): nothing in the visible script calls `progress_apply` — confirm
# whether this call is still needed here.
tqdm.pandas()
# Local directory that receives one CSV per split of the Hub dataset.
CACHE_DIR = "cached_rewrites"
# Hugging Face Hub dataset holding the cached rewrites.
HF_CACHE_REPO = "Kyle1668/LLM-TTA-Cached-Rewrites"


def split_to_file_stem(split_name):
    """Translate a Hub split name into the local cache file name (no extension).

    The substitutions are applied in this order:
    ``"dot" -> "."``, ``"equals" -> "="``,
    ``"back_translate" -> "back-translate"`` and
    ``"StableBeluga_" -> "StableBeluga-"``.
    NOTE(review): presumably the split names were encoded this way because
    Hub split names cannot contain these characters — confirm with the
    script that uploads the cache.

    Args:
        split_name: Name of a split in the downloaded dataset.

    Returns:
        The decoded file name without the ``.csv`` suffix.
    """
    # `str.replace` is a no-op when the substring is absent, so no
    # membership checks are needed before each substitution.
    return (
        split_name.replace("dot", ".")
        .replace("equals", "=")
        .replace("back_translate", "back-translate")
        .replace("StableBeluga_", "StableBeluga-")
    )


def reset_cache_dir(cache_dir=CACHE_DIR):
    """Delete ``cache_dir`` if it exists, then create it empty.

    Uses the standard library instead of shelling out to ``rm -rf`` so the
    script also works where no POSIX shell is available, and so a failed
    deletion raises immediately instead of being silently ignored.

    Args:
        cache_dir: Path of the directory to recreate.

    Raises:
        OSError: If the old path cannot be removed or the new directory
            cannot be created.
    """
    if os.path.exists(cache_dir):
        print("Removing old cache")
        if os.path.isdir(cache_dir) and not os.path.islink(cache_dir):
            shutil.rmtree(cache_dir)
        else:
            # A plain file or a symlink: `rm -rf` removed these as well,
            # but `shutil.rmtree` refuses them.
            os.remove(cache_dir)
    os.mkdir(cache_dir)


if __name__ == "__main__":
    print("Pulling cache from HF")
    # Download before touching the local directory so a failed download
    # leaves the previous cache intact.
    hf_cache = load_dataset(HF_CACHE_REPO)

    reset_cache_dir(CACHE_DIR)

    for split_name in hf_cache:
        local_file_name = split_to_file_stem(split_name)
        print(f"Writing {local_file_name} to disk")
        hf_cache[split_name].to_csv(f"{CACHE_DIR}/{local_file_name}.csv", index=False)