Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Dec 17, 2024
1 parent bf83e68 commit acdaa7c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions nncf/quantization/algorithms/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def run_step(

def run_from_step(
self,
model: ModelWrapper,
model_wrapper: ModelWrapper,
dataset: Dataset,
start_step_index: int = 0,
step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None,
Expand All @@ -134,23 +134,23 @@ def run_from_step(
:return: The updated model after executing the pipeline from the specified pipeline
step to the end.
"""
pipeline_steps = self._remove_unsupported_algorithms(get_backend(model.model))
pipeline_steps = self._remove_unsupported_algorithms(model_wrapper.backend)
if step_index_to_statistics is None:
step_index_to_statistics = {}

# The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step
step_model = model
step_model_wrapper = model_wrapper
for step_index in range(start_step_index, len(pipeline_steps)):
# Collect statistics required to run current pipeline step
step_statistics = step_index_to_statistics.get(step_index)
if step_statistics is None:
statistic_points = self.get_statistic_points_for_step(step_index, step_model)
step_statistics = collect_statistics(statistic_points, step_model, dataset)
statistic_points = self.get_statistic_points_for_step(step_index, step_model_wrapper)
step_statistics = collect_statistics(statistic_points, step_model_wrapper, dataset)

# Run current pipeline step
step_model = self.run_step(step_index, step_statistics, step_model)
step_model_wrapper = self.run_step(step_index, step_statistics, step_model_wrapper)

return step_model
return step_model_wrapper

def get_statistic_points_for_step(self, step_index: int, model_wrapper: ModelWrapper) -> StatisticPointsContainer:
"""
Expand Down
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/post_training/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin

def apply(
self,
model: ModelWrapper,
model_wrapper: ModelWrapper,
*,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
Expand All @@ -109,4 +109,4 @@ def apply(
if statistic_points:
step_index_to_statistics = {0: statistic_points}

return self._pipeline.run_from_step(model, dataset, 0, step_index_to_statistics)
return self._pipeline.run_from_step(model_wrapper, dataset, 0, step_index_to_statistics)

0 comments on commit acdaa7c

Please sign in to comment.