diff --git a/omiclearn/gui.py b/omiclearn/gui.py
index e81a2cd..ba3f8bb 100644
--- a/omiclearn/gui.py
+++ b/omiclearn/gui.py
@@ -32,6 +32,7 @@ def run():
         file_path,
         "--global.developmentMode=false",
         "--browser.gatherUsageStats=False",
+        "--logger.level=error",
     ]
 
     sys.argv = args
diff --git a/omiclearn/utils/ml_helper.py b/omiclearn/utils/ml_helper.py
index f8311bb..962d1f1 100644
--- a/omiclearn/utils/ml_helper.py
+++ b/omiclearn/utils/ml_helper.py
@@ -283,7 +283,9 @@ def perform_cross_validation(state, cohort_column=None):
 
     for metric_name, metric_fct in scorer_dict.items():
         _cv_results[metric_name] = []
+        _cv_results[metric_name + "_train"] = []
     _cv_results["pr_auc"] = []  # ADD pr_auc manually
+    _cv_results["pr_auc_train"] = []  # ADD pr_auc manually
 
     X = state.X
     y = state.y
@@ -367,10 +369,20 @@ def perform_cross_validation(state, cohort_column=None):
 
                 calibrated_clf = CalibratedClassifierCV(clf, cv=cv_generator)
                 calibrated_clf.fit(X_train, y_train)
+
+                # Train
+                y_train_pred = calibrated_clf.predict(X_train)
+                y_train_pred_proba = calibrated_clf.predict_proba(X_train)
+                # Validation
                 y_pred = calibrated_clf.predict(X_test)
                 y_pred_proba = calibrated_clf.predict_proba(X_test)
             else:
                 clf.fit(X_train, y_train)
+
+                # Train
+                y_train_pred = clf.predict(X_train)
+                y_train_pred_proba = clf.predict_proba(X_train)
+                # Validation
                 y_pred = clf.predict(X_test)
                 y_pred_proba = clf.predict_proba(X_test)
 
@@ -395,21 +407,42 @@ def perform_cross_validation(state, cohort_column=None):
                 feature_importance = None
 
             # ROC CURVE
+            # Validation
             fpr, tpr, cutoffs = roc_curve(y_test, y_pred_proba[:, 1])
 
             # PR CURVE
+            # Train
+            precision_train, recall_train, _train = precision_recall_curve(
+                y_train, y_train_pred_proba[:, 1]
+            )
+            # Validation
             precision, recall, _ = precision_recall_curve(y_test, y_pred_proba[:, 1])
 
             for metric_name, metric_fct in scorer_dict.items():
                 if metric_name == "roc_auc":
+                    # Train
+                    _cv_results[metric_name + "_train"].append(
+                        metric_fct(y_train, y_train_pred_proba[:, 1])
+                    )
+                    # Validation
                     _cv_results[metric_name].append(
                         metric_fct(y_test, y_pred_proba[:, 1])
                     )
                 elif metric_name in ["precision", "recall", "f1"]:
+                    # Train
+                    _cv_results[metric_name + "_train"].append(
+                        metric_fct(y_train, y_train_pred, zero_division=0)
+                    )
+                    # Validation
                     _cv_results[metric_name].append(
                         metric_fct(y_test, y_pred, zero_division=0)
                     )
                 else:
+                    # Train
+                    _cv_results[metric_name + "_train"].append(
+                        metric_fct(y_train, y_train_pred)
+                    )
+                    # Validation
                     _cv_results[metric_name].append(metric_fct(y_test, y_pred))
 
             # Results of Cross Validation
@@ -423,12 +456,10 @@ def perform_cross_validation(state, cohort_column=None):
             _cv_results["n_class_0_test"].append(np.sum(y_test))
             _cv_results["n_class_1_test"].append(np.sum(~y_test))
             _cv_results["class_ratio_test"].append(np.sum(y_test) / len(y_test))
-            _cv_results["pr_auc"].append(
-                auc(recall, precision)
-            )  # ADD PR Curve AUC Score
-            _cv_curves["pr_auc"].append(
-                auc(recall, precision)
-            )  # ADD PR Curve AUC Score
+            # Train PR Curve AUC Score
+            _cv_results["pr_auc_train"].append(auc(recall_train, precision_train))
+            # Validation PR Curve AUC Score
+            _cv_results["pr_auc"].append(auc(recall, precision))
             _cv_curves["roc_curves_"].append((fpr, tpr, cutoffs))
             _cv_curves["pr_curves_"].append((precision, recall, _))
             _cv_curves["y_hats_"].append((y_test.values, y_pred))
diff --git a/tests/test_helper.py b/tests/test_helper.py
index 90a6890..77cedfc 100644
--- a/tests/test_helper.py
+++ b/tests/test_helper.py
@@ -38,16 +38,13 @@ def test_load_data():
     # csv
     df = pd.DataFrame({"A": [1, 1], "B": [0, 0]})
     csv_data, warnings = load_data("test_csv_c.csv", "Comma (,)")
-    print(csv_data)
     pd.testing.assert_frame_equal(csv_data, df)
 
     csv_data, warnings = load_data("test_csv_sc.csv", "Semicolon (;)")
