Skip to content

Commit

Permalink
feat: support async endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
zac-li committed Jan 16, 2025
1 parent b01a4da commit b76c8e8
Showing 1 changed file with 187 additions and 9 deletions.
196 changes: 187 additions & 9 deletions jina_sagemaker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(
self._aas_client = boto3.client("application-autoscaling", **client_args)
self._cw_client = boto3.client("cloudwatch", **client_args)

self._endpoint_name = None
self._endpoint_config_name = None
self._model_name = None

def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
self._sm_client.describe_endpoint(EndpointName=endpoint_name)
Expand All @@ -66,6 +70,114 @@ def connect_to_endpoint(self, endpoint_name: str, arn: str) -> None:
)
self._arn = arn

def create_async_endpoint(
    self,
    arn: str,
    endpoint_name: str,
    s3_output_path: str,
    instance_type: str,
    n_instances: int = 1,
    recreate: bool = False,
    role: Optional[str] = None,
    success_topic: Optional[str] = None,
    error_topic: Optional[str] = None,
) -> None:
    """
    Creates an asynchronous SageMaker endpoint from a model package ARN.

    The model, the endpoint configuration and the endpoint are all created
    under the same name (``endpoint_name``). On success the client is
    connected to the new endpoint.

    Args:
        arn (str): The model package ARN.
        endpoint_name (str): The name of the endpoint.
        s3_output_path (str): S3 path where the asynchronous inference results will be stored.
        instance_type (str): The instance type for the endpoint (e.g., "ml.m5.xlarge").
        n_instances (int): The number of instances to deploy (default: 1).
        recreate (bool): Whether to recreate the endpoint if it already exists (default: False).
        role (Optional[str]): The IAM role ARN to associate with the model.
            When omitted, the role returned by ``get_role()`` is used.
        success_topic (Optional[str]): SNS topic ARN for successful inference notifications (default: None).
        error_topic (Optional[str]): SNS topic ARN for error notifications (default: None).

    Raises:
        Exception: If the endpoint already exists and ``recreate`` is False.
        ClientError: If removing a stale model / endpoint configuration fails
            with an error code other than ``ValidationException``.
    """
    from botocore.exceptions import ClientError

    def _delete_ignoring_missing(delete_call, **kwargs) -> None:
        # Run a SageMaker delete_* call; a ValidationException (which the
        # original code treats as "resource absent") is ignored.
        try:
            delete_call(**kwargs)
        except ClientError as e:
            if e.response["Error"]["Code"] != "ValidationException":
                raise

    if role is None:
        role = get_role()

    # Settle the "endpoint already exists" case BEFORE creating or deleting
    # anything else:
    #   * with recreate=False nothing must be destroyed before we raise;
    #   * with recreate=True, delete_endpoint() also removes the tracked
    #     model, so it has to run before the new model is created, not after.
    if self._does_endpoint_exist(endpoint_name):
        if not recreate:
            raise Exception(
                f"Endpoint {endpoint_name} already exists and recreate={recreate}."
            )
        self.connect_to_endpoint(endpoint_name, arn)
        self.delete_endpoint()
        # Endpoint deletion is asynchronous; the name cannot be reused until
        # the old endpoint is gone.
        self._sm_client.get_waiter("endpoint_deleted").wait(
            EndpointName=endpoint_name
        )

    # Model, endpoint configuration and endpoint share a single name.
    model_name = endpoint_name
    endpoint_config_name = endpoint_name

    # Delete a stale model with the same name, or it will block create_model.
    _delete_ignoring_missing(self._sm_client.delete_model, ModelName=model_name)
    self._sm_client.create_model(
        ModelName=model_name,
        ExecutionRoleArn=role,
        Containers=[{"ModelPackageName": arn}],
    )
    self._model_name = model_name

    # Delete a stale endpoint configuration with the same name, or it will
    # block create_endpoint_config.
    _delete_ignoring_missing(
        self._sm_client.delete_endpoint_config,
        EndpointConfigName=endpoint_config_name,
    )

    # Build the AsyncInferenceConfig; the notification section is only added
    # when at least one SNS topic was supplied.
    output_config = {"S3OutputPath": s3_output_path}
    notification_config = {}
    if success_topic:
        notification_config["SuccessTopic"] = success_topic
    if error_topic:
        notification_config["ErrorTopic"] = error_topic
    if notification_config:
        output_config["NotificationConfig"] = notification_config

    self._sm_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[
            {
                "VariantName": "AllTraffic",
                "ModelName": model_name,
                "InstanceType": instance_type,
                "InitialInstanceCount": n_instances,
            }
        ],
        AsyncInferenceConfig={"OutputConfig": output_config},
    )
    self._endpoint_config_name = endpoint_config_name

    self._sm_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
    )

    # Connect to the new endpoint. The model package ARN (not the model name)
    # is passed so that ARN-based checks elsewhere in the client keep working.
    self.connect_to_endpoint(endpoint_name, arn)

