Skip to content

Commit

Permalink
Add DataModule support and save/load functionality
Browse files Browse the repository at this point in the history
to GrowingSphere class
  • Loading branch information
BirkhoffG committed Nov 8, 2023
1 parent c79e07e commit dca3e3e
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 9 deletions.
68 changes: 62 additions & 6 deletions nbs/methods/05_sphere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@
"#| export\n",
"from __future__ import annotations\n",
"from relax.import_essentials import *\n",
"from relax.methods.base import CFModule\n",
"from relax.methods.base import CFModule, BaseConfig\n",
"from relax.utils import auto_reshaping, grad_update, validate_configs\n",
"from relax.data_utils import Feature, FeaturesList"
"from relax.data_utils import Feature, FeaturesList\n",
"from relax.data_module import DataModule"
]
},
{
Expand Down Expand Up @@ -342,7 +343,7 @@
"outputs": [],
"source": [
"#| export\n",
"class GSConfig(BaseParser):\n",
"class GSConfig(BaseConfig):\n",
" n_steps: int = 100\n",
" n_samples: int = 300\n",
" step_size: float = 0.05\n",
Expand All @@ -365,9 +366,26 @@
" self.perturb_fn = perturb_fn\n",
" super().__init__(config, name=name)\n",
"\n",
" def has_data_module(self):\n",
" return hasattr(self, 'data_module') and self.data_module is not None\n",
" \n",
" def save(self, path: str, *, save_data_module: bool = True):\n",
" self.config.save(Path(path) / 'config.json')\n",
" if self.has_data_module() and save_data_module:\n",
" self.data_module.save(Path(path) / 'data_module')\n",
" \n",
" @classmethod\n",
" def load_from_path(cls, path: str):\n",
" config = GSConfig.load_from_json(Path(path) / 'config.json')\n",
" gs = cls(config=config)\n",
" if (Path(path) / 'data_module').exists():\n",
" dm = DataModule.load_from_path(Path(path) / 'data_module')\n",
" gs.set_data_module(dm)\n",
" return gs\n",
"\n",
" def before_generate_cf(self, *args, **kwargs):\n",
" if self.perturb_fn is None:\n",
" if hasattr(self, 'data_module'):\n",
" if self.has_data_module():\n",
" feats_info, perturb_fn = features_to_infos_and_perturb_fn(self.data_module.features)\n",
" self.perturb_fn = ft.partial(\n",
" perturb_function_with_features, \n",
Expand Down Expand Up @@ -429,7 +447,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "92ff72c0cd264e5ea278c75310a93588",
"model_id": "54ebe22c890f47e5b50654e5a9461831",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -443,7 +461,9 @@
],
"source": [
"gs = GrowingSphere()\n",
"assert not gs.has_data_module()\n",
"gs.set_data_module(dm)\n",
"assert gs.has_data_module()\n",
"gs.set_apply_constraints_fn(dm.apply_constraints)\n",
"gs.before_generate_cf()\n",
"\n",
Expand All @@ -458,7 +478,43 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1291d10cbc9c4e63adc7fd1cae9bfcdd",
"model_id": "900d02459f67488293174606b72b25a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"gs.save('tmp/gs/')\n",
"gs_1 = GrowingSphere.load_from_path('tmp/gs/')\n",
"assert gs_1.has_data_module()\n",
"gs_1.set_apply_constraints_fn(dm.apply_constraints)\n",
"gs_1.before_generate_cf()\n",
"\n",
"cf_1 = gs_1.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))\n",
"assert jnp.allclose(cf, cf_1)\n",
"\n",
"shutil.rmtree('tmp/gs/')\n",
"gs.save('tmp/gs/', save_data_module=False)\n",
"gs_2 = GrowingSphere.load_from_path('tmp/gs/')\n",
"assert not gs_2.has_data_module()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3c3eac1cfaa458c9feedba3d0f98783",
"version_major": 2,
"version_minor": 0
},
Expand Down
6 changes: 6 additions & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,12 @@
'relax/methods/sphere.py'),
'relax.methods.sphere.GrowingSphere.generate_cf': ( 'methods/sphere.html#growingsphere.generate_cf',
'relax/methods/sphere.py'),
'relax.methods.sphere.GrowingSphere.has_data_module': ( 'methods/sphere.html#growingsphere.has_data_module',
'relax/methods/sphere.py'),
'relax.methods.sphere.GrowingSphere.load_from_path': ( 'methods/sphere.html#growingsphere.load_from_path',
'relax/methods/sphere.py'),
'relax.methods.sphere.GrowingSphere.save': ( 'methods/sphere.html#growingsphere.save',
'relax/methods/sphere.py'),
'relax.methods.sphere._growing_spheres': ( 'methods/sphere.html#_growing_spheres',
'relax/methods/sphere.py'),
'relax.methods.sphere.cat_perturb_fn': ( 'methods/sphere.html#cat_perturb_fn',
Expand Down
24 changes: 21 additions & 3 deletions relax/methods/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# %% ../../nbs/methods/05_sphere.ipynb 3
from __future__ import annotations
from ..import_essentials import *
from .base import CFModule
from .base import CFModule, BaseConfig
from ..utils import auto_reshaping, grad_update, validate_configs
from ..data_utils import Feature, FeaturesList
from ..data_module import DataModule

# %% auto 0
__all__ = ['hyper_sphere_coordindates', 'sample_categorical', 'default_perturb_function', 'perturb_function_with_features',
Expand Down Expand Up @@ -215,7 +216,7 @@ def step(i, state):
return candidate_cf

# %% ../../nbs/methods/05_sphere.ipynb 12
class GSConfig(BaseParser):
class GSConfig(BaseConfig):
n_steps: int = 100
n_samples: int = 300
step_size: float = 0.05
Expand All @@ -232,9 +233,26 @@ def __init__(self, config: dict | GSConfig = None, *, name: str = None, perturb_
self.perturb_fn = perturb_fn
super().__init__(config, name=name)

def has_data_module(self):
return hasattr(self, 'data_module') and self.data_module is not None

def save(self, path: str, *, save_data_module: bool = True):
self.config.save(Path(path) / 'config.json')
if self.has_data_module() and save_data_module:
self.data_module.save(Path(path) / 'data_module')

@classmethod
def load_from_path(cls, path: str):
config = GSConfig.load_from_json(Path(path) / 'config.json')
gs = cls(config=config)
if (Path(path) / 'data_module').exists():
dm = DataModule.load_from_path(Path(path) / 'data_module')
gs.set_data_module(dm)
return gs

def before_generate_cf(self, *args, **kwargs):
if self.perturb_fn is None:
if hasattr(self, 'data_module'):
if self.has_data_module():
feats_info, perturb_fn = features_to_infos_and_perturb_fn(self.data_module.features)
self.perturb_fn = ft.partial(
perturb_function_with_features,
Expand Down

0 comments on commit dca3e3e

Please sign in to comment.