From 0d9160f3bac7f3acfd27f1f23aa15b1b6fe714e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= <45119610+ireneisdoomed@users.noreply.github.com> Date: Mon, 24 Jun 2024 17:20:58 +0100 Subject: [PATCH] feat(l2g): better l2g training, evaluation, and integration (#576) * chore: checkpoint * chore: checkpoint * chore: deprecate spark evaluator * chore: checkpoint * chore: resolve conflicts with dev * chore: resolve conflicts with dev * chore(model): add parameters class property * feat: add module to export model to hub * refactor: make model agnostic of features list * chore: add wandb to gitignore * feat: download model from hub * chore(model): adapt predict method * feat(trainer): add hyperparameter tuning * chore: deprecate trainer tests * refactor: modularise step * feat: download model from hub by default * fix: convert omegaconfig defaults to python objects * fix: write serialised model to disk and then upload to gcs * fix(matrix): drop goldStandardSet when in predict mode * chore: pass token to access private model * chore: pass token to access private model * fix: pass right schema * chore: pre-commit auto fixes [...] * chore: fix mypy issues * build: remove xgboost * chore: merge * chore: pre-commit auto fixes [...] * chore: address comments --- .gitignore | 1 + config/datasets/ot_gcp.yaml | 2 +- config/step/ot_locus_to_gene_predict.yaml | 2 +- config/step/ot_locus_to_gene_train.yaml | 6 +- docs/python_api/methods/l2g/evaluator.md | 5 - poetry.lock | 592 +++++++++++--------- pyproject.toml | 2 +- src/gentropy/common/utils.py | 23 + src/gentropy/config.py | 36 +- src/gentropy/dataset/l2g_feature_matrix.py | 48 +- src/gentropy/dataset/l2g_prediction.py | 65 +-- src/gentropy/l2g.py | 301 ++++++---- src/gentropy/method/l2g/evaluator.py | 204 ------- src/gentropy/method/l2g/model.py | 424 ++++++-------- src/gentropy/method/l2g/trainer.py | 253 ++++++--- tests/gentropy/method/test_locus_to_gene.py | 57 +- 16 files changed, 961 insertions(+), 1060 deletions(-) delete mode 100644 docs/python_api/methods/l2g/evaluator.md delete mode 100644 src/gentropy/method/l2g/evaluator.py diff --git a/.gitignore b/.gitignore index 9aef6d065..f4c85d797 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ src/airflow/logs/* site/ .env .coverage* +wandb/ diff --git a/config/datasets/ot_gcp.yaml b/config/datasets/ot_gcp.yaml index 09cece996..c8f67d418 100644 --- a/config/datasets/ot_gcp.yaml +++ b/config/datasets/ot_gcp.yaml @@ -70,7 +70,7 @@ from_sumstats_pics: ${datasets.credible_set}/from_sumstats # ETL output datasets: l2g_gold_standard_curation: ${datasets.release_folder}/locus_to_gene_gold_standard.json -l2g_model: ${datasets.release_folder}/locus_to_gene_model +l2g_model: ${datasets.release_folder}/locus_to_gene_model/classifier.skops l2g_predictions: ${datasets.release_folder}/locus_to_gene_predictions l2g_feature_matrix: ${datasets.release_folder}/locus_to_gene_feature_matrix colocalisation: ${datasets.release_folder}/colocalisation diff --git a/config/step/ot_locus_to_gene_predict.yaml b/config/step/ot_locus_to_gene_predict.yaml index 2c2c3a092..a98e3cf2a 100644 --- a/config/step/ot_locus_to_gene_predict.yaml +++ b/config/step/ot_locus_to_gene_predict.yaml @@ -2,7 +2,7 @@ defaults: - locus_to_gene run_mode: predict -model_path: ${datasets.l2g_model} +model_path: null predictions_path: ${datasets.l2g_predictions} feature_matrix_path: ${datasets.l2g_feature_matrix} credible_set_path: ${datasets.credible_set} diff --git a/config/step/ot_locus_to_gene_train.yaml 
b/config/step/ot_locus_to_gene_train.yaml index d055621ca..25f3710c5 100644 --- a/config/step/ot_locus_to_gene_train.yaml +++ b/config/step/ot_locus_to_gene_train.yaml @@ -3,7 +3,7 @@ defaults: run_mode: train wandb_run_name: null -perform_cross_validation: false +hf_hub_repo_id: opentargets/locus_to_gene model_path: ${datasets.l2g_model} predictions_path: ${datasets.l2g_predictions} credible_set_path: ${datasets.credible_set} @@ -13,5 +13,7 @@ study_index_path: ${datasets.study_index} gold_standard_curation_path: ${datasets.l2g_gold_standard_curation} gene_interactions_path: ${datasets.gene_interactions} hyperparameters: + n_estimators: 100 max_depth: 5 - loss_function: binary:logistic + loss: log_loss +download_from_hub: true diff --git a/docs/python_api/methods/l2g/evaluator.md b/docs/python_api/methods/l2g/evaluator.md deleted file mode 100644 index 4b389e8c0..000000000 --- a/docs/python_api/methods/l2g/evaluator.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -title: W&B evaluator ---- - -::: gentropy.method.l2g.evaluator.WandbEvaluator diff --git a/poetry.lock b/poetry.lock index 4fa3027d2..17ef009f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiodns" @@ -675,13 +675,13 @@ files = [ [[package]] name = "argcomplete" -version = "3.3.0" +version = "3.4.0" description = "Bash tab completion for argparse" optional = false python-versions = ">=3.8" files = [ - {file = "argcomplete-3.3.0-py3-none-any.whl", hash = "sha256:c168c3723482c031df3c207d4ba8fa702717ccb9fc0bfe4117166c1f537b4a54"}, - {file = "argcomplete-3.3.0.tar.gz", hash = "sha256:fd03ff4a5b9e6580569d34b273f741e85cd9e072f3feeeee3eba4891c70eda62"}, + {file = "argcomplete-3.4.0-py3-none-any.whl", hash = "sha256:69a79e083a716173e5532e0fa3bef45f793f4e61096cf52b5a42c0211c8b8aa5"}, + {file = "argcomplete-3.4.0.tar.gz", hash = "sha256:c2abcdfe1be8ace47ba777d4fce319eb13bf8ad9dace8d085dcad6eded88057f"}, ] [package.extras] @@ -798,13 +798,13 @@ aio = ["aiohttp (>=3.0)"] [[package]] name = "azure-identity" -version = "1.16.1" +version = "1.17.1" description = "Microsoft Azure Identity Library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "azure-identity-1.16.1.tar.gz", hash = "sha256:6d93f04468f240d59246d8afde3091494a5040d4f141cad0f49fc0c399d0d91e"}, - {file = "azure_identity-1.16.1-py3-none-any.whl", hash = "sha256:8fb07c25642cd4ac422559a8b50d3e77f73dcc2bbfaba419d06d6c9d7cff6726"}, + {file = "azure-identity-1.17.1.tar.gz", hash = "sha256:32ecc67cc73f4bd0595e4f64b1ca65cd05186f4fe6f98ed2ae9f1aa32646efea"}, + {file = "azure_identity-1.17.1-py3-none-any.whl", hash = "sha256:db8d59c183b680e763722bfe8ebc45930e6c57df510620985939f7f3191e0382"}, ] [package.dependencies] @@ -812,6 +812,7 @@ azure-core = ">=1.23.0" cryptography = ">=2.5" msal = ">=1.24.0" msal-extensions = ">=0.3.0" +typing-extensions = ">=4.0.0" [[package]] name = "azure-mgmt-core" @@ -944,17 +945,17 @@ xyzservices = ">=2021.09.1" [[package]] name = "boto3" -version = "1.34.125" +version = "1.34.131" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.125-py3-none-any.whl", hash = "sha256:116d9eb3c26cf313a2e1e44ef704d1f98f9eb18e7628695d07b01b44a8683544"}, - {file = "boto3-1.34.125.tar.gz", hash = "sha256:31c4a5e4d6f9e6116be61ff654b424ddbd1afcdefe0e8b870c4796f9108eb1c6"}, + {file = 
"boto3-1.34.131-py3-none-any.whl", hash = "sha256:05e388cb937e82be70bfd7eb0c84cf8011ff35cf582a593873ac21675268683b"}, + {file = "boto3-1.34.131.tar.gz", hash = "sha256:dab8f72a6c4e62b4fd70da09e08a6b2a65ea2115b27dd63737142005776ef216"}, ] [package.dependencies] -botocore = ">=1.34.125,<1.35.0" +botocore = ">=1.34.131,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -963,13 +964,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.125" +version = "1.34.131" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.125-py3-none-any.whl", hash = "sha256:71e97e7d2c088f1188ba6976441b5857a5425acd4aaa31b45d13119c9cb86424"}, - {file = "botocore-1.34.125.tar.gz", hash = "sha256:d2882be011ad5b16e7ab4a96360b5b66a0a7e175c1ea06dbf2de473c0a0a33d8"}, + {file = "botocore-1.34.131-py3-none-any.whl", hash = "sha256:13b011d7b206ce00727dcee26548fa3b550db9046d5a0e90ac25a6e6c8fde6ef"}, + {file = "botocore-1.34.131.tar.gz", hash = "sha256:502ddafe1d627fcf1e4c007c86454e5dd011dba7c58bd8e8a5368a79f3e387dc"}, ] [package.dependencies] @@ -1427,63 +1428,63 @@ test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] [[package]] name = "coverage" -version = "7.5.3" +version = "7.5.4" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"}, - {file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"}, - {file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"}, - {file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"}, - {file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"}, - {file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"}, - {file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"}, - {file = 
"coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"}, - {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"}, - {file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"}, - {file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"}, - {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"}, - {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"}, - {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"}, - {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"}, - {file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"}, - {file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"}, - {file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"}, - {file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"}, - {file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"}, - {file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"}, - {file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"}, - {file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"}, - {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"}, - {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, + {file = "coverage-7.5.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6cfb5a4f556bb51aba274588200a46e4dd6b505fb1a5f8c5ae408222eb416f99"}, + {file = "coverage-7.5.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2174e7c23e0a454ffe12267a10732c273243b4f2d50d07544a91198f05c48f47"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2214ee920787d85db1b6a0bd9da5f8503ccc8fcd5814d90796c2f2493a2f4d2e"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1137f46adb28e3813dec8c01fefadcb8c614f33576f672962e323b5128d9a68d"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b385d49609f8e9efc885790a5a0e89f2e3ae042cdf12958b6034cc442de428d3"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = 
"sha256:b4a474f799456e0eb46d78ab07303286a84a3140e9700b9e154cfebc8f527016"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:5cd64adedf3be66f8ccee418473c2916492d53cbafbfcff851cbec5a8454b136"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e564c2cf45d2f44a9da56f4e3a26b2236504a496eb4cb0ca7221cd4cc7a9aca9"}, + {file = "coverage-7.5.4-cp310-cp310-win32.whl", hash = "sha256:7076b4b3a5f6d2b5d7f1185fde25b1e54eb66e647a1dfef0e2c2bfaf9b4c88c8"}, + {file = "coverage-7.5.4-cp310-cp310-win_amd64.whl", hash = "sha256:018a12985185038a5b2bcafab04ab833a9a0f2c59995b3cec07e10074c78635f"}, + {file = "coverage-7.5.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:db14f552ac38f10758ad14dd7b983dbab424e731588d300c7db25b6f89e335b5"}, + {file = "coverage-7.5.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3257fdd8e574805f27bb5342b77bc65578e98cbc004a92232106344053f319ba"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a6612c99081d8d6134005b1354191e103ec9705d7ba2754e848211ac8cacc6b"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d45d3cbd94159c468b9b8c5a556e3f6b81a8d1af2a92b77320e887c3e7a5d080"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed550e7442f278af76d9d65af48069f1fb84c9f745ae249c1a183c1e9d1b025c"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a892be37ca35eb5019ec85402c3371b0f7cda5ab5056023a7f13da0961e60da"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8192794d120167e2a64721d88dbd688584675e86e15d0569599257566dec9bf0"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:820bc841faa502e727a48311948e0461132a9c8baa42f6b2b84a29ced24cc078"}, + {file = "coverage-7.5.4-cp311-cp311-win32.whl", hash = "sha256:6aae5cce399a0f065da65c7bb1e8abd5c7a3043da9dceb429ebe1b289bc07806"}, + {file = "coverage-7.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:d2e344d6adc8ef81c5a233d3a57b3c7d5181f40e79e05e1c143da143ccb6377d"}, + {file = "coverage-7.5.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:54317c2b806354cbb2dc7ac27e2b93f97096912cc16b18289c5d4e44fc663233"}, + {file = "coverage-7.5.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:042183de01f8b6d531e10c197f7f0315a61e8d805ab29c5f7b51a01d62782747"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6bb74ed465d5fb204b2ec41d79bcd28afccf817de721e8a807d5141c3426638"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3d45ff86efb129c599a3b287ae2e44c1e281ae0f9a9bad0edc202179bcc3a2e"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5013ed890dc917cef2c9f765c4c6a8ae9df983cd60dbb635df8ed9f4ebc9f555"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1014fbf665fef86cdfd6cb5b7371496ce35e4d2a00cda501cf9f5b9e6fced69f"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3684bc2ff328f935981847082ba4fdc950d58906a40eafa93510d1b54c08a66c"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:581ea96f92bf71a5ec0974001f900db495488434a6928a2ca7f01eee20c23805"}, + {file = 
"coverage-7.5.4-cp312-cp312-win32.whl", hash = "sha256:73ca8fbc5bc622e54627314c1a6f1dfdd8db69788f3443e752c215f29fa87a0b"}, + {file = "coverage-7.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:cef4649ec906ea7ea5e9e796e68b987f83fa9a718514fe147f538cfeda76d7a7"}, + {file = "coverage-7.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdd31315fc20868c194130de9ee6bfd99755cc9565edff98ecc12585b90be882"}, + {file = "coverage-7.5.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:02ff6e898197cc1e9fa375581382b72498eb2e6d5fc0b53f03e496cfee3fac6d"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05c16cf4b4c2fc880cb12ba4c9b526e9e5d5bb1d81313d4d732a5b9fe2b9d53"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5986ee7ea0795a4095ac4d113cbb3448601efca7f158ec7f7087a6c705304e4"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df54843b88901fdc2f598ac06737f03d71168fd1175728054c8f5a2739ac3e4"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ab73b35e8d109bffbda9a3e91c64e29fe26e03e49addf5b43d85fc426dde11f9"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:aea072a941b033813f5e4814541fc265a5c12ed9720daef11ca516aeacd3bd7f"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:16852febd96acd953b0d55fc842ce2dac1710f26729b31c80b940b9afcd9896f"}, + {file = "coverage-7.5.4-cp38-cp38-win32.whl", hash = "sha256:8f894208794b164e6bd4bba61fc98bf6b06be4d390cf2daacfa6eca0a6d2bb4f"}, + {file = "coverage-7.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:e2afe743289273209c992075a5a4913e8d007d569a406ffed0bd080ea02b0633"}, + {file = "coverage-7.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b95c3a8cb0463ba9f77383d0fa8c9194cf91f64445a63fc26fb2327e1e1eb088"}, + {file = "coverage-7.5.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d7564cc09dd91b5a6001754a5b3c6ecc4aba6323baf33a12bd751036c998be4"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44da56a2589b684813f86d07597fdf8a9c6ce77f58976727329272f5a01f99f7"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e16f3d6b491c48c5ae726308e6ab1e18ee830b4cdd6913f2d7f77354b33f91c8"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbc5958cb471e5a5af41b0ddaea96a37e74ed289535e8deca404811f6cb0bc3d"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a04e990a2a41740b02d6182b498ee9796cf60eefe40cf859b016650147908029"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ddbd2f9713a79e8e7242d7c51f1929611e991d855f414ca9996c20e44a895f7c"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b1ccf5e728ccf83acd313c89f07c22d70d6c375a9c6f339233dcf792094bcbf7"}, + {file = "coverage-7.5.4-cp39-cp39-win32.whl", hash = "sha256:56b4eafa21c6c175b3ede004ca12c653a88b6f922494b023aeb1e836df953ace"}, + {file = "coverage-7.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:65e528e2e921ba8fd67d9055e6b9f9e34b21ebd6768ae1c1723f4ea6ace1234d"}, + {file = "coverage-7.5.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:79b356f3dd5b26f3ad23b35c75dbdaf1f9e2450b6bcefc6d0825ea0aa3f86ca5"}, + {file = 
"coverage-7.5.4.tar.gz", hash = "sha256:a44963520b069e12789d0faea4e9fdb1e410cdc4aab89d94f7f55cbb7fef0353"}, ] [package.dependencies] @@ -1761,13 +1762,13 @@ files = [ [[package]] name = "docstring-parser-fork" -version = "0.0.5" +version = "0.0.8" description = "Parse Python docstrings in reST, Google and Numpydoc format" optional = false -python-versions = ">=3.6,<4.0" +python-versions = "<4.0,>=3.7" files = [ - {file = "docstring_parser_fork-0.0.5-py3-none-any.whl", hash = "sha256:d521dea9b9cc6c60ab5569fa0c1115e3b84a83e6413266fb111a7c81cb935997"}, - {file = "docstring_parser_fork-0.0.5.tar.gz", hash = "sha256:395ae8ee6a359e268670ebc4fe9a40dab917a94f6decd7cda8e86f9bea5c9456"}, + {file = "docstring_parser_fork-0.0.8-py3-none-any.whl", hash = "sha256:88098ae01b0909b241954ad2c50c0c29ec2292223366a540bfd68332be8fd595"}, + {file = "docstring_parser_fork-0.0.8.tar.gz", hash = "sha256:59d3b00d42ba9f4e229a7df7e1f6fc742845f88a1190973cc33ba336a5405425"}, ] [[package]] @@ -1783,13 +1784,13 @@ files = [ [[package]] name = "email-validator" -version = "2.1.1" +version = "2.2.0" description = "A robust email address syntax and deliverability validation library." optional = false python-versions = ">=3.8" files = [ - {file = "email_validator-2.1.1-py3-none-any.whl", hash = "sha256:97d882d174e2a65732fb43bfce81a3a834cbc1bde8bf419e30ef5ea976370a05"}, - {file = "email_validator-2.1.1.tar.gz", hash = "sha256:200a70680ba08904be6d1eef729205cc0d687634399a5924d842533efb824b84"}, + {file = "email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631"}, + {file = "email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7"}, ] [package.dependencies] @@ -1840,34 +1841,34 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "filelock" -version = "3.15.1" +version = "3.15.4" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, - {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] name = "flake8" -version = "7.0.0" +version = "7.1.0" description = "the modular source code checker: pep8 pyflakes and co" optional = false python-versions = ">=3.8.1" files = [ - {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, - {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, + {file = "flake8-7.1.0-py2.py3-none-any.whl", hash = "sha256:2e416edcc62471a64cea09353f4e7bdba32aeb079b6e360554c659a122b1bc6a"}, + {file = "flake8-7.1.0.tar.gz", hash = "sha256:48a07b626b55236e0fb4784ee69a465fbf59d79eec1f5b4785c3d3bc57d17aa5"}, ] [package.dependencies] mccabe = ">=0.7.0,<0.8.0" -pycodestyle = ">=2.11.0,<2.12.0" +pycodestyle = ">=2.12.0,<2.13.0" pyflakes = ">=3.2.0,<3.3.0" [[package]] @@ -2403,13 +2404,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-api-python-client" -version = "2.133.0" +version = "2.134.0" description = "Google API Client Library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.133.0.tar.gz", hash = "sha256:293092905b66a046d3187a99ac454e12b00cc2c70444f26eb2f1f9c1a82720b4"}, - {file = "google_api_python_client-2.133.0-py2.py3-none-any.whl", hash = "sha256:396fe676ea0dfed066654dcf9f8dea77a1342f9d9bb23bb88e45b7b81e773926"}, + {file = "google-api-python-client-2.134.0.tar.gz", hash = "sha256:4a8f0bea651a212997cc83c0f271fc86f80ef93d1cee9d84de7dfaeef2a858b6"}, + {file = "google_api_python_client-2.134.0-py2.py3-none-any.whl", hash = "sha256:ba05d60f6239990b7994f6328f17bb154c602d31860fb553016dc9f8ce886945"}, ] [package.dependencies] @@ -2477,13 +2478,13 @@ tool = ["click (>=6.0.0)"] [[package]] name = "google-cloud-aiplatform" -version = "1.55.0" +version = "1.56.0" description = "Vertex AI API client library" optional = false python-versions = ">=3.8" files = [ - {file = "google-cloud-aiplatform-1.55.0.tar.gz", hash = "sha256:aa87cb6c49ae5fde87fb831ce8ad4a853c4656fe04babe505e9144c7a9e09c1a"}, - {file = "google_cloud_aiplatform-1.55.0-py2.py3-none-any.whl", hash = "sha256:c6cc76ca5537f4636a0c3f8c0288d2e0d2d86ef708e562d2654313e11d6ee46a"}, + {file = "google-cloud-aiplatform-1.56.0.tar.gz", hash = "sha256:d4cfb085427dac01142915f523949ac2955d6c7f148d95017d3286a77caf5d5e"}, + {file = "google_cloud_aiplatform-1.56.0-py2.py3-none-any.whl", hash = 
"sha256:ee1ab3bd115c3caebf8ddfd3e47eeb8396a3ec2fc5f5baf1a5c295c8d64333ab"}, ] [package.dependencies] @@ -2505,8 +2506,8 @@ cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)"] endpoint = ["requests (>=2.28.1)"] full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"] -langchain = ["langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)"] -langchain-testing = ["absl-py", "cloudpickle (>=2.2.1,<4.0)", "langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist"] +langchain = ["langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "tenacity (<=8.3)"] +langchain-testing = ["absl-py", "cloudpickle (>=3.0,<4.0)", "langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist", "tenacity (<=8.3)"] lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"] metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"] pipelines = ["pyyaml (>=5.3.1,<7)"] @@ -2516,7 +2517,7 @@ private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"] rapid-evaluation = ["nest-asyncio (>=1.0.0,<1.6.0)", "pandas (>=1.0.0,<2.2.0)"] ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "setuptools (<70.0.0)"] ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "ray[train] (==2.9.3)", "scikit-learn", "setuptools (<70.0.0)", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"] -reasoningengine = ["cloudpickle (>=2.2.1,<4.0)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)"] +reasoningengine = ["cloudpickle (>=3.0,<4.0)", 
"opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)"] tensorboard = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"] vizier = ["google-vizier (>=0.1.6)"] @@ -2595,13 +2596,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4 [[package]] name = "google-cloud-bigquery" -version = "3.24.0" +version = "3.25.0" description = "Google BigQuery API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-bigquery-3.24.0.tar.gz", hash = "sha256:e95e6f6e0aa32e6c453d44e2b3298931fdd7947c309ea329a31b6ff1f939e17e"}, - {file = "google_cloud_bigquery-3.24.0-py2.py3-none-any.whl", hash = "sha256:bc08323ce99dee4e811b7c3d0cde8929f5bf0b1aeaed6bcd75fc89796dd87652"}, + {file = "google-cloud-bigquery-3.25.0.tar.gz", hash = "sha256:5b2aff3205a854481117436836ae1403f11f2594e6810a98886afd57eda28509"}, + {file = "google_cloud_bigquery-3.25.0-py2.py3-none-any.whl", hash = "sha256:7f0c371bc74d2a7fb74dacbc00ac0f90c8c2bec2289b51dd6685a275873b1ce9"}, ] [package.dependencies] @@ -2787,13 +2788,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4 [[package]] name = "google-cloud-dataplex" -version = "2.0.0" +version = "2.0.1" description = "Google Cloud Dataplex API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-dataplex-2.0.0.tar.gz", hash = "sha256:b5140f77a694ef9d6f5c7f12693873f5acf97cc75661d46bb8cc725a94a9017c"}, - {file = "google_cloud_dataplex-2.0.0-py2.py3-none-any.whl", hash = "sha256:b914aaf9040fc96d06831d33d43c8ac554eeda301cadd5342ca8ce7201eb0d27"}, + {file = "google-cloud-dataplex-2.0.1.tar.gz", hash = "sha256:f4ccb1f76eb7b8a2ae01cdcb2041bb613045d262b57bd65d2c7522e766923c15"}, + {file = "google_cloud_dataplex-2.0.1-py2.py3-none-any.whl", hash = "sha256:ebe732dcf54b372c4af8ee40e2f43e3eff329b98d2060b1a803a0e9e656f5da6"}, ] [package.dependencies] @@ -2985,13 +2986,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4 [[package]] name = "google-cloud-pubsub" -version = "2.21.3" +version = "2.21.5" description = "Google Cloud Pub/Sub API client 
library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-pubsub-2.21.3.tar.gz", hash = "sha256:df44c79a4d7bfb5adfa675413e3b91bbbe06e91d2715a15f925140b05715bb7d"}, - {file = "google_cloud_pubsub-2.21.3-py2.py3-none-any.whl", hash = "sha256:a417d63d9db5e8b9ff51dee705c69ed02b0a8bc62cac2474a86f22aea288c0ec"}, + {file = "google-cloud-pubsub-2.21.5.tar.gz", hash = "sha256:4fa96e7f200359ccc49cf6657e31ac35f5e6e55d00fbb3cedfa672903cf75b24"}, + {file = "google_cloud_pubsub-2.21.5-py2.py3-none-any.whl", hash = "sha256:fbd6b00a1e28ea47609b2a5562aeecbaf31ad9cf4f7a83f91c3605e869c6447c"}, ] [package.dependencies] @@ -3001,7 +3002,7 @@ grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" grpcio = ">=1.51.3,<2.0dev" grpcio-status = ">=1.33.2" proto-plus = ">=1.22.0,<2.0.0dev" -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" [package.extras] libcst = ["libcst (>=0.3.10)"] @@ -3572,13 +3573,13 @@ test = ["objgraph", "psutil"] [[package]] name = "griffe" -version = "0.45.3" +version = "0.47.0" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." optional = false python-versions = ">=3.8" files = [ - {file = "griffe-0.45.3-py3-none-any.whl", hash = "sha256:ed1481a680ae3e28f91a06e0d8a51a5c9b97555aa2527abc2664447cc22337d6"}, - {file = "griffe-0.45.3.tar.gz", hash = "sha256:02ee71cc1a5035864b97bd0dbfff65c33f6f2c8854d3bd48a791905c2b8a44b9"}, + {file = "griffe-0.47.0-py3-none-any.whl", hash = "sha256:07a2fd6a8c3d21d0bbb0decf701d62042ccc8a576645c7f8799fe1f10de2b2de"}, + {file = "griffe-0.47.0.tar.gz", hash = "sha256:95119a440a3c932b13293538bdbc405bee4c36428547553dc6b327e7e7d35e5a"}, ] [package.dependencies] @@ -3843,6 +3844,40 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "huggingface-hub" +version = "0.23.4" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.23.4-py3-none-any.whl", hash = "sha256:3a0b957aa87150addf0cc7bd71b4d954b78e749850e1e7fb29ebbd2db64ca037"}, + {file = "huggingface_hub-0.23.4.tar.gz", hash = "sha256:35d99016433900e44ae7efe1c209164a5a81dbbcd53a52f99c281dcd7ce22431"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", 
"pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors", "torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + [[package]] name = "humanize" version = "1.1.0" @@ -4310,13 +4345,13 @@ files = [ [[package]] name = "limits" -version = "3.12.0" +version = "3.13.0" description = "Rate limiting utilities" optional = false python-versions = ">=3.8" files = [ - {file = "limits-3.12.0-py3-none-any.whl", hash = "sha256:48d91e94a0888fb1251aa31423d716ae72ceff997231363f7968a5eaa51dc56d"}, - {file = "limits-3.12.0.tar.gz", hash = "sha256:95764065715a11b9fdcc82558cac2fb59a1febbb7aa2acd045f72ab0c16ec04f"}, + {file = "limits-3.13.0-py3-none-any.whl", hash = "sha256:9767f7233da4255e9904b79908a728e8ec0984c0b086058b4cbbd309aea553f6"}, + {file = "limits-3.13.0.tar.gz", hash = "sha256:6571b0c567bfa175a35fed9f8a954c0c92f1c3200804282f1b8f1de4ad98a953"}, ] [package.dependencies] @@ -4952,13 +4987,13 @@ pytz = "*" [[package]] name = "mkdocs-material" -version = "9.5.26" +version = "9.5.27" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.5.26-py3-none-any.whl", hash = "sha256:5d01fb0aa1c7946a1e3ae8689aa2b11a030621ecb54894e35aabb74c21016312"}, - {file = "mkdocs_material-9.5.26.tar.gz", hash = "sha256:56aeb91d94cffa43b6296fa4fbf0eb7c840136e563eecfd12c2d9e92e50ba326"}, + {file = "mkdocs_material-9.5.27-py3-none-any.whl", hash = "sha256:af8cc263fafa98bb79e9e15a8c966204abf15164987569bd1175fd66a7705182"}, + {file = "mkdocs_material-9.5.27.tar.gz", hash = "sha256:a7d4a35f6d4a62b0c43a0cfe7e987da0980c13587b5bc3c26e690ad494427ec0"}, ] [package.dependencies] @@ -5032,17 +5067,17 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] [[package]] name = "mkdocstrings-python" -version = "1.10.3" +version = "1.10.5" description = "A Python handler for mkdocstrings." 
optional = false python-versions = ">=3.8" files = [ - {file = "mkdocstrings_python-1.10.3-py3-none-any.whl", hash = "sha256:11ff6d21d3818fb03af82c3ea6225b1534837e17f790aa5f09626524171f949b"}, - {file = "mkdocstrings_python-1.10.3.tar.gz", hash = "sha256:321cf9c732907ab2b1fedaafa28765eaa089d89320f35f7206d00ea266889d03"}, + {file = "mkdocstrings_python-1.10.5-py3-none-any.whl", hash = "sha256:92e3c588ef1b41151f55281d075de7558dd8092e422cb07a65b18ee2b0863ebb"}, + {file = "mkdocstrings_python-1.10.5.tar.gz", hash = "sha256:acdc2a98cd9d46c7ece508193a16ca03ccabcb67520352b7449f84b57c162bdf"}, ] [package.dependencies] -griffe = ">=0.44" +griffe = ">=0.47" mkdocstrings = ">=0.25" [[package]] @@ -5058,13 +5093,13 @@ files = [ [[package]] name = "msal" -version = "1.28.1" +version = "1.29.0" description = "The Microsoft Authentication Library (MSAL) for Python library enables your app to access the Microsoft Cloud by supporting authentication of users with Microsoft Azure Active Directory accounts (AAD) and Microsoft Accounts (MSA) using industry standard OAuth2 and OpenID Connect." optional = false python-versions = ">=3.7" files = [ - {file = "msal-1.28.1-py3-none-any.whl", hash = "sha256:563c2d70de77a2ca9786aab84cb4e133a38a6897e6676774edc23d610bfc9e7b"}, - {file = "msal-1.28.1.tar.gz", hash = "sha256:d72bbfe2d5c2f2555f4bc6205be4450ddfd12976610dd9a16a9ab0f05c68b64d"}, + {file = "msal-1.29.0-py3-none-any.whl", hash = "sha256:6b301e63f967481f0cc1a3a3bac0cf322b276855bc1b0955468d9deb3f33d511"}, + {file = "msal-1.29.0.tar.gz", hash = "sha256:8f6725f099752553f9b2fe84125e2a5ebe47b49f92eacca33ebedd3a9ebaae25"}, ] [package.dependencies] @@ -5077,22 +5112,18 @@ broker = ["pymsalruntime (>=0.13.2,<0.17)"] [[package]] name = "msal-extensions" -version = "1.1.0" +version = "1.2.0" description = "Microsoft Authentication Library extensions (MSAL EX) provides a persistence API that can save your data on disk, encrypted on Windows, macOS and Linux. Concurrent data access will be coordinated by a file lock mechanism." 
optional = false python-versions = ">=3.7" files = [ - {file = "msal-extensions-1.1.0.tar.gz", hash = "sha256:6ab357867062db7b253d0bd2df6d411c7891a0ee7308d54d1e4317c1d1c54252"}, - {file = "msal_extensions-1.1.0-py3-none-any.whl", hash = "sha256:01be9711b4c0b1a151450068eeb2c4f0997df3bba085ac299de3a66f585e382f"}, + {file = "msal_extensions-1.2.0-py3-none-any.whl", hash = "sha256:cf5ba83a2113fa6dc011a254a72f1c223c88d7dfad74cc30617c4679a417704d"}, + {file = "msal_extensions-1.2.0.tar.gz", hash = "sha256:6f41b320bfd2933d631a215c91ca0dd3e67d84bd1a2f50ce917d5874ec646bef"}, ] [package.dependencies] -msal = ">=0.4.1,<2.0.0" -packaging = "*" -portalocker = [ - {version = ">=1.0,<3", markers = "platform_system != \"Windows\""}, - {version = ">=1.6,<3", markers = "platform_system == \"Windows\""}, -] +msal = ">=1.29,<2" +portalocker = ">=1.4,<3" [[package]] name = "msrest" @@ -5529,57 +5560,57 @@ dev = ["black", "mypy", "pytest"] [[package]] name = "orjson" -version = "3.10.4" +version = "3.10.5" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.4-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:afca963f19ca60c7aedadea9979f769139127288dd58ccf3f7c5e8e6dc62cabf"}, - {file = "orjson-3.10.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b112eff36ba7ccc7a9d6b87e17b9d6bde4312d05e3ddf66bf5662481dee846"}, - {file = "orjson-3.10.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02b192eaba048b1039eca9a0cef67863bd5623042f5c441889a9957121d97e14"}, - {file = "orjson-3.10.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:827c3d0e4fc44242c82bfdb1a773235b8c0575afee99a9fa9a8ce920c14e440f"}, - {file = "orjson-3.10.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca8ec09724f10ec209244caeb1f9f428b6bb03f2eda9ed5e2c4dd7f2b7fabd44"}, - {file = "orjson-3.10.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8eaa5d531a8fde11993cbcb27e9acf7d9c457ba301adccb7fa3a021bfecab46c"}, - {file = "orjson-3.10.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e112aa7fc4ea67367ec5e86c39a6bb6c5719eddc8f999087b1759e765ddaf2d4"}, - {file = "orjson-3.10.4-cp310-none-win32.whl", hash = "sha256:1538844fb88446c42da3889f8c4ecce95a630b5a5ba18ecdfe5aea596f4dff21"}, - {file = "orjson-3.10.4-cp310-none-win_amd64.whl", hash = "sha256:de02811903a2e434127fba5389c3cc90f689542339a6e52e691ab7f693407b5a"}, - {file = "orjson-3.10.4-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:358afaec75de7237dfea08e6b1b25d226e33a1e3b6dc154fc99eb697f24a1ffa"}, - {file = "orjson-3.10.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb4e292c3198ab3d93e5f877301d2746be4ca0ba2d9c513da5e10eb90e19ff52"}, - {file = "orjson-3.10.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c39e57cf6323a39238490092985d5d198a7da4a3be013cc891a33fef13a536e"}, - {file = "orjson-3.10.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f86df433fc01361ff9270ad27455ce1ad43cd05e46de7152ca6adb405a16b2f6"}, - {file = "orjson-3.10.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c9966276a2c97e93e6cbe8286537f88b2a071827514f0d9d47a0aefa77db458"}, - {file = "orjson-3.10.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = 
"sha256:c499a14155a1f5a1e16e0cd31f6cf6f93965ac60a0822bc8340e7e2d3dac1108"}, - {file = "orjson-3.10.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3087023ce904a327c29487eb7e1f2c060070e8dbb9a3991b8e7952a9c6e62f38"}, - {file = "orjson-3.10.4-cp311-none-win32.whl", hash = "sha256:f965893244fe348b59e5ce560693e6dd03368d577ce26849b5d261ce31c70101"}, - {file = "orjson-3.10.4-cp311-none-win_amd64.whl", hash = "sha256:c212f06fad6aa6ce85d5665e91a83b866579f29441a47d3865c57329c0857357"}, - {file = "orjson-3.10.4-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d0965a8b0131959833ca8a65af60285995d57ced0de2fd8f16fc03235975d238"}, - {file = "orjson-3.10.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27b64695d9f2aef3ae15a0522e370ec95c946aaea7f2c97a1582a62b3bdd9169"}, - {file = "orjson-3.10.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:867d882ddee6a20be4c8b03ae3d2b0333894d53ad632d32bd9b8123649577171"}, - {file = "orjson-3.10.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0667458f8a8ceb6dee5c08fec0b46195f92c474cbbec71dca2a6b7fd5b67b8d"}, - {file = "orjson-3.10.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3eac9befc4eaec1d1ff3bba6210576be4945332dde194525601c5ddb5c060d3"}, - {file = "orjson-3.10.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4343245443552eae240a33047a6d1bcac7a754ad4b1c57318173c54d7efb9aea"}, - {file = "orjson-3.10.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30153e269eea43e98918d4d462a36a7065031d9246407dfff2579a4e457515c1"}, - {file = "orjson-3.10.4-cp312-none-win32.whl", hash = "sha256:1a7d092ee043abf3db19c2183115e80676495c9911843fdb3ebd48ca7b73079e"}, - {file = "orjson-3.10.4-cp312-none-win_amd64.whl", hash = "sha256:07a2adbeb8b9efe6d68fc557685954a1f19d9e33f5cc018ae1a89e96647c1b65"}, - {file = "orjson-3.10.4-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f5a746f3d908bce1a1e347b9ca89864047533bdfab5a450066a0315f6566527b"}, - {file = "orjson-3.10.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:465b4a8a3e459f8d304c19071b4badaa9b267c59207a005a7dd9dfe13d3a423f"}, - {file = "orjson-3.10.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:35858d260728c434a3d91b60685ab32418318567e8902039837e1c2af2719e0b"}, - {file = "orjson-3.10.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8a5ba090d40c4460312dd69c232b38c2ff67a823185cfe667e841c9dd5c06841"}, - {file = "orjson-3.10.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dde86755d064664e62e3612a166c28298aa8dfd35a991553faa58855ae739cc"}, - {file = "orjson-3.10.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:020a9e9001cfec85c156ef3b185ff758b62ef986cefdb8384c4579facd5ce126"}, - {file = "orjson-3.10.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3bf8e6e3388a2e83a86466c912387e0f0a765494c65caa7e865f99969b76ba0d"}, - {file = "orjson-3.10.4-cp38-none-win32.whl", hash = "sha256:c5a1cca6a4a3129db3da68a25dc0a459a62ae58e284e363b35ab304202d9ba9e"}, - {file = "orjson-3.10.4-cp38-none-win_amd64.whl", hash = "sha256:ecd97d98d7bee3e3d51d0b51c92c457f05db4993329eea7c69764f9820e27eb3"}, - {file = "orjson-3.10.4-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:71362daa330a2fc85553a1469185ac448547392a8f83d34e67779f8df3a52743"}, - {file = 
"orjson-3.10.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d24b59d1fecb0fd080c177306118a143f7322335309640c55ed9580d2044e363"}, - {file = "orjson-3.10.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e906670aea5a605b083ebb58d575c35e88cf880fa372f7cedaac3d51e98ff164"}, - {file = "orjson-3.10.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ce32ed4bc4d632268e4978e595fe5ea07e026b751482b4a0feec48f66a90abc"}, - {file = "orjson-3.10.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1dcd34286246e0c5edd0e230d1da2daab2c1b465fcb6bac85b8d44057229d40a"}, - {file = "orjson-3.10.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c45d4b8c403e50beedb1d006a8916d9910ed56bceaf2035dc253618b44d0a161"}, - {file = "orjson-3.10.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:aaed3253041b5002a4f5bfdf6f7b5cce657d974472b0699a469d439beba40381"}, - {file = "orjson-3.10.4-cp39-none-win32.whl", hash = "sha256:9a4f41b7dbf7896f8dbf559b9b43dcd99e31e0d49ac1b59d74f52ce51ab10eb9"}, - {file = "orjson-3.10.4-cp39-none-win_amd64.whl", hash = "sha256:6c4eb7d867ed91cb61e6514cb4f457aa01d7b0fd663089df60a69f3d38b69d4c"}, - {file = "orjson-3.10.4.tar.gz", hash = "sha256:c912ed25b787c73fe994a5decd81c3f3b256599b8a87d410d799d5d52013af2a"}, + {file = "orjson-3.10.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:545d493c1f560d5ccfc134803ceb8955a14c3fcb47bbb4b2fee0232646d0b932"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4324929c2dd917598212bfd554757feca3e5e0fa60da08be11b4aa8b90013c1"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c13ca5e2ddded0ce6a927ea5a9f27cae77eee4c75547b4297252cb20c4d30e6"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b6c8e30adfa52c025f042a87f450a6b9ea29649d828e0fec4858ed5e6caecf63"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:338fd4f071b242f26e9ca802f443edc588fa4ab60bfa81f38beaedf42eda226c"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6970ed7a3126cfed873c5d21ece1cd5d6f83ca6c9afb71bbae21a0b034588d96"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:235dadefb793ad12f7fa11e98a480db1f7c6469ff9e3da5e73c7809c700d746b"}, + {file = "orjson-3.10.5-cp310-none-win32.whl", hash = "sha256:be79e2393679eda6a590638abda16d167754393f5d0850dcbca2d0c3735cebe2"}, + {file = "orjson-3.10.5-cp310-none-win_amd64.whl", hash = "sha256:c4a65310ccb5c9910c47b078ba78e2787cb3878cdded1702ac3d0da71ddc5228"}, + {file = "orjson-3.10.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cdf7365063e80899ae3a697def1277c17a7df7ccfc979990a403dfe77bb54d40"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b68742c469745d0e6ca5724506858f75e2f1e5b59a4315861f9e2b1df77775a"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d10cc1b594951522e35a3463da19e899abe6ca95f3c84c69e9e901e0bd93d38"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcbe82b35d1ac43b0d84072408330fd3295c2896973112d495e7234f7e3da2e1"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:10c0eb7e0c75e1e486c7563fe231b40fdd658a035ae125c6ba651ca3b07936f5"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:53ed1c879b10de56f35daf06dbc4a0d9a5db98f6ee853c2dbd3ee9d13e6f302f"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:099e81a5975237fda3100f918839af95f42f981447ba8f47adb7b6a3cdb078fa"}, + {file = "orjson-3.10.5-cp311-none-win32.whl", hash = "sha256:1146bf85ea37ac421594107195db8bc77104f74bc83e8ee21a2e58596bfb2f04"}, + {file = "orjson-3.10.5-cp311-none-win_amd64.whl", hash = "sha256:36a10f43c5f3a55c2f680efe07aa93ef4a342d2960dd2b1b7ea2dd764fe4a37c"}, + {file = "orjson-3.10.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:68f85ecae7af14a585a563ac741b0547a3f291de81cd1e20903e79f25170458f"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28afa96f496474ce60d3340fe8d9a263aa93ea01201cd2bad844c45cd21f5268"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cd684927af3e11b6e754df80b9ffafd9fb6adcaa9d3e8fdd5891be5a5cad51e"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d21b9983da032505f7050795e98b5d9eee0df903258951566ecc358f6696969"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ad1de7fef79736dde8c3554e75361ec351158a906d747bd901a52a5c9c8d24b"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d97531cdfe9bdd76d492e69800afd97e5930cb0da6a825646667b2c6c6c0211"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d69858c32f09c3e1ce44b617b3ebba1aba030e777000ebdf72b0d8e365d0b2b3"}, + {file = "orjson-3.10.5-cp312-none-win32.whl", hash = "sha256:64c9cc089f127e5875901ac05e5c25aa13cfa5dbbbd9602bda51e5c611d6e3e2"}, + {file = "orjson-3.10.5-cp312-none-win_amd64.whl", hash = "sha256:b2efbd67feff8c1f7728937c0d7f6ca8c25ec81373dc8db4ef394c1d93d13dc5"}, + {file = "orjson-3.10.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:03b565c3b93f5d6e001db48b747d31ea3819b89abf041ee10ac6988886d18e01"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:584c902ec19ab7928fd5add1783c909094cc53f31ac7acfada817b0847975f26"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a35455cc0b0b3a1eaf67224035f5388591ec72b9b6136d66b49a553ce9eb1e6"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1670fe88b116c2745a3a30b0f099b699a02bb3482c2591514baf5433819e4f4d"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:185c394ef45b18b9a7d8e8f333606e2e8194a50c6e3c664215aae8cf42c5385e"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ca0b3a94ac8d3886c9581b9f9de3ce858263865fdaa383fbc31c310b9eac07c9"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:dfc91d4720d48e2a709e9c368d5125b4b5899dced34b5400c3837dadc7d6271b"}, + {file = "orjson-3.10.5-cp38-none-win32.whl", hash = "sha256:c05f16701ab2a4ca146d0bca950af254cb7c02f3c01fca8efbbad82d23b3d9d4"}, + {file = "orjson-3.10.5-cp38-none-win_amd64.whl", hash = "sha256:8a11d459338f96a9aa7f232ba95679fc0c7cedbd1b990d736467894210205c09"}, + {file = 
"orjson-3.10.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:85c89131d7b3218db1b24c4abecea92fd6c7f9fab87441cfc342d3acc725d807"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66215277a230c456f9038d5e2d84778141643207f85336ef8d2a9da26bd7ca"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51bbcdea96cdefa4a9b4461e690c75ad4e33796530d182bdd5c38980202c134a"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbead71dbe65f959b7bd8cf91e0e11d5338033eba34c114f69078d59827ee139"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df58d206e78c40da118a8c14fc189207fffdcb1f21b3b4c9c0c18e839b5a214"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4057c3b511bb8aef605616bd3f1f002a697c7e4da6adf095ca5b84c0fd43595"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b39e006b00c57125ab974362e740c14a0c6a66ff695bff44615dcf4a70ce2b86"}, + {file = "orjson-3.10.5-cp39-none-win32.whl", hash = "sha256:eded5138cc565a9d618e111c6d5c2547bbdd951114eb822f7f6309e04db0fb47"}, + {file = "orjson-3.10.5-cp39-none-win_amd64.whl", hash = "sha256:cc28e90a7cae7fcba2493953cff61da5a52950e78dc2dacfe931a317ee3d8de7"}, + {file = "orjson-3.10.5.tar.gz", hash = "sha256:7a5baef8a4284405d96c90c7c62b755e9ef1ada84c2406c24a9ebec86b89f46d"}, ] [[package]] @@ -6004,13 +6035,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "portalocker" -version = "2.8.2" +version = "2.10.0" description = "Wraps the portalocker recipe for easy usage" optional = false python-versions = ">=3.8" files = [ - {file = "portalocker-2.8.2-py3-none-any.whl", hash = "sha256:cfb86acc09b9aa7c3b43594e19be1345b9d16af3feb08bf92f23d4dce513a28e"}, - {file = "portalocker-2.8.2.tar.gz", hash = "sha256:2b035aa7828e46c58e9b31390ee1f169b98e1066ab10b9a6a861fe7e25ee4f33"}, + {file = "portalocker-2.10.0-py3-none-any.whl", hash = "sha256:48944147b2cd42520549bc1bb8fe44e220296e56f7c3d551bc6ecce69d9b0de1"}, + {file = "portalocker-2.10.0.tar.gz", hash = "sha256:49de8bc0a2f68ca98bf9e219c81a3e6b27097c7bf505a87c5a112ce1aaeb9b81"}, ] [package.dependencies] @@ -6083,20 +6114,20 @@ wcwidth = "*" [[package]] name = "proto-plus" -version = "1.23.0" +version = "1.24.0" description = "Beautiful, Pythonic protocol buffers." optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "proto-plus-1.23.0.tar.gz", hash = "sha256:89075171ef11988b3fa157f5dbd8b9cf09d65fffee97e29ce403cd8defba19d2"}, - {file = "proto_plus-1.23.0-py3-none-any.whl", hash = "sha256:a829c79e619e1cf632de091013a4173deed13a55f326ef84f05af6f50ff4c82c"}, + {file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"}, + {file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"}, ] [package.dependencies] -protobuf = ">=3.19.0,<5.0.0dev" +protobuf = ">=3.19.0,<6.0.0dev" [package.extras] -testing = ["google-api-core[grpc] (>=1.31.5)"] +testing = ["google-api-core (>=1.31.5)"] [[package]] name = "protobuf" @@ -6131,27 +6162,28 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" +version = "6.0.0" description = "Cross-platform lib for process and system monitoring in Python." 
optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, - {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, - {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, - {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = 
"sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [package.extras] @@ -6345,13 +6377,13 @@ idna = ["idna (>=2.1)"] [[package]] name = "pycodestyle" -version = "2.11.1" +version = "2.12.0" description = "Python style guide checker" optional = false python-versions = ">=3.8" files = [ - {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, - {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, + {file = "pycodestyle-2.12.0-py2.py3-none-any.whl", hash = "sha256:949a39f6b86c3e1515ba1787c2022131d165a8ad271b11370a8819aa070269e4"}, + {file = "pycodestyle-2.12.0.tar.gz", hash = "sha256:442f950141b4f43df752dd303511ffded3a04c2b6fb7f65980574f0c31e6e79c"}, ] [[package]] @@ -6849,6 +6881,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -7187,13 +7220,13 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] [[package]] name = "rich-argparse" -version = "1.5.1" +version = "1.5.2" description = "Rich help formatters for argparse and optparse" optional = false python-versions = ">=3.8" files = [ - {file = 
"rich_argparse-1.5.1-py3-none-any.whl", hash = "sha256:527ffd2afebc0fd6fae312da4720df0c8bfa858e27c35d5471c338bc4c193939"}, - {file = "rich_argparse-1.5.1.tar.gz", hash = "sha256:025ca081da4dbb013dd4ea5213a782553ef8653d4da33c5b0a96b0d28f73058e"}, + {file = "rich_argparse-1.5.2-py3-none-any.whl", hash = "sha256:7027503d5849e27fc7cc85fb58504363606f2ec1c8b3c27d9a8ad28788faf877"}, + {file = "rich_argparse-1.5.2.tar.gz", hash = "sha256:84d348d5b6dafe99fffe2c7ea1ca0afe14096c921693445b9eee65ee4fcbfd2c"}, ] [package.dependencies] @@ -7323,28 +7356,28 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.4.8" +version = "0.4.10" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7663a6d78f6adb0eab270fa9cf1ff2d28618ca3a652b60f2a234d92b9ec89066"}, - {file = "ruff-0.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eeceb78da8afb6de0ddada93112869852d04f1cd0f6b80fe464fd4e35c330913"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aad360893e92486662ef3be0a339c5ca3c1b109e0134fcd37d534d4be9fb8de3"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:284c2e3f3396fb05f5f803c9fffb53ebbe09a3ebe7dda2929ed8d73ded736deb"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7354f921e3fbe04d2a62d46707e569f9315e1a613307f7311a935743c51a764"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:72584676164e15a68a15778fd1b17c28a519e7a0622161eb2debdcdabdc71883"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9678d5c9b43315f323af2233a04d747409d1e3aa6789620083a82d1066a35199"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704977a658131651a22b5ebeb28b717ef42ac6ee3b11e91dc87b633b5d83142b"}, - {file = "ruff-0.4.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d05f8d6f0c3cce5026cecd83b7a143dcad503045857bc49662f736437380ad45"}, - {file = "ruff-0.4.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6ea874950daca5697309d976c9afba830d3bf0ed66887481d6bca1673fc5b66a"}, - {file = "ruff-0.4.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fc95aac2943ddf360376be9aa3107c8cf9640083940a8c5bd824be692d2216dc"}, - {file = "ruff-0.4.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:384154a1c3f4bf537bac69f33720957ee49ac8d484bfc91720cc94172026ceed"}, - {file = "ruff-0.4.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9d5ce97cacc99878aa0d084c626a15cd21e6b3d53fd6f9112b7fc485918e1fa"}, - {file = "ruff-0.4.8-py3-none-win32.whl", hash = "sha256:6d795d7639212c2dfd01991259460101c22aabf420d9b943f153ab9d9706e6a9"}, - {file = "ruff-0.4.8-py3-none-win_amd64.whl", hash = "sha256:e14a3a095d07560a9d6769a72f781d73259655919d9b396c650fc98a8157555d"}, - {file = "ruff-0.4.8-py3-none-win_arm64.whl", hash = "sha256:14019a06dbe29b608f6b7cbcec300e3170a8d86efaddb7b23405cb7f7dcaf780"}, - {file = "ruff-0.4.8.tar.gz", hash = "sha256:16d717b1d57b2e2fd68bd0bf80fb43931b79d05a7131aa477d66fc40fbd86268"}, + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = 
"ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] [[package]] @@ -7453,13 +7486,13 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "sentry-sdk" -version = "2.5.1" +version = "2.6.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = ">=3.6" files = [ - {file = "sentry_sdk-2.5.1-py2.py3-none-any.whl", hash = "sha256:1f87acdce4a43a523ae5aa21a3fc37522d73ebd9ec04b1dbf01aa3d173852def"}, - {file = "sentry_sdk-2.5.1.tar.gz", hash = "sha256:fbc40a78a8a9c6675133031116144f0d0940376fa6e4e1acd5624c90b0aaf58b"}, + {file = "sentry_sdk-2.6.0-py2.py3-none-any.whl", hash = "sha256:422b91cb49378b97e7e8d0e8d5a1069df23689d45262b86f54988a7db264e874"}, + {file = "sentry_sdk-2.6.0.tar.gz", hash = "sha256:65cc07e9c6995c5e316109f138570b32da3bd7ff8d0d0ee4aaf2628c3dd8127d"}, ] [package.dependencies] @@ -7696,6 +7729,28 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "skops" +version = "0.9.0" +description = "A set of tools to push scikit-learn based models to and pull from Hugging Face Hub" +optional = false +python-versions = ">=3.8" +files = [ + {file = "skops-0.9.0-py3-none-any.whl", hash = "sha256:05645199bf6976e1f6dbba4a0704799cd5d2fcef18a98b069b4c84744e1a80a1"}, + {file = "skops-0.9.0.tar.gz", hash 
= "sha256:3e39333d65f26d5863ad44db5001b4cfe6a29642274ac37af54fb834813aee3f"}, +] + +[package.dependencies] +huggingface-hub = ">=0.17.0" +packaging = ">=17.0" +scikit-learn = ">=0.24" +tabulate = ">=0.8.8" + +[package.extras] +docs = ["fairlearn (>=0.7.0)", "matplotlib (>=3.3)", "numpydoc (>=1.0.0)", "pandas (>=1)", "scikit-learn-intelex (>=2021.7.1)", "sphinx (>=3.2.0)", "sphinx-gallery (>=0.7.0)", "sphinx-issues (>=1.2.0)", "sphinx-prompt (>=1.3.0)", "sphinx-rtd-theme (>=1)"] +rich = ["rich (>=12)"] +tests = ["catboost (>=1.0)", "fairlearn (>=0.7.0)", "flake8 (>=3.8.2)", "flaky (>=3.7.0)", "lightgbm (>=3)", "matplotlib (>=3.3)", "pandas (>=1)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "quantile-forest (>=1.0.0)", "rich (>=12)", "types-requests (>=2.28.5)", "xgboost (>=1.6)"] + [[package]] name = "smmap" version = "5.0.1" @@ -7955,13 +8010,13 @@ widechars = ["wcwidth"] [[package]] name = "tenacity" -version = "8.3.0" +version = "8.4.1" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" files = [ - {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, - {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, + {file = "tenacity-8.4.1-py3-none-any.whl", hash = "sha256:28522e692eda3e1b8f5e99c51464efcc0b9fc86933da92415168bc1c4e2308fa"}, + {file = "tenacity-8.4.1.tar.gz", hash = "sha256:54b1412b878ddf7e1f1577cd49527bad8cdef32421bd599beac0c6c3f10582fd"}, ] [package.extras] @@ -8035,6 +8090,26 @@ files = [ {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, ] +[[package]] +name = "tqdm" +version = "4.66.4" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "traitlets" version = "5.14.3" @@ -8144,13 +8219,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] @@ -8205,13 +8280,13 @@ test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)" [[package]] name = "virtualenv" -version = "20.26.2" +version = "20.26.3" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"}, - {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"}, + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, ] [package.dependencies] @@ -8225,18 +8300,18 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "wandb" -version = "0.17.1" +version = "0.17.2" description = "A CLI and library for interacting with the Weights & Biases API." optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.17.1-py3-none-any.whl", hash = "sha256:6d3eb0abb5cd189992ffe7892167aa54900cc9bfc78b9a1af032b930cc9b4bc5"}, - {file = "wandb-0.17.1-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:14f4e6751541c171e85ab25d208b2d04fe12d08e1d9bed9fb67c78b2e31617f6"}, - {file = "wandb-0.17.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:037aef76eddc0eef3b0fee604eade50273eae6558eda900b2a1c096acfa764f1"}, - {file = "wandb-0.17.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1098b1244fa448e0ad8326c518588bb27e3a78caf9c627f0bb74019446d392db"}, - {file = "wandb-0.17.1-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eeacf9b7ab075fe2c07cc38481d8be8ff50c8e7569f68633c9d885c118ef65b2"}, - {file = "wandb-0.17.1-py3-none-win32.whl", hash = "sha256:93baed41ce7842c1a2f6e7eb20d6ad2abc402c74fdd035b14cd46ab1ee354e21"}, - {file = "wandb-0.17.1-py3-none-win_amd64.whl", hash = "sha256:c6efd1f9c77815ed0b9b3df00d8bb46df811c0e3cfd290deb7c9c65ad9a0370b"}, + {file = "wandb-0.17.2-py3-none-any.whl", hash = "sha256:4bd351be28cea87730365856cfaa72f72ceb787accc21bad359dde5aa9c4356d"}, + {file = "wandb-0.17.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:638353a2d702caedd304a5f1e526ef93a291c984c109fcb444262a57aeaacec9"}, + {file = "wandb-0.17.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:824e33ca77af87f87a9cf1122acba164da5bf713adc9d67332bc686028921ec9"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:032ca5939008643349af178a8b66b8047a1eefcb870c4c4a86e22acafde6470f"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9558bab47a0c8ac4f22cfa2d43f91d1bc1f75d4255629286db674fe49fcd30e5"}, + {file = "wandb-0.17.2-py3-none-win32.whl", hash = 
"sha256:4bc176e3c81be216dc889fcd098341eb17a14b04e080d4343ce3f0b1740abfc1"}, + {file = "wandb-0.17.2-py3-none-win_amd64.whl", hash = "sha256:62cd707f38b5711971729dae80343b8c35f6003901e690166cc6d526187a9785"}, ] [package.dependencies] @@ -8464,33 +8539,6 @@ markupsafe = "*" [package.extras] email = ["email-validator"] -[[package]] -name = "xgboost" -version = "1.7.6" -description = "XGBoost Python Package" -optional = false -python-versions = ">=3.8" -files = [ - {file = "xgboost-1.7.6-py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64.whl", hash = "sha256:4c34675b4d2678c624ddde5d45361e7e16046923e362e4e609b88353e6b87124"}, - {file = "xgboost-1.7.6-py3-none-macosx_12_0_arm64.whl", hash = "sha256:59b4b366d2cafc7f645e87d897983a5b59be02876194b1d213bd8d8b811d8ce8"}, - {file = "xgboost-1.7.6-py3-none-manylinux2014_aarch64.whl", hash = "sha256:281c3c6f4fbed2d36bf95cd02a641afa95e72e9abde70064056da5e76233e8df"}, - {file = "xgboost-1.7.6-py3-none-manylinux2014_x86_64.whl", hash = "sha256:b1d5db49b199152d62bd9217c98760207d3de86d2b9d243260c573ffe638f80a"}, - {file = "xgboost-1.7.6-py3-none-win_amd64.whl", hash = "sha256:127cf1f5e2ec25cd41429394c6719b87af1456ce583e89f0bffd35d02ad18bcb"}, - {file = "xgboost-1.7.6.tar.gz", hash = "sha256:1c527554a400445e0c38186039ba1a00425dcdb4e40b37eed0e74cb39a159c47"}, -] - -[package.dependencies] -numpy = "*" -scipy = "*" - -[package.extras] -dask = ["dask", "distributed", "pandas"] -datatable = ["datatable"] -pandas = ["pandas"] -plotting = ["graphviz", "matplotlib"] -pyspark = ["cloudpickle", "pyspark", "scikit-learn"] -scikit-learn = ["scikit-learn"] - [[package]] name = "xyzservices" version = "2024.6.0" @@ -8641,4 +8689,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.10, <3.11" -content-hash = "8fbf0fce421fcf89ec73570e7ff33e562a87af3cf13cbf96c1e1add1f15ad742" +content-hash = "6c00c39b39dbea140d90c2dcf903b1dfb19090a43f95c8d02db91e939e8fff9b" diff --git a/pyproject.toml b/pyproject.toml index 8c9c19fc3..e2d973421 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ pyspark = "3.3.4" scipy = "^1.11.4" hydra-core = "^1.3.2" pyliftover = "^0.4" -xgboost = "^1.7.3" numpy = "^1.26.2" hail = "0.2.127" wandb = ">=0.16.2,<0.18.0" @@ -33,6 +32,7 @@ omegaconf = "^2.3.0" typing-extensions = "^4.9.0" scikit-learn = "^1.3.2" pandas = {extras = ["gcp", "parquet"], version = "^2.2.2"} +skops = "^0.9.0" google-cloud-secret-manager = "^2.20.0" [tool.poetry.dev-dependencies] diff --git a/src/gentropy/common/utils.py b/src/gentropy/common/utils.py index 52be4429f..81a2b4bfd 100644 --- a/src/gentropy/common/utils.py +++ b/src/gentropy/common/utils.py @@ -332,3 +332,26 @@ def access_gcp_secret(secret_id: str, project_id: str) -> str: name = f"projects/{project_id}/secrets/{secret_id}/versions/latest" response = client.access_secret_version(name=name) return response.payload.data.decode("UTF-8") + + +def copy_to_gcs(source_path: str, destination_blob: str) -> None: + """Copy a file to a Google Cloud Storage bucket. 
+ + Args: + source_path (str): Path to the local file to copy + destination_blob (str): GS path to the destination blob in the GCS bucket + + Raises: + ValueError: If the path is a directory + """ + import os + from urllib.parse import urlparse + + from google.cloud import storage + + if os.path.isdir(source_path): + raise ValueError("Path should be a file, not a directory.") + client = storage.Client() + bucket = client.bucket(bucket_name=urlparse(destination_blob).hostname) + blob = bucket.blob(blob_name=urlparse(destination_blob).path.lstrip("/")) + blob.upload_from_filename(source_path) diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 29ef5a48f..f8077722f 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -209,12 +209,12 @@ class LocusToGeneConfig(StepConfig): } ) run_mode: str = MISSING - model_path: str = MISSING predictions_path: str = MISSING credible_set_path: str = MISSING variant_gene_path: str = MISSING colocalisation_path: str = MISSING study_index_path: str = MISSING + model_path: str | None = None feature_matrix_path: str | None = None gold_standard_curation_path: str | None = None gene_interactions_path: str | None = None @@ -248,28 +248,34 @@ class LocusToGeneConfig(StepConfig): "tuqtlColocClppMaximum", # max clpp for each (study, locus, gene) aggregating over all tuQTLs "tuqtlColocClppMaximumNeighborhood", - # # max log-likelihood ratio value for each (study, locus, gene) aggregating over all eQTLs - # "eqtlColocLlrLocalMaximum", - # # max log-likelihood ratio value for each (study, locus) aggregating over all eQTLs - # "eqtlColocLlpMaximumNeighborhood", - # # max log-likelihood ratio value for each (study, locus, gene) aggregating over all pQTLs - # "pqtlColocLlrLocalMaximum", - # # max log-likelihood ratio value for each (study, locus) aggregating over all pQTLs - # "pqtlColocLlpMaximumNeighborhood", - # # max log-likelihood ratio value for each (study, locus, gene) aggregating over all sQTLs - # "sqtlColocLlrLocalMaximum", - # # max log-likelihood ratio value for each (study, locus) aggregating over all sQTLs - # "sqtlColocLlpMaximumNeighborhood", + # max log-likelihood ratio value for each (study, locus, gene) aggregating over all eQTLs + "eqtlColocLlrMaximum", + # max log-likelihood ratio value for each (study, locus) aggregating over all eQTLs + "eqtlColocLlrMaximumNeighborhood", + # max log-likelihood ratio value for each (study, locus, gene) aggregating over all pQTLs + "pqtlColocLlrMaximum", + # max log-likelihood ratio value for each (study, locus) aggregating over all pQTLs + "pqtlColocLlrMaximumNeighborhood", + # max log-likelihood ratio value for each (study, locus, gene) aggregating over all sQTLs + "sqtlColocLlrMaximum", + # max log-likelihood ratio value for each (study, locus) aggregating over all sQTLs + "sqtlColocLlrMaximumNeighborhood", + # max log-likelihood ratio value for each (study, locus, gene) aggregating over all tuQTLs + "tuqtlColocLlrMaximum", + # max log-likelihood ratio value for each (study, locus) aggregating over all tuQTLs + "tuqtlColocLlrMaximumNeighborhood", ] ) hyperparameters: dict[str, Any] = field( default_factory=lambda: { + "n_estimators": 100, "max_depth": 5, - "loss_function": "binary:logistic", + "loss": "log_loss", } ) wandb_run_name: str | None = None - perform_cross_validation: bool = False + hf_hub_repo_id: str | None = "opentargets/locus_to_gene" + download_from_hub: bool = True _target_: str = "gentropy.l2g.LocusToGeneStep" diff --git a/src/gentropy/dataset/l2g_feature_matrix.py 
b/src/gentropy/dataset/l2g_feature_matrix.py index 78926b46b..4c611e3da 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import reduce from typing import TYPE_CHECKING, Type @@ -26,15 +26,26 @@ class L2GFeatureMatrix(Dataset): Attributes: features_list (list[str] | None): List of features to use. If None, all possible features are used. + fixed_cols (list[str]): Columns that should be kept fixed in the feature matrix, although not considered as features. + mode (str): Mode of the feature matrix. Defaults to "train". Can be either "train" or "predict". """ features_list: list[str] | None = None + fixed_cols: list[str] = field(default_factory=lambda: ["studyLocusId", "geneId"]) + mode: str = "train" def __post_init__(self: L2GFeatureMatrix) -> None: - """Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used.""" - fixed_cols = ["studyLocusId", "geneId", "goldStandardSet"] + """Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used. + + Raises: + ValueError: If the mode is neither 'train' nor 'predict'. + """ + if self.mode not in ["train", "predict"]: + raise ValueError("Mode should be either 'train' or 'predict'") + if self.mode == "train": + self.fixed_cols = self.fixed_cols + ["goldStandardSet"] self.features_list = self.features_list or [ - col for col in self._df.columns if col not in fixed_cols + col for col in self._df.columns if col not in self.fixed_cols ] self.validate_schema() @@ -138,7 +149,8 @@ def fill_na( return self def select_features( - self: L2GFeatureMatrix, features_list: list[str] | None + self: L2GFeatureMatrix, + features_list: list[str] | None, ) -> L2GFeatureMatrix: """Select a subset of features from the feature matrix. @@ -147,25 +159,11 @@ def select_features( Returns: L2GFeatureMatrix: L2G feature matrix dataset - """ - features_list = features_list or self.features_list - fixed_cols = ["studyLocusId", "geneId", "goldStandardSet"] - self.df = self._df.select(fixed_cols + features_list) # type: ignore - return self - def train_test_split( - self: L2GFeatureMatrix, fraction: float - ) -> tuple[L2GFeatureMatrix, L2GFeatureMatrix]: - """Split the dataset into training and test sets. - - Args: - fraction (float): Fraction of the dataset to use for training - - Returns: - tuple[L2GFeatureMatrix, L2GFeatureMatrix]: Training and test datasets + Raises: + ValueError: If no features have been selected. 
""" - train, test = self._df.randomSplit([fraction, 1 - fraction], seed=42) - return ( - L2GFeatureMatrix(_df=train, _schema=L2GFeatureMatrix.get_schema()), - L2GFeatureMatrix(_df=test, _schema=L2GFeatureMatrix.get_schema()), - ) + if features_list := features_list or self.features_list: + self.df = self._df.select(self.fixed_cols + features_list) + return self + raise ValueError("features_list cannot be None") diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 2a704fbc9..9895f55b7 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -5,10 +5,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Type -import pyspark.sql.functions as f -from pyspark.ml.functions import vector_to_array - from gentropy.common.schemas import parse_spark_schema +from gentropy.common.session import Session from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix @@ -42,26 +40,41 @@ def get_schema(cls: type[L2GPrediction]) -> StructType: @classmethod def from_credible_set( cls: Type[L2GPrediction], - model_path: str, features_list: list[str], credible_set: StudyLocus, study_index: StudyIndex, v2g: V2G, coloc: Colocalisation, + session: Session, + model_path: str | None, + hf_token: str | None = None, + download_from_hub: bool = True, ) -> tuple[L2GPrediction, L2GFeatureMatrix]: """Extract L2G predictions for a set of credible sets derived from GWAS. Args: - model_path (str): Path to the fitted model features_list (list[str]): List of features to use for the model credible_set (StudyLocus): Credible set dataset study_index (StudyIndex): Study index dataset v2g (V2G): Variant to gene dataset coloc (Colocalisation): Colocalisation dataset + session (Session): Session object that contains the Spark session + model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). + hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. + download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. Returns: tuple[L2GPrediction, L2GFeatureMatrix]: L2G dataset and feature matrix limited to GWAS study only. """ + # Load the model + if download_from_hub: + # Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops". 
+ model_id = model_path or "opentargets/locus_to_gene" + l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token) + elif model_path: + l2g_model = LocusToGeneModel.load_from_disk(model_path) + + # Prepare data fm = L2GFeatureMatrix.generate_features( features_list=features_list, credible_set=credible_set, @@ -70,35 +83,23 @@ def from_credible_set( colocalisation=coloc, ).fill_na() - gwas_fm = L2GFeatureMatrix( - _df=( - fm.df.join( - credible_set.filter_by_study_type("gwas", study_index).df.select( - "studyLocusId" - ), - on="studyLocusId", - ) - ), - _schema=L2GFeatureMatrix.get_schema(), - ) - return ( - L2GPrediction( - # Load and apply fitted model + gwas_fm = ( + L2GFeatureMatrix( _df=( - LocusToGeneModel.load_from_disk( - model_path, - features_list=features_list, - ) - .predict(gwas_fm) - # the probability of the positive class is the second element inside the probability array - # - this is selected as the L2G probability - .select( - "studyLocusId", - "geneId", - vector_to_array(f.col("probability"))[1].alias("score"), + fm.df.join( + credible_set.filter_by_study_type( + "gwas", study_index + ).df.select("studyLocusId"), + on="studyLocusId", ) ), - _schema=cls.get_schema(), - ), + _schema=L2GFeatureMatrix.get_schema(), + mode="predict", + ) + .select_features(features_list) + .persist() + ) + return ( + l2g_model.predict(gwas_fm, session), gwas_fm, ) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 0f1871872..432e46f88 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -1,13 +1,16 @@ """Step to run Locus to Gene either for inference or for training.""" + from __future__ import annotations from typing import Any import pyspark.sql.functions as f -import sklearn -from xgboost.spark import SparkXGBClassifier +from sklearn.ensemble import GradientBoostingClassifier +from wandb import login as wandb_login from gentropy.common.session import Session +from gentropy.common.utils import access_gcp_secret +from gentropy.config import LocusToGeneConfig from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard @@ -26,7 +29,6 @@ def __init__( self, session: Session, run_mode: str, - model_path: str, predictions_path: str, credible_set_path: str, variant_gene_path: str, @@ -36,153 +38,204 @@ def __init__( gene_interactions_path: str, features_list: list[str], hyperparameters: dict[str, Any], + download_from_hub: bool, + model_path: str | None, feature_matrix_path: str | None = None, wandb_run_name: str | None = None, - perform_cross_validation: bool = False, + hf_hub_repo_id: str | None = LocusToGeneConfig().hf_hub_repo_id, ) -> None: - """Run step. + """Initialise the step and run the logic based on mode. Args: - session (Session): Session object. - run_mode (str): One of "train" or "predict". - model_path (str): Path to save the model. - predictions_path (str): Path to save the predictions. - credible_set_path (str): Path to credible set Parquet files. - variant_gene_path (str): Path to variant to gene Parquet files. - colocalisation_path (str): Path to colocalisation Parquet files. - study_index_path (str): Path to study index Parquet files. - gold_standard_curation_path (str): Path to gold standard curation JSON files. - gene_interactions_path (str): Path to gene interactions Parquet files. - features_list (list[str]): List of features to use. - hyperparameters (dict[str, Any]): Hyperparameters for the model. 
- feature_matrix_path (str | None): Optional path where the raw feature matrix should be stored. - If None, the feature matrix is not published. The feature matrix is published only when `run_mode` is `predict`. - wandb_run_name (str | None): Name of the run to be tracked on W&B. - perform_cross_validation (bool): Whether to perform cross validation. + session (Session): Session object that contains the Spark session + run_mode (str): Run mode, either 'train' or 'predict' + predictions_path (str): Path to save the predictions + credible_set_path (str): Path to the credible set dataset + variant_gene_path (str): Path to the variant to gene dataset + colocalisation_path (str): Path to the colocalisation dataset + study_index_path (str): Path to the study index dataset + gold_standard_curation_path (str): Path to the gold standard curation dataset + gene_interactions_path (str): Path to the gene interactions dataset + features_list (list[str]): List of features to use for the model + hyperparameters (dict[str, Any]): Hyperparameters for the model + download_from_hub (bool): Whether to download the model from the Hugging Face Hub + model_path (str | None): Path to the fitted model + feature_matrix_path (str | None): Path to save the feature matrix. Defaults to None. + wandb_run_name (str | None): Name of the wandb run. Defaults to None. + hf_hub_repo_id (str | None): Hugging Face Hub repo id. Defaults to the one set in the step configuration. Raises: - ValueError: if run_mode is not one of "train" or "predict". + ValueError: If run_mode is not 'train' or 'predict' """ - print("Sci-kit learn version: ", sklearn.__version__) # noqa: T201 if run_mode not in ["train", "predict"]: raise ValueError( f"run_mode must be one of 'train' or 'predict', got {run_mode}" ) + + self.session = session + self.run_mode = run_mode + self.model_path = model_path + self.predictions_path = predictions_path + self.credible_set_path = credible_set_path + self.variant_gene_path = variant_gene_path + self.colocalisation_path = colocalisation_path + self.study_index_path = study_index_path + self.gold_standard_curation_path = gold_standard_curation_path + self.gene_interactions_path = gene_interactions_path + self.features_list = list(features_list) + self.hyperparameters = dict(hyperparameters) + self.feature_matrix_path = feature_matrix_path + self.wandb_run_name = wandb_run_name + self.hf_hub_repo_id = hf_hub_repo_id + self.download_from_hub = download_from_hub + # Load common inputs - credible_set = StudyLocus.from_parquet( + self.credible_set = StudyLocus.from_parquet( session, credible_set_path, recursiveFileLookup=True ) - studies = StudyIndex.from_parquet( + self.studies = StudyIndex.from_parquet( session, study_index_path, recursiveFileLookup=True ) - v2g = V2G.from_parquet(session, variant_gene_path) - coloc = Colocalisation.from_parquet( + self.v2g = V2G.from_parquet(session, variant_gene_path) + self.coloc = Colocalisation.from_parquet( session, colocalisation_path, recursiveFileLookup=True ) if run_mode == "predict": - if not model_path or not predictions_path: - raise ValueError( - "model_path and predictions_path must be set for predict mode." - ) - predictions, feature_matrix = L2GPrediction.from_credible_set( - model_path, list(features_list), credible_set, studies, v2g, coloc + self.run_predict() + elif run_mode == "train": + self.run_train() + + def run_predict(self) -> None: + """Run the prediction step. + + Raises: + ValueError: If predictions_path is not set. 
+ """ + if not self.predictions_path: + raise ValueError("predictions_path must be set for predict mode.") + predictions, feature_matrix = L2GPrediction.from_credible_set( + self.features_list, + self.credible_set, + self.studies, + self.v2g, + self.coloc, + self.session, + model_path=self.model_path, + hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), + download_from_hub=self.download_from_hub, + ) + if self.feature_matrix_path: + feature_matrix.df.write.mode(self.session.write_mode).parquet( + self.feature_matrix_path ) - if feature_matrix_path: - feature_matrix.df.write.mode(session.write_mode).parquet( - feature_matrix_path - ) - predictions.df.write.mode(session.write_mode).parquet(predictions_path) - session.logger.info(predictions_path) - elif ( - run_mode == "train" - and gold_standard_curation_path - and gene_interactions_path + predictions.df.write.mode(self.session.write_mode).parquet( + self.predictions_path + ) + self.session.logger.info(self.predictions_path) + + def run_train(self) -> None: + """Run the training step. + + Raises: + ValueError: If gold_standard_curation_path, gene_interactions_path, or wandb_run_name are not set. + """ + if not ( + self.gold_standard_curation_path + and self.gene_interactions_path + and self.wandb_run_name + and self.model_path ): - # Process gold standard and L2G features - gs_curation = session.spark.read.json(gold_standard_curation_path) - interactions = session.spark.read.parquet(gene_interactions_path) - study_locus_overlap = StudyLocus( - # We just extract overlaps of associations in the gold standard. This parsing is a duplication of the one in the gold standard curation, - # but we need to do it here to be able to parse gold standards later - _df=credible_set.df.join( - f.broadcast( - gs_curation.select( - StudyLocus.assign_study_locus_id( - f.col("association_info.otg_id"), # studyId - f.concat_ws( # variantId - "_", - f.col("sentinel_variant.locus_GRCh38.chromosome"), - f.col("sentinel_variant.locus_GRCh38.position"), - f.col("sentinel_variant.alleles.reference"), - f.col("sentinel_variant.alleles.alternative"), - ), - ).alias("studyLocusId"), - ) - ), - "studyLocusId", - "inner", - ), - _schema=StudyLocus.get_schema(), - ).find_overlaps(studies) - - gold_standards = L2GGoldStandard.from_otg_curation( - gold_standard_curation=gs_curation, - v2g=v2g, - study_locus_overlap=study_locus_overlap, - interactions=interactions, + raise ValueError( + "gold_standard_curation_path, gene_interactions_path, and wandb_run_name, and a path to save the model must be set for train mode." 
) - fm = L2GFeatureMatrix.generate_features( - features_list=features_list, - credible_set=credible_set, - study_index=studies, - variant_gene=v2g, - colocalisation=coloc, - ) + wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") + # Process gold standard and L2G features + data = self._generate_feature_matrix().persist() - data = ( - # Annotate gold standards with features - L2GFeatureMatrix( - _df=fm.df.join( - f.broadcast( - gold_standards.df.drop("variantId", "studyId", "sources") - ), - on=["studyLocusId", "geneId"], - how="inner", - ), - _schema=L2GFeatureMatrix.get_schema(), + # Instantiate classifier and train model + l2g_model = LocusToGeneModel( + model=GradientBoostingClassifier(random_state=42), + hyperparameters=self.hyperparameters, + ) + wandb_login(key=wandb_key) + trained_model = LocusToGeneTrainer(model=l2g_model, feature_matrix=data).train( + self.wandb_run_name + ) + if trained_model.training_data and trained_model.model: + trained_model.save(self.model_path) + if self.hf_hub_repo_id: + hf_hub_token = access_gcp_secret( + "hfhub-key", "open-targets-genetics-dev" + ) + trained_model.export_to_hugging_face_hub( + # we upload the model in the filesystem + self.model_path.split("/")[-1], + hf_hub_token, + data=trained_model.training_data.df.drop( + "goldStandardSet", "geneId" + ).toPandas(), + repo_id=self.hf_hub_repo_id, + commit_message="chore: update model", ) - .fill_na() - .select_features(list(features_list)) - ) - # Instantiate classifier - estimator = SparkXGBClassifier( - eval_metric="logloss", - features_col="features", - label_col="label", - max_depth=5, - ) - l2g_model = LocusToGeneModel( - features_list=list(features_list), estimator=estimator + def _generate_feature_matrix(self) -> L2GFeatureMatrix: + """Generate the feature matrix for training. + + Returns: + L2GFeatureMatrix: Feature matrix with gold standards annotated with features. 
+ """ + gs_curation = self.session.spark.read.json(self.gold_standard_curation_path) + interactions = self.session.spark.read.parquet(self.gene_interactions_path) + study_locus_overlap = StudyLocus( + _df=self.credible_set.df.join( + f.broadcast( + gs_curation.select( + StudyLocus.assign_study_locus_id( + f.col("association_info.otg_id"), # studyId + f.concat_ws( # variantId + "_", + f.col("sentinel_variant.locus_GRCh38.chromosome"), + f.col("sentinel_variant.locus_GRCh38.position"), + f.col("sentinel_variant.alleles.reference"), + f.col("sentinel_variant.alleles.alternative"), + ), + ).alias("studyLocusId"), + ) + ), + "studyLocusId", + "inner", + ), + _schema=StudyLocus.get_schema(), + ).find_overlaps(self.studies) + + gold_standards = L2GGoldStandard.from_otg_curation( + gold_standard_curation=gs_curation, + v2g=self.v2g, + study_locus_overlap=study_locus_overlap, + interactions=interactions, + ) + + fm = L2GFeatureMatrix.generate_features( + features_list=self.features_list, + credible_set=self.credible_set, + study_index=self.studies, + variant_gene=self.v2g, + colocalisation=self.coloc, + ) + + return ( + L2GFeatureMatrix( + _df=fm.df.join( + f.broadcast( + gold_standards.df.drop("variantId", "studyId", "sources") + ), + on=["studyLocusId", "geneId"], + how="inner", + ), + _schema=L2GFeatureMatrix.get_schema(), ) - if perform_cross_validation: - # Perform cross validation to extract what are the best hyperparameters - cv_folds = hyperparameters.get("cross_validation_folds", 5) - LocusToGeneTrainer.cross_validate( - l2g_model=l2g_model, - data=data, - num_folds=cv_folds, - ) - else: - # Train model - LocusToGeneTrainer.train( - gold_standard_data=data, - l2g_model=l2g_model, - model_path=model_path, - evaluate=True, - wandb_run_name=wandb_run_name, - **hyperparameters, - ) - session.logger.info(model_path) + .fill_na() + .select_features(self.features_list) + ) diff --git a/src/gentropy/method/l2g/evaluator.py b/src/gentropy/method/l2g/evaluator.py deleted file mode 100644 index f41b1d45e..000000000 --- a/src/gentropy/method/l2g/evaluator.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Module that integrates Spark ML Evaluators with W&B for experiment tracking.""" -from __future__ import annotations - -import itertools -from typing import TYPE_CHECKING, Any, Dict - -from pyspark import keyword_only -from pyspark.ml.evaluation import ( - BinaryClassificationEvaluator, - Evaluator, - MulticlassClassificationEvaluator, -) -from pyspark.ml.param import Param, Params, TypeConverters -from wandb.sdk.wandb_run import Run - -if TYPE_CHECKING: - from pyspark.sql import DataFrame - - -class WandbEvaluator(Evaluator): - """Wrapper for pyspark Evaluators. It is expected that the user will provide an Evaluators, and this wrapper will log metrics from said evaluator to W&B.""" - - spark_ml_evaluator: Param[Evaluator] = Param( - Params._dummy(), "spark_ml_evaluator", "evaluator from pyspark.ml.evaluation" - ) - - wandb_run: Param[Run] = Param( - Params._dummy(), - "wandb_run", - "wandb run. Expects an already initialized run. You should set this, or wandb_run_kwargs, NOT BOTH", - ) - - wandb_run_kwargs: Param[Any] = Param( - Params._dummy(), - "wandb_run_kwargs", - "kwargs to be passed to wandb.init. You should set this, or wandb_runId, NOT BOTH. Setting this is useful when using with WandbCrossValdidator", - ) - - wandb_runId: Param[str] = Param( # noqa: N815 - Params._dummy(), - "wandb_runId", - "wandb run id. 
if not providing an intialized run to wandb_run, a run with id wandb_runId will be resumed", - ) - - wandb_project_name: Param[str] = Param( - Params._dummy(), - "wandb_project_name", - "name of W&B project", - typeConverter=TypeConverters.toString, - ) - - label_values: Param[list[str]] = Param( - Params._dummy(), - "label_values", - "for classification and multiclass classification, this is a list of values the label can assume\nIf provided Multiclass or Multilabel evaluator without label_values, we'll figure it out from dataset passed through to evaluate.", - ) - - _input_kwargs: Dict[str, Any] - - @keyword_only - def __init__( - self: WandbEvaluator, - label_values: list[str] | None = None, - **kwargs: BinaryClassificationEvaluator - | MulticlassClassificationEvaluator - | Run, - ) -> None: - """Initialize a WandbEvaluator. - - Args: - label_values (list[str] | None): List of label values. - **kwargs (BinaryClassificationEvaluator | MulticlassClassificationEvaluator | Run): Keyword arguments. - """ - if label_values is None: - label_values = [] - super(Evaluator, self).__init__() - - self.metrics = { - MulticlassClassificationEvaluator: [ - "f1", - "accuracy", - "weightedPrecision", - "weightedRecall", - "weightedTruePositiveRate", - "weightedFalsePositiveRate", - "weightedFMeasure", - "truePositiveRateByLabel", - "falsePositiveRateByLabel", - "precisionByLabel", - "recallByLabel", - "fMeasureByLabel", - "logLoss", - "hammingLoss", - ], - BinaryClassificationEvaluator: ["areaUnderROC", "areaUnderPR"], - } - - self._setDefault(label_values=[]) - kwargs = self._input_kwargs - self._set(**kwargs) - - def setspark_ml_evaluator(self: WandbEvaluator, value: Evaluator) -> None: - """Set the spark_ml_evaluator parameter. - - Args: - value (Evaluator): Spark ML evaluator. - """ - self._set(spark_ml_evaluator=value) - - def setlabel_values(self: WandbEvaluator, value: list[str]) -> None: - """Set the label_values parameter. - - Args: - value (list[str]): List of label values. - """ - self._set(label_values=value) - - def getspark_ml_evaluator(self: WandbEvaluator) -> Evaluator: - """Get the spark_ml_evaluator parameter. - - Returns: - Evaluator: Spark ML evaluator. - """ - return self.getOrDefault(self.spark_ml_evaluator) - - def getwandb_run(self: WandbEvaluator) -> Run: - """Get the wandb_run parameter. - - Returns: - Run: Wandb run object. - """ - return self.getOrDefault(self.wandb_run) - - def getwandb_project_name(self: WandbEvaluator) -> Any: - """Get the wandb_project_name parameter. - - Returns: - Any: Name of the W&B project. - """ - return self.getOrDefault(self.wandb_project_name) - - def getlabel_values(self: WandbEvaluator) -> list[str]: - """Get the label_values parameter. - - Returns: - list[str]: List of label values. - """ - return self.getOrDefault(self.label_values) - - def _evaluate(self: WandbEvaluator, dataset: DataFrame) -> float: - """Evaluate the model on the given dataset. - - Args: - dataset (DataFrame): Dataset to evaluate the model on. - - Returns: - float: Metric value. 
- """ - dataset.persist() - metric_values: list[tuple[str, Any]] = [] - label_values = self.getlabel_values() - spark_ml_evaluator: BinaryClassificationEvaluator | MulticlassClassificationEvaluator = ( - self.getspark_ml_evaluator() # type: ignore[assignment, unused-ignore] - ) - run = self.getwandb_run() - evaluator_type = type(spark_ml_evaluator) - for metric in self.metrics[evaluator_type]: - if "ByLabel" in metric and label_values == []: - print( # noqa: T201 - "no label_values for the target have been provided and will be determined by the dataset. This could take some time" - ) - label_values = [ - r[spark_ml_evaluator.getLabelCol()] - for r in dataset.select(spark_ml_evaluator.getLabelCol()) - .distinct() - .collect() - ] - if isinstance(label_values[0], list): - merged = list(itertools.chain(*label_values)) - label_values = list(dict.fromkeys(merged).keys()) - self.setlabel_values(label_values) - for label in label_values: - out = spark_ml_evaluator.evaluate( - dataset, - { - spark_ml_evaluator.metricLabel: label, # type: ignore[assignment, unused-ignore] - spark_ml_evaluator.metricName: metric, - }, - ) - metric_values.append((f"{metric}:{label}", out)) - out = spark_ml_evaluator.evaluate( - dataset, {spark_ml_evaluator.metricName: metric} - ) - metric_values.append((f"{metric}", out)) - run.log(dict(metric_values)) - config = [ - (f"{k.parent.split('_')[0]}.{k.name}", v) - for k, v in spark_ml_evaluator.extractParamMap().items() - if "metric" not in k.name - ] - run.config.update(dict(config)) - return_metric = spark_ml_evaluator.evaluate(dataset) - dataset.unpersist() - return return_metric diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index f8d892e07..1eab0792c 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -2,307 +2,243 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Any, Type -from pyspark.ml import Pipeline, PipelineModel -from pyspark.ml.evaluation import ( - BinaryClassificationEvaluator, - MulticlassClassificationEvaluator, -) -from pyspark.ml.feature import StringIndexer, VectorAssembler -from pyspark.ml.tuning import ParamGridBuilder -from wandb.data_types import Table -from wandb.sdk import init as wandb_init -from wandb.wandb_run import Run -from xgboost.spark.core import SparkXGBClassifierModel - -from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix -from gentropy.method.l2g.evaluator import WandbEvaluator +import skops.io as sio +from pandas import DataFrame as pd_dataframe +from pandas import to_numeric as pd_to_numeric +from sklearn.ensemble import GradientBoostingClassifier +from skops import hub_utils + +from gentropy.common.session import Session +from gentropy.common.utils import copy_to_gcs if TYPE_CHECKING: - from pyspark.ml import Transformer - from pyspark.sql import DataFrame + from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix + from gentropy.dataset.l2g_prediction import L2GPrediction @dataclass class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" - features_list: list[str] - estimator: Any = None - pipeline: Pipeline = Pipeline(stages=[]) - model: PipelineModel | None = None - wandb_l2g_project_name: str = "otg_l2g" + model: Any = GradientBoostingClassifier(random_state=42) + hyperparameters: dict[str, Any] | None = None + training_data: L2GFeatureMatrix | None = None + label_encoder: dict[str, int] = field( 
+ default_factory=lambda: { + "negative": 0, + "positive": 1, + } + ) def __post_init__(self: LocusToGeneModel) -> None: - """Post init that adds the model to the ML pipeline.""" - label_indexer = StringIndexer( - inputCol="goldStandardSet", outputCol="label", handleInvalid="keep" - ) - vector_assembler = LocusToGeneModel.features_vector_assembler( - self.features_list - ) - - self.pipeline = Pipeline( - stages=[ - label_indexer, - vector_assembler, - ] - ) + """Post-initialisation to fit the estimator with the provided params.""" + if self.hyperparameters: + self.model.set_params(**self.hyperparameters_dict) - def save(self: LocusToGeneModel, path: str) -> None: - """Saves fitted pipeline model to disk. + @classmethod + def load_from_disk( + cls: Type[LocusToGeneModel], path: str | Path + ) -> LocusToGeneModel: + """Load a fitted model from disk. Args: - path (str): Path to save the model to + path (str | Path): Path to the model + + Returns: + LocusToGeneModel: L2G model loaded from disk Raises: ValueError: If the model has not been fitted yet """ - if self.model is None: + loaded_model = sio.load(path, trusted=True) + if not loaded_model._is_fitted(): raise ValueError("Model has not been fitted yet.") - self.model.write().overwrite().save(path) + return cls(model=loaded_model) - @property - def classifier(self: LocusToGeneModel) -> Any: - """Return the model. - - Returns: - Any: An estimator object from Spark ML - """ - return self.estimator - - @staticmethod - def features_vector_assembler(features_cols: list[str]) -> VectorAssembler: - """Spark transformer to assemble the feature columns into a vector. + @classmethod + def load_from_hub( + cls: Type[LocusToGeneModel], + model_id: str, + hf_token: str | None = None, + model_name: str = "classifier.skops", + ) -> LocusToGeneModel: + """Load a model from the Hugging Face Hub. This will download the model from the hub and load it from disk. Args: - features_cols (list[str]): List of feature columns to assemble + model_id (str): Model ID on the Hugging Face Hub + hf_token (str | None): Hugging Face Hub token to download the model (only required if private) + model_name (str): Name of the persisted model to load. Defaults to "classifier.skops". Returns: - VectorAssembler: Spark transformer to assemble the feature columns into a vector - - Examples: - >>> from pyspark.ml.feature import VectorAssembler - >>> df = spark.createDataFrame([(5.2, 3.5)], schema="feature_1 FLOAT, feature_2 FLOAT") - >>> assembler = LocusToGeneModel.features_vector_assembler(["feature_1", "feature_2"]) - >>> assembler.transform(df).show() - +---------+---------+--------------------+ - |feature_1|feature_2| features| - +---------+---------+--------------------+ - | 5.2| 3.5|[5.19999980926513...| - +---------+---------+--------------------+ - - """ - return ( - VectorAssembler(handleInvalid="error") - .setInputCols(features_cols) - .setOutputCol("features") - ) - - def log_to_wandb( - self: LocusToGeneModel, - results: DataFrame, - training_data: L2GFeatureMatrix, - evaluators: list[ - BinaryClassificationEvaluator | MulticlassClassificationEvaluator - ], - wandb_run: Run, - ) -> None: - """Log evaluation results and feature importance to W&B. - - Args: - results (DataFrame): Dataframe containing the predictions - training_data (L2GFeatureMatrix): Training data used for the model. 
If provided, the table and the number of positive and negative labels will be logged to W&B - evaluators (list[BinaryClassificationEvaluator | MulticlassClassificationEvaluator]): List of Spark ML evaluators to use for evaluation - wandb_run (Run): W&B run to log the results to + LocusToGeneModel: L2G model loaded from the Hugging Face Hub """ - ## Track evaluation metrics - for evaluator in evaluators: - wandb_evaluator = WandbEvaluator( - spark_ml_evaluator=evaluator, wandb_run=wandb_run - ) - wandb_evaluator.evaluate(results) - ## Track feature importance - wandb_run.log({"importances": self.get_feature_importance()}) - ## Track training set - training_table = Table(dataframe=training_data.df.toPandas()) - wandb_run.log({"trainingSet": training_table}) - # Count number of positive and negative labels - gs_counts_dict = { - "goldStandard" + row["goldStandardSet"].capitalize(): row["count"] - for row in training_data.df.groupBy("goldStandardSet").count().collect() - } - wandb_run.log(gs_counts_dict) - # Missingness rates - wandb_run.log( - {"missingnessRates": training_data.calculate_feature_missingness_rate()} - ) - - @classmethod - def load_from_disk( - cls: Type[LocusToGeneModel], path: str, features_list: list[str] - ) -> LocusToGeneModel: - """Load a fitted pipeline model from disk. + local_path = Path(model_id) + hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token) + return cls.load_from_disk(Path(local_path) / model_name) - Args: - path (str): Path to the model - features_list (list[str]): List of features used for the model + @property + def hyperparameters_dict(self) -> dict[str, Any]: + """Return hyperparameters as a dictionary. Returns: - LocusToGeneModel: L2G model loaded from disk + dict[str, Any]: Hyperparameters + + Raises: + ValueError: If hyperparameters have not been set """ - return cls(model=PipelineModel.load(path), features_list=features_list) + if not self.hyperparameters: + raise ValueError("Hyperparameters have not been set.") + elif isinstance(self.hyperparameters, dict): + return self.hyperparameters + return self.hyperparameters.default_factory() - @classifier.setter # type: ignore - def classifier(self: LocusToGeneModel, new_estimator: Any) -> None: - """Set the model. + def predict( + self: LocusToGeneModel, + feature_matrix: L2GFeatureMatrix, + session: Session, + ) -> L2GPrediction: + """Apply the model to a given feature matrix dataframe. The feature matrix needs to be preprocessed first. Args: - new_estimator (Any): An estimator object from Spark ML - """ - self.estimator = new_estimator - - def get_param_grid(self: LocusToGeneModel) -> list[Any]: - """Return the parameter grid for the model. + feature_matrix (L2GFeatureMatrix): Feature matrix to apply the model to. 
+ session (Session): Session object to convert data to Spark Returns: - list[Any]: List of parameter maps to use for cross validation + L2GPrediction: Dataset containing credible sets and their L2G scores """ - return ( - ParamGridBuilder() - .addGrid(self.estimator.max_depth, [3, 5, 7]) - .addGrid(self.estimator.learning_rate, [0.01, 0.1, 1.0]) - .build() + from gentropy.dataset.l2g_prediction import L2GPrediction + + pd_dataframe.iteritems = pd_dataframe.items + + feature_matrix_pdf = feature_matrix.df.toPandas() + # L2G score is the probability the classifier assigns to the positive class (the second element in the probability array) + feature_matrix_pdf["score"] = self.model.predict_proba( + # We drop the fixed columns to only pass the feature values to the classifier + feature_matrix_pdf.drop(feature_matrix.fixed_cols, axis=1) + .apply(pd_to_numeric) + .values + )[:, 1] + output_cols = [field.name for field in L2GPrediction.get_schema().fields] + return L2GPrediction( + _df=session.spark.createDataFrame(feature_matrix_pdf.filter(output_cols)), + _schema=L2GPrediction.get_schema(), ) - def add_pipeline_stage( - self: LocusToGeneModel, transformer: Transformer - ) -> LocusToGeneModel: - """Adds a stage to the L2G pipeline. + def save(self: LocusToGeneModel, path: str) -> None: + """Saves fitted model to disk using the skops persistence format. Args: - transformer (Transformer): Spark transformer to add to the pipeline + path (str): Path to save the persisted model. Should end with .skops - Returns: - LocusToGeneModel: L2G model with the new transformer - - Examples: - >>> from pyspark.ml.regression import LinearRegression - >>> estimator = LinearRegression() - >>> test_model = LocusToGeneModel(features_list=["a", "b"]) - >>> print(len(test_model.pipeline.getStages())) - 2 - >>> print(len(test_model.add_pipeline_stage(estimator).pipeline.getStages())) - 3 + Raises: + ValueError: If the model has not been fitted yet + ValueError: If the path does not end with .skops """ - pipeline_stages = self.pipeline.getStages() - new_stages = pipeline_stages + [transformer] - self.pipeline = Pipeline(stages=new_stages) - return self - - def evaluate( + if self.model is None: + raise ValueError("Model has not been fitted yet.") + if not path.endswith(".skops"): + raise ValueError("Path should end with .skops") + if path.startswith("gs://"): + local_path = path.split("/")[-1] + sio.dump(self.model, local_path) + copy_to_gcs(local_path, path) + else: + sio.dump(self.model, path) + + def _create_hugging_face_model_card( self: LocusToGeneModel, - results: DataFrame, - hyperparameters: dict[str, Any], - wandb_run_name: str | None, - gold_standard_data: L2GFeatureMatrix | None = None, + local_repo: str, ) -> None: - """Perform evaluation of the model predictions for the test set and track the results with W&B. + """Create a model card to document the model in the hub. The model card is saved in the local repo before pushing it to the hub. Args: - results (DataFrame): Dataframe containing the predictions - hyperparameters (dict[str, Any]): Hyperparameters used for the model - wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B - gold_standard_data (L2GFeatureMatrix | None): Feature matrix for the associations in the gold standard. 
If provided, the ratio of positive to negative labels will be logged to W&B + local_repo (str): Path to the folder where the README file will be saved to be pushed to the Hugging Face Hub """ - binary_evaluator = BinaryClassificationEvaluator( - rawPredictionCol="rawPrediction", labelCol="label" - ) - multi_evaluator = MulticlassClassificationEvaluator( - labelCol="label", predictionCol="prediction" - ) - - if wandb_run_name and gold_standard_data: - run = wandb_init( - project=self.wandb_l2g_project_name, - config=hyperparameters, - name=wandb_run_name, - ) - if isinstance(run, Run): - self.log_to_wandb( - results, - gold_standard_data, - [binary_evaluator, multi_evaluator], - run, - ) - run.finish() + from skops import card - @property - def feature_name_map(self: LocusToGeneModel) -> dict[str, str]: - """Return a dictionary mapping encoded feature names to the original names. + # Define card metadata + description = """The locus-to-gene (L2G) model derives features to prioritise likely causal genes at each GWAS locus based on genetic and functional genomics features. The main categories of predictive features are: - Returns: - dict[str, str]: Feature name map of the model + - Distance: (from credible set variants to gene) + - Molecular QTL Colocalization + - Chromatin Interaction: (e.g., promoter-capture Hi-C) + - Variant Pathogenicity: (from VEP) - Raises: - ValueError: If the model has not been fitted yet + More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/ """ - if not self.model: - raise ValueError("Model not fitted yet. `fit()` has to be called first.") - elif isinstance(self.model.stages[1], VectorAssembler): - feature_names = self.model.stages[1].getInputCols() - return {f"f{i}": feature_name for i, feature_name in enumerate(feature_names)} - - def get_feature_importance(self: LocusToGeneModel) -> dict[str, float]: - """Return dictionary with relative importances of every feature in the model. Feature names are encoded and have to be mapped back to their original names. + how_to = """To use the model, you can load it using the `LocusToGeneModel.load_from_hub` method. This will return a `LocusToGeneModel` object that can be used to make predictions on a feature matrix. + The model can then be used to make predictions using the `predict` method. - Returns: - dict[str, float]: Dictionary mapping feature names to their importance - - Raises: - ValueError: If the model has not been fitted yet or is not an XGBoost model + More information can be found at: https://opentargets.github.io/gentropy/python_api/methods/l2g/model/ """ - if not self.model or not isinstance( - self.model.stages[-1], SparkXGBClassifierModel - ): - raise ValueError( - f"Model type {type(self.model)} not supported for feature importance." - ) - importance_map = self.model.stages[-1].get_feature_importances() - return {self.feature_name_map[k]: v for k, v in importance_map.items()} - - def fit( - self: LocusToGeneModel, - feature_matrix: L2GFeatureMatrix, - ) -> LocusToGeneModel: - """Fit the pipeline to the feature matrix dataframe. 
- - Args: - feature_matrix (L2GFeatureMatrix): Feature matrix dataframe to fit the model to - - Returns: - LocusToGeneModel: Fitted model - """ - self.model = self.pipeline.fit(feature_matrix.df) - return self + model_card = card.Card( + self.model, + metadata=card.metadata_from_config(Path(local_repo)), + ) + model_card.add( + **{ + "Model description": description, + "Model description/Training Procedure": "Gradient Boosting Classifier", + "How to Get Started with the Model": how_to, + "Model Card Authors": "Open Targets", + "License": "MIT", + "Citation": "https://doi.org/10.1038/s41588-021-00945-5", + } + ) + model_card.delete("Model description/Training Procedure/Model Plot") + model_card.delete("Model description/Evaluation Results") + model_card.delete("Model Card Authors") + model_card.delete("Model Card Contact") + model_card.save(Path(local_repo) / "README.md") - def predict( + def export_to_hugging_face_hub( self: LocusToGeneModel, - feature_matrix: L2GFeatureMatrix, - ) -> DataFrame: - """Apply the model to a given feature matrix dataframe. The feature matrix needs to be preprocessed first. + model_path: str, + hf_hub_token: str, + data: pd_dataframe, + commit_message: str, + repo_id: str = "opentargets/locus_to_gene", + local_repo: str = "locus_to_gene", + ) -> None: + """Share the model on Hugging Face Hub. Args: - feature_matrix (L2GFeatureMatrix): Feature matrix dataframe to apply the model to - - Returns: - DataFrame: Dataframe with predictions + model_path (str): The path to the L2G model file. + hf_hub_token (str): Hugging Face Hub token + data (pd_dataframe): Data used to train the model. This is used to have an example input for the model and to store the column order. + commit_message (str): Commit message for the push + repo_id (str): The Hugging Face Hub repo id where the model will be stored. + local_repo (str): Path to the folder where the contents of the model repo + the documentation are located. This is used to push the model to the Hugging Face Hub. Raises: - ValueError: If the model has not been fitted yet + Exception: If the push to the Hugging Face Hub fails """ - if not self.model: - raise ValueError("Model not fitted yet. 
`fit()` has to be called first.") - return self.model.transform(feature_matrix.df) + from sklearn import __version__ as sklearn_version + + try: + hub_utils.init( + model=model_path, + requirements=[f"scikit-learn={sklearn_version}"], + dst=local_repo, + task="tabular-classification", + data=data, + ) + self._create_hugging_face_model_card(local_repo) + hub_utils.push( + repo_id=repo_id, + source=local_repo, + token=hf_hub_token, + commit_message=commit_message, + create_remote=True, + ) + except Exception as e: + # remove the local repo if the push fails + if Path(local_repo).exists(): + for p in Path(local_repo).glob("*"): + p.unlink() + Path(local_repo).rmdir() + raise e diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index 1638a0417..85fedc45b 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -3,10 +3,23 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Optional - -from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.ml.tuning import CrossValidator +from functools import partial +from typing import Any + +import pandas as pd +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, +) +from sklearn.model_selection import train_test_split +from wandb.data_types import Table +from wandb.sdk.wandb_init import init as wandb_init +from wandb.sdk.wandb_sweep import sweep as wandb_sweep +from wandb.sklearn import plot_classifier +from wandb.wandb_agent import agent as wandb_agent from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.method.l2g.model import LocusToGeneModel @@ -16,95 +29,177 @@ class LocusToGeneTrainer: """Modelling of what is the most likely causal gene associated with a given locus.""" - _model: LocusToGeneModel - train_set: L2GFeatureMatrix + model: LocusToGeneModel + feature_matrix: L2GFeatureMatrix - @classmethod - def train( - cls: type[LocusToGeneTrainer], - gold_standard_data: L2GFeatureMatrix, - l2g_model: LocusToGeneModel, - evaluate: bool, - wandb_run_name: str | None = None, - model_path: str | None = None, - **hyperparams: dict[str, Any], - ) -> LocusToGeneModel: - """Train the Locus to Gene model. + # Initialise vars + features_list: list[str] | None = None + target_labels: list[str] | None = None + x_train: pd.DataFrame | None = None + y_train: pd.Series | None = None + x_test: pd.DataFrame | None = None + y_test: pd.Series | None = None + wandb_l2g_project_name: str = "gentropy-locus-to-gene" - Args: - gold_standard_data (L2GFeatureMatrix): Feature matrix for the associations in the gold standard - l2g_model (LocusToGeneModel): Model to fit to the data on - evaluate (bool): Whether to evaluate the model on a test set - wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B - model_path (str | None): Path to save the model to - **hyperparams (dict[str, Any]): Hyperparameters to use for the model + def fit( + self: LocusToGeneTrainer, + ) -> LocusToGeneModel: + """Fit the pipeline to the feature matrix dataframe. Returns: - LocusToGeneModel: Trained model + LocusToGeneModel: Fitted model + + Raises: + ValueError: Train data not set, nothing to fit. """ - train, test = gold_standard_data.train_test_split(fraction=0.8) + if self.x_train is not None and self.y_train is not None: + assert ( + not self.x_train.empty and not self.y_train.empty + ), "Train data not set, nothing to fit." 
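+            # Fit the wrapped scikit-learn classifier on the training split; the fitted
+            # estimator and its parameters are then re-wrapped in a new LocusToGeneModel below.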
+ fitted_model = self.model.model.fit(X=self.x_train.values, y=self.y_train) + self.model = LocusToGeneModel( + model=fitted_model, + hyperparameters=fitted_model.get_params(), + training_data=self.feature_matrix, + ) + return self.model + raise ValueError("Train data not set, nothing to fit.") - model = l2g_model.add_pipeline_stage(l2g_model.estimator).fit(train) + def log_to_wandb( + self: LocusToGeneTrainer, + wandb_run_name: str, + ) -> None: + """Log evaluation results and feature importance to W&B to compare between different L2G runs. - if evaluate: - l2g_model.evaluate( - results=model.predict(test), - hyperparameters=hyperparams, - wandb_run_name=wandb_run_name, - gold_standard_data=gold_standard_data, + Dashboard is available at https://wandb.ai/open-targets/gentropy-locus-to-gene?nw=nwuseropentargets + Credentials to access W&B are available at the OT central login sheet. + + Args: + wandb_run_name (str): Name of the W&B run + """ + if ( + self.x_train is not None + and self.x_test is not None + and self.y_train is not None + and self.y_test is not None + ): + assert ( + not self.x_train.empty and not self.y_train.empty + ), "Train data not set, nothing to evaluate." + fitted_classifier = self.model.model + y_predicted = fitted_classifier.predict(self.x_test.values) + y_probas = fitted_classifier.predict_proba(self.x_test.values) + run = wandb_init( + project=self.wandb_l2g_project_name, + name=wandb_run_name, + config=fitted_classifier.get_params(), + ) + # Track classification plots + plot_classifier( + self.model.model, + self.x_train.values, + self.x_test.values, + self.y_train, + self.y_test, + y_predicted, + y_probas, + labels=list(self.model.label_encoder.values()), + model_name="L2G-classifier", + feature_names=self.features_list, + is_binary=True, + ) + # Track evaluation metrics + run.log( + { + "areaUnderROC": roc_auc_score( + self.y_test, y_probas[:, 1], average="weighted" + ) + } + ) + run.log({"accuracy": accuracy_score(self.y_test, y_predicted)}) + run.log( + { + "weightedPrecision": precision_score( + self.y_test, y_predicted, average="weighted" + ) + } + ) + run.log( + { + "weightedRecall": recall_score( + self.y_test, y_predicted, average="weighted" + ) + } + ) + run.log({"f1": f1_score(self.y_test, y_predicted, average="weighted")}) + # Track gold standards and their features + run.log( + {"featureMatrix": Table(dataframe=self.feature_matrix.df.toPandas())} + ) + # Log feature missingness + run.log( + { + "missingnessRates": self.feature_matrix.calculate_feature_missingness_rate() + } ) - if model_path: - l2g_model.save(model_path) - return l2g_model - - @classmethod - def cross_validate( - cls: type[LocusToGeneTrainer], - l2g_model: LocusToGeneModel, - data: L2GFeatureMatrix, - num_folds: int, - param_grid: Optional[list] = None, # type: ignore - ) -> LocusToGeneModel: - """Perform k-fold cross validation on the model. - By providing a model with a parameter grid, this method will perform k-fold cross validation on the model for each - combination of parameters and return the best model. + def train( + self: LocusToGeneTrainer, + wandb_run_name: str, + ) -> LocusToGeneModel: + """Train the Locus to Gene model. Args: - l2g_model (LocusToGeneModel): Model to fit to the data on - data (L2GFeatureMatrix): Data to perform cross validation on - num_folds (int): Number of folds to use for cross validation - param_grid (Optional[list]): List of parameter maps to use for cross validation + wandb_run_name (str): Name of the W&B run. 
Unless this is provided, the model will not be logged to W&B. Returns: - LocusToGeneModel: Trained model fitted with the best hyperparameters - - Raises: - ValueError: Parameter grid is empty. Cannot perform cross-validation. - ValueError: Unable to retrieve the best model. + LocusToGeneModel: Fitted model """ - evaluator = MulticlassClassificationEvaluator() - params_grid = param_grid or l2g_model.get_param_grid() - if not param_grid: - raise ValueError( - "Parameter grid is empty. Cannot perform cross-validation." - ) - cv = CrossValidator( - numFolds=num_folds, - estimator=l2g_model.estimator, - estimatorParamMaps=params_grid, - evaluator=evaluator, - parallelism=2, - collectSubModels=False, - seed=42, + data_df = self.feature_matrix.df.drop("geneId").toPandas() + + # Encode labels in `goldStandardSet` to a numeric value + data_df["goldStandardSet"] = data_df["goldStandardSet"].map( + self.model.label_encoder ) - l2g_model.add_pipeline_stage(cv) # type: ignore[assignment, unused-ignore] + # Convert all columns to numeric and split + data_df = data_df.apply(pd.to_numeric) + self.feature_cols = [ + col + for col in data_df.columns + if col not in ["studyLocusId", "goldStandardSet"] + ] + label_col = "goldStandardSet" + X = data_df[self.feature_cols].copy() + y = data_df[label_col].copy() + self.x_train, self.x_test, self.y_train, self.y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) - # Integrate the best model from the last stage of the pipeline - if (full_pipeline_model := l2g_model.fit(data).model) is None or not hasattr( - full_pipeline_model, "stages" - ): - raise ValueError("Unable to retrieve the best model.") - l2g_model.model = full_pipeline_model.stages[-1].bestModel # type: ignore[assignment, unused-ignore] - return l2g_model + # Train + model = self.fit() + + # Evaluate + self.log_to_wandb( + wandb_run_name=wandb_run_name, + ) + + return model + + def hyperparameter_tuning( + self: LocusToGeneTrainer, wandb_run_name: str, parameter_grid: dict[str, Any] + ) -> None: + """Perform hyperparameter tuning on the model with W&B Sweeps. Metrics for every combination of hyperparameters will be logged to W&B for comparison. + + Args: + wandb_run_name (str): Name of the W&B run + parameter_grid (dict[str, Any]): Dictionary containing the hyperparameters to sweep over. The keys are the hyperparameter names, and the values are dictionaries containing the values to sweep over. 
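+                Example of the expected format (hypothetical values): {"n_estimators": {"values": [100, 200]}, "max_depth": {"values": [3, 5]}}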
+ """ + sweep_config = { + "method": "grid", + "metric": {"name": "roc", "goal": "maximize"}, + "parameters": parameter_grid, + } + sweep_id = wandb_sweep(sweep_config, project=self.wandb_l2g_project_name) + + wandb_agent(sweep_id, partial(self.train, wandb_run_name=wandb_run_name)) diff --git a/tests/gentropy/method/test_locus_to_gene.py b/tests/gentropy/method/test_locus_to_gene.py index 898252f9f..35f736d1f 100644 --- a/tests/gentropy/method/test_locus_to_gene.py +++ b/tests/gentropy/method/test_locus_to_gene.py @@ -7,15 +7,11 @@ import pytest from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature import L2GFeature -from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.method.l2g.feature_factory import ColocalisationFactory, StudyLocusFactory from gentropy.method.l2g.model import LocusToGeneModel -from gentropy.method.l2g.trainer import LocusToGeneTrainer -from pyspark.ml import PipelineModel -from pyspark.ml.tuning import ParamGridBuilder -from xgboost.spark import SparkXGBClassifier +from sklearn.ensemble import RandomForestClassifier if TYPE_CHECKING: from gentropy.dataset.v2g import V2G @@ -25,56 +21,7 @@ @pytest.fixture(scope="module") def model() -> LocusToGeneModel: """Creates an instance of the LocusToGene class.""" - estimator = SparkXGBClassifier( - eval_metric="logloss", - features_col="features", - label_col="label", - max_depth=5, - ) - return LocusToGeneModel(estimator=estimator, features_list=["distanceTssMean"]) - - -class TestLocusToGeneTrainer: - """Test the L2GTrainer methods using a logistic regression model as estimation algorithm.""" - - def test_cross_validate( - self: TestLocusToGeneTrainer, - mock_l2g_feature_matrix: L2GFeatureMatrix, - model: LocusToGeneModel, - ) -> None: - """Test the k-fold cross-validation function.""" - param_grid = ( - ParamGridBuilder() - .addGrid(model.estimator.learning_rate, [0.1, 0.01]) - .build() - ) - best_model = LocusToGeneTrainer.cross_validate( - model, mock_l2g_feature_matrix.fill_na(), num_folds=2, param_grid=param_grid - ) - assert isinstance( - best_model, LocusToGeneModel - ), "Unexpected model type returned from cross_validate" - # Check that the best model's hyperparameters are among those in the param_grid - assert best_model.model.getOrDefault("learning_rate") in [ # type: ignore - 0.1, - 0.01, - ], "Unexpected learning rate in the best model" - - def test_train( - self: TestLocusToGeneTrainer, - mock_l2g_feature_matrix: L2GFeatureMatrix, - model: LocusToGeneModel, - ) -> None: - """Test the training function.""" - trained_model = LocusToGeneTrainer.train( - mock_l2g_feature_matrix.fill_na().select_features(["distanceTssMean"]), - model, - evaluate=False, - ) - # Check that `model` is a PipelineModel object and not None - assert isinstance( - trained_model.model, PipelineModel - ), "Model is not a PipelineModel object." + return LocusToGeneModel(model=RandomForestClassifier()) class TestColocalisationFactory: