Skip to content

Commit

Permalink
Agents - missing type hint (#2896)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Nov 4, 2024
1 parent 3a26ee5 commit 0662bc6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap:
raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}")

if resource.outputs and not isinstance(resource.outputs, LiteralMap):
return TypeEngine.dict_to_literal_map(ctx, resource.outputs)
return TypeEngine.dict_to_literal_map(ctx, resource.outputs, type_hints=self.python_interface.outputs)
return resource.outputs

async def _do(
Expand Down
46 changes: 46 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,49 @@ def test_resource_type():
# round-tripping creates a literal map out of outputs
assert o2.outputs.literals["o0"].scalar.primitive.integer == 1
assert o2.custom_info == o.custom_info


def test_agent_complex_type():
@dataclass
class Foo:
val: str

class FooAgent(SyncAgentBase):
def __init__(self) -> None:
super().__init__(task_type_name="foo")

def do(
self,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs: typing.Any,
) -> Resource:
return Resource(
phase=TaskExecution.SUCCEEDED, outputs={"foos": [Foo(val="a"), Foo(val="b")], "has_foos": True}
)

AgentRegistry.register(FooAgent(), override=True)

class FooTask(SyncAgentExecutorMixin, PythonTask): # type: ignore
_TASK_TYPE = "foo"

def __init__(self, name: str, **kwargs: typing.Any) -> None:
task_config: dict[str, typing.Any] = {}

outputs = {"has_foos": bool, "foos": typing.Optional[typing.List[Foo]]}

super().__init__(
task_type=self._TASK_TYPE,
name=name,
task_config=task_config,
interface=Interface(outputs=outputs),
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing.Any]:
return {}

foo_task = FooTask(name="foo_task")
res = foo_task()
assert res.has_foos
assert res.foos[1].val == "b"

0 comments on commit 0662bc6

Please sign in to comment.