-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
129 lines (96 loc) · 4.62 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse

# CLI definition: three positional arguments — what to do, which dataset to
# act on, and where to write the result.  Help strings are kept verbatim.
parser = argparse.ArgumentParser(description="Download and transform EEG dataset")
_CLI_ARGS = (
    ("action",
     dict(type=str,
          choices=["download", "transform", "compress"],
          help="Action can be download/transform/compress the dataset")),
    ("dataset",
     dict(type=str,
          help="If the action is download, specifies the dataset size (small, large, full), otherwise specifies dataset path (.tar.gz)")),
    ("output-path",
     dict(type=str,
          help="Output dataset path. The resulting file will be in .tar.gz if the action is download and .csv or parquet format if compression is enabled")),
)
for _arg_name, _arg_opts in _CLI_ARGS:
    parser.add_argument(_arg_name, **_arg_opts)

# Parse eagerly so `--help` and argument errors respond before the heavier
# imports below are executed; stays None when imported as a module.
args = None
if __name__ == "__main__":
    args = vars(parser.parse_args())
# imports
import os
import csv
import tempfile
from tqdm import tqdm
from termcolor import colored
from sibyl import dataset
from sibyl import transformer
from sibyl.util import filesystem as fs
def download_dataset(dataset_type: str, output_path: str):
    """Resolve the download URL for *dataset_type* and fetch it to *output_path*.

    Pure delegation to the project's `dataset` module; progress is reported
    via colored console messages.
    """
    print(colored("Starting to download {} dataset".format(dataset_type), "cyan"))
    source_url = dataset.get_download_url(dataset_type)
    print(colored("Downloading {}".format(source_url), "cyan"))
    dataset.async_download(source_url, output_path)
    print(colored("Download finished.", "cyan"))
def process_dataset(dataset_path: str, output_path: str):
    """Extract a .tar/.tar.gz EEG dataset archive and flatten it into one CSV.

    Steps:
      1. Decompress the main archive into a fresh temporary directory.
      2. Decompress any nested .tar.gz archives in place, deleting each one.
      3. Decompress per-sample .gz files into .txt record files.
      4. Parse every .txt record and stream its rows into the output CSV.

    Args:
        dataset_path: path to the source archive (.tar or .tar.gz).
        output_path: path of the CSV file to write.

    The temporary directory is always removed, even when a step raises —
    previously a failure mid-way leaked the extracted temp tree on disk.
    """
    print(colored("Transforming dataset...", "cyan"))
    temp_dataset_path = tempfile.mkdtemp(prefix="sibyl_eeg_temp")
    try:
        # if the passed path is a file, extract it first
        print(colored("Decompressing main dataset archive...", "cyan"))
        fs.decompress_tar(dataset_path, temp_dataset_path)
        # decompress all nested tar archives, deleting each once extracted
        for data_file in tqdm(fs.find_files(temp_dataset_path, ".tar.gz"), desc="Decompressing files (step 1)"):
            fs.decompress_tar(data_file, temp_dataset_path)
            fs.delete_file(data_file)
        # decompress all per-sample gz files into flat .txt records
        for data_file in tqdm(fs.find_files(temp_dataset_path, ".gz"), desc="Decompressing files (step 2)"):
            sample_extract_path = os.path.join(temp_dataset_path, os.path.basename(data_file) + ".txt")
            fs.decompress_gz(data_file, sample_extract_path)
            fs.delete_file(data_file)
        # parse every record file and append its rows to a single CSV
        with open(output_path, 'w', newline='', encoding='utf-8') as file_stream:
            csv_writer = csv.writer(file_stream)
            transformer.write_csv_header(csv_writer)
            for record_file in tqdm(fs.find_files(temp_dataset_path, ".txt"), desc="Parsing dataset"):
                rows = transformer.parse_file(record_file)
                # parse_file may return None for unparseable records; skip them
                if rows is None:
                    continue
                for row in rows:
                    csv_writer.writerow(row)
                fs.delete_file(record_file)
    finally:
        # always delete the temporary directory, even on failure
        print(colored("Deleting temporary files...", "cyan"))
        fs.delete_dir(temp_dataset_path)
    print(colored("Transform complete!", "cyan"))
def compress_dataset(dataset_path: str, output_path: str):
    """Convert a CSV dataset into a gzip-compressed parquet file.

    Args:
        dataset_path: path of the input .csv file.
        output_path: path of the parquet file to write.
    """
    print(colored("Saving as parquet file with gzip compression, might take a while!", "cyan"))
    print(colored("Compressing...", "cyan"))
    # the actual conversion is delegated to the project's transformer module
    transformer.save_as_parquet(dataset_path, output_path)
    print(colored("Saved as {}".format(output_path), "cyan"))
# main app entry point
if __name__ == "__main__":
    # Dispatch on the requested action.  Every failure path prints a red
    # message and exits with a non-zero status (previously `exit()` returned
    # status 0, which made error runs look successful to callers/CI).
    if args["action"] == "download":
        # download dataset from UCI server (.tar.gz)
        if args["dataset"] not in ["small", "large", "full"]:
            print(colored("Unknown dataset type, valid values are: small, large, full", "red"))
            raise SystemExit(1)
        download_dataset(args["dataset"], args["output-path"])
    # transform dataset (.tar.gz to .csv)
    elif args["action"] == "transform":
        if not fs.is_file_exists(args["dataset"]):
            print(colored("Dataset file does not exist", "red"))
            raise SystemExit(1)
        if not fs.is_file_extension(args["dataset"], [".tar.gz", ".tar"]):
            print(colored("Dataset is not in .tar or .tar.gz extension", "red"))
            raise SystemExit(1)
        process_dataset(args["dataset"], args["output-path"])
    # compress dataset (.csv to .parquet)
    elif args["action"] == "compress":
        if not fs.is_file_exists(args["dataset"]):
            print(colored("Dataset file does not exist", "red"))
            raise SystemExit(1)
        if not fs.is_file_extension(args["dataset"], [".csv"]):
            print(colored("Dataset is not in .csv extension", "red"))
            raise SystemExit(1)
        compress_dataset(args["dataset"], args["output-path"])
    # unreachable while argparse `choices` is enforced; kept as a safety net.
    # The message previously omitted "compress" and was the only uncolored one.
    else:
        print(colored("Unknown action, valid values are: download, transform, compress", "red"))
        raise SystemExit(1)