Skip to content

Commit

Permalink
Added ability to specify a schema when extracting columns for Postgre…
Browse files Browse the repository at this point in the history
…s tables (#49)

* enabled custom schemas in postgres

* added ability to support schemas other than `public` to `generate_postgres_schema`

* removed vestigal code

* linting

* fixed issue with extracting columns for tables with schema names pre-prended
  • Loading branch information
rishsriv authored May 9, 2024
1 parent e6e83ac commit 958af22
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 53 deletions.
85 changes: 36 additions & 49 deletions defog/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def generate_postgres_schema(
return_format: str = "csv",
scan: bool = True,
return_tables_only: bool = False,
schemas: list[str] = ["public"],
) -> str:
# when upload is True, we send the schema to the defog servers and generate a CSV
# when its false, we return the schema as a dict
Expand All @@ -24,60 +25,48 @@ def generate_postgres_schema(

conn = psycopg2.connect(**self.db_creds)
cur = conn.cursor()
schemas = {}
schemas = tuple(schemas)

if len(tables) == 0:
# get all tables
cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
)
tables = [row[0] for row in cur.fetchall()]
for schema in schemas:
cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = %s;",
(schema,),
)
if schema == "public":
tables += [row[0] for row in cur.fetchall()]
else:
tables += [schema + "." + row[0] for row in cur.fetchall()]

if return_tables_only:
return tables

print("Getting schema for each table that you selected...")
# get the schema for each table
for table_name in tables:
cur.execute(
"SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;",
(table_name,),
)
rows = cur.fetchall()
rows = [row for row in rows]
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
schemas[table_name] = rows

# get foreign key relationships
print("Getting foreign keys for each table that you selected...")
tables_regclass_str = ", ".join(
[f"'{table_name}'::regclass" for table_name in tables]
)
query = f"""SELECT
conrelid::regclass AS table_from,
pg_get_constraintdef(oid) AS foreign_key_definition
FROM pg_constraint
WHERE contype = 'f'
AND conrelid::regclass IN ({tables_regclass_str})
AND confrelid::regclass IN ({tables_regclass_str});
"""
cur.execute(query)
foreign_keys = list(cur.fetchall())
foreign_keys = [fk[0] + " " + fk[1] for fk in foreign_keys]

# get indexes for each table
print("Getting indexes for each table that you selected...")
tables_str = ", ".join([f"'{table_name}'" for table_name in tables])
query = f"""SELECT indexdef FROM pg_indexes WHERE tablename IN ({tables_str});"""
cur.execute(query)
indexes = list(cur.fetchall())
if len(indexes) > 0:
indexes = [index[0] for index in indexes]
else:
indexes = []
# print("No indexes found.")
table_columns = {}

# get the columns for each table
for schema in schemas:
for table_name in tables:
if "." in table_name:
_, table_name = table_name.split(".", 1)
cur.execute(
"SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s AND table_schema = %s;",
(
table_name,
schema,
),
)
rows = cur.fetchall()
rows = [row for row in rows]
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if scan:
rows = identify_categorical_columns(cur, table_name, rows)
if schema == "public":
table_columns[table_name] = rows
else:
table_columns[schema + table_name] = rows
conn.close()

print(
Expand All @@ -89,9 +78,7 @@ def generate_postgres_schema(
f"{self.base_url}/get_schema_csv",
json={
"api_key": self.api_key,
"schemas": schemas,
"foreign_keys": foreign_keys,
"indexes": indexes,
"schemas": table_columns,
},
)
resp = r.json()
Expand All @@ -110,7 +97,7 @@ def generate_postgres_schema(
f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email [email protected]."
)
else:
return schemas
return table_columns


def generate_redshift_schema(
Expand Down
4 changes: 0 additions & 4 deletions defog/query_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ def get_query(
"previous_context": previous_context,
"db_type": self.db_type if self.db_type != "databricks" else "postgres",
"glossary": glossary,
"language": language,
"hard_filters": hard_filters,
"dev": dev,
"ignore_cache": ignore_cache,
}
if schema != {}:
data["schema"] = schema
data["is_direct"] = True

t_start = datetime.now()
r = requests.post(
Expand Down

0 comments on commit 958af22

Please sign in to comment.