Custom data types #68

Merged · 5 commits · Feb 4, 2025
51 changes: 45 additions & 6 deletions defog/async_generate_schema.py
@@ -3,7 +3,7 @@
 from io import StringIO
 import pandas as pd
 import json
-from typing import List
+from typing import List, Dict, Union
 
 
 async def generate_postgres_schema(
@@ -14,9 +14,16 @@ async def generate_postgres_schema(
     scan: bool = True,
     return_tables_only: bool = False,
     schemas: List[str] = ["public"],
-) -> str:
-    # when upload is True, we send the schema to the defog servers and generate a CSV
-    # when its false, we return the schema as a dict
+) -> Union[Dict, List, str]:
+    """
+    Returns the schema of the tables in the database. Keys: column_name, data_type, column_description, custom_type_labels
+    If tables is non-empty, we only generate the schema for the mentioned tables in the list.
+    If schemas is non-empty, we only generate the schema for the mentioned schemas in the list.
+    If return_tables_only is True, we return only the table names as a list.
+    If upload is True, we send the schema to the defog servers and generate a CSV.
+    If upload is False, we return the schema as a dict.
+    If scan is True, we also scan the tables for categorical columns to enhance the column description.
+    """
     try:
         import asyncpg
     except ImportError:
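
For reviewers, a minimal sketch of consuming the widened `Union[Dict, List, str]` return type. Only the coroutine's signature and docstring are taken from this diff; the `client` stand-in object and the exact dict layout (table name mapped to column rows) are assumptions:

```python
import asyncio
from typing import Dict, List, Union

async def dump_schema(client) -> None:
    # `client` stands in for whatever object this coroutine is bound to.
    result: Union[Dict, List, str] = await client.generate_postgres_schema(
        tables=[],             # empty list -> all tables (per the docstring)
        upload=False,          # False -> schema returned as a dict
        scan=False,
        return_tables_only=False,
        schemas=["public"],
    )
    if isinstance(result, str):
        # upload=True path: presumably a link/payload for the generated CSV
        print(result)
    elif isinstance(result, list):
        # return_tables_only=True path: just the table names
        print(result)
    else:
        # Assumed layout: {table_name: [{"column_name": ..., "data_type": ...,
        #   "column_description": ..., "custom_type_labels": [...]}, ...]}
        for table, columns in result.items():
            print(table, [c["column_name"] for c in columns])
```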
@@ -55,13 +62,45 @@ async def generate_postgres_schema(
         if "." in table_name:
             _, table_name = table_name.split(".", 1)
         query = """
-        SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT)
+        SELECT
+            CAST(column_name AS TEXT),
+            CAST(
+                CASE
+                    WHEN data_type = 'USER-DEFINED' THEN udt_name
+                    ELSE data_type
+                END AS TEXT
+            ) AS data_type,
+            col_description(
+                (quote_ident($2) || '.' || quote_ident($1))::regclass::oid,
+                ordinal_position
+            ) AS column_description,
+            CASE
+                WHEN data_type = 'USER-DEFINED' THEN (
+                    SELECT string_agg(enumlabel, ', ')
+                    FROM pg_enum
+                    WHERE enumtypid = (
+                        SELECT oid
+                        FROM pg_type
+                        WHERE typname = udt_name
+                    )
+                )
+                ELSE NULL
+            END AS custom_type_labels
         FROM information_schema.columns
         WHERE table_name = $1 AND table_schema = $2;
         """
+        print(f"Schema for {schema}.{table_name}")
         rows = await conn.fetch(query, table_name, schema)
         rows = [row for row in rows]
-        rows = [{"column_name": row[0], "data_type": row[1]} for row in rows]
+        rows = [
+            {
+                "column_name": row[0],
+                "data_type": row[1],
+                "column_description": row[2] or "",
+                "custom_type_labels": row[3].split(", ") if row[3] else [],
+            }
+            for row in rows
+        ]
         if len(rows) > 0:
             if scan:
                 rows = await async_identify_categorical_columns(
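
The heart of the custom-type support is the `pg_enum` lookup: when `information_schema.columns` reports `USER-DEFINED`, the query resolves the column's `udt_name` against `pg_type` and aggregates the matching enum labels. A standalone sketch of that lookup with asyncpg (the DSN and the `mood` type are illustrative, not from this PR):

```python
import asyncio
import asyncpg

# Same subquery shape as the diff: typname -> pg_type oid -> pg_enum labels.
ENUM_LABELS_QUERY = """
SELECT string_agg(enumlabel, ', ') AS labels
FROM pg_enum
WHERE enumtypid = (
    SELECT oid FROM pg_type WHERE typname = $1
);
"""

async def enum_labels(dsn: str, type_name: str) -> list:
    conn = await asyncpg.connect(dsn)
    try:
        row = await conn.fetchrow(ENUM_LABELS_QUERY, type_name)
        # Mirror the diff's parsing: comma-joined labels -> list of strings.
        return row["labels"].split(", ") if row and row["labels"] else []
    finally:
        await conn.close()

# After e.g. CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'):
#   asyncio.run(enum_labels("postgresql://localhost/postgres", "mood"))
#   -> ['sad', 'ok', 'happy']
```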
128 changes: 121 additions & 7 deletions defog/generate_schema.py
@@ -3,7 +3,7 @@
 from io import StringIO
 import pandas as pd
 import json
-from typing import List
+from typing import List, Dict, Union
 
 
 def generate_postgres_schema(
@@ -14,9 +14,16 @@ def generate_postgres_schema(
     scan: bool = True,
     return_tables_only: bool = False,
     schemas: List[str] = [],
-) -> str:
-    # when upload is True, we send the schema to the defog servers and generate a CSV
-    # when its false, we return the schema as a dict
+) -> Union[Dict, List, str]:
+    """
+    Returns the schema of the tables in the database. Keys: column_name, data_type, column_description, custom_type_labels
+    If tables is non-empty, we only generate the schema for the mentioned tables in the list.
+    If schemas is non-empty, we only generate the schema for the mentioned schemas in the list.
+    If return_tables_only is True, we return only the table names as a list.
+    If upload is True, we send the schema to the defog servers and generate a CSV.
+    If upload is False, we return the schema as a dict.
+    If scan is True, we also scan the tables for categorical columns to enhance the column description.
+    """
     try:
         import psycopg2
     except ImportError:
@@ -75,11 +82,28 @@ def generate_postgres_schema(
             """
             SELECT
                 CAST(column_name AS TEXT),
-                CAST(data_type AS TEXT),
+                CAST(
+                    CASE
+                        WHEN data_type = 'USER-DEFINED' THEN udt_name
+                        ELSE data_type
+                    END AS TEXT
+                ) AS type,
                 col_description(
                     FORMAT('%%s.%%s', table_schema, table_name)::regclass::oid,
                     ordinal_position
-                ) AS column_description
+                ) AS column_description,
+                CASE
+                    WHEN data_type = 'USER-DEFINED' THEN (
+                        SELECT string_agg(enumlabel, ', ')
+                        FROM pg_enum
+                        WHERE enumtypid = (
+                            SELECT oid
+                            FROM pg_type
+                            WHERE typname = udt_name
+                        )
+                    )
+                    ELSE NULL
+                END AS custom_type_labels
             FROM information_schema.columns
             WHERE table_name::text = %s
             AND table_schema = %s;
@@ -92,7 +116,12 @@
         rows = cur.fetchall()
         rows = [row for row in rows]
         rows = [
-            {"column_name": i[0], "data_type": i[1], "column_description": i[2] or ""}
+            {
+                "column_name": i[0],
+                "data_type": i[1],
+                "column_description": i[2] or "",
+                "custom_type_labels": i[3].split(", ") if i[3] else [],
+            }
             for i in rows
         ]
         if len(rows) > 0:
@@ -136,6 +165,91 @@ def generate_postgres_schema(
         return table_columns
 
 
+def get_postgres_functions(
+    self, schemas: List[str] = []
+) -> Dict[str, List[Dict[str, str]]]:
+    """
+    Returns the custom functions and their definitions of the mentioned schemas in the database.
+    """
+    try:
+        import psycopg2
+    except ImportError:
+        raise ImportError(
+            "psycopg2 not installed. Please install it with `pip install psycopg2-binary`."
+        )
+
+    conn = psycopg2.connect(**self.db_creds)
+    cur = conn.cursor()
+    functions = {}
+
+    if len(schemas) == 0:
+        schemas = ["public"]
+
+    for schema in schemas:
+        cur.execute(
+            """
+            SELECT
+                CAST(p.proname AS TEXT) AS function_name,
+                pg_get_functiondef(p.oid) AS function_definition
+            FROM pg_proc p
+            JOIN pg_namespace n ON n.oid = p.pronamespace
+            WHERE n.nspname = %s;
+            """,
+            (schema,),
+        )
+        rows = [
+            {"function_name": row[0], "function_definition": row[1]}
+            for row in cur.fetchall()
+            if row[1] is not None
+        ]
+        if rows:
+            functions[schema] = rows
+    conn.close()
+    return functions
+
+
+def get_postgres_triggers(
+    self, schemas: List[str] = []
+) -> Dict[str, List[Dict[str, str]]]:
+    """
+    Returns the triggers and their definitions of the mentioned schemas in the database.
+    """
+    try:
+        import psycopg2
+    except ImportError:
+        raise ImportError(
+            "psycopg2 not installed. Please install it with `pip install psycopg2-binary`."
+        )
+
+    conn = psycopg2.connect(**self.db_creds)
+    cur = conn.cursor()
+    triggers = {}
+
+    if len(schemas) == 0:
+        schemas = ["public"]
+
+    for schema in schemas:
+        cur.execute(
+            """
+            SELECT
+                CAST(t.tgname AS TEXT) AS trigger_name,
+                pg_get_triggerdef(t.oid) AS trigger_definition
+            FROM pg_trigger t
+            JOIN pg_class c ON t.tgrelid = c.oid
+            JOIN pg_namespace n ON c.relnamespace = n.oid
+            WHERE n.nspname = %s
+            """,
+            (schema,),
+        )
+        rows = [
+            {"trigger_name": row[0], "trigger_definition": row[1]}
+            for row in cur.fetchall()
+            if row[1] is not None
+        ]
+        if rows:
+            triggers[schema] = rows
+
+    conn.close()
+    return triggers
+
+
 def generate_redshift_schema(
     self,
     tables: list,
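
The two new helpers are plain psycopg2 reads over `pg_proc` and `pg_trigger`. A minimal calling sketch; as above, `client` stands in for the object these functions are bound to, and the "analytics" schema name is hypothetical:

```python
# Assumes client.db_creds holds psycopg2.connect(...) kwargs,
# as the function bodies in this diff do.
functions = client.get_postgres_functions(schemas=["public", "analytics"])
triggers = client.get_postgres_triggers()  # empty schemas -> defaults to ["public"]

for schema, fns in functions.items():
    for fn in fns:
        # fn["function_definition"] is the full CREATE OR REPLACE FUNCTION
        # source from pg_get_functiondef().
        print(schema, fn["function_name"])

for schema, trs in triggers.items():
    for tr in trs:
        # tr["trigger_definition"] is the CREATE TRIGGER statement
        # from pg_get_triggerdef().
        print(schema, tr["trigger_name"])
```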
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@ def package_files(directory):
     name="defog",
     packages=find_packages(),
     package_data={"defog": ["gcp/*", "aws/*"] + next_static_files},
-    version="0.65.24",
+    version="0.66.0",
     description="Defog is a Python library that helps you generate data queries from natural language questions.",
     author="Full Stack Data Pte. Ltd.",
     license="MIT",