Skip to content

Commit

Permalink
fix: Update default params and add associated UTs (#294)
Browse files Browse the repository at this point in the history
Default parameters for text generation are here:
https://huggingface.co/docs/transformers/en/main_classes/text_generation

With growing list of params, I shorten the list here and add UT.

---------

Signed-off-by: Ishaan Sehgal <[email protected]>
  • Loading branch information
ishaansehgal99 authored Mar 13, 2024
1 parent 62340b3 commit 60f464e
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 16 deletions.
20 changes: 4 additions & 16 deletions presets/inference/text-generation/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,30 +123,18 @@ def health_check():
return {"status": "Healthy"}

class GenerateKwargs(BaseModel):
max_length: int = 200
max_length: int = 200 # Length of input prompt+max_new_tokens
min_length: int = 0
do_sample: bool = True
do_sample: bool = False
early_stopping: bool = False
num_beams: int = 1
num_beam_groups: int = 1
diversity_penalty: float = 0.0
temperature: float = 1.0
top_k: int = 10
top_k: int = 50
top_p: float = 1
typical_p: float = 1
repetition_penalty: float = 1
length_penalty: float = 1
no_repeat_ngram_size: int = 0
encoder_no_repeat_ngram_size: int = 0
bad_words_ids: Optional[List[int]] = None
num_return_sequences: int = 1
output_scores: bool = False
return_dict_in_generate: bool = False
pad_token_id: Optional[int] = tokenizer.pad_token_id
eos_token_id: Optional[int] = tokenizer.eos_token_id
forced_bos_token_id: Optional[int] = None
forced_eos_token_id: Optional[int] = None
remove_invalid_values: Optional[bool] = None
class Config:
extra = Extra.allow # Allows for additional fields not explicitly defined

Expand All @@ -157,7 +145,7 @@ class UnifiedRequestModel(BaseModel):
clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
prefix: Optional[str] = Field(None, description="Prefix added to prompt")
handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method")
generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")

# Field for conversational model
messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")
Expand Down
107 changes: 107 additions & 0 deletions presets/inference/text-generation/tests/test_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch
from fastapi.testclient import TestClient
from transformers import AutoTokenizer

# Get the parent directory of the current file
parent_dir = str(Path(__file__).resolve().parent.parent)
Expand Down Expand Up @@ -127,3 +128,109 @@ def test_get_metrics_no_gpus(configured_app):
response = client.get("/metrics")
assert response.status_code == 200
assert response.json()["gpu_info"] == []

def test_default_generation_params(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)

request_data = {
"prompt": "Test default params",
"return_full_text": True,
"clean_up_tokenization_spaces": False
# Note: generate_kwargs is not provided, so defaults should be used
}

with patch('inference_api.pipeline') as mock_pipeline:
mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
assert "Result" in data
assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response"

# Check the default args
_, kwargs = mock_pipeline.call_args
assert kwargs['max_length'] == 200
assert kwargs['min_length'] == 0
assert kwargs['do_sample'] is False
assert kwargs['temperature'] == 1.0
assert kwargs['top_k'] == 50
assert kwargs['top_p'] == 1
assert kwargs['typical_p'] == 1
assert kwargs['repetition_penalty'] == 1
assert kwargs['num_beams'] == 1
assert kwargs['early_stopping'] is False

def test_generation_with_max_length(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)
prompt = "This prompt requests a response of a certain minimum length to test the functionality."
avg_res_len = 15
max_length = 40 # Set to lower than default (200) to prevent test hanging

request_data = {
"prompt": prompt,
"return_full_text": True,
"clean_up_tokenization_spaces": False,
"generate_kwargs": {"max_length": max_length}
}

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
print("Response: ", data["Result"])
assert "Result" in data, "The response should contain a 'Result' key"

tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
prompt_tokens = tokenizer.tokenize(prompt)
total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt

prompt_tokens_len = len(prompt_tokens)
max_new_tokens = max_length - prompt_tokens_len
new_tokens = len(total_tokens) - prompt_tokens_len

assert avg_res_len <= new_tokens, f"Ideally response should generate at least 15 tokens"
assert new_tokens <= max_new_tokens, "Response must not generate more than max new tokens"
assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"

def test_generation_with_min_length(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)
prompt = "This prompt requests a response of a certain minimum length to test the functionality."
min_length = 30
max_length = 40

request_data = {
"prompt": prompt,
"return_full_text": True,
"clean_up_tokenization_spaces": False,
"generate_kwargs": {"min_length": min_length, "max_length": max_length}
}

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
assert "Result" in data, "The response should contain a 'Result' key"

tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
prompt_tokens = tokenizer.tokenize(prompt)
total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt

prompt_tokens_len = len(prompt_tokens)

min_new_tokens = min_length - prompt_tokens_len
max_new_tokens = max_length - prompt_tokens_len
new_tokens = len(total_tokens) - prompt_tokens_len

assert min_new_tokens <= new_tokens <= max_new_tokens, "Response should generate at least min_new_tokens and at most max_new_tokens new tokens"
assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"

0 comments on commit 60f464e

Please sign in to comment.