generated from SparkJiao/pytorch-transformers-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path llama_example.py
112 lines (94 loc) · 3.17 KB
/
llama_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from meta_llama import ModelArgs, Transformer, Tokenizer, LLaMA
def setup_model_parallel() -> Tuple[int, int]:
    """Initialise the NCCL process group and fairscale model parallelism.

    ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment (typically
    populated by a distributed launcher such as ``torchrun``); each falls back
    to ``-1`` when unset.

    Returns:
        The ``(local_rank, world_size)`` pair parsed from the environment.
    """
    env = os.environ
    rank_here = int(env.get("LOCAL_RANK", -1))
    n_procs = int(env.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(n_procs)
    torch.cuda.set_device(rank_here)

    # The seed has to be identical in every process.
    torch.manual_seed(1)
    return rank_here, n_procs
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> "LLaMA":
    """Build a ``LLaMA`` generator from this rank's checkpoint shard.

    Args:
        ckpt_dir: Directory containing the ``*.pth`` shard files and a
            ``params.json`` whose keys are forwarded to ``ModelArgs``.
        tokenizer_path: Path handed to ``Tokenizer(model_path=...)``.
        local_rank: Index (in sorted filename order) of the shard to load.
        world_size: Number of processes; must equal the number of shards.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` wrapping the constructed ``Transformer`` and tokenizer.

    Raises:
        ValueError: If the number of ``*.pth`` files differs from ``world_size``.
    """
    start_time = time.time()
    ckpt_root = Path(ckpt_dir)
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if world_size != len(checkpoints):
        raise ValueError(
            f"Loading a checkpoint for MP={len(checkpoints)} "
            f"but world size is {world_size}"
        )
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    # NOTE(review): torch.load unpickles the file -- only point this at
    # trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(ckpt_root / "params.json", "r", encoding="utf-8") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Construct the model with CUDA half tensors as the default type, and put
    # the global default back to FloatTensor even if construction raises, so
    # the process is never left with a CUDA/fp16 default.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)

    # strict=False tolerates extra keys in the checkpoint; keys missing from
    # the checkpoint are reported rather than silently ignored, because those
    # parameters keep whatever values the constructor gave them.
    incompatible = model.load_state_dict(checkpoint, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} parameter(s) missing "
            f"from checkpoint: {incompatible.missing_keys}"
        )
    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    max_gen_len: int = 256,
):
    """Load the model on every rank and print completions for sample prompts.

    Args:
        ckpt_dir: Directory with the checkpoint shards and ``params.json``.
        tokenizer_path: Path to the tokenizer model file.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_seq_len: Maximum sequence length the model is built for.
        max_batch_size: Maximum batch size the model is built for.
        max_gen_len: Generation length limit forwarded to ``generate``
            (previously hard-coded to 256, which remains the default).
    """
    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        # Only the first rank prints; the others write to the null device.
        # The handle is intentionally kept open for the process lifetime.
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        "Building a website can be done in 10 simple steps:\n",
        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
        """Tweet: "I hate it when my phone battery dies."
Sentiment: Negative
###
Tweet: "My day has been 👍"
Sentiment: Positive
###
Tweet: "This is the link to the article"
Sentiment: Neutral
###
Tweet: "This new music video was incredibile"
Sentiment:""",
        """Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
    ]
    results = generator.generate(
        prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
    )

    for result in results:
        print(result)
        print("\n==================================\n")


if __name__ == "__main__":
    # fire turns main's parameters into command-line flags.
    fire.Fire(main)