diff --git a/bench/generation/benchmark.py b/bench/generation/benchmark.py index b91cd421..3aa34e60 100644 --- a/bench/generation/benchmark.py +++ b/bench/generation/benchmark.py @@ -173,8 +173,9 @@ def main(): # Very simple calibration to avoid completely off results with Calibration(): generate(model, tokenizer, device, prompt=CALIBRATION_PROMPT) + print("Freezing") freeze(model) - print(f"Finished: {time.time()-start}") + print(f"Finished: {time.time()-start:.2f}") memory = get_device_memory(device) if memory is not None: diff --git a/quanto/nn/qlinear.py b/quanto/nn/qlinear.py index b32878cf..bd83f3c1 100644 --- a/quanto/nn/qlinear.py +++ b/quanto/nn/qlinear.py @@ -24,6 +24,7 @@ def from_module(cls, module, weights=torch.int8, activations: Optional[torch.dty dtype=module.weight.dtype, weights=weights, activations=activations, + device=module.weight.device, ) with torch.no_grad(): qmodule.weight.copy_(module.weight)