Skip to content

Commit

Permalink
MSSQL Unqualified Schema Table Parsing (#530)
Browse files Browse the repository at this point in the history
Co-authored-by: Santos, Tyler (Boston) <[email protected]>
  • Loading branch information
T-Santos and Santos, Tyler (Boston) authored Oct 24, 2024
1 parent 8bdc496 commit da144a4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
19 changes: 11 additions & 8 deletions sql_metadata/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def columns(self) -> List[str]:
self._handle_column_save(token=token, columns=columns)

elif token.is_column_name_inside_insert_clause:
column = str(token.value).strip("`")
column = str(token.value)
self._add_to_columns_subsection(
keyword=token.last_keyword_normalized, column=column
)
Expand Down Expand Up @@ -369,10 +369,8 @@ def tables(self) -> List[str]:
and self.query_type == "INSERT"
):
continue

table_name = str(token.value.strip("`"))
token.token_type = TokenType.TABLE
tables.append(table_name)
tables.append(str(token.value))

self._tables = tables - with_names
return self._tables
Expand Down Expand Up @@ -1013,6 +1011,8 @@ def _is_token_part_of_complex_identifier(
Checks if token is a part of complex identifier like
<schema>.<table>.<column> or <table/sub_query>.<column>
"""
if token.is_keyword:
return False
return str(token) == "." or (
index + 1 < self.tokens_length
and str(self.non_empty_tokens[index + 1]) == "."
Expand All @@ -1026,16 +1026,19 @@ def _combine_qualified_names(self, index: int, token: SQLToken) -> None:
is_complex = True
while is_complex:
value, is_complex = self._combine_tokens(index=index, value=value)
index = index - 2
index = index - 1
token.value = value

def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]:
"""
Checks if complex identifier is longer and follows back until it's finished
"""
if index > 1 and str(self.non_empty_tokens[index - 1]) == ".":
prev_value = self.non_empty_tokens[index - 2].value.strip("`").strip('"')
value = f"{prev_value}.{value}"
if index > 1:
prev_value = self.non_empty_tokens[index - 1]
if not self._is_token_part_of_complex_identifier(prev_value, index - 1):
return value, False
prev_value = str(prev_value).strip("`")
value = f"{prev_value}{value}"
return value, True
return value, False

Expand Down
31 changes: 31 additions & 0 deletions test/test_mssql_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
import pytest

from sql_metadata.parser import Parser


# MSSQL multi-part names: qualifiers may be omitted (db..table) or wrapped
# in [bracket] delimiters; the parser must return them verbatim.
_TABLE_NAME_CASES = [
    pytest.param(
        "SELECT * FROM mydb..test_table",
        ["mydb..test_table"],
        id="Default schema, db qualified",
    ),
    pytest.param(
        "SELECT * FROM ..test_table",
        ["..test_table"],
        id="Default schema, db unqualified",
    ),
    pytest.param(
        "SELECT * FROM [mydb].[dbo].[test_table]",
        ["[mydb].[dbo].[test_table]"],
        id="With object identifier delimiters",
    ),
    pytest.param(
        "SELECT * FROM [my_server].[mydb].[dbo].[test_table]",
        ["[my_server].[mydb].[dbo].[test_table]"],
        id="With linked-server and object identifier delimiters",
    ),
]


@pytest.mark.parametrize("query, expected", _TABLE_NAME_CASES)
def test_simple_queries_tables(query, expected):
    """Parser keeps MSSQL-style qualified table names intact."""
    parsed_tables = Parser(query).tables
    assert parsed_tables == expected


def test_sql_server_cte():
"""
Tests support for SQL Server's common table expression (CTE).
Expand Down

0 comments on commit da144a4

Please sign in to comment.