-    print(csv_data)
     pd.testing.assert_frame_equal(csv_data, df)
 
     # TSV
     tsv_data, warnings = load_data("test_tsv.tsv", "Tab (\\t) for TSV")
-    print(tsv_data)
     pd.testing.assert_frame_equal(tsv_data, df)
 
 
@@ -153,10 +150,10 @@ def test_integration():
     test_state["cv_repeats"] = 2
     test_state["bar"] = st.progress(0)
     test_state["features"] = ["AAA", "BBB", "CCC", "_study"]
+
     # Generate X and y
     main_analysis_run(test_state)
 
-    # print("\n", test_state, "\n")
     _cv_results, _cv_curves = perform_cross_validation(test_state, cohort_column=None)
     assert _cv_results == expected_cv_results, "Error in CV Results"
     assert str(_cv_curves) == str(expected_cv_curves_str), "Error in CV Curves"
diff --git a/tests/test_results.py b/tests/test_results.py
index 28a0bdc..4a92208 100644
--- a/tests/test_results.py
+++ b/tests/test_results.py
@@ -31,21 +31,35 @@
         0.8333333333333334,
         0.8571428571428571,
         0.8333333333333334,
-        1,
+        1.0,
     ],
-    "roc_auc": [0.875, 0.8333333333333333, 1, 0.875, 0.8333333333333333, 1],
-    "precision": [0.75, 1, 0.75, 0.75, 1, 1],
-    "recall": [1, 0.6666666666666666, 1, 1, 0.6666666666666666, 1],
-    "f1": [0.8571428571428571, 0.8, 0.8571428571428571, 0.8571428571428571, 0.8, 1],
+    "accuracy_train": [1.0, 1.0, 0.9230769230769231, 1.0, 1.0, 0.9230769230769231],
+    "roc_auc": [0.875, 0.8333333333333333, 1.0, 0.875, 0.8333333333333333, 1.0],
+    "roc_auc_train": [1.0, 1.0, 0.9761904761904763, 1.0, 1.0, 0.9761904761904763],
+    "precision": [0.75, 1.0, 0.75, 0.75, 1.0, 1.0],
+    "precision_train": [1.0, 1.0, 0.8571428571428571, 1.0, 1.0, 0.8571428571428571],
+    "recall": [1.0, 0.6666666666666666, 1.0, 1.0, 0.6666666666666666, 1.0],
+    "recall_train": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+    "f1": [0.8571428571428571, 0.8, 0.8571428571428571, 0.8571428571428571, 0.8, 1.0],
+    "f1_train": [1.0, 1.0, 0.923076923076923, 1.0, 1.0, 0.923076923076923],
     "balanced_accuracy": [
         0.875,
         0.8333333333333333,
         0.8333333333333333,
         0.875,
         0.8333333333333333,
-        1,
+        1.0,
     ],
-    "pr_auc": [0.875, 0.9166666666666666, 1, 0.875, 0.9166666666666666, 1],
+    "balanced_accuracy_train": [
+        1.0,
+        1.0,
+        0.9285714285714286,
+        1.0,
+        1.0,
+        0.9285714285714286,
+    ],
+    "pr_auc": [0.875, 0.9166666666666666, 1.0, 0.875, 0.9166666666666666, 1.0],
+    "pr_auc_train": [1.0, 1.0, 0.9742063492063492, 1.0, 1.0, 0.9742063492063492],
 }
 
-expected_cv_curves_str = """{'pr_auc': [0.875, 0.9166666666666666, 1.0, 0.875, 0.9166666666666666, 1.0], 'roc_curves_': [(array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0. , 0. , 0.33333333, 1. ]), array([0., 1., 1., 1.]), array([1.8069754, 0.8069754, 0.502567 , 0.1422766], dtype=float32)), (array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0., 0., 1.]), array([0., 1., 1.]), array([1.8069754, 0.8069754, 0.1422766], dtype=float32))], 'pr_curves_': [(array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5 , 0.75, 1. , 1. ]), array([1., 1., 1., 0.]), array([0.1422766, 0.502567 , 0.8069754], dtype=float32)), (array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1., 1., 0.]), array([0.1422766, 0.8069754], dtype=float32))], 'y_hats_': [(array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 1, 0, 0])), (array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 0, 0, 0]))], 'feature_importances_': [{'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}], 'features_': []}"""
+expected_cv_curves_str = """{'pr_auc': [], 'roc_curves_': [(array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0. , 0. , 0.33333333, 1. ]), array([0., 1., 1., 1.]), array([1.8069754, 0.8069754, 0.502567 , 0.1422766], dtype=float32)), (array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0., 0., 1.]), array([0., 1., 1.]), array([1.8069754, 0.8069754, 0.1422766], dtype=float32))], 'pr_curves_': [(array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5 , 0.75, 1. , 1. ]), array([1., 1., 1., 0.]), array([0.1422766, 0.502567 , 0.8069754], dtype=float32)), (array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1., 1., 0.]), array([0.1422766, 0.8069754], dtype=float32))], 'y_hats_': [(array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 1, 0, 0])), (array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 0, 0, 0]))], 'feature_importances_': [{'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}], 'features_': []}"""