-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathutils.py
18 lines (16 loc) · 806 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision.

    Trainable variables are requested from ``getter`` with float32 storage.
    If a different training ``dtype`` was asked for, the result is cast to
    it. Non-trainable variables are requested with ``dtype`` unchanged and
    returned as-is.

    Args:
        getter: The wrapped getter that creates or fetches the variable.
            It receives ``name`` and ``shape`` positionally and the
            remaining arguments as keywords.
        name: Variable name, forwarded to ``getter``.
        shape: Variable shape, forwarded to ``getter``.
        dtype: Requested training dtype. For trainable variables it is the
            dtype of the returned value (float32 storage is used
            regardless). For non-trainable variables it is passed through
            as the storage dtype. ``None`` means no cast is applied.
        initializer: Forwarded to ``getter``.
        regularizer: Forwarded to ``getter``.
        trainable: Whether the variable is trainable. It is tested for
            truthiness, so ``None`` is handled exactly like ``False``.
        *args: Extra positional arguments, forwarded to ``getter``.
        **kwargs: Extra keyword arguments, forwarded to ``getter``.

    Returns:
        Whatever ``getter`` returns. For a trainable variable whose
        ``dtype`` is neither ``None`` nor float32, it is the result of
        ``tf.cast(variable, dtype)`` instead.
    """
    # Only trainable variables are promoted to float32 storage.
    storage_dtype = tf.float32 if trainable else dtype
    # NOTE(review): extra positionals are bound right after name and shape,
    # i.e. into getter's third positional slot. If that slot is `dtype`,
    # any non-empty *args raises a duplicate-argument TypeError. This
    # presumably relies on callers passing everything by keyword — confirm.
    variable = getter(name, shape, *args,
                      dtype=storage_dtype,
                      initializer=initializer,
                      regularizer=regularizer,
                      trainable=trainable,
                      **kwargs)
    # Cast only when a concrete, non-float32 training dtype was requested.
    # With dtype=None there is no target dtype for tf.cast, so the
    # float32-stored variable is returned unchanged.
    if trainable and dtype is not None and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable