Skip to content

Commit

Permalink
DBs without schema should store None
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Apr 17, 2024
1 parent e164a6f commit f392cbd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class DatabaseConnection(BaseModel):

@classmethod
def get_dialect(cls, input_string):
pattern = r"([^:/]+):/+([^/]+)/?([^/]+)"
pattern = r"([^:/]+)://"
match = re.match(pattern, input_string)
if not match:
raise InvalidURIFormatError(f"Invalid URI format: {input_string}")
Expand Down
9 changes: 7 additions & 2 deletions dataherald/sql_database/services/database_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def __init__(self, scanner: Scanner, storage: DB):
self.scanner = scanner
self.storage = storage

def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]:
def get_current_schema(
self, database_connection: DatabaseConnection
) -> list[str] | None:
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
inspector = inspect(sql_database.engine)
if inspector.default_schema_name and database_connection.dialect not in [
Expand All @@ -37,7 +39,7 @@ def get_current_schema(self, database_connection: DatabaseConnection) -> list[st
match = re.search(pattern, str(sql_database.engine.url))
if match:
return [match.group(1)]
return ["default"]
return None

def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str:
if dialect in ["snowflake"]:
Expand Down Expand Up @@ -99,6 +101,9 @@ def create(
)
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
schemas_and_tables[schema] = sql_database.get_tables_and_views()
else:
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
schemas_and_tables[None] = sql_database.get_tables_and_views()

# Connect db
db_connection_repository = DatabaseConnectionRepository(self.storage)
Expand Down

0 comments on commit f392cbd

Please sign in to comment.