diff --git a/examples/tensorflow/classification/configs/quantization/mobilenet_v2_imagenet_int8.json b/examples/tensorflow/classification/configs/quantization/mobilenet_v2_imagenet_int8.json
index 3462c2f3663..577d96f2cbb 100644
--- a/examples/tensorflow/classification/configs/quantization/mobilenet_v2_imagenet_int8.json
+++ b/examples/tensorflow/classification/configs/quantization/mobilenet_v2_imagenet_int8.json
@@ -6,14 +6,14 @@
     },
 
     "batch_size": 256,
-    "epochs": 15,
+    "epochs": 9,
 
     "optimizer": {
        "type": "Adam",
        "schedule_type": "piecewise_constant",
        "schedule_params": {
-            "boundaries": [5, 10],
-            "values": [1e-4, 1e-5, 1e-6]
+            "boundaries": [3, 6],
+            "values": [1e-3, 1e-4, 1e-5]
        }
    },
 
@@ -23,10 +23,9 @@
     "compression": {
        "algorithm": "quantization",
        "initializer": {
-            "batchnorm_adaptation": {
-                "num_bn_adaptation_samples": 2048,
-                "num_bn_forget_samples": 1024
-            }
-        }
+            "range": {
+                "num_init_samples": 0
+            }
+        }
    }
 }
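
Note: the schedule above is piecewise-constant over epochs, so this change shortens the run from 15 to 9 epochs and raises each step of the decay by 10x. A minimal sketch of the semantics (hypothetical helper, assuming boundaries are epoch indices as in the NNCF config):

    def piecewise_constant_lr(epoch, boundaries=(3, 6), values=(1e-3, 1e-4, 1e-5)):
        # Return the value for the interval the epoch falls into:
        # 1e-3 for epochs 0-2, 1e-4 for epochs 3-5, 1e-5 from epoch 6 on.
        for boundary, value in zip(boundaries, values):
            if epoch < boundary:
                return value
        return values[-1]

    assert piecewise_constant_lr(0) == 1e-3
    assert piecewise_constant_lr(4) == 1e-4
    assert piecewise_constant_lr(8) == 1e-5
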
diff --git a/examples/tensorflow/classification/main.py b/examples/tensorflow/classification/main.py
index 2e6c29f4c68..b63c2b0ab9c 100644
--- a/examples/tensorflow/classification/main.py
+++ b/examples/tensorflow/classification/main.py
@@ -307,6 +307,7 @@ def export(config):
 
 def main(argv):
     parser = get_argument_parser()
     config = get_config_from_argv(argv, parser)
+    # config['eager_mode'] = True
 
     serialize_config(config, config.log_dir)
 
@@ -319,4 +320,8 @@ def main(argv):
 
 
 if __name__ == '__main__':
+    physical_devices = tf.config.list_physical_devices('GPU')
+    for device in physical_devices:
+        tf.config.experimental.set_memory_growth(device, True)
+
     main(sys.argv[1:])
diff --git a/nncf/tensorflow/layers/wrapper.py b/nncf/tensorflow/layers/wrapper.py
index 9ac77e86f49..9c61b2fbea0 100644
--- a/nncf/tensorflow/layers/wrapper.py
+++ b/nncf/tensorflow/layers/wrapper.py
@@ -21,6 +21,31 @@
 from nncf.tensorflow.layers.operation import InputType
+from tensorflow.python.distribute.values_util import get_current_replica_id_as_int
+from tensorflow.python.framework import importer
+from tensorflow.python.eager import wrap_function
+from tensorflow.python.pywrap_tfe import TFE_Py_TapeSetShouldRecordBackprop as \
+    check_tensor_in_tape
+from tensorflow.python.ops.resource_variable_ops import variable_accessed as \
+    add_resource_var_in_tape
+
+from tensorflow.python.framework import auto_control_deps
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+from tensorflow.python.util import object_identity
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.framework.func_graph import _get_defun_inputs_from_args
+from tensorflow.python.framework.func_graph import _get_defun_inputs_from_kwargs
+from tensorflow.python.framework.func_graph import convert_structure_to_signature
+from tensorflow.python.framework.func_graph import flatten
+from tensorflow.python.framework.func_graph import check_mutation
+
+
 
 
 @NNCF_CUSTOM_OBJECTS.register()
 class NNCFWrapper(tf.keras.layers.Wrapper):
     """
@@ -169,6 +194,15 @@ def layer_weights(self):
 
     def build(self, input_shape=None):
         super().build(input_shape)
+
+        self.input_shape__ = input_shape
+        self.tf_f = tf.function(self.layer.call)
+
+        concrete = self.tf_f.get_concrete_function(*[tf.TensorSpec(self.input_shape__, tf.float32)])
+        with concrete.graph.as_default() as g:
+            tf.nn.softmax(g.outputs[0])
+
+        self.fn_train = concrete
         for weight_attr, ops in self.weights_attr_ops.items():
             weight = self.get_layer_weight(weight_attr)
             for op_name, op in ops.items():
@@ -179,16 +213,25 @@ def build(self, input_shape=None):
         self._op_build = True
 
     def call(self, inputs, training=None):
-        training = self._get_training_value(training)
-
-        self._apply_ops(training)
-
-        if self._layer_expects_training_arg:
-            outputs = self.layer.call(inputs, training=training)
+        replica_context = tf.distribute.get_replica_context()
+        if replica_context is not None:
+            replica_id = get_current_replica_id_as_int()
+            new_variables = []
+            new_captured = []
+            for var, input_tensor in zip(self.layer.variables, self.fn_train.inputs[1:]):
+                new_variables.append(var._get_replica(replica_id))
+                new_captured.append((var._get_replica(replica_id).handle, input_tensor))
         else:
-            outputs = self.layer.call(inputs)
-
-        return outputs
+            new_variables = self.fn_train.graph.variables
+            new_captured = self.fn_train.graph.captures
+
+        fn_train = make_new_func(self.fn_train.graph.as_graph_def(),
+                                 new_captured,
+                                 new_variables,
+                                 self.fn_train.inputs,
+                                 self.fn_train.outputs)
+        return fn_train(inputs)
+
 
     def _apply_ops(self, training):
         for weight_attr, ops in self.weights_attr_ops.items():
@@ -200,6 +243,7 @@ def _apply_ops(self, training):
             self.set_layer_weight(weight_attr, layer_weight)
 
     def registry_weight_operation(self, weights_attr, op):
+        return
         if weights_attr not in self.weights_attr_ops:
             self.weights_attr_ops[weights_attr] = OrderedDict()
 
@@ -279,3 +323,61 @@ def from_config(cls, config, custom_objects=None):
         )
 
         return wrapper
+
+
+#######
+# To make it possible to get gradients out of a concrete function,
+# the ids of its variables and of its captured inputs must be equal.
+#######
+def get_concrete_vars_id(concrete):
+    res = []
+    for var in concrete._func_graph.variables:
+        res.append(var.handle._id)
+    return res
+
+
+def get_concrete_captured_id(concrete):
+    res = []
+    for var in concrete.captured_inputs:
+        res.append(var._id)
+    return res
+
+
+def _add_concrete_fun_resource_vars_to_tape(concrete):
+    for v in concrete._func_graph.variables:
+        add_resource_var_in_tape(v)
+
+
+def _check_concrete_fun_resource_vars_is_in_tape(concrete):
+    return check_tensor_in_tape(concrete.captured_inputs)
+
+
+def make_new_func(output_graph_def, captures, variables, inputs, outputs):
+    new_input_names = [tensor.name for tensor in inputs]
+    inputs_map = {
+        tensor.name: tensor for tensor in inputs
+    }
+    new_output_names = [tensor.name for tensor in outputs]
+    new_func = my_function_from_graph_def(output_graph_def,
+                                          new_input_names,
+                                          new_output_names,
+                                          captures)
+    for input_tensor in new_func.inputs:
+        input_tensor.set_shape(inputs_map[input_tensor.name].shape)
+        break  # only the first (data) input's static shape is restored
+
+    new_func.graph.variables = variables
+    return new_func
+
+
+def my_function_from_graph_def(graph_def, inputs, outputs, ref_captures):
+    def _imports_graph_def():
+        importer.import_graph_def(graph_def, name="")
+
+    wrapped_import = wrap_function.wrap_function(_imports_graph_def, [])
+    import_graph = wrapped_import.graph
+    wrapped_import.graph.reset_captures([(tensor, import_graph.get_tensor_by_name(placeholder.name))
+                                         for tensor, placeholder in ref_captures])
+    return wrapped_import.prune(
+        nest.map_structure(import_graph.as_graph_element, inputs),
+        nest.map_structure(import_graph.as_graph_element, outputs))
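
Note: the helpers above rebuild a callable from a concrete function's GraphDef while re-pointing its captures at the live variable handles, which is what lets a gradient tape see through the replayed graph. A hedged sketch of how they compose (hypothetical usage mirroring the non-distributed branch of NNCFWrapper.call; assumes the TF-internal APIs imported in wrapper.py behave as on the TF 2.x version this branch targets; layer, fn and the prints are illustrative only):

    import tensorflow as tf
    from nncf.tensorflow.layers.wrapper import (make_new_func,
                                                get_concrete_vars_id,
                                                get_concrete_captured_id,
                                                _add_concrete_fun_resource_vars_to_tape)

    layer = tf.keras.layers.Dense(4)
    layer.build((None, 8))
    concrete = tf.function(layer.call).get_concrete_function(
        tf.TensorSpec([None, 8], tf.float32))

    # Rebuild a callable from the GraphDef, reusing the original captures,
    # variables, inputs and outputs, as NNCFWrapper.call does.
    fn = make_new_func(concrete.graph.as_graph_def(),
                       concrete.graph.captures,
                       concrete.graph.variables,
                       concrete.inputs,
                       concrete.outputs)

    # Per the comment block above: gradients can flow only if the ids of the
    # rebuilt function's variables and of its captured inputs line up.
    print(get_concrete_vars_id(fn) == get_concrete_captured_id(fn))

    with tf.GradientTape() as tape:
        _add_concrete_fun_resource_vars_to_tape(fn)  # mark the vars as accessed
        out = fn(tf.ones([2, 8]))
    print(tape.gradient(out, layer.trainable_variables))
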
diff --git a/nncf/tensorflow/quantization/algorithm.py b/nncf/tensorflow/quantization/algorithm.py
index 452e799013d..8ac4734c6c7 100644
--- a/nncf/tensorflow/quantization/algorithm.py
+++ b/nncf/tensorflow/quantization/algorithm.py
@@ -174,20 +174,20 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
                         callable_object=operation,
                         priority=TransformationPriority.QUANTIZATION_PRIORITY))
 
-        insertion_points = self._find_insertion_points(nncf_graph)
-        qconfig = self._get_default_qconfig(self.global_quantizer_constraints[QuantizerGroup.ACTIVATIONS])
-        for original_node_name, instance_index in insertion_points:
-            fake_quantize_name = self._get_fake_quantize_name(original_node_name, instance_index)
-            fake_quantize_layer = FakeQuantize(
-                TFQuantizerSpec.from_config(qconfig, narrow_range=False, half_range=False),
-                name=fake_quantize_name)
-            self._op_names.append(fake_quantize_layer.op_name)
-
-            transformations.register(
-                TFInsertionCommand(
-                    target_point=TFAfterLayer(original_node_name, instance_index),
-                    callable_object=fake_quantize_layer,
-                    priority=TransformationPriority.QUANTIZATION_PRIORITY))
+        # insertion_points = self._find_insertion_points(nncf_graph)
+        # qconfig = self._get_default_qconfig(self.global_quantizer_constraints[QuantizerGroup.ACTIVATIONS])
+        # for original_node_name, instance_index in insertion_points:
+        #     fake_quantize_name = self._get_fake_quantize_name(original_node_name, instance_index)
+        #     fake_quantize_layer = FakeQuantize(
+        #         TFQuantizerSpec.from_config(qconfig, narrow_range=False, half_range=False),
+        #         name=fake_quantize_name)
+        #     self._op_names.append(fake_quantize_layer.op_name)
+
+        #     transformations.register(
+        #         TFInsertionCommand(
+        #             target_point=TFAfterLayer(original_node_name, instance_index),
+        #             callable_object=fake_quantize_layer,
+        #             priority=TransformationPriority.QUANTIZATION_PRIORITY))
 
         return transformations
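
Note: under tf.distribute, the patched NNCFWrapper.call above swaps each DistributedVariable's per-replica component (and its handle) into the capture list before replaying the traced graph. A minimal sketch of that lookup (hypothetical; relies on the same TF-internal get_current_replica_id_as_int and _get_replica APIs the patch imports, on the TF version it targets):

    import tensorflow as tf
    from tensorflow.python.distribute.values_util import get_current_replica_id_as_int

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        layer = tf.keras.layers.Dense(4)
        layer.build((None, 8))

    def local_handles():
        replica_id = get_current_replica_id_as_int()
        # _get_replica picks this replica's component of a DistributedVariable,
        # as the patched call() does when rebuilding the captures.
        return [var._get_replica(replica_id).handle for var in layer.variables]

    strategy.run(local_handles)
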