-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed external training dataset parameter. Now updates orig_params instead of params
- Loading branch information
1 parent
65ec4da
commit c5481a0
Showing
3 changed files
with
91 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
atomsci/ddm/test/integrative/external_dataset/H1_graphconv.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"system": "LC", | ||
"datastore": "False", | ||
"save_results": "False", | ||
"data_owner": "username", | ||
"prediction_type": "regression", | ||
"dataset_key": "../../test_datasets/H1_std.csv", | ||
"id_col": "compound_id", | ||
"smiles_col": "base_rdkit_smiles", | ||
"response_cols": "pKi_mean", | ||
"split_uuid": "002251a2-83f8-4511-acf5-e8bbc5f86677", | ||
"previously_split": "True", | ||
"uncertainty": "True", | ||
"verbose": "True", | ||
"transformers": "True", | ||
"model_type": "NN", | ||
"featurizer": "graphconv", | ||
"result_dir": "./output" | ||
} |
71 changes: 71 additions & 0 deletions
71
atomsci/ddm/test/integrative/external_dataset/test_external_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python | ||
|
||
import json | ||
import pandas as pd | ||
import os | ||
import sys | ||
|
||
import atomsci.ddm.pipeline.parameter_parser as parse | ||
from atomsci.ddm.pipeline import model_pipeline as mp | ||
from atomsci.ddm.pipeline import predict_from_model as pfm | ||
import pytest | ||
|
||
def clean():
    """Ensure ./output exists and delete the regular files directly inside it.

    Paths are relative to the current working directory. Subdirectories of
    ``output`` (and anything inside them) are left untouched.
    """
    out_dir = "./output"

    # Create the directory on first use so the listing below cannot fail.
    if "output" not in os.listdir():
        os.mkdir("output")

    candidates = [out_dir + "/" + name for name in os.listdir(out_dir)]
    for path in candidates:
        if not os.path.isfile(path):
            continue
        os.remove(path)
|
||
def test():
    """Train a model and check external-training-data handling at prediction time.

    Steps:
      1. Reset the ./output directory.
      2. Load H1_graphconv.json from the current working directory, point the
         parameters at the installed package, and train a model.
      3. Verify that predicting with a nonexistent ``external_training_data``
         file raises an exception with the expected message.
      4. Verify that a normal prediction with ``AD_method="z_score"`` returns a
         frame containing an ``AD_index`` column.

    The config file and the output directory are resolved relative to the
    current working directory, so run this from the test's own directory.
    """
    # Clean
    # -----
    clean()

    # Load model parameters
    # ---------------------
    with open("H1_graphconv.json", "r") as f:
        hp_params = json.load(f)

    # The package root is the parent of the directory holding
    # parameter_parser.py. Use dirname rather than str.strip(): strip()
    # removes a *set of characters* from both ends, not a suffix, and only
    # yields the right answer by accident of the path's characters.
    script_dir = os.path.dirname(os.path.dirname(parse.__file__))
    hp_params["script_dir"] = script_dir
    hp_params["python_path"] = sys.executable

    params = parse.wrapper(hp_params)
    # Fall back to a path relative to the package root when the dataset is not
    # reachable from the current working directory.
    if not os.path.isfile(params.dataset_key):
        params.dataset_key = os.path.join(params.script_dir, params.dataset_key)

    train_df = pd.read_csv(params.dataset_key)
    sample_df = train_df[:10]

    # Train model
    # -----------
    pl = mp.ModelPipeline(params)
    pl.train_model()

    # Arguments shared by both prediction calls below.
    predict_kwargs = dict(
        model_path=pl.params.model_tarball_path,
        input_df=sample_df,
        id_col="compound_id",
        smiles_col="base_rdkit_smiles",
        response_col="pKi_mean",
        dont_standardize=True,
        AD_method="z_score",
    )

    # A missing external training dataset must raise.
    with pytest.raises(Exception, match='Dataset file file_does_not_exist.csv does not exist'):
        pfm.predict_from_model_file(external_training_data='file_does_not_exist.csv',
                                    **predict_kwargs)

    # Without an external dataset the prediction should succeed and include
    # the applicability-domain index column.
    pred_df_file = pfm.predict_from_model_file(**predict_kwargs)
    assert "AD_index" in pred_df_file.columns, 'Error: No AD_index column in pred_df_file'

    clean()
|
||
# Allow running this test directly as a script, without pytest collection.
if __name__ == '__main__':
    test()