Skip to content

Commit

Permalink
address review comments in agent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 22, 2025
1 parent 34be93d commit 1c90065
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 39 deletions.
60 changes: 41 additions & 19 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,10 @@ class AgentSession(AgentSessionSynchronousProtocol):
from synapseclient import Synapse
from synapseclient.models.agent import AgentSession, AgentSessionAccessLevel
AGENT_REGISTRATION_ID = 0 # replace with your custom agent's registration id
syn = Synapse()
syn.login()
my_session = AgentSession(agent_registration_id=AGENT_REGISTRATION_ID).start()
my_session = AgentSession(agent_registration_id="foo").start()
my_session.prompt(
prompt="Hello",
enable_trace=True,
Expand All @@ -173,12 +170,10 @@ class AgentSession(AgentSessionSynchronousProtocol):
from synapseclient import Synapse
from synapseclient.models.agent import AgentSession, AgentSessionAccessLevel
SESSION_ID = "my_session_id" # replace with your existing session's ID
syn = Synapse()
syn.login()
my_session = AgentSession(id=SESSION_ID).get()
my_session = AgentSession(id="foo").get()
my_session.prompt(
prompt="Hello",
enable_trace=True,
Expand All @@ -195,7 +190,7 @@ class AgentSession(AgentSessionSynchronousProtocol):
syn = Synapse()
syn.login()
my_session = AgentSession(id="my_session_id").get()
my_session = AgentSession(id="foo").get()
my_session.access_level = AgentSessionAccessLevel.READ_YOUR_PRIVATE_DATA
my_session.update()
"""
Expand Down Expand Up @@ -353,10 +348,10 @@ async def prompt_async(
self.chat_history.append(agent_prompt)

if print_response:
print(f"PROMPT:\n{prompt}\n")
print(f"RESPONSE:\n{agent_prompt.response}\n")
synapse_client.logger.info(f"PROMPT:\n{prompt}\n")
synapse_client.logger.info(f"RESPONSE:\n{agent_prompt.response}\n")
if enable_trace:
print(f"TRACE:\n{agent_prompt.trace}")
synapse_client.logger.info(f"TRACE:\n{agent_prompt.trace}")


@dataclass
Expand All @@ -368,9 +363,11 @@ class Agent(AgentSynchronousProtocol):
cloud_agent_id: The unique ID of the agent in the cloud provider.
cloud_alias_id: The alias ID of the agent in the cloud provider.
Defaults to 'TSTALIASID' in the Synapse API.
synapse_registration_id: The ID number of the agent assigned by Synapse.
registration_id: The ID number of the agent assigned by Synapse.
registered_on: The date the agent was registered.
type: The type of agent.
sessions: A dictionary of AgentSession objects, keyed by session ID.
current_session: The current session. Prompts will be sent to this session by default.
Example: Chat with the baseline Synapse Agent
You can chat with the same agent which is available in the Synapse UI
Expand All @@ -379,6 +376,9 @@ class Agent(AgentSynchronousProtocol):
the Agent class will start a session and set that new session as the
current session if one is not already set.
from synapseclient import Synapse
from synapseclient.models.agent import Agent
syn = Synapse()
syn.login()
Expand All @@ -395,17 +395,20 @@ class Agent(AgentSynchronousProtocol):
Alternatively, you can register a custom agent and chat with it provided
you have already created it.
from synapseclient import Synapse
from synapseclient.models.agent import Agent
syn = Synapse()
syn.login(silent=True)
syn.login()
my_agent = Agent(cloud_agent_id=AWS_AGENT_ID)
my_agent = Agent(cloud_agent_id="foo")
my_agent.register()
my_agent.prompt(
prompt="Hello",
enable_trace=True,
print_response=True,
)
enable_trace=True,
print_response=True,
)
Advanced Example: Start and prompt multiple sessions
Here, we connect to a custom agent and start one session with the prompt "Hello".
Expand All @@ -419,7 +422,7 @@ class Agent(AgentSynchronousProtocol):
syn = Synapse()
syn.login()
my_agent = Agent(registration_id=my_registration_id).get()
my_agent = Agent(registration_id="foo").get()
my_agent.prompt(
prompt="Hello",
Expand Down Expand Up @@ -473,7 +476,11 @@ def fill_from_dict(self, agent_registration: Dict[str, str]) -> "Agent":
self.cloud_alias_id = agent_registration.get("awsAliasId", None)
self.registration_id = agent_registration.get("agentRegistrationId", None)
self.registered_on = agent_registration.get("registeredOn", None)
self.type = agent_registration.get("type", None)
self.type = (
AgentType(agent_registration.get("type"))
if agent_registration.get("type", None)
else None
)
return self

@otel_trace_method(
Expand Down Expand Up @@ -607,6 +614,21 @@ async def prompt_async(
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Example: Prompt the baseline Synapse Agent to add annotations to a file on Synapse
The baseline Synpase Agent can be used to add annotations to files.
from synapseclient import Synapse
syn = Synapse()
syn.login()
my_agent = Agent()
my_agent.prompt(
prompt="Add the annotation 'test' to the file 'syn123456789'",
enable_trace=True,
print_response=True,
)
"""
if session:
await self.get_session_async(
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/synapseclient/models/async/unit_test_agent_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ async def test_prompt_trace_enabled_print_response(self) -> None:
new_callable=AsyncMock,
return_value=self.test_prompt_trace_enabled,
) as mock_send_job_and_wait_async,
patch(
"synapseclient.models.agent.print",
side_effect=print,
) as mock_print,
patch.object(
self.syn.logger,
"info",
) as mock_logger_info,
):
# GIVEN an existing AgentSession
# WHEN I call prompt with trace enabled and print_response enabled
Expand All @@ -337,7 +337,7 @@ async def test_prompt_trace_enabled_print_response(self) -> None:
synapse_client=self.syn, post_exchange_args={"newer_than": 0}
)
# AND the trace should be printed
mock_print.assert_called_with(
mock_logger_info.assert_called_with(
f"TRACE:\n{self.test_prompt_trace_enabled.trace}"
)

Expand All @@ -348,10 +348,10 @@ async def test_prompt_trace_disabled_no_print(self) -> None:
new_callable=AsyncMock,
return_value=self.test_prompt_trace_disabled,
) as mock_send_job_and_wait_async,
patch(
"synapseclient.models.agent.print",
side_effect=print,
) as mock_print,
patch.object(
self.syn.logger,
"info",
) as mock_logger_info,
):
# WHEN I call prompt with trace disabled and print_response disabled
await self.test_session.prompt_async(
Expand All @@ -370,7 +370,7 @@ async def test_prompt_trace_disabled_no_print(self) -> None:
synapse_client=self.syn, post_exchange_args={"newer_than": 0}
)
# AND print should not have been called
mock_print.assert_not_called()
mock_logger_info.assert_not_called()


