
Commit

Remove debugging-related stuff
ahadnagy committed Jan 6, 2025
1 parent 02127ca commit d2eba7f
Showing 3 changed files with 35 additions and 129 deletions.
3 changes: 0 additions & 3 deletions optimum/quanto/library/extensions/cuda/__init__.py
@@ -48,8 +48,6 @@ def get_max_cuda_arch():
 extra_cuda_cflags = [
     "--expt-extended-lambda",
     "--use_fast_math",
-    "-lineinfo",
-    '-O0'
 ]
 # We need to know the minimum CUDA Arch to select only the relevant kernels
 # but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
@@ -189,7 +187,6 @@ def gemm_f16i4_marlin(
         dtype=input.dtype,
         device=input.device,
     )
-    print(f"input shapes: {input.reshape((-1, input.shape[-1])).shape}, in2: {other.shape}, out: {output.reshape((-1, output.shape[-1])).shape}")
     ext.lib.marlin_gemm_f16i4(
         input.reshape((-1, input.shape[-1])),
         other,
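For context: "-lineinfo" embeds source-line information for profilers such as Nsight Compute, and "-O0" disables nvcc optimizations; both are debugging aids rather than production flags. Below is a minimal sketch, assuming a JIT-built CUDA extension, of how such flags are typically passed through torch.utils.cpp_extension.load; the module name and source list are illustrative, not this repository's exact build code.

# Minimal sketch (assumed, not this repository's exact build code): passing nvcc
# flags when JIT-building a CUDA extension with torch.utils.cpp_extension.load.
from torch.utils.cpp_extension import load

quanto_cuda = load(
    name="quanto_cuda_debug",                    # hypothetical module name
    sources=["unpack.cu", "pybind_module.cpp"],  # hypothetical source list
    extra_cuda_cflags=[
        "--expt-extended-lambda",
        "--use_fast_math",
        # Debug-only flags of the kind removed by this commit:
        # "-lineinfo",  # keep source-line info for profilers (e.g. Nsight Compute)
        # "-O0",        # disable nvcc optimizations to ease stepping through kernels
    ],
    verbose=True,
)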
156 changes: 33 additions & 123 deletions test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
@@ -12,139 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+import numpy as np
 import pytest
 import torch
-from helpers import device_eq, random_qweight
-from tensor.weights.weight_helpers import check_weight_qtensor_linear
+from helpers import device_eq
 
+from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor
 
 
-from optimum.quanto import qint4
-from optimum.quanto.library.extensions import is_extension_available
-from optimum.quanto.tensor.weights import WeightQBitsTensor
-from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor
+def get_uint4_tensor(shape, device, random=False):
+    qmax = 2**4
+    if random:
+        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
+    else:
+        numel = np.prod(shape)
+        t = torch.tensor(range(numel), dtype=torch.int32)
+        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
+    return t
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available"
-)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
-def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features):
-    qtype = qint4
-    group_size = 128
-    dtype = torch.float16
+@pytest.mark.parametrize("random", [True, False])
+def test_pack_marlin_int4_tensor(in_features, out_features, random):
     shape = (out_features, in_features)
     device = torch.device("cuda")
-    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
-    # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members
-    marlinqbt = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    assert marlinqbt.dtype == dtype
-    assert marlinqbt.qtype == qtype
-    assert marlinqbt.shape == shape
-    assert device_eq(marlinqbt.device, device)
-    # Verify the dequantized tensors are identical
-    assert torch.equal(marlinqbt.dequantize(), qbt.dequantize())
-    # Now verify that we can reconstruct the WeightQBitsTensor
-    new_qbt = marlinqbt.weight_qbits_tensor()
-    assert type(new_qbt) is WeightQBitsTensor
-    assert new_qbt.dtype == dtype
-    assert new_qbt.qtype == qtype
-    assert new_qbt.shape == shape
-    assert torch.equal(new_qbt._data, qbt._data)
-    assert torch.equal(new_qbt._scale, qbt._scale)
-    assert torch.equal(new_qbt._shift, qbt._shift)
+    t = get_uint4_tensor(shape, device, random)
+    packed = MarlinInt4PackedTensor.pack(t)
+    assert isinstance(packed, MarlinInt4PackedTensor)
+    assert device_eq(packed.device, device)
+    assert torch.equal(t, packed.unpack())
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_marlin_int4_weight_qbits_tensor_move(device):
-    qtype = qint4
-    group_size = 128
-    dtype = torch.float16
-    shape = (1024, 1024)
+def test_move_marlin_int4_packed_tensor(device):
+    shape = (256, 256)
     device = torch.device("cuda")
-    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
-    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda"))
-    marlinqbt = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    # Move to device, dequantize and compare
-    moved_qbt = marlinqbt.to(device)
-    assert isinstance(moved_qbt, WeightQBitsTensor)
-    if device.type != "cuda":
-        assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor
-    assert marlinqbt.dtype == moved_qbt.dtype
-    assert marlinqbt.qtype == moved_qbt.qtype
-    assert marlinqbt.shape == moved_qbt.shape
-    assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())
-
-
-def _test_marlin_int4_weight_qbits_tensor_linear(
-    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
-):
-    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
-    qbt = random_qweight(
-        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
-    )
-    marlin_qweight = MarlinInt4WeightQBitsTensor(
-        qtype=qbt.qtype,
-        axis=qbt.axis,
-        group_size=qbt._group_size,
-        size=qbt.size(),
-        stride=qbt.stride(),
-        data=qbt._data.unpack(),
-        scale=qbt._scale,
-        shift=qbt._shift,
-    )
-    check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias)
-
-
-@pytest.mark.skipif(
-    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
-    reason="CUDA >= sm80 not available",
-)
-@pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("tokens", [16, 32])
-@pytest.mark.parametrize("in_features", [1024])
-@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
-@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
-    dtype = torch.float16
-    weight_qtype = qint4
-    group_size = 128
-    _test_marlin_int4_weight_qbits_tensor_linear(
-        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
-    )
-
-
-#Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
-@pytest.mark.skipif(
-    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
-    reason="CUDA >= sm80 not available",
-)
-@pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.parametrize("tokens", [48, 64])
-# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
-@pytest.mark.parametrize("in_features", [4096, 16384])
-@pytest.mark.parametrize("out_features", [2048, 4096])
-def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
-    dtype = torch.float16
-    weight_qtype = qint4
-    group_size = 128
-    _test_marlin_int4_weight_qbits_tensor_linear(
-        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False
-    )
+    t = get_uint4_tensor(shape, device)
+    packed = MarlinInt4PackedTensor.pack(t)
+    moved = packed.to("cuda")
+    assert isinstance(moved, MarlinInt4PackedTensor)
+    # Marlin int4 tensors are unpacked when moved out of CUDA device
+    moved = packed.to("cpu")
+    assert type(moved) is torch.Tensor
+    assert torch.equal(t, moved.to("cuda"))
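The restored tests boil down to a pack/unpack round trip on a CUDA device. A short usage sketch of what they exercise, assuming CUDA is available and using the MarlinInt4PackedTensor import shown above (shape and values are arbitrary examples):

# Sketch of the round trip the restored tests check (assumes CUDA is available).
import torch

from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor

device = torch.device("cuda")
# uint4 values (0..15) stored in a uint8 tensor, as get_uint4_tensor above produces
t = torch.randint(0, 16, (256, 256), dtype=torch.uint8, device=device)

packed = MarlinInt4PackedTensor.pack(t)  # repack into the Marlin int4 kernel layout
assert torch.equal(packed.unpack(), t)   # unpacking restores the original values

# Per the move test above, leaving the CUDA device falls back to a plain tensor.
cpu_copy = packed.to("cpu")
assert type(cpu_copy) is torch.Tensor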
5 changes: 2 additions & 3 deletions test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
@@ -131,15 +131,14 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
     )
 
 
-@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
+#Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
 @pytest.mark.skipif(
     not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
     reason="CUDA >= sm80 not available",
 )
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [48, 64])
-# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
-@pytest.mark.parametrize("in_features", [4096, 16384])
+@pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
 @pytest.mark.parametrize("out_features", [2048, 4096])
 def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
     dtype = torch.float16
