From b57df12ace583526ab46752a3bd914996f3abd1a Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Thu, 17 Oct 2024 14:21:39 +0200 Subject: [PATCH] core: extract object matching from flow manager (#11458) --- authentik/core/sources/flow_manager.py | 143 +++-------------------- authentik/core/sources/matcher.py | 152 +++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 128 deletions(-) create mode 100644 authentik/core/sources/matcher.py diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 86b78d47efcf..5ec95df0c21b 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -1,11 +1,9 @@ """Source decision helper""" -from enum import Enum from typing import Any from django.contrib import messages from django.db import IntegrityError, transaction -from django.db.models.query_utils import Q from django.http import HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse @@ -16,12 +14,11 @@ Group, GroupSourceConnection, Source, - SourceGroupMatchingModes, - SourceUserMatchingModes, User, UserSourceConnection, ) from authentik.core.sources.mapper import SourceMapper +from authentik.core.sources.matcher import Action, SourceMatcher from authentik.core.sources.stage import ( PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage, @@ -54,16 +51,6 @@ PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" -class Action(Enum): - """Actions that can be decided based on the request - and source settings""" - - LINK = "link" - AUTH = "auth" - ENROLL = "enroll" - DENY = "deny" - - class MessageStage(StageView): """Show a pre-configured message after the flow is done""" @@ -86,6 +73,7 @@ class SourceFlowManager: source: Source mapper: SourceMapper + matcher: SourceMatcher request: HttpRequest identifier: str @@ -108,6 +96,9 @@ def __init__( ) -> None: self.source = source self.mapper = SourceMapper(self.source) + self.matcher = SourceMatcher( + self.source, self.user_connection_type, self.group_connection_type + ) self.request = request self.identifier = identifier self.user_info = user_info @@ -131,66 +122,19 @@ def __init__( def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911 """decide which action should be taken""" - new_connection = self.user_connection_type(source=self.source, identifier=self.identifier) # When request is authenticated, always link if self.request.user.is_authenticated: + new_connection = self.user_connection_type( + source=self.source, identifier=self.identifier + ) new_connection.user = self.request.user new_connection = self.update_user_connection(new_connection, **kwargs) return Action.LINK, new_connection - existing_connections = self.user_connection_type.objects.filter( - source=self.source, identifier=self.identifier - ) - if existing_connections.exists(): - connection = existing_connections.first() - return Action.AUTH, self.update_user_connection(connection, **kwargs) - # No connection exists, but we match on identifier, so enroll - if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: - # We don't save the connection here cause it doesn't have a user assigned yet - return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) - - # Check for existing users with matching attributes - query = Q() - # Either query existing user based on email or username - if self.source.user_matching_mode in [ - SourceUserMatchingModes.EMAIL_LINK, - SourceUserMatchingModes.EMAIL_DENY, - ]: - if not self.user_properties.get("email", None): - self._logger.warning("Refusing to use none email") - return Action.DENY, None - query = Q(email__exact=self.user_properties.get("email", None)) - if self.source.user_matching_mode in [ - SourceUserMatchingModes.USERNAME_LINK, - SourceUserMatchingModes.USERNAME_DENY, - ]: - if not self.user_properties.get("username", None): - self._logger.warning("Refusing to use none username") - return Action.DENY, None - query = Q(username__exact=self.user_properties.get("username", None)) - self._logger.debug("trying to link with existing user", query=query) - matching_users = User.objects.filter(query) - # No matching users, always enroll - if not matching_users.exists(): - self._logger.debug("no matching users found, enrolling") - return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) - - user = matching_users.first() - if self.source.user_matching_mode in [ - SourceUserMatchingModes.EMAIL_LINK, - SourceUserMatchingModes.USERNAME_LINK, - ]: - new_connection.user = user - new_connection = self.update_user_connection(new_connection, **kwargs) - return Action.LINK, new_connection - if self.source.user_matching_mode in [ - SourceUserMatchingModes.EMAIL_DENY, - SourceUserMatchingModes.USERNAME_DENY, - ]: - self._logger.info("denying source because user exists", user=user) - return Action.DENY, None - # Should never get here as default enroll case is returned above. - return Action.DENY, None # pragma: no cover + action, connection = self.matcher.get_user_action(self.identifier, self.user_properties) + if connection: + connection = self.update_user_connection(connection, **kwargs) + return action, connection def update_user_connection( self, connection: UserSourceConnection, **kwargs @@ -408,74 +352,16 @@ def handle_enroll( class GroupUpdateStage(StageView): """Dynamically injected stage which updates the user after enrollment/authentication.""" - def get_action( - self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] - ) -> tuple[Action, GroupSourceConnection | None]: - """decide which action should be taken""" - new_connection = self.group_connection_type(source=self.source, identifier=group_id) - - existing_connections = self.group_connection_type.objects.filter( - source=self.source, identifier=group_id - ) - if existing_connections.exists(): - return Action.LINK, existing_connections.first() - # No connection exists, but we match on identifier, so enroll - if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER: - # We don't save the connection here cause it doesn't have a user assigned yet - return Action.ENROLL, new_connection - - # Check for existing groups with matching attributes - query = Q() - if self.source.group_matching_mode in [ - SourceGroupMatchingModes.NAME_LINK, - SourceGroupMatchingModes.NAME_DENY, - ]: - if not group_properties.get("name", None): - LOGGER.warning( - "Refusing to use none group name", source=self.source, group_id=group_id - ) - return Action.DENY, None - query = Q(name__exact=group_properties.get("name")) - LOGGER.debug( - "trying to link with existing group", source=self.source, query=query, group_id=group_id - ) - matching_groups = Group.objects.filter(query) - # No matching groups, always enroll - if not matching_groups.exists(): - LOGGER.debug( - "no matching groups found, enrolling", source=self.source, group_id=group_id - ) - return Action.ENROLL, new_connection - - group = matching_groups.first() - if self.source.group_matching_mode in [ - SourceGroupMatchingModes.NAME_LINK, - ]: - new_connection.group = group - return Action.LINK, new_connection - if self.source.group_matching_mode in [ - SourceGroupMatchingModes.NAME_DENY, - ]: - LOGGER.info( - "denying source because group exists", - source=self.source, - group=group, - group_id=group_id, - ) - return Action.DENY, None - # Should never get here as default enroll case is returned above. - return Action.DENY, None # pragma: no cover - def handle_group( self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] ) -> Group | None: - action, connection = self.get_action(group_id, group_properties) + action, connection = self.matcher.get_group_action(group_id, group_properties) if action == Action.ENROLL: group = Group.objects.create(**group_properties) connection.group = group connection.save() return group - elif action == Action.LINK: + elif action in (Action.LINK, Action.AUTH): group = connection.group group.update_attributes(group_properties) connection.save() @@ -489,6 +375,7 @@ def handle_groups(self) -> bool: self.group_connection_type: GroupSourceConnection = ( self.executor.current_stage.group_connection_type ) + self.matcher = SourceMatcher(self.source, None, self.group_connection_type) raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ PLAN_CONTEXT_SOURCE_GROUPS diff --git a/authentik/core/sources/matcher.py b/authentik/core/sources/matcher.py new file mode 100644 index 000000000000..f45792761e02 --- /dev/null +++ b/authentik/core/sources/matcher.py @@ -0,0 +1,152 @@ +"""Source user and group matching""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from django.db.models import Q +from structlog import get_logger + +from authentik.core.models import ( + Group, + GroupSourceConnection, + Source, + SourceGroupMatchingModes, + SourceUserMatchingModes, + User, + UserSourceConnection, +) + + +class Action(Enum): + """Actions that can be decided based on the request and source settings""" + + LINK = "link" + AUTH = "auth" + ENROLL = "enroll" + DENY = "deny" + + +@dataclass +class MatchableProperty: + property: str + link_mode: SourceUserMatchingModes | SourceGroupMatchingModes + deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes + + +class SourceMatcher: + def __init__( + self, + source: Source, + user_connection_type: type[UserSourceConnection], + group_connection_type: type[GroupSourceConnection], + ): + self.source = source + self.user_connection_type = user_connection_type + self.group_connection_type = group_connection_type + self._logger = get_logger().bind(source=self.source) + + def get_action( + self, + object_type: type[User | Group], + matchable_properties: list[MatchableProperty], + identifier: str, + properties: dict[str, Any | dict[str, Any]], + ) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]: + connection_type = None + matching_mode = None + identifier_matching_mode = None + if object_type == User: + connection_type = self.user_connection_type + matching_mode = self.source.user_matching_mode + identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER + if object_type == Group: + connection_type = self.group_connection_type + matching_mode = self.source.group_matching_mode + identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER + if not connection_type or not matching_mode or not identifier_matching_mode: + return Action.DENY, None + + new_connection = connection_type(source=self.source, identifier=identifier) + + existing_connections = connection_type.objects.filter( + source=self.source, identifier=identifier + ) + if existing_connections.exists(): + return Action.AUTH, existing_connections.first() + # No connection exists, but we match on identifier, so enroll + if matching_mode == identifier_matching_mode: + # We don't save the connection here cause it doesn't have a user/group assigned yet + return Action.ENROLL, new_connection + + # Check for existing users with matching attributes + query = Q() + for matchable_property in matchable_properties: + property = matchable_property.property + if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]: + if not properties.get(property, None): + self._logger.warning( + "Refusing to use none property", identifier=identifier, property=property + ) + return Action.DENY, None + query_args = { + f"{property}__exact": properties[property], + } + query = Q(**query_args) + self._logger.debug( + "Trying to link with existing object", query=query, identifier=identifier + ) + matching_objects = object_type.objects.filter(query) + # Not matching objects, always enroll + if not matching_objects.exists(): + self._logger.debug("No matching objects found, enrolling") + return Action.ENROLL, new_connection + + obj = matching_objects.first() + if matching_mode in [mp.link_mode for mp in matchable_properties]: + attr = None + if object_type == User: + attr = "user" + if object_type == Group: + attr = "group" + setattr(new_connection, attr, obj) + return Action.LINK, new_connection + if matching_mode in [mp.deny_mode for mp in matchable_properties]: + self._logger.info("Denying source because object exists", obj=obj) + return Action.DENY, None + + # Should never get here as default enroll case is returned above. + return Action.DENY, None # pragma: no cover + + def get_user_action( + self, identifier: str, properties: dict[str, Any | dict[str, Any]] + ) -> tuple[Action, UserSourceConnection | None]: + return self.get_action( + User, + [ + MatchableProperty( + "username", + SourceUserMatchingModes.USERNAME_LINK, + SourceUserMatchingModes.USERNAME_DENY, + ), + MatchableProperty( + "email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY + ), + ], + identifier, + properties, + ) + + def get_group_action( + self, identifier: str, properties: dict[str, Any | dict[str, Any]] + ) -> tuple[Action, GroupSourceConnection | None]: + return self.get_action( + Group, + [ + MatchableProperty( + "name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY + ), + ], + identifier, + properties, + )