diff --git a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
index a2a87a4b..bd685715 100644
--- a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
+++ b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
@@ -728,7 +728,7 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 164 * 1024; // max shared memory on compute capability 8.0
+const int SHARED_MEM = 128 * 1024; // conservative budget; the maximum on compute capability 8.0 is 164 KB
 
 // ADDED: add scaled zero pointer
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
diff --git a/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py b/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
index 4048a3ea..185dc26b 100644
--- a/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
+++ b/test/tensor/weights/optimized/test_marlin_int4_packed_tensor.py
@@ -12,49 +12,139 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import numpy as np
 import pytest
 import torch
-from helpers import device_eq
-
-from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor
-
+from helpers import device_eq, random_qweight
+from tensor.weights.weight_helpers import check_weight_qtensor_linear
 
-def get_uint4_tensor(shape, device, random=False):
-    qmax = 2**4
-    if random:
-        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
-    else:
-        numel = np.prod(shape)
-        t = torch.tensor(range(numel), dtype=torch.int32)
-        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
-    return t
+from optimum.quanto import qint4
+from optimum.quanto.library.extensions import is_extension_available
+from optimum.quanto.tensor.weights import WeightQBitsTensor
+from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available"
+)
 @pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
-@pytest.mark.parametrize("random", [True, False])
-def test_pack_marlin_int4_tensor(in_features, out_features, random):
+def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features):
+    qtype = qint4
+    group_size = 128
+    dtype = torch.float16
     shape = (out_features, in_features)
     device = torch.device("cuda")
-    t = get_uint4_tensor(shape, device, random)
-    packed = MarlinInt4PackedTensor.pack(t)
-    assert isinstance(packed, MarlinInt4PackedTensor)
-    assert device_eq(packed.device, device)
-    assert torch.equal(t, packed.unpack())
+    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
+    # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members
+    marlinqbt = MarlinInt4WeightQBitsTensor(
+        qtype=qbt.qtype,
+        axis=qbt.axis,
+        group_size=qbt._group_size,
+        size=qbt.size(),
+        stride=qbt.stride(),
+        data=qbt._data.unpack(),
+        scale=qbt._scale,
+        shift=qbt._shift,
+    )
+    assert marlinqbt.dtype == dtype
+    assert marlinqbt.qtype == qtype
+    assert marlinqbt.shape == shape
+    assert device_eq(marlinqbt.device, device)
+    # Verify the dequantized tensors are identical
+    assert torch.equal(marlinqbt.dequantize(), qbt.dequantize())
+    # Now verify that we can reconstruct the WeightQBitsTensor
+    new_qbt = marlinqbt.weight_qbits_tensor()
+    assert type(new_qbt) is WeightQBitsTensor
+    assert new_qbt.dtype == dtype
+    assert new_qbt.qtype == qtype
+    assert new_qbt.shape == shape
+    assert torch.equal(new_qbt._data, qbt._data)
+    assert torch.equal(new_qbt._scale, qbt._scale)
+    assert torch.equal(new_qbt._shift, qbt._shift)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_move_marlin_int4_packed_tensor(device):
-    shape = (256, 256)
+def test_marlin_int4_weight_qbits_tensor_move(device):
+    qtype = qint4
+    group_size = 128
+    dtype = torch.float16
+    shape = (1024, 1024)
     device = torch.device("cuda")
-    t = get_uint4_tensor(shape, device)
-    packed = MarlinInt4PackedTensor.pack(t)
-    moved = packed.to("cuda")
-    assert isinstance(moved, MarlinInt4PackedTensor)
-    # Marlin int4 tensors are unpacked when moved out of CUDA device
-    moved = packed.to("cpu")
-    assert type(moved) is torch.Tensor
-    assert torch.equal(t, moved.to("cuda"))
+    # Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
+    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda"))
+    marlinqbt = MarlinInt4WeightQBitsTensor(
+        qtype=qbt.qtype,
+        axis=qbt.axis,
+        group_size=qbt._group_size,
+        size=qbt.size(),
+        stride=qbt.stride(),
+        data=qbt._data.unpack(),
+        scale=qbt._scale,
+        shift=qbt._shift,
+    )
+    # Move to device, dequantize and compare
+    moved_qbt = marlinqbt.to(device)
+    assert isinstance(moved_qbt, WeightQBitsTensor)
+    if device.type != "cuda":
+        assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor
+    assert marlinqbt.dtype == moved_qbt.dtype
+    assert marlinqbt.qtype == moved_qbt.qtype
+    assert marlinqbt.shape == moved_qbt.shape
+    assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())
+
+
+def _test_marlin_int4_weight_qbits_tensor_linear(
+    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
+):
+    # Create a MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
+    qbt = random_qweight(
+        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
+    )
+    marlin_qweight = MarlinInt4WeightQBitsTensor(
+        qtype=qbt.qtype,
+        axis=qbt.axis,
+        group_size=qbt._group_size,
+        size=qbt.size(),
+        stride=qbt.stride(),
+        data=qbt._data.unpack(),
+        scale=qbt._scale,
+        shift=qbt._shift,
+    )
+    check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias)
+
+
+@pytest.mark.skipif(
+    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
+    reason="CUDA >= sm80 not available",
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("tokens", [16, 32])
+@pytest.mark.parametrize("in_features", [1024])
+@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
+@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
+def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
+    dtype = torch.float16
+    weight_qtype = qint4
+    group_size = 128
+    _test_marlin_int4_weight_qbits_tensor_linear(
+        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
+    )
+
+
+# Non-regression test for the Marlin kernel bug reported in https://github.com/huggingface/optimum-quanto/issues/332
+@pytest.mark.skipif(
+    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
+    reason="CUDA >= sm80 not available",
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("tokens", [48, 64])
+# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
+@pytest.mark.parametrize("in_features", [4096, 16384])
+@pytest.mark.parametrize("out_features", [2048, 4096])
+def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
+    dtype = torch.float16
+    weight_qtype = qint4
+    group_size = 128
+    _test_marlin_int4_weight_qbits_tensor_linear(
+        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False
+    )
\ No newline at end of file