diff --git a/synapseclient/models/agent.py b/synapseclient/models/agent.py index d1ae8f9b4..fa676b4ff 100644 --- a/synapseclient/models/agent.py +++ b/synapseclient/models/agent.py @@ -153,13 +153,10 @@ class AgentSession(AgentSessionSynchronousProtocol): from synapseclient import Synapse from synapseclient.models.agent import AgentSession, AgentSessionAccessLevel - - AGENT_REGISTRATION_ID = 0 # replace with your custom agent's registration id - syn = Synapse() syn.login() - my_session = AgentSession(agent_registration_id=AGENT_REGISTRATION_ID).start() + my_session = AgentSession(agent_registration_id="foo").start() my_session.prompt( prompt="Hello", enable_trace=True, @@ -173,12 +170,10 @@ class AgentSession(AgentSessionSynchronousProtocol): from synapseclient import Synapse from synapseclient.models.agent import AgentSession, AgentSessionAccessLevel - SESSION_ID = "my_session_id" # replace with your existing session's ID - syn = Synapse() syn.login() - my_session = AgentSession(id=SESSION_ID).get() + my_session = AgentSession(id="foo").get() my_session.prompt( prompt="Hello", enable_trace=True, @@ -195,7 +190,7 @@ class AgentSession(AgentSessionSynchronousProtocol): syn = Synapse() syn.login() - my_session = AgentSession(id="my_session_id").get() + my_session = AgentSession(id="foo").get() my_session.access_level = AgentSessionAccessLevel.READ_YOUR_PRIVATE_DATA my_session.update() """ @@ -353,10 +348,10 @@ async def prompt_async( self.chat_history.append(agent_prompt) if print_response: - print(f"PROMPT:\n{prompt}\n") - print(f"RESPONSE:\n{agent_prompt.response}\n") + synapse_client.logger.info(f"PROMPT:\n{prompt}\n") + synapse_client.logger.info(f"RESPONSE:\n{agent_prompt.response}\n") if enable_trace: - print(f"TRACE:\n{agent_prompt.trace}") + synapse_client.logger.info(f"TRACE:\n{agent_prompt.trace}") @dataclass @@ -368,9 +363,11 @@ class Agent(AgentSynchronousProtocol): cloud_agent_id: The unique ID of the agent in the cloud provider. cloud_alias_id: The alias ID of the agent in the cloud provider. Defaults to 'TSTALIASID' in the Synapse API. - synapse_registration_id: The ID number of the agent assigned by Synapse. + registration_id: The ID number of the agent assigned by Synapse. registered_on: The date the agent was registered. type: The type of agent. + sessions: A dictionary of AgentSession objects, keyed by session ID. + current_session: The current session. Prompts will be sent to this session by default. Example: Chat with the baseline Synapse Agent You can chat with the same agent which is available in the Synapse UI @@ -379,6 +376,9 @@ class Agent(AgentSynchronousProtocol): the Agent class will start a session and set that new session as the current session if one is not already set. + from synapseclient import Synapse + from synapseclient.models.agent import Agent + syn = Synapse() syn.login() @@ -395,17 +395,20 @@ class Agent(AgentSynchronousProtocol): Alternatively, you can register a custom agent and chat with it provided you have already created it. + from synapseclient import Synapse + from synapseclient.models.agent import Agent + syn = Synapse() - syn.login(silent=True) + syn.login() - my_agent = Agent(cloud_agent_id=AWS_AGENT_ID) + my_agent = Agent(cloud_agent_id="foo") my_agent.register() my_agent.prompt( prompt="Hello", - enable_trace=True, - print_response=True, - ) + enable_trace=True, + print_response=True, + ) Advanced Example: Start and prompt multiple sessions Here, we connect to a custom agent and start one session with the prompt "Hello". @@ -419,7 +422,7 @@ class Agent(AgentSynchronousProtocol): syn = Synapse() syn.login() - my_agent = Agent(registration_id=my_registration_id).get() + my_agent = Agent(registration_id="foo").get() my_agent.prompt( prompt="Hello", @@ -473,7 +476,11 @@ def fill_from_dict(self, agent_registration: Dict[str, str]) -> "Agent": self.cloud_alias_id = agent_registration.get("awsAliasId", None) self.registration_id = agent_registration.get("agentRegistrationId", None) self.registered_on = agent_registration.get("registeredOn", None) - self.type = agent_registration.get("type", None) + self.type = ( + AgentType(agent_registration.get("type")) + if agent_registration.get("type", None) + else None + ) return self @otel_trace_method( @@ -607,6 +614,21 @@ async def prompt_async( synapse_client: If not passed in and caching was not disabled by `Synapse.allow_client_caching(False)` this will use the last created instance from the Synapse class constructor. + + Example: Prompt the baseline Synapse Agent to add annotations to a file on Synapse + The baseline Synpase Agent can be used to add annotations to files. + + from synapseclient import Synapse + + syn = Synapse() + syn.login() + + my_agent = Agent() + my_agent.prompt( + prompt="Add the annotation 'test' to the file 'syn123456789'", + enable_trace=True, + print_response=True, + ) """ if session: await self.get_session_async( diff --git a/tests/unit/synapseclient/models/async/unit_test_agent_async.py b/tests/unit/synapseclient/models/async/unit_test_agent_async.py index cb0f405d0..290094301 100644 --- a/tests/unit/synapseclient/models/async/unit_test_agent_async.py +++ b/tests/unit/synapseclient/models/async/unit_test_agent_async.py @@ -314,10 +314,10 @@ async def test_prompt_trace_enabled_print_response(self) -> None: new_callable=AsyncMock, return_value=self.test_prompt_trace_enabled, ) as mock_send_job_and_wait_async, - patch( - "synapseclient.models.agent.print", - side_effect=print, - ) as mock_print, + patch.object( + self.syn.logger, + "info", + ) as mock_logger_info, ): # GIVEN an existing AgentSession # WHEN I call prompt with trace enabled and print_response enabled @@ -337,7 +337,7 @@ async def test_prompt_trace_enabled_print_response(self) -> None: synapse_client=self.syn, post_exchange_args={"newer_than": 0} ) # AND the trace should be printed - mock_print.assert_called_with( + mock_logger_info.assert_called_with( f"TRACE:\n{self.test_prompt_trace_enabled.trace}" ) @@ -348,10 +348,10 @@ async def test_prompt_trace_disabled_no_print(self) -> None: new_callable=AsyncMock, return_value=self.test_prompt_trace_disabled, ) as mock_send_job_and_wait_async, - patch( - "synapseclient.models.agent.print", - side_effect=print, - ) as mock_print, + patch.object( + self.syn.logger, + "info", + ) as mock_logger_info, ): # WHEN I call prompt with trace disabled and print_response disabled await self.test_session.prompt_async( @@ -370,7 +370,7 @@ async def test_prompt_trace_disabled_no_print(self) -> None: synapse_client=self.syn, post_exchange_args={"newer_than": 0} ) # AND print should not have been called - mock_print.assert_not_called() + mock_logger_info.assert_not_called() class TestAgent: diff --git a/tests/unit/synapseclient/models/synchronous/unit_test_agent.py b/tests/unit/synapseclient/models/synchronous/unit_test_agent.py index 951f30ed8..83f33cb7b 100644 --- a/tests/unit/synapseclient/models/synchronous/unit_test_agent.py +++ b/tests/unit/synapseclient/models/synchronous/unit_test_agent.py @@ -218,10 +218,10 @@ def test_prompt_trace_enabled_print_response(self) -> None: "synapseclient.models.agent.AgentPrompt.send_job_and_wait_async", return_value=self.test_prompt_trace_enabled, ) as mock_send_job_and_wait_async, - patch( - "synapseclient.models.agent.print", - side_effect=print, - ) as mock_print, + patch.object( + self.syn.logger, + "info", + ) as mock_logger_info, ): # GIVEN an existing AgentSession # WHEN I call prompt with trace enabled and print_response enabled @@ -239,7 +239,7 @@ def test_prompt_trace_enabled_print_response(self) -> None: synapse_client=self.syn, post_exchange_args={"newer_than": 0} ) # AND the trace should be printed - mock_print.assert_called_with( + mock_logger_info.assert_called_with( f"TRACE:\n{self.test_prompt_trace_enabled.trace}" ) @@ -249,10 +249,10 @@ def test_prompt_trace_disabled_no_print(self) -> None: "synapseclient.models.agent.AgentPrompt.send_job_and_wait_async", return_value=self.test_prompt_trace_disabled, ) as mock_send_job_and_wait_async, - patch( - "synapseclient.models.agent.print", - side_effect=print, - ) as mock_print, + patch.object( + self.syn.logger, + "info", + ) as mock_logger_info, ): # WHEN I call prompt with trace disabled and print_response disabled self.test_session.prompt( @@ -269,7 +269,7 @@ def test_prompt_trace_disabled_no_print(self) -> None: synapse_client=self.syn, post_exchange_args={"newer_than": 0} ) # AND print should not have been called - mock_print.assert_not_called() + mock_logger_info.assert_not_called() class TestAgent: