diff --git a/nbs/methods/05_sphere.ipynb b/nbs/methods/05_sphere.ipynb index 8375bec..9a3b092 100644 --- a/nbs/methods/05_sphere.ipynb +++ b/nbs/methods/05_sphere.ipynb @@ -57,9 +57,10 @@ "#| export\n", "from __future__ import annotations\n", "from relax.import_essentials import *\n", - "from relax.methods.base import CFModule\n", + "from relax.methods.base import CFModule, BaseConfig\n", "from relax.utils import auto_reshaping, grad_update, validate_configs\n", - "from relax.data_utils import Feature, FeaturesList" + "from relax.data_utils import Feature, FeaturesList\n", + "from relax.data_module import DataModule" ] }, { @@ -342,7 +343,7 @@ "outputs": [], "source": [ "#| export\n", - "class GSConfig(BaseParser):\n", + "class GSConfig(BaseConfig):\n", " n_steps: int = 100\n", " n_samples: int = 300\n", " step_size: float = 0.05\n", @@ -365,9 +366,26 @@ " self.perturb_fn = perturb_fn\n", " super().__init__(config, name=name)\n", "\n", + " def has_data_module(self):\n", + " return hasattr(self, 'data_module') and self.data_module is not None\n", + " \n", + " def save(self, path: str, *, save_data_module: bool = True):\n", + " self.config.save(Path(path) / 'config.json')\n", + " if self.has_data_module() and save_data_module:\n", + " self.data_module.save(Path(path) / 'data_module')\n", + " \n", + " @classmethod\n", + " def load_from_path(cls, path: str):\n", + " config = GSConfig.load_from_json(Path(path) / 'config.json')\n", + " gs = cls(config=config)\n", + " if (Path(path) / 'data_module').exists():\n", + " dm = DataModule.load_from_path(Path(path) / 'data_module')\n", + " gs.set_data_module(dm)\n", + " return gs\n", + "\n", " def before_generate_cf(self, *args, **kwargs):\n", " if self.perturb_fn is None:\n", - " if hasattr(self, 'data_module'):\n", + " if self.has_data_module():\n", " feats_info, perturb_fn = features_to_infos_and_perturb_fn(self.data_module.features)\n", " self.perturb_fn = ft.partial(\n", " perturb_function_with_features, \n", @@ -429,7 +447,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "92ff72c0cd264e5ea278c75310a93588", + "model_id": "54ebe22c890f47e5b50654e5a9461831", "version_major": 2, "version_minor": 0 }, @@ -443,7 +461,9 @@ ], "source": [ "gs = GrowingSphere()\n", + "assert not gs.has_data_module()\n", "gs.set_data_module(dm)\n", + "assert gs.has_data_module()\n", "gs.set_apply_constraints_fn(dm.apply_constraints)\n", "gs.before_generate_cf()\n", "\n", @@ -458,7 +478,43 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1291d10cbc9c4e63adc7fd1cae9bfcdd", + "model_id": "900d02459f67488293174606b72b25a9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00