[torch-quant] automatic mixed precision quantization (#1081)
Commit c85a5d5, 1 parent d95b24a. Showing 8 changed files with 335 additions and 37 deletions.

New file (+72 lines):

```python
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import Callable

import torch
from torch.quantization import QConfig

from torch_quant.amp_module import AmpModule
from torch_quant.observed_module import Linear
from torch_quant.observer import (
    BiasObserver,
    MinMaxObserver,
    Observer,
    PerChannelMinMaxObserver,
    toggle_observer,
)


class TestAmpModule(unittest.TestCase):
    def test_basic(self):
        model = torch.nn.Linear(2, 4)
        dummy_input = torch.randn((4, 2))
        original_output = model(dummy_input)

        # Calibrate an activation observer on the given tensor (observe only, no fake-quant).
        def _act_ob(data: torch.Tensor) -> Observer:
            ob = MinMaxObserver()
            ob.set_mode(observe=True, fake_quant=False)
            ob(data)
            return ob

        act_ob = _act_ob(dummy_input)
        out_ob = _act_ob(model(dummy_input))

        def _w_ob(ctr: Callable[..., Observer], param: torch.nn.Parameter) -> Observer:
            ob = ctr()
            ob.set_mode(observe=True, fake_quant=False)
            ob(param)
            return ob

        w_ob_ctr = PerChannelMinMaxObserver
        w_ob = _w_ob(w_ob_ctr, model.weight)
        bias_ob_ctr = partial(BiasObserver, w_ob, act_ob)
        bias_ob = _w_ob(bias_ob_ctr, model.bias)
        model.qconfig = QConfig(activation=None, weight=w_ob_ctr)
        observed_model = Linear.from_float(model, w_ob, bias_ob)
        amp = AmpModule(model, observed_model, act_ob, out_ob)

        # The AMP wrapper must return the float output unchanged.
        amp_output = amp(dummy_input)
        self.assertTrue(torch.equal(original_output, amp_output))

        # Fake-quantize the weight manually and check that AmpModule recorded the same MSE noise.
        w_ob.set_mode(observe=False, fake_quant=True)
        quant_weight = w_ob(model.weight)
        model.load_state_dict({'weight': quant_weight, 'bias': model.bias})
        toggle_observer(model, observe=False, fake_quant=True)
        quant_output = out_ob(model(act_ob(dummy_input)))
        mse = torch.mean(torch.pow(original_output - quant_output, 2))
        self.assertTrue(torch.equal(amp.noise, mse))


if __name__ == "__main__":
    unittest.main()
```

tools/torch_quant/tests/test_pass_quantizable_module_to_amp.py (new file, +53 lines):

```python
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tests.models import SimpleModule
from torch_quant.amp_module import AmpModule
from torch_quant.graph import (
    GraphModContext,
    insert_act_observer,
    quantizable_module_to_amp,
    set_qconfig,
)
from torch_quant.observer import BiasObserver, MinMaxObserver, PerChannelMinMaxObserver


class QuantizableModuleToAmpTest(unittest.TestCase):
    def test_base(self) -> None:
        model = SimpleModule()
        ctx = GraphModContext(
            gm=torch.fx.symbolic_trace(model),
            root=model,
            act_ob_ctr=MinMaxObserver,
            w_ob_ctr=PerChannelMinMaxObserver,
            bias_ob_ctr=BiasObserver,
        )
        insert_act_observer(ctx)
        ctx.gm = torch.fx.symbolic_trace(model)
        set_qconfig(ctx)
        quantizable_module_to_amp(ctx)
        amp_modules = dict(ctx.gm.named_modules())
        for name, mod in model.named_modules():
            if type(mod) in [torch.nn.Conv2d, torch.nn.Linear]:
                self.assertTrue(isinstance(amp_modules[name], AmpModule))

        dummy_input = torch.randn((1, 2, 5, 5))
        original_output = model(dummy_input)
        amp_output = ctx.gm(dummy_input)
        self.assertTrue(torch.equal(original_output, amp_output))


if __name__ == '__main__':
    unittest.main()
```

New file (+65 lines):

```python
# Copyright 2023 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
import torch.nn as nn

from torch_quant.observer import Observer, toggle_observer

LOGGER = logging.getLogger(__name__)


class AmpModule(nn.Module):
    """
    This module holds both the original float op and its fake-quantized op
    (i.e. the observed module). Mean squared error between their outputs is
    used to analyze the quantization precision.
    """

    def __init__(
        self,
        float_op: nn.Module,
        observed_op: nn.Module,
        act_ob: Observer,
        out_ob: Observer,
    ) -> None:
        super(AmpModule, self).__init__()
        self.float_op = float_op
        self.observed_op = observed_op
        self.act_ob = act_ob
        self.out_ob = out_ob
        self.register_buffer('noise', torch.tensor(0.0))
        toggle_observer(self, observe=False, fake_quant=True)

    def forward(self, x):
        # Run both the float op and the fake-quantized op, accumulate the MSE
        # between their outputs, and return the float result unchanged.
        y = self.float_op(x)
        quant_y = self.out_ob(self.observed_op(self.act_ob(x)))
        noise = torch.mean(torch.pow(y.detach() - quant_y.detach(), 2))
        self.noise.copy_(self.noise + noise)
        return y


def get_fallback_names(root: nn.Module, num: int) -> List[str]:
    """Return the names of the `num` AmpModules with the largest accumulated noise."""
    modules = dict(root.named_modules())
    candidates = [k for k, v in modules.items() if isinstance(v, AmpModule)]
    if len(candidates) < num:
        LOGGER.warning(
            f"No module will be quantized: there are only {len(candidates)} "
            f"quantizable modules, but the fallback number is {num}."
        )
        num = len(candidates)
    LOGGER.info(f"Fall back {num} modules to float precision.")
    noises = {name: modules[name].noise for name in candidates}
    sorted_noises = sorted(noises.items(), key=lambda x: x[1], reverse=True)
    fallback_names = [k[0] for k in sorted_noises[:num]]
    return fallback_names
```

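To make the interplay of `AmpModule` and `get_fallback_names` concrete, here is a minimal usage sketch. It is not part of this commit: the names `amp_gm` (a traced module whose quantizable ops were replaced by `AmpModule`, e.g. via `quantizable_module_to_amp` as in the test above) and `calib_batches` (an iterable of calibration inputs) are assumptions for illustration.

```python
# Minimal sketch, not part of this commit. Assumes `amp_gm` is a module whose
# Conv2d/Linear layers were replaced by AmpModule (e.g. via
# quantizable_module_to_amp) and `calib_batches` is a user-provided iterable
# of calibration inputs.
from torch_quant.amp_module import get_fallback_names

for batch in calib_batches:
    amp_gm(batch)  # each forward adds the per-module MSE into its `noise` buffer

# Names of the 3 modules with the largest accumulated quantization noise;
# these would be kept in float precision while the rest are quantized.
fallback_names = get_fallback_names(amp_gm, num=3)
```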