Integrate C++ kernels for 4-bit & 2-bit MatMul #113

Closed

Conversation

@younesbelkada (Contributor) commented Mar 7, 2024

What does this PR do?

This PR implements a C++ udqmm kernel for 4-bit and 2-bit MatMul. In this kernel we do the following steps: unpack + dequantize + mm, where mm is the only parallelized operation for now. We also implement the C++ version of ungroup.
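
In Python terms, a rough pure-PyTorch reference of that flow (illustrative only; the name udqmm_reference is not from the code base, and the ungroup / reshape-to-original-shape steps are omitted):

import torch

def udqmm_reference(input, packed_weights, scale, zeropoint, bits):
    # 1. unpack the 4-bit / 2-bit values stored in the packed uint8 tensor
    unpacked = torch.ops.quanto.unpack(packed_weights, bits)
    # 2. dequantize: shift by the zero-point and rescale to the float dtype
    shifted = unpacked.to(torch.int8) - zeropoint.to(torch.int8)
    dequantized = shifted.to(scale.dtype) * scale
    # 3. matmul, the only step parallelized by the C++ kernel for now
    return torch.mm(input, dequantized)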

cc @SunMarc @dacorvo

@dacorvo (Collaborator) left a comment:

Looks really really good ... I left a few comments.
I am wondering how we can pass the qlinear tests though, since I thought the packed weights would be transposed when passed to mm and their axis would be 1, which is not supported yet ...

@@ -41,3 +42,16 @@ def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.Te
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)


@torch.library.impl("quanto_ext::udqmm", ["CPU", "CUDA"])
Collaborator:

This is where we will need the other pull request to also support MPS devices.
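
For reference, a rough sketch of what that registration could look like once an MPS build of the extension exists (the "MPS" dispatch key and the MPS support in ext() are assumptions, mirroring the existing CPU/CUDA registration above):

@torch.library.impl("quanto_ext::udqmm", ["CPU", "CUDA", "MPS"])
def udqmm_cpp(input, weights, scale, zeropoint, axis, bits, orig_shape):
    # assumes the compiled extension also ships an MPS kernel for udqmm
    return ext().udqmm(input, weights, scale, zeropoint, axis, bits, orig_shape)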


torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor& scale, torch::Tensor& zeropoint, int axis, int bits, torch::IntArrayRef orig_shape) {
torch::Tensor unpacked_weights = unpack(weights, bits);
torch::Tensor dq_output = (unpacked_weights.to(torch::kInt8) - zeropoint.to(torch::kInt8)) * scale;
Collaborator:

Ok to cast the unpacked_weights since they are uint8 (we could also change the unpack method to do the cast there, making it part of the contract). We should only assert and err on zeropoint though, using TORCH_CHECK.

Member:

It will probably be cleaner to do the cast in the unpack, so that we have the same dtype for unpacked_weights and zeropoint! LMK what you think. If yes, should we do that in another PR?

@dacorvo (Collaborator) commented Mar 8, 2024:

I thought about that a little bit and I am inclined to keep returning an unsigned int.
The unpacked weights really are unsigned integers in the range [0, 2**bits - 1], and the
only reason for the cast is the subtraction of a signed int.
Now regarding the cast, I thought we could simply drop it, but I realized pytorch is doing an implicit type promotion to int16 because of the subtraction:

>>> a = torch.tensor(255, dtype=torch.uint8).to('mps')
>>> b = torch.tensor(1, dtype=torch.int8).to('mps')
>>> a - b
tensor(254, device='mps:0', dtype=torch.int16)

Collaborator:

The reason why pytorch is doing that is probably because the first term might become negative, and change the result:

>>> a.to(torch.int8) - b
tensor(-2, device='mps:0', dtype=torch.int8)

In our case we know the first term will stay positive during the cast, so it is safe, and avoids the materialization of a large int16 intermediate tensor.
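
The same behaviour can be checked quickly on CPU (a sketch; the promotion rules are device-independent):

import torch

a = torch.tensor(255, dtype=torch.uint8)
b = torch.tensor(1, dtype=torch.int8)

print((a - b).dtype)         # torch.int16: implicit promotion
print(a.to(torch.int8) - b)  # tensor(-2, dtype=torch.int8): wraps when the value does not fit in int8
# With 2-bit / 4-bit weights the unpacked values are at most 15, so the cast
# to int8 cannot wrap and the int16 intermediate is avoided.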

Member:

Sounds good, thanks for explaining!


torch::Tensor ungrouped_output;

// Ungroup. TODO: move this to its own function
Collaborator:
Definitely. And this should be tested independently like unpack.

):
unpacked_weights = torch.ops.quanto.unpack(weights, bits)
shifted_weights = unpacked_weights.to(torch.int8) - zeropoint.to(torch.int8)
scaled_weights = shifted_weights * scale
Collaborator:

FYI, this will actually promote shifted_weights to float16 and then do the multiplication.
Since I have had issues with this kind of implicit operation, maybe we could do it explicitly:

scaled_weights = shifted_weights.to(scale.dtype) * scale
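
A quick check of that promotion (a sketch; the values are arbitrary):

import torch

shifted_weights = torch.randint(-8, 8, (4,), dtype=torch.int8)
scale = torch.tensor(0.5, dtype=torch.float16)

implicit = shifted_weights * scale                    # promoted to float16 implicitly
explicit = shifted_weights.to(scale.dtype) * scale    # same result, promotion made explicit

assert implicit.dtype == explicit.dtype == torch.float16
assert torch.equal(implicit, explicit)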

Member:

Good to know!

test/library/test_mm.py (outdated review comment, resolved)
# torch.float32, torch.int8, False, device
# ),
# "unpack_2bit": lambda device: get_unpack_bench(2, device),
# "unpack_4bit": lambda device: get_unpack_bench(4, device),
Collaborator:

You can pass just the bench you want as a parameter to the script ... just sayin ...

Member:

Haha, indeed!

@SunMarc marked this pull request as ready for review March 12, 2024 20:50
@SunMarc changed the title from "[DRAFT] Integrate C++ kernels for 4-bit & 2-bit MatMul" to "Integrate C++ kernels for 4-bit & 2-bit MatMul" Mar 12, 2024
Comment on lines 5 to 23
def udqmm(
input: torch.Tensor,
weights: torch.Tensor,
scale: torch.Tensor,
zeropoint: torch.Tensor,
axis: int,
bits: int,
orig_shape: torch.Size,
unpacked_shape: torch.Size,
) -> torch.Tensor:
# we transpose it back, so it is simpler to unpack, since we have the packed + transposed weights
weights = weights.transpose(0, 1)
unpacked_weights = torch.ops.quanto.unpack(weights, bits)
# TODO: we should probably add that in unpack with the arg unpacked_shape.
# It depends on whether the weights have been transposed or not:
# if not transposed, we need to do unpacked_weights[: unpacked_shape[0]]
unpacked_weights = unpacked_weights[: unpacked_shape[1]]
# transpose back
unpacked_weights = unpacked_weights.transpose(0, 1)
@SunMarc (Member) commented Mar 12, 2024:

When we call the quantized linear layer, the QBitsTensor tensor gets transposed and we end up with transposed weights in the udqmm function. This is quite troublesome for the unpacking step, since we don't have a variable (like axis for the ungroup) to indicate when we should transpose our weights again to unpack them correctly. Do you have any solution to fix this @dacorvo? We could maybe add an attribute to PackedTensor that tells us whether the data was transposed, so that we know how to unpack it.

This is the reason why the test_mm.py tests are failing. For test_qlinear.py (where we have transposed weights), there is no issue because of the fix I did above.

Collaborator:

When the weights are transposed, the axis is also modified, so all the information should be available to the caller. I am not sure I see what the issue is here, because the flow has been validated in python already.

@SunMarc (Member) commented Mar 13, 2024:

The issue is that the packing/unpacking do not depend on the axis.
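
To illustrate the point with a toy packing scheme (this is not quanto's actual layout, just a sketch of why a packed tensor only unpacks correctly along the dimension it was packed on, which is what the transpose-back workaround above deals with):

import torch

def pack_4bit(t):
    # toy scheme: rows 2k and 2k+1 are stored in one uint8 row (low / high nibble)
    return (t[::2] | (t[1::2] << 4)).to(torch.uint8)

def unpack_4bit(packed):
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=1).reshape(-1, packed.shape[1])

w = torch.randint(0, 16, (8, 4), dtype=torch.uint8)
packed = pack_4bit(w)

assert torch.equal(unpack_4bit(packed), w)
# unpack_4bit(packed.t()) would interleave the wrong dimension: unpacking never
# looks at axis, so the caller has to undo the transpose (or track it, e.g. with
# a flag on PackedTensor) before unpacking.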

@dacorvo (Collaborator) commented Mar 13, 2024:

Before reviewing, can you please rebase your branch and clean up the commits that are related to one another (like a fix to another commit)?

Rebase on main

The last commit from main you are using is 3bddd0a:

$ git rebase 3bddd0a --onto origin/main

You will need to resolve conflicts (because at some point you merged an obsolete main, which makes things more difficult here).

Interactive rebase to clean up redundant commits

$ git rebase -i origin/main

This will open an editor with the following list of commits:

pick 7671993 v1 - does not support group wise?
pick 5ea42cb add grouped weights support for axis=0
pick c34443b refactor everything
pick 1c16b33 add python impl
pick 61b49e9 add benchmark script and fix python impl
pick f0bfa31 add dispatch to `QBitsTensor` and update benchmark
pick 2a8846e fix zero-point and orig_shape
pick cbd1922 add slicing
pick 0c485c7 uncomment
pick 81dc58a feat(cpp): check tensor type in udqmm with scalar_type
pick bc33442 feat(ungroup): add cpp and python implementation
pick 5bf6e18 feat(ungroup): add missing python impl
pick 02481a3 feat(PackedTensor): implement torch.ops.aten.t
pick 15734ad feat(QBitsTensor): implement torch.ops.aten.t
pick d3d280f feat(udqmm): fix issue with unpacking with transposed tensor
pick d1046df feat(udqmm): fix naming

Edit the command on the left to choose between pick, squash or fixup for each commit (you can also reorder them).
The difference between squash and fixup is that when squashing you keep the message.
When done, save and let git do its magic.

An obvious example based on your commits:

pick bc33442 feat(ungroup): add cpp and python implementation
fixup 5bf6e18 feat(ungroup): add missing python impl
...

@SunMarc force-pushed the try-4bit-mm branch 3 times, most recently from 9aa908d to c579e4a on March 13, 2024 16:34
@SunMarc (Member) commented Mar 13, 2024:

Hi @dacorvo, I've addressed the issue with the packing/unpacking and cleaned up the commits!

@SunMarc requested a review from dacorvo March 13, 2024 16:35
@SunMarc force-pushed the try-4bit-mm branch 2 times, most recently from 409bd96 to 0611a76 on March 13, 2024 16:56
@SunMarc (Member) commented Mar 14, 2024:

I had the issue multiple times, but it seems that the following test is flaky: test_quantize_symmetric[cpu-per-axis-float8-fp16-matrix]. Sometimes it doesn't pass with Python 3.11.

That's a crappy kernel anyway. I will remove it.

@dacorvo (Collaborator) commented Mar 19, 2024:

As discussed, the results are quite disappointing, especially on CUDA. I am closing this but keeping the branch.

@dacorvo closed this Mar 19, 2024
@alejandroarmas commented:
Hey there @SunMarc @dacorvo, I see that this PR was closed and wanted to understand how it relates to #107. I must have missed why the results were disappointing. Should someone in my position scrap #107 and pivot to a different issue in the repo? Thank you in advance 😊 I just want to be as helpful as possible!
