Skip to content

Commit

Permalink
test
Browse files: browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Jan 20, 2025
1 parent cb4d2f5 commit e5e540a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 35 deletions.
24 changes: 13 additions & 11 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,19 @@ def quantize_impl(
ignored_scope=ignored_scope,
advanced_parameters=advanced_parameters,
)

# To make it easier for bias correction algorithms.
apply_quantization_transformations(copied_model)

nncf_graph = NNCFGraphFactory.create(copied_model)
quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
with torch.no_grad():
nncf_graph = NNCFGraphFactory.create(copied_model)
quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)

if is_weight_compression_needed(advanced_parameters):
compress_post_quantize_transformation(quantized_model)
else:
fq_weights_transformation(quantized_model)
if is_weight_compression_needed(advanced_parameters):
compress_post_quantize_transformation(quantized_model)
else:
fq_weights_transformation(quantized_model)

# Magic. Without this call compiled model
# is not preformant
# Magic. Without this call compiled model is not performant
quantized_model = GraphModule(quantized_model, quantized_model.graph)

quantized_model = _fold_conv_bn_qat(quantized_model)
Expand Down Expand Up @@ -151,8 +150,11 @@ def compress_weights_impl(
backup_mode,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)

with torch.no_grad():
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)

compressed_model = GraphModule(compressed_model, compressed_model.graph)
compressed_model = _disallow_eval_train(compressed_model)

Expand Down
11 changes: 5 additions & 6 deletions nncf/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def infer(
:param input_data: Inputs for the model.
:return: Model outputs.
"""
with torch.no_grad():
if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)
return self._model(input_data)
if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)
return self._model(input_data)
38 changes: 20 additions & 18 deletions nncf/torch/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,26 @@ def quantize_impl(
if mode is not None:
raise ValueError(f"mode={mode} is not supported")

copied_model = deepcopy(model)
with torch.no_grad():
copied_model = deepcopy(model)

example_input = next(iter(calibration_dataset.get_inference_data()))
nncf_network = wrap_model(copied_model.eval(), example_input, trace_parameters=True)
example_input = next(iter(calibration_dataset.get_inference_data()))
nncf_network = wrap_model(copied_model.eval(), example_input, trace_parameters=True)

quantization_algorithm = PostTrainingQuantization(
preset=preset,
target_device=target_device,
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
model_type=model_type,
ignored_scope=ignored_scope,
advanced_parameters=advanced_parameters,
)
graph = nncf_network.nncf.get_graph()
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
quantized_model = quantization_algorithm.apply(nncf_network, graph, dataset=calibration_dataset)
quantization_algorithm = PostTrainingQuantization(
preset=preset,
target_device=target_device,
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
model_type=model_type,
ignored_scope=ignored_scope,
advanced_parameters=advanced_parameters,
)
graph = nncf_network.nncf.get_graph()
warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
quantized_model = quantization_algorithm.apply(nncf_network, graph, dataset=calibration_dataset)

quantized_model.nncf.disable_dynamic_graph_building()
quantized_model.nncf.disable_dynamic_graph_building()

return quantized_model

Expand Down Expand Up @@ -117,5 +118,6 @@ def compress_weights_impl(
backup_mode,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
return compression_algorithm.apply(model, graph, dataset=dataset)
with torch.no_grad():
graph = NNCFGraphFactory.create(model)
return compression_algorithm.apply(model, graph, dataset=dataset)

0 comments on commit e5e540a

Please sign in to comment.