# hf_infer.py
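"""Quick inference/latency benchmark for the merged model in runs/merged.

Loads the checkpoint with 4-bit NF4 quantization, generates completions for a
fixed chat prompt, and prints the average wall-clock time per call.
"""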
import time
import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    MistralForCausalLM,
    GenerationConfig,
)
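
# Quantize weights to 4-bit NF4 with double quantization; compute in bfloat16.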
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage="bfloat16",
)
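
# Load the merged checkpoint, quantizing on the fly. FlashAttention 2 requires
# a compatible GPU and the flash-attn package.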
MODEL = MistralForCausalLM.from_pretrained(
    "runs/merged",
    quantization_config=BNB_CONFIG,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # keep the weights on the same device the inputs are moved to
)
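
# Tokenizer and generation defaults saved alongside the checkpoint.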
TOKENIZER = AutoTokenizer.from_pretrained("runs/merged")
GENERATION_CONFIG = GenerationConfig.from_pretrained("runs/merged")


def infer(messages):
    """Generate a reply for a chat message list and return only the assistant text."""
    # Render the ChatML template to a string; add_generation_prompt appends the
    # assistant header so the model continues with the assistant turn.
    # (return_tensors is meaningless with tokenize=False, so it is dropped here.)
    text = TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = TOKENIZER(text, return_tensors="pt").to("cuda:0")
    generated = MODEL.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=100,
        # top_k/top_p only take effect if the generation config enables sampling.
        top_k=50,
        top_p=0.95,
        generation_config=GENERATION_CONFIG,
    )
    response = TOKENIZER.batch_decode(generated, skip_special_tokens=False)[0]
    # Keep only the assistant turn: drop everything up to "<|im_start|>assistant",
    # the newline after it, and anything from "<|im_end|>" onward.
    return response.split("<|im_start|>assistant")[-1][1:].split("<|im_end|>")[0]


if __name__ == "__main__":
    # Benchmark: run the same request N times and report the mean latency per call.
    N = 100
    messages = [
        {
            "role": "system",
            "content": "You are an ai image generator that takes user requests and interprets and converts them to optimized stable diffusion prompts. Always specify if the image is a photograph/painting/digital art/etc.",
        },
        {
            "role": "user",
            "content": "a photo of donald trump doing a handstand, 32k, HD, best quality",
        },
    ]
    start = time.time()
    for _ in range(N):
        print(infer(messages))
    # Average wall-clock seconds per inference call.
    print((time.time() - start) / N)