Skip to content

Commit

Permalink
Update test_serde_in_pipeline test
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jan 23, 2025
1 parent 91f2330 commit 8889d95
Showing 1 changed file with 76 additions and 32 deletions.
108 changes: 76 additions & 32 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,42 +574,86 @@ def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools):
== results["tool_invoker"]["tool_messages"][0].tool_call_result.result
)

@pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough
@pytest.mark.integration
def test_pipeline_with_amazon_bedrock_chat_generator_serde(self, model_name, tools):
def test_serde_in_pipeline(self, mock_boto3_session, monkeypatch):
"""
Test that the AmazonBedrockChatGenerator component can be serialized and deserialized in a pipeline
Test serialization/deserialization of AmazonBedrockChatGenerator in a Pipeline,
including YAML conversion and detailed dictionary validation
"""
# Create original pipeline
# Set mock AWS credentials
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-key")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret")
monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")

# Create a test tool
tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters={"city": {"type": "string"}},
function=weather
)

# Create generator with specific configuration
generator = AmazonBedrockChatGenerator(
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
generation_kwargs={"temperature": 0.7},
stop_words=["eviscerate"],
streaming_callback=print_streaming_chunk,
tools=[tool]
)

# Create and configure pipeline
pipeline = Pipeline()
pipeline.add_component("generator", AmazonBedrockChatGenerator(model=model_name, tools=tools))
pipeline.add_component("tool_invoker", ToolInvoker(tools=tools))
pipeline.connect("generator", "tool_invoker")
pipeline.add_component("generator", generator)

# Serialize and deserialize
# Get pipeline dictionary and verify its structure
pipeline_dict = pipeline.to_dict()

# Verify tools in serialized dict
generator_tools = pipeline_dict["components"]["generator"]["init_parameters"]["tools"]
tool_invoker_tools = pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"]

# Both components should have the same tool configuration
assert generator_tools == tool_invoker_tools
assert len(generator_tools) == 1

# Verify tool details
tool_dict = generator_tools[0]
assert tool_dict["type"] == "haystack.tools.tool.Tool"
assert tool_dict["data"]["name"] == "weather"
assert tool_dict["data"]["description"] == "useful to determine the weather in a given location"
assert tool_dict["data"]["parameters"] == {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
assert pipeline_dict == {
"metadata": {},
"max_runs_per_component": 100,
"components": {
"generator": {
"type": KLASS,
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"generation_kwargs": {"temperature": 0.7},
"stop_words": ["eviscerate"],
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"boto3_config": None,
"tools": [
{
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "useful to determine the weather in a given location",
"parameters": {
"city": {"type": "string"}
},
"function": "tests.test_chat_generator.weather"
}
}
]
}
}
},
"connections": []
}
assert tool_dict["data"]["function"] == "tests.test_chat_generator.weather"

# Load pipeline and verify it works
loaded_pipeline = Pipeline.from_dict(pipeline_dict)

assert pipeline == loaded_pipeline
# Test YAML serialization/deserialization
pipeline_yaml = pipeline.dumps()
new_pipeline = Pipeline.loads(pipeline_yaml)
assert new_pipeline == pipeline

# Verify the loaded pipeline's generator has the same configuration
loaded_generator = new_pipeline.get_component("generator")
assert loaded_generator.model == generator.model
assert loaded_generator.generation_kwargs == generator.generation_kwargs
assert loaded_generator.streaming_callback == generator.streaming_callback
assert len(loaded_generator.tools) == len(generator.tools)
assert loaded_generator.tools[0].name == generator.tools[0].name
assert loaded_generator.tools[0].description == generator.tools[0].description
assert loaded_generator.tools[0].parameters == generator.tools[0].parameters

0 comments on commit 8889d95

Please sign in to comment.