From af779256beb13bcbbc683572dd0c67ab421ab857 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Fri, 10 Nov 2023 02:52:15 -0500 Subject: [PATCH 1/3] Implemented set_config and ran test cases --- nbs/00_utils.ipynb | 337 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 282 insertions(+), 55 deletions(-) diff --git a/nbs/00_utils.ipynb b/nbs/00_utils.ipynb index 86e64a3..0d71ba9 100644 --- a/nbs/00_utils.ipynb +++ b/nbs/00_utils.ipynb @@ -12,8 +12,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:01.508540797Z", + "start_time": "2023-11-10T06:38:01.467755314Z" + } + }, "outputs": [], "source": [ "#| default_exp utils" @@ -21,8 +26,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:02.115318558Z", + "start_time": "2023-11-10T06:38:02.085323051Z" + } + }, "outputs": [], "source": [ "#| hide\n", @@ -36,8 +46,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:07.238516713Z", + "start_time": "2023-11-10T06:38:02.798672829Z" + } + }, "outputs": [ { "name": "stdout", @@ -68,8 +83,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:09.186196715Z", + "start_time": "2023-11-10T06:38:09.144526819Z" + } + }, "outputs": [], "source": [ "#| export\n", @@ -104,8 +124,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:12.853148753Z", + "start_time": "2023-11-10T06:38:12.739172987Z" + } + }, "outputs": [], "source": [ "class LearningConfigs(BaseParser):\n", @@ -121,8 +146,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:14.348411924Z", + "start_time": "2023-11-10T06:38:14.303681198Z" + } + }, "outputs": [], "source": [ "configs_dict = dict(lr=0.01)" @@ -137,8 +167,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:15.305738543Z", + "start_time": "2023-11-10T06:38:15.264866659Z" + } + }, "outputs": [], "source": [ "configs = validate_configs(configs_dict, LearningConfigs)\n", @@ -148,8 +183,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:15.935419169Z", + "start_time": "2023-11-10T06:38:15.904612109Z" + } + }, "outputs": [], "source": [ "#| include: false\n", @@ -160,8 +200,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:16.506981806Z", + "start_time": "2023-11-10T06:38:16.472002790Z" + } + }, "outputs": [], "source": [ "#| hide\n", @@ -198,8 +243,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:17.649017551Z", + "start_time": "2023-11-10T06:38:17.592457500Z" + } + }, "outputs": [], "source": [ "#| export\n", @@ -231,8 +281,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:19.296940315Z", + "start_time": "2023-11-10T06:38:19.208847906Z" + } + }, "outputs": [], "source": [ "pytree = {\n", @@ -256,18 +311,23 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T06:38:20.586706908Z", + "start_time": "2023-11-10T06:38:20.537278805Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "data: [array([[-0.4150979 ],\n", - " [-0.59805975],\n", - " [-0.59252158],\n", - " [-0.88781678],\n", - " [ 0.08100867]]), 1, True, 'Hello', array(['a', 'b', 'c'], dtype=' Config: \n", " return main_config" ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T07:27:54.894610863Z", + "start_time": "2023-11-10T07:27:54.764399917Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "#| export\n", + "def set_config(\n", + " *,\n", + " rng_reserve_size: int=None,\n", + " global_seed: int=None,\n", + " **kwargs\n", + ") -> None:\n", + " \n", + " if not kwargs:\n", + " #set to default if no arguments are passed\n", + " #Can change to provide error\n", + " main_config.rng_reserve_size = Config.default().rng_reserve_size\n", + " main_config.global_seed = Config.default().global_seed\n", + " \n", + " if rng_reserve_size is not None:\n", + " if not isinstance(rng_reserve_size, int):\n", + " raise TypeError(f\"`rng_reserve_size` must be an integer, but got {type(rng_reserve_size).__name__}.\")\n", + " if rng_reserve_size < 0:\n", + " raise ValueError(f\"`rng_reserve_size` must be non-negative, but got {rng_reserve_size}.\")\n", + " main_config.rng_reserve_size = rng_reserve_size\n", + " \n", + " if global_seed is not None:\n", + " if not isinstance(global_seed, int):\n", + " raise TypeError(f\"`global_seed` must be an integer, but got {type(global_seed).__name__}.\")\n", + " if global_seed < 0:\n", + " raise ValueError(f\"`global_seed` must be non-negative, but got {global_seed}.\")\n", + " main_config.global_seed = global_seed\n", + " \n", + " for k, v in kwargs.items():\n", + " #check if the config name is valid\n", + " if not hasattr(main_config, k):\n", + " raise ValueError(f\"Invalid config name: {k}.\")\n", + " \n", + " if k == \"rng_reserve_size\":\n", + " if not isinstance(v, int):\n", + " raise ValueError(f\"`rng_reserve_size` must be an integer, but got {type(rng_reserve_size).__name__}.\")\n", + " if v < 0:\n", + " raise ValueError(f\"`rng_reserve_size` must be non-negative, but got {rng_reserve_size}.\")\n", + "\n", + " elif k == \"global_seed\":\n", + " if not isinstance(v, int):\n", + " raise ValueError(f\"`global_seed` must be an integer, but got {type(global_seed).__name__}.\")\n", + " if v < 0:\n", + " raise ValueError(f\"`global_seed` must be non-negative, but got {global_seed}.\")\n", + "\n", + " setattr(main_config, k, v)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-10T07:27:55.725263458Z", + "start_time": "2023-11-10T07:27:55.691327435Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# Generic Test cases\n", + "set_config(rng_reserve_size=100)\n", + "assert get_config().rng_reserve_size == 100\n", + "set_config(global_seed=1234)\n", + "assert get_config().global_seed == 1234\n", + "set_config(rng_reserve_size=2, global_seed=234)\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", + "set_config()\n", + "assert get_config() == Config.default()\n", + "#Tests for invalid inputs\n", + "test_fail(set_config, kwargs={'rng_reserve_size': -1}, contains='must be non-negative')\n", + "test_fail(set_config, kwargs={'rng_reserve_size': 22.7}, contains='must be an integer')\n", + "test_fail(set_config, kwargs={'global_seed': -4}, contains='must be non-negative')\n", + "test_fail(set_config, kwargs={'global_seed': 3.14}, contains='must be an integer')\n", + "test_fail(set_config, kwargs={'random': 3}, contains='Invalid config name')\n" + ] } ], "metadata": { "kernelspec": { - "display_name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From 981db427a5d88fe0fcbd0dce6ee30550e3b665d0 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Mon, 13 Nov 2023 15:10:11 -0500 Subject: [PATCH 2/3] Modified set_config and added arg_check --- nbs/00_utils.ipynb | 473 +++++++++++++++++++++++---------------------- 1 file changed, 247 insertions(+), 226 deletions(-) diff --git a/nbs/00_utils.ipynb b/nbs/00_utils.ipynb index 0d71ba9..9fc4360 100644 --- a/nbs/00_utils.ipynb +++ b/nbs/00_utils.ipynb @@ -12,11 +12,11 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T06:38:01.508540797Z", - "start_time": "2023-11-10T06:38:01.467755314Z" + "end_time": "2023-11-13T19:56:19.376220024Z", + "start_time": "2023-11-13T19:56:19.375715707Z" } }, "outputs": [], @@ -26,14 +26,18 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:02.115318558Z", - "start_time": "2023-11-10T06:38:02.085323051Z" + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] } - }, - "outputs": [], + ], "source": [ "#| hide\n", "%load_ext autoreload\n", @@ -46,19 +50,22 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:07.238516713Z", - "start_time": "2023-11-10T06:38:02.798672829Z" - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using JAX backend.\n" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'matplotlib'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| export\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrelax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mimport_essentials\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnbdev\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbasics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AttrDict\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m __version__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m0.2.0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_module\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataModule, DataModuleConfig, load_data\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m(\n\u001b[1;32m 5\u001b[0m Feature, FeaturesList\n\u001b[1;32m 6\u001b[0m )\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mml_model\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MLModule, MLModuleConfig, load_ml_module\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/data_module.py:5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_data.ipynb.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# %% ../nbs/01_data.ipynb 3\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_json, validate_configs, get_config, save_pytree, load_pytree, get_config\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbase\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/utils.py:5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils.ipynb.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# %% ../nbs/00_utils.ipynb 3\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mimport_essentials\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnbdev\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbasics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AttrDict\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/import_essentials.py:4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Cell\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# https://github.com/fastai/fastai/blob/master/fastai/imports.py\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mscipy\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Union,Optional,Dict,List,Tuple,Sequence,Mapping,Callable,Iterable,Any,NamedTuple,Literal\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mio\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01moperator\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01msys\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mos\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mre\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mmimetypes\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mcsv\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mitertools\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mjson\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mshutil\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mglob\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mpickle\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mtarfile\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mcollections\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'matplotlib'" ] } ], @@ -83,13 +90,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:09.186196715Z", - "start_time": "2023-11-10T06:38:09.144526819Z" - } - }, + "execution_count": 8, + "metadata": {}, "outputs": [], "source": [ "#| export\n", @@ -124,14 +126,21 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:12.853148753Z", - "start_time": "2023-11-10T06:38:12.739172987Z" + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'BaseParser' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mLearningConfigs\u001b[39;00m(\u001b[43mBaseParser\u001b[49m):\n\u001b[1;32m 2\u001b[0m lr: \u001b[38;5;28mfloat\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'BaseParser' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "class LearningConfigs(BaseParser):\n", " lr: float" @@ -146,13 +155,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:14.348411924Z", - "start_time": "2023-11-10T06:38:14.303681198Z" - } - }, + "execution_count": 10, + "metadata": {}, "outputs": [], "source": [ "configs_dict = dict(lr=0.01)" @@ -167,14 +171,21 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:15.305738543Z", - "start_time": "2023-11-10T06:38:15.264866659Z" + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'LearningConfigs' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m configs \u001b[38;5;241m=\u001b[39m validate_configs(configs_dict, \u001b[43mLearningConfigs\u001b[49m)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(configs) \u001b[38;5;241m==\u001b[39m LearningConfigs\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m configs\u001b[38;5;241m.\u001b[39mlr \u001b[38;5;241m==\u001b[39m configs_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", + "\u001b[0;31mNameError\u001b[0m: name 'LearningConfigs' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "configs = validate_configs(configs_dict, LearningConfigs)\n", "assert type(configs) == LearningConfigs\n", @@ -183,13 +194,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:15.935419169Z", - "start_time": "2023-11-10T06:38:15.904612109Z" - } - }, + "execution_count": 12, + "metadata": {}, "outputs": [], "source": [ "#| include: false\n", @@ -200,13 +206,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:16.506981806Z", - "start_time": "2023-11-10T06:38:16.472002790Z" - } - }, + "execution_count": 13, + "metadata": {}, "outputs": [], "source": [ "#| hide\n", @@ -243,13 +244,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:17.649017551Z", - "start_time": "2023-11-10T06:38:17.592457500Z" - } - }, + "execution_count": 14, + "metadata": {}, "outputs": [], "source": [ "#| export\n", @@ -281,14 +277,21 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:19.296940315Z", - "start_time": "2023-11-10T06:38:19.208847906Z" + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m pytree \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m'\u001b[39m: {\n\u001b[1;32m 5\u001b[0m \n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124md\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124me\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m }\n\u001b[1;32m 10\u001b[0m }\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "pytree = {\n", " 'a': np.random.randn(5, 1),\n", @@ -311,24 +314,18 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:20.586706908Z", - "start_time": "2023-11-10T06:38:20.537278805Z" - } - }, + "execution_count": 16, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "data: [array([[-1.56906349],\n", - " [-0.18791164],\n", - " [-0.08231128],\n", - " [ 1.2647431 ],\n", - " [-2.12119984]]), 1, True, 'Hello', array(['a', 'b', 'c'], dtype=' 2\u001b[0m data, pytreedef \u001b[38;5;241m=\u001b[39m \u001b[43mjax\u001b[49m\u001b[38;5;241m.\u001b[39mtree_util\u001b[38;5;241m.\u001b[39mtree_flatten(pytree)\n\u001b[1;32m 3\u001b[0m pytreedef \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mtree_util\u001b[38;5;241m.\u001b[39mtree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: _is_array(x), pytree)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata: \u001b[39m\u001b[38;5;124m'\u001b[39m, data)\n", + "\u001b[0;31mNameError\u001b[0m: name 'jax' is not defined" ] } ], @@ -342,13 +339,8 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:21.344706738Z", - "start_time": "2023-11-10T06:38:21.318678110Z" - } - }, + "execution_count": 17, + "metadata": {}, "outputs": [], "source": [ "#| export\n", @@ -368,14 +360,21 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:22.059530789Z", - "start_time": "2023-11-10T06:38:22.024062961Z" + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[18], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Store a dictionary to disk\u001b[39;00m\n\u001b[1;32m 2\u001b[0m pytree \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m'\u001b[39m: {\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124md\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124me\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m }\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 11\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 12\u001b[0m save_pytree(pytree, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "# Store a dictionary to disk\n", "pytree = {\n", @@ -400,14 +399,21 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:22.900176104Z", - "start_time": "2023-11-10T06:38:22.878383648Z" + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Store a list to disk\u001b[39;00m\n\u001b[1;32m 2\u001b[0m pytree \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m----> 3\u001b[0m \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 4\u001b[0m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m])},\n\u001b[1;32m 5\u001b[0m \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 6\u001b[0m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m],\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgood\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 8\u001b[0m ]\n\u001b[1;32m 9\u001b[0m save_pytree(pytree, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 10\u001b[0m pytree_loaded \u001b[38;5;241m=\u001b[39m load_pytree(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "# Store a list to disk\n", "pytree = [\n", @@ -432,14 +438,21 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:23.913173104Z", - "start_time": "2023-11-10T06:38:23.877605477Z" + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'shutil' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| hide\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mshutil\u001b[49m\u001b[38;5;241m.\u001b[39mrmtree(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'shutil' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "#| hide\n", "shutil.rmtree('tmp')" @@ -455,13 +468,8 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:24.996603960Z", - "start_time": "2023-11-10T06:38:24.955404319Z" - } - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ "#| exporti\n", @@ -479,13 +487,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:26.011283260Z", - "start_time": "2023-11-10T06:38:25.975995184Z" - } - }, + "execution_count": 22, + "metadata": {}, "outputs": [], "source": [ "#| export\n", @@ -536,14 +539,21 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:27.141513274Z", - "start_time": "2023-11-10T06:38:27.024133904Z" + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'vmap' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_vmap\u001b[39m(x): \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m,))\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mvmap\u001b[49m(f_vmap)(jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m)))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, reshape_output\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_vmap\u001b[39m(x): \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m,))\n", + "\u001b[0;31mNameError\u001b[0m: name 'vmap' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "@auto_reshaping('x')\n", "def f_vmap(x): return x * jnp.ones((10,))\n", @@ -556,14 +566,21 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:27.600491863Z", - "start_time": "2023-11-10T06:38:27.494606555Z" + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'jnp' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[24], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m f_1(\u001b[43mjnp\u001b[49m\u001b[38;5;241m.\u001b[39mones(\u001b[38;5;241m10\u001b[39m))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m10\u001b[39m,)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m f_1(jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m10\u001b[39m)))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@jit\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_2\u001b[39m(y, x):\n", + "\u001b[0;31mNameError\u001b[0m: name 'jnp' is not defined" + ] } - }, - "outputs": [], + ], "source": [ "#| hide\n", "@auto_reshaping('x')\n", @@ -608,13 +625,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-10T06:38:28.729238117Z", - "start_time": "2023-11-10T06:38:28.679544633Z" - } - }, + "execution_count": 25, + "metadata": {}, "outputs": [], "source": [ "#| export\n", @@ -638,11 +650,11 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 26, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T06:38:29.890801775Z", - "start_time": "2023-11-10T06:38:29.846618254Z" + "end_time": "2023-11-13T19:47:57.495915554Z", + "start_time": "2023-11-13T19:47:57.450970065Z" } }, "outputs": [], @@ -663,14 +675,26 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 27, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T06:38:31.044953559Z", - "start_time": "2023-11-10T06:38:30.999553258Z" + "end_time": "2023-11-13T19:48:00.252414140Z", + "start_time": "2023-11-13T19:48:00.126479788Z" } }, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'dataclass' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[27], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| exporti\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mConfig\u001b[39;00m:\n\u001b[1;32m 4\u001b[0m rng_reserve_size: \u001b[38;5;28mint\u001b[39m\n\u001b[1;32m 5\u001b[0m global_seed: \u001b[38;5;28mint\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'dataclass' is not defined" + ] + } + ], "source": [ "#| exporti\n", "@dataclass\n", @@ -687,11 +711,11 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 28, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T06:38:32.771582697Z", - "start_time": "2023-11-10T06:38:32.695643857Z" + "end_time": "2023-11-13T19:48:01.238592167Z", + "start_time": "2023-11-13T19:48:01.225359977Z" } }, "outputs": [], @@ -703,83 +727,79 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 29, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T07:27:54.894610863Z", - "start_time": "2023-11-10T07:27:54.764399917Z" - }, - "collapsed": false, - "jupyter": { - "outputs_hidden": false + "end_time": "2023-11-13T19:48:01.732117487Z", + "start_time": "2023-11-13T19:48:01.703501116Z" } }, "outputs": [], "source": [ - "#| export\n", + "# | export\n", "def set_config(\n", - " *,\n", - " rng_reserve_size: int=None,\n", - " global_seed: int=None,\n", - " **kwargs\n", + " *,\n", + " rng_reserve_size: int = None,\n", + " global_seed: int = None,\n", + " **kwargs\n", ") -> None:\n", - " \n", - " if not kwargs:\n", - " #set to default if no arguments are passed\n", - " #Can change to provide error\n", - " main_config.rng_reserve_size = Config.default().rng_reserve_size\n", - " main_config.global_seed = Config.default().global_seed\n", - " \n", - " if rng_reserve_size is not None:\n", - " if not isinstance(rng_reserve_size, int):\n", - " raise TypeError(f\"`rng_reserve_size` must be an integer, but got {type(rng_reserve_size).__name__}.\")\n", - " if rng_reserve_size < 0:\n", - " raise ValueError(f\"`rng_reserve_size` must be non-negative, but got {rng_reserve_size}.\")\n", - " main_config.rng_reserve_size = rng_reserve_size\n", - " \n", - " if global_seed is not None:\n", - " if not isinstance(global_seed, int):\n", - " raise TypeError(f\"`global_seed` must be an integer, but got {type(global_seed).__name__}.\")\n", - " if global_seed < 0:\n", - " raise ValueError(f\"`global_seed` must be non-negative, but got {global_seed}.\")\n", - " main_config.global_seed = global_seed\n", - " \n", - " for k, v in kwargs.items():\n", - " #check if the config name is valid\n", - " if not hasattr(main_config, k):\n", - " raise ValueError(f\"Invalid config name: {k}.\")\n", - " \n", - " if k == \"rng_reserve_size\":\n", - " if not isinstance(v, int):\n", - " raise ValueError(f\"`rng_reserve_size` must be an integer, but got {type(rng_reserve_size).__name__}.\")\n", - " if v < 0:\n", - " raise ValueError(f\"`rng_reserve_size` must be non-negative, but got {rng_reserve_size}.\")\n", + " \"\"\"\n", + " set_config() sets the global configurations.\n", + " :param rng_reserve_size: set the number of random number generators to reserve.\n", + " :param global_seed: set the global seed for random number generators.\n", + " :param kwargs: A dictionary of keyword arguments, where the keys are the config keys to set and the values are the new values for those keys.\n", + " \"\"\"\n", "\n", - " elif k == \"global_seed\":\n", - " if not isinstance(v, int):\n", - " raise ValueError(f\"`global_seed` must be an integer, but got {type(global_seed).__name__}.\")\n", - " if v < 0:\n", - " raise ValueError(f\"`global_seed` must be non-negative, but got {global_seed}.\")\n", + " def arg_check(arg, arg_value, arg_min):\n", + " \"\"\"\n", + " arg_check() checks the validity of the argument and returns the argument value.\n", + " :param arg: The name of the argument.\n", + " :param arg_value: The value of the argument.\n", + " :param arg_min: The minimum value of the argument.\n", + " :return: The argument value.\n", + " \"\"\"\n", "\n", - " setattr(main_config, k, v)" + " if arg_value is not None:\n", + " if not isinstance(arg_value, int):\n", + " raise TypeError(f\"`{arg}` must be an integer, but got {type(arg_value).__name__}.\")\n", + " if arg_value < arg_min:\n", + " raise ValueError(f\"`{arg}` must be non-negative, but got {arg_value}.\")\n", + " return arg_value\n", + "\n", + " if arg_check('rng_reserve_size', rng_reserve_size, 0) is not None:\n", + " main_config.rng_reserve_size = rng_reserve_size\n", + "\n", + " if arg_check('global_seed', global_seed, 0) is not None:\n", + " main_config.global_seed = global_seed" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 30, "metadata": { "ExecuteTime": { - "end_time": "2023-11-10T07:27:55.725263458Z", - "start_time": "2023-11-10T07:27:55.691327435Z" - }, - "collapsed": false, - "jupyter": { - "outputs_hidden": false + "end_time": "2023-11-13T19:48:02.375119176Z", + "start_time": "2023-11-13T19:48:02.326026800Z" } }, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'main_config' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[30], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Generic Test cases\u001b[39;00m\n\u001b[1;32m 2\u001b[0m set_config()\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mget_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mrng_reserve_size \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m get_config()\u001b[38;5;241m.\u001b[39mglobal_seed \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m42\u001b[39m\n\u001b[1;32m 4\u001b[0m set_config(rng_reserve_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m get_config()\u001b[38;5;241m.\u001b[39mrng_reserve_size \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m100\u001b[39m\n", + "Cell \u001b[0;32mIn[28], line 3\u001b[0m, in \u001b[0;36mget_config\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_config\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Config: \n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmain_config\u001b[49m\n", + "\u001b[0;31mNameError\u001b[0m: name 'main_config' is not defined" + ] + } + ], "source": [ "# Generic Test cases\n", + "set_config()\n", + "assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42\n", "set_config(rng_reserve_size=100)\n", "assert get_config().rng_reserve_size == 100\n", "set_config(global_seed=1234)\n", @@ -787,13 +807,14 @@ "set_config(rng_reserve_size=2, global_seed=234)\n", "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", "set_config()\n", - "assert get_config() == Config.default()\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", + "set_config(lol = 80)\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", "#Tests for invalid inputs\n", "test_fail(set_config, kwargs={'rng_reserve_size': -1}, contains='must be non-negative')\n", "test_fail(set_config, kwargs={'rng_reserve_size': 22.7}, contains='must be an integer')\n", "test_fail(set_config, kwargs={'global_seed': -4}, contains='must be non-negative')\n", - "test_fail(set_config, kwargs={'global_seed': 3.14}, contains='must be an integer')\n", - "test_fail(set_config, kwargs={'random': 3}, contains='Invalid config name')\n" + "test_fail(set_config, kwargs={'global_seed': 3.14}, contains='must be an integer')" ] } ], From 4783dd67f71a652eb9d9da874115d9d9d55a3a8d Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Mon, 13 Nov 2023 15:16:13 -0500 Subject: [PATCH 3/3] Removed redundant test cases for set_config --- nbs/00_utils.ipynb | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/nbs/00_utils.ipynb b/nbs/00_utils.ipynb index 9fc4360..fb976dc 100644 --- a/nbs/00_utils.ipynb +++ b/nbs/00_utils.ipynb @@ -809,12 +809,7 @@ "set_config()\n", "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", "set_config(lol = 80)\n", - "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", - "#Tests for invalid inputs\n", - "test_fail(set_config, kwargs={'rng_reserve_size': -1}, contains='must be non-negative')\n", - "test_fail(set_config, kwargs={'rng_reserve_size': 22.7}, contains='must be an integer')\n", - "test_fail(set_config, kwargs={'global_seed': -4}, contains='must be non-negative')\n", - "test_fail(set_config, kwargs={'global_seed': 3.14}, contains='must be an integer')" + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234" ] } ],