diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index d311662fe..cbb68a8e1 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -7,7 +7,8 @@ from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response @@ -47,6 +48,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + embedding_type: Optional[EmbeddingTypes] = None, ): """ :param api_key: the Cohere API key. @@ -72,6 +74,8 @@ def __init__( to keep the logs clean. :param meta_fields_to_embed: list of meta fields that should be embedded along with the Document text. :param embedding_separator: separator used to concatenate the meta fields to the Document text. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. + Note that int8, uint8, binary, and ubinary are only valid for v3 models. """ self.api_key = api_key @@ -85,6 +89,7 @@ def __init__( self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.embedding_type = embedding_type or EmbeddingTypes.FLOAT def to_dict(self) -> Dict[str, Any]: """ @@ -106,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + embedding_type=self.embedding_type.value, ) @classmethod @@ -120,6 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) + + # Convert embedding_type string to EmbeddingTypes enum value + init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"]) + return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -163,17 +173,19 @@ def run(self, documents: List[Document]): assert api_key is not None if self.use_async_client: - cohere_client = AsyncClient( + cohere_client = AsyncClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) all_embeddings, metadata = asyncio.run( - get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate) + get_async_response( + cohere_client, texts_to_embed, self.model, self.input_type, self.truncate, self.embedding_type + ) ) else: - cohere_client = Client( + cohere_client = ClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, @@ -187,6 +199,7 @@ def run(self, documents: List[Document]): self.truncate, self.batch_size, self.progress_bar, + self.embedding_type, ) for doc, embeddings in zip(documents, all_embeddings): diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py new file mode 100644 index 000000000..2f11c02cb --- /dev/null +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum + + +class EmbeddingTypes(Enum): + """ + Supported types for Cohere embeddings. + + FLOAT: Default float embeddings. Valid for all models. + INT8: Signed int8 embeddings. Valid for only v3 models. + UINT8: Unsigned int8 embeddings. Valid for only v3 models. + BINARY: Signed binary embeddings. Valid for only v3 models. + UBINARY: Unsigned binary embeddings. Valid for only v3 models. + """ + + FLOAT = "float" + INT8 = "int8" + UINT8 = "uint8" + BINARY = "binary" + UBINARY = "ubinary" + + def __str__(self): + return self.value + + @staticmethod + def from_str(string: str) -> "EmbeddingTypes": + """ + Convert a string to an EmbeddingTypes enum. + """ + enum_map = {e.value: e for e in EmbeddingTypes} + embedding_type = enum_map.get(string.lower()) + if embedding_type is None: + msg = f"Unknown embedding type '{string}'. Supported types are: {list(enum_map.keys())}" + raise ValueError(msg) + return embedding_type diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index c1e9bd613..fc7ff8cd2 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -2,12 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 import asyncio -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response @@ -40,6 +41,7 @@ def __init__( truncate: str = "END", use_async_client: bool = False, timeout: int = 120, + embedding_type: Optional[EmbeddingTypes] = None, ): """ :param api_key: the Cohere API key. @@ -60,6 +62,8 @@ def __init__( :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. :param timeout: request timeout in seconds. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. + Note that int8, uint8, binary, and ubinary are only valid for v3 models. """ self.api_key = api_key @@ -69,6 +73,7 @@ def __init__( self.truncate = truncate self.use_async_client = use_async_client self.timeout = timeout + self.embedding_type = embedding_type or EmbeddingTypes.FLOAT def to_dict(self) -> Dict[str, Any]: """ @@ -86,6 +91,7 @@ def to_dict(self) -> Dict[str, Any]: truncate=self.truncate, use_async_client=self.use_async_client, timeout=self.timeout, + embedding_type=self.embedding_type.value, ) @classmethod @@ -100,6 +106,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) + + # Convert embedding_type string to EmbeddingTypes enum value + init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"]) + return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) @@ -125,22 +135,26 @@ def run(self, text: str): assert api_key is not None if self.use_async_client: - cohere_client = AsyncClient( + cohere_client = AsyncClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) embedding, metadata = asyncio.run( - get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) + get_async_response( + cohere_client, [text], self.model, self.input_type, self.truncate, self.embedding_type + ) ) else: - cohere_client = Client( + cohere_client = ClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) - embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) + embedding, metadata = get_response( + cohere_client, [text], self.model, self.input_type, self.truncate, embedding_type=self.embedding_type + ) return {"embedding": embedding[0], "meta": metadata} diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index a5c20cb35..951938143 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -1,14 +1,22 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from tqdm import tqdm -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes -async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): +async def get_async_response( + cohere_async_client: AsyncClientV2, + texts: List[str], + model_name, + input_type, + truncate, + embedding_type: Optional[EmbeddingTypes] = None, +): """Embeds a list of texts asynchronously using the Cohere API. :param cohere_async_client: the Cohere `AsyncClient` @@ -17,6 +25,7 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], :param input_type: one of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed. :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. :returns: A tuple of the embeddings and metadata. @@ -25,17 +34,36 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} - response = await cohere_async_client.embed(texts=texts, model=model_name, input_type=input_type, truncate=truncate) + embedding_type = embedding_type or EmbeddingTypes.FLOAT + response = await cohere_async_client.embed( + texts=texts, + model=model_name, + input_type=input_type, + truncate=truncate, + embedding_types=[embedding_type.value], + ) if response.meta is not None: metadata = response.meta - for emb in response.embeddings: - all_embeddings.append(emb) + for emb_tuple in response.embeddings: + # emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.) + if emb_tuple[1] is not None: + # ok we have embeddings for this type, let's take all + # the embeddings (a list of embeddings) and break the loop + all_embeddings.extend(emb_tuple[1]) + break return all_embeddings, metadata def get_response( - cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False + cohere_client: ClientV2, + texts: List[str], + model_name, + input_type, + truncate, + batch_size=32, + progress_bar=False, + embedding_type: Optional[EmbeddingTypes] = None, ) -> Tuple[List[List[float]], Dict[str, Any]]: """Embeds a list of texts using the Cohere API. @@ -47,6 +75,7 @@ def get_response( :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. :param batch_size: the batch size to use :param progress_bar: if `True`, show a progress bar + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. :returns: A tuple of the embeddings and metadata. @@ -55,6 +84,7 @@ def get_response( all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} + embedding_type = embedding_type or EmbeddingTypes.FLOAT for i in tqdm( range(0, len(texts), batch_size), @@ -62,9 +92,20 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed(texts=batch, model=model_name, input_type=input_type, truncate=truncate) - for emb in response.embeddings: - all_embeddings.append(emb) + response = cohere_client.embed( + texts=batch, + model=model_name, + input_type=input_type, + truncate=truncate, + embedding_types=[embedding_type.value], + ) + ## response.embeddings always returns 5 tuples, one tuple per embedding type + ## let's take first non None tuple as that's the one we want + for emb_tuple in response.embeddings: + # emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.) + if emb_tuple[1] is not None: + # ok we have embeddings for this type, let's take all the embeddings (a list of embeddings) + all_embeddings.extend(emb_tuple[1]) if response.meta is not None: metadata = response.meta diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py index 7da823bbc..2c3060cb9 100644 --- a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -40,6 +40,7 @@ def __init__( max_chunks_per_doc: Optional[int] = None, meta_fields_to_embed: Optional[List[str]] = None, meta_data_separator: str = "\n", + max_tokens_per_doc: int = 4096, ): """ Creates an instance of the 'CohereRanker'. @@ -57,6 +58,7 @@ def __init__( with the document content for reranking. :param meta_data_separator: Separator used to concatenate the meta fields to the Document content. + :param max_tokens_per_doc: The maximum number of tokens to embed for each document defaults to 4096. """ self.model_name = model self.api_key = api_key @@ -65,7 +67,18 @@ def __init__( self.max_chunks_per_doc = max_chunks_per_doc self.meta_fields_to_embed = meta_fields_to_embed or [] self.meta_data_separator = meta_data_separator - self._cohere_client = cohere.Client( + self.max_tokens_per_doc = max_tokens_per_doc + if max_chunks_per_doc is not None: + # Note: max_chunks_per_doc is currently not supported by the Cohere V2 API + # See: https://docs.cohere.com/reference/rerank + import warnings + + warnings.warn( + "The max_chunks_per_doc parameter currently has no effect as it is not supported by the Cohere V2 API.", + UserWarning, + stacklevel=2, + ) + self._cohere_client = cohere.ClientV2( api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) @@ -85,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]: max_chunks_per_doc=self.max_chunks_per_doc, meta_fields_to_embed=self.meta_fields_to_embed, meta_data_separator=self.meta_data_separator, + max_tokens_per_doc=self.max_tokens_per_doc, ) @classmethod @@ -152,7 +166,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None model=self.model_name, query=query, documents=cohere_input_docs, - max_chunks_per_doc=self.max_chunks_per_doc, + max_tokens_per_doc=self.max_tokens_per_doc, top_n=top_k, ) indices = [output.index for output in response.results] diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py index ff861b39d..34a9d1456 100644 --- a/integrations/cohere/tests/test_cohere_ranker.py +++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -20,7 +20,7 @@ def mock_ranker_response(): RerankResult, RerankResult] """ - with patch("cohere.Client.rerank", autospec=True) as mock_ranker_response: + with patch("cohere.ClientV2.rerank", autospec=True) as mock_ranker_response: mock_response = Mock() @@ -48,6 +48,7 @@ def test_init_default(self, monkeypatch): assert component.max_chunks_per_doc is None assert component.meta_fields_to_embed == [] assert component.meta_data_separator == "\n" + assert component.max_tokens_per_doc == 4096 def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("CO_API_KEY", raising=False) @@ -65,6 +66,7 @@ def test_init_with_parameters(self, monkeypatch): max_chunks_per_doc=40, meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=",", + max_tokens_per_doc=100, ) assert component.model_name == "rerank-multilingual-v2.0" assert component.top_k == 5 @@ -73,6 +75,7 @@ def test_init_with_parameters(self, monkeypatch): assert component.max_chunks_per_doc == 40 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] assert component.meta_data_separator == "," + assert component.max_tokens_per_doc == 100 def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") @@ -88,6 +91,7 @@ def test_to_dict_default(self, monkeypatch): "max_chunks_per_doc": None, "meta_fields_to_embed": [], "meta_data_separator": "\n", + "max_tokens_per_doc": 4096, }, } @@ -101,6 +105,7 @@ def test_to_dict_with_parameters(self, monkeypatch): max_chunks_per_doc=50, meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=",", + max_tokens_per_doc=100, ) data = component.to_dict() assert data == { @@ -113,6 +118,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "max_chunks_per_doc": 50, "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], "meta_data_separator": ",", + "max_tokens_per_doc": 100, }, } @@ -128,6 +134,7 @@ def test_from_dict(self, monkeypatch): "max_chunks_per_doc": 50, "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], "meta_data_separator": ",", + "max_tokens_per_doc": 100, }, } component = CohereRanker.from_dict(data) @@ -138,6 +145,7 @@ def test_from_dict(self, monkeypatch): assert component.max_chunks_per_doc == 50 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] assert component.meta_data_separator == "," + assert component.max_tokens_per_doc == 100 def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("CO_API_KEY", raising=False) @@ -149,6 +157,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "top_k": 2, "max_chunks_per_doc": 50, + "max_tokens_per_doc": 100, }, } with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index d69e1a5a2..895e27c7d 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -8,6 +8,7 @@ from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes pytestmark = pytest.mark.embedders COHERE_API_URL = "https://api.cohere.com" @@ -27,6 +28,7 @@ def test_init_default(self): assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( @@ -53,6 +55,7 @@ def test_init_with_parameters(self): assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == "-" + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_to_dict(self): embedder_component = CohereDocumentEmbedder() @@ -71,6 +74,7 @@ def test_to_dict(self): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", + "embedding_type": "float", }, } @@ -87,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self): progress_bar=False, meta_fields_to_embed=["text_field"], embedding_separator="-", + embedding_type=EmbeddingTypes.INT8, ) component_dict = embedder_component.to_dict() assert component_dict == { @@ -103,6 +108,7 @@ def test_to_dict_with_custom_init_parameters(self): "progress_bar": False, "meta_fields_to_embed": ["text_field"], "embedding_separator": "-", + "embedding_type": "int8", }, } @@ -112,7 +118,7 @@ def test_to_dict_with_custom_init_parameters(self): ) @pytest.mark.integration def test_run(self): - embedder = CohereDocumentEmbedder() + embedder = CohereDocumentEmbedder(model="embed-english-v2.0", embedding_type=EmbeddingTypes.FLOAT) docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 80f7c1a3e..58fff3900 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -7,6 +7,7 @@ from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereTextEmbedder +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes pytestmark = pytest.mark.embedders COHERE_API_URL = "https://api.cohere.com" @@ -47,6 +48,7 @@ def test_init_with_parameters(self): assert embedder.truncate == "START" assert embedder.use_async_client is True assert embedder.timeout == 60 + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_to_dict(self): """ @@ -64,6 +66,7 @@ def test_to_dict(self): "truncate": "END", "use_async_client": False, "timeout": 120, + "embedding_type": "float", }, } @@ -79,6 +82,7 @@ def test_to_dict_with_custom_init_parameters(self): truncate="START", use_async_client=True, timeout=60, + embedding_type=EmbeddingTypes.INT8, ) component_dict = embedder_component.to_dict() assert component_dict == { @@ -91,6 +95,7 @@ def test_to_dict_with_custom_init_parameters(self): "truncate": "START", "use_async_client": True, "timeout": 60, + "embedding_type": "int8", }, }