-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added ability to specify a schema when extracting columns for Postgre…
…s tables (#49) * enabled custom schemas in postgres * added ability to support schemas other than `public` to `generate_postgres_schema` * removed vestigal code * linting * fixed issue with extracting columns for tables with schema names pre-prended
- Loading branch information
Showing
2 changed files
with
36 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ def generate_postgres_schema( | |
return_format: str = "csv", | ||
scan: bool = True, | ||
return_tables_only: bool = False, | ||
schemas: list[str] = ["public"], | ||
) -> str: | ||
# when upload is True, we send the schema to the defog servers and generate a CSV | ||
# when its false, we return the schema as a dict | ||
|
@@ -24,60 +25,48 @@ def generate_postgres_schema( | |
|
||
conn = psycopg2.connect(**self.db_creds) | ||
cur = conn.cursor() | ||
schemas = {} | ||
schemas = tuple(schemas) | ||
|
||
if len(tables) == 0: | ||
# get all tables | ||
cur.execute( | ||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" | ||
) | ||
tables = [row[0] for row in cur.fetchall()] | ||
for schema in schemas: | ||
cur.execute( | ||
"SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", | ||
(schema,), | ||
) | ||
if schema == "public": | ||
tables += [row[0] for row in cur.fetchall()] | ||
else: | ||
tables += [schema + "." + row[0] for row in cur.fetchall()] | ||
|
||
if return_tables_only: | ||
return tables | ||
|
||
print("Getting schema for each table that you selected...") | ||
# get the schema for each table | ||
for table_name in tables: | ||
cur.execute( | ||
"SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;", | ||
(table_name,), | ||
) | ||
rows = cur.fetchall() | ||
rows = [row for row in rows] | ||
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] | ||
if scan: | ||
rows = identify_categorical_columns(cur, table_name, rows) | ||
schemas[table_name] = rows | ||
|
||
# get foreign key relationships | ||
print("Getting foreign keys for each table that you selected...") | ||
tables_regclass_str = ", ".join( | ||
[f"'{table_name}'::regclass" for table_name in tables] | ||
) | ||
query = f"""SELECT | ||
conrelid::regclass AS table_from, | ||
pg_get_constraintdef(oid) AS foreign_key_definition | ||
FROM pg_constraint | ||
WHERE contype = 'f' | ||
AND conrelid::regclass IN ({tables_regclass_str}) | ||
AND confrelid::regclass IN ({tables_regclass_str}); | ||
""" | ||
cur.execute(query) | ||
foreign_keys = list(cur.fetchall()) | ||
foreign_keys = [fk[0] + " " + fk[1] for fk in foreign_keys] | ||
|
||
# get indexes for each table | ||
print("Getting indexes for each table that you selected...") | ||
tables_str = ", ".join([f"'{table_name}'" for table_name in tables]) | ||
query = f"""SELECT indexdef FROM pg_indexes WHERE tablename IN ({tables_str});""" | ||
cur.execute(query) | ||
indexes = list(cur.fetchall()) | ||
if len(indexes) > 0: | ||
indexes = [index[0] for index in indexes] | ||
else: | ||
indexes = [] | ||
# print("No indexes found.") | ||
table_columns = {} | ||
|
||
# get the columns for each table | ||
for schema in schemas: | ||
for table_name in tables: | ||
if "." in table_name: | ||
_, table_name = table_name.split(".", 1) | ||
cur.execute( | ||
"SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s AND table_schema = %s;", | ||
( | ||
table_name, | ||
schema, | ||
), | ||
) | ||
rows = cur.fetchall() | ||
rows = [row for row in rows] | ||
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] | ||
if scan: | ||
rows = identify_categorical_columns(cur, table_name, rows) | ||
if schema == "public": | ||
table_columns[table_name] = rows | ||
else: | ||
table_columns[schema + table_name] = rows | ||
conn.close() | ||
|
||
print( | ||
|
@@ -89,9 +78,7 @@ def generate_postgres_schema( | |
f"{self.base_url}/get_schema_csv", | ||
json={ | ||
"api_key": self.api_key, | ||
"schemas": schemas, | ||
"foreign_keys": foreign_keys, | ||
"indexes": indexes, | ||
"schemas": table_columns, | ||
}, | ||
) | ||
resp = r.json() | ||
|
@@ -110,7 +97,7 @@ def generate_postgres_schema( | |
f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email [email protected]." | ||
) | ||
else: | ||
return schemas | ||
return table_columns | ||
|
||
|
||
def generate_redshift_schema( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters