Integrate C++ kernels for 4-bit & 2-bit MatMul #113

Closed · wants to merge 6 commits
62 changes: 57 additions & 5 deletions bench/library/benchmark.py
@@ -6,7 +6,38 @@
import torch
from tqdm.auto import tqdm

from quanto import group
from quanto.library import disable_extensions
from quanto.tensor.packed import pack_weights


def get_udqmm_bench(input_dtype, device, bits):
input = torch.rand([128, 1024], dtype=input_dtype).to(device)
weight = torch.randint(-127, 127, [1024, 1024], dtype=torch.int8).to(device)

orig_shape = weight.shape
    grouped_weights = group(weight, axis=0, group_size=orig_shape[-1] // 4)
scale = torch.ones((1, grouped_weights.shape[1]), dtype=input_dtype, device=device) * 0.5
zeropoint = torch.randint(
torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max, (1, grouped_weights.shape[1]), dtype=torch.int8
).to(device)

packed_weights = pack_weights(grouped_weights, bits)

def bench_fn():
return torch.ops.quanto.udqmm(
input,
packed_weights,
scale,
zeropoint,
axis=0,
bits=bits,
orig_shape=orig_shape,
unpacked_shape=grouped_weights.shape,
packed_axis=0,
)

return bench_fn


def get_dqmm_bench(input_dtype, device):
@@ -31,12 +62,28 @@ def bench_fn():
return bench_fn


def get_unpack_bench(bits, device):
def get_unpack_bench(bits, axis, device):
qmax = 2**bits
a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
packed_size = 10240
a = torch.randint(0, qmax, [packed_size, packed_size], dtype=torch.uint8).to(device)
n_packed = 8 // bits
    orig_shape = [packed_size, packed_size]
    # The unpacked tensor is n_packed times larger along the axis being
    # unpacked, so grow that axis to describe the original shape.
    orig_shape[axis] = orig_shape[axis] * n_packed

def bench_fn():
return torch.ops.quanto.unpack(a, bits, orig_shape, axis)

return bench_fn


def get_ungroup_bench(device):
qmax = 2**8
weights = torch.randint(0, qmax, (10240, 10240), dtype=torch.uint8).to(device)
grouped_weights = group(weights, axis=0, group_size=32)

def bench_fn():
return torch.ops.quanto.unpack(a, bits)
return torch.ops.quanto.ungroup(grouped_weights, axis=0, orig_shape=weights.shape)

return bench_fn

@@ -95,8 +142,13 @@ def elapsed_time(self, other):
"quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench(
torch.float32, torch.int8, False, device
),
"unpack_2bit": lambda device: get_unpack_bench(2, device),
"unpack_4bit": lambda device: get_unpack_bench(4, device),
"unpack_2bit_axis_0": lambda device: get_unpack_bench(2, 0, device),
"unpack_2bit_axis_1": lambda device: get_unpack_bench(2, 1, device),
"unpack_4bit_axis_0": lambda device: get_unpack_bench(4, 0, device),
"unpack_4bit_axis_1": lambda device: get_unpack_bench(4, 1, device),
"ungroup": lambda device: get_ungroup_bench(device),
"udqmm_2bit": lambda device: get_udqmm_bench(torch.float16, device, 2),
"udqmm_4bit": lambda device: get_udqmm_bench(torch.float16, device, 4),
}


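For reference, the fused op can be driven directly. A minimal sketch mirroring `get_udqmm_bench` above (float32 input and uint8 weights in the native 4-bit value range are assumptions of this sketch, not requirements stated in the diff):

import torch
from quanto import group
from quanto.tensor.packed import pack_weights

bits = 4
weight = torch.randint(0, 2**bits, (1024, 1024), dtype=torch.uint8)  # 4-bit values
orig_shape = weight.shape
grouped = group(weight, axis=0, group_size=orig_shape[-1] // 4)
packed = pack_weights(grouped, bits)

input = torch.rand(128, 1024, dtype=torch.float32)
scale = torch.full((1, grouped.shape[1]), 0.5, dtype=torch.float32)  # one scale per group
zeropoint = torch.full((1, grouped.shape[1]), 8, dtype=torch.int8)   # one zeropoint per group

# One fused call: unpack -> dequantize -> ungroup -> matmul.
output = torch.ops.quanto.udqmm(
    input, packed, scale, zeropoint,
    axis=0, bits=bits, orig_shape=orig_shape,
    unpacked_shape=grouped.shape, packed_axis=0,
)
print(output.shape)  # torch.Size([128, 1024])
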
26 changes: 24 additions & 2 deletions quanto/library/ext/cpp/__init__.py
@@ -19,8 +19,10 @@ def ext():
name="quanto_cpp",
sources=[
f"{module_path}/mm.cpp",
f"{module_path}/udqmm.cpp",
f"{module_path}/quantize.cpp",
f"{module_path}/unpack.cpp",
f"{module_path}/ungroup.cpp",
f"{module_path}/pybind_module.cpp",
],
extra_cflags=["-O3"],
@@ -39,5 +41,25 @@ def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.Te


@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack_cpp(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
def unpack_cpp(t: torch.Tensor, bits: int, orig_shape: torch.Size, axis: int):
return ext().unpack(t, bits, orig_shape, axis)


@torch.library.impl("quanto_ext::ungroup", ["CPU", "CUDA", "MPS"])
def ungroup_cpp(grouped: torch.Tensor, axis: int, orig_shape: torch.Size):
return ext().ungroup(grouped, axis, orig_shape)


@torch.library.impl("quanto_ext::udqmm", ["CPU", "CUDA"])

Collaborator comment: This is where we will need the other pull request to also support MPS devices.

def udqmm_cpp(
input: torch.Tensor,
weights: torch.Tensor,
scale: torch.Tensor,
zeropoint: torch.Tensor,
axis: int,
bits: int,
orig_shape: torch.Size,
unpacked_shape: torch.Size,
packed_axis: int,
):
return ext().udqmm(input, weights, scale, zeropoint, axis, bits, orig_shape, unpacked_shape, packed_axis)
4 changes: 4 additions & 0 deletions quanto/library/ext/cpp/pybind_module.cpp
@@ -1,7 +1,9 @@
#include <torch/extension.h>
#include "mm.h"
#include "quantize.h"
#include "udqmm.h"
#include "unpack.h"
#include "ungroup.h"

// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
@@ -12,11 +14,13 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dqmm", &dqmm, "dqmm");
m.def("udqmm", &udqmm, "udqmm");
m.def("quantize_symmetric",
[](const torch::Tensor& t, const torch::Tensor& scale, py::object dtype) {
return quantize_symmetric(t,
scale,
torch::python::detail::py_object_to_dtype(dtype));
}, "quantize_symmetric");
m.def("unpack", &unpack, "unpack");
m.def("ungroup", &ungroup, "ungroup");
}
19 changes: 19 additions & 0 deletions quanto/library/ext/cpp/udqmm.cpp
@@ -0,0 +1,19 @@
#include "udqmm.h"
#include "unpack.h"
#include "ungroup.h"

#include <torch/extension.h>

torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor &scale, torch::Tensor &zeropoint, int axis, int bits, torch::IntArrayRef orig_shape, torch::IntArrayRef unpacked_shape, int packed_axis) {
    TORCH_CHECK(zeropoint.scalar_type() == torch::kInt8, "zeropoint must have scalar type: torch.int8");
    // Restore the grouped low-bit weights from their uint8 packing.
    torch::Tensor unpacked_weights = unpack(weights, bits, unpacked_shape, packed_axis);

    // Dequantize: shift by the per-group zeropoint, then apply the per-group scale.
    torch::Tensor dq_output = (unpacked_weights.to(torch::kInt8) - zeropoint).to(scale.dtype()) * scale;

    // Undo the grouping to recover the original weight layout.
    torch::Tensor ungrouped_output = ungroup(dq_output, axis, orig_shape);

return torch::mm(input, ungrouped_output);
}
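In symbols, the kernel is a restatement of the Python reference further down: with W_q the unpacked grouped weights, z the per-group zeropoint, s the per-group scale, and ⊙ broadcast multiplication,

Y = X · ungroup((W_q − z) ⊙ s)

so the only work beyond a plain matmul is an element-wise shift-and-scale plus a reshape.
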
3 changes: 3 additions & 0 deletions quanto/library/ext/cpp/udqmm.h
@@ -0,0 +1,3 @@
#include <torch/extension.h>

torch::Tensor udqmm(torch::Tensor &input, torch::Tensor &weights, torch::Tensor &scale, torch::Tensor &zeropoint, int axis, int bits, torch::IntArrayRef orig_shape, torch::IntArrayRef unpacked_shape, int packed_axis);
22 changes: 22 additions & 0 deletions quanto/library/ext/cpp/ungroup.cpp
@@ -0,0 +1,22 @@
#include "ungroup.h"
#include <torch/extension.h>

torch::Tensor ungroup(torch::Tensor &grouped, int axis, torch::IntArrayRef orig_shape){
if (grouped.sizes() == orig_shape){
return grouped;
}
if (axis == 0) {
return torch::reshape(grouped, orig_shape);
}
int64_t group_size = (axis == -1) ? grouped.size(0) : grouped.size(-1);
int64_t axis_dim = (axis == -1) ? orig_shape.back() : orig_shape[axis];
// Calculate the number of groups per axis
int64_t groups_per_axis = grouped.numel() / axis_dim / group_size;

torch::Tensor ungrouped = grouped.reshape({group_size, axis_dim, groups_per_axis});
ungrouped = ungrouped.transpose(1, 2);
ungrouped = ungrouped.transpose(0, 1);

// Reshape to the original shape
return ungrouped.reshape(orig_shape);
}
3 changes: 3 additions & 0 deletions quanto/library/ext/cpp/ungroup.h
@@ -0,0 +1,3 @@
#include <torch/extension.h>

torch::Tensor ungroup(torch::Tensor &grouped, int axis, torch::IntArrayRef orig_shape);
26 changes: 17 additions & 9 deletions quanto/library/ext/cpp/unpack.cpp
@@ -2,31 +2,39 @@
#include <torch/extension.h>


static torch::Tensor unpack_4bit(torch::Tensor &t) {
static torch::Tensor unpack_4bit(torch::Tensor &t, int axis) {
return torch::cat({
(t & 0x0F),
(t & 0xF0).__rshift__(4)
},
0);
axis);
}

static torch::Tensor unpack_2bit(torch::Tensor &t) {
static torch::Tensor unpack_2bit(torch::Tensor &t, int axis) {
return torch::cat({
(t & 0x03),
(t & 0x0C).__rshift__(2),
(t & 0x30).__rshift__(4),
(t & 0xC0).__rshift__(6)
},
0);
axis);
}

torch::Tensor unpack(torch::Tensor &t, int bits) {
static torch::Tensor slice_along_axis(torch::Tensor& t, torch::IntArrayRef orig_shape, int axis) {
return t.slice(axis, 0, orig_shape[axis]);
}

torch::Tensor unpack(torch::Tensor &t, int bits, torch::IntArrayRef orig_shape, int axis) {
TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
switch(bits) {
case 4:
return unpack_4bit(t);
case 2:
return unpack_2bit(t);
case 4: {
auto output = unpack_4bit(t, axis);
return slice_along_axis(output, orig_shape, axis);
}
case 2: {
auto output = unpack_2bit(t, axis);
return slice_along_axis(output, orig_shape, axis);
}
default:
throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
}
2 changes: 1 addition & 1 deletion quanto/library/ext/cpp/unpack.h
@@ -1,3 +1,3 @@
#include <torch/extension.h>

torch::Tensor unpack(torch::Tensor &t, int bits);
torch::Tensor unpack(torch::Tensor &t, int bits, torch::IntArrayRef orig_shape, int axis);
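The new `orig_shape`/`axis` arguments let `unpack` trim the concatenated output back to the caller's original shape. A minimal round-trip sketch, assuming `pack_weights` packs along axis 0 as the benchmark's `packed_axis=0` suggests:

import torch
from quanto.tensor.packed import pack_weights

w = torch.randint(0, 4, (12, 16), dtype=torch.uint8)  # 2-bit values
packed = pack_weights(w, 2)                           # 4 values per byte -> (3, 16)
unpacked = torch.ops.quanto.unpack(packed, 2, w.shape, 0)
assert unpacked.shape == w.shape  # concat gives (12, 16); the slice trims any padding
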
4 changes: 2 additions & 2 deletions quanto/library/ext/mps/__init__.py
@@ -25,5 +25,5 @@ def ext():


@impl("quanto_ext::unpack", "MPS")
def unpack_mps(t: torch.Tensor, bits: int):
return ext().unpack(t, bits)
def unpack_mps(t: torch.Tensor, bits: int, orig_shape: torch.Size, axis: int):
return ext().unpack(t, bits, orig_shape, axis)
2 changes: 1 addition & 1 deletion quanto/library/ext/mps/unpack.h
@@ -1,3 +1,3 @@
#include <torch/extension.h>

torch::Tensor unpack(const torch::Tensor &input, int bits);
torch::Tensor unpack(const torch::Tensor &input, int bits, torch::IntArrayRef orig_shape, int axis);
26 changes: 17 additions & 9 deletions quanto/library/ext/mps/unpack.mm
@@ -93,16 +93,16 @@ kernel void mask_and_rshift(constant uint8_t* input [[buffer(0)]],
return output;
}

torch::Tensor unpack_4bit(const torch::Tensor &input) {
torch::Tensor unpack_4bit(const torch::Tensor &input, int axis) {

torch::Tensor output = torch::empty_like(input);
mask_and_shift(input, output, 0x0F, 0);
torch::Tensor output1 = torch::empty_like(input);
mask_and_shift(input, output1, 0xF0, 4);
return torch::cat({output, output1}, 0);
return torch::cat({output, output1}, axis);
}

torch::Tensor unpack_2bit(const torch::Tensor &input) {
torch::Tensor unpack_2bit(const torch::Tensor &input, int axis) {

torch::Tensor output = torch::empty_like(input);
mask_and_shift(input, output, 0x03, 0);
@@ -112,11 +112,15 @@ kernel void mask_and_rshift(constant uint8_t* input [[buffer(0)]],
mask_and_shift(input, output2, 0x30, 4);
torch::Tensor output3 = torch::empty_like(input);
mask_and_shift(input, output3, 0xC0, 6);
return torch::cat({output, output1, output2, output3}, 0);
return torch::cat({output, output1, output2, output3}, axis);
}

static torch::Tensor slice_along_axis(torch::Tensor& t, torch::IntArrayRef orig_shape, int axis) {
return t.slice(axis, 0, orig_shape[axis]);
}

// C++ op dispatching the Metal unpack operation.
torch::Tensor unpack(const torch::Tensor &input, int bits) {
torch::Tensor unpack(const torch::Tensor &input, int bits, torch::IntArrayRef orig_shape, int axis) {
// Check whether the input tensor resides on the MPS device and whether it's contiguous.
TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
@@ -125,10 +129,14 @@ kernel void mask_and_rshift(constant uint8_t* input [[buffer(0)]],
TORCH_CHECK(input.scalar_type() == torch::kUInt8, "Unsupported data type: ", input.scalar_type());

switch(bits) {
case 4:
return unpack_4bit(input);
case 2:
return unpack_2bit(input);
case 4: {
auto output = unpack_4bit(input, axis);
return slice_along_axis(output, orig_shape, axis);
}
case 2: {
auto output = unpack_2bit(input, axis);
return slice_along_axis(output, orig_shape, axis);
}
default:
throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
}
7 changes: 6 additions & 1 deletion quanto/library/ops.py
@@ -55,4 +55,9 @@

define("dqmm", "(Tensor input, Tensor other, Tensor other_scale) -> Tensor")
define("quantize_symmetric", "(Tensor self, Tensor scale, ScalarType dtype) -> Tensor")
define("unpack", "(Tensor self, int bits) -> Tensor")
define("unpack", "(Tensor self, int bits, Any orig_shape, int axis) -> Tensor")
define("ungroup", "(Tensor grouped, int axis, Any orig_shape) -> Tensor")
define(
"udqmm",
"(Tensor input, Tensor weight, Tensor scales, Tensor zeropoint, int axis, int bits, Any orig_shape, Any unpacked_shape, int packed_axis) -> Tensor",
)
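Each op keeps a pure-Python reference implementation under `quanto_py::` (see below), so a compiled kernel can be checked against it with `disable_extensions`, which the benchmark script imports above. A sketch, assuming `disable_extensions` acts as a context manager:

import torch
from quanto.library import disable_extensions

packed = torch.randint(0, 256, (8, 16), dtype=torch.uint8)
orig_shape = (32, 16)  # 2-bit: four values per byte along axis 0

fast = torch.ops.quanto.unpack(packed, 2, orig_shape, 0)
with disable_extensions():  # falls back to the quanto_py implementation
    ref = torch.ops.quanto.unpack(packed, 2, orig_shape, 0)
assert torch.equal(fast, ref)
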
2 changes: 2 additions & 0 deletions quanto/library/python/__init__.py
@@ -1,3 +1,5 @@
from .mm import *
from .quantize import *
from .udqmm import *
from .ungroup import *
from .unpack import *
20 changes: 20 additions & 0 deletions quanto/library/python/udqmm.py
@@ -0,0 +1,20 @@
import torch


@torch.library.impl("quanto_py::udqmm", "default")
def udqmm(
input: torch.Tensor,
weights: torch.Tensor,
scale: torch.Tensor,
zeropoint: torch.Tensor,
axis: int,
bits: int,
orig_shape: torch.Size,
unpacked_shape: torch.Size,
packed_axis: int,
) -> torch.Tensor:
unpacked_weights = torch.ops.quanto.unpack(weights, bits, unpacked_shape, packed_axis)
shifted_weights = unpacked_weights.to(torch.int8) - zeropoint
scaled_weights = shifted_weights.to(scale.dtype) * scale
ungrouped_weights = torch.ops.quanto.ungroup(scaled_weights, axis, orig_shape)
return torch.ops.aten.mm(input, ungrouped_weights)
18 changes: 18 additions & 0 deletions quanto/library/python/ungroup.py
@@ -0,0 +1,18 @@
import torch


@torch.library.impl("quanto_py::ungroup", "default")
def ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size) -> torch.Tensor:
if grouped.shape == orig_shape:
return grouped
if axis == 0:
# No transposition required, just reshape
return grouped.reshape(orig_shape)
group_size = grouped.shape[0] if axis == -1 else grouped.shape[-1]
axis_dim = orig_shape[axis]
groups_per_axis = grouped.numel() // axis_dim // group_size
ungrouped = grouped.reshape(group_size, axis_dim, groups_per_axis)
    # A dual transposition is required to reorder to (groups_per_axis, group_size, axis_dim)
ungrouped = ungrouped.transpose(1, 2)
ungrouped = ungrouped.transpose(0, 1)
return ungrouped.reshape(orig_shape)
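
Since `ungroup` is meant to invert `quanto.group`, a quick round trip (axis=0, matching the benchmark's usage) illustrates the contract:

import torch
from quanto import group

weights = torch.randint(0, 256, (64, 64), dtype=torch.uint8)
grouped = group(weights, axis=0, group_size=32)  # -> (32, 128): one column per group
restored = torch.ops.quanto.ungroup(grouped, axis=0, orig_shape=weights.shape)
assert torch.equal(restored, weights)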