Skip to content

Commit

Permalink
Merge branch 'develop' into dl/challen_alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 27, 2023
2 parents 6a22667 + a8af396 commit 2274fb3
Show file tree
Hide file tree
Showing 91 changed files with 5,225 additions and 474 deletions.
4 changes: 3 additions & 1 deletion nncf/common/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.config.extractors import extract_algo_specific_config
from nncf.config.extractors import extract_bn_adaptation_init_params
from nncf.config.extractors import has_bn_section
from nncf.config.schemata.defaults import VALIDATE_SCOPES

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -201,8 +202,9 @@ def __init__(self, config: NNCFConfig, should_init: bool = True):
self.should_init = should_init
self._algo_config = self._get_algo_specific_config_section()

self.ignored_scopes = self.config.get("ignored_scopes")
self.validate_scopes = self._algo_config.get("validate_scopes", VALIDATE_SCOPES)

self.ignored_scopes = self.config.get("ignored_scopes")
if "ignored_scopes" in self._algo_config:
algo_ignored_scopes = self._algo_config["ignored_scopes"]
if self.ignored_scopes is not None:
Expand Down
10 changes: 9 additions & 1 deletion nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.
from collections import defaultdict
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, KeysView, List, Tuple, Type, ValuesView
from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, ValuesView

