-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove insights, have single asyncio run, refactor
- Loading branch information
1 parent
b32025b
commit 33f09f2
Showing
3 changed files
with
104 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
{ | ||
"dataset": [ | ||
"hotpotqa" | ||
], | ||
"rag_option": [ | ||
"no_rag", | ||
"cognee", | ||
"simple_rag", | ||
"brute_force" | ||
], | ||
"num_samples": [ | ||
2 | ||
], | ||
"metric_names": [ | ||
"Correctness", | ||
"Comprehensiveness" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import itertools | ||
import matplotlib.pyplot as plt | ||
from jsonschema import ValidationError, validate | ||
import pandas as pd | ||
from pathlib import Path | ||
|
||
paramset_json_schema = { | ||
"type": "object", | ||
"properties": { | ||
"dataset": { | ||
"type": "array", | ||
"items": {"type": "string"}, | ||
}, | ||
"rag_option": { | ||
"type": "array", | ||
"items": {"type": "string"}, | ||
}, | ||
"num_samples": { | ||
"type": "array", | ||
"items": {"type": "integer", "minimum": 1}, | ||
}, | ||
"metric_names": { | ||
"type": "array", | ||
"items": {"type": "string"}, | ||
}, | ||
}, | ||
"required": ["dataset", "rag_option", "num_samples", "metric_names"], | ||
"additionalProperties": False, | ||
} | ||
|
||
|
||
def save_table_as_image(df, image_path): | ||
plt.figure(figsize=(10, 6)) | ||
plt.axis("tight") | ||
plt.axis("off") | ||
plt.table(cellText=df.values, colLabels=df.columns, rowLabels=df.index, loc="center") | ||
plt.title(f"{df.index.name}") | ||
plt.savefig(image_path, bbox_inches="tight") | ||
plt.close() | ||
|
||
|
||
def save_results_as_image(results, out_path): | ||
for dataset, num_samples_data in results.items(): | ||
for num_samples, table_data in num_samples_data.items(): | ||
df = pd.DataFrame.from_dict(table_data, orient="index") | ||
df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}" | ||
image_path = Path(out_path) / Path(f"table_{dataset}_{num_samples}.png") | ||
save_table_as_image(df, image_path) | ||
|
||
|
||
def get_combinations(parameters): | ||
try: | ||
validate(instance=parameters, schema=paramset_json_schema) | ||
except ValidationError as e: | ||
raise ValidationError(f"Invalid parameter set: {e.message}") | ||
|
||
params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"} | ||
keys, values = zip(*params_for_combos.items()) | ||
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)] | ||
return combinations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters