diff --git a/backend/app/api/admin_routes/embedding_model/models.py b/backend/app/api/admin_routes/embedding_model/models.py index 19c5f038e..7968251d4 100644 --- a/backend/app/api/admin_routes/embedding_model/models.py +++ b/backend/app/api/admin_routes/embedding_model/models.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_validator from typing_extensions import Optional -from app.types import EmbeddingProvider +from app.rag.embeddings.provider import EmbeddingProvider class EmbeddingModelCreate(BaseModel): diff --git a/backend/app/api/admin_routes/embedding_model/routes.py b/backend/app/api/admin_routes/embedding_model/routes.py index 87a4bbbdf..9f8c383fd 100644 --- a/backend/app/api/admin_routes/embedding_model/routes.py +++ b/backend/app/api/admin_routes/embedding_model/routes.py @@ -13,19 +13,29 @@ ) from app.api.deps import CurrentSuperuserDep, SessionDep from app.exceptions import EmbeddingModelNotFound, InternalServerError -from app.rag.chat_config import get_embed_model -from app.rag.embed_model_option import EmbeddingModelOption, admin_embed_model_options from app.repositories.embedding_model import embed_model_repo +from app.rag.embeddings import ( + get_embed_model, + EmbeddingProviderOption, + embedding_provider_options, +) router = APIRouter() logger = logging.getLogger(__name__) -@router.get("/admin/embedding-models/options") +@router.get("/admin/embedding-models/provider/options") +def list_embedding_model_provider_options( + user: CurrentSuperuserDep, +) -> List[EmbeddingProviderOption]: + return embedding_provider_options + + +@router.get("/admin/embedding-models/options", deprecated=True) def get_embedding_model_options( user: CurrentSuperuserDep, -) -> List[EmbeddingModelOption]: - return admin_embed_model_options +) -> List[EmbeddingProviderOption]: + return embedding_provider_options @router.post("/admin/embedding-models") diff --git a/backend/app/api/admin_routes/llm/routes.py b/backend/app/api/admin_routes/llm/routes.py index bf5d36cf5..d0bad1dab 100644 --- a/backend/app/api/admin_routes/llm/routes.py +++ b/backend/app/api/admin_routes/llm/routes.py @@ -9,17 +9,25 @@ from app.api.deps import CurrentSuperuserDep, SessionDep from app.exceptions import InternalServerError, LLMNotFound from app.models import AdminLLM, LLM, ChatEngine, KnowledgeBase -from app.rag.chat_config import get_llm -from app.rag.llm_option import LLMOption, admin_llm_options from app.repositories.llm import llm_repo +from app.rag.llms import ( + get_llm, + LLMProviderOption, + llm_provider_options, +) router = APIRouter() logger = logging.getLogger(__name__) -@router.get("/admin/llms/options") -def get_llm_options(user: CurrentSuperuserDep) -> List[LLMOption]: - return admin_llm_options +@router.get("/admin/llms/provider/options") +def list_llm_provider_options(user: CurrentSuperuserDep) -> List[LLMProviderOption]: + return llm_provider_options + + +@router.get("/admin/llms/options", deprecated=True) +def get_llm_options(user: CurrentSuperuserDep) -> List[LLMProviderOption]: + return llm_provider_options @router.get("/admin/llms") diff --git a/backend/app/api/admin_routes/reranker_model/routes.py b/backend/app/api/admin_routes/reranker_model/routes.py index 08f611659..0ca39bc47 100644 --- a/backend/app/api/admin_routes/reranker_model/routes.py +++ b/backend/app/api/admin_routes/reranker_model/routes.py @@ -11,20 +11,29 @@ from app.api.deps import CurrentSuperuserDep, SessionDep from app.exceptions import RerankerModelNotFound, InternalServerError from app.models import RerankerModel, 
AdminRerankerModel, ChatEngine -from app.rag.chat_config import get_reranker_model -from app.rag.reranker_model_option import ( - RerankerModelOption, - admin_reranker_model_options, -) from app.repositories.reranker_model import reranker_model_repo +from app.rag.rerankers import ( + get_reranker_model, + reranker_provider_options, + RerankerProviderOption, +) router = APIRouter() logger = logging.getLogger(__name__) -@router.get("/admin/reranker-models/options") -def get_reranker_model_options(user: CurrentSuperuserDep) -> List[RerankerModelOption]: - return admin_reranker_model_options +@router.get("/admin/reranker-models/provider/options") +def list_reranker_model_provider_options( + user: CurrentSuperuserDep, +) -> List[RerankerProviderOption]: + return reranker_provider_options + + +@router.get("/admin/reranker-models/options", deprecated=True) +def get_reranker_model_options( + user: CurrentSuperuserDep, +) -> List[RerankerProviderOption]: + return reranker_provider_options @router.post("/admin/reranker-models/test") diff --git a/backend/app/models/llm.py b/backend/app/models/llm.py index 360640067..7b7d1188e 100644 --- a/backend/app/models/llm.py +++ b/backend/app/models/llm.py @@ -1,9 +1,8 @@ from typing import Optional, Any - from sqlmodel import Field, Column, JSON, String +from app.rag.llms.provider import LLMProvider from .base import UpdatableBaseModel, AESEncryptedColumn -from app.types import LLMProvider class BaseLLM(UpdatableBaseModel): diff --git a/backend/app/rag/chat_config.py b/backend/app/rag/chat_config.py index 6508347d8..38da1a08e 100644 --- a/backend/app/rag/chat_config.py +++ b/backend/app/rag/chat_config.py @@ -1,43 +1,26 @@ -import os import logging -from typing import Dict, Optional - import dspy -from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS + +from typing import Dict, Optional from pydantic import BaseModel -from llama_index.llms.openai import OpenAI -from llama_index.llms.openai_like import OpenAILike -from llama_index.llms.gemini import Gemini -from llama_index.llms.bedrock import Bedrock -from llama_index.llms.ollama import Ollama -from llama_index.core.llms.llm import LLM -from llama_index.core.base.embeddings.base import BaseEmbedding -from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.embeddings.jinaai import JinaEmbedding -from llama_index.embeddings.cohere import CohereEmbedding -from llama_index.embeddings.bedrock import BedrockEmbedding -from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.postprocessor.jinaai_rerank import JinaRerank -from llama_index.postprocessor.cohere_rerank import CohereRerank -from llama_index.postprocessor.xinference_rerank import XinferenceRerank -from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank from sqlmodel import Session -from google.oauth2 import service_account -from google.auth.transport.requests import Request - -from app.rag.embeddings.openai_like_embedding import OpenAILikeEmbedding -from app.rag.node_postprocessor import MetadataPostFilter -from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters -from app.rag.node_postprocessor.baisheng_reranker import BaishengRerank -from app.rag.node_postprocessor.local_reranker import LocalRerank -from app.rag.node_postprocessor.vllm_reranker import VLLMRerank -from app.rag.embeddings.local_embedding import LocalEmbedding -from app.repositories import chat_engine_repo, knowledge_base_repo -from 
app.repositories.embedding_model import embed_model_repo -from app.repositories.llm import llm_repo -from app.repositories.reranker_model import reranker_model_repo -from app.types import LLMProvider, EmbeddingProvider, RerankerProvider + +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.llms.llm import LLM + +from app.utils.dspy import get_dspy_lm_by_llama_llm +from app.rag.llms import get_default_llm, get_llm +from app.rag.rerankers import get_default_reranker_model, get_reranker_model +from app.rag.postprocessors import get_metadata_post_filter, MetadataFilters + + +from app.models import ( + ChatEngine as DBChatEngine, + LLM as DBLLM, + RerankerModel as DBRerankerModel, + KnowledgeBase, +) +from app.repositories import chat_engine_repo from app.rag.default_prompt import ( DEFAULT_INTENT_GRAPH_KNOWLEDGE, DEFAULT_NORMAL_GRAPH_KNOWLEDGE, @@ -49,15 +32,7 @@ DEFAULT_GENERATE_GOAL_PROMPT, DEFAULT_CLARIFYING_QUESTION_PROMPT, ) -from app.models import ( - ChatEngine as DBChatEngine, - LLM as DBLLM, - RerankerModel as DBRerankerModel, - KnowledgeBase, -) -from app.rag.llms.anthropic_vertex import AnthropicVertex -from app.utils.dspy import get_dspy_lm_by_llama_llm logger = logging.getLogger(__name__) @@ -208,283 +183,3 @@ def screenshot(self) -> dict: "post_verification_token": True, } ) - - -# LLM - - -def get_llm( - provider: LLMProvider, - model: str, - config: dict, - credentials: str | list | dict | None, -) -> LLM: - match provider: - case LLMProvider.OPENAI: - return OpenAI( - model=model, - api_key=credentials, - **config, - ) - case LLMProvider.OPENAI_LIKE: - config.setdefault("context_window", 200 * 1000) - return OpenAILike(model=model, api_key=credentials, **config) - case LLMProvider.GEMINI: - os.environ["GOOGLE_API_KEY"] = credentials - return Gemini(model=model, api_key=credentials, **config) - case LLMProvider.BEDROCK: - access_key_id = credentials["aws_access_key_id"] - secret_access_key = credentials["aws_secret_access_key"] - region_name = credentials["aws_region_name"] - - context_size = None - if model not in BEDROCK_FOUNDATION_LLMS: - context_size = 200000 - - llm = Bedrock( - model=model, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - region_name=region_name, - context_size=context_size, - **config, - ) - # Note: Because llama index Bedrock class doesn't set up these values to the corresponding - # attributes in its constructor function, we pass the values again via setter to pass them to - # `get_dspy_lm_by_llama_llm` function. 
- llm.aws_access_key_id = access_key_id - llm.aws_secret_access_key = secret_access_key - llm.region_name = region_name - return llm - case LLMProvider.ANTHROPIC_VERTEX: - google_creds: service_account.Credentials = ( - service_account.Credentials.from_service_account_info( - credentials, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - ) - google_creds.refresh(request=Request()) - if "max_tokens" not in config: - config.update(max_tokens=4096) - return AnthropicVertex(model=model, credentials=google_creds, **config) - case LLMProvider.OLLAMA: - config.setdefault("request_timeout", 60 * 10) - config.setdefault("context_window", 4096) - return Ollama(model=model, **config) - case LLMProvider.GITEEAI: - config.setdefault("context_window", 200 * 1000) - return OpenAILike( - model=model, - api_base="https://ai.gitee.com/v1", - api_key=credentials, - **config, - ) - case _: - raise ValueError(f"Got unknown LLM provider: {provider}") - - -def get_default_llm(session: Session) -> Optional[LLM]: - db_llm = llm_repo.get_default(session) - if not db_llm: - return None - return get_llm( - db_llm.provider, - db_llm.model, - db_llm.config, - db_llm.credentials, - ) - - -def must_get_default_llm(session: Session) -> LLM: - db_llm = llm_repo.must_get_default(session) - return get_llm( - db_llm.provider, - db_llm.model, - db_llm.config, - db_llm.credentials, - ) - - -# Embedding model - - -def get_embed_model( - provider: EmbeddingProvider, - model: str, - config: dict, - credentials: str | list | dict | None, -) -> BaseEmbedding: - match provider: - case EmbeddingProvider.OPENAI: - return OpenAIEmbedding( - model=model, - api_key=credentials, - **config, - ) - case EmbeddingProvider.JINA: - return JinaEmbedding( - model=model, - api_key=credentials, - **config, - ) - case EmbeddingProvider.COHERE: - return CohereEmbedding( - model_name=model, - cohere_api_key=credentials, - **config, - ) - case EmbeddingProvider.BEDROCK: - return BedrockEmbedding( - model_name=model, - aws_access_key_id=credentials["aws_access_key_id"], - aws_secret_access_key=credentials["aws_secret_access_key"], - region_name=credentials["aws_region_name"], - **config, - ) - case EmbeddingProvider.OLLAMA: - return OllamaEmbedding( - model_name=model, - **config, - ) - case EmbeddingProvider.LOCAL: - return LocalEmbedding( - model=model, - **config, - ) - case EmbeddingProvider.GITEEAI: - return OpenAILikeEmbedding( - model=model, - api_base="https://ai.gitee.com/v1", - api_key=credentials, - **config, - ) - case EmbeddingProvider.OPENAI_LIKE: - return OpenAILikeEmbedding( - model=model, - api_key=credentials, - **config, - ) - case _: - raise ValueError(f"Got unknown embedding provider: {provider}") - - -def get_default_embed_model(session: Session) -> Optional[BaseEmbedding]: - db_embed_model = embed_model_repo.get_default(session) - if not db_embed_model: - return None - return get_embed_model( - db_embed_model.provider, - db_embed_model.model, - db_embed_model.config, - db_embed_model.credentials, - ) - - -def must_get_default_embed_model(session: Session) -> BaseEmbedding: - db_embed_model = embed_model_repo.must_get_default(session) - return get_embed_model( - db_embed_model.provider, - db_embed_model.model, - db_embed_model.config, - db_embed_model.credentials, - ) - - -# Reranker model - - -def get_reranker_model( - provider: RerankerProvider, - model: str, - top_n: int, - config: dict, - credentials: str | list | dict | None, -) -> BaseNodePostprocessor: - match provider: - case RerankerProvider.JINA: - return 
JinaRerank( - model=model, - top_n=top_n, - api_key=credentials, - **config, - ) - case RerankerProvider.COHERE: - return CohereRerank( - model=model, - top_n=top_n, - api_key=credentials, - **config, - ) - case RerankerProvider.BAISHENG: - return BaishengRerank( - model=model, - top_n=top_n, - api_key=credentials, - **config, - ) - case RerankerProvider.LOCAL: - return LocalRerank( - model=model, - top_n=top_n, - **config, - ) - case RerankerProvider.VLLM: - return VLLMRerank( - model=model, - top_n=top_n, - **config, - ) - case RerankerProvider.XINFERENCE: - return XinferenceRerank( - model=model, - top_n=top_n, - **config, - ) - case RerankerProvider.BEDROCK: - return AWSBedrockRerank( - rerank_model_name=model, - top_n=top_n, - aws_access_key_id=credentials["aws_access_key_id"], - aws_secret_access_key=credentials["aws_secret_access_key"], - region_name=credentials["aws_region_name"], - **config, - ) - case _: - raise ValueError(f"Got unknown reranker provider: {provider}") - - -# FIXME: Reranker top_n should be config in the retrival config. -def get_default_reranker_model( - session: Session, top_n: int = None -) -> Optional[BaseNodePostprocessor]: - db_reranker = reranker_model_repo.get_default(session) - if not db_reranker: - return None - top_n = db_reranker.top_n if top_n is None else top_n - return get_reranker_model( - db_reranker.provider, - db_reranker.model, - top_n, - db_reranker.config, - db_reranker.credentials, - ) - - -def must_get_default_reranker_model(session: Session) -> BaseNodePostprocessor: - db_reranker = reranker_model_repo.must_get_default(session) - return get_reranker_model( - db_reranker.provider, - db_reranker.model, - db_reranker.top_n, - db_reranker.config, - db_reranker.credentials, - ) - - -# Metadata post filter - - -def get_metadata_post_filter( - filters: Optional[MetadataFilters] = None, -) -> BaseNodePostprocessor: - return MetadataPostFilter(filters) diff --git a/backend/app/rag/embeddings/__init__.py b/backend/app/rag/embeddings/__init__.py index e69de29bb..a9c7216fa 100644 --- a/backend/app/rag/embeddings/__init__.py +++ b/backend/app/rag/embeddings/__init__.py @@ -0,0 +1,14 @@ +from .provider import EmbeddingProviderOption, embedding_provider_options +from .resolver import ( + get_embed_model, + get_default_embed_model, + must_get_default_embed_model, +) + +__all__ = [ + "get_embed_model", + "get_default_embed_model", + "must_get_default_embed_model", + "EmbeddingProviderOption", + "embedding_provider_options", +] diff --git a/backend/app/rag/embeddings/local_embedding.py b/backend/app/rag/embeddings/local/local_embedding.py similarity index 100% rename from backend/app/rag/embeddings/local_embedding.py rename to backend/app/rag/embeddings/local/local_embedding.py diff --git a/backend/app/rag/embeddings/openai_like_embedding.py b/backend/app/rag/embeddings/open_like/openai_like_embedding.py similarity index 100% rename from backend/app/rag/embeddings/openai_like_embedding.py rename to backend/app/rag/embeddings/open_like/openai_like_embedding.py diff --git a/backend/app/rag/embed_model_option.py b/backend/app/rag/embeddings/provider.py similarity index 91% rename from backend/app/rag/embed_model_option.py rename to backend/app/rag/embeddings/provider.py index f308a58bb..c7e67ce4e 100644 --- a/backend/app/rag/embed_model_option.py +++ b/backend/app/rag/embeddings/provider.py @@ -1,10 +1,21 @@ +import enum + from typing import List from pydantic import BaseModel -from app.types import EmbeddingProvider + +class EmbeddingProvider(str, enum.Enum): 
+ OPENAI = "openai" + JINA = "jina" + COHERE = "cohere" + BEDROCK = "bedrock" + OLLAMA = "ollama" + GITEEAI = "giteeai" + LOCAL = "local" + OPENAI_LIKE = "openai_like" -class EmbeddingModelOption(BaseModel): +class EmbeddingProviderOption(BaseModel): provider: EmbeddingProvider provider_display_name: str | None = None provider_description: str | None = None @@ -19,8 +30,8 @@ class EmbeddingModelOption(BaseModel): credentials_type: str = "str" -admin_embed_model_options: List[EmbeddingModelOption] = [ - EmbeddingModelOption( +embedding_provider_options: List[EmbeddingProviderOption] = [ + EmbeddingProviderOption( provider=EmbeddingProvider.OPENAI, provider_display_name="OpenAI", provider_description="The OpenAI API provides a simple interface for developers to create an intelligence layer in their applications, powered by OpenAI's state of the art models.", @@ -32,7 +43,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="sk-****", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.JINA, provider_display_name="JinaAI", provider_description="Jina AI provides multimodal, bilingual long-context embeddings for search and RAG", @@ -44,7 +55,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="jina_****", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.COHERE, provider_display_name="Cohere", provider_description="Cohere provides industry-leading large language models (LLMs) and RAG capabilities tailored to meet the needs of enterprise use cases that solve real-world problems.", @@ -56,7 +67,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="*****", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.BEDROCK, provider_display_name="Bedrock", provider_description="Amazon Bedrock is a fully managed foundation models service.", @@ -72,7 +83,7 @@ class EmbeddingModelOption(BaseModel): "aws_region_name": "us-west-2", }, ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.OLLAMA, provider_display_name="Ollama", provider_description="Ollama is a lightweight framework for building and running large language models and embed models.", @@ -88,7 +99,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="dummy", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.OPENAI_LIKE, provider_display_name="OpenAI Like", provider_description="OpenAI-Like is a set of platforms that provide text embeddings similar to OpenAI. 
Such as ZhiPuAI.", @@ -100,7 +111,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="dummy", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.GITEEAI, provider_display_name="Gitee AI", provider_description="Gitee AI is a third-party model provider that offers ready-to-use cutting-edge model APIs for AI developers.", @@ -112,7 +123,7 @@ class EmbeddingModelOption(BaseModel): credentials_type="str", default_credentials="****", ), - EmbeddingModelOption( + EmbeddingProviderOption( provider=EmbeddingProvider.LOCAL, provider_display_name="Local Embedding", provider_description="Autoflow's local embedding server, deployed on your own infrastructure and powered by sentence-transformers.", diff --git a/backend/app/rag/embeddings/resolver.py b/backend/app/rag/embeddings/resolver.py new file mode 100644 index 000000000..aad709e8d --- /dev/null +++ b/backend/app/rag/embeddings/resolver.py @@ -0,0 +1,97 @@ +from typing import Optional +from sqlmodel import Session + +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.embeddings.jinaai import JinaEmbedding +from llama_index.embeddings.cohere import CohereEmbedding +from llama_index.embeddings.bedrock import BedrockEmbedding +from llama_index.embeddings.ollama import OllamaEmbedding + +from app.rag.embeddings.open_like.openai_like_embedding import OpenAILikeEmbedding +from app.rag.embeddings.local.local_embedding import LocalEmbedding + +from app.repositories.embedding_model import embed_model_repo +from app.rag.embeddings.provider import EmbeddingProvider + + +def get_embed_model( + provider: EmbeddingProvider, + model: str, + config: dict, + credentials: str | list | dict | None, +) -> BaseEmbedding: + match provider: + case EmbeddingProvider.OPENAI: + return OpenAIEmbedding( + model=model, + api_key=credentials, + **config, + ) + case EmbeddingProvider.JINA: + return JinaEmbedding( + model=model, + api_key=credentials, + **config, + ) + case EmbeddingProvider.COHERE: + return CohereEmbedding( + model_name=model, + cohere_api_key=credentials, + **config, + ) + case EmbeddingProvider.BEDROCK: + return BedrockEmbedding( + model_name=model, + aws_access_key_id=credentials["aws_access_key_id"], + aws_secret_access_key=credentials["aws_secret_access_key"], + region_name=credentials["aws_region_name"], + **config, + ) + case EmbeddingProvider.OLLAMA: + return OllamaEmbedding( + model_name=model, + **config, + ) + case EmbeddingProvider.LOCAL: + return LocalEmbedding( + model=model, + **config, + ) + case EmbeddingProvider.GITEEAI: + return OpenAILikeEmbedding( + model=model, + api_base="https://ai.gitee.com/v1", + api_key=credentials, + **config, + ) + case EmbeddingProvider.OPENAI_LIKE: + return OpenAILikeEmbedding( + model=model, + api_key=credentials, + **config, + ) + case _: + raise ValueError(f"Got unknown embedding provider: {provider}") + + +def get_default_embed_model(session: Session) -> Optional[BaseEmbedding]: + db_embed_model = embed_model_repo.get_default(session) + if not db_embed_model: + return None + return get_embed_model( + db_embed_model.provider, + db_embed_model.model, + db_embed_model.config, + db_embed_model.credentials, + ) + + +def must_get_default_embed_model(session: Session) -> BaseEmbedding: + db_embed_model = embed_model_repo.must_get_default(session) + return get_embed_model( + db_embed_model.provider, + db_embed_model.model, + db_embed_model.config, + 
db_embed_model.credentials, + ) diff --git a/backend/app/rag/knowledge_graph/extractor.py b/backend/app/rag/knowledge_graph/extractor.py index dd207aa5c..3204663fd 100644 --- a/backend/app/rag/knowledge_graph/extractor.py +++ b/backend/app/rag/knowledge_graph/extractor.py @@ -112,9 +112,7 @@ def get_llm_output_config(self): elif "bedrock" in self.dspy_lm.provider.lower(): # Fix: add bedrock branch to fix 'Malformed input request' error # subject must not be valid against schema {"required":["messages"]}: extraneous key [response_mime_type] is not permitted - return { - "max_tokens": 8192 - } + return {"max_tokens": 8192} else: return { "response_mime_type": "application/json", diff --git a/backend/app/rag/llms/__init__.py b/backend/app/rag/llms/__init__.py new file mode 100644 index 000000000..a416384df --- /dev/null +++ b/backend/app/rag/llms/__init__.py @@ -0,0 +1,10 @@ +from .provider import LLMProviderOption, llm_provider_options +from .resolver import get_llm, get_default_llm, must_get_default_llm + +__all__ = [ + "LLMProviderOption", + "llm_provider_options", + "get_llm", + "get_default_llm", + "must_get_default_llm", +] diff --git a/backend/app/rag/llms/anthropic_vertex/__init__.py b/backend/app/rag/llms/anthropic_vertex/__init__.py deleted file mode 100644 index 471b5f6f6..000000000 --- a/backend/app/rag/llms/anthropic_vertex/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import AnthropicVertex - -__all__ = ["AnthropicVertex"] diff --git a/backend/app/rag/llms/anthropic_vertex/base.py b/backend/app/rag/llms/anthropic_vertex/base.py deleted file mode 100644 index 440821f86..000000000 --- a/backend/app/rag/llms/anthropic_vertex/base.py +++ /dev/null @@ -1,490 +0,0 @@ -import anthropic -import json -from anthropic.types import ( - ContentBlockDeltaEvent, - TextBlock, - TextDelta, - ContentBlockStartEvent, - ContentBlockStopEvent, -) -from anthropic.types.tool_use_block import ToolUseBlock -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Union, - TYPE_CHECKING, -) - -from google.oauth2 import service_account -from llama_index.core.base.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) -from llama_index.core.bridge.pydantic import Field, PrivateAttr -from llama_index.core.callbacks import CallbackManager -from llama_index.core.constants import DEFAULT_TEMPERATURE -from llama_index.core.llms.callbacks import ( - llm_chat_callback, - llm_completion_callback, -) -from llama_index.core.base.llms.generic_utils import ( - achat_to_completion_decorator, - astream_chat_to_completion_decorator, - chat_to_completion_decorator, - stream_chat_to_completion_decorator, -) -from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection -from llama_index.core.types import BaseOutputParser, PydanticProgramMode -from llama_index.core.utils import Tokenizer -from llama_index.core.llms.utils import parse_partial_json - -from .utils import ( - anthropic_modelname_to_contextsize, - force_single_tool_call, - is_function_calling_model, - messages_to_anthropic_messages, -) - -if TYPE_CHECKING: - from llama_index.core.tools.types import BaseTool - - -DEFAULT_ANTHROPIC_MODEL = "claude-2.1" -DEFAULT_ANTHROPIC_MAX_TOKENS = 512 - - -class AnthropicVertex(FunctionCallingLLM): - """AnthropicVertex LLM. 
- - Examples: - `pip install llama-index-llms-anthropic` - - ```python - from llama_index.llms.anthropic import Anthropic - - llm = Anthropic(model="claude-instant-1") - resp = llm.stream_complete("Paul Graham is ") - for r in resp: - print(r.delta, end="") - ``` - """ - - model: str = Field( - default=DEFAULT_ANTHROPIC_MODEL, description="The anthropic model to use." - ) - temperature: float = Field( - default=DEFAULT_TEMPERATURE, - description="The temperature to use for sampling.", - gte=0.0, - lte=1.0, - ) - max_tokens: int = Field( - default=DEFAULT_ANTHROPIC_MAX_TOKENS, - description="The maximum number of tokens to generate.", - gt=0, - ) - - base_url: Optional[str] = Field(default=None, description="The base URL to use.") - timeout: Optional[float] = Field( - default=None, description="The timeout to use in seconds.", gte=0 - ) - max_retries: int = Field( - default=10, description="The maximum number of API retries.", gte=0 - ) - additional_kwargs: Dict[str, Any] = Field( - default_factory=dict, description="Additional kwargs for the anthropic API." - ) - - _client: anthropic.AnthropicVertex = PrivateAttr() - _aclient: anthropic.AsyncAnthropicVertex = PrivateAttr() - - def __init__( - self, - model: str = DEFAULT_ANTHROPIC_MODEL, - temperature: float = DEFAULT_TEMPERATURE, - max_tokens: int = DEFAULT_ANTHROPIC_MAX_TOKENS, - base_url: Optional[str] = None, - timeout: Optional[float] = None, - credentials: Optional[service_account.Credentials] = None, - max_retries: int = 10, - additional_kwargs: Optional[Dict[str, Any]] = None, - callback_manager: Optional[CallbackManager] = None, - default_headers: Optional[Dict[str, str]] = None, - system_prompt: Optional[str] = None, - messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, - completion_to_prompt: Optional[Callable[[str], str]] = None, - pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, - output_parser: Optional[BaseOutputParser] = None, - ) -> None: - additional_kwargs = additional_kwargs or {} - callback_manager = callback_manager or CallbackManager([]) - - super().__init__( - temperature=temperature, - max_tokens=max_tokens, - additional_kwargs=additional_kwargs, - base_url=base_url, - timeout=timeout, - max_retries=max_retries, - model=model, - callback_manager=callback_manager, - system_prompt=system_prompt, - messages_to_prompt=messages_to_prompt, - completion_to_prompt=completion_to_prompt, - pydantic_program_mode=pydantic_program_mode, - output_parser=output_parser, - ) - - self._client = anthropic.AnthropicVertex( - region="us-east5", - timeout=timeout, - max_retries=max_retries, - default_headers=default_headers, - # credentials=credentials, - project_id=credentials.project_id, - ) - self._client._credentials = credentials - self._aclient = anthropic.AsyncAnthropicVertex( - region="us-east5", - timeout=timeout, - max_retries=max_retries, - default_headers=default_headers, - # credentials=credentials, - project_id=credentials.project_id, - ) - self._aclient._credentials = credentials - - @classmethod - def class_name(cls) -> str: - return "AnthropicVertex_LLM" - - @property - def metadata(self) -> LLMMetadata: - return LLMMetadata( - context_window=anthropic_modelname_to_contextsize(self.model), - num_output=self.max_tokens, - is_chat_model=True, - model_name=self.model, - is_function_calling_model=is_function_calling_model(self.model), - ) - - @property - def tokenizer(self) -> Tokenizer: - return self._client.get_tokenizer() - - @property - def _model_kwargs(self) -> 
Dict[str, Any]: - base_kwargs = { - "model": self.model, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - } - return { - **base_kwargs, - **self.additional_kwargs, - } - - def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: - return { - **self._model_kwargs, - **kwargs, - } - - def _get_content_and_tool_calls( - self, response: Any - ) -> Tuple[str, List[ToolUseBlock]]: - tool_calls = [] - content = "" - for content_block in response.content: - if isinstance(content_block, TextBlock): - content += content_block.text - elif isinstance(content_block, ToolUseBlock): - tool_calls.append(content_block.dict()) - - return content, tool_calls - - @llm_chat_callback() - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) - all_kwargs = self._get_all_kwargs(**kwargs) - - response = self._client.messages.create( - messages=anthropic_messages, - stream=False, - system=system_prompt, - **all_kwargs, - ) - - content, tool_calls = self._get_content_and_tool_calls(response) - - return ChatResponse( - message=ChatMessage( - role=MessageRole.ASSISTANT, - content=content, - additional_kwargs={"tool_calls": tool_calls}, - ), - raw=dict(response), - ) - - @llm_completion_callback() - def complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - complete_fn = chat_to_completion_decorator(self.chat) - return complete_fn(prompt, **kwargs) - - @llm_chat_callback() - def stream_chat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponseGen: - anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) - all_kwargs = self._get_all_kwargs(**kwargs) - - response = self._client.messages.create( - messages=anthropic_messages, system=system_prompt, stream=True, **all_kwargs - ) - - def gen() -> ChatResponseGen: - content = "" - cur_tool_calls: List[ToolUseBlock] = [] - cur_tool_call: Optional[ToolUseBlock] = None - cur_tool_json: str = "" - role = MessageRole.ASSISTANT - for r in response: - if isinstance(r, ContentBlockDeltaEvent): - if isinstance(r.delta, TextDelta): - content_delta = r.delta.text - content += content_delta - else: - if not isinstance(cur_tool_call, ToolUseBlock): - raise ValueError("Tool call not started") - content_delta = r.delta.partial_json - cur_tool_json += content_delta - try: - argument_dict = parse_partial_json(cur_tool_json) - cur_tool_call.input = argument_dict - except ValueError: - pass - - if cur_tool_call is not None: - tool_calls_to_send = [*cur_tool_calls, cur_tool_call] - else: - tool_calls_to_send = cur_tool_calls - yield ChatResponse( - message=ChatMessage( - role=role, - content=content, - additional_kwargs={ - "tool_calls": [t.dict() for t in tool_calls_to_send] - }, - ), - delta=content_delta, - raw=r, - ) - elif isinstance(r, ContentBlockStartEvent): - if isinstance(r.content_block, ToolUseBlock): - cur_tool_call = r.content_block - cur_tool_json = "" - elif isinstance(r, ContentBlockStopEvent): - if isinstance(cur_tool_call, ToolUseBlock): - cur_tool_calls.append(cur_tool_call) - - return gen() - - @llm_completion_callback() - def stream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseGen: - stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) - return stream_complete_fn(prompt, **kwargs) - - @llm_chat_callback() - async def achat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> 
ChatResponse: - anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) - all_kwargs = self._get_all_kwargs(**kwargs) - - response = await self._aclient.messages.create( - messages=anthropic_messages, - system=system_prompt, - stream=False, - **all_kwargs, - ) - - content, tool_calls = self._get_content_and_tool_calls(response) - - return ChatResponse( - message=ChatMessage( - role=MessageRole.ASSISTANT, - content=content, - additional_kwargs={"tool_calls": tool_calls}, - ), - raw=dict(response), - ) - - @llm_completion_callback() - async def acomplete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - acomplete_fn = achat_to_completion_decorator(self.achat) - return await acomplete_fn(prompt, **kwargs) - - @llm_chat_callback() - async def astream_chat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponseAsyncGen: - anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) - all_kwargs = self._get_all_kwargs(**kwargs) - - response = await self._aclient.messages.create( - messages=anthropic_messages, system=system_prompt, stream=True, **all_kwargs - ) - - async def gen() -> ChatResponseAsyncGen: - content = "" - cur_tool_calls: List[ToolUseBlock] = [] - cur_tool_call: Optional[ToolUseBlock] = None - cur_tool_json: str = "" - role = MessageRole.ASSISTANT - async for r in response: - if isinstance(r, ContentBlockDeltaEvent): - if isinstance(r.delta, TextDelta): - content_delta = r.delta.text - content += content_delta - else: - if not isinstance(cur_tool_call, ToolUseBlock): - raise ValueError("Tool call not started") - content_delta = r.delta.partial_json - cur_tool_json += content_delta - try: - argument_dict = parse_partial_json(cur_tool_json) - cur_tool_call.input = argument_dict - except ValueError: - pass - - if cur_tool_call is not None: - tool_calls_to_send = [*cur_tool_calls, cur_tool_call] - else: - tool_calls_to_send = cur_tool_calls - yield ChatResponse( - message=ChatMessage( - role=role, - content=content, - additional_kwargs={ - "tool_calls": [t.dict() for t in tool_calls_to_send] - }, - ), - delta=content_delta, - raw=r, - ) - elif isinstance(r, ContentBlockStartEvent): - if isinstance(r.content_block, ToolUseBlock): - cur_tool_call = r.content_block - cur_tool_json = "" - elif isinstance(r, ContentBlockStopEvent): - if isinstance(cur_tool_call, ToolUseBlock): - cur_tool_calls.append(cur_tool_call) - - return gen() - - @llm_completion_callback() - async def astream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseAsyncGen: - astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat) - return await astream_complete_fn(prompt, **kwargs) - - def _prepare_chat_with_tools( - self, - tools: List["BaseTool"], - user_msg: Optional[Union[str, ChatMessage]] = None, - chat_history: Optional[List[ChatMessage]] = None, - verbose: bool = False, - allow_parallel_tool_calls: bool = False, - **kwargs: Any, - ) -> Dict[str, Any]: - """Prepare the chat with tools.""" - chat_history = chat_history or [] - - if isinstance(user_msg, str): - user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) - chat_history.append(user_msg) - - tool_dicts = [] - for tool in tools: - tool_dicts.append( - { - "name": tool.metadata.name, - "description": tool.metadata.description, - "input_schema": tool.metadata.get_parameters_dict(), - } - ) - return {"messages": chat_history, "tools": tool_dicts or None, **kwargs} - - def 
_validate_chat_with_tools_response( - self, - response: ChatResponse, - tools: List["BaseTool"], - allow_parallel_tool_calls: bool = False, - **kwargs: Any, - ) -> ChatResponse: - """Validate the response from chat_with_tools.""" - if not allow_parallel_tool_calls: - force_single_tool_call(response) - return response - - def get_tool_calls_from_response( - self, - response: "ChatResponse", - error_on_no_tool_call: bool = True, - **kwargs: Any, - ) -> List[ToolSelection]: - """Predict and call the tool.""" - tool_calls = response.message.additional_kwargs.get("tool_calls", []) - - if len(tool_calls) < 1: - if error_on_no_tool_call: - raise ValueError( - f"Expected at least one tool call, but got {len(tool_calls)} tool calls." - ) - else: - return [] - - tool_selections = [] - for tool_call in tool_calls: - if ( - "input" not in tool_call - or "id" not in tool_call - or "name" not in tool_call - ): - raise ValueError("Invalid tool call.") - if tool_call["type"] != "tool_use": - raise ValueError("Invalid tool type. Unsupported by Anthropic") - argument_dict = ( - json.loads(tool_call["input"]) - if isinstance(tool_call["input"], str) - else tool_call["input"] - ) - - tool_selections.append( - ToolSelection( - tool_id=tool_call["id"], - tool_name=tool_call["name"], - tool_kwargs=argument_dict, - ) - ) - - return tool_selections diff --git a/backend/app/rag/llms/anthropic_vertex/utils.py b/backend/app/rag/llms/anthropic_vertex/utils.py deleted file mode 100644 index e91e71889..000000000 --- a/backend/app/rag/llms/anthropic_vertex/utils.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Dict, Sequence, Tuple - -from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole - -from anthropic.types import MessageParam, TextBlockParam -from anthropic.types.tool_result_block_param import ToolResultBlockParam -from anthropic.types.tool_use_block_param import ToolUseBlockParam - -HUMAN_PREFIX = "\n\nHuman:" -ASSISTANT_PREFIX = "\n\nAssistant:" - -CLAUDE_MODELS: Dict[str, int] = { - "claude-instant-1": 100000, - "claude-instant-1.2": 100000, - "claude-2": 100000, - "claude-2.0": 100000, - "claude-2.1": 200000, - "claude-3-opus-20240229": 180000, - "claude-3-sonnet-20240229": 180000, - "claude-3-haiku-20241022": 180000, - "claude-3-5-sonnet-20241022": 180000, - "claude-3-5-sonnet@20241022": 180000, -} - - -def is_function_calling_model(modelname: str) -> bool: - return "claude-3" in modelname - - -def anthropic_modelname_to_contextsize(modelname: str) -> int: - if modelname not in CLAUDE_MODELS: - raise ValueError( - f"Unknown model: {modelname}. Please provide a valid Anthropic model name." - "Known models are: " + ", ".join(CLAUDE_MODELS.keys()) - ) - - return CLAUDE_MODELS[modelname] - - -def __merge_common_role_msgs( - messages: Sequence[MessageParam], -) -> Sequence[MessageParam]: - """Merge consecutive messages with the same role.""" - postprocessed_messages: Sequence[MessageParam] = [] - for message in messages: - if ( - postprocessed_messages - and postprocessed_messages[-1]["role"] == message["role"] - ): - postprocessed_messages[-1]["content"] += message["content"] - else: - postprocessed_messages.append(message) - return postprocessed_messages - - -def messages_to_anthropic_messages( - messages: Sequence[ChatMessage], -) -> Tuple[Sequence[MessageParam], str]: - """Converts a list of generic ChatMessages to anthropic messages. 
- - Args: - messages: List of ChatMessages - - Returns: - Tuple of: - - List of anthropic messages - - System prompt - """ - anthropic_messages = [] - system_prompt = "" - for message in messages: - if message.role == MessageRole.SYSTEM: - system_prompt += message.content + "\n" - elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL: - content = ToolResultBlockParam( - tool_use_id=message.additional_kwargs["tool_call_id"], - type="tool_result", - content=[TextBlockParam(text=message.content, type="text")], - ) - anth_message = MessageParam( - role=MessageRole.USER.value, - content=[content], - ) - anthropic_messages.append(anth_message) - else: - content = [] - if message.content: - content.append(TextBlockParam(text=message.content, type="text")) - - tool_calls = message.additional_kwargs.get("tool_calls", []) - for tool_call in tool_calls: - assert "id" in tool_call - assert "input" in tool_call - assert "name" in tool_call - - content.append( - ToolUseBlockParam( - id=tool_call["id"], - input=tool_call["input"], - name=tool_call["name"], - type="tool_use", - ) - ) - - anth_message = MessageParam( - role=message.role.value, - content=content, # TODO: type detect for multimodal - ) - anthropic_messages.append(anth_message) - - return __merge_common_role_msgs(anthropic_messages), system_prompt.strip() - - -# Function used in bedrock -def _message_to_anthropic_prompt(message: ChatMessage) -> str: - if message.role == MessageRole.USER: - prompt = f"{HUMAN_PREFIX} {message.content}" - elif message.role == MessageRole.ASSISTANT: - prompt = f"{ASSISTANT_PREFIX} {message.content}" - elif message.role == MessageRole.SYSTEM: - prompt = f"{message.content}" - elif message.role == MessageRole.FUNCTION: - raise ValueError(f"Message role {MessageRole.FUNCTION} is not supported.") - else: - raise ValueError(f"Unknown message role: {message.role}") - - return prompt - - -def messages_to_anthropic_prompt(messages: Sequence[ChatMessage]) -> str: - if len(messages) == 0: - raise ValueError("Got empty list of messages.") - - # NOTE: make sure the prompt ends with the assistant prefix - if messages[-1].role != MessageRole.ASSISTANT: - messages = [ - *list(messages), - ChatMessage(role=MessageRole.ASSISTANT, content=""), - ] - - str_list = [_message_to_anthropic_prompt(message) for message in messages] - return "".join(str_list) - - -def force_single_tool_call(response: ChatResponse) -> None: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) - if len(tool_calls) > 1: - response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] diff --git a/backend/app/rag/llm_option.py b/backend/app/rag/llms/provider.py similarity index 92% rename from backend/app/rag/llm_option.py rename to backend/app/rag/llms/provider.py index a96950a82..840aadfa2 100644 --- a/backend/app/rag/llm_option.py +++ b/backend/app/rag/llms/provider.py @@ -1,10 +1,20 @@ +import enum + from typing import List from pydantic import BaseModel -from app.types import LLMProvider + +class LLMProvider(str, enum.Enum): + OPENAI = "openai" + GEMINI = "gemini" + ANTHROPIC_VERTEX = "anthropic_vertex" + OPENAI_LIKE = "openai_like" + BEDROCK = "bedrock" + OLLAMA = "ollama" + GITEEAI = "giteeai" -class LLMOption(BaseModel): +class LLMProviderOption(BaseModel): provider: LLMProvider provider_display_name: str | None = None provider_description: str | None = None @@ -19,8 +29,8 @@ class LLMOption(BaseModel): credentials_type: str = "str" -admin_llm_options: List[LLMOption] = [ - LLMOption( 
+llm_provider_options: List[LLMProviderOption] = [ + LLMProviderOption( provider=LLMProvider.OPENAI, provider_display_name="OpenAI", provider_description="The OpenAI API provides a simple interface for developers to create an intelligence layer in their applications, powered by OpenAI's state of the art models.", @@ -32,7 +42,7 @@ class LLMOption(BaseModel): credentials_type="str", default_credentials="sk-****", ), - LLMOption( + LLMProviderOption( provider=LLMProvider.OPENAI_LIKE, provider_display_name="OpenAI Like", default_llm_model="", @@ -51,7 +61,7 @@ class LLMOption(BaseModel): credentials_type="str", default_credentials="sk-****", ), - LLMOption( + LLMProviderOption( provider=LLMProvider.GEMINI, provider_display_name="Gemini", provider_description="The Gemini API and Google AI Studio help you start working with Google's latest models. Access the whole Gemini model family and turn your ideas into real applications that scale.", @@ -63,7 +73,7 @@ class LLMOption(BaseModel): credentials_type="str", default_credentials="AIza****", ), - LLMOption( + LLMProviderOption( provider=LLMProvider.OLLAMA, provider_display_name="Ollama", provider_description="Ollama is a lightweight framework for building and running large language models.", @@ -85,7 +95,7 @@ class LLMOption(BaseModel): credentials_type="str", default_credentials="dummy", ), - LLMOption( + LLMProviderOption( provider=LLMProvider.GITEEAI, provider_display_name="Gitee AI", provider_description="Gitee AI is a third-party model provider that offers ready-to-use cutting-edge model APIs for AI developers.", @@ -105,7 +115,7 @@ class LLMOption(BaseModel): credentials_type="str", default_credentials="****", ), - LLMOption( + LLMProviderOption( provider=LLMProvider.ANTHROPIC_VERTEX, provider_display_name="Anthropic Vertex AI", provider_description="Anthropic's Claude models are now generally available through Vertex AI.", @@ -121,7 +131,7 @@ class LLMOption(BaseModel): "private_key_id": "****", }, ), - LLMOption( + LLMProviderOption( provider=LLMProvider.BEDROCK, provider_display_name="Bedrock", provider_description="Amazon Bedrock is a fully managed foundation models service.", diff --git a/backend/app/rag/llms/resolver.py b/backend/app/rag/llms/resolver.py new file mode 100644 index 000000000..fd3c97858 --- /dev/null +++ b/backend/app/rag/llms/resolver.py @@ -0,0 +1,112 @@ +import os +from typing import Optional +from llama_index.core.llms.llm import LLM +from llama_index.llms.openai import OpenAI +from llama_index.llms.openai_like import OpenAILike +from llama_index.llms.gemini import Gemini +from llama_index.llms.bedrock import Bedrock +from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS +from llama_index.llms.ollama import Ollama +from llama_index.llms.vertex import Vertex +from google.oauth2 import service_account +from google.auth.transport.requests import Request +from sqlmodel import Session + +from app.repositories.llm import llm_repo +from app.rag.llms.provider import LLMProvider + + +def get_llm( + provider: LLMProvider, + model: str, + config: dict, + credentials: str | list | dict | None, +) -> LLM: + match provider: + case LLMProvider.OPENAI: + return OpenAI( + model=model, + api_key=credentials, + **config, + ) + case LLMProvider.OPENAI_LIKE: + config.setdefault("context_window", 200 * 1000) + return OpenAILike(model=model, api_key=credentials, **config) + case LLMProvider.GEMINI: + os.environ["GOOGLE_API_KEY"] = credentials + return Gemini(model=model, api_key=credentials, **config) + case 
LLMProvider.BEDROCK: + access_key_id = credentials["aws_access_key_id"] + secret_access_key = credentials["aws_secret_access_key"] + region_name = credentials["aws_region_name"] + + context_size = None + if model not in BEDROCK_FOUNDATION_LLMS: + context_size = 200000 + + llm = Bedrock( + model=model, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name=region_name, + context_size=context_size, + **config, + ) + # Note: Because the llama-index Bedrock class doesn't assign these values to the corresponding + # attributes in its constructor, we set them again here so that they can be passed on to + # the `get_dspy_lm_by_llama_llm` function. + llm.aws_access_key_id = access_key_id + llm.aws_secret_access_key = secret_access_key + llm.region_name = region_name + return llm + case LLMProvider.ANTHROPIC_VERTEX: + google_creds: service_account.Credentials = ( + service_account.Credentials.from_service_account_info( + credentials, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + ) + google_creds.refresh(request=Request()) + if "max_tokens" not in config: + config.update(max_tokens=4096) + return Vertex( + model=model, + credentials=google_creds, + **config, + ) + case LLMProvider.OLLAMA: + config.setdefault("request_timeout", 60 * 10) + config.setdefault("context_window", 4096) + return Ollama(model=model, **config) + case LLMProvider.GITEEAI: + config.setdefault("context_window", 200 * 1000) + return OpenAILike( + model=model, + api_base="https://ai.gitee.com/v1", + api_key=credentials, + **config, + ) + case _: + raise ValueError(f"Got unknown LLM provider: {provider}") + + +def get_default_llm(session: Session) -> Optional[LLM]: + db_llm = llm_repo.get_default(session) + if not db_llm: + return None + return get_llm( + db_llm.provider, + db_llm.model, + db_llm.config, + db_llm.credentials, + ) + + +def must_get_default_llm(session: Session) -> LLM: + db_llm = llm_repo.must_get_default(session) + return get_llm( + db_llm.provider, + db_llm.model, + db_llm.config, + db_llm.credentials, + ) diff --git a/backend/app/rag/node_postprocessor/__init__.py b/backend/app/rag/node_postprocessor/__init__.py deleted file mode 100644 index 1d3fd9f3e..000000000 --- a/backend/app/rag/node_postprocessor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .metadata_post_filter import MetadataPostFilter - -__all__ = ["MetadataPostFilter"] diff --git a/backend/app/rag/postprocessors/__init__.py b/backend/app/rag/postprocessors/__init__.py new file mode 100644 index 000000000..c23d2f333 --- /dev/null +++ b/backend/app/rag/postprocessors/__init__.py @@ -0,0 +1,8 @@ +from .metadata_post_filter import MetadataPostFilter, MetadataFilters +from .resolver import get_metadata_post_filter + +__all__ = [ + "MetadataPostFilter", + "MetadataFilters", + "get_metadata_post_filter", +] diff --git a/backend/app/rag/node_postprocessor/metadata_post_filter.py b/backend/app/rag/postprocessors/metadata_post_filter.py similarity index 100% rename from backend/app/rag/node_postprocessor/metadata_post_filter.py rename to backend/app/rag/postprocessors/metadata_post_filter.py diff --git a/backend/app/rag/postprocessors/resolver.py b/backend/app/rag/postprocessors/resolver.py new file mode 100644 index 000000000..ad0178c1c --- /dev/null +++ b/backend/app/rag/postprocessors/resolver.py @@ -0,0 +1,12 @@ +from typing import Optional +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from .metadata_post_filter import ( + MetadataFilters, + MetadataPostFilter, +) + + +def
get_metadata_post_filter( + filters: Optional[MetadataFilters] = None, +) -> BaseNodePostprocessor: + return MetadataPostFilter(filters) diff --git a/backend/app/rag/rerankers/__init__.py b/backend/app/rag/rerankers/__init__.py new file mode 100644 index 000000000..956c49779 --- /dev/null +++ b/backend/app/rag/rerankers/__init__.py @@ -0,0 +1,26 @@ +from .baisheng.baisheng_reranker import BaishengRerank +from .local.local_reranker import LocalRerank +from .vllm.vllm_reranker import VLLMRerank + +from .provider import ( + RerankerProvider, + RerankerProviderOption, + reranker_provider_options, +) +from .resolver import ( + get_reranker_model, + get_default_reranker_model, + must_get_default_reranker_model, +) + +__all__ = [ + "RerankerProvider", + "RerankerProviderOption", + "BaishengRerank", + "LocalRerank", + "VLLMRerank", + "get_reranker_model", + "get_default_reranker_model", + "must_get_default_reranker_model", + "reranker_provider_options", +] diff --git a/backend/app/rag/node_postprocessor/baisheng_reranker.py b/backend/app/rag/rerankers/baisheng/baisheng_reranker.py similarity index 100% rename from backend/app/rag/node_postprocessor/baisheng_reranker.py rename to backend/app/rag/rerankers/baisheng/baisheng_reranker.py diff --git a/backend/app/rag/node_postprocessor/local_reranker.py b/backend/app/rag/rerankers/local/local_reranker.py similarity index 100% rename from backend/app/rag/node_postprocessor/local_reranker.py rename to backend/app/rag/rerankers/local/local_reranker.py diff --git a/backend/app/rag/reranker_model_option.py b/backend/app/rag/rerankers/provider.py similarity index 91% rename from backend/app/rag/reranker_model_option.py rename to backend/app/rag/rerankers/provider.py index f595093f0..b4a6af281 100644 --- a/backend/app/rag/reranker_model_option.py +++ b/backend/app/rag/rerankers/provider.py @@ -1,10 +1,19 @@ +import enum from typing import List from pydantic import BaseModel -from app.types import RerankerProvider +class RerankerProvider(str, enum.Enum): + JINA = "jina" + COHERE = "cohere" + BAISHENG = "baisheng" + LOCAL = "local" + VLLM = "vllm" + XINFERENCE = "xinference" + BEDROCK = "bedrock" -class RerankerModelOption(BaseModel): + +class RerankerProviderOption(BaseModel): provider: RerankerProvider provider_display_name: str | None = None provider_description: str | None = None @@ -20,8 +29,8 @@ class RerankerModelOption(BaseModel): credentials_type: str = "str" -admin_reranker_model_options: List[RerankerModelOption] = [ - RerankerModelOption( +reranker_provider_options: List[RerankerProviderOption] = [ + RerankerProviderOption( provider=RerankerProvider.JINA, provider_display_name="Jina AI", provider_description="We provide best-in-class embeddings, rerankers, LLM-reader and prompt optimizers, pioneering search AI for multimodal data.", @@ -34,7 +43,7 @@ class RerankerModelOption(BaseModel): credentials_type="str", default_credentials="jina_****", ), - RerankerModelOption( + RerankerProviderOption( provider=RerankerProvider.COHERE, provider_display_name="Cohere", provider_description="Cohere provides industry-leading large language models (LLMs) and RAG capabilities tailored to meet the needs of enterprise use cases that solve real-world problems.", @@ -47,7 +56,7 @@ class RerankerModelOption(BaseModel): credentials_type="str", default_credentials="*****", ), - RerankerModelOption( + RerankerProviderOption( provider=RerankerProvider.BAISHENG, provider_display_name="BaiSheng", default_reranker_model="bge-reranker-v2-m3", @@ -61,7 +70,7 @@ class 
RerankerModelOption(BaseModel):
         credentials_type="str",
         default_credentials="*****",
     ),
-    RerankerModelOption(
+    RerankerProviderOption(
         provider=RerankerProvider.LOCAL,
         provider_display_name="Local Reranker",
         provider_description="TIDB.AI's local reranker server, deployed on your own infrastructure and powered by sentence-transformers.",
@@ -77,7 +86,7 @@ class RerankerModelOption(BaseModel):
         credentials_type="str",
         default_credentials="dummy",
     ),
-    RerankerModelOption(
+    RerankerProviderOption(
         provider=RerankerProvider.VLLM,
         provider_display_name="vLLM",
         provider_description="vLLM is a fast and easy-to-use library for LLM inference and serving.",
@@ -93,7 +102,7 @@ class RerankerModelOption(BaseModel):
         credentials_type="str",
         default_credentials="dummy",
     ),
-    RerankerModelOption(
+    RerankerProviderOption(
         provider=RerankerProvider.XINFERENCE,
         provider_display_name="Xinference Reranker",
         provider_description="Xorbits Inference (Xinference) is an open-source platform to streamline the operation and integration of a wide array of AI models.",
@@ -109,7 +118,7 @@ class RerankerModelOption(BaseModel):
         credentials_type="str",
         default_credentials="dummy",
     ),
-    RerankerModelOption(
+    RerankerProviderOption(
         provider=RerankerProvider.BEDROCK,
         provider_display_name="Bedrock Reranker",
         provider_description="Amazon Bedrock is a fully managed foundation models service.",
@@ -125,5 +134,5 @@ class RerankerModelOption(BaseModel):
             "aws_secret_access_key": "****",
             "aws_region_name": "us-west-2",
         },
-    )
+    ),
 ]
diff --git a/backend/app/rag/rerankers/resolver.py b/backend/app/rag/rerankers/resolver.py
new file mode 100644
index 000000000..21114fe5e
--- /dev/null
+++ b/backend/app/rag/rerankers/resolver.py
@@ -0,0 +1,103 @@
+from typing import Optional
+from sqlmodel import Session
+
+from llama_index.core.postprocessor.types import BaseNodePostprocessor
+from llama_index.postprocessor.jinaai_rerank import JinaRerank
+from llama_index.postprocessor.cohere_rerank import CohereRerank
+from llama_index.postprocessor.xinference_rerank import XinferenceRerank
+from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
+
+from app.rag.rerankers.baisheng.baisheng_reranker import BaishengRerank
+from app.rag.rerankers.local.local_reranker import LocalRerank
+from app.rag.rerankers.vllm.vllm_reranker import VLLMRerank
+from app.rag.rerankers.provider import RerankerProvider
+
+from app.repositories.reranker_model import reranker_model_repo
+
+
+def get_reranker_model(
+    provider: RerankerProvider,
+    model: str,
+    top_n: int,
+    config: dict,
+    credentials: str | list | dict | None,
+) -> BaseNodePostprocessor:
+    match provider:
+        case RerankerProvider.JINA:
+            return JinaRerank(
+                model=model,
+                top_n=top_n,
+                api_key=credentials,
+                **config,
+            )
+        case RerankerProvider.COHERE:
+            return CohereRerank(
+                model=model,
+                top_n=top_n,
+                api_key=credentials,
+                **config,
+            )
+        case RerankerProvider.BAISHENG:
+            return BaishengRerank(
+                model=model,
+                top_n=top_n,
+                api_key=credentials,
+                **config,
+            )
+        case RerankerProvider.LOCAL:
+            return LocalRerank(
+                model=model,
+                top_n=top_n,
+                **config,
+            )
+        case RerankerProvider.VLLM:
+            return VLLMRerank(
+                model=model,
+                top_n=top_n,
+                **config,
+            )
+        case RerankerProvider.XINFERENCE:
+            return XinferenceRerank(
+                model=model,
+                top_n=top_n,
+                **config,
+            )
+        case RerankerProvider.BEDROCK:
+            return AWSBedrockRerank(
+                rerank_model_name=model,
+                top_n=top_n,
+                aws_access_key_id=credentials["aws_access_key_id"],
+                aws_secret_access_key=credentials["aws_secret_access_key"],
+                region_name=credentials["aws_region_name"],
+                **config,
+            )
+        case _:
+            raise ValueError(f"Got unknown reranker provider: {provider}")
+
+
+# FIXME: Reranker top_n should be configurable in the retrieval config.
+def get_default_reranker_model(
+    session: Session, top_n: Optional[int] = None
+) -> Optional[BaseNodePostprocessor]:
+    db_reranker = reranker_model_repo.get_default(session)
+    if not db_reranker:
+        return None
+    top_n = db_reranker.top_n if top_n is None else top_n
+    return get_reranker_model(
+        db_reranker.provider,
+        db_reranker.model,
+        top_n,
+        db_reranker.config,
+        db_reranker.credentials,
+    )
+
+
+def must_get_default_reranker_model(session: Session) -> BaseNodePostprocessor:
+    db_reranker = reranker_model_repo.must_get_default(session)
+    return get_reranker_model(
+        db_reranker.provider,
+        db_reranker.model,
+        db_reranker.top_n,
+        db_reranker.config,
+        db_reranker.credentials,
+    )
diff --git a/backend/app/rag/node_postprocessor/vllm_reranker.py b/backend/app/rag/rerankers/vllm/vllm_reranker.py
similarity index 98%
rename from backend/app/rag/node_postprocessor/vllm_reranker.py
rename to backend/app/rag/rerankers/vllm/vllm_reranker.py
index e66ddab4d..cc1ccc36e 100644
--- a/backend/app/rag/node_postprocessor/vllm_reranker.py
+++ b/backend/app/rag/rerankers/vllm/vllm_reranker.py
@@ -83,7 +83,9 @@ def _postprocess_nodes(
             raise RuntimeError(f"Got error from reranker: {resp_json}")
 
         results = zip(range(len(nodes)), resp_json["data"])
-        results = sorted(results, key=lambda x: x[1]["score"], reverse=True)[: self.top_n]
+        results = sorted(results, key=lambda x: x[1]["score"], reverse=True)[
+            : self.top_n
+        ]
 
         new_nodes = []
         for result in results:
diff --git a/backend/app/repositories/chat_engine.py b/backend/app/repositories/chat_engine.py
index 8b6f128e1..8c6de8d0f 100644
--- a/backend/app/repositories/chat_engine.py
+++ b/backend/app/repositories/chat_engine.py
@@ -7,7 +7,7 @@
 from fastapi_pagination.ext.sqlmodel import paginate
 from sqlalchemy.orm.attributes import flag_modified
 
-from app.models import ChatEngine, ChatEngineUpdate
+from app.models.chat_engine import ChatEngine, ChatEngineUpdate
 from app.repositories.base_repo import BaseRepo
 
 
diff --git a/backend/app/repositories/staff_action_log.py b/backend/app/repositories/staff_action_log.py
index 3b7e8adc3..14eb9c1f6 100644
--- a/backend/app/repositories/staff_action_log.py
+++ b/backend/app/repositories/staff_action_log.py
@@ -1,6 +1,6 @@
 from sqlmodel import Session
 
-from app.models import StaffActionLog
+from app.models.staff_action_log import StaffActionLog
 from app.repositories.base_repo import BaseRepo
 
 
diff --git a/backend/app/types.py b/backend/app/types.py
index 73ef5e06e..791afef1a 100644
--- a/backend/app/types.py
+++ b/backend/app/types.py
@@ -1,37 +1,6 @@
 import enum
 
 
-class LLMProvider(str, enum.Enum):
-    OPENAI = "openai"
-    GEMINI = "gemini"
-    ANTHROPIC_VERTEX = "anthropic_vertex"
-    OPENAI_LIKE = "openai_like"
-    BEDROCK = "bedrock"
-    OLLAMA = "ollama"
-    GITEEAI = "giteeai"
-
-
-class EmbeddingProvider(str, enum.Enum):
-    OPENAI = "openai"
-    JINA = "jina"
-    COHERE = "cohere"
-    BEDROCK = "bedrock"
-    OLLAMA = "ollama"
-    GITEEAI = "giteeai"
-    LOCAL = "local"
-    OPENAI_LIKE = "openai_like"
-
-
-class RerankerProvider(str, enum.Enum):
-    JINA = "jina"
-    COHERE = "cohere"
-    BAISHENG = "baisheng"
-    LOCAL = "local"
-    VLLM = "vllm"
-    XINFERENCE = "xinference"
-    BEDROCK = "bedrock"
-
-
 class MimeTypes(str, enum.Enum):
     PLAIN_TXT = "text/plain"
     MARKDOWN = "text/markdown"
diff --git a/backend/app/utils/dspy.py b/backend/app/utils/dspy.py
index b76918f9a..8c117ee2d 100644
--- a/backend/app/utils/dspy.py
+++ b/backend/app/utils/dspy.py
@@ -52,7 +52,9 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
         bedrock = dspy.Bedrock(region_name=llama_llm.region_name)
         if llama_llm.model.startswith("anthropic"):
             return dspy.AWSAnthropic(
-                bedrock, model=llama_llm.model, max_new_tokens=llama_llm.max_tokens or 8192
+                bedrock,
+                model=llama_llm.model,
+                max_new_tokens=llama_llm.max_tokens or 8192,
             )
         elif llama_llm.model.startswith("meta"):
             return dspy.AWSMeta(
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index b6e02dd1a..e26d8128e 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -6,61 +6,61 @@ authors = [
     { name = "wd0517", email = "me@wangdi.ink" }
 ]
 dependencies = [
-    "fastapi>=0.115.4",
-    "sqlmodel==0.0.19",
+    "fastapi>=0.115.6",
+    "fastapi-cli>=0.0.5",
+    "fastapi-users>=13.0.0",
+    "fastapi-pagination>=0.12.25",
+    "fastapi-users-db-sqlmodel>=0.3.0",
+    "sqlmodel==0.0.22",
     "pymysql>=1.1.1",
-    "celery>=5.4.0",
-    "dspy-ai>=2.4.9",
-    "langfuse>=2.48.0",
-    "llama-index>=0.11.10",
-    "alembic>=1.13.1",
-    "pydantic>=2.8.2",
+    "asyncmy>=0.2.9",
+    "tidb-vector>=0.0.14",
+    "alembic>=1.14.0",
+    "pydantic>=2.10.4",
     # Update Check: https://github.com/pydantic/pydantic/issues/8061
     "pydantic-settings>=2.3.3",
+    "redis>=5.0.5",
+    "celery>=5.4.0",
+    "flower>=2.0.1",
+    "httpx-oauth>=0.14.1",
+    "uvicorn>=0.30.3",
+    "gunicorn>=22.0.0",
+    "python-dotenv>=1.0.1",
     "sentry-sdk>=2.5.1",
+    "dspy-ai>=2.4.9",
+    "langfuse>=2.48.0",
+    "langchain-openai>=0.2.9",
+    "ragas>=0.2.6",
+    "deepeval>=0.21.73",
     "click>=8.1.7",
-    "uvicorn>=0.30.3",
     "tenacity~=8.4.0",
-    "redis>=5.0.5",
-    "flower>=2.0.1",
-    "llama-index-llms-gemini>=0.1.11",
-    "tidb-vector>=0.0.14",
+    "retry>=0.9.2",
     "deepdiff>=7.0.1",
-    "python-dotenv>=1.0.1",
-    "fastapi-users>=13.0.0",
-    "asyncmy>=0.2.9",
-    "fastapi-users-db-sqlmodel>=0.3.0",
-    "llama-index-postprocessor-jinaai-rerank>=0.1.6",
-    "httpx-oauth>=0.14.1",
+    "colorama>=0.4.6",
     "jinja2>=3.1.4",
-    "fastapi-pagination>=0.12.25",
-    "gunicorn>=22.0.0",
     "pyyaml>=6.0.1",
-    "anthropic[vertex]>=0.28.1",
-    "google-cloud-aiplatform>=1.59.0",
-    "deepeval>=0.21.73",
-    "llama-index-llms-openai>=0.1.27",
-    "llama-index-llms-openai-like>=0.1.3",
     "playwright>=1.45.1",
     "markdownify>=0.13.1",
-    "llama-index-postprocessor-cohere-rerank>=0.1.7",
-    "llama-index-llms-bedrock>=0.1.12",
     "pypdf>=4.3.1",
-    "llama-index-llms-ollama>=0.3.0",
-    "llama-index-embeddings-ollama>=0.3.0",
-    "llama-index-embeddings-jinaai>=0.3.0",
-    "llama-index-embeddings-cohere>=0.2.0",
     "python-docx>=1.1.2",
     "python-pptx>=1.0.2",
-    "colorama>=0.4.6",
     "openpyxl>=3.1.5",
-    "fastapi-cli>=0.0.5",
-    "retry>=0.9.2",
-    "langchain-openai>=0.2.9",
-    "ragas>=0.2.6",
-    "llama-index-embeddings-bedrock>=0.2.0",
+    "llama-index>=0.12.10",
+    "llama-index-llms-openai>=0.3.13",
+    "llama-index-llms-openai-like>=0.3.3",
+    "llama-index-llms-bedrock>=0.3.3",
+    "llama-index-llms-ollama>=0.5.0",
+    # 0.4.3: AttributeError: module 'google.generativeai.types' has no attribute 'RequestOptions'
+    "llama-index-llms-gemini==0.4.2",
+    "llama-index-embeddings-ollama>=0.5.0",
+    "llama-index-embeddings-jinaai>=0.4.0",
+    "llama-index-embeddings-cohere>=0.4.0",
+    "llama-index-embeddings-bedrock>=0.4.0",
+    "llama-index-postprocessor-jinaai-rerank>=0.3.0",
+    "llama-index-postprocessor-cohere-rerank>=0.3.0",
     "llama-index-postprocessor-xinference-rerank>=0.2.0",
     "llama-index-postprocessor-bedrock-rerank>=0.3.0",
+    "llama-index-llms-vertex>=0.4.2",
 ]
 readme = "README.md"
 requires-python = ">= 3.8"
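Note: the resolver added above in backend/app/rag/rerankers/resolver.py turns a stored provider, model name, top_n, config, and credentials into a concrete llama-index BaseNodePostprocessor. The minimal sketch below is not part of the diff and only illustrates how the function might be called; the provider choice, model name, and credential value are placeholders, and the resolved object is driven through llama-index's standard postprocess_nodes interface.

# Illustrative usage sketch (not part of the diff); values below are placeholders.
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

from app.rag.rerankers.provider import RerankerProvider
from app.rag.rerankers.resolver import get_reranker_model

# Resolve a concrete reranker from provider metadata, as a caller would after
# loading a reranker model record from the database.
reranker = get_reranker_model(
    provider=RerankerProvider.JINA,
    model="jina-reranker-v2-base-multilingual",  # placeholder model name
    top_n=2,
    config={},
    credentials="jina-api-key-placeholder",  # placeholder credential
)

# The result is an ordinary llama-index node postprocessor, so it can reorder
# retrieved nodes for a query (this call would hit the provider's API).
nodes = [
    NodeWithScore(node=TextNode(text="TiDB is a distributed SQL database."), score=0.2),
    NodeWithScore(node=TextNode(text="A reranker reorders retrieved chunks by relevance."), score=0.3),
]
reranked = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle("What does a reranker do?"))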
diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock
index 0813dc3e4..688c05262 100644
--- a/backend/requirements-dev.lock
+++ b/backend/requirements-dev.lock
@@ -17,7 +17,7 @@ aiohttp==3.9.5
     # via llama-index-core
 aiosignal==1.3.1
     # via aiohttp
-alembic==1.13.1
+alembic==1.14.0
     # via optuna
 amqp==5.2.0
     # via kombu
@@ -141,7 +141,7 @@ et-xmlfile==1.1.0
     # via openpyxl
 execnet==2.1.1
     # via pytest-xdist
-fastapi==0.115.4
+fastapi==0.115.6
     # via fastapi-users
 fastapi-cli==0.0.5
 fastapi-pagination==0.12.25
@@ -191,22 +191,23 @@ google-auth==2.30.0
     # via google-generativeai
 google-auth-httplib2==0.2.0
     # via google-api-python-client
-google-cloud-aiplatform==1.59.0
-google-cloud-bigquery==3.25.0
+google-cloud-aiplatform==1.76.0
+    # via llama-index-llms-vertex
+google-cloud-bigquery==3.27.0
     # via google-cloud-aiplatform
 google-cloud-core==2.4.1
     # via google-cloud-bigquery
     # via google-cloud-storage
-google-cloud-resource-manager==1.12.4
+google-cloud-resource-manager==1.12.5
     # via google-cloud-aiplatform
-google-cloud-storage==2.17.0
+google-cloud-storage==2.19.0
     # via google-cloud-aiplatform
-google-crc32c==1.5.0
+google-crc32c==1.6.0
     # via google-cloud-storage
     # via google-resumable-media
 google-generativeai==0.5.4
     # via llama-index-llms-gemini
-google-resumable-media==2.7.1
+google-resumable-media==2.7.2
     # via google-cloud-bigquery
     # via google-cloud-storage
 googleapis-common-protos==1.63.2
@@ -218,7 +219,7 @@ greenlet==3.0.3
     # via fastapi-users-db-sqlmodel
     # via playwright
     # via sqlalchemy
-grpc-google-iam-v1==0.13.1
+grpc-google-iam-v1==0.14.0
     # via google-cloud-resource-manager
 grpcio==1.63.0
     # via deepeval
@@ -337,6 +338,7 @@ llama-index-core==0.12.10.post1
     # via llama-index-llms-ollama
     # via llama-index-llms-openai
     # via llama-index-llms-openai-like
+    # via llama-index-llms-vertex
     # via llama-index-multi-modal-llms-openai
     # via llama-index-postprocessor-bedrock-rerank
     # via llama-index-postprocessor-cohere-rerank
@@ -360,8 +362,8 @@ llama-index-llms-anthropic==0.6.3
     # via llama-index-llms-bedrock
 llama-index-llms-bedrock==0.3.3
 llama-index-llms-gemini==0.4.2
-llama-index-llms-ollama==0.4.2
-llama-index-llms-openai==0.3.12
+llama-index-llms-ollama==0.5.0
+llama-index-llms-openai==0.3.13
     # via llama-index
     # via llama-index-agent-openai
     # via llama-index-cli
@@ -370,6 +372,7 @@ llama-index-llms-openai==0.3.12
     # via llama-index-program-openai
     # via llama-index-question-gen-openai
 llama-index-llms-openai-like==0.3.3
+llama-index-llms-vertex==0.4.2
 llama-index-multi-modal-llms-openai==0.4.2
     # via llama-index
 llama-index-postprocessor-bedrock-rerank==0.3.0
@@ -433,7 +436,7 @@ numpy==1.26.4
     # via shapely
     # via tidb-vector
     # via transformers
-ollama==0.3.1
+ollama==0.4.5
     # via llama-index-embeddings-ollama
     # via llama-index-llms-ollama
 openai==1.59.3
@@ -532,7 +535,7 @@ pyasn1-modules==0.4.0
     # via google-auth
 pycparser==2.22
     # via cffi
-pydantic==2.8.2
+pydantic==2.10.5
     # via anthropic
     # via cohere
     # via deepeval
@@ -547,11 +550,12 @@ pydantic==2.8.2
     # via langsmith
     # via llama-cloud
     # via llama-index-core
+    # via ollama
     # via openai
     # via pydantic-settings
     # via ragas
     # via sqlmodel
-pydantic-core==2.20.1
+pydantic-core==2.27.2
     # via pydantic
 pydantic-settings==2.6.1
     # via langchain-community
@@ -644,7 +648,7 @@ safetensors==0.4.3
     # via transformers
 sentry-sdk==2.5.1
     # via deepeval
-shapely==2.0.4
+shapely==2.0.6
     # via google-cloud-aiplatform
 shellingham==1.5.4
     # via typer
@@ -665,7 +669,7 @@ sqlalchemy==2.0.30
     # via llama-index-core
     # via optuna
     # via sqlmodel
-sqlmodel==0.0.19
+sqlmodel==0.0.22
     # via fastapi-users-db-sqlmodel
 starlette==0.41.2
     # via fastapi
@@ -715,6 +719,7 @@ typing-extensions==4.12.2
     # via cohere
     # via fastapi
     # via fastapi-pagination
+    # via google-cloud-aiplatform
     # via google-generativeai
     # via huggingface-hub
     # via langchain-core
diff --git a/backend/requirements.lock b/backend/requirements.lock
index d731e9f3d..5444678e1 100644
--- a/backend/requirements.lock
+++ b/backend/requirements.lock
@@ -17,7 +17,7 @@ aiohttp==3.9.5
     # via llama-index-core
 aiosignal==1.3.1
     # via aiohttp
-alembic==1.13.1
+alembic==1.14.0
     # via optuna
 amqp==5.2.0
     # via kombu
@@ -137,7 +137,7 @@ et-xmlfile==1.1.0
     # via openpyxl
 execnet==2.1.1
     # via pytest-xdist
-fastapi==0.115.4
+fastapi==0.115.6
     # via fastapi-users
 fastapi-cli==0.0.5
 fastapi-pagination==0.12.25
@@ -186,22 +186,23 @@ google-auth==2.30.0
     # via google-generativeai
 google-auth-httplib2==0.2.0
     # via google-api-python-client
-google-cloud-aiplatform==1.59.0
-google-cloud-bigquery==3.25.0
+google-cloud-aiplatform==1.76.0
+    # via llama-index-llms-vertex
+google-cloud-bigquery==3.27.0
     # via google-cloud-aiplatform
 google-cloud-core==2.4.1
     # via google-cloud-bigquery
     # via google-cloud-storage
-google-cloud-resource-manager==1.12.4
+google-cloud-resource-manager==1.12.5
     # via google-cloud-aiplatform
-google-cloud-storage==2.17.0
+google-cloud-storage==2.19.0
     # via google-cloud-aiplatform
-google-crc32c==1.5.0
+google-crc32c==1.6.0
     # via google-cloud-storage
     # via google-resumable-media
 google-generativeai==0.5.4
     # via llama-index-llms-gemini
-google-resumable-media==2.7.1
+google-resumable-media==2.7.2
     # via google-cloud-bigquery
     # via google-cloud-storage
 googleapis-common-protos==1.63.2
@@ -213,7 +214,7 @@ greenlet==3.0.3
     # via fastapi-users-db-sqlmodel
     # via playwright
     # via sqlalchemy
-grpc-google-iam-v1==0.13.1
+grpc-google-iam-v1==0.14.0
     # via google-cloud-resource-manager
 grpcio==1.63.0
     # via deepeval
@@ -330,6 +331,7 @@ llama-index-core==0.12.10.post1
     # via llama-index-llms-ollama
     # via llama-index-llms-openai
     # via llama-index-llms-openai-like
+    # via llama-index-llms-vertex
     # via llama-index-multi-modal-llms-openai
     # via llama-index-postprocessor-bedrock-rerank
     # via llama-index-postprocessor-cohere-rerank
@@ -353,8 +355,8 @@ llama-index-llms-anthropic==0.6.3
     # via llama-index-llms-bedrock
 llama-index-llms-bedrock==0.3.3
 llama-index-llms-gemini==0.4.2
-llama-index-llms-ollama==0.4.2
-llama-index-llms-openai==0.3.12
+llama-index-llms-ollama==0.5.0
+llama-index-llms-openai==0.3.13
     # via llama-index
     # via llama-index-agent-openai
     # via llama-index-cli
@@ -363,6 +365,7 @@ llama-index-llms-openai==0.3.12
     # via llama-index-program-openai
     # via llama-index-question-gen-openai
 llama-index-llms-openai-like==0.3.3
+llama-index-llms-vertex==0.4.2
 llama-index-multi-modal-llms-openai==0.4.2
     # via llama-index
 llama-index-postprocessor-bedrock-rerank==0.3.0
@@ -424,7 +427,7 @@ numpy==1.26.4
     # via shapely
     # via tidb-vector
     # via transformers
-ollama==0.3.1
+ollama==0.4.5
     # via llama-index-embeddings-ollama
     # via llama-index-llms-ollama
 openai==1.59.3
@@ -520,7 +523,7 @@ pyasn1-modules==0.4.0
     # via google-auth
 pycparser==2.22
     # via cffi
-pydantic==2.8.2
+pydantic==2.10.5
     # via anthropic
     # via cohere
     # via deepeval
@@ -535,11 +538,12 @@ pydantic==2.8.2
     # via langsmith
     # via llama-cloud
     # via llama-index-core
+    # via ollama
     # via openai
     # via pydantic-settings
     # via ragas
     # via sqlmodel
-pydantic-core==2.20.1
+pydantic-core==2.27.2
     # via pydantic
 pydantic-settings==2.6.1
     # via langchain-community
@@ -630,7 +634,7 @@ safetensors==0.4.3
     # via transformers
 sentry-sdk==2.5.1
     # via deepeval
-shapely==2.0.4
+shapely==2.0.6
     # via google-cloud-aiplatform
 shellingham==1.5.4
     # via typer
@@ -651,7 +655,7 @@ sqlalchemy==2.0.30
     # via llama-index-core
     # via optuna
     # via sqlmodel
-sqlmodel==0.0.19
+sqlmodel==0.0.22
     # via fastapi-users-db-sqlmodel
 starlette==0.41.2
     # via fastapi
@@ -701,6 +705,7 @@ typing-extensions==4.12.2
     # via cohere
     # via fastapi
     # via fastapi-pagination
+    # via google-cloud-aiplatform
     # via google-generativeai
     # via huggingface-hub
     # via langchain-core