-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata.py
205 lines (164 loc) · 6.91 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import re
import numpy as np
import selfies as sf
class selfies_vocabulary(object):
    """Vocabulary mapping SELFIES tokens to integer indices and back.

    The base alphabet is read from a whitespace-separated token file; the
    special tokens [EOS]/[BOS]/[PAD]/[UNK] are appended after sorting, so
    their indices are always the last four of the vocabulary.
    """

    def __init__(self, vocab_path):
        """Build the vocabulary from the token file at ``vocab_path``."""
        self.alphabet = set()
        with open(vocab_path, 'r') as f:
            chars = f.read().split()
        for char in chars:
            self.alphabet.add(char)

        self.special_tokens = ['[EOS]', '[BOS]', '[PAD]', '[UNK]']

        # Sort for a deterministic token -> index mapping, then append the
        # special tokens so they occupy the final four index slots.
        self.alphabet_list = sorted(self.alphabet)
        self.alphabet_list = self.alphabet_list + self.special_tokens
        self.alphabet_length = len(self.alphabet_list)

        self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)}
        self.idx_to_alphabet = {s: i for i, s in self.alphabet_to_idx.items()}

        # NOTE(review): [:-3] drops [BOS]/[PAD]/[UNK] but deliberately keeps
        # [EOS], presumably so termination can be chosen as an action —
        # confirm against the sampling/policy code.
        self.action_list = self.alphabet_list[:-3]
        self.action_length = len(self.action_list)
        self.special_tokens_idx = [self.eos, self.bos, self.pad, self.unk]

    def tokenize(self, selfies, add_bos=False, add_eos=False):
        """Split a SELFIES string into a list of tokens.

        BUGFIX: the original iterated ``sf.split_selfies(selfies)`` twice.
        Because ``split_selfies`` returns a generator, the second loop was a
        silent no-op (and would have duplicated every token had it been a
        list). A single pass is the intended behavior.
        """
        tokenized = []
        for char in sf.split_selfies(selfies):
            if char.startswith('['):
                # Bracketed SELFIES symbol — keep as a single token.
                tokenized.append(char)
            else:
                # Unbracketed fragment (e.g. '.') — one token per character.
                tokenized.extend(char)
        if add_bos:
            tokenized.insert(0, "[BOS]")
        if add_eos:
            tokenized.append('[EOS]')
        return tokenized

    def encode(self, selfies, add_bos=False, add_eos=False):
        """Encode a SELFIES string into a numpy array of token indices.

        Raises KeyError on tokens absent from the vocabulary (no [UNK]
        fallback). NOTE(review): dtype uint8 silently wraps if the
        vocabulary ever exceeds 256 entries — confirm vocab size.
        """
        char_list = self.tokenize(selfies, add_bos, add_eos)
        encoded_selfies = np.zeros(len(char_list), dtype=np.uint8)
        for i, char in enumerate(char_list):
            encoded_selfies[i] = self.alphabet_to_idx[char]
        return encoded_selfies

    def decode(self, encoded_selfies, rem_bos=True, rem_eos=True):
        """Take a list/array of indices and return the decoded SMILES.

        Strips a single leading [BOS] / trailing [EOS] if requested, then
        converts SELFIES -> SMILES via ``sf.decoder``.
        """
        if rem_bos and encoded_selfies[0] == self.bos:
            encoded_selfies = encoded_selfies[1:]
        if rem_eos and encoded_selfies[-1] == self.eos:
            encoded_selfies = encoded_selfies[:-1]
        chars = []
        for i in encoded_selfies:
            chars.append(self.idx_to_alphabet[i])
        selfies = "".join(chars)
        smiles = sf.decoder(selfies)
        return smiles

    def decode_padded(self, encoded_selfies, rem_bos=True):
        """Decode a padded index array (may contain special tokens) to SMILES.

        Stops at the first [EOS] and skips any other special tokens.
        """
        if rem_bos and encoded_selfies[0] == self.bos:
            encoded_selfies = encoded_selfies[1:]
        chars = []
        for i in encoded_selfies:
            if i == self.eos:
                break
            if i not in self.special_tokens_idx:
                chars.append(self.idx_to_alphabet[i])
        selfies = "".join(chars)
        smiles = sf.decoder(selfies)
        return smiles

    def __len__(self):
        """Number of tokens in the vocabulary (including special tokens)."""
        return len(self.alphabet_to_idx)

    @property
    def bos(self):
        """Index of the beginning-of-sequence token."""
        return self.alphabet_to_idx['[BOS]']

    @property
    def eos(self):
        """Index of the end-of-sequence token."""
        return self.alphabet_to_idx['[EOS]']

    @property
    def pad(self):
        """Index of the padding token."""
        return self.alphabet_to_idx['[PAD]']

    @property
    def unk(self):
        """Index of the unknown token."""
        return self.alphabet_to_idx['[UNK]']
