Commit 98b5634 (parent 838bf83)
Showing 2 changed files with 229 additions and 111 deletions.
graph_rag/evaluation/ragas_evaluation/QA_graphrag_testdataset.py
130 changes: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
""" | ||
This script contains functions to generate question-answer pairs from input documents using a language model, | ||
and critique them based on various criteria like groundedness, relevance, and standalone quality. | ||
Functions: | ||
- get_response: Sends a request to a language model API to generate responses based on a provided prompt. | ||
- qa_generator: Generates a specified number of question-answer pairs from input documents. | ||
- critique_qa: Critiques the generated QA pairs based on groundedness, relevance, and standalone quality. | ||
""" | ||
|
||
from prompts import * | ||
import pandas as pd | ||
import random | ||
from tqdm.auto import tqdm | ||
import requests | ||
|
||
|
||
def get_response(
    prompt: str, url: str = "http://localhost:11434/api/generate", model: str = "llama3"
):
    """
    Sends a prompt to the Ollama API and retrieves the generated response.
    Args:
        prompt: The text input that the model will use to generate a response.
        url: The API endpoint for the model (default: "http://localhost:11434/api/generate").
        model: The model to be used for generation (default: "llama3").
    Returns:
        The generated response from the language model as a string.
    """
    payload = {"model": model, "prompt": prompt, "stream": False}
    response = requests.post(url, json=payload)
    resp = response.json()
    return resp["response"]

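Because the request sets "stream": False, Ollama returns a single JSON object and only its "response" field is used downstream. A minimal standalone sketch of the same call, assuming a local Ollama server with the llama3 model pulled (the prompt string is just an example):

import requests

# Same request get_response sends, shown on its own.
payload = {"model": "llama3", "prompt": "Reply with one word.", "stream": False}
reply = requests.post("http://localhost:11434/api/generate", json=payload).json()
print(reply["response"])  # generated text; other fields of the JSON are ignored here
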
def qa_generator(
    documents: object,
    N_GENERATIONS: int = 20,
):
    """
    Generates a specified number of question-answer pairs from the provided documents.
    Args:
        documents: A collection of document objects to generate QA pairs from.
        N_GENERATIONS: The number of question-answer pairs to generate (default: 20).
    Returns:
        A list of dictionaries, each containing the generated context, question, answer, and source document metadata.
    """
    print(f"Generating {N_GENERATIONS} QA couples...")

    outputs = []
    for sampled_context in tqdm(random.sample(documents, N_GENERATIONS)):
        # Generate a QA couple for the sampled context.
        output_QA_couple = get_response(
            QA_generation_prompt.format(context=sampled_context.text)
        )
        try:
            question = output_QA_couple.split("Factoid question: ")[-1].split(
                "Answer: "
            )[0]
            answer = output_QA_couple.split("Answer: ")[-1]
            assert len(answer) < 300, "Answer is too long"
            outputs.append(
                {
                    "context": sampled_context.text,
                    "question": question,
                    "answer": answer,
                    "source_doc": sampled_context.metadata,
                }
            )
        except (AssertionError, IndexError):
            # Skip generations that do not follow the expected format.
            continue
    df = pd.DataFrame(outputs)
    df.to_csv("QA.csv")
    return outputs

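The string splitting above assumes the model answers QA_generation_prompt with a "Factoid question: ..." line followed by an "Answer: ..." line. A small illustration of that parsing on a made-up completion:

# Hypothetical completion in the format qa_generator expects.
sample = (
    "Factoid question: What file does the QA generator write?\n"
    "Answer: A QA.csv file containing the generated pairs."
)
question = sample.split("Factoid question: ")[-1].split("Answer: ")[0]
answer = sample.split("Answer: ")[-1]
print(question.strip())  # -> the factoid question
print(answer.strip())    # -> the short answer (kept only if under 300 characters)
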
def critique_qa(
    outputs: list,
):
    """
    Critiques the generated question-answer pairs based on groundedness, relevance, and standalone quality.
    Args:
        outputs: A list of dictionaries containing generated QA pairs to be critiqued.
    Returns:
        The critiqued QA pairs with additional fields for groundedness, relevance, and standalone quality scores and evaluations.
    """
    print("Generating critique for each QA couple...")
    for output in tqdm(outputs):
        evaluations = {
            "groundedness": get_response(
                question_groundedness_critique_prompt.format(
                    context=output["context"], question=output["question"]
                ),
            ),
            "relevance": get_response(
                question_relevance_critique_prompt.format(question=output["question"]),
            ),
            "standalone": get_response(
                question_standalone_critique_prompt.format(question=output["question"]),
            ),
        }
        try:
            for criterion, evaluation in evaluations.items():
                # Each critique is expected to contain "Evaluation: ..." followed by "Total rating: <int>".
                score, evaluation_text = (
                    int(evaluation.split("Total rating: ")[-1].strip()),
                    evaluation.split("Total rating: ")[-2].split("Evaluation: ")[1],
                )
                output.update(
                    {
                        f"{criterion}_score": score,
                        f"{criterion}_eval": evaluation_text,
                    }
                )
        except (ValueError, IndexError):
            # Skip QA couples whose critiques do not follow the expected format.
            continue
    generated_questions = pd.DataFrame.from_dict(outputs)
    generated_questions = generated_questions.loc[
        (generated_questions["groundedness_score"] >= 4)
        & (generated_questions["relevance_score"] >= 4)
        & (generated_questions["standalone_score"] >= 4)
    ]
    generated_questions.to_csv("generated_questions.csv")
    return outputs
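Neither helper is wired to a command line in this commit, so here is a hedged usage sketch; it assumes the documents come from LlamaIndex's SimpleDirectoryReader (whose Document objects expose the .text and .metadata attributes read above) and that prompts.py defines the referenced prompt templates:

from llama_index.core import SimpleDirectoryReader  # assumed document source, not part of this file

documents = SimpleDirectoryReader("data").load_data()  # needs at least N_GENERATIONS documents
qa_pairs = qa_generator(documents, N_GENERATIONS=20)   # writes QA.csv
critiqued = critique_qa(qa_pairs)                      # adds *_score/*_eval fields, writes generated_questions.csv
print(f"{len(critiqued)} QA pairs critiqued")
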
graph_rag/evaluation/ragas_evaluation/evaluation_ragas.py
210 changes: 99 additions & 111 deletions (the removed lines were the QA-generation helpers now in QA_graphrag_testdataset.py above)
@@ -1,130 +1,118 @@
"""
This script loads a pre-processed dataset, slices it for batch evaluation, and runs a series of metrics to evaluate the
performance of a query engine using a language model and embeddings.
Functions:
- load_test_dataset: Loads a test dataset from a pickle file.
- slice_data: Slices the dataset into batches for evaluation.
- evaluate: Runs evaluation on the sliced dataset using specified metrics, LLMs, and embeddings.
"""

import pickle
import pandas as pd
from datasets import Dataset
from ragas.integrations.llama_index import evaluate as ragas_evaluate  # aliased so the local evaluate() below does not shadow it
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from ragas.metrics.critique import harmfulness
from llama_index.llms.ollama import Ollama
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_precision,
    context_recall,
)


def load_test_dataset(
    data: str,
):
    """
    Loads a test dataset from a pickle file.
    Args:
        data: The path to the dataset file in pickle format.
    Returns:
        A dictionary representing the loaded dataset or an empty dictionary if loading fails due to EOFError.
    """
    try:
        with open(data, "rb") as f:
            dataset = pickle.load(f)
    except EOFError:
        print("EOFError: The file may be corrupted or incomplete; loading an empty dictionary.")
        dataset = {}
    return dataset

def slice_data(i: int, k: int, dataset: dict):
    """
    Slices the dataset into smaller chunks for batch processing.
    Args:
        i: The starting index for the slice.
        k: The size of the slice (number of records to include in each batch).
        dataset: The dataset to be sliced.
    Returns:
        A dictionary containing the sliced dataset with renamed columns for consistency with the evaluation process.
    """
    hf_dataset = Dataset.from_list(dataset[i : i + k])
    hf_dataset = hf_dataset.rename_column("context", "contexts")
    hf_dataset = hf_dataset.rename_column("answer", "ground_truth")
    ds_dict = hf_dataset.to_dict()
    return ds_dict

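For intuition, assuming the records are shaped like qa_generator's output (dicts with "question", "context", and "answer" keys; "source_doc" omitted here), a slice becomes the column-oriented dict the evaluation step consumes, with "context" renamed to "contexts" and "answer" to "ground_truth":

# Hypothetical records mirroring the QA pairs generated above.
records = [
    {"question": "Q1?", "context": "ctx 1", "answer": "A1"},
    {"question": "Q2?", "context": "ctx 2", "answer": "A2"},
]
print(slice_data(0, 2, dataset=records))
# -> {'question': ['Q1?', 'Q2?'], 'contexts': ['ctx 1', 'ctx 2'], 'ground_truth': ['A1', 'A2']}
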
def evaluate(
    query_engine: object,
    dataset: object,
    batch: int = 4,
    metrics: list = [
        faithfulness,
        answer_relevancy,
        context_precision,
        context_recall,
    ],
    llm: object = Ollama(base_url="http://localhost:11434", model="codellama"),
    embeddings=HuggingFaceEmbedding(model_name="microsoft/codebert-base"),
):
    """
    Evaluates the performance of a query engine on a dataset using various metrics and a language model.
    Args:
        query_engine: The query engine to be evaluated.
        dataset: The dataset to be evaluated against.
        batch: The number of records to process in each batch (default: 4).
        metrics: A list of metrics to be used for evaluation (default: faithfulness, answer relevancy, context precision, and context recall).
        llm: The language model to be used for evaluation (default: Ollama with model 'codellama').
        embeddings: The embedding model to be used (default: HuggingFaceEmbedding with 'microsoft/codebert-base').
    Returns:
        A pandas DataFrame containing the evaluation results for each batch.
    """
    rows_count = len(next(iter(dataset.values())))

    results_df = pd.DataFrame()

    for i in range(0, rows_count, batch):
        batch_data = slice_data(i, batch, dataset=dataset)

        # Delegate the actual scoring to ragas' llama_index integration.
        result = ragas_evaluate(
            query_engine=query_engine,
            metrics=metrics,
            dataset=batch_data,
            llm=llm,
            embeddings=embeddings,
        )

        rdf = result.to_pandas()
        results_df = pd.concat([results_df, rdf], ignore_index=True)
        print(f"Processed batch {i // batch + 1}:")
        print(rdf)
    print(results_df)
    results_df.to_csv("results.csv", index=False)
    return results_df
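Finally, a short driver sketch, assuming a LlamaIndex query engine built elsewhere in the project and a pickled test set produced by the QA-generation step; the paths and the index construction below are placeholders, not part of this commit:

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex  # placeholder query engine setup

documents = SimpleDirectoryReader("data").load_data()
query_engine = VectorStoreIndex.from_documents(documents).as_query_engine()

dataset = load_test_dataset("test_dataset.pkl")  # hypothetical path
results = evaluate(query_engine=query_engine, dataset=dataset, batch=4)  # writes results.csv
print(results.head())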