From aa090cf750b3c04188d51df6d63683b5bb1a20ba Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 28 Nov 2023 19:21:36 -0600 Subject: [PATCH] [DH-5018] Add dataset as prefix for table_name in table_descriptions collection --- dataherald/api/fastapi.py | 5 +++++ dataherald/db_scanner/sqlalchemy.py | 21 ++++++++++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 26409845..68089b56 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -106,6 +106,9 @@ def scan_db( scanner = self.system.instance(Scanner) all_tables = scanner.get_all_tables_and_views(database) + if database.engine.driver in ["bigquery", "snowflake"]: + all_tables = [f"{database.engine.url.database}.{x}" for x in all_tables] + if scanner_request.table_names: for table in scanner_request.table_names: if table not in all_tables: @@ -338,6 +341,8 @@ def list_table_descriptions( scanner = self.system.instance(Scanner) all_tables = scanner.get_all_tables_and_views(database) + if database.engine.driver in ["bigquery", "snowflake"]: + all_tables = [f"{database.engine.url.database}.{x}" for x in all_tables] for table_description in table_descriptions: if table_description.table_name not in all_tables: diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index dd29195d..50687d95 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -158,12 +158,17 @@ def get_table_schema( self, meta: MetaData, db_engine: SQLDatabase, table: str ) -> str: print(f"Create table schema: {table}") - create_table = str( - CreateTable([x for x in meta.sorted_tables if x.name == table][0]).compile( - db_engine.engine + obj = CreateTable( + [x for x in meta.sorted_tables if x.name == table][0] + ).compile(db_engine.engine) + + create_table = str(obj).rstrip() + + if db_engine.engine.driver in ["bigquery", "snowflake"]: + create_table = create_table.replace( + table, f"{db_engine.engine.url.database}.{table}" ) - ) - return f"{create_table.rstrip()}" + return create_table def scan_single_table( self, @@ -189,7 +194,9 @@ def scan_single_table( object = TableDescription( db_connection_id=db_connection_id, - table_name=table, + table_name=f"{db_engine.engine.url.database}.{table}" + if db_engine.engine.driver in ["bigquery", "snowflake"] + else table, columns=table_columns, table_schema=self.get_table_schema( meta=meta, db_engine=db_engine, table=table @@ -217,7 +224,7 @@ def scan( MetaData.reflect(meta, views=True) tables = inspector.get_table_names() + inspector.get_view_names() if table_names: - table_names = [table.lower() for table in table_names] + table_names = [table.lower().split(".")[-1] for table in table_names] tables = [ table for table in tables if table and table.lower() in table_names ]