Skip to content

Commit

Permalink
Add save and load_from_path methods to VanillaCF
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Nov 8, 2023
1 parent 915252e commit 63db944
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 8 deletions.
9 changes: 9 additions & 0 deletions nbs/methods/00_base.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Base API"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -97,6 +104,8 @@
" \"set_compute_reg_loss_fn\",\n",
" \"apply_constraints\",\n",
" \"compute_reg_loss\",\n",
" \"save\",\n",
" \"load_from_path\",\n",
" \"before_generate_cf\",\n",
" \"generate_cf\"\n",
" ]"
Expand Down
46 changes: 42 additions & 4 deletions nbs/methods/01_vanilla.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
Expand Down Expand Up @@ -160,6 +160,14 @@
" name = \"VanillaCF\" if name is None else name\n",
" super().__init__(config, name=name)\n",
"\n",
" def save(self, path: str):\n",
" self.config.save(Path(path) / 'config.json')\n",
" \n",
" @classmethod\n",
" def load_from_path(cls, path: str):\n",
" config = VanillaCFConfig.load_from_json(Path(path) / 'config.json')\n",
" return cls(config=config)\n",
"\n",
" @auto_reshaping('x')\n",
" def generate_cf(\n",
" self,\n",
Expand Down Expand Up @@ -207,7 +215,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "79ad3ff2955b4d9ba79f2a56ac8fd390",
"model_id": "2fbb0bca45d843499a67a6645ff388d7",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -221,7 +229,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c624088b27b474ebc26beaf14431901",
"model_id": "55b5780c4ffd4f12874997a82faf806f",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -262,7 +270,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e4ad8906413b49ca90178fcceb82a7d6",
"model_id": "e48714e145a8478ab5444c89a8f9201b",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -298,6 +306,36 @@
").mean())\n",
"assert (cfs >= 0).all() and (cfs <= 1).all()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6f72b5c4f24342feb1b2ba088180b634",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"vcf.save('tmp/vanillacf/')\n",
"vcf_1 = VanillaCF.load_from_path('tmp/vanillacf/')\n",
"vcf_1.set_apply_constraints_fn(apply_constraint_fn)\n",
"partial_gen_1 = ft.partial(vcf_1.generate_cf, pred_fn=model.pred_fn)\n",
"cfs_1 = jax.vmap(partial_gen_1)(xs_test)\n",
"\n",
"assert jnp.allclose(cfs, cfs_1)"
]
}
],
"metadata": {
Expand Down
4 changes: 4 additions & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,10 @@
'relax/methods/vanilla.py'),
'relax.methods.vanilla.VanillaCF.generate_cf': ( 'methods/vanilla.html#vanillacf.generate_cf',
'relax/methods/vanilla.py'),
'relax.methods.vanilla.VanillaCF.load_from_path': ( 'methods/vanilla.html#vanillacf.load_from_path',
'relax/methods/vanilla.py'),
'relax.methods.vanilla.VanillaCF.save': ( 'methods/vanilla.html#vanillacf.save',
'relax/methods/vanilla.py'),
'relax.methods.vanilla.VanillaCFConfig': ( 'methods/vanilla.html#vanillacfconfig',
'relax/methods/vanilla.py'),
'relax.methods.vanilla._vanilla_cf': ( 'methods/vanilla.html#_vanilla_cf',
Expand Down
10 changes: 6 additions & 4 deletions relax/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
# %% auto 0
__all__ = ['CFModule', 'ParametricCFModule']

# %% ../../nbs/methods/00_base.ipynb 1
# %% ../../nbs/methods/00_base.ipynb 2
from ..import_essentials import *
from ..base import BaseConfig, BaseModule, PredFnMixedin, TrainableMixedin

# %% ../../nbs/methods/00_base.ipynb 2
# %% ../../nbs/methods/00_base.ipynb 3
def default_apply_constraints_fn(x, cf, hard, **kwargs):
return cf

def default_compute_reg_loss_fn(x, cf, **kwargs):
return 0.

# %% ../../nbs/methods/00_base.ipynb 3
# %% ../../nbs/methods/00_base.ipynb 4
class CFModule(BaseModule):
"""Base class for all counterfactual modules."""

Expand Down Expand Up @@ -71,11 +71,13 @@ def generate_cf(
"set_compute_reg_loss_fn",
"apply_constraints",
"compute_reg_loss",
"save",
"load_from_path",
"before_generate_cf",
"generate_cf"
]

# %% ../../nbs/methods/00_base.ipynb 4
# %% ../../nbs/methods/00_base.ipynb 5
class ParametricCFModule(CFModule, TrainableMixedin):
"""Base class for parametric counterfactual modules."""

Expand Down
8 changes: 8 additions & 0 deletions relax/methods/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def __init__(
name = "VanillaCF" if name is None else name
super().__init__(config, name=name)

def save(self, path: str):
self.config.save(Path(path) / 'config.json')

@classmethod
def load_from_path(cls, path: str):
config = VanillaCFConfig.load_from_json(Path(path) / 'config.json')
return cls(config=config)

@auto_reshaping('x')
def generate_cf(
self,
Expand Down

0 comments on commit 63db944

Please sign in to comment.