[WC, PT] Store compression scale in f16 (#2596)
### Changes

- Store the compression scale in FP16
- Add a conversion back to the original data type after decompression (sketched below)
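
A rough sketch of the resulting decompression path for a uint8-compressed weight (simplified pseudocode under these assumptions, not the actual NNCF implementation):

```python
import torch


def decompress_sketch(
    weight_u8: torch.Tensor,      # compressed weight, uint8
    zero_point_u8: torch.Tensor,  # zero point, uint8
    scale_f16: torch.Tensor,      # scale, now stored in float16
    result_dtype: torch.dtype,    # original weight dtype (e.g. float16 or float32)
) -> torch.Tensor:
    # (weight - zero_point) * scale, computed with the float16 scale ...
    decompressed = (weight_u8.to(torch.float16) - zero_point_u8.to(torch.float16)) * scale_f16
    # ... then cast back to the original weight dtype
    return decompressed.to(result_dtype)
```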

Below are the compression subgraphs for the first conv2d in mobilenet_v2 after conversion to OV; this is similar to the table presented in #2537.

![image](https://github.com/openvinotoolkit/nncf/assets/23343961/740953d6-2615-4c8f-bbd3-6cfae5585dfd)
Compared to the OV case, there is an additional Multiply node after the
scale Multiply node. It seems to come from the Batch Norm applied to the
convolution. In the case of PT weight compression it does not get merged
into the weight as it does in the OV case.


### Reason for changes

Weight compression for the PT backend fails when applied to a model in half
precision. The reason is that the scale is always stored in FP32, so the
decompression result is also in FP32, which conflicts with the FP16 input
type, as illustrated below.
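
A minimal standalone illustration of the failure mode (shapes and tensor names are hypothetical, not taken from NNCF):

```python
import torch

# uint8-compressed weight of a half-precision model; the scale is kept in FP32 (old behavior)
weight_u8 = torch.randint(0, 256, (300, 300), dtype=torch.uint8)
zero_point = torch.tensor(128, dtype=torch.uint8)
scale_fp32 = torch.rand(300, 1, dtype=torch.float32)

# Decompression promotes the result to FP32 because of the FP32 scale
decompressed = (weight_u8.to(torch.int32) - zero_point.to(torch.int32)) * scale_fp32
assert decompressed.dtype == torch.float32

# The rest of the FP16 model expects FP16 weights, so the FP32 result clashes downstream:
x = torch.rand(1, 300, dtype=torch.float16)
# torch.nn.functional.linear(x, decompressed)  # fails with a dtype mismatch error
```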

### Related tickets

134063

### Tests

Added tests for the half- and full-precision cases. Also added cases for
different devices, since the device was suspected to influence tracing in
half precision.
nikita-savelyevv authored Mar 26, 2024
1 parent 3d3b797 commit c79111b
Showing 3 changed files with 33 additions and 5 deletions.
@@ -197,6 +197,7 @@ def transform_model(

         # calculates compressed weights and decompression parameters
         compressed_weight = compress_weight(Tensor(weight), wc_params.reduction_axes, compression_config)
+        compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16)

         # pack compressed tensor
         packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8)
@@ -217,7 +218,9 @@ def transform_model(
         packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8)

         # creates weight decompressor
-        decompressor = WeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data)
+        decompressor = WeightsDecompressor(
+            compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype
+        )

         # registry weight decompression module in the model
         decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
8 changes: 6 additions & 2 deletions nncf/torch/quantization/layers.py
@@ -1039,14 +1039,18 @@ class WeightsDecompressor(nn.Module):
     Applies decompression of compressed weights in the forward pass
     """

-    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor):
+    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None):
         """
         :param scale: A scale in quantization scheme
         :param zero_point: A zero point in quantization scheme
+        :param result_dtype: (Optional) A data type that result should be cast to
         """
         super().__init__()
         self.register_buffer("_scale", scale)
         self.register_buffer("_zero_point", zero_point)
+        self.result_dtype = result_dtype

     def forward(self, x):
-        return decompress(x, self._scale, self._zero_point)
+        result = decompress(x, self._scale, self._zero_point)
+        result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
+        return result
25 changes: 23 additions & 2 deletions tests/torch/ptq/test_weights_compression.py
@@ -53,7 +53,7 @@ def forward(self, input_ids):
         return res


-class NestedMatMul(torch.nn.Module):
+class MatMulModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))
@@ -68,7 +68,7 @@ def __init__(self):
         self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
         self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
         self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
-        self.nested_matmul = NestedMatMul()
+        self.nested_matmul = MatMulModel()

     def forward(self, input_):
         x = input_.to(torch.float32)
@@ -241,3 +241,24 @@ def test_get_dtype_attribute_of_parameter():
     assert compressed_model.weight.dtype == torch.uint8
     compressed_model(dummy_input)
     assert compressed_model.weight.dtype == torch.uint8
+
+
+@pytest.mark.parametrize("device", ("cpu", "cuda"))
+@pytest.mark.parametrize("dtype", ("float16", "float32"))
+def test_model_devices_and_precisions(device, dtype):
+    device = torch.device(device)
+    dtype = torch.float16 if dtype == "float16" else torch.float32
+
+    model = MatMulModel().to(device)
+    if dtype == torch.float16:
+        model.half()
+
+    dummy_input = torch.rand((1, 300), dtype=dtype, device=device)
+    wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+    compressed_model = compress_weights(wrapped_model)
+    result = compressed_model(dummy_input)
+
+    # Scale should always be in float16
+    assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16
+    # Result should be in the precision of the model
+    assert result.dtype == dtype