diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 27ff0fbf..9fc952f1 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -59,15 +59,25 @@
        }
    },
    "containerEnv": {
-        "SCRATCH": "/home/vscode/scratch"
+        "SCRATCH": "/home/vscode/scratch",
+        "SLURM_TMPDIR": "/tmp"
    },
-    // Mount a "$SCRATCH" directory in the host to ~/scratch in the container.
-    // Mount /network to use this to mount a "$SCRATCH" directory in the host to ~/scratch in the container.
    "mounts": [
        // https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount
-        "source=${localEnv:HOME}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached",
+        // Mount a "$SCRATCH" directory on the host to ~/scratch in the container.
        "source=${localEnv:SCRATCH},target=/home/vscode/scratch,type=bind,consistency=cached",
+        // Mount a directory which will contain the pdm installation cache (shared with the host machine).
+        // This now uses $SCRATCH/.cache/pdm on the host instead of $HOME/.cache/pdm.
+        "source=${localEnv:SCRATCH}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached",
+        // Mount /network to match the /network directory on the host.
+        // FIXME: This assumes that either the NETWORK_DIR environment variable is set on the host, or
+        // that the /network directory exists.
+        "source=${localEnv:NETWORK_DIR:/network},target=/network,type=bind,readonly",
+        // Mount $SLURM_TMPDIR (or /tmp/slurm_tmpdir) on the host machine to /tmp in the container.
+        // note: there's also a SLURM_TMPDIR env variable set to /tmp in the container.
+        // NOTE: this assumes that either $SLURM_TMPDIR is set on the host machine (e.g. a compute node)
+        // or that `/tmp/slurm_tmpdir` exists on the host machine.
+        "source=${localEnv:SLURM_TMPDIR:/tmp/slurm_tmpdir},target=/tmp,type=bind,consistency=cached"
    ],
    "runArgs": [
        "--gpus",
@@ -76,7 +86,10 @@
    ],
    // create the pdm cache dir on the host machine if it doesn't exist yet so the mount above
    // doesn't fail.
-    "initializeCommand": "mkdir -p ~/.cache/pdm",
+    "initializeCommand": {
+        "create pdm install cache": "mkdir -p ${SCRATCH?need the SCRATCH environment variable to be set.}/.cache/pdm", // todo: put this on $SCRATCH on the host (e.g. compute node)
+        "create fake SLURM_TMPDIR": "mkdir -p ${SLURM_TMPDIR?need the SLURM_TMPDIR environment variable to be set.}" // this is fine on compute nodes
+    },
    // NOTE: Getting some permission issues with the .cache dir if mounting .cache/pdm to
    // .cache/pdm in the container. Therefore, here I'm making a symlink from ~/.cache/pdm to
    // ~/.pdm_install_cache so the ~/.cache directory is writeable by the container.
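Note on the first hunk: `${localEnv:VAR:default}` in devcontainer.json resolves to the host's `$VAR`, falling back to the value after the second colon when the variable is unset (the `${localEnv:NETWORK_DIR:/network}` mount relies on this). A minimal host-side sketch of what the two `initializeCommand` entries do before the container starts, assuming a bash host; the `${SLURM_TMPDIR:-/tmp/slurm_tmpdir}` fallback below is an assumption that mirrors the mount's default, not the stricter `${SLURM_TMPDIR?...}` used in the diff, which errors out when the variable is unset (e.g. off a compute node):

    # Create the host-side directories up front so the bind mounts above cannot fail.
    mkdir -p "${SCRATCH?need the SCRATCH environment variable to be set.}/.cache/pdm"
    mkdir -p "${SLURM_TMPDIR:-/tmp/slurm_tmpdir}"  # assumed fallback; the diff instead requires SLURM_TMPDIR to be set
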
diff --git a/pdm.lock b/pdm.lock index 12a8b637..1eb350cd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:805e3f5f1a98de3530f8ec547141537d7d30b7d7d7ca5a3b5f9477809327ecdd" +content_hash = "sha256:2e4f9911bfebcfc3a32ec5f7c9257db49d5c51fde8087009c4b966873563c79c" [[package]] name = "absl-py" @@ -65,14 +65,14 @@ files = [ ] [[package]] -name = "ansicon" -version = "1.89.0" -summary = "Python wrapper for loading Jason Hood's ANSICON" +name = "annotated-types" +version = "0.7.0" +requires_python = ">=3.8" +summary = "Reusable constraint types to use with typing.Annotated" groups = ["default"] -marker = "platform_system == \"Windows\"" files = [ - {file = "ansicon-1.89.0-py2.py3-none-any.whl", hash = "sha256:f1def52d17f65c2c9682cf8370c03f541f410c1752d6a14029f97318e4b9dfec"}, - {file = "ansicon-1.89.0.tar.gz", hash = "sha256:e4d039def5768a47e4afec8e89e83ec3ae5a26bf00ad851f914d1240b444d2b1"}, + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] [[package]] @@ -84,36 +84,6 @@ files = [ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, ] -[[package]] -name = "anyio" -version = "4.4.0" -requires_python = ">=3.8" -summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default"] -dependencies = [ - "idna>=2.8", - "sniffio>=1.1", -] -files = [ - {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, - {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, -] - -[[package]] -name = "arrow" -version = "1.3.0" -requires_python = ">=3.8" -summary = "Better dates & times for Python" -groups = ["default"] -dependencies = [ - "python-dateutil>=2.7.0", - "types-python-dateutil>=2.8.10", -] -files = [ - {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, - {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, -] - [[package]] name = "attrs" version = "23.2.0" @@ -139,22 +109,6 @@ files = [ {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, ] -[[package]] -name = "blessed" -version = "1.20.0" -requires_python = ">=2.7" -summary = "Easy, practical library for making terminal apps, by providing an elegant, well-documented interface to Colors, Keyboard input, and screen Positioning capabilities." 
-groups = ["default"] -dependencies = [ - "jinxed>=1.1.0; platform_system == \"Windows\"", - "six>=1.9.0", - "wcwidth>=0.1.4", -] -files = [ - {file = "blessed-1.20.0-py2.py3-none-any.whl", hash = "sha256:0c542922586a265e699188e52d5f5ac5ec0dd517e5a1041d90d2bbf23f906058"}, - {file = "blessed-1.20.0.tar.gz", hash = "sha256:2cdd67f8746e048f00df47a2880f4d6acbcdb399031b604e34ba8f71d5787680"}, -] - [[package]] name = "blinker" version = "1.8.2" @@ -166,41 +120,9 @@ files = [ {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, ] -[[package]] -name = "boto3" -version = "1.34.116" -requires_python = ">=3.8" -summary = "The AWS SDK for Python" -groups = ["default"] -dependencies = [ - "botocore<1.35.0,>=1.34.116", - "jmespath<2.0.0,>=0.7.1", - "s3transfer<0.11.0,>=0.10.0", -] -files = [ - {file = "boto3-1.34.116-py3-none-any.whl", hash = "sha256:e7f5ab2d1f1b90971a2b9369760c2c6bae49dae98c084a5c3f5c78e3968ace15"}, - {file = "boto3-1.34.116.tar.gz", hash = "sha256:53cb8aeb405afa1cd2b25421e27a951aeb568026675dec020587861fac96ac87"}, -] - -[[package]] -name = "botocore" -version = "1.34.116" -requires_python = ">=3.8" -summary = "Low-level, data-driven core of boto 3." -groups = ["default"] -dependencies = [ - "jmespath<2.0.0,>=0.7.1", - "python-dateutil<3.0.0,>=2.1", - "urllib3!=2.2.0,<3,>=1.25.4; python_version >= \"3.10\"", -] -files = [ - {file = "botocore-1.34.116-py3-none-any.whl", hash = "sha256:ec4d42c816e9b2d87a2439ad277e7dda16a4a614ef6839cf66f4c1a58afa547c"}, - {file = "botocore-1.34.116.tar.gz", hash = "sha256:269cae7ba99081519a9f87d7298e238d9e68ba94eb4f8ddfa906224c34cb8b6c"}, -] - [[package]] name = "brax" -version = "0.10.4" +version = "0.10.5" summary = "A differentiable physics engine written in JAX." groups = ["default"] dependencies = [ @@ -222,6 +144,7 @@ dependencies = [ "mujoco-mjx", "numpy", "optax", + "orbax-checkpoint", "pytinyrenderer", "scipy", "tensorboardX", @@ -229,19 +152,19 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "brax-0.10.4-py3-none-any.whl", hash = "sha256:c47affa423ed0b2a987baef2553eeb84e701d52bfaa72695421d8b4ed9a826a5"}, - {file = "brax-0.10.4.tar.gz", hash = "sha256:6646bb5e280d3de2301f4908f236a14333817bdba5c7ec7faf38d4e8a627aec8"}, + {file = "brax-0.10.5-py3-none-any.whl", hash = "sha256:304fe6e5e266e42a18f197f2b7b6a9bb03a87bd97928e385c51c874b56f95866"}, + {file = "brax-0.10.5.tar.gz", hash = "sha256:e7563130c2b08bf0c9453d87602126732f20afc4624cb8574b3577fa62fdbcec"}, ] [[package]] name = "certifi" -version = "2024.2.2" +version = "2024.6.2" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." groups = ["default"] files = [ - {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, - {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, + {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, + {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] [[package]] @@ -321,7 +244,7 @@ version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." 
groups = ["default", "dev"] -marker = "platform_system == \"Windows\" or sys_platform == \"win32\"" +marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -424,20 +347,6 @@ files = [ {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, ] -[[package]] -name = "croniter" -version = "1.3.15" -requires_python = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -summary = "croniter provides iteration for datetime object with cron like format" -groups = ["default"] -dependencies = [ - "python-dateutil", -] -files = [ - {file = "croniter-1.3.15-py2.py3-none-any.whl", hash = "sha256:f17f877be1d93b9e3191151584a19d8b367b017ab0febc8c5472b9300da61c4c"}, - {file = "croniter-1.3.15.tar.gz", hash = "sha256:924a38fda88f675ec6835667e1d32ac37ff0d65509c2152729d16ff205e32a65"}, -] - [[package]] name = "cycler" version = "0.12.1" @@ -449,20 +358,6 @@ files = [ {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, ] -[[package]] -name = "dateutils" -version = "0.6.12" -summary = "Various utilities for working with date and datetime objects" -groups = ["default"] -dependencies = [ - "python-dateutil", - "pytz", -] -files = [ - {file = "dateutils-0.6.12-py2.py3-none-any.whl", hash = "sha256:f33b6ab430fa4166e7e9cb8b21ee9f6c9843c48df1a964466f52c79b2a8d53b3"}, - {file = "dateutils-0.6.12.tar.gz", hash = "sha256:03dd90bcb21541bd4eb4b013637e4f1b5f944881c46cc6e4b67a6059e370e3f1"}, -] - [[package]] name = "decorator" version = "4.4.2" @@ -474,20 +369,6 @@ files = [ {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, ] -[[package]] -name = "deepdiff" -version = "7.0.1" -requires_python = ">=3.8" -summary = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." 
-groups = ["default"] -dependencies = [ - "ordered-set<4.2.0,>=4.1.0", -] -files = [ - {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, - {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, -] - [[package]] name = "dm-env" version = "1.6" @@ -534,40 +415,36 @@ files = [ ] [[package]] -name = "editor" -version = "1.6.6" -requires_python = ">=3.8" -summary = "🖋 Open the default text editor 🖋" +name = "docstring-parser" +version = "0.16" +requires_python = ">=3.6,<4.0" +summary = "Parse Python docstrings in reST, Google and Numpydoc format" groups = ["default"] -dependencies = [ - "runs", - "xmod", -] files = [ - {file = "editor-1.6.6-py3-none-any.whl", hash = "sha256:e818e6913f26c2a81eadef503a2741d7cca7f235d20e217274a009ecd5a74abf"}, - {file = "editor-1.6.6.tar.gz", hash = "sha256:bb6989e872638cd119db9a4fce284cd8e13c553886a1c044c6b8d8a160c871f8"}, + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epath", "epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "etils[epy]", "fsspec", "importlib-resources", @@ -576,19 +453,19 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epath"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "etils[epy]", "fsspec", "importlib-resources", @@ -596,24 +473,24 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, 
] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "typing-extensions", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] @@ -637,30 +514,15 @@ files = [ {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, ] -[[package]] -name = "fastapi" -version = "0.88.0" -requires_python = ">=3.7" -summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -groups = ["default"] -dependencies = [ - "pydantic!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0,>=1.6.2", - "starlette==0.22.0", -] -files = [ - {file = "fastapi-0.88.0-py3-none-any.whl", hash = "sha256:263b718bb384422fe3d042ffc9a0c8dece5e034ab6586ff034f6b4b1667c3eee"}, - {file = "fastapi-0.88.0.tar.gz", hash = "sha256:915bf304180a0e7c5605ec81097b7d4cd8826ff87a02bb198e336fb9f3b5ff02"}, -] - [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.3" requires_python = ">=3.8" summary = "A platform independent file lock." groups = ["default", "dev"] files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, + {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, ] [[package]] @@ -855,21 +717,21 @@ files = [ [[package]] name = "grpcio" -version = "1.64.0" +version = "1.64.1" requires_python = ">=3.8" summary = "HTTP/2-based RPC framework" groups = ["default"] files = [ - {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, - {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, - {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = 
"sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, - {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, - {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, + {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"}, + {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"}, + {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"}, + {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"}, + {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"}, ] [[package]] @@ -936,17 +798,6 @@ files = [ {file = "gymnax-0.0.8.tar.gz", hash = "sha256:81defc17f52a30a84338b3daa574d7a3bb112f2656f45c783a71efe31eea68ff"}, ] -[[package]] -name = "h11" -version = "0.14.0" -requires_python = ">=3.7" -summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default"] -files = [ - {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, - {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, -] - [[package]] name = "hydra-colorlog" version = "1.2.0" @@ -1035,7 +886,7 @@ files = [ [[package]] name = "imageio-ffmpeg" -version = "0.5.0" +version = "0.5.1" requires_python = ">=3.5" summary = "FFMPEG wrapper for Python" groups = ["default"] @@ -1044,12 +895,12 @@ dependencies = [ "setuptools", ] files = [ - {file = "imageio-ffmpeg-0.5.0.tar.gz", hash = "sha256:75c9c45079510cfeb4849a17fcd3edd4f14062ea6b69c5b62695fb2075295c87"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:e9aba9cdd01164a50a4cfb1b825fc8769151a0d3b5b5a7d5d50ff9fcda7eee9c"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:ba55f392ee5db9eb0a6d7699e0060a2edcaa7dbc740ca29671bdc8dbb763ca3b"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9c813be7d6a24236bb68aeab249ea67f5a7fdf7d86988855578247694c42e94a"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-win32.whl", hash = "sha256:c4a3b32fc38d4a26c15582bf12246ddae060932889da5c9da487cc675740039b"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-win_amd64.whl", hash = 
"sha256:8135f4d146094b62b31721ca53fe943f4134e3578e22015468e3df595217c24b"}, + {file = "imageio-ffmpeg-0.5.1.tar.gz", hash = "sha256:0ed7a9b31f560b0c9d929c5291cd430edeb9bed3ce9a497480e536dd4326484c"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:1460e84712b9d06910c1f7bb524096b0341d4b7844cea6c20e099d0a24e795b1"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:5289f75c7f755b499653f3209fea4efd1430cba0e39831c381aad2d458f7a316"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7fa9132a291d5eb28c44553550deb40cbdab831f2a614e55360301a6582eb205"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-win32.whl", hash = "sha256:89efe2c79979d8174ba8476deb7f74d74c331caee3fb2b65ba2883bec0737625"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-win_amd64.whl", hash = "sha256:1521e79e253bedbdd36a547e0cbd94a025ba0b558e17f08fea687d805a0e4698"}, ] [[package]] @@ -1074,22 +925,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "inquirer" -version = "3.2.4" -requires_python = ">=3.8.1" -summary = "Collection of common interactive command line user interfaces, based on Inquirer.js" -groups = ["default"] -dependencies = [ - "blessed>=1.19.0", - "editor>=1.6.0", - "readchar>=3.0.6", -] -files = [ - {file = "inquirer-3.2.4-py3-none-any.whl", hash = "sha256:273a4e4a4345ac1afdb17408d40fc8dccf3485db68203357919468561035a763"}, - {file = "inquirer-3.2.4.tar.gz", hash = "sha256:33b09efc1b742b9d687b540296a8b6a3f773399673321fcc2ab0eb4c109bf9b5"}, -] - [[package]] name = "intel-openmp" version = "2021.4.0" @@ -1188,7 +1023,7 @@ files = [ [[package]] name = "jaxlib" -version = "0.4.28" +version = "0.4.28+cuda12.cudnn89" requires_python = ">=3.9" summary = "XLA library for JAX" groups = ["default"] @@ -1238,31 +1073,6 @@ files = [ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] -[[package]] -name = "jinxed" -version = "1.2.1" -summary = "Jinxed Terminal Library" -groups = ["default"] -marker = "platform_system == \"Windows\"" -dependencies = [ - "ansicon; platform_system == \"Windows\"", -] -files = [ - {file = "jinxed-1.2.1-py2.py3-none-any.whl", hash = "sha256:37422659c4925969c66148c5e64979f553386a4226b9484d910d3094ced37d30"}, - {file = "jinxed-1.2.1.tar.gz", hash = "sha256:30c3f861b73279fea1ed928cfd4dfb1f273e16cd62c8a32acfac362da0f78f3f"}, -] - -[[package]] -name = "jmespath" -version = "1.0.1" -requires_python = ">=3.7" -summary = "JSON Matching Expressions" -groups = ["default"] -files = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] - [[package]] name = "kiwisolver" version = "1.4.5" @@ -1305,73 +1115,25 @@ files = [ [[package]] name = "lightning" -version = "1.9.0" -requires_python = ">=3.7" -summary = "Use Lightning Apps to build everything from production-ready, multi-cloud ML systems to simple research demos." +version = "2.3.0" +requires_python = ">=3.8" +summary = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." 
groups = ["default"] dependencies = [ - "Jinja2<5.0", - "PyYAML<8.0", "PyYAML<8.0,>=5.4", - "arrow<3.0,>=1.2.0", - "beautifulsoup4<6.0,>=4.8.0", - "click<10.0", - "croniter<1.4.0,>=1.3.0", - "dateutils<2.0", - "deepdiff<8.0,>=5.7.0", - "fastapi<0.89.0", - "fsspec<2024.0,>=2022.5.0", - "fsspec[http]<2024.0,>2021.06.0", - "inquirer<5.0,>=2.10.0", - "lightning-cloud<2.0,>=0.5.12", - "lightning-utilities<2.0,>=0.4.2", + "fsspec[http]<2026.0,>=2022.5.0", + "lightning-utilities<2.0,>=0.8.0", "numpy<3.0,>=1.17.2", - "packaging", - "packaging<23.0,>=17.1", - "psutil<7.0", - "pydantic<3.0", - "requests<4.0", - "rich<15.0", - "starlette<2.0", - "starsessions<2.0,>=1.2.1", - "torch<3.0,>=1.10.0", - "torchmetrics<2.0,>=0.7.0", + "packaging<25.0,>=20.0", + "pytorch-lightning", + "torch<4.0,>=2.0.0", + "torchmetrics<3.0,>=0.7.0", "tqdm<6.0,>=4.57.0", - "traitlets<7.0,>=5.3.0", - "typing-extensions<6.0,>=4.0.0", - "urllib3<3.0", - "uvicorn<2.0", - "websocket-client<3.0", - "websockets<12.0", + "typing-extensions<6.0,>=4.4.0", ] files = [ - {file = "lightning-1.9.0-py3-none-any.whl", hash = "sha256:27db661b37c3581fb3467016cbfcaf39aee77e52a12c32344b0c8da1d1e9e311"}, - {file = "lightning-1.9.0.tar.gz", hash = "sha256:d002270e2cd6bdf239d6605f8ec7f6f79bd2ec4eb5e7758b38ca36c57d4d1fdf"}, -] - -[[package]] -name = "lightning-cloud" -version = "0.5.69" -requires_python = ">=3.7.0" -summary = "Lightning Cloud" -groups = ["default"] -dependencies = [ - "boto3", - "click", - "fastapi", - "protobuf", - "pyjwt", - "python-multipart", - "requests", - "rich", - "six", - "urllib3", - "uvicorn", - "websocket-client", -] -files = [ - {file = "lightning_cloud-0.5.69-py3-none-any.whl", hash = "sha256:8e26b534c3970ea939d37c284e9de5d0c880339a49d18c9b9181c0e093f95fd1"}, - {file = "lightning_cloud-0.5.69.tar.gz", hash = "sha256:0baeef05c06a6d89c482abea1826cc3e3bec48901d10cc2749f39b344e6f1dc3"}, + {file = "lightning-2.3.0-py3-none-any.whl", hash = "sha256:ed66c2053be1295c8452b996b719badf5a26a0652607c121103dfdd5d2dccfae"}, + {file = "lightning-2.3.0.tar.gz", hash = "sha256:4bb4d6e3650d2d5f544ad60853a22efc4e164aa71b9596d13f0454b29df05130"}, ] [[package]] @@ -1583,7 +1345,7 @@ files = [ [[package]] name = "mujoco" -version = "3.1.5" +version = "3.1.6" requires_python = ">=3.8" summary = "MuJoCo Physics Simulator" groups = ["default"] @@ -1595,17 +1357,17 @@ dependencies = [ "pyopengl", ] files = [ - {file = "mujoco-3.1.5-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:0a78079b07e63d04f2985684ccd3a9937badba4cf51432662ff818b092442dbc"}, - {file = "mujoco-3.1.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4145c6277a1e71000a54c0bfef337c885a57452c5f0aa7cddf4b41932b639f41"}, - {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20bb70bfee28e026efc71f6872871c689fa2eaecc54d019ae1a21362453619cd"}, - {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f93bf770c3c963efe03c27b34ca59015e27ae70cdd4272a8312e583f52dbf40"}, - {file = "mujoco-3.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:8b139b1950ad52924e8666561414dd8f4f3f69f89364f1d0304371839be9264e"}, - {file = "mujoco-3.1.5.tar.gz", hash = "sha256:9099ba6001341cc9e38b7b94b8ef7a67346c7638fa3e94f520743a357891f296"}, + {file = "mujoco-3.1.6-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:dc0ab85bcda35b2d87df91b7a13152e970a7108d87ef811f28dc32b2dbfb6754"}, + {file = "mujoco-3.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:37a41c5558bd8823da8b2822d2dd941a4c57ee11bf56be5e77ee157c0e5552a1"}, + {file = "mujoco-3.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e839b3a3758a0010673ec954a1728ce076be923f868d37739040b029489544"}, + {file = "mujoco-3.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07d3b8c270ba9ae5c87e8e37061277ccc0d46767959b68f2a5c5c1e065213021"}, + {file = "mujoco-3.1.6-cp312-cp312-win_amd64.whl", hash = "sha256:49a6b3f88446686aebd345b12d1ec38259701215c7db355725499be9c0e53ef0"}, + {file = "mujoco-3.1.6.tar.gz", hash = "sha256:7cf8887526f071e7411dc02ce1cd665e39b4b6083fdff49fe1348a82d2314651"}, ] [[package]] name = "mujoco-mjx" -version = "3.1.5" +version = "3.1.6" requires_python = ">=3.8" summary = "MuJoCo XLA (MJX)" groups = ["default"] @@ -1614,13 +1376,13 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.1.5.dev0", + "mujoco>=3.1.6.dev0", "scipy", "trimesh", ] files = [ - {file = "mujoco_mjx-3.1.5-py3-none-any.whl", hash = "sha256:4fc54e10c0cb811fd97584222a00ce9fa433f79d7ce46a8d7b22c8a054c35238"}, - {file = "mujoco_mjx-3.1.5.tar.gz", hash = "sha256:ee6b409d694a0a34ab93803089e3c1297ed91ae6a9461661cd1d80a9f0565880"}, + {file = "mujoco_mjx-3.1.6-py3-none-any.whl", hash = "sha256:0392975c610a8cbd8ad71ba7d7f524fccdb28bacf041998fea34370dd83d46a3"}, + {file = "mujoco_mjx-3.1.6.tar.gz", hash = "sha256:22f70227c3b7ee94b9e89a706c7a9387ba6f34219ecce5a2dadffac225c6637a"}, ] [[package]] @@ -1721,6 +1483,7 @@ requires_python = ">=3" summary = "CUDA nvcc" groups = ["default"] files = [ + {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:dcea4f7fa223ac32ad40503499cec117e5543a1b34bb91a886049821bfa75304"}, {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8347e2458c99eb9db3c392035c1781798f2593d495554106cf45502eeabc1a10"}, {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:616cd3280a05657d1e40d4985058bbd4c88384b92c88a7c30228643abe7465f2"}, ] @@ -1832,6 +1595,7 @@ requires_python = ">=3" summary = "Nvidia JIT LTO Library" groups = ["default", "dev"] files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -1897,14 +1661,14 @@ files = [ [[package]] name = "orbax-checkpoint" -version = "0.5.15" +version = "0.5.17" requires_python = ">=3.9" summary = "Orbax Checkpoint" groups = ["default"] dependencies = [ "absl-py", "etils[epath,epy]", - "jax>=0.4.9", + "jax>=0.4.25", "jaxlib", "msgpack", "nest-asyncio", @@ -1915,30 +1679,19 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "orbax_checkpoint-0.5.15-py3-none-any.whl", hash = "sha256:658dd89bc925cecc584d89eaa19af9a7e16e3371377907eb713fbd59b85262e4"}, - {file = "orbax_checkpoint-0.5.15.tar.gz", hash = "sha256:15195e8d1b381b56f23a62a25599a3644f5d08655fa64f60bb1b938b8ffe7ef3"}, -] - -[[package]] -name = "ordered-set" -version = "4.1.0" -requires_python = ">=3.7" -summary = "An OrderedSet is a custom MutableSet that remembers its order, so that every" -groups = ["default"] -files = [ - {file = "ordered-set-4.1.0.tar.gz", hash = 
"sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, - {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, + {file = "orbax_checkpoint-0.5.17-py3-none-any.whl", hash = "sha256:212a29bd43c368ba4b62b0c12d565b56bf50ce6df904aedb6ab379a8a3206fd9"}, + {file = "orbax_checkpoint-0.5.17.tar.gz", hash = "sha256:705574a0b41d935b17312fe36988e72da1f33d57d97732937a77a84c02793f94"}, ] [[package]] name = "packaging" -version = "22.0" -requires_python = ">=3.7" +version = "24.1" +requires_python = ">=3.8" summary = "Core utilities for Python packages" groups = ["default", "dev"] files = [ - {file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"}, - {file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -2052,18 +1805,19 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" -requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +version = "6.0.0" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" summary = "Cross-platform lib for process and system monitoring in Python." groups = ["default"] files = [ - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = 
"sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [[package]] @@ -2078,16 +1832,60 @@ files = [ [[package]] name = "pydantic" -version = "1.10.15" -requires_python = ">=3.7" -summary = "Data validation and settings management using python type hints" +version = "2.7.4" +requires_python = ">=3.8" +summary = "Data validation using Python type hints" groups = ["default"] dependencies = [ - "typing-extensions>=4.2.0", + "annotated-types>=0.4.0", + "pydantic-core==2.18.4", + "typing-extensions>=4.6.1", ] files = [ - {file = "pydantic-1.10.15-py3-none-any.whl", hash = "sha256:28e552a060ba2740d0d2aabe35162652c1459a0b9069fe0db7f4ee0e18e74d58"}, - {file = "pydantic-1.10.15.tar.gz", hash = "sha256:ca832e124eda231a60a041da4f013e3ff24949d94a01154b137fc2f2a43c3ffb"}, + {file = "pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"}, + {file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"}, +] + +[[package]] +name = "pydantic-core" +version = "2.18.4" +requires_python = ">=3.8" +summary = "Core functionality for Pydantic validation and serialization" +groups = ["default"] +dependencies = [ + "typing-extensions!=4.7.0,>=4.6.0", +] +files = [ + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, + {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, + {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, + {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", 
hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, + {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, ] [[package]] @@ -2127,17 +1925,6 @@ files = [ {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, ] -[[package]] -name = "pyjwt" -version = "2.8.0" -requires_python = ">=3.7" -summary = "JSON Web Token implementation in Python" -groups = ["default"] -files = [ - {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, - {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, -] - [[package]] name = "pyopengl" version = "3.1.7" @@ -2172,7 +1959,7 @@ files = [ [[package]] name = "pytest" -version = 
"8.2.1" +version = "8.2.2" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" groups = ["default", "dev"] @@ -2183,8 +1970,8 @@ dependencies = [ "pluggy<2.0,>=1.5", ] files = [ - {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, - {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [[package]] @@ -2275,6 +2062,21 @@ files = [ {file = "pytest_skip_slow-0.0.5-py3-none-any.whl", hash = "sha256:e2f6401d6ed0db3be1402622a7b24f7df14f61ebd26feda808a0d45433d4d474"}, ] +[[package]] +name = "pytest-testmon" +version = "2.1.1" +requires_python = ">=3.8" +summary = "selects tests affected by changed files and methods" +groups = ["dev"] +dependencies = [ + "coverage<8,>=6", + "pytest<9,>=5", +] +files = [ + {file = "pytest-testmon-2.1.1.tar.gz", hash = "sha256:8ebe2c3de42d99306ee54cd4536fed0fc48346a954420da904b18e8d59b5da98"}, + {file = "pytest_testmon-2.1.1-py3-none-any.whl", hash = "sha256:8271ca47bc8c80760c4fc7fd7895ea786b111bbb31f13eeea879a6fd11fe2226"}, +] + [[package]] name = "pytest-timeout" version = "2.3.1" @@ -2318,17 +2120,6 @@ files = [ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] -[[package]] -name = "python-multipart" -version = "0.0.9" -requires_python = ">=3.8" -summary = "A streaming multipart parser for Python" -groups = ["default"] -files = [ - {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, - {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, -] - [[package]] name = "pytinyrenderer" version = "0.0.14" @@ -2341,6 +2132,28 @@ files = [ {file = "pytinyrenderer-0.0.14.tar.gz", hash = "sha256:5fedb4798509cb911a03a3bc9e8de8d4d5aa36b1de52eb878efef104b95a3d15"}, ] +[[package]] +name = "pytorch-lightning" +version = "2.3.0" +requires_python = ">=3.8" +summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." 
+groups = ["default"] +dependencies = [ + "PyYAML>=5.4", + "fsspec[http]>=2022.5.0", + "lightning-utilities>=0.8.0", + "numpy>=1.17.2", + "packaging>=20.0", + "torch>=2.0.0", + "torchmetrics>=0.7.0", + "tqdm>=4.57.0", + "typing-extensions>=4.4.0", +] +files = [ + {file = "pytorch-lightning-2.3.0.tar.gz", hash = "sha256:89caf90e3543b314508493f26e0eca8d5e10e43e3d9e6c143acd8ddceb584ce2"}, + {file = "pytorch_lightning-2.3.0-py3-none-any.whl", hash = "sha256:b8eec361f4342ca628d0d8e6985511c9515435e4db62c5e982bb1c53a5a5140a"}, +] + [[package]] name = "pytorch2jax" version = "0.1.0" @@ -2384,17 +2197,6 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] -[[package]] -name = "readchar" -version = "4.1.0" -requires_python = ">=3.8" -summary = "Library to easily read single chars and key strokes" -groups = ["default"] -files = [ - {file = "readchar-4.1.0-py3-none-any.whl", hash = "sha256:d163680656b34f263fb5074023db44b999c68ff31ab394445ebfd1a2a41fe9a2"}, - {file = "readchar-4.1.0.tar.gz", hash = "sha256:6f44d1b5f0fd93bd93236eac7da39609f15df647ab9cea39f5bc7478b3344b99"}, -] - [[package]] name = "requests" version = "2.32.3" @@ -2445,56 +2247,28 @@ files = [ [[package]] name = "ruff" -version = "0.4.6" +version = "0.4.9" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev"] files = [ - {file = "ruff-0.4.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ef995583a038cd4a7edf1422c9e19118e2511b8ba0b015861b4abd26ec5367c5"}, - {file = "ruff-0.4.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:602ebd7ad909eab6e7da65d3c091547781bb06f5f826974a53dbe563d357e53c"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f9ced5cbb7510fd7525448eeb204e0a22cabb6e99a3cb160272262817d49786"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04a80acfc862e0e1630c8b738e70dcca03f350bad9e106968a8108379e12b31f"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47700ecb004dfa3fd4dcdddf7322d4e632de3c06cd05329d69c45c0280e618"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1ff930d6e05f444090a0139e4e13e1e2e1f02bd51bb4547734823c760c621e79"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13410aabd3b5776f9c5699f42b37a3a348d65498c4310589bc6e5c548dc8a2f"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cf5cc02d3ae52dfb0c8a946eb7a1d6ffe4d91846ffc8ce388baa8f627e3bd50"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea3424793c29906407e3cf417f28fc33f689dacbbadfb52b7e9a809dd535dcef"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1fa8561489fadf483ffbb091ea94b9c39a00ed63efacd426aae2f197a45e67fc"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4d5b914818d8047270308fe3e85d9d7f4a31ec86c6475c9f418fbd1624d198e0"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:4f02284335c766678778475e7698b7ab83abaf2f9ff0554a07b6f28df3b5c259"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3a6a0a4f4b5f54fff7c860010ab3dd81425445e37d35701a965c0248819dde7a"}, - {file = "ruff-0.4.6-py3-none-win32.whl", hash = "sha256:9018bf59b3aa8ad4fba2b1dc0299a6e4e60a4c3bc62bbeaea222679865453062"}, - {file = 
"ruff-0.4.6-py3-none-win_amd64.whl", hash = "sha256:a769ae07ac74ff1a019d6bd529426427c3e30d75bdf1e08bb3d46ac8f417326a"}, - {file = "ruff-0.4.6-py3-none-win_arm64.whl", hash = "sha256:735a16407a1a8f58e4c5b913ad6102722e80b562dd17acb88887685ff6f20cf6"}, - {file = "ruff-0.4.6.tar.gz", hash = "sha256:a797a87da50603f71e6d0765282098245aca6e3b94b7c17473115167d8dfb0b7"}, -] - -[[package]] -name = "runs" -version = "1.2.2" -requires_python = ">=3.8" -summary = "🏃 Run a block of text as a subprocess 🏃" -groups = ["default"] -dependencies = [ - "xmod", -] -files = [ - {file = "runs-1.2.2-py3-none-any.whl", hash = "sha256:0980dcbc25aba1505f307ac4f0e9e92cbd0be2a15a1e983ee86c24c87b839dfd"}, - {file = "runs-1.2.2.tar.gz", hash = "sha256:9dc1815e2895cfb3a48317b173b9f1eac9ba5549b36a847b5cc60c3bf82ecef1"}, -] - -[[package]] -name = "s3transfer" -version = "0.10.1" -requires_python = ">= 3.8" -summary = "An Amazon S3 Transfer Manager" -groups = ["default"] -dependencies = [ - "botocore<2.0a.0,>=1.33.2", -] -files = [ - {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, - {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, + {file = "ruff-0.4.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b262ed08d036ebe162123170b35703aaf9daffecb698cd367a8d585157732991"}, + {file = "ruff-0.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:98ec2775fd2d856dc405635e5ee4ff177920f2141b8e2d9eb5bd6efd50e80317"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4555056049d46d8a381f746680db1c46e67ac3b00d714606304077682832998e"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e91175fbe48f8a2174c9aad70438fe9cb0a5732c4159b2a10a3565fea2d94cde"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e8e7b95673f22e0efd3571fb5b0cf71a5eaaa3cc8a776584f3b2cc878e46bff"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2d45ddc6d82e1190ea737341326ecbc9a61447ba331b0a8962869fcada758505"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78de3fdb95c4af084087628132336772b1c5044f6e710739d440fc0bccf4d321"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06b60f91bfa5514bb689b500a25ba48e897d18fea14dce14b48a0c40d1635893"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88bffe9c6a454bf8529f9ab9091c99490578a593cc9f9822b7fc065ee0712a06"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:673bddb893f21ab47a8334c8e0ea7fd6598ecc8e698da75bcd12a7b9d0a3206e"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8c1aff58c31948cc66d0b22951aa19edb5af0a3af40c936340cd32a8b1ab7438"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:784d3ec9bd6493c3b720a0b76f741e6c2d7d44f6b2be87f5eef1ae8cc1d54c84"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:732dd550bfa5d85af8c3c6cbc47ba5b67c6aed8a89e2f011b908fc88f87649db"}, + {file = "ruff-0.4.9-py3-none-win32.whl", hash = "sha256:8064590fd1a50dcf4909c268b0e7c2498253273309ad3d97e4a752bb9df4f521"}, + {file = "ruff-0.4.9-py3-none-win_amd64.whl", hash = "sha256:e0a22c4157e53d006530c902107c7f550b9233e9706313ab57b892d7197d8e52"}, + {file = "ruff-0.4.9-py3-none-win_arm64.whl", hash = 
"sha256:5d5460f789ccf4efd43f265a58538a2c24dbce15dbf560676e430375f20a8198"}, + {file = "ruff-0.4.9.tar.gz", hash = "sha256:f1cb0828ac9533ba0135d148d214e284711ede33640465e706772645483427e3"}, ] [[package]] @@ -2534,7 +2308,7 @@ files = [ [[package]] name = "sentry-sdk" -version = "2.3.1" +version = "2.6.0" requires_python = ">=3.6" summary = "Python client for Sentry (https://sentry.io)" groups = ["default"] @@ -2543,8 +2317,8 @@ dependencies = [ "urllib3>=1.26.11", ] files = [ - {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"}, - {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"}, + {file = "sentry_sdk-2.6.0-py2.py3-none-any.whl", hash = "sha256:422b91cb49378b97e7e8d0e8d5a1069df23689d45262b86f54988a7db264e874"}, + {file = "sentry_sdk-2.6.0.tar.gz", hash = "sha256:65cc07e9c6995c5e316109f138570b32da3bd7ff8d0d0ee4aaf2628c3dd8127d"}, ] [[package]] @@ -2587,13 +2361,28 @@ files = [ [[package]] name = "setuptools" -version = "70.0.0" +version = "70.1.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"}, + {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"}, +] + +[[package]] +name = "simple-parsing" +version = "0.1.5" +requires_python = ">=3.7" +summary = "A small utility for simplifying and cleaning up argument parsing scripts." +groups = ["default"] +dependencies = [ + "docstring-parser~=0.15", + "typing-extensions>=4.5.0", +] +files = [ + {file = "simple_parsing-0.1.5-py3-none-any.whl", hash = "sha256:46f35ed7002f9bb25dca3a49eac491cc78d2140e4adcbe156225ae643c2874ea"}, + {file = "simple_parsing-0.1.5.tar.gz", hash = "sha256:d26ac15be5173cf28174e171a68153c11e462ad2cb3c23d3ad8634b00719d1fc"}, ] [[package]] @@ -2618,17 +2407,6 @@ files = [ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, ] -[[package]] -name = "sniffio" -version = "1.3.1" -requires_python = ">=3.7" -summary = "Sniff out which async library your code is running under" -groups = ["default"] -files = [ - {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, - {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, -] - [[package]] name = "soupsieve" version = "2.5" @@ -2640,35 +2418,6 @@ files = [ {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, ] -[[package]] -name = "starlette" -version = "0.22.0" -requires_python = ">=3.7" -summary = "The little ASGI library that shines." 
-groups = ["default"] -dependencies = [ - "anyio<5,>=3.4.0", -] -files = [ - {file = "starlette-0.22.0-py3-none-any.whl", hash = "sha256:b5eda991ad5f0ee5d8ce4c4540202a573bb6691ecd0c712262d0bc85cf8f2c50"}, - {file = "starlette-0.22.0.tar.gz", hash = "sha256:b092cbc365bea34dd6840b42861bdabb2f507f8671e642e8272d2442e08ea4ff"}, -] - -[[package]] -name = "starsessions" -version = "1.3.0" -requires_python = ">=3.6.2,<4.0.0" -summary = "Pluggable session support for Starlette." -groups = ["default"] -dependencies = [ - "itsdangerous<3.0.0,>=2.0.1", - "starlette<1,>=0", -] -files = [ - {file = "starsessions-1.3.0-py3-none-any.whl", hash = "sha256:c0758f2a1a2438ec7ba88b232e82008f2261a75584f01179c787b3636fae6040"}, - {file = "starsessions-1.3.0.tar.gz", hash = "sha256:8d3b509d4e6d235655f7dd495fcf0afc1bd86da84de3a8d434e6f82137ebcde8"}, -] - [[package]] name = "submitit" version = "1.5.1" @@ -2716,18 +2465,18 @@ name = "tensor-regression" version = "0.0.2.post3.dev0" requires_python = "<4.0,>=3.11" git = "https://www.github.com/lebrice/tensor_regression" -revision = "2b15f9312fe8891f0c617b5cbce1ba757d514a0a" +revision = "7b3a07ae924eaeacde6ebeade2efcd7f8ce526d5" summary = "A small wrapper around pytest_regressions for Tensors" groups = ["default", "dev"] dependencies = [ "numpy<2.0.0,>=1.26.4", "pytest-regressions<3.0.0,>=2.5.0", - "torch<3.0.0,>=2.3.1", + "torch<3.0.0,>=2.0.0", ] [[package]] name = "tensorboard" -version = "2.16.2" +version = "2.17.0" requires_python = ">=3.9" summary = "TensorBoard lets you watch Tensors Flow" groups = ["default"] @@ -2736,14 +2485,14 @@ dependencies = [ "grpcio>=1.48.2", "markdown>=2.6.8", "numpy>=1.12.0", - "protobuf!=4.24.0,>=3.19.6", + "protobuf!=4.24.0,<5.0.0,>=3.19.6", "setuptools>=41.0.0", "six>1.9", "tensorboard-data-server<0.8.0,>=0.7.0", "werkzeug>=1.0.1", ] files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, + {file = "tensorboard-2.17.0-py3-none-any.whl", hash = "sha256:859a499a9b1fb68a058858964486627100b71fcb21646861c61d31846a6478fb"}, ] [[package]] @@ -2775,20 +2524,20 @@ files = [ [[package]] name = "tensorstore" -version = "0.1.60" +version = "0.1.62" requires_python = ">=3.9" summary = "Read and write large, multi-dimensional arrays" groups = ["default"] dependencies = [ "ml-dtypes>=0.3.1", - "numpy>=1.16.0", + "numpy>=1.22.0", ] files = [ - {file = "tensorstore-0.1.60-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:65677e21304fcf272557f195c597704f4ccf55b75314e68ece17bb1784cb59f7"}, - {file = "tensorstore-0.1.60-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:725d1f70c17838815704805d2853c636bb2d680424e81f91677a7defea68373b"}, - {file = "tensorstore-0.1.60-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c477a0e6948326c414ed1bcdab2949e975f0b4e7e449cce39e0fec14b273e1b2"}, - {file = "tensorstore-0.1.60-cp312-cp312-win_amd64.whl", hash = "sha256:32cba3cf0ae6dd03d504162b8ea387f140050e279cf23e7eced68d3c845693da"}, - {file = "tensorstore-0.1.60.tar.gz", hash = "sha256:88da8f1978982101b8dbb144fd29ee362e4e8c97fc595c4992d555f80ce62a79"}, + {file = "tensorstore-0.1.62-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:616cd5d55ff6e2979d6f4578ad76c1d12dfdb361d43edfd90728b558857f33b9"}, + {file = "tensorstore-0.1.62-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6824d3f49fc2c75c7a5da1b77e840014565660852bff2544c38ccafbe63ed5a7"}, + {file = "tensorstore-0.1.62-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:78e081786b293bf3a4acf2ae54d62d25c82a21ad9503c0986ba6fcf03c6c9fbb"}, + {file = "tensorstore-0.1.62-cp312-cp312-win_amd64.whl", hash = "sha256:446e46dfd149ab516fdf47598684fc472b206afcd5e365e0e3e55c7f280cc288"}, + {file = "tensorstore-0.1.62.tar.gz", hash = "sha256:d0e88dae5d983e500700f9f1636eaa742f9e673b4a230d7126f1380e021f373f"}, ] [[package]] @@ -2838,16 +2587,16 @@ files = [ [[package]] name = "torch-jax-interop" version = "0.0.4.post7.dev0" -requires_python = "<4.0,>=3.11" +requires_python = ">=3.11,<4.0" git = "https://www.github.com/lebrice/torch_jax_interop" -revision = "7f0c72fe19d8bd4bd957f20dd90d77acd8178bd4" +revision = "3a4261f949d739cfe684280203137114b169e70e" summary = "Utility to convert Tensors from Jax to Torch and vice-versa" groups = ["default"] dependencies = [ - "flax<1.0.0,>=0.8.4", - "jax[cuda12]<1.0.0,>=0.4.28", - "pytorch2jax<1.0.0,>=0.1.0", - "torch<3.0.0,>=2.3.0", + "flax<0.9.0,>=0.8.4", + "jax[cuda12]<0.5.0,>=0.4.28", + "pytorch2jax<0.2.0,>=0.1.0", + "torch<3.0.0,>=2.0.0", ] [[package]] @@ -2899,20 +2648,9 @@ files = [ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, ] -[[package]] -name = "traitlets" -version = "5.14.3" -requires_python = ">=3.8" -summary = "Traitlets Python configuration system" -groups = ["default"] -files = [ - {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, - {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, -] - [[package]] name = "trimesh" -version = "4.4.0" +version = "4.4.1" requires_python = ">=3.7" summary = "Import, export, process, analyze and view triangular meshes." groups = ["default"] @@ -2920,30 +2658,19 @@ dependencies = [ "numpy>=1.20", ] files = [ - {file = "trimesh-4.4.0-py3-none-any.whl", hash = "sha256:e192458da391c1b0a850df0b713c59234a6582e641569b004b588ada337b05c0"}, - {file = "trimesh-4.4.0.tar.gz", hash = "sha256:daf6e56715de2e93dd905e926f9bb10d23dc4157f9724aa7caab5d0e28963e56"}, -] - -[[package]] -name = "types-python-dateutil" -version = "2.9.0.20240316" -requires_python = ">=3.8" -summary = "Typing stubs for python-dateutil" -groups = ["default"] -files = [ - {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"}, - {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, + {file = "trimesh-4.4.1-py3-none-any.whl", hash = "sha256:dc00e293f4efed692b57e95ff9dafd5b62f2126439fb377d2a6b048d7d086933"}, + {file = "trimesh-4.4.1.tar.gz", hash = "sha256:767fe3c866ba74e6d9a9d216c34ecc1cfe2fbf3f129a6c11d59871705a591aba"}, ] [[package]] name = "typing-extensions" -version = "4.12.0" +version = "4.12.2" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" groups = ["default", "dev"] files = [ - {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, - {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = 
"sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -2959,33 +2686,18 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" requires_python = ">=3.8" summary = "HTTP library with thread-safe connection pooling, file post, and more." groups = ["default"] files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, -] - -[[package]] -name = "uvicorn" -version = "0.30.0" -requires_python = ">=3.8" -summary = "The lightning-fast ASGI server." -groups = ["default"] -dependencies = [ - "click>=7.0", - "h11>=0.8", -] -files = [ - {file = "uvicorn-0.30.0-py3-none-any.whl", hash = "sha256:78fa0b5f56abb8562024a59041caeb555c86e48d0efdd23c3fe7de7a4075bdab"}, - {file = "uvicorn-0.30.0.tar.gz", hash = "sha256:f678dec4fa3a39706bbf49b9ec5fc40049d42418716cea52b53f07828a60aa37"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [[package]] name = "wandb" -version = "0.17.0" +version = "0.17.2" requires_python = ">=3.7" summary = "A CLI and library for interacting with the Weights & Biases API." groups = ["default"] @@ -2994,8 +2706,8 @@ dependencies = [ "docker-pycreds>=0.4.0", "gitpython!=3.1.29,>=1.0.0", "platformdirs", - "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"", - "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"", + "protobuf!=4.21.0,<6,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"", + "protobuf!=4.21.0,<6,>=3.19.0; sys_platform != \"linux\"", "psutil>=5.0.0", "pyyaml", "requests<3,>=2.0.0", @@ -3004,60 +2716,13 @@ dependencies = [ "setuptools", ] files = [ - {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"}, - {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"}, - {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"}, - {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"}, - {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"}, - {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"}, - {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"}, -] - -[[package]] -name = "wcwidth" -version = "0.2.13" -summary = "Measures the displayed width of unicode strings in a terminal" -groups = ["default"] -files = [ - {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, - {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, -] - -[[package]] -name = "websocket-client" -version = "1.8.0" -requires_python = ">=3.8" -summary = 
"WebSocket client for Python with low level API options" -groups = ["default"] -files = [ - {file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"}, - {file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"}, -] - -[[package]] -name = "websockets" -version = "11.0.3" -requires_python = ">=3.7" -summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -groups = ["default"] -files = [ - {file = "websockets-11.0.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e58f2c36cc52d41f2659e4c0cbf7353e28c8c9e63e30d8c6d3494dc9fdedcf"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de36fe9c02995c7e6ae6efe2e205816f5f00c22fd1fbf343d4d18c3d5ceac2f5"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0ac56b661e60edd453585f4bd68eb6a29ae25b5184fd5ba51e97652580458998"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e052b8467dd07d4943936009f46ae5ce7b908ddcac3fda581656b1b19c083d9b"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:42cc5452a54a8e46a032521d7365da775823e21bfba2895fb7b77633cce031bb"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e6316827e3e79b7b8e7d8e3b08f4e331af91a48e794d5d8b099928b6f0b85f20"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8531fdcad636d82c517b26a448dcfe62f720e1922b33c81ce695d0edb91eb931"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c114e8da9b475739dde229fd3bc6b05a6537a88a578358bc8eb29b4030fac9c9"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e063b1865974611313a3849d43f2c3f5368093691349cf3c7c8f8f75ad7cb280"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:92b2065d642bf8c0a82d59e59053dd2fdde64d4ed44efe4870fa816c1232647b"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0ee68fe502f9031f19d495dae2c268830df2760c0524cbac5d759921ba8c8e82"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcacf2c7a6c3a84e720d1bb2b543c675bf6c40e460300b628bab1b1efc7c034c"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b67c6f5e5a401fc56394f191f00f9b3811fe843ee93f4a70df3c389d1adf857d"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d5023a4b6a5b183dc838808087033ec5df77580485fc533e7dab2567851b0a4"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ed058398f55163a79bb9f06a90ef9ccc063b204bb346c4de78efc5d15abfe602"}, - {file = "websockets-11.0.3-py3-none-any.whl", hash = "sha256:6681ba9e7f8f3b19440921e99efbb40fc89f26cd71bf539e45d8c8a25c976dc6"}, - {file = "websockets-11.0.3.tar.gz", hash = "sha256:88fc51d9a26b10fc331be344f1781224a375b78488fc343620184e95a4b27016"}, + {file = 
"wandb-0.17.2-py3-none-any.whl", hash = "sha256:4bd351be28cea87730365856cfaa72f72ceb787accc21bad359dde5aa9c4356d"}, + {file = "wandb-0.17.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:638353a2d702caedd304a5f1e526ef93a291c984c109fcb444262a57aeaacec9"}, + {file = "wandb-0.17.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:824e33ca77af87f87a9cf1122acba164da5bf713adc9d67332bc686028921ec9"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:032ca5939008643349af178a8b66b8047a1eefcb870c4c4a86e22acafde6470f"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9558bab47a0c8ac4f22cfa2d43f91d1bc1f75d4255629286db674fe49fcd30e5"}, + {file = "wandb-0.17.2-py3-none-win32.whl", hash = "sha256:4bc176e3c81be216dc889fcd098341eb17a14b04e080d4343ce3f0b1740abfc1"}, + {file = "wandb-0.17.2-py3-none-win_amd64.whl", hash = "sha256:62cd707f38b5711971729dae80343b8c35f6003901e690166cc6d526187a9785"}, ] [[package]] @@ -3074,17 +2739,6 @@ files = [ {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] -[[package]] -name = "xmod" -version = "1.8.1" -requires_python = ">=3.8" -summary = "🌱 Turn any object into a module 🌱" -groups = ["default"] -files = [ - {file = "xmod-1.8.1-py3-none-any.whl", hash = "sha256:a24e9458a4853489042522bdca9e50ee2eac5ab75c809a91150a8a7f40670d48"}, - {file = "xmod-1.8.1.tar.gz", hash = "sha256:38c76486b9d672c546d57d8035df0beb7f4a9b088bc3fb2de5431ae821444377"}, -] - [[package]] name = "yarl" version = "1.9.4" @@ -3117,11 +2771,11 @@ files = [ [[package]] name = "zipp" -version = "3.19.1" +version = "3.19.2" requires_python = ">=3.8" summary = "Backport of pathlib-compatible object wrapper for zip files" groups = ["default"] files = [ - {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, - {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, + {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, + {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, ] diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py index 5df5ee95..8d9a69ef 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/bases/algorithm_test.py @@ -1,10 +1,8 @@ from __future__ import annotations -import contextlib import copy import inspect import operator -import random import sys import typing from collections.abc import Callable, Sequence @@ -12,7 +10,6 @@ from pathlib import Path from typing import Any, ClassVar, Generic, Literal, TypeVar -import numpy as np import pytest import torch from lightning import Callback, LightningDataModule, LightningModule, Trainer @@ -28,7 +25,7 @@ from project.datamodules.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.experiment import ( instantiate_datamodule, instantiate_network, @@ -36,7 +33,9 @@ from project.main import main from project.utils.hydra_utils import resolve_dictconfig from project.utils.testutils import ( + default_marks_for_config_combinations, default_marks_for_config_name, + fork_rng, 
get_all_datamodule_names_params, get_all_network_names, get_type_for_config_name, @@ -146,8 +145,11 @@ def get_testing_callbacks(self) -> list[TestingCallback]: AllParamsShouldHaveGradients(), ] + # todo: make this much faster to run! + # Also, some combinations don't work, e.g. `imagenet + fcnet`, there are nans in the network. + @pytest.mark.slow - @pytest.mark.timeout(10) # todo: make this much faster to run! + # @pytest.mark.timeout(10) def test_overfit_training_batch( self, algorithm: AlgorithmType, @@ -264,15 +266,6 @@ def test_experiment_reproducible_given_seed( overrides_1 = all_overrides + [f"++trainer.default_root_dir={tmp_path_1}"] overrides_2 = all_overrides + [f"++trainer.default_root_dir={tmp_path_2}"] - @contextlib.contextmanager - def fork_rng(): - with torch.random.fork_rng(): - random_state = random.getstate() - np_random_state = np.random.get_state() - yield - np.random.set_state(np_random_state) - random.setstate(random_state) - with ( fork_rng(), setup_hydra_for_tests_and_compose(overrides_1, tmp_path=tmp_path_1) as config_1, @@ -297,11 +290,9 @@ def fork_rng(): def datamodule_name(self, request: pytest.FixtureRequest): """Fixture that gives the name of a datamodule to use.""" datamodule_name = request.param - if datamodule_name in default_marks_for_config_name: for marker in default_marks_for_config_name[datamodule_name]: request.applymarker(marker) - self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL) return datamodule_name @@ -325,7 +316,11 @@ def network_name(self, request: pytest.FixtureRequest): @pytest.fixture(scope="class") def _hydra_config( - self, datamodule_name: str, network_name: str, tmp_path_factory: pytest.TempPathFactory + self, + datamodule_name: str, + network_name: str, + tmp_path_factory: pytest.TempPathFactory, + request: pytest.FixtureRequest, ) -> DictConfig: """Fixture that gives the Hydra configuration for an experiment that uses this algorithm, datamodule, and network. @@ -338,6 +333,16 @@ def _hydra_config( # todo: Get the name of the algorithm from the hydra config? algorithm_name = self.algorithm_name + + combination = set([datamodule_name, network_name, algorithm_name]) + for configs, marks in default_marks_for_config_combinations.items(): + configs = set(configs) + if combination >= configs: + logger.debug(f"Applying markers because {combination} contains {configs}") + # There is a combination of potentially unsupported configs here. 
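The `fork_rng` context manager deleted above and the new `default_marks_for_config_combinations` mapping are both now imported from `project.utils.testutils`, which this patch doesn't show. A sketch of what that module presumably provides — the body of `fork_rng` is grounded in the code removed above, while the shape of the mapping is an assumption inferred from the consuming loop here:

```python
# Hypothetical sketch of the helpers now living in project/utils/testutils.py.
import contextlib
import random

import numpy as np
import pytest
import torch


@contextlib.contextmanager
def fork_rng():
    """Restore the torch, numpy and stdlib `random` RNG states on exit."""
    with torch.random.fork_rng():
        random_state = random.getstate()
        np_random_state = np.random.get_state()
        yield
        np.random.set_state(np_random_state)
        random.setstate(random_state)


# Marks to apply when a test uses all of the named configs at once
# (hypothetical entry, mirroring the `imagenet + fcnet` note above).
default_marks_for_config_combinations: dict[tuple[str, ...], list[pytest.MarkDecorator]] = {
    ("imagenet", "fcnet"): [pytest.mark.xfail(reason="NaNs in the network")],
}
```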
+ for mark in marks: + request.applymarker(mark) + with setup_hydra_for_tests_and_compose( all_overrides=[ f"algorithm={algorithm_name}", @@ -654,7 +659,8 @@ def on_train_batch_end( parameters_with_nans = [ name for name, param in pl_module.named_parameters() if param.isnan().any() ] - assert not parameters_with_nans + if parameters_with_nans: + raise RuntimeError(f"Parameters {parameters_with_nans} contain NaNs!") parameters_with_nans_in_grad = [ name diff --git a/project/configs/__init__.py b/project/configs/__init__.py index bb734082..13e6a8d7 100644 --- a/project/configs/__init__.py +++ b/project/configs/__init__.py @@ -2,11 +2,9 @@ from hydra.core.config_store import ConfigStore +from ..utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR from .config import Config from .datamodule import ( - REPO_ROOTDIR, - SLURM_JOB_ID, - SLURM_TMPDIR, datamodule_store, ) from .network import network_store diff --git a/project/configs/config.py b/project/configs/config.py index 5ff41808..6a39e912 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -3,6 +3,12 @@ from logging import getLogger as get_logger from typing import Any, Literal +from omegaconf import OmegaConf + +from project.utils.env_vars import get_constant + +OmegaConf.register_new_resolver("constant", get_constant) + logger = get_logger(__name__) LogLevel = Literal["debug", "info", "warning", "error", "critical"] diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index 03154b43..b8bd32a5 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -1,155 +1,49 @@ -import os -from collections.abc import Callable -from dataclasses import dataclass, field from logging import getLogger as get_logger from pathlib import Path -import torch -from hydra_zen import hydrated_dataclass, instantiate, store -from torch import Tensor +from hydra_zen import store -from project.datamodules import ( - CIFAR10DataModule, - FashionMNISTDataModule, - ImageNet32DataModule, - MNISTDataModule, - VisionDataModule, -) -from project.datamodules.image_classification.cifar10 import cifar10_train_transforms -from project.datamodules.image_classification.imagenet32 import imagenet32_train_transforms -from project.datamodules.image_classification.inaturalist import ( - INaturalistDataModule, - TargetType, - Version, -) -from project.datamodules.image_classification.mnist import mnist_train_transforms +from project.utils.env_vars import NETWORK_DIR -FILE = Path(__file__) -REPO_ROOTDIR = FILE.parent -for level in range(5): - if "README.md" in list(p.name for p in REPO_ROOTDIR.iterdir()): - break - REPO_ROOTDIR = REPO_ROOTDIR.parent - - -SLURM_TMPDIR: Path | None = ( - Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None -) -SLURM_JOB_ID: int | None = ( - int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None -) - -logger = get_logger(__name__) - - -TORCHVISION_DIR: Path | None = None - -_torchvision_dir = Path("/network/datasets/torchvision") -if _torchvision_dir.exists() and _torchvision_dir.is_dir(): - TORCHVISION_DIR = _torchvision_dir - - -if not SLURM_TMPDIR and SLURM_JOB_ID is not None: - # This can happens when running the integrated VSCode terminal with `mila code`! 
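The environment-variable probing removed from this config module (here and just below) moves to `project.utils.env_vars`, judging by the new imports of `REPO_ROOTDIR`, `SLURM_JOB_ID`, `SLURM_TMPDIR`, `NETWORK_DIR` and `get_constant` elsewhere in this patch. A sketch of that module, reconstructed from the deleted code; the exact contents are an assumption:

```python
# Assumed sketch of project/utils/env_vars.py.
import os
from pathlib import Path

SLURM_JOB_ID: int | None = (
    int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
)
SLURM_TMPDIR: Path | None = (
    Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
)
if SLURM_TMPDIR is None and SLURM_JOB_ID is not None:
    # Can happen when running in the integrated VSCode terminal with `mila code`.
    _slurm_tmpdir = Path(f"/Tmp/slurm.{SLURM_JOB_ID}.0")
    if _slurm_tmpdir.exists():
        SLURM_TMPDIR = _slurm_tmpdir

SCRATCH: Path | None = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
NETWORK_DIR: Path | None = (
    Path(os.environ["NETWORK_DIR"]) if "NETWORK_DIR" in os.environ else None
)
# REPO_ROOTDIR, DATA_DIR and NUM_WORKERS would be defined along the same lines
# as the code removed below.


def get_constant(name: str):
    """Resolver for the `${constant:NAME}` interpolations registered in config.py."""
    return globals()[name]
```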
-    _slurm_tmpdir = Path(f"/Tmp/slurm.{SLURM_JOB_ID}.0")
-    if _slurm_tmpdir.exists():
-        SLURM_TMPDIR = _slurm_tmpdir
-SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
-DATA_DIR = Path(os.environ.get("DATA_DIR", (SLURM_TMPDIR or SCRATCH or REPO_ROOTDIR) / "data"))
-
-NUM_WORKERS = int(
-    os.environ.get(
-        "SLURM_CPUS_PER_TASK",
-        os.environ.get(
-            "SLURM_CPUS_ON_NODE",
-            len(os.sched_getaffinity(0))
-            if hasattr(os, "sched_getaffinity")
-            else torch.multiprocessing.cpu_count(),
-        ),
-    )
-)
 logger = get_logger(__name__)
-
-Transform = Callable[[Tensor], Tensor]
-
-
-@dataclass
-class DataModuleConfig: ...
+torchvision_dir: Path | None = None
+"""Network directory with torchvision datasets."""
+if (
+    NETWORK_DIR
+    and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists()
+    and _torchvision_dir.is_dir()
+):
+    torchvision_dir = _torchvision_dir
+# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
+# (for example, ImageNet32DataModule has a new constructor argument which can't be set atm in the
+# config).
 datamodule_store = store(group="datamodule")
-@hydrated_dataclass(target=VisionDataModule, populate_full_signature=True)
-class VisionDataModuleConfig(DataModuleConfig):
-    data_dir: str | None = str(TORCHVISION_DIR or DATA_DIR)
-    val_split: int | float = 0.1  # NOTE: reduced from default of 0.2
-    num_workers: int = NUM_WORKERS
-    normalize: bool = True  # NOTE: Set to True by default instead of False
-    batch_size: int = 32
-    seed: int = 42
-    shuffle: bool = True  # NOTE: Set to True by default instead of False.
-    pin_memory: bool = True  # NOTE: Set to True by default instead of False.
-    drop_last: bool = False
-
-    __call__ = instantiate
-
-
-# todo: look into this to avoid having to make dataclasses with no fields just to call a function..
-from hydra_zen import store, zen  # noqa
-
-
-# FIXME: This is dumb!
-@hydrated_dataclass(target=mnist_train_transforms)
-class MNISTTrainTransforms: ...
-
-
-@hydrated_dataclass(target=MNISTDataModule, populate_full_signature=True)
-class MNISTDataModuleConfig(VisionDataModuleConfig):
-    normalize: bool = True
-    batch_size: int = 128
-    train_transforms: MNISTTrainTransforms = field(default_factory=MNISTTrainTransforms)
-
-
-@hydrated_dataclass(target=FashionMNISTDataModule, populate_full_signature=True)
-class FashionMNISTDataModuleConfig(MNISTDataModuleConfig): ...
-
-
-@hydrated_dataclass(target=cifar10_train_transforms)
-class Cifar10TrainTransforms: ...
-
-
-@hydrated_dataclass(target=CIFAR10DataModule, populate_full_signature=True)
-class CIFAR10DataModuleConfig(VisionDataModuleConfig):
-    train_transforms: Cifar10TrainTransforms = field(default_factory=Cifar10TrainTransforms)
-    # Overwriting this one:
-    batch_size: int = 128
-
-
-@hydrated_dataclass(target=imagenet32_train_transforms)
-class ImageNet32TrainTransforms: ...
-
-
-@hydrated_dataclass(target=ImageNet32DataModule, populate_full_signature=True)
-class ImageNet32DataModuleConfig(VisionDataModuleConfig):
-    data_dir: Path = ((SCRATCH / "data") if SCRATCH else DATA_DIR) / "imagenet32"
-
-    val_split: int | float = -1
-    num_images_per_val_class: int = 50  # Slightly different.
- normalize: bool = True - train_transforms: ImageNet32TrainTransforms = field(default_factory=ImageNet32TrainTransforms) +# @hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) +# class VisionDataModuleConfig: +# data_dir: str | None = str(torchvision_dir or DATA_DIR) +# val_split: int | float = 0.1 # NOTE: reduced from default of 0.2 +# num_workers: int = NUM_WORKERS +# normalize: bool = True # NOTE: Set to True by default instead of False +# batch_size: int = 32 +# seed: int = 42 +# shuffle: bool = True # NOTE: Set to True by default instead of False. +# pin_memory: bool = True # NOTE: Set to True by default instead of False. +# drop_last: bool = False +# __call__ = instantiate -@hydrated_dataclass(target=INaturalistDataModule, populate_full_signature=True) -class INaturalistDataModuleConfig(VisionDataModuleConfig): - data_dir: Path | None = None - version: Version = "2021_train" - target_type: TargetType | list[TargetType] = "full" +# datamodule_store(VisionDataModuleConfig, name="vision") -datamodule_store(CIFAR10DataModuleConfig, name="cifar10") -datamodule_store(MNISTDataModuleConfig, name="mnist") -datamodule_store(FashionMNISTDataModuleConfig, name="fashion_mnist") -datamodule_store(ImageNet32DataModuleConfig, name="imagenet32") -datamodule_store(INaturalistDataModuleConfig, name="inaturalist") +# inaturalist_config = hydra_zen.builds( +# INaturalistDataModule, +# builds_bases=(VisionDataModuleConfig,), +# populate_full_signature=True, +# dataclass_name=f"{INaturalistDataModule.__name__}Config", +# ) +# datamodule_store(inaturalist_config, name="inaturalist") diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml new file mode 100644 index 00000000..e8d3fb78 --- /dev/null +++ b/project/configs/datamodule/cifar10.yaml @@ -0,0 +1,6 @@ +defaults: +- vision +_target_: project.datamodules.CIFAR10DataModule +batch_size: 128 +train_transforms: + _target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms diff --git a/project/configs/datamodule/fashion_mnist.yaml b/project/configs/datamodule/fashion_mnist.yaml new file mode 100644 index 00000000..f2038d2d --- /dev/null +++ b/project/configs/datamodule/fashion_mnist.yaml @@ -0,0 +1,3 @@ +defaults: +- mnist +_target_: project.datamodules.FashionMNISTDataModule diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml new file mode 100644 index 00000000..3e82c78b --- /dev/null +++ b/project/configs/datamodule/imagenet.yaml @@ -0,0 +1,4 @@ +defaults: + - vision +_target_: project.datamodules.ImageNetDataModule +# todo: add good configuration options here. 
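These yaml files replace the hydra-zen structured configs deleted (and kept commented out) above: each one pulls in the shared `vision` defaults and names the class to build in `_target_`. A minimal sketch of how such a config gets composed and instantiated, assuming a top-level config that consumes the `datamodule` group:

```python
# Composing and instantiating one of the yaml datamodule configs (a sketch;
# the top-level config name is an assumption).
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="project/configs", version_base=None):
    cfg = compose(config_name="config", overrides=["datamodule=cifar10"])

# `defaults: [vision]` merges vision.yaml first, then the file's own keys
# (e.g. `batch_size: 128` in cifar10.yaml) override them; instantiate() then
# calls the `_target_` class with the merged keyword arguments.
datamodule = instantiate(cfg.datamodule)
```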
diff --git a/project/configs/datamodule/imagenet32.yaml b/project/configs/datamodule/imagenet32.yaml new file mode 100644 index 00000000..208119be --- /dev/null +++ b/project/configs/datamodule/imagenet32.yaml @@ -0,0 +1,9 @@ +defaults: +- vision +_target_: project.datamodules.ImageNet32DataModule +data_dir: ${constant:SCRATCH} +val_split: -1 +num_images_per_val_class: 50 +normalize: True +train_transforms: + _target_: project.datamodules.image_classification.imagenet32.imagenet32_train_transforms diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml new file mode 100644 index 00000000..5670be6b --- /dev/null +++ b/project/configs/datamodule/inaturalist.yaml @@ -0,0 +1,5 @@ +defaults: +- vision +_target_: project.datamodules.INaturalistDataModule +version: "2021_train" +target_type: "full" diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml new file mode 100644 index 00000000..c9a16639 --- /dev/null +++ b/project/configs/datamodule/mnist.yaml @@ -0,0 +1,7 @@ +defaults: +- vision +_target_: project.datamodules.MNISTDataModule +normalize: True +batch_size: 128 +train_transforms: + _target_: project.datamodules.image_classification.mnist.mnist_train_transforms diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml new file mode 100644 index 00000000..e3f10b79 --- /dev/null +++ b/project/configs/datamodule/vision.yaml @@ -0,0 +1,9 @@ +_target_: project.datamodules.VisionDataModule +data_dir: ${constant:DATA_DIR} +num_workers: ${constant:NUM_WORKERS} +val_split: 0.1 # NOTE: reduced from default of 0.2 +normalize: True # NOTE: Set to True by default instead of False +shuffle: True # NOTE: Set to True by default instead of False. +pin_memory: True # NOTE: Set to True by default instead of False. 
+seed: 42 +batch_size: 64 diff --git a/project/configs/trainer/callbacks/no_checkpoints.yaml b/project/configs/trainer/callbacks/no_checkpoints.yaml index 49dbccf0..ee01c442 100644 --- a/project/configs/trainer/callbacks/no_checkpoints.yaml +++ b/project/configs/trainer/callbacks/no_checkpoints.yaml @@ -1,6 +1,11 @@ -model_summary: - _target_: lightning.pytorch.callbacks.RichModelSummary - max_depth: 1 +defaults: + - default -rich_progress_bar: - _target_: lightning.pytorch.callbacks.RichProgressBar +model_checkpoint: null + +# model_summary: +# _target_: lightning.pytorch.callbacks.RichModelSummary +# max_depth: 1 + +# rich_progress_bar: +# _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/project/conftest.py b/project/conftest.py index 0a69188d..5cf5a914 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -21,11 +21,10 @@ from torch.utils.data import DataLoader from project.configs.config import Config -from project.configs.datamodule import DATA_DIR from project.datamodules.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.experiment import ( instantiate_algorithm, instantiate_datamodule, @@ -36,7 +35,10 @@ setup_logging, ) from project.utils.hydra_utils import resolve_dictconfig -from project.utils.testutils import default_marks_for_config_name +from project.utils.testutils import ( + default_marks_for_config_combinations, + default_marks_for_config_name, +) from project.utils.types import is_sequence_of from project.utils.types.protocols import DataModule @@ -135,11 +137,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[Function]): items.pop(index) -@pytest.fixture(scope="session") -def data_dir() -> Path: - return DATA_DIR - - @pytest.fixture(autouse=True) def seed(request: pytest.FixtureRequest): """Fixture that seeds everything for reproducibility and yields the random seed used.""" @@ -362,9 +359,19 @@ def experiment_dictconfig( datamodule_name: str | None, network_name: str | None, overrides: tuple[str, ...], + request: pytest.FixtureRequest, ) -> Generator[DictConfig, None, None]: tmp_path = tmp_path_factory.mktemp("experiment_testing") + combination = set([datamodule_name, network_name, algorithm_name]) + for configs, marks in default_marks_for_config_combinations.items(): + configs = set(configs) + if combination >= configs: + logger.debug(f"Applying markers because {combination} contains {configs}") + # There is a combination of potentially unsupported configs here. + for mark in marks: + request.applymarker(mark) + default_overrides = [ # NOTE: if we were to run the test in a slurm job, this wouldn't make sense. 
"seed=42", diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 5e298407..40eb5928 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -1,17 +1,21 @@ from .image_classification import ImageClassificationDataModule from .image_classification.cifar10 import CIFAR10DataModule, cifar10_normalization from .image_classification.fashion_mnist import FashionMNISTDataModule +from .image_classification.imagenet import ImageNetDataModule from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization +from .image_classification.inaturalist import INaturalistDataModule from .image_classification.mnist import MNISTDataModule -from .vision.base import VisionDataModule +from .vision import VisionDataModule __all__ = [ "cifar10_normalization", "CIFAR10DataModule", "FashionMNISTDataModule", + "INaturalistDataModule", "ImageClassificationDataModule", "imagenet32_normalization", "ImageNet32DataModule", + "ImageNetDataModule", "MNISTDataModule", "VisionDataModule", ] diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index 432a4f7f..e0e9666c 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -3,7 +3,10 @@ import matplotlib.pyplot as plt import pytest -from tensor_regression.fixture import TensorRegressionFixture, get_test_source_and_temp_file_paths +from tensor_regression.fixture import ( + TensorRegressionFixture, + get_test_source_and_temp_file_paths, +) from torch import Tensor from project.utils.testutils import run_for_all_datamodules @@ -12,7 +15,8 @@ from ..utils.types.protocols import DataModule -@pytest.mark.timeout(25, func_only=True) +# @pytest.mark.timeout(25, func_only=True) +@pytest.mark.slow @run_for_all_datamodules() def test_first_batch( datamodule: DataModule, @@ -21,6 +25,7 @@ def test_first_batch( original_datadir: Path, datadir: Path, ): + # todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI). 
datamodule.prepare_data() datamodule.setup("fit") @@ -67,7 +72,10 @@ def test_first_batch( fig.suptitle(f"First batch of datamodule {type(datamodule).__name__}") figure_path, _ = get_test_source_and_temp_file_paths( - extension=".png", request=request, original_datadir=original_datadir, datadir=datadir + extension=".png", + request=request, + original_datadir=original_datadir, + datadir=datadir, ) figure_path.parent.mkdir(exist_ok=True, parents=True) fig.savefig(figure_path) diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml index e027265f..f798ebc1 100644 --- a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml +++ b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml @@ -9,7 +9,7 @@ - 3 - 32 - 32 - sum: -2919.015 + sum: -2919.016 '1': device: cpu hash: 3692171093056153318 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml new file mode 100644 index 00000000..6c2baa19 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 3674008927974037273 + max: 2.64 + mean: -0.084 + min: -2.118 + shape: + - 64 + - 3 + - 224 + - 224 + sum: -809988.0 +'1': + device: cpu + hash: 3360823606619711831 + max: 988 + mean: 518.219 + min: 0 + shape: + - 64 + sum: 33166 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml new file mode 100644 index 00000000..2540dc74 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -8533209956811673698 + max: 2.64 + mean: 0.014 + min: -2.118 + shape: + - 64 + - 3 + - 32 + - 32 + sum: 2763.33 +'1': + device: cpu + hash: -8357971836707848708 + max: 993 + mean: 487.125 + min: 1 + shape: + - 64 + sum: 31176 diff --git a/project/datamodules/image_classification/base.py b/project/datamodules/image_classification/base.py index 0741e021..331cfbe6 100644 --- a/project/datamodules/image_classification/base.py +++ b/project/datamodules/image_classification/base.py @@ -2,12 +2,14 @@ from torch import Tensor -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.utils.types import C, H, W +# todo: decide if this should be a protocol or an actual base class (currently a base class). 
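For comparison with the todo above, a hypothetical `Protocol` version of this interface might look like the sketch below; the patch keeps the concrete base class for now:

```python
from typing import Protocol

from torch.utils.data import DataLoader


class ImageClassification(Protocol):
    """Hypothetical structural-typing variant of the base class."""

    num_classes: int
    dims: tuple[int, int, int]

    def train_dataloader(self) -> DataLoader: ...
```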
+ class ImageClassificationDataModule[BatchType: tuple[Tensor, Tensor]](VisionDataModule[BatchType]): - """Protocol that describes lightning data modules for image classification.""" + """Lightning data modules for image classification.""" num_classes: int """Number of classes in the dataset.""" diff --git a/project/datamodules/image_classification/fashion_mnist.py b/project/datamodules/image_classification/fashion_mnist.py index df42a784..8b8c080d 100644 --- a/project/datamodules/image_classification/fashion_mnist.py +++ b/project/datamodules/image_classification/fashion_mnist.py @@ -1,17 +1,11 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Any - -import torch from torchvision.datasets import FashionMNIST -from torchvision.transforms import v2 as transform_lib -from project.datamodules.image_classification.base import ImageClassificationDataModule -from project.utils.types import C, H, W +from project.datamodules.image_classification.mnist import MNISTDataModule -class FashionMNISTDataModule(ImageClassificationDataModule): +class FashionMNISTDataModule(MNISTDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png @@ -42,63 +36,3 @@ class FashionMNISTDataModule(ImageClassificationDataModule): name = "fashion_mnist" dataset_cls = FashionMNIST - dims = (C(1), H(28), W(28)) - num_classes = 10 - - def __init__( - self, - data_dir: str | None = None, - val_split: int | float = 0.2, - num_workers: int | None = 0, - normalize: bool = False, - batch_size: int = 32, - seed: int = 42, - shuffle: bool = True, - pin_memory: bool = True, - drop_last: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: - """ - - Args: - data_dir: Root directory of dataset. - val_split: Percent (float) or number (int) of samples to use for the validation split. - num_workers: Number of workers to use for loading data. - normalize: If ``True``, applies image normalization. - batch_size: Number of samples per batch to load. - seed: Random seed to be used for train/val/test splits. - shuffle: If ``True``, shuffles the train data every epoch. - pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory \ - before returning them. - drop_last: If ``True``, drops the last incomplete batch. 
- """ - super().__init__( - data_dir=data_dir, - val_split=val_split, - num_workers=num_workers, - normalize=normalize, - batch_size=batch_size, - seed=seed, - shuffle=shuffle, - pin_memory=pin_memory, - drop_last=drop_last, - *args, - **kwargs, - ) - self.prepare_data() - self.setup("fit") - - def default_transforms(self) -> Callable: - if self.normalize: - mnist_transforms = transform_lib.Compose( - [ - transform_lib.ToImage(), - transform_lib.ToDtype(torch.float32, scale=True), - transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ] - ) - else: - mnist_transforms = transform_lib.Compose([transform_lib.ToImage()]) - - return mnist_transforms diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py new file mode 100644 index 00000000..60554038 --- /dev/null +++ b/project/datamodules/image_classification/imagenet.py @@ -0,0 +1,427 @@ +from __future__ import annotations + +import logging +import math +import os +import shutil +import tarfile +import time +from collections import defaultdict +from collections.abc import Callable +from logging import getLogger as get_logger +from pathlib import Path +from typing import ClassVar, Literal + +import rich +import rich.logging +import torch +import torch.utils.data +import tqdm +from torchvision.datasets import ImageNet +from torchvision.models.resnet import ResNet152_Weights +from torchvision.transforms import v2 as transform_lib + +from project.datamodules.vision import VisionDataModule +from project.utils.env_vars import DATA_DIR, NUM_WORKERS +from project.utils.types import C, H, StageStr, W +from project.utils.types.protocols import Module + +logger = get_logger(__name__) + + +def imagenet_normalization(): + return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +type ClassIndex = int +type ImageIndex = int + + +class ImageNetDataModule(VisionDataModule): + """ImageNet datamodule. + + Extracted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/datamodules/imagenet_datamodule.py + - Made this a subclass of VisionDataModule + + Notes: + - train_dataloader uses the train split of imagenet2012 and puts away a portion of it for the validation split. + - val_dataloader uses the part of the train split of imagenet2012 that was not used for training via + `num_imgs_per_val_class` + - TODO: needs to pass split='val' to UnlabeledImagenet. + - test_dataloader uses the validation split of imagenet2012 for testing. + - TODO: need to pass num_imgs_per_class=-1 for test dataset and split="test". + """ + + name: ClassVar[str] = "imagenet" + """Dataset name.""" + + dataset_cls: ClassVar[type[ImageNet]] = ImageNet + """Dataset class to use.""" + + dims: tuple[C, H, W] = (C(3), H(224), W(224)) + """A tuple describing the shape of the data.""" + + num_classes: ClassVar[int] = 1000 + + def __init__( + self, + data_dir: str | Path = DATA_DIR, + *, + val_split: int | float = 0.01, + num_workers: int = NUM_WORKERS, + normalize: bool = False, + image_size: int = 224, + batch_size: int = 32, + seed: int = 42, + shuffle: bool = True, + pin_memory: bool = True, + drop_last: bool = False, + train_transforms: Callable | None = None, + val_transforms: Callable | None = None, + test_transforms: Callable | None = None, + **kwargs, + ): + """Creates an ImageNet datamodule (doesn't load or prepare the dataset yet). 
+ + Parameters + ---------- + data_dir: path to the imagenet dataset file + val_split: save `val_split`% of the training data *of each class* for validation. + image_size: final image size + num_workers: how many data workers + batch_size: batch_size + shuffle: If true shuffles the data every epoch + pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before \ + returning them + drop_last: If true drops the last incomplete batch + """ + self.image_size = image_size + super().__init__( + data_dir, + num_workers=num_workers, + val_split=val_split, + shuffle=shuffle, + pin_memory=pin_memory, + normalize=normalize, + seed=seed, + batch_size=batch_size, + drop_last=drop_last, + train_transforms=train_transforms or self.train_transform(), + val_transforms=val_transforms or self.val_transform(), + test_transforms=test_transforms, + **kwargs, + ) + self.dims = (C(3), H(self.image_size), W(self.image_size)) + self.train_kwargs = self.train_kwargs | {"split": "train"} + self.valid_kwargs = self.valid_kwargs | {"split": "train"} + self.test_kwargs = self.test_kwargs | {"split": "val"} + # self.test_dataset_cls = UnlabeledImagenet + + def prepare_data(self) -> None: + network_imagenet_dir = Path("/network/datasets/imagenet") + logger.debug(f"Preparing ImageNet train split in {self.data_dir}...") + prepare_imagenet( + self.data_dir, + network_imagenet_dir=network_imagenet_dir, + split="train", + ) + logger.debug(f"Preparing ImageNet val (test) split in {self.data_dir}...") + prepare_imagenet( + self.data_dir, + network_imagenet_dir=network_imagenet_dir, + split="val", + ) + + super().prepare_data() + + def setup(self, stage: StageStr | None = None) -> None: + logger.debug(f"Setup ImageNet datamodule for {stage=}") + super().setup(stage) + + def _split_dataset(self, dataset: ImageNet, train: bool = True) -> torch.utils.data.Dataset: + class_item_indices: dict[ClassIndex, list[ImageIndex]] = defaultdict(list) + for dataset_index, y in enumerate(dataset.targets): + class_item_indices[y].append(dataset_index) + + train_val_split_seed = self.seed + gen = torch.Generator().manual_seed(train_val_split_seed) + + train_class_indices: dict[ClassIndex, list[ImageIndex]] = {} + valid_class_indices: dict[ClassIndex, list[ImageIndex]] = {} + + for label, dataset_indices in class_item_indices.items(): + num_images_in_class = len(dataset_indices) + num_valid = math.ceil(self.val_split * num_images_in_class) + num_train = num_images_in_class - num_valid + + permutation = torch.randperm(len(dataset_indices), generator=gen) + dataset_indices = torch.tensor(dataset_indices)[permutation].tolist() + + train_indices = dataset_indices[:num_train] + valid_indices = dataset_indices[num_train:] + + train_class_indices[label] = train_indices + valid_class_indices[label] = valid_indices + + all_train_indices = sum(train_class_indices.values(), []) + all_valid_indices = sum(valid_class_indices.values(), []) + train_dataset = torch.utils.data.Subset(dataset, all_train_indices) + valid_dataset = torch.utils.data.Subset(dataset, all_valid_indices) + if train: + return train_dataset + return valid_dataset + + def _verify_splits(self, data_dir: str | Path, split: str) -> None: + dirs = os.listdir(data_dir) + if split not in dirs: + raise FileNotFoundError( + f"a {split} Imagenet split was not found in {data_dir}," + f" make sure the folder contains a subfolder named {split}" + ) + + def default_transforms(self) -> Module[[torch.Tensor], torch.Tensor]: + return ResNet152_Weights.IMAGENET1K_V1.transforms + + def 
train_transform(self) -> Module[[torch.Tensor], torch.Tensor]: + """The standard imagenet transforms. + + .. code-block:: python + + transform_lib.Compose([ + transform_lib.RandomResizedCrop(self.image_size), + transform_lib.RandomHorizontalFlip(), + transform_lib.ToTensor(), + transform_lib.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + return transform_lib.Compose( + [ + transform_lib.RandomResizedCrop(self.image_size), + transform_lib.RandomHorizontalFlip(), + transform_lib.ToImage(), + transform_lib.ToDtype(torch.float32, scale=True), + imagenet_normalization(), + ] + ) + + def val_transform(self) -> Callable: + """The standard imagenet transforms for validation. + + .. code-block:: python + + transform_lib.Compose([ + transform_lib.Resize(self.image_size + 32), + transform_lib.CenterCrop(self.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + + return transform_lib.Compose( + [ + transform_lib.Resize(self.image_size + 32), + transform_lib.CenterCrop(self.image_size), + transform_lib.ToImage(), + transform_lib.ToDtype(torch.float32, scale=True), + imagenet_normalization(), + ] + ) + + +def prepare_imagenet( + root: Path, + split: Literal["train", "val"] = "train", + network_imagenet_dir: Path = Path("/network/datasets/imagenet"), +) -> None: + """Custom preparation function for ImageNet, using @obilaniu's tar magic in Python form. + + The core of this is equivalent to these bash commands: + + ```bash + mkdir -p $SLURM_TMPDIR/imagenet/val + cd $SLURM_TMPDIR/imagenet/val + tar -xf /network/scratch/b/bilaniuo/ILSVRC2012_img_val.tar + mkdir -p $SLURM_TMPDIR/imagenet/train + cd $SLURM_TMPDIR/imagenet/train + tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar \ + --to-command='mkdir ${TAR_REALNAME%.tar}; tar -xC ${TAR_REALNAME%.tar}' + ``` + """ + if not network_imagenet_dir.exists(): + raise NotImplementedError( + f"Assuming that we're running on a cluster where {network_imagenet_dir} exists for now." + ) + val_archive_file_name = "ILSVRC2012_img_val.tar" + train_archive_file_name = "ILSVRC2012_img_train.tar" + devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz" + md5sums_file_name = "md5sums" + if not root.exists(): + root.mkdir(parents=True) + + def _symlink_if_needed(filename: str, network_imagenet_dir: Path): + if not (symlink := root / filename).exists(): + symlink.symlink_to(network_imagenet_dir / filename) + + # Create a symlink to the archive in $SLURM_TMPDIR, because torchvision expects it to be + # there. + _symlink_if_needed(train_archive_file_name, network_imagenet_dir) + _symlink_if_needed(val_archive_file_name, network_imagenet_dir) + _symlink_if_needed(devkit_file_name, network_imagenet_dir) + # TODO: COPY the file, not symlink it! (otherwise we get some "Read-only filesystem" errors + # when calling tvd.ImageNet(...). 
(Probably because the constructor tries to open the file) + # _symlink_if_needed(md5sums_file_name, network_imagenet_dir) + md5sums_file = root / md5sums_file_name + if not md5sums_file.exists(): + shutil.copyfile(network_imagenet_dir / md5sums_file_name, md5sums_file) + md5sums_file.chmod(0o755) + + if split == "train": + train_dir = root / "train" + train_dir.mkdir(exist_ok=True, parents=True) + train_archive = network_imagenet_dir / train_archive_file_name + previously_extracted_dirs_file = train_dir / ".previously_extracted_dirs.txt" + _extract_train_archive( + train_archive=train_archive, + train_dir=train_dir, + previously_extracted_dirs_file=previously_extracted_dirs_file, + ) + if previously_extracted_dirs_file.exists(): + previously_extracted_dirs_file.unlink() + + # OR: could just reuse the equivalent-ish from torchvision, but which doesn't support + # resuming after an interrupt. + # from torchvision.datasets.imagenet import parse_train_archive + # parse_train_archive(root, file=train_archive_file_name, folder="train") + else: + from torchvision.datasets.imagenet import ( + load_meta_file, + parse_devkit_archive, + parse_val_archive, + ) + + parse_devkit_archive(root, file=devkit_file_name) + wnids = load_meta_file(root)[1] + val_dir = root / "val" + if not val_dir.exists(): + logger.debug(f"Extracting ImageNet test set to {val_dir}") + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + return + + logger.debug(f"listing the contents of {val_dir}") + children = list(val_dir.iterdir()) + + if not children: + logger.debug(f"Extracting ImageNet test set to {val_dir}") + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + return + + if all(child.is_dir() for child in children): + logger.info("Validation split already extracted. Skipping.") + return + + logger.warning( + f"Incomplete extraction of the ImageNet test set in {val_dir}, deleting it and extracting again." + ) + shutil.rmtree(root / "val", ignore_errors=False) + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + + # val_dir = root / "val" + # val_dir.mkdir(exist_ok=True, parents=True) + # with tarfile.open(network_imagenet_dir / val_archive_file_name) as val_tarfile: + # val_tarfile.extractall(val_dir) + + +def _extract_train_archive( + *, train_archive: Path, train_dir: Path, previously_extracted_dirs_file: Path +) -> None: + # The ImageNet train archive is a tarfile of tarfiles (one for each class). + logger.debug("Extracting the ImageNet train archive using Olexa's tar magic in python form...") + train_dir.mkdir(exist_ok=True, parents=True) + + # Save a small text file or something that tells us which subdirs are + # done extracting so we can just skip ahead to the right directory? + previously_extracted_dirs: set[str] = set() + + if previously_extracted_dirs_file.exists(): + previously_extracted_dirs = set( + stripped_line + for line in previously_extracted_dirs_file.read_text().splitlines() + if (stripped_line := line.strip()) + ) + if len(previously_extracted_dirs) == 1000: + logger.info("Train archive already fully extracted. Skipping.") + return + logger.debug( + f"{len(previously_extracted_dirs)} directories have already been fully extracted." + ) + previously_extracted_dirs_file.write_text( + "\n".join(sorted(previously_extracted_dirs)) + "\n" + ) + + elif len(list(train_dir.iterdir())) == 1000: + logger.info("Train archive already fully extracted. 
Skipping.") + return + + with tarfile.open(train_archive, mode="r") as train_tarfile: + for member in tqdm.tqdm( + train_tarfile, + total=1000, # hard-coded here, since we know there are 1000 folders. + desc="Extracting train archive", + unit="Directories", + position=0, + ): + if member.name in previously_extracted_dirs: + continue + + buffer = train_tarfile.extractfile(member) + assert buffer is not None + + class_subdir = train_dir / member.name.replace(".tar", "") + class_subdir_existed = class_subdir.exists() + if class_subdir_existed: + # Remove all the (potentially partially constructed) files in the directory. + logger.debug(f"Removing partially-constructed dir {class_subdir}") + shutil.rmtree(class_subdir, ignore_errors=False) + else: + class_subdir.mkdir(parents=True, exist_ok=True) + + with tarfile.open(fileobj=buffer, mode="r|*") as class_tarfile: + class_tarfile.extractall(class_subdir, filter="data") + + # Alternative: .extractall with a list of members to extract: + # members = sub_tarfile.getmembers() # note: loads the full archive. + # if not files_in_subdir: + # members_to_extract = members + # else: + # members_to_extract = [m for m in members if m.name not in files_in_subdir] + # if members_to_extract: + # sub_tarfile.extractall(subdir, members=members_to_extract, filter="data") + + assert member.name not in previously_extracted_dirs + previously_extracted_dirs.add(member.name) + with previously_extracted_dirs_file.open("a") as f: + f.write(f"{member.name}\n") + + +def main(): + logging.basicConfig( + level=logging.DEBUG, format="%(message)s", handlers=[rich.logging.RichHandler()] + ) + datamodule = ImageNetDataModule() + start = time.time() + datamodule.prepare_data() + datamodule.setup("fit") + dl = datamodule.train_dataloader() + _batch = next(iter(dl)) + end = time.time() + print(f"Prepared imagenet in {end-start:.2f}s.") + + +if __name__ == "__main__": + main() diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 825ba493..c66eb0f3 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -11,15 +11,16 @@ import gdown import numpy as np +import torch from PIL import Image from torch.utils.data import DataLoader, Dataset, Subset from torchvision.datasets import VisionDataset from torchvision.transforms import v2 as transforms +from project.datamodules.vision import VisionDataModule +from project.utils.env_vars import DATA_DIR, SCRATCH from project.utils.types import C, H, StageStr, W -from ..vision.base import VisionDataModule - logger = getLogger(__name__) @@ -176,11 +177,11 @@ class ImageNet32DataModule(VisionDataModule): def __init__( self, - data_dir: str | Path, - readonly_datasets_dir: str | Path | None = None, + data_dir: Path = DATA_DIR, + readonly_datasets_dir: str | Path | None = SCRATCH, val_split: int | float = -1, num_images_per_val_class: int | None = 50, - num_workers: int | None = 0, + num_workers: int = 0, normalize: bool = False, batch_size: int = 32, seed: int = 42, @@ -193,7 +194,7 @@ def __init__( ) -> None: Path(data_dir).mkdir(parents=True, exist_ok=True) super().__init__( - data_dir=str(data_dir), + data_dir=data_dir, val_split=val_split, num_workers=num_workers, normalize=normalize, @@ -205,9 +206,10 @@ def __init__( train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, + # extra kwargs + readonly_datasets_dir=readonly_datasets_dir, ) 
self.num_images_per_val_class = num_images_per_val_class - if self.val_split == -1 and self.num_images_per_val_class is None: raise ValueError( "Can't have both `val_split` and `num_images_per_val_class` set to `None`!" @@ -219,9 +221,6 @@ def __init__( ) self.num_images_per_val_class = None - # ImageNetDataModule uses num_imgs_per_val_class: int = 50, which makes sense! Here - # however we're using probably more than that for validation. - self.EXTRA_ARGS["readonly_datasets_dir"] = readonly_datasets_dir self.dataset_train: ImageNet32Dataset | Subset self.dataset_val: ImageNet32Dataset | Subset self.dataset_test: ImageNet32Dataset | Subset @@ -232,9 +231,7 @@ def num_samples(self) -> int: def prepare_data(self) -> None: """Saves files to data_dir.""" - # NOTE: In our case, the download gives us both. No need to do it twice. - self.dataset_cls(self.data_dir, train=True, download=True, **self.EXTRA_ARGS) - self.dataset_cls(self.data_dir, train=False, download=True, **self.EXTRA_ARGS) + super().prepare_data() def setup(self, stage: StageStr | None = None) -> None: """Creates train, val, and test dataset.""" @@ -246,32 +243,17 @@ def setup(self, stage: StageStr | None = None) -> None: else: logger.debug("Setting up for all stages") - if stage in ["fit", "val", None]: - train_transforms = ( - self.default_transforms() - if self.train_transforms is None - else self.train_transforms - ) - val_transforms = ( - self.default_transforms() if self.val_transforms is None else self.val_transforms - ) - # Create the entire dataset twice. This is only needed because they have different - # transforms... - base_dataset = self.dataset_cls( - self.data_dir, - train=True, - transform=transforms.ToTensor(), - **self.EXTRA_ARGS, - ) - # Make sure they both use the same underlying data. (so we don't use twice as much - # memory, like the base-class does! 
+ if stage in ["fit", "validate", None]: + base_dataset = self.dataset_cls(self.data_dir, **self.train_kwargs) + assert len(base_dataset) == 1_281_159 + base_dataset_train = copy.deepcopy(base_dataset) - base_dataset_train.transform = train_transforms + base_dataset_train.transform = self.train_transforms base_dataset_train.data = base_dataset.data base_dataset_train.targets = base_dataset.targets base_dataset_valid = copy.deepcopy(base_dataset) - base_dataset_valid.transform = val_transforms + base_dataset_valid.transform = self.val_transforms base_dataset_valid.data = base_dataset.data base_dataset_valid.targets = base_dataset.targets @@ -288,22 +270,20 @@ def setup(self, stage: StageStr | None = None) -> None: self.dataset_val = self._split_dataset(base_dataset_valid, train=False) if stage in ["test", None]: - test_transforms = ( - self.default_transforms() if self.test_transforms is None else self.test_transforms - ) + test_transforms = self.test_transforms or self.default_transforms() self.dataset_test = self.dataset_cls( self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS ) def default_transforms(self) -> Callable: """Default transform for the dataset.""" - if self.normalize: - in32_transforms = transforms.Compose( - [transforms.ToTensor(), imagenet32_normalization()] - ) - else: - in32_transforms = transforms.Compose([transforms.ToTensor()]) - return in32_transforms + return transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + ] + + ([imagenet32_normalization()] if self.normalize else []) + ) def train_dataloader(self) -> DataLoader: """The train dataloader.""" @@ -328,6 +308,7 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: ) def _split_dataset(self, dataset: ImageNet32Dataset, train: bool = True) -> Subset: + assert self.val_split >= 0 split_dataset = super()._split_dataset(dataset, train=train) assert isinstance(split_dataset, Subset) return split_dataset @@ -363,6 +344,7 @@ def imagenet32_train_transforms(): return transforms.Compose( [ transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomCrop(size=32, padding=4, padding_mode="edge"), imagenet32_normalization(), diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py index 771978cc..537c91ce 100644 --- a/project/datamodules/image_classification/imagenet32_test.py +++ b/project/datamodules/image_classification/imagenet32_test.py @@ -1,18 +1,19 @@ import itertools -from pathlib import Path import pytest -from project.configs.datamodule import SCRATCH +from project.utils.env_vars import DATA_DIR, SCRATCH +from project.utils.testutils import IN_GITHUB_CI from .imagenet32 import ImageNet32DataModule +@pytest.mark.skipif(IN_GITHUB_CI, reason="Can't run on the GitHub CI.") @pytest.mark.slow -def test_dataset_download_works(data_dir: Path): +def test_dataset_download_works(): batch_size = 16 datamodule = ImageNet32DataModule( - data_dir=data_dir, + data_dir=DATA_DIR, readonly_datasets_dir=SCRATCH, batch_size=batch_size, num_images_per_val_class=10, @@ -21,8 +22,8 @@ def test_dataset_download_works(data_dir: Path): assert datamodule.val_split == -1 datamodule.prepare_data() datamodule.setup(None) + expected_total = 1_281_159 - expected_total = 1281159 assert ( datamodule.num_samples == expected_total - datamodule.num_classes * datamodule.num_images_per_val_class diff --git 
diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datamodules/image_classification/inaturalist.py
index fffff33c..a1090a72 100644
--- a/project/datamodules/image_classification/inaturalist.py
+++ b/project/datamodules/image_classification/inaturalist.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import os
 import warnings
 from collections.abc import Callable
 from logging import getLogger as get_logger
@@ -11,6 +10,7 @@
 from torchvision.datasets import INaturalist

 from project.datamodules.image_classification.base import ImageClassificationDataModule
+from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR
 from project.utils.types import C, H, W

 logger = get_logger(__name__)
@@ -25,23 +25,6 @@
 Version = Version2017_2019 | Version2021


-def get_slurm_tmpdir() -> Path:
-    if "SLURM_TMPDIR" in os.environ:
-        return Path(os.environ["SLURM_TMPDIR"])
-    if "SLURM_JOB_ID" not in os.environ:
-        raise RuntimeError(
-            "SLURM_JOBID environment variable isn't set. Are you running this from a SLURM "
-            "cluster?"
-        )
-    slurm_tmpdir = Path(f"/Tmp/slurm.{os.environ['SLURM_JOB_ID']}.0")
-    if not slurm_tmpdir.is_dir():
-        raise NotImplementedError(
-            f"TODO: You appear to be running this outside the Mila cluster, since SLURM_TMPDIR "
-            f"isn't located at {slurm_tmpdir}."
-        )
-    return slurm_tmpdir
-
-
 def inat_dataset_dir() -> Path:
     network_dir = Path("/network/datasets/inat")
     if not network_dir.exists():
@@ -61,9 +44,9 @@ class INaturalistDataModule(ImageClassificationDataModule):

     def __init__(
         self,
-        data_dir: str | Path | None = None,
+        data_dir: str | Path = DATA_DIR,
         val_split: int | float = 0.1,
-        num_workers: int | None = None,
+        num_workers: int = NUM_WORKERS,
         normalize: bool = False,
         batch_size: int = 32,
         seed: int = 42,
@@ -79,7 +62,8 @@ def __init__(
     ) -> None:
         # assuming that we're on the Mila cluster atm.
         self.network_dir = inat_dataset_dir()
-        slurm_tmpdir = get_slurm_tmpdir()
+        assert SLURM_TMPDIR, "assuming that we're on a compute node."
+        slurm_tmpdir = SLURM_TMPDIR
         default_data_dir = slurm_tmpdir / "data"
         if data_dir is None:
             data_dir = default_data_dir
@@ -121,7 +105,8 @@ def __init__(
         if not isinstance(target_type, list):
             self.num_classes = None

-        if version == "2021_train_mini" and target_type == "full":
+        # todo: double-check that the 2021_train split also has 10_000 classes.
+        if version in ["2021_train_mini", "2021_train"] and target_type == "full":
             self.num_classes = 10_000
         if isinstance(train_transforms, T.Compose):
             channels = 3
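Note: with `SLURM_TMPDIR` now exposed as a module-level constant (which is `None` off-cluster), callers can guard on it directly instead of catching exceptions from the removed `get_slurm_tmpdir()` helper. A sketch under that assumption, using the names from `project.utils.env_vars`; `fast_local_dir` is a hypothetical helper, not part of this PR:

    from pathlib import Path

    from project.utils.env_vars import SLURM_TMPDIR


    def fast_local_dir(fallback: Path = Path("data")) -> Path:
        # On a compute node, SLURM_TMPDIR points at node-local storage; elsewhere it is None.
        return SLURM_TMPDIR / "data" if SLURM_TMPDIR is not None else fallback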
diff --git a/project/datamodules/image_classification/mnist.py b/project/datamodules/image_classification/mnist.py
index dc268aa2..905aa261 100644
--- a/project/datamodules/image_classification/mnist.py
+++ b/project/datamodules/image_classification/mnist.py
@@ -27,11 +27,7 @@ def mnist_train_transforms():
 def mnist_normalization():
     # NOTE: Taken from https://stackoverflow.com/a/67233938/6388696
     # return transforms.Normalize(mean=0.5, std=0.5)
-    return transforms.Compose(
-        [
-            transforms.Normalize(mean=[0.1307], std=[0.3081]),
-        ]
-    )
+    return transforms.Normalize(mean=[0.1307], std=[0.3081])


 def mnist_unnormalization(x: Tensor) -> Tensor:
diff --git a/project/datamodules/image_classification/transforms.py b/project/datamodules/image_classification/transforms.py
deleted file mode 100644
index 283541d4..00000000
--- a/project/datamodules/image_classification/transforms.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from collections.abc import Callable
-
-from torchvision import transforms
-
-
-def imagenet_normalization() -> Callable:
-    return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-
-
-def cifar10_normalization() -> Callable:
-    return transforms.Normalize(
-        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
-        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
-    )
-
-
-def stl10_normalization() -> Callable:
-    return transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27))
-
-
-def emnist_normalization(split: str) -> Callable:
-    # `stats` contains mean and std for each `split`.
-    stats = {
-        "balanced": (0.175, 0.333),
-        "byclass": (0.174, 0.332),
-        "bymerge": (0.174, 0.332),
-        "digits": (0.173, 0.332),
-        "letters": (0.172, 0.331),
-        "mnist": (0.173, 0.332),
-    }
-    return transforms.Normalize(mean=stats[split][0], std=stats[split][1])
diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision.py
similarity index 71%
rename from project/datamodules/vision/base.py
rename to project/datamodules/vision.py
index 9b54144e..c2d0f6b1 100644
--- a/project/datamodules/vision/base.py
+++ b/project/datamodules/vision.py
@@ -6,27 +6,18 @@
 from collections.abc import Callable
 from logging import getLogger as get_logger
 from pathlib import Path
-from typing import Any, ClassVar, Concatenate
+from typing import ClassVar, Concatenate

 import torch
 from lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision.datasets import VisionDataset
-from typing_extensions import ParamSpec
+from torchvision.transforms import v2 as transforms

+from project.utils.env_vars import DATA_DIR, NUM_WORKERS
 from project.utils.types import C, H, StageStr, W
+from project.utils.types.protocols import DataModule

-from ...utils.types.protocols import DataModule
-
-P = ParamSpec("P")
-
-SLURM_TMPDIR: Path | None = (
-    Path(os.environ["SLURM_TMPDIR"])
-    if "SLURM_TMPDIR" in os.environ
-    else tmp
-    if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists()
-    else None
-)

 logger = get_logger(__name__)

@@ -47,9 +38,9 @@ class VisionDataModule[BatchType_co](LightningDataModule, DataModule[BatchType_c

     def __init__(
         self,
-        data_dir: str | Path | None = None,
+        data_dir: str | Path = DATA_DIR,
         val_split: int | float = 0.2,
-        num_workers: int | None = None,
+        num_workers: int = NUM_WORKERS,
         normalize: bool = False,
         batch_size: int = 32,
         seed: int = 42,
@@ -79,13 +70,8 @@ def __init__(
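Note: the `ToImage()` + `ToDtype(torch.float32, scale=True)` pair used throughout this PR is the torchvision v2 replacement for the old `ToTensor()`: it yields a CHW float32 tensor scaled to [0, 1], after which `Normalize` applies `(x - mean) / std`. A quick sanity check of the full pipeline with the MNIST statistics from `mnist_normalization()` above:

    import numpy as np
    import torch
    from torchvision.transforms import v2 as transforms

    # HWC uint8 array -> CHW image tensor -> float32 in [0, 1] -> normalized.
    pipeline = transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.1307], std=[0.3081]),
        ]
    )
    x = np.random.randint(0, 256, size=(28, 28, 1), dtype=np.uint8)
    z = pipeline(x)
    assert z.dtype == torch.float32 and z.shape == (1, 28, 28)
    x_back = z * 0.3081 + 0.1307  # undo Normalize: z = (x - mean) / std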
         """
         super().__init__()
-        from project.configs.datamodule import DATA_DIR
-
-        self.data_dir = data_dir if data_dir is not None else DATA_DIR
+        self.data_dir: Path = Path(data_dir or DATA_DIR)
         self.val_split = val_split
-        if num_workers is None:
-            num_workers = num_cpus_on_node()
-            logger.debug(f"Setting the number of dataloader workers to {num_workers}.")
         self.num_workers = num_workers
         self.normalize = normalize
         self.batch_size = batch_size
@@ -93,51 +79,35 @@ def __init__(
         self.shuffle = shuffle
         self.pin_memory = pin_memory
         self.drop_last = drop_last
-        self._train_transforms = train_transforms
-        self._val_transforms = val_transforms
-        self._test_transforms = test_transforms
-        self.EXTRA_ARGS = kwargs
-
-        self.train_kwargs = self.EXTRA_ARGS.copy()
-        self.test_kwargs = self.EXTRA_ARGS.copy()
-        if _has_constructor_argument(self.dataset_cls, "train"):
-            self.train_kwargs["train"] = True
-            self.test_kwargs["train"] = False
+        self.train_transforms = train_transforms or self.default_transforms()
+        self.val_transforms = val_transforms or transforms.Compose(
+            [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]
+        )
+        self.test_transforms = test_transforms or transforms.Compose(
+            [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]
+        )

+        # todo: what about the shuffling at each epoch?
         _rng = torch.Generator(device="cpu").manual_seed(self.seed)
         self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
         self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
         self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())

-        self.dataset_test: VisionDataset | None = None
-
-    @property
-    def train_transforms(self) -> Callable[..., Any] | None:
-        """Optional transforms (or collection of transforms) you can apply to train dataset."""
-        return self._train_transforms
-
-    @train_transforms.setter
-    def train_transforms(self, t: Callable) -> None:
-        self._train_transforms = t
+        self.test_dataset_cls = self.dataset_cls

-    @property
-    def val_transforms(self) -> Callable[..., Any] | None:
-        """Optional transforms (or collection of transforms) you can apply to validation
-        dataset."""
-        return self._val_transforms
-
-    @val_transforms.setter
-    def val_transforms(self, t: Callable) -> None:
-        self._val_transforms = t
+        self.dataset_train: Dataset | None = None
+        self.dataset_val: Dataset | None = None
+        self.dataset_test: VisionDataset | None = None

-    @property
-    def test_transforms(self) -> Callable[..., Any] | None:
-        """Optional transforms (or collection of transforms) you can apply to test dataset."""
-        return self._test_transforms
+        self.EXTRA_ARGS = kwargs
+        self.train_kwargs = self.EXTRA_ARGS | {"transform": self.train_transforms}
+        self.valid_kwargs = self.EXTRA_ARGS | {"transform": self.val_transforms}
+        self.test_kwargs = self.EXTRA_ARGS | {"transform": self.test_transforms}

-    @test_transforms.setter
-    def test_transforms(self, t: Callable) -> None:
-        self._test_transforms = t
+        if _has_constructor_argument(self.dataset_cls, "train"):
+            self.train_kwargs["train"] = True
+            self.valid_kwargs["train"] = True
+            self.test_kwargs["train"] = False

     def prepare_data(self) -> None:
         """Saves files to data_dir."""
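Note: the `|` dict-union used above returns a new dict each time, so mutating `train_kwargs` afterwards (e.g. setting `"train": True`) leaves `EXTRA_ARGS` untouched. A minimal illustration:

    extra_args = {"download": True}

    # `|` builds a new dict each time, so the per-split kwargs can diverge safely:
    train_kwargs = extra_args | {"transform": "train_tf"}
    valid_kwargs = extra_args | {"transform": "val_tf"}
    train_kwargs["train"] = True  # mutates only train_kwargs...
    assert extra_args == {"download": True}  # ...the shared defaults are untouched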
@@ -148,54 +118,38 @@ def prepare_data(self) -> None:
         if _has_constructor_argument(self.dataset_cls, "download"):
             train_kwargs["download"] = True
             test_kwargs["download"] = True
-        logger.info(
+        logger.debug(
             f"Preparing {self.name} dataset training split in {self.data_dir} with {train_kwargs}"
         )
         self.dataset_cls(str(self.data_dir), **train_kwargs)
         if test_kwargs != train_kwargs:
-            logger.info(
+            logger.debug(
                 f"Preparing {self.name} dataset test split in {self.data_dir} with {test_kwargs=}"
             )
-            self.dataset_cls(str(self.data_dir), **test_kwargs)
+            self.test_dataset_cls(str(self.data_dir), **test_kwargs)

     def setup(self, stage: StageStr | None = None) -> None:
         """Creates train, val, and test dataset."""
         if stage in ["fit", "validate"] or stage is None:
-            train_transforms = (
-                self.default_transforms()
-                if self.train_transforms is None
-                else self.train_transforms
-            )
-            val_transforms = (
-                self.default_transforms() if self.val_transforms is None else self.val_transforms
-            )
-
+            logger.debug(f"creating training dataset with kwargs {self.train_kwargs}")
             dataset_train = self.dataset_cls(
                 str(self.data_dir),
-                transform=train_transforms,
                 **self.train_kwargs,
             )
-            # dataset_train = wrap_dataset_for_transforms_v2(dataset_train)
+            logger.debug(f"creating validation dataset with kwargs {self.valid_kwargs}")
             dataset_val = self.dataset_cls(
                 str(self.data_dir),
-                transform=val_transforms,
-                **self.train_kwargs,  # todo: Assuming those are the same for now.
+                **self.valid_kwargs,
             )
-            # dataset_val = wrap_dataset_for_transforms_v2(dataset_val)
-
-            # Split
+            # Train/validation split.
+            # NOTE: the dataset is created twice (with the right transforms) and split in the same
+            # way, such that there is no overlap in indices between train and validation sets.
             self.dataset_train = self._split_dataset(dataset_train, train=True)
             self.dataset_val = self._split_dataset(dataset_val, train=False)

         if stage == "test" or stage is None:
-            test_transforms = (
-                self.default_transforms() if self.test_transforms is None else self.test_transforms
-            )
-            dataset_test = self.dataset_cls(
-                str(self.data_dir), transform=test_transforms, **self.test_kwargs
-            )
-            # dataset_test = wrap_dataset_for_transforms_v2(dataset_test)
-            self.dataset_test = dataset_test
+            logger.debug(f"creating test dataset with kwargs {self.test_kwargs}")
+            self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs)

     def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset:
         """Splits the dataset into train and validation set."""
@@ -227,7 +181,7 @@ def _get_splits(self, len_dataset: int) -> list[int]:

     def default_transforms(self) -> Callable:
         """Default transform for the dataset."""

-    def train_dataloader(
+    def train_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -247,7 +201,7 @@ def train_dataloader(
             ),
         )

-    def val_dataloader(
+    def val_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -262,7 +216,7 @@
             **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs),
         )

-    def test_dataloader(
+    def test_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -279,7 +233,7 @@
             **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs),
         )

-    def _data_loader(
+    def _data_loader[**P](
         self,
         dataset: Dataset,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
@@ -292,7 +246,7 @@ def _data_loader(
             num_workers=self.num_workers,
             drop_last=self.drop_last,
             pin_memory=self.pin_memory,
-            persistent_workers=True if self.num_workers > 0 else False,
+            persistent_workers=(self.num_workers or 0) > 0,
         )
         | dataloader_kwargs
     )
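Note: `def train_dataloader[**P](...)` uses the PEP 695 generic syntax (hence the `requires-python = ">=3.12"` constraint in pyproject.toml), typing the passthrough `*args`/`**kwargs` against the dataloader factory's remaining parameters. A standalone sketch of the same shape; `make_loader` is a hypothetical free function, not part of this PR:

    from collections.abc import Callable
    from typing import Concatenate

    import torch
    from torch.utils.data import DataLoader, Dataset, TensorDataset


    def make_loader[**P](
        dataset: Dataset,
        _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> DataLoader:
        # *args/**kwargs are type-checked against the factory's signature (minus `dataset`).
        return _dataloader_fn(dataset, *args, **kwargs)


    loader = make_loader(TensorDataset(torch.arange(10)), batch_size=4)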
diff --git a/project/datamodules/vision/__init__.py b/project/datamodules/vision/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/project/experiment.py b/project/experiment.py
index f4112356..d437a9b1 100644
--- a/project/experiment.py
+++ b/project/experiment.py
@@ -112,9 +112,9 @@ def instantiate_trainer(experiment_config: Config) -> Trainer:
     # fields have the right type.

     # instantiate all the callbacks
-    callbacks: dict[str, Callback] | None = hydra_zen.instantiate(
-        experiment_config.trainer.pop("callbacks", {})
-    )
+    callback_configs = experiment_config.trainer.pop("callbacks", {})
+    callback_configs = {k: v for k, v in callback_configs.items() if v is not None}
+    callbacks: dict[str, Callback] | None = hydra_zen.instantiate(callback_configs)
     # Create the loggers, if any.
     loggers: dict[str, Any] | None = instantiate(experiment_config.trainer.pop("logger", {}))
     # Create the Trainer.
diff --git a/project/main_test.py b/project/main_test.py
index a855ba57..32d16e4b 100644
--- a/project/main_test.py
+++ b/project/main_test.py
@@ -9,7 +9,6 @@

 from project.algorithms import Algorithm, ExampleAlgorithm
 from project.configs.config import Config
-from project.configs.datamodule import CIFAR10DataModuleConfig
 from project.conftest import setup_hydra_for_tests_and_compose, use_overrides
 from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
 from project.networks.fcnet import FcNet
@@ -42,7 +41,10 @@ def set_testing_hydra_dir():
 @use_overrides([""])
 def test_defaults(experiment_config: Config) -> None:
     assert isinstance(experiment_config.algorithm, ExampleAlgorithm.HParams)
-    assert isinstance(experiment_config.datamodule, CIFAR10DataModuleConfig | CIFAR10DataModule)
+    assert (
+        isinstance(experiment_config.datamodule, CIFAR10DataModule)
+        or hydra_zen.get_target(experiment_config.datamodule) is CIFAR10DataModule
+    )


 def _ids(v):
diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py
new file mode 100644
index 00000000..cc9d663d
--- /dev/null
+++ b/project/utils/env_vars.py
@@ -0,0 +1,52 @@
+import os
+from pathlib import Path
+
+import torch
+
+SLURM_TMPDIR: Path | None = (
+    Path(os.environ["SLURM_TMPDIR"])
+    if "SLURM_TMPDIR" in os.environ
+    else tmp
+    if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists()
+    else None
+)
+SLURM_JOB_ID: int | None = (
+    int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
+)
+
+NETWORK_DIR = (
+    Path(os.environ["NETWORK_DIR"])
+    if "NETWORK_DIR" in os.environ
+    else _network_dir
+    if (_network_dir := Path("/network")).exists()
+    else None
+)
+
+REPO_ROOTDIR = Path(__file__).parent
+for level in range(5):
+    if "README.md" in list(p.name for p in REPO_ROOTDIR.iterdir()):
+        break
+    REPO_ROOTDIR = REPO_ROOTDIR.parent
+
+SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
+"""SCRATCH directory where logs / checkpoints / custom datasets should be saved."""
+
+DATA_DIR = Path(os.environ.get("DATA_DIR", (SLURM_TMPDIR or SCRATCH or REPO_ROOTDIR) / "data"))
+"""Directory where datasets should be extracted."""
+
+
+def get_constant(name: str):
+    return globals()[name]
+
+
+NUM_WORKERS = int(
+    os.environ.get(
+        "SLURM_CPUS_PER_TASK",
+        os.environ.get(
+            "SLURM_CPUS_ON_NODE",
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else torch.multiprocessing.cpu_count(),
+        ),
+    )
+)
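Note: the `DATA_DIR` fallback chain in the new module resolves to an explicit `$DATA_DIR` if set, and otherwise to a `data` subdirectory under the first available of `SLURM_TMPDIR`, `SCRATCH`, or the repository root. A standalone sketch of the same logic, with `Path.cwd()` standing in for `REPO_ROOTDIR`:

    import os
    from pathlib import Path

    slurm_tmpdir = Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
    scratch = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
    repo_root = Path.cwd()  # stand-in for REPO_ROOTDIR

    # An explicit $DATA_DIR wins; otherwise the first available location gets a "data" subdir.
    data_dir = Path(os.environ.get("DATA_DIR", (slurm_tmpdir or scratch or repo_root) / "data"))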
diff --git a/project/utils/testutils.py b/project/utils/testutils.py
index bfe9a203..184a0605 100644
--- a/project/utils/testutils.py
+++ b/project/utils/testutils.py
@@ -2,18 +2,21 @@

 from __future__ import annotations

+import contextlib
 import copy
 import dataclasses
 import hashlib
 import importlib
+import os
+import random
 from collections.abc import Mapping, Sequence
 from contextlib import contextmanager
 from logging import getLogger as get_logger
 from pathlib import Path
 from typing import Any, TypeVar

-import hydra.errors
 import hydra_zen
+import numpy as np
 import pytest
 import torch
 import yaml
@@ -23,47 +26,66 @@
 from torch import Tensor, nn
 from torch.optim import Optimizer

-from project.configs import Config, cs
-from project.configs.datamodule import DATA_DIR, SLURM_JOB_ID
+from project.configs import Config
 from project.datamodules.image_classification import (
     ImageClassificationDataModule,
 )
-from project.datamodules.vision.base import VisionDataModule
+from project.datamodules.vision import VisionDataModule
 from project.experiment import instantiate_trainer
+from project.utils.env_vars import NETWORK_DIR
 from project.utils.hydra_utils import get_attr, get_outer_class
 from project.utils.types import PhaseStr
 from project.utils.types.protocols import DataModule
 from project.utils.utils import get_device

-SLOW_DATAMODULES = ["inaturalist", "imagenet32"]
+logger = get_logger(__name__)
+
+IN_GITHUB_CI = "GITHUB_ACTIONS" in os.environ
+IN_SELF_HOSTED_GITHUB_CI = IN_GITHUB_CI and "self-hosted" in os.environ.get("RUNNER_LABELS", "")
+

 default_marks_for_config_name: dict[str, list[pytest.MarkDecorator]] = {
     "imagenet32": [pytest.mark.slow],
     "inaturalist": [
         pytest.mark.slow,
-        pytest.mark.xfail(
-            not Path("/network/datasets/inat").exists(),
-            strict=True,
-            raises=hydra.errors.InstantiationException,
+        pytest.mark.skipif(
+            not (NETWORK_DIR and (NETWORK_DIR / "datasets/inat").exists()),
+            # strict=True,
+            # raises=hydra.errors.InstantiationException,
             reason="Expects to be run on the Mila cluster for now",
         ),
     ],
-    "rl": [
-        pytest.mark.xfail(
-            strict=False,
-            raises=AssertionError,
-            # match="Shapes are not the same."
-            reason="Isn't entirely deterministic yet.",
+    "imagenet": [
+        pytest.mark.slow,
+        pytest.mark.skipif(
+            not (NETWORK_DIR and (NETWORK_DIR / "datasets/imagenet").exists()),
+            # strict=True,
+            # raises=hydra.errors.InstantiationException,
+            reason="Expects to be run on a cluster with the ImageNet dataset.",
         ),
     ],
-    "moving_mnist": [
-        (pytest.mark.slow if not (DATA_DIR / "MovingMNIST").exists() else pytest.mark.timeout(5))
-    ],
+    "vision": [pytest.mark.skip(reason="Base class, shouldn't be instantiated.")],
 }
 """Dict with some default marks for some configs name."""

-
-logger = get_logger(__name__)
+default_marks_for_config_combinations: dict[tuple[str, ...], list[pytest.MarkDecorator]] = {
+    ("imagenet", "fcnet"): [
+        pytest.mark.xfail(
+            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
+        )
+    ],
+    ("imagenet", "jax_fcnet"): [
+        pytest.mark.xfail(
+            reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters."
+        )
+    ],
+    ("imagenet", "jax_cnn"): [
+        pytest.mark.xfail(
+            reason="todo: parameters contain nans when overfitting on one batch? Maybe we're "
+            "using too many iterations?"
+        )
+    ],
+}


 def parametrized_fixture(name: str, values: Sequence, ids=None, **kwargs):
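Note: `default_marks_for_config_combinations` keys on tuples of config names. A hypothetical sketch of how a test collector might look up the marks that apply to a given (datamodule, network) pair; `marks_for_combination` is illustrative, not part of this PR:

    import pytest


    def marks_for_combination(datamodule: str, network: str, marks_by_combo: dict) -> list:
        marks = []
        for combo, combo_marks in marks_by_combo.items():
            if set(combo) <= {datamodule, network}:  # every name in the tuple is selected
                marks.extend(combo_marks)
        return marks


    combos = {("imagenet", "fcnet"): [pytest.mark.xfail(reason="nans in parameters")]}
    param = pytest.param(
        "imagenet", "fcnet", marks=marks_for_combination("imagenet", "fcnet", combos)
    )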
@@ -155,11 +177,15 @@ def get_all_algorithm_names() -> list[str]:
     return get_all_configs_in_group("algorithm")


-def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigStore = cs) -> type:
+def get_type_for_config_name(
+    config_group: str, config_name: str, _cs: ConfigStore | None = None
+) -> type:
     """Returns the class that is to be instantiated by the given config name.

     In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model).
     """
+    if _cs is None:
+        from project.configs import cs as _cs

     config_loader = get_config_loader()
     _, caching_repo = config_loader._parse_overrides_and_create_caching_repo(
@@ -274,7 +300,11 @@ def test_network_output_is_reproducible(network: nn.Module, x: Tensor):

 def get_all_datamodule_names() -> list[str]:
     """Retrieves the names of all the datamodules that are saved in the ConfigStore of Hydra."""
-    return get_all_configs_in_group("datamodule")
+    datamodules = get_all_configs_in_group("datamodule")
+    # todo: automatically detect which ones are configs for ABCs and remove them?
+    if "vision" in datamodules:
+        datamodules.remove("vision")
+    return datamodules


 def get_all_datamodule_names_params():
@@ -290,18 +320,7 @@ def get_all_datamodule_names_params():
             marks=[
                 pytest.mark.xdist_group(name=dm_name),
             ]
-            + ([pytest.mark.slow] if dm_name in SLOW_DATAMODULES else [])
-            + (
-                [
-                    pytest.mark.xfail(
-                        SLURM_JOB_ID is None,
-                        raises=NotImplementedError,
-                        reason="Needs to be run on the Mila cluster atm.",
-                    )
-                ]
-                if dm_name == "inaturalist"
-                else []
-            ),
+            + default_marks_for_config_name.get(dm_name, []),
         )
         for dm_name in dm_names
     ]
@@ -603,8 +622,17 @@ def assert_no_nans_in_params_or_grads(module: nn.Module):
         assert not torch.isnan(param.grad).any(), name


+@contextlib.contextmanager
+def fork_rng():
+    with torch.random.fork_rng():
+        random_state = random.getstate()
+        np_random_state = np.random.get_state()
+        yield
+        np.random.set_state(np_random_state)
+        random.setstate(random_state)
+
+
 @contextmanager
 def seeded(seed: int = 42):
-    with torch.random.fork_rng():
-        torch.random.manual_seed(seed)
+    with fork_rng():
+        torch.random.manual_seed(seed)
         yield
diff --git a/pyproject.toml b/pyproject.toml
index 1443b0ff..6f7025a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ dependencies = [
     "tqdm>=4.66.2",
     "hydra-zen>=0.12.1",
     "gym==0.26.2",
-    "lightning==1.9.0",
+    "lightning>=2.3.0",
     "gdown>=5.1.0",
     "hydra-submitit-launcher>=1.2.0",
     "wandb>=0.16.4",
@@ -27,11 +27,16 @@ dependencies = [
     "gymnax>=0.0.8",
     "torch-jax-interop @ git+https://www.github.com/lebrice/torch_jax_interop",
     "tensor-regression @ git+https://www.github.com/lebrice/tensor_regression",
+    "simple-parsing>=0.1.5",
+    "pydantic==2.7.4",
 ]
 requires-python = ">=3.12"
 readme = "README.md"
 license = {text = "MIT"}

+[project.scripts]
+project = "project:main.main"
+
 [tool.setuptools]
 packages = ["project"]

@@ -50,6 +55,7 @@ dev = [
     "pytest-benchmark>=4.0.0",
     "pytest-cov>=5.0.0",
     "tensor-regression>=0.0.2.post3.dev0",
+    "pytest-testmon>=2.1.1",
 ]

 [[tool.pdm.source]]
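Note: the new `[project.scripts]` table registers a `project` console command. With the standard `module:attr` entry-point syntax, `project = "project:main.main"` resolves to importing the `project` package and calling its `main.main` attribute, roughly equivalent to the sketch below (assuming `project/main.py` defines a `main()` function):

    # Rough equivalent of running `project` on the command line after `pdm install`:
    from project import main

    main.main()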