From 1c6eb2ea161aa642e89cd5d65222c598e53ba8fb Mon Sep 17 00:00:00 2001 From: jerryzhuang Date: Wed, 13 Nov 2024 05:53:56 +1100 Subject: [PATCH] feat: support adaptive `max_model_len` (#657) **Reason for Change**: - upgrade base image to python-3.12 - added new model: phi-3.5-mini - support adaptive `max_model_len` --------- Signed-off-by: jerryzhuang Co-authored-by: Fei Guo --- docker/presets/models/tfs/Dockerfile | 2 +- pkg/utils/test/testModel.go | 2 + .../inference/preset-inferences_test.go | 6 +- presets/inference/vllm/inference_api.py | 62 +++++++++++++++++-- presets/models/phi3/model.go | 45 ++++++++++++++ presets/models/supported_models.yaml | 2 +- .../falcon-40b-instruct.yaml | 6 +- .../test/manifests/falcon-40b/falcon-40b.yaml | 6 +- .../falcon-7b-adapter/falcon-7b-adapter.yaml | 6 +- .../falcon-7b-instruct.yaml | 6 +- .../falcon-7b-with-adapter/falcon-7b.yaml | 2 +- .../test/manifests/falcon-7b/falcon-7b.yaml | 6 +- .../llama-2-13b-chat/llama-2-13b-chat.yaml | 6 +- .../manifests/llama-2-13b/llama-2-13b.yaml | 6 +- .../llama-2-7b-chat/llama-2-7b-chat.yaml | 6 +- .../test/manifests/llama-2-7b/llama-2-7b.yaml | 6 +- .../mistral-7b-instruct.yaml | 6 +- .../test/manifests/mistral-7b/mistral-7b.yaml | 6 +- presets/test/manifests/phi-2/phi-2.yaml | 6 +- .../phi-3-medium-128k-instruct.yaml | 6 +- .../phi-3-medium-4k-instruct.yaml | 6 +- .../phi-3-mini-128k-instruct.yaml | 6 +- .../phi-3-mini-4k-instruct.yaml | 6 +- .../phi-3-small-128k-instruct.yaml | 6 +- .../phi-3-small-8k-instruct.yaml | 6 +- 25 files changed, 165 insertions(+), 64 deletions(-) diff --git a/docker/presets/models/tfs/Dockerfile b/docker/presets/models/tfs/Dockerfile index da5a3b732..3ed435ba9 100644 --- a/docker/presets/models/tfs/Dockerfile +++ b/docker/presets/models/tfs/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim@sha256:684b1aaf96a7942b3c3af438d162e0baa3510aa7af25ad76d238e0c746bdec79 +FROM python:3.12-slim ARG WEIGHTS_PATH ARG MODEL_TYPE diff --git a/pkg/utils/test/testModel.go b/pkg/utils/test/testModel.go index d4ef5faf6..2a820191f 100644 --- a/pkg/utils/test/testModel.go +++ b/pkg/utils/test/testModel.go @@ -16,6 +16,7 @@ func (*testModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: "python3", } } func (*testModel) GetTuningParameters() *model.PresetParam { @@ -37,6 +38,7 @@ func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: "python3", } } func (*testDistributedModel) GetTuningParameters() *model.PresetParam { diff --git a/pkg/workspace/inference/preset-inferences_test.go b/pkg/workspace/inference/preset-inferences_test.go index 0a532ad20..bebbab4e7 100644 --- a/pkg/workspace/inference/preset-inferences_test.go +++ b/pkg/workspace/inference/preset-inferences_test.go @@ -46,7 +46,7 @@ func TestCreatePresetInference(t *testing.T) { workload: "Deployment", // No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams // So expected cmd consists of shell command and inference file - expectedCmd: "/bin/sh -c inference_api.py", + expectedCmd: "/bin/sh -c python3 inference_api.py", hasAdapters: false, }, @@ -58,7 +58,7 @@ func TestCreatePresetInference(t *testing.T) { c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil) }, workload: "StatefulSet", - expectedCmd: "/bin/sh -c 
inference_api.py", + expectedCmd: "/bin/sh -c python3 inference_api.py", hasAdapters: false, }, @@ -69,7 +69,7 @@ func TestCreatePresetInference(t *testing.T) { c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil) }, workload: "Deployment", - expectedCmd: "/bin/sh -c inference_api.py", + expectedCmd: "/bin/sh -c python3 inference_api.py", hasAdapters: true, expectedVolume: "adapter-volume", }, diff --git a/presets/inference/vllm/inference_api.py b/presets/inference/vllm/inference_api.py index 5b9a2d881..10fc3e312 100644 --- a/presets/inference/vllm/inference_api.py +++ b/presets/inference/vllm/inference_api.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import logging +import gc import os import uvloop +import torch from vllm.utils import FlexibleArgumentParser import vllm.entrypoints.openai.api_server as api_server +from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig) # Initialize logger logger = logging.getLogger(__name__) @@ -26,15 +29,44 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # See https://docs.vllm.ai/en/latest/models/engine_args.html for more args engine_default_args = { "model": "/workspace/vllm/weights", - "cpu-offload-gb": 0, - "gpu-memory-utilization": 0.9, - "swap-space": 4, - "disable-log-stats": False, + "cpu_offload_gb": 0, + "gpu_memory_utilization": 0.95, + "swap_space": 4, + "disable_log_stats": False, + "uvicorn_log_level": "error" } parser.set_defaults(**engine_default_args) return parser +def find_max_available_seq_len(engine_config: EngineConfig) -> int: + """ + Load model and run profiler to find max available seq len. + """ + # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L335 + executor_class = LLMEngine._get_executor_cls(engine_config) + executor = executor_class( + model_config=engine_config.model_config, + cache_config=engine_config.cache_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + lora_config=engine_config.lora_config, + speculative_config=engine_config.speculative_config, + load_config=engine_config.load_config, + prompt_adapter_config=engine_config.prompt_adapter_config, + observability_config=engine_config.observability_config, + ) + + # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477 + num_gpu_blocks, _ = executor.determine_num_available_blocks() + + # release memory + del executor + gc.collect() + torch.cuda.empty_cache() + + return engine_config.cache_config.block_size * num_gpu_blocks if __name__ == "__main__": parser = FlexibleArgumentParser(description='vLLM serving server') @@ -42,6 +74,28 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser = make_arg_parser(parser) args = parser.parse_args() + if args.max_model_len is None: + engine_args = EngineArgs.from_cli_args(args) + # read the model config from hf weights path. + # vllm will perform different parser for different model architectures + # and read it into a unified EngineConfig. 
+ engine_config = engine_args.create_engine_config()
+
+ logger.info("Running profiler to find max available seq len")
+ available_seq_len = find_max_available_seq_len(engine_config)
+ # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
+ if available_seq_len <= 0:
+ raise ValueError("No available memory for the cache blocks. "
+ "Try increasing `gpu_memory_utilization` when "
+ "initializing the engine.")
+ max_model_len = engine_config.model_config.max_model_len
+ if available_seq_len > max_model_len:
+ available_seq_len = max_model_len
+
+ if available_seq_len != max_model_len:
+ logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
+ args.max_model_len = available_seq_len
+
 # Run the serving server
 logger.info(f"Starting server on port {args.port}")
 # See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for more
diff --git a/presets/models/phi3/model.go b/presets/models/phi3/model.go
index 0cde3578c..abf63003b 100644
--- a/presets/models/phi3/model.go
+++ b/presets/models/phi3/model.go
@@ -28,6 +28,10 @@ func init() {
 Name: PresetPhi3Medium128kModel,
 Instance: &phi3MediumB,
 })
+ plugin.KaitoModelRegister.Register(&plugin.Registration{
+ Name: PresetPhi3_5MiniInstruct,
+ Instance: &phi3_5MiniC,
+ })
 }
 
 var (
@@ -35,12 +39,14 @@ var (
 PresetPhi3Mini128kModel = "phi-3-mini-128k-instruct"
 PresetPhi3Medium4kModel = "phi-3-medium-4k-instruct"
 PresetPhi3Medium128kModel = "phi-3-medium-128k-instruct"
+ PresetPhi3_5MiniInstruct = "phi-3.5-mini-instruct"
 
 PresetPhiTagMap = map[string]string{
 "Phi3Mini4kInstruct": "0.0.2",
 "Phi3Mini128kInstruct": "0.0.2",
 "Phi3Medium4kInstruct": "0.0.2",
 "Phi3Medium128kInstruct": "0.0.2",
+ "Phi3_5MiniInstruct": "0.0.1",
 }
 
 baseCommandPresetPhiInference = "accelerate launch"
@@ -130,6 +136,45 @@ func (*phi3Mini128KInst) SupportTuning() bool {
 return true
 }
 
+var phi3_5MiniC phi3_5MiniInst
+
+type phi3_5MiniInst struct{}
+
+func (*phi3_5MiniInst) GetInferenceParameters() *model.PresetParam {
+ return &model.PresetParam{
+ ModelFamilyName: "Phi3_5",
+ ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
+ DiskStorageRequirement: "50Gi",
+ GPUCountRequirement: "1",
+ TotalGPUMemoryRequirement: "8Gi",
+ PerGPUMemoryRequirement: "0Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement.
+ TorchRunParams: inference.DefaultAccelerateParams, + ModelRunParams: phiRunParams, + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhiInference, + Tag: PresetPhiTagMap["Phi3_5MiniInstruct"], + } +} +func (*phi3_5MiniInst) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Phi3_5", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "72Gi", + PerGPUMemoryRequirement: "72Gi", + // TorchRunParams: inference.DefaultAccelerateParams, + // ModelRunParams: phiRunParams, + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhiTuning, + Tag: PresetPhiTagMap["Phi3_5MiniInstruct"], + } +} +func (*phi3_5MiniInst) SupportDistributedInference() bool { return false } +func (*phi3_5MiniInst) SupportTuning() bool { + return true +} + var phi3MediumA Phi3Medium4kInstruct type Phi3Medium4kInstruct struct{} diff --git a/presets/models/supported_models.yaml b/presets/models/supported_models.yaml index be25c83f5..db49641b8 100644 --- a/presets/models/supported_models.yaml +++ b/presets/models/supported_models.yaml @@ -134,4 +134,4 @@ models: tag: 0.0.2 # Tag history: # 0.0.2 - Add Logging & Metrics Server - # 0.0.1 - Initial Release + # 0.0.1 - Initial Release \ No newline at end of file diff --git a/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct.yaml b/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct.yaml index 13060349f..37f3c6a6b 100644 --- a/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct.yaml +++ b/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 4 # Requesting 4 GPUs @@ -27,13 +27,13 @@ spec: nvidia.com/gpu: 4 livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/falcon-40b/falcon-40b.yaml b/presets/test/manifests/falcon-40b/falcon-40b.yaml index a4cb2d524..a1c11af0e 100644 --- a/presets/test/manifests/falcon-40b/falcon-40b.yaml +++ b/presets/test/manifests/falcon-40b/falcon-40b.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 4 # Requesting 4 GPUs @@ -27,13 +27,13 @@ spec: nvidia.com/gpu: 4 livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml 
b/presets/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml index e160dfc7a..c48a1c2cf 100644 --- a/presets/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml +++ b/presets/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml @@ -30,7 +30,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 @@ -38,13 +38,13 @@ spec: nvidia.com/gpu: 2 # Requesting 2 GPUs livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct.yaml b/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct.yaml index 93b37444e..cbf7f6f7f 100644 --- a/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct.yaml +++ b/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 2 # Requesting 2 GPUs livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml b/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml index 3f2212ab6..349a377a0 100644 --- a/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml +++ b/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml @@ -29,7 +29,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/test/manifests/falcon-7b/falcon-7b.yaml b/presets/test/manifests/falcon-7b/falcon-7b.yaml index ed86043e7..f985124ea 100644 --- a/presets/test/manifests/falcon-7b/falcon-7b.yaml +++ b/presets/test/manifests/falcon-7b/falcon-7b.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 2 # Requesting 2 GPUs livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 
readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/llama-2-13b-chat/llama-2-13b-chat.yaml b/presets/test/manifests/llama-2-13b-chat/llama-2-13b-chat.yaml index 973f6d238..61a309821 100644 --- a/presets/test/manifests/llama-2-13b-chat/llama-2-13b-chat.yaml +++ b/presets/test/manifests/llama-2-13b-chat/llama-2-13b-chat.yaml @@ -35,7 +35,7 @@ spec: - | echo "MASTER_ADDR: $MASTER_ADDR" NODE_RANK=$(echo $HOSTNAME | grep -o '[^-]*$') - cd /workspace/llama/llama-2 && torchrun --nnodes 2 --nproc_per_node 1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port 29500 inference_api.py + cd /workspace/llama/llama-2 && torchrun --nnodes 2 --nproc_per_node 1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port 29500 /workspace/tfs/inference_api.py resources: limits: nvidia.com/gpu: "1" @@ -43,13 +43,13 @@ spec: nvidia.com/gpu: "1" livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/llama-2-13b/llama-2-13b.yaml b/presets/test/manifests/llama-2-13b/llama-2-13b.yaml index 46c609bbb..daff4cd0a 100644 --- a/presets/test/manifests/llama-2-13b/llama-2-13b.yaml +++ b/presets/test/manifests/llama-2-13b/llama-2-13b.yaml @@ -35,7 +35,7 @@ spec: - | echo "MASTER_ADDR: $MASTER_ADDR" NODE_RANK=$(echo $HOSTNAME | grep -o '[^-]*$') - cd /workspace/llama/llama-2 && torchrun --nnodes 2 --nproc_per_node 1 --node_rank $NODE_RANK --master-addr $MASTER_ADDR --master-port 29500 inference_api.py + cd /workspace/llama/llama-2 && torchrun --nnodes 2 --nproc_per_node 1 --node_rank $NODE_RANK --master-addr $MASTER_ADDR --master-port 29500 /workspace/tfs/inference_api.py resources: limits: nvidia.com/gpu: "1" @@ -43,13 +43,13 @@ spec: nvidia.com/gpu: "1" livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/llama-2-7b-chat/llama-2-7b-chat.yaml b/presets/test/manifests/llama-2-7b-chat/llama-2-7b-chat.yaml index f26b003a8..61ec695dc 100644 --- a/presets/test/manifests/llama-2-7b-chat/llama-2-7b-chat.yaml +++ b/presets/test/manifests/llama-2-7b-chat/llama-2-7b-chat.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - cd /workspace/llama/llama-2 && torchrun inference_api.py + - cd /workspace/llama/llama-2 && torchrun /workspace/tfs/inference_api.py resources: limits: nvidia.com/gpu: "1" @@ -27,13 +27,13 @@ spec: nvidia.com/gpu: "1" livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/llama-2-7b/llama-2-7b.yaml b/presets/test/manifests/llama-2-7b/llama-2-7b.yaml index f68d43c64..af295b8db 100644 --- a/presets/test/manifests/llama-2-7b/llama-2-7b.yaml +++ b/presets/test/manifests/llama-2-7b/llama-2-7b.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - cd /workspace/llama/llama-2 && torchrun inference_api.py + - cd /workspace/llama/llama-2 && torchrun /workspace/tfs/inference_api.py resources: limits: nvidia.com/gpu: "1" @@ -27,13 +27,13 @@ spec: nvidia.com/gpu: "1" livenessProbe: 
httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/mistral-7b-instruct/mistral-7b-instruct.yaml b/presets/test/manifests/mistral-7b-instruct/mistral-7b-instruct.yaml index 35fad823a..a64780db9 100644 --- a/presets/test/manifests/mistral-7b-instruct/mistral-7b-instruct.yaml +++ b/presets/test/manifests/mistral-7b-instruct/mistral-7b-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 2 # Requesting 2 GPUs livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/mistral-7b/mistral-7b.yaml b/presets/test/manifests/mistral-7b/mistral-7b.yaml index 5521ef2f8..219f42ff5 100644 --- a/presets/test/manifests/mistral-7b/mistral-7b.yaml +++ b/presets/test/manifests/mistral-7b/mistral-7b.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 2 # Requesting 2 GPUs livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-2/phi-2.yaml b/presets/test/manifests/phi-2/phi-2.yaml index edd7de0a1..cbc6f94e7 100644 --- a/presets/test/manifests/phi-2/phi-2.yaml +++ b/presets/test/manifests/phi-2/phi-2.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct.yaml b/presets/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct.yaml index a978521ea..0adb122e4 100644 --- a/presets/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct.yaml +++ b/presets/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct.yaml @@ -18,7 +18,7 
@@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct.yaml b/presets/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct.yaml index f3dc2e158..1d0d64e47 100644 --- a/presets/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct.yaml +++ b/presets/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct.yaml b/presets/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct.yaml index 14f6e39b2..cf8898015 100644 --- a/presets/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct.yaml +++ b/presets/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct.yaml b/presets/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct.yaml index 9e7f74e2b..1d7069a38 100644 --- a/presets/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct.yaml +++ b/presets/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch 
--num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml b/presets/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml index ee74638ca..1827155f4 100644 --- a/presets/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml +++ b/presets/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10 diff --git a/presets/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml b/presets/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml index d05c51337..1f515cc6a 100644 --- a/presets/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml +++ b/presets/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 @@ -26,13 +26,13 @@ spec: nvidia.com/gpu: 1 # Requesting 1 GPU livenessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 600 # 10 Min periodSeconds: 10 readinessProbe: httpGet: - path: /healthz + path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 10
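
Note for reviewers: the adaptive `max_model_len` behavior added in `presets/inference/vllm/inference_api.py` reduces to one rule: the KV cache can hold at most `block_size * num_gpu_blocks` tokens, so the model's configured context window is clamped to that budget (and kept as-is when the budget is larger). Below is a minimal, self-contained sketch of that clamping rule, written only from what the patch itself shows; the function name `clamp_max_model_len` and the example numbers are illustrative and not part of the vLLM or KAITO APIs.

```python
# Illustrative sketch only: restates the clamping rule used by inference_api.py above.
# In the real code, num_gpu_blocks comes from executor.determine_num_available_blocks()
# and block_size from the vLLM cache config; the names and numbers here are assumptions.

def clamp_max_model_len(num_gpu_blocks: int, block_size: int,
                        configured_max_model_len: int) -> int:
    """Return the max_model_len the server should start with.

    The KV cache holds at most block_size * num_gpu_blocks tokens, so a model whose
    configured context window exceeds that budget is served with a smaller
    max_model_len instead of failing at engine start-up.
    """
    available_seq_len = block_size * num_gpu_blocks
    if available_seq_len <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization`.")
    return min(available_seq_len, configured_max_model_len)


if __name__ == "__main__":
    # Example: 1,000 free GPU blocks of 16 tokens each leave room for 16,000 tokens,
    # so a 128k-context model (131072) would be served with max_model_len = 16000.
    print(clamp_max_model_len(num_gpu_blocks=1000, block_size=16,
                              configured_max_model_len=131072))  # -> 16000
```

When the profiled budget is larger than the configured window, the configured value wins, which matches the `if available_seq_len > max_model_len` branch in the patch.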