recognize.py
# Copyright (c) 2024 Binbin Zhang([email protected])
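"""Decode a speech dataset with a speech LLM and write the recognized text to a file.

Example invocation (a sketch: the model path and the data file below are
placeholders, and any ModelArguments/DataArguments flags other than
--llm_model_name_or_path and --data_path depend on speech_llm.py and
dataset.py, which are not shown here):

    python recognize.py \
        --llm_model_name_or_path Qwen/Qwen2-7B-Instruct \
        --data_path data/test.jsonl \
        --llm_type qwen2 \
        --batch_size 4 \
        --result_path result.txt
"""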
import sys
from dataclasses import dataclass, field

import torch
import transformers
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from dataset import SpeechDataset, DataArguments
from speech_llm import init_model, ModelArguments
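# Decoding options supplied on the command line; they are parsed together with
# ModelArguments and DataArguments in main() below.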
@dataclass
class DecodeArguments:
    llm_type: str = 'qwen2'
    max_new_tokens: int = 50
    num_beams: int = 1
    batch_size: int = 1
    result_path: str = field(default=None, metadata={"help": "Path to result"})
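# HfArgumentParser exposes every field of the three dataclasses as a --flag,
# so a single command line fills ModelArguments, DataArguments and
# DecodeArguments at once.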
def main():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, DecodeArguments))
    model_args, data_args, decode_args = parser.parse_args_into_dataclasses()
    model = init_model(model_args)
    tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
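    # Stop generation at either the end-of-text or the end-of-turn token.
    # Qwen2-style and Llama-3-style chat models use different special tokens,
    # hence the branch; '<|finetune_right_pad_id|>' is (presumably) the
    # dedicated padding token of a Llama-3-style tokenizer, which otherwise
    # has no pad token set.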
    if decode_args.llm_type == 'qwen2':
        eos_token_id = tokenizer.convert_tokens_to_ids(
            ['<|endoftext|>', '<|im_end|>'])
    else:
        tokenizer.pad_token = '<|finetune_right_pad_id|>'
        eos_token_id = tokenizer.convert_tokens_to_ids(
            ['<|end_of_text|>', '<|eot_id|>'])
    print('eos_token_id', eos_token_id)
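    # inference=True presumably makes SpeechDataset yield prompts without
    # target labels; its implementation lives in dataset.py, not shown here.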
    test_dataset = SpeechDataset(data_args.data_path,
                                 tokenizer=tokenizer,
                                 inference=True)
    data_loader = DataLoader(test_dataset, batch_size=decode_args.batch_size)
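    # accelerator.prepare() places the model on the right device and shards the
    # dataloader across processes, so the script also works under
    # `accelerate launch` for multi-GPU decoding; the explicit .cuda() call is
    # then redundant but harmless.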
    if torch.cuda.is_available():
        model = model.cuda()
    accelerator = Accelerator()
    model, data_loader = accelerator.prepare(model, data_loader)
    model.eval()
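    # Decode batch by batch and append the hypotheses to result_path, one line
    # per utterance. decode_config is not a standard Hugging Face generate()
    # argument, so it is presumably consumed by the custom generate() of the
    # model built in speech_llm.py.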
    fid = open(decode_args.result_path, 'w', encoding='utf8')
    with torch.no_grad():
        for item in tqdm(data_loader):
            generated_ids = model.generate(**item,
                                           eos_token_id=eos_token_id,
                                           decode_config=decode_args)
            text = tokenizer.batch_decode(generated_ids,
                                          skip_special_tokens=True)
            print(text)
            for t in text:
                fid.write(t + '\n')
            sys.stdout.flush()
    fid.close()
if __name__ == "__main__":
    main()