-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerator.py
97 lines (78 loc) · 3.3 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import csv
import re
import numpy as np
import pandas as pd
import tensorflow as tf
from convert import Converter
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving batches of converted text with geo targets.

    Every batch is a positional slice of ``df``. The ``text`` column is run
    through ``Converter.convert_text`` to build the inputs, and the
    ``latitude`` / ``longitude`` columns are returned as the targets.
    """

    def __init__(self, df: pd.DataFrame,
                 batch_size: int = 16,
                 input_size: int = 128,
                 shuffle: bool = True):
        """Store a private copy of ``df`` and build the text converter.

        Args:
            df: Frame with ``text``, ``latitude`` and ``longitude`` columns.
            batch_size: Number of frame rows taken per batch.
            input_size: Value handed to ``Converter``.
            shuffle: Whether rows are reshuffled at the end of each epoch.
        """
        # Keras 3 expects Sequence/PyDataset subclasses to initialise the
        # base class; on older releases this is a harmless no-op.
        super().__init__()
        # Copy so that shuffling never touches the caller's frame.
        self.df = df.copy()
        self.batch_size = batch_size
        self.input_size = input_size
        self.shuffle = shuffle
        self.nums = len(self.df)
        self.converter = Converter(input_size)

    def on_epoch_end(self):
        """Reshuffle the rows when shuffling is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` pair of arrays.

        Rows whose text ``convert_text`` rejects (returns ``None``) are
        skipped, so a batch may hold fewer than ``batch_size`` samples.
        """
        start = index * self.batch_size
        # Explicit positional slice, independent of the index labels.
        batch = self.df.iloc[start:start + self.batch_size]

        X = []
        y = []
        for text, latitude, longitude in zip(batch["text"],
                                             batch["latitude"],
                                             batch["longitude"]):
            raw = self.converter.convert_text(text)
            if raw is None:
                continue
            # Kept as nested lists so the resulting array dtype matches what
            # this generator has always produced.
            X.append(raw.tolist())
            y.append((latitude, longitude))
        return np.array(X), np.array(y)

    def __len__(self):
        """Return the number of full batches; a trailing partial one is dropped."""
        return self.nums // self.batch_size

    @staticmethod
    def get_data(input_size, base_dir) -> "pd.DataFrame | None":
        """Load every ``<latitude>_<longitude>.csv`` under ``base_dir/data``.

        Each file is read as a ``;``-separated CSV. Its first row is skipped
        as a header and the text is taken from the fourth column. Only texts
        accepted by ``Converter.convert_text`` are kept.

        Args:
            input_size: Value handed to ``Converter``.
            base_dir: Directory that contains the ``data`` sub-directory.

        Returns:
            A frame with ``text``, ``latitude`` and ``longitude`` columns, or
            ``None`` when ``base_dir`` or its ``data`` sub-directory is
            missing.

        Raises:
            ValueError: If a ``.csv`` file name cannot be parsed as
                ``<latitude>_<longitude>``.
        """
        if not os.path.exists(base_dir):
            print(f"Data dir '{base_dir}' does not exist")
            return None
        data_dir = os.path.join(base_dir, "data")
        if not os.path.isdir(data_dir):
            print(f"Data dir '{data_dir}' does not exist")
            return None

        converter = Converter(input_size)
        # Collect plain lists and build the frame once at the end; growing a
        # DataFrame one row at a time re-allocates on every insert.
        texts = []
        latitudes = []
        longitudes = []
        # Sorted so the row order does not depend on the file system.
        for fname in sorted(os.listdir(data_dir)):
            if not fname.endswith('.csv'):
                continue
            latitude, longitude = fname[:-4].split('_', 1)
            latitude = float(latitude)
            longitude = float(longitude)
            fullname = os.path.join(data_dir, fname)
            with open(fullname, newline='', encoding="utf8") as csvfile:
                reader = csv.reader(csvfile, delimiter=';')
                # Skip the header; the default keeps an empty file from
                # raising StopIteration.
                next(reader, None)
                for row in reader:
                    # Blank or truncated lines have no text column.
                    if len(row) <= 3:
                        continue
                    text = row[3]
                    if converter.convert_text(text) is None:
                        continue
                    texts.append(text)
                    # NOTE(review): 90 / 180 are presumably the coordinate
                    # ranges used for scaling - confirm against Converter.
                    latitudes.append(converter.convert_geo(latitude, 90))
                    longitudes.append(converter.convert_geo(longitude, 180))

        return pd.DataFrame({'text': pd.Series(texts, dtype='str'),
                             'latitude': pd.Series(latitudes, dtype='float'),
                             'longitude': pd.Series(longitudes, dtype='float')})
if __name__ == "__main__":
    # Manual smoke test: load the data set found under the current directory
    # and show its size and first rows.
    gdf = DataGenerator.get_data(128, os.curdir)
    if gdf is None:
        # get_data has already printed the reason; exit instead of failing
        # with an AttributeError on None below.
        raise SystemExit(1)
    print(gdf.shape)
    print(gdf.head())