diff --git a/haystack/preview/components/retrievers/__init__.py b/haystack/preview/components/retrievers/__init__.py index 2783d7f36a..ea387ae94a 100644 --- a/haystack/preview/components/retrievers/__init__.py +++ b/haystack/preview/components/retrievers/__init__.py @@ -1,3 +1,3 @@ -from haystack.preview.components.retrievers.memory import MemoryRetriever +from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever -__all__ = ["MemoryRetriever"] +__all__ = ["MemoryBM25Retriever", "MemoryEmbeddingRetriever"] diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py index e1792057c1..ca85d60414 100644 --- a/haystack/preview/components/retrievers/memory.py +++ b/haystack/preview/components/retrievers/memory.py @@ -5,7 +5,7 @@ @component -class MemoryRetriever: +class MemoryBM25Retriever: """ A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm. @@ -20,12 +20,12 @@ def __init__( scale_score: bool = True, ): """ - Create a MemoryRetriever component. + Create a MemoryBM25Retriever component. :param document_store: An instance of MemoryDocumentStore. - :param filters: A dictionary with filters to narrow down the search space (default is None). - :param top_k: The maximum number of documents to retrieve (default is 10). - :param scale_score: Whether to scale the BM25 score or not (default is True). + :param filters: A dictionary with filters to narrow down the search space. Default is None. + :param top_k: The maximum number of documents to retrieve. Default is 10. + :param scale_score: Whether to scale the BM25 score or not. Default is True. :raises ValueError: If the specified top_k is not > 0. """ @@ -51,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MemoryRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever": """ Deserialize this component from a dictionary. """ @@ -77,13 +77,12 @@ def run( scale_score: Optional[bool] = None, ): """ - Run the MemoryRetriever on the given input data. + Run the MemoryBM25Retriever on the given input data. :param query: The query string for the retriever. :param filters: A dictionary with filters to narrow down the search space. :param top_k: The maximum number of documents to return. :param scale_score: Whether to scale the BM25 scores or not. - :param document_stores: A dictionary mapping DocumentStore names to instances. :return: The retrieved documents. :raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance. @@ -101,3 +100,119 @@ def run( self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score) ) return {"documents": docs} + + +@component +class MemoryEmbeddingRetriever: + """ + A component for retrieving documents from a MemoryDocumentStore using a vector similarity metric. + + Needs to be connected to a MemoryDocumentStore to run. + """ + + def __init__( + self, + document_store: MemoryDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, + return_embedding: bool = False, + ): + """ + Create a MemoryEmbeddingRetriever component. + + :param document_store: An instance of MemoryDocumentStore. + :param filters: A dictionary with filters to narrow down the search space. Default is None. + :param top_k: The maximum number of documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. + :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. + + :raises ValueError: If the specified top_k is not > 0. + """ + if not isinstance(document_store, MemoryDocumentStore): + raise ValueError("document_store must be an instance of MemoryDocumentStore") + + self.document_store = document_store + + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + + self.filters = filters + self.top_k = top_k + self.scale_score = scale_score + self.return_embedding = return_embedding + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + docstore = self.document_store.to_dict() + return default_to_dict( + self, + document_store=docstore, + filters=self.filters, + top_k=self.top_k, + scale_score=self.scale_score, + return_embedding=self.return_embedding, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MemoryBM25Retriever": + """ + Deserialize this component from a dictionary. + """ + init_params = data.get("init_parameters", {}) + if "document_store" not in init_params: + raise DeserializationError("Missing 'document_store' in serialization data") + if "type" not in init_params["document_store"]: + raise DeserializationError("Missing 'type' in document store's serialization data") + if init_params["document_store"]["type"] not in document_store.registry: + raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found") + + docstore_class = document_store.registry[init_params["document_store"]["type"]] + docstore = docstore_class.from_dict(init_params["document_store"]) + data["init_parameters"]["document_store"] = docstore + return default_from_dict(cls, data) + + @component.output_types(documents=List[List[Document]]) + def run( + self, + queries_embeddings: List[List[float]], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + """ + Run the MemoryEmbeddingRetriever on the given input data. + + :param queries_embeddings: Embeddings of the queries. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the embedding of the retrieved Documents. + :return: The retrieved documents. + + :raises ValueError: If the specified DocumentStore is not found or is not a MemoryDocumentStore instance. + """ + if filters is None: + filters = self.filters + if top_k is None: + top_k = self.top_k + if scale_score is None: + scale_score = self.scale_score + if return_embedding is None: + return_embedding = self.return_embedding + + docs = [] + for query_embedding in queries_embeddings: + docs.append( + self.document_store.embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + ) + ) + return {"documents": docs} diff --git a/releasenotes/notes/memory-embedding-retriever-dde22dedc83d1603.yaml b/releasenotes/notes/memory-embedding-retriever-dde22dedc83d1603.yaml new file mode 100644 index 0000000000..946dc4fda7 --- /dev/null +++ b/releasenotes/notes/memory-embedding-retriever-dde22dedc83d1603.yaml @@ -0,0 +1,6 @@ +--- +preview: + - | + Rename `MemoryRetriever` to `MemoryBM25Retriever` + Add `MemoryEmbeddingRetriever`, which takes as input a query embedding and + retrieves the most relevant Documents from a `MemoryDocumentStore`. diff --git a/test/preview/components/retrievers/test_memory_retriever.py b/test/preview/components/retrievers/test_memory_retriever.py index 11752711a9..ad09352ca3 100644 --- a/test/preview/components/retrievers/test_memory_retriever.py +++ b/test/preview/components/retrievers/test_memory_retriever.py @@ -1,11 +1,10 @@ from typing import Dict, Any -from unittest.mock import MagicMock, patch import pytest from haystack.preview import Pipeline, DeserializationError from haystack.preview.testing.factory import document_store_class -from haystack.preview.components.retrievers.memory import MemoryRetriever +from haystack.preview.components.retrievers.memory import MemoryBM25Retriever, MemoryEmbeddingRetriever from haystack.preview.dataclasses import Document from haystack.preview.document_stores import MemoryDocumentStore @@ -21,36 +20,39 @@ def mock_docs(): ] -class TestMemoryRetriever: +class TestMemoryRetrievers: + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_init_default(self): - retriever = MemoryRetriever(MemoryDocumentStore()) + def test_init_default(self, retriever_cls): + retriever = retriever_cls(MemoryDocumentStore()) assert retriever.filters is None assert retriever.top_k == 10 assert retriever.scale_score + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_init_with_parameters(self): - retriever = MemoryRetriever(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False) + def test_init_with_parameters(self, retriever_cls): + retriever = retriever_cls(MemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=False) assert retriever.filters == {"name": "test.txt"} assert retriever.top_k == 5 assert not retriever.scale_score + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_init_with_invalid_top_k_parameter(self): + def test_init_with_invalid_top_k_parameter(self, retriever_cls): with pytest.raises(ValueError, match="top_k must be > 0, but got -2"): - MemoryRetriever(MemoryDocumentStore(), top_k=-2, scale_score=False) + retriever_cls(MemoryDocumentStore(), top_k=-2, scale_score=False) @pytest.mark.unit - def test_to_dict(self): + def test_bm25_retriever_to_dict(self): MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) document_store = MyFakeStore() document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} - component = MemoryRetriever(document_store=document_store) + component = MemoryBM25Retriever(document_store=document_store) data = component.to_dict() assert data == { - "type": "MemoryRetriever", + "type": "MemoryBM25Retriever", "init_parameters": { "document_store": {"type": "MyFakeStore", "init_parameters": {}}, "filters": None, @@ -60,69 +62,116 @@ def test_to_dict(self): } @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): + def test_embedding_retriever_to_dict(self): MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) document_store = MyFakeStore() document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} - component = MemoryRetriever( + component = MemoryEmbeddingRetriever(document_store=document_store) + + data = component.to_dict() + assert data == { + "type": "MemoryEmbeddingRetriever", + "init_parameters": { + "document_store": {"type": "MyFakeStore", "init_parameters": {}}, + "filters": None, + "top_k": 10, + "scale_score": True, + "return_embedding": False, + }, + } + + @pytest.mark.unit + def test_bm25_retriever_to_dict_with_custom_init_parameters(self): + MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) + document_store = MyFakeStore() + document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} + component = MemoryBM25Retriever( document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=False ) data = component.to_dict() assert data == { - "type": "MemoryRetriever", + "type": "MemoryBM25Retriever", + "init_parameters": { + "document_store": {"type": "MyFakeStore", "init_parameters": {}}, + "filters": {"name": "test.txt"}, + "top_k": 5, + "scale_score": False, + }, + } + + @pytest.mark.unit + def test_embedding_retriever_to_dict_with_custom_init_parameters(self): + MyFakeStore = document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) + document_store = MyFakeStore() + document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} + component = MemoryEmbeddingRetriever( + document_store=document_store, + filters={"name": "test.txt"}, + top_k=5, + scale_score=False, + return_embedding=True, + ) + data = component.to_dict() + assert data == { + "type": "MemoryEmbeddingRetriever", "init_parameters": { "document_store": {"type": "MyFakeStore", "init_parameters": {}}, "filters": {"name": "test.txt"}, "top_k": 5, "scale_score": False, + "return_embedding": True, }, } + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_from_dict(self): + def test_from_dict(self, retriever_cls): document_store_class("MyFakeStore", bases=(MemoryDocumentStore,)) data = { - "type": "MemoryRetriever", + "type": retriever_cls.__name__, "init_parameters": { "document_store": {"type": "MyFakeStore", "init_parameters": {}}, "filters": {"name": "test.txt"}, "top_k": 5, }, } - component = MemoryRetriever.from_dict(data) + component = retriever_cls.from_dict(data) assert isinstance(component.document_store, MemoryDocumentStore) assert component.filters == {"name": "test.txt"} assert component.top_k == 5 assert component.scale_score + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_from_dict_without_docstore(self): - data = {"type": "MemoryRetriever", "init_parameters": {}} + def test_from_dict_without_docstore(self, retriever_cls): + data = {"type": retriever_cls.__name__, "init_parameters": {}} with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): - MemoryRetriever.from_dict(data) + retriever_cls.from_dict(data) + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_from_dict_without_docstore_type(self): - data = {"type": "MemoryRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}} + def test_from_dict_without_docstore_type(self, retriever_cls): + data = {"type": retriever_cls.__name__, "init_parameters": {"document_store": {"init_parameters": {}}}} with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): - MemoryRetriever.from_dict(data) + retriever_cls.from_dict(data) + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) @pytest.mark.unit - def test_from_dict_nonexisting_docstore(self): + def test_from_dict_nonexisting_docstore(self, retriever_cls): data = { - "type": "MemoryRetriever", + "type": retriever_cls.__name__, "init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}}, } with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"): - MemoryRetriever.from_dict(data) + retriever_cls.from_dict(data) @pytest.mark.unit - def test_valid_run(self, mock_docs): + def test_bm25_retriever_valid_run(self, mock_docs): top_k = 5 ds = MemoryDocumentStore() ds.write_documents(mock_docs) - retriever = MemoryRetriever(ds, top_k=top_k) + retriever = MemoryBM25Retriever(ds, top_k=top_k) result = retriever.run(queries=["PHP", "Java"]) assert "documents" in result @@ -133,10 +182,32 @@ def test_valid_run(self, mock_docs): assert result["documents"][1][0].content == "Java is a popular programming language" @pytest.mark.unit - def test_invalid_run_wrong_store_type(self): + def test_embedding_retriever_valid_run(self): + top_k = 3 + ds = MemoryDocumentStore(embedding_similarity_function="cosine") + docs = [ + Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), + Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]), + ] + ds.write_documents(docs) + + retriever = MemoryEmbeddingRetriever(ds, top_k=top_k) + result = retriever.run(queries_embeddings=[[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], return_embedding=True) + + assert "documents" in result + assert len(result["documents"]) == 2 + assert len(result["documents"][0]) == top_k + assert len(result["documents"][1]) == top_k + assert result["documents"][0][0].embedding == [0.1, 0.2, 0.3, 0.4] + assert result["documents"][1][0].embedding == [1.0, 1.0, 1.0, 1.0] + + @pytest.mark.parametrize("retriever_cls", [MemoryBM25Retriever, MemoryEmbeddingRetriever]) + @pytest.mark.unit + def test_invalid_run_wrong_store_type(self, retriever_cls): SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore") with pytest.raises(ValueError, match="document_store must be an instance of MemoryDocumentStore"): - MemoryRetriever(SomeOtherDocumentStore()) + retriever_cls(SomeOtherDocumentStore()) @pytest.mark.integration @pytest.mark.parametrize( @@ -146,10 +217,10 @@ def test_invalid_run_wrong_store_type(self): ("Java", "Java is a popular programming language"), ], ) - def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): + def test_run_bm25_retriever_with_pipeline(self, mock_docs, query: str, query_result: str): ds = MemoryDocumentStore() ds.write_documents(mock_docs) - retriever = MemoryRetriever(ds) + retriever = MemoryBM25Retriever(ds) pipeline = Pipeline() pipeline.add_component("retriever", retriever) @@ -161,6 +232,38 @@ def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): assert results_docs assert results_docs[0][0].content == query_result + @pytest.mark.integration + def test_run_embedding_retriever_with_pipeline(self): + ds = MemoryDocumentStore(embedding_similarity_function="cosine") + top_k = 2 + docs = [ + Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), + Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]), + ] + ds.write_documents(docs) + retriever = MemoryEmbeddingRetriever(ds, top_k=top_k) + + pipeline = Pipeline() + pipeline.add_component("retriever", retriever) + result: Dict[str, Any] = pipeline.run( + data={ + "retriever": { + "queries_embeddings": [[0.2, 0.4, 0.6, 0.8], [0.1, 0.1, 0.1, 0.1]], + "return_embedding": True, + } + } + ) + + assert result + assert "retriever" in result + results_docs = result["retriever"]["documents"] + assert results_docs + assert len(results_docs[0]) == top_k + assert len(results_docs[1]) == top_k + assert results_docs[0][0].embedding == [0.1, 0.2, 0.3, 0.4] + assert results_docs[1][0].embedding == [1.0, 1.0, 1.0, 1.0] + @pytest.mark.integration @pytest.mark.parametrize( "query, query_result, top_k", @@ -170,10 +273,10 @@ def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): ("Ruby", "Ruby is a popular programming language", 3), ], ) - def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): + def test_run_bm25_retriever_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): ds = MemoryDocumentStore() ds.write_documents(mock_docs) - retriever = MemoryRetriever(ds) + retriever = MemoryBM25Retriever(ds) pipeline = Pipeline() pipeline.add_component("retriever", retriever)