feat: DIA-1715: VertexAI Gemini model support (#298)
Co-authored-by: hakan458 <[email protected]>
Co-authored-by: matt-bernstein <[email protected]>
Co-authored-by: niklub <[email protected]>
Co-authored-by: Matt Bernstein <[email protected]>
5 people authored Jan 17, 2025
1 parent 92c3169 commit 2d3ef74
Showing 3 changed files with 194 additions and 29 deletions.
78 changes: 56 additions & 22 deletions adala/runtimes/_litellm.py
@@ -153,20 +153,35 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
return data


class InstructorClientMixin:
def normalize_litellm_model_and_provider(model_name: str, provider: str):
"""
When using litellm.get_model_info() some models are accessed with their provider prefix
while others are not.
This helper function contains logic which normalizes this for supported providers
"""
if "/" in model_name:
model_name = model_name.split('/', 1)[1]
provider = provider.lower()
if provider == "vertexai":
provider = "vertex_ai"

return model_name, provider


class InstructorClientMixin(BaseModel):

# Note: most models work better with json mode; this is set only for backwards compatibility
# instructor_mode: str = "json_mode"
instructor_mode: str = "tool_call"

# Note: a separate function shouldn't be necessary here, but it errors when combined with @cached_property
def _from_litellm(self, **kwargs):
return instructor.from_litellm(litellm.completion, **kwargs)

@cached_property
def client(self):
kwargs = {}
if self.is_custom_openai_endpoint:
kwargs["mode"] = instructor.Mode.JSON
return self._from_litellm(**kwargs)

@property
def is_custom_openai_endpoint(self) -> bool:
return self.model.startswith("openai/") and self.model_extra.get("base_url")
return self._from_litellm(mode=instructor.Mode(self.instructor_mode))


class InstructorAsyncClientMixin(InstructorClientMixin):
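
For orientation, here is a minimal usage sketch of the new normalize_litellm_model_and_provider helper; it simply traces the logic above, and the model names are illustrative.

from adala.runtimes._litellm import normalize_litellm_model_and_provider

# The provider prefix is stripped from the model name, and "vertexai" is mapped
# to litellm's "vertex_ai" provider key.
assert normalize_litellm_model_and_provider("vertex_ai/gemini-1.5-flash", "VertexAI") == (
    "gemini-1.5-flash",
    "vertex_ai",
)
# Model names without a prefix pass through; the provider is only lowercased.
assert normalize_litellm_model_and_provider("gpt-4o-mini", "OpenAI") == ("gpt-4o-mini", "openai")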
@@ -241,7 +256,6 @@ class LiteLLMChatRuntime(InstructorClientMixin, Runtime):
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model: str = "gpt-4o-mini"
@@ -382,7 +396,6 @@ class AsyncLiteLLMChatRuntime(InstructorAsyncClientMixin, AsyncRuntime):
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model: str = "gpt-4o-mini"
@@ -553,9 +566,12 @@ def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
return user_tokens + system_tokens

@staticmethod
def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
def _get_completion_tokens(
model: str, output_fields: Optional[List[str]], provider: str
) -> int:
model, provider = normalize_litellm_model_and_provider(model, provider)
max_tokens = litellm.get_model_info(
model=model, custom_llm_provider="openai"
model=model, custom_llm_provider=provider
).get("max_tokens", None)
if not max_tokens:
raise ValueError
@@ -565,10 +581,14 @@ def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:

@classmethod
def _estimate_cost(
cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
cls,
user_prompt: str,
model: str,
output_fields: Optional[List[str]],
provider: str,
):
prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
completion_tokens = cls._get_completion_tokens(model, output_fields)
completion_tokens = cls._get_completion_tokens(model, output_fields, provider)
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
@@ -579,7 +599,11 @@ def _estimate_cost(
return prompt_cost, completion_cost, total_cost

def get_cost_estimate(
self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
self,
prompt: str,
substitutions: List[Dict],
output_fields: Optional[List[str]],
provider: str,
) -> CostEstimate:
try:
user_prompts = [
@@ -594,6 +618,7 @@ def get_cost_estimate(
user_prompt=user_prompt,
model=self.model,
output_fields=output_fields,
provider=provider,
)
cumulative_prompt_cost += prompt_cost
cumulative_completion_cost += completion_cost
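
A hedged sketch of calling get_cost_estimate with the new provider argument; the runtime constructor arguments, prompt, and substitutions are assumptions for illustration, not taken from this commit.

from adala.runtimes._litellm import AsyncLiteLLMChatRuntime

# Illustrative values only; a real call would also carry Vertex credentials.
runtime = AsyncLiteLLMChatRuntime(model="vertex_ai/gemini-1.5-flash")
estimate = runtime.get_cost_estimate(
    prompt="Label the sentiment of: {text}",
    substitutions=[{"text": "great product"}, {"text": "slow support"}],
    output_fields=["sentiment"],
    provider="VertexAI",  # normalized internally before the litellm model-info lookup
)
print(estimate)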
@@ -729,8 +754,12 @@ class AsyncLiteLLMVisionRuntime(AsyncLiteLLMChatRuntime):

def init_runtime(self) -> "Runtime":
super().init_runtime()
if not litellm.supports_vision(self.model):
raise ValueError(f"Model {self.model} does not support vision")
# Only run this supports_vision check for non-Vertex models: it relies on a static JSON file in
# litellm that was not up to date. It should be fixed in the next litellm release - revisit this then.
if not self.model.startswith("vertex_ai"):
model_name = self.model
if not litellm.supports_vision(model_name):
raise ValueError(f"Model {self.model} does not support vision")
return self

async def batch_to_batch(
@@ -816,7 +845,10 @@ async def batch_to_batch(

# TODO: cost estimate

def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:

def get_model_info(
provider: str, model_name: str, auth_info: Optional[dict] = None
) -> dict:
if auth_info is None:
auth_info = {}
try:
@@ -826,11 +858,13 @@ def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:
model=f"azure/{model_name}",
messages=[{"role": "user", "content": ""}],
max_tokens=1,
**auth_info
**auth_info,
)
model_name = dummy_completion.model
full_name = f"{provider}/{model_name}"
return litellm.get_model_info(full_name)
model_name, provider = normalize_litellm_model_and_provider(
model_name, provider
)
return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
except Exception as err:
logger.error("Hit error when trying to get model metadata: %s", err)
return {}
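
And a brief sketch of the refactored metadata helper; the model name is illustrative, and the returned keys (max_tokens, per-token costs) come from litellm's model map.

from adala.runtimes._litellm import get_model_info

# Non-Azure providers resolve metadata from litellm's static model map, so this
# particular call needs no credentials.
info = get_model_info(provider="VertexAI", model_name="vertex_ai/gemini-1.5-flash")
print(info.get("max_tokens"), info.get("input_cost_per_token"))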
144 changes: 137 additions & 7 deletions server/app.py
@@ -16,6 +16,8 @@
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from litellm.exceptions import AuthenticationError
from litellm.utils import check_valid_key, get_valid_models
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
@@ -37,7 +39,6 @@

logger = init_logger(__name__)


settings = Settings()

app = fastapi.FastAPI()
@@ -83,10 +84,36 @@ class BatchSubmitted(BaseModel):
job_id: str


class ModelsListRequest(BaseModel):
provider: str


class ModelsListResponse(BaseModel):
models_list: List[str]


class CostEstimateRequest(BaseModel):
agent: Agent
prompt: str
substitutions: List[Dict]
provider: str


class ValidateConnectionRequest(BaseModel):
provider: str
api_key: Optional[str] = None
vertex_credentials: Optional[str] = None
vertex_location: Optional[str] = None
vertex_project: Optional[str] = None
api_version: Optional[str] = None
deployment_name: Optional[str] = None
endpoint: Optional[str] = None
auth_token: Optional[str] = None


class ValidateConnectionResponse(BaseModel):
model: str
success: bool


class Status(Enum):
@@ -216,6 +243,99 @@ async def submit_batch(batch: BatchData):
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


@app.post("/validate-connection", response_model=Response[ValidateConnectionResponse])
async def validate_connection(request: ValidateConnectionRequest):
multi_model_provider_test_models = {
"openai": "gpt-4o-mini",
"vertexai": "vertex_ai/gemini-1.5-flash",
}
provider = request.provider.lower()
messages = [{"role": "user", "content": "Hey, how's it going?"}]

# For multi-model providers, use a model that every account should have access to
if provider in multi_model_provider_test_models.keys():
model = multi_model_provider_test_models[provider]
if provider == "openai":
model_extra = {"api_key": request.api_key}
elif provider == "vertexai":
model_extra = {"vertex_credentials": request.vertex_credentials}
if request.vertex_location:
model_extra["vertex_location"] = request.vertex_location
if request.vertex_project:
model_extra["vertex_project"] = request.vertex_project
try:
response = litellm.completion(
messages=messages,
model=model,
max_tokens=10,
temperature=0.0,
**model_extra,
)
except AuthenticationError:
raise HTTPException(
status_code=400,
detail=f"Requested model '{model}' is not available with your api_key / credentials",
)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Error validating credentials for provider {provider}: {e}",
)

# For single-model connections use the provided model
else:
if provider.lower() == "azureopenai":
model = "azure/" + request.deployment_name
model_extra = {"base_url": request.endpoint}
elif provider.lower() == "custom":
model = "openai/" + request.deployment_name
model_extra = (
{"extra_headers": {"Authorization": request.auth_token}}
if request.auth_token
else {}
)
model_extra["api_key"] = request.api_key
try:
response = litellm.completion(
messages=messages,
model=model,
max_tokens=10,
temperature=0.0,
**model_extra,
)
except AuthenticationError:
raise HTTPException(
status_code=400,
detail=f"Requested model '{model}' is not available with your api_key and settings.",
)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to check availability of requested model '{model}': {e}",
)

return Response[ValidateConnectionResponse](
data=ValidateConnectionResponse(success=True, model=response.model)
)
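
A hedged way to exercise the new endpoint with FastAPI's test client; the import path assumes server/app.py is importable as server.app, and every credential value below is a placeholder.

from fastapi.testclient import TestClient

from server.app import app  # assumed import path for server/app.py

client = TestClient(app)
resp = client.post(
    "/validate-connection",
    json={
        "provider": "VertexAI",
        "vertex_credentials": "<service-account-json-string>",  # placeholder
        "vertex_project": "my-gcp-project",                      # placeholder
        "vertex_location": "us-central1",
    },
)
print(resp.status_code, resp.json())  # expect 200 and the resolved model name for valid credentials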


@app.post("/models-list", response_model=Response[ModelsListResponse])
async def models_list(request: ModelsListRequest):
# get_valid_models uses the API key set in the env; however, the list is not dynamically retrieved
# https://docs.litellm.ai/docs/set_keys#get_valid_models
# https://github.com/BerriAI/litellm/blob/b9280528d368aced49cb4d287c57cd0b46168cb6/litellm/utils.py#L5705
# Ultimately it just uses litellm.models_by_provider, so setting an API key is not needed
lse_provider_to_litellm_provider = {"openai": "openai", "vertexai": "vertex_ai"}
provider = request.provider.lower()
valid_models = litellm.models_by_provider[
lse_provider_to_litellm_provider[provider]
]

return Response[ModelsListResponse](
data=ModelsListResponse(models_list=valid_models)
)
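
The listing endpoint can be checked the same way; only the provider field is required, and "vertexai" maps to litellm's vertex_ai catalogue.

from fastapi.testclient import TestClient

from server.app import app  # assumed import path

client = TestClient(app)
resp = client.post("/models-list", json={"provider": "VertexAI"})
print(resp.json())  # litellm's known vertex_ai model names via litellm.models_by_provider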


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
request: CostEstimateRequest,
@@ -238,6 +358,7 @@ async def estimate_cost(
prompt = request.prompt
substitutions = request.substitutions
agent = request.agent
provider = request.provider
runtime = agent.get_runtime()

try:
@@ -247,7 +368,10 @@
list(skill.field_schema.keys()) if skill.field_schema else None
)
cost_estimate = runtime.get_cost_estimate(
prompt=prompt, substitutions=substitutions, output_fields=output_fields
prompt=prompt,
substitutions=substitutions,
output_fields=output_fields,
provider=provider,
)
cost_estimates.append(cost_estimate)
total_cost_estimate = sum(
@@ -429,21 +553,27 @@ class ModelMetadataRequestItem(BaseModel):
model_name: str
auth_info: Optional[Dict[str, str]] = None


class ModelMetadataRequest(BaseModel):
models: List[ModelMetadataRequestItem]


class ModelMetadataResponse(BaseModel):
model_metadata: Dict[str, Dict]


@app.post("/model-metadata", response_model=Response[ModelMetadataResponse])
async def model_metadata(request: ModelMetadataRequest):
from adala.runtimes._litellm import get_model_info

resp = {'model_metadata': {item.model_name: get_model_info(**item.model_dump()) for item in request.models}}
return Response[ModelMetadataResponse](
success=True,
data=resp
)
resp = {
"model_metadata": {
item.model_name: get_model_info(**item.model_dump())
for item in request.models
}
}
return Response[ModelMetadataResponse](success=True, data=resp)
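
Finally, a hedged sketch of the metadata endpoint; the request body follows ModelMetadataRequest above, and the model name is illustrative.

from fastapi.testclient import TestClient

from server.app import app  # assumed import path

client = TestClient(app)
resp = client.post(
    "/model-metadata",
    json={"models": [{"provider": "VertexAI", "model_name": "vertex_ai/gemini-1.5-flash"}]},
)
print(resp.json())  # roughly {"success": true, "data": {"model_metadata": {...}}}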


if __name__ == "__main__":
# for debugging
1 change: 1 addition & 0 deletions tests/test_serialization.py
@@ -101,6 +101,7 @@ def test_agent_is_serializable():
"verbose": False,
"batch_size": 100,
"concurrency": 1,
"instructor_mode": "tool_call",
"model": "gpt-4o-mini",
"max_tokens": 200,
"temperature": 0.0,