Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] prompt preprocessing #541

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,270 changes: 1,270 additions & 0 deletions preprocessing_demo/preprocessing_demo.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions preprocessing_demo/temporal_docs/doc_0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Project report 2022

This year we decided to go to the moon
3 changes: 3 additions & 0 deletions preprocessing_demo/temporal_docs/doc_1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Project report 2023

This year we built a rocket
3 changes: 3 additions & 0 deletions preprocessing_demo/temporal_docs/doc_2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Project report 2024

This year we went to the moon
19 changes: 19 additions & 0 deletions preprocessing_demo/temporal_docs/mkdocs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

# Contents of the demo documents, one entry per generated file. The position
# in the list becomes the number in the ``doc_<index>.md`` file name.
docs = [
    """# Project report 2022

This year we decided to go to the moon""",
    """# Project report 2023

This year we built a rocket""",
    """# Project report 2024

This year we went to the moon""",
]

# Write the documents next to this script rather than into the current working
# directory, so the result does not depend on where the script is invoked from.
base_path = Path(__file__).parent
for i, doc in enumerate(docs):
    doc_path = base_path / f"doc_{i}.md"
    # An explicit encoding keeps the files identical across platforms; the
    # default text encoding depends on the locale.
    doc_path.write_text(doc, encoding="utf-8")
6 changes: 6 additions & 0 deletions ragna/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
"Source",
"SourceStorage",
"PlainTextDocumentHandler",
"QueryProcessingStep",
"ProcessedQuery",
"QueryPreprocessor",
]

from ._utils import (
Expand Down Expand Up @@ -51,6 +54,9 @@
Component,
Message,
MessageRole,
ProcessedQuery,
QueryPreprocessor,
QueryProcessingStep,
Source,
SourceStorage,
)
Expand Down
31 changes: 31 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import functools
import inspect
import uuid
from dataclasses import field
from datetime import datetime, timezone
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Iterator,
List,
Optional,
Type,
Union,
Expand Down Expand Up @@ -302,3 +304,32 @@ def answer(self, messages: list[Message]) -> Iterator[str]:
Answer.
"""
...


class QueryProcessingStep(pydantic.BaseModel):
    """Record of a single step applied to a query during preprocessing."""

    # Query text the step started from.
    original_query: str
    # Query text the step produced.
    processed_query: str
    # Metadata filter associated with the step; None when there is none.
    metadata_filter: Optional[MetadataFilter] = None
    # Name of the processor that performed the step; defaults to an empty string.
    processor_name: str = ""


class ProcessedQuery(pydantic.BaseModel):
    """Result of running a query through a preprocessor.

    Attributes:
        original_query: Query as it was passed to the preprocessor.
        processed_query: Query after all steps of the processing pipeline.
        metadata_filter: Metadata filter that goes along with the processed
            query, or `None` if there is none.
        processing_history: Individual steps that led from the original to the
            processed query. Defaults to an empty list.
    """

    original_query: str
    processed_query: str
    metadata_filter: Optional[MetadataFilter] = None
    # Use pydantic's own Field rather than dataclasses.field: this is a
    # pydantic model, and pydantic.Field is the documented way to declare a
    # per-instance default factory on one.
    processing_history: List[QueryProcessingStep] = pydantic.Field(
        default_factory=list
    )


class QueryPreprocessor(Component, abc.ABC):
    """Abstract base class for query preprocessors.

    Concrete preprocessors have to implement `process`.
    """

    @abc.abstractmethod
    def process(
        self, query: str, metadata_filter: Optional[MetadataFilter] = None
    ) -> ProcessedQuery:
        """Preprocess a query.

        Args:
            query: Query to preprocess.
            metadata_filter: Optional metadata filter that accompanies the
                query.

        Returns:
            The processed query.
        """
        ...
25 changes: 22 additions & 3 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@

from ragna._utils import as_async_iterator, as_awaitable, default_user

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._components import (
Assistant,
Component,
Message,
MessageRole,
QueryPreprocessor,
SourceStorage,
)
from ._document import Document, LocalDocument
from ._metadata_filter import MetadataFilter
from ._utils import RagnaException, merge_models
Expand Down Expand Up @@ -96,7 +103,6 @@ def _load_component(
) -> Optional[C]:
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
Expand Down Expand Up @@ -148,6 +154,7 @@ def chat(
*,
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
preprocessor: Optional[QueryPreprocessor] = None,
corpus_name: str = "default",
**params: Any,
) -> Chat:
Expand All @@ -167,11 +174,14 @@ def chat(
corpus_name: Corpus of documents to use.
**params: Additional parameters passed to the source storage and assistant.
"""
if preprocessor is not None:
preprocessor = (cast(preprocessor, self._load_component(preprocessor)),) # type: ignore[arg-type]
return Chat(
self,
input=input,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
preprocessor=preprocessor,
corpus_name=corpus_name,
**params,
)
Expand Down Expand Up @@ -239,6 +249,7 @@ def __init__(
*,
source_storage: SourceStorage,
assistant: Assistant,
preprocessor: QueryPreprocessor = None,
corpus_name: str = "default",
**params: Any,
) -> None:
Expand All @@ -248,6 +259,7 @@ def __init__(
self.source_storage = source_storage
self.assistant = assistant
self.corpus_name = corpus_name
self.preprocessor = preprocessor

special_params = SpecialChatParams().model_dump()
special_params.update(params)
Expand Down Expand Up @@ -299,9 +311,16 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
http_status_code=status.HTTP_400_BAD_REQUEST,
http_detail=RagnaException.EVENT,
)
if self.preprocessor is not None:
processed = self.preprocessor.process(prompt, self.metadata_filter)
prompt = processed.processed_query
self.metadata_filter = processed.metadata_filter

