diff --git a/nbs/00_utils.ipynb b/nbs/00_utils.ipynb index 86e64a3..fb976dc 100644 --- a/nbs/00_utils.ipynb +++ b/nbs/00_utils.ipynb @@ -12,8 +12,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:56:19.376220024Z", + "start_time": "2023-11-13T19:56:19.375715707Z" + } + }, "outputs": [], "source": [ "#| default_exp utils" @@ -21,9 +26,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "#| hide\n", "%load_ext autoreload\n", @@ -36,14 +50,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using JAX backend.\n" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'matplotlib'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| export\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrelax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mimport_essentials\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnbdev\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbasics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AttrDict\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m __version__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m0.2.0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_module\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataModule, DataModuleConfig, load_data\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m(\n\u001b[1;32m 5\u001b[0m Feature, FeaturesList\n\u001b[1;32m 6\u001b[0m )\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mml_model\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MLModule, MLModuleConfig, load_ml_module\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/data_module.py:5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_data.ipynb.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# %% ../nbs/01_data.ipynb 3\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_json, validate_configs, get_config, save_pytree, load_pytree, get_config\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbase\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/utils.py:5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils.ipynb.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# %% ../nbs/00_utils.ipynb 3\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mimport_essentials\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnbdev\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbasics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AttrDict\n", + "File \u001b[0;32m~/UniversityFiles/RAISE_LAB/jax-relax/nbs/../relax/import_essentials.py:4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Cell\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# https://github.com/fastai/fastai/blob/master/fastai/imports.py\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mscipy\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Union,Optional,Dict,List,Tuple,Sequence,Mapping,Callable,Iterable,Any,NamedTuple,Literal\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mio\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01moperator\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01msys\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mos\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mre\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mmimetypes\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mcsv\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mitertools\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mjson\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mshutil\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mglob\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mpickle\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mtarfile\u001b[39;00m\u001b[38;5;241m,\u001b[39m\u001b[38;5;21;01mcollections\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'matplotlib'" ] } ], @@ -68,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -104,9 +126,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'BaseParser' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mLearningConfigs\u001b[39;00m(\u001b[43mBaseParser\u001b[49m):\n\u001b[1;32m 2\u001b[0m lr: \u001b[38;5;28mfloat\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'BaseParser' is not defined" + ] + } + ], "source": [ "class LearningConfigs(BaseParser):\n", " lr: float" @@ -121,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -137,9 +171,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'LearningConfigs' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m configs \u001b[38;5;241m=\u001b[39m validate_configs(configs_dict, \u001b[43mLearningConfigs\u001b[49m)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(configs) \u001b[38;5;241m==\u001b[39m LearningConfigs\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m configs\u001b[38;5;241m.\u001b[39mlr \u001b[38;5;241m==\u001b[39m configs_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", + "\u001b[0;31mNameError\u001b[0m: name 'LearningConfigs' is not defined" + ] + } + ], "source": [ "configs = validate_configs(configs_dict, LearningConfigs)\n", "assert type(configs) == LearningConfigs\n", @@ -148,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -160,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -198,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -231,9 +277,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m pytree \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m'\u001b[39m: {\n\u001b[1;32m 5\u001b[0m \n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124md\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124me\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m }\n\u001b[1;32m 10\u001b[0m }\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], "source": [ "pytree = {\n", " 'a': np.random.randn(5, 1),\n", @@ -256,19 +314,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "data: [array([[-0.4150979 ],\n", - " [-0.59805975],\n", - " [-0.59252158],\n", - " [-0.88781678],\n", - " [ 0.08100867]]), 1, True, 'Hello', array(['a', 'b', 'c'], dtype=' 2\u001b[0m data, pytreedef \u001b[38;5;241m=\u001b[39m \u001b[43mjax\u001b[49m\u001b[38;5;241m.\u001b[39mtree_util\u001b[38;5;241m.\u001b[39mtree_flatten(pytree)\n\u001b[1;32m 3\u001b[0m pytreedef \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mtree_util\u001b[38;5;241m.\u001b[39mtree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: _is_array(x), pytree)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata: \u001b[39m\u001b[38;5;124m'\u001b[39m, data)\n", + "\u001b[0;31mNameError\u001b[0m: name 'jax' is not defined" ] } ], @@ -282,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -303,9 +360,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[18], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Store a dictionary to disk\u001b[39;00m\n\u001b[1;32m 2\u001b[0m pytree \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m'\u001b[39m: {\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124md\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124me\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mc\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m }\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 11\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 12\u001b[0m save_pytree(pytree, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], "source": [ "# Store a dictionary to disk\n", "pytree = {\n", @@ -330,9 +399,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Store a list to disk\u001b[39;00m\n\u001b[1;32m 2\u001b[0m pytree \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m----> 3\u001b[0m \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m1\u001b[39m),\n\u001b[1;32m 4\u001b[0m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m1\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m])},\n\u001b[1;32m 5\u001b[0m \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 6\u001b[0m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m],\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgood\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 8\u001b[0m ]\n\u001b[1;32m 9\u001b[0m save_pytree(pytree, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 10\u001b[0m pytree_loaded \u001b[38;5;241m=\u001b[39m load_pytree(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], "source": [ "# Store a list to disk\n", "pytree = [\n", @@ -357,9 +438,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'shutil' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| hide\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mshutil\u001b[49m\u001b[38;5;241m.\u001b[39mrmtree(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtmp\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'shutil' is not defined" + ] + } + ], "source": [ "#| hide\n", "shutil.rmtree('tmp')" @@ -375,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -394,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -446,9 +539,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'vmap' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_vmap\u001b[39m(x): \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m,))\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mvmap\u001b[49m(f_vmap)(jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m)))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, reshape_output\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_vmap\u001b[39m(x): \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m10\u001b[39m,))\n", + "\u001b[0;31mNameError\u001b[0m: name 'vmap' is not defined" + ] + } + ], "source": [ "@auto_reshaping('x')\n", "def f_vmap(x): return x * jnp.ones((10,))\n", @@ -461,9 +566,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'jnp' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[24], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m f_1(\u001b[43mjnp\u001b[49m\u001b[38;5;241m.\u001b[39mones(\u001b[38;5;241m10\u001b[39m))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m10\u001b[39m,)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m f_1(jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m10\u001b[39m)))\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;129m@auto_reshaping\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@jit\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mf_2\u001b[39m(y, x):\n", + "\u001b[0;31mNameError\u001b[0m: name 'jnp' is not defined" + ] + } + ], "source": [ "#| hide\n", "@auto_reshaping('x')\n", @@ -508,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -533,8 +650,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:47:57.495915554Z", + "start_time": "2023-11-13T19:47:57.450970065Z" + } + }, "outputs": [], "source": [ "#| export\n", @@ -553,9 +675,26 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 27, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:48:00.252414140Z", + "start_time": "2023-11-13T19:48:00.126479788Z" + } + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'dataclass' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[27], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| exporti\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mConfig\u001b[39;00m:\n\u001b[1;32m 4\u001b[0m rng_reserve_size: \u001b[38;5;28mint\u001b[39m\n\u001b[1;32m 5\u001b[0m global_seed: \u001b[38;5;28mint\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'dataclass' is not defined" + ] + } + ], "source": [ "#| exporti\n", "@dataclass\n", @@ -572,23 +711,127 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 28, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:48:01.238592167Z", + "start_time": "2023-11-13T19:48:01.225359977Z" + } + }, "outputs": [], "source": [ "#| export\n", "def get_config() -> Config: \n", " return main_config" ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:48:01.732117487Z", + "start_time": "2023-11-13T19:48:01.703501116Z" + } + }, + "outputs": [], + "source": [ + "# | export\n", + "def set_config(\n", + " *,\n", + " rng_reserve_size: int = None,\n", + " global_seed: int = None,\n", + " **kwargs\n", + ") -> None:\n", + " \"\"\"\n", + " set_config() sets the global configurations.\n", + " :param rng_reserve_size: set the number of random number generators to reserve.\n", + " :param global_seed: set the global seed for random number generators.\n", + " :param kwargs: A dictionary of keyword arguments, where the keys are the config keys to set and the values are the new values for those keys.\n", + " \"\"\"\n", + "\n", + " def arg_check(arg, arg_value, arg_min):\n", + " \"\"\"\n", + " arg_check() checks the validity of the argument and returns the argument value.\n", + " :param arg: The name of the argument.\n", + " :param arg_value: The value of the argument.\n", + " :param arg_min: The minimum value of the argument.\n", + " :return: The argument value.\n", + " \"\"\"\n", + "\n", + " if arg_value is not None:\n", + " if not isinstance(arg_value, int):\n", + " raise TypeError(f\"`{arg}` must be an integer, but got {type(arg_value).__name__}.\")\n", + " if arg_value < arg_min:\n", + " raise ValueError(f\"`{arg}` must be non-negative, but got {arg_value}.\")\n", + " return arg_value\n", + "\n", + " if arg_check('rng_reserve_size', rng_reserve_size, 0) is not None:\n", + " main_config.rng_reserve_size = rng_reserve_size\n", + "\n", + " if arg_check('global_seed', global_seed, 0) is not None:\n", + " main_config.global_seed = global_seed" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "ExecuteTime": { + "end_time": "2023-11-13T19:48:02.375119176Z", + "start_time": "2023-11-13T19:48:02.326026800Z" + } + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'main_config' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[30], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Generic Test cases\u001b[39;00m\n\u001b[1;32m 2\u001b[0m set_config()\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mget_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mrng_reserve_size \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m get_config()\u001b[38;5;241m.\u001b[39mglobal_seed \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m42\u001b[39m\n\u001b[1;32m 4\u001b[0m set_config(rng_reserve_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m get_config()\u001b[38;5;241m.\u001b[39mrng_reserve_size \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m100\u001b[39m\n", + "Cell \u001b[0;32mIn[28], line 3\u001b[0m, in \u001b[0;36mget_config\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_config\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Config: \n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmain_config\u001b[49m\n", + "\u001b[0;31mNameError\u001b[0m: name 'main_config' is not defined" + ] + } + ], + "source": [ + "# Generic Test cases\n", + "set_config()\n", + "assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42\n", + "set_config(rng_reserve_size=100)\n", + "assert get_config().rng_reserve_size == 100\n", + "set_config(global_seed=1234)\n", + "assert get_config().global_seed == 1234\n", + "set_config(rng_reserve_size=2, global_seed=234)\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", + "set_config()\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234\n", + "set_config(lol = 80)\n", + "assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234" + ] } ], "metadata": { "kernelspec": { - "display_name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }