From c85a5d571ff383e5adf013270390fce2beb4f3d3 Mon Sep 17 00:00:00 2001
From: xiaowan0322 <35219371+xiaowan0322@users.noreply.github.com>
Date: Fri, 17 Mar 2023 09:39:49 +0800
Subject: [PATCH] [torch-quant] automatic mixed precision quantization (#1081)

---
 tools/torch_quant/README.md                 |  8 +-
 tools/torch_quant/tests/test_amp_module.py  | 72 ++++++++++++++++++
 .../test_pass_quantizable_module_to_amp.py  | 53 +++++++++++++
 tools/torch_quant/tests/test_quantizer.py   | 75 ++++++++++++-------
 tools/torch_quant/torch_quant/amp_module.py | 65 ++++++++++++++++
 tools/torch_quant/torch_quant/graph.py      | 40 +++++++++-
 tools/torch_quant/torch_quant/observer.py   | 17 +++--
 tools/torch_quant/torch_quant/quantizer.py  | 42 ++++++++++-
 8 files changed, 335 insertions(+), 37 deletions(-)
 create mode 100644 tools/torch_quant/tests/test_amp_module.py
 create mode 100644 tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py
 create mode 100644 tools/torch_quant/torch_quant/amp_module.py

diff --git a/tools/torch_quant/README.md b/tools/torch_quant/README.md
index 6c8ee9d6ea6..c7232562693 100644
--- a/tools/torch_quant/README.md
+++ b/tools/torch_quant/README.md
@@ -39,6 +39,12 @@ quantizer = Quantizer()
 # create a proxy model and run forward to calibrate quantization params
 quantizer.calib(model)(typical_data)
 
+# [Optional] perform automatic mixed precision quantization
+# create a proxy model and run forward to fall back a few sensitive layers to float precision
+amp_model = quantizer.amp(model)
+amp_model(typical_data)
+quantizer.fallback(amp_model, num=1)
+
 # create a proxy model representing quantized model
 quant_model = quantizer.quantize(model)
 
@@ -58,4 +64,4 @@ opt = torch_blade.optimize(quant_model)
 
 *TBD*
 
-A initial example can be found at [bert_ptq_demo.py](bert_ptq_demo.py)
\ No newline at end of file
+An initial example can be found at [bert_ptq_demo.py](bert_ptq_demo.py)
diff --git a/tools/torch_quant/tests/test_amp_module.py b/tools/torch_quant/tests/test_amp_module.py
new file mode 100644
index 00000000000..3eda3351dcd
--- /dev/null
+++ b/tools/torch_quant/tests/test_amp_module.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import Callable
+
+import torch
+from torch.quantization import QConfig
+
+from torch_quant.amp_module import AmpModule
+from torch_quant.observed_module import Linear
+from torch_quant.observer import (
+    BiasObserver,
+    MinMaxObserver,
+    Observer,
+    PerChannelMinMaxObserver,
+    toggle_observer,
+)
+
+
+class TestAmpModule(unittest.TestCase):
+    def test_basic(self):
+        model = torch.nn.Linear(2, 4)
+        dummy_input = torch.randn((4, 2))
+        original_output = model(dummy_input)
+
+        def _act_ob(data: torch.Tensor) -> Observer:
+            ob = MinMaxObserver()
+            ob.set_mode(observe=True, fake_quant=False)
+            ob(data)
+            return ob
+
+        act_ob = _act_ob(dummy_input)
+        out_ob = _act_ob(model(dummy_input))
+
+        def _w_ob(ctr: Callable[..., Observer], param: torch.nn.Parameter) -> Observer:
+            ob = ctr()
+            ob.set_mode(observe=True, fake_quant=False)
+            ob(param)
+            return ob
+
+        w_ob_ctr = PerChannelMinMaxObserver
+        w_ob = _w_ob(w_ob_ctr, model.weight)
+        bias_ob_ctr = partial(BiasObserver, w_ob, act_ob)
+        bias_ob = _w_ob(bias_ob_ctr, model.bias)
+        model.qconfig = QConfig(activation=None, weight=w_ob_ctr)
+        observed_model = Linear.from_float(model, w_ob, bias_ob)
+        amp = AmpModule(model, observed_model, act_ob, out_ob)
+
+        amp_output = amp(dummy_input)
+        self.assertTrue(torch.equal(original_output, amp_output))
+
+        w_ob.set_mode(observe=False, fake_quant=True)
+        quant_weight = w_ob(model.weight)
+        model.load_state_dict({'weight': quant_weight, 'bias': model.bias})
+        toggle_observer(model, observe=False, fake_quant=True)
+        quant_output = out_ob(model(act_ob(dummy_input)))
+        mse = torch.mean(torch.pow(original_output - quant_output, 2))
+        self.assertTrue(torch.equal(amp.noise, mse))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py b/tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py
new file mode 100644
index 00000000000..bd92cd38707
--- /dev/null
+++ b/tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from tests.models import SimpleModule
+from torch_quant.amp_module import AmpModule
+from torch_quant.graph import (
+    GraphModContext,
+    insert_act_observer,
+    quantizable_module_to_amp,
+    set_qconfig,
+)
+from torch_quant.observer import BiasObserver, MinMaxObserver, PerChannelMinMaxObserver
+
+
+class QuantizableModuleToAmpTest(unittest.TestCase):
+    def test_base(self) -> None:
+        model = SimpleModule()
+        ctx = GraphModContext(
+            gm=torch.fx.symbolic_trace(model),
+            root=model,
+            act_ob_ctr=MinMaxObserver,
+            w_ob_ctr=PerChannelMinMaxObserver,
+            bias_ob_ctr=BiasObserver,
+        )
+        insert_act_observer(ctx)
+        ctx.gm = torch.fx.symbolic_trace(model)
+        set_qconfig(ctx)
+        quantizable_module_to_amp(ctx)
+        amp_modules = dict(ctx.gm.named_modules())
+        for name, mod in model.named_modules():
+            if type(mod) in [torch.nn.Conv2d, torch.nn.Linear]:
+                self.assertTrue(isinstance(amp_modules[name], AmpModule))
+
+        dummy_input = torch.randn((1, 2, 5, 5))
+        original_output = model(dummy_input)
+        amp_output = ctx.gm(dummy_input)
+        self.assertTrue(torch.equal(original_output, amp_output))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/torch_quant/tests/test_quantizer.py b/tools/torch_quant/tests/test_quantizer.py
index cf28a6119c9..f85f7d2e5e4 100644
--- a/tools/torch_quant/tests/test_quantizer.py
+++ b/tools/torch_quant/tests/test_quantizer.py
@@ -11,7 +11,7 @@
 
 import tempfile
 import unittest
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.intrinsic as nni
@@ -19,24 +19,24 @@
 import torch.nn.quantized._reference as nnqr
 from parameterized import parameterized
 
-from tests.models import SimpleModule, SubModule, UntraceableSimpleModule
+from tests.models import LinearReLU, SimpleModule, SubModule, UntraceableSimpleModule
+from torch_quant.amp_module import AmpModule
 from torch_quant.module import ModuleFilter
 from torch_quant.observer import toggle_observer
 from torch_quant.quantizer import Backend, Quantizer
 
 
-def parameterized_backend(backend):
+def parameterized_with_backends(parameters: Optional[List] = None):
+    if parameters is None:
+        parameters = [(Backend.REFERENCE,), (Backend.FBGEMM,), (Backend.DISC,)]
     # skip if fbgemm not available
     if torch.backends.quantized.engine != 'fbgemm':
-        backend.remove((Backend.FBGEMM, ))
-    return parameterized.expand(backend)
+        parameters = [param for param in parameters if Backend.FBGEMM not in param]
+    return parameterized.expand(parameters)
+
 
 class QuantizerTest(unittest.TestCase):
-    @parameterized_backend([
-        (Backend.REFERENCE, ),
-        (Backend.FBGEMM, ),
-        (Backend.DISC, ),
-    ])
+    @parameterized_with_backends()
     def test_calib_and_quantize(self, backend: Backend) -> None:
         model = SimpleModule()
         quantizer = Quantizer(backend=backend)
@@ -54,10 +54,7 @@ def test_calib_and_quantize(self, backend: Backend) -> None:
             quant_output, original_output, rtol=0.1, atol=0.5)
 
     # TODO(litan.ls): QAT is more suitable for this case
-    @parameterized.expand([
-        (Backend.REFERENCE, ),
-        (Backend.DISC, ),
-    ])
+    @parameterized_with_backends()
     def test_load_from_state_dict(self, backend: Backend) -> None:
         model = SimpleModule()
         quantizer = Quantizer(backend=backend)
@@ -73,10 +70,7 @@ def test_load_from_state_dict(self, backend: Backend) -> None:
         quant_output = quantizer.quantize(model)(dummy_input)
         self.assertTrue(torch.equal(loaded_quant_output, quant_output))
 
-    @parameterized.expand([
-        (Backend.REFERENCE, ),
-        (Backend.DISC, ),
-    ])
+    @parameterized_with_backends()
     def test_save_and_load_quantized(self, backend: Backend) -> None:
         model = SimpleModule()
         quantizer = Quantizer(backend=backend)
@@ -125,9 +119,39 @@ def test_calib_quantize_qat_quantize_state_equal(self):
         out3 = quant_model(dummy_input)
         self.assertTrue(torch.equal(out2, out3))
 
-    @parameterized.expand(
+    @parameterized_with_backends()
+    def test_calib_amp_quantize(self, backend: Backend) -> None:
+        model = SimpleModule()
+        dummy_input = torch.randn((1, 2, 5, 5))
+        quantizer = Quantizer(backend=backend)
+        original_output = model(dummy_input)
+
+        calib_model = quantizer.calib(model)
+        calib_output = calib_model(dummy_input)
+        self.assertTrue(torch.equal(original_output, calib_output))
+
+        amp_model = quantizer.amp(model)
+        amp_modules = dict(amp_model.named_modules())
+        for name, mod in model.named_modules():
+            if type(mod) in [torch.nn.Conv2d, torch.nn.Linear]:
+                self.assertTrue(isinstance(amp_modules[name], AmpModule))
+        amp_output = amp_model(dummy_input)
+        self.assertTrue(torch.equal(original_output, amp_output))
+        quantizer.fallback(amp_model, num=2)
+        self.assertEqual(len(quantizer.module_filter.exclude_names), 2)
+
+        quant_model = quantizer.quantize(model)
+        modules = dict(model.named_modules())
+        quant_modules = dict(quant_model.named_modules())
+        for name in quantizer.module_filter.exclude_names:
+            self.assertEqual(quant_modules[name], modules[name])
+        quant_output = quant_model(dummy_input)
+        self.assertFalse(torch.equal(original_output, quant_output))
+        torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)
+
+    @parameterized_with_backends(
         [
-            (Backend.REFERENCE, ModuleFilter(include_op_types=[torch.nn.Linear])),
+            (Backend.FBGEMM, ModuleFilter(include_op_types=[torch.nn.Linear])),
             (Backend.DISC, ModuleFilter(exclude_op_types=[torch.nn.Conv2d])),
         ]
     )
@@ -154,10 +178,10 @@ def test_calib_and_quantize_with_op_types_filter(
         self.assertFalse(torch.equal(original_output, quant_output))
         torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)
 
-    @parameterized.expand(
+    @parameterized_with_backends(
         [
             (
-                Backend.REFERENCE,
+                Backend.FBGEMM,
                 ModuleFilter(
                     include_names=['traceable_sub'],
                     exclude_names=['traceable_sub.sub.conv'],
@@ -196,11 +220,7 @@ def test_calib_and_quantize_with_module_filter(
         quant_output = qmodel(dummy_input)
         torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)
 
-
-    @parameterized.expand([
-        (Backend.REFERENCE, ),
-        # (Backend.FBGEMM, ),
-    ])
+    @parameterized_with_backends([(Backend.REFERENCE,), (Backend.FBGEMM,)])
     def test_calib_and_quantize_with_module_fusion(self, backend):
         model = SimpleModule()
         quantizer = Quantizer(backend=backend)
@@ -213,7 +233,6 @@ def test_calib_and_quantize_with_module_fusion(self, backend):
             self.assertTrue(isinstance(quant_model.sub.conv[0], nnqr.Conv2d))
         elif backend == Backend.FBGEMM:
             self.assertTrue(isinstance(quant_model.sub.conv, nniq.ConvReLU2d))
-            self.assertTrue(isinstance(quant_model.linear, nniq.LinearReLU))
 
 
 if __name__ == '__main__':
diff --git a/tools/torch_quant/torch_quant/amp_module.py b/tools/torch_quant/torch_quant/amp_module.py
new file mode 100644
index 00000000000..9804274a421
--- /dev/null
+++ b/tools/torch_quant/torch_quant/amp_module.py
@@ -0,0 +1,65 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List
+
+import torch
+import torch.nn as nn
+
+from torch_quant.observer import Observer, toggle_observer
+
+LOGGER = logging.getLogger(__name__)
+
+
+class AmpModule(nn.Module):
+    """
+    Wraps the original float op together with its fake-quantized (observed) counterpart.
+    The mean squared error between their outputs is accumulated as quantization noise.
+    """
+
+    def __init__(
+        self,
+        float_op: nn.Module,
+        observed_op: nn.Module,
+        act_ob: Observer,
+        out_ob: Observer,
+    ) -> None:
+        super(AmpModule, self).__init__()
+        self.float_op = float_op
+        self.observed_op = observed_op
+        self.act_ob = act_ob
+        self.out_ob = out_ob
+        self.register_buffer('noise', torch.tensor(0.0))
+        toggle_observer(self, observe=False, fake_quant=True)
+
+    def forward(self, x):
+        y = self.float_op(x)
+        quant_y = self.out_ob(self.observed_op(self.act_ob(x)))
+        noise = torch.mean(torch.pow(y.detach() - quant_y.detach(), 2))
+        self.noise.copy_(self.noise + noise)
+        return y
+
+
+def get_fallback_names(root: nn.Module, num: int) -> List[str]:
+    modules = dict(root.named_modules())
+    candidates = [k for k, v in modules.items() if isinstance(v, AmpModule)]
+    if len(candidates) < num:
+        LOGGER.warning(
+            f"Only {len(candidates)} quantizable modules are available, but "
+            f"the fallback number is {num}; falling back all of them."
+        )
+        num = len(candidates)
+    LOGGER.info(f"Falling back {num} modules to float precision.")
+    noises = {name: modules[name].noise for name in candidates}
+    sorted_noises = sorted(noises.items(), key=lambda x: x[1], reverse=True)
+    fallback_names = [k[0] for k in sorted_noises[:num]]
+    return fallback_names
diff --git a/tools/torch_quant/torch_quant/graph.py b/tools/torch_quant/torch_quant/graph.py
index fbe52041c35..10829cc44bb 100644
--- a/tools/torch_quant/torch_quant/graph.py
+++ b/tools/torch_quant/torch_quant/graph.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import logging
 from collections import defaultdict
 from functools import partial
@@ -22,6 +23,7 @@ import torch.nn.quantized._reference as nnqr
 from torch.fx import GraphModule, Node
 from torch.quantization import DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, QConfig
 
+from torch_quant.amp_module import AmpModule
 from torch_quant.module import ModuleFilter
 from torch_quant.observed_module import OB_MODULE_MAPPING
 
@@ -34,8 +36,6 @@
     nn.Conv2d,
     nn.Conv3d,
     nn.Linear,
-    nni.LinearReLU,
-    nni.ConvReLU2d,
 )
 
 REF_TO_QUANT_MAP = {
@@ -393,3 +393,39 @@ def fold_qdq(ctx: GraphModContext) -> None:
         ctx.gm.graph.erase_node(q)
     ctx.gm.graph.eliminate_dead_code()
     ctx.gm.recompile()
+
+
+def quantizable_module_to_amp(ctx: GraphModContext) -> None:
+    for node in ctx.quantizable_nodes():
+        src = ctx.modules[node.target]
+        fused_module = None
+        if isinstance(src, nn.intrinsic._FusedModule):
+            fused_module = src
+            src = fused_module[0]
+        dst_type = OB_MODULE_MAPPING.get(type(src))
+        if dst_type is None:
+            raise ValueError(f'{type(src)} cannot be observed.')
+        act = node.args[0]
+        act_name = act.name if act.op == 'call_function' else act.target
+        act_ob = ctx.modules[f'{act_name}_ob']
+        out_ob = ctx.modules[f'{node.target}_ob']
+        w_ob = ctx.modules.get(f'{node.target}.w_ob')
+        if w_ob is None:
+            w_ob = ctx.w_ob_ctr()
+            w_ob.set_mode(observe=True, fake_quant=False)
+            w_ob(src.weight)
+        bias_ob = None
+        if getattr(src, 'bias', None) is not None:
+            bias_ob = ctx.modules.get(f'{node.target}.bias_ob')
+            if bias_ob is None and ctx.bias_ob_ctr:
+                bias_ob = ctx.bias_ob_ctr(w_ob, act_ob)
+                bias_ob.set_mode(observe=True, fake_quant=False)
+                bias_ob(src.bias)
+        dst = dst_type.from_float(src, w_ob, bias_ob)
+        if fused_module is not None:
+            copied = copy.deepcopy(fused_module)
+            copied[0] = dst
+            dst = copied
+        amp = AmpModule(ctx.modules[node.target], dst, act_ob, out_ob)
+        ctx.replace_module(node.target, amp)
+    ctx.gm.recompile()
diff --git a/tools/torch_quant/torch_quant/observer.py b/tools/torch_quant/torch_quant/observer.py
index b3437a3931c..3c7718fa22b 100644
--- a/tools/torch_quant/torch_quant/observer.py
+++ b/tools/torch_quant/torch_quant/observer.py
@@ -155,18 +155,25 @@ def _calculate_qparams(self, min_val, max_val) -> Tuple[torch.Tensor, torch.Tens
         zero_point = torch.clamp(zero_point, q_min, q_max)
 
         LOGGER.debug(
-            f'calc qparams: {self.min_val=}, {self.max_val=}, {self.q_min=}, {self.q_max=}, {self.bit=}, {self.signed=}, {scale=}, {zero_point=}')
+            f'calc qparams: min_val={self.min_val}, max_val={self.max_val}, '
+            f'q_min={self.q_min}, q_max={self.q_max}, bit={self.bit}, '
+            f'signed={self.signed}, scale={scale}, zero_point={zero_point}'
+        )
         return scale, zero_point
 
     @classmethod
     def from_qparams(cls, qparams: QParams):
        raise RuntimeError(f"Instantiating a {type(cls)} from QParams is not implemented")
 
+    def set_mode(self, *, observe: bool, fake_quant: bool) -> None:
+        self.observe = observe
+        self.fake_quant = fake_quant
+
+
 def toggle_observer(root: nn.Module, *, observe: bool, fake_quant: bool) -> None:
     for m in root.modules():
         if isinstance(m, Observer):
-            m.observe = observe
-            m.fake_quant = fake_quant
+            m.set_mode(observe=observe, fake_quant=fake_quant)
 
 
 DTYPE_TO_BIT_SIGN = {
@@ -422,8 +429,8 @@ def __init__(
         self.register_buffer("histogram", torch.zeros(self.bins))
         self.register_buffer("min_val", torch.tensor(float("inf")))
         self.register_buffer("max_val", torch.tensor(float("-inf")))
-        self.register_buffer("scale", torch.tensor([1.]))
-        self.register_buffer("zero_point", torch.tensor([0], dtype=torch.int32))
+        self.register_buffer("scale", torch.tensor(1.))
+        self.register_buffer("zero_point", torch.tensor(0, dtype=torch.int32))
         self.dst_nbins = 2 ** self.bit
         self.upsample_rate = upsample_rate
 
diff --git a/tools/torch_quant/torch_quant/quantizer.py b/tools/torch_quant/torch_quant/quantizer.py
index 4b21153ce74..bf4e4e46412 100644
--- a/tools/torch_quant/torch_quant/quantizer.py
+++ b/tools/torch_quant/torch_quant/quantizer.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
 from torch.fx import GraphModule, Tracer
+from torch_quant.amp_module import get_fallback_names
 from torch_quant.graph import (
     GraphModContext,
     fold_qdq,
@@ -23,6 +24,7 @@
     insert_act_observer,
     observer_to_qdq,
     q_ref_dq_to_fbgemm,
+    quantizable_module_to_amp,
     quantizable_module_to_observed,
     quantizable_module_to_ref,
     set_qconfig
@@ -102,7 +104,7 @@ def __init__(self, module_filter: Optional[ModuleFilter] = None,
             raise ValueError('fbgemm is not available, it only for x86_64')
 
     def calib_gm(
-        self, name: str, gm: GraphModule, root: nn.Module, ob_types: ObserverTypes,
+        self, name: str, gm: GraphModule, root: nn.Module, ob_types: ObserverTypes
    ) -> None:
         mf = submodule_filter(self.module_filter, name) if self.module_filter else None
         ctx = GraphModContext(
@@ -138,6 +140,44 @@ def calib(self, model: nn.Module,
             self.calib_gm(name, traced.gm, traced.m, ob_types)
         return copy_and_replace(model, trace_mapping)
 
+    def amp_gm(
+        self, name: str, gm: GraphModule, root: nn.Module, ob_types: ObserverTypes
+    ) -> None:
+        mf = submodule_filter(self.module_filter, name) if self.module_filter else None
+        ctx = GraphModContext(
+            gm, root, mf, ob_types.act_ob_ctr, ob_types.w_ob_ctr, ob_types.bias_ob_ctr
+        )
+        if self.backend == Backend.DISC:
+            ctx.modify_graph([set_qconfig, quantizable_module_to_amp])
+        else:
+            ctx.modify_graph([set_qconfig, fuse_modules, quantizable_module_to_amp])
+        toggle_observer(gm, observe=False, fake_quant=True)
+
+    def amp(
+        self,
+        model: nn.Module,
+        act_ob_ctr: Optional[Callable[..., Observer]] = None,
+        w_ob_ctr: Optional[Callable[..., Observer]] = None,
+        bias_ob_ctr: Optional[Callable[..., Observer]] = None,
+    ) -> nn.Module:
+        ob_types = get_observer_types(
+            act_ob_ctr,
+            w_ob_ctr,
+            bias_ob_ctr,
+            DEFAULT_ACT_OB_CTR[self.backend],
+            DEFAULT_W_OB_CTR[self.backend],
+            DEFAULT_BIAS_OB_CTR,
+        )
+        trace_mapping = fx_trace(model, self.module_filter, tracer=self.tracer)
+        for name, traced in trace_mapping.items():
+            self.amp_gm(name, traced.gm, traced.m, ob_types)
+        return copy_and_replace(model, trace_mapping)
+
+    def fallback(self, model: nn.Module, num: int) -> None:
+        self.module_filter = self.module_filter or ModuleFilter()
+        self.module_filter.exclude_names = self.module_filter.exclude_names or list()
+        self.module_filter.exclude_names.extend(get_fallback_names(model, num))
+
     def qat_gm(
         self, name: str, gm: GraphModule, root: nn.Module, ob_types: ObserverTypes
     ) -> None:
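
For readers of this patch, the whole AMP workflow boils down to four calls: calibrate, measure per-layer quantization noise, fall back the noisiest layers, and quantize. The following is a minimal, hedged sketch of that flow; the toy model and calibration batch are hypothetical stand-ins, while `Quantizer`, `calib`, `amp`, `fallback`, and `quantize` are the APIs documented in the README hunk above.

```python
import torch
from torch_quant.quantizer import Quantizer

# Hypothetical stand-ins for a real model and a typical calibration batch.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
typical_data = torch.randn(32, 8)

quantizer = Quantizer()

# 1. Calibrate observers (activation/weight statistics) on typical data.
quantizer.calib(model)(typical_data)

# 2. Run the AMP proxy; each AmpModule accumulates the MSE between its
#    float output and its fake-quantized output into a `noise` buffer.
amp_model = quantizer.amp(model)
amp_model(typical_data)

# 3. Keep the single noisiest layer in float precision.
quantizer.fallback(amp_model, num=1)

# 4. Build the quantized proxy; fallback layers are excluded from quantization.
quant_model = quantizer.quantize(model)
print(quant_model(typical_data).shape)
```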
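The fallback decision itself is simple: `get_fallback_names` ranks AmpModules by their accumulated output MSE and keeps the top `num` in float. The self-contained sketch below illustrates that ranking with plain tensors and a toy symmetric fake-quantizer; it does not use the library's observers, and the layer names are made up.

```python
import torch


def fake_quantize(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Toy symmetric, per-tensor fake quantization (illustration only).
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max() / qmax
    return torch.round(x / scale).clamp(-qmax - 1, qmax) * scale


# Hypothetical per-layer float outputs collected during an AMP forward pass.
outputs = {
    'linear1': torch.randn(16, 8),
    'linear2': torch.randn(16, 8) * 10.0,  # larger dynamic range, larger error
}

# Mean squared error between float and fake-quantized outputs, per layer.
noise = {
    name: torch.mean((y - fake_quantize(y)) ** 2).item()
    for name, y in outputs.items()
}

# Mirror get_fallback_names: the noisiest `num` layers stay in float precision.
num = 1
fallback = [
    name for name, _ in sorted(noise.items(), key=lambda kv: kv[1], reverse=True)[:num]
]
print(fallback)
```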