sources = await self._as_awaitable(
self.source_storage.retrieve, self.corpus_name, self.metadata_filter, prompt
self.source_storage.retrieve,
self.corpus_name,
self.metadata_filter,
prompt,
)
if not sources:
event = "Unable to retrieve any sources."
Expand Down
10 changes: 10 additions & 0 deletions ragna/preprocessors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Names exported by this package.
__all__ = [
    "RagnaDemoPreprocessor",
]

from ragna._utils import fix_module

from ._demo import RagnaDemoPreprocessor

# NOTE(review): fix_module is defined in ragna._utils and not visible here;
# presumably it adjusts the module metadata of the exported objects. Confirm.
fix_module(globals())
# Drop the helper so that it is not left behind as an attribute of this package.
del fix_module
30 changes: 30 additions & 0 deletions ragna/preprocessors/_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional

from ragna.core import (
MetadataFilter,
ProcessedQuery,
QueryPreprocessor,
QueryProcessingStep,
)


class RagnaDemoPreprocessor(QueryPreprocessor):
    """Demo preprocessor that prefixes the query with a fixed notice."""

    def process(
        self, query: str, metadata_filter: Optional[MetadataFilter] = None
    ) -> ProcessedQuery:
        """Prefix the query with a demo notice and pass the filter through.

        Args:
            query: Query to preprocess.
            metadata_filter: Optional metadata filter. It is returned
                unchanged.

        Returns:
            Processed query whose processing history holds exactly one step
            describing the transformation performed here.
        """
        processed_query = (
            "This is a demo preprocessor. It doesn't do anything to the query. "
            f"original query: {query}"
        )
        return ProcessedQuery(
            original_query=query,
            processed_query=processed_query,
            metadata_filter=metadata_filter,
            processing_history=[
                QueryProcessingStep(
                    original_query=query,
                    # Record what this step actually produced so that the
                    # history agrees with the final processed query.
                    processed_query=processed_query,
                    metadata_filter=metadata_filter,
                    processor_name=self.display_name(),
                )
            ],
        )
Empty file added tests/preprocessors/__init__.py
Empty file.
Loading