-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata_reader.py
64 lines (62 loc) · 1.9 KB
/
data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import math
from caffe2.python import (
core, utils, workspace, schema, layer_model_helper
)
from caffe2.proto import caffe2_pb2
from caffe2.python.layers.tags import Tags
def write_db(db_type, db_name, data_lst):
    '''Write one TensorProtos record per example into a Caffe2 DB.

    The DB is opened in write mode. For every index ``i`` along the first
    axis of ``data_lst[0]``, the ``i``-th slice of each array in
    ``data_lst`` is converted to a Caffe2 tensor, the tensors are packed
    (in ``data_lst`` order) into a single ``TensorProtos`` message, and the
    serialized message is stored under the key ``str(i)``.

    Pre-condition: the first dimension of every array in ``data_lst`` is
    the number of examples.

    NOTE(review): the original author observed that a minidb database
    "seems immutable" once written -- confirm before reusing a db_name.

    Args:
        db_type: DB backend name, forwarded to ``create_db``.
        db_name: DB location, forwarded to ``create_db``.
        data_lst: non-empty sequence of arrays indexed by example.

    Returns:
        ``db_name``, unchanged.
    '''
    database = core.C.create_db(db_type, db_name, core.C.Mode.write)
    txn = database.new_transaction()
    num_examples = data_lst[0].shape[0]
    for index in range(num_examples):
        record = caffe2_pb2.TensorProtos()
        record.protos.extend([
            utils.NumpyArrayToCaffe2Tensor(column[index])
            for column in data_lst
        ])
        txn.put(str(index), record.SerializeToString())
    # Release the transaction first, then the DB handle; presumably this is
    # what flushes the writes and closes the DB -- TODO confirm.
    del txn
    del database
    return db_name
def build_input_reader(
    model,
    db_name, db_type,
    input_names_lst,
    batch_size = 1,
    data_type='train',
):
    '''
    Initialize a DB reader and add the layers that read batches from it.

    The reading network is not connected to the computation network yet,
    so callers can switch between different data sources.

    All arguments are validated before anything is created, so an invalid
    call leaves no reader net or reader blob behind.

    Args:
        model: model helper providing ``TensorProtosDBInput`` and
            ``StopGradient``.
        db_name: forwarded to ``CreateDB`` as ``db``.
        db_type: forwarded to ``CreateDB`` as ``db_type``.
        input_names_lst: output names given to ``TensorProtosDBInput``;
            the last one is treated as the label.
        batch_size: batch size for the input layer; must not be zero.
        data_type: ``'train'`` or ``'eval'``. Selects the layer tag and is
            appended to the net, reader and layer names.

    Returns:
        A list with the outputs of the input layer, where every entry
        except the last has been passed through ``StopGradient``.

    Raises:
        AssertionError: if ``batch_size`` is zero.
        Exception: if ``data_type`` is neither ``'train'`` nor ``'eval'``.
    '''
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type and message are unchanged.
    if batch_size == 0:
        raise AssertionError('batch_size cannot be zero')

    # Resolve the tag BEFORE creating/running the reader net, so that an
    # invalid data_type fails without side effects on the workspace.
    if data_type == 'train':
        tag = Tags.TRAIN_ONLY
    elif data_type == 'eval':
        tag = Tags.EVAL_ONLY
    else:
        raise Exception('data type: {} not valid.'.format(data_type))

    reader_init_net = core.Net('reader_init_net_' + data_type)
    dbreader = reader_init_net.CreateDB(
        [], 'dbreader_' + data_type, db=db_name, db_type=db_type)
    # The dbreader needs to be initialized ONLY ONCE.
    workspace.RunNetOnce(reader_init_net)

    with Tags(tag):
        input_data_struct = model.TensorProtosDBInput(
            [dbreader],
            input_names_lst,
            name='DBInput_' + data_type,
            batch_size=batch_size
        )
        input_data_lst = [input_data for input_data in input_data_struct]
        # The last entry is the label and is left untouched; every other
        # input gets a StopGradient, created under the same tag scope.
        for i in range(len(input_data_lst) - 1):
            input_data_lst[i] = model.StopGradient(
                input_data_lst[i], input_data_lst[i]
            )
    return input_data_lst