Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 9, 2023
1 parent 84b0a53 commit 3c0e659
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
12 changes: 10 additions & 2 deletions nncf/torch/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,23 @@ class GraphBuilder:
def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]):
self.custom_forward_fn = custom_forward_fn

def build_dynamic_graph(
self,
model: torch.nn.Module,
context_to_use: Optional[TracingContext] = None,
as_eval: bool = False,
) -> DynamicGraph:
tracer = GraphTracer(self.custom_forward_fn)
return tracer.trace_graph(model, context_to_use, as_eval)

def build_graph(
self,
model: torch.nn.Module,
context_to_use: Optional[TracingContext] = None,
as_eval: bool = False,
input_infos: List[ModelInputInfo] = None,
) -> PTNNCFGraph:
tracer = GraphTracer(self.custom_forward_fn)
dynamic_graph = tracer.trace_graph(model, context_to_use, as_eval)
dynamic_graph = self.build_dynamic_graph(model=model, context_to_use=context_to_use, as_eval=as_eval)
return GraphConverter.convert(dynamic_graph, input_infos)


Expand Down
2 changes: 2 additions & 0 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P
for transformation_command in transformations:
target_point: PTTargetPoint = transformation_command.target_point
target_node_name = target_point.target_node_name
if target_node_name not in node_to_op_address_mapping:
breakpoint()
pt_ip = PTInsertionPoint(
target_type=target_point.target_type,
op_address=node_to_op_address_mapping[target_node_name],
Expand Down
35 changes: 29 additions & 6 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def __init__(
)
self._original_graph = GraphConverter.convert(self._original_dynamic_graph, input_infos=self._input_infos)
self._compressed_graph: PTNNCFGraph = None
self._compressed_traced_graph: DynamicGraph = None

self._compressed_context = TracingContext()

Expand Down Expand Up @@ -358,6 +359,25 @@ def reset_nncf_modules(self):
module = self.get_module_by_scope(some_scope)
module.reset()

def get_shallow_copy(self) -> "NNCFNetwork":
from nncf.torch.utils import load_module_state
from nncf.torch.utils import save_module_state

saved_state = save_module_state(self._model_ref)
new_interface = NNCFNetworkInterface(
self._model_ref,
self._input_infos,
self._user_dummy_forward_fn,
self._wrap_inputs_fn,
self._scopes_without_shape_matching,
self._ignored_scopes,
self._target_scopes,
wrap_outputs_fn=self._wrap_outputs_fn,
)
self._model_ref._nncf = new_interface
load_module_state(self._model_ref, saved_state)
return self._model_ref

def get_clean_shallow_copy(self) -> "NNCFNetwork":
# WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
# and load_nncf_module_additions to preserve these, or temporary_clean_view().
Expand Down Expand Up @@ -500,8 +520,9 @@ def rebuild_graph(self, *input_args):
builder = GraphBuilder(dummy_forward_fn)

with training_mode_switcher(self._model_ref, is_training=False):
self._compressed_graph = builder.build_graph(
self._model_ref, self._compressed_context, input_infos=self._input_infos
self._compressed_traced_graph = builder.build_dynamic_graph(self._model_ref, self._compressed_context)
self._compressed_graph = GraphConverter.convert(
self._compressed_traced_graph, input_infos=self._input_infos
)

def is_scope_in_nncf_module_scope(self, scope: Scope) -> bool:
Expand Down Expand Up @@ -736,13 +757,15 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
return result

def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
# The IDs of corresponding nodes of the original dynamic graph and original NNCF graph
# must be equal for this to work.
retval = {}
for node in self._original_dynamic_graph.get_all_nodes():
dynamic_graph = (
self._original_dynamic_graph if self._compressed_traced_graph is None else self._compressed_traced_graph
)
nncf_graph = self._original_graph if self._compressed_graph is None else self._compressed_graph
for node in dynamic_graph.get_all_nodes():
node_id = node.node_id
op_address = node.op_exec_context.op_address
nncf_node = self._original_graph.get_node_by_id(node_id)
nncf_node = nncf_graph.get_node_by_id(node_id)
retval[nncf_node.node_name] = op_address
return retval

Expand Down
15 changes: 14 additions & 1 deletion nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -26,10 +27,22 @@
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class ModelView:
def __init__(self, model: NNCFNetwork):
self.model = model

def __enter__(self):
self.nncf_interface = deepcopy(self.model.nncf)
return self.model

def __exit__(self, exc_type, exc_val, exc_tb):
self.model._nncf = self.nncf_interface


class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
with model.nncf.temporary_clean_view() as intermediate_model:
with ModelView(model) as intermediate_model:
super().collect_statistics(intermediate_model, graph)

def _register_statistics(
Expand Down

0 comments on commit 3c0e659

Please sign in to comment.