Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Migrate Cohere to V2 #1321

Merged
merged 8 commits into from
Jan 28, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

from cohere import AsyncClient, Client
from cohere import AsyncClientV2, ClientV2
from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes
from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response


Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
embedding_type: Optional[EmbeddingTypes] = None,
):
"""
:param api_key: the Cohere API key.
Expand All @@ -72,6 +74,8 @@ def __init__(
to keep the logs clean.
:param meta_fields_to_embed: list of meta fields that should be embedded along with the Document text.
:param embedding_separator: separator used to concatenate the meta fields to the Document text.
:param embedding_type: the type of embeddings to return. Defaults to float embeddings.
Note that int8, uint8, binary, and ubinary are only valid for v3 models.
"""

self.api_key = api_key
Expand All @@ -85,6 +89,7 @@ def __init__(
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.embedding_type = embedding_type or EmbeddingTypes.FLOAT

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -106,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
embedding_type=self.embedding_type.value,
)

@classmethod
Expand All @@ -120,6 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder":
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])

# Convert embedding_type string to EmbeddingTypes enum value
init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"])

return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down Expand Up @@ -163,17 +173,19 @@ def run(self, documents: List[Document]):
assert api_key is not None

if self.use_async_client:
cohere_client = AsyncClient(
cohere_client = AsyncClientV2(
api_key,
base_url=self.api_base_url,
timeout=self.timeout,
client_name="haystack",
)
all_embeddings, metadata = asyncio.run(
get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate)
get_async_response(
cohere_client, texts_to_embed, self.model, self.input_type, self.truncate, self.embedding_type
)
)
else:
cohere_client = Client(
cohere_client = ClientV2(
api_key,
base_url=self.api_base_url,
timeout=self.timeout,
Expand All @@ -187,6 +199,7 @@ def run(self, documents: List[Document]):
self.truncate,
self.batch_size,
self.progress_bar,
self.embedding_type,
)

for doc, embeddings in zip(documents, all_embeddings):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from enum import Enum


class EmbeddingTypes(Enum):
    """
    Supported types for Cohere embeddings.

    FLOAT: Default float embeddings. Valid for all models.
    INT8: Signed int8 embeddings. Valid for only v3 models.
    UINT8: Unsigned int8 embeddings. Valid for only v3 models.
    BINARY: Signed binary embeddings. Valid for only v3 models.
    UBINARY: Unsigned binary embeddings. Valid for only v3 models.
    """

    FLOAT = "float"
    INT8 = "int8"
    UINT8 = "uint8"
    BINARY = "binary"
    UBINARY = "ubinary"

    def __str__(self) -> str:
        # Render as the raw API string (e.g. "float"), matching what the
        # Cohere embed endpoint expects in `embedding_types`.
        return self.value

    @staticmethod
    def from_str(string: str) -> "EmbeddingTypes":
        """
        Convert a string to an EmbeddingTypes enum.
        """
        # Case-insensitive lookup over the member values.
        by_value = {member.value: member for member in EmbeddingTypes}
        match = by_value.get(string.lower())
        if match is None:
            msg = f"Unknown embedding type '{string}'. Supported types are: {list(by_value.keys())}"
            raise ValueError(msg)
        return match
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

from cohere import AsyncClient, Client
from cohere import AsyncClientV2, ClientV2
from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes
from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response


Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(
truncate: str = "END",
use_async_client: bool = False,
timeout: int = 120,
embedding_type: Optional[EmbeddingTypes] = None,
):
"""
:param api_key: the Cohere API key.
Expand All @@ -60,6 +62,8 @@ def __init__(
:param use_async_client: flag to select the AsyncClient. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param timeout: request timeout in seconds.
:param embedding_type: the type of embeddings to return. Defaults to float embeddings.
Note that int8, uint8, binary, and ubinary are only valid for v3 models.
"""

self.api_key = api_key
Expand All @@ -69,6 +73,7 @@ def __init__(
self.truncate = truncate
self.use_async_client = use_async_client
self.timeout = timeout
self.embedding_type = embedding_type or EmbeddingTypes.FLOAT

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -86,6 +91,7 @@ def to_dict(self) -> Dict[str, Any]:
truncate=self.truncate,
use_async_client=self.use_async_client,
timeout=self.timeout,
embedding_type=self.embedding_type.value,
)

@classmethod
Expand All @@ -100,6 +106,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder":
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])

# Convert embedding_type string to EmbeddingTypes enum value
init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"])

return default_from_dict(cls, data)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
Expand All @@ -125,22 +135,26 @@ def run(self, text: str):
assert api_key is not None

if self.use_async_client:
cohere_client = AsyncClient(
cohere_client = AsyncClientV2(
api_key,
base_url=self.api_base_url,
timeout=self.timeout,
client_name="haystack",
)
embedding, metadata = asyncio.run(
get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate)
get_async_response(
cohere_client, [text], self.model, self.input_type, self.truncate, self.embedding_type
)
)
else:
cohere_client = Client(
cohere_client = ClientV2(
api_key,
base_url=self.api_base_url,
timeout=self.timeout,
client_name="haystack",
)
embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate)
embedding, metadata = get_response(
cohere_client, [text], self.model, self.input_type, self.truncate, embedding_type=self.embedding_type
)

