Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CShorten committed Jan 14, 2025
1 parent 0d10c5e commit d9e227c
Show file tree
Hide file tree
Showing 7 changed files with 13,672 additions and 26,590 deletions.
49 changes: 39 additions & 10 deletions app/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, Any
from src.lm.query_executor import execute_weaviate_query
from src.models import WeaviateQuery

app = FastAPI()

Expand All @@ -18,9 +20,25 @@

import json
import os
import weaviate

QUERIES_FILE = "synthetic-weaviate-queries-with-results.json"

# Get API keys from environment variables
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

print("Connecting to Weaviate...")
weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=WEAVIATE_URL,
auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
headers = {
"X-OpenAI-Api-Key": OPENAI_API_KEY
}
)
print("Successfully connected to Weaviate...")

with open(QUERIES_FILE, 'r') as f:
synthetic_query_data = json.load(f)

Expand Down Expand Up @@ -60,20 +78,31 @@ async def get_data():
class QueryUpdate(BaseModel):
index: int
updated_query: Dict[Any, Any]
updated_result: str

@app.put("/update-query")
async def update_query(query_update: QueryUpdate):
print(query_update)
try:
if 0 <= query_update.index < len(synthetic_query_data):
synthetic_query_data[query_update.index]["query"] = query_update.updated_query
synthetic_query_data[query_update.index]["ground_truth_query_result"] = query_update.updated_result

with open(QUERIES_FILE, 'w') as f:
json.dump(synthetic_query_data, f, indent=2)

return {"message": "Query and result updated successfully"}
else:
# 1) Validate the requested index
if not (0 <= query_update.index < len(synthetic_query_data)):
raise HTTPException(status_code=404, detail="Query index not found")

# 2) Overwrite the existing "query" field with the updated query data
synthetic_query_data[query_update.index]["query"] = query_update.updated_query

# 3) Build a WeaviateQuery object from the updated query dictionary
updated_query_obj = WeaviateQuery(**query_update.updated_query)

# 4) Execute the query to get the new ground-truth result
final_response = execute_weaviate_query(weaviate_client, updated_query_obj)

# 5) Update ground_truth_query_result and save changes to disk
synthetic_query_data[query_update.index]["ground_truth_query_result"] = final_response

with open(QUERIES_FILE, 'w') as f:
json.dump(synthetic_query_data, f, indent=2)

return {"message": "Query and ground truth result updated successfully"}

except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
13,596 changes: 6,796 additions & 6,800 deletions app/backend/synthetic-weaviate-queries-with-results.json

Large diffs are not rendered by default.

6,473 changes: 0 additions & 6,473 deletions app/backend/synthetic-weaviate-queries-with-schemas.json

This file was deleted.

23 changes: 9 additions & 14 deletions data/print-query-execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import weaviate
import weaviate.classes as wvc
import json
import time
import pandas as pd
import os
from src.models import WeaviateQuery
Expand Down Expand Up @@ -68,12 +69,14 @@
)

# Load and insert data from corresponding CSV
csv_path = f'./{collection_name}.csv'
start_time = time.time()
csv_path = f'./data-for-use-cases/{collection_name}.csv'
if os.path.exists(csv_path):
df = pd.read_csv(csv_path)
collection_obj = weaviate_client.collections.get(collection_name)
for _, row in df.iterrows():
collection_obj.data.insert(properties=row.to_dict())
print(f"Loading data for {collection_name} took {time.time() - start_time:.2f} seconds")

created_collections.add(collection_name)

Expand All @@ -87,20 +90,12 @@
try:
result = execute_weaviate_query(weaviate_client, query)
query_data['ground_truth_query_result'] = result
print(f"\033[92mQuery executed successfully\033[0m") # Green text
print(f"Query result: {result}") # Print the query result
except Exception as e:
failed_queries += 1
print("\nQuery:", query_data['query']['corresponding_natural_language_query'])
print("\nQuery details:")
print(f"Target collection: {query.target_collection}")
print(f"Search query: {query.search_query}")
print(f"Integer filters: {query.integer_property_filter}")
print(f"Text filters: {query.text_property_filter}")
print(f"Boolean filters: {query.boolean_property_filter}")
print(f"Integer aggregations: {query.integer_property_aggregation}")
print(f"Text aggregations: {query.text_property_aggregation}")
print(f"Boolean aggregations: {query.boolean_property_aggregation}")
print(f"Group by: {query.groupby_property}")
print(f"\033[91mQuery execution failed: {str(e)}\033[0m") # Red text
print(f"\033[91mQuery execution failed\033[0m") # Red text
print(f"Error: {str(e)}") # Print the error message
query_data['ground_truth_query_result'] = "QUERY EXECUTION FAILED"
print("Connecting to Weaviate...")
weaviate_client = weaviate.connect_to_weaviate_cloud(
Expand All @@ -122,4 +117,4 @@
print(f"\nResults saved to {output_path}")

# Close the Weaviate client connection
weaviate_client.close()
weaviate_client.close()
Loading

0 comments on commit d9e227c

Please sign in to comment.