diff --git a/nbs/methods/00_base.ipynb b/nbs/methods/00_base.ipynb index a4c7c9d..3a2885b 100644 --- a/nbs/methods/00_base.ipynb +++ b/nbs/methods/00_base.ipynb @@ -1,5 +1,12 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Base API" + ] + }, { "cell_type": "code", "execution_count": null, @@ -97,6 +104,8 @@ " \"set_compute_reg_loss_fn\",\n", " \"apply_constraints\",\n", " \"compute_reg_loss\",\n", + " \"save\",\n", + " \"load_from_path\",\n", " \"before_generate_cf\",\n", " \"generate_cf\"\n", " ]" diff --git a/nbs/methods/01_vanilla.ipynb b/nbs/methods/01_vanilla.ipynb index 0e4186c..90423bd 100644 --- a/nbs/methods/01_vanilla.ipynb +++ b/nbs/methods/01_vanilla.ipynb @@ -44,7 +44,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" ] } ], @@ -160,6 +160,14 @@ " name = \"VanillaCF\" if name is None else name\n", " super().__init__(config, name=name)\n", "\n", + " def save(self, path: str):\n", + " self.config.save(Path(path) / 'config.json')\n", + " \n", + " @classmethod\n", + " def load_from_path(cls, path: str):\n", + " config = VanillaCFConfig.load_from_json(Path(path) / 'config.json')\n", + " return cls(config=config)\n", + "\n", " @auto_reshaping('x')\n", " def generate_cf(\n", " self,\n", @@ -207,7 +215,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "79ad3ff2955b4d9ba79f2a56ac8fd390", + "model_id": "2fbb0bca45d843499a67a6645ff388d7", "version_major": 2, "version_minor": 0 }, @@ -221,7 +229,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5c624088b27b474ebc26beaf14431901", + "model_id": "55b5780c4ffd4f12874997a82faf806f", "version_major": 2, "version_minor": 0 }, @@ -262,7 +270,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e4ad8906413b49ca90178fcceb82a7d6", + "model_id": "e48714e145a8478ab5444c89a8f9201b", "version_major": 2, "version_minor": 0 }, @@ -298,6 +306,36 @@ ").mean())\n", "assert (cfs >= 0).all() and (cfs <= 1).all()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6f72b5c4f24342feb1b2ba088180b634", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00