def create_endpoint(
self,
arn: str,
Expand Down Expand Up @@ -223,6 +335,52 @@ def create_transform_job(
job_name = transformer.latest_transform_job.name
return job_name

def read_async(self, prompt: str, input_s3_path: str):
    """
    Asynchronous version of the read method that uses invoke_endpoint_async.

    The request payload (model name + prompt) is serialized to JSON, uploaded
    to ``input_s3_path`` and the connected endpoint is then invoked with that
    object as its input location.

    Args:
        prompt (str): The input prompt for the model.
        input_s3_path (str): S3 URI of the form ``s3://<bucket>/<key>`` where
            the input data will be uploaded.

    Returns:
        dict: ``{"OutputLocation": ..., "InputLocation": ...}`` where
        ``OutputLocation`` is taken from the invoke_endpoint_async response
        and ``InputLocation`` is ``input_s3_path``.

    Raises:
        Exception: If no endpoint is connected.
        ValueError: If ``input_s3_path`` is not a ``s3://<bucket>/<key>`` URI.
    """
    if self._endpoint_name is None:
        raise Exception("No endpoint connected. Run connect_to_endpoint() first.")

    # Validate and split the URI before doing any work, so a malformed path
    # fails with a clear message and nothing is uploaded. Only the leading
    # scheme is stripped: a blanket str.replace would also rewrite keys that
    # happen to contain "s3://", making the uploaded key differ from the
    # InputLocation sent to the endpoint.
    scheme = "s3://"
    if not input_s3_path.startswith(scheme):
        raise ValueError(
            f"input_s3_path must look like 's3://<bucket>/<key>', got: {input_s3_path!r}"
        )
    bucket_name, _, input_key = input_s3_path[len(scheme):].partition("/")
    if not bucket_name or not input_key:
        raise ValueError(
            f"input_s3_path must look like 's3://<bucket>/<key>', got: {input_s3_path!r}"
        )

    # The model identifier sent in the payload is derived from the ARN.
    if "1500m" in self._arn:
        model = "reader-lm-1.5b"
    elif "v2" in self._arn:
        model = "ReaderLM-v2"
    else:
        model = "reader-lm-0.5b"

    # Prepare the input payload
    data = json.dumps(
        {
            "model": model,
            "prompt": prompt,
        }
    )

    # NOTE(review): this S3 client is built with default settings, not the
    # client arguments used for the other clients in __init__ — confirm that
    # is intended.
    s3 = boto3.client("s3")
    s3.put_object(Bucket=bucket_name, Key=input_key, Body=data)

    # Call the async endpoint
    response = self._sm_runtime_client.invoke_endpoint_async(
        EndpointName=self._endpoint_name,
        InputLocation=input_s3_path,
        ContentType="application/json",
    )

    # Return the response metadata, including the output location
    return {
        "OutputLocation": response["OutputLocation"],
        "InputLocation": input_s3_path,
    }

def read(self, prompt: str, stream: bool = False):
if self._endpoint_name is None:
raise Exception(
Expand Down Expand Up @@ -293,13 +451,13 @@ def embed(
else:
data = {"data": [{"text": text} for text in texts]}

if 'jina-embeddings-v3' in self._arn:
if "jina-embeddings-v3" in self._arn:
data["parameters"] = {
"task": task_type.value if task_type else None,
"dimensions": dimensions,
"late_chunking": late_chunking,
}
elif 'jina-clip-v2' in self._arn:
elif "jina-clip-v2" in self._arn:
data["parameters"] = {
"task": task_type.value if task_type else None,
"dimensions": dimensions,
Expand Down Expand Up @@ -375,19 +533,39 @@ def rerank(self, documents: List[str], query: str, top_n: Optional[int] = None):
return resp["data"]

def delete_endpoint(self) -> None:
    """
    Deletes the endpoint, its configuration, and the associated model if their names are set.

    Each resource is deleted exactly once. Deletion is best-effort: a
    ``ClientError`` raised by any individual delete call is reported on
    stdout as "not found" and the remaining resources are still processed.

    The endpoint configuration is looked up under the tracked configuration
    name; when none is tracked, the endpoint name is used instead. The model
    is only deleted when a model name is tracked.

    Raises:
        Exception: If no endpoint is connected.
    """
    if self._endpoint_name is None:
        raise Exception("No endpoint connected.")

    # Delete the endpoint
    try:
        self._sm_client.delete_endpoint(EndpointName=self._endpoint_name)
        print(f"Deleted endpoint: {self._endpoint_name}")
    except ClientError:
        print(f"Endpoint '{self._endpoint_name}' not found, skipping deletion.")

    # Delete the endpoint configuration (once). Fall back to the endpoint name
    # when no configuration name was recorded, so a configuration named after
    # the endpoint is still cleaned up.
    config_name = self._endpoint_config_name or self._endpoint_name
    try:
        self._sm_client.delete_endpoint_config(EndpointConfigName=config_name)
        print(f"Deleted endpoint configuration: {config_name}")
    except ClientError:
        print(
            f"Endpoint configuration '{config_name}' not found, skipping deletion."
        )

    # Delete the model
    if self._model_name is not None:
        try:
            self._sm_client.delete_model(ModelName=self._model_name)
            print(f"Deleted model: {self._model_name}")
        except ClientError:
            print(f"Model '{self._model_name}' not found, skipping deletion.")

def close(self) -> None:
try:
Expand Down

0 comments on commit b76c8e8

Please sign in to comment.