forked from JDACS4C-IMPROVE/GraphDRP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrm.py
171 lines (150 loc) · 5.01 KB
/
frm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from pathlib import Path
import torch
import candle
import candle_improve_utils as improve_utils
# Absolute path of the directory containing this script; passed to the
# CANDLE Benchmark so relative resource/config paths resolve correctly.
file_path = Path(__file__).resolve().parent
# Model-specific parameter definitions, merged with the IMPROVE framework
# definitions by BenchmarkFRM.set_locals(). Each entry follows the CANDLE
# argparse-style schema: name, type, optional default/choices/nargs, help.
additional_definitions = [
    {"name": "pred_col_name_suffix",
     "type": str,
     "default": "_pred",
     "help": "Tag to add to predictions when storing the data frame."},
    {"name": "y_col_name",
     "type": str,
     "default": "auc",
     "help": "Drug sensitivity score to use as the target variable (e.g., IC50, AUC)."},
    {"name": "model_arch",
     "default": "GINConvNet",
     "choices": ["GINConvNet", "GATNet", "GAT_GCN", "GCNNet"],
     "type": str,
     "help": "Model architecture to run.", },
    # Preprocessing
    {"name": "download",
     "type": candle.str2bool,
     "default": False,
     "help": "Flag to indicate if downloading from FTP site.",},
    {"name": "set",
     "default": "mixed",
     "choices": ["mixed", "cell", "drug"],
     "type": str,
     "help": "Validation scheme (data splitting strategy).", },
    {"name": "train_split",
     "nargs": "+",
     "type": str,
     "help": "path to the file that contains the split ids (e.g., 'split_0_tr_id', 'split_0_vl_id').", },
    # Training / Inference
    {"name": "log_interval",
     "action": "store",
     "type": int,
     "help": "Interval for saving o/p", },
    {"name": "cuda_name",
     "action": "store",
     "type": str,
     # Fixed unclosed parenthesis in the help text.
     "help": "Cuda device (e.g.: cuda:0, cuda:1)."},
    {"name": "train_ml_data_dir",
     "action": "store",
     "type": str,
     "help": "Datadir where train data is stored."},
    {"name": "val_ml_data_dir",
     "action": "store",
     "type": str,
     "help": "Datadir where val data is stored."},
    {"name": "test_ml_data_dir",
     "action": "store",
     "type": str,
     "help": "Datadir where test data is stored."},
    {"name": "model_outdir",
     "action": "store",
     "type": str,
     "help": "Datadir to store trained model."},
    {"name": "model_params",
     "type": str,
     "default": "model.pt",
     "help": "Filename to store trained model."},
    {"name": "pred_fname",
     "type": str,
     "default": "test_preds.csv",
     "help": "Name of file to store inference results."},
    {"name": "response_data",
     "type": str,
     "default": "test_response.csv",
     # Was a copy-paste of pred_fname's help ("inference results"); this
     # parameter names the response (ground-truth) data file.
     "help": "Name of file to store the response data used for inference."},
    {"name": "out_json",
     "type": str,
     "default": "test_scores.json",
     "help": "Name of file to store scores."},
]
# Parameters that must be supplied (via the default model file or the
# command line) for the benchmark to run. "train_split" is currently
# optional and therefore not listed.
required = ["train_data", "val_data", "test_data"]
class BenchmarkFRM(candle.Benchmark):
    """CANDLE benchmark definition for the FRM workflow."""

    def set_locals(self):
        """Attach parameter metadata to this benchmark instance.

        Sets ``self.required`` from the module-level ``required`` list and
        ``self.additional_definitions`` from the module-level
        ``additional_definitions`` merged with the IMPROVE-framework
        definitions parsed from ``candle_improve.json``.
        """
        # IMPROVE definitions are parsed unconditionally, mirroring the
        # original call order, even though they are only consumed when
        # additional_definitions is set.
        improve_defs = improve_utils.parser_from_json("candle_improve.json")
        if required is not None:
            self.required = set(required)
        if additional_definitions is not None:
            self.additional_definitions = additional_definitions + improve_defs
def initialize_parameters(default_model="frm_default_model.txt"):
    """Parse execution parameters from file or command line.

    Parameters
    ----------
    default_model : string
        File containing the default parameter definition.

    Returns
    -------
    gParameters: python dictionary
        A dictionary of Candle keywords and parsed values, augmented with
        IMPROVE path entries.
    """
    benchmark = BenchmarkFRM(
        file_path,
        default_model,
        "python",
        prog="frm",
        desc="frm functionality",
    )
    # finalize_parameters merges defaults, config file, and CLI arguments;
    # build_improve_paths then adds the IMPROVE directory layout.
    params = candle.finalize_parameters(benchmark)
    return improve_utils.build_improve_paths(params)
def predicting(model, device, loader):
    """ Method to run predictions/inference.
    The same method is in frm_train.py
    TODO: put this in some utils script. --> graphdrp?

    Parameters
    ----------
    model : pytorch model
        Model to evaluate.
    device : string
        Identifier for hardware that will be used to evaluate model.
    loader : pytorch data loader.
        Object to load data to evaluate.

    Returns
    -------
    total_labels: numpy array
        Array with ground truth.
    total_preds: numpy array
        Array with inferred outputs.
    """
    model.eval()
    print("Make prediction for {} samples...".format(len(loader.dataset)))
    # Collect per-batch tensors in lists and concatenate once at the end.
    # The previous torch.cat-inside-the-loop reallocated and copied the
    # whole accumulator on every batch (quadratic in dataset size).
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Model is expected to return (prediction, embedding/aux).
            output, _ = model(data)
            pred_batches.append(output.cpu())
            label_batches.append(data.y.view(-1, 1).cpu())
    # Handle an empty loader gracefully (torch.cat requires >= 1 tensor).
    total_preds = torch.cat(pred_batches, 0) if pred_batches else torch.Tensor()
    total_labels = torch.cat(label_batches, 0) if label_batches else torch.Tensor()
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()