Skip to content

Commit

Permalink
fix: move oauth2 capture to get_sqla_engine (#32137)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Feb 4, 2025
1 parent c64018d commit c7c3b1b
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 24 deletions.
31 changes: 15 additions & 16 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
OAuth2ClientConfigSchema,
)

config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
Expand Down Expand Up @@ -451,13 +455,14 @@ def get_sqla_engine( # pylint: disable=too-many-arguments

engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
with check_for_oauth2(self):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)

def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
Expand Down Expand Up @@ -583,10 +588,9 @@ def get_raw_connection(
nullpool=nullpool,
source=source,
) as engine:
try:
with check_for_oauth2(self):
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the # noqa: E501
# future, the selected catalog
# pre-session queries are used to set the selected catalog/schema
for prequery in self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
Expand All @@ -597,11 +601,6 @@ def get_raw_connection(

yield conn

except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.db_engine_spec.start_oauth2_dance(self)
raise

def get_default_catalog(self) -> str | None:
"""
Return the default configured catalog for the database.
Expand Down
18 changes: 16 additions & 2 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from __future__ import annotations

from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, TYPE_CHECKING
from typing import Any, Iterator, TYPE_CHECKING

import backoff
import jwt
Expand All @@ -32,7 +33,7 @@

if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import DatabaseUserOAuth2Tokens
from superset.models.core import Database, DatabaseUserOAuth2Tokens

JWT_EXPIRATION = timedelta(minutes=5)

Expand Down Expand Up @@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema):
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)


@contextmanager
def check_for_oauth2(database: Database) -> Iterator[None]:
"""
Run code and check if OAuth2 is needed.
"""
try:
yield
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
database.db_engine_spec.start_oauth2_dance(database)
raise
98 changes: 92 additions & 6 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
}


def test_raw_connection_oauth(mocker: MockerFixture) -> None:
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the engine is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")

with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."


def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
Expand All @@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
assert str(excinfo.value) == "You don't have permission to access the data."


def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
OAuth2Error("OAuth2 required")
)

with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
with database.get_raw_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
assert str(excinfo.value) == "You don't have permission to access the data."


def test_get_schema_access_for_file_upload() -> None:
"""
Test the `get_schema_access_for_file_upload` method.
Expand Down Expand Up @@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
)


def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)

with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass

start_oauth2_dance.assert_called_with(database)


def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.
Expand Down

0 comments on commit c7c3b1b

Please sign in to comment.