generated from Lightning-AI/deep-learning-project-template
-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathpreprocessing_utils.py
61 lines (53 loc) · 1.92 KB
/
preprocessing_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import logging
from pathlib import Path
import numpy as np
import pandas as pd
# Module-level logger shared by the helpers in this file. basicConfig installs
# a root handler (only if none is configured yet) so INFO messages are shown
# with a timestamp, logger name and level when the file is run as a script.
logger = logging.getLogger("preprocessing_utils")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
def update_test(test_csv_file):
    """Combine a test CSV and its companion labels CSV into one file.

    The labels are read from ``<stem>_labels.csv`` located next to
    ``test_csv_file``. Every labels column except the first is copied onto
    the test frame, a ``content`` column (when present) is renamed to
    ``comment_text``, and the result is written to ``<stem>_updated.csv`` in
    the same directory.

    Args:
        test_csv_file: Path (``str`` or ``Path``) of the test CSV.

    Returns:
        The combined ``pandas.DataFrame``.
    """
    test_csv_file = Path(test_csv_file)
    # Derive the sibling file from the stem instead of slicing four characters
    # off the string, so the name is right whatever the extension length is.
    labels_file = test_csv_file.with_name(f"{test_csv_file.stem}_labels.csv")

    test_set = pd.read_csv(test_csv_file)
    data_labels = pd.read_csv(labels_file)

    # The first labels column is skipped (presumably an id column shared with
    # the test file -- TODO confirm). Column assignment aligns on the row
    # index, so both files are expected to list rows in the same order.
    for category in data_labels.columns[1:]:
        test_set[category] = data_labels[category]

    if "content" in test_set.columns:
        test_set = test_set.rename(columns={"content": "comment_text"})

    output_file = test_csv_file.with_name(f"{test_csv_file.stem}_updated.csv")
    test_set.to_csv(output_file)
    logger.info("Updated test set saved to %s", output_file)
    return test_set
def create_val_set(csv_file, val_fraction):
    """Sample a validation set from a CSV file and save it as ``val.csv``.

    Rows whose ``toxic`` column equals -1 are dropped first (presumably these
    are unlabelled rows -- TODO confirm). Each remaining row is then selected
    independently when a uniform draw in [0, 1) does not exceed
    ``val_fraction``, so roughly that fraction of rows ends up in the
    validation set. The draw uses a fixed seed of 0 and is therefore
    reproducible.

    Args:
        csv_file: Path (``str`` or ``Path``) of the source CSV; it must
            contain a ``toxic`` column.
        val_fraction: Approximate fraction of rows to select.

    Returns:
        The selected rows as a ``pandas.DataFrame``. The same rows are
        written to ``val.csv`` in the directory of ``csv_file``.
    """
    csv_file = Path(csv_file)
    dataset = pd.read_csv(csv_file)
    dataset_mod = dataset[dataset.toxic != -1]

    # A private RandomState seeded with 0 yields the same draws as seeding the
    # global generator, without clobbering the caller's global RNG state.
    rng = np.random.RandomState(0)
    keep_mask = rng.rand(len(dataset_mod)) > val_fraction
    val_set = dataset_mod[~keep_mask]

    output_file = csv_file.parent / "val.csv"
    val_set.to_csv(output_file)
    # Log only after the write succeeded so the message is never misleading.
    logger.info("Validation set saved to %s", output_file)
    return val_set
if __name__ == "__main__":
    # Command-line entry point: each action flag runs one preprocessing step
    # on the CSV given by its companion path option.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_csv",
        type=str,
        help="path of the test CSV (used with --update_test)",
    )
    parser.add_argument(
        "--val_csv",
        type=str,
        help="path of the CSV to sample from (used with --create_val_set)",
    )
    parser.add_argument(
        "--update_test",
        action="store_true",
        help="merge the test CSV with its *_labels.csv companion",
    )
    parser.add_argument(
        "--create_val_set",
        action="store_true",
        help="write a val.csv sampled from --val_csv",
    )
    args = parser.parse_args()

    # Report a missing path as a usage error up front; otherwise Path(None)
    # inside the helpers raises an opaque TypeError.
    if args.update_test and args.test_csv is None:
        parser.error("--update_test requires --test_csv")
    if args.create_val_set and args.val_csv is None:
        parser.error("--create_val_set requires --val_csv")

    if args.update_test:
        update_test(args.test_csv)
    if args.create_val_set:
        create_val_set(args.val_csv, val_fraction=0.1)