diff --git a/unimol_tools/README.md b/unimol_tools/README.md index 0dc3405..8186f1a 100644 --- a/unimol_tools/README.md +++ b/unimol_tools/README.md @@ -50,7 +50,16 @@ export HF_ENDPOINT=https://hf-mirror.com Setting the `HF_ENDPOINT` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models. +### Modify the default directory for weights + +Setting the `UNIMOL_WEIGHT_DIR` environment variable specifies the directory for pre-trained weights if the weights have been downloaded from another source. + +```bash +export UNIMOL_WEIGHT_DIR=/path/to/your/weights/dir/ +``` + ## News +- 2024-07-23: User experience improvements: Add `UNIMOL_WEIGHT_DIR`. - 2024-06-25: unimol_tools has been publish to pypi! Huggingface has been used to manage the pretrain models. - 2024-06-20: unimol_tools v0.1.0 released, we remove the dependency of Uni-Core. And we will publish to pypi soon. - 2024-03-20: unimol_tools documents is available at https://unimol.readthedocs.io/en/latest/ diff --git a/unimol_tools/setup.py b/unimol_tools/setup.py index 6d80b95..a51a833 100644 --- a/unimol_tools/setup.py +++ b/unimol_tools/setup.py @@ -5,7 +5,7 @@ setup( name="unimol_tools", - version="0.1.0.post1", + version="0.1.0.post2", description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."), author="DP Technology", author_email="unimol@dp.tech", diff --git a/unimol_tools/unimol_tools/weights/weighthub.py b/unimol_tools/unimol_tools/weights/weighthub.py index 5d6de89..e2088d2 100644 --- a/unimol_tools/unimol_tools/weights/weighthub.py +++ b/unimol_tools/unimol_tools/weights/weighthub.py @@ -9,7 +9,12 @@ def snapshot_download(*args, **kwargs): raise ImportError('huggingface_hub is not installed. If weights are not avaliable, please install it by running: pip install huggingface_hub. Otherwise, please download the weights manually from https://huggingface.co/dptech/Uni-Mol-Models') -WEIGHT_DIR = os.path.dirname(os.path.abspath(__file__)) +WEIGHT_DIR = os.environ.get('UNIMOL_WEIGHT_DIR', os.path.dirname(os.path.abspath(__file__))) + +if 'UNIMOL_WEIGHT_DIR' in os.environ: + logger.warning(f'Using custom weight directory from UNIMOL_WEIGHT_DIR: {WEIGHT_DIR}') +else: + logger.info(f'Weights will be downloaded to default directory: {WEIGHT_DIR}') os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" # use mirror to download weights