TABCF: Counterfactual Explanations for Tabular Data Using a Transformer-Based VAE

This paper was presented at the 5th ACM International Conference on AI in Finance (ICAIF '24), November 14-17, 2024, Brooklyn, NY, USA. For a detailed explanation, see the full paper: TABCF Paper.

Figure 1

Abstract:

In the field of Explainable AI (XAI), counterfactual (CF) explanations are one prominent method to interpret a black-box model by suggesting changes to the input that would alter a prediction. In real-world applications, the input is predominantly in tabular form and comprised of mixed data types and complex feature interdependencies. These unique data characteristics are difficult to model, and we empirically show that they lead to bias towards specific feature types when generating CFs. To overcome this issue, we introduce TABCF, a CF explanation method that leverages a transformer-based Variational Autoencoder (VAE) tailored for modeling tabular data. Our approach uses transformers to learn a continuous latent space and a novel Gumbel-Softmax detokenizer that enables precise categorical reconstruction while preserving end-to-end differentiability. Extensive quantitative evaluation on five financial datasets demonstrates that TABCF does not exhibit bias toward specific feature types, and outperforms existing methods in producing effective CFs that align with common CF desiderata.

Setup

Create a conda environment

conda create -n tabcf python=3.10
conda activate tabcf

Install dependencies

pip install -r requirements.txt

Download and process datasets

python download_dataset.py
python process_dataset.py

Install the local DiCE optimization framework and the local CARLA framework (run from the repository root)

cd baselines/dice/DiCE-main
pip install -e .
cd ../../CARLA
pip install -e .
cd ../..

Train VAE model

python main.py --dataname [NAME_OF_DATASET] --method vae --mode train

Usage

To generate counterfactuals with TABCF, run the following command:

python main.py --dataname [NAME_OF_DATASET] --method tabcf --mode sample --num_samples [NUMBER_OF_SAMPLES]

The same command can be used to generate counterfactuals using a competitor method, e.g.:

python main.py --dataname [NAME_OF_DATASET] --method wachter --mode sample --num_samples [NUMBER_OF_SAMPLES]

The resulting counterfactuals are saved as .csv files under the directory /counterfactual_results.

To compute all evaluation metrics from the generated .csv files, run the following command:

python main.py --dataname [NAME_OF_DATASET] --method wachter --mode evaluate --num_samples [NUMBER_OF_SAMPLES]
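The train, sample, and evaluate commands above can be chained for one dataset. The following is a minimal sketch, not part of the repository: the `run` and `pipeline` helpers and the `DRY_RUN` variable are hypothetical names introduced here, and the sketch assumes that `main.py` is invoked from the repository root and that evaluation uses the same `--num_samples` value as sampling, as in the commands above.

```shell
#!/usr/bin/env bash
# Hypothetical helper functions (not shipped with TABCF) that chain the
# README commands: train the VAE once, then sample and evaluate each method.
set -euo pipefail

run() {
  # Print the command; execute it unless DRY_RUN=1 is set.
  echo "+ $*"
  if [ "${DRY_RUN:-0}" != "1" ]; then
    "$@"
  fi
}

pipeline() {
  # Usage: pipeline DATANAME NUM_SAMPLES METHOD [METHOD...]
  local dataname="$1" num_samples="$2"
  shift 2
  run python main.py --dataname "$dataname" --method vae --mode train
  local method
  for method in "$@"; do
    run python main.py --dataname "$dataname" --method "$method" \
      --mode sample --num_samples "$num_samples"
    run python main.py --dataname "$dataname" --method "$method" \
      --mode evaluate --num_samples "$num_samples"
  done
}
```

After sourcing these functions, `DRY_RUN=1 pipeline [NAME_OF_DATASET] [NUMBER_OF_SAMPLES] tabcf wachter` prints the five commands without running them; dropping `DRY_RUN=1` executes them in order and stops at the first failure.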

General code structure

Acknowledgments

Our code uses and adapts versions of the following repositories:

  • TABSYN: VAE structure.
  • DiCE (under /baselines/dice/DiCE-main): SGD optimization for counterfactual generation.
  • CARLA (under /baselines): framework used for evaluation against baseline methods.