Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAG #42

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 32 additions & 19 deletions adala/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
print_text,
print_error,
highlight_differences,
is_running_in_jupyter,
)
from adala.utils.internal_data import InternalDataFrame, InternalDataFrameConcat

Expand Down Expand Up @@ -276,25 +277,37 @@ def learn(
).merge(predictions, left_index=True, right_index=True)
)
# -----------------------------
train_skill_name, train_skill_output, accuracy = self.select_skill_to_train(
feedback, accuracy_threshold
)
if not train_skill_name:
print_text(f"No skill to improve found. Continue learning...")
skill_mismatch = feedback.match.fillna(True) == False
has_errors = skill_mismatch.any(axis=1).any()
if not has_errors:
print_text("No errors found!")
continue
train_skill = self.skills[train_skill_name]
print_text(
f'Output to improve: "{train_skill_output}" (Skill="{train_skill_name}")\n'
f"Accuracy = {accuracy * 100:0.2f}%",
style="bold red",
)

old_instructions = train_skill.instructions
train_skill.improve(
predictions, train_skill_output, feedback, runtime=teacher_runtime
)

highlight_differences(old_instructions, train_skill.instructions)
# print_text(f'{train_skill.instructions}', style='bold green')
first_skill_with_errors = skill_mismatch.any(axis=0).idxmax()

accuracy = feedback.get_accuracy()
# TODO: iterating over skill can be more complex, and we should take order into account
for skill_output, skill_name in self.skills.get_skill_outputs().items():
skill = self.skills[skill_name]
if skill.frozen:
continue

print_text(
f'Skill output to improve: "{skill_output}" (Skill="{skill_name}")\n'
f"Accuracy = {accuracy[skill_output] * 100:0.2f}%",
style="bold red",
)

old_instructions = skill.instructions
skill.improve(
predictions, skill_output, feedback, runtime=teacher_runtime
)

if is_running_in_jupyter():
highlight_differences(old_instructions, skill.instructions)
else:
print_text(skill.instructions, style="bold green")

if skill_name == first_skill_with_errors:
break

