Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: copy oauth2 capture to get_sqla_engine #32137

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
OAuth2ClientConfigSchema,
)

config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
Expand Down Expand Up @@ -451,13 +455,14 @@ def get_sqla_engine( # pylint: disable=too-many-arguments

engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
with check_for_oauth2(self):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)

def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
Expand Down Expand Up @@ -583,10 +588,9 @@ def get_raw_connection(
nullpool=nullpool,
source=source,
) as engine:
try:
with check_for_oauth2(self):
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the # noqa: E501
# future, the selected catalog
# pre-session queries are used to set the selected catalog/schema
for prequery in self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
Expand All @@ -597,11 +601,6 @@ def get_raw_connection(

yield conn

except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.db_engine_spec.start_oauth2_dance(self)
raise

def get_default_catalog(self) -> str | None:
"""
Return the default configured catalog for the database.
Expand Down
18 changes: 16 additions & 2 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from __future__ import annotations

from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, TYPE_CHECKING
from typing import Any, Iterator, TYPE_CHECKING

import backoff
import jwt
Expand All @@ -32,7 +33,7 @@

if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import DatabaseUserOAuth2Tokens
from superset.models.core import Database, DatabaseUserOAuth2Tokens

JWT_EXPIRATION = timedelta(minutes=5)

Expand Down Expand Up @@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema):
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)


@contextmanager
def check_for_oauth2(database: Database) -> Iterator[None]:
"""
Run code and check if OAuth2 is needed.
"""
try:
yield
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
database.db_engine_spec.start_oauth2_dance(database)
raise
98 changes: 92 additions & 6 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
}


def test_raw_connection_oauth(mocker: MockerFixture) -> None:
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.

Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.

This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the engine is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")

with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."


def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.

With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.

This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
Expand All @@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
assert str(excinfo.value) == "You don't have permission to access the data."


def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.

With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.

This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
OAuth2Error("OAuth2 required")
)

with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
with database.get_raw_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
assert str(excinfo.value) == "You don't have permission to access the data."


def test_get_schema_access_for_file_upload() -> None:
"""
Test the `get_schema_access_for_file_upload` method.
Expand Down Expand Up @@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
)


def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)

with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass

start_oauth2_dance.assert_called_with(database)


def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.
Expand Down
Loading