-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_llm_locally.py
41 lines (32 loc) · 1.91 KB
/
run_llm_locally.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from lib.llm import start_language_models
def summarise_articles(args):
    """Load the language models, summarise the articles and write the result as JSON.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``language_model``, ``embedding_model``, ``device_map``,
            ``model_type``, ``chat_type``, ``n_ctx``, ``n_gpu_layers``,
            ``articles_json``, ``query`` and ``output_file``.

    Side effects:
        Creates (or truncates) ``args.output_file`` and writes the value
        returned by ``Summariser.process()`` to it via ``json.dump``.
    """
    import json
    from lib.summariser import Summariser

    # Model-related CLI options are forwarded unchanged, by keyword, to the loader.
    loader_options = {
        "lang_model_path": args.language_model,
        "emb_model_path": args.embedding_model,
        "device_map": args.device_map,
        "model_type": args.model_type,
        "chat_type": args.chat_type,
        "n_ctx": args.n_ctx,
        "n_gpu_layers": args.n_gpu_layers,
    }
    model = start_language_models(**loader_options)

    # NOTE(review): ``texts`` receives the *path* passed via --articles_json,
    # not parsed article data; presumably Summariser reads the file itself —
    # confirm against lib.summariser.
    result = Summariser(model, texts=args.articles_json, query=args.query).process()

    with open(args.output_file, "w") as out_file:
        json.dump(result, out_file)
def build_parser():
    """Return the command-line parser for this script.

    The parser is built in a function (rather than inline under the
    ``__main__`` guard) so it can be imported and inspected without loading
    any models. The local variable is named ``parser`` so it no longer
    shadows the ``argparse`` module, as the original inline code did.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--language_model", type=str, required=True, help="Main LLM model for summarisation")
    parser.add_argument("--embedding_model", type=str, default=None, required=False, help="Sentence embedding model to calculate similarity")
    # NOTE(review): device_map defaults to 'cuda' while n_gpu_layers defaults
    # to 0; confirm which of the two the selected backend actually honours.
    parser.add_argument("--device_map", type=str, default='cuda', required=False, help="Device map for LLM model")
    parser.add_argument("--n_gpu_layers", type=int, default=0, required=False, help="number of layers for LLM model")
    parser.add_argument("--model_type", type=str, default='gguf', required=False, help="LLM model type")
    parser.add_argument("--chat_type", type=str, default="llama", required=False, help="LLM chat type")
    parser.add_argument("--n_ctx", type=int, default=8000, required=False, help="Number of context for LLM model")
    parser.add_argument("--articles_json", type=str, required=True, help="JSON file containing articles")
    parser.add_argument("--query", type=str, required=True, help="Query string")
    parser.add_argument("--output_file", type=str, required=True, help="Output file containing summarised articles")
    return parser


if __name__ == "__main__":
    # argparse exits with status 2 and a usage message if a required option is missing.
    summarise_articles(build_parser().parse_args())