[torch-quant] automatic mixed precision quantization (#1081)
xiaowan0322 authored Mar 17, 2023
1 parent d95b24a commit c85a5d5
Showing 8 changed files with 335 additions and 37 deletions.
8 changes: 7 additions & 1 deletion tools/torch_quant/README.md
@@ -39,6 +39,12 @@ quantizer = Quantizer()
# create a proxy model and run forward to calibrate quantization params
quantizer.calib(model)(typical_data)

# [Optional] perform automatic mixed precision quantization
# create a proxy model and run forward to fall back a few quantization-sensitive layers to float precision
amp_model = quantizer.amp(model)
amp_model(typical_data)
quantizer.fallback(amp_model, num=1)

# create a proxy model representing quantized model
quant_model = quantizer.quantize(model)

@@ -58,4 +58,4 @@ opt = torch_blade.optimize(quant_model)

*TBD*

An initial example can be found at [bert_ptq_demo.py](bert_ptq_demo.py)
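Putting the snippet above together with the rest of the README flow, a minimal end-to-end sketch of post-training quantization with AMP fallback might look like the following. `TinyModel`, the random calibration batch, and the printed shape are illustrative placeholders; the `Quantizer` calls mirror the README snippet above rather than a verified API reference.

```python
import torch
from torch_quant.quantizer import Quantizer


class TinyModel(torch.nn.Module):
    """Placeholder model with one conv and one linear layer."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.linear(x.mean(dim=(2, 3)))


model = TinyModel().eval()
typical_data = torch.randn(4, 3, 32, 32)

quantizer = Quantizer()
# calibrate quantization params on representative data
quantizer.calib(model)(typical_data)

# optional AMP pass: measure per-layer quantization noise,
# then keep the single noisiest layer in float precision
amp_model = quantizer.amp(model)
amp_model(typical_data)
quantizer.fallback(amp_model, num=1)

# build the quantized proxy model
quant_model = quantizer.quantize(model)
print(quant_model(typical_data).shape)
```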
72 changes: 72 additions & 0 deletions tools/torch_quant/tests/test_amp_module.py
@@ -0,0 +1,72 @@
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import Callable

import torch
from torch.quantization import QConfig

from torch_quant.amp_module import AmpModule
from torch_quant.observed_module import Linear
from torch_quant.observer import (
    BiasObserver,
    MinMaxObserver,
    Observer,
    PerChannelMinMaxObserver,
    toggle_observer,
)


class TestAmpModule(unittest.TestCase):
    def test_basic(self):
        model = torch.nn.Linear(2, 4)
        dummy_input = torch.randn((4, 2))
        original_output = model(dummy_input)

        def _act_ob(data: torch.Tensor) -> Observer:
            ob = MinMaxObserver()
            ob.set_mode(observe=True, fake_quant=False)
            ob(data)
            return ob

        act_ob = _act_ob(dummy_input)
        out_ob = _act_ob(model(dummy_input))

        def _w_ob(ctr: Callable[..., Observer], param: torch.nn.Parameter) -> Observer:
            ob = ctr()
            ob.set_mode(observe=True, fake_quant=False)
            ob(param)
            return ob

        w_ob_ctr = PerChannelMinMaxObserver
        w_ob = _w_ob(w_ob_ctr, model.weight)
        bias_ob_ctr = partial(BiasObserver, w_ob, act_ob)
        bias_ob = _w_ob(bias_ob_ctr, model.bias)
        model.qconfig = QConfig(activation=None, weight=w_ob_ctr)
        observed_model = Linear.from_float(model, w_ob, bias_ob)
        amp = AmpModule(model, observed_model, act_ob, out_ob)

        amp_output = amp(dummy_input)
        self.assertTrue(torch.equal(original_output, amp_output))

        w_ob.set_mode(observe=False, fake_quant=True)
        quant_weight = w_ob(model.weight)
        model.load_state_dict({'weight': quant_weight, 'bias': model.bias})
        toggle_observer(model, observe=False, fake_quant=True)
        quant_output = out_ob(model(act_ob(dummy_input)))
        mse = torch.mean(torch.pow(original_output - quant_output, 2))
        self.assertTrue(torch.equal(amp.noise, mse))


if __name__ == "__main__":
    unittest.main()
53 changes: 53 additions & 0 deletions tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py
@@ -0,0 +1,53 @@
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tests.models import SimpleModule
from torch_quant.amp_module import AmpModule
from torch_quant.graph import (
    GraphModContext,
    insert_act_observer,
    quantizable_module_to_amp,
    set_qconfig,
)
from torch_quant.observer import BiasObserver, MinMaxObserver, PerChannelMinMaxObserver


class QuantizableModuleToAmpTest(unittest.TestCase):
    def test_base(self) -> None:
        model = SimpleModule()
        ctx = GraphModContext(
            gm=torch.fx.symbolic_trace(model),
            root=model,
            act_ob_ctr=MinMaxObserver,
            w_ob_ctr=PerChannelMinMaxObserver,
            bias_ob_ctr=BiasObserver,
        )
        insert_act_observer(ctx)
        ctx.gm = torch.fx.symbolic_trace(model)
        set_qconfig(ctx)
        quantizable_module_to_amp(ctx)
        amp_modules = dict(ctx.gm.named_modules())
        for name, mod in model.named_modules():
            if type(mod) in [torch.nn.Conv2d, torch.nn.Linear]:
                self.assertTrue(isinstance(amp_modules[name], AmpModule))

        dummy_input = torch.randn((1, 2, 5, 5))
        original_output = model(dummy_input)
        amp_output = ctx.gm(dummy_input)
        self.assertTrue(torch.equal(original_output, amp_output))


if __name__ == '__main__':
    unittest.main()
75 changes: 47 additions & 28 deletions tools/torch_quant/tests/test_quantizer.py
@@ -11,32 +11,32 @@

import tempfile
import unittest
from typing import Optional
from typing import List, Optional

import torch
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.quantized._reference as nnqr
from parameterized import parameterized

from tests.models import SimpleModule, SubModule, UntraceableSimpleModule
from tests.models import LinearReLU, SimpleModule, SubModule, UntraceableSimpleModule
from torch_quant.amp_module import AmpModule
from torch_quant.module import ModuleFilter
from torch_quant.observer import toggle_observer
from torch_quant.quantizer import Backend, Quantizer


def parameterized_backend(backend):
def parameterized_with_backends(parameters: Optional[List] = None):
if parameters is None:
parameters = [(Backend.REFERENCE,), (Backend.FBGEMM,), (Backend.DISC,)]
# skip if fbgemm not available
if torch.backends.quantized.engine != 'fbgemm':
backend.remove((Backend.FBGEMM, ))
return parameterized.expand(backend)
parameters = [param for param in parameters if Backend.FBGEMM not in param]
return parameterized.expand(parameters)


class QuantizerTest(unittest.TestCase):
@parameterized_backend([
(Backend.REFERENCE, ),
(Backend.FBGEMM, ),
(Backend.DISC, ),
])
@parameterized_with_backends()
def test_calib_and_quantize(self, backend: Backend) -> None:
model = SimpleModule()
quantizer = Quantizer(backend=backend)
@@ -54,10 +54,7 @@ def test_calib_and_quantize(self, backend: Backend) -> None:
quant_output, original_output, rtol=0.1, atol=0.5)

# TODO(litan.ls): QAT is more suitable for this case
@parameterized.expand([
(Backend.REFERENCE, ),
(Backend.DISC, ),
])
@parameterized_with_backends()
def test_load_from_state_dict(self, backend: Backend) -> None:
model = SimpleModule()
quantizer = Quantizer(backend=backend)
@@ -73,10 +70,7 @@ def test_load_from_state_dict(self, backend: Backend) -> None:
quant_output = quantizer.quantize(model)(dummy_input)
self.assertTrue(torch.equal(loaded_quant_output, quant_output))

@parameterized.expand([
(Backend.REFERENCE, ),
(Backend.DISC, ),
])
@parameterized_with_backends()
def test_save_and_load_quantized(self, backend: Backend) -> None:
model = SimpleModule()
quantizer = Quantizer(backend=backend)
@@ -125,9 +119,39 @@ def test_calib_quantize_qat_quantize_state_equal(self):
out3 = quant_model(dummy_input)
self.assertTrue(torch.equal(out2, out3))

@parameterized.expand(
@parameterized_with_backends()
def test_calib_amp_quantize(self, backend: Backend) -> None:
model = SimpleModule()
dummy_input = torch.randn((1, 2, 5, 5))
quantizer = Quantizer(backend=backend)
original_output = model(dummy_input)

calib_model = quantizer.calib(model)
calib_output = calib_model(dummy_input)
self.assertTrue(torch.equal(original_output, calib_output))

amp_model = quantizer.amp(model)
amp_modules = dict(amp_model.named_modules())
for name, mod in model.named_modules():
if type(mod) in [torch.nn.Conv2d, torch.nn.Linear]:
self.assertTrue(isinstance(amp_modules[name], AmpModule))
amp_output = amp_model(dummy_input)
self.assertTrue(torch.equal(original_output, amp_output))
quantizer.fallback(amp_model, num=2)
self.assertEqual(len(quantizer.module_filter.exclude_names), 2)

quant_model = quantizer.quantize(model)
modules = dict(model.named_modules())
quant_modules = dict(quant_model.named_modules())
for name in quantizer.module_filter.exclude_names:
self.assertEqual(quant_modules[name], modules[name])
quant_output = quant_model(dummy_input)
self.assertFalse(torch.equal(original_output, quant_output))
torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)

@parameterized_with_backends(
[
(Backend.REFERENCE, ModuleFilter(include_op_types=[torch.nn.Linear])),
(Backend.FBGEMM, ModuleFilter(include_op_types=[torch.nn.Linear])),
(Backend.DISC, ModuleFilter(exclude_op_types=[torch.nn.Conv2d])),
]
)
@@ -154,10 +178,10 @@ def test_calib_and_quantize_with_op_types_filter(
self.assertFalse(torch.equal(original_output, quant_output))
torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)

@parameterized.expand(
@parameterized_with_backends(
[
(
Backend.REFERENCE,
Backend.FBGEMM,
ModuleFilter(
include_names=['traceable_sub'],
exclude_names=['traceable_sub.sub.conv'],
@@ -196,11 +220,7 @@ def test_calib_and_quantize_with_module_filter(
quant_output = qmodel(dummy_input)
torch.testing.assert_close(quant_output, original_output, rtol=0.1, atol=0.5)


@parameterized.expand([
(Backend.REFERENCE, ),
# (Backend.FBGEMM, ),
])
@parameterized_with_backends([(Backend.REFERENCE,), (Backend.FBGEMM,)])
def test_calib_and_quantize_with_module_fusion(self, backend):
model = SimpleModule()
quantizer = Quantizer(backend=backend)
@@ -213,7 +233,6 @@ def test_calib_and_quantize_with_module_fusion(self, backend):
self.assertTrue(isinstance(quant_model.sub.conv[0], nnqr.Conv2d))
elif backend == Backend.FBGEMM:
self.assertTrue(isinstance(quant_model.sub.conv, nniq.ConvReLU2d))
self.assertTrue(isinstance(quant_model.linear, nniq.LinearReLU))


if __name__ == '__main__':
65 changes: 65 additions & 0 deletions tools/torch_quant/torch_quant/amp_module.py
@@ -0,0 +1,65 @@
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
import torch.nn as nn

from torch_quant.observer import Observer, toggle_observer

LOGGER = logging.getLogger(__name__)


class AmpModule(nn.Module):
"""
This module includes original float op and fake quantized op (i.e. observed module).
Mean square error is used to analyze the quantization precision.
"""

def __init__(
self,
float_op: nn.Module,
observed_op: nn.Module,
act_ob: Observer,
out_ob: Observer,
) -> None:
super(AmpModule, self).__init__()
self.float_op = float_op
self.observed_op = observed_op
self.act_ob = act_ob
self.out_ob = out_ob
self.register_buffer('noise', torch.tensor(0.0))
toggle_observer(self, observe=False, fake_quant=True)

def forward(self, x):
y = self.float_op(x)
quant_y = self.out_ob(self.observed_op(self.act_ob(x)))
noise = torch.mean(torch.pow(y.detach() - quant_y.detach(), 2))
self.noise.copy_(self.noise + noise)
return y


def get_fallback_names(root: nn.Module, num: int) -> List[str]:
    modules = dict(root.named_modules())
    candidates = [k for k, v in modules.items() if isinstance(v, AmpModule)]
    if len(candidates) < num:
        LOGGER.warning(
            f"No module will be quantized: there are only {len(candidates)} "
            f"quantizable modules, but the fallback number is {num}."
        )
        num = len(candidates)
    LOGGER.info(f"Falling back {num} modules to float precision.")
    noises = {name: modules[name].noise for name in candidates}
    sorted_noises = sorted(noises.items(), key=lambda x: x[1], reverse=True)
    fallback_names = [k[0] for k in sorted_noises[:num]]
    return fallback_names
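The quantizer.py and graph.py portions of this commit did not load in this view, but the tests above show the intended contract: `quantizer.fallback(amp_model, num=2)` adds two entries to `quantizer.module_filter.exclude_names`, and those modules are then left untouched by `quantizer.quantize`. A plausible sketch of that glue, written here as a free function and assuming the `Quantizer` exposes a `module_filter` with an `exclude_names` list (both assumptions, not the actual implementation):

```python
import torch.nn as nn

from torch_quant.amp_module import get_fallback_names
from torch_quant.module import ModuleFilter


def fallback(quantizer, amp_model: nn.Module, num: int) -> None:
    """Hypothetical sketch: exclude the `num` noisiest layers from quantization."""
    names = get_fallback_names(amp_model, num)
    if quantizer.module_filter is None:
        # no filter yet: create one that only carries the exclusions
        quantizer.module_filter = ModuleFilter(exclude_names=list(names))
    elif quantizer.module_filter.exclude_names is None:
        quantizer.module_filter.exclude_names = list(names)
    else:
        quantizer.module_filter.exclude_names.extend(names)
```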
