refactor(quanto): avoid qlinear composite gradients
By explicitly defining the backward pass for the quantized linear function, we save computations and prepare for the introduction of fused operations for which we won't necessarily have a registered gradient.
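A minimal sketch of that last point (purely illustrative; the class name and constant below are not from this commit): once the backward is written by hand, the forward can call kernels that autograd cannot differentiate, since nothing inside a Function's forward is traced.

import torch

class OpaqueDouble(torch.autograd.Function):
    # Sketch: the forward could be any fused or custom kernel without a
    # registered gradient; autograd never tries to differentiate through it.

    @staticmethod
    def forward(ctx, x):
        return x.detach() * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # Hand-written gradient of y = 2x.
        return grad_output * 2.0

x = torch.randn(3, requires_grad=True)
OpaqueDouble.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))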
Showing 2 changed files with 75 additions and 12 deletions.
@@ -66,9 +66,51 @@ def unsupported_op(func, *args, **kwargs):
     return qfallback(func, *args, **kwargs)
 
 
+class QTensorLinear(torch.autograd.Function):
+    """Quantized linear function.
+
+    This is a quantized implementation of torch.nn.functional.linear.
+
+    It defines explicitly the backward pass instead of letting pytorch
+    build it by combining the gradients of the underlying quantized operations.
+
+    This has two main benefits:
+    - this saves computations,
+    - this allows to use operations that do not have a registered backward method,
+    such as quanto custom operations.
+
+    The drawback is that the extra tensors involved in the quantization graph, such as
+    the scales and zeropoint, cannot be trained.
+    This is however consistent with the quanto quantizers backward pass, that returns
+    a zero gradient for these tensors.
+    """
+
+    @staticmethod
+    def forward(ctx, input, other, bias):
+        ctx.save_for_backward(input, other)
+        output = torch.matmul(input, other.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, gO):
+        input_gO = other_gO = bias_gO = None
+        input, other = ctx.saved_tensors
+        out_features, in_features = other.shape
+        if ctx.needs_input_grad[0]:
+            # grad(A@(B.t())) = gO => grad(A) = gO@(B.t().t()) = gO@B
+            input_gO = torch.matmul(gO, other)
+        if ctx.needs_input_grad[1]:
+            # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
+            other_gO = torch.matmul(gO.view(-1, out_features).t(), input.view(-1, in_features))
+        if ctx.needs_input_grad[2]:
+            # Bias gradient is the sum on all dimensions but the last one
+            dim = tuple(range(gO.ndim - 1))
+            bias_gO = gO.sum(dim)
+        return input_gO, other_gO, bias_gO
+
+
 @register_qtensor_func([torch.nn.functional.linear])
 def linear(func, input, other, bias=None):
-    output = torch.matmul(input, other.t())
-    if bias is not None:
-        output = output + bias
-    return output
+    return QTensorLinear.apply(input, other, bias)
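A quick way to sanity-check the hand-written backward (not part of this commit, just a sketch assuming QTensorLinear from the diff above is in scope) is to run torch.autograd.gradcheck on plain float64 tensors, where the explicit gradients should match the numerically estimated ones.

import torch

# Hypothetical check: gradcheck compares the explicit backward above against
# numerical gradients; float64 keeps the comparison numerically stable.
inputs = (
    torch.randn(8, 16, dtype=torch.float64, requires_grad=True),   # input
    torch.randn(32, 16, dtype=torch.float64, requires_grad=True),  # other (out_features, in_features)
    torch.randn(32, dtype=torch.float64, requires_grad=True),      # bias
)
assert torch.autograd.gradcheck(QTensorLinear.apply, inputs)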