Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Common] Unified Scales for SDPA #3205

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f9e5d7c
Update torch_fx_backend.py
anzr299 Aug 20, 2024
5b11455
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 26, 2024
0eff5cb
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
c7b9093
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
e7097bd
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
2665666
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 2, 2024
1b4a926
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 10, 2024
74d8f4c
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 12, 2024
415a222
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 18, 2024
939a560
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 24, 2024
cc544ff
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 24, 2024
9a359ab
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 26, 2024
f3047c9
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 27, 2024
85ec57e
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Oct 3, 2024
7ad1586
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Oct 24, 2024
231ea70
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Nov 12, 2024
5512d0c
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Dec 31, 2024
bd863cc
init
anzr299 Jan 21, 2025
fe406d4
init add SDPA in scales unification map for MinMax Backends
anzr299 Jan 22, 2025
c4bd12d
Add tests for unified scales with concat and SDPA block
anzr299 Jan 22, 2025
04d6e59
pre commit fixes
anzr299 Jan 22, 2025
85bed29
Merge branch 'openvinotoolkit:develop' into common/SDPA_unified_scale
anzr299 Jan 22, 2025
c6000f2
adjust for latest nncf changes
anzr299 Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

@property
def hw_config(self) -> HWConfig:
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
return {om.OVConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

@property
def hw_config(self) -> HWConfig:
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.PTCatMetatype: self.overflow_fix_metatypes}
return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

@property
def hw_config(self) -> HWConfig:
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.PTCatMetatype: self.overflow_fix_metatypes}
return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

@property
def hw_config(self) -> HWConfig:
Expand Down
51 changes: 51 additions & 0 deletions tests/cross_fw/test_templates/test_unified_scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, TypeVar

import pytest
import torch

from nncf.common.factory import NNCFGraphFactory
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from tests.torch.test_models.synthetic import ConcatSDPABlock

TModel = TypeVar("TModel")


class TemplateTestUnifiedScales:
@property
@abstractmethod
def get_backend_specific_model(self, model: TModel) -> TModel:
"""
Convert and return backend specific Model

:param model: Model (for example in PT) to be converted to backend specific model
:return: Backend specific Model
"""

@pytest.mark.parametrize(
"model_cls, unified_group, unified_group_nncf_network",
((ConcatSDPABlock, [["x", "y"]], [["/nncf_model_input_0", "/nncf_model_input_1"]]),),
)
def test_unified_groups(
self, model_cls: TModel, unified_group: List[List[str]], unified_group_nncf_network: List[List[str]]
):
backend_model = self.get_backend_specific_model(model_cls())
if isinstance(backend_model, torch.nn.Module) and not isinstance(backend_model, torch.fx.GraphModule):
unified_group = unified_group_nncf_network

nncf_graph = NNCFGraphFactory.create(backend_model)
algo = MinMaxQuantization()
algo._set_backend_entity(backend_model)
_, groups = algo._get_quantization_target_points(backend_model, nncf_graph)
assert [[target.target_node_name for target in groups] for groups in groups] == unified_group
29 changes: 29 additions & 0 deletions tests/openvino/native/test_unified_scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import openvino as ov
import torch

from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales


class TestUnifiedScales(TemplateTestUnifiedScales):
def get_backend_specific_model(self, model: torch.nn.Module) -> ov.Model:
input_shape = model.INPUT_SHAPE
backend_model = ov.convert_model(
model,
example_input=(
torch.randn(input_shape),
torch.randn(input_shape),
),
)

return backend_model
30 changes: 30 additions & 0 deletions tests/torch/fx/test_unified_scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nncf.torch.nncf_network import NNCFNetwork
from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
from tests.torch.fx.helpers import get_torch_fx_model_q_transformed


class TestUnifiedScales(TemplateTestUnifiedScales):
def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork:
input_shape = model.INPUT_SHAPE
backend_model = get_torch_fx_model_q_transformed(
model,
(
torch.randn(input_shape),
torch.randn(input_shape),
),
)

return backend_model
18 changes: 18 additions & 0 deletions tests/torch/quantization/test_unified_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.torch.dynamic_graph.operation_address import OperationAddress
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_creation import wrap_model
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import AsymmetricQuantizer
from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
from tests.torch.helpers import create_compressed_model_and_algo_for_test
from tests.torch.helpers import get_nodes_by_type
from tests.torch.helpers import register_bn_adaptation_init_args
Expand Down Expand Up @@ -711,3 +714,18 @@ def test_unified_scales_with_shared_nodes():

assert len(compression_ctrl.weight_quantizers) == 1 # The two embedding nodes point to a single shared layer
assert len(compression_ctrl.non_weight_quantizers) == 0 # The "add" operation has its inputs already quantized


class TestUnifiedScales(TemplateTestUnifiedScales):
def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork:
input_shape = model.INPUT_SHAPE
backend_model = wrap_model(
model,
(
torch.randn(input_shape),
torch.randn(input_shape),
),
trace_parameters=True,
)

return backend_model
16 changes: 16 additions & 0 deletions tests/torch/test_models/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,19 @@ def forward(self, x):
kq /= 2**-2
kq = torch.softmax(kq, -1)
return torch.matmul(torch.transpose(kq, 1, 2), v)


class ConcatSDPABlock(torch.nn.Module):
INPUT_SHAPE = (2, 10, 6)

def __init__(self):
super().__init__()

def forward(self, x, y):
concatenated_input = torch.cat((x, y), dim=-1)
query = concatenated_input
key = concatenated_input
value = concatenated_input
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, dropout_p=0.2)

return attn_output