Skip to content

Commit

Permalink
[DH-5018] Add dataset as prefix for table_name in table_descriptions collection
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Nov 29, 2023
1 parent acee6ab commit aa090cf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
5 changes: 5 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def scan_db(

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)
if database.engine.driver in ["bigquery", "snowflake"]:
all_tables = [f"{database.engine.url.database}.{x}" for x in all_tables]

if scanner_request.table_names:
for table in scanner_request.table_names:
if table not in all_tables:
Expand Down Expand Up @@ -338,6 +341,8 @@ def list_table_descriptions(

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)
if database.engine.driver in ["bigquery", "snowflake"]:
all_tables = [f"{database.engine.url.database}.{x}" for x in all_tables]

for table_description in table_descriptions:
if table_description.table_name not in all_tables:
Expand Down
21 changes: 14 additions & 7 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,17 @@ def get_table_schema(
self, meta: MetaData, db_engine: SQLDatabase, table: str
) -> str:
print(f"Create table schema: {table}")
create_table = str(
CreateTable([x for x in meta.sorted_tables if x.name == table][0]).compile(
db_engine.engine
obj = CreateTable(
[x for x in meta.sorted_tables if x.name == table][0]
).compile(db_engine.engine)

create_table = str(obj).rstrip()

if db_engine.engine.driver in ["bigquery", "snowflake"]:
create_table = create_table.replace(
table, f"{db_engine.engine.url.database}.{table}"
)
)
return f"{create_table.rstrip()}"
return create_table

def scan_single_table(
self,
Expand All @@ -189,7 +194,9 @@ def scan_single_table(

object = TableDescription(
db_connection_id=db_connection_id,
table_name=table,
table_name=f"{db_engine.engine.url.database}.{table}"
if db_engine.engine.driver in ["bigquery", "snowflake"]
else table,
columns=table_columns,
table_schema=self.get_table_schema(
meta=meta, db_engine=db_engine, table=table
Expand Down Expand Up @@ -217,7 +224,7 @@ def scan(
MetaData.reflect(meta, views=True)
tables = inspector.get_table_names() + inspector.get_view_names()
if table_names:
table_names = [table.lower() for table in table_names]
table_names = [table.lower().split(".")[-1] for table in table_names]
tables = [
table for table in tables if table and table.lower() in table_names
]
Expand Down

0 comments on commit aa090cf

Please sign in to comment.