Skip to content

Commit

Permalink
4042 Hugging Face Hub integration (Project-MONAI#6454)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#4042 .

### Description
- [x] Add functionality to download a model bundle from the Hugging Face
Hub.
- [x] Add functionality to upload a model bundle to the Hugging Face
Hub.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: katielink <[email protected]>
  • Loading branch information
katielink authored Oct 19, 2023
1 parent c2e2a96 commit 6e5fdc0
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/source/bundle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ Model Bundle
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
.. autofunction:: init_bundle
.. autofunction:: push_to_hf_hub
.. autofunction:: update_kwargs
5 changes: 2 additions & 3 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
```

which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips` and `nvidia-ml-py` respectively.

`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, and `huggingface_hub` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
init_bundle,
load,
onnx_export,
push_to_hf_hub,
run,
run_workflow,
trt_export,
Expand Down
124 changes: 113 additions & 11 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
requests_get, has_requests = optional_import("requests", name="get")
onnx, _ = optional_import("onnx")
huggingface_hub, _ = optional_import("huggingface_hub")

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -244,6 +245,14 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
elif source == "huggingface_hub":
refs = huggingface_hub.list_repo_refs(repo_id=repo)
if len(refs.tags) > 0:
all_versions = [t.name for t in refs.tags] # git tags, not to be confused with `tag`
latest_version = ["latest_version" if "latest_version" in all_versions else all_versions[-1]][0]
else:
latest_version = [b.name for b in refs.branches][0] # use the branch that was last updated
return latest_version
else:
raise ValueError(
f"To get the latest bundle version, source should be 'github', 'monaihosting' or 'ngc', got {source}."
Expand Down Expand Up @@ -293,6 +302,9 @@ def download(
# Execute this module as a CLI entry, and download bundle from monaihosting with latest version:
python -m monai.bundle download --name <bundle_name> --source "monaihosting" --bundle_dir "./"
# Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"
# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name <bundle_name> --url <url>
Expand All @@ -311,14 +323,15 @@ def download(
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
the latest version (or the last commit to the `main` branch in the case of Hugging Face Hub).
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc", "monaihosting" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
it should be "ngc", "monaihosting", "github", or "huggingface_hub".
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
Expand Down Expand Up @@ -351,9 +364,10 @@ def download(
bundle_dir_ = _process_bundle_dir(bundle_dir_)
if repo_ is None:
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
if len(repo_.split("/")) != 3:
if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")

elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
if url_ is not None:
if name_ is not None:
filepath = bundle_dir_ / f"{name_}.zip"
Expand All @@ -380,9 +394,12 @@ def download(
remove_prefix=remove_prefix_,
progress=progress_,
)
elif source_ == "huggingface_hub":
extract_path = os.path.join(bundle_dir_, name_)
huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
else:
raise NotImplementedError(
"Currently only download from `url`, source 'github', 'monaihosting' or 'ngc' are implemented,"
"Currently only download from `url`, source 'github', 'monaihosting', 'huggingface_hub' or 'ngc' are implemented,"
f"got source: {source_}."
)

Expand Down Expand Up @@ -427,7 +444,7 @@ def load(
https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/mednist_gan/versions/0.2.0/files/mednist_gan_v0.2.0.zip
model: a pytorch module to be updated. Default to None, using the "network_def" in the bundle.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
the latest version. If `source` is "huggingface_hub", this argument is a Git revision id.
workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
Expand All @@ -440,9 +457,10 @@ def load(
source: storage location name. This argument is used when `model_file` is not existing locally and need to be
downloaded first.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc", "monaihosting" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
it should be "ngc", "monaihosting", "github", or "huggingface_hub".
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these three sources, remove prefix is necessary.
Expand Down Expand Up @@ -1597,6 +1615,90 @@ def init_bundle(
save_state(network, str(models_dir / "model.pt"))


def _add_model_card_metadata(new_modelcard_path):
# Extract license from LICENSE file
license_name = "unknown"
license_path = os.path.join(os.path.dirname(new_modelcard_path), "LICENSE")
if os.path.exists(license_path):
with open(license_path) as file:
content = file.read()
if "Apache License" in content and "Version 2.0" in content:
license_name = "apache-2.0"
elif "MIT License" in content:
license_name = "mit"
# Add relevant tags
tags = "- monai\n- medical\nlibrary_name: monai\n"
# Create tag section
tag_content = f"---\ntags:\n{tags}license: {license_name}\n---"

# Update model card
with open(new_modelcard_path) as file:
content = file.read()
new_content = tag_content + "\n" + content
with open(new_modelcard_path, "w") as file:
file.write(new_content)


def push_to_hf_hub(
repo: str,
name: str,
bundle_dir: str,
token: str | None = None,
private: bool | None = True,
version: str | None = None,
tag_as_latest_version: bool | None = False,
**upload_folder_kwargs: Any,
) -> Any:
"""
Push a MONAI bundle to the Hugging Face Hub.
Typical usage examples:
.. code-block:: bash
python -m monai.bundle push_to_hf_hub --repo <HF repository id> --name <bundle name> \
--bundle_dir <bundle directory> --version <version> ...
Args:
repo: namespace (user or organization) and a repo name separated by a /, e.g. `hf_username/bundle_name`
bundle_name: name of the bundle directory to push.
bundle_dir: path to the bundle directory.
token: Hugging Face authentication token. Default is `None` (will default to the stored token).
private: Private visibility of the repository on Hugging Face. Default is `True`.
version_name: Name of the version tag to create. Default is `None` (no version tag is created).
tag_as_latest_version: Whether to tag the commit as `latest_version`.
This version will downloaded by default when using `bundle.download()`. Default is `False`.
upload_folder_kwargs: Keyword arguments to pass to `HfApi.upload_folder`.
Returns:
repo_url: URL of the Hugging Face repo
"""
# Connect to API and create repo
hf_api = huggingface_hub.HfApi(token=token)
hf_api.create_repo(repo_id=repo, private=private, exist_ok=True)

# Create model card in bundle directory
new_modelcard_path = os.path.join(bundle_dir, name, "README.md")
modelcard_path = os.path.join(bundle_dir, name, "docs", "README.md")
if os.path.exists(modelcard_path):
# Copy README from old path if it exists
copyfile(modelcard_path, new_modelcard_path)
_add_model_card_metadata(new_modelcard_path)

# Upload bundle folder to repo
repo_url = hf_api.upload_folder(repo_id=repo, folder_path=os.path.join(bundle_dir, name), **upload_folder_kwargs)

# Create version tag if specified
if version is not None:
hf_api.create_tag(repo_id=repo, tag=version, exist_ok=True)

# Optionally tag as `latest_version`
if tag_as_latest_version:
hf_api.create_tag(repo_id=repo, tag="latest_version", exist_ok=True)

return repo_url


def create_workflow(
workflow_name: str | BundleWorkflow | None = None,
config_file: str | Sequence[str] | None = None,
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ pynvml =
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
huggingface_hub =
huggingface_hub

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run_testsuit():
"test_auto3dseg",
"test_bundle_onnx_export",
"test_bundle_trt_export",
"test_bundle_push_to_hf_hub",
"test_cachedataset",
"test_cachedataset_parallel",
"test_cachedataset_persistent_workers",
Expand Down
73 changes: 53 additions & 20 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import tempfile
import unittest
from unittest.case import skipUnless

import numpy as np
import torch
Expand All @@ -23,6 +24,7 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, create_workflow, load
from monai.utils import optional_import
from tests.utils import (
SkipIfBeforePyTorchVersion,
assert_allclose,
Expand All @@ -32,6 +34,8 @@
skip_if_quick,
)

_, has_huggingface_hub = optional_import("huggingface_hub")

TEST_CASE_1 = ["test_bundle", None]

TEST_CASE_2 = ["test_bundle", "0.1.1"]
Expand All @@ -46,35 +50,41 @@
TEST_CASE_4 = [
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
"test_bundle",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
"cuda" if torch.cuda.is_available() else "cpu",
"model.pt",
"monai-test/test_bundle",
]

TEST_CASE_5 = [
["test_output.pt", "test_input.pt"],
"test_bundle",
"0.1.1",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
"cuda" if torch.cuda.is_available() else "cpu",
"model.ts",
]

TEST_CASE_6 = [
["models/model.pt", "models/model.ts", "configs/train.json"],
"brats_mri_segmentation",
"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip",
]

TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]

TEST_CASE_7 = [
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
"test_bundle",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
"cuda" if torch.cuda.is_available() else "cpu",
"model.pt",
]

TEST_CASE_8 = [
"spleen_ct_segmentation",
"cuda" if torch.cuda.is_available() else "cpu",
{"spatial_dims": 3, "out_channels": 5},
]

TEST_CASE_8 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]

TEST_CASE_9 = [
["test_output.pt", "test_input.pt"],
"test_bundle",
"0.1.1",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
"cuda" if torch.cuda.is_available() else "cpu",
"model.ts",
]

TEST_CASE_10 = [
["network.json", "test_output.pt", "test_input.pt", "large_files.yaml"],
"test_bundle",
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.2.zip",
Expand Down Expand Up @@ -122,7 +132,30 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
if file == "network.json":
self.assertTrue(check_hash(filepath=file_path, val=hash_val))

@parameterized.expand([TEST_CASE_6])
@parameterized.expand([TEST_CASE_4])
@skip_if_quick
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
def test_hf_hub_download_bundle(self, bundle_files, bundle_name, repo):
with skip_if_downloading_fails():
with tempfile.TemporaryDirectory() as tempdir:
cmd = [
"coverage",
"run",
"-m",
"monai.bundle",
"download",
"--name",
bundle_name,
"--source",
"huggingface_hub",
]
cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"]
command_line_tests(cmd)
for file in bundle_files:
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@parameterized.expand([TEST_CASE_5])
@skip_if_quick
def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):
with skip_if_downloading_fails():
Expand All @@ -139,7 +172,7 @@ def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@parameterized.expand([TEST_CASE_8])
@parameterized.expand([TEST_CASE_6])
@skip_if_quick
def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):
with skip_if_downloading_fails():
Expand All @@ -159,7 +192,7 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve

@skip_if_no_cuda
class TestLoad(unittest.TestCase):
@parameterized.expand([TEST_CASE_4])
@parameterized.expand([TEST_CASE_7])
@skip_if_quick
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
with skip_if_downloading_fails():
Expand Down Expand Up @@ -225,7 +258,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
output_3 = model_3.forward(input_tensor)
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand([TEST_CASE_7])
@parameterized.expand([TEST_CASE_8])
@skip_if_quick
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
with skip_if_downloading_fails():
Expand Down Expand Up @@ -270,7 +303,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
expected_shape = (1, 5, 96, 96, 96)
np.testing.assert_equal(output.shape, expected_shape)

@parameterized.expand([TEST_CASE_5])
@parameterized.expand([TEST_CASE_9])
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 7, 1))
def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):
Expand Down Expand Up @@ -303,7 +336,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device,


class TestDownloadLargefiles(unittest.TestCase):
@parameterized.expand([TEST_CASE_9])
@parameterized.expand([TEST_CASE_10])
@skip_if_quick
def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val):
with skip_if_downloading_fails():
Expand Down
Loading

0 comments on commit 6e5fdc0

Please sign in to comment.