class smiles_vocabulary(object):
    """Vocabulary mapping SMILES tokens to integer indices and back.

    Two-character halogens are pre-collapsed by ``replace_halogen``
    (Br -> R, Cl -> L) so tokenization is single-character outside bracket
    atoms; ``decode``/``decode_padded`` re-expand them. The special tokens
    EOS/BOS/PAD/UNK are appended after sorting, so their indices are always
    the last four of the vocabulary.
    """

    # Matches a bracket atom such as [nH] or [C@@H] (1-6 chars inside).
    # Raw string + class-level compile fixes the invalid-escape warning the
    # original non-raw pattern triggers on modern Python, and avoids
    # re-deriving the pattern per call.
    _BRACKET_RE = re.compile(r'(\[[^\[\]]{1,6}\])')

    def __init__(self, vocab_path):
        """Build the vocabulary from the token file at ``vocab_path``."""
        self.alphabet = set()
        with open(vocab_path, 'r') as f:
            chars = f.read().split()
        for char in chars:
            self.alphabet.add(char)

        self.special_tokens = ['EOS', 'BOS', 'PAD', 'UNK']

        # Sort for a deterministic token -> index mapping, then append the
        # special tokens so they occupy the final four index slots.
        self.alphabet_list = sorted(self.alphabet)
        self.alphabet_list = self.alphabet_list + self.special_tokens
        self.alphabet_length = len(self.alphabet_list)

        self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)}
        self.idx_to_alphabet = {s: i for i, s in self.alphabet_to_idx.items()}
        self.special_tokens_idx = [self.eos, self.bos, self.pad, self.unk]

    def tokenize(self, smiles, add_bos=False, add_eos=False):
        """Split a SMILES string into a list of characters/tokens."""
        smiles = replace_halogen(smiles)
        tokenized = []
        for fragment in self._BRACKET_RE.split(smiles):
            if fragment.startswith('['):
                # Bracket atom — keep as one token.
                tokenized.append(fragment)
            else:
                # Plain SMILES run — one token per character.
                tokenized.extend(fragment)
        if add_bos:
            tokenized.insert(0, "BOS")
        if add_eos:
            tokenized.append('EOS')
        return tokenized

    def encode(self, smiles, add_bos=False, add_eos=False):
        """Encode a SMILES string into a numpy array of token indices.

        Raises KeyError on tokens absent from the vocabulary (no UNK
        fallback). NOTE(review): dtype uint8 silently wraps if the
        vocabulary ever exceeds 256 entries — confirm vocab size.
        """
        char_list = self.tokenize(smiles, add_bos, add_eos)
        encoded_smiles = np.zeros(len(char_list), dtype=np.uint8)
        for i, char in enumerate(char_list):
            encoded_smiles[i] = self.alphabet_to_idx[char]
        return encoded_smiles

    def decode(self, encoded_smiles, rem_bos=True, rem_eos=True):
        """Take a list/array of indices and return the corresponding SMILES.

        Strips a single leading BOS / trailing EOS if requested.
        NOTE(review): the blanket L->Cl / R->Br substitution assumes no
        vocabulary token legitimately contains a bare 'L' or 'R' — verify
        against the vocab file.
        """
        if rem_bos and encoded_smiles[0] == self.bos:
            encoded_smiles = encoded_smiles[1:]
        if rem_eos and encoded_smiles[-1] == self.eos:
            encoded_smiles = encoded_smiles[:-1]
        chars = []
        for i in encoded_smiles:
            chars.append(self.idx_to_alphabet[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def decode_padded(self, encoded_smiles, rem_bos=True):
        """Decode a padded index array (may contain special tokens) to SMILES.

        Stops at the first EOS and skips any other special tokens.
        """
        if rem_bos and encoded_smiles[0] == self.bos:
            encoded_smiles = encoded_smiles[1:]
        chars = []
        for i in encoded_smiles:
            if i == self.eos:
                break
            if i not in self.special_tokens_idx:
                chars.append(self.idx_to_alphabet[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def __len__(self):
        """Number of tokens in the vocabulary (including special tokens)."""
        return len(self.alphabet_to_idx)

    @property
    def bos(self):
        """Index of the beginning-of-sequence token."""
        return self.alphabet_to_idx['BOS']

    @property
    def eos(self):
        """Index of the end-of-sequence token."""
        return self.alphabet_to_idx['EOS']

    @property
    def pad(self):
        """Index of the padding token."""
        return self.alphabet_to_idx['PAD']

    @property
    def unk(self):
        """Index of the unknown token."""
        return self.alphabet_to_idx['UNK']
def replace_halogen(string):
    """Collapse two-character halogen symbols to single letters.

    Rewrites every 'Br' as 'R' and every 'Cl' as 'L' so downstream
    tokenization can treat halogens as single characters.
    """
    # Plain literal substitution — no regex machinery needed.
    return string.replace('Br', 'R').replace('Cl', 'L')