import networkx as nx
import networkx.algorithms.isomorphism as iso
Expand Down Expand Up @@ -113,6 +113,7 @@ def __init__(
output_port_id: int,
tensor_shape: List[int],
dtype: Dtype,
parallel_input_port_ids: List[int],
):
"""
:param from_node: An NNCFNode that sources the directed edge.
Expand All @@ -128,6 +129,7 @@ def __init__(
self.output_port_id = output_port_id
self.tensor_shape = tensor_shape
self.dtype = dtype
self.parallel_input_port_ids = parallel_input_port_ids

def __str__(self):
return str(self.from_node) + " -> " + str(self.tensor_shape) + " -> " + str(self.to_node)
Expand Down Expand Up @@ -175,6 +177,7 @@ class NNCFGraph:
IS_INTEGER_INPUT_NODE_ATTR = "is_integer_input"
DTYPE_EDGE_ATTR = "dtype"
IS_SHARED_ATTR = "is_shared"
PARALLEL_INPUT_PORT_IDS_ATTR = "parallel_input_ports"

def __init__(self):
self._nx_graph = nx.DiGraph()
Expand Down Expand Up @@ -470,6 +473,7 @@ def add_edge_between_nncf_nodes(
input_port_id: int,
output_port_id: int,
dtype: Dtype,
parallel_input_port_ids: Optional[List[int]] = None,
):
"""
Adds a directed edge between two `NNCFNode`s that are already present in the graph.
Expand All @@ -482,6 +486,7 @@ def add_edge_between_nncf_nodes(
:param output_port_id: Specifies the index among the possible outputs of the `from_node_id` node that this
tensor should correspond to.
:param dtype: The data type of the tensor.
:param parallel_input_port_ids: Input ports for parallel edges, if any should be present for this edge.
"""
from_node_key = self._node_id_to_key_dict[from_node_id]
to_node_key = self._node_id_to_key_dict[to_node_id]
Expand All @@ -505,6 +510,7 @@ def add_edge_between_nncf_nodes(
NNCFGraph.INPUT_PORT_ID_EDGE_ATTR: input_port_id,
NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR: output_port_id,
NNCFGraph.DTYPE_EDGE_ATTR: dtype,
NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: [] if parallel_input_port_ids is None else parallel_input_port_ids,
}
self._nx_graph.add_edge(from_node_key, to_node_key, **attrs)

Expand Down Expand Up @@ -649,6 +655,7 @@ def get_nncf_graph_pattern_io(self, match: List[str]) -> NNCFGraphPatternIO:
output_port_id=data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR],
tensor_shape=data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR],
dtype=data[NNCFGraph.DTYPE_EDGE_ATTR],
parallel_input_port_ids=data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR],
)
if from_node_key in match:
output_nncf_edges.append(nncf_edge)
Expand Down Expand Up @@ -683,6 +690,7 @@ def get_edge(self, from_node: NNCFNode, to_node: NNCFNode) -> NNCFGraphEdge:
data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR],
data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR],
data[NNCFGraph.DTYPE_EDGE_ATTR],
data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR],
)

def get_all_edges(self) -> Generator[NNCFGraphEdge, None, None]:
Expand Down
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class HWFusedPatternNames(Enum):
LINEAR_SCALE_SHIFT_ACTIVATIONS = PatternDesc("linear_scale_shift_activations")
LINEAR_CONST_MULTIPLY = PatternDesc("linear_const_multiply")
LINEAR_SQUEEZE_ACTIVATIONS = PatternDesc("linear_squeeze_activations")
LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE = PatternDesc("linear_activations_unsqueeze_bn_squeeze")
SCALE_SHIFT_ACTIVATIONS = PatternDesc("scale_shift_activations")
MVN_SCALE_SHIFT_ACTIVATIONS = PatternDesc("mvn_scale_shift_activations")

Expand Down
22 changes: 19 additions & 3 deletions nncf/common/insertion_point_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,14 @@ def __init__(
for edge in self._base_nx_graph.edges:
input_port_id = self._base_nx_graph.edges[edge][NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]
dtype = self._base_nx_graph.edges[edge][NNCFGraph.DTYPE_EDGE_ATTR]
parallel_input_port_ids = self._base_nx_graph.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]
from_node, to_node = edge
attrs = {INPUT_PORT_ID: input_port_id, self.IS_INTEGER_PATH_EDGE_ATTR: dtype is Dtype.INTEGER}

attrs = {
INPUT_PORT_ID: input_port_id,
self.IS_INTEGER_PATH_EDGE_ATTR: dtype is Dtype.INTEGER,
NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: parallel_input_port_ids,
}
self.add_edge(from_node, to_node, **attrs)

node_keys_working_set = [deepcopy(node_key) for node_key in nx.lexicographical_topological_sort(self)]
Expand All @@ -148,7 +154,14 @@ def __init__(
pre_hook_ips = list(target_node_name_vs_pre_hook_ips[original_node.node_name])
pre_hook_ips = sorted(pre_hook_ips, key=lambda x: x.input_port_id)
in_edges = list(self.in_edges(operator_node_key))
input_port_id_vs_edge = {self.edges[edge][INPUT_PORT_ID]: edge for edge in in_edges}
input_port_id_vs_edge = {}
for edge in in_edges:
input_port_id = self.edges[edge][INPUT_PORT_ID]
input_port_id_vs_edge[input_port_id] = edge
for parallel_input_port_id in self.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
input_port_id_vs_edge[parallel_input_port_id] = edge

encountered_input_edges = set()
for pre_hook_point in pre_hook_ips:
edge = input_port_id_vs_edge[pre_hook_point.input_port_id]
original_edge_attrs = self.edges[edge]
Expand All @@ -162,11 +175,14 @@ def __init__(

self.add_node(ip_node_key, **pre_hook_ip_attrs)

self.remove_edge(from_node_key, to_node_key)
encountered_input_edges.add(edge)
self.add_edge(from_node_key, ip_node_key, **original_edge_attrs)
self.add_edge(ip_node_key, operator_node_key, **original_edge_attrs)
operator_node[InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR].add(ip_node_key)

for edge in encountered_input_edges:
self.remove_edge(*edge)

if original_node.node_name in target_node_name_vs_post_hook_ips:
post_hook_ips = target_node_name_vs_post_hook_ips[original_node.node_name]
assert len(post_hook_ips) == 1, "Multiple post-hooks for a single NNCFGraph node are not supported!"
Expand Down
81 changes: 70 additions & 11 deletions nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
from collections import deque
from copy import copy
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import networkx as nx

from nncf import nncf_logger
from nncf.common.graph import INPUT_NOOP_METATYPES
from nncf.common.graph import OUTPUT_NOOP_METATYPES
from nncf.common.graph import NNCFNode
from nncf.common.graph import NNCFNodeName
from nncf.common.graph import OperatorMetatype
Expand All @@ -26,7 +29,6 @@
from nncf.common.insertion_point_graph import InsertionPointGraphNodeType
from nncf.common.insertion_point_graph import PostHookInsertionPoint
from nncf.common.insertion_point_graph import PreHookInsertionPoint
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_propagation.grouping import UnifiedScalePropagatingQuantizerGroupManager
from nncf.common.quantization.quantizer_propagation.structs import IgnoreReason
from nncf.common.quantization.quantizer_propagation.structs import PropagatingQuantizer
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(

self._unified_scale_group_manager = UnifiedScalePropagatingQuantizerGroupManager()
self._input_node_keys_vs_nncf_nodes = {} # type: Dict[str, NNCFNode]
self._output_node_keys_vs_nncf_nodes = {} # type: Dict[str, NNCFNode]
self._pqs_after_weight_dependent_output_quantized_nodes = {} # type: Dict[PropagatingQuantizer, str]
self.op_node_keys_to_underlying_nodes_mapping = {} # type: Dict[str, List[NNCFNode]]

Expand Down Expand Up @@ -139,6 +142,8 @@ def __init__(

if nncf_node_ref.metatype in INPUT_NOOP_METATYPES:
self._input_node_keys_vs_nncf_nodes[node_key] = nncf_node_ref
if nncf_node_ref.metatype in OUTPUT_NOOP_METATYPES:
self._output_node_keys_vs_nncf_nodes[node_key] = nncf_node_ref

if nncf_node_ref.is_in_iteration_scope():
iteration_scope_node_keys.append(node_key)
Expand All @@ -153,6 +158,7 @@ def __init__(

for barred_node_key in list(self.ignored_node_keys.keys()) + iteration_scope_node_keys:
self._add_barrier_after_node(barred_node_key)
self._branch_nodes_directly_dominating_outputs = None

def get_node_keys_by_metatype(self, metatype: Type[OperatorMetatype]) -> List[str]:
"""
Expand All @@ -167,8 +173,9 @@ def get_node_keys_by_metatype(self, metatype: Type[OperatorMetatype]) -> List[st
output.append(node)
return output

@staticmethod
def _insertion_point_to_quant_insertion_point(
self, ip: Union[PreHookInsertionPoint, PostHookInsertionPoint]
ip: Union[PreHookInsertionPoint, PostHookInsertionPoint]
) -> QuantizationInsertionPointBase:
if isinstance(ip, PreHookInsertionPoint):
return ActivationQuantizationInsertionPoint(ip.target_node_name, input_port_id=ip.input_port_id)
Expand Down Expand Up @@ -233,8 +240,8 @@ def mark_act_quantizer_as_dependent_on_weights(self, pq: PropagatingQuantizer, o
and self._pqs_after_weight_dependent_output_quantized_nodes[pq] != operator_node_key
):
raise RuntimeError(
"Propagating quantizer {} is already marked as depending on node {} weight "
"quantization!".format(pq.id, operator_node_key)
f"Propagating quantizer {pq.id} is already marked as depending on node "
f"{operator_node_key} weight quantization!"
)
self._pqs_after_weight_dependent_output_quantized_nodes[pq] = operator_node_key

Expand Down Expand Up @@ -847,6 +854,55 @@ def traverse_fn(
self.traverse_graph(node_key, traverse_fn, retval)
return retval

def _build_branch_direct_output_dominators_info(self) -> Set[str]:
"""
Collects the branching nodes that directly dominate at least one graph output.

Traverses the graph backwards starting from outputs. If there is a path from an output to a branching node
that only passes through quantization-agnostic ops, then this branching node is directly dominating an output.
:return: The set of node keys that directly dominate at least one output.
"""

# State object handed to `traverse_fn` on every visited node; `global_result_ref` points at the
# single result set that is shared by the traversals of all outputs.
@dataclass
class LocalState:
global_result_ref: Set[str]
encountered_quantizer_aware_ops: bool = False

def traverse_fn(curr_node_key: str, local_state: LocalState) -> Tuple[bool, LocalState]:
curr_node = self.nodes[curr_node_key]
# A node with more than one successor is treated as a branching node. It is recorded only while
# no quantizer-aware op has been flagged in the state.
if len(list(self.successors(curr_node_key))) > 1:
if not local_state.encountered_quantizer_aware_ops:
local_state.global_result_ref.add(curr_node_key)
# The first tuple element is the "finished" flag consumed by the traversal helper.
return True, local_state

curr_node_type = curr_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
if curr_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
node_trait = curr_node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR]
op_meta = curr_node[QuantizerPropagationStateGraph.OPERATOR_METATYPE_NODE_ATTR]
# Output no-op nodes are excluded, so the traversal's own starting node does not set the flag.
if op_meta not in OUTPUT_NOOP_METATYPES and node_trait in [
QuantizationTrait.INPUTS_QUANTIZABLE,
QuantizationTrait.OUTPUT_QUANTIZATION_AS_WEIGHTS,
QuantizationTrait.NON_QUANTIZABLE,
]:
local_state.encountered_quantizer_aware_ops = True
return False, local_state

visited_node_keys = set()
result = set()
for output_node_key in self._output_node_keys_vs_nncf_nodes:
# A fresh state is created per output, but it wraps the same `result` set.
# NOTE(review): `traverse_fn` mutates and returns this one object, so once the quantizer-aware
# flag is set on one backward path it presumably stays set for sibling paths visited later from
# the same output - confirm that this order dependence is intended.
output_state = LocalState(result)
self._traverse_graph_recursive_helper(
output_node_key, visited_node_keys, traverse_fn, output_state, traverse_backward=True, visit_once=False
)
return result

def is_branching_node_dominating_outputs(self, from_node_key: str) -> bool:
"""
Checks whether the given branching node directly dominates at least one graph output, i.e. whether there is
a path from an output back to this node that only passes through quantization-agnostic ops.

:param from_node_key: The key of the branching node to check.
:return: True if `from_node_key` is among the branching nodes that directly dominate an output.
"""
# The set of dominating branch nodes is built on the first call and stored for later calls.
if self._branch_nodes_directly_dominating_outputs is None:
self._branch_nodes_directly_dominating_outputs = self._build_branch_direct_output_dominators_info()
return from_node_key in self._branch_nodes_directly_dominating_outputs

def get_visualized_graph(self):
out_graph = nx.DiGraph()
unified_scale_group_vs_pq_node_id_dict = {} # type: Dict[int, List[str]]
Expand Down Expand Up @@ -967,18 +1023,21 @@ def _traverse_graph_recursive_helper(
visited_node_keys: Set[str],
traverse_function: Callable[[str, Any], Tuple[bool, Any]],
output: Any,
traverse_forward: bool,
traverse_backward: bool = False,
visit_once: bool = True,
):
"""This is DFS, and may fail with 'maximum recursion depth exceeded' for complex graphs."""
is_finished, output = traverse_function(curr_node_key, output)
visited_node_keys.add(curr_node_key)
next_node_keys_indexer = self.succ if traverse_forward else self.pred
if visit_once:
visited_node_keys.add(curr_node_key)
next_node_keys_indexer = self.pred if traverse_backward else self.succ
if not is_finished:
for node_key in next_node_keys_indexer[curr_node_key]:
if node_key not in visited_node_keys:
self._traverse_graph_recursive_helper(
node_key, visited_node_keys, traverse_function, output, traverse_forward
)
if visit_once and node_key in visited_node_keys:
continue
self._traverse_graph_recursive_helper(
node_key, visited_node_keys, traverse_function, output, traverse_backward, visit_once
)
return output

def _get_next_prop_quantizer_id(self):
Expand Down
6 changes: 5 additions & 1 deletion nncf/common/quantization/quantizer_propagation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def propagation_step(
# pylint:disable=too-many-branches
# pylint:disable=too-many-statements
curr_node_key = curr_prop_quantizer.current_location_node_key
curr_node = quant_prop_graph.nodes[curr_prop_quantizer.current_location_node_key]
curr_node = quant_prop_graph.nodes[curr_node_key]
curr_node_type = curr_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
assert QuantizerPropagationStateGraph.is_insertion_point(curr_node_type)

Expand Down Expand Up @@ -1220,6 +1220,10 @@ def check_branching_transition(
that branches downwards.
:return: The TransitionStatus indicating in which fashion the transition should occur.
"""
is_dominating_outputs = quant_prop_graph.is_branching_node_dominating_outputs(branching_node_key)
if is_dominating_outputs and not self._quantize_outputs:
return TransitionStatus.SHOULD_NOT_TRANSITION

dom_op_node_keys = quant_prop_graph.get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(
branching_node_key
)
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def revert_operations_to_floating_point_precision(
command_creator.create_command_to_update_bias(node, original_bias, quantized_model_graph)
)

if node.layer_attributes is not None:
if node.layer_attributes and node.layer_attributes.const_attrs is not None:
weight_port_ids = node.layer_attributes.get_const_port_ids()
for port_id in weight_port_ids:
original_weight = node.data.get(f"original_weight.{port_id}", None)
Expand Down
8 changes: 7 additions & 1 deletion nncf/common/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph import NNCFNodeName
from nncf.common.logging import nncf_logger
from nncf.common.quantization.structs import QuantizerId
from nncf.scopes import IgnoredScope
from nncf.scopes import convert_ignored_scope_to_list
Expand Down Expand Up @@ -103,6 +104,7 @@ def check_scopes_in_graph(
graph: NNCFGraph,
ignored_scopes: Union[IgnoredScope, List[str]],
target_scopes: Optional[List[str]] = None,
validate_scopes: bool = True,
) -> None:
"""
Raise RuntimeError if ignored/target scope names do not match the model graph.
Expand All @@ -111,6 +113,8 @@ def check_scopes_in_graph(
:param ignored_scopes: The instance of IgnoredScope or a list of strings specifying a denylist
for the serializable_id.
:param target_scopes: A list of strings specifying an allowlist for the serializable_id.
:param validate_scopes: If set to True, then a RuntimeError will be raised if the names of the
ignored/target scopes do not match the names of the scopes in the model graph.
"""
node_list = graph.get_all_nodes()
not_matched_ignored_scopes = get_not_matched_scopes(ignored_scopes, node_list)
Expand All @@ -132,4 +136,6 @@ def check_scopes_in_graph(
"scopes in terms of the names there."
)

raise RuntimeError(err_message)
if validate_scopes:
raise RuntimeError(err_message)
nncf_logger.info(err_message)
7 changes: 3 additions & 4 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ def _shape(self):
return self._all_shapes[0]


class BatchStatisticCollector(OfflineTensorStatisticCollector):
class RawStatisticCollector(OfflineTensorStatisticCollector):
"""
Collects tensor samples, where each tensor is averaged along the batch axis (and only that axis).
Collects tensor samples, where each tensor is represented in raw format.
Each sample stays available for usage in further stages of the algorithm.
"""

Expand All @@ -498,7 +498,6 @@ def __init__(self, num_samples: Optional[int] = None) -> None:
the number of samples that will be processed.
"""
super().__init__(num_samples=num_samples)
self._tensor_processor = self._get_processor()
self._all_values = []

@staticmethod
Expand All @@ -507,7 +506,7 @@ def _get_processor():
pass

def _register_input_common(self, x: NNCFTensor):
self._all_values.append(self._tensor_processor.batch_mean(x).tensor)
self._all_values.append(x.tensor)

def _reset(self):
self._all_values.clear()
Expand Down
Loading

0 comments on commit 2274fb3

Please sign in to comment.