Skip to content

Commit

Permalink
Add docs, add pdm deps
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Nov 29, 2023
1 parent 2f332c8 commit 4563aa3
Show file tree
Hide file tree
Showing 3 changed files with 1,363 additions and 7 deletions.
47 changes: 46 additions & 1 deletion adala/skills/collection/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
class RAGSkill(TransformSkill):
"""
Skill for RAG (Retrieval-Augmented Generation) models.
Attributes:
input_template: Template for the input. It wraps the input with the template to create a query.
rag_input_template: Template for RAG input. It wraps each retrieved item with the template, and then concatenates them with two newlines.
Example: "Question: {question}\nContext: {context}" with num_results=2 will result in "Question: <question>\nContext: <context>\n\nQuestion: <question>\nContext: <context>"
instructions: Instructions for the generation part of the RAG model.
output_template: Template for the output. It wraps the output with the template.
num_results: Number of results to retrieve from the memory.
memory: Memory to use for retrieval. If None, a VectorDBMemory will be used.
Examples:
>>> from adala.skills import RAGSkill
>>> skill = RAGSkill(
... name="rag",
... input_template="Question: {question}",
... rag_input_template="Question: {question}\nContext: {context}",
... instructions="Answer the question.",
... output_template="{answer}",
... num_results=2,
... )
>>> skill.apply(
... input=InternalDataFrame(
... data=[
... {"question": "What is the meaning of life?", "context": "Life is a game."},
... {"question": "What is the meaning of life?", "context": "Life is a game."},
... ]
... ))
"""

name: str = "rag"
Expand All @@ -32,6 +60,15 @@ def apply(
) -> InternalDataFrame:
"""
Apply the skill.
Args:
input: Input data.
runtime: Runtime to use for generation.
Returns:
Output data. The output field is named after the output_template.
If no instructions are given, the output field contains concatenated strings from retrieved items.
If instructions are given, the output field contains the generated output.
"""
input_strings = input.apply(
lambda r: self.input_template.format(**r), axis=1
Expand All @@ -51,13 +88,15 @@ def apply(
output_field = output_fields[0]
rag_input = InternalDataFrame({output_field: rag_input_strings})
if self.instructions:
# if instructions are given, use the runtime to generate the output
output = runtime.batch_to_batch(
rag_input,
instructions_template=self.instructions,
input_template=f"{{{output_field}}}",
output_template=self.output_template,
)
else:
# if no instructions - simply return the rag input
output = rag_input
output.index = input.index

Expand All @@ -71,7 +110,13 @@ def improve(
runtime: Runtime,
):
"""
Improve the skill.
Improve the skill by storing the feedback match errors in the memory.
Args:
predictions: Predictions made by the skill.
train_skill_output: Output field of the skill used for training.
feedback: Feedback data. Where feedback.match equals False (prediction errors), the input is stored in the memory.
runtime: Runtime to use for generation (not used).
"""

error_indices = feedback.match[
Expand Down
Loading

0 comments on commit 4563aa3

Please sign in to comment.