diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index 8aaa0222..c682ed3a 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -815,3 +815,33 @@ async def batch_to_batch(
     return output_df.set_index(batch.index)
 
     # TODO: cost estimate
+
+
+def get_model_info(provider: str, model_name: str, auth_info: Optional[dict] = None) -> dict:
+    """Return litellm's model-metadata dict for ``provider``/``model_name``.
+
+    For Azure, the deployment name is not the canonical model name, so a
+    minimal one-token completion is issued first to discover the canonical
+    name reported by the backend.
+
+    Best-effort: on any failure the error is logged and {} is returned.
+    """
+    if auth_info is None:
+        auth_info = {}
+    try:
+        # For azure models, need to get the canonical name for the model;
+        # it is only reported back on a completion response.
+        if provider == "azure":
+            dummy_completion = litellm.completion(
+                model=f"azure/{model_name}",
+                messages=[{"role": "user", "content": ""}],
+                max_tokens=1,
+                **auth_info,
+            )
+            model_name = dummy_completion.model
+        full_name = f"{provider}/{model_name}"
+        return litellm.get_model_info(full_name)
+    except Exception as err:
+        # Deliberate broad catch: metadata is non-critical for callers.
+        logger.error("Hit error when trying to get model metadata: %s", err)
+        return {}
diff --git a/server/app.py b/server/app.py
index 24259aa2..ad21e3d0 100644
--- a/server/app.py
+++ b/server/app.py
@@ -424,6 +424,47 @@ async def improved_prompt(request: ImprovedPromptRequest):
     )
 
 
+class ModelMetadataRequestItem(BaseModel):
+    """A single (provider, model) pair to fetch metadata for."""
+
+    provider: str
+    model_name: str
+    # Forwarded to litellm as keyword arguments (e.g. api_key, api_base).
+    auth_info: Optional[Dict[str, str]] = None
+
+
+class ModelMetadataRequest(BaseModel):
+    """Batch lookup request: one item per model of interest."""
+
+    models: List[ModelMetadataRequestItem]
+
+
+class ModelMetadataResponse(BaseModel):
+    """Maps each requested model_name to its litellm metadata ({} on failure)."""
+
+    model_metadata: Dict[str, Dict]
+
+
+@app.post("/model-metadata", response_model=Response[ModelMetadataResponse])
+def model_metadata(request: ModelMetadataRequest):
+    """Return litellm metadata for every model in the request.
+
+    Deliberately a sync (def) endpoint: get_model_info performs blocking
+    network I/O via litellm.completion, so FastAPI executes this handler
+    in its threadpool instead of stalling the event loop.
+
+    NOTE: items that share a model_name overwrite one another in the map.
+    """
+    from adala.runtimes._litellm import get_model_info
+
+    resp = {
+        "model_metadata": {
+            item.model_name: get_model_info(**item.model_dump())
+            for item in request.models
+        }
+    }
+    return Response[ModelMetadataResponse](success=True, data=resp)
+
+
 if __name__ == "__main__":
     # for debugging
     uvicorn.run("app:app", host="0.0.0.0", port=30001)