Skip to content

Commit

Permalink
explained types of recourse methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Praneyg committed Nov 8, 2023
1 parent 13762fc commit a2c426b
Showing 1 changed file with 107 additions and 9 deletions.
116 changes: 107 additions & 9 deletions nbs/tutorials/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
"\n",
"In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n",
"for debugging, diagnoising, interpreting your JAX models.\n"
"for debugging, diagnosing, interpreting your JAX models.\n"
]
},
{
Expand All @@ -27,11 +27,18 @@
"source": [
"## Types of Recourse Methods\n",
"\n",
"TODO: Describe the difference between non-parametric, semi-parametric, and parametric methods. \n",
"What it means conceptually (text and formula), and what it means in terms of code \n",
"(e.g., parametric methods inherites `ParametricCFModule`). \n",
"1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include VanillaCF and GrowingSpheres. These methods inherit from NonParametricCFModule.\n",
"\n",
"TODO: Include a table to describe the difference."
"2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include DiverseCF and ProtoCF. These methods inherit from SemiParametricCFModule.\n",
"\n",
"3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include CounterNet, CCHVAE, VAECF and CLUE. These methods inherit from ParametricCFModule.\n",
"\n",
"\n",
"|Method Type | Learned Parameters | Training Required | Example Methods | \n",
"|-----|:-----|:---:|:-----:|\n",
"|Non-parametric |None |No |VanillaCF, GrowingSpheres |\n",
"|Semi-parametric|Some (θ)|Modest amount |DiverseCF, ProtoCF |\n",
"|Parametric|Full generator model (φ)|Substantial amount|CounterNet, CCHVAE, VAECF, CLUE|"
]
},
{
Expand Down Expand Up @@ -59,9 +66,39 @@
"vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
"# xs is a batched data. Shape: `(N, K)`\n",
"cfs = jax.vmap(vcf_gen_fn)(xs)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example of using ReLax for parametric methods (using CCHVAE)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"from relax.methods import CCHVAE\n",
"\n",
"cchvae = CCHVAE()\n",
"# x is one data point. Shape: `(K)` or `(1, K)`\n",
"cf = vcf.generate_cf(x, pred_fn=pred_fn)\n",
"```\n",
"\n",
"TODO: Also show examples of using parametric methods (using `CCHVAE` as an example)."
"Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n",
"\n",
"```python\n",
"...\n",
"import functools as ft\n",
"\n",
"cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n",
"cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals\n",
"\n",
"```"
]
},
{
Expand All @@ -70,8 +107,69 @@
"source": [
"## Config Recourse Methods\n",
"\n",
"TODO: Refer to [this link](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html#load-dataset-with-datamodule)\n",
"on how config works in `ReLax`. It is similar on how to config recourse methods here."
"Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n",
"\n",
"For example, to configure VanillaCF:\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF \n",
"from relax.methods.vanilla import VanillaCFConfig\n",
"\n",
"config = VanillaCFConfig(\n",
" n_steps=100,\n",
" lr=0.1,\n",
" lambda_=0.1\n",
")\n",
"\n",
"vcf = VanillaCF(config)\n",
"\n",
"```\n",
"Each Config class inherits from a BaseConfig that defines common options like num_cfs. Method-specific parameters are defined on the individual Config classes.\n",
"\n",
"See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name]."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can also specify this config via a dictionary.\n",
"\n",
"```Python\n",
"from relax.methods import VanillaCF\n",
"\n",
"config = {\n",
" \"num_cfs\": 10, \n",
" \"epsilon\": 0.01,\n",
" \"lr\": 0.1 \n",
"}\n",
"\n",
"vcf = VanillaCF(config)\n",
"```\n",
"\n",
"This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our VanillaCF instance is configured to:\n",
"\n",
" * Generate 10 counterfactuals per input (num_cfs=10)\n",
" * Use a maximum perturbation of 0.01 (epsilon=0.01)\n",
" * Use a learning rate of 0.1 for optimization (lr=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Modifying at Runtime:\n",
"\n",
"The configuration can also be updated after constructing the recourse method:\n",
"\n",
"```Python\n",
"vcf = VanillaCF() \n",
"\n",
"# Later, modify config\n",
"vcf.config[\"lr\"] = 0.5\n",
"\n",
"```\n",
"This allows dynamically adjusting the configuration as needed."
]
}
],
Expand All @@ -83,5 +181,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

0 comments on commit a2c426b

Please sign in to comment.