diff --git a/flytekit/extras/webhook/agent.py b/flytekit/extras/webhook/agent.py index fa532ffdc0..b2a1ffca6b 100644 --- a/flytekit/extras/webhook/agent.py +++ b/flytekit/extras/webhook/agent.py @@ -1,7 +1,8 @@ +import asyncio import http from typing import Optional -import httpx +import aiohttp from flyteidl.core.execution_pb2 import TaskExecution from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase @@ -9,7 +10,6 @@ from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.utils.dict_formatter import format_dict - from .constants import BODY_KEY, HEADERS_KEY, METHOD_KEY, SHOW_BODY_KEY, SHOW_URL_KEY, TASK_TYPE, URL_KEY @@ -18,9 +18,17 @@ class WebhookAgent(SyncAgentBase): def __init__(self): super().__init__(task_type_name=TASK_TYPE) + self._loop = asyncio.get_running_loop() + self._loop.create_task(self._initialize_session()) + + def __del__(self): + self._loop.create_task(self._session.close()) + + async def _initialize_session(self): + self._session = aiohttp.ClientSession() - def do( - self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs + async def do( + self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs ) -> Resource: try: custom_dict = task_template.custom @@ -37,22 +45,17 @@ def do( show_body = final_dict.get(SHOW_BODY_KEY, False) show_url = final_dict.get(SHOW_URL_KEY, False) - async with httpx.AsyncClient() as client: - if method == http.HTTPMethod.GET: - response = await client.get(url, headers=headers) - else: - response = await client.post(url, data=body, headers=headers) if method == http.HTTPMethod.GET: - response = httpx.get(url, headers=headers) + response = await self._session.get(url, headers=headers) else: - response = httpx.post(url, data=body, headers=headers) - if response.status_code != 200: + response = self._session.post(url, data=body, headers=headers) + if response.status != 200: return Resource( phase=TaskExecution.FAILED, - message=f"Webhook failed with status code {response.status_code}, response: {response.text}", + message=f"Webhook failed with status code {response.status}, response: {response.text}", ) final_response = { - "status_code": response.status_code, + "status_code": response.status, "body": response.text, } if show_body: diff --git a/tests/flytekit/unit/extras/webhook/test_agent.py b/tests/flytekit/unit/extras/webhook/test_agent.py index fe359b185b..ebee64bf48 100644 --- a/tests/flytekit/unit/extras/webhook/test_agent.py +++ b/tests/flytekit/unit/extras/webhook/test_agent.py @@ -1,10 +1,11 @@ -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock import pytest + +from flytekit.extras.webhook.agent import WebhookAgent from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.extras.webhook.agent import WebhookAgent @pytest.fixture @@ -20,48 +21,57 @@ def mock_task_template(): } return task_template -@patch('flytekit.extras.webhook.agent.httpx') -def test_do_post_success(mock_httpx, mock_task_template): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = "Success" - mock_httpx.post.return_value = mock_response + +@pytest.fixture +def mock_aiohttp_session(): + with patch('aiohttp.ClientSession') as mock_session: + yield mock_session + + +@pytest.mark.asyncio +async def test_do_post_success(mock_task_template, mock_aiohttp_session): + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.text = AsyncMock(return_value="Success") + mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response) agent = WebhookAgent() - result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) + result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.SUCCEEDED assert result.outputs["status_code"] == 200 assert result.outputs["body"] == "Success" assert result.outputs["url"] == "http://example.com" -@patch('flytekit.extras.webhook.agent.httpx') -def test_do_get_success(mock_httpx, mock_task_template): + +@pytest.mark.asyncio +async def test_do_get_success(mock_task_template, mock_aiohttp_session): mock_task_template.custom["method"] = "GET" mock_task_template.custom.pop("body") mock_task_template.custom["show_body"] = False - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = "Success" - mock_httpx.get.return_value = mock_response + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.text = AsyncMock(return_value="Success") + mock_aiohttp_session.return_value.get = AsyncMock(return_value=mock_response) agent = WebhookAgent() - result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) + result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.SUCCEEDED assert result.outputs["status_code"] == 200 assert result.outputs["url"] == "http://example.com" -@patch('flytekit.extras.webhook.agent.httpx') -def test_do_failure(mock_httpx, mock_task_template): - mock_response = MagicMock() - mock_response.status_code = 500 - mock_response.text = "Internal Server Error" - mock_httpx.post.return_value = mock_response + +@pytest.mark.asyncio +async def test_do_failure(mock_task_template, mock_aiohttp_session): + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.text = AsyncMock(return_value="Internal Server Error") + mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response) agent = WebhookAgent() - result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) + result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.FAILED assert "Webhook failed with status code 500" in result.message