print_text("Train is done!")
28 changes: 15 additions & 13 deletions adala/environments/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd
import numpy as np
from pydantic import BaseModel
from pydantic import BaseModel, Field
from abc import ABC, abstractmethod
from typing import Optional, Dict, Union, Callable
from adala.utils.internal_data import (
Expand Down Expand Up @@ -124,21 +124,21 @@ class StaticEnvironment(Environment):

Attributes
df (InternalDataFrame): The dataframe containing the ground truth.
ground_truth_columns (Optional[Dict[str, str]], optional):
ground_truth_columns ([Dict[str, str]]):
A dictionary mapping skill outputs to ground truth columns.
If None, the skill outputs names are assumed to be the ground truth columns names.
Defaults to None.
If not specified, the skill outputs are assumed to be the ground truth columns.
If a skill output is not in the dictionary, it is assumed to have no ground truth signal - NaNs are returned in the feedback.
matching_function (str, optional): The matching function to match ground truth strings with prediction strings.
Defaults to 'fuzzy'.
matching_threshold (float, optional): The matching threshold for the matching function.

Examples:
>>> df = pd.DataFrame({'skill_1': ['a', 'b', 'c'], 'skill_2': ['d', 'e', 'f'], 'skill_3': ['g', 'h', 'i']})
>>> env = StaticEnvironment(df)
>>> env = StaticEnvironment(df, ground_truth_columns={'skill_1': 'ground_truth_1', 'skill_2': 'ground_truth_2'})
"""

df: InternalDataFrame = None
ground_truth_columns: Optional[Dict[str, str]] = None
ground_truth_columns: Dict[str, str] = Field(default_factory=dict)
matching_function: Union[str, Callable] = "fuzzy"
matching_threshold: float = 0.9

Expand Down Expand Up @@ -171,16 +171,16 @@ def get_feedback(
predictions = predictions.sample(n=num_feedbacks)

for pred_column in pred_columns:
if not self.ground_truth_columns:
gt_column = pred_column
else:
gt_column = self.ground_truth_columns[pred_column]

pred = predictions[pred_column]
gt_column = self.ground_truth_columns.get(pred_column, pred_column)
if gt_column not in self.df.columns:
# if ground truth column is not in the dataframe, assume no ground truth signal - return NaNs
pred_match[pred_column] = InternalSeries(np.nan, index=pred.index)
pred_feedback[pred_column] = InternalSeries(np.nan, index=pred.index)
continue

gt = self.df[gt_column]
pred = predictions[pred_column]

gt, pred = gt.align(pred)
nonnull_index = gt.notnull() & pred.notnull()
gt = gt[nonnull_index]
Expand Down Expand Up @@ -214,10 +214,12 @@ def get_feedback(
else np.nan,
axis=1,
)
return EnvironmentFeedback(

fb = EnvironmentFeedback(
match=InternalDataFrame(pred_match).reindex(predictions.index),
feedback=InternalDataFrame(pred_feedback).reindex(predictions.index),
)
return fb

def get_data_batch(self, batch_size: int = None) -> InternalDataFrame:
"""
Expand Down
38 changes: 32 additions & 6 deletions adala/memories/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,51 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, TYPE_CHECKING, Dict

from typing import Any, Optional, TYPE_CHECKING, Dict, List
from pydantic import BaseModel, Field
from adala.utils.internal_data import InternalDataFrame


class Memory(BaseModel, ABC):

"""
Base class for long-term memories.
Long-term memories are used to store acquired knowledge and can be shared between agents.
Base class for memories.
"""

@abstractmethod
def remember(self, observation: str, experience: Any):
def remember(self, observation: str, data: Dict):
"""
Base method for remembering experiences in long term memory.
"""

def remember_many(self, observations: List[str], data: List[Dict]):
    """
    Store a batch of observations in long-term memory.

    Default implementation delegates to :meth:`remember` once per
    (observation, data) pair; backends may override with a bulk insert.

    Args:
        observations: the observations to remember
        data: one metadata dict per observation, aligned by position
    """
    for obs, payload in zip(observations, data):
        self.remember(obs, payload)

@abstractmethod
def retrieve(self, observation: str) -> Any:
def retrieve(self, observation: str, num_results: int = 1) -> Any:
"""
Base method for retrieving past experiences from long term memory, based on current observations

Args:
observation: the current observation
num_results: the number of results to return
"""

def retrieve_many(self, observations: List[str], num_results: int = 1) -> List[Any]:
    """
    Retrieve past experiences for a batch of observations.

    Default implementation delegates to :meth:`retrieve` once per
    observation; backends may override with a bulk query.

    Args:
        observations: the current observations
        num_results: the number of results to return per observation

    Returns:
        One retrieval result per observation, aligned by position.
    """
    # Bug fix: num_results was previously dropped — self.retrieve(observation)
    # was called without it, so callers always got the default single result.
    return [
        self.retrieve(observation, num_results=num_results)
        for observation in observations
    ]

@abstractmethod
def clear(self):
    """
    Base method for clearing memory.

    Implementations should discard all stored observations/experiences,
    leaving the memory empty but usable for subsequent remember/retrieve calls.
    """
54 changes: 54 additions & 0 deletions adala/memories/vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import chromadb
import hashlib
from .base import Memory
from uuid import uuid4
from pydantic import BaseModel, Field, model_validator
from chromadb.utils import embedding_functions
from typing import Any, List, Dict


class VectorDBMemory(Memory):
    """
    Memory backed by a vector database (ChromaDB).

    Observations are embedded with OpenAI's ``text-embedding-ada-002`` model
    and stored in a Chroma collection named ``db_name``; associated ``data``
    dicts are stored as collection metadatas and returned on retrieval.
    NOTE(review): the embedding function presumably reads the OpenAI API key
    from the environment — confirm deployment configuration.
    """

    # Name of the Chroma collection backing this memory.
    db_name: str = ""
    # Private (non-pydantic-field) runtime handles, populated by init_database.
    _client = None
    _collection = None
    _embedding_function = None

    @model_validator(mode="after")
    def init_database(self):
        """Create the Chroma client, embedding function, and collection."""
        self._client = chromadb.Client()
        self._embedding_function = embedding_functions.OpenAIEmbeddingFunction(
            model_name="text-embedding-ada-002"
        )
        self._collection = self._client.get_or_create_collection(
            name=self.db_name, embedding_function=self._embedding_function
        )
        # Bug fix: pydantic v2 'after' model validators must return the model
        # instance; without this return, validation does not yield the
        # constructed VectorDBMemory object.
        return self

    def create_unique_id(self, string):
        """Return a deterministic id for *string* so duplicate observations
        overwrite rather than accumulate.

        md5 is used purely as a stable content hash, not for security.
        """
        return hashlib.md5(string.encode()).hexdigest()

    def remember(self, observation: str, data: Any):
        """Store a single observation with its metadata dict."""
        self.remember_many([observation], [data])

    def remember_many(self, observations: List[str], data: List[Dict]):
        """Bulk-insert observations; ids are content-hashed for idempotency."""
        self._collection.add(
            ids=[self.create_unique_id(o) for o in observations],
            documents=observations,
            metadatas=data,
        )

    def retrieve_many(self, observations: List[str], num_results: int = 1) -> List[Any]:
        """Return, for each observation, the metadatas of its nearest neighbors."""
        result = self._collection.query(query_texts=observations, n_results=num_results)
        return result["metadatas"]

    def retrieve(self, observation: str, num_results: int = 1) -> Any:
        """Return the nearest-neighbor metadatas for a single observation."""
        return self.retrieve_many([observation], num_results=num_results)[0]

    def clear(self):
        """Drop and recreate the collection, discarding all stored memories."""
        self._client.delete_collection(name=self.db_name)
        self._collection = self._client.create_collection(
            name=self.db_name, embedding_function=self._embedding_function
        )
2 changes: 1 addition & 1 deletion adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def record_to_record(
# for example, output template "Output: {answer} is correct" results in output_prefix "Output: "
output_prefix = output_template[: output_field["start"]]
if instructions_first:
user_prompt += f"\n\n{output_prefix}"
user_prompt += f"\n{output_prefix}"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
Expand Down
1 change: 1 addition & 0 deletions adala/skills/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .skillset import SkillSet, LinearSkillSet, ParallelSkillSet
from .collection.classification import ClassificationSkill
from .collection.rag import RAGSkill
from ._base import Skill, TransformSkill, AnalysisSkill, SynthesisSkill
23 changes: 18 additions & 5 deletions adala/skills/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ class Skill(BaseModel, ABC):
examples=[True, False],
)

frozen: bool = Field(
default=False,
title="Frozen",
description="Flag indicating if the skill is frozen.",
examples=[True, False],
)

def _get_extra_fields(self):
"""
Retrieves fields that are not categorized as system fields.
Expand Down Expand Up @@ -177,6 +184,12 @@ def improve(
runtime (Runtime): The runtime instance to be used for processing (CURRENTLY SUPPORTS ONLY `OpenAIChatRuntime`).

"""
if (
feedback.match[train_skill_output].all()
and not feedback.match[train_skill_output].isna().all()
):
# nothing to improve
return

fb = feedback.feedback.rename(
columns=lambda x: x + "__fb" if x in predictions.columns else x
Expand Down Expand Up @@ -257,13 +270,13 @@ def improve(
{
"role": "user",
"content": f"""
## Current prompt
{self.instructions}
## Current prompt
{self.instructions}

## Examples
{examples}
## Examples
{examples}

Summarize your analysis about incorrect predictions and suggest changes to the prompt.""",
Summarize your analysis about incorrect predictions and suggest changes to the prompt.""",
}
]

Expand Down
Loading