[WC, PT] Store compression scale in f16 #2596

Merged
@@ -197,6 +197,7 @@ def transform_model(
 
             # calculates compressed weights and decompression parameters
             compressed_weight = compress_weight(Tensor(weight), wc_params.reduction_axes, compression_config)
+            compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16)
 
             # pack compressed tensor
             packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8)
@@ -217,7 +218,9 @@ def transform_model(
             packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8)
 
             # creates weight decompressor
-            decompressor = WeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data)
+            decompressor = WeightsDecompressor(
+                compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype
+            )
 
             # registry weight decompression module in the model
             decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
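Side note on the motivation: keeping the decompression scale in float16 halves the memory spent on scales compared with float32, and float16 typically has enough range and precision for weight scales. A minimal sketch in plain PyTorch (shapes and values are made up for illustration; this is not NNCF code):

```python
import torch

# Hypothetical per-output-channel scales of a uint8-compressed 4096x4096 weight.
scale_fp32 = torch.rand(4096, 1, dtype=torch.float32) * 0.1
scale_fp16 = scale_fp32.to(torch.float16)

print(scale_fp32.numel() * scale_fp32.element_size())  # 16384 bytes
print(scale_fp16.numel() * scale_fp16.element_size())  # 8192 bytes

# The cast costs little accuracy for scales of this magnitude.
print((scale_fp32 - scale_fp16.float()).abs().max().item())
```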
8 changes: 6 additions & 2 deletions nncf/torch/quantization/layers.py
@@ -1039,14 +1039,18 @@ class WeightsDecompressor(nn.Module):
     Applies decompression of compressed weights in the forward pass
     """
 
-    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor):
+    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None):
"""
:param scale: A scale in quantization scheme
:param zero_point: A zero point in quantization scheme
:param result_dtype: (Optional) A data type that result should be cast to
"""
super().__init__()
self.register_buffer("_scale", scale)
self.register_buffer("_zero_point", zero_point)
self.result_dtype = result_dtype

def forward(self, x):
return decompress(x, self._scale, self._zero_point)
result = decompress(x, self._scale, self._zero_point)
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result
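To make the interplay between the float16 scale buffer and `result_dtype` concrete, here is a self-contained toy decompressor following the same pattern. It assumes the usual affine scheme weight ≈ (q − zero_point) · scale; `ToyWeightsDecompressor` and its shapes are illustrative, not NNCF's actual `decompress` implementation:

```python
import torch
from torch import nn


class ToyWeightsDecompressor(nn.Module):
    """Illustrative analogue of WeightsDecompressor: float16 scale, optional output cast."""

    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None):
        super().__init__()
        self.register_buffer("_scale", scale.to(torch.float16))
        self.register_buffer("_zero_point", zero_point)
        self.result_dtype = result_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Affine dequantization: (q - zero_point) * scale, computed in the scale's dtype.
        result = (x.to(self._scale.dtype) - self._zero_point) * self._scale
        # Cast back to the original weight dtype (e.g. float32) when requested.
        return result.type(self.result_dtype) if self.result_dtype is not None else result


# Usage: a uint8 weight with a per-channel scale, decompressed back to float32.
q_weight = torch.randint(0, 256, (8, 16), dtype=torch.uint8)
zero_point = torch.full((8, 1), 128, dtype=torch.uint8)
decompressor = ToyWeightsDecompressor(torch.rand(8, 1) * 0.05, zero_point, result_dtype=torch.float32)
print(decompressor(q_weight).dtype)  # torch.float32
```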
25 changes: 23 additions & 2 deletions tests/torch/ptq/test_weights_compression.py
@@ -53,7 +53,7 @@ def forward(self, input_ids):
         return res
 
 
-class NestedMatMul(torch.nn.Module):
+class MatMulModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))
@@ -68,7 +68,7 @@ def __init__(self):
         self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
         self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
         self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
-        self.nested_matmul = NestedMatMul()
+        self.nested_matmul = MatMulModel()
 
     def forward(self, input_):
         x = input_.to(torch.float32)
@@ -241,3 +241,24 @@ def test_get_dtype_attribute_of_parameter():
     assert compressed_model.weight.dtype == torch.uint8
     compressed_model(dummy_input)
     assert compressed_model.weight.dtype == torch.uint8
+
+
+@pytest.mark.parametrize("device", ("cpu", "cuda"))
+@pytest.mark.parametrize("dtype", ("float16", "float32"))
+def test_model_devices_and_precisions(device, dtype):
+    device = torch.device(device)
+    dtype = torch.float16 if dtype == "float16" else torch.float32
+
+    model = MatMulModel().to(device)
+    if dtype == torch.float16:
+        model.half()
+
+    dummy_input = torch.rand((1, 300), dtype=dtype, device=device)
+    wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+    compressed_model = compress_weights(wrapped_model)
+    result = compressed_model(dummy_input)
+
+    # Scale should always be in float16
+    assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16
+    # Result should be in the precision of the model
+    assert result.dtype == dtype
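The new case can be run on its own with, for example (assuming the repository root as the working directory; the cuda parametrization needs a CUDA-enabled PyTorch build):

```
pytest tests/torch/ptq/test_weights_compression.py -k test_model_devices_and_precisions
```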