[WC] Align compression subgraphs for both weight input data types #2537

Merged
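For context, here is a minimal hand-written sketch (not code from this PR) of the decompression subgraph that the change produces for an int8-compressed weight. After this change the same pattern is built whether the original weight constant was fp32 or fp16: the compressed weight and zero point are always converted to f32, and the f16 scale always goes through a named f16 -> f32 Convert. The shapes, names, imports and opset version below are illustrative assumptions.

```python
# Minimal sketch of the aligned decompression subgraph (illustrative only).
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

compressed = opset.constant(np.zeros((4, 8), dtype=np.uint8), dtype=ov.Type.u8, name="weights")
zero_point = opset.constant(np.zeros((4, 1), dtype=np.uint8), dtype=ov.Type.u8, name="weights/zero_point")
scale = opset.constant(np.ones((4, 1), dtype=np.float16), dtype=ov.Type.f16, name="weights/scale")

# The weight and the zero point are decompressed in f32, regardless of the
# original weight data type.
w = opset.convert(compressed, ov.Type.f32)
zp = opset.convert(zero_point, ov.Type.f32)
w = opset.subtract(w, zp, name="weights/zero_point/subtract")

# The scale is stored in f16 and always converted to f32 by a dedicated node.
s = opset.convert(scale, ov.Type.f32, name="weights/scale_convert")
decompressed = opset.multiply(w, s, name="weights/fq_weights_1")
```

The changes below in `transform_model` implement this, and the test is updated to check the same subgraph for both weight data types.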
@@ -144,7 +144,7 @@ def transform_model(
             const_attributes = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]
             const_node_name = const_attributes["name"]
             const_node = self.name_to_node_mapping[const_node_name]
-            const_dtype = const_node.output(0).get_element_type().to_dtype()
+            const_dtype = const_node.output(0).get_element_type()

             weight = Tensor(get_const_value(const_node))
             original_shape = weight.shape
@@ -153,19 +153,22 @@ def transform_model(
             compressed_const = opset.constant(
                 compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
             )
-            converted_const = opset.convert(compressed_const, const_dtype)
+            converted_const = opset.convert(compressed_const, ov.Type.f32)
             if compressed_weight.zero_point is not None:
                 zero_point_const = opset.constant(
                     compressed_weight.zero_point.data,
                     dtype=compression_dtype,
                     name=f"{const_node_name}/zero_point",
                 )
-                converted_zero_point = opset.convert(zero_point_const, const_dtype)
-                converted_const = opset.subtract(converted_const, converted_zero_point)
+                converted_zero_point = opset.convert(zero_point_const, ov.Type.f32)
+                converted_const = opset.subtract(
+                    converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
+                )

-            scale_const = opset.constant(compressed_weight.scale.data, dtype="float16", name=f"{const_node_name}/scale")
-            if const_dtype != "float16":
-                scale_const = opset.convert(scale_const, const_dtype, name=f"{const_node_name}/scale_convert")
+            scale_const = opset.constant(
+                compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
+            )
+            scale_const = opset.convert(scale_const, ov.Type.f32, name=f"{const_node_name}/scale_convert")
             mul = opset.multiply(
                 converted_const,
                 scale_const,
@@ -175,8 +178,13 @@ def transform_model(
             if compression_config.group_size != -1:
                 mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)

+            const_node_output = const_node.output(0)
+            if const_dtype == ov.Type.f16:
+                # Bypass fp16 -> fp32 convert node
+                const_node_output = next(iter(const_node_output.get_target_inputs())).get_node().output(0)
+
             mul_output = mul.output(0)
-            for target_input in const_node.output(0).get_target_inputs():
+            for target_input in const_node_output.get_target_inputs():
                 target_input.replace_source_output(mul_output)

             # reset name_to_node_mapping
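A note on the fp16 branch above: when the original weights are fp16, the model typically stores an f16 Constant followed by an f16 -> f32 Convert node. Since the new decompression subgraph already produces f32, the consumers to rewire are the ones attached to that Convert, not to the Constant itself. The sketch below restates that rewiring with fuller comments; `const_node`, `mul` and the single-consumer assumption are taken from the code above, the rest is illustrative.

```python
# Rough illustration of the rewiring step (assumes `const_node` is the original
# weight Constant, `mul` is the final Multiply of the decompression subgraph,
# and `ov` is openvino.runtime).
const_node_output = const_node.output(0)
if const_node_output.get_element_type() == ov.Type.f16:
    # fp16 weights are followed by an f16 -> f32 Convert; rewire the consumers
    # of that Convert instead of the Constant itself.
    convert_node = next(iter(const_node_output.get_target_inputs())).get_node()
    const_node_output = convert_node.output(0)

# Point every consumer at the output of the decompression Multiply; the old
# Constant (and Convert, if any) are left disconnected.
for target_input in const_node_output.get_target_inputs():
    target_input.replace_source_output(mul.output(0))
```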
tests/openvino/native/quantization/test_weights_compression.py (22 additions & 24 deletions)
@@ -95,9 +95,8 @@ def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode

     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
-    scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
+    convert_node = mul_node.input_value(1).get_node()
+    scale_node = convert_node.input_value(0).get_node()
     scale = get_const_value(scale_node)

     return {
@@ -133,9 +132,8 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int =

     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
-    scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
+    convert_node = mul_node.input_value(1).get_node()
+    scale_node = convert_node.input_value(0).get_node()
     assert list(scale_node.shape) == reduced_weight_shape

     reshape_node = get_next_node(mul_node)
@@ -159,9 +157,8 @@ def check_nf4_grouped(op: ov.Node, group_size: int = 7):

     mul_node = get_next_node(convert_node)
     assert mul_node.get_type_name() == "Multiply"
-    scale_node = mul_node.input_value(1).get_node()
-    if scale_node.get_type_name() == "Convert":
-        scale_node = scale_node.input_value(0).get_node()
+    convert_node = mul_node.input_value(1).get_node()
+    scale_node = convert_node.input_value(0).get_node()
     assert list(scale_node.shape) == reduced_weight_shape

     reshape_node = get_next_node(mul_node)
@@ -698,21 +695,22 @@ def test_data_type_for_num_weights(mocker):


 def test_weight_scale_datatype():
-    # When model weight is in fp32, there will be an extra convert node for weight scale f16 > f32
-    model_fp32 = IdentityMatmul(weights_dtype=np.float32).ov_model
-    compressed_model_fp32 = compress_weights(model_fp32)
-    name_to_node_map = {op.get_friendly_name(): op for op in compressed_model_fp32.get_ops()}
-    assert "weights/scale_convert" in name_to_node_map
-    scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
-    assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f32
-
-    # When model weight is in fp16, there will be no extra convert node for weight scale
-    model_fp16 = IdentityMatmul(weights_dtype=np.float16).ov_model
-    compressed_model_fp16 = compress_weights(model_fp16)
-    name_to_node_map = {op.get_friendly_name(): op for op in compressed_model_fp16.get_ops()}
-    assert "weights/scale_convert" not in name_to_node_map
-    scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
-    assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f16
+    for weight_dtype in [np.float32, np.float16]:
+        model_fp32 = IdentityMatmul(weights_dtype=weight_dtype).ov_model
+        compressed_model_fp32 = compress_weights(model_fp32)
+        name_to_node_map = {op.get_friendly_name(): op for op in compressed_model_fp32.get_ops()}
+
+        # Scale should always be converted from f16 to f32
+        assert "weights/scale_convert" in name_to_node_map
+        scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
+        convert_node = scale_multiply_node.input_value(1).get_node()
+        scale_node = convert_node.input_value(0).get_node()
+        assert scale_node.get_element_type() == ov.Type.f16
+        assert convert_node.get_element_type() == ov.Type.f32
+
+        # There should be no Convert node after scale multiply
+        matmul_node = get_next_node(scale_multiply_node)
+        assert matmul_node.get_type_name() == "MatMul"


 DATASET_SIZE = 129