From 0b512d78eb290158de756b975973b6a8927e0eaa Mon Sep 17 00:00:00 2001
From: shcho1118 <chosanghoon1118@gmail.com>
Date: Sat, 29 Jun 2024 23:55:09 +0900
Subject: [PATCH] perf: add Marlin to w4a16 benchmark

---
 bench/kernels/benchmark_w4a16.py | 33 ++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/bench/kernels/benchmark_w4a16.py b/bench/kernels/benchmark_w4a16.py
index 36629c67..ac3b4c8f 100644
--- a/bench/kernels/benchmark_w4a16.py
+++ b/bench/kernels/benchmark_w4a16.py
@@ -5,6 +5,8 @@
 import torch
 
 from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
+from optimum.quanto.tensor.weights.marlin import marlin_permute
+from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor
 
 
 def benchmark(f, warmup=1, iter=10):
@@ -28,12 +30,15 @@ def get_problem(m, n, k, groupsize=128):
     A = torch.rand((m, k), dtype=torch.half, device=dev)
     B_4bit = torch.randint(0, 2**4, (n, k), dtype=torch.uint8, device=dev)
     B_awq = AWQPackedTensor.pack(B_4bit, packing=AWQPacking.V2)._data
+    B_marlin = MarlinInt4PackedTensor.pack(B_4bit)._data
     B_ref = torch.rand((k, n), dtype=torch.half, device=dev)
     s = torch.rand((k // groupsize, n), dtype=torch.half, device=dev) / 2**4
+    s_marlin = marlin_permute(s)
     z = torch.randint(-(2 ** (4 - 1)), 2 ** (4 - 1), (k // groupsize, n), dtype=torch.int8, device=dev)
     sz = -z * s
+    sz_marlin = marlin_permute(sz)
     torch.cuda.synchronize()
-    return A, B_ref, B_awq, s, sz
+    return A, B_ref, B_awq, B_marlin, s, s_marlin, sz, sz_marlin
 
 
 def benchmark_dense(A, B, m, n, k):
@@ -56,6 +61,16 @@ def benchmark_awq(A, B, s, sz, m, n, k):
     }
 
 
+def benchmark_marlin(A, B, s, sz, m, n, k):
+    """Time torch.ops.quanto.gemm_f16i4_marlin on one (m, n, k) problem; return "s", "TFLOP/s" and "GB/s"."""
+    workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device=A.device)  # follow A, not a fixed "cuda:0"
+    res = benchmark(lambda: torch.ops.quanto.gemm_f16i4_marlin(A, B, s, sz, workspace))
+    flops = 2 * (m * k) * n
+    # Traffic estimate: 2 per element of A, s, sz and the m * n term (presumably the output); 4 per element of B.
+    traffic = 2 * A.numel() + 4 * B.numel() + 2 * (m * n) + 2 * s.numel() + 2 * sz.numel()
+    return {"s": res, "TFLOP/s": flops / res / 10**12, "GB/s": traffic / res / 10**9}
+
+
 MODELS = {
     "Llama7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
     "Llama13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
@@ -79,9 +94,10 @@ def run_benchmark(model, tokens=None):
     print(model)
     for m in tokens:
         tot_awq = {"s": 0, "TFLOP/s": 0, "GB/s": 0, "speedup": 0}
+        tot_marlin = {"s": 0, "TFLOP/s": 0, "GB/s": 0, "speedup": 0}
         for layer in layers:
             k, n = layer
-            A, B_ref, B_awq, s, sz = get_problem(m, n, k, groupsize)
+            A, B_ref, B_awq, B_marlin, s, s_marlin, sz, sz_marlin = get_problem(m, n, k, groupsize)
             res_d = benchmark_dense(A, B_ref, m, n, k)
             res_awq = benchmark_awq(A, B_awq, s, sz, m, n, k)
             res_awq["speedup"] = res_d["s"] / res_awq["s"]
@@ -89,13 +105,26 @@ def run_benchmark(model, tokens=None):
             for key in tot_awq:
                 if key != "s":
                     tot_awq[key] += res_awq[key] * res_awq["s"]
+            res_marlin = benchmark_marlin(A, B_marlin, s_marlin, sz_marlin, m, n, k)
+            res_marlin["speedup"] = res_d["s"] / res_marlin["s"]
+            tot_marlin["s"] += res_marlin["s"]
+            for key in tot_marlin:
+                if key != "s":
+                    tot_marlin[key] += res_marlin[key] * res_marlin["s"]
         for key in tot_awq:
             if key != "s":
                 tot_awq[key] /= tot_awq["s"]
+        for key in tot_marlin:
+            if key != "s":
+                tot_marlin[key] /= tot_marlin["s"]
         print(
             "AWQ, tokens=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f"
             % (m, tot_awq["s"], tot_awq["TFLOP/s"], tot_awq["GB/s"], tot_awq["speedup"])
         )
+        print(
+            "Marlin, batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f"
+            % (m, tot_marlin["s"], tot_marlin["TFLOP/s"], tot_marlin["GB/s"], tot_marlin["speedup"])
+        )
 
 
 def main():