Skip to content

Commit

Permalink
feat: Support requesting connected_account in custom action (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushar-composio authored Dec 25, 2024
1 parent 9fdb719 commit 0111d9c
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions python/composio/tools/base/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel, Field

from composio import Composio
from composio.client.collections import CustomAuthParameter
from composio.client.collections import ConnectedAccountModel, CustomAuthParameter
from composio.client.enums.base import ActionData, SentinalObject, add_runtime_action
from composio.client.exceptions import ComposioClientError
from composio.exceptions import ComposioSDKError
Expand Down Expand Up @@ -253,12 +253,26 @@ def _parse_docstring(
return header, params, returns


def _get_connected_account(
app: str, entity_id: str
) -> t.Optional[ConnectedAccountModel]:
try:
client = Composio.get_latest()
connected_account = client.connected_accounts.get(
connection_id=client.get_entity(entity_id).get_connection(app=app).id
)
return connected_account
except ComposioClientError:
return None


def _get_auth_params(app: str, entity_id: str) -> t.Optional[t.Dict]:
try:
client = Composio.get_latest()
connection_params = client.connected_accounts.get(
connected_account = client.connected_accounts.get(
connection_id=client.get_entity(entity_id).get_connection(app=app).id
).connectionParams
)
connection_params = connected_account.connectionParams
return {
"headers": connection_params.headers,
"base_url": connection_params.base_url,
Expand Down Expand Up @@ -303,7 +317,8 @@ def _build_executable_from_args( # pylint: disable=too-many-statements
}

shell_argument = None
auth_params = False
auth_param = False
connected_account_param = False
request_executor = False
if "return" not in argspec.annotations:
raise InvalidRuntimeAction(
Expand All @@ -316,7 +331,11 @@ def _build_executable_from_args( # pylint: disable=too-many-statements
continue

if arg == "auth":
auth_params = True
auth_param = True
continue

if arg == "connected_account":
connected_account_param = True
continue

if arg == "execute_request":
Expand Down Expand Up @@ -376,7 +395,11 @@ def execute(request: BaseModel, metadata: t.Dict) -> BaseModel:
if shell_argument is not None:
kwargs[shell_argument] = metadata["workspace"].shells.recent

if auth_params > 0:
if connected_account_param:
kwargs["connected_account"] = (
_get_connected_account(app=app, entity_id=metadata["entity_id"]) or {}
)
if auth_param:
kwargs["auth"] = (
_get_auth_params(app=app, entity_id=metadata["entity_id"]) or {}
)
Expand Down

0 comments on commit 0111d9c

Please sign in to comment.