diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index 141471fa977..ee78b7c5899 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -197,6 +197,7 @@ def transform_model(

             # calculates compressed weights and decompression parameters
             compressed_weight = compress_weight(Tensor(weight), wc_params.reduction_axes, compression_config)
+            compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16)

             # pack compressed tensor
             packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8)
@@ -217,7 +218,9 @@ def transform_model(
             packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8)

             # creates weight decompressor
-            decompressor = WeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data)
+            decompressor = WeightsDecompressor(
+                compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype
+            )

             # registry weight decompression module in the model
             decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 937d156c8fe..cb8906b1ff0 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -1039,14 +1039,18 @@ class WeightsDecompressor(nn.Module):
     Applies decompression of compressed weights in the forward pass
     """

-    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor):
+    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None):
         """
         :param scale: A scale in quantization scheme
         :param zero_point: A zero point in quantization scheme
+        :param result_dtype: (Optional) A data type that the result should be cast to
         """
         super().__init__()
         self.register_buffer("_scale", scale)
         self.register_buffer("_zero_point", zero_point)
+        self.result_dtype = result_dtype

     def forward(self, x):
-        return decompress(x, self._scale, self._zero_point)
+        result = decompress(x, self._scale, self._zero_point)
+        result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
+        return result
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index c357005a3bd..a0056d09e5f 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -53,7 +53,7 @@ def forward(self, input_ids):
         return res


-class NestedMatMul(torch.nn.Module):
+class MatMulModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))
@@ -68,7 +68,7 @@ def __init__(self):
         self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
         self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
         self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
-        self.nested_matmul = NestedMatMul()
+        self.nested_matmul = MatMulModel()

     def forward(self, input_):
         x = input_.to(torch.float32)
@@ -241,3 +241,24 @@ def test_get_dtype_attribute_of_parameter():
     assert compressed_model.weight.dtype == torch.uint8
     compressed_model(dummy_input)
     assert compressed_model.weight.dtype == torch.uint8
+
+
+@pytest.mark.parametrize("device", ("cpu", "cuda"))
+@pytest.mark.parametrize("dtype", ("float16", "float32"))
+def test_model_devices_and_precisions(device, dtype):
+    device = torch.device(device)
+    dtype = torch.float16 if dtype == "float16" else torch.float32
+
+    model = MatMulModel().to(device)
+    if dtype == torch.float16:
+        model.half()
+
+    dummy_input = torch.rand((1, 300), dtype=dtype, device=device)
+    wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+    compressed_model = compress_weights(wrapped_model)
+    result = compressed_model(dummy_input)
+
+    # Scale should always be in float16
+    assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16
+    # Result should be in the precision of the model
+    assert result.dtype == dtype
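
Note (illustrative, not part of the diff): a minimal sketch, assuming standard PyTorch type promotion, of why the new `result_dtype` argument is needed once the scale is stored in float16 — the decompressed weight follows the scale's dtype, which would otherwise mismatch a float32 model. The tensors below are hypothetical placeholders, not the actual nncf `decompress` implementation.

```python
import torch

# Hypothetical emulation of the decompression math: uint8 weight, float16 scale.
packed_weight = torch.randint(0, 255, (4, 4), dtype=torch.uint8)
zero_point = torch.tensor(128, dtype=torch.uint8)
scale = torch.rand(4, 1, dtype=torch.float16)

decompressed = (packed_weight.to(torch.float16) - zero_point) * scale
print(decompressed.dtype)  # torch.float16 -- follows the float16 scale

# Casting back to the original weight dtype is what result_dtype enables,
# so a float32 model keeps receiving float32 weights after decompression.
restored = decompressed.type(dtype=torch.float32)
print(restored.dtype)  # torch.float32
```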