diff --git a/.env.example b/.env.example
index 9156d2ea..cdfe7143 100644
--- a/.env.example
+++ b/.env.example
@@ -90,4 +90,5 @@ SPARK_APPID=changethis
SPARK_APISecret=changethis
SPARK_APIKey=changethis
-ZHIPUAI_API_KEY=changethis
\ No newline at end of file
+ZHIPUAI_API_KEY=changethis
+SILICONFLOW_API_KEY=changethis
\ No newline at end of file
diff --git a/backend/app/core/celery_app.py b/backend/app/core/celery_app.py
index 5f929cae..d81be29c 100644
--- a/backend/app/core/celery_app.py
+++ b/backend/app/core/celery_app.py
@@ -1,8 +1,10 @@
import os
-
from celery import Celery
-
from app.core.config import settings
+import logging
+
+# Configure basic logging
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "fastembed_cache")
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
@@ -22,3 +24,18 @@
celery_app.conf.update(
result_expires=3600,
)
+
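+# Route worker tasks to a dedicated "main-queue" and report tasks as STARTED
+# once a worker picks them up.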
+celery_app.conf.task_routes = {"app.worker.celery_worker.*": "main-queue"}
+celery_app.conf.update(task_track_started=True)
+
+# Configure Celery worker logging
+celery_app.conf.update(
+ worker_hijack_root_logger=False,
+ worker_log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ worker_task_log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+
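+# Smoke-test task for verifying worker/broker wiring
+# (e.g. run test_celery.delay("ping") from a shell).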
+@celery_app.task(acks_late=True)
+def test_celery(word: str) -> str:
+ logging.info(f"Test task received: {word}")
+ return f"test task return {word}"
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 09c1f942..776dcff2 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -142,9 +142,11 @@ def _enforce_non_default_secrets(self) -> Self:
return self
# Qdrant
- QDRANT_SERVICE_API_KEY: str | None = None
- QDRANT_URL: str | None = None
- QDRANT_COLLECTION: str | None = None
+ QDRANT_SERVICE_API_KEY: str | None = "XMj3HXm5GlBKQLwZuStOlkwZiOWTdd_IwZNDJINFh-w"
+ # QDRANT_URL: str = "http://localhost:6333"
+ QDRANT_URL: str = "http://127.0.0.1:6333"
+
+ QDRANT_COLLECTION: str | None = "kb_uploads"
# LangSmith
# USE_LANGSMITH: bool = True
@@ -154,9 +156,18 @@ def _enforce_non_default_secrets(self) -> Self:
# LANGCHAIN_PROJECT: str | None = None
# Embeddings
- DENSE_EMBEDDING_MODEL: str | None = None
- SPARSE_EMBEDDING_MODEL: str | None = None
- FASTEMBED_CACHE_PATH: str | None = None
+    # EMBEDDING_MODEL: str = "local"
+    EMBEDDING_MODEL: str = "zhipuai"  # which embedding backend to use
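+    # Supported values: "openai", "zhipuai", "siliconflow", "local";
+    # see get_embedding_model() in app/core/rag/embeddings.py.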
+
+    DENSE_EMBEDDING_MODEL: str = (
+        "sentence-transformers/all-MiniLM-L6-v2"  # default dense embedding model
+    )
+    SPARSE_EMBEDDING_MODEL: str = (
+        "sentence-transformers/all-MiniLM-L6-v2"  # default sparse embedding model
+    )
+ ZHIPUAI_API_KEY: str | None = None
+ SILICONFLOW_API_KEY: str | None = None
+ OLLAMA_BASE_URL: str | None = None
# Celery
CELERY_BROKER_URL: str | None = None
@@ -166,5 +177,7 @@ def _enforce_non_default_secrets(self) -> Self:
RECURSION_LIMIT: int = 25
TAVILY_API_KEY: str | None = None
+    OPENAI_API_KEY: str | None = None
+
settings = Settings() # type: ignore
diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py
index 10227bf3..17734238 100644
--- a/backend/app/core/graph/members.py
+++ b/backend/app/core/graph/members.py
@@ -1,8 +1,9 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any
-
+from langchain_core.tools import BaseTool
from langchain.chat_models import init_chat_model
-from langchain.tools.retriever import create_retriever_tool
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -18,7 +19,7 @@
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
-from app.core.graph.rag.qdrant import QdrantStore
+from app.core.rag.qdrant import QdrantStore
from app.core.tools import managed_tools
from app.core.tools.api_tool import dynamic_api_tool
from app.core.tools.retriever_tool import create_retriever_tool
@@ -41,6 +42,18 @@ def tool(self) -> BaseTool:
raise ValueError("Skill is not managed and no definition provided.")
+# class GraphUpload(BaseModel):
+# name: str = Field(description="Name of the upload")
+# description: str = Field(description="Description of the upload")
+# owner_id: int = Field(description="Id of the user that owns this upload")
+# upload_id: int = Field(description="Id of the upload")
+
+# @property
+# def tool(self) -> BaseTool:
+# retriever = QdrantStore().retriever(self.owner_id, self.upload_id)
+# return create_retriever_tool(retriever)
+
+
class GraphUpload(BaseModel):
name: str = Field(description="Name of the upload")
description: str = Field(description="Description of the upload")
@@ -49,7 +62,8 @@ class GraphUpload(BaseModel):
@property
def tool(self) -> BaseTool:
- retriever = QdrantStore().retriever(self.owner_id, self.upload_id)
+ qdrant_store = QdrantStore()
+ retriever = qdrant_store.retriever(self.owner_id, self.upload_id)
return create_retriever_tool(retriever)
diff --git a/backend/app/core/graph/messages.py b/backend/app/core/graph/messages.py
index 63e2fdf9..4fecbf12 100644
--- a/backend/app/core/graph/messages.py
+++ b/backend/app/core/graph/messages.py
@@ -83,7 +83,7 @@ def event_to_response(event: StreamEvent) -> ChatResponse | None:
for doc in docs:
documents.append(
{
- "score": doc.metadata["score"],
+ # "score": doc.metadata["score"],
"content": doc.page_content,
}
)
diff --git a/backend/app/core/graph/rag/qdrant.py b/backend/app/core/graph/rag/qdrant.py
deleted file mode 100644
index 64c32fd8..00000000
--- a/backend/app/core/graph/rag/qdrant.py
+++ /dev/null
@@ -1,200 +0,0 @@
-import os
-from collections.abc import Callable
-from typing import Any
-
-import pymupdf # type: ignore[import-untyped]
-from langchain_community.document_loaders import WebBaseLoader
-from langchain_core.documents import Document
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from qdrant_client import QdrantClient
-from qdrant_client.http import models as rest
-
-from app.core.config import settings
-from app.core.graph.rag.qdrant_retriever import QdrantRetriever
-
-
-class QdrantStore:
- """
- A class to handle uploading and searching documents in a Qdrant vector store.
- """
-
- collection_name = settings.QDRANT_COLLECTION
- url = settings.QDRANT_URL
-
- def __init__(self) -> None:
- self.client = self._create_collection()
-
- def add(
- self,
- file_path: str,
- upload_id: int,
- user_id: int,
- chunk_size: int = 500,
- chunk_overlap: int = 50,
- callback: Callable[[], None] | None = None,
- ) -> None:
- """
- Uploads a PDF document to the Qdrant vector store after converting it to markdown and splitting into chunks.
-
- Args:
- upload_name (str): The name of the upload (PDF file path).
- user_id (int): The ID of the user uploading the document.
- chunk_size (int, optional): The size of each text chunk. Defaults to 500.
- chunk_overlap (int, optional): The overlap size between chunks. Defaults to 50.
- """
- if os.path.basename(file_path).endswith(".pdf"):
- doc = pymupdf.open(file_path)
- elif os.path.basename(file_path).endswith(".html"):
- loader = WebBaseLoader(file_path)
- doc = loader.load()
- else:
- raise ValueError("Unsupported file type")
- documents = [
- Document(
- page_content=page.get_text().encode("utf8"),
- metadata={"user_id": user_id, "upload_id": upload_id},
- )
- for page in doc
- ]
- text_splitter = RecursiveCharacterTextSplitter(
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap,
- )
- docs = text_splitter.split_documents(documents)
-
- doc_texts: list[str] = []
- metadata: list[dict[Any, Any]] = []
- for doc in docs:
- doc_texts.append(doc.page_content)
- metadata.append(doc.metadata)
-
- self.client.add(
- collection_name=self.collection_name,
- documents=doc_texts,
- metadata=metadata,
- )
-
- callback() if callback else None
-
- def _create_collection(self) -> QdrantClient:
- """
- Creates a collection in Qdrant if it does not already exist, configured for hybrid search.
-
- The collection uses both dense and sparse vector models. Returns an instance of the Qdrant client.
-
- Returns:
- QdrantClient: An instance of the Qdrant client.
- """
- client = QdrantClient(
- url=self.url, api_key=settings.QDRANT_SERVICE_API_KEY, prefer_grpc=True
- )
- client.set_model(settings.DENSE_EMBEDDING_MODEL)
- client.set_sparse_model(settings.SPARSE_EMBEDDING_MODEL)
-
- if not client.collection_exists(self.collection_name):
- client.create_collection(
- collection_name=self.collection_name,
- vectors_config=client.get_fastembed_vector_params(),
- sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
- )
- return client
-
- def delete(self, upload_id: int, user_id: int) -> None:
- """Delete points from collection where upload_id and user_id in metadata matches."""
- self.client.delete(
- collection_name=self.collection_name,
- points_selector=rest.FilterSelector(
- filter=rest.Filter(
- must=[
- rest.FieldCondition(
- key="user_id",
- match=rest.MatchValue(value=user_id),
- ),
- rest.FieldCondition(
- key="upload_id",
- match=rest.MatchValue(value=upload_id),
- ),
- ]
- )
- ),
- )
-
- def update(
- self,
- file_path: str,
- upload_id: int,
- user_id: int,
- chunk_size: int = 500,
- chunk_overlap: int = 50,
- callback: Callable[[], None] | None = None,
- ) -> None:
- """Delete and re-upload the new PDF document to the Qdrant vector store"""
- self.delete(user_id, upload_id)
- self.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
- callback() if callback else None
-
- def retriever(self, user_id: int, upload_id: int) -> QdrantRetriever:
- """
- Creates a VectorStoreRetriever that retrieves results containing the specified user_id and upload_id in the metadata.
-
- Args:
- user_id (int): Filters the retriever results to only include those belonging to this user.
- upload_id (int): Filters the retriever results to only include those from this upload ID.
-
- Returns:
- VectorStoreRetriever: A VectorStoreRetriever instance.
- """
- retriever = QdrantRetriever(
- client=self.client,
- collection_name=self.collection_name,
- search_kwargs=rest.Filter(
- must=[
- rest.FieldCondition(
- key="user_id",
- match=rest.MatchValue(value=user_id),
- ),
- rest.FieldCondition(
- key="upload_id",
- match=rest.MatchValue(value=upload_id),
- ),
- ],
- ),
- )
- return retriever
-
- def search(self, user_id: int, upload_ids: list[int], query: str) -> list[Document]:
- """
- Performs a similarity search in the Qdrant vector store for a given query, filtered by user ID and upload names.
-
- Args:
- user_id (str): The ID of the user performing the search.
- upload_names (list[str]): A list of upload names to filter the search.
- query (str): The search query.
-
- Returns:
- List[Document]: A list of documents matching the search criteria.
- """
- search_results = self.client.query(
- collection_name=self.collection_name,
- query_text=query,
- query_filter=rest.Filter(
- must=[
- rest.FieldCondition(
- key="user_id",
- match=rest.MatchValue(value=user_id),
- ),
- rest.FieldCondition(
- key="upload_id",
- match=rest.MatchAny(any=upload_ids),
- ),
- ],
- ),
- )
- documents: list[Document] = []
- for result in search_results:
- document = Document(
- page_content=result.document,
- metadata={"score": result.score},
- )
- documents.append(document)
- return documents
diff --git a/backend/app/core/graph/rag/__init__.py b/backend/app/core/rag/__init__.py
similarity index 100%
rename from backend/app/core/graph/rag/__init__.py
rename to backend/app/core/rag/__init__.py
diff --git a/backend/app/core/rag/embeddings.py b/backend/app/core/rag/embeddings.py
new file mode 100644
index 00000000..07db8d64
--- /dev/null
+++ b/backend/app/core/rag/embeddings.py
@@ -0,0 +1,112 @@
+from typing import List
+from langchain_core.embeddings import Embeddings
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra
+import requests
+from app.core.config import settings
+import logging
+
+logger = logging.getLogger(__name__)
+
+
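+# Embeddings backed by ZhipuAI's embedding API. `dimension` starts at a default
+# and is updated with the real vector length after the first request, so
+# QdrantStore can size its collection from it.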
+class ZhipuAIEmbeddings(BaseModel, Embeddings):
+ api_key: str = settings.ZHIPUAI_API_KEY
+ model: str = "embedding-3"
+    dimension: int = 2048  # default dimension; refreshed after the first embedding call
+
+ class Config:
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ from zhipuai import ZhipuAI
+
+ client = ZhipuAI(api_key=self.api_key)
+ response = client.embeddings.create(model=self.model, input=texts)
+ embeddings = [item.embedding for item in response.data]
+ if embeddings:
+            self.dimension = len(embeddings[0])  # update to the actual dimension
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ return self.embed_documents([text])[0]
+
+
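+# Embeddings served by SiliconFlow's hosted embeddings endpoint, called via
+# plain requests; the request/response shape mirrors OpenAI's embeddings API.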
+class SiliconFlowEmbeddings(BaseModel, Embeddings):
+ api_key: str = settings.SILICONFLOW_API_KEY
+ model: str = "BAAI/bge-large-zh-v1.5"
+
+ class Config:
+ extra = Extra.forbid
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ url = "https://api.siliconflow.cn/v1/embeddings"
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "Authorization": f"Bearer {self.api_key}",
+ }
+ payload = {"model": self.model, "input": texts, "encoding_format": "float"}
+ response = requests.post(url, json=payload, headers=headers)
+ response_json = response.json()
+        logger.debug(
+            f"SiliconFlow API response: {response_json}"
+        )  # log the full response to aid debugging
+
+ if "data" not in response_json or not isinstance(response_json["data"], list):
+ raise ValueError(
+ f"Unexpected response format from SiliconFlow API: {response_json}"
+ )
+
+ embeddings = []
+ for item in response_json["data"]:
+ if "embedding" not in item or not isinstance(item["embedding"], list):
+ raise ValueError(f"Unexpected embedding format in response: {item}")
+ embeddings.append(item["embedding"])
+
+ return embeddings
+
+ def embed_query(self, text: str) -> List[float]:
+ return self.embed_documents([text])[0]
+
+
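+# Best-effort lookup of the embedding vector size: prefer an explicit attribute,
+# otherwise probe the model with a throwaway query.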
+def get_embedding_dimension(embedding_model: Embeddings) -> int:
+ if hasattr(embedding_model, "dimension"):
+ return embedding_model.dimension
+ elif hasattr(embedding_model, "embedding_dim"):
+ return embedding_model.embedding_dim
+ else:
+        # If the model exposes no dimension attribute, embed a sample text and use its length
+ sample_embedding = embedding_model.embed_query("Sample text for dimension")
+ return len(sample_embedding)
+
+
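+# Factory used by QdrantStore: maps the EMBEDDING_MODEL setting to a concrete
+# Embeddings implementation.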
+def get_embedding_model(model_name: str) -> Embeddings:
+ logger.info(f"Initializing embedding model: {model_name}")
+ try:
+ if model_name == "openai":
+ embedding_model = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
+ elif model_name == "zhipuai":
+ embedding_model = ZhipuAIEmbeddings()
+ elif model_name == "siliconflow":
+ embedding_model = SiliconFlowEmbeddings()
+ elif model_name == "local":
+ embedding_model = HuggingFaceEmbeddings(
+ model_name=settings.DENSE_EMBEDDING_MODEL,
+ model_kwargs={"device": "cpu"},
+ )
+ else:
+ raise ValueError(f"Unsupported embedding model: {model_name}")
+
+ logger.info(f"Embedding model created: {type(embedding_model)}")
+
+        # ZhipuAIEmbeddings tracks its own dimension; other backends may reject
+        # extra attributes (pydantic extra="forbid"), so attach it best-effort.
+        if not isinstance(embedding_model, ZhipuAIEmbeddings):
+            try:
+                embedding_model.dimension = get_embedding_dimension(embedding_model)
+            except (ValueError, AttributeError):
+                logger.warning("Embedding model does not accept a dimension attribute")
+
+        logger.info(
+            f"Embedding model dimension: {get_embedding_dimension(embedding_model)}"
+        )
+
+ return embedding_model
+ except Exception as e:
+ logger.error(f"Error initializing embedding model: {e}", exc_info=True)
+ raise
diff --git a/backend/app/core/rag/qdrant.py b/backend/app/core/rag/qdrant.py
new file mode 100644
index 00000000..2007e9df
--- /dev/null
+++ b/backend/app/core/rag/qdrant.py
@@ -0,0 +1,208 @@
+from typing import List, Callable
+from langchain_community.document_loaders import PyMuPDFLoader, WebBaseLoader
+from langchain_core.documents import Document
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_qdrant import QdrantVectorStore
+from qdrant_client import QdrantClient
+from qdrant_client.http import models as rest
+from qdrant_client.models import VectorParams, Distance
+from app.core.config import settings
+from app.core.rag.embeddings import get_embedding_dimension, get_embedding_model
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
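+# Qdrant-backed document store for RAG: chunks are embedded with the configured
+# embedding model and tagged with user_id / upload_id so retrieval can be
+# filtered per user and per upload.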
+class QdrantStore:
+ def __init__(self) -> None:
+ self.collection_name = settings.QDRANT_COLLECTION
+ self.url = settings.QDRANT_URL
+ self.embedding_model = get_embedding_model(settings.EMBEDDING_MODEL)
+
+ logger.info(f"Initializing QdrantStore with URL: {self.url}")
+
+ self.client = QdrantClient(
+ url=self.url, api_key=settings.QDRANT_SERVICE_API_KEY, prefer_grpc=False
+ )
+ logger.info("QdrantClient initialized successfully")
+
+ self._initialize_vector_store()
+
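+    # Create the collection (cosine distance, vector size taken from the embedding
+    # model) plus integer payload indexes on user_id / upload_id if it does not
+    # exist yet, then wrap it in a langchain QdrantVectorStore.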
+ def _initialize_vector_store(self):
+ try:
+ collections = self.client.get_collections().collections
+ if self.collection_name not in [collection.name for collection in collections]:
+ logger.info(f"Creating new collection: {self.collection_name}")
+ self.client.create_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(
+                        size=get_embedding_dimension(self.embedding_model),
+                        distance=Distance.COSINE,
+ ),
+ )
+ self.client.create_payload_index(
+ collection_name=self.collection_name,
+ field_name="user_id",
+ field_schema="integer",
+ )
+ self.client.create_payload_index(
+ collection_name=self.collection_name,
+ field_name="upload_id",
+ field_schema="integer",
+ )
+ else:
+ logger.info(f"Using existing collection: {self.collection_name}")
+
+ collection_info = self.client.get_collection(self.collection_name)
+ logger.info(f"Collection info: {collection_info}")
+
+ self.vector_store = QdrantVectorStore(
+ client=self.client,
+ collection_name=self.collection_name,
+ embedding=self.embedding_model,
+ )
+ except Exception as e:
+ logger.error(f"Error initializing vector store: {str(e)}", exc_info=True)
+ raise
+
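+    # Load a PDF or HTML source, tag every document with user_id / upload_id
+    # metadata, split it into chunks, and embed the chunks into the collection.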
+ def add(
+ self,
+ file_path: str,
+ upload_id: int,
+ user_id: int,
+ chunk_size: int = 500,
+ chunk_overlap: int = 50,
+ callback: Callable[[], None] | None = None,
+ ) -> None:
+ if file_path.endswith(".pdf"):
+ loader = PyMuPDFLoader(file_path)
+ elif file_path.endswith(".html"):
+ loader = WebBaseLoader(file_path)
+ else:
+ raise ValueError("Unsupported file type")
+
+ documents = loader.load()
+ for doc in documents:
+ doc.metadata.update({"user_id": user_id, "upload_id": upload_id})
+ logger.info(f"Document metadata: {doc.metadata}")
+
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ )
+ docs = text_splitter.split_documents(documents)
+
+ self.vector_store.add_documents(docs)
+
+ if callback:
+ callback()
+
+ def delete(self, upload_id: int, user_id: int) -> bool:
+ try:
+ result = self.client.delete(
+ collection_name=self.collection_name,
+ points_selector=rest.Filter(
+ must=[
+ rest.FieldCondition(
+ key="user_id",
+ match=rest.MatchValue(value=user_id),
+ ),
+ rest.FieldCondition(
+ key="upload_id",
+ match=rest.MatchValue(value=upload_id),
+ ),
+ ],
+ )
+ )
+ logger.info(f"Delete operation result: {result}")
+ return True
+ except Exception as e:
+ logger.error(f"Error deleting documents: {str(e)}", exc_info=True)
+ return False
+
+ def update(
+ self,
+ file_path: str,
+ upload_id: int,
+ user_id: int,
+ chunk_size: int = 500,
+ chunk_overlap: int = 50,
+ callback: Callable[[], None] | None = None,
+ ) -> None:
+ self.delete(upload_id, user_id)
+ self.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
+ if callback:
+ callback()
+
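+    # Note: langchain's QdrantVectorStore stores document metadata under the
+    # "metadata" payload key, so filters target "metadata.user_id" and
+    # "metadata.upload_id" rather than the bare field names.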
+ def search(self, user_id: int, upload_ids: List[int], query: str) -> List[Document]:
+ logger.info(f"Searching with query: '{query}' for user_id: {user_id}, upload_ids: {upload_ids}")
+
+ query_vector = self.embedding_model.embed_query(query)
+
+ filter_condition = {
+ "must": [
+ {"key": "metadata.user_id", "match": {"value": user_id}},
+ {"key": "metadata.upload_id", "match": {"any": upload_ids}}
+ ]
+ }
+ logger.info(f"Search filter condition: {filter_condition}")
+
+ search_results = self.client.search(
+ collection_name=self.collection_name,
+ query_vector=query_vector,
+ query_filter=filter_condition,
+ limit=4
+ )
+
+ documents = [Document(page_content=result.payload.get('page_content', ''), metadata=result.payload.get('metadata', {})) for result in search_results]
+
+ logger.info(f"Search results: {len(documents)} documents found")
+ for doc in documents:
+ logger.info(f"Document metadata: {doc.metadata}")
+
+ return documents
+
+ def retriever(self, user_id: int, upload_id: int):
+ logger.info(f"Creating retriever for user_id: {user_id}, upload_id: {upload_id}")
+ filter_condition = {
+ "must": [
+ {"key": "metadata.user_id", "match": {"value": user_id}},
+ {"key": "metadata.upload_id", "match": {"value": upload_id}}
+ ]
+ }
+ retriever = self.vector_store.as_retriever(
+ search_kwargs={
+ "filter": filter_condition,
+ "k": 5
+ },
+ search_type="similarity",
+ )
+ logger.info(f"Retriever created: {retriever}")
+ return retriever
+
+ def debug_retriever(self, user_id: int, upload_id: int, query: str):
+ logger.info(
+ f"Debug retriever for user_id: {user_id}, upload_id: {upload_id}, query: '{query}'"
+ )
+
+        # Search with the user/upload filter applied
+ filtered_docs = self.search(user_id, [upload_id], query)
+ logger.info(f"Filtered search found {len(filtered_docs)} documents")
+ for doc in filtered_docs:
+ logger.info(f"Filtered doc metadata: {doc.metadata}")
+
+        # Search without any filter
+ unfiltered_docs = self.vector_store.similarity_search(query, k=5)
+ logger.info(f"Unfiltered search found {len(unfiltered_docs)} documents")
+
+        # Log metadata for every unfiltered document
+ for i, doc in enumerate(unfiltered_docs):
+ logger.info(f"Unfiltered doc {i} metadata: {doc.metadata}")
+
+ return filtered_docs
+
+ def get_collection_info(self):
+ collection_info = self.client.get_collection(self.collection_name)
+ logger.info(f"Collection info: {collection_info}")
+ return collection_info
\ No newline at end of file
diff --git a/backend/app/core/graph/rag/qdrant_retriever.py b/backend/app/core/rag/qdrant_retriever.py
similarity index 98%
rename from backend/app/core/graph/rag/qdrant_retriever.py
rename to backend/app/core/rag/qdrant_retriever.py
index 7340e59b..4bdcb2d7 100644
--- a/backend/app/core/graph/rag/qdrant_retriever.py
+++ b/backend/app/core/rag/qdrant_retriever.py
@@ -56,4 +56,4 @@ def _get_relevant_documents(
limit=self.k,
)
documents.append(document)
- return documents
+ return documents
\ No newline at end of file
diff --git a/backend/app/core/rag/rag_test_temp.py b/backend/app/core/rag/rag_test_temp.py
new file mode 100644
index 00000000..bd1ba4e9
--- /dev/null
+++ b/backend/app/core/rag/rag_test_temp.py
@@ -0,0 +1,27 @@
+import sys
+
+sys.path.append("./")
+
+from app.core.rag.qdrant import QdrantStore
+from app.core.rag.embeddings import get_embedding_model
+
+# Initialize QdrantStore and the embedding model
+qdrant_store = QdrantStore()
+embedding_model = get_embedding_model("zhipuai")
+
+# Generate the embedding vector for the query text
+query_text = "东莞市博视自控科技有限公司"
+query_vector = embedding_model.embed_query(query_text)
+
+# Run the search
+results = qdrant_store.vector_store.similarity_search(
+ query_text,
+    k=5,  # limit the number of results
+ filter={"must": [{"key": "page_content", "match": {"text": query_text}}]},
+)
+
+# Print the results
+for doc in results:
+ print(f"Content: {doc.page_content}")
+ print(f"Metadata: {doc.metadata}")
+ print("---")
diff --git a/backend/app/core/rag/ragtest.py b/backend/app/core/rag/ragtest.py
new file mode 100644
index 00000000..b72b9ccc
--- /dev/null
+++ b/backend/app/core/rag/ragtest.py
@@ -0,0 +1,130 @@
+from app.core.rag.qdrant import QdrantStore
+from app.core.tools.retriever_tool import create_retriever_tool
+import logging
+import json
+from qdrant_client.http import models as rest
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+user_id = 1
+upload_id = 26
+query = "东莞市博视自控科技有限公司"
+
+# Create a QdrantStore instance
+qdrant_store = QdrantStore()
+
+# Inspect the collection info
+collection_info = qdrant_store.get_collection_info()
+logger.info(f"Collection info: {collection_info}")
+
+logger.info("Checking document content:")
+all_points = qdrant_store.client.scroll(
+ collection_name=qdrant_store.collection_name,
+ limit=10,
+ with_payload=True,
+ with_vectors=False,
+)
+for point in all_points[0]:
+ logger.info(f"Point ID: {point.id}")
+ logger.info(f"Payload: {point.payload}")
+ logger.info(f"user_id type: {type(point.payload['metadata']['user_id'])}")
+ logger.info(f"upload_id type: {type(point.payload['metadata']['upload_id'])}")
+
+
+# Run the filtered search
+search_results = qdrant_store.search(user_id, [upload_id], query)
+logger.info(f"Search found {len(search_results)} documents")
+for doc in search_results:
+ logger.info(f"Content: {doc.page_content[:100]}...")
+ logger.info(f"Metadata: {json.dumps(doc.metadata, ensure_ascii=False, indent=2)}")
+ logger.info("---")
+
+# Create and exercise the retriever tool
+retriever = qdrant_store.retriever(user_id, upload_id)
+logger.info(f"Created retriever: {retriever}")
+retriever_tool = create_retriever_tool(retriever)
+logger.info(f"Created retriever tool: {retriever_tool}")
+
+# Use the retriever tool
+result, docs = retriever_tool._run(query)
+
+logger.info(f"Retriever tool result: {result[:100]}...")
+logger.info(f"Retriever tool found {len(docs)} documents")
+logger.info(f"Retriever search kwargs: {retriever.search_kwargs}")
+logger.info(f"Retriever vectorstore: {retriever.vectorstore}")
+
+for doc in docs:
+ logger.info(f"Retrieved document content: {doc.page_content[:100]}...")
+ logger.info(
+ f"Retrieved document metadata: {json.dumps(doc.metadata, ensure_ascii=False, indent=2)}"
+ )
+ logger.info("---")
+
+# Run an unfiltered vector-store search
+logger.info("Performing unfiltered search:")
+unfiltered_results = qdrant_store.vector_store.similarity_search(query, k=5)
+for doc in unfiltered_results:
+ logger.info(f"Unfiltered Content: {doc.page_content[:100]}...")
+ logger.info(
+ f"Unfiltered Metadata: {json.dumps(doc.metadata, ensure_ascii=False, indent=2)}"
+ )
+ logger.info("---")
+
+# Run a direct Qdrant search without a filter
+logger.info("Performing search without filter:")
+query_vector = qdrant_store.embedding_model.embed_query(query)
+unfiltered_results = qdrant_store.client.search(
+ collection_name=qdrant_store.collection_name, query_vector=query_vector, limit=5
+)
+for result in unfiltered_results:
+ logger.info(f"Unfiltered search result: {result.payload}")
+
+
+# Native Qdrant API search using an explicit metadata filter
+def perform_native_qdrant_search(qdrant_store, user_id, upload_id, query):
+ logger.info("Performing native Qdrant API search with filter:")
+ query_vector = qdrant_store.embedding_model.embed_query(query)
+
+ # filter_condition = rest.Filter(
+ # must=[
+ # rest.FieldCondition(
+ # key="user_id",
+ # match=rest.MatchValue(value=user_id)
+ # ),
+ # rest.FieldCondition(
+ # key="upload_id",
+ # match=rest.MatchValue(value=upload_id)
+ # )
+ # ]
+ # )
+ filter_condition = {
+ "must": [
+ {"key": "metadata.user_id", "match": {"value": user_id}},
+ {"key": "metadata.upload_id", "match": {"value": upload_id}},
+ ]
+ }
+
+ native_results = qdrant_store.client.search(
+ collection_name=qdrant_store.collection_name,
+ query_vector=query_vector,
+ query_filter=filter_condition,
+ limit=5,
+ )
+
+ logger.info(f"Native Qdrant API search found {len(native_results)} results")
+ for result in native_results:
+ logger.info(f"Native search result: {result.payload}")
+
+ return native_results
+
+
+# Call the native search from the main section
+if __name__ == "__main__":
+    # ... (the module-level checks above have already run)
+
+    # Compare against a native Qdrant search
+ native_results = perform_native_qdrant_search(
+ qdrant_store, user_id, upload_id, query
+ )
diff --git a/backend/app/core/tools/retriever_tool.py b/backend/app/core/tools/retriever_tool.py
index 3dccd0be..a68a9047 100644
--- a/backend/app/core/tools/retriever_tool.py
+++ b/backend/app/core/tools/retriever_tool.py
@@ -1,10 +1,12 @@
from typing import Annotated, Literal
+import logging
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import BaseTool
+logger = logging.getLogger(__name__)
class RetrieverTool(BaseTool):
name: str = "KnowledgeBase"
@@ -12,29 +14,36 @@ class RetrieverTool(BaseTool):
response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"
retriever: BaseRetriever
- document_prompt: BasePromptTemplate | PromptTemplate # type: ignore [type-arg]
+ document_prompt: BasePromptTemplate | PromptTemplate
document_separator: str
def _run(
self, query: Annotated[str, "query to look up in retriever"]
) -> tuple[str, list[Document]]:
"""Retrieve documents from knowledge base."""
+ logger.info(f"Retrieving documents for query: {query}")
docs = self.retriever.invoke(query, config={"callbacks": self.callbacks})
+ logger.info(f"Retrieved {len(docs)} documents")
+
+ if not docs:
+ logger.warning("No documents retrieved")
+ return "", []
+
result_string = self.document_separator.join(
[format_document(doc, self.document_prompt) for doc in docs]
)
+ logger.info(f"Formatted result string (first 100 chars): {result_string[:100]}...")
return result_string, docs
-
def create_retriever_tool(
retriever: BaseRetriever,
- document_prompt: BasePromptTemplate | None = None, # type: ignore [type-arg]
+ document_prompt: BasePromptTemplate | None = None,
document_separator: str = "\n\n",
) -> BaseTool:
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
-
+ logger.info(f"Creating retriever tool with retriever type: {type(retriever)}")
return RetrieverTool(
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
- )
+ )
\ No newline at end of file
diff --git a/backend/app/core/workflow/node.py b/backend/app/core/workflow/node.py
index b074fc6e..6c186574 100644
--- a/backend/app/core/workflow/node.py
+++ b/backend/app/core/workflow/node.py
@@ -18,7 +18,7 @@
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
-from app.core.graph.rag.qdrant import QdrantStore
+from app.core.rag.qdrant import QdrantStore
from app.core.tools import managed_tools
from app.core.tools.api_tool import dynamic_api_tool
from app.core.tools.retriever_tool import create_retriever_tool
@@ -253,7 +253,9 @@ async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamStat
"If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
"Execute what you can to make progress. "
"And your role is:" + self.system_prompt + "\n"
- "And your name is:" + self.agent_name + "please remember your name\n"
+ "And your name is:"
+ + self.agent_name
+                    + ", please remember your name\n"
"Stay true to your role and use your tools if necessary.\n\n",
),
(
diff --git a/backend/app/tasks/tasks.py b/backend/app/tasks/tasks.py
index 464f2497..464a2547 100644
--- a/backend/app/tasks/tasks.py
+++ b/backend/app/tasks/tasks.py
@@ -1,12 +1,15 @@
import os
+import logging
from sqlmodel import Session
from app.core.celery_app import celery_app
from app.core.db import engine
-from app.core.graph.rag.qdrant import QdrantStore
+from app.core.rag.qdrant import QdrantStore
from app.models import Upload, UploadStatus
+logger = logging.getLogger(__name__)
+
@celery_app.task
def add_upload(
@@ -21,8 +24,9 @@ def add_upload(
upload.status = UploadStatus.COMPLETED
session.add(upload)
session.commit()
+ logger.info(f"Upload {upload_id} completed successfully")
except Exception as e:
- print(f"add_upload failed: {e}")
+ logger.error(f"add_upload failed: {e}", exc_info=True)
upload.status = UploadStatus.FAILED
session.add(upload)
session.commit()
@@ -40,14 +44,17 @@ def edit_upload(
if not upload:
raise ValueError("Upload not found")
try:
- QdrantStore().update(
+ qdrant_store = QdrantStore()
+ logger.info("QdrantStore initialized successfully")
+ qdrant_store.update(
file_path, upload_id, user_id, chunk_size, chunk_overlap
)
upload.status = UploadStatus.COMPLETED
session.add(upload)
session.commit()
+ logger.info(f"Upload {upload_id} updated successfully")
except Exception as e:
- print(f"edit_upload failed: {e}")
+ logger.error(f"Error in edit_upload task: {e}", exc_info=True)
upload.status = UploadStatus.FAILED
session.add(upload)
session.commit()
@@ -61,10 +68,32 @@ def remove_upload(upload_id: int, user_id: int) -> None:
with Session(engine) as session:
upload = session.get(Upload, upload_id)
if not upload:
- raise ValueError("Upload not found")
+ logger.warning(
+ f"Upload not found in database for upload_id: {upload_id}, user_id: {user_id}"
+ )
+ return
+
try:
- QdrantStore().delete(upload_id, user_id)
- session.delete(upload)
- session.commit()
+ qdrant_store = QdrantStore()
+ deletion_successful = qdrant_store.delete(upload_id, user_id)
+
+            if not deletion_successful:
+                logger.warning(
+                    f"Failed to delete documents from Qdrant for upload_id: {upload_id}, user_id: {user_id}"
+                )
+            # Remove the upload record from the database regardless of whether the
+            # Qdrant deletion succeeded, so stale rows do not linger.
+            session.delete(upload)
+            session.commit()
+            logger.info(f"Upload {upload_id} removed from database successfully")
except Exception as e:
- print(f"remove_upload failed: {e}")
+ logger.error(
+ f"remove_upload failed for upload_id: {upload_id}, user_id: {user_id}. Error: {str(e)}",
+ exc_info=True,
+ )
+ upload.status = UploadStatus.FAILED
+ session.add(upload)
+ session.commit()
diff --git a/web/src/app/(applayout)/knowledge/page.tsx b/web/src/app/(applayout)/knowledge/page.tsx
index fdccc5fd..60e94808 100644
--- a/web/src/app/(applayout)/knowledge/page.tsx
+++ b/web/src/app/(applayout)/knowledge/page.tsx
@@ -159,7 +159,7 @@ function Uploads() {
alignItems={"center"}
>
{upload.status}
-
+
))}