
Commit

Update docs, fix issue with params, add tracking
Vasilije1990 committed Oct 30, 2023
1 parent 552a8e6 commit 3409d5b
Showing 5 changed files with 80 additions and 68 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -149,7 +149,7 @@ Send the request to the API:
```
curl -X POST -H "Content-Type: application/json" -d '{
  "payload": {
-   "user_id": "681",
+   "user_id": "97980cfea0067",
    "data": [".data/3ZCCCW.pdf"],
    "test_set": "sample",
    "params": ["chunk_size"],
@@ -217,7 +217,7 @@ After that, you can run the RAG test manager from your command line.
python rag_test_manager.py \
--file ".data" \
--test_set "example_data/test_set.json" \
- --user_id "666" \
+ --user_id "97980cfea0067" \
--params "chunk_size" "search_type" \
--metadata "example_data/metadata.json" \
--retriever_type "single_document_context"
@@ -226,3 +226,6 @@ After that, you can run the RAG test manager from your command line.

Examples of the metadata structure and the test set are in the "example_data" folder.
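For reference, a minimal sketch of what those two files can contain, mirroring the commented-out example this commit removes from rag_test_manager.py (field names and values are illustrative only):

```
# Illustrative shapes only — adapted from the example removed from rag_test_manager.py.
metadata = {
    "version": "1.0",
    "agreement_id": "AG123456",
    "format": "json",
    "schema_version": "1.1",
    "owner": "John Doe",
    "license": "MIT",
}

test_set = [
    {
        "question": "Who is the main character in 'The Call of the Wild'?",
        "answer": "Buck",
    },
    {"question": "Who wrote 'The Call of the Wild'?", "answer": "Jack London"},
]
```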




3 changes: 2 additions & 1 deletion level_3/.env.template
@@ -6,4 +6,5 @@ POSTGRES_USER = bla
POSTGRES_PASSWORD = bla
POSTGRES_DB = bubu
POSTGRES_HOST = localhost
-POSTGRES_HOST_DOCKER = postgres
\ No newline at end of file
+POSTGRES_HOST_DOCKER = postgres
+SEGMENT_KEY = Etl4WJwzOkeDPAjaOXOMgyU16hO7mV7B
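The new SEGMENT_KEY value is read from the environment and used as the Segment write key; a minimal sketch of that wiring, matching the lines added to rag_test_manager.py further down, looks like this:

```
import os

import segment.analytics as analytics
from dotenv import load_dotenv

# Load SEGMENT_KEY (and the other settings) from the .env file.
load_dotenv()

# Falls back to an empty string when the key is not set, as in rag_test_manager.py.
analytics.write_key = os.getenv("SEGMENT_KEY", "")
```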
39 changes: 35 additions & 4 deletions level_3/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions level_3/pyproject.toml
@@ -40,6 +40,7 @@ dash = "^2.14.0"
unstructured = {extras = ["pdf"], version = "^0.10.23"}
sentence-transformers = "2.2.2"
torch = "2.0.*"
segment-analytics-python = "^2.2.3"



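The segment-analytics-python package added above is imported as segment.analytics in rag_test_manager.py. As a rough, standalone sanity check of the new dependency (a hypothetical snippet, not part of this commit — the event name and user id are made up), the client can be configured and asked to send a single event:

```
import os

import segment.analytics as analytics

analytics.write_key = os.getenv("SEGMENT_KEY", "")
analytics.debug = True  # verbose client logging, as enabled in rag_test_manager.py


def on_error(error, items):
    # Same error hook that rag_test_manager.py registers below.
    print("An error occurred:", error)


analytics.on_error = on_error

# Hypothetical event; "SmokeTest" is not an event name used in the codebase.
analytics.track("test-user", "SmokeTest", {"source": "manual check"})
analytics.flush()  # make sure the queued event is delivered before exiting
```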
98 changes: 37 additions & 61 deletions level_3/rag_test_manager.py
@@ -32,6 +32,8 @@
from models.operation import Operation
from models.docs import DocsModel

import segment.analytics as analytics

load_dotenv()
import ast
import tracemalloc
@@ -55,8 +57,16 @@
from database.database import AsyncSessionLocal

openai.api_key = os.getenv("OPENAI_API_KEY", "")
analytics.write_key = os.getenv("SEGMENT_KEY", "")


def on_error(error, items):
print("An error occurred:", error)


analytics.debug = True
analytics.on_error = on_error

async def retrieve_latest_test_case(session, user_id, memory_id):
try:
# Use await with session.execute() and row.fetchone() or row.all() for async query execution
@@ -476,6 +486,10 @@ async def start_test(
await add_entity(
session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set))
)
analytics.track(user_id, 'TestSet', {
'id': test_set_id,
'content': str(test_set)
})

if params is None:
data_format = data_format_route(
@@ -521,6 +535,15 @@
test_set_id=test_set_id,
),
)
analytics.track(user_id, 'Operation', {
'id': job_id,
'operation_params': str(test_params),
'number_of_files': count_files_in_data_folder(),
'operation_status': "RUNNING",
'operation_type': retriever_type,
'test_set_id': test_set_id,
})

doc_names = get_document_names(data)
for doc in doc_names:

@@ -697,7 +720,27 @@ async def run_generate_test_set(test_id):
test_params=str(chunk), # Add params to the database table
),
)
analytics.track(user_id, 'TestOutput', {
'test_set_id': test_set_id,
'operation_id': job_id,
'set_id' : str(uuid.uuid4()),
'test_results' : result["success"],
'test_score' : str(result["score"]),
'test_metric_name' : result["metric_name"],
'test_query' : result["query"],
'test_output' : result["output"],
'test_expected_output' : str(result["expected_output"]),
'test_context' : result["context"][0],
'test_params' : str(chunk),
})
analytics.flush()

await update_entity(session, Operation, job_id, "COMPLETED")

return results


async def main():
# metadata = {
# "version": "1.0",
# "agreement_id": "AG123456",
# "privacy_policy": "https://example.com/privacy",
# "terms_of_service": "https://example.com/terms",
# "format": "json",
# "schema_version": "1.1",
# "checksum": "a1b2c3d4e5f6",
# "owner": "John Doe",
# "license": "MIT",
# "validity_start": "2023-08-01",
# "validity_end": "2024-07-31",
# }
#
# test_set = [
# {
# "question": "Who is the main character in 'The Call of the Wild'?",
# "answer": "Buck",
# },
# {"question": "Who wrote 'The Call of the Wild'?", "answer": "Jack London"},
# {
# "question": "Where does Buck live at the start of the book?",
# "answer": "In the Santa Clara Valley, at Judge Miller’s place.",
# },
# {
# "question": "Why is Buck kidnapped?",
# "answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush.",
# },
# {
# "question": "How does Buck become the leader of the sled dog team?",
# "answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight.",
# },
# ]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
# # http://public-library.uk/ebooks/59/83.pdf
# result = await start_test(
# [".data/3ZCCCW.pdf"],
# test_set=test_set,
# user_id="677",
# params=["chunk_size", "search_type"],
# metadata=metadata,
# retriever_type="single_document_context",
# )

parser = argparse.ArgumentParser(description="Run tests against a document.")
parser.add_argument("--file", nargs="+", required=True, help="List of file paths to test.")
@@ -793,21 +787,3 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())

# delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
# namespace=test_id)
# test_load_pipeline = await asyncio.gather(
# *(run_load_test_element(test_item,loader_settings, metadata, test_id) for test_item in test_set)
# )
#
# test_eval_pipeline = await asyncio.gather(
# *(run_search_eval_element(test_item, test_id) for test_item in test_set)
# )
# logging.info("Results of the eval pipeline %s", str(test_eval_pipeline))
# await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_eval_pipeline)))
# return test_eval_pipeline

# # Gather and run all tests in parallel
# results = await asyncio.gather(
# *(run_testo(test, loader_settings, metadata) for test in test_params)
# )
# return results
