test.py
#!/usr/bin/env python

"""Sample structures for a test set with a trained diffusion model and write them out as PDB files."""

import os
import sys
from argparse import ArgumentParser
import logging

import torch
from torch.utils.data import DataLoader

from openfold.utils.rigid_utils import Rigid

from diffusion.optimizer import DiffusionModelOptimizer
from diffusion.model import Model
from diffusion.data import MhcpDataset
from diffusion.tools.pdb import save


_log = logging.getLogger(__name__)


arg_parser = ArgumentParser()
arg_parser.add_argument("model", help="model parameters file")
arg_parser.add_argument("test_hdf5", help="test data")
arg_parser.add_argument("--debug", "-d", action="store_true", help="run in debug mode")
arg_parser.add_argument("-T", type=int, default=1000, help="number of noise steps")
arg_parser.add_argument("--batch-size", "-b", type=int, help="data batch size", default=64)
arg_parser.add_argument("--num-workers", "-w", type=int, help="number of batch loading workers", default=4)


if __name__ == "__main__":

    args = arg_parser.parse_args()

    # init logger
    log_level = logging.INFO
    if args.debug:
        log_level = logging.DEBUG
    logging.basicConfig(stream=sys.stdout, level=log_level)

    # select device
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # init model and load its state from the input parameters file
    _log.debug("initializing model")
    model = Model(16, 22, args.T).to(device=device)
    model.load_state_dict(torch.load(args.model, map_location=device))

    # init diffusion model optimizer, which handles noising and sampling
    _log.debug("initializing diffusion model optimizer")
    dm = DiffusionModelOptimizer(args.T, model, 0.0)

    # open dataset
    test_dataset = MhcpDataset(args.test_hdf5, device)

    # make an output directory, named after the test set
    output_path = os.path.splitext(args.test_hdf5)[0] + "-sampled"
    if not os.path.isdir(output_path):
        os.mkdir(output_path)

    with torch.no_grad():
        for true_batch in DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers):

            # get entry names
            names = list(true_batch['name'][0])

            # noisify: replace the true frames and torsions by generated noise
            noise = dm.gen_noise(true_batch["frames"].shape[:-1], device=device)
            input_batch = {k: true_batch[k] for k in true_batch}
            input_batch["frames"] = noise["frames"].to_tensor_7()
            input_batch["torsions"] = noise["torsions"]

            # denoisify
            pred_batch = dm.sample(input_batch)

            # add all protein residues
            pred_batch.update(test_dataset.get_protein_positions(names))

            # save denoisified data, one PDB file per entry
            for i, name in enumerate(names):
                save(pred_batch, i, f"{output_path}/{name}.pdb")
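
For reference, a minimal sketch of how the script is invoked, with placeholder file names (model.pth and test.hdf5 are assumptions, not taken from the repository):

    python test.py model.pth test.hdf5 --batch-size 32 --num-workers 4

Given the output-path logic above, each test-set entry would then be written to test-sampled/<entry name>.pdb.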