return {"embedding": embedding[0], "meta": metadata}
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from cohere import AsyncClient, Client
from cohere import AsyncClientV2, ClientV2
from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes


async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate):
async def get_async_response(
    cohere_async_client: AsyncClientV2,
    texts: List[str],
    model_name,
    input_type,
    truncate,
    embedding_type: Optional[EmbeddingTypes] = None,
):
    """Embeds a list of texts asynchronously using the Cohere API.

    :param cohere_async_client: the Cohere `AsyncClientV2`
    :param texts: the texts to embed
    :param model_name: the name of the Cohere model to use
    :param input_type: one of "classification", "clustering", "search_document", "search_query".
        The type of input text provided to embed.
    :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length.
    :param embedding_type: the type of embeddings to return. Defaults to float embeddings.

    :returns: A tuple of the embeddings and metadata.
    """
    all_embeddings: List[List[float]] = []
    metadata: Dict[str, Any] = {}

    embedding_type = embedding_type or EmbeddingTypes.FLOAT
    response = await cohere_async_client.embed(
        texts=texts,
        model=model_name,
        input_type=input_type,
        truncate=truncate,
        embedding_types=[embedding_type.value],
    )
    if response.meta is not None:
        metadata = response.meta
    # The V2 response carries one (type, values) pair per possible embedding
    # type; only the single type we requested is expected to be non-None.
    for emb_tuple in response.embeddings:
        # emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.)
        if emb_tuple[1] is not None:
            # Take all the embeddings (a list of embeddings) for the
            # requested type and stop looking at the remaining pairs.
            all_embeddings.extend(emb_tuple[1])
            break

    return all_embeddings, metadata


def get_response(
cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False
cohere_client: ClientV2,
texts: List[str],
model_name,
input_type,
truncate,
batch_size=32,
progress_bar=False,
embedding_type: Optional[EmbeddingTypes] = None,
) -> Tuple[List[List[float]], Dict[str, Any]]:
"""Embeds a list of texts using the Cohere API.

Expand All @@ -47,6 +75,7 @@ def get_response(
:param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length.
:param batch_size: the batch size to use
:param progress_bar: if `True`, show a progress bar
:param embedding_type: the type of embeddings to return. Defaults to float embeddings.

:returns: A tuple of the embeddings and metadata.

Expand All @@ -55,16 +84,28 @@ def get_response(

all_embeddings: List[List[float]] = []
metadata: Dict[str, Any] = {}
embedding_type = embedding_type or EmbeddingTypes.FLOAT

for i in tqdm(
range(0, len(texts), batch_size),
disable=not progress_bar,
desc="Calculating embeddings",
):
batch = texts[i : i + batch_size]
response = cohere_client.embed(texts=batch, model=model_name, input_type=input_type, truncate=truncate)
for emb in response.embeddings:
all_embeddings.append(emb)
response = cohere_client.embed(
texts=batch,
model=model_name,
input_type=input_type,
truncate=truncate,
embedding_types=[embedding_type.value],
)
## response.embeddings always returns 5 tuples, one tuple per embedding type
## let's take first non None tuple as that's the one we want
for emb_tuple in response.embeddings:
# emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.)
if emb_tuple[1] is not None:
# ok we have embeddings for this type, let's take all the embeddings (a list of embeddings)
all_embeddings.extend(emb_tuple[1])
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
if response.meta is not None:
metadata = response.meta

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
max_chunks_per_doc: Optional[int] = None,
meta_fields_to_embed: Optional[List[str]] = None,
meta_data_separator: str = "\n",
max_tokens_per_doc: int = 4096,
):
"""
Creates an instance of the 'CohereRanker'.
Expand All @@ -57,6 +58,7 @@ def __init__(
with the document content for reranking.
:param meta_data_separator: Separator used to concatenate the meta fields
to the Document content.
:param max_tokens_per_doc: The maximum number of tokens to embed for each document defaults to 4096.
"""
self.model_name = model
self.api_key = api_key
Expand All @@ -65,7 +67,18 @@ def __init__(
self.max_chunks_per_doc = max_chunks_per_doc
self.meta_fields_to_embed = meta_fields_to_embed or []
self.meta_data_separator = meta_data_separator
self._cohere_client = cohere.Client(
self.max_tokens_per_doc = max_tokens_per_doc
if max_chunks_per_doc is not None:
# Note: max_chunks_per_doc is currently not supported by the Cohere V2 API
# See: https://docs.cohere.com/reference/rerank
import warnings

warnings.warn(
"The max_chunks_per_doc parameter currently has no effect as it is not supported by the Cohere V2 API.",
UserWarning,
stacklevel=2,
)
self._cohere_client = cohere.ClientV2(
api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack"
)

Expand All @@ -85,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
max_chunks_per_doc=self.max_chunks_per_doc,
meta_fields_to_embed=self.meta_fields_to_embed,
meta_data_separator=self.meta_data_separator,
max_tokens_per_doc=self.max_tokens_per_doc,
)

@classmethod
Expand Down Expand Up @@ -152,7 +166,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
model=self.model_name,
query=query,
documents=cohere_input_docs,
max_chunks_per_doc=self.max_chunks_per_doc,
max_tokens_per_doc=self.max_tokens_per_doc,
top_n=top_k,
)
indices = [output.index for output in response.results]
Expand Down
Loading
Loading