[DH-4306] Split db methods for cardinality and logs (#269)
* [DH-4306] Split db methods for cardinality and logs

* Improve query history method for Snowflake and BigQuery

* Document query_history endpoint
jcjc712 authored Dec 15, 2023
1 parent 6e00ee3 commit 611724b
Showing 17 changed files with 386 additions and 60 deletions.
27 changes: 26 additions & 1 deletion README.md
@@ -263,7 +263,9 @@ Once you have connected to the data warehouse, you should add context to the engine
While only the database scan is required to start generating SQL, adding verified SQL and string descriptions is also important for the tool to generate accurate SQL.

#### Scanning the Database
The database scan is used to gather information about the database, including table and column names, and to identify low cardinality columns and their values to be stored in the context store and used in the prompts to the LLM.
In addition, it retrieves logs: historical queries associated with each database table. These records are stored in the `query_history` collection. The retrieved queries cover the past three months and are grouped by query and user.
You can trigger a scan of a database from the `POST /api/v1/table-descriptions/sync-schemas` endpoint. Example below


```
@@ -279,6 +281,29 @@ curl -X 'POST' \

Since the endpoint identifies low cardinality columns (and their values), it can take time to complete. Therefore, while it is possible to trigger a scan of the entire DB by not specifying `table_names`, we recommend against it for large databases.

#### Get logs per db connection
Once a database has been scanned, you can use this endpoint to retrieve its table logs

```
curl -X 'GET' \
'http://localhost/api/v1/query-history?db_connection_id=656e52cb4d1fda50cae7b939' \
-H 'accept: application/json'
```

Response example:
```
[
{
"id": "656e52cb4d1fda50cae7b939",
"db_connection_id": "656e52cb4d1fda50cae7b939",
"table_name": "table_name",
"query": "select QUERY_TEXT, USER_NAME, count(*) as occurrences from ....",
"user": "user_name",
"occurrences": 1
}
]
```

#### Get a scanned db
Once a database has been scanned, you can use this endpoint to retrieve the table names and columns

6 changes: 5 additions & 1 deletion dataherald/api/__init__.py
@@ -6,7 +6,7 @@

from dataherald.api.types import Query
from dataherald.config import Component
from dataherald.db_scanner.models.types import QueryHistory, TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CancelFineTuningRequest,
@@ -125,6 +125,10 @@ def create_response(
) -> Response:
pass

@abstractmethod
def get_query_history(self, db_connection_id: str) -> list[QueryHistory]:
pass

@abstractmethod
def get_responses(self, question_id: str | None = None) -> list[Response]:
pass
15 changes: 14 additions & 1 deletion dataherald/api/fastapi.py
@@ -19,11 +19,16 @@
from dataherald.context_store import ContextStore
from dataherald.db import DB
from dataherald.db_scanner import Scanner
from dataherald.db_scanner.models.types import (
QueryHistory,
TableDescription,
TableDescriptionStatus,
)
from dataherald.db_scanner.repository.base import (
InvalidColumnNameError,
TableDescriptionRepository,
)
from dataherald.db_scanner.repository.query_history import QueryHistoryRepository
from dataherald.eval import Evaluator
from dataherald.finetuning.openai_finetuning import OpenAIFineTuning
from dataherald.repositories.base import ResponseRepository
@@ -71,6 +76,7 @@ def async_scanning(scanner, database, scanner_request, storage):
scanner_request.db_connection_id,
scanner_request.table_names,
TableDescriptionRepository(storage),
QueryHistoryRepository(storage),
)


@@ -381,6 +387,13 @@ def get_table_description(self, table_description_id: str) -> TableDescription:
raise HTTPException(status_code=404, detail="Table description not found")
return result

@override
def get_query_history(self, db_connection_id: str) -> list[QueryHistory]:
query_history_repository = QueryHistoryRepository(self.storage)
return query_history_repository.find_by(
{"db_connection_id": ObjectId(db_connection_id)}
)

@override
def get_responses(self, question_id: str | None = None) -> list[Response]:
response_repository = ResponseRepository(self.storage)
2 changes: 2 additions & 0 deletions dataherald/db_scanner/__init__.py
@@ -3,6 +3,7 @@

from dataherald.config import Component
from dataherald.db_scanner.repository.base import TableDescriptionRepository
from dataherald.db_scanner.repository.query_history import QueryHistoryRepository
from dataherald.sql_database.base import SQLDatabase


@@ -14,6 +15,7 @@ def scan(
db_connection_id: str,
table_names: list[str] | None,
repository: TableDescriptionRepository,
query_history_repository: QueryHistoryRepository,
) -> None:
""" "Scan a db"""

9 changes: 9 additions & 0 deletions dataherald/db_scanner/models/types.py
@@ -45,3 +45,12 @@ def parse_datetime_with_timezone(cls, value):
if not value:
return None
return value.replace(tzinfo=timezone.utc) # Set the timezone to UTC


class QueryHistory(BaseModel):
id: str | None
db_connection_id: str
table_name: str
query: str
user: str
occurrences: int = 0
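
For reference, a `QueryHistory` record can be constructed directly from these fields; a minimal sketch with illustrative values (not part of the commit):

```
from dataherald.db_scanner.models.types import QueryHistory

# Illustrative values; `id` is filled in by the repository on insert.
history = QueryHistory(
    id=None,
    db_connection_id="656e52cb4d1fda50cae7b939",
    table_name="orders",
    query="SELECT * FROM orders",
    user="analyst",
    occurrences=3,
)
print(history.dict())
```
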
31 changes: 31 additions & 0 deletions dataherald/db_scanner/repository/query_history.py
@@ -0,0 +1,31 @@
from bson.objectid import ObjectId

from dataherald.db_scanner.models.types import QueryHistory

DB_COLLECTION = "query_history"


class QueryHistoryRepository:
def __init__(self, storage):
self.storage = storage

def insert(self, query_history: QueryHistory) -> QueryHistory:
query_history_dict = query_history.dict(exclude={"id"})
query_history_dict["db_connection_id"] = ObjectId(
query_history.db_connection_id
)
query_history.id = str(
self.storage.insert_one(DB_COLLECTION, query_history_dict)
)
return query_history

def find_by(
self, query: dict, page: int = 1, limit: int = 10
) -> list[QueryHistory]:
rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit)
result = []
for row in rows:
row["id"] = str(row["_id"])
row["db_connection_id"] = str(row["db_connection_id"])
result.append(QueryHistory(**row))
return result
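
As a usage sketch, the repository only needs a storage object exposing `insert_one` and `find` with the call shapes used above. The following hypothetical in-memory stand-in (illustrative, not part of the codebase) makes the round trip visible:

```
from bson.objectid import ObjectId

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.db_scanner.repository.query_history import QueryHistoryRepository


class InMemoryStorage:
    """Hypothetical stand-in for the storage layer the repository calls into."""

    def __init__(self):
        self.rows = []

    def insert_one(self, collection, doc):
        doc["_id"] = ObjectId()
        self.rows.append(doc)
        return doc["_id"]

    def find(self, collection, query, page=1, limit=10):
        matches = [r for r in self.rows if all(r.get(k) == v for k, v in query.items())]
        return matches[(page - 1) * limit : page * limit]


repo = QueryHistoryRepository(InMemoryStorage())
saved = repo.insert(
    QueryHistory(
        id=None,
        db_connection_id=str(ObjectId()),
        table_name="orders",
        query="SELECT * FROM orders",
        user="analyst",
        occurrences=3,
    )
)
found = repo.find_by({"db_connection_id": ObjectId(saved.db_connection_id)})
print(found[0].query)  # SELECT * FROM orders
```
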
20 changes: 20 additions & 0 deletions dataherald/db_scanner/services/abstract_scanner.py
@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod

from sqlalchemy.sql.schema import Column

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.sql_database.base import SQLDatabase


class AbstractScanner(ABC):
@abstractmethod
def cardinality_values(self, column: Column, db_engine: SQLDatabase) -> list | None:
"""Returns a list if it is a catalog otherwise return None"""
pass

@abstractmethod
def get_logs(
self, table: str, db_engine: SQLDatabase, db_connection_id: str
) -> list[QueryHistory]:
"""Returns a list of logs"""
pass
28 changes: 28 additions & 0 deletions dataherald/db_scanner/services/base_scanner.py
@@ -0,0 +1,28 @@
import sqlalchemy
from overrides import override
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import Column

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.db_scanner.services.abstract_scanner import AbstractScanner
from dataherald.sql_database.base import SQLDatabase

MIN_CATEGORY_VALUE = 1
MAX_CATEGORY_VALUE = 100


class BaseScanner(AbstractScanner):
@override
def cardinality_values(self, column: Column, db_engine: SQLDatabase) -> list | None:
cardinality_query = sqlalchemy.select([func.distinct(column)]).limit(101)
cardinality = db_engine.engine.execute(cardinality_query).fetchall()

if MAX_CATEGORY_VALUE > len(cardinality) > MIN_CATEGORY_VALUE:
return [str(category[0]) for category in cardinality]
return None

@override
def get_logs(
self, table: str, db_engine: SQLDatabase, db_connection_id: str # noqa: ARG002
) -> list[QueryHistory]:
return []
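
The cardinality check above fetches up to 101 distinct values and treats a column as a low-cardinality catalog only when the count lies strictly between the two bounds; a self-contained sketch of that rule (the helper name is ours, not the codebase's):

```
MIN_CATEGORY_VALUE = 1
MAX_CATEGORY_VALUE = 100


def is_low_cardinality(distinct_values: list) -> bool:
    # Same bounds as BaseScanner.cardinality_values: more than one distinct
    # value, and strictly fewer than MAX_CATEGORY_VALUE of them.
    return MAX_CATEGORY_VALUE > len(distinct_values) > MIN_CATEGORY_VALUE


print(is_low_cardinality(["a", "b", "c"]))               # True
print(is_low_cardinality(["only_value"]))                # False: too few
print(is_low_cardinality([str(i) for i in range(101)]))  # False: too many
```
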
52 changes: 52 additions & 0 deletions dataherald/db_scanner/services/big_query_scanner.py
@@ -0,0 +1,52 @@
from datetime import datetime, timedelta

import sqlalchemy
from overrides import override
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import Column

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.db_scanner.services.abstract_scanner import AbstractScanner
from dataherald.sql_database.base import SQLDatabase

MIN_CATEGORY_VALUE = 1
MAX_CATEGORY_VALUE = 100
MAX_LOGS = 5_000


class BigQueryScanner(AbstractScanner):
@override
def cardinality_values(self, column: Column, db_engine: SQLDatabase) -> list | None:
rs = db_engine.engine.execute(
f"SELECT APPROX_COUNT_DISTINCT({column.name}) FROM {column.table.name}" # noqa: S608 E501
).fetchall()

if (
len(rs) > 0
and len(rs[0]) > 0
and MIN_CATEGORY_VALUE < rs[0][0] <= MAX_CATEGORY_VALUE
):
cardinality_query = sqlalchemy.select([func.distinct(column)]).limit(101)
cardinality = db_engine.engine.execute(cardinality_query).fetchall()
return [str(category[0]) for category in cardinality]

return None

@override
def get_logs(
self, table: str, db_engine: SQLDatabase, db_connection_id: str
) -> list[QueryHistory]:
filter_date = (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
rows = db_engine.engine.execute(
f"SELECT query, user_email, count(*) as occurrences FROM `region-us.INFORMATION_SCHEMA.JOBS`, UNNEST(referenced_tables) AS t where job_type = 'QUERY' and statement_type = 'SELECT' and t.table_id = '{table}' and state = 'DONE' and creation_time >='{filter_date}' group by query, user_email ORDER BY occurrences DESC limit {MAX_LOGS}" # noqa: S608 E501
).fetchall()
return [
QueryHistory(
db_connection_id=db_connection_id,
table_name=table,
query=row[0],
user=row[1],
occurrences=row[2],
)
for row in rows
]
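
For readability, the single-line logs query above can be reassembled as a multi-line statement; a sketch with illustrative parameter values, not code from the commit:

```
from datetime import datetime, timedelta

MAX_LOGS = 5_000
table = "orders"  # illustrative table name
filter_date = (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")

# The same statement get_logs sends to BigQuery, reflowed for readability.
query = f"""
SELECT query, user_email, count(*) AS occurrences
FROM `region-us.INFORMATION_SCHEMA.JOBS`, UNNEST(referenced_tables) AS t
WHERE job_type = 'QUERY'
  AND statement_type = 'SELECT'
  AND t.table_id = '{table}'
  AND state = 'DONE'
  AND creation_time >= '{filter_date}'
GROUP BY query, user_email
ORDER BY occurrences DESC
LIMIT {MAX_LOGS}
"""
print(query)
```
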
30 changes: 30 additions & 0 deletions dataherald/db_scanner/services/postgre_sql_scanner.py
@@ -0,0 +1,30 @@
from overrides import override
from sqlalchemy.sql.schema import Column

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.db_scanner.services.abstract_scanner import AbstractScanner
from dataherald.sql_database.base import SQLDatabase

MIN_CATEGORY_VALUE = 1
MAX_CATEGORY_VALUE = 100


class PostgreSqlScanner(AbstractScanner):
@override
def cardinality_values(self, column: Column, db_engine: SQLDatabase) -> list | None:
rs = db_engine.engine.execute(
f"SELECT n_distinct, most_common_vals::TEXT::TEXT[] FROM pg_catalog.pg_stats WHERE tablename = '{column.table.name}' AND attname = '{column.name}'" # noqa: S608 E501
).fetchall()

if (
len(rs) > 0
and MIN_CATEGORY_VALUE < rs[0]["n_distinct"] <= MAX_CATEGORY_VALUE
):
return rs[0]["most_common_vals"]
return None

@override
def get_logs(
self, table: str, db_engine: SQLDatabase, db_connection_id: str # noqa: ARG002
) -> list[QueryHistory]:
return []
53 changes: 53 additions & 0 deletions dataherald/db_scanner/services/snowflake_scanner.py
@@ -0,0 +1,53 @@
from datetime import datetime, timedelta

import sqlalchemy
from overrides import override
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import Column

from dataherald.db_scanner.models.types import QueryHistory
from dataherald.db_scanner.services.abstract_scanner import AbstractScanner
from dataherald.sql_database.base import SQLDatabase

MIN_CATEGORY_VALUE = 1
MAX_CATEGORY_VALUE = 100
MAX_LOGS = 5_000


class SnowflakeScanner(AbstractScanner):
@override
def cardinality_values(self, column: Column, db_engine: SQLDatabase) -> list | None:
rs = db_engine.engine.execute(
f"select HLL({column.name}) from {column.table.name}" # noqa: S608 E501
).fetchall()

if (
len(rs) > 0
and len(rs[0]) > 0
and MIN_CATEGORY_VALUE < rs[0][0] <= MAX_CATEGORY_VALUE
):
cardinality_query = sqlalchemy.select([func.distinct(column)]).limit(101)
cardinality = db_engine.engine.execute(cardinality_query).fetchall()
return [str(category[0]) for category in cardinality]

return None

@override
def get_logs(
self, table: str, db_engine: SQLDatabase, db_connection_id: str
) -> list[QueryHistory]:
database_name = db_engine.engine.url.database.split("/")[0]
filter_date = (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
rows = db_engine.engine.execute(
f"select QUERY_TEXT, USER_NAME, count(*) as occurrences from TABLE(INFORMATION_SCHEMA.QUERY_HISTORY()) where DATABASE_NAME = '{database_name}' and QUERY_TYPE = 'SELECT' and EXECUTION_STATUS = 'SUCCESS' and START_TIME > '{filter_date}' and QUERY_TEXT like '%FROM {table}%' and QUERY_TEXT not like '%QUERY_HISTORY%' group by QUERY_TEXT, USER_NAME ORDER BY occurrences DESC limit {MAX_LOGS}" # noqa: S608 E501
).fetchall()
return [
QueryHistory(
db_connection_id=db_connection_id,
table_name=table,
query=row[0],
user=row[1],
occurrences=row[2],
)
for row in rows
]
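
The `database_name` extraction above relies on Snowflake SQLAlchemy URLs packing database and schema into the URL's database field, so the first path segment is the database name; a small sketch (connection string illustrative):

```
from sqlalchemy.engine.url import make_url

# Snowflake connection URLs carry "database/schema" in the database slot.
url = make_url("snowflake://user:password@account/MY_DB/PUBLIC")
print(url.database)                # MY_DB/PUBLIC
print(url.database.split("/")[0])  # MY_DB
```
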