diff --git a/ansible_base/authentication/utils/authentication.py b/ansible_base/authentication/utils/authentication.py index 4992b7580..ccf2a6acd 100644 --- a/ansible_base/authentication/utils/authentication.py +++ b/ansible_base/authentication/utils/authentication.py @@ -61,7 +61,7 @@ def determine_username_from_uid(uid: str = None, authenticator: Authenticator = new_username = get_local_username({'username': uid}) logger.info( f'Authenticator {authenticator.name} wants to authenticate {uid} but that' - f'username is already in use by another authenticator,' + f' username is already in use by another authenticator,' f' the user from this authenticator will be {new_username}' ) return new_username diff --git a/ansible_base/lib/dynamic_config/dynamic_settings.py b/ansible_base/lib/dynamic_config/dynamic_settings.py index fa7d893a2..b5c5ea896 100644 --- a/ansible_base/lib/dynamic_config/dynamic_settings.py +++ b/ansible_base/lib/dynamic_config/dynamic_settings.py @@ -168,3 +168,37 @@ ORG_ADMINS_CAN_SEE_ALL_USERS except NameError: ORG_ADMINS_CAN_SEE_ALL_USERS = True + + +if 'ansible_base.oauth2_provider' in INSTALLED_APPS: # noqa: F821 + if 'oauth2_provider' not in INSTALLED_APPS: # noqa: F821 + INSTALLED_APPS.append('oauth2_provider') # noqa: F821 + + try: + OAUTH2_PROVIDER # noqa: F821 + except NameError: + OAUTH2_PROVIDER = {} + + if 'ACCESS_TOKEN_EXPIRE_SECONDS' not in OAUTH2_PROVIDER: + OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] = 31536000000 + if 'AUTHORIZATION_CODE_EXPIRE_SECONDS' not in OAUTH2_PROVIDER: + OAUTH2_PROVIDER['AUTHORIZATION_CODE_EXPIRE_SECONDS'] = 600 + if 'REFRESH_TOKEN_EXPIRE_SECONDS' not in OAUTH2_PROVIDER: + OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 2628000 + + OAUTH2_PROVIDER['APPLICATION_MODEL'] = 'dab_oauth2_provider.OAuth2Application' + OAUTH2_PROVIDER['ACCESS_TOKEN_MODEL'] = 'dab_oauth2_provider.OAuth2AccessToken' + + oauth2_authentication_class = 'ansible_base.oauth2_provider.authentication.LoggedOAuth2Authentication' + if 'DEFAULT_AUTHENTICATION_CLASSES' not in REST_FRAMEWORK: # noqa: F821 + REST_FRAMEWORK['DEFAULT_AUTHENTICATION_CLASSES'] = [] # noqa: F821 + if oauth2_authentication_class not in REST_FRAMEWORK['DEFAULT_AUTHENTICATION_CLASSES']: # noqa: F821 + REST_FRAMEWORK['DEFAULT_AUTHENTICATION_CLASSES'].insert(0, oauth2_authentication_class) # noqa: F821 + + # These have to be defined for the migration to function + OAUTH2_PROVIDER_APPLICATION_MODEL = 'dab_oauth2_provider.OAuth2Application' + OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = 'dab_oauth2_provider.OAuth2AccessToken' + OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "dab_oauth2_provider.OAuth2RefreshToken" + OAUTH2_PROVIDER_ID_TOKEN_MODEL = "dab_oauth2_provider.OAuth2IDToken" + + ALLOW_OAUTH2_FOR_EXTERNAL_USERS = False diff --git a/ansible_base/lib/serializers/common.py b/ansible_base/lib/serializers/common.py index 2409b5860..8d9ef0972 100644 --- a/ansible_base/lib/serializers/common.py +++ b/ansible_base/lib/serializers/common.py @@ -47,10 +47,16 @@ def is_list_view(self) -> bool: def _get_related(self, obj) -> dict[str, str]: if obj is None: return {} + related_fields = {} + view = self.context.get('view') + if view is not None and hasattr(view, 'extra_related_fields'): + related_fields.update(view.extra_related_fields(obj)) if not hasattr(obj, 'related_fields'): logger.warning(f"Object {obj.__class__} has no related_fields method") - return {} - return obj.related_fields(self.context.get('request')) + else: + related_fields.update(obj.related_fields(self.context.get('request'))) + + return related_fields def _get_summary_fields(self, obj) -> dict[str, dict]: if obj is None: diff --git a/ansible_base/lib/utils/views/ansible_base.py b/ansible_base/lib/utils/views/ansible_base.py index a7b27a3ca..683a4381b 100644 --- a/ansible_base/lib/utils/views/ansible_base.py +++ b/ansible_base/lib/utils/views/ansible_base.py @@ -51,3 +51,13 @@ def finalize_response(self, request, response, *args, **kwargs): response['Warning'] = _('This resource has been deprecated and will be removed in a future release.') return response + + def extra_related_fields(self, obj): + """ + A hook for adding extra related fields to serializers which + make use of this view/viewset. + + This is particularly useful for mixins which want to extend a viewset + with additional actions and provide those actions as related fields. + """ + return {} diff --git a/ansible_base/oauth2_provider/__init__.py b/ansible_base/oauth2_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ansible_base/oauth2_provider/admin.py b/ansible_base/oauth2_provider/admin.py new file mode 100644 index 000000000..4fd549025 --- /dev/null +++ b/ansible_base/oauth2_provider/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin # noqa: F401 + +# Register your models here. diff --git a/ansible_base/oauth2_provider/apps.py b/ansible_base/oauth2_provider/apps.py new file mode 100644 index 000000000..9b549e4c3 --- /dev/null +++ b/ansible_base/oauth2_provider/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class Oauth2ProviderConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'ansible_base.oauth2_provider' + label = 'dab_oauth2_provider' diff --git a/ansible_base/oauth2_provider/authentication.py b/ansible_base/oauth2_provider/authentication.py new file mode 100644 index 000000000..a1879177e --- /dev/null +++ b/ansible_base/oauth2_provider/authentication.py @@ -0,0 +1,20 @@ +import logging + +from django.utils.encoding import smart_str +from oauth2_provider.contrib.rest_framework import OAuth2Authentication + +logger = logging.getLogger('ansible_base.oauth2_provider.authentication') + + +class LoggedOAuth2Authentication(OAuth2Authentication): + def authenticate(self, request): + ret = super().authenticate(request) + if ret: + user, token = ret + username = user.username if user else '' + logger.info( + smart_str(u"User {} performed a {} to {} through the API using OAuth 2 token {}.".format(username, request.method, request.path, token.pk)) + ) + # TODO: check oauth_scopes when we have RBAC in Gateway + setattr(user, 'oauth_scopes', [x for x in token.scope.split() if x]) + return ret diff --git a/ansible_base/oauth2_provider/migrations/0001_initial.py b/ansible_base/oauth2_provider/migrations/0001_initial.py new file mode 100644 index 000000000..5732c858d --- /dev/null +++ b/ansible_base/oauth2_provider/migrations/0001_initial.py @@ -0,0 +1,130 @@ +# Generated by Django 4.2.8 on 2024-02-11 20:16 + +import re +import uuid + +import django.core.validators +import django.db.models.deletion +import oauth2_provider.generators +from django.conf import settings +from django.db import migrations, models + +import ansible_base.oauth2_provider.models.application + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + migrations.swappable_dependency(settings.ANSIBLE_BASE_ORGANIZATION_MODEL), + ] + + run_before = [ + ('oauth2_provider', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='OAuth2Application', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now, editable=False, help_text='The date/time this resource was created')), + ('modified', models.DateTimeField(default=None, editable=False, help_text='The date/time this resource was created')), + ('name', models.CharField(blank=True, help_text='The name of this resource', max_length=255)), + ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), + ('description', models.TextField(blank=True, default='')), + ('logo_data', models.TextField(default='', editable=False, validators=[django.core.validators.RegexValidator(re.compile('.*'))])), + ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, help_text='Used for more stringent verification of access to an application when creating a token.', max_length=1024)), + ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], help_text='Set to Public or Confidential depending on how secure the client device is.', max_length=32)), + ('skip_authorization', models.BooleanField(default=False, help_text='Set True to skip authorization step for completely trusted applications.')), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('password', 'Resource owner password-based')], help_text='The Grant type the user must use for acquire tokens for this application.', max_length=32)), + ('created_by', models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL)), + ('modified_by', models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL)), + ('organization', models.ForeignKey(help_text='Organization containing this application.', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='applications', to=settings.ANSIBLE_BASE_ORGANIZATION_MODEL)), + ('algorithm', models.CharField(blank=True, choices=[('', 'No OIDC support'), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5)), + ('post_logout_redirect_uris', models.TextField(blank=True, help_text='Allowed Post Logout URIs list, space separated')), + ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), + ('updated', models.DateTimeField(auto_now=True)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(app_label)s_%(class)s', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'application', + 'ordering': ('organization', 'name'), + 'swappable': 'OAUTH2_PROVIDER_APPLICATION_MODEL', + 'unique_together': {('name', 'organization')}, + }, + ), + migrations.CreateModel( + name='OAuth2IDToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now, editable=False, help_text='The date/time this resource was created')), + ('modified', models.DateTimeField(default=None, editable=False, help_text='The date/time this resource was created')), + ('created_by', models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL)), + ('modified_by', models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('expires', models.DateTimeField(default=None)), + ('jti', models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='JWT Token ID')), + ('scope', models.TextField(blank=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(app_label)s_%(class)s', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'id token', + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + migrations.CreateModel( + name='OAuth2RefreshToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', models.DateTimeField(default=None, editable=False, help_text='The date/time this resource was created')), + ('modified', models.DateTimeField(default=None, editable=False, help_text='The date/time this resource was created')), + ('created_by', models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL)), + ('modified_by', models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL)), + ('application', models.ForeignKey(default='', on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('revoked', models.DateTimeField(null=True)), + ('token', models.CharField(default='', max_length=255)), + ('updated', models.DateTimeField(auto_now=True)), + ('user', models.ForeignKey(default='', on_delete=django.db.models.deletion.CASCADE, related_name='%(app_label)s_%(class)s', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'access token', + 'ordering': ('id',), + 'swappable': 'OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL', + 'unique_together': {('token', 'revoked')}, + }, + ), + migrations.CreateModel( + name='OAuth2AccessToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now, editable=False, help_text='The date/time this resource was created')), + ('modified', models.DateTimeField(default=None, editable=False, help_text='The date/time this resource was created')), + ('description', models.TextField(blank=True, default='')), + ('last_used', models.DateTimeField(default=None, editable=False, null=True)), + ('scope', models.CharField(blank=True, choices=[('read', 'read'), ('write', 'write')], default='write', help_text="Allowed scopes, further restricts user's permissions. Must be a simple space-separated string with allowed scopes ['read', 'write'].", max_length=32)), + ('created_by', models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL)), + ('modified_by', models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL)), + ('user', models.ForeignKey(blank=True, help_text='The user representing the token owner', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(app_label)s_%(class)s', to=settings.AUTH_USER_MODEL)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('expires', models.DateTimeField(default=None)), + ('token', models.CharField(default='', max_length=255, unique=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('id_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL)), + ('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL)), + ], + options={ + 'verbose_name': 'access token', + 'ordering': ('id',), + 'swappable': 'OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL', + }, + ), + migrations.AddField( + model_name='oauth2refreshtoken', + name='access_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='refresh_token', to=settings.OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL), + ), + ] diff --git a/ansible_base/oauth2_provider/migrations/0002_alter_oauth2refreshtoken_options_and_more.py b/ansible_base/oauth2_provider/migrations/0002_alter_oauth2refreshtoken_options_and_more.py new file mode 100644 index 000000000..36547007b --- /dev/null +++ b/ansible_base/oauth2_provider/migrations/0002_alter_oauth2refreshtoken_options_and_more.py @@ -0,0 +1,198 @@ +# Generated by Django 4.2.11 on 2024-05-08 14:27 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import oauth2_provider.generators +import oauth2_provider.models + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('dab_oauth2_provider', '0001_initial'), + ] + + operations = [ + migrations.AlterModelOptions( + name='oauth2refreshtoken', + options={'ordering': ('id',), 'verbose_name': 'refresh token'}, + ), + migrations.RemoveField( + model_name='oauth2accesstoken', + name='updated', + ), + migrations.RemoveField( + model_name='oauth2application', + name='updated', + ), + migrations.RemoveField( + model_name='oauth2idtoken', + name='updated', + ), + migrations.RemoveField( + model_name='oauth2refreshtoken', + name='updated', + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='application', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_tokens', to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='created', + field=models.DateTimeField(auto_now_add=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='created_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='expires', + field=models.DateTimeField(), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='id', + field=models.BigAutoField(primary_key=True, serialize=False), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='modified', + field=models.DateTimeField(auto_now=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='modified_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='scope', + field=models.CharField(blank=True, choices=[('read', 'Read'), ('write', 'Write')], default='write', help_text="Allowed scopes, further restricts user's permissions. Must be a simple space-separated string with allowed scopes ['read', 'write'].", max_length=32), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='token', + field=models.CharField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='oauth2accesstoken', + name='user', + field=models.ForeignKey(blank=True, help_text='The user representing the token owner', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_tokens', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2application', + name='client_secret', + field=oauth2_provider.models.ClientSecretField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, help_text='Hashed on Save. Copy it now if this is a new secret.', max_length=255), + ), + migrations.AlterField( + model_name='oauth2application', + name='created', + field=models.DateTimeField(auto_now_add=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2application', + name='created_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2application', + name='id', + field=models.BigAutoField(primary_key=True, serialize=False), + ), + migrations.AlterField( + model_name='oauth2application', + name='modified', + field=models.DateTimeField(auto_now=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2application', + name='modified_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2application', + name='name', + field=models.CharField(help_text='The name of this resource', max_length=512), + ), + migrations.AlterField( + model_name='oauth2application', + name='user', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='applications', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='created', + field=models.DateTimeField(auto_now_add=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='created_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='expires', + field=models.DateTimeField(), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='id', + field=models.BigAutoField(primary_key=True, serialize=False), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='modified', + field=models.DateTimeField(auto_now=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2idtoken', + name='modified_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='application', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='created', + field=models.DateTimeField(auto_now_add=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='created_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who created this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_created+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='id', + field=models.BigAutoField(primary_key=True, serialize=False), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='modified', + field=models.DateTimeField(auto_now=True, help_text='The date/time this resource was created'), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='modified_by', + field=models.ForeignKey(default=None, editable=False, help_text='The user who last modified this resource', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='%(app_label)s_%(class)s_modified+', to=settings.AUTH_USER_MODEL), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='token', + field=models.CharField(max_length=255), + ), + migrations.AlterField( + model_name='oauth2refreshtoken', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='%(app_label)s_%(class)s', to=settings.AUTH_USER_MODEL), + ), + ] diff --git a/ansible_base/oauth2_provider/migrations/__init__.py b/ansible_base/oauth2_provider/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ansible_base/oauth2_provider/models/__init__.py b/ansible_base/oauth2_provider/models/__init__.py new file mode 100644 index 000000000..9864347a0 --- /dev/null +++ b/ansible_base/oauth2_provider/models/__init__.py @@ -0,0 +1,60 @@ +from .access_token import OAuth2AccessToken +from .application import OAuth2Application +from .id_token import OAuth2IDToken +from .refresh_token import OAuth2RefreshToken + +__all__ = ( + 'OAuth2AccessToken', + 'OAuth2Application', + 'OAuth2IDToken', + 'OAuth2RefreshToken', +) + +# +# There were a lot of problems making the initial migrations for this class +# See https://github.com/jazzband/django-oauth-toolkit/issues/634 which helped +# +# Here were my steps: +# 1. Make sure 'ansible_base.oauth2_provider' is in test_app INSTALLED_APPS (this already should be) +# 2. Comment out all OAUTH2_ settings in ansible_base/lib/dynamic_config/dynamic_settings.py which reference dab_oauth2_provider.* +# 3. Change all model classes to: +# remove oauth2_models.Abstract* as superclasses (including the meta ones) +# comment out the "import oauth2_provider.models as oauth2_models" imports +# 4. ./manage.py makemigrations dab_oauth2_provider +# 5. Edit the created 0001 migration and delete the dpendency on ('test_app', '000X') +# 6. ./manage.py migrate dab_oauth2_provider +# 7. Look at the generated migration, if this has a direct reference to your applications organization model in OAuth2Application model we need to update it +# for example, if it looks like: +# ('organization', ... to='.organization')), +# We want to change this to reference the setting: +# ('organization', ... to=settings.ANSIBLE_BASE_ORGANIZATION_MODEL)), +# We should also add this in the migration dependencies: +# migrations.swappable_dependency(settings.ANSIBLE_BASE_ORGANIZATION_MODEL), +# 8. Uncomment all OAUTH2_PROVIDER_* settings +# 9. Revert step 3 +# 10. gateway-manage makemigrations && gateway-manage migrate ansible_base +# When you do this django does not realize that you are creating an initial migration and tell you its impossible to migrate so fields +# It will ask you to either: 1. Enter a default 2. Quit +# Tell it to use the default if it has one populated at the prompt. Other wise use django.utils.timezone.now for timestamps and '' for other items +# This wont matter for us because there will be no data in the tables between these two migrations +# 11. You can now combine the migration into one. +# Add the `import uuid` to the top of the initial migration file +# Copy all of the operations from the second file to the first +# Find the AddFields commands for oauth2refreshtoken.access_token and oauth2accesstoken.source_refresh_token and move them to the end of the operations +# If desired, convert the remaining AddFilds into actual fields on the table creation. For example: +# migrations.AddField( +# model_name='oauth2accesstoken', +# name='created', +# field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), +# preserve_default=False, +# ), +# Would become the following field on the oauth2accesstoken table: +# ('created', models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now)), +# Next put the table creation in the following order: OAuth2Application, OAuth2IDToken, OAuth2RefreshToken, OAuth2AccessToken +# Finally, be sure to add this to the migration file: +# run_before = [ +# ('oauth2_provider', '0001_initial'), +# ] +# 12. Delete the new migration +# 13. zero out the initial migration with: ./manage.py migrate dab_oauth2_provider zero +# 14. Make the actual migration with: ./manage.py migrate dab_oauth2_provider diff --git a/ansible_base/oauth2_provider/models/access_token.py b/ansible_base/oauth2_provider/models/access_token.py new file mode 100644 index 000000000..175cc2baa --- /dev/null +++ b/ansible_base/oauth2_provider/models/access_token.py @@ -0,0 +1,98 @@ +import oauth2_provider.models as oauth2_models +from django.conf import settings +from django.db import connection, models +from django.utils.timezone import now +from django.utils.translation import gettext_lazy as _ +from oauthlib import oauth2 + +from ansible_base.lib.abstract_models.common import CommonModel +from ansible_base.lib.utils.models import prevent_search +from ansible_base.lib.utils.settings import get_setting +from ansible_base.oauth2_provider.utils import is_external_account + +activitystream = object +if 'ansible_base.activitystream' in settings.INSTALLED_APPS: + from ansible_base.activitystream.models import AuditableModel + + activitystream = AuditableModel + + +class OAuth2AccessToken(CommonModel, oauth2_models.AbstractAccessToken, activitystream): + router_basename = 'token' + ignore_relations = ['refresh_token'] + + class Meta(oauth2_models.AbstractAccessToken.Meta): + verbose_name = _('access token') + ordering = ('id',) + swappable = "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL" + + SCOPE_CHOICES = [ + ('read', _('Read')), + ('write', _('Write')), + ] + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="access_tokens", + help_text=_('The user representing the token owner'), + ) + # Overriding to set related_name + application = models.ForeignKey( + settings.OAUTH2_PROVIDER_APPLICATION_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name='access_tokens', + ) + description = models.TextField( + default='', + blank=True, + ) + last_used = models.DateTimeField( + null=True, + default=None, + editable=False, + ) + scope = models.CharField( + blank=True, + default='write', + max_length=32, + choices=SCOPE_CHOICES, + help_text=_("Allowed scopes, further restricts user's permissions. Must be a simple space-separated string with allowed scopes ['read', 'write']."), + ) + token = prevent_search( + models.CharField( + max_length=255, + unique=True, + ) + ) + updated = None # Tracked in CommonModel with 'modified', no need for this + + def is_valid(self, scopes=None): + valid = super(OAuth2AccessToken, self).is_valid(scopes) + if valid: + self.last_used = now() + + def _update_last_used(): + if OAuth2AccessToken.objects.filter(pk=self.pk).exists(): + self.save(update_fields=['last_used']) + + connection.on_commit(_update_last_used) + return valid + + def validate_external_users(self): + if self.user and get_setting('ALLOW_OAUTH2_FOR_EXTERNAL_USERS') is False: + external_account = is_external_account(self.user) + if external_account: + raise oauth2.AccessDeniedError( + _('OAuth2 Tokens cannot be created by users associated with an external authentication provider (%(authenticator)s)') + % {'authenticator': external_account.name} + ) + + def save(self, *args, **kwargs): + if not self.pk: + self.validate_external_users() + super().save(*args, **kwargs) diff --git a/ansible_base/oauth2_provider/models/application.py b/ansible_base/oauth2_provider/models/application.py new file mode 100644 index 000000000..0a0c981cb --- /dev/null +++ b/ansible_base/oauth2_provider/models/application.py @@ -0,0 +1,88 @@ +import re + +import oauth2_provider.models as oauth2_models +from django.conf import settings +from django.core.validators import RegexValidator +from django.db import models +from django.urls import reverse +from django.utils.translation import gettext_lazy as _ + +from ansible_base.lib.abstract_models.common import NamedCommonModel + +activitystream = object +if 'ansible_base.activitystream' in settings.INSTALLED_APPS: + from ansible_base.activitystream.models import AuditableModel + + activitystream = AuditableModel + + +DATA_URI_RE = re.compile(r'.*') # FIXME + + +class OAuth2Application(NamedCommonModel, oauth2_models.AbstractApplication, activitystream): + router_basename = 'application' + ignore_relations = ['oauth2idtoken', 'grant', 'oauth2refreshtoken'] + # We do NOT add client_secret to encrypted_fields because it is hashed by Django OAuth Toolkit + # and it would end up hashing the encrypted value. + + class Meta(oauth2_models.AbstractAccessToken.Meta): + verbose_name = _('application') + unique_together = (("name", "organization"),) + ordering = ('organization', 'name') + swappable = "OAUTH2_PROVIDER_APPLICATION_MODEL" + + CLIENT_TYPES = ( + ("confidential", _("Confidential")), + ("public", _("Public")), + ) + + GRANT_TYPES = ( + ("authorization-code", _("Authorization code")), + ("password", _("Resource owner password-based")), + ) + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + related_name="applications", + null=True, + blank=True, + on_delete=models.CASCADE, + ) + + description = models.TextField( + default='', + blank=True, + ) + logo_data = models.TextField( + default='', + editable=False, + validators=[RegexValidator(DATA_URI_RE)], + ) + organization = models.ForeignKey( + getattr(settings, 'ANSIBLE_BASE_ORGANIZATION_MODEL'), + related_name='applications', + help_text=_('Organization containing this application.'), + on_delete=models.CASCADE, + null=True, + ) + # Not overriding client_secret... Details: + # It would be nice to just use our usual encrypted_fields flow here + # until DOT makes a release with https://github.com/jazzband/django-oauth-toolkit/pull/1311 + # there is no way to disable its expectation of using its own hashing + # (which is Django's make_password/check_password). + # So we use their field here. + # Previous versions of DOT didn't hash the field at all and AWX pins + # to <2.0.0 so AWX used the AWX encryption with no issue. + client_type = models.CharField( + max_length=32, choices=CLIENT_TYPES, help_text=_('Set to Public or Confidential depending on how secure the client device is.') + ) + skip_authorization = models.BooleanField(default=False, help_text=_('Set True to skip authorization step for completely trusted applications.')) + authorization_grant_type = models.CharField( + max_length=32, choices=GRANT_TYPES, help_text=_('The Grant type the user must use for acquire tokens for this application.') + ) + updated = None # Tracked in CommonModel with 'modified', no need for this + + def get_absolute_url(self): + # This is kind of annoying. This method lives on the superclass and we check for it in CommonModel. + # But better would be to not have this method and let the CommonModel logic fall back to the "right" way of finding this. + return reverse(f'{self.router_basename}-detail', kwargs={'pk': self.pk}) diff --git a/ansible_base/oauth2_provider/models/id_token.py b/ansible_base/oauth2_provider/models/id_token.py new file mode 100644 index 000000000..67f641231 --- /dev/null +++ b/ansible_base/oauth2_provider/models/id_token.py @@ -0,0 +1,19 @@ +import oauth2_provider.models as oauth2_models +from django.conf import settings +from django.utils.translation import gettext_lazy as _ + +from ansible_base.lib.abstract_models.common import CommonModel + +activitystream = object +if 'ansible_base.activitystream' in settings.INSTALLED_APPS: + from ansible_base.activitystream.models import AuditableModel + + activitystream = AuditableModel + + +class OAuth2IDToken(CommonModel, oauth2_models.AbstractIDToken, activitystream): + class Meta(oauth2_models.AbstractIDToken.Meta): + verbose_name = _('id token') + swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" + + updated = None # Tracked in CommonModel with 'modified', no need for this diff --git a/ansible_base/oauth2_provider/models/refresh_token.py b/ansible_base/oauth2_provider/models/refresh_token.py new file mode 100644 index 000000000..078a87cf9 --- /dev/null +++ b/ansible_base/oauth2_provider/models/refresh_token.py @@ -0,0 +1,23 @@ +import oauth2_provider.models as oauth2_models +from django.conf import settings +from django.db import models +from django.utils.translation import gettext_lazy as _ + +from ansible_base.lib.abstract_models.common import CommonModel +from ansible_base.lib.utils.models import prevent_search + +activitystream = object +if 'ansible_base.activitystream' in settings.INSTALLED_APPS: + from ansible_base.activitystream.models import AuditableModel + + activitystream = AuditableModel + + +class OAuth2RefreshToken(CommonModel, oauth2_models.AbstractRefreshToken, activitystream): + class Meta(oauth2_models.AbstractRefreshToken.Meta): + verbose_name = _('refresh token') + ordering = ('id',) + swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" + + token = prevent_search(models.CharField(max_length=255)) + updated = None # Tracked in CommonModel with 'modified', no need for this diff --git a/ansible_base/oauth2_provider/serializers/__init__.py b/ansible_base/oauth2_provider/serializers/__init__.py new file mode 100644 index 000000000..cac44fafb --- /dev/null +++ b/ansible_base/oauth2_provider/serializers/__init__.py @@ -0,0 +1,2 @@ +from .application import OAuth2ApplicationSerializer # noqa: F401 +from .token import OAuth2TokenSerializer # noqa: F401 diff --git a/ansible_base/oauth2_provider/serializers/application.py b/ansible_base/oauth2_provider/serializers/application.py new file mode 100644 index 000000000..ee1f0058b --- /dev/null +++ b/ansible_base/oauth2_provider/serializers/application.py @@ -0,0 +1,89 @@ +from django.core.exceptions import ObjectDoesNotExist +from django.utils.translation import gettext_lazy as _ +from oauth2_provider.generators import generate_client_secret + +from ansible_base.lib.serializers.common import NamedCommonModelSerializer +from ansible_base.lib.utils.encryption import ENCRYPTED_STRING +from ansible_base.oauth2_provider.models import OAuth2Application + + +class OAuth2ApplicationSerializer(NamedCommonModelSerializer): + oauth2_client_secret = None + + class Meta: + model = OAuth2Application + fields = NamedCommonModelSerializer.Meta.fields + [x.name for x in OAuth2Application._meta.concrete_fields] + read_only_fields = ('client_id', 'client_secret') + read_only_on_update_fields = ('user', 'authorization_grant_type') + extra_kwargs = { + 'user': {'allow_null': True, 'required': False}, + 'organization': {'allow_null': False}, + 'authorization_grant_type': {'allow_null': False, 'label': _('Authorization Grant Type')}, + 'client_secret': {'label': _('Client Secret')}, + 'client_type': {'label': _('Client Type')}, + 'redirect_uris': {'label': _('Redirect URIs')}, + 'skip_authorization': {'label': _('Skip Authorization')}, + } + + def _get_client_secret(self, obj): + request = self.context.get('request', None) + try: + if obj.client_type == 'public': + return None + elif request.method == 'POST': + # Show the secret, one time, on POST + return self.oauth2_client_secret + else: + return ENCRYPTED_STRING + except ObjectDoesNotExist: + return '' + + def to_representation(self, instance): + # We have to override this because in AbstractCommonModelSerializer, we'll + # auto-force all encrypted fields to ENCRYPTED_STRING. Usually that's fine, + # but we want to show the client_secret on POST. Ideally we'd just use + # get_client_secret() and a SerializerMethodField. + ret = super().to_representation(instance) + secret = self._get_client_secret(instance) + if secret is None: + del ret['client_secret'] + else: + ret['client_secret'] = secret + return ret + + def _summary_field_tokens(self, obj): + token_list = [{'id': x.pk, 'token': ENCRYPTED_STRING, 'scope': x.scope} for x in obj.access_tokens.all()[:10]] + if len(token_list) < 10: + token_count = len(token_list) + else: + token_count = obj.access_tokens.count() + return {'count': token_count, 'results': token_list} + + def _get_summary_fields(self, obj): + ret = super()._get_summary_fields(obj) + ret['tokens'] = self._summary_field_tokens(obj) + return ret + + def create(self, validated_data): + # This is hacky: + # There is a cascading set of issues here. + # 1. The first thing to know is that DOT automatically hashes the client_secret + # in a pre_save method on the client_secret field. + # 2. In current released versions, there is no way to disable (1). It uses + # the built-in Django password hashing stuff to do this. There's a merged + # PR to allow disabling this (DOT #1311), but it's not released yet. + # 3. If we use our own encrypted_field stuff, it conflicts with (1) and (2). + # They end up giving our encrypted field to Django's password check + # and *we* end up showing *their* hashed value to the user on POST, which + # doesn't work, the user needs to see the real (decrypted) value. So + # until upstream #1311 is released, we do NOT treat the field as an + # encrypted_field, we just defer to the upstream hashing. + # 4. But we have no way to see the client_secret on POST, if we let the + # model generate it, because it's hashed by the time we get to the + # serializer... + # + # So to that end, on POST, we'll make the client secret here, and then + # we can access it to show the user the value (once) on POST. + validated_data['client_secret'] = generate_client_secret() + self.oauth2_client_secret = validated_data['client_secret'] + return super().create(validated_data) diff --git a/ansible_base/oauth2_provider/serializers/token.py b/ansible_base/oauth2_provider/serializers/token.py new file mode 100644 index 000000000..319a15230 --- /dev/null +++ b/ansible_base/oauth2_provider/serializers/token.py @@ -0,0 +1,102 @@ +import logging +from datetime import timedelta + +from crum import get_current_user +from django.core.exceptions import ObjectDoesNotExist +from django.utils.timezone import now +from django.utils.translation import gettext_lazy as _ +from oauthlib.common import generate_token +from oauthlib.oauth2 import AccessDeniedError +from rest_framework.exceptions import PermissionDenied, ValidationError +from rest_framework.serializers import SerializerMethodField + +from ansible_base.lib.serializers.common import CommonModelSerializer +from ansible_base.lib.utils.encryption import ENCRYPTED_STRING +from ansible_base.lib.utils.settings import get_setting +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken + +logger = logging.getLogger("ansible_base.oauth2_provider.serializers.token") + + +class BaseOAuth2TokenSerializer(CommonModelSerializer): + refresh_token = SerializerMethodField() + token = SerializerMethodField() + ALLOWED_SCOPES = [x[0] for x in OAuth2AccessToken.SCOPE_CHOICES] + + class Meta: + model = OAuth2AccessToken + fields = CommonModelSerializer.Meta.fields + [x.name for x in OAuth2AccessToken._meta.concrete_fields] + ['refresh_token'] + # The source_refresh_token and id_token are the concrete field but we change them to just token and refresh_token + # We wrap these in a try for when we need to make the initial models + try: + fields.remove('source_refresh_token') + except ValueError: + pass + try: + fields.remove('id_token') + except ValueError: + pass + read_only_fields = ('user', 'token', 'expires', 'refresh_token') + extra_kwargs = {'scope': {'allow_null': False, 'required': False}, 'user': {'allow_null': False, 'required': True}} + + def get_token(self, obj): + request = self.context.get('request') + try: + if request and request.method == 'POST': + return obj.token + else: + return ENCRYPTED_STRING + except ObjectDoesNotExist: + return '' + + def get_refresh_token(self, obj): + request = self.context.get('request') + try: + if not obj.refresh_token: + return None + elif request and request.method == 'POST': + return getattr(obj.refresh_token, 'token', '') + else: + return ENCRYPTED_STRING + except ObjectDoesNotExist: + return None + + def _is_valid_scope(self, value): + if not value or (not isinstance(value, str)): + return False + words = value.split() + for word in words: + if words.count(word) > 1: + return False # do not allow duplicates + if word not in self.ALLOWED_SCOPES: + return False + return True + + def validate_scope(self, value): + if not self._is_valid_scope(value): + raise ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(self.ALLOWED_SCOPES)) + return value + + def create(self, validated_data): + validated_data['user'] = self.context['request'].user + try: + return super().create(validated_data) + except AccessDeniedError as e: + raise PermissionDenied(str(e)) + + +class OAuth2TokenSerializer(BaseOAuth2TokenSerializer): + def create(self, validated_data): + current_user = get_current_user() + validated_data['token'] = generate_token() + expires_delta = get_setting('OAUTH2_PROVIDER', {}).get('ACCESS_TOKEN_EXPIRE_SECONDS', 0) + if expires_delta == 0: + logger.warning("OAUTH2_PROVIDER.ACCESS_TOKEN_EXPIRE_SECONDS was set to 0, creating token that has already expired") + validated_data['expires'] = now() + timedelta(seconds=expires_delta) + obj = super().create(validated_data) + if obj.application and obj.application.user: + obj.user = obj.application.user + obj.save() + if obj.application: + OAuth2RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj) + return obj diff --git a/ansible_base/oauth2_provider/urls.py b/ansible_base/oauth2_provider/urls.py new file mode 100644 index 000000000..24e673610 --- /dev/null +++ b/ansible_base/oauth2_provider/urls.py @@ -0,0 +1,36 @@ +from django.urls import include, path, re_path +from oauth2_provider import views as oauth_views + +from ansible_base.lib.routers import AssociationResourceRouter +from ansible_base.oauth2_provider import views as oauth2_provider_views +from ansible_base.oauth2_provider.apps import Oauth2ProviderConfig + +app_name = Oauth2ProviderConfig.label + +router = AssociationResourceRouter() + +router.register( + r'applications', + oauth2_provider_views.OAuth2ApplicationViewSet, + basename='application', + related_views={ + 'tokens': (oauth2_provider_views.OAuth2TokenViewSet, 'access_tokens'), + }, +) + +router.register( + r'tokens', + oauth2_provider_views.OAuth2TokenViewSet, + basename='token', +) + +api_version_urls = [ + path('', include(router.urls)), +] + +root_urls = [ + re_path(r'^o/$', oauth2_provider_views.ApiOAuthAuthorizationRootView.as_view(), name='oauth_authorization_root_view'), + re_path(r"^o/authorize/$", oauth_views.AuthorizationView.as_view(), name="authorize"), + re_path(r"^o/token/$", oauth2_provider_views.TokenView.as_view(), name="token"), + re_path(r"^o/revoke_token/$", oauth_views.RevokeTokenView.as_view(), name="revoke-token"), +] diff --git a/ansible_base/oauth2_provider/utils.py b/ansible_base/oauth2_provider/utils.py new file mode 100644 index 000000000..51d3d341d --- /dev/null +++ b/ansible_base/oauth2_provider/utils.py @@ -0,0 +1,25 @@ +from typing import Optional + +from django.contrib.auth import get_user_model + +from ansible_base.authentication.models import Authenticator + +User = get_user_model() + + +def is_external_account(user: User) -> Optional[Authenticator]: + """ + Determines whether the user is associated with any external + login source. If they are, return the source. Otherwise, None. + + :param user: The user to test + :return: If the user is associated with any external login source, return it (the first, if multiple) + Otherwise, return None + """ + authenticator_users = user.authenticator_users.all() + local = 'ansible_base.authentication.authenticator_plugins.local' + for auth_user in authenticator_users: + if auth_user.provider.type != local: + return auth_user.provider + + return None diff --git a/ansible_base/oauth2_provider/views/__init__.py b/ansible_base/oauth2_provider/views/__init__.py new file mode 100644 index 000000000..5fce180e5 --- /dev/null +++ b/ansible_base/oauth2_provider/views/__init__.py @@ -0,0 +1,4 @@ +from .application import OAuth2ApplicationViewSet # noqa: F401 +from .authorization_root import ApiOAuthAuthorizationRootView # noqa: F401 +from .token import OAuth2TokenViewSet, TokenView # noqa: F401 +from .user_mixin import DABOAuth2UserViewsetMixin # noqa: F401 diff --git a/ansible_base/oauth2_provider/views/application.py b/ansible_base/oauth2_provider/views/application.py new file mode 100644 index 000000000..25c5919ef --- /dev/null +++ b/ansible_base/oauth2_provider/views/application.py @@ -0,0 +1,12 @@ +from rest_framework import permissions +from rest_framework.viewsets import ModelViewSet + +from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView +from ansible_base.oauth2_provider.models import OAuth2Application +from ansible_base.oauth2_provider.serializers import OAuth2ApplicationSerializer + + +class OAuth2ApplicationViewSet(AnsibleBaseDjangoAppApiView, ModelViewSet): + queryset = OAuth2Application.objects.all() + serializer_class = OAuth2ApplicationSerializer + permission_classes = [permissions.IsAuthenticated] diff --git a/ansible_base/oauth2_provider/views/authorization_root.py b/ansible_base/oauth2_provider/views/authorization_root.py new file mode 100644 index 000000000..125dc3ba1 --- /dev/null +++ b/ansible_base/oauth2_provider/views/authorization_root.py @@ -0,0 +1,22 @@ +from collections import OrderedDict + +from django.utils.translation import gettext_lazy as _ +from rest_framework import permissions +from rest_framework.response import Response +from rest_framework.reverse import _reverse + +from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView + + +class ApiOAuthAuthorizationRootView(AnsibleBaseDjangoAppApiView): + permission_classes = (permissions.AllowAny,) + name = _("API OAuth 2 Authorization Root") + versioning_class = None + swagger_topic = 'Authentication' + + def get(self, request, format=None): + data = OrderedDict() + data['authorize'] = _reverse('authorize') + data['revoke_token'] = _reverse('revoke-token') + data['token'] = _reverse('token') + return Response(data) diff --git a/ansible_base/oauth2_provider/views/token.py b/ansible_base/oauth2_provider/views/token.py new file mode 100644 index 000000000..021823797 --- /dev/null +++ b/ansible_base/oauth2_provider/views/token.py @@ -0,0 +1,57 @@ +from datetime import timedelta + +from django.utils.timezone import now +from oauth2_provider import views as oauth_views +from oauthlib import oauth2 +from rest_framework import permissions +from rest_framework.viewsets import ModelViewSet + +from ansible_base.lib.utils.settings import get_setting +from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken +from ansible_base.oauth2_provider.serializers import OAuth2TokenSerializer + + +class TokenView(oauth_views.TokenView, AnsibleBaseDjangoAppApiView): + # There is a big flow of logic that happens around this (create_token_response) behind the scenes. + # + # oauth2_provider.views.TokenView inherits from oauth2_provider.views.mixins.OAuthLibMixin + # That's where this method comes from originally. + # Then *that* method ends up calling oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response + # Then *that* method ends up (ultimately) calling oauthlib.oauth2.rfc6749.... + def create_token_response(self, request): + # Django OAuth2 Toolkit has a bug whereby refresh tokens are *never* + # properly expired (ugh): + # + # https://github.com/jazzband/django-oauth-toolkit/issues/746 + # + # This code detects and auto-expires them on refresh grant + # requests. + if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST: + refresh_token = OAuth2RefreshToken.objects.filter(token=request.POST['refresh_token']).first() + if refresh_token: + expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0) + if refresh_token.created + timedelta(seconds=expire_seconds) < now(): + return request.build_absolute_uri(), {}, 'The refresh token has expired.', '403' + + core = self.get_oauthlib_core() # oauth2_provider.views.mixins.OAuthLibMixin.create_token_response + + # oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response + # (we override this so we can implement our own error handling to be compatible with AWX) + uri, http_method, body, headers = core._extract_params(request) + extra_credentials = core._get_extra_credentials(request) + try: + headers, body, status = core.server.create_token_response(uri, http_method, body, headers, extra_credentials) + uri = headers.get("Location", None) + status = 201 if request.method == 'POST' and status == 200 else status + return uri, headers, body, status + except oauth2.AccessDeniedError as e: + return request.build_absolute_uri(), {}, str(e), 403 # Compat with AWX + except oauth2.OAuth2Error as e: + return request.build_absolute_uri(), {}, str(e), e.status_code + + +class OAuth2TokenViewSet(ModelViewSet, AnsibleBaseDjangoAppApiView): + queryset = OAuth2AccessToken.objects.all() + serializer_class = OAuth2TokenSerializer + permission_classes = [permissions.IsAuthenticated] diff --git a/ansible_base/oauth2_provider/views/user_mixin.py b/ansible_base/oauth2_provider/views/user_mixin.py new file mode 100644 index 000000000..7233d6adc --- /dev/null +++ b/ansible_base/oauth2_provider/views/user_mixin.py @@ -0,0 +1,39 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework.decorators import action +from rest_framework.response import Response + +from ansible_base.lib.abstract_models.common import get_cls_view_basename +from ansible_base.oauth2_provider.models import OAuth2AccessToken +from ansible_base.oauth2_provider.serializers import OAuth2TokenSerializer + + +class DABOAuth2UserViewsetMixin: + """ + This mixin provides several actions to expose as sub-urls for a given user PK. + """ + + def extra_related_fields(self, obj) -> dict[str, str]: + fields = super().extra_related_fields(obj) + user_basename = get_cls_view_basename(get_user_model()) + fields['personal_tokens'] = reverse(f'{user_basename}-personal-tokens-list', kwargs={"pk": obj.pk}) + fields['authorized_tokens'] = reverse(f'{user_basename}-authorized-tokens-list', kwargs={"pk": obj.pk}) + return fields + + def _user_token_response(self, request, application_isnull, pk): + tokens = OAuth2AccessToken.objects.filter(application__isnull=application_isnull, user=pk) + page = self.paginate_queryset(tokens) + if page is not None: + serializer = OAuth2TokenSerializer(page, many=True, context={"request": request}) + return self.get_paginated_response(serializer.data) + + serializer = OAuth2TokenSerializer(tokens, many=True) + return Response(serializer.data) + + @action(detail=True, methods=["get"], url_name="personal-tokens-list") + def personal_tokens(self, request, pk=None): + return self._user_token_response(request, True, pk) + + @action(detail=True, methods=["get"], url_name="authorized-tokens-list") + def authorized_tokens(self, request, pk=None): + return self._user_token_response(request, False, pk) diff --git a/docs/apps/oauth2_provider.md b/docs/apps/oauth2_provider.md new file mode 100644 index 000000000..3875c7ec2 --- /dev/null +++ b/docs/apps/oauth2_provider.md @@ -0,0 +1,6 @@ +# Differences from AWX + +* Because of how DAB's router works, we don't allow for POSTing to (for example) + `/applications/PK/tokens/` to create a token which belongs to an application. + The workaround is to just use /tokens/ and in the body specify + `{"application": PK}`. diff --git a/pyproject.toml b/pyproject.toml index 0a5d0f8bf..ebec43e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ optional-dependencies.all = { file = [ "requirements/requirements_jwt_consumer.in", "requirements/requirements_testing.in", "requirements/requirements_redis_client.in", + "requirements/requirements_oauth2_provider.in", ] } optional-dependencies.activitystream = { file = [ "requirements/requirements_activitystream.in" ] } optional-dependencies.authentication = { file = [ "requirements/requirements_authentication.in" ] } @@ -53,6 +54,7 @@ optional-dependencies.channel_auth = { file = [ "requirements/requirements_chann optional-dependencies.jwt_consumer = { file = [ "requirements/requirements_jwt_consumer.in" ] } optional-dependencies.testing = { file = [ "requirements/requirements_testing.in" ] } optional-dependencies.redis_client = { file = [ "requirements/requirements_redis_client.in" ] } +optional-dependencies.oauth2_provider = { file = [ "requirements/requirements_oauth2_provider.in" ] } [build-system] requires = ["setuptools>=64", "setuptools_scm>=8"] @@ -72,7 +74,8 @@ force-exclude = ''' [tool.isort] profile = "black" line_length = 160 -extend_skip = [ "ansible_base/authentication/migrations", "ansible_base/activitystream/migrations", "test_app/migrations" ] +extend_skip = [ "test_app/migrations" ] +skip_glob = [ "ansible_base/*/migrations" ] [tool.flake8] diff --git a/requirements/requirements_all.txt b/requirements/requirements_all.txt index 819d4e689..27bcbacf5 100644 --- a/requirements/requirements_all.txt +++ b/requirements/requirements_all.txt @@ -18,6 +18,7 @@ cryptography==42.0.5 # via # -r requirements/requirements.in # -r requirements/requirements_testing.in + # jwcrypto # social-auth-core defusedxml==0.8.0rc2 # via @@ -29,6 +30,7 @@ django==4.2.11 # channels # django-auth-ldap # django-crum + # django-oauth-toolkit # django-redis # djangorestframework # drf-spectacular @@ -37,6 +39,8 @@ django-auth-ldap==4.7.0 # via -r requirements/requirements_authentication.in django-crum==0.7.9 # via -r requirements/requirements.in +django-oauth-toolkit==2.3.0 + # via -r requirements/requirements_oauth2_provider.in django-redis==5.4.0 # via -r requirements/requirements_redis_client.in django-split-settings==1.3.0 @@ -61,6 +65,8 @@ jsonschema==4.21.1 # via drf-spectacular jsonschema-specifications==2023.12.1 # via jsonschema +jwcrypto==1.5.6 + # via django-oauth-toolkit lxml==5.1.0 # via # python3-saml @@ -69,6 +75,7 @@ netaddr==1.2.1 # via pyrad oauthlib==3.2.2 # via + # django-oauth-toolkit # requests-oauthlib # social-auth-core packaging==24.0 @@ -116,6 +123,7 @@ referencing==0.34.0 requests==2.31.0 # via # -r requirements/requirements_jwt_consumer.in + # django-oauth-toolkit # requests-oauthlib # social-auth-core requests-oauthlib==2.0.0 @@ -139,6 +147,8 @@ tabulate==0.9.0 # via -r requirements/requirements_authentication.in tacacs-plus==2.6 # via -r requirements/requirements_authentication.in +typing-extensions==4.11.0 + # via jwcrypto uritemplate==4.1.1 # via drf-spectacular urllib3==2.2.1 diff --git a/requirements/requirements_oauth2_provider.in b/requirements/requirements_oauth2_provider.in new file mode 100644 index 000000000..c6686727d --- /dev/null +++ b/requirements/requirements_oauth2_provider.in @@ -0,0 +1 @@ +django-oauth-toolkit \ No newline at end of file diff --git a/test_app/management/commands/create_demo_data.py b/test_app/management/commands/create_demo_data.py index cc771b5c2..f9289bbdf 100644 --- a/test_app/management/commands/create_demo_data.py +++ b/test_app/management/commands/create_demo_data.py @@ -7,6 +7,7 @@ from django.core.management.base import BaseCommand from ansible_base.authentication.models import Authenticator, AuthenticatorUser +from ansible_base.oauth2_provider.models import OAuth2Application from ansible_base.rbac.models import RoleDefinition from ansible_base.rbac.validators import combine_values, permissions_allowed_for_role from test_app.models import EncryptionModel, InstanceGroup, Inventory, Organization, Team, User @@ -122,6 +123,14 @@ def handle(self, *args, **kwargs): team_member.give_permission(spud, awx_devs) + OAuth2Application.objects.get_or_create( + name="Demo OAuth2 Application", + description="Demo OAuth2 Application", + redirect_uris="http://example.com/callback", + authorization_grant_type="authorization-code", + client_type="confidential", + ) + self.stdout.write('Finished creating demo data!') self.stdout.write(f'Admin user password: {admin_password}') diff --git a/test_app/router.py b/test_app/router.py index b3d54cced..160a4451a 100644 --- a/test_app/router.py +++ b/test_app/router.py @@ -1,4 +1,5 @@ from ansible_base.lib.routers import AssociationResourceRouter +from ansible_base.oauth2_provider import views as oauth2_provider_views from ansible_base.rbac.api import views as rbac_views from test_app import views @@ -61,6 +62,8 @@ def filter_queryset(self, qs): related_views={ 'organizations': (views.OrganizationViewSet, 'organizations'), 'teams': (views.TeamViewSet, 'teams'), + 'tokens': (oauth2_provider_views.OAuth2TokenViewSet, 'access_tokens'), + 'applications': (oauth2_provider_views.OAuth2ApplicationViewSet, 'applications'), }, basename='user', ) diff --git a/test_app/settings.py b/test_app/settings.py index a38b58baf..64a35884f 100644 --- a/test_app/settings.py +++ b/test_app/settings.py @@ -35,6 +35,11 @@ 'handlers': ['console'], 'level': 'DEBUG', }, + '': { + 'handlers': ['console'], + 'level': 'DEBUG', + 'propagate': True, + }, }, } for logger in LOGGING["loggers"]: # noqa: F405 @@ -57,6 +62,7 @@ 'ansible_base.resource_registry', 'ansible_base.rest_pagination', 'ansible_base.rbac', + 'ansible_base.oauth2_provider', 'test_app', 'django_extensions', 'debug_toolbar', @@ -150,3 +156,5 @@ ANSIBLE_BASE_ALLOW_SINGLETON_TEAM_ROLES = True ANSIBLE_BASE_USER_VIEWSET = 'test_app.views.UserViewSet' + +LOGIN_URL = "/login/login" diff --git a/test_app/tests/lib/routers/test_association_resoure_router.py b/test_app/tests/lib/routers/test_association_resoure_router.py index ea5c24060..9646b696a 100644 --- a/test_app/tests/lib/routers/test_association_resoure_router.py +++ b/test_app/tests/lib/routers/test_association_resoure_router.py @@ -3,7 +3,7 @@ from ansible_base.lib.routers import AssociationResourceRouter from test_app import views -from test_app.models import Inventory, User +from test_app.models import Inventory, Organization, User def validate_expected_url_pattern_names(router, expected_url_pattern_names): @@ -19,20 +19,20 @@ def validate_expected_url_pattern_names(router, expected_url_pattern_names): def test_association_router_basic_viewset(): router = AssociationResourceRouter() router.register( - r'user', - views.UserViewSet, - basename='user', + r'organizations', + views.OrganizationViewSet, + basename='organization', ) - validate_expected_url_pattern_names(router, ['user-list', 'user-detail']) + validate_expected_url_pattern_names(router, ['organization-list', 'organization-detail']) def test_association_router_basic_viewset_no_basename(): - class UserViewSetWithQueryset(views.UserViewSet): - queryset = User.objects.all() + class OrganizationViewSetWithQueryset(views.OrganizationViewSet): + queryset = Organization.objects.all() router = AssociationResourceRouter() - router.register(r'user', UserViewSetWithQueryset) - validate_expected_url_pattern_names(router, ['user-list', 'user-detail']) + router.register(r'organizations', OrganizationViewSetWithQueryset) + validate_expected_url_pattern_names(router, ['organization-list', 'organization-detail']) def test_association_router_associate_viewset_all_mapings(): @@ -106,9 +106,6 @@ def test_association_router_associate_existing_item(db, admin_api_client, random related_model = RelatedFieldsTestModel.objects.create() related_model.users.add(random_user) assert related_model.users.count() == 1 - - from test_app.models import User - assert User.objects.get(pk=random_user.pk) is not None url = reverse('related_fields_test_model-users-associate', kwargs={'pk': related_model.pk}) diff --git a/test_app/tests/oauth2_provider/__init__.py b/test_app/tests/oauth2_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test_app/tests/oauth2_provider/conftest.py b/test_app/tests/oauth2_provider/conftest.py new file mode 100644 index 000000000..076248cb3 --- /dev/null +++ b/test_app/tests/oauth2_provider/conftest.py @@ -0,0 +1,66 @@ +from datetime import datetime, timezone + +import pytest +from django.urls import reverse +from oauthlib.common import generate_token + +from ansible_base.lib.testing.fixtures import copy_fixture +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2Application + + +@pytest.fixture +def oauth2_application(randname): + """ + Creates an OAuth2 application with a random name and returns + both the application and its client secret. + """ + app = OAuth2Application( + name=randname("OAuth2 Application"), + description="Test OAuth2 Application", + redirect_uris="http://example.com/callback", + authorization_grant_type="authorization-code", + client_type="confidential", + ) + # Store this before it gets hashed + secret = app.client_secret + app.save() + return (app, secret) + + +@pytest.fixture +def oauth2_application_password(randname): + """ + Creates an OAuth2 application with a random name and returns + both the application and its client secret. + """ + app = OAuth2Application( + name=randname("OAuth2 Application"), + description="Test OAuth2 Application", + redirect_uris="http://example.com/callback", + authorization_grant_type="password", + client_type="confidential", + ) + # Store this before it gets hashed + secret = app.client_secret + app.save() + return (app, secret) + + +@pytest.fixture +def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user): + url = reverse('token-list') + response = admin_api_client.post(url, {'application': oauth2_application[0].pk}) + assert response.status_code == 201 + return OAuth2AccessToken.objects.get(token=response.data['token']) + + +@copy_fixture(copies=3) +@pytest.fixture +def oauth2_user_pat(user, randname): + return OAuth2AccessToken.objects.get_or_create( + user=user, + description=randname("Personal Access Token for 'user'"), + # This has to be timezone aware + expires=datetime(2088, 1, 1, tzinfo=timezone.utc), + token=generate_token(), + )[0] diff --git a/test_app/tests/oauth2_provider/test_authentication.py b/test_app/tests/oauth2_provider/test_authentication.py new file mode 100644 index 000000000..a99abee90 --- /dev/null +++ b/test_app/tests/oauth2_provider/test_authentication.py @@ -0,0 +1,118 @@ +import pytest +from django.urls import reverse +from oauthlib.common import generate_token + + +def test_oauth2_bearer_get_user_correct(unauthenticated_api_client, oauth2_admin_access_token): + """ + Perform a GET with a bearer token and ensure the authed user is correct. + """ + url = reverse("user-me") + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {oauth2_admin_access_token.token}'}, + ) + assert response.status_code == 200 + assert response.data['username'] == oauth2_admin_access_token.user.username + + +@pytest.mark.parametrize( + 'token, expected', + [ + ('fixture', 200), + ('bad', 401), + ], +) +def test_oauth2_bearer_get(unauthenticated_api_client, oauth2_admin_access_token, animal, token, expected): + """ + GET an animal with a bearer token. + """ + url = reverse("animal-detail", kwargs={"pk": animal.pk}) + token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {token}'}, + ) + assert response.status_code == expected + if expected != 401: + assert response.data['name'] == animal.name + + +@pytest.mark.parametrize( + 'token, expected', + [ + ('fixture', 201), + ('bad', 401), + ], +) +def test_oauth2_bearer_post(unauthenticated_api_client, oauth2_admin_access_token, admin_user, token, expected): + """ + POST an animal with a bearer token. + """ + url = reverse("animal-list") + token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + data = { + "name": "Fido", + "owner": admin_user.pk, + } + response = unauthenticated_api_client.post( + url, + data=data, + headers={'Authorization': f'Bearer {token}'}, + ) + assert response.status_code == expected + if expected != 401: + assert response.data['name'] == 'Fido' + + +@pytest.mark.parametrize( + 'token, expected', + [ + ('fixture', 200), + ('bad', 401), + ], +) +def test_oauth2_bearer_patch(unauthenticated_api_client, oauth2_admin_access_token, animal, admin_user, token, expected): + """ + PATCH an animal with a bearer token. + """ + url = reverse("animal-detail", kwargs={"pk": animal.pk}) + token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + data = { + "name": "Fido", + } + response = unauthenticated_api_client.patch( + url, + data=data, + headers={'Authorization': f'Bearer {token}'}, + ) + assert response.status_code == expected + if expected != 401: + assert response.data['name'] == 'Fido' + + +@pytest.mark.parametrize( + 'token, expected', + [ + ('fixture', 200), + ('bad', 401), + ], +) +def test_oauth2_bearer_put(unauthenticated_api_client, oauth2_admin_access_token, animal, admin_user, token, expected): + """ + PUT an animal with a bearer token. + """ + url = reverse("animal-detail", kwargs={"pk": animal.pk}) + token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + data = { + "name": "Fido", + "owner": admin_user.pk, + } + response = unauthenticated_api_client.put( + url, + data=data, + headers={'Authorization': f'Bearer {token}'}, + ) + assert response.status_code == expected + if expected != 401: + assert response.data['name'] == 'Fido' diff --git a/test_app/tests/oauth2_provider/test_models.py b/test_app/tests/oauth2_provider/test_models.py new file mode 100644 index 000000000..aaccfc2e0 --- /dev/null +++ b/test_app/tests/oauth2_provider/test_models.py @@ -0,0 +1,34 @@ +import pytest + +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken + + +@pytest.mark.django_db +def test_oauth2_revoke_access_then_refresh_token(oauth2_admin_access_token): + token = oauth2_admin_access_token + refresh_token = oauth2_admin_access_token.refresh_token + assert OAuth2AccessToken.objects.count() == 1 + assert OAuth2RefreshToken.objects.count() == 1 + + token.revoke() + assert OAuth2AccessToken.objects.count() == 0 + assert OAuth2RefreshToken.objects.count() == 1 + assert not refresh_token.revoked + + refresh_token.revoke() + assert OAuth2AccessToken.objects.count() == 0 + assert OAuth2RefreshToken.objects.count() == 1 + + +@pytest.mark.django_db +def test_oauth2_revoke_refresh_token(oauth2_admin_access_token): + refresh_token = oauth2_admin_access_token.refresh_token + assert OAuth2AccessToken.objects.count() == 1 + assert OAuth2RefreshToken.objects.count() == 1 + + refresh_token.revoke() + assert OAuth2AccessToken.objects.count() == 0 + # the same OAuth2RefreshToken is recycled + new_refresh_token = OAuth2RefreshToken.objects.all().first() + assert refresh_token == new_refresh_token + assert new_refresh_token.revoked diff --git a/test_app/tests/oauth2_provider/test_utils.py b/test_app/tests/oauth2_provider/test_utils.py new file mode 100644 index 000000000..fbe40a5a7 --- /dev/null +++ b/test_app/tests/oauth2_provider/test_utils.py @@ -0,0 +1,29 @@ +import pytest + +from ansible_base.authentication.models import Authenticator, AuthenticatorUser +from ansible_base.oauth2_provider.utils import is_external_account + + +@pytest.mark.parametrize("link_local, link_ldap, expected", [(False, False, None), (True, False, None), (False, True, "ldap"), (True, True, "ldap")]) +def test_oauth2_provider_is_external_account_with_user(user, local_authenticator, ldap_authenticator, link_local, link_ldap, expected): + if link_local: + # Link the user to the local authenticator + local_au = AuthenticatorUser(provider=local_authenticator, user=user) + local_au.save() + if link_ldap: + # Link the user to the ldap authenticator + ldap_au = AuthenticatorUser(provider=ldap_authenticator, user=user) + ldap_au.save() + + if expected == "ldap": + expected = ldap_authenticator + assert is_external_account(user) == expected + + +def test_oauth2_provider_is_external_account_import_error(user, local_authenticator): + au = AuthenticatorUser(provider=local_authenticator, user=user) + au.save() + local_authenticator.type = "test_app.tests.fixtures.authenticator_plugins.broken" + # Avoid save() which would raise an ImportError + Authenticator.objects.bulk_update([local_authenticator], ['type']) + assert is_external_account(user) diff --git a/test_app/tests/oauth2_provider/views/__init__.py b/test_app/tests/oauth2_provider/views/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test_app/tests/oauth2_provider/views/test_application.py b/test_app/tests/oauth2_provider/views/test_application.py new file mode 100644 index 000000000..d132139f8 --- /dev/null +++ b/test_app/tests/oauth2_provider/views/test_application.py @@ -0,0 +1,265 @@ +import pytest +from django.contrib.auth.hashers import check_password +from django.urls import reverse + +from ansible_base.lib.utils.encryption import ENCRYPTED_STRING +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2Application, OAuth2RefreshToken + + +@pytest.mark.parametrize( + "client_fixture,expected_status", + [ + ("admin_api_client", 200), + ("user_api_client", 200), + ("unauthenticated_api_client", 401), + ], +) +@pytest.mark.django_db +def test_oauth2_provider_application_list(request, client_fixture, expected_status, oauth2_application): + """ + Test that we can view the list of OAuth2 applications iff we are authenticated. + """ + client = request.getfixturevalue(client_fixture) + url = reverse("application-list") + response = client.get(url) + assert response.status_code == expected_status + if expected_status == 200: + assert len(response.data['results']) == OAuth2Application.objects.count() + assert response.data['results'][0]['name'] == oauth2_application[0].name + + +@pytest.mark.parametrize( + "view, path", + [ + ("application-list", lambda data: data['results'][0]), + ("application-detail", lambda data: data), + ], +) +def test_oauth2_provider_application_related(admin_api_client, oauth2_application, organization, view, path): + """ + Test that the related fields are correct. + + Organization should only be shown if the application is associated with an organization. + Associating an application with an organization should not affect other related fields. + """ + oauth2_application = oauth2_application[0] + if view == "application-list": + url = reverse(view) + else: + url = reverse(view, args=[oauth2_application.pk]) + + oauth2_application.organization = None + oauth2_application.save() + response = admin_api_client.get(url) + assert response.status_code == 200 + assert path(response.data)['related']['access_tokens'] == reverse("application-access_tokens-list", args=[oauth2_application.pk]) + assert 'organization' not in path(response.data)['related'] + + oauth2_application.organization = organization + oauth2_application.save() + response = admin_api_client.get(url) + assert response.status_code == 200 + assert path(response.data)['related']['access_tokens'] == reverse("application-access_tokens-list", args=[oauth2_application.pk]) + assert path(response.data)['related']['organization'] == reverse("organization-detail", args=[organization.pk]) + + +@pytest.mark.parametrize( + "client_fixture,expected_status", + [ + ("admin_api_client", 200), + ("user_api_client", 200), + ("unauthenticated_api_client", 401), + ], +) +@pytest.mark.django_db +def test_oauth2_provider_application_detail(request, client_fixture, expected_status, oauth2_application): + """ + Test that we can view the detail of an OAuth2 application iff we are authenticated. + """ + oauth2_application = oauth2_application[0] + client = request.getfixturevalue(client_fixture) + url = reverse("application-detail", args=[oauth2_application.pk]) + response = client.get(url) + assert response.status_code == expected_status + if expected_status == 200: + assert response.data['name'] == oauth2_application.name + + +@pytest.mark.parametrize( + "client_fixture,expected_status", + [ + ("admin_api_client", 201), + ("user_api_client", 201), + ("unauthenticated_api_client", 401), + ], +) +def test_oauth2_provider_application_create(request, client_fixture, expected_status, randname, organization): + """ + As an admin, I should be able to create an OAuth2 application. + """ + client = request.getfixturevalue(client_fixture) + url = reverse("application-list") + name = randname("Test Application") + response = client.post( + url, + data={ + 'name': name, + 'description': 'Test Description', + 'organization': organization.pk, + 'redirect_uris': 'http://example.com/callback', + 'authorization_grant_type': 'authorization-code', + 'client_type': 'confidential', + }, + ) + assert response.status_code == expected_status, response.data + if expected_status == 201: + assert response.data['name'] == name + assert OAuth2Application.objects.get(pk=response.data['id']).organization == organization + + created_app = OAuth2Application.objects.get(client_id=response.data['client_id']) + assert created_app.name == name + assert not created_app.skip_authorization + assert created_app.redirect_uris == 'http://example.com/callback' + assert created_app.client_type == 'confidential' + assert created_app.authorization_grant_type == 'authorization-code' + assert created_app.organization == organization + + +def test_oauth2_provider_application_validator(admin_api_client): + """ + If we don't get enough information in the request, we should 400 + """ + url = reverse("application-list") + response = admin_api_client.post( + url, + data={ + 'name': 'test app', + 'authorization_grant_type': 'authorization-code', + 'client_type': 'confidential', + }, + ) + assert response.status_code == 400 + + +@pytest.mark.parametrize( + "client_fixture,expected_status", + [ + ("admin_api_client", 200), + ("user_api_client", 200), + ("unauthenticated_api_client", 401), + ], +) +@pytest.mark.django_db +def test_oauth2_provider_application_update(request, client_fixture, expected_status, oauth2_application): + """ + Test that we can update oauth2 applications iff we are authenticated. + """ + oauth2_application = oauth2_application[0] + client = request.getfixturevalue(client_fixture) + url = reverse("application-detail", args=[oauth2_application.pk]) + response = client.patch( + url, + data={ + 'name': 'Updated Name', + 'description': 'Updated Description', + 'redirect_uris': 'http://example.com/updated', + 'client_type': 'public', + }, + ) + assert response.status_code == expected_status, response.data + if expected_status == 200: + assert response.data['name'] == 'Updated Name' + assert response.data['description'] == 'Updated Description' + assert response.data['redirect_uris'] == 'http://example.com/updated' + assert response.data['client_type'] == 'public' + oauth2_application.refresh_from_db() + assert oauth2_application.name == 'Updated Name' + assert oauth2_application.description == 'Updated Description' + assert oauth2_application.redirect_uris == 'http://example.com/updated' + assert oauth2_application.client_type == 'public' + + +def test_oauth2_provider_application_client_secret_encrypted(admin_api_client, organization): + """ + The client_secret should be encrypted in the database. + We only show it to the user once, on creation. All other requests should show the encrypted value. + """ + url = reverse("application-list") + + # POST + response = admin_api_client.post( + url, + data={ + 'name': 'Test Application', + 'description': 'Test Description', + 'organization': organization.pk, + 'redirect_uris': 'http://example.com/callback', + 'authorization_grant_type': 'authorization-code', + 'client_type': 'confidential', + }, + ) + assert response.status_code == 201, response.data + application = OAuth2Application.objects.get(pk=response.data['id']) + + # If we ever switch to using *our* encryption, this is a good test. + # But until a release with jazzband/django-oauth-toolkit#1311 hits pypi, + # we have no way to disable their built-in hashing (which conflicts with our + # own encryption). + # with connection.cursor() as cursor: + # cursor.execute("SELECT client_secret FROM dab_oauth2_provider_oauth2application WHERE id = %s", [application.pk]) + # encrypted = cursor.fetchone()[0] + # assert encrypted.startswith(ENCRYPTED_STRING), encrypted + # assert ansible_encryption.decrypt_string(encrypted) == response.data['client_secret'], response.data + # assert response.data['client_secret'] == application.client_secret + + # For now we just make sure it shows the real client secret on POST + # and never on any other method. + assert 'client_secret' in response.data + assert check_password(response.data['client_secret'], application.client_secret) + + # GET + response = admin_api_client.get(reverse("application-detail", args=[application.pk])) + assert response.status_code == 200 + assert response.data['client_secret'] == ENCRYPTED_STRING, response.data + + # PATCH + response = admin_api_client.patch( + reverse("application-detail", args=[application.pk]), + data={'name': 'Updated Name'}, + ) + assert response.status_code == 200 + assert response.data['client_secret'] == ENCRYPTED_STRING, response.data + + # PUT + response = admin_api_client.put( + reverse("application-detail", args=[application.pk]), + data={ + 'name': 'Updated Name', + 'description': 'Updated Description', + 'organization': organization.pk, + 'redirect_uris': 'http://example.com/updated', + 'client_type': 'public', + 'authorization_grant_type': 'password', + }, + ) + assert response.status_code == 200 + assert 'client_secret' not in response.data + + # DELETE + response = admin_api_client.delete(reverse("application-detail", args=[application.pk])) + assert response.status_code == 204 + assert response.data is None, response.data + + +@pytest.mark.django_db +def test_oauth2_application_delete(oauth2_application, admin_api_client): + """ + Test that we can delete an OAuth2 application. + """ + oauth2_application = oauth2_application[0] + url = reverse("application-detail", args=[oauth2_application.pk]) + response = admin_api_client.delete(url) + assert response.status_code == 204 + assert OAuth2Application.objects.filter(client_id=oauth2_application.client_id).count() == 0 + assert OAuth2RefreshToken.objects.filter(application=oauth2_application).count() == 0 + assert OAuth2AccessToken.objects.filter(application=oauth2_application).count() == 0 diff --git a/test_app/tests/oauth2_provider/views/test_authorization_root.py b/test_app/tests/oauth2_provider/views/test_authorization_root.py new file mode 100644 index 000000000..eeb76baaa --- /dev/null +++ b/test_app/tests/oauth2_provider/views/test_authorization_root.py @@ -0,0 +1,12 @@ +from django.urls import reverse + + +def test_oauth2_provider_authorization_root_view(admin_api_client, unauthenticated_api_client, user_api_client): + """ + As an admin, accessing /o/ gives an index of oauth endpoints. + """ + url = reverse("oauth_authorization_root_view") + for client in (admin_api_client, unauthenticated_api_client, user_api_client): + response = admin_api_client.get(url) + assert response.status_code == 200 + assert 'authorize' in response.data diff --git a/test_app/tests/oauth2_provider/views/test_authorize.py b/test_app/tests/oauth2_provider/views/test_authorize.py new file mode 100644 index 000000000..135f3065e --- /dev/null +++ b/test_app/tests/oauth2_provider/views/test_authorize.py @@ -0,0 +1,56 @@ +from django.urls import reverse +from django.utils.http import urlencode + + +def test_oauth2_provider_authorize_view_as_admin(admin_api_client): + """ + As an admin, accessing /o/authorize/ without client_id parameter should return a 400 error. + """ + url = reverse("authorize") + response = admin_api_client.get(url) + + assert response.status_code == 400 + assert 'Missing client_id parameter.' in str(response.content) + + +def test_oauth2_provider_authorize_view_anon(client, settings): + """ + As an anonymous user, accessing /o/authorize/ should redirect to the login page. + """ + url = reverse("authorize") + response = client.get(url) + + assert response.status_code == 302 + assert response.url.startswith(settings.LOGIN_URL) + + +def test_oauth2_provider_authorize_view_flow(user_api_client, oauth2_application): + """ + As a user, I should be able to complete the authorization flow and get an authorization code. + """ + oauth2_application = oauth2_application[0] + url = reverse("authorize") + query_params = { + 'client_id': oauth2_application.client_id, + 'response_type': 'code', + 'scope': 'read', + # PKCE + 'code_challenge': '4-as-randomly-generated-by-rolling-a-die', + 'code_challenge_method': 'S256', + } + + # Initial request - authorization request, should show a form to authorize the application + response = user_api_client.get(url + '?' + urlencode(query_params)) + assert response.status_code == 200, response.headers + assert f'Authorize {oauth2_application.name}' in str(response.content) + + # But the form mostly just repackages the GET params into a POST request + query_params['redirect_uri'] = oauth2_application.redirect_uris + query_params['allow'] = 'Authorize' + response = user_api_client.post(url, data=query_params) + assert response.status_code == 302 + assert response.url.startswith(query_params['redirect_uri']) + + # On success, it takes us to the redirect_uri with the code + assert 'code=' in response.url, response.url + assert 'error=' not in response.url, response.url diff --git a/test_app/tests/oauth2_provider/views/test_token.py b/test_app/tests/oauth2_provider/views/test_token.py new file mode 100644 index 000000000..de14292a6 --- /dev/null +++ b/test_app/tests/oauth2_provider/views/test_token.py @@ -0,0 +1,408 @@ +import base64 +import json +import time + +import pytest +from django.urls import reverse +from django.utils.http import urlencode + +from ansible_base.authentication.models import AuthenticatorUser +from ansible_base.lib.utils.encryption import ENCRYPTED_STRING +from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken + + +@pytest.mark.django_db +@pytest.mark.parametrize('allow_oauth, status', [(True, 201), (False, 403)]) +def test_oauth2_token_creation_disabled_for_external_accounts( + oauth2_application_password, + user, + ldap_authenticator, + local_authenticator, + settings, + unauthenticated_api_client, + allow_oauth, + status, +): + """ + If ALLOW_OAUTH2_FOR_EXTERNAL_USERS is enabled, users associated with an external authentication provider + can create OAuth2 tokens. Otherwise, they cannot. + """ + AuthenticatorUser.objects.get_or_create(uid=user.username, user=user, provider=ldap_authenticator) + AuthenticatorUser.objects.get_or_create(uid=user.username, user=user, provider=local_authenticator) + app = oauth2_application_password[0] + secret = oauth2_application_password[1] + url = reverse('token') + settings.ALLOW_OAUTH2_FOR_EXTERNAL_USERS = allow_oauth + data = { + 'grant_type': 'password', + 'username': 'user', + 'password': 'password', + 'scope': 'read', + } + resp = unauthenticated_api_client.post( + url, + data=urlencode(data), + content_type='application/x-www-form-urlencoded', + headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, + ) + + assert resp.status_code == status + if allow_oauth: + assert OAuth2AccessToken.objects.count() == 1 + else: + assert 'OAuth2 Tokens cannot be created by users associated with an external authentication provider' in resp.content.decode() + assert OAuth2AccessToken.objects.count() == 0 + + +@pytest.mark.django_db +def test_oauth2_existing_token_enabled_for_external_accounts( + oauth2_application_password, user, unauthenticated_api_client, settings, ldap_authenticator, local_authenticator +): + """ + If a token already exists but then ALLOW_OAUTH2_FOR_EXTERNAL_USERS becomes False + the token should still be usable. + """ + AuthenticatorUser.objects.get_or_create(uid=user.username, user=user, provider=ldap_authenticator) + AuthenticatorUser.objects.get_or_create(uid=user.username, user=user, provider=local_authenticator) + app = oauth2_application_password[0] + secret = oauth2_application_password[1] + url = reverse('token') + settings.ALLOW_OAUTH2_FOR_EXTERNAL_USERS = True + data = { + 'grant_type': 'password', + 'username': 'user', + 'password': 'password', + 'scope': 'read', + } + resp = unauthenticated_api_client.post( + url, + data=urlencode(data), + content_type='application/x-www-form-urlencoded', + headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, + ) + assert resp.status_code == 201 + token = resp.json()['access_token'] + assert OAuth2AccessToken.objects.count() == 1 + + for val in (True, False): + settings.ALLOW_OAUTH2_FOR_EXTERNAL_USERS = val + url = reverse('user-me') + resp = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {token}'}, + ) + assert resp.json()['username'] == user.username + + +@pytest.mark.django_db +@pytest.mark.parametrize( + 'client_fixture, user_fixture', + [ + pytest.param('user_api_client', 'user', id='user'), + pytest.param('admin_api_client', 'admin_user', id='admin'), + ], +) +def test_oauth2_pat_create_and_list(request, client_fixture, user_fixture): + """ + A user can create and list personal access tokens. + """ + client = request.getfixturevalue(client_fixture) + user = request.getfixturevalue(user_fixture) + url = reverse('token-list') + response = client.post(url, data={'scope': 'read'}) + assert response.status_code == 201 + assert response.data['scope'] == 'read' + assert response.data['user'] == user.pk + + get_response = client.get(url) + assert get_response.status_code == 200 + assert len(get_response.data['results']) == 1 + + +@pytest.mark.django_db +def test_oauth2_pat_creation(oauth2_application_password, user, unauthenticated_api_client): + app = oauth2_application_password[0] + secret = oauth2_application_password[1] + url = reverse('token') + data = { + "grant_type": "password", + "username": "user", + "password": "password", + "scope": "read", + } + resp = unauthenticated_api_client.post( + url, + data=urlencode(data), + content_type='application/x-www-form-urlencoded', + headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, + ) + + assert resp.status_code == 201, resp.content + resp_json = resp.json() + assert 'access_token' in resp_json + assert len(resp_json['access_token']) > 0 + assert 'scope' in resp_json + assert resp_json['scope'] == 'read' + assert 'refresh_token' in resp_json + + +@pytest.mark.django_db +def test_oauth2_pat_creation_no_default_scope(oauth2_application, admin_api_client): + """ + Tests that the default scope is overriden + """ + url = reverse('token-list') + response = admin_api_client.post( + url, + { + 'description': 'test token', + 'scope': 'read', + 'application': oauth2_application[0].pk, + }, + ) + assert response.data['scope'] == 'read' + + +@pytest.mark.django_db +def test_oauth2_pat_creation_no_scope(oauth2_application, admin_api_client): + """ + Tests that the default scope is as expected + """ + url = reverse('token-list') + response = admin_api_client.post( + url, + { + 'description': 'test token', + 'application': oauth2_application[0].pk, + }, + ) + assert response.data['scope'] == 'write' + + +def test_oauth2_pat_list_for_user(oauth2_user_pat, oauth2_user_pat_1, user, admin_api_client): + """ + Tests that we can list a user's PATs via API. + """ + url = reverse('user-personal-tokens-list', kwargs={"pk": user.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert len(response.data['results']) == 2 + + +def test_oauth2_pat_list_for_invalid_user(oauth2_user_pat, oauth2_user_pat_1, user, admin_api_client): + """ + Ensure we don't fatal if we give a bad user PK. + + We return an empty list. + """ + url = reverse('user-personal-tokens-list', kwargs={"pk": 1000}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['results'] == [] + + +def test_oauth2_pat_list_is_user_related_field(user, admin_api_client): + """ + Ensure 'personal_tokens' shows up in the user's related fields. + """ + url = reverse('user-detail', kwargs={"pk": user.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert 'personal_tokens' in response.data['related'] + assert response.data['related']['personal_tokens'] == reverse('user-personal-tokens-list', kwargs={"pk": user.pk}) + + +def test_oauth2_application_token_summary_fields(admin_api_client, oauth2_admin_access_token, oauth2_application): + url = reverse('application-detail', kwargs={'pk': oauth2_application[0].pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['summary_fields']['tokens']['count'] == 1 + assert response.data['summary_fields']['tokens']['results'][0] == {'id': oauth2_admin_access_token.pk, 'scope': 'write', 'token': ENCRYPTED_STRING} + + +@pytest.mark.django_db +def test_oauth2_authorized_list_for_user(oauth2_application, oauth2_user_pat, oauth2_user_pat_1, user, admin_api_client): + """ + Tests that we can list a user's authorized tokens via API. + """ + # Turn the PATs into authorized tokens by attaching an application + oauth2_application = oauth2_application[0] + oauth2_user_pat.application = oauth2_application + oauth2_user_pat.save() + oauth2_user_pat_1.application = oauth2_application + oauth2_user_pat_1.save() + + url = reverse('user-authorized-tokens-list', kwargs={"pk": user.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert len(response.data['results']) == 2 + + +def test_oauth2_authorized_list_for_invalid_user(oauth2_user_pat, oauth2_user_pat_1, user, admin_api_client): + """ + Ensure we don't fatal if we give a bad user PK. + + We return an empty list. + """ + url = reverse('user-authorized-tokens-list', kwargs={"pk": 1000}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['results'] == [] + + +def test_oauth2_authorized_list_is_user_related_field(user, admin_api_client): + """ + Ensure 'authorized_tokens' shows up in the user's related fields. + """ + url = reverse('user-detail', kwargs={"pk": user.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert 'authorized_tokens' in response.data['related'] + assert response.data['related']['authorized_tokens'] == reverse('user-authorized-tokens-list', kwargs={"pk": user.pk}) + + +@pytest.mark.django_db +def test_oauth2_token_createn(oauth2_application, admin_api_client, admin_user): + oauth2_application = oauth2_application[0] + url = reverse('token-list') + response = admin_api_client.post(url, {'scope': 'read', 'application': oauth2_application.pk}) + assert response.status_code == 201 + assert 'modified' in response.data and response.data['modified'] is not None + assert 'updated' not in response.data + token = OAuth2AccessToken.objects.get(token=response.data['token']) + refresh_token = OAuth2RefreshToken.objects.get(token=response.data['refresh_token']) + assert token.application == oauth2_application + assert refresh_token.application == oauth2_application + assert token.user == admin_user + assert refresh_token.user == admin_user + assert refresh_token.access_token == token + assert token.scope == 'read' + + url = reverse('application-access_tokens-list', kwargs={'pk': oauth2_application.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['count'] == 1 + assert response.data['results'][0]['id'] == token.pk + + url = reverse('application-detail', kwargs={'pk': oauth2_application.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['summary_fields']['tokens']['count'] == 1 + assert response.data['summary_fields']['tokens']['results'][0] == {'id': token.pk, 'scope': token.scope, 'token': ENCRYPTED_STRING} + + url = reverse('token-list') + response = admin_api_client.post(url, {'scope': 'write', 'application': oauth2_application.pk}) + assert response.status_code == 201 + assert response.data['refresh_token'] + + url = reverse('token-list') + response = admin_api_client.post(url, {'scope': 'read', 'application': oauth2_application.pk, 'user': admin_user.pk}) + assert response.status_code == 201 + assert response.data['refresh_token'] + + url = reverse('token-list') + response = admin_api_client.post(url, {'scope': 'read', 'application': oauth2_application.pk}) + assert response.status_code == 201 + assert response.data['refresh_token'] + + +@pytest.mark.django_db +def test_oauth2_token_update(oauth2_admin_access_token, admin_api_client): + assert oauth2_admin_access_token.scope == 'write' + url = reverse('token-detail', kwargs={'pk': oauth2_admin_access_token.pk}) + response = admin_api_client.patch(url, {'scope': 'read'}) + assert response.status_code == 200 + oauth2_admin_access_token.refresh_from_db() + assert oauth2_admin_access_token.scope == 'read' + + +@pytest.mark.django_db +def test_oauth2_token_delete(oauth2_admin_access_token, admin_api_client): + url = reverse('token-detail', kwargs={'pk': oauth2_admin_access_token.pk}) + response = admin_api_client.delete(url) + assert response.status_code == 204 + assert OAuth2AccessToken.objects.count() == 0 + assert OAuth2RefreshToken.objects.count() == 1 + + url = reverse('application-access_tokens-list', kwargs={'pk': oauth2_admin_access_token.application.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['count'] == 0 + + url = reverse('application-detail', kwargs={'pk': oauth2_admin_access_token.application.pk}) + response = admin_api_client.get(url) + assert response.status_code == 200 + assert response.data['summary_fields']['tokens']['count'] == 0 + + +@pytest.mark.django_db +def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_token, unauthenticated_api_client): + """ + Test that we can refresh an access token. + """ + app = oauth2_application[0] + secret = oauth2_application[1] + refresh_token = oauth2_admin_access_token.refresh_token + + url = reverse('token') + data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token.token, + } + resp = unauthenticated_api_client.post( + url, + data=urlencode(data), + content_type='application/x-www-form-urlencoded', + headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, + ) + assert resp.status_code == 201 + assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() + original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token) + assert oauth2_admin_access_token not in OAuth2AccessToken.objects.all() + assert OAuth2AccessToken.objects.count() == 1 + + # the same RefreshToken remains but is marked revoked + assert OAuth2RefreshToken.objects.count() == 2 + assert original_refresh_token.revoked + + json_resp = json.loads(resp.content) + new_token = json_resp['access_token'] + new_refresh_token = json_resp['refresh_token'] + + assert OAuth2AccessToken.objects.filter(token=new_token).count() == 1 + # checks that RefreshTokens are rotated (new RefreshToken issued) + assert OAuth2RefreshToken.objects.filter(token=new_refresh_token).count() == 1 + new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token) + assert not new_refresh_obj.revoked + + +@pytest.mark.django_db +def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2_admin_access_token, admin_api_client, settings): + """ + Test that a refresh token that has expired cannot be used to refresh an access token. + """ + app = oauth2_application[0] + secret = oauth2_application[1] + refresh_token = oauth2_admin_access_token.refresh_token + + settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 1 + settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] = 1 + settings.OAUTH2_PROVIDER['AUTHORIZATION_CODE_EXPIRE_SECONDS'] = 1 + time.sleep(1) + + url = reverse('token') + data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token.token, + } + response = admin_api_client.post( + url, + data=urlencode(data), + content_type='application/x-www-form-urlencoded', + headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, + ) + assert response.status_code == 403 + assert b'The refresh token has expired.' in response.content + assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() + assert OAuth2AccessToken.objects.count() == 1 + assert OAuth2RefreshToken.objects.count() == 1 diff --git a/test_app/views.py b/test_app/views.py index 8fe529f26..50acecdee 100644 --- a/test_app/views.py +++ b/test_app/views.py @@ -1,3 +1,5 @@ +from itertools import chain + from django.shortcuts import render from rest_framework.decorators import action, api_view from rest_framework.response import Response @@ -5,6 +7,7 @@ from rest_framework.viewsets import ModelViewSet from ansible_base.lib.utils.views.ansible_base import AnsibleBaseView +from ansible_base.oauth2_provider.views import DABOAuth2UserViewsetMixin from ansible_base.rbac import permission_registry from ansible_base.rbac.api.permissions import AnsibleBaseObjectPermissions, AnsibleBaseUserPermissions from ansible_base.rbac.policies import visible_users @@ -47,7 +50,7 @@ class TeamViewSet(TestAppViewSet): select_related = ('resource__content_type',) -class UserViewSet(TestAppViewSet): +class UserViewSet(DABOAuth2UserViewsetMixin, TestAppViewSet): queryset = models.User.objects.all() permission_classes = [AnsibleBaseUserPermissions] serializer_class = serializers.UserSerializer @@ -58,6 +61,12 @@ def filter_queryset(self, qs): qs = self.apply_optimizations(qs) return qs + @action(detail=False, methods=['get']) + def me(self, request, pk=None): + user = request.user + serializer = self.get_serializer(user) + return Response(serializer.data) + class EncryptionModelViewSet(TestAppViewSet): serializer_class = serializers.EncryptionModelSerializer @@ -125,12 +134,21 @@ class UUIDModelViewSet(TestAppViewSet): def api_root(request, format=None): from ansible_base.activitystream.urls import router as activitystream_router from ansible_base.authentication.urls import router as auth_router + from ansible_base.oauth2_provider.urls import router as oauth2_provider_router from ansible_base.rbac.api.router import router as rbac_router from ansible_base.resource_registry.urls import service_router from test_app.router import router list_endpoints = {} - for url in router.urls + auth_router.urls + service_router.urls + activitystream_router.urls + rbac_router.urls: + urls = [ + activitystream_router.urls, + auth_router.urls, + oauth2_provider_router.urls, + rbac_router.urls, + router.urls, + service_router.urls, + ] + for url in chain(*urls): # only want "root" list views, for example: # want '^users/$' [name='user-list'] # do not want '^users/(?P[^/.]+)/organizations/$' [name='user-organizations-list'],