refactor(quanto): avoid qlinear composite gradients
By explicitly defining the backward path for the quantized linear function,
we save computations and prepare for the introduction of fused operations for
which we won't necessarily have a gradient.
dacorvo committed Apr 23, 2024
1 parent 544981d commit 1f0a2a5
Showing 2 changed files with 75 additions and 12 deletions.
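For illustration (this snippet is not part of the commit), the composite expression torch.matmul(input, other.t()) makes autograd record separate t() and matmul nodes and backpropagate through both, whereas the hand-derived formulas used by the QTensorLinear backward introduced below give the same gradients in one step per operand. A minimal standalone PyTorch check, with arbitrary names and shapes:

import torch

x = torch.randn(5, 16, requires_grad=True)   # stands in for the activation
w = torch.randn(32, 16, requires_grad=True)  # stands in for the (dequantized) weight
gO = torch.randn(5, 32)                      # incoming output gradient

# Composite path: autograd backpropagates through t() and matmul separately.
y = torch.matmul(x, w.t())
y.backward(gO)

# Explicit formulas: dL/dx = gO @ W and dL/dW = gO.t() @ x.
print(torch.allclose(x.grad, gO @ w))      # True
print(torch.allclose(w.grad, gO.t() @ x))  # True

This equivalence is what allows the backward pass to be written by hand without changing the computed gradients.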
50 changes: 46 additions & 4 deletions quanto/tensor/func.py
@@ -66,9 +66,51 @@ def unsupported_op(func, *args, **kwargs):
    return qfallback(func, *args, **kwargs)


class QTensorLinear(torch.autograd.Function):
"""Quantized linear function.
This is a quantized implementation of torch.nn.functional.linear.
It defines explicitly the backward pass instead of letting pytorch
build it by combining the gradients of the underlying quantized operations.
This has two main benefits:
- this saves computations,
- this allows to use operations that do not have a registered backward method,
such as quanto custom operations.
The drawback is that the extra tensors involved in the quantization graph, such as
the scales and zeropoint, cannot be trained.
This is however consistent with the quanto quantizers backward pass, that returns
a zero gradient for these tensors.
"""

    @staticmethod
    def forward(ctx, input, other, bias):
        ctx.save_for_backward(input, other)
        output = torch.matmul(input, other.t())
        if bias is not None:
            output = output + bias
        return output

    def backward(ctx, gO):
        input_gO = other_gO = bias_gO = None
        input, other = ctx.saved_tensors
        out_features, in_features = other.shape
        if ctx.needs_input_grad[0]:
            # grad(A@(B.t())) = gO => grad(A) = gO@(B.t().t()) = gO@B
            input_gO = torch.matmul(gO, other)
        if ctx.needs_input_grad[1]:
            # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
            other_gO = torch.matmul(gO.view(-1, out_features).t(), input.view(-1, in_features))
        if ctx.needs_input_grad[2]:
            # Bias gradient is the sum over all dimensions but the last one
            dim = tuple(range(gO.ndim - 1))
            bias_gO = gO.sum(dim)
        return input_gO, other_gO, bias_gO


@register_qtensor_func([torch.nn.functional.linear])
def linear(func, input, other, bias=None):
    output = torch.matmul(input, other.t())
    if bias is not None:
        output = output + bias
    return output
    return QTensorLinear.apply(input, other, bias)
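The trickier parts of the backward above are the flattening of gO and input for the weight gradient and the reduction over all but the last dimension for the bias gradient. The standalone sketch below (not from the repository, arbitrary shapes and names) compares these formulas against autograd's gradients for torch.nn.functional.linear on a 3-D input:

import torch

batch, tokens, in_features, out_features = 2, 7, 16, 32
x = torch.randn(batch, tokens, in_features, requires_grad=True)
w = torch.randn(out_features, in_features, requires_grad=True)
b = torch.randn(out_features, requires_grad=True)
gO = torch.randn(batch, tokens, out_features)

# Reference gradients computed by autograd.
torch.nn.functional.linear(x, w, b).backward(gO)

# Hand-written formulas mirroring the backward defined in the diff above.
input_gO = torch.matmul(gO, w)
other_gO = torch.matmul(gO.reshape(-1, out_features).t(), x.reshape(-1, in_features))
bias_gO = gO.sum(tuple(range(gO.ndim - 1)))

print(torch.allclose(x.grad, input_gO, atol=1e-5))  # True
print(torch.allclose(w.grad, other_gO, atol=1e-5))  # True
print(torch.allclose(b.grad, bias_gO, atol=1e-5))   # True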
37 changes: 29 additions & 8 deletions test/nn/test_qlinear.py
Expand Up @@ -17,9 +17,20 @@

import pytest
import torch
from helpers import assert_similar, random_qactivation

from quanto import Calibration, QBitsTensor, QTensor, qfloat8, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
from helpers import assert_similar, random_qactivation, random_tensor

from quanto import (
    Calibration,
    QBitsTensor,
    QTensor,
    absmax_scale,
    qfloat8,
    qfloat8_e4m3fn,
    qfloat8_e5m2,
    qint4,
    qint8,
    quantize_activation,
)
from quanto.nn import QLinear


@@ -124,20 +135,30 @@ def test_qlinear_gradient(tokens, embeddings, activations, weights, device):
    qlinear = QLinear.from_module(linear, weights=weights, activations=activations)
    assert qlinear.weight.requires_grad is True
    assert qlinear.bias.requires_grad is True
    # Run an inference with identical inputs
    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
    # Run an inference with quantized inputs
    inputs = random_tensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)
    inputs.requires_grad = True
    qinputs = quantize_activation(inputs, qtype=qint8, scale=absmax_scale(inputs, qint8))
    qout = qlinear(qinputs)
    out = linear(qinputs.dequantize())
    # Run an equivalent inference with float inputs
    dqinputs = qinputs.dequantize().clone().detach()
    dqinputs.requires_grad = True
    out = linear(dqinputs)
    # Outputs are not identical because of the quantization
    assert not torch.equal(qout, out)
    # Compute gradients and compare
    gradient = torch.randn(qout.size()).to(device)
    qout.backward(gradient)
    out.backward(gradient)
    # Gradients are nearly identical because they depend only on the input
    # Bias gradients are identical because they don't depend on the inputs and weights
    atol = 1e-6
    assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol)
    # Weight gradients are nearly identical: they are computed from identical inputs through subtly different graphs
    atol = 1e-5
    assert_similar(qlinear.weight.grad, linear.weight.grad, atol=atol)
    assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol)
    # Input gradients are slightly different because they depend on the quantized weights
    atol = {qint8: 1e-5, qint4: 5e-3}[weights]
    assert_similar(inputs.grad, dqinputs.grad, atol=atol)


@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])

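A note on the bitwidth-dependent tolerance used for the input gradients in the test above: since the input gradient is gO @ W, evaluating it through (de)quantized weights shifts it by roughly the weight quantization error, which grows as the bitwidth shrinks. The sketch below uses a crude symmetric per-tensor fake-quantizer, purely illustrative and not quanto's actual scheme, to show the trend:

import torch


def fake_quantize(w, bits):
    # Crude symmetric per-tensor quantization, for illustration only.
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale


w = torch.randn(32, 16)
gO = torch.randn(5, 32)
for bits in (8, 4):
    deviation = (gO @ fake_quantize(w, bits) - gO @ w).abs().max().item()
    print(f"{bits}-bit weights: max input-gradient deviation {deviation:.2e}")

The 4-bit deviation comes out noticeably larger than the 8-bit one, which is why the test loosens atol for qint4 relative to qint8.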