This script demonstrates a deep learning workflow for medical image classification using the MedMNIST dataset. It integrates PyTorch for model training and evaluation and SHAP for explainability. Below is a concise overview:
- Dataset and Preprocessing**
- Uses the MedMNIST API to load datasets (e.g.,
pathmnist
) with preprocessing (tensor conversion and normalization). - Data is organized into train and test loaders for efficient batching.
- CNN Model**
- Defines a convolutional neural network (
Net
) with:- Convolutional Layers: Feature extraction using ReLU, BatchNorm, and MaxPooling.
- Fully Connected Layers: Classification into dataset-specific classes.
- Loads a pre-trained model (
cnnmodel.pt
) for evaluation.
- SHAP Explainability**
- Creates multiple SHAP explainers (
DeepExplainer
) using subsets of the training data. - Computes Shapley values to quantify feature (pixel) importance for predictions.
- Saves results (image indices, labels, Shapley values) into CSV files for analysis.
Key Features**
- Parameters: Configurable batch size, learning rate, and epochs.
- Output: Model predictions and interpretable Shapley values for each image.
This workflow combines accurate classification with explainability, making it useful for medical AI applications where interpretability is essential.