[TorchFX] INT8 Weights Compression Support #2891

Merged
merged 92 commits into from
Sep 26, 2024
Changes from 1 commit
297fdb4
weights compression init
anzr299 Aug 14, 2024
534e294
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 16, 2024
06ca5a3
compression complete
anzr299 Aug 16, 2024
b4b2603
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 19, 2024
c770d2c
Modify graph builder to include support for embedding op
anzr299 Aug 19, 2024
70b00f9
modify function to set new node meta for new module insertion to fx g…
anzr299 Aug 19, 2024
c7fa7f2
Add weights compression support for torch fx
anzr299 Aug 19, 2024
667b8a5
Add test for torch fx weights compression
anzr299 Aug 19, 2024
dca2374
reorder comments
anzr299 Aug 19, 2024
6f693c9
variable names fix
anzr299 Aug 19, 2024
159a615
Fix messages, use transformation for updating weight
anzr299 Aug 19, 2024
7a896d6
Minor mypy fix
anzr299 Aug 19, 2024
0de1d9b
fix set_weight
anzr299 Aug 19, 2024
f9e5d7c
Update torch_fx_backend.py
anzr299 Aug 20, 2024
443dce7
Add embedding metatype for torch fx as a subtype
anzr299 Aug 20, 2024
03d16f8
replace embedding metatype with torch fx subtype in torch fx graph bu…
anzr299 Aug 20, 2024
5226934
1. Adjust the torch fx weights compression backend to use fx embeddin…
anzr299 Aug 20, 2024
3cdb7b3
Update test for weight compression. Include test to see if
anzr299 Aug 20, 2024
28f7053
Fix FX metatype mapping
anzr299 Aug 20, 2024
cb0bf6b
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 20, 2024
8b3c6e2
Add metatypes registry for torch fx specific embedding metatype and c…
anzr299 Aug 20, 2024
79ec939
Add copyright to new torch fx operator_metatypes file
anzr299 Aug 20, 2024
7accaf2
Add weights compression graph test
anzr299 Aug 26, 2024
5b11455
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 26, 2024
1cb55c2
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Aug 26, 2024
71c50ff
pre-commit fix
anzr299 Aug 26, 2024
9f68831
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 28, 2024
2cb0a41
Handle Lora correction in torch fx weights compression
anzr299 Aug 28, 2024
a9c3d57
Add graph test for compressed models in test_models
anzr299 Aug 28, 2024
0172ad1
pre commit fix
anzr299 Aug 28, 2024
f590200
1. Moved Embedding FX metatype from `experimental/torch/fx` to torch …
anzr299 Aug 29, 2024
0c7be62
shared weights support in torch fx graph builder and constant update …
anzr299 Aug 29, 2024
0a1157d
Update tests for more description
anzr299 Aug 29, 2024
0eff5cb
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
c7b9093
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
93ecc4e
add torch fx in supported backends
anzr299 Aug 30, 2024
b6ad458
Remove Compressed reference graphs
anzr299 Aug 30, 2024
e7097bd
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
64b9ba7
add test for shared weights
anzr299 Sep 2, 2024
2665666
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 2, 2024
fb74267
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Sep 2, 2024
287cb2c
pre-commit fix
anzr299 Sep 2, 2024
449f767
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 3, 2024
c79dfc2
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 4, 2024
a10cb68
Add test for shared node decompressor call
anzr299 Sep 4, 2024
1c144a5
update backend supported in docs
anzr299 Sep 4, 2024
c5291b7
pre-commit fix
anzr299 Sep 4, 2024
174fb32
remove todo
anzr299 Sep 4, 2024
45a5274
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 10, 2024
b46d00e
add get_dtype and get_shape methods to torch fx weights compression b…
anzr299 Sep 10, 2024
32f5098
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 12, 2024
1819241
get the updated constant name from graph
anzr299 Sep 16, 2024
8a6b6d5
updated constant name from graph
anzr299 Sep 16, 2024
502c6c3
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 16, 2024
3503674
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 17, 2024
71901c5
update shared constants transformation
anzr299 Sep 20, 2024
bd5ff1f
pre commit fix
anzr299 Sep 20, 2024
b6a29ab
update docs
anzr299 Sep 20, 2024
7dd9782
refactor get weight name and port ids
anzr299 Sep 20, 2024
bbfeff0
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 20, 2024
48848be
update docs from X to Torch FX
anzr299 Sep 20, 2024
20544fd
fix shared weights attribute
anzr299 Sep 20, 2024
60ef615
Merge branch 'fx_compress_weights' of https://github.com/anzr299/nncf…
anzr299 Sep 20, 2024
fb89a4d
Fix Suggestions
anzr299 Sep 20, 2024
002758b
pre commit fix
anzr299 Sep 20, 2024
fe4d390
update is_shared attribute
anzr299 Sep 20, 2024
2ca11f8
Add tests for cosntant update transformation
anzr299 Sep 20, 2024
2be2487
pre commit fix
anzr299 Sep 20, 2024
fc543c9
Add test for edge shape
anzr299 Sep 20, 2024
02861e9
make decompressor name more readible
anzr299 Sep 20, 2024
33afddb
fix model_devices and precision test
anzr299 Sep 20, 2024
15bfeb0
Update is_shared attribute using a one liner
anzr299 Sep 20, 2024
04ed994
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 23, 2024
7683b5d
add test for nncf node is_shared attribute before applying transforma…
anzr299 Sep 23, 2024
fa56e7e
Change code to include _capture_model function for torch FX graph cap…
anzr299 Sep 23, 2024
fd9498a
pre-commit fix
anzr299 Sep 23, 2024
782b509
Fix is_shared attribute test
anzr299 Sep 23, 2024
48d050b
pre- commit fix
anzr299 Sep 23, 2024
3477d7c
add reference for checking shared constant unification transformation
anzr299 Sep 23, 2024
cbc2106
Add synthetic model with embedding to test models and include create …
anzr299 Sep 23, 2024
229517c
add reference graphs
anzr299 Sep 23, 2024
fde56b7
Include assert in shared attribute test
anzr299 Sep 24, 2024
30ff3d2
Fix reference graphs structure
anzr299 Sep 24, 2024
f26a7a0
pre-commit fix
anzr299 Sep 24, 2024
1d0a866
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 24, 2024
49d3dec
Change FXEmbedding metatype to PTAtenEmbeddingMetatype
anzr299 Sep 24, 2024
2e7e639
Move shared constants unification transformation to `apply_quantizati…
anzr299 Sep 24, 2024
817c233
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 25, 2024
26a4ff4
Corrections, comments and refactoring
anzr299 Sep 26, 2024
065bacb
Add seperate error message for dataset attribute
anzr299 Sep 26, 2024
3942d45
fix comments
anzr299 Sep 26, 2024
14096b7
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 26, 2024
2 changes: 1 addition & 1 deletion nncf/experimental/torch/fx/transformations.py
@@ -78,7 +78,7 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert, model)
with graph.inserting_after(target_node):
-for user in list(target_node.users.keys()):
+for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)
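
The change above drops `.keys()` but keeps the `list(...)` wrapper. A minimal sketch of why both forms behave the same and why the snapshot matters, assuming (as in torch.fx, to my knowledge) that `Node.users` is a dict and that rewiring a user mutates it. The plain dict and the helper names below are hypothetical stand-ins, not NNCF or torch.fx code:

```python
# Toy stand-in for `Node.users`: a dict keyed by the nodes that consume a value.
users = {"linear": None, "relu": None, "decompressor": None}

# Iterating a dict yields its keys, so both snapshots are identical.
assert list(users) == list(users.keys())


def reroute(users_map, new_node):
    """Detach every user except `new_node`, iterating over a snapshot."""
    visited = []
    for user in list(users_map):  # snapshot taken before any mutation
        if user == new_node:
            continue
        visited.append(user)
        del users_map[user]  # stands in for the user being rewired away
    return visited


def reroute_without_snapshot(users_map):
    """Same mutation, but iterating the live dict; returns the error message."""
    try:
        for user in users_map:
            del users_map[user]
    except RuntimeError as err:
        return str(err)
    return None


visited = reroute(users, "decompressor")
error = reroute_without_snapshot({"linear": None, "relu": None})
```

Without the snapshot, Python raises a `RuntimeError` as soon as the dict changes size during iteration, which is why only the redundant `.keys()` call could be removed.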
@@ -13,7 +13,6 @@

import torch
import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value

import nncf
import nncf.errors
@@ -30,6 +29,7 @@
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
from nncf.experimental.torch.fx.transformations import module_insertion_transformation_builder
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
@@ -180,23 +180,13 @@ def set_weight(
graph: NNCFGraph,
weight: Tensor,
) -> torch.fx.Node:
-weight_node = graph.get_previous_nodes(node_with_weight)[weight_port_id]
-graph_node = get_graph_node_by_name(model.graph, weight_node.node_name)
-if len(graph_node.users) != 1:
-raise nncf.InternalError(f"Weight Node has {len(graph_node.users)} users, 1 expected.")
-
-node_with_weight_graph = next(iter(graph_node.users))
-with model.graph.inserting_before(node_with_weight_graph):
-new_weight_node = create_getattr_from_value(
-model, model.graph, node_with_weight.node_name + "_compressed_weight", weight.data
-)
-
-args = list(node_with_weight_graph.args)
-args[weight_port_id] = new_weight_node
-node_with_weight_graph.args = tuple(args)
-model.graph.eliminate_dead_code()
-
-return new_weight_node
+weight_update_command = FXApplyTransformationCommand(
+constant_update_transformation_builder(node_with_weight, weight.data)
+)
+layout = TransformationLayout()
+layout.register(weight_update_command)
+model = FXModelTransformer(model).transform(layout)
daniil-lyakhov marked this conversation as resolved.

def transform_model(
self,
@@ -240,12 +230,10 @@ def transform_model(
dtype = TensorDataType.uint8
packed_tensor = compressed_weight.tensor.astype(dtype)

-new_weight = self.set_weight(
-wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor
-)
+self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor)

if len(consumer_nodes) > 1:
-raise nncf.InternalError("Shared weights not supported in compression for Torch Fx models")
+raise nncf.InternalError("Shared weights not supported in compression for TorchFX models")

Contributor

What is the reason for the restriction?

Contributor Author

So far, capture_pre_autograd_graph() adds an extra node for each use of a shared weight. Under the hood these nodes point to the original weight, but they appear as separate new nodes in the graph. I added this check in case an edge case is ever hit, or in case we change how the Torch FX graph is created; it can be extended at that point.

Contributor Author
@anzr299 Aug 29, 2024

Done. Added support for shared constants in the constant update transformation.

Contributor
@alexsu52 Aug 30, 2024

Do you have a test for it?

Contributor Author
@anzr299 Aug 30, 2024

Yes, the test cases test_compress_weights_shared and test_compress_weights_model_size_conv in tests/torch/fx/test_compress_weights.py cover this case for compression of shared weights.
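
To illustrate what "support for shared constants" in a constant update can mean, here is a toy sketch: when one constant feeds several consumers, replacing it should rewire every consumer, not just the one that triggered the update. `ToyNode` and `update_constant` are hypothetical names invented for this illustration, and the logic is an assumption about the general idea, not the transformation implemented in this PR:

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    args: list = field(default_factory=list)  # names of the input nodes


def update_constant(nodes, consumer_name, port_id, new_const_name):
    """Swap the constant feeding `consumer_name` at `port_id` for all its users."""
    old_const = nodes[consumer_name].args[port_id]
    nodes[new_const_name] = ToyNode(new_const_name)
    for node in nodes.values():
        # Every consumer of the old constant is pointed at the new one.
        node.args = [new_const_name if arg == old_const else arg for arg in node.args]
    del nodes[old_const]  # the old constant has no users left
    return new_const_name


nodes = {
    "x": ToyNode("x"),
    "weight": ToyNode("weight"),
    "linear_a": ToyNode("linear_a", ["x", "weight"]),
    "linear_b": ToyNode("linear_b", ["linear_a", "weight"]),  # shares the weight
}
new_name = update_constant(nodes, "linear_a", 1, "linear_a_updated_constant0")
```

After the call, both `linear_a` and `linear_b` read from the single updated constant, so the compressed weight is stored once.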


# creates weight decompressor
if compression_config.mode == CompressWeightsMode.INT8_SYM:
@@ -261,15 +249,16 @@ def transform_model(
decompressor_type = "asymmetric"

# registry weight decompression module in the model
-compressed_weight_name = wc_params.node_with_weight.node_name
+# TODO: Find a more efficient way to access updated constant name

Contributor

I would suggest solving this issue in this PR and removing the implicit connection between weight assignment and getting the weight name.

Contributor Author

Alright

+compressed_weight_name = wc_params.node_with_weight.node_name + "_updated_constant0"
decompressor_name = f"{decompressor_type}_weights_decompressor_{compressed_weight_name.replace('.', '_')}"

# inserts the weight decompressor into the model as the post hook on the model weight
transformation_layout.register(
FXApplyTransformationCommand(
module_insertion_transformation_builder(
decompressor,
-[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=new_weight.name)],
+[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=compressed_weight_name)],
decompressor_name,
)
)
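
The naming logic in this hunk can be reproduced in isolation. The sketch below copies the two string expressions from the diff into a standalone helper; the helper name `build_decompressor_name` and the dotted node name are hypothetical, chosen only for illustration:

```python
def build_decompressor_name(decompressor_type: str, node_name: str) -> str:
    """Mirror the two naming expressions shown in the diff above."""
    # Suffix hardcoded in this commit (the reviewer asks to remove this coupling).
    compressed_weight_name = node_name + "_updated_constant0"
    # Dots are replaced with underscores in the decompressor name.
    return f"{decompressor_type}_weights_decompressor_{compressed_weight_name.replace('.', '_')}"


simple = build_decompressor_name("asymmetric", "matmul")
dotted = build_decompressor_name("asymmetric", "block.0.matmul")  # hypothetical node name
```

For the node name `matmul`, this yields `asymmetric_weights_decompressor_matmul_updated_constant0`, which matches the `state_dict()` key prefix asserted in the updated test.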
6 changes: 3 additions & 3 deletions nncf/quantization/quantize_model.py
@@ -474,13 +474,13 @@ def compress_weights(

if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
raise AttributeError(
-"Torch backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
+"TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
f"but given {mode.value} mode."
)

-if True in [awq, scale_estimation, gptq]:
+if any((awq, scale_estimation, gptq)):
raise AttributeError(
-"Torch backend doesn`t supports scale estimation and AWQ algorithm, "
+"TorchFX backend doesn`t supports scale estimation and AWQ algorithm, "
"but awq=True or scale_estimation=True or gptq=True is specified."
)
compression_weights_impl = fx_compression_weights_impl
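
The switch from `True in [...]` to `any((...))` is not purely cosmetic. A small sketch of the difference in plain Python: membership testing compares by equality with `True`, while `any` tests truthiness. Whether callers ever pass non-boolean values for these flags is not stated in the PR, so the non-boolean cases below are hypothetical:

```python
def has_true_member(*flags):
    """Old check: is some element equal to True?"""
    return True in list(flags)


def has_truthy_member(*flags):
    """New check: is some element truthy?"""
    return any(flags)


# For plain booleans and None the two checks agree.
bools_old = has_true_member(False, None, True)
bools_new = has_truthy_member(False, None, True)
none_old = has_true_member(None, None, None)
none_new = has_truthy_member(None, None, None)

# A truthy non-boolean value is only caught by any().
string_old = has_true_member("yes", False, False)
string_new = has_truthy_member("yes", False, False)
```

With boolean or `None` flags, behavior is unchanged; `any` is simply the more idiomatic spelling and is stricter about truthy non-boolean inputs.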
9 changes: 6 additions & 3 deletions tests/torch/fx/test_compress_weights.py
@@ -162,9 +162,9 @@ def test_get_dtype_attribute_of_parameter():
dummy_input = torch.randint(0, 10, [3, 3])
exported_model = capture_pre_autograd_graph(model, args=(dummy_input,))
compressed_model = compress_weights(exported_model)
-assert compressed_model.matmul_compressed_weight0.dtype == torch.uint8
+assert compressed_model.matmul_updated_constant0.dtype == torch.uint8
compressed_model(dummy_input)
-assert compressed_model.matmul_compressed_weight0.dtype == torch.uint8
+assert compressed_model.matmul_updated_constant0.dtype == torch.uint8


@pytest.mark.parametrize("dtype", ("float16", "float32"))
@@ -183,6 +183,9 @@ def test_model_devices_and_precisions(use_cuda, dtype):
compressed_model = compress_weights(exported_model)
result = compressed_model(dummy_input)
# Scale should always be in float16
-assert compressed_model.state_dict()["asymmetric_weights_decompressor_matmul._scale"].dtype == torch.float16
+assert (
+compressed_model.state_dict()["asymmetric_weights_decompressor_matmul_updated_constant0._scale"].dtype
+== torch.float16
+)
# Result should be in the precision of the model
assert result.dtype == dtype
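
As background for the tests above, which check that the compressed constant is stored as uint8 and that the decompressor carries a scale, here is a textbook min/max asymmetric 8-bit scheme in plain Python. It is a generic sketch, not NNCF's implementation: the function names are hypothetical, and the offset is kept as the float minimum rather than an integer zero point for simplicity:

```python
def quantize_asym_uint8(weights):
    """Map floats to integers in [0, 255] using a min/max affine scheme."""
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / 255 or 1.0  # avoid a zero scale for constant input
    quantized = [min(255, max(0, round((w - w_min) / scale))) for w in weights]
    return quantized, scale, w_min


def dequantize(quantized, scale, w_min):
    """Invert the mapping; the error per element is at most about scale / 2."""
    return [q * scale + w_min for q in quantized]


weights = [-1.0, -0.25, 0.0, 0.5, 2.0]
quantized, scale, w_min = quantize_asym_uint8(weights)
restored = dequantize(quantized, scale, w_min)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
```

The decompressor inserted by the PR plays the role of `dequantize` here: it turns the stored 8-bit values back into floating point at inference time.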