Skip to content

Commit

Permalink
feat(bench): generation int4 weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Feb 18, 2024
1 parent 096650a commit 6302171
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions bench/generation/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

from quanto import Calibration, freeze, qint8, quantize
from quanto import Calibration, freeze, qint4, qint8, quantize


CALIBRATION_PROMPT = "It was a bright cold day in April, and the clocks were striking thirteen."
Expand Down Expand Up @@ -129,7 +129,7 @@ def main():
"--quantization",
type=str,
default="none",
choices=["bnb_4bit", "bnb_8bit", "w8a16", "w8a8"],
choices=["bnb_4bit", "bnb_8bit", "w4a16", "w4a8", "w8a16", "w8a8"],
help="One of none, bnb_4bit, bnb_8bit, w8a16, w8a8.",
)
args = parser.parse_args()
Expand Down Expand Up @@ -162,10 +162,10 @@ def main():
else:
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype, low_cpu_mem_usage=True).to(device)

if args.quantization in ("w8a8", "w8a16"):
if args.quantization in ("w4a8", "w4a16", "w8a8", "w8a16"):
print("quantizing")
start = time.time()
weights = qint8
weights = qint8 if "w8" in args.quantization else qint4
activations = None if "a16" in args.quantization else qint8
quantize(model, weights=weights, activations=activations)
if activations is not None:
Expand Down

0 comments on commit 6302171

Please sign in to comment.