eval.py
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path

import torch
from tqdm import tqdm

import cerebras.pytorch as cstorch
from configuration import parse_args
from data import get_dataloader
from model import GPTModel

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

def main(model_config, config, cs_config):
    backend = cstorch.backend(config.backend, cluster_config=cs_config)
    out_dir = Path(config.out_dir)

    # Use bfloat16 for mixed-precision execution on non-CPU backends.
    if not backend.is_cpu:
        cstorch.amp.set_half_dtype("bfloat16")

    # Instantiate the model on the backend device, then compile it.
    with backend.device:
        model = GPTModel(model_config)

    compiled_model = cstorch.compile(model, backend)

    logger.info(f"Loading checkpoint from {config.checkpoint_path}")
    state_dict = cstorch.load(config.checkpoint_path)
    model.load_state_dict(state_dict["model"])

    global_step = state_dict.get("global_step", 0)

    # A single traced forward pass; no gradients are needed for eval.
    @cstorch.trace
    @torch.no_grad()
    def eval_step(batch):
        input_ids, labels = batch
        loss = compiled_model(input_ids, labels)
        return loss

    total_loss = 0
    total_steps = 0

    # Step closures run on the host once each step's outputs are available,
    # so the running totals accumulate outside the traced graph.
    @cstorch.step_closure
    def post_eval_step(loss, step):
        nonlocal total_loss
        nonlocal total_steps
        total_loss += loss
        total_steps += 1

    from cerebras.pytorch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir=out_dir.joinpath("eval"))

    # Evaluate on the validation split of the tokenized dataset.
    data_path = os.path.join(config.dataset, "val.bin")
    dataloader = cstorch.utils.data.DataLoader(
        get_dataloader,
        data_path,
        config.sequence_length,
        config.batch_size,
    )
    num_steps = len(dataloader)
    executor = cstorch.utils.data.DataExecutor(dataloader, num_steps=num_steps)

    logger.info(f"Total eval steps: {num_steps}")
    for step, batch in tqdm(enumerate(executor, start=1), total=num_steps):
        if step > num_steps:
            break
        loss = eval_step(batch)
        post_eval_step(loss, step)

    avg_loss = total_loss / total_steps
    writer.add_scalar("loss", avg_loss, global_step)
    logger.info(f"Average eval loss: {avg_loss}")
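
# For reference, the executor loop above is the Cerebras-accelerated form of a
# standard PyTorch eval loop. A minimal CPU-only sketch of the same logic
# follows; `model` and `dataloader` are placeholders, not objects from this
# repo, and the sketch assumes the model returns a scalar loss as above:
#
#     import torch
#
#     @torch.no_grad()
#     def evaluate(model, dataloader):
#         model.eval()
#         total_loss, total_steps = 0.0, 0
#         for input_ids, labels in dataloader:
#             total_loss += model(input_ids, labels).item()
#             total_steps += 1
#         return total_loss / total_steps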


if __name__ == "__main__":
    model_config, run_config, cs_config = parse_args()
    if run_config.checkpoint_path is None:
        raise ValueError("You must specify a checkpoint path for model eval")
    main(model_config, run_config, cs_config)
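
# Example invocation (a sketch only: the actual CLI is defined by
# configuration.parse_args, which is not shown in this file, so the flag
# below is an assumption rather than the confirmed interface):
#
#     python eval.py --checkpoint_path /path/to/checkpoint.mdl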