diff --git a/tests/test_random_prompt.py b/tests/test_random_prompt.py new file mode 100644 index 0000000..5080db5 --- /dev/null +++ b/tests/test_random_prompt.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/1/27 上午10:41 +# @Author : sudoskys +# @File : test_random_prompt.py +# @Software: PyCharm + +from src.novelai_python.utils.random_prompt import RandomPromptGenerator + + +def test_generate_returns_string(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.generate() + assert isinstance(result, str) + + +def test_generate_returns_non_empty_string(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.generate() + assert len(result) > 0 + + +def test_generate_returns_different_results(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result1 = generator.generate() + result2 = generator.generate() + assert result1 != result2 + + +def test_generate_with_nsfw_disabled(): + generator = RandomPromptGenerator(nsfw_enabled=False) + result = generator.generate() + assert 'nsfw' not in result + + +def test_generate_with_nsfw_enabled(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.generate() + assert 'nsfw' in result + + +def test_get_weighted_choice_returns_string(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.get_weighted_choice([['tag1', 1], ['tag2', 2]], []) + assert isinstance(result, str) + + +def test_get_weighted_choice_returns_valid_tag(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.get_weighted_choice([['tag1', 1], ['tag2', 2]], []) + assert result in ['tag1', 'tag2'] + + +def test_character_features_returns_list(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.character_features('m', 'front', True, 1) + assert isinstance(result, list) + + +def test_character_features_returns_non_empty_list(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result = generator.character_features('m', 'front', True, 1) + assert len(result) > 0 + + +def test_character_features_with_different_genders(): + generator = RandomPromptGenerator(nsfw_enabled=True) + result_m = generator.character_features('m', 'front', True, 1) + result_f = generator.character_features('f', 'front', True, 1) + result_o = generator.character_features('o', 'front', True, 1) + assert result_m != result_f != result_o diff --git a/tests/test_server_run.py b/tests/test_server_run.py new file mode 100644 index 0000000..dc6d885 --- /dev/null +++ b/tests/test_server_run.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/1/30 下午11:52 +# @Author : sudoskys +# @File : test_server_run.py +# @Software: PyCharm +from fastapi.testclient import TestClient + +from src.novelai_python.sdk.ai.generate_image import GenerateImageInfer +from src.novelai_python.server import app, get_session + +client = TestClient(app) + + +def test_health_check(): + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_generate_image_with_valid_token(): + valid_token = "valid_token" + get_session(valid_token) # to simulate a valid session + response = client.post( + "/ai/generate_image", + headers={"Authorization": valid_token}, + json=GenerateImageInfer(input="1girl").model_dump() + ) + assert response.status_code == 500