From 031e3945e44030d4f085753ffdc43dc104708f91 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Mon, 24 Jul 2023 13:46:44 -0400 Subject: [PATCH] Allow auth managers to override the security manager (#32525) Allow auth managers to override the security manager --- airflow/api/auth/backend/session.py | 4 +- airflow/auth/managers/base_auth_manager.py | 12 + airflow/auth/managers/fab/fab_auth_manager.py | 5 + .../managers/fab/security_manager_override.py | 220 ++++++++++++++++++ airflow/configuration.py | 34 ++- airflow/www/auth.py | 5 +- airflow/www/decorators.py | 4 +- airflow/www/extensions/init_appbuilder.py | 7 - airflow/www/extensions/init_auth_manager.py | 40 ++++ airflow/www/extensions/init_jinja_globals.py | 5 +- airflow/www/extensions/init_security.py | 5 +- airflow/www/fab_security/manager.py | 137 +---------- airflow/www/fab_security/sqla/manager.py | 2 +- airflow/www/security.py | 40 +++- airflow/www/views.py | 12 +- .../auh/managers/fab/test_fab_auth_manager.py | 4 + tests/auh/managers/test_base_auth_manager.py | 35 +++ 17 files changed, 409 insertions(+), 162 deletions(-) create mode 100644 airflow/auth/managers/fab/security_manager_override.py create mode 100644 airflow/www/extensions/init_auth_manager.py create mode 100644 tests/auh/managers/test_base_auth_manager.py diff --git a/airflow/api/auth/backend/session.py b/airflow/api/auth/backend/session.py index c55f7484605c..ef914b57e442 100644 --- a/airflow/api/auth/backend/session.py +++ b/airflow/api/auth/backend/session.py @@ -22,7 +22,7 @@ from flask import Response -from airflow.configuration import auth_manager +from airflow.www.extensions.init_auth_manager import get_auth_manager CLIENT_AUTH: tuple[str, str] | Any | None = None @@ -39,7 +39,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - if not auth_manager.is_logged_in(): + if not get_auth_manager().is_logged_in(): return Response("Unauthorized", 401, {}) return function(*args, **kwargs) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 462fe34d6304..ab5356bf8ca3 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -38,3 +38,15 @@ def get_user_name(self) -> str: def is_logged_in(self) -> bool: """Return whether the user is logged in.""" ... + + def get_security_manager_override_class(self) -> type: + """ + Return the security manager override class. + + The security manager override class is responsible for overriding the default security manager + class airflow.www.security.AirflowSecurityManager with a custom implementation. This class is + essentially inherited from airflow.www.security.AirflowSecurityManager. + + By default, return an empty class. + """ + return object diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index b9f0c1e1df5a..f90a9ac06353 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -20,6 +20,7 @@ from flask_login import current_user from airflow.auth.managers.base_auth_manager import BaseAuthManager +from airflow.auth.managers.fab.security_manager_override import FabAirflowSecurityManagerOverride class FabAuthManager(BaseAuthManager): @@ -43,3 +44,7 @@ def get_user_name(self) -> str: def is_logged_in(self) -> bool: """Return whether the user is logged in.""" return current_user and not current_user.is_anonymous + + def get_security_manager_override_class(self) -> type: + """Return the security manager override.""" + return FabAirflowSecurityManagerOverride diff --git a/airflow/auth/managers/fab/security_manager_override.py b/airflow/auth/managers/fab/security_manager_override.py new file mode 100644 index 000000000000..5be9ee1f3672 --- /dev/null +++ b/airflow/auth/managers/fab/security_manager_override.py @@ -0,0 +1,220 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property + +from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID, AUTH_REMOTE_USER +from flask_babel import lazy_gettext + + +class FabAirflowSecurityManagerOverride: + """ + This security manager overrides the default AirflowSecurityManager security manager. + + This security manager is used only if the auth manager FabAuthManager is used. It defines everything in + the security manager that is needed for the FabAuthManager to work. Any operation specific to + the AirflowSecurityManager should be defined here instead of AirflowSecurityManager. + + :param appbuilder: The appbuilder. + :param actionmodelview: The obj instance for action model view. + :param authdbview: The class for auth db view. + :param authldapview: The class for auth ldap view. + :param authoauthview: The class for auth oauth view. + :param authoidview: The class for auth oid view. + :param authremoteuserview: The class for auth remote user view. + :param permissionmodelview: The class for permission model view. + :param registeruser_view: The class for register user view. + :param registeruserdbview: The class for register user db view. + :param registeruseroauthview: The class for register user oauth view. + :param registerusermodelview: The class for register user model view. + :param registeruseroidview: The class for register user oid view. + :param resetmypasswordview: The class for reset my password view. + :param resetpasswordview: The class for reset password view. + :param rolemodelview: The class for role model view. + :param userinfoeditview: The class for user info edit view. + :param userdbmodelview: The class for user db model view. + :param userldapmodelview: The class for user ldap model view. + :param useroauthmodelview: The class for user oauth model view. + :param useroidmodelview: The class for user oid model view. + :param userremoteusermodelview: The class for user remote user model view. + :param userstatschartview: The class for user stats chart view. + """ + + """ The obj instance for authentication view """ + auth_view = None + """ The obj instance for user view """ + user_view = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.appbuilder = kwargs["appbuilder"] + self.actionmodelview = kwargs["actionmodelview"] + self.authdbview = kwargs["authdbview"] + self.authldapview = kwargs["authldapview"] + self.authoauthview = kwargs["authoauthview"] + self.authoidview = kwargs["authoidview"] + self.authremoteuserview = kwargs["authremoteuserview"] + self.permissionmodelview = kwargs["permissionmodelview"] + self.registeruser_view = kwargs["registeruser_view"] + self.registeruserdbview = kwargs["registeruserdbview"] + self.registeruseroauthview = kwargs["registeruseroauthview"] + self.registerusermodelview = kwargs["registerusermodelview"] + self.registeruseroidview = kwargs["registeruseroidview"] + self.resetmypasswordview = kwargs["resetmypasswordview"] + self.resetpasswordview = kwargs["resetpasswordview"] + self.rolemodelview = kwargs["rolemodelview"] + self.userinfoeditview = kwargs["userinfoeditview"] + self.userdbmodelview = kwargs["userdbmodelview"] + self.userldapmodelview = kwargs["userldapmodelview"] + self.useroauthmodelview = kwargs["useroauthmodelview"] + self.useroidmodelview = kwargs["useroidmodelview"] + self.userremoteusermodelview = kwargs["userremoteusermodelview"] + self.userstatschartview = kwargs["userstatschartview"] + + def register_views(self): + """Register FAB auth manager related views.""" + if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): + return + + if self.auth_user_registration: + if self.auth_type == AUTH_DB: + self.registeruser_view = self.registeruserdbview() + elif self.auth_type == AUTH_OID: + self.registeruser_view = self.registeruseroidview() + elif self.auth_type == AUTH_OAUTH: + self.registeruser_view = self.registeruseroauthview() + if self.registeruser_view: + self.appbuilder.add_view_no_menu(self.registeruser_view) + + self.appbuilder.add_view_no_menu(self.resetpasswordview()) + self.appbuilder.add_view_no_menu(self.resetmypasswordview()) + self.appbuilder.add_view_no_menu(self.userinfoeditview()) + + if self.auth_type == AUTH_DB: + self.user_view = self.userdbmodelview + self.auth_view = self.authdbview() + elif self.auth_type == AUTH_LDAP: + self.user_view = self.userldapmodelview + self.auth_view = self.authldapview() + elif self.auth_type == AUTH_OAUTH: + self.user_view = self.useroauthmodelview + self.auth_view = self.authoauthview() + elif self.auth_type == AUTH_REMOTE_USER: + self.user_view = self.userremoteusermodelview + self.auth_view = self.authremoteuserview() + else: + self.user_view = self.useroidmodelview + self.auth_view = self.authoidview() + + self.appbuilder.add_view_no_menu(self.auth_view) + + # this needs to be done after the view is added, otherwise the blueprint + # is not initialized + if self.is_auth_limited: + self.limiter.limit(self.auth_rate_limit, methods=["POST"])(self.auth_view.blueprint) + + self.user_view = self.appbuilder.add_view( + self.user_view, + "List Users", + icon="fa-user", + label=lazy_gettext("List Users"), + category="Security", + category_icon="fa-cogs", + category_label=lazy_gettext("Security"), + ) + + role_view = self.appbuilder.add_view( + self.rolemodelview, + "List Roles", + icon="fa-group", + label=lazy_gettext("List Roles"), + category="Security", + category_icon="fa-cogs", + ) + role_view.related_views = [self.user_view.__class__] + + if self.userstatschartview: + self.appbuilder.add_view( + self.userstatschartview, + "User's Statistics", + icon="fa-bar-chart-o", + label=lazy_gettext("User's Statistics"), + category="Security", + ) + if self.auth_user_registration: + self.appbuilder.add_view( + self.registerusermodelview, + "User's Statistics", + icon="fa-user-plus", + label=lazy_gettext("User Registrations"), + category="Security", + ) + self.appbuilder.menu.add_separator("Security") + if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEW", True): + self.appbuilder.add_view( + self.actionmodelview, + "Actions", + icon="fa-lock", + label=lazy_gettext("Actions"), + category="Security", + ) + if self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEW_MENU_VIEW", True): + self.appbuilder.add_view( + self.resourcemodelview, + "Resources", + icon="fa-list-alt", + label=lazy_gettext("Resources"), + category="Security", + ) + if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True): + self.appbuilder.add_view( + self.permissionmodelview, + "Permission Pairs", + icon="fa-link", + label=lazy_gettext("Permissions"), + category="Security", + ) + + @property + def auth_user_registration(self): + """Will user self registration be allowed.""" + return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION"] + + @property + def auth_type(self): + """Get the auth type.""" + return self.appbuilder.get_app.config["AUTH_TYPE"] + + @property + def is_auth_limited(self) -> bool: + """Is the auth rate limited.""" + return self.appbuilder.get_app.config["AUTH_RATE_LIMITED"] + + @property + def auth_rate_limit(self) -> str: + """Get the auth rate limit.""" + return self.appbuilder.get_app.config["AUTH_RATE_LIMIT"] + + @cached_property + def resourcemodelview(self): + """Return the resource model view.""" + from airflow.www.views import ResourceModelView + + return ResourceModelView diff --git a/airflow/configuration.py b/airflow/configuration.py index cca5588e543b..d6492891d3dd 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -2209,6 +2209,39 @@ def initialize_secrets_backends() -> list[BaseSecretsBackend]: return backend_list +@functools.lru_cache(maxsize=None) +def _DEFAULT_CONFIG() -> str: + path = _default_config_file_path("default_airflow.cfg") + with open(path) as fh: + return fh.read() + + +@functools.lru_cache(maxsize=None) +def _TEST_CONFIG() -> str: + path = _default_config_file_path("default_test.cfg") + with open(path) as fh: + return fh.read() + + +_deprecated = { + "DEFAULT_CONFIG": _DEFAULT_CONFIG, + "TEST_CONFIG": _TEST_CONFIG, + "TEST_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_test.cfg"), + "DEFAULT_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_airflow.cfg"), +} + + +def __getattr__(name): + if name in _deprecated: + warnings.warn( + f"{__name__}.{name} is deprecated and will be removed in future", + DeprecationWarning, + stacklevel=2, + ) + return _deprecated[name]() + raise AttributeError(f"module {__name__} has no attribute {name}") + + def initialize_auth_manager() -> BaseAuthManager: """ Initialize auth manager. @@ -2257,5 +2290,4 @@ def initialize_auth_manager() -> BaseAuthManager: conf = initialize_config() secrets_backend_list = initialize_secrets_backends() -auth_manager = initialize_auth_manager() conf.validate() diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 82fb5d34c5c0..54114da1c8a7 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -21,8 +21,9 @@ from flask import current_app, flash, g, redirect, render_template, request, url_for -from airflow.configuration import auth_manager, conf +from airflow.configuration import conf from airflow.utils.net import get_hostname +from airflow.www.extensions.init_auth_manager import get_auth_manager T = TypeVar("T", bound=Callable) @@ -46,7 +47,7 @@ def decorated(*args, **kwargs): ) if appbuilder.sm.check_authorization(permissions, dag_id): return func(*args, **kwargs) - elif auth_manager.is_logged_in() and not g.user.perms: + elif get_auth_manager().is_logged_in() and not g.user.perms: return ( render_template( "airflow/no_roles_permissions.html", diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index fc386d220f6b..af316e3ed0fd 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -29,10 +29,10 @@ from flask import after_this_request, g, request from pendulum.parsing.exceptions import ParserError -from airflow.configuration import auth_manager from airflow.models import Log from airflow.utils.log import secrets_masker from airflow.utils.session import create_session +from airflow.www.extensions.init_auth_manager import get_auth_manager T = TypeVar("T", bound=Callable) @@ -85,7 +85,7 @@ def wrapper(*args, **kwargs): __tracebackhide__ = True # Hide from pytest traceback. with create_session() as session: - if not auth_manager.is_logged_in(): + if not get_auth_manager().is_logged_in(): user = "anonymous" else: user = f"{g.user.username} ({g.user.get_full_name()})" diff --git a/airflow/www/extensions/init_appbuilder.py b/airflow/www/extensions/init_appbuilder.py index ac9d2c9107df..ae793ca95646 100644 --- a/airflow/www/extensions/init_appbuilder.py +++ b/airflow/www/extensions/init_appbuilder.py @@ -208,13 +208,6 @@ def init_app(self, app, session): if self.update_perms: # default is True, if False takes precedence from config self.update_perms = app.config.get("FAB_UPDATE_PERMS", True) - _security_manager_class_name = app.config.get("FAB_SECURITY_MANAGER_CLASS", None) - if _security_manager_class_name is not None: - self.security_manager_class = dynamic_class_import(_security_manager_class_name) - if self.security_manager_class is None: - from flask_appbuilder.security.sqla.manager import SecurityManager - - self.security_manager_class = SecurityManager self._addon_managers = app.config["ADDON_MANAGERS"] self.session = session diff --git a/airflow/www/extensions/init_auth_manager.py b/airflow/www/extensions/init_auth_manager.py new file mode 100644 index 000000000000..d21139f67069 --- /dev/null +++ b/airflow/www/extensions/init_auth_manager.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.auth.managers.base_auth_manager import BaseAuthManager +from airflow.compat.functools import cache +from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException + + +@cache +def get_auth_manager() -> BaseAuthManager: + """ + Initialize auth manager. + + Import the user manager class, instantiate it and return it. + """ + auth_manager_cls = conf.getimport(section="core", key="auth_manager") + + if not auth_manager_cls: + raise AirflowConfigException( + "No auth manager defined in the config. " + "Please specify one using section/key [core/auth_manager]." + ) + + return auth_manager_cls() diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index 9ef948084cc9..13baeea7bc67 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -21,10 +21,11 @@ import pendulum import airflow -from airflow.configuration import auth_manager, conf +from airflow.configuration import conf from airflow.settings import IS_K8S_OR_K8SCELERY_EXECUTOR, STATE_COLORS from airflow.utils.net import get_hostname from airflow.utils.platform import get_airflow_git_version +from airflow.www.extensions.init_auth_manager import get_auth_manager def init_jinja_globals(app): @@ -68,7 +69,7 @@ def prepare_jinja_globals(): "git_version": git_version, "k8s_or_k8scelery_executor": IS_K8S_OR_K8SCELERY_EXECUTOR, "rest_api_enabled": False, - "auth_manager": auth_manager, + "auth_manager": get_auth_manager(), "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py index ba57f99b1448..17f93fc1c5ee 100644 --- a/airflow/www/extensions/init_security.py +++ b/airflow/www/extensions/init_security.py @@ -22,8 +22,9 @@ from flask import g, redirect, url_for from flask_login import logout_user -from airflow.configuration import auth_manager, conf +from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.www.extensions.init_auth_manager import get_auth_manager log = logging.getLogger(__name__) @@ -68,6 +69,6 @@ def init_api_experimental_auth(app): def init_check_user_active(app): @app.before_request def check_user_active(): - if auth_manager.is_logged_in() and not g.user.is_active: + if get_auth_manager().is_logged_in() and not g.user.is_active: logout_user() return redirect(url_for(app.appbuilder.sm.auth_view.endpoint + ".login")) diff --git a/airflow/www/fab_security/manager.py b/airflow/www/fab_security/manager.py index 00d38064363e..3c59522a6268 100644 --- a/airflow/www/fab_security/manager.py +++ b/airflow/www/fab_security/manager.py @@ -22,7 +22,6 @@ import datetime import json import logging -from functools import cached_property from typing import Any from uuid import uuid4 @@ -34,7 +33,6 @@ AUTH_LDAP, AUTH_OAUTH, AUTH_OID, - AUTH_REMOTE_USER, LOGMSG_ERR_SEC_ADD_REGISTER_USER, LOGMSG_ERR_SEC_AUTH_LDAP, LOGMSG_ERR_SEC_AUTH_LDAP_TLS, @@ -66,14 +64,14 @@ UserRemoteUserModelView, UserStatsChartView, ) -from flask_babel import lazy_gettext as _ from flask_jwt_extended import JWTManager, current_user as current_user_jwt from flask_limiter import Limiter from flask_limiter.util import get_remote_address from flask_login import AnonymousUserMixin, LoginManager, current_user from werkzeug.security import check_password_hash, generate_password_hash -from airflow.configuration import auth_manager, conf +from airflow.configuration import conf +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.fab_security.sqla.models import Action, Permission, RegisterUser, Resource, Role, User # This product contains a modified portion of 'Flask App Builder' developed by Daniel Vaz Gaspar. @@ -208,12 +206,6 @@ def oauth_tokengetter(token=None): userstatschartview = UserStatsChartView permissionmodelview = PermissionModelView - @cached_property - def resourcemodelview(self): - from airflow.www.views import ResourceModelView - - return ResourceModelView - def __init__(self, appbuilder): self.appbuilder = appbuilder app = self.appbuilder.get_app @@ -374,11 +366,6 @@ def builtin_roles(self): def api_login_allow_multiple_providers(self): return self.appbuilder.get_app.config["AUTH_API_LOGIN_ALLOW_MULTIPLE_PROVIDERS"] - @property - def auth_type(self): - """Get the auth type.""" - return self.appbuilder.get_app.config["AUTH_TYPE"] - @property def auth_username_ci(self): """Gets the auth username for CI.""" @@ -529,18 +516,10 @@ def oauth_providers(self): """Oauth providers.""" return self.appbuilder.get_app.config["OAUTH_PROVIDERS"] - @property - def is_auth_limited(self) -> bool: - return self.appbuilder.get_app.config["AUTH_RATE_LIMITED"] - - @property - def auth_rate_limit(self) -> str: - return self.appbuilder.get_app.config["AUTH_RATE_LIMIT"] - @property def current_user(self): """Current user object.""" - if auth_manager.is_logged_in(): + if get_auth_manager().is_logged_in(): return g.user elif current_user_jwt: return current_user_jwt @@ -732,114 +711,6 @@ def _azure_jwt_token_parse(self, id_token): return jwt_decoded_payload - def register_views(self): - if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): - return - - if self.auth_user_registration: - if self.auth_type == AUTH_DB: - self.registeruser_view = self.registeruserdbview() - elif self.auth_type == AUTH_OID: - self.registeruser_view = self.registeruseroidview() - elif self.auth_type == AUTH_OAUTH: - self.registeruser_view = self.registeruseroauthview() - if self.registeruser_view: - self.appbuilder.add_view_no_menu(self.registeruser_view) - - self.appbuilder.add_view_no_menu(self.resetpasswordview()) - self.appbuilder.add_view_no_menu(self.resetmypasswordview()) - self.appbuilder.add_view_no_menu(self.userinfoeditview()) - - if self.auth_type == AUTH_DB: - self.user_view = self.userdbmodelview - self.auth_view = self.authdbview() - - elif self.auth_type == AUTH_LDAP: - self.user_view = self.userldapmodelview - self.auth_view = self.authldapview() - elif self.auth_type == AUTH_OAUTH: - self.user_view = self.useroauthmodelview - self.auth_view = self.authoauthview() - elif self.auth_type == AUTH_REMOTE_USER: - self.user_view = self.userremoteusermodelview - self.auth_view = self.authremoteuserview() - else: - self.user_view = self.useroidmodelview - self.auth_view = self.authoidview() - if self.auth_user_registration: - pass - # self.registeruser_view = self.registeruseroidview() - # self.appbuilder.add_view_no_menu(self.registeruser_view) - - self.appbuilder.add_view_no_menu(self.auth_view) - - # this needs to be done after the view is added, otherwise the blueprint - # is not initialized - if self.is_auth_limited: - self.limiter.limit(self.auth_rate_limit, methods=["POST"])(self.auth_view.blueprint) - - self.user_view = self.appbuilder.add_view( - self.user_view, - "List Users", - icon="fa-user", - label=_("List Users"), - category="Security", - category_icon="fa-cogs", - category_label=_("Security"), - ) - - role_view = self.appbuilder.add_view( - self.rolemodelview, - "List Roles", - icon="fa-group", - label=_("List Roles"), - category="Security", - category_icon="fa-cogs", - ) - role_view.related_views = [self.user_view.__class__] - - if self.userstatschartview: - self.appbuilder.add_view( - self.userstatschartview, - "User's Statistics", - icon="fa-bar-chart-o", - label=_("User's Statistics"), - category="Security", - ) - if self.auth_user_registration: - self.appbuilder.add_view( - self.registerusermodelview, - "User's Statistics", - icon="fa-user-plus", - label=_("User Registrations"), - category="Security", - ) - self.appbuilder.menu.add_separator("Security") - if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEW", True): - self.appbuilder.add_view( - self.actionmodelview, - "Actions", - icon="fa-lock", - label=_("Actions"), - category="Security", - ) - if self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEW_MENU_VIEW", True): - self.appbuilder.add_view( - self.resourcemodelview, - "Resources", - icon="fa-list-alt", - label=_("Resources"), - category="Security", - ) - if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True): - self.appbuilder.add_view( - self.permissionmodelview, - "Permission Pairs", - icon="fa-link", - label=_("Permissions"), - category="Security", - ) - def create_db(self): """Setups the DB, creates admin and public roles if they don't exist.""" roles_mapping = self.appbuilder.get_app.config.get("FAB_ROLES_MAPPING", {}) @@ -1415,7 +1286,7 @@ def _get_user_permission_resources( return result def get_user_menu_access(self, menu_names: list[str] | None = None) -> set[str]: - if auth_manager.is_logged_in(): + if get_auth_manager().is_logged_in(): return self._get_user_permission_resources(g.user, "menu_access", resource_names=menu_names) elif current_user_jwt: return self._get_user_permission_resources( diff --git a/airflow/www/fab_security/sqla/manager.py b/airflow/www/fab_security/sqla/manager.py index 62decfb184b0..c0daf5553ca2 100644 --- a/airflow/www/fab_security/sqla/manager.py +++ b/airflow/www/fab_security/sqla/manager.py @@ -58,7 +58,7 @@ class SecurityManager(BaseSecurityManager): permission_model = Permission registeruser_model = RegisterUser - def __init__(self, appbuilder): + def __init__(self, appbuilder, **kwargs): """ Class constructor. diff --git a/airflow/www/security.py b/airflow/www/security.py index 33188328fc09..d9229c9b1eff 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -18,18 +18,18 @@ from __future__ import annotations import warnings -from typing import Any, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence from flask import g from sqlalchemy import or_ from sqlalchemy.orm import Session, joinedload -from airflow.configuration import auth_manager from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models import DagBag, DagModel from airflow.security import permissions from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.fab_security.sqla.manager import SecurityManager from airflow.www.fab_security.sqla.models import Permission, Resource, Role, User from airflow.www.fab_security.views import ( @@ -57,8 +57,14 @@ "Public", } +if TYPE_CHECKING: + SecurityManagerOverride: type = object +else: + # Fetch the security manager override from the auth manager + SecurityManagerOverride = get_auth_manager().get_security_manager_override_class() -class AirflowSecurityManager(SecurityManager, LoggingMixin): + +class AirflowSecurityManager(SecurityManagerOverride, SecurityManager, LoggingMixin): """Custom security manager, which introduces a permission model adapted to Airflow.""" ########################################################################### @@ -193,7 +199,31 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): userstatschartview = CustomUserStatsChartView def __init__(self, appbuilder) -> None: - super().__init__(appbuilder) + super().__init__( + appbuilder=appbuilder, + actionmodelview=self.actionmodelview, + authdbview=self.authdbview, + authldapview=self.authldapview, + authoauthview=self.authoauthview, + authoidview=self.authoidview, + authremoteuserview=self.authremoteuserview, + permissionmodelview=self.permissionmodelview, + registeruser_view=self.registeruser_view, + registeruserdbview=self.registeruserdbview, + registeruseroauthview=self.registeruseroauthview, + registerusermodelview=self.registerusermodelview, + registeruseroidview=self.registeruseroidview, + resetmypasswordview=self.resetmypasswordview, + resetpasswordview=self.resetpasswordview, + rolemodelview=self.rolemodelview, + userinfoeditview=self.userinfoeditview, + userdbmodelview=self.userdbmodelview, + userldapmodelview=self.userldapmodelview, + useroauthmodelview=self.useroauthmodelview, + useroidmodelview=self.useroidmodelview, + userremoteusermodelview=self.userremoteusermodelview, + userstatschartview=self.userstatschartview, + ) # Go and fix up the SQLAInterface used from the stock one to our subclass. # This is needed to support the "hack" where we had to edit @@ -339,7 +369,7 @@ def get_accessible_dag_ids( if not user_actions: user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] - if not auth_manager.is_logged_in(): + if not get_auth_manager().is_logged_in(): roles = user.roles else: if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( diff --git a/airflow/www/views.py b/airflow/www/views.py index 359f1ca52b77..54f46575be35 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -83,7 +83,7 @@ set_dag_run_state_to_success, set_state, ) -from airflow.configuration import AIRFLOW_CONFIG, auth_manager, conf +from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.datasets import Dataset from airflow.exceptions import ( AirflowConfigException, @@ -131,6 +131,7 @@ from airflow.version import version from airflow.www import auth, utils as wwwutils from airflow.www.decorators import action_logging, gzipped +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.forms import ( DagRunEditForm, DateTimeForm, @@ -622,16 +623,17 @@ def method_not_allowed(error): def show_traceback(error): """Show Traceback for a given error.""" + is_logged_in = get_auth_manager().is_logged_in() return ( render_template( "airflow/traceback.html", - python_version=sys.version.split(" ")[0] if auth_manager.is_logged_in() else "redact", - airflow_version=version if auth_manager.is_logged_in() else "redact", + python_version=sys.version.split(" ")[0] if is_logged_in else "redact", + airflow_version=version if is_logged_in else "redact", hostname=get_hostname() - if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and auth_manager.is_logged_in() + if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and is_logged_in else "redact", info=traceback.format_exc() - if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and auth_manager.is_logged_in() + if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and is_logged_in else "Error! Please contact server admin.", ), 500, diff --git a/tests/auh/managers/fab/test_fab_auth_manager.py b/tests/auh/managers/fab/test_fab_auth_manager.py index 4f24b1297ef7..baaec623f47c 100644 --- a/tests/auh/managers/fab/test_fab_auth_manager.py +++ b/tests/auh/managers/fab/test_fab_auth_manager.py @@ -22,6 +22,7 @@ import pytest from airflow.auth.managers.fab.fab_auth_manager import FabAuthManager +from airflow.auth.managers.fab.security_manager_override import FabAirflowSecurityManagerOverride from airflow.www.fab_security.sqla.models import User @@ -55,3 +56,6 @@ def test_is_logged_in(self, mock_current_user, auth_manager): mock_current_user.return_value = user assert auth_manager.is_logged_in() is False + + def test_get_security_manager_override_class_return_fab_security_manager_override(self, auth_manager): + assert auth_manager.get_security_manager_override_class() is FabAirflowSecurityManagerOverride diff --git a/tests/auh/managers/test_base_auth_manager.py b/tests/auh/managers/test_base_auth_manager.py new file mode 100644 index 000000000000..61bee75371a0 --- /dev/null +++ b/tests/auh/managers/test_base_auth_manager.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.auth.managers.base_auth_manager import BaseAuthManager + + +@pytest.fixture +def auth_manager(): + class EmptyAuthManager(BaseAuthManager): + def get_user_name(self) -> str: + raise NotImplementedError() + + return EmptyAuthManager() + + +class TestBaseAuthManager: + def test_get_security_manager_override_class_return_empty_class(self, auth_manager): + assert auth_manager.get_security_manager_override_class() is object