Code for "An explainable knowledge distillation method with XGBoost for ICU mortality prediction".

Liumucan/XGB-KD-for-ICU-Mortality-Prediction

Dependencies

For straightforward use of XGB-KD, you can install the required libraries from requirements.txt:

pip install -r requirements.txt

Dataset

We conducted our experiments on the MIMIC-III dataset, which is publicly available. You can request access to the dataset at https://mimic.physionet.org/gettingstarted/access/.

After downloading the MIMIC-III dataset, you can use the YerevaNN/mimic3-benchmarks repo (https://github.com/YerevaNN/mimic3-benchmarks) for pre-processing; it is a Python suite to construct benchmark machine learning datasets from the MIMIC-III clinical database.

That project provides preprocessed multivariate time series data and benchmark deep learning models. You can generate structured features from the multivariate time series and save them as a pickle file. We put the dataset file dataset_dict.pkl in ./X_knowledge_distill/load_data. The file stores a Python dictionary of the following form:

{
   'train':{'train_X': array([[            nan,             nan,             nan, ...,
         1.49999997e-02, -8.88443439e-14,  2.00000000e+00],
       [            nan,             nan,             nan, ...,
         1.24721909e-02,  3.81801784e-01,  3.00000000e+00],
       [            nan,             nan,             nan, ...,
         1.24721909e-02,  3.81801784e-01,  3.00000000e+00],
       ...,
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [            nan,             nan,             nan, ...,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00]]),
           'train_y':  [0, 1, 0, 0, 0, 0,...],
           'train_names': ['12797_episode1_timeseries.csv',
                           '9027_episode1_timeseries.csv',
                           '40386_episode1_timeseries.csv',
                           '48770_episode1_timeseries.csv',...]},
    'val':{'val_X': array([[        nan,         nan,         nan, ...,  0.08458041,
        -0.74828583, 13.        ],
       [        nan,         nan,         nan, ...,  0.        ,
         0.        ,  1.        ],
       [        nan,         nan,         nan, ...,  0.05117372,
         0.68207926,  4.        ],
       ...,
       [        nan,         nan,         nan, ...,  0.        ,
         0.        ,  2.        ],
       [        nan,         nan,         nan, ...,  0.        ,
         0.        ,  1.        ],
       [        nan,         nan,         nan, ...,  0.        ,
         0.        ,  2.        ]]),
          'val_y': [0, 0, 0, 0, 0, 1, ...],
          'val_names': ['76541_episode1_timeseries.csv',
                        '22933_episode1_timeseries.csv',
                        '24771_episode1_timeseries.csv',...]},
    'test':{'test_X': array([[            nan,             nan,             nan, ...,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
       [            nan,             nan,             nan, ...,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       ...,
       [            nan,             nan,             nan, ...,
         2.81735696e-02, -1.02709770e+00,  8.00000000e+00],
       [            nan,             nan,             nan, ...,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00],
       [            nan,             nan,             nan, ...,
         1.99999996e-02,  6.65979655e-14,  4.00000000e+00]]),
            'test_y': [1, 0, 0, 0, 0, 1, ...],
            'test_names': ['10011_episode1_timeseries.csv',
                           '10026_episode1_timeseries.csv',
                           '10030_episode1_timeseries.csv',
                           '10042_episode1_timeseries.csv',...]
           }
}
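
As a minimal sketch (assuming the dictionary layout shown above and the file location named earlier), the dataset can be loaded and unpacked like this:

import pickle

# Load the dataset dictionary saved by the feature-generation step.
with open('./X_knowledge_distill/load_data/dataset_dict.pkl', 'rb') as f:
    dataset = pickle.load(f)

# Unpack the splits; NaN entries mark missing measurements.
X_train = dataset['train']['train_X']          # 2-D numpy array of structured features
y_train = dataset['train']['train_y']          # binary in-hospital mortality labels
train_names = dataset['train']['train_names']  # source time series file names

X_val, y_val = dataset['val']['val_X'], dataset['val']['val_y']
X_test, y_test = dataset['test']['test_X'], dataset['test']['test_y']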

After running the teacher models, you can generate the soft labels and save them in a y_train_soft_labels.pkl file. Then, put the files generated by deep learning teacher models with different architectures in ./X_knowledge_distill/teachers_predictions.
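
A minimal sketch of this step, assuming the soft labels are stored as a plain array of predicted probabilities (the exact format expected by the training script may differ):

import pickle
import numpy as np

# Hypothetical stand-in: in practice this would be a trained deep
# learning teacher's predicted mortality probabilities on the
# training set, e.g. teacher_model.predict(X_train).
teacher_proba = np.random.rand(1000)

# Save the soft labels where the distillation script expects them.
with open('./X_knowledge_distill/teachers_predictions/y_train_soft_labels.pkl', 'wb') as f:
    pickle.dump(teacher_proba, f)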

Example Usage

For training XGB-KD, you can use the following statement to run the program:

python -um mimic3models.in_hospital_mortality.X_knowledge_distill.main --label_fusion label_fusion --y_proba_file y_proba_file --Temperature Temperature --alpha alpha

where label_fusion specifies whether to use the soft labels from a single teacher model or from an ensemble of teacher models; y_proba_file is the file containing the teacher models' soft labels; and Temperature and alpha are hyperparameters of the knowledge distillation process.
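
The following is an illustrative sketch only, not the repository's exact formulation: in knowledge distillation, the temperature typically softens the teacher's probabilities, and alpha blends the softened soft labels with the hard labels to form the training target.

import numpy as np

def soften(p, T):
    # Temperature-scale binary probabilities: divide the logit by T,
    # pushing predictions toward 0.5 as T grows.
    logit = np.log(p / (1.0 - p))
    return 1.0 / (1.0 + np.exp(-logit / T))

def kd_targets(y_hard, teacher_proba, T=2.0, alpha=0.5):
    # Blend temperature-softened teacher soft labels with hard labels.
    return alpha * soften(teacher_proba, T) + (1.0 - alpha) * np.asarray(y_hard, dtype=float)

# Example: two ICU stays with hard labels [0, 1] and teacher
# probabilities [0.2, 0.7].
print(kd_targets([0, 1], np.array([0.2, 0.7])))  # ~[0.167, 0.802]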

For convenience, you can run the ./X_knowledge_distill/run_with_shell.py program directly to conduct the experiments.
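
For example, a concrete invocation might look like the following (the argument values here are illustrative, not defaults from the repository):

python -um mimic3models.in_hospital_mortality.X_knowledge_distill.main --label_fusion single --y_proba_file y_train_soft_labels.pkl --Temperature 2.0 --alpha 0.5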
