Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kshpv committed Apr 12, 2024
1 parent 01db1cb commit 2021777
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def print_results(optimized_model: ov.Model, similarity: float) -> None:
else:
print(best_params_info)
footprint = Path(MODEL_PATH).with_suffix(".bin").stat().st_size
print(f"Memory footprint: {footprint / 2**20:.2f} MB")
print(f"Memory footprint: {footprint / 2**20 :.2f} MB")
print(f"Similarity: {similarity:.2f}")


Expand Down
4 changes: 3 additions & 1 deletion examples/torch/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def is_staged_quantization(config):
if isinstance(compression_config, list):
compression_config = compression_config[0]
algo_type = compression_config.get("algorithm")
return bool(algo_type is not None and algo_type == "quantization" and compression_config.get("params", {}))
if algo_type is not None and algo_type == "quantization" and compression_config.get("params", {}):
return True
return False


def is_pretrained_model_requested(config: SampleConfig) -> bool:
Expand Down
20 changes: 15 additions & 5 deletions nncf/common/pruning/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def mask_propagation(
class ConvolutionPruningOp(BasePruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    """Return whether this convolution node can consume a pruned input.

    Grouped convolutions are rejected unless they are prunable depthwise
    convolutions (the only grouped form whose channels can be pruned safely).

    :param node: Convolution node to check.
    :return: True if the node accepts a pruned input, False otherwise.
    """
    # Diff residue removed: only one return path is kept.
    if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
        return False
    return True

@classmethod
def mask_propagation(
Expand All @@ -124,7 +126,9 @@ def mask_propagation(
class TransposeConvolutionPruningOp(BasePruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    """Return whether this transpose-convolution node can consume a pruned input.

    Grouped transpose convolutions are rejected unless they are prunable
    depthwise convolutions.

    :param node: Transpose-convolution node to check.
    :return: True if the node accepts a pruned input, False otherwise.
    """
    # Diff residue removed: only one return path is kept.
    if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
        return False
    return True

@classmethod
def mask_propagation(
Expand Down Expand Up @@ -274,7 +278,9 @@ def match_multiple_output_masks(

@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    """Return whether ``node`` can accept a pruned input.

    The node is accepted only when its layer attributes are known; without
    them the pruning masks cannot be matched to the node's channels.

    :param node: Node to check.
    :return: True if the node carries layer attributes, False otherwise.
    """
    # Duplicate (unreachable) return removed; explicit bool result preserved.
    return node.layer_attributes is not None

@classmethod
def generate_output_masks(
Expand Down Expand Up @@ -329,7 +335,9 @@ class PadPruningOp(IdentityMaskForwardPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    """Return whether a Pad node can pass a pruning mask through.

    Constant padding with a non-zero value would alter the contents of
    pruned channels, so only zero 'constant' padding (or any non-constant
    padding mode) is accepted.

    :param node: Pad node to check; must expose ``mode`` and ``value``
        layer attributes.
    :return: True if the pad operation is mask-transparent, False otherwise.
    """
    mode, value = node.layer_attributes.mode, node.layer_attributes.value
    # Diff residue removed: only one return path is kept.
    if mode == "constant" and value != 0:
        return False
    return True


class ElementwisePruningOp(BasePruningOp):
Expand Down Expand Up @@ -386,7 +394,9 @@ def mask_propagation(
class FlattenPruningOp(BasePruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    """Return whether the flatten ``node`` can accept a pruned input.

    Accepted only when layer attributes are present, since they are needed
    to map pruned channels through the flatten operation.

    :param node: Flatten node to check.
    :return: True if the node carries layer attributes, False otherwise.
    """
    # Duplicate (unreachable) return removed; explicit bool result preserved.
    return node.layer_attributes is not None

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph, tensor_processor: Type[NNCFPruningBaseTensorProcessor]):
Expand Down
2 changes: 1 addition & 1 deletion nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _is_node_with_bias(node: onnx.NodeProto, model: onnx.ModelProto) -> bool:
"""
metatype = get_metatype(model, node)
bias_tensor_port_id = get_bias_tensor_port_id(metatype)
return bool(bias_tensor_port_id is not None and len(node.input) > bias_tensor_port_id)
return bias_tensor_port_id is not None and len(node.input) > bias_tensor_port_id


def _get_gemm_attrs(node: onnx.NodeProto) -> Dict[str, int]:
Expand Down
4 changes: 3 additions & 1 deletion nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def _check_consumer_conv_node(self, conv_node: NNCFNode) -> bool:
if any(elem != 1 for elem in attrs.stride):
return False
# Check Node has valid dilation
return not any(elem != 1 for elem in attrs.dilations)
if any(elem != 1 for elem in attrs.dilations):
return False
return True

def _check_producer_conv_node(self, conv_node):
    """A producer convolution is usable only when its layer attributes are known."""
    attrs = conv_node.layer_attributes
    return attrs is not None
Expand Down
2 changes: 1 addition & 1 deletion nncf/tensorflow/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def commands(self) -> List[TFTransformationCommand]:
return self._commands

def check_insertion_command(self, command: TFTransformationCommand) -> bool:
return bool(
return (
isinstance(command, TFTransformationCommand)
and command.type == TransformationType.INSERT
and self.check_target_points_fn(self.target_point, command.target_point)
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/accuracy_aware_training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def is_dist_avail_and_initialized():
    """Return True when torch.distributed is available and the default
    process group has been initialized, False otherwise.

    :return: bool distributed-readiness flag.
    """
    if not dist.is_available():
        return False
    # is_initialized() already returns a bool; duplicate (unreachable)
    # return from the interleaved diff removed.
    return dist.is_initialized()


def get_rank():
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def matches(
) -> bool:
if not isinstance(layer_attributes, ConvolutionLayerAttributes):
return False
return bool(layer_attributes.groups == layer_attributes.in_channels and layer_attributes.in_channels > 1)
return layer_attributes.groups == layer_attributes.in_channels and layer_attributes.in_channels > 1


@PT_OPERATOR_METATYPES.register()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def match_fn(obj):
def _is_loss(obj):
    """Return True when ``obj`` is a torch.Tensor that requires grad.

    A gradient-requiring tensor is treated as a loss candidate; any
    non-tensor object is rejected outright.

    :param obj: Arbitrary object to test.
    :return: bool indicating whether ``obj`` looks like a loss tensor.
    """
    if not isinstance(obj, torch.Tensor):
        return False
    # requires_grad is already a bool; duplicate (unreachable) return
    # from the interleaved diff removed.
    return obj.requires_grad

def forward(self) -> torch.Tensor:
"""
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_flat_tensor_contents_string(input_tensor):
def is_dist_avail_and_initialized():
    """Return True when torch.distributed is available and the default
    process group has been initialized, False otherwise.

    :return: bool distributed-readiness flag.
    """
    if not dist.is_available():
        return False
    # is_initialized() already returns a bool; duplicate (unreachable)
    # return from the interleaved diff removed.
    return dist.is_initialized()


def get_rank():
Expand Down
Loading

0 comments on commit 2021777

Please sign in to comment.