diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 042cf29c1..0331c4721 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -33,7 +33,7 @@ jobs: matrix: python_version: ["3.8", "3.9", "3.10", "3.11"] use_gpu: ["y", "n"] - job_type: ["server", "training"] + job_type: ["server"] steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v1 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b50e4865d..405e0a0cf 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,7 +36,7 @@ jobs: matrix: python_version: ["3.8", "3.9", "3.10", "3.11"] use_gpu: ["y", "n"] - job_type: ["server", "training"] + job_type: ["server"] steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v1 diff --git a/bin/generate_base_images.py b/bin/generate_base_images.py index 42b11609e..bc82cbbfd 100755 --- a/bin/generate_base_images.py +++ b/bin/generate_base_images.py @@ -104,13 +104,18 @@ def _build( "docker", "buildx", "build", - "--platform=linux/amd64", + "--platform=linux/arm64,linux/amd64", ".", "-t", image_with_tag, ] if push: cmd.append("--push") + + # Needed to support multi-arch build. + subprocess.run( + ["docker", "buildx", "create", "--use"], cwd=build_ctx_path, check=True + ) subprocess.run(cmd, cwd=build_ctx_path, check=True) diff --git a/docker/base_images/base_image.Dockerfile.jinja b/docker/base_images/base_image.Dockerfile.jinja index 18479e7f3..37bf4c6cb 100644 --- a/docker/base_images/base_image.Dockerfile.jinja +++ b/docker/base_images/base_image.Dockerfile.jinja @@ -1,19 +1,19 @@ {% if use_gpu %} -FROM nvidia/cuda:11.2.1-base-ubuntu20.04 -ENV CUDNN_VERSION=8.1.0.77 -ENV CUDA=11.2 +FROM nvidia/cuda:12.2.2-base-ubuntu20.04 +ENV CUDNN_VERSION=8.9.5.29 +ENV CUDA=12.2 ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \ apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ - cuda-command-line-tools-11-2 \ - libcublas-11-2 \ - libcublas-dev-11-2 \ - libcufft-11-2 \ - libcurand-11-2 \ - libcusolver-11-2 \ - libcusparse-11-2 \ + cuda-command-line-tools-12-2 \ + libcublas-12-2 \ + libcublas-dev-12-2 \ + libcufft-12-2 \ + libcurand-12-2 \ + libcusolver-12-2 \ + libcusparse-12-2 \ libcudnn8=${CUDNN_VERSION}-1+cuda${CUDA} \ libgomp1 \ && \ diff --git a/docs/_snippets/config-params.mdx b/docs/_snippets/config-params.mdx index a908cb9cb..c0001eee7 100644 --- a/docs/_snippets/config-params.mdx +++ b/docs/_snippets/config-params.mdx @@ -231,36 +231,36 @@ Either `VLLM` for vLLM, or `TGI` for TGI. The arguments for the model server. This includes information such as which model you intend to load, and which endpoin from the server you'd like to use. -### `hf_cache` +### `model_cache` -The `hf_cache` section is used for caching model weights at build-time. This is one of the biggest levers +The `model_cache` section is used for caching model weights at build-time. This is one of the biggest levers for decreasing cold start times, as downloading weights can be one of the lengthiest parts of starting a new model instance. Using this section ensures that model weights are cached at _build_ time. See the [model cache guide](guides/model-cache) for the full details on how to use this field. - Despite the fact that this field is called the `hf_cache`, there are multiple backends supported, not just Hugging Face. You can + Despite the fact that this field is called the `model_cache`, there are multiple backends supported, not just Hugging Face. You can also cache weights stored on GCS, for instance. -#### `hf_cache..repo_id` +#### `model_cache..repo_id` The endpoint for your cloud bucket. Currently, we support Hugging Face and Google Cloud Storage. Example: `madebyollin/sdxl-vae-fp16-fix` for a Hugging Face repo, or `gcs://path-to-my-bucket` for a GCS bucket. -#### `hf_cache..revision` +#### `model_cache..revision` Points to your revision. This is only relevant if you are pulling By default, it refers to `main`. -#### `hf_cache..allow_patterns` +#### `model_cache..allow_patterns` Only cache files that match specified patterns. Utilize Unix shell-style wildcards to denote these patterns. By default, all paths are included. -#### `hf_cache..ignore_patterns` +#### `model_cache..ignore_patterns` Conversely, you can also denote file patterns to ignore, hence streamlining the caching process. By default, nothing is ignored. diff --git a/docs/examples/04-image-generation.mdx b/docs/examples/04-image-generation.mdx index 8060ea4b6..8ae54c4bc 100644 --- a/docs/examples/04-image-generation.mdx +++ b/docs/examples/04-image-generation.mdx @@ -215,7 +215,7 @@ subsequently. To enable caching, add the following to the config: ```yaml -hf_cache: +model_cache: - repo_id: madebyollin/sdxl-vae-fp16-fix allow_patterns: - config.json diff --git a/docs/examples/06-high-performance-cached-weights.mdx b/docs/examples/06-high-performance-cached-weights.mdx index c2616573b..c3905a8f2 100644 --- a/docs/examples/06-high-performance-cached-weights.mdx +++ b/docs/examples/06-high-performance-cached-weights.mdx @@ -89,9 +89,9 @@ requirements: - sentencepiece==0.1.99 - protobuf==4.24.4 ``` -# Configuring the hf_cache +# Configuring the model_cache -To cache model weights, set the `hf_cache` key. +To cache model weights, set the `model_cache` key. The `repo_id` field allows you to specify a Huggingface repo to pull down and cache at build-time, and the `ignore_patterns` field allows you to specify files to ignore. If this is specified, then @@ -100,7 +100,7 @@ this repo won't have to be pulled during runtime. Check out the [guide](https://truss.baseten.co/guides/model-cache) for more info. ```yaml config.yaml -hf_cache: +model_cache: - repo_id: "NousResearch/Llama-2-7b-chat-hf" ignore_patterns: - "*.bin" @@ -197,7 +197,7 @@ requirements: - transformers==4.34.0 - sentencepiece==0.1.99 - protobuf==4.24.4 -hf_cache: +model_cache: - repo_id: "NousResearch/Llama-2-7b-chat-hf" ignore_patterns: - "*.bin" diff --git a/docs/examples/performance/cached-weights.mdx b/docs/examples/performance/cached-weights.mdx index 03f19edad..8c65fda93 100644 --- a/docs/examples/performance/cached-weights.mdx +++ b/docs/examples/performance/cached-weights.mdx @@ -3,7 +3,7 @@ title: Deploy Llama 2 with Caching description: "Enable fast cold starts for a model with private Hugging Face weights" --- -In this example, we will cover how you can use the `hf_cache` key in your Truss's `config.yml` to automatically bundle model weights from a private Hugging Face repo. +In this example, we will cover how you can use the `model_cache` key in your Truss's `config.yml` to automatically bundle model weights from a private Hugging Face repo. Bundling model weights can significantly reduce cold start times because your instance won't waste time downloading the model weights from Hugging Face's servers. @@ -116,10 +116,10 @@ Always pin exact versions for your Python dependencies. The ML/AI space moves fa ### Step 3: Configure Hugging Face caching -Finally, we can configure Hugging Face caching in `config.yaml` by adding the `hf_cache` key. When building the image for your Llama 2 deployment, the Llama 2 model weights will be downloaded and cached for future use. +Finally, we can configure Hugging Face caching in `config.yaml` by adding the `model_cache` key. When building the image for your Llama 2 deployment, the Llama 2 model weights will be downloaded and cached for future use. ```yaml config.yaml -hf_cache: +model_cache: - repo_id: "meta-llama/Llama-2-7b-chat-hf" ignore_patterns: - "*.bin" @@ -163,7 +163,7 @@ requirements: - safetensors==0.3.2 - torch==2.0.1 - transformers==4.30.2 -hf_cache: +model_cache: - repo_id: "NousResearch/Llama-2-7b-chat-hf" ignore_patterns: - "*.bin" diff --git a/docs/guides/model-cache.mdx b/docs/guides/model-cache.mdx index ccc5d1da8..855cca50a 100644 --- a/docs/guides/model-cache.mdx +++ b/docs/guides/model-cache.mdx @@ -18,17 +18,20 @@ In practice, this reduces the cold start for large models to just a few seconds. ## Enabling Caching for a Model -To enable caching, simply add `hf_cache` to your `config.yml` with a valid `repo_id`. The `hf_cache` has a few key configurations: +To enable caching, simply add `model_cache` to your `config.yml` with a valid `repo_id`. The `model_cache` has a few key configurations: - `repo_id` (required): The endpoint for your cloud bucket. Currently, we support Hugging Face and Google Cloud Storage. - `revision`: Points to your revision. This is only relevant if you are pulling By default, it refers to `main`. - `allow_patterns`: Only cache files that match specified patterns. Utilize Unix shell-style wildcards to denote these patterns. - `ignore_patterns`: Conversely, you can also denote file patterns to ignore, hence streamlining the caching process. -Here is an example of a well written `hf_cache` for Stable Diffusion XL. Note how it only pulls the model weights that it needs using `allow_patterns`. +We recently renamed `hf_cache` to `model_cache`, but don't worry! If you're using `hf_cache` in any of your projects, it will automatically be aliased to `model_cache`. + + +Here is an example of a well written `model_cache` for Stable Diffusion XL. Note how it only pulls the model weights that it needs using `allow_patterns`. ```yaml config.yml ... -hf_cache: +model_cache: - repo_id: madebyollin/sdxl-vae-fp16-fix allow_patterns: - config.json @@ -51,7 +54,7 @@ Many Hugging Face repos have model weights in different formats (`.bin`, `.safet There are also some additional steps depending on the cloud bucket you want to query. ### Hugging Face 🤗 -For any public Hugging Face repo, you don't need to do anything else. Adding the `hf_cache` key with an appropriate `repo_id` should be enough. +For any public Hugging Face repo, you don't need to do anything else. Adding the `model_cache` key with an appropriate `repo_id` should be enough. However, if you want to deploy a model from a gated repo like [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to Baseten, there's a few steps you need to take: @@ -86,15 +89,17 @@ Weights will be cached in the default Hugging Face cache directory, `~/.cache/hu ### Google Cloud Storage Google Cloud Storage is a great alternative to Hugging Face when you have a custom model or fine-tune you want to gate, especially if you are already using GCP and care about security and compliance. -Your `hf_cache` should look something like this: +Your `model_cache` should look something like this: ```yaml config.yml ... -hf_cache: +model_cache: repo_id: gcs://path-to-my-bucket ... ``` +If you are accessing a public GCS bucket, you can ignore the following steps, but make sure you set appropriate permissions on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail. + For a private GCS bucket, first export your service account key. Rename it to be `service_account.json` and add it to the `data` directory of your Truss. Your file structure should look something like this: @@ -111,9 +116,80 @@ your-truss If you are using version control, like git, for your Truss, make sure to add `service_account.json` to your `.gitignore` file. You don't want to accidentally expose your service account key. -Weights will be cached at `/app/hf_cache/{your_bucket_name}`. +Weights will be cached at `/app/model_cache/{your_bucket_name}`. + + +### Amazon Web Services S3 + +Another popular cloud storage option for hosting model weights is AWS S3, especially if you're already using AWS services. + +Your `model_cache` should look something like this: + +```yaml config.yml +... +model_cache: + repo_id: s3://path-to-my-bucket +... +``` + +If you are accessing a public GCS bucket, you can ignore the subsequent steps, but make sure you set an appropriate appropriate policy on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail. + +However, for a private S3 bucket, you need to first find your `aws_access_key_id`, `aws_secret_access_key`, and `aws_region` in your AWS dashboard. Create a file named `s3_credentials.json`. Inside this file, add the credentials that you identified earlier as shown below. Place this file into the `data` directory of your Truss. +The key `aws_session_token` can be included, but is optional. + +Here is an example of how your `s3_credentials.json` file should look: + +```json +{ + "aws_access_key_id": "YOUR-ACCESS-KEY", + "aws_secret_access_key": "YOUR-SECRET-ACCESS-KEY", + "aws_region": "YOUR-REGION" +} +``` + +Your overall file structure should now look something like this: + +``` +your-truss +|--model +| └── model.py +|--data +|. └── s3_credentials.json +``` + +When you are generating credentials, make sure that the resulting keys have at minimum the following IAM policy: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Action": [ + "s3:GetObject", + "s3:ListObjects", + ], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::S3_BUCKET/PATH_TO_MODEL/*"] + }, + { + "Action": [ + "s3:ListBucket", + ], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::S3_BUCKET"] + } + ] + } +``` + + + +If you are using version control, like git, for your Truss, make sure to add `s3_credentials.json` to your `.gitignore` file. You don't want to accidentally expose your service account key. + + +Weights will be cached at `/app/model_cache/{your_bucket_name}`. ### Other Buckets -We're currently workign on adding support for more bucket types, including AWS S3. If you have any suggestions, please [leave an issue](https://github.com/basetenlabs/truss/issues) on our GitHub repo. +We can work with you to support additional bucket types if needed. If you have any suggestions, please [leave an issue](https://github.com/basetenlabs/truss/issues) on our GitHub repo. diff --git a/pyproject.toml b/pyproject.toml index a62c1f452..8058d8688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.7.14" +version = "0.7.15" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/cli/cli.py b/truss/cli/cli.py index aac1c0e6b..87900222a 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -43,6 +43,8 @@ ] } +console = rich.console.Console() + def echo_output(f: Callable[..., object]): @wraps(f) @@ -298,19 +300,25 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]): is_flag=True, required=False, default=False, - help="Invoked the published model version.", + help="Call the published model deployment.", ) @click.option( "--model-version", type=str, required=False, - help="ID of model version to invoke", + help="[DEPRECATED] Use --model-deployment instead, this will be removed in future release. ID of model deployment", +) +@click.option( + "--model-deployment", + type=str, + required=False, + help="ID of model deployment to call", ) @click.option( "--model", type=str, required=False, - help="ID of model to invoke", + help="ID of model to call", ) @echo_output def predict( @@ -320,10 +328,11 @@ def predict( file: Optional[Path], published: Optional[bool], model_version: Optional[str], + model_deployment: Optional[str], model: Optional[str], ): """ - Invokes the packaged model + Calls the packaged model TARGET_DIRECTORY: A Truss directory. If none, use current directory. @@ -336,10 +345,17 @@ def predict( remote_provider = RemoteFactory.create(remote=remote) + if model_version: + console.print( + "[DEPRECATED] --model-version is deprecated, use --model-deployment instead.", + style="yellow", + ) + model_deployment = model_version + model_identifier = _extract_and_validate_model_identifier( target_directory, model_id=model, - model_version_id=model_version, + model_version_id=model_deployment, published=published, ) @@ -419,10 +435,10 @@ def push( if service.is_draft: draft_model_text = """ |---------------------------------------------------------------------------------------| -| Your model has been deployed as a draft. Draft models allow you to | +| Your model has been deployed as a development model. Development models allow you to | | iterate quickly during the deployment process. | | | -| When you are ready to publish your deployed model as a new version, | +| When you are ready to publish your deployed model as a new deployment, | | pass `--publish` to the `truss push` command. To monitor changes to your model and | | rapidly iterate, run the `truss watch` command. | | | diff --git a/truss/contexts/image_builder/cache_warmer.py b/truss/contexts/image_builder/cache_warmer.py index 12236f840..ea8615d70 100644 --- a/truss/contexts/image_builder/cache_warmer.py +++ b/truss/contexts/image_builder/cache_warmer.py @@ -3,17 +3,24 @@ import os import subprocess import sys +from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Optional, Type import boto3 +from botocore import UNSIGNED from botocore.client import Config +from botocore.exceptions import ClientError, NoCredentialsError from google.cloud import storage from huggingface_hub import hf_hub_download os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS" +GCS_CREDENTIALS = "/app/data/service_account.json" +S3_CREDENTIALS = "/app/data/s3_credentials.json" + def _b10cp_path() -> Optional[str]: return os.environ.get(B10CP_PATH_TRUSS_ENV_VAR_NAME) @@ -35,6 +42,38 @@ def _download_from_url_using_b10cp( ) +@dataclass +class AWSCredentials: + access_key_id: str + secret_access_key: str + region: str + session_token: Optional[str] + + +def parse_s3_credentials_file(key_file_path: str) -> AWSCredentials: + # open the json file + with open(key_file_path, "r") as f: + data = json.load(f) + + # validate the data + if ( + "aws_access_key_id" not in data + or "aws_secret_access_key" not in data + or "aws_region" not in data + ): + raise ValueError("Invalid AWS credentials file") + + # create an AWS Service Account object + aws_sa = AWSCredentials( + access_key_id=data["aws_access_key_id"], + secret_access_key=data["aws_secret_access_key"], + region=data["aws_region"], + session_token=data.get("aws_session_token", None), + ) + + return aws_sa + + def split_path(path, prefix="gs://"): # Remove the 'gs://' prefix path = path.replace(prefix, "") @@ -48,107 +87,149 @@ def split_path(path, prefix="gs://"): return bucket_name, path -def parse_s3_service_account_file(file_path): - # open the json file - with open(file_path, "r") as f: - data = json.load(f) +class RepositoryFile(ABC): + def __init__(self, repo_name, file_name, revision_name): + self.repo_name = repo_name + self.file_name = file_name + self.revision_name = revision_name + + @staticmethod + def from_file( + new_repo_name: str, new_file_name: str, new_revision_name: str + ) -> "RepositoryFile": + repository_class: Type["RepositoryFile"] + if new_repo_name.startswith("gs://"): + repository_class = GCSFile + elif new_repo_name.startswith("s3://"): + repository_class = S3File + else: + repository_class = HuggingFaceFile + return repository_class(new_repo_name, new_file_name, new_revision_name) + + @abstractmethod + def download_to_cache(self): + pass + + +class HuggingFaceFile(RepositoryFile): + def download_to_cache(self): + secret_path = Path("/etc/secrets/hf-access-token") + secret = secret_path.read_text().strip() if secret_path.exists() else None + try: + hf_hub_download( + self.repo_name, + self.file_name, + revision=self.revision_name, + token=secret, + ) + except FileNotFoundError: + raise RuntimeError( + "Hugging Face repository not found (and no valid secret found for possibly private repository)." + ) - # validate the data - if "aws_access_key_id" not in data or "aws_secret_access_key" not in data: - raise ValueError("Invalid AWS credentials file") - # parse the data - aws_access_key_id = data["aws_access_key_id"] - aws_secret_access_key = data["aws_secret_access_key"] - aws_region = data["aws_region"] +class GCSFile(RepositoryFile): + def download_to_cache(self): + # Create GCS Client + bucket_name, _ = split_path(repo_name, prefix="gs://") - return aws_access_key_id, aws_secret_access_key, aws_region + is_private = os.path.exists(GCS_CREDENTIALS) + print(is_private) + if is_private: + print("loading...") + client = storage.Client.from_service_account_json(GCS_CREDENTIALS) + else: + client = storage.Client.create_anonymous_client() + bucket = client.bucket(bucket_name) -def download_file( - repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json" -): - # Check if repo_name starts with "gs://" - if repo_name.startswith(("gs://", "s3://")): - prefix = repo_name[:5] - - # Create directory if not exist - bucket_name, _ = split_path(repo_name, prefix=prefix) - repo_name = repo_name.replace(prefix, "") - cache_dir = Path(f"/app/hf_cache/{bucket_name}") + # Cache file + cache_dir = Path(f"/app/model_cache/{bucket_name}") cache_dir.mkdir(parents=True, exist_ok=True) - if prefix == "gs://": - # Connect to GCS storage - storage_client = storage.Client.from_service_account_json(key_file) - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(file_name) + dst_file = cache_dir / self.file_name + if not dst_file.parent.exists(): + dst_file.parent.mkdir(parents=True) - dst_file = Path(f"{cache_dir}/{file_name}") - if not dst_file.parent.exists(): - dst_file.parent.mkdir(parents=True) + blob = bucket.blob(self.file_name) - if not blob.exists(storage_client): - raise RuntimeError(f"File not found on GCS bucket: {blob.name}") + if not blob.exists(client): + raise RuntimeError(f"File not found on GCS bucket: {blob.name}") + if is_private: url = blob.generate_signed_url( version="v4", expiration=datetime.timedelta(minutes=15), method="GET", ) - try: - proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file) - proc.wait() - except Exception as e: - raise RuntimeError(f"Failure downloading file from GCS: {e}") - elif prefix == "s3://": - ( - AWS_ACCESS_KEY_ID, - AWS_SECRET_ACCESS_KEY, - AWS_REGION, - ) = parse_s3_service_account_file(key_file) + else: + base_url = "https://storage.googleapis.com" + url = f"{base_url}/{bucket_name}/{blob.name}" + + download_file_using_b10cp(url, dst_file, self.file_name) + + +class S3File(RepositoryFile): + def download_to_cache(self): + # Create S3 Client + bucket_name, _ = split_path(repo_name, prefix="s3://") + + if os.path.exists(S3_CREDENTIALS): + s3_credentials = parse_s3_credentials_file(S3_CREDENTIALS) client = boto3.client( "s3", - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION, + aws_access_key_id=s3_credentials.access_key_id, + aws_secret_access_key=s3_credentials.secret_access_key, + region_name=s3_credentials.region, + aws_session_token=s3_credentials.session_token, config=Config(signature_version="s3v4"), ) - bucket_name, _ = split_path(bucket_name, prefix="s3://") - - dst_file = Path(f"{cache_dir}/{file_name}") - if not dst_file.parent.exists(): - dst_file.parent.mkdir(parents=True) - - try: - url = client.generate_presigned_url( - "get_object", - Params={"Bucket": bucket_name, "Key": file_name}, - ExpiresIn=3600, - ) - except Exception: - raise RuntimeError(f"File not found on S3 bucket: {file_name}") - - try: - proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file) - proc.wait() - except Exception as e: - raise RuntimeError(f"Failure downloading file from S3: {e}") - else: - secret_path = Path("/etc/secrets/hf-access-token") - secret = secret_path.read_text().strip() if secret_path.exists() else None + else: + client = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + + # Cache file + cache_dir = Path(f"/app/model_cache/{bucket_name}") + cache_dir.mkdir(parents=True, exist_ok=True) + + dst_file = cache_dir / self.file_name + if not dst_file.parent.exists(): + dst_file.parent.mkdir(parents=True) + try: - hf_hub_download( - repo_name, - file_name, - revision=revision_name, - token=secret, + url = client.generate_presigned_url( + "get_object", + Params={"Bucket": bucket_name, "Key": file_name}, + ExpiresIn=3600, ) - except FileNotFoundError: + except NoCredentialsError as nce: raise RuntimeError( - "Hugging Face repository not found (and no valid secret found for possibly private repository)." + f"No AWS credentials found\nOriginal exception: {str(nce)}" + ) + except ClientError as ce: + raise RuntimeError( + f"Client error when accessing the S3 bucket (check your credentials): {str(ce)}" + ) + except Exception as exc: + raise RuntimeError( + f"File not found on S3 bucket: {file_name}\nOriginal exception: {str(exc)}" ) + download_file_using_b10cp(url, dst_file, self.file_name) + + +def download_file_using_b10cp(url, dst_file, file_name): + try: + proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file) + proc.wait() + + except FileNotFoundError as file_error: + raise RuntimeError(f"Failure due to file ({file_name}) not found: {file_error}") + + +def download_file(repo_name, file_name, revision_name=None): + file = RepositoryFile.from_file(repo_name, file_name, revision_name) + file.download_to_cache() + if __name__ == "__main__": file_path = Path.home() / ".cache/huggingface/hub/version.txt" diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 0cdc48ba0..5a38aacaf 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -1,9 +1,14 @@ -import json +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import boto3 import yaml +from botocore import UNSIGNED +from botocore.client import Config from google.cloud import storage from huggingface_hub import get_hf_file_metadata, hf_hub_url, list_repo_files from huggingface_hub.utils import filter_repo_objects @@ -21,6 +26,10 @@ TEMPLATES_DIR, TRITON_SERVER_CODE_DIR, ) +from truss.contexts.image_builder.cache_warmer import ( + AWSCredentials, + parse_s3_credentials_file, +) from truss.contexts.image_builder.image_builder import ImageBuilder from truss.contexts.image_builder.util import ( TRUSS_BASE_IMAGE_VERSION_TAG, @@ -31,7 +40,7 @@ ) from truss.contexts.truss_context import TrussContext from truss.patch.hash import directory_content_hash -from truss.truss_config import Build, HuggingFaceModel, ModelServer, TrussConfig +from truss.truss_config import Build, ModelRepo, ModelServer, TrussConfig from truss.truss_spec import TrussSpec from truss.util.download import download_external_data from truss.util.jinja import read_template_from_fs @@ -45,10 +54,174 @@ BUILD_CONTROL_SERVER_DIR_NAME = "control" CONFIG_FILE = "config.yaml" +GCS_CREDENTIALS = "service_account.json" +S3_CREDENTIALS = "s3_credentials.json" HF_ACCESS_TOKEN_SECRET_NAME = "hf_access_token" HF_ACCESS_TOKEN_FILE_NAME = "hf-access-token" +CLOUD_BUCKET_CACHE = Path("/app/model_cache/") +HF_SOURCE_DIR = Path("./root/.cache/huggingface/hub/") +HF_CACHE_DIR = Path("/root/.cache/huggingface/hub/") + + +class RemoteCache(ABC): + def __init__(self, repo_name, data_dir, revision=None): + self.repo_name = repo_name + self.data_dir = data_dir + self.revision = revision + + @staticmethod + def from_repo(repo_name: str, data_dir: Path) -> "RemoteCache": + repository_class: Type["RemoteCache"] + if repo_name.startswith("gs://"): + repository_class = GCSCache + elif repo_name.startswith("s3://"): + repository_class = S3Cache + else: + repository_class = HuggingFaceCache + return repository_class(repo_name, data_dir) + + def filter(self, allow_patterns, ignore_patterns): + return list( + filter_repo_objects( + items=self.list_files(revision=self.revision), + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + ) + + @abstractmethod + def list_files(self, revision=None): + pass + + @abstractmethod + def prepare_for_cache(self, filenames): + pass + + +class GCSCache(RemoteCache): + def list_files(self, revision=None): + gcs_credentials_file = self.data_dir / GCS_CREDENTIALS + + if gcs_credentials_file.exists(): + storage_client = storage.Client.from_service_account_json( + gcs_credentials_file + ) + else: + storage_client = storage.Client.create_anonymous_client() + + bucket_name, prefix = split_path(self.repo_name, prefix="gs://") + blobs = storage_client.list_blobs(bucket_name, prefix=prefix) + + all_objects = [] + for blob in blobs: + # leave out folders + if blob.name[-1] == "/": + continue + all_objects.append(blob.name) + + return all_objects + + def prepare_for_cache(self, filenames): + bucket_name, _ = split_path(self.repo_name, prefix="gs://") + + files_to_cache = [] + for filename in filenames: + file_location = str(CLOUD_BUCKET_CACHE / bucket_name / filename) + files_to_cache.append(CachedFile(source=file_location, dst=file_location)) + + return files_to_cache + + +class S3Cache(RemoteCache): + def list_files(self, revision=None): + s3_credentials_file = self.data_dir / S3_CREDENTIALS + + if s3_credentials_file.exists(): + s3_credentials: AWSCredentials = parse_s3_credentials_file( + self.data_dir / S3_CREDENTIALS + ) + session = boto3.Session( + aws_access_key_id=s3_credentials.access_key_id, + aws_secret_access_key=s3_credentials.secret_access_key, + aws_session_token=s3_credentials.session_token, + region_name=s3_credentials.region, + ) + s3 = session.resource("s3") + else: + s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED)) + + # path may be like "folderA/folderB" + bucket_name, path = split_path(self.repo_name, prefix="s3://") + bucket = s3.Bucket(bucket_name) + + all_objects = [] + for blob in bucket.objects.filter(Prefix=path): + all_objects.append(blob.key) + return all_objects + + def prepare_for_cache(self, filenames): + bucket_name, _ = split_path(self.repo_name, prefix="s3://") + + files_to_cache = [] + for filename in filenames: + file_location = str(CLOUD_BUCKET_CACHE / bucket_name / filename) + files_to_cache.append(CachedFile(source=file_location, dst=file_location)) + + return files_to_cache + + +def hf_cache_file_from_location(path: str): + src_file_location = str(HF_SOURCE_DIR / path) + dst_file_location = str(HF_CACHE_DIR / path) + cache_file = CachedFile(source=src_file_location, dst=dst_file_location) + return cache_file + + +class HuggingFaceCache(RemoteCache): + def list_files(self, revision=None): + return list_repo_files(self.repo_name, revision=revision) + + def prepare_for_cache(self, filenames): + files_to_cache = [] + repo_folder_name = f"models--{self.repo_name.replace('/', '--')}" + for filename in filenames: + hf_url = hf_hub_url(self.repo_name, filename) + hf_file_metadata = get_hf_file_metadata(hf_url) + + files_to_cache.append( + hf_cache_file_from_location( + f"{repo_folder_name}/blobs/{hf_file_metadata.etag}" + ) + ) + + # snapshots is just a set of folders with symlinks -- we can copy the entire thing separately + files_to_cache.append( + hf_cache_file_from_location(f"{repo_folder_name}/snapshots/") + ) + + # refs just has files with revision commit hashes + files_to_cache.append(hf_cache_file_from_location(f"{repo_folder_name}/refs/")) + + files_to_cache.append(hf_cache_file_from_location("version.txt")) + + return files_to_cache + + +def get_credentials_to_cache(data_dir: Path) -> List[str]: + gcs_credentials_file = data_dir / GCS_CREDENTIALS + s3_credentials_file = data_dir / S3_CREDENTIALS + credentials = [gcs_credentials_file, s3_credentials_file] + + credentials_to_cache = [] + for file in credentials: + if file.exists(): + build_path = Path(*file.parts[-2:]) + credentials_to_cache.append(str(build_path)) + + return credentials_to_cache + def create_triton_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path): _spec = TrussSpec(truss_dir) @@ -103,79 +276,10 @@ def split_path(path, prefix="gs://"): return bucket_name, path -def list_gcs_bucket_files( - bucket_name, - data_dir, - is_trusted=False, -): - if is_trusted: - storage_client = storage.Client.from_service_account_json( - data_dir / "service_account.json" - ) - else: - storage_client = storage.Client() - bucket_name, prefix = split_path(bucket_name) - blobs = storage_client.list_blobs(bucket_name, prefix=prefix) - - all_objects = [] - for blob in blobs: - # leave out folders - if blob.name[-1] == "/": - continue - all_objects.append(blob.name) - - return all_objects - - -def parse_s3_service_account_file(file_path): - # open the json file - with open(file_path, "r") as f: - data = json.load(f) - - # validate the data - if "aws_access_key_id" not in data or "aws_secret_access_key" not in data: - raise ValueError("Invalid AWS credentials file") - - # parse the data - aws_access_key_id = data["aws_access_key_id"] - aws_secret_access_key = data["aws_secret_access_key"] - - return aws_access_key_id, aws_secret_access_key - - -def list_s3_bucket_files(bucket_name, data_dir, is_trusted=False): - if is_trusted: - AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY = parse_s3_service_account_file( - data_dir / "service_account.json" - ) - session = boto3.Session(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) - s3 = session.resource("s3") - else: - s3 = boto3.client("s3") - - bucket_name, _ = split_path(bucket_name, prefix="s3://") - bucket = s3.Bucket(bucket_name) - - all_objects = [] - for blob in bucket.objects.all(): - all_objects.append(blob.key) - - return all_objects - - -def list_files(repo_id, data_dir, revision=None): - credentials_file = data_dir / "service_account.json" - if repo_id.startswith("gs://"): - return list_gcs_bucket_files( - repo_id, data_dir, is_trusted=credentials_file.exists() - ) - elif repo_id.startswith("s3://"): - return list_s3_bucket_files( - repo_id, data_dir, is_trusted=credentials_file.exists() - ) - else: - # we assume it's a HF bucket - return list_repo_files(repo_id, revision=revision) +@dataclass +class CachedFile: + source: str + dst: str def update_model_key(config: TrussConfig) -> str: @@ -198,14 +302,14 @@ def update_model_name(config: TrussConfig, model_key: str) -> str: "Key for model missing in config or incorrect key used. Use `model` for VLLM and `model_id` for TGI." ) model_name = config.build.arguments[model_key] - if "gs://" in model_name: + if "gs://" in model_name or "s3://" in model_name: # if we are pulling from a gs bucket, we want to alias it as a part of the cache - model_to_cache = HuggingFaceModel(model_name) - config.hf_cache.models.append(model_to_cache) + model_to_cache = ModelRepo(model_name) + config.model_cache.models.append(model_to_cache) + + prefix_removed = model_name[4:] # removes "gs://" or "s3://" - config.build.arguments[ - model_key - ] = f"/app/hf_cache/{model_name.replace('gs://', '')}" + config.build.arguments[model_key] = f"/app/model_cache/{prefix_removed}" return model_name @@ -213,70 +317,31 @@ def get_files_to_cache(config: TrussConfig, truss_dir: Path, build_dir: Path): def copy_into_build_dir(from_path: Path, path_in_build_dir: str): copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator] - model_files = {} - cached_files: List[str] = [] - if config.hf_cache: + remote_model_files = {} + local_files_to_cache: List[CachedFile] = [] + if config.model_cache: curr_dir = Path(__file__).parent.resolve() copy_into_build_dir(curr_dir / "cache_warmer.py", "cache_warmer.py") - for model in config.hf_cache.models: + for model in config.model_cache.models: repo_id = model.repo_id revision = model.revision allow_patterns = model.allow_patterns ignore_patterns = model.ignore_patterns - filtered_repo_files = list( - filter_repo_objects( - items=list_files( - repo_id, truss_dir / config.data_dir, revision=revision - ), - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - ) - ) - - cached_files = fetch_files_to_cache( - cached_files, repo_id, filtered_repo_files - ) + model_cache = RemoteCache.from_repo(repo_id, truss_dir / config.data_dir) + remote_filtered_files = model_cache.filter(allow_patterns, ignore_patterns) + local_files_to_cache += model_cache.prepare_for_cache(remote_filtered_files) - model_files[repo_id] = {"files": filtered_repo_files, "revision": revision} + remote_model_files[repo_id] = { + "files": remote_filtered_files, + "revision": revision, + } copy_into_build_dir( TEMPLATES_DIR / "cache_requirements.txt", "cache_requirements.txt" ) - return model_files, cached_files - - -def fetch_files_to_cache(cached_files: list, repo_id: str, filtered_repo_files: list): - if repo_id.startswith("gs://"): - bucket_name, _ = split_path(repo_id) - repo_id = f"gs://{bucket_name}" - - for filename in filtered_repo_files: - cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}") - elif repo_id.startswith("s3://"): - bucket_name, _ = split_path(repo_id, prefix="s3://") - repo_id = f"s3://{bucket_name}" - - for filename in filtered_repo_files: - cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}") - else: - repo_folder_name = f"models--{repo_id.replace('/', '--')}" - for filename in filtered_repo_files: - hf_url = hf_hub_url(repo_id, filename) - hf_file_metadata = get_hf_file_metadata(hf_url) - - cached_files.append(f"{repo_folder_name}/blobs/{hf_file_metadata.etag}") - - # snapshots is just a set of folders with symlinks -- we can copy the entire thing separately - cached_files.append(f"{repo_folder_name}/snapshots/") - - # refs just has files with revision commit hashes - cached_files.append(f"{repo_folder_name}/refs/") - - cached_files.append("version.txt") - - return cached_files + return remote_model_files, local_files_to_cache def update_config_and_gather_files( @@ -306,14 +371,13 @@ def create_tgi_build_dir( ) data_dir = build_dir / "data" - credentials_file = data_dir / "service_account.json" dockerfile_content = dockerfile_template.render( config=config, hf_access_token=hf_access_token, models=model_files, - hf_cache=config.hf_cache, + model_cache=config.model_cache, data_dir_exists=data_dir.exists(), - credentials_exists=credentials_file.exists(), + credentials_to_cache=get_credentials_to_cache(data_dir), cached_files=cached_file_paths, use_hf_secret=use_hf_secret, hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME, @@ -364,15 +428,15 @@ def create_vllm_build_dir( nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja") data_dir = build_dir / "data" - credentials_file = data_dir / "service_account.json" + dockerfile_content = dockerfile_template.render( config=config, hf_access_token=hf_access_token, models=model_files, should_install_server_requirements=True, - hf_cache=config.hf_cache, + model_cache=config.model_cache, data_dir_exists=data_dir.exists(), - credentials_exists=credentials_file.exists(), + credentials_to_cache=get_credentials_to_cache(data_dir), cached_files=cached_file_paths, use_hf_secret=use_hf_secret, hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME, @@ -523,7 +587,7 @@ def _render_dockerfile( config = self._spec.config data_dir = build_dir / config.data_dir bundled_packages_dir = build_dir / config.bundled_packages_dir - credentials_file = data_dir / "service_account.json" + dockerfile_template = read_template_from_fs( TEMPLATES_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME ) @@ -562,8 +626,8 @@ def _render_dockerfile( models=model_files, use_hf_secret=use_hf_secret, cached_files=cached_files, - credentials_exists=credentials_file.exists(), - hf_cache=len(config.hf_cache.models) > 0, + credentials_to_cache=get_credentials_to_cache(data_dir), + model_cache=len(config.model_cache.models) > 0, hf_access_token=hf_access_token, hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME, ) diff --git a/truss/contexts/image_builder/util.py b/truss/contexts/image_builder/util.py index d27351d8a..d50f60944 100644 --- a/truss/contexts/image_builder/util.py +++ b/truss/contexts/image_builder/util.py @@ -11,7 +11,7 @@ # [IMPORTANT] Make sure all images for this version are published to dockerhub # before change to this value lands. This value is used to look for base images # when building docker image for a truss. -TRUSS_BASE_IMAGE_VERSION_TAG = "v0.4.9" +TRUSS_BASE_IMAGE_VERSION_TAG = "v0.7.15" def file_is_empty(path: Path, ignore_hash_style_comments: bool = True) -> bool: diff --git a/truss/templates/cache.Dockerfile.jinja b/truss/templates/cache.Dockerfile.jinja index edc44c61f..5ed3dd169 100644 --- a/truss/templates/cache.Dockerfile.jinja +++ b/truss/templates/cache.Dockerfile.jinja @@ -1,20 +1,22 @@ FROM python:3.11-slim as cache_warmer -RUN mkdir -p /app/hf_cache +RUN mkdir -p /app/model_cache WORKDIR /app {% if hf_access_token %} ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}} {% endif %} -{%- if credentials_exists %} -COPY ./data/service_account.json /app/data/service_account.json -{%- endif %} RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp ENV B10CP_PATH_TRUSS /app/b10cp COPY ./cache_requirements.txt /app/cache_requirements.txt RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip COPY ./cache_warmer.py /cache_warmer.py + +{% for credential in credentials_to_cache %} +COPY ./{{credential}} /app/{{credential}} +{% endfor %} + {% for repo, hf_dir in models.items() %} {% for file in hf_dir.files %} {{ "RUN --mount=type=secret,id=" + hf_access_token_file_name + ",dst=/etc/secrets/" + hf_access_token_file_name if use_hf_secret else "RUN" }} python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %} diff --git a/truss/templates/copy_cache_files.Dockerfile.jinja b/truss/templates/copy_cache_files.Dockerfile.jinja index 98f089268..3772acf44 100644 --- a/truss/templates/copy_cache_files.Dockerfile.jinja +++ b/truss/templates/copy_cache_files.Dockerfile.jinja @@ -1,7 +1,3 @@ {% for file in cached_files %} - {%- if credentials_exists %} -COPY --from=cache_warmer {{file}} {{file}} - {%- else %} -COPY --from=cache_warmer ./root/.cache/huggingface/hub/{{file}} {{hf_dst_directory}}{{file}} - {%- endif %} +COPY --from=cache_warmer {{file.source}} {{file.dst}} {% endfor %} diff --git a/truss/templates/server.Dockerfile.jinja b/truss/templates/server.Dockerfile.jinja index c0386f80b..0dec1fefe 100644 --- a/truss/templates/server.Dockerfile.jinja +++ b/truss/templates/server.Dockerfile.jinja @@ -1,4 +1,4 @@ -{%- if hf_cache %} +{%- if model_cache %} {%- include "cache.Dockerfile.jinja" %} {%- endif %} @@ -65,8 +65,7 @@ RUN python3 -m venv /control/.env \ && /control/.env/bin/pip3 install -r /control/requirements.txt {%- endif %} -{%- if hf_cache %} - {%- set hf_dst_directory="/root/.cache/huggingface/hub/"%} +{%- if model_cache %} {%- include "copy_cache_files.Dockerfile.jinja"%} {%- endif %} diff --git a/truss/templates/tgi/tgi.Dockerfile.jinja b/truss/templates/tgi/tgi.Dockerfile.jinja index a27d65daa..67496b4ca 100644 --- a/truss/templates/tgi/tgi.Dockerfile.jinja +++ b/truss/templates/tgi/tgi.Dockerfile.jinja @@ -1,4 +1,4 @@ -{%- if hf_cache %} +{%- if model_cache %} {%- include "cache.Dockerfile.jinja" %} {%- endif %} @@ -22,7 +22,7 @@ ENV {{ env_var_name }} {{ env_var_value }} ENV SERVER_START_CMD /usr/bin/supervisord -{%- if hf_cache %} +{%- if model_cache %} {%- set hf_dst_directory="/data/"%} {%- include "copy_cache_files.Dockerfile.jinja"%} {%- endif %} diff --git a/truss/templates/vllm/vllm.Dockerfile.jinja b/truss/templates/vllm/vllm.Dockerfile.jinja index 5b452f49d..8d69f51c7 100644 --- a/truss/templates/vllm/vllm.Dockerfile.jinja +++ b/truss/templates/vllm/vllm.Dockerfile.jinja @@ -1,4 +1,4 @@ -{%- if hf_cache %} +{%- if model_cache %} {%- include "cache.Dockerfile.jinja" %} {%- endif %} @@ -22,7 +22,7 @@ ENV {{ env_var_name }} {{ env_var_value }} ENV SERVER_START_CMD /usr/bin/supervisord -{%- if hf_cache %} +{%- if model_cache %} {%- set hf_dst_directory="./root/.cache/huggingface/hub/"%} {%- include "copy_cache_files.Dockerfile.jinja"%} {%- endif %} diff --git a/truss/test_data/gcs_fix/config.yaml b/truss/test_data/gcs_fix/config.yaml new file mode 100644 index 000000000..6bbfbe9ab --- /dev/null +++ b/truss/test_data/gcs_fix/config.yaml @@ -0,0 +1,13 @@ +environment_variables: {} +external_package_dirs: [] +model_metadata: {} +model_name: gcs fix +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: '1' + memory: 2Gi + use_gpu: false +secrets: {} +system_packages: [] diff --git a/truss/test_data/gcs_fix/model/__init__.py b/truss/test_data/gcs_fix/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/truss/test_data/gcs_fix/model/model.py b/truss/test_data/gcs_fix/model/model.py new file mode 100644 index 000000000..9216b50d7 --- /dev/null +++ b/truss/test_data/gcs_fix/model/model.py @@ -0,0 +1,32 @@ +""" +The `Model` class is an interface between the ML model that you're packaging and the model +server that you're running it on. + +The main methods to implement here are: +* `load`: runs exactly once when the model server is spun up or patched and loads the + model onto the model server. Include any logic for initializing your model, such + as downloading model weights and loading the model into memory. +* `predict`: runs every time the model server is called. Include any logic for model + inference and return the model output. + +See https://truss.baseten.co/quickstart for more. +""" + + +class Model: + def __init__(self, **kwargs): + # Uncomment the following to get access + # to various parts of the Truss config. + + # self._data_dir = kwargs["data_dir"] + # self._config = kwargs["config"] + # self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + def predict(self, model_input): + # Run model inference here + return model_input diff --git a/truss/test_data/test_basic_truss/config.yaml b/truss/test_data/test_basic_truss/config.yaml new file mode 100644 index 000000000..8120b0ad3 --- /dev/null +++ b/truss/test_data/test_basic_truss/config.yaml @@ -0,0 +1,13 @@ +environment_variables: {} +external_package_dirs: [] +model_metadata: {} +model_name: basic truss +python_version: py39 +requirements: [] +resources: + accelerator: null + cpu: '1' + memory: 2Gi + use_gpu: false +secrets: {} +system_packages: [] diff --git a/truss/test_data/test_basic_truss/model/__init__.py b/truss/test_data/test_basic_truss/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/truss/test_data/test_basic_truss/model/model.py b/truss/test_data/test_basic_truss/model/model.py new file mode 100644 index 000000000..9216b50d7 --- /dev/null +++ b/truss/test_data/test_basic_truss/model/model.py @@ -0,0 +1,32 @@ +""" +The `Model` class is an interface between the ML model that you're packaging and the model +server that you're running it on. + +The main methods to implement here are: +* `load`: runs exactly once when the model server is spun up or patched and loads the + model onto the model server. Include any logic for initializing your model, such + as downloading model weights and loading the model into memory. +* `predict`: runs every time the model server is called. Include any logic for model + inference and return the model output. + +See https://truss.baseten.co/quickstart for more. +""" + + +class Model: + def __init__(self, **kwargs): + # Uncomment the following to get access + # to various parts of the Truss config. + + # self._data_dir = kwargs["data_dir"] + # self._config = kwargs["config"] + # self._secrets = kwargs["secrets"] + self._model = None + + def load(self): + # Load model here and assign to self._model. + pass + + def predict(self, model_input): + # Run model inference here + return model_input diff --git a/truss/test_data/test_tgi_truss/config.yaml b/truss/test_data/test_tgi_truss/config.yaml index 33a0d08f1..986542030 100644 --- a/truss/test_data/test_tgi_truss/config.yaml +++ b/truss/test_data/test_tgi_truss/config.yaml @@ -9,7 +9,7 @@ model_metadata: {} model_name: Test TGI python_version: py39 requirements: [] -hf_cache: +model_cache: - repo_id: facebook/opt-125m ignore_patterns: - "*.h5" diff --git a/truss/test_data/test_truss_server_caching_truss/config.yaml b/truss/test_data/test_truss_server_caching_truss/config.yaml index 692a7d9ef..672d4e39b 100644 --- a/truss/test_data/test_truss_server_caching_truss/config.yaml +++ b/truss/test_data/test_truss_server_caching_truss/config.yaml @@ -6,7 +6,7 @@ python_version: py39 requirements: - transformers - torch -hf_cache: +model_cache: - repo_id: julien-c/EsperBERTo-small ignore_patterns: - "*.bin" diff --git a/truss/test_data/test_vllm_truss/config.yaml b/truss/test_data/test_vllm_truss/config.yaml index b9ad01d9b..70e9f5243 100644 --- a/truss/test_data/test_vllm_truss/config.yaml +++ b/truss/test_data/test_vllm_truss/config.yaml @@ -9,7 +9,7 @@ model_metadata: {} model_name: Test vLLM python_version: py39 requirements: [] -hf_cache: +model_cache: - repo_id: facebook/opt-125m ignore_patterns: - "*.h5" diff --git a/truss/tests/contexts/image_builder/test_serving_image_builder.py b/truss/tests/contexts/image_builder/test_serving_image_builder.py index 7e829b794..c8c158a47 100644 --- a/truss/tests/contexts/image_builder/test_serving_image_builder.py +++ b/truss/tests/contexts/image_builder/test_serving_image_builder.py @@ -12,13 +12,7 @@ update_model_name, ) from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all -from truss.truss_config import ( - Build, - HuggingFaceCache, - HuggingFaceModel, - ModelServer, - TrussConfig, -) +from truss.truss_config import Build, ModelCache, ModelRepo, ModelServer, TrussConfig from truss.truss_handle import TrussHandle BASE_DIR = Path(__file__).parent @@ -79,9 +73,9 @@ def test_overrides_model_id_vllm(): update_model_name(config, model_key) # Assert model overridden in config - assert Path(config.build.arguments["model"]) == Path("/app/hf_cache/llama-2-7b") - assert config.hf_cache == HuggingFaceCache( - models=[HuggingFaceModel(repo_id="gs://llama-2-7b/")] + assert Path(config.build.arguments["model"]) == Path("/app/model_cache/llama-2-7b") + assert config.model_cache == ModelCache( + models=[ModelRepo(repo_id="gs://llama-2-7b/")] ) @@ -98,17 +92,23 @@ def test_overrides_model_id_tgi(): update_model_name(config, model_key) # Assert model overridden in config - assert Path(config.build.arguments["model_id"]) == Path("/app/hf_cache/llama-2-7b") - assert config.hf_cache == HuggingFaceCache( - models=[HuggingFaceModel(repo_id="gs://llama-2-7b/")] + assert Path(config.build.arguments["model_id"]) == Path( + "/app/model_cache/llama-2-7b" + ) + assert config.model_cache == ModelCache( + models=[ModelRepo(repo_id="gs://llama-2-7b/")] ) +def flatten_cached_files(local_cache_files): + return [file.source for file in local_cache_files] + + def test_correct_hf_files_accessed_for_caching(): model = "openai/whisper-small" config = TrussConfig( python_version="py39", - hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + model_cache=ModelCache(models=[ModelRepo(repo_id=model)]), ) with TemporaryDirectory() as tmp_dir: @@ -116,12 +116,18 @@ def test_correct_hf_files_accessed_for_caching(): build_path = truss_path / "build" build_path.mkdir(parents=True, exist_ok=True) + hf_path = Path("root/.cache/huggingface/hub") + model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) - assert "version.txt" in files_to_cache + files_to_cache = flatten_cached_files(files_to_cache) + assert str(hf_path / "version.txt") in files_to_cache # It's unlikely the repo will change assert ( - "models--openai--whisper-small/blobs/1d7734884874f1a1513ed9aa760a4f8e97aaa02fd6d93a3a85d27b2ae9ca596b" + str( + hf_path + / "models--openai--whisper-small/blobs/59ef8a839f271fa2183c6a4c302669d097e43b6d" + ) in files_to_cache ) @@ -131,7 +137,7 @@ def test_correct_hf_files_accessed_for_caching(): assert "tokenizer_config.json" in files -@patch("truss.contexts.image_builder.serving_image_builder.list_gcs_bucket_files") +@patch("truss.contexts.image_builder.serving_image_builder.GCSCache.list_files") def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files): mock_list_bucket_files.return_value = [ "fake_model-001-of-002.bin", @@ -141,7 +147,7 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files): config = TrussConfig( python_version="py39", - hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + model_cache=ModelCache(models=[ModelRepo(repo_id=model)]), ) with TemporaryDirectory() as tmp_dir: @@ -150,13 +156,14 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files): build_path.mkdir(parents=True, exist_ok=True) model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) + files_to_cache = flatten_cached_files(files_to_cache) assert ( - "/app/hf_cache/crazy-good-new-model-7b/fake_model-001-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/fake_model-001-of-002.bin" in files_to_cache ) assert ( - "/app/hf_cache/crazy-good-new-model-7b/fake_model-002-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/fake_model-002-of-002.bin" in files_to_cache ) @@ -164,7 +171,7 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files): assert "fake_model-001-of-002.bin" in model_files[model]["files"] -@patch("truss.contexts.image_builder.serving_image_builder.list_s3_bucket_files") +@patch("truss.contexts.image_builder.serving_image_builder.S3Cache.list_files") def test_correct_s3_files_accessed_for_caching(mock_list_bucket_files): mock_list_bucket_files.return_value = [ "fake_model-001-of-002.bin", @@ -174,7 +181,7 @@ def test_correct_s3_files_accessed_for_caching(mock_list_bucket_files): config = TrussConfig( python_version="py39", - hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + model_cache=ModelCache(models=[ModelRepo(repo_id=model)]), ) with TemporaryDirectory() as tmp_dir: @@ -183,13 +190,14 @@ def test_correct_s3_files_accessed_for_caching(mock_list_bucket_files): build_path.mkdir(parents=True, exist_ok=True) model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) + files_to_cache = flatten_cached_files(files_to_cache) assert ( - "/app/hf_cache/crazy-good-new-model-7b/fake_model-001-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/fake_model-001-of-002.bin" in files_to_cache ) assert ( - "/app/hf_cache/crazy-good-new-model-7b/fake_model-002-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/fake_model-002-of-002.bin" in files_to_cache ) @@ -197,7 +205,7 @@ def test_correct_s3_files_accessed_for_caching(mock_list_bucket_files): assert "fake_model-001-of-002.bin" in model_files[model]["files"] -@patch("truss.contexts.image_builder.serving_image_builder.list_gcs_bucket_files") +@patch("truss.contexts.image_builder.serving_image_builder.GCSCache.list_files") def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files): mock_list_bucket_files.return_value = [ "folder_a/folder_b/fake_model-001-of-002.bin", @@ -207,7 +215,7 @@ def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files): config = TrussConfig( python_version="py39", - hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + model_cache=ModelCache(models=[ModelRepo(repo_id=model)]), ) with TemporaryDirectory() as tmp_dir: @@ -216,14 +224,14 @@ def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files): build_path.mkdir(parents=True, exist_ok=True) model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) - print(files_to_cache) + files_to_cache = flatten_cached_files(files_to_cache) assert ( - "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin" in files_to_cache ) assert ( - "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin" in files_to_cache ) @@ -235,7 +243,7 @@ def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files): ) -@patch("truss.contexts.image_builder.serving_image_builder.list_s3_bucket_files") +@patch("truss.contexts.image_builder.serving_image_builder.S3Cache.list_files") def test_correct_nested_s3_files_accessed_for_caching(mock_list_bucket_files): mock_list_bucket_files.return_value = [ "folder_a/folder_b/fake_model-001-of-002.bin", @@ -245,7 +253,7 @@ def test_correct_nested_s3_files_accessed_for_caching(mock_list_bucket_files): config = TrussConfig( python_version="py39", - hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + model_cache=ModelCache(models=[ModelRepo(repo_id=model)]), ) with TemporaryDirectory() as tmp_dir: @@ -254,14 +262,14 @@ def test_correct_nested_s3_files_accessed_for_caching(mock_list_bucket_files): build_path.mkdir(parents=True, exist_ok=True) model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) - print(files_to_cache) + files_to_cache = flatten_cached_files(files_to_cache) assert ( - "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin" in files_to_cache ) assert ( - "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin" + "/app/model_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin" in files_to_cache ) @@ -317,7 +325,7 @@ def test_truss_server_caching_truss(): assert "Downloading model.safetensors:" not in container.logs() -def test_hf_cache_dockerfile(): +def test_model_cache_dockerfile(): truss_root = Path(__file__).parent.parent.parent.parent.parent.resolve() / "truss" truss_dir = truss_root / "test_data" / "test_truss_server_caching_truss" tr = TrussHandle(truss_dir) diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index a9ecd4154..8a3366b55 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -10,8 +10,8 @@ Accelerator, AcceleratorSpec, BaseImage, - HuggingFaceCache, - HuggingFaceModel, + ModelCache, + ModelRepo, Resources, Train, TrussConfig, @@ -206,56 +206,64 @@ def test_non_default_train(): assert new_config == config.to_dict(verbose=False) +def test_null_model_cache_key(): + config_yaml_dict = {"model_cache": None} + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: + yaml.safe_dump(config_yaml_dict, tmp_file) + config = TrussConfig.from_yaml(Path(tmp_file.name)) + assert config.model_cache == ModelCache.from_list([]) + + def test_null_hf_cache_key(): config_yaml_dict = {"hf_cache": None} with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: yaml.safe_dump(config_yaml_dict, tmp_file) config = TrussConfig.from_yaml(Path(tmp_file.name)) - assert config.hf_cache == HuggingFaceCache.from_list([]) + assert config.model_cache == ModelCache.from_list([]) def test_huggingface_cache_single_model_default_revision(): config = TrussConfig( python_version="py39", requirements=[], - hf_cache=HuggingFaceCache(models=[HuggingFaceModel("test/model")]), + model_cache=ModelCache(models=[ModelRepo("test/model")]), ) new_config = generate_default_config() - new_config["hf_cache"] = [ + new_config["model_cache"] = [ { "repo_id": "test/model", } ] assert new_config == config.to_dict(verbose=False) - assert config.to_dict(verbose=True)["hf_cache"][0].get("revision") is None + assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None def test_huggingface_cache_single_model_non_default_revision(): config = TrussConfig( python_version="py39", requirements=[], - hf_cache=HuggingFaceCache(models=[HuggingFaceModel("test/model", "not-main")]), + model_cache=ModelCache(models=[ModelRepo("test/model", "not-main")]), ) - assert config.to_dict(verbose=False)["hf_cache"][0].get("revision") == "not-main" + assert config.to_dict(verbose=False)["model_cache"][0].get("revision") == "not-main" def test_huggingface_cache_multiple_models_default_revision(): config = TrussConfig( python_version="py39", requirements=[], - hf_cache=HuggingFaceCache( + model_cache=ModelCache( models=[ - HuggingFaceModel("test/model1", "main"), - HuggingFaceModel("test/model2"), + ModelRepo("test/model1", "main"), + ModelRepo("test/model2"), ] ), ) new_config = generate_default_config() - new_config["hf_cache"] = [ + new_config["model_cache"] = [ {"repo_id": "test/model1", "revision": "main"}, { "repo_id": "test/model2", @@ -263,24 +271,24 @@ def test_huggingface_cache_multiple_models_default_revision(): ] assert new_config == config.to_dict(verbose=False) - assert config.to_dict(verbose=True)["hf_cache"][0].get("revision") == "main" - assert config.to_dict(verbose=True)["hf_cache"][1].get("revision") is None + assert config.to_dict(verbose=True)["model_cache"][0].get("revision") == "main" + assert config.to_dict(verbose=True)["model_cache"][1].get("revision") is None def test_huggingface_cache_multiple_models_mixed_revision(): config = TrussConfig( python_version="py39", requirements=[], - hf_cache=HuggingFaceCache( + model_cache=ModelCache( models=[ - HuggingFaceModel("test/model1"), - HuggingFaceModel("test/model2", "not-main2"), + ModelRepo("test/model1"), + ModelRepo("test/model2", "not-main2"), ] ), ) new_config = generate_default_config() - new_config["hf_cache"] = [ + new_config["model_cache"] = [ { "repo_id": "test/model1", }, @@ -288,8 +296,8 @@ def test_huggingface_cache_multiple_models_mixed_revision(): ] assert new_config == config.to_dict(verbose=False) - assert config.to_dict(verbose=True)["hf_cache"][0].get("revision") is None - assert config.to_dict(verbose=True)["hf_cache"][1].get("revision") == "not-main2" + assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None + assert config.to_dict(verbose=True)["model_cache"][1].get("revision") == "not-main2" def test_empty_config(): diff --git a/truss/tests/test_truss_handle.py b/truss/tests/test_truss_handle.py index f18ec2705..706a905a6 100644 --- a/truss/tests/test_truss_handle.py +++ b/truss/tests/test_truss_handle.py @@ -121,6 +121,7 @@ def test_build_serving_docker_image_from_user_base_image_live_reload( assert "It returned with code 1" in str(exc) +@pytest.mark.skip(reason="Training Integration tests not supported") @pytest.mark.integration def test_build_training_docker_image_from_user_base_image(custom_model_truss_dir): th = TrussHandle(custom_model_truss_dir) @@ -235,6 +236,7 @@ def test_docker_predict_model_with_external_packages( assert result == [1, 1] +@pytest.mark.skip(reason="Training Integration tests not supported") @pytest.mark.integration def test_docker_train(variables_to_artifacts_training_truss): th = TrussHandle(variables_to_artifacts_training_truss) diff --git a/truss/truss_config.py b/truss/truss_config.py index 4b921b9d9..ff5cd0b44 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -1,3 +1,4 @@ +import logging from dataclasses import _MISSING_TYPE, dataclass, field, fields from enum import Enum from pathlib import Path @@ -39,6 +40,10 @@ DEFAULT_BLOB_BACKEND = HTTP_PUBLIC_BLOB_BACKEND +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + class Accelerator(Enum): T4 = "T4" @@ -78,7 +83,7 @@ def from_str(acc_spec: Optional[str]): @dataclass -class HuggingFaceModel: +class ModelRepo: repo_id: str = "" revision: Optional[str] = None allow_patterns: Optional[List[str]] = None @@ -94,7 +99,7 @@ def from_dict(d): allow_patterns = d.get("allow_patterns", None) ignore_pattenrs = d.get("ignore_patterns", None) - return HuggingFaceModel( + return ModelRepo( repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, @@ -117,12 +122,12 @@ def to_dict(self, verbose=False): @dataclass -class HuggingFaceCache: - models: List[HuggingFaceModel] = field(default_factory=list) +class ModelCache: + models: List[ModelRepo] = field(default_factory=list) @staticmethod - def from_list(items: List[Dict[str, str]]) -> "HuggingFaceCache": - return HuggingFaceCache([HuggingFaceModel.from_dict(item) for item in items]) + def from_list(items: List[Dict[str, str]]) -> "ModelCache": + return ModelCache([ModelRepo.from_dict(item) for item in items]) def to_list(self, verbose=False) -> List[Dict[str, str]]: return [model.to_dict(verbose=verbose) for model in self.models] @@ -442,7 +447,7 @@ class TrussConfig: train: Train = field(default_factory=Train) base_image: Optional[BaseImage] = None - hf_cache: HuggingFaceCache = field(default_factory=HuggingFaceCache) + model_cache: ModelCache = field(default_factory=ModelCache) @property def canonical_python_version(self) -> str: @@ -490,9 +495,9 @@ def from_dict(d): d.get("external_data"), ExternalData.from_list ), base_image=transform_optional(d.get("base_image"), BaseImage.from_dict), - hf_cache=transform_optional( - d.get("hf_cache") or [], - HuggingFaceCache.from_list, + model_cache=transform_optional( + d.get("model_cache") or d.get("hf_cache") or [], + ModelCache.from_list, ), ) config.validate() @@ -502,6 +507,12 @@ def from_dict(d): def from_yaml(yaml_path: Path): with yaml_path.open() as yaml_file: raw_data = yaml.safe_load(yaml_file) or {} + if "hf_cache" in raw_data: + logger.warning( + """Warning: `hf_cache` is deprecated in favor of `model_cache`. + Everything will run as before, but if you are pulling weights from S3 or GCS, they will be + stored at /app/model_cache instead of /app/hf_cache as before.""" + ) return TrussConfig.from_dict(raw_data) def write_to_yaml_file(self, path: Path, verbose: bool = True): @@ -574,8 +585,8 @@ def obj_to_dict(obj, verbose: bool = False): d["external_data"] = transform_optional( field_curr_value, lambda data: data.to_list() ) - elif isinstance(field_curr_value, HuggingFaceCache): - d["hf_cache"] = transform_optional( + elif isinstance(field_curr_value, ModelCache): + d["model_cache"] = transform_optional( field_curr_value, lambda data: data.to_list(verbose=verbose) ) else: