Skip to content

Commit

Permalink
make tests more idiomatic of pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
IndrajeetPatil committed Oct 18, 2024
1 parent 08ed475 commit 90048d2
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 126 deletions.
183 changes: 95 additions & 88 deletions server/chatgptserver/api/tests/test_azure_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,98 +42,105 @@ def mock_get_instance():
return mock_client


def create_mock_completion(content: str) -> object:
"""Helper function to create mock completion objects."""
MockMessage = type("MockMessage", (), {"content": content})
MockChoice = type("MockChoice", (), {"message": MockMessage()})
return type("MockCompletion", (), {"choices": [MockChoice()]})()


@pytest.mark.django_db
@override_settings(
    AZURE_OPENAI_ENDPOINT="https://test-endpoint.openai.azure.com/",
    AZURE_OPENAI_API_VERSION="2023-05-15",
    AZURE_OPENAI_API_KEY="test-api-key",
)
def test_singleton_instance() -> None:
    """Repeated lookups return one shared, correctly configured client."""
    first = AzureOpenAIClient.get_instance()
    second = AzureOpenAIClient.get_instance()

    # Singleton: both lookups must hand back the very same object.
    assert first is second
    assert isinstance(first, AzureOpenAI)
    # Configuration is picked up from the overridden Django settings.
    assert first.api_key == "test-api-key"
    assert first.max_retries == 5


