Integrate C++ kernels for 4-bit & 2-bit MatMul #113
Conversation
Looks really really good ... I left a few comments.
I am wondering how we can pass the qlinear tests though, since I thought the packed weights would be transposed when passed to mm, and their axis would be 1, which is not supported yet ...
@@ -41,3 +42,16 @@ def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.Te
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
    return ext().unpack(t, bits)

@torch.library.impl("quanto_ext::udqmm", ["CPU", "CUDA"])
This is where we will need the other pull request to also support MPS devices.
quanto/library/ext/cpp/udqmm.cpp
Outdated
torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor& scale, torch::Tensor& zeropoint, int axis, int bits, torch::IntArrayRef orig_shape) {
    torch::Tensor unpacked_weights = unpack(weights, bits);
    torch::Tensor dq_output = (unpacked_weights.to(torch::kInt8) - zeropoint.to(torch::kInt8)) * scale;
Ok to cast the unpacked_weights since they are uint8 (we could also change the unpack method to do the cast there, making it part of the contract). We should only assert and err on zeropoint though, using TORCH_CHECK.
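For illustration, a minimal Python sketch of that contract (the helper name is hypothetical; in the C++ kernel this would be a TORCH_CHECK on the zeropoint dtype):

import torch

def check_zeropoint_dtype(zeropoint: torch.Tensor) -> None:
    # Hypothetical helper: do not silently cast the zeropoint; if it is not
    # int8 the caller violated the contract, so fail loudly instead.
    if zeropoint.dtype != torch.int8:
        raise TypeError(f"zeropoint must be int8, got {zeropoint.dtype}")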
It will probably be cleaner to do the cast in unpack, so that unpacked_weights and zeropoint have the same dtype! LMK what you think. If yes, should we do that in another PR?
I thought about that a little bit and I am inclined to keep returning an unsigned int. The unpacked weights really are unsigned integers in the range [0, 2**bits - 1], and the only reason for the cast is the subtraction of a signed int.
Now regarding the cast, I thought we could simply drop it, but I realized pytorch does an implicit type promotion to int16 because of the subtraction:
>>> a = torch.tensor(255, dtype=torch.uint8).to('mps')
>>> b = torch.tensor(1, dtype=torch.int8).to('mps')
>>> a - b
tensor(254, device='mps:0', dtype=torch.int16)
The reason why pytorch is doing that is probably because the first term might become negative, and change the result:
>>> a.to(torch.int8) - b
tensor(-2, device='mps:0', dtype=torch.int8)
In our case we know the first term stays non-negative and fits in int8, so the cast is safe and avoids materializing a large int16 intermediate tensor.
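A quick check of this reasoning on CPU (illustrative values; for 4-bit weights the unpacked values are at most 15, so the uint8 to int8 cast is lossless):
>>> a = torch.tensor(15, dtype=torch.uint8)
>>> b = torch.tensor(1, dtype=torch.int8)
>>> (a - b).dtype, (a.to(torch.int8) - b).dtype
(torch.int16, torch.int8)
>>> (a - b).item(), (a.to(torch.int8) - b).item()
(14, 14)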
Sounds good, thanks for explaining !
quanto/library/ext/cpp/udqmm.cpp
Outdated
torch::Tensor ungrouped_output;

// Ungroup TODO : put on its own function
Definitely. And this should be tested independently like unpack.
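For reference, a rough Python sketch of what a standalone, independently testable ungroup could look like. The plain-reshape assumption and the names are purely illustrative, not quanto's actual group/ungroup contract:

import torch

def ungroup_sketch(grouped: torch.Tensor, axis: int, orig_shape) -> torch.Tensor:
    # Illustrative only: assumes grouping reshaped the tensor into
    # (num_groups, group_size) rows along axis 0.
    if axis != 0:
        raise NotImplementedError("only axis=0 is sketched here")
    return grouped.reshape(orig_shape)

# Round-trip test in the same spirit as the unpack tests:
t = torch.randint(0, 16, (64, 32), dtype=torch.uint8)
grouped = t.reshape(-1, 8)  # pretend group_size is 8
assert torch.equal(ungroup_sketch(grouped, 0, t.shape), t)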
quanto/library/python/udqmm.py
Outdated
):
    unpacked_weights = torch.ops.quanto.unpack(weights, bits)
    shifted_weights = unpacked_weights.to(torch.int8) - zeropoint.to(torch.int8)
    scaled_weights = shifted_weights * scale
FYI, this will actually promote shifted_weights to float16 and then do the multiplication. Since I have had issues with this kind of implicit operation, maybe we could do it explicitly:
scaled_weights = shifted_weights.to(scale.dtype) * scale
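A quick check of the implicit promotion being described (illustrative values):
>>> w = torch.tensor([2], dtype=torch.int8)
>>> s = torch.tensor([0.5], dtype=torch.float16)
>>> (w * s).dtype
torch.float16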
Good to know !
bench/library/benchmark.py
Outdated
# torch.float32, torch.int8, False, device
# ),
# "unpack_2bit": lambda device: get_unpack_bench(2, device),
# "unpack_4bit": lambda device: get_unpack_bench(4, device),
You can pass just the bench you want as a parameter to the script ... just sayin' ...
Haha, indeed !
quanto/library/python/udqmm.py
Outdated
def udqmm(
    input: torch.Tensor,
    weights: torch.Tensor,
    scale: torch.Tensor,
    zeropoint: torch.Tensor,
    axis: int,
    bits: int,
    orig_shape: torch.Size,
    unpacked_shape: torch.Size,
) -> torch.Tensor:
    # we transpose it back, so it is simpler to unpack since we have the packed + transposed weights
    weights = weights.transpose(0, 1)
    unpacked_weights = torch.ops.quanto.unpack(weights, bits)
    # TODO : we should probably add that in unpack with an unpacked_shape arg.
    # Depends on whether the weights have been transposed or not:
    # if not transposed, we need to do unpacked_weights[: unpacked_shape[0]]
    unpacked_weights = unpacked_weights[: unpacked_shape[1]]
    # transpose back
    unpacked_weights = unpacked_weights.transpose(0, 1)
When we call the quantized linear layer, the QBitsTensor gets transposed and we end up with transposed weights in the udqmm function. This is quite troublesome for the unpacking step, since we don't have a variable (e.g. the axis used for the ungroup) indicating when we should transpose the weights again to unpack them correctly. Do you have any solution to fix this @dacorvo ? We could maybe add an attribute in PackedTensor that tells us whether the data was transposed, so that we know how to unpack it.
This is the reason why the test_mm.py tests are failing. For test_qlinear.py (where we have transposed weights), there is no issue because of the fix I did above.
When the weights are transposed, the axis is also modified, so all the information should be available to the caller. I am not sure I see what the issue is here, because the flow has already been validated in python.
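To illustrate the point with a hypothetical sketch (this is not quanto's actual QBitsTensor implementation): a wrapper can flip its quantization axis whenever it is transposed, so the orientation is never lost:

import torch

class TrackedQWeights:
    def __init__(self, packed: torch.Tensor, axis: int):
        self.packed = packed
        self.axis = axis  # quantization axis of the 2D weights, 0 or 1

    def t(self) -> "TrackedQWeights":
        # Transposing a 2D tensor swaps its dimensions, so the
        # quantization axis flips between 0 and 1 accordingly.
        return TrackedQWeights(self.packed.t(), 1 - self.axis)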
The issue is that the packing/unpacking do not depend on the axis.
Before reviewing, can you rebase your branch and clean up the commits that are related to one another (like a fix to another commit)?
Rebase on main: the last commit from main should become the base of your branch. You will need to solve conflicts (because at some point you merged an obsolete branch).
Interactive rebase to clean up redundant commits: this will open an editor with the list of commits on your branch. Edit the command on the left of each commit to choose what to do with it. An obvious example based on your commits: ...
force-pushed from 9aa908d to c579e4a
Hi @dacorvo, I've addressed the issue about the packing/unpacking and cleaned up the commits!
force-pushed from 409bd96 to 0611a76
That's a crappy kernel anyway. I will remove it.
As discussed, the results are quite disappointing, especially on CUDA. I am closing this but keeping the branch.
Hey there @SunMarc @dacorvo, I'm seeing that this PR was closed and wanted to understand how it relates to #107. I must have missed why the results were disappointing? Should someone in a position like mine scrap #107 and pivot to a different issue in the repo? Thank you in advance 😊 I just want to be as helpful as possible!
What does this PR do ?
This PR implements a C++ udqmm kernel for 4-bit and 2-bit MatMul. The kernel performs the following steps: unpack + dequantize + mm, where mm is the only parallelized operation for now. We also implement the C++ version of ungroup.
cc @SunMarc @dacorvo
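For readers skimming the thread, here is a simplified Python sketch of that flow, mirroring the Python reference quoted above together with the explicit cast suggested in review; the ungroup/reshape handling and the exact operator signature are omitted:

import torch

def udqmm_reference(input: torch.Tensor, weights: torch.Tensor, scale: torch.Tensor,
                    zeropoint: torch.Tensor, bits: int) -> torch.Tensor:
    # 1. unpack the packed low-bit values back to uint8
    unpacked = torch.ops.quanto.unpack(weights, bits)
    # 2. dequantize: shift by the zero-point, then rescale
    shifted = unpacked.to(torch.int8) - zeropoint.to(torch.int8)
    dequantized = shifted.to(scale.dtype) * scale
    # 3. plain matmul on the dequantized weights
    return torch.mm(input, dequantized)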