class TestAgent:
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/synapseclient/models/synchronous/unit_test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def test_prompt_trace_enabled_print_response(self) -> None:
"synapseclient.models.agent.AgentPrompt.send_job_and_wait_async",
return_value=self.test_prompt_trace_enabled,
) as mock_send_job_and_wait_async,
patch(
"synapseclient.models.agent.print",
side_effect=print,
) as mock_print,
patch.object(
self.syn.logger,
"info",
) as mock_logger_info,
):
# GIVEN an existing AgentSession
# WHEN I call prompt with trace enabled and print_response enabled
Expand All @@ -239,7 +239,7 @@ def test_prompt_trace_enabled_print_response(self) -> None:
synapse_client=self.syn, post_exchange_args={"newer_than": 0}
)
# AND the trace should be printed
mock_print.assert_called_with(
mock_logger_info.assert_called_with(
f"TRACE:\n{self.test_prompt_trace_enabled.trace}"
)

Expand All @@ -249,10 +249,10 @@ def test_prompt_trace_disabled_no_print(self) -> None:
"synapseclient.models.agent.AgentPrompt.send_job_and_wait_async",
return_value=self.test_prompt_trace_disabled,
) as mock_send_job_and_wait_async,
patch(
"synapseclient.models.agent.print",
side_effect=print,
) as mock_print,
patch.object(
self.syn.logger,
"info",
) as mock_logger_info,
):
# WHEN I call prompt with trace disabled and print_response disabled
self.test_session.prompt(
Expand All @@ -269,7 +269,7 @@ def test_prompt_trace_disabled_no_print(self) -> None:
synapse_client=self.syn, post_exchange_args={"newer_than": 0}
)
# AND print should not have been called
mock_print.assert_not_called()
mock_logger_info.assert_not_called()


class TestAgent:
Expand Down

0 comments on commit 1c90065

Please sign in to comment.