@pytest.mark.django_db
class TestAzureOpenAIClient:
@override_settings(
AZURE_OPENAI_ENDPOINT="https://test-endpoint.openai.azure.com/",
AZURE_OPENAI_API_VERSION="2023-05-15",
AZURE_OPENAI_API_KEY="test-api-key",
@pytest.mark.parametrize(
("api_response", "expected_output"),
[
("Valid content", "Valid content"),
("", ""),
(None, ""),
],
)
def test_successful_response(
mock_azure_client: MockAzureClient,
api_response: str,
expected_output: str,
) -> None:
"""Test successful API responses with different content values."""
mock_azure_client.chat.completions.return_value = create_mock_completion(
api_response,
)
def test_singleton_instance(self) -> None:
instance1 = AzureOpenAIClient.get_instance()
instance2 = AzureOpenAIClient.get_instance()
assert instance1 is instance2
assert isinstance(instance1, AzureOpenAI)
assert instance1.api_key == "test-api-key"
assert instance1.max_retries == 5

response = get_azure_openai_response("Test prompt")

assert response == expected_output
assert len(mock_azure_client.chat.completions.create_calls) == 1
assert mock_azure_client.chat.completions.create_calls[0] == {
"model": AssistantModel.FULL.value,
"temperature": AssistantTemperature.BALANCED.value,
"messages": [{"role": "user", "content": "Test prompt"}],
}


@pytest.mark.django_db
def test_api_exception(mock_azure_client: MockAzureClient) -> None:
    """Test that API errors propagate to the caller.

    A bare ``pytest.raises(Exception)`` would pass on *any* exception,
    including unrelated bugs in the code under test (flake8-bugbear
    B017), so match the message to pin the failure to the simulated
    API error.
    """
    mock_azure_client.chat.completions.side_effect = Exception("API Error")

    with pytest.raises(Exception, match="API Error"):
        get_azure_openai_response("Test prompt")


@pytest.mark.django_db
class TestGetAzureOpenAIResponse:
@pytest.mark.parametrize(
("api_response", "expected_output"),
[
("Valid content", "Valid content"),
("", ""),
(None, ""),
],
def test_unexpected_response_format(mock_azure_client: MockAzureClient) -> None:
    """Test handling of unexpected response format."""
    # A completion whose ``choices`` list is empty mimics a malformed reply.
    empty_completion = type("MockCompletion", (), {"choices": []})()
    mock_azure_client.chat.completions.return_value = empty_completion

    result = get_azure_openai_response("Test prompt")

    # The client should degrade gracefully to an empty string.
    assert result == ""


@pytest.mark.django_db
@pytest.mark.parametrize(
("model", "temperature"),
list(product(AssistantModel, AssistantTemperature)),
)
def test_different_models_and_temperatures(
mock_azure_client: MockAzureClient,
model: AssistantModel,
temperature: AssistantTemperature,
) -> None:
"""Test API calls with different combinations of models and temperatures."""
mock_azure_client.chat.completions.return_value = create_mock_completion(
"Test response",
)
def test_successful_response(
self,
mock_azure_client: MockAzureClient,
api_response: str,
expected_output: str,
) -> None:
mock_azure_client.chat.completions.return_value = self._create_mock_completion(
api_response,
)

response = get_azure_openai_response("Test prompt")

assert response == expected_output
assert len(mock_azure_client.chat.completions.create_calls) == 1
assert mock_azure_client.chat.completions.create_calls[0] == {
"model": AssistantModel.FULL.value,
"temperature": AssistantTemperature.BALANCED.value,
"messages": [{"role": "user", "content": "Test prompt"}],
}

def test_api_exception(self, mock_azure_client: MockAzureClient) -> None:
mock_azure_client.chat.completions.side_effect = Exception("API Error")

with pytest.raises(Exception):
get_azure_openai_response("Test prompt")

def test_unexpected_response_format(
self,
mock_azure_client: MockAzureClient,
) -> None:
mock_completion = type("MockCompletion", (), {"choices": []})()
mock_azure_client.chat.completions.return_value = mock_completion

response = get_azure_openai_response("Test prompt")

assert response == ""

@pytest.mark.parametrize(
("model", "temperature"),
list(product(AssistantModel, AssistantTemperature)),

response = get_azure_openai_response(
"Test prompt",
model=model,
temperature=temperature,
)
def test_different_models_and_temperatures(
self,
mock_azure_client: MockAzureClient,
model: AssistantModel,
temperature: AssistantTemperature,
) -> None:
mock_azure_client.chat.completions.return_value = self._create_mock_completion(
"Test response",
)

response = get_azure_openai_response(
"Test prompt",
model=model,
temperature=temperature,
)

assert response == "Test response"
assert len(mock_azure_client.chat.completions.create_calls) == 1
assert mock_azure_client.chat.completions.create_calls[0] == {
"model": model.value,
"temperature": temperature.value,
"messages": [{"role": "user", "content": "Test prompt"}],
}

def _create_mock_completion(self, content: str) -> object:
MockMessage = type("MockMessage", (), {"content": content})
MockChoice = type("MockChoice", (), {"message": MockMessage()})
return type("MockCompletion", (), {"choices": [MockChoice()]})()

assert response == "Test response"
assert len(mock_azure_client.chat.completions.create_calls) == 1
assert mock_azure_client.chat.completions.create_calls[0] == {
"model": model.value,
"temperature": temperature.value,
"messages": [{"role": "user", "content": "Test prompt"}],
}
78 changes: 40 additions & 38 deletions server/chatgptserver/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,45 @@ def mock_azure_response(monkeypatch: MonkeyPatch) -> MockAzureResponse:


@pytest.mark.django_db
class TestChatView:
    """Request/response behaviour of the chat endpoint."""

    def test_post_chat_view_success(
        self,
        api_client: APIClient,
        chat_url: str,
        valid_payload: dict[str, str],
        mock_azure_response: MockAzureResponse,
    ) -> None:
        """A valid payload returns 200 and relays exactly one Azure call."""
        response = api_client.post(
            f"{chat_url}?temperature=BALANCED",
            valid_payload,
            format="json",
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data["response"] == MOCK_RESPONSE
        recorded = mock_azure_response.calls
        assert len(recorded) == 1
        assert recorded[0] == {
            "prompt": MOCK_PROMPT,
            "model": AssistantModel.FULL,
            "temperature": AssistantTemperature.BALANCED,
        }

    def test_post_chat_view_empty_payload(
        self,
        api_client: APIClient,
        chat_url: str,
    ) -> None:
        """An empty payload is rejected with 400 and a ``prompt`` error."""
        empty_body: dict[str, Any] = {}

        response = api_client.post(
            f"{chat_url}?temperature=BALANCED",
            empty_body,
            format="json",
        )

        assert response.status_code == status.HTTP_400_BAD_REQUEST
        assert "prompt" in response.data
def test_post_chat_view_success(
    api_client: APIClient,
    chat_url: str,
    valid_payload: dict[str, str],
    mock_azure_response: MockAzureResponse,
) -> None:
    """Test successful POST request to chat view with valid payload."""
    response = api_client.post(
        f"{chat_url}?temperature=BALANCED",
        valid_payload,
        format="json",
    )

    # The view should succeed and relay exactly one call to Azure.
    assert response.status_code == status.HTTP_200_OK
    assert response.data["response"] == MOCK_RESPONSE
    recorded = mock_azure_response.calls
    assert len(recorded) == 1
    assert recorded[0] == {
        "prompt": MOCK_PROMPT,
        "model": AssistantModel.FULL,
        "temperature": AssistantTemperature.BALANCED,
    }


@pytest.mark.django_db
def test_post_chat_view_empty_payload(
    api_client: APIClient,
    chat_url: str,
) -> None:
    """Test POST request to chat view with empty payload."""
    empty_body: dict[str, Any] = {}

    response = api_client.post(
        f"{chat_url}?temperature=BALANCED",
        empty_body,
        format="json",
    )

    # Serializer validation should flag the missing ``prompt`` field.
    assert response.status_code == status.HTTP_400_BAD_REQUEST
    assert "prompt" in response.data


@pytest.mark.django_db
Expand Down Expand Up @@ -140,6 +141,7 @@ def test_chat_view_parameters(
expected_status: int,
expected_errors: dict[str, str] | None,
) -> None:
"""Test chat view with different combinations of model and temperature parameters."""
url: str = reverse("chat", kwargs={"model": model})
response = api_client.post(
f"{url}?temperature={temperature}",
Expand Down

0 comments on commit 90048d2

Please sign in to comment.