diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 538da8dd..dc969fa6 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -106,7 +106,7 @@ class DatabaseConnection(BaseModel): @classmethod def get_dialect(cls, input_string): - pattern = r"([^:/]+):/+([^/]+)/?([^/]+)" + pattern = r"([^:/]+)://" match = re.match(pattern, input_string) if not match: raise InvalidURIFormatError(f"Invalid URI format: {input_string}") diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py index 916c3839..e011ce3b 100644 --- a/dataherald/sql_database/services/database_connection.py +++ b/dataherald/sql_database/services/database_connection.py @@ -17,7 +17,9 @@ def __init__(self, scanner: Scanner, storage: DB): self.scanner = scanner self.storage = storage - def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]: + def get_current_schema( + self, database_connection: DatabaseConnection + ) -> list[str] | None: sql_database = SQLDatabase.get_sql_engine(database_connection, True) inspector = inspect(sql_database.engine) if inspector.default_schema_name and database_connection.dialect not in [ @@ -37,7 +39,7 @@ def get_current_schema(self, database_connection: DatabaseConnection) -> list[st match = re.search(pattern, str(sql_database.engine.url)) if match: return [match.group(1)] - return ["default"] + return None def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str: if dialect in ["snowflake"]: @@ -99,6 +101,9 @@ def create( ) sql_database = SQLDatabase.get_sql_engine(database_connection, True) schemas_and_tables[schema] = sql_database.get_tables_and_views() + else: + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + schemas_and_tables[None] = sql_database.get_tables_and_views() # Connect db db_connection_repository = DatabaseConnectionRepository(self.storage)