Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Jan 14, 2025
1 parent d5ae0a9 commit 50da8ec
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 37 deletions.
31 changes: 17 additions & 14 deletions flytekit/extras/webhook/agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import http
from typing import Optional

import httpx
import aiohttp
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.utils.dict_formatter import format_dict

from .constants import BODY_KEY, HEADERS_KEY, METHOD_KEY, SHOW_BODY_KEY, SHOW_URL_KEY, TASK_TYPE, URL_KEY


Expand All @@ -18,9 +18,17 @@ class WebhookAgent(SyncAgentBase):

def __init__(self):
super().__init__(task_type_name=TASK_TYPE)
self._loop = asyncio.get_running_loop()
self._loop.create_task(self._initialize_session())

def __del__(self):
self._loop.create_task(self._session.close())

async def _initialize_session(self):
self._session = aiohttp.ClientSession()

def do(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
async def do(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
) -> Resource:
try:
custom_dict = task_template.custom
Expand All @@ -37,22 +45,17 @@ def do(
show_body = final_dict.get(SHOW_BODY_KEY, False)
show_url = final_dict.get(SHOW_URL_KEY, False)

async with httpx.AsyncClient() as client:
if method == http.HTTPMethod.GET:
response = await client.get(url, headers=headers)
else:
response = await client.post(url, data=body, headers=headers)
if method == http.HTTPMethod.GET:
response = httpx.get(url, headers=headers)
response = await self._session.get(url, headers=headers)
else:
response = httpx.post(url, data=body, headers=headers)
if response.status_code != 200:
response = self._session.post(url, data=body, headers=headers)
if response.status != 200:
return Resource(
phase=TaskExecution.FAILED,
message=f"Webhook failed with status code {response.status_code}, response: {response.text}",
message=f"Webhook failed with status code {response.status}, response: {response.text}",
)
final_response = {
"status_code": response.status_code,
"status_code": response.status,
"body": response.text,
}
if show_body:
Expand Down
56 changes: 33 additions & 23 deletions tests/flytekit/unit/extras/webhook/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, AsyncMock

import pytest

from flytekit.extras.webhook.agent import WebhookAgent
from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.extras.webhook.agent import WebhookAgent


@pytest.fixture
Expand All @@ -20,48 +21,57 @@ def mock_task_template():
}
return task_template

@patch('flytekit.extras.webhook.agent.httpx')
def test_do_post_success(mock_httpx, mock_task_template):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "Success"
mock_httpx.post.return_value = mock_response

@pytest.fixture
def mock_aiohttp_session():
with patch('aiohttp.ClientSession') as mock_session:
yield mock_session


@pytest.mark.asyncio
async def test_do_post_success(mock_task_template, mock_aiohttp_session):
mock_response = AsyncMock()
mock_response.status = 200
mock_response.text = AsyncMock(return_value="Success")
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)

agent = WebhookAgent()
result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.SUCCEEDED
assert result.outputs["status_code"] == 200
assert result.outputs["body"] == "Success"
assert result.outputs["url"] == "http://example.com"

@patch('flytekit.extras.webhook.agent.httpx')
def test_do_get_success(mock_httpx, mock_task_template):

@pytest.mark.asyncio
async def test_do_get_success(mock_task_template, mock_aiohttp_session):
mock_task_template.custom["method"] = "GET"
mock_task_template.custom.pop("body")
mock_task_template.custom["show_body"] = False

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "Success"
mock_httpx.get.return_value = mock_response
mock_response = AsyncMock()
mock_response.status = 200
mock_response.text = AsyncMock(return_value="Success")
mock_aiohttp_session.return_value.get = AsyncMock(return_value=mock_response)

agent = WebhookAgent()
result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.SUCCEEDED
assert result.outputs["status_code"] == 200
assert result.outputs["url"] == "http://example.com"

@patch('flytekit.extras.webhook.agent.httpx')
def test_do_failure(mock_httpx, mock_task_template):
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_httpx.post.return_value = mock_response

@pytest.mark.asyncio
async def test_do_failure(mock_task_template, mock_aiohttp_session):
mock_response = AsyncMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)

agent = WebhookAgent()
result = agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.FAILED
assert "Webhook failed with status code 500" in result.message

0 comments on commit 50da8ec

Please sign in to comment.