diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index ab26c21..f43faf1 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -10,6 +10,9 @@ load_dotenv() # take environment variables from .env. app.logger.setLevel(os.environ.get("LOG_LEVEL", "INFO")) +def get_bool_env_variable(variable, default_value): + value = os.environ.get(variable, str(default_value)) + return value.lower() in ["true", "1", "t"] class AppConfig: def __init__(self): @@ -26,10 +29,11 @@ def __init__(self): self.OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None) self.OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None) self.OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None) + self.ENABLE_AUTOMATIC_LOGIN_REDIRECT = get_bool_env_variable("ENABLE_AUTOMATIC_LOGIN_REDIRECT", False) # session self.SESSION_TYPE = os.environ.get("SESSION_TYPE", "cachelib") - self.SESSION_PERMANENT = os.environ.get("SESSION_PERMANENT", str(False)).lower() in ("true", "1", "t") + self.SESSION_PERMANENT = get_bool_env_variable("SESSION_PERMANENT", False) self.SESSION_KEY_PREFIX = os.environ.get("SESSION_KEY_PREFIX", "mlflow_oidc:") self.PERMANENT_SESSION_LIFETIME = os.environ.get("PERMANENT_SESSION_LIFETIME", 86400) if self.SESSION_TYPE: diff --git a/mlflow_oidc_auth/hooks/before_request.py b/mlflow_oidc_auth/hooks/before_request.py index 5eabbb7..1d76d5b 100644 --- a/mlflow_oidc_auth/hooks/before_request.py +++ b/mlflow_oidc_auth/hooks/before_request.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, Optional -from flask import render_template, request, session +from flask import render_template, request, session, redirect, url_for from mlflow.protos.model_registry_pb2 import ( CreateModelVersion, CreateRegisteredModel, @@ -195,6 +195,9 @@ def before_request_hook(): else: if session.get("username") is None: session.clear() + + if config.ENABLE_AUTOMATIC_LOGIN_REDIRECT: + return redirect(url_for("login", _external=True)) return render_template( "auth.html", username=None, diff --git a/mlflow_oidc_auth/views/authentication.py b/mlflow_oidc_auth/views/authentication.py index 70563fa..0e217a3 100644 --- a/mlflow_oidc_auth/views/authentication.py +++ b/mlflow_oidc_auth/views/authentication.py @@ -1,6 +1,6 @@ import secrets -from flask import redirect, session, url_for +from flask import redirect, session, url_for, render_template import mlflow_oidc_auth.utils as utils from mlflow_oidc_auth.auth import get_oauth_instance @@ -17,6 +17,12 @@ def login(): def logout(): session.clear() + if config.ENABLE_AUTOMATIC_LOGIN_REDIRECT: + return render_template( + "auth.html", + username=None, + provide_display_name=config.OIDC_PROVIDER_DISPLAY_NAME, + ) return redirect("/")