From 1f0a2a59adb1ee8ff2ad3070db0270479a26e56b Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 22 Apr 2024 16:34:30 +0200 Subject: [PATCH] refactor(quanto): avoid qlinear composite gradients By explicitly defining the backward path for the quantized linear function, we save calculations and prepare the introduction of fused operations for which we won't necessarily have a gradient. --- quanto/tensor/func.py | 50 +++++++++++++++++++++++++++++++++++++---- test/nn/test_qlinear.py | 37 +++++++++++++++++++++++------- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/quanto/tensor/func.py b/quanto/tensor/func.py index b5d7a3e2..ba1f434e 100644 --- a/quanto/tensor/func.py +++ b/quanto/tensor/func.py @@ -66,9 +66,51 @@ def unsupported_op(func, *args, **kwargs): return qfallback(func, *args, **kwargs) +class QTensorLinear(torch.autograd.Function): + """Quantized linear function. + + This is a quantized implementation of torch.nn.functional.linear. + + It defines explicitly the backward pass instead of letting pytorch + build it by combining the gradients of the underlying quantized operations. + + This has two main benefits: + + - this saves computations, + - this allows to use operations that do not have a registered backward method, + such as quanto custom operations. + + The drawback is that the extra tensors involved in the quantization graph, such as + the scales and zeropoint, cannot be trained. + This is however consistent with the quanto quantizers backward pass, that returns + a zero gradient for these tensors. + """ + + @staticmethod + def forward(ctx, input, other, bias): + ctx.save_for_backward(input, other) + output = torch.matmul(input, other.t()) + if bias is not None: + output = output + bias + return output + + def backward(ctx, gO): + input_gO = other_gO = bias_gO = None + input, other = ctx.saved_tensors + out_features, in_features = other.shape + if ctx.needs_input_grad[0]: + # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B + input_gO = torch.matmul(gO, other) + if ctx.needs_input_grad[1]: + # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A + other_gO = torch.matmul(gO.view(-1, out_features).t(), input.view(-1, in_features)) + if ctx.needs_input_grad[2]: + # Bias gradient is the sum on all dimensions but the last one + dim = tuple(range(gO.ndim - 1)) + bias_gO = gO.sum(dim) + return input_gO, other_gO, bias_gO + + @register_qtensor_func([torch.nn.functional.linear]) def linear(func, input, other, bias=None): - output = torch.matmul(input, other.t()) - if bias is not None: - output = output + bias - return output + return QTensorLinear.apply(input, other, bias) diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py index c580fca4..746651d8 100644 --- a/test/nn/test_qlinear.py +++ b/test/nn/test_qlinear.py @@ -17,9 +17,20 @@ import pytest import torch -from helpers import assert_similar, random_qactivation - -from quanto import Calibration, QBitsTensor, QTensor, qfloat8, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8 +from helpers import assert_similar, random_qactivation, random_tensor + +from quanto import ( + Calibration, + QBitsTensor, + QTensor, + absmax_scale, + qfloat8, + qfloat8_e4m3fn, + qfloat8_e5m2, + qint4, + qint8, + quantize_activation, +) from quanto.nn import QLinear @@ -124,20 +135,30 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device): qlinear = QLinear.from_module(linear, weights=weights, activations=activations) assert qlinear.weight.requires_grad is True assert qlinear.bias.requires_grad is True - # Run an inference with identical inputs - qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) + # Run an inference with quantized inputs + inputs = random_tensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) + inputs.requires_grad = True + qinputs = quantize_activation(inputs, qtype=qint8, scale=absmax_scale(inputs, qint8)) qout = qlinear(qinputs) - out = linear(qinputs.dequantize()) + # Run an equivalent inference with float inputs + dqinputs = qinputs.dequantize().clone().detach() + dqinputs.requires_grad = True + out = linear(dqinputs) # Outputs are not identical because of the quantization assert not torch.equal(qout, out) # Compute gradients and compare gradient = torch.randn(qout.size()).to(device) qout.backward(gradient) out.backward(gradient) - # Gradients are nearly identical because they depend only on the input + # Bias gradients are identical because they don't depend on inputs and weights + atol = 1e-6 + assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol) + # Weights gradients are nearly identical, based on identical inputs through subtly different graphs atol = 1e-5 assert_similar(qlinear.weight.grad, linear.weight.grad, atol=atol) - assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol) + # Inputs gradients are slightly different because they depend on the quantized weights + atol = {qint8: 1e-5, qint4: 5e-3}[weights] + assert_similar(inputs.grad, dqinputs.grad, atol=atol) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])