-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathutils.py
100 lines (78 loc) · 2.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import string
import pickle
class Index(object):
    """Bidirectional mapping between keys and dense integer ids.

    Ids are handed out in insertion order starting at 0. ``key2idx`` maps a
    key to its id; ``idx2key`` is the list of keys ordered by id.
    """

    def __init__(self):
        self.key2idx = {}
        self.idx2key = []

    def add(self, key):
        """Register *key* if it is not present yet and return its id."""
        if key not in self.key2idx:
            self.key2idx[key] = len(self.idx2key)
            self.idx2key.append(key)
        return self.key2idx[key]

    def __getitem__(self, key):
        """Translate a string key to its id, or an integer id to its key.

        Raises:
            KeyError: if a string key has not been registered.
            IndexError: if an integer id is out of range.
            TypeError: if *key* is neither a string nor integer-like.
        """
        if isinstance(key, str):
            return self.key2idx[key]
        # Accept every integer-like object (anything exposing __index__), not
        # just built-in int, so ids coming from numeric libraries work too.
        if hasattr(key, "__index__"):
            return self.idx2key[key]
        # Falling through used to return None silently, which hid caller bugs.
        raise TypeError(
            "Index keys must be str or int, not {}".format(type(key).__name__)
        )

    def __len__(self):
        """Return the number of registered keys."""
        return len(self.idx2key)

    def save(self, f):
        """Write one ``key<TAB>id`` line per key to the UTF-8 text file *f*."""
        with open(f, 'wt', encoding='utf-8') as fout:
            for index, key in enumerate(self.idx2key):
                fout.write(key + '\t' + str(index) + '\n')

    def load(self, f):
        """Add the keys stored in the UTF-8 text file *f* to this index.

        Lines are expected in the ``key<TAB>id`` layout produced by ``save``.
        The stored id is ignored: keys receive ids in file order via ``add``.
        Lines without a tab fall back to their first whitespace-separated
        token. Blank lines are skipped.
        """
        with open(f, 'rt', encoding='utf-8') as fin:
            for line in fin:
                # Only drop the line terminator; the key itself may contain
                # (or consist of) whitespace.
                line = line.rstrip('\n')
                if not line.strip():
                    continue
                if '\t' in line:
                    # Exact inverse of save(): everything before the last tab.
                    key = line.rpartition('\t')[0]
                else:
                    key = line.split()[0]
                self.add(key)
class Charset(Index):
    """Character index pre-filled with printable ASCII and special tokens.

    On construction every non-whitespace printable ASCII character is
    registered, followed by the "<pad>" and "<unk>" tokens. Looking up an
    unregistered string yields the id of "<unk>".
    """

    def __init__(self):
        super().__init__()
        # string.printable ends with the six whitespace characters; dropping
        # them leaves the digits, ASCII letters and punctuation.
        for char in string.printable[0:-6]:
            self.add(char)
        self.add("<pad>")
        self.add("<unk>")

    @staticmethod
    def type(char):
        """Classify a single character.

        Returns one of "Digits", "Lower Case", "Upper Case", "Punctuation"
        or "Other". Anything that is not exactly one character long is
        reported as "Other".
        """
        # `in` on a str is a substring test, so "" (contained in every
        # string) and runs such as "12" would otherwise be misclassified.
        if len(char) != 1:
            return "Other"
        if char in string.digits:
            return "Digits"
        if char in string.ascii_lowercase:
            return "Lower Case"
        if char in string.ascii_uppercase:
            return "Upper Case"
        if char in string.punctuation:
            return "Punctuation"
        return "Other"

    def __getitem__(self, key):
        """Look up *key*, mapping unregistered strings to the "<unk>" id.

        Registered strings and non-string keys are delegated to the base
        class lookup.
        """
        if isinstance(key, str) and key not in self.key2idx:
            return self.key2idx["<unk>"]
        return super().__getitem__(key)
class Vocabulary(Index):
    """Token index that starts out with the "<pad>" and "<unk>" entries.

    The two special tokens are registered first, in that order. Looking up
    a string that was never added yields the id of "<unk>".
    """

    def __init__(self):
        super().__init__()
        for special in ("<pad>", "<unk>"):
            self.add(special)

    def __getitem__(self, key):
        """Look up *key*, substituting the "<unk>" id for unknown strings."""
        # Non-string keys and registered strings take the base-class lookup.
        if not isinstance(key, str) or key in self.key2idx:
            return super().__getitem__(key)
        return self.key2idx["<unk>"]
def prepare_sequence(seq, to_idx):
    """Return a list with ``to_idx[item]`` for every item of *seq*, in order."""
    indices = []
    for item in seq:
        indices.append(to_idx[item])
    return indices
def save(obj, path):
    """Pickle *obj* into the file at *path*, overwriting any existing file."""
    with open(path, 'wb') as sink:
        pickle.Pickler(sink).dump(obj)
def load(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files from a trusted source.
    """
    with open(path, 'rb') as source:
        restored = pickle.Unpickler(source).load()
    return restored
def time_display(s):
    """Format a duration of *s* seconds as ``[<days>d ]HH:MM:SS``.

    When the day count is zero the day prefix is replaced by a single
    space. Fractional parts are truncated by ``int()`` when formatting.
    """
    remaining = s
    counts = []
    # Peel off whole days, hours and minutes in turn; what is left over
    # after the loop is the seconds component.
    for unit in (3600*24, 3600, 60):
        whole = remaining // unit
        remaining -= whole * unit
        counts.append(whole)
    days, hours, minutes = counts

    if days:
        prefix = "{:1d}d ".format(int(days))
    else:
        prefix = " "
    clock = "{:0>2d}:{:0>2d}:{:0>2d}".format(
        int(hours), int(minutes), int(remaining)
    )
    return prefix + clock