import numpy as np
import json
import sys
import argparse
import requests
# from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
import nltk as nt
from config import FLAGS
from collections import Counter
import pickle
import os
import random
import re
"""
twit.py: tweet data preprocessing.
"""


class Twit():
    MAIN_KEY = ['hashtags', 'text', 'media']  # fields consumed from the raw tweet JSON

    def __init__(self):
        self.hashtag_max = 0
        self.data_path = FLAGS.twit_path
        self.twits, self.test = self.load_data()  # train/test split (90/10 in load_data)
        self.voca = self.build_voca()
        self.vocab_size = len(self.voca.keys())
        self.ImgTag_data()
        self.vec_generation()  # convert each tweet's tokens into id vectors (data embedding)
        self.curr_tag = 0  # tag index of the tweet currently being read, for 1 twit - 1 tag matching
        self._idx_in_epoch = 0

    def load_data(self):
        """
        Extract only the MAIN_KEY fields from the raw data and preprocess them:
        text tokenization, VISION API (api not yet).
        """
        if os.path.exists('data.pickle'):
            with open('data.pickle', 'rb') as f:
                print("Reusing existing preprocessed tweet data")
                data, tests = pickle.load(f)
            return data, tests
        with open(self.data_path, encoding='utf8') as f:
            raw_data = json.load(f)
        data = []
        print("Building tweet data from scratch")
        for i, t in enumerate(raw_data):
            t['text'], tmp_url = self.extract_url(t['text'].lower())
            t['text'] = re.sub(r'[^\w]', ' ', t['text']).lower()
            lowered_tag = [j.lower() for j in t['hashtags']]
            if len(lowered_tag) == 0:
                continue
            data.append({'raw_text': t['text'], 'hashtags': lowered_tag, 'image': t['media']})
            pos = nt.pos_tag(nt.word_tokenize(t['text']))
            data[-1]['tokens'] = [_pos[0] for _pos in pos] + tmp_url
            for tok in tmp_url:
                data[-1]['raw_text'] += ' ' + tok
        for i in range(10):
            print(data[i]['tokens'])
        tests = data[int(len(data) * 9 / 10):]
        data = data[:int(len(data) * 9 / 10)]
        print('train: {}, test: {}'.format(len(data), len(tests)))
        print('Shuffling')
        random.shuffle(data)
        random.shuffle(tests)
        with open('data.pickle', 'wb') as f:
            pickle.dump((data, tests), f)
        return data, tests
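
    # Illustrative shape of one processed record (the values below are made up
    # for documentation, not taken from the dataset):
    #   {'raw_text': 'new paper out https://t.co/abc',
    #    'hashtags': ['nlp'],
    #    'image': [...],  # t['media'] passed through unchanged
    #    'tokens': ['new', 'paper', 'out', 'https://t.co/abc']}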

    def ImgTag_data(self):
        '''
        # TODO
        Duplicates load_data(); merge the two once image-tag handling is in place.
        '''
        with open(FLAGS.twit_img_path, encoding='utf8') as f:
            raw_data = json.load(f)
        data = []
        for i, t in enumerate(raw_data):
            t['text'], tmp_url = self.extract_url(t['text'].lower())
            t['text'] = re.sub(r'[^\w]', ' ', t['text']).lower()
            lowered_tag = [j.lower() for j in t['hashtags']]
            if len(lowered_tag) == 0:
                continue
            data.append({'raw_text': t['text'], 'hashtags': lowered_tag, 'image': t['media'], 'image_tag': t['mediaTags']})
            pos = nt.pos_tag(nt.word_tokenize(t['text']))
            data[-1]['tokens'] = [_pos[0] for _pos in pos] + tmp_url + t['mediaTags']
            for tok in tmp_url:
                data[-1]['raw_text'] += ' ' + tok
        self.test = data[int(len(data) * 9 / 10):]
        self.twits = data[:int(len(data) * 9 / 10)]

    def extract_url(self, strs):
        url = []
        text = ''
        for i, tok in enumerate(strs.split()):
            if 'https' in tok:
                url.append(tok)
            else:
                text += ' ' + tok
        if text[:3] == ' rt':  # remove the retweet marker and the quoted user handle
            text = ' '.join(text.strip().split()[2:])
        return text.strip(), url
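
    # Hedged trace of extract_url (the input string is an assumption, not data):
    #   extract_url('rt @user check this https://t.co/abc out')
    #   -> ('check this out', ['https://t.co/abc'])
    # Every whitespace-separated token containing "https" is pulled into the URL
    # list, and a leading "rt @user" pair is stripped as a retweet marker.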

    def build_voca(self):
        # TODO
        # When adding word embeddings other than BOW, extend this method.
        voca = Counter()
        for d in self.twits:
            for w in d['tokens']:
                voca[w] += 1
        # Drop low-frequency words from the vocabulary: keep only words that
        # occur more than FLAGS.minimum_cnt times.
        pairs = sorted(voca.items(), key=lambda x: (-x[1], x[0]))
        pairs = [p for p in pairs if p[1] > FLAGS.minimum_cnt]
        words, _ = list(zip(*pairs))
        word_id = dict(zip(words, range(len(words))))
        self.voca_list = list(words) + ['_UNK_', '_BEG_', '_EOS_', '_PAD_']
        self.UNK_KEY = word_id['_UNK_'] = len(words)
        self.BEG_KEY = word_id['_BEG_'] = len(words) + 1
        self.EOS_KEY = word_id['_EOS_'] = len(words) + 2
        self.PAD_KEY = word_id['_PAD_'] = len(words) + 3
        self.DEFINED = [self.UNK_KEY, self.BEG_KEY, self.EOS_KEY, self.PAD_KEY]
        return word_id
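
    # Sketch of the resulting mapping (the words, counts and FLAGS.minimum_cnt = 2
    # are assumptions): with token counts {'news': 5, 'ai': 3, 'rare': 1},
    # build_voca keeps {'news': 0, 'ai': 1} and appends the special ids
    # _UNK_=2, _BEG_=3, _EOS_=4, _PAD_=5, so vocab_size becomes 6.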

    def vec_generation(self):
        # Convert each tweet's tokens and hashtags to id vectors;
        # the encoder input ('vec') is stored in reverse order.
        for idx, t in enumerate(self.twits):
            self.twits[idx]['vec'] = self.tokens_to_id(t['tokens'])
            self.twits[idx]['vec'].reverse()
            self.twits[idx]['tag_vec'] = self.tokens_to_id(t['hashtags'])
        for idx, t in enumerate(self.test):
            self.test[idx]['vec'] = self.tokens_to_id(t['tokens'])
            self.test[idx]['vec'].reverse()
            self.test[idx]['tag_vec'] = self.tokens_to_id(t['hashtags'])

    def next_batch(self, batch_size=100, test=False):
        start = self._idx_in_epoch
        enc_input = []
        dec_input = []
        target = []
        batch_set = []
        if not test:
            # Each batch slot holds one tweet paired with a single hashtag.
            while len(batch_set) < batch_size:
                if self._idx_in_epoch == len(self.twits) - 1:
                    self._idx_in_epoch = 0
                    self.curr_tag = 0
                if self.curr_tag == len(self.twits[self._idx_in_epoch]['tag_vec']):
                    self.curr_tag = 0
                    self._idx_in_epoch += 1
                t = self.twits[self._idx_in_epoch].copy()
                t['tag_vec'] = [t['tag_vec'][self.curr_tag]]
                batch_set.append(t)
                self.curr_tag += 1
        if test:
            batch_set = self.test
        max_len_input, self.hashtag_max = self._max_len(batch_set)
        if not test:
            for t in batch_set:
                for tag in t['tag_vec']:
                    a, b, c = self.transform(t['vec'], [tag], max_len_input)
                    enc_input.append(a)
                    dec_input.append(b)
                    target.append(c)
        if test:
            for t in batch_set:
                a, b, c = self.transform(t['vec'], t['tag_vec'], max_len_input)
                enc_input.append(a)
                dec_input.append(b)
                target.append(c)
        return enc_input, dec_input, target
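
    # Hedged walkthrough (the ids are made up): in training mode a tweet with
    # tag_vec == [4, 9] fills two consecutive batch slots, one per hashtag, so
    # it yields two (enc_input, dec_input, target) triples sharing the same
    # encoder sequence. In test mode the full hashtag list stays on one example.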

    def _max_len(self, batch_set):
        max_len_input = 0
        max_len_output = 0
        for i in range(0, len(batch_set)):
            len_input = len(batch_set[i]['vec'])
            len_output = len(batch_set[i]['tag_vec'])
            if len_input > max_len_input:
                max_len_input = len_input
            if len_output > max_len_output:
                max_len_output = len_output
        return max_len_input, max_len_output + 1

    def transform(self, input, output, input_max):
        # Pad the encoder input to input_max and the decoder sequences to
        # hashtag_max, then one-hot encode the encoder and decoder inputs
        # (the target stays a plain id list).
        enc_input = input + [self.PAD_KEY] * max(0, input_max - len(input))
        dec_input = [self.BEG_KEY] + output + [self.PAD_KEY] * max(0, self.hashtag_max - len(output) - 1)
        target = output + [self.EOS_KEY] + [self.PAD_KEY] * max(0, self.hashtag_max - len(output) - 1)
        enc_input = np.eye(self.vocab_size)[enc_input]
        dec_input = np.eye(self.vocab_size)[dec_input]
        return enc_input, dec_input, target
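
    # Illustrative call (all numbers are assumptions): with vocab_size = 10,
    # input = [5, 3], output = [7], input_max = 4 and self.hashtag_max = 2:
    #   enc ids [5, 3, PAD, PAD] -> one-hot array of shape (4, 10)
    #   dec ids [BEG, 7]         -> one-hot array of shape (2, 10)
    #   target  [7, EOS]         -> plain id list, no one-hot encoding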

    def tokens_to_id(self, tokens):
        # Token strings to vocabulary ids; out-of-vocabulary tokens map to _UNK_.
        ids = [self.voca[i] if i in self.voca else self.UNK_KEY for i in tokens]
        return ids

    def decode(self, indices, string=False):
        # Vocabulary ids back to token strings.
        tok = [[self.voca_list[i] for i in dec] for dec in indices]
        return tok


if __name__ == '__main__':
    t = Twit()
    print(len(t.voca_list))
    for i in t.twits:
        print(i)
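
    # Hedged usage sketch (assumes the inputs above loaded successfully; the
    # batch size is illustrative): draw one training batch and check shapes.
    enc, dec, tgt = t.next_batch(batch_size=4)
    print(np.shape(enc), np.shape(dec), np.shape(tgt))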