Accept token_ids in Shortfin LLM Server (#862)
The MLPerf harness that I'm building sends input_ids instead of text. The request schema already had an input_ids field, but we never actually implemented handling for it. With this change, a request can carry either a string prompt or a pre-tokenized prompt.
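For context, here is a minimal client-side sketch (not part of this commit) of the two request shapes the `/generate` endpoint now accepts, mirroring the payload built in the test helper below. The port number and the token IDs are placeholders.

```python
import uuid

import requests

PORT = 8000  # placeholder; use the port your shortfin LLM server is listening on


def generate(prompt, port: int = PORT, pretokenized: bool = False) -> str:
    """POST a generation request, using `input_ids` for pre-tokenized prompts."""
    payload = {
        "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
        "rid": uuid.uuid4().hex,
        "stream": False,
    }
    if pretokenized:
        payload["input_ids"] = prompt  # list[int] produced by the model's tokenizer
    else:
        payload["text"] = prompt  # raw string; the server tokenizes it
    response = requests.post(
        f"http://localhost:{port}/generate",
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=30,
    )
    response.raise_for_status()
    return response.text


# Existing path: send a text prompt.
generate("1 2 3 4 5 ")

# New path: send a pre-tokenized prompt (the IDs here are illustrative only).
generate([100, 101, 102], pretokenized=True)
```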
stbaione authored Jan 24, 2025
1 parent 782ec5b commit 06bf638
Showing 3 changed files with 87 additions and 22 deletions.
16 changes: 13 additions & 3 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -1,7 +1,9 @@
"""Test fixtures and configurations."""

import hashlib
import pytest
from pathlib import Path
import hashlib
from tokenizers import Tokenizer, Encoding

from ..model_management import (
ModelProcessor,
@@ -25,8 +27,10 @@
),
"llama3.1_8b": ModelConfig(
source=ModelSource.LOCAL,
local_path=Path("/data/llama3.1/8b/llama8b_f16.irpa"),
model_file="llama8b_f16.irpa",
local_path=Path(
"/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa"
),
model_file="llama3.1_8b_fp16_instruct.irpa",
tokenizer_id="NousResearch/Meta-Llama-3.1-8B",
batch_sizes=(1, 4),
device_settings=device_settings.CPU,
@@ -89,3 +93,9 @@ def server(model_artifacts, request):

process.terminate()
process.wait()


@pytest.fixture(scope="module")
def encoded_prompt(model_artifacts: ModelArtifacts, request) -> list[int]:
tokenizer = Tokenizer.from_file(str(model_artifacts.tokenizer_path))
return tokenizer.encode(request.param).ids
82 changes: 66 additions & 16 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -1,11 +1,11 @@
"""Main test module for LLM server functionality."""

from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import pytest
import requests
import uuid
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any
import uuid

logger = logging.getLogger(__name__)

@@ -40,16 +40,10 @@ class TestLLMServer:
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
marks=pytest.mark.xfail(
reason="llama3.1_8b irpa file not available on CI machine"
),
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
marks=pytest.mark.xfail(
reason="llama3.1_8b irpa file not available on CI machine"
),
),
],
ids=[
@@ -78,6 +72,58 @@ def test_basic_generation(self, server: tuple[Any, int]) -> None:
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server,encoded_prompt",
[
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
),
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
),
],
ids=[
"open_llama_3b_none_input_ids",
"open_llama_3b_trie_input_ids",
"llama31_8b_none_input_ids",
"llama31_8b_trie_input_ids",
],
indirect=True,
)
def test_basic_generation_input_ids(
self, server: tuple[Any, int], encoded_prompt
) -> None:
"""Tests basic text generation capabilities.
Args:
server: Tuple of (process, port) from server fixture
"""
process, port = server
assert process.poll() is None, "Server process terminated unexpectedly"

response = self._generate(encoded_prompt, port, input_ids=True)
expected_prefix = "6 7 8"
if not response.startswith(expected_prefix):
raise AccuracyValidationException(
expected=f"{expected_prefix}...",
actual=response,
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server",
[
@@ -121,7 +167,7 @@ def test_concurrent_generation(
message=f"Concurrent generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

def _generate(self, prompt: str, port: int) -> str:
def _generate(self, prompt: str | list[int], port: int, input_ids=False) -> str:
"""Helper method to make generation request to server.
Args:
@@ -135,15 +181,19 @@ def _generate(self, prompt: str) -> str:
requests.exceptions.RequestException: If request fails
AccuracyValidationException: If response format is invalid
"""
payload = {
"sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"stream": False,
}
if input_ids:
payload["input_ids"] = prompt
else:
payload["text"] = prompt
response = requests.post(
f"http://localhost:{port}/generate",
headers={"Content-Type": "application/json"},
json={
"text": prompt,
"sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"stream": False,
},
json=payload,
timeout=30, # Add reasonable timeout
)
response.raise_for_status()
11 changes: 8 additions & 3 deletions shortfin/python/shortfin_apps/llm/components/generate.py
@@ -125,14 +125,19 @@ async def run(self):
try:
# Launch all individual generate processes and wait for them to finish.
gen_processes = []
input_ids = self.gen_req.input_ids
is_pretokenized = input_ids is not None
# TODO: We should send this to an executor and await the results.
input_batch = self.tokenize()
if is_pretokenized:
input_batch = [input_ids] if self.gen_req.is_single else input_ids
else:
input_batch = self.tokenize()
for index, input_tokens in enumerate(input_batch):
gen_process = GenerateItemProcess(
self,
self.gen_req,
index,
input_tokens.ids,
input_tokens if is_pretokenized else input_tokens.ids,
max_completion_tokens=self.gen_req.sampling_params[
"max_completion_tokens"
],
@@ -184,4 +189,4 @@ def tokenize(self) -> list[Encoding]:
logger.debug("Generated encodings: %r", encodings)
return encodings
else:
raise NotImplementedError("GenerateReqInput.input_ids handling NYI")
raise ValueError("Cannot tokenize 'None' value")
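For reference, here is a condensed, self-contained restatement of the branch above (a sketch, not the shortfin source): `build_input_batch`, its parameters, and the `tokenize` callback stand in for `self.gen_req.input_ids`, `self.gen_req.is_single`, `self.gen_req.text`, and the process's `tokenize()` method in the diff.

```python
def build_input_batch(input_ids, text, is_single: bool, tokenize):
    """Return a list of token-id sequences, one per prompt in the request."""
    if input_ids is not None:
        # Pre-tokenized path: a single request carries list[int], a batched one
        # list[list[int]]; wrap the single case so downstream code always sees a batch.
        return [input_ids] if is_single else input_ids
    if text is None:
        raise ValueError("Cannot tokenize 'None' value")
    # Text path: `tokenize` is assumed to return objects with an `.ids` attribute
    # (e.g. tokenizers.Encoding), from which the token-id lists are extracted.
    return [encoding.ids for encoding in tokenize(text)]
```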
