Download Support for Custom HuggingFace Models (#27)
* Download Custom HuggingFace Models and Progress Bar

* Minor linting fix

* minor repo_version change
saileshd1402 authored Dec 4, 2023
1 parent e912408 commit f5bfb83
Showing 8 changed files with 302 additions and 71 deletions.
111 changes: 62 additions & 49 deletions llm/download.py
@@ -15,7 +15,6 @@
import uuid
from typing import List
import huggingface_hub as hfh
from huggingface_hub.utils import HfHubHTTPError
from utils.marsgen import get_mar_name, generate_mars
from utils.system_utils import (
check_if_path_exists,
@@ -176,65 +175,68 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel:
"repo_version"
]

# Make sure there is HF hub token for LLAMA(2)
if (
gen_model.repo_info.repo_id.startswith("meta-llama")
and gen_model.repo_info.hf_token is None
):
print(
"## Error: HuggingFace Hub token is required for llama download."
" Please specify it using --hf_token=<your token>. "
"Refer https://huggingface.co/docs/hub/security-tokens"
)
sys.exit(1)

# Validate downloaded files
hf_api = hfh.HfApi()
hf_api.list_repo_commits(
repo_id=gen_model.repo_info.repo_id,
revision=gen_model.repo_info.repo_version,
token=gen_model.repo_info.hf_token,
)

# Read handler file name
if not gen_model.mar_utils.handler_path:
gen_model.mar_utils.handler_path = os.path.join(
os.path.dirname(__file__),
models[gen_model.model_name]["handler"],
)

except (KeyError, HfHubHTTPError):
# Validate hf_token
gen_model.validate_hf_token()

# Validate repository info
gen_model.validate_commit_info()

except (KeyError, ValueError):
print(
"## Error: Please check either repo_id, repo_version"
" or HuggingFace ID is not correct\n"
"## There seems to be an error in the model_config.json file. "
"Please check the same."
)
sys.exit(1)

else: # Custom model case
if not gen_model.skip_download:
print(
"## Please check your model name,"
" it should be one of the following : "
)
print(list(models.keys()))
print(
"\n## If you want to use custom model files,"
" use the '--no_download' argument"
)
sys.exit(1)
gen_model.is_custom_model = True
if gen_model.skip_download:
if check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Error: The given model path folder is empty\n")
sys.exit(1)

if check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Error: The given model path folder is empty\n")
sys.exit(1)
if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = "1.0"

else:
if not gen_model.repo_info.repo_id:
print(
"## If you want to create a model archive file with the supported models, "
"make sure you're model name is present in the below : "
)
print(list(models.keys()))
print(
"\nIf you want to create a model archive file for"
" a custom model,there are two methods:\n"
"1. If you have already downloaded the custom model"
" files, please include"
" the --no_download flag and provide the model_path "
"directory which contains the model files.\n"
"2. If you need to download the model files, provide "
"the HuggingFace repository ID using 'repo_id'"
" along with an empty model_path driectory where the "
"model files will be downloaded.\n"
)
sys.exit(1)

# Validate hf_token
gen_model.validate_hf_token()

# Validate repository info
gen_model.validate_commit_info()

if not gen_model.mar_utils.handler_path:
gen_model.mar_utils.handler_path = os.path.join(
os.path.dirname(__file__), "handler.py"
)

if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = "1.0"

gen_model.is_custom_model = True
print(
f"\n## Generating MAR file for "
f"custom model files: {gen_model.model_name}"
@@ -260,7 +262,9 @@ def run_download(gen_model: GenerateDataModel) -> GenerateDataModel:
GenerateDataModel: An instance of the GenerateDataModel class.
"""
if not check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Make sure the path provided to download model files is empty\n")
print(
"## Make sure the model_path provided to download model files through is empty\n"
)
sys.exit(1)

print(
@@ -290,7 +294,9 @@ def create_mar(gen_model: GenerateDataModel) -> None:
Args:
gen_model (GenerateDataModel): An instance of the GenerateDataModel dataclass
"""
if not gen_model.is_custom_model and not check_if_model_files_exist(gen_model):
if not (
gen_model.is_custom_model and gen_model.skip_download
) and not check_if_model_files_exist(gen_model):
print("## Model files do not match HuggingFace repository files")
sys.exit(1)

@@ -347,13 +353,20 @@ def run_script(params: argparse.Namespace) -> bool:
type=str,
default="",
required=True,
metavar="mn",
metavar="n",
help="Name of model",
)
parser.add_argument(
"--repo_id",
type=str,
default=None,
metavar="ri",
help="HuggingFace repository ID (In case of custom model download)",
)
parser.add_argument(
"--repo_version",
type=str,
default="",
default=None,
metavar="rv",
help="Commit ID of models repo from HuggingFace repository",
)
@@ -367,15 +380,15 @@ def run_script(params: argparse.Namespace) -> bool:
type=str,
default="",
required=True,
metavar="mp",
metavar="p",
help="Absolute path of model files (should be empty if downloading)",
)
parser.add_argument(
"--mar_output",
type=str,
default="",
required=True,
metavar="mx",
metavar="a",
help="Absolute path of exported MAR file (.mar)",
)
parser.add_argument(
@@ -389,7 +402,7 @@ def run_script(params: argparse.Namespace) -> bool:
"--hf_token",
type=str,
default=None,
metavar="hft",
metavar="ht",
help="HuggingFace Hub token to download LLAMA(2) models",
)
parser.add_argument("--debug", action="store_true", help="flag to debug")
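To illustrate the two custom-model methods described in the new error message above, here is a minimal sketch of driving the downloader programmatically, modeled on the test helpers in test_download.py below. The model name and paths are hypothetical placeholders, not values from this commit.

import argparse
import download  # llm/download.py

# Method 2: download custom model files from HuggingFace via repo_id.
# model_path must point to an empty directory in this mode.
args = argparse.Namespace(
    model_name="my_custom_model",      # hypothetical; not in model_config.json
    repo_id="gpt2",
    repo_version=None,                 # None works in the tests; presumably the latest commit is used
    model_path="/abs/path/empty_dir",
    mar_output="/abs/path/model_store",
    handler_path="",                   # empty -> the default handler.py is used
    no_download=False,
    hf_token=None,
    debug=False,
)
download.run_script(args)

# Method 1: reuse files already on disk; model_path must be non-empty.
args.no_download = True
args.repo_id = None
download.run_script(args)
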
73 changes: 67 additions & 6 deletions llm/tests/test_download.py
@@ -60,7 +60,7 @@ def cleanup_folders() -> None:

def set_generate_args(
model_name: str = MODEL_NAME,
repo_version: str = "",
repo_version: str = None,
model_path: str = MODEL_PATH,
mar_output: str = MODEL_STORE,
handler_path: str = "",
@@ -82,6 +82,7 @@
args.model_path = model_path
args.mar_output = mar_output
args.no_download = False
args.repo_id = None
args.repo_version = repo_version
args.handler_path = handler_path
args.debug = False
@@ -250,16 +251,20 @@ def test_skip_download_success() -> None:
assert result is True


def custom_model_setup() -> None:
def custom_model_setup(download_model: bool = True) -> None:
"""
This function is used to setup custom model case.
It runs download.py to download model files and
deletes the contents of 'model_config.json' after
making a backup.
Args:
download_model (bool): Whether to download the model files (defaults to True)
"""
download_setup()
args = set_generate_args()
download.run_script(args)
if download_model:
args = set_generate_args()
download.run_script(args)

# creating a backup of original model_config.json
copy_file(MODEL_CONFIG_PATH, MODEL_TEMP_CONFIG_PATH)
@@ -277,9 +282,9 @@ def custom_model_restore() -> None:
cleanup_folders()


def test_custom_model_success() -> None:
def test_custom_model_skip_download_success() -> None:
"""
This function tests the custom model case.
This function tests the no-download custom model case.
This is done by clearing the 'model_config.json' and
generating the 'GPT2' MAR file.
Expected result: Success.
@@ -296,6 +301,62 @@ def test_custom_model_success() -> None:
custom_model_restore()


def test_custom_model_download_success() -> None:
"""
This function tests the download custom model case.
This is done by clearing the 'model_config.json' and
generating the 'GPT2' MAR file.
Expected result: Success.
"""
custom_model_setup(download_model=False)
args = set_generate_args()
args.repo_id = "gpt2"
try:
result = download.run_script(args)
except SystemExit:
assert False
else:
assert result is True
custom_model_restore()


def test_custom_model_download_wrong_repo_id_throw_error() -> None:
"""
This function tests the download custom model case and
passes a wrong repo_id.
Expected result: Failure.
"""
custom_model_setup(download_model=False)
args = set_generate_args()
args.repo_id = "wrong_repo_id"
try:
download.run_script(args)
except SystemExit as e:
assert e.code == 1
else:
assert False
custom_model_restore()


def test_custom_model_download_wrong_repo_version_throw_error() -> None:
"""
This function tests the download custom model case and
passes a correct repo_id but wrong repo_version.
Expected result: Failure.
"""
custom_model_setup(download_model=False)
args = set_generate_args()
args.repo_id = "gpt2"
args.repo_version = "wrong_repo_version"
try:
download.run_script(args)
except SystemExit as e:
assert e.code == 1
else:
assert False
custom_model_restore()


# Run the tests
if __name__ == "__main__":
pytest.main(["-v", __file__])
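The wrong repo_id / repo_version tests above expect run_script to exit with code 1. In download.py this now hinges on gen_model.validate_hf_token() and gen_model.validate_commit_info(), whose bodies are not part of this diff; the new except (KeyError, ValueError) clause suggests they translate HuggingFace hub failures into ValueError. A plausible sketch of validate_commit_info, assuming it wraps the same hfh.HfApi().list_repo_commits call the old inline code used and propagates a ValueError for callers to handle (whether it instead exits directly is not visible here):

import huggingface_hub as hfh
from huggingface_hub.utils import HfHubHTTPError

class GenerateDataModel:  # only the relevant method is sketched
    def validate_commit_info(self) -> None:
        """Raise ValueError if repo_id/repo_version do not name a real commit."""
        try:
            hfh.HfApi().list_repo_commits(
                repo_id=self.repo_info.repo_id,
                revision=self.repo_info.repo_version,
                token=self.repo_info.hf_token,
            )
        except HfHubHTTPError as err:
            # Translating hub errors lets callers treat bad repository info
            # the same way as a malformed model_config.json entry.
            raise ValueError("repo_id or repo_version is not correct") from err
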
24 changes: 22 additions & 2 deletions llm/tests/test_torchserve_run.py
@@ -201,9 +201,9 @@ def test_inference_json_file_success() -> None:
assert False


def test_custom_model_success() -> None:
def test_custom_model_skip_download_success() -> None:
"""
This function tests custom model with input folder.
This function tests the custom model case, skipping download, with an input folder.
Expected result: Success.
"""
custom_model_setup()
@@ -221,6 +221,26 @@
process = subprocess.run(["python3", "cleanup.py"], check=False)


def test_custom_model_download_success() -> None:
"""
This function tests the download custom model case with an input folder.
Expected result: Success.
"""
custom_model_setup(download_model=False)
args = set_generate_args()
args.repo_id = "gpt2"
try:
download.run_script(args)
except SystemExit:
assert False

process = subprocess.run(get_run_cmd(input_path=INPUT_PATH), check=False)
assert process.returncode == 0

custom_model_restore()
process = subprocess.run(["python3", "cleanup.py"], check=False)


# Run the tests
if __name__ == "__main__":
pytest.main(["-v", __file__])
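For the new test above to run, custom_model_setup, custom_model_restore, set_generate_args, and the download module must be in scope in test_torchserve_run.py; the import lines fall outside the visible hunks. Presumably the file reuses the helpers from test_download.py, along the lines of the following sketch (the module path is an assumption):

import download
from tests.test_download import (
    custom_model_setup,
    custom_model_restore,
    set_generate_args,
)
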