-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_orientnet.py
92 lines (78 loc) · 2.8 KB
/
main_orientnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
This module defines the training loop for Orientnet
"""
import argparse
import datetime
import os

import numpy as np
import tensorflow as tf

from data_generation import create_data_generator
from model import keypoint_model, orientation_model
from utils import (
    Transformer,
    mvc_loss,
    pose_loss,
    post_process_kp,
    post_process_orient,
    separation_loss,
    silhouette_loss,
    variance_loss,
)
def orientation_loss(orient, mv, transformer=None):
    """Return the MSE between predicted orientation and the projected x-axis.

    The two endpoints of the canonical x-axis ([1, 0, 0, 1] and
    [-1, 0, 0, 1] in homogeneous coordinates) are transformed by the
    model-view matrix and projected to the image plane; the loss compares
    their 2-D image coordinates with the network's prediction.

    Args:
        orient: predicted orientation points, shape (batch_size, 2, 2).
        mv: model-view matrix, shape (batch_size, 4, 4).
        transformer: optional Transformer providing ``project``.  When
            None, falls back to the module-level ``t`` built in
            ``__main__`` (the original behavior).

    Returns:
        orient_loss: per-endpoint MSE, shape (batch_size, 2).
    """
    proj = t if transformer is None else transformer
    # Tile the two axis endpoints across the batch.
    xp_axis = tf.tile(
        tf.constant([[[1.0, 0, 0, 1], [-1.0, 0, 0, 1]]]),
        [tf.shape(orient)[0], 1, 1],
    )
    # Row-vector points times the model-view matrix, then perspective
    # projection to image coordinates.
    xp = tf.matmul(xp_axis, mv)
    xp = proj.project(xp)
    # Only the x/y components of the projected points are compared.
    return tf.keras.losses.MSE(orient, xp[..., :2])
def orient_net_train_step(rgb, mv):
    """Run a single optimization step of the orientation network.

    Args:
        rgb: batch of RGB images fed to ``orient_net``.
        mv: model-view matrices, shape (batch_size, 4, 4).

    Returns:
        orient: the raw network output (before post-processing), so the
        caller can reuse it.

    Side effects:
        Updates ``orient_net`` weights via the module-level ``optim`` and
        accumulates the batch loss into ``train_orient_loss``.
    """
    with tf.GradientTape() as tape:
        orient = orient_net(rgb)
        post_orient, _ = post_process_orient(orient)
        loss = orientation_loss(post_orient, mv)
    # Compute gradients outside the tape context so the backward pass is
    # not itself recorded on the tape.
    grads = tape.gradient(loss, orient_net.trainable_variables)
    optim.apply_gradients(zip(grads, orient_net.trainable_variables))
    train_orient_loss(loss)
    return orient
if __name__ == '__main__':
    parser = argparse.ArgumentParser("./main_orientnet.py")
    parser.add_argument(
        '--dataset_dir', '-d',
        type=str,
        required=True,
        help='Dataset to train with. No Default',
    )
    parser.add_argument(
        '--batch_size', '-bs',
        type=int,
        default=5,
        help='Batch size',
    )
    parser.add_argument(
        '--num_epochs', '-n',
        type=int,
        required=True,
        # Fixed: help string was a copy-paste of 'Batch size'.
        help='Number of training epochs',
    )
    FLAGS, unparsed = parser.parse_known_args()

    dataset_dir = FLAGS.dataset_dir
    batch_size = FLAGS.batch_size
    num_epochs = FLAGS.num_epochs

    # Viewport width/height used by the Transformer (and network input).
    vw, vh = 128, 128
    t = Transformer(vw, vh, dataset_dir)

    # Pick up only the tfrecord files in the dataset directory.
    # os.path.join handles dataset_dir with or without a trailing slash
    # (plain string concatenation broke without one).
    filenames = [
        os.path.join(dataset_dir, val)
        for val in os.listdir(dataset_dir)
        if val.endswith('tfrecord')
    ]
    dataset = create_data_generator(filenames, batch_size=batch_size)

    orient_net = orientation_model()
    # `learning_rate` replaces the deprecated `lr` keyword argument.
    optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
    train_orient_loss = tf.keras.metrics.Mean('train_orient_loss', dtype=tf.float32)

    for epoch in range(num_epochs):
        for idx, data in enumerate(dataset):
            # Each record carries two views: (img0, mv0) and (img1, mv1).
            for i in range(2):
                rgb = data[f"img{i}"][..., :3]  # drop the alpha channel
                mv = data[f"mv{i}"]
                orient = orient_net_train_step(rgb, mv)
            if idx % 100000 == 0:
                print('loss_orient', train_orient_loss.result())
                train_orient_loss.reset_states()

    orient_net.save_weights('orientation_network.h5')