Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up OAuth2 helper functions #898

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog.d/+oauth2.removed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
As part of refactoring some authentication utility functions the function
`get_psa_authentication_names()` has been removed as it wasn't used anywhere in
Argus proper.
35 changes: 24 additions & 11 deletions src/argus/auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
from django.conf import settings
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.backends import ModelBackend, RemoteUserBackend
from django.utils.module_loading import import_string

from rest_framework.reverse import reverse
from social_core.backends.base import BaseAuth

from social_core.backends.oauth import BaseOAuth2


_all__ = [
"get_authentication_backend_classes",
"has_model_backend",
"has_remote_user_backend",
"get_psa_authentication_backends",
"get_authentication_backend_name_and_type",
]


def get_authentication_backend_classes():
backend_dotted_paths = getattr(settings, "AUTHENTICATION_BACKENDS")
backends = [import_string(path) for path in backend_dotted_paths]
return backends


def get_psa_authentication_names(backends=None):
def has_model_backend(backends):
return ModelBackend in backends


def has_remote_user_backend(backends):
return RemoteUserBackend in backends


def get_psa_authentication_backends(backends=None):
backends = backends if backends else get_authentication_backend_classes()
psa_backends = set()
for backend in backends:
if issubclass(backend, BaseAuth):
psa_backends.add(backend.name)
return sorted(psa_backends)
return [backend for backend in backends if issubclass(backend, BaseOAuth2)]


def get_authentication_backend_name_and_type(request):
# Needed for SPA /login-methods/ API endpoint
backends = get_authentication_backend_classes()
data = []
if ModelBackend in backends:
if has_model_backend(backends):
data.append(
{
"type": "username_password",
Expand All @@ -40,8 +54,7 @@ def get_authentication_backend_name_and_type(request):
"url": reverse("social:begin", kwargs={"backend": backend.name}, request=request),
"name": backend.name,
}
for backend in backends
if issubclass(backend, BaseOAuth2)
for backend in get_psa_authentication_backends(backends)
)

return data
7 changes: 1 addition & 6 deletions tests/auth/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.contrib.auth.backends import ModelBackend
from django.test import TestCase

from argus.auth.utils import get_authentication_backend_name_and_type, get_psa_authentication_names
from argus.auth.utils import get_authentication_backend_name_and_type
from argus.dataporten.social import DataportenFeideOAuth2


Expand Down Expand Up @@ -33,8 +33,3 @@ def test_get_authentication_backend_name_and_type_returns_feide_login(
"name": "dataporten_feide",
}
]

@patch("argus.auth.utils.get_authentication_backend_classes")
def test_get_psa_authentication_names_returns_feide_name(self, mock_get_authentication_backend_classes):
mock_get_authentication_backend_classes.return_value = [DataportenFeideOAuth2]
assert get_psa_authentication_names() == ["dataporten_feide"]
Loading