-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
122 lines (100 loc) · 3.27 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import *
import requests
import json
import datetime
import random
import socket
import numpy as np
# Slack incoming-webhook URL, read from the first line of "slack_webhook.txt" in the
# current working directory. It stays None when the file is missing or unreadable, so
# the name is always defined and notify_slack() can fail softly instead of depending
# on an undefined global.
slack_url: Optional[str] = None
try:
    with open("slack_webhook.txt", "r", encoding="utf-8") as f:
        slack_url = f.readline().strip()
except (OSError, ValueError):
    # Best-effort: Slack notifications are optional. OSError covers a missing or
    # unreadable file; ValueError covers UnicodeDecodeError from bad file contents.
    pass
def bfs_words(alphabet: str, limit_depth: Optional[int] = None, limit_num: Optional[int] = None) -> Iterable[str]:
    """Yield words over ``alphabet`` in breadth-first order, starting from the empty word.

    Shorter words come first. Words of equal length are ordered by the position of
    their symbols in ``alphabet``. For ``alphabet="ab"`` the sequence is
    ``"", "a", "b", "aa", "ab", "ba", "bb", ...``.

    :param alphabet: symbols to build words from, in the order they are appended.
    :param limit_depth: if given, stop before yielding any word longer than this.
    :param limit_num: if given, yield at most this many words (0 or less yields none).
    :return: a generator of words. It is infinite when no limit is given and
        ``alphabet`` is non-empty.
    """
    from collections import deque

    if limit_num is not None and limit_num <= 0:
        return
    num = 0
    # deque gives O(1) pops from the left; list.pop(0) is O(len(queue)), and the
    # queue grows exponentially with depth.
    queue = deque([""])
    # With an empty alphabet the queue runs dry after "". Stop at that point
    # instead of raising IndexError by popping an empty queue.
    while queue:
        word = queue.popleft()
        # BFS order means every later word is at least this long, so stop here.
        if limit_depth is not None and len(word) > limit_depth:
            break
        num += 1
        yield word
        if limit_num is not None and num >= limit_num:
            break
        queue.extend(word + c for c in alphabet)
def notify_slack(message: str, timeout: float = 10.0) -> bool:
    """Post ``message`` to the Slack incoming webhook at the module-level ``slack_url``.

    Best-effort: this function never raises. Any failure is reported as ``False``,
    including an undefined or unusable ``slack_url``, a network error, a timeout,
    or an HTTP error status from Slack.

    :param message: text to post.
    :param timeout: seconds to wait for Slack before giving up. Without a timeout,
        a stalled connection would block the caller indefinitely.
    :return: True if the request completed with a non-error status
        (``Response.ok``), False otherwise.
    """
    payload = {
        'text': message,  # text to post
        'username': u'rnn2wfa',  # display name shown for the post
        'icon_emoji': u':ghost:',  # emoji shown as the poster's profile image
        'link_names': 1,  # enable @-mentions in the text
    }
    try:
        response = requests.post(slack_url, data=json.dumps(payload), timeout=timeout)
    except Exception:
        # Deliberately broad: notification must never crash the caller. This also
        # covers slack_url being undefined (NameError) or invalid.
        return False
    # requests does not raise on 4xx/5xx, so check the status explicitly.
    return response.ok
