-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
198 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from __future__ import annotations | ||
|
||
from typing import List, Dict, Any | ||
|
||
from together.abstract import api_requestor | ||
from together.together_response import TogetherResponse | ||
from together.types import ( | ||
RerankRequest, | ||
RerankResponse, | ||
TogetherClient, | ||
TogetherRequest, | ||
) | ||
|
||
|
||
class Rerank: | ||
def __init__(self, client: TogetherClient) -> None: | ||
self._client = client | ||
|
||
def create( | ||
self, | ||
*, | ||
model: str, | ||
query: str, | ||
documents: List[str] | List[Dict[str, Any]], | ||
top_n: int | None = None, | ||
return_documents: bool = False, | ||
rank_fields: List[str] | None = None, | ||
) -> RerankResponse: | ||
""" | ||
Method to generate completions based on a given prompt using a specified model. | ||
Args: | ||
model (str): The name of the model to query. | ||
query (str): The input query or list of queries to rerank. | ||
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked. | ||
top_n (int | None): Number of top results to return. | ||
return_documents (bool): Flag to indicate whether to return documents. | ||
rank_fields (List[str] | None): Fields to be used for ranking the documents. | ||
Returns: | ||
RerankResponse: Object containing reranked scores and documents | ||
""" | ||
|
||
requestor = api_requestor.APIRequestor( | ||
client=self._client, | ||
) | ||
|
||
parameter_payload = RerankRequest( | ||
model=model, | ||
query=query, | ||
documents=documents, | ||
top_n=top_n, | ||
return_documents=return_documents, | ||
rank_fields=rank_fields, | ||
).model_dump(exclude_none=True) | ||
|
||
response, _, _ = requestor.request( | ||
options=TogetherRequest( | ||
method="POST", | ||
url="rerank", | ||
params=parameter_payload, | ||
), | ||
stream=False, | ||
) | ||
|
||
assert isinstance(response, TogetherResponse) | ||
|
||
return RerankResponse(**response.data) | ||
|
||
|
||
class AsyncRerank: | ||
def __init__(self, client: TogetherClient) -> None: | ||
self._client = client | ||
|
||
async def create( | ||
self, | ||
*, | ||
model: str, | ||
query: str, | ||
documents: List[str] | List[Dict[str, Any]], | ||
top_n: int | None = None, | ||
return_documents: bool = False, | ||
rank_fields: List[str] | None = None, | ||
) -> RerankResponse: | ||
""" | ||
Async method to generate completions based on a given prompt using a specified model. | ||
Args: | ||
model (str): The name of the model to query. | ||
query (str): The input query or list of queries to rerank. | ||
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked. | ||
top_n (int | None): Number of top results to return. | ||
return_documents (bool): Flag to indicate whether to return documents. | ||
rank_fields (List[str] | None): Fields to be used for ranking the documents. | ||
Returns: | ||
RerankResponse: Object containing reranked scores and documents | ||
""" | ||
|
||
requestor = api_requestor.APIRequestor( | ||
client=self._client, | ||
) | ||
|
||
parameter_payload = RerankRequest( | ||
model=model, | ||
query=query, | ||
documents=documents, | ||
top_n=top_n, | ||
return_documents=return_documents, | ||
rank_fields=rank_fields, | ||
).model_dump(exclude_none=True) | ||
|
||
response, _, _ = await requestor.arequest( | ||
options=TogetherRequest( | ||
method="POST", | ||
url="rerank", | ||
params=parameter_payload, | ||
), | ||
stream=False, | ||
) | ||
|
||
assert isinstance(response, TogetherResponse) | ||
|
||
return RerankResponse(**response.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
from typing import List, Literal, Dict, Any | ||
|
||
from together.types.abstract import BaseModel | ||
from together.types.common import UsageData | ||
|
||
|
||
class RerankRequest(BaseModel): | ||
# model to query | ||
model: str | ||
# input or list of inputs | ||
query: str | ||
# list of documents | ||
documents: List[str] | List[Dict[str, Any]] | ||
# return top_n results | ||
top_n: int | None = None | ||
# boolean to return documents | ||
return_documents: bool = False | ||
# field selector for documents | ||
rank_fields: List[str] | None = None | ||
|
||
|
||
class RerankChoicesData(BaseModel): | ||
# response index | ||
index: int | ||
# object type | ||
relevance_score: float | ||
# rerank response | ||
document: Dict[str, Any] | None = None | ||
|
||
|
||
class RerankResponse(BaseModel): | ||
# job id | ||
id: str | None = None | ||
# object type | ||
object: Literal["rerank"] | None = None | ||
# query model | ||
model: str | None = None | ||
# list of reranked results | ||
results: List[RerankChoicesData] | None = None | ||
# usage stats | ||
usage: UsageData | None = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters