feat: support LoRA adapters for vllm runtime #774

Merged: 1 commit merged on Dec 11, 2024
57 changes: 40 additions & 17 deletions presets/workspace/inference/vllm/inference_api.py
@@ -3,12 +3,13 @@
 import logging
 import gc
 import os
-from typing import Callable
+from typing import Callable, Optional, List

 import uvloop
 import torch
 from vllm.utils import FlexibleArgumentParser
 import vllm.entrypoints.openai.api_server as api_server
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 from vllm.engine.llm_engine import (LLMEngine, EngineArgs, EngineConfig)
 from vllm.executor.executor_base import ExecutorBase

@@ -26,12 +27,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     port = 5000 + local_rank  # Adjust port based on local rank

     server_default_args = {
-        "disable-frontend-multiprocessing": False,
-        "port": port
+        "disable_frontend_multiprocessing": False,
+        "port": port,
     }
     parser.set_defaults(**server_default_args)

-    # See https://docs.vllm.ai/en/latest/models/engine_args.html for more args
+    # See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
     engine_default_args = {
         "model": "/workspace/vllm/weights",
         "cpu_offload_gb": 0,
@@ -42,9 +43,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     }
     parser.set_defaults(**engine_default_args)

+    # KAITO only args
+    # They should start with "kaito-" prefix to avoid conflict with vllm args
+    parser.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
+
     return parser

-def find_max_available_seq_len(engine_config: EngineConfig) -> int:
+def load_lora_adapters(adapters_dir: str) -> Optional[LoRAModulePath]:
+    lora_list: List[LoRAModulePath] = []
+
+    logger.info(f"Loading LoRA adapters from {adapters_dir}")
+    if not os.path.exists(adapters_dir):
+        return lora_list
+
+    for adapter in os.listdir(adapters_dir):
+        adapter_path = os.path.join(adapters_dir, adapter)
+        if os.path.isdir(adapter_path):
+            lora_list.append(LoRAModulePath(adapter, adapter_path))
+
+    return lora_list
+
+def find_max_available_seq_len(engine_config: EngineConfig, max_probe_steps: int) -> int:
     """
     Load model and run profiler to find max available seq len.
     """
@@ -63,13 +82,6 @@ def find_max_available_seq_len(engine_config: EngineConfig) -> int:
         observability_config=engine_config.observability_config,
     )

-    max_probe_steps = 6
-    if os.getenv("MAX_PROBE_STEPS") is not None:
-        try:
-            max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
-        except ValueError:
-            raise ValueError("MAX_PROBE_STEPS must be an integer.")
-
     model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size)
     res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x))

@@ -131,23 +143,34 @@ def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
     parser = make_arg_parser(parser)
     args = parser.parse_args()

+    # set LoRA adapters
+    if args.lora_modules is None:
+        args.lora_modules = load_lora_adapters(args.kaito_adapters_dir)
+
     if args.max_model_len is None:
+        max_probe_steps = 6
+        if os.getenv("MAX_PROBE_STEPS") is not None:
+            try:
+                max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
+            except ValueError:
+                raise ValueError("MAX_PROBE_STEPS must be an integer.")
+
         engine_args = EngineArgs.from_cli_args(args)
         # read the model config from hf weights path.
         # vllm will perform different parser for different model architectures
         # and read it into a unified EngineConfig.
         engine_config = engine_args.create_engine_config()

-        logger.info("Try run profiler to find max available seq len")
-        available_seq_len = find_max_available_seq_len(engine_config)
+        max_model_len = engine_config.model_config.max_model_len
+        available_seq_len = max_model_len
+        if max_probe_steps > 0:
+            logger.info("Try run profiler to find max available seq len")
+            available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
         # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
         if available_seq_len <= 0:
             raise ValueError("No available memory for the cache blocks. "
                              "Try increasing `gpu_memory_utilization` when "
                              "initializing the engine.")
-        max_model_len = engine_config.model_config.max_model_len
-        if available_seq_len > max_model_len:
-            available_seq_len = max_model_len

         if available_seq_len != max_model_len:
             logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
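
Once registered, each adapter is addressable through the OpenAI-compatible endpoints by using its folder name as the model id (the adapter name and prompt below are placeholders; the request shape follows vLLM's standard completions API):

    # Hypothetical client call against a server started with the defaults above.
    import requests

    resp = requests.post(
        "http://localhost:5000/v1/completions",   # port 5000 for local rank 0
        json={
            "model": "mylora1",      # adapter folder name doubles as the model id
            "prompt": "Say hello",
            "max_tokens": 16,
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["text"])
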
@@ -14,30 +14,42 @@
 sys.path.append(parent_dir)

 from inference_api import binary_search_with_limited_steps
+from huggingface_hub import snapshot_download
+import shutil

 TEST_MODEL = "facebook/opt-125m"
+TEST_ADAPTER_NAME1 = "mylora1"
+TEST_ADAPTER_NAME2 = "mylora2"
 CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
                  "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
                  "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")

-@pytest.fixture
-def setup_server(request):
+@pytest.fixture(scope="session", autouse=True)
+def setup_server(request, tmp_path_factory, autouse=True):
     if os.getenv("DEVICE") == "cpu":
         pytest.skip("Skipping test on cpu device")
     print("\n>>> Doing setup")
     port = find_available_port()
     global TEST_PORT
     TEST_PORT = port

+    tmp_file_dir = tmp_path_factory.mktemp("adapter")
+    print(f"Downloading adapter image to {tmp_file_dir}")
+    snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME1))
+    snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME2))
+
     args = [
         "python3",
         os.path.join(parent_dir, "inference_api.py"),
         "--model", TEST_MODEL,
         "--chat-template", CHAT_TEMPLATE,
-        "--port", str(TEST_PORT)
+        "--port", str(TEST_PORT),
+        "--kaito-adapters-dir", tmp_file_dir,
     ]
     print(f">>> Starting server on port {TEST_PORT}")
-    process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    env = os.environ.copy()
+    env["MAX_PROBE_STEPS"] = "0"
+    process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

     def fin():
         process.terminate()
@@ -47,6 +59,8 @@ def fin():
         stdout = process.stdout.read().decode()
         print(f">>> Server stdout: {stdout}")
         print ("\n>>> Doing teardown")
+        print(f"Removing adapter image from {tmp_file_dir}")
+        shutil.rmtree(tmp_file_dir)

     if not is_port_open("localhost", TEST_PORT):
         fin()
@@ -115,6 +129,17 @@ def test_chat_completions_api(setup_server):
         assert "content" in choice["message"], "Each message should contain a 'content' key"
         assert len(choice["message"]["content"]) > 0, "The completion text should not be empty"

+def test_model_list(setup_server):
+    response = requests.get(f"http://127.0.0.1:{TEST_PORT}/v1/models")
+    data = response.json()
+
+    assert "data" in data, f"The response should contain a 'data' key, but got {data}"
+    assert len(data["data"]) == 3, f"The response should contain three models, but got {data['data']}"
+    assert data["data"][0]["id"] == TEST_MODEL, f"The first model should be the test model, but got {data['data'][0]['id']}"
+    assert data["data"][1]["id"] == TEST_ADAPTER_NAME2, f"The second model should be the test adapter, but got {data['data'][1]['id']}"
+    assert data["data"][1]["parent"] == TEST_MODEL, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
+    assert data["data"][2]["id"] == TEST_ADAPTER_NAME1, f"The third model should be the test adapter, but got {data['data'][2]['id']}"
+    assert data["data"][2]["parent"] == TEST_MODEL, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"

 def test_binary_search_with_limited_steps():

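
For reference, the listing that test_model_list asserts can be reproduced against a running server: the base model is returned first, and each LoRA adapter reports the base model as its parent. A minimal sketch (port and adapter names are placeholders):

    # Hypothetical check mirroring test_model_list.
    import requests

    models = requests.get("http://localhost:5000/v1/models", timeout=60).json()["data"]
    for m in models:
        print(m["id"], "parent:", m.get("parent"))
    # Expected: the base model first, then each adapter with the base model as its parent.
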