jit pass of adding a fake-quant for weight-only quantizable op. (#998)

chenbohua3 authored Mar 17, 2023
1 parent 89730c1 commit d95b24a

Showing 5 changed files with 319 additions and 46 deletions.
139 changes: 139 additions & 0 deletions pytorch_blade/pytorch_blade/quantization/pybind_functions.cpp
@@ -18,6 +18,7 @@
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/script.h>
#include <cmath>
#include <unordered_set>

namespace torch {
namespace blade {
@@ -78,6 +79,143 @@ torch::jit::Value* insert_prim_constant(
return constant_node->output();
}

const static std::unordered_set<std::string> weight_only_list{"aten::linear"};

// Prepare graph for weight-only quantization:
// 1. extract the weight of each weight-only quantizable op
// 2. calculate the scale & zero_point of each weight (currently using a
//    min/max observer)
// 3. add a torch_blade.fake_quant to each weight
// NOTE: This pass assumes that the graph is frozen.
// In other words, the weight tensor of an aten op comes from a
// prim::Constant node instead of prim::GetAttr.
// TODO: Currently, there is no difference between the fake-quant used in
// static quantization and the one used in weight-only quantization. This
// makes it impossible for us to distinguish between the two when doing
// mixed-type quantization (e.g. using static and weight-only quantization
// simultaneously). Consider adding a new attribute like weight-only to
// torch_blade::fake_quant.
void add_fake_quant_for_weight(Module& model) {
auto g = model.get_method("forward").graph();
Symbol sym = Symbol::fromQualString(
torch::blade::quantization::custom_fake_quant_name);

for (auto&& n : g->nodes()) {
std::string node_kind_str = n->kind().toQualString();
if (weight_only_list.find(node_kind_str) != weight_only_list.end()) {
// TODO: If we support quantizing other types of layers, check whether
// input index 1 still holds the weight.
Value* weight_val = n->inputs()[1];
auto weight_val_type = weight_val->type()->cast<c10::TensorType>();
if (!weight_val_type) {
// This condition is unlikely to be triggered, but we keep the check
// for safety.
continue;
}

Node* weight_node = weight_val->node();
if (weight_node->kind() != prim::Constant) {
// The weight is not a constant; the graph should be frozen first.
continue;
}
c10::optional<IValue> constant = torch::jit::toIValue(weight_val);
const at::Tensor& weight_t = constant->toTensor();
at::Tensor weight_min_t, weight_max_t;

// TODO: determine dim according to the type of layer to be quantized
std::tie(weight_min_t, weight_max_t) = at::_aminmax(weight_t, 1);

// TODO: calculate the quantization info based on the backend
int32_t quant_min = -128;
int32_t quant_max = 127;
// the following process is the same as that in
// UniformQuantizationObserverBase's _calculate_qparams
at::Tensor min_val_neg_t =
torch::min(weight_min_t, torch::zeros_like(weight_min_t));
at::Tensor max_val_pos_t =
torch::max(weight_max_t, torch::zeros_like(weight_max_t));
auto device = weight_val_type->device();
auto scale_option =
torch::TensorOptions().dtype(torch::kFloat32).device(device);
at::Tensor scale_t = torch::ones(min_val_neg_t.sizes(), scale_option);
#if PYTORCH_MAJOR_VERSION == 1 && PYTORCH_MINOR_VERSION >= 10
auto zero_point_option =
torch::TensorOptions().dtype(torch::kInt32).device(device);
#else
auto zero_point_option =
torch::TensorOptions().dtype(torch::kInt64).device(device);
#endif
at::Tensor zero_point_t =
torch::zeros(min_val_neg_t.sizes(), zero_point_option);
// for per_channel_symmetric
max_val_pos_t = torch::max(-min_val_neg_t, max_val_pos_t);
scale_t = max_val_pos_t / (float(quant_max - quant_min) / 2);
const static float epsilon = std::numeric_limits<float>::epsilon();
at::Tensor epsilon_t = torch::ones_like(scale_t) * epsilon;
scale_t = torch::max(scale_t, epsilon_t);

// Create a torch_blade.fake_quant for the weight, and replace the
// original weight input with the output of the newly constructed
// fake_quant node.
Node* fake_quant_node = g->insertNode(g->create(sym));
fake_quant_node->moveAfter(weight_node);
fake_quant_node->output()->setType(weight_node->outputs()[0]->type());
n->replaceInputWith(weight_val, fake_quant_node->outputs()[0]);

// Create needed inputs to the torch_blade.fake_quant
// 1. scale
Value* scale_val =
insert_prim_constant(g, fake_quant_node, false, scale_t);

// 2. zero_point
Value* zero_point_val =
insert_prim_constant(g, fake_quant_node, false, zero_point_t);

// 3. quant_min & quant_max
Value* quant_min_val =
insert_prim_constant<int>(g, fake_quant_node, false, quant_min);
Value* quant_max_val =
insert_prim_constant<int>(g, fake_quant_node, false, quant_max);

// 4. num_bits
// TODO: support more kinds of num_bits
Value* num_bits_val =
insert_prim_constant<int>(g, fake_quant_node, false, 8);

// 5. axis
Value* axis_val = insert_prim_constant<int>(g, fake_quant_node, false, 0);
Node* list_axis_node = g->insertNode(g->create(prim::ListConstruct));
list_axis_node->addInput(axis_val);
list_axis_node->moveAfter(axis_val->node());
list_axis_node->output()->setType(c10::ListType::ofInts());

// 6. boolean value
Value* use_signed_val =
insert_prim_constant<bool>(g, fake_quant_node, false, true);
Value* use_symmetric_val =
insert_prim_constant<bool>(g, fake_quant_node, false, true);
Value* use_dynamic_val =
insert_prim_constant<bool>(g, fake_quant_node, false, true);
Value* use_per_channel_val =
insert_prim_constant<bool>(g, fake_quant_node, false, true);

fake_quant_node->addInput(weight_val);
fake_quant_node->addInput(scale_val);
fake_quant_node->addInput(zero_point_val);
fake_quant_node->addInput(quant_min_val);
fake_quant_node->addInput(quant_max_val);
fake_quant_node->addInput(num_bits_val);
fake_quant_node->addInput(list_axis_node->output());
fake_quant_node->addInput(use_signed_val);
fake_quant_node->addInput(use_symmetric_val);
fake_quant_node->addInput(use_dynamic_val);
fake_quant_node->addInput(use_per_channel_val);
}
}
// run some jit passes to clean up the graph
EliminateDeadCode(g->block());
}

void replace_aten_fake_quant_with_custom_version(Module& model) {
auto g = model.get_method("forward").graph();
// the graph should be inlined first
@@ -213,6 +351,7 @@ void initQuantizationBindings(py::module& m) {
"_quantization", "torch_blade python bindings for quantization");
quantization.def(
"add_placeholder_for_fake_quant", &add_placeholder_for_fake_quant);
quantization.def("add_fake_quant_for_weight", &add_fake_quant_for_weight);
quantization.def("remove_placeholder", &remove_placeholder);
quantization.def(
"replace_aten_fake_quant_with_custom_version",
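The scale & zero_point computation in add_fake_quant_for_weight mirrors PyTorch's per-channel symmetric min/max observer. Below is a minimal Python sketch of the same math, assuming a 2-D aten::linear weight reduced over dim 1 and signed 8-bit quantization; calc_weight_qparams is an illustrative name and torch.aminmax needs a recent PyTorch (the C++ pass uses at::_aminmax), so treat this as a sketch rather than code from the commit.

import torch

def calc_weight_qparams(weight, quant_min=-128, quant_max=127):
    # Per-output-channel min/max observer over dim 1.
    min_val, max_val = torch.aminmax(weight, dim=1)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    # Symmetric scheme: the scale must cover the largest absolute value.
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # Clamp tiny scales to epsilon, as the pass does.
    eps = torch.finfo(torch.float32).eps
    scale = torch.max(scale, torch.full_like(scale, eps))
    # Zero points are all zero under the symmetric scheme; newer PyTorch
    # accepts int32 here, older versions expect int64 (cf. the PYTORCH
    # version check in the C++ pass above).
    zero_point = torch.zeros_like(scale, dtype=torch.int32)
    return scale, zero_point

weight = torch.randn(4, 3)  # e.g. the weight of an nn.Linear(3, 4)
scale, zero_point = calc_weight_qparams(weight)
fq_weight = torch.fake_quantize_per_channel_affine(
    weight, scale, zero_point, axis=0, quant_min=-128, quant_max=127)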
40 changes: 37 additions & 3 deletions pytorch_blade/tests/quantization/__init__.py
@@ -9,8 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import torch.nn.functional as F
from torch import nn
@@ -47,7 +45,8 @@ def forward(self, x):
weight = torch.fake_quantize_per_channel_affine(
self.weight.data, self.weight_scale.data,
self.weight_zero_point.data.to(zero_point_dtype),
axis=self.weight_axis, quant_min=self.weight_quant_min, quant_max=self.weight_quant_max
axis=self.weight_axis, quant_min=self.weight_quant_min,
quant_max=self.weight_quant_max
)
y = F.conv2d(x, weight, bias=None)
return y
@@ -92,3 +91,38 @@ def setUp(self):
self.is_quantization_available = is_quantization_available()
if not is_quantization_available():
self.skipTest("Quantization support was not built")

def _test_fake_quant_params(self, fake_quant_node, target_val):
# The order of the constant nodes is not fixed, so it is not
# easy to use the FileCheck system to check each attribute of
# the fake_quant node. We extract all attributes and compare
# them with the target values one by one.
input_list = fake_quant_node.input_list()
scale = input_list[1].node().t("value")
self.assertTrue(torch.equal(scale, target_val['scale']))

zero_point = input_list[2].node().t("value")
self.assertTrue(torch.equal(zero_point, target_val['zero_point']))

quant_min = input_list[3].node().i("value")
self.assertEqual(quant_min, target_val["quant_min"])

quant_max = input_list[4].node().i("value")
self.assertEqual(quant_max, target_val["quant_max"])

num_bits = input_list[5].node().i("value")
self.assertEqual(num_bits, target_val["num_bits"])

# TODO: find a way to check axis

use_signed = bool(input_list[7].node().i("value"))
self.assertEqual(use_signed, target_val["use_signed"])

use_symmetric = bool(input_list[8].node().i("value"))
self.assertEqual(use_symmetric, target_val["use_symmetric"])

use_dynamic = bool(input_list[9].node().i("value"))
self.assertEqual(use_dynamic, target_val["use_dynamic"])

use_per_channel = bool(input_list[10].node().i("value"))
self.assertEqual(use_per_channel, target_val["use_per_channel"])
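
As a reference for _test_fake_quant_params, the input_list indices correspond to the order in which add_fake_quant_for_weight wires arguments into the torch_blade::fake_quant node in pybind_functions.cpp; the listing below is illustrative only, not part of the commit.

# Input order of torch_blade::fake_quant as constructed by
# add_fake_quant_for_weight (reference only):
FAKE_QUANT_INPUT_ORDER = [
    "input",            # 0: the weight tensor being fake-quantized
    "scale",            # 1: per-channel scales
    "zero_point",       # 2: per-channel zero points (all zero here)
    "quant_min",        # 3: e.g. -128
    "quant_max",        # 4: e.g. 127
    "num_bits",         # 5: e.g. 8
    "axis",             # 6: an int[] such as [0]; skipped by the test (see TODO)
    "use_signed",       # 7
    "use_symmetric",    # 8
    "use_dynamic",      # 9
    "use_per_channel",  # 10
]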
132 changes: 94 additions & 38 deletions pytorch_blade/tests/quantization/test_graph_process.py
@@ -13,15 +13,25 @@

import torch
from tests.quantization import (
TORCH_VERSION,
ModelWithFakeQuant,
PerChannelFakeQuant,
PerTensorFakeQuant,
QuantizationTestCase,
zero_point_dtype
)
from torch import nn

try:
from torch.ao.quantization.observer import PerChannelMinMaxObserver
except ModuleNotFoundError:
from torch.quantization.observer import PerChannelMinMaxObserver

from torch.nn import functional as F
from torch.testing import FileCheck
from torch_blade import tools
from torch_blade.quantization import (
_jit_add_fake_quant_for_weight,
_jit_pass_add_placeholder_for_fake_quant,
_jit_pass_remove_all_placeholder,
_jit_replace_aten_fake_quant_with_custom_version,
@@ -76,41 +86,6 @@ def test_insert_and_remove_fake_quant(self):


class TestReplaceAtenFakeQuant(QuantizationTestCase):

def _test_fake_quant_params(self, fake_quant_node, target_val):
# The order of the constant nodes is not fixed, so it is not easy to
# use the FileCheck system to check each attribute of the fake_quant node.
# We extract all attributes and compare them with the target values one by one.
input_list = fake_quant_node.input_list()
scale = input_list[1].node().t("value")
self.assertTrue(torch.equal(scale, target_val['scale']))

zero_point = input_list[2].node().t("value")
self.assertTrue(torch.equal(zero_point, target_val['zero_point']))

quant_min = input_list[3].node().i("value")
self.assertEqual(quant_min, target_val["quant_min"])

quant_max = input_list[4].node().i("value")
self.assertEqual(quant_max, target_val["quant_max"])

num_bits = input_list[5].node().i("value")
self.assertEqual(num_bits, target_val["num_bits"])

# TODO: find a way to check axis

use_signed = bool(input_list[7].node().i("value"))
self.assertEqual(use_signed, target_val["use_signed"])

use_symmetric = bool(input_list[8].node().i("value"))
self.assertEqual(use_symmetric, target_val["use_symmetric"])

use_dynamic = bool(input_list[9].node().i("value"))
self.assertEqual(use_dynamic, target_val["use_dynamic"])

use_per_channel = bool(input_list[10].node().i("value"))
self.assertEqual(use_per_channel, target_val["use_per_channel"])

def _test_replace_aten_fake_quant(self, model, inp, all_target_val):
traced_model = torch.jit.trace(model, inp)
origin_output = traced_model(inp)
@@ -138,8 +113,8 @@ def test_per_tensor_symmetry(self):
model = PerTensorFakeQuant(
scale=0.1,
zero_point=0,
quant_min=-2**(bit-1),
quant_max=2**(bit-1)-1
quant_min=-2 ** (bit - 1),
quant_max=2 ** (bit - 1) - 1
).eval().to(self.device)

target_val = {
@@ -217,7 +192,7 @@ def test_per_channel_asymmetry(self):
scale=scale,
zero_point=zero_point,
quant_min=0,
quant_max=2**bit-1,
quant_max=2 ** bit - 1,
axis=1
).eval().to(self.device)
target_val = {
@@ -263,5 +238,86 @@ def test_dummy_model(self):
self._test_replace_aten_fake_quant(model, inp, target_val)


class TestAddFakeQuantForWeight(QuantizationTestCase):
def _test_add_fake_quant_for_weight(self, model, inp, target_quant_info, target_output, target_graph):
traced_model = torch.jit.trace(model, inp)
c_module = traced_model._c
c_module = tools.freeze_module(c_module, [], disableShapePeephole=False)
_jit_add_fake_quant_for_weight(c_module)
graph = c_module.forward.graph
fake_quant_nodes = get_fake_quant_node(graph)
self.assertEqual(len(fake_quant_nodes), len(target_quant_info))
for n, node_target_val in zip(fake_quant_nodes, target_quant_info):
self._test_fake_quant_params(n, node_target_val)

c_module.create_method_from_graph("_forward", graph)
now_output = c_module._forward(inp)
self.assertEqual(now_output, target_output)

FileCheck().run(target_graph, graph)

@unittest.skipIf(TORCH_VERSION < (1, 9), "Unsupported torch version")
def test_per_channel_symmetry(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 4, bias=True)

def forward(self, x):
x = self.linear(x)
return x

inp = torch.randn(1, 3)
model = Model().eval()
# use torch's observer to calculate the scale & zero point
obs = PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
observer = obs()
observer(model.linear.weight)
scale, zero_point = observer.calculate_qparams()
zero_point = zero_point.to(zero_point_dtype)
fake_quantized_weight = torch.fake_quantize_per_channel_affine(
model.linear.weight, scale, zero_point, 0, -128, 127)
target_output = F.linear(inp, fake_quantized_weight, model.linear.bias)

target_quant_info = [{
"scale": scale,
"zero_point": zero_point,
"quant_min": -128,
"quant_max": 127,
"num_bits": 8,
"use_signed": True,
"use_symmetric": True,
"use_dynamic": True,
"use_per_channel": True,
"target_output": target_output
}]

target_graph = """
graph(%self.1 : __torch__.___torch_mangle_2.Model,
%x : Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)):
%self.linear.weight : Float(4, 3, strides=[3, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%11 : Float(4, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value=0.001 * 2.5694 2.9053 1.9990 2.3660 [ CPUFloatType{4} ]]()
%12 : Int(4, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 0 0 0 0 [ CPUIntType{4} ]]()
%13 : int = prim::Constant[value=-128]()
%14 : int = prim::Constant[value=127]()
%15 : int = prim::Constant[value=8]()
%16 : int = prim::Constant[value=1]()
%17 : int[] = prim::ListConstruct(%16)
%18 : bool = prim::Constant[value=1]()
%19 : bool = prim::Constant[value=1]()
%20 : bool = prim::Constant[value=1]()
%21 : bool = prim::Constant[value=1]()
# CHECK: torch_blade::fake_quant
%10 : Float(4, 3, strides=[3, 1], requires_grad=0, device=cpu) = torch_blade::fake_quant(%self.linear.weight, %11, %12, %13, %14, %15, %17, %18, %19, %20, %21)
%self.linear.bias : Float(4, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 0.1169 -0.2259 -0.2832 0.1494 [ CPUFloatType{4} ]]()
%6 : Float(1, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::linear(%x, %10, %self.linear.bias)
return (%6)
"""

self._test_add_fake_quant_for_weight(model, inp, target_quant_info, target_output, target_graph)


if __name__ == "__main__":
unittest.main()
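
Outside the test harness, the new pass is driven in three steps, trace, freeze, then call the binding, just as _test_add_fake_quant_for_weight does above. A minimal standalone sketch, assuming torch_blade is built with quantization support:

import torch
from torch import nn
from torch_blade import tools
from torch_blade.quantization import _jit_add_fake_quant_for_weight

model = nn.Linear(3, 4).eval()
inp = torch.randn(1, 3)

traced = torch.jit.trace(model, inp)
# The pass assumes a frozen graph: weights must come from prim::Constant,
# not prim::GetAttr.
frozen = tools.freeze_module(traced._c, [], disableShapePeephole=False)
_jit_add_fake_quant_for_weight(frozen)
# aten::linear now consumes the output of a torch_blade::fake_quant node.
print(frozen.forward.graph)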