This repository has been archived by the owner on Dec 17, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsplit_data.py
72 lines (56 loc) · 2.13 KB
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
from collections import Counter
from pathlib import Path
import fire
import yaml
from loguru import logger
from sklearn.model_selection import train_test_split
def load_texts_labels(input_folder):
    """Collect every ``*.txt`` file under *input_folder* together with its label.

    A file's label is the name of the directory that directly contains it, so
    the expected layout is ``input_folder/<label>/<file>.txt`` (only the
    immediate parent directory name is used, even for deeper nesting).

    Args:
        input_folder: Root directory (str or Path) searched recursively.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists: ``texts[i]`` is the
        full UTF-8-decoded content of a file and ``labels[i]`` is the name of
        its parent directory. Files are returned in sorted path order.
    """
    texts = []
    labels = []
    logger.info(f"Loading texts and labels from {input_folder} ...")
    # rglob yields entries in filesystem-dependent order; sorting makes the
    # returned lists (and any split computed from them) reproducible.
    for text_file in sorted(Path(input_folder).rglob("*.txt")):
        # Skip directories that merely match the "*.txt" pattern.
        if not text_file.is_file():
            continue
        # Decode explicitly as UTF-8 instead of the locale's default encoding,
        # which differs between platforms (e.g. cp1252 on Windows).
        texts.append(text_file.read_text(encoding="utf-8"))
        labels.append(text_file.parent.name)
    logger.success("Done!")
    logger.info(f"texts: {len(texts)}")
    logger.info(f"labels: {len(labels)}")
    return texts, labels
def _encode_records(texts, labels, label_to_id):
    """Pair each text with the integer id of its label as ``{"text", "label"}`` dicts."""
    return [
        {"text": text, "label": label_to_id[label]}
        for text, label in zip(texts, labels)
    ]


def _write_json(path, payload, **dump_kwargs):
    """Write *payload* to *path* as UTF-8 JSON, forwarding *dump_kwargs* to ``json.dump``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, **dump_kwargs)


@logger.catch(reraise=True)
def split(input_folder, output_folder, test_size, metrics_file, random_state=None):
    """Split a folder-per-label text corpus into stratified train/val JSON files.

    Label ids come from the ordered list ``data.labels`` in ``params.yaml``,
    which is read from the current working directory: a label's id is its
    position in that list.

    Args:
        input_folder: Root folder laid out as ``<label>/<file>.txt``.
        output_folder: Folder (created if missing) that receives
            ``train.json`` and ``val.json``, each shaped
            ``{"data": [{"text": str, "label": int}, ...]}``.
        test_size: Forwarded to ``sklearn.model_selection.train_test_split``
            (fraction or absolute count of validation examples).
        metrics_file: Path of the JSON file that receives example counts and
            per-label counts for both splits.
        random_state: Optional seed forwarded to ``train_test_split``; the
            default ``None`` keeps the original non-deterministic behavior.

    Raises:
        ValueError: If no texts are found, or a label folder name is not
            listed in ``params.yaml`` ``data.labels``.
    """
    with open("params.yaml", encoding="utf-8") as f:
        params = yaml.safe_load(f)
    label_names = params["data"]["labels"]
    # Build the name -> id lookup once instead of an O(n) list.index() per
    # example. setdefault keeps the FIRST position of a duplicated name, which
    # is exactly what list.index() returned.
    label_to_id = {}
    for position, name in enumerate(label_names):
        label_to_id.setdefault(name, position)

    texts, labels = load_texts_labels(input_folder)
    if not texts:
        raise ValueError(f"No *.txt files found under {input_folder}")
    # Fail early with a readable message rather than "'x' is not in list".
    unknown = sorted(set(labels).difference(label_to_id))
    if unknown:
        raise ValueError(
            f"Labels {unknown} found in {input_folder} are not listed in "
            f"params.yaml data.labels {label_names}"
        )

    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts,
        labels,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    output_dir = Path(output_folder)
    output_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Writing outputs to {output_folder} ...")
    # ensure_ascii=False keeps non-ASCII text readable; the UTF-8 file
    # encoding in _write_json is what makes that safe on every platform.
    _write_json(
        output_dir / "train.json",
        {"data": _encode_records(train_texts, train_labels, label_to_id)},
        ensure_ascii=False,
    )
    _write_json(
        output_dir / "val.json",
        {"data": _encode_records(val_texts, val_labels, label_to_id)},
        ensure_ascii=False,
    )
    logger.success("Done!")

    logger.info("Computing metrics ...")
    metrics = {
        "train_texts": len(train_texts),
        "val_texts": len(val_texts),
        # Counter is a dict subclass, so it serializes as a plain JSON object.
        "train_labels": Counter(train_labels),
        "val_labels": Counter(val_labels),
    }
    logger.success("Done!")
    logger.info(json.dumps(metrics, indent=4))

    logger.info("Writing metrics ...")
    _write_json(metrics_file, metrics)
    logger.success("Done!")
if __name__ == "__main__":
    # Expose `split` as a command-line interface; python-fire maps CLI
    # arguments/flags onto its parameters.
    fire.Fire(component=split)