def get_time_hash() -> str:
    """Return an identifier built from the current local time and this host's address.

    Format: ``YYYYmmddHHMMSS-<host>``. ``<host>`` is the IPv4 address that this
    machine's hostname resolves to, with dots replaced by hyphens (e.g.
    ``20240131235959-192-168-0-1``). If the hostname cannot be resolved, the
    hostname itself is used instead, again with dots replaced by hyphens.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    hostname = socket.gethostname()
    try:
        host = socket.gethostbyname(hostname)
    except OSError:
        # socket.gaierror (a subclass of OSError) is raised when the hostname has
        # no address; fall back rather than crash.
        host = hostname
    return timestamp + "-" + host.replace(".", "-")
T = TypeVar("T")


def weighted_choice(x: Dict[T, int]) -> T:
    """Pick a key of ``x`` at random, with probability proportional to its value.

    Equivalent to ``random.choice(concat([[k] * v for k, v in x.items()]))``, but
    without materializing the expanded list. It makes exactly one
    ``random.randint`` draw from the module-level ``random`` generator, so results
    are reproducible under ``random.seed``.

    :param x: mapping from candidate to non-negative integer weight.
    :return: the chosen key.
    :raises ValueError: if any weight is negative, or if the weights sum to zero
        (including an empty mapping).
    """
    if any(v < 0 for v in x.values()):
        raise ValueError("weighted_choice() weights must be non-negative")
    n = sum(x.values())
    if n <= 0:
        raise ValueError("weighted_choice() weights must sum to a positive value")
    val = random.randint(0, n - 1)
    for k, v in x.items():
        if val < v:
            return k
        val -= v
    # Unreachable because 0 <= val < n, the sum of the weights. Raise explicitly:
    # `assert False` is stripped under `python -O`, which would return None.
    raise AssertionError("unreachable: random draw exceeded total weight")
def argmax_dict(d: Dict[T, float]) -> T:
    """Return the key of ``d`` whose value is largest.

    Ties go to the key that comes first in iteration (insertion) order. Any
    hashable key is allowed, including ``None``.

    :param d: non-empty mapping from key to score.
    :return: a key with maximal value.
    :raises ValueError: if ``d`` is empty.
    """
    if not d:
        raise ValueError("argmax_dict() arg is an empty dict")
    # max() keeps the first maximal key it encounters, so earlier keys win ties.
    # Iterating keys directly (rather than tracking a None sentinel) also works when
    # None itself is a key.
    return max(d, key=d.__getitem__)
def sample_length_from_all_words(n_alphabets: int,
                                 max_length: int) -> int:
    """Sample a length so that the resulting word is uniform over all short words.

    Over an alphabet of ``n_alphabets`` symbols there are ``n_alphabets ** i``
    distinct words of length ``i``. So length ``i`` (0 <= i <= max_length) is
    drawn with weight ``n_alphabets ** i``. Combined with choosing each symbol
    uniformly (as ``make_words`` does), this makes every word of length
    <= ``max_length`` equally likely.

    :param n_alphabets: alphabet size.
    :param max_length: largest length that may be returned.
    :return: a length in ``0..max_length``.
    """
    # Weight is (alphabet size) ** length. The swapped form, length ** (alphabet size),
    # grows polynomially and does not count words.
    return weighted_choice({i: n_alphabets ** i for i in range(max_length + 1)})
def sample_length_from_all_lengths(n_alphabets: int,
                                   max_length: int) -> int:
    """Draw a word length uniformly from ``1..max_length`` (both ends inclusive).

    Every length is equally likely, however many words of that length exist.
    ``n_alphabets`` is accepted only so this function matches the
    ``(n_alphabets, max_length) -> length`` shape expected of ``length_sampling``;
    it is not used.
    """
    shortest = 1
    longest = max_length
    return random.randint(shortest, longest)
def make_words(alphabets: str,
               max_length: int,
               n_samples: int,
               length_sampling: Callable[[int, int], int],
               exclude_list: Optional[List[str]] = None,
               random_seed: Optional[int] = None) -> List[str]:
    """Randomly generate up to ``n_samples`` distinct words over ``alphabets``.

    Each attempt asks ``length_sampling(len(alphabets), max_length)`` for a length,
    then fills that many positions with symbols drawn uniformly from ``alphabets``.
    Words found in ``exclude_list`` and duplicates are discarded. Generation gives
    up after ``n_samples * 100`` attempts, so the result can be shorter than
    ``n_samples`` when there are not enough admissible words.

    :param alphabets: symbols to draw from.
    :param max_length: upper bound passed to ``length_sampling``. Generated words
        are asserted not to exceed it.
    :param n_samples: desired number of distinct words.
    :param length_sampling: ``(n_alphabets, max_length) -> length`` callback.
    :param exclude_list: words that must not appear in the output.
    :param random_seed: if given, the global ``random`` generator is seeded with it
        for the duration of the call. Its previous state is restored afterwards,
        even if an exception is raised.
    :return: the generated words, in the order they were first produced.
    """
    n_alphabets = len(alphabets)
    # Set for O(1) membership tests; callers may pass a long list.
    excluded = set(exclude_list) if exclude_list is not None else set()
    saved_state = None
    if random_seed is not None:
        saved_state = random.getstate()
        random.seed(random_seed)
    try:
        # dict instead of set: it deduplicates while keeping first-seen order. That
        # makes the output order reproducible; set iteration order for str depends
        # on PYTHONHASHSEED and differs between processes.
        words: Dict[str, None] = {}
        attempts = 0
        while len(words) < n_samples:
            attempts += 1
            if attempts > n_samples * 100:
                break
            length = length_sampling(n_alphabets, max_length)
            word = ''.join(random.choice(alphabets) for _ in range(length))
            assert len(word) <= max_length
            if word in excluded:
                continue
            words[word] = None
        return list(words)
    finally:
        # Restore the caller's RNG state even if length_sampling raised.
        if saved_state is not None:
            random.setstate(saved_state)
def dist_f(d: Callable[[np.ndarray, np.ndarray], float],
           f: Callable[[T], np.ndarray],
           x: T,
           y: T) -> float:
    """Measure the distance between ``x`` and ``y`` after mapping both through ``f``.

    :param d: distance function applied to the two mapped values.
    :param f: feature map applied to each input.
    :param x: first input.
    :param y: second input.
    :return: ``d(f(x), f(y))``.
    """
    # Map x before y so the call order of f matches evaluating d(f(x), f(y)).
    mapped_x = f(x)
    mapped_y = f(y)
    return d(mapped_x, mapped_y)