From 60f464e9bbd9d266ab0c5fca587bc1537e7ec243 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 20:26:26 -0700
Subject: [PATCH] fix: Update default params and add associated UTs (#294)

The default parameters for text generation are documented here:
https://huggingface.co/docs/transformers/en/main_classes/text_generation

With the growing list of params, I shortened the list here and added UTs.

---------

Signed-off-by: Ishaan Sehgal
---
 .../text-generation/inference_api.py          |  20 +---
 .../tests/test_inference_api.py               | 107 ++++++++++++++++++
 2 files changed, 111 insertions(+), 16 deletions(-)

diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py
index 73c7b5095..bf739844d 100644
--- a/presets/inference/text-generation/inference_api.py
+++ b/presets/inference/text-generation/inference_api.py
@@ -123,30 +123,18 @@ def health_check():
     return {"status": "Healthy"}
 
 class GenerateKwargs(BaseModel):
-    max_length: int = 200
+    max_length: int = 200  # Length of input prompt+max_new_tokens
     min_length: int = 0
-    do_sample: bool = True
+    do_sample: bool = False
     early_stopping: bool = False
     num_beams: int = 1
-    num_beam_groups: int = 1
-    diversity_penalty: float = 0.0
     temperature: float = 1.0
-    top_k: int = 10
+    top_k: int = 50
     top_p: float = 1
     typical_p: float = 1
     repetition_penalty: float = 1
-    length_penalty: float = 1
-    no_repeat_ngram_size: int = 0
-    encoder_no_repeat_ngram_size: int = 0
-    bad_words_ids: Optional[List[int]] = None
-    num_return_sequences: int = 1
-    output_scores: bool = False
-    return_dict_in_generate: bool = False
     pad_token_id: Optional[int] = tokenizer.pad_token_id
     eos_token_id: Optional[int] = tokenizer.eos_token_id
-    forced_bos_token_id: Optional[int] = None
-    forced_eos_token_id: Optional[int] = None
-    remove_invalid_values: Optional[bool] = None
 
     class Config:
         extra = Extra.allow  # Allows for additional fields not explicitly defined
@@ -157,7 +145,7 @@ class UnifiedRequestModel(BaseModel):
     clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
     prefix: Optional[str] = Field(None, description="Prefix added to prompt")
     handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
-    generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method")
+    generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")
 
     # Field for conversational model
     messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")
diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index ff9866bbb..d6506b08b 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from fastapi.testclient import TestClient
+from transformers import AutoTokenizer
 
 # Get the parent directory of the current file
 parent_dir = str(Path(__file__).resolve().parent.parent)
@@ -127,3 +128,109 @@ def test_get_metrics_no_gpus(configured_app):
     response = client.get("/metrics")
     assert response.status_code == 200
     assert response.json()["gpu_info"] == []
+
+def test_default_generation_params(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+
+    request_data = {
+        "prompt": "Test default params",
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False
+        # Note: generate_kwargs is not provided, so defaults should be used
+    }
+
+    with patch('inference_api.pipeline') as mock_pipeline:
+        mock_pipeline.return_value = [{"generated_text": "Mocked response"}]  # Mock the response of the pipeline function
+
+        response = client.post("/chat", json=request_data)
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "Result" in data
+        assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response"
+
+        # Check the default args
+        _, kwargs = mock_pipeline.call_args
+        assert kwargs['max_length'] == 200
+        assert kwargs['min_length'] == 0
+        assert kwargs['do_sample'] is False
+        assert kwargs['temperature'] == 1.0
+        assert kwargs['top_k'] == 50
+        assert kwargs['top_p'] == 1
+        assert kwargs['typical_p'] == 1
+        assert kwargs['repetition_penalty'] == 1
+        assert kwargs['num_beams'] == 1
+        assert kwargs['early_stopping'] is False
+
+def test_generation_with_max_length(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+    prompt = "This prompt requests a response of a certain minimum length to test the functionality."
+    avg_res_len = 15
+    max_length = 40  # Set to lower than default (200) to prevent test hanging
+
+    request_data = {
+        "prompt": prompt,
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False,
+        "generate_kwargs": {"max_length": max_length}
+    }
+
+    response = client.post("/chat", json=request_data)
+
+    assert response.status_code == 200
+    data = response.json()
+    print("Response: ", data["Result"])
+    assert "Result" in data, "The response should contain a 'Result' key"
+
+    tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
+    prompt_tokens = tokenizer.tokenize(prompt)
+    total_tokens = tokenizer.tokenize(data["Result"])  # data["Result"] includes the input prompt
+
+    prompt_tokens_len = len(prompt_tokens)
+    max_new_tokens = max_length - prompt_tokens_len
+    new_tokens = len(total_tokens) - prompt_tokens_len
+
+    assert avg_res_len <= new_tokens, f"Ideally, the response should generate at least {avg_res_len} tokens"
+    assert new_tokens <= max_new_tokens, "Response must not generate more than max new tokens"
+    assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"
+
+def test_generation_with_min_length(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+    prompt = "This prompt requests a response of a certain minimum length to test the functionality."
+    min_length = 30
+    max_length = 40
+
+    request_data = {
+        "prompt": prompt,
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False,
+        "generate_kwargs": {"min_length": min_length, "max_length": max_length}
+    }
+
+    response = client.post("/chat", json=request_data)
+
+    assert response.status_code == 200
+    data = response.json()
+    assert "Result" in data, "The response should contain a 'Result' key"
+
+    tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
+    prompt_tokens = tokenizer.tokenize(prompt)
+    total_tokens = tokenizer.tokenize(data["Result"])  # data["Result"] includes the input prompt
+
+    prompt_tokens_len = len(prompt_tokens)
+
+    min_new_tokens = min_length - prompt_tokens_len
+    max_new_tokens = max_length - prompt_tokens_len
+    new_tokens = len(total_tokens) - prompt_tokens_len
+
+    assert min_new_tokens <= new_tokens <= max_new_tokens, "Response should generate at least min_new_tokens and at most max_new_tokens new tokens"
+    assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"
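
Reviewer note (not part of the patch): a minimal usage sketch of how a caller can override
the new GenerateKwargs defaults per request. The /chat endpoint, the request fields, and the
"Result" response key come from the diff above; the host/port and the use of the `requests`
library are assumptions for illustration only.

    # Hypothetical client call against a running text-generation preset server.
    import requests

    payload = {
        "prompt": "Write a short poem about Kubernetes.",
        "return_full_text": True,
        "clean_up_tokenization_spaces": False,
        # Any field omitted from generate_kwargs falls back to the new defaults
        # (do_sample=False, top_k=50, max_length=200, ...).
        "generate_kwargs": {"max_length": 80, "do_sample": True, "temperature": 0.7},
    }

    # The URL is assumed; point it at wherever the inference server is exposed.
    response = requests.post("http://localhost:5000/chat", json=payload)
    print(response.json()["Result"])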