Skip to content

Commit

Permalink
Update inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
heroding77 authored Aug 9, 2024
1 parent 4904f29 commit 3103781
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
import json
from tqdm import tqdm
import os

model_name = "/path/to/SEA-E/"
tokenizer = AutoTokenizer.from_pretrained(model_name)
chat_model = AutoModelForCausalLM.from_pretrained(model_name)
chat_model.to("cuda:0")


def read_txt_file(path):
with open(path, 'r') as f:
content = f.read()
Expand All @@ -22,12 +29,6 @@ def get_subfile(path):
subfiles = [d for d in os.listdir(path) if os.path.isfile(os.path.join(path, d))]
return subfiles


model_name = "/path/to/SEA-E/"
tokenizer = AutoTokenizer.from_pretrained(model_name)
chat_model = AutoModelForCausalLM.from_pretrained(model_name)
chat_model.to("cuda:0")

def infer_one(mmd_file_path):
system_prompt_dict = read_json_file(os.path.join(os.path.dirname(os.path.abspath(__file__)),"template.json"))
instruction = system_prompt_dict['instruction_e']
Expand All @@ -54,6 +55,8 @@ def run_review(mmd_file_path):
os.mkdir(infer_save_path)
res = infer_one(mmd_file_path)
return res


if __name__ == "__main__":
review = run_review("/path/to")
print(review)

0 comments on commit 3103781

Please sign in to comment.