proof_steps_by_file.py

#!/usr/bin/env python3
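"""Reorganize extracted proof steps on disk.

Default mode: read per-step pickle files under INDIR and append each step,
JSON-encoded (lark parse trees included), to a JSON-lines file under OUTDIR
that mirrors the step's original source path.

With --to-hashes: read those JSON-lines files back and write one pickle per
step to OUTDIR/train or OUTDIR/valid, named by the SHA-256 of its repr,
according to the project split in --projs_split.
"""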
import argparse
import fcntl
import functools
import glob
import json
import multiprocessing
import os
import pickle
import time
from hashlib import sha256
from pathlib import Path
from typing import Dict, List

from lark import Tree
from tqdm import tqdm


class LarkEncoder(json.JSONEncoder):
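    """JSON encoder that serializes lark Tree objects as tagged dicts."""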
    def default(self, obj):
        if isinstance(obj, Tree):
            # Tag the dict so lark_decoder can recognize and rebuild it.
            # Children are returned as-is: the encoder recurses into the
            # list itself, so Token (str) leaves serialize without error.
            return {"__is_tree__": True,
                    "data": obj.data,
                    "children": obj.children}
        return json.JSONEncoder.default(self, obj)
def lark_decoder(dct):
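    """object_hook for json.loads: rebuild lark Trees from tagged dicts."""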
if "__is_tree__" in dct:
return Tree(dct["data"], dct["children"])
return dct
class FileLock:
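    """Context manager holding an exclusive flock on an open file handle,
    busy-waiting (non-blocking lock plus a short sleep) until acquired."""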
def __init__(self, file_handle):
self.file_handle = file_handle
def __enter__(self):
while True:
try:
fcntl.flock(self.file_handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
break
except OSError:
time.sleep(0.01)
return self
def __exit__(self, type, value, traceback):
fcntl.flock(self.file_handle, fcntl.LOCK_UN)
def get_local_path(filename):
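    """Strip everything up to and including the 'data' directory from a path."""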
parts = Path(filename).parts
i = parts.index('data')
return os.path.join(*parts[i + 1:])
def get_proj(filename):
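    """Return the project name: the first path component after 'data'."""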
parts = Path(filename).parts
i = parts.index('data')
return parts[i+1]
def to_file(args: argparse.Namespace, filename: str) -> None:
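    """Append one pickled proof step, JSON-encoded, to the JSON-lines file
    under args.outdir that mirrors the step's source file path."""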
with open(filename, 'rb') as f:
step = pickle.load(f)
local_path = get_local_path(step['file'])
outpath = os.path.join(args.outdir, local_path)
os.makedirs(os.path.dirname(outpath), exist_ok=True)
with open(outpath, 'a') as f, FileLock(f):
print(json.dumps(step, cls=LarkEncoder), file=f)
def to_hashes(args: argparse.Namespace, projs_split: Dict[str, List[str]], filename: str) -> None:
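    """Re-split a JSON-lines step file into per-step pickles named by the
    SHA-256 of each step's repr, under args.outdir/train or /valid."""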
    with open(filename, 'r') as f:
        for line in f:
            step = json.loads(line, object_hook=lark_decoder)
            step_hash = sha256(repr(step).encode('utf-8')).hexdigest()
            if get_proj(step["file"]) in projs_split["projs_train"]:
                outpath = os.path.join(args.outdir, "train", f"{step_hash}.pickle")
            else:
                assert get_proj(step["file"]) in projs_split["projs_valid"], get_proj(step["file"])
                outpath = os.path.join(args.outdir, "valid", f"{step_hash}.pickle")
            os.makedirs(os.path.dirname(outpath), exist_ok=True)
            # pickle is a binary format: open in 'wb' and dump to the handle.
            with open(outpath, 'wb') as out:
                pickle.dump(step, out)
parser = argparse.ArgumentParser()
parser.add_argument("indir")
parser.add_argument("outdir")
parser.add_argument("-j", "--num-threads", default=4)
parser.add_argument("--to-hashes", action="store_true")
parser.add_argument("--projs_split", default="projs_split.json")
# Guard the entry point so multiprocessing works on spawn-based platforms.
if __name__ == "__main__":
    args = parser.parse_args()
    with multiprocessing.Pool(args.num_threads) as pool:
        if args.to_hashes:
            paths = glob.glob(os.path.join(args.indir, "**/*.json"), recursive=True)
            with open(args.projs_split) as f:
                projs_split = json.load(f)
            res = list(tqdm(pool.imap(functools.partial(to_hashes, args, projs_split), paths),
                            total=len(paths)))
        else:
            paths = glob.glob(os.path.join(args.indir, "**/*.pickle"), recursive=True)
            res = list(tqdm(pool.imap(functools.partial(to_file, args), paths), total=len(paths)))
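
# Example invocations (a sketch; the directory names are hypothetical):
#
#   # Pass 1: group per-step pickles into JSON-lines files mirroring
#   # each step's source file path.
#   python proof_steps_by_file.py scraped_steps/ steps_by_file/ -j 8
#
#   # Pass 2: re-split the JSON-lines files into train/valid per-step
#   # pickles named by content hash, using projs_split.json.
#   python proof_steps_by_file.py steps_by_file/ steps_by_hash/ --to-hashes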