From 0b52f58d09b8fde83b6755550429ccd44b09cca3 Mon Sep 17 00:00:00 2001 From: Alexander J Sheehan Date: Wed, 20 Sep 2023 21:16:15 +0000 Subject: [PATCH] fix: enterprise sso orchestrator api cleanup --- CHANGELOG.rst | 10 +- enterprise/__init__.py | 2 +- enterprise/admin/__init__.py | 141 +++++++++-------- .../enterprise_customer_sso_configuration.py | 85 +++++++++- .../migrations/0186_auto_20230921_1828.py | 33 ++++ enterprise/models.py | 23 +-- tests/test_enterprise/api/test_views.py | 147 +++++++++++++++++- 7 files changed, 354 insertions(+), 87 deletions(-) create mode 100644 enterprise/migrations/0186_auto_20230921_1828.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 18e791ccbf..6d8cd94043 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,16 +15,20 @@ Change Log Unreleased ---------- +[4.3.2] +------- +fix: enterprise sso orchestrator api cleanup + [4.3.1] --------- +------- chore: use lms_update_or_create_enrollment without feature flag [4.3.0] --------- +------- feat: Added the ``enable_career_engagement_network_on_learner_portal`` field for EnterpriseCustomer [4.2.0] --------- +------- feat: create generic ``PaginationWithFeatureFlags`` to add a ``features`` property to DRF's default pagination response containing Waffle-based feature flags. feat: integrate ``PaginationWithFeatureFlags`` with ``EnterpriseCustomerViewSet``. diff --git a/enterprise/__init__.py b/enterprise/__init__.py index 27abfebdc4..fdf652a73f 100644 --- a/enterprise/__init__.py +++ b/enterprise/__init__.py @@ -2,4 +2,4 @@ Your project description goes here. """ -__version__ = "4.3.1" +__version__ = "4.3.2" diff --git a/enterprise/admin/__init__.py b/enterprise/admin/__init__.py index 2d3cc8baa7..3b9a60df61 100644 --- a/enterprise/admin/__init__.py +++ b/enterprise/admin/__init__.py @@ -20,7 +20,7 @@ from django.utils.safestring import mark_safe from django.utils.translation import gettext as _ -from enterprise import constants +from enterprise import constants, models from enterprise.admin.actions import export_as_csv_action, refresh_catalog from enterprise.admin.forms import ( AdminNotificationForm, @@ -41,29 +41,12 @@ ) from enterprise.api_client.lms import CourseApiClient, EnrollmentApiClient from enterprise.config.models import UpdateRoleAssignmentsWithCustomersConfig -from enterprise.models import ( - AdminNotification, - AdminNotificationFilter, - AdminNotificationRead, - ChatGPTResponse, - EnrollmentNotificationEmailTemplate, - EnterpriseCatalogQuery, - EnterpriseCourseEnrollment, - EnterpriseCustomer, - EnterpriseCustomerBrandingConfiguration, - EnterpriseCustomerCatalog, - EnterpriseCustomerIdentityProvider, - EnterpriseCustomerInviteKey, - EnterpriseCustomerReportingConfiguration, - EnterpriseCustomerType, - EnterpriseCustomerUser, - EnterpriseFeatureUserRoleAssignment, - PendingEnrollment, - PendingEnterpriseCustomerAdminUser, - PendingEnterpriseCustomerUser, - SystemWideEnterpriseUserRoleAssignment, +from enterprise.utils import ( + discovery_query_url, + get_all_field_names, + get_default_catalog_content_filter, + localized_utcnow, ) -from enterprise.utils import discovery_query_url, get_all_field_names, get_default_catalog_content_filter try: from enterprise.api_client.enterprise_catalog import EnterpriseCatalogApiClient @@ -88,7 +71,7 @@ class EnterpriseCustomerBrandingConfigurationInline(admin.StackedInline): https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline """ - model = EnterpriseCustomerBrandingConfiguration + model = models.EnterpriseCustomerBrandingConfiguration can_delete = False @@ -100,7 +83,7 @@ class EnterpriseCustomerIdentityProviderInline(admin.StackedInline): https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline """ - model = EnterpriseCustomerIdentityProvider + model = models.EnterpriseCustomerIdentityProvider form = EnterpriseCustomerIdentityProviderAdminForm extra = 0 @@ -112,7 +95,7 @@ class EnterpriseCustomerCatalogInline(admin.TabularInline): https://docs.djangoproject.com/en/1.8/ref/contrib/admin/#django.contrib.admin.StackedInline """ - model = EnterpriseCustomerCatalog + model = models.EnterpriseCustomerCatalog form = EnterpriseCustomerCatalogAdminForm extra = 0 can_delete = False @@ -128,7 +111,7 @@ class PendingEnterpriseCustomerAdminUserInline(admin.TabularInline): Django admin inline model for PendingEnterpriseCustomerAdminUser. """ - model = PendingEnterpriseCustomerAdminUser + model = models.PendingEnterpriseCustomerAdminUser extra = 0 fieldsets = ( (None, { @@ -149,14 +132,14 @@ def get_admin_registration_url(self, obj): return format_html('{0}'.format(obj.admin_registration_url)) -@admin.register(EnterpriseCustomerType) +@admin.register(models.EnterpriseCustomerType) class EnterpriseCustomerTypeAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCustomerType. """ class Meta: - model = EnterpriseCustomerType + model = models.EnterpriseCustomerType fields = ( 'name', @@ -166,7 +149,7 @@ class Meta: search_fields = ('name', ) -@admin.register(EnterpriseCustomer) +@admin.register(models.EnterpriseCustomer) class EnterpriseCustomerAdmin(DjangoObjectActions, SimpleHistoryAdmin): """ Django admin model for EnterpriseCustomer. @@ -254,7 +237,7 @@ class EnterpriseCustomerAdmin(DjangoObjectActions, SimpleHistoryAdmin): form = EnterpriseCustomerAdminForm class Meta: - model = EnterpriseCustomer + model = models.EnterpriseCustomer def get_search_results(self, request, queryset, search_term): original_queryset = queryset @@ -397,14 +380,14 @@ def get_urls(self): return customer_urls + super().get_urls() -@admin.register(EnterpriseCustomerUser) +@admin.register(models.EnterpriseCustomerUser) class EnterpriseCustomerUserAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCustomerUser. """ class Meta: - model = EnterpriseCustomerUser + model = models.EnterpriseCustomerUser fields = ( 'user_id', @@ -446,7 +429,7 @@ def get_search_results(self, request, queryset, search_term): use_distinct = False if search_term: - queryset = EnterpriseCustomerUser.objects.filter( + queryset = models.EnterpriseCustomerUser.objects.filter( user_id__in=User.objects.filter( Q(email__icontains=search_term) | Q(username__icontains=search_term) ) @@ -495,7 +478,9 @@ def _get_enterprise_course_enrollments(self, enterprise_customer_user): enterprise_customer_user: The instance of EnterpriseCustomerUser being rendered with this admin form. """ - enrollments = EnterpriseCourseEnrollment.objects.filter(enterprise_customer_user=enterprise_customer_user) + enrollments = models.EnterpriseCourseEnrollment.objects.filter( + enterprise_customer_user=enterprise_customer_user + ) return [enrollment.course_id for enrollment in enrollments] def _get_all_enrollments(self, enterprise_customer_user): @@ -540,14 +525,14 @@ def get_enrolled_course_string(self, course_ids): ) -@admin.register(PendingEnterpriseCustomerUser) +@admin.register(models.PendingEnterpriseCustomerUser) class PendingEnterpriseCustomerUserAdmin(admin.ModelAdmin): """ Django admin model for PendingEnterpriseCustomerUser """ class Meta: - model = PendingEnterpriseCustomerUser + model = models.PendingEnterpriseCustomerUser fields = ( 'user_email', @@ -562,14 +547,14 @@ class Meta: ) -@admin.register(PendingEnterpriseCustomerAdminUser) +@admin.register(models.PendingEnterpriseCustomerAdminUser) class PendingEnterpriseCustomerAdminUserAdmin(admin.ModelAdmin): """ Django admin model for PendingEnterpriseCustomerAdminUser """ class Meta: - model = PendingEnterpriseCustomerAdminUser + model = models.PendingEnterpriseCustomerAdminUser fields = ( 'user_email', @@ -611,7 +596,7 @@ def get_admin_registration_url(self, obj): return format_html('{0}'.format(obj.admin_registration_url)) -@admin.register(EnrollmentNotificationEmailTemplate) +@admin.register(models.EnrollmentNotificationEmailTemplate) class EnrollmentNotificationEmailTemplateAdmin(DjangoObjectActions, admin.ModelAdmin): """ Django admin for EnrollmentNotificationEmailTemplate model @@ -619,7 +604,7 @@ class EnrollmentNotificationEmailTemplateAdmin(DjangoObjectActions, admin.ModelA change_actions = ("preview_as_course", "preview_as_program") class Meta: - model = EnrollmentNotificationEmailTemplate + model = models.EnrollmentNotificationEmailTemplate def get_urls(self): """ @@ -669,14 +654,14 @@ def preview_as_program(self, request, obj): preview_as_program.label = _("Preview (program)") -@admin.register(EnterpriseCourseEnrollment) +@admin.register(models.EnterpriseCourseEnrollment) class EnterpriseCourseEnrollmentAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCourseEnrollment """ class Meta: - model = EnterpriseCourseEnrollment + model = models.EnterpriseCourseEnrollment readonly_fields = ( 'enterprise_customer_user', @@ -737,14 +722,14 @@ def get_urls(self): return custom_urls + super().get_urls() -@admin.register(PendingEnrollment) +@admin.register(models.PendingEnrollment) class PendingEnrollmentAdmin(admin.ModelAdmin): """ Django admin model for PendingEnrollment """ class Meta: - model = PendingEnrollment + model = models.PendingEnrollment readonly_fields = ( 'user', @@ -773,14 +758,14 @@ def has_delete_permission(self, request, obj=None): return False -@admin.register(EnterpriseCatalogQuery) +@admin.register(models.EnterpriseCatalogQuery) class EnterpriseCatalogQueryAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCatalogQuery. """ class Meta: - model = EnterpriseCatalogQuery + model = models.EnterpriseCatalogQuery def get_urls(self): """ @@ -820,7 +805,7 @@ def has_delete_permission(self, request, obj=None): readonly_fields = ('discovery_query_url', 'uuid') -@admin.register(EnterpriseCustomerCatalog) +@admin.register(models.EnterpriseCustomerCatalog) class EnterpriseCustomerCatalogAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCustomerCatalog. @@ -829,7 +814,7 @@ class EnterpriseCustomerCatalogAdmin(admin.ModelAdmin): actions = [refresh_catalog] class Meta: - model = EnterpriseCustomerCatalog + model = models.EnterpriseCustomerCatalog class Media: js = ('enterprise/admin/enterprise_customer_catalog.js',) @@ -897,7 +882,7 @@ def get_actions(self, request): return actions -@admin.register(EnterpriseCustomerReportingConfiguration) +@admin.register(models.EnterpriseCustomerReportingConfiguration) class EnterpriseCustomerReportingConfigurationAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCustomerReportingConfiguration. @@ -920,7 +905,7 @@ class EnterpriseCustomerReportingConfigurationAdmin(admin.ModelAdmin): form = EnterpriseCustomerReportingConfigAdminForm class Meta: - model = EnterpriseCustomerReportingConfiguration + model = models.EnterpriseCustomerReportingConfiguration def get_fields(self, request, obj=None): """ @@ -968,7 +953,7 @@ def count(self): # pylint: disable=invalid-overridden-method return self._whole_table_count -@admin.register(SystemWideEnterpriseUserRoleAssignment) +@admin.register(models.SystemWideEnterpriseUserRoleAssignment) class SystemWideEnterpriseUserRoleAssignmentAdmin(UserRoleAssignmentAdmin): """ Django admin model for SystemWideEnterpriseUserRoleAssignment. @@ -994,10 +979,10 @@ class SystemWideEnterpriseUserRoleAssignmentAdmin(UserRoleAssignmentAdmin): form = SystemWideEnterpriseUserRoleAssignmentForm class Meta: - model = SystemWideEnterpriseUserRoleAssignment + model = models.SystemWideEnterpriseUserRoleAssignment -@admin.register(EnterpriseFeatureUserRoleAssignment) +@admin.register(models.EnterpriseFeatureUserRoleAssignment) class EnterpriseFeatureUserRoleAssignmentAdmin(UserRoleAssignmentAdmin): """ Django admin model for EnterpriseFeatureUserRoleAssignment. @@ -1006,45 +991,45 @@ class EnterpriseFeatureUserRoleAssignmentAdmin(UserRoleAssignmentAdmin): form = EnterpriseFeatureUserRoleAssignmentForm class Meta: - model = EnterpriseFeatureUserRoleAssignment + model = models.EnterpriseFeatureUserRoleAssignment admin.site.register(UpdateRoleAssignmentsWithCustomersConfig, ConfigurationModelAdmin) -@admin.register(AdminNotificationRead) +@admin.register(models.AdminNotificationRead) class AdminNotificationReadAdmin(admin.ModelAdmin): """ Django admin for AdminNotificationRead model. """ - model = AdminNotificationRead + model = models.AdminNotificationRead list_display = ('id', 'enterprise_customer_user', 'admin_notification', 'is_read', 'created', 'modified') -@admin.register(AdminNotification) +@admin.register(models.AdminNotification) class AdminNotificationAdmin(admin.ModelAdmin): """ Django admin for AdminNotification model. """ - model = AdminNotification + model = models.AdminNotification form = AdminNotificationForm list_display = ('id', 'title', 'text', 'is_active', 'start_date', 'expiration_date', 'created', 'modified') filter_horizontal = ('admin_notification_filter',) -@admin.register(AdminNotificationFilter) +@admin.register(models.AdminNotificationFilter) class AdminNotificationFilterAdmin(admin.ModelAdmin): """ - Django admin for AdminNotificationFilter model. + Django admin for models.AdminNotificationFilter model. """ - model = AdminNotificationFilter + model = models.AdminNotificationFilter list_display = ('id', 'filter', 'created', 'modified') -@admin.register(EnterpriseCustomerInviteKey) +@admin.register(models.EnterpriseCustomerInviteKey) class EnterpriseCustomerInviteKeyAdmin(admin.ModelAdmin): """ Django admin model for EnterpriseCustomerInviteKey. @@ -1076,7 +1061,7 @@ class EnterpriseCustomerInviteKeyAdmin(admin.ModelAdmin): ) class Meta: - model = EnterpriseCustomerInviteKey + model = models.EnterpriseCustomerInviteKey def get_readonly_fields(self, request, obj=None): readonly_fields = super().get_readonly_fields(request, obj=obj) @@ -1087,12 +1072,36 @@ def get_readonly_fields(self, request, obj=None): return readonly_fields -@admin.register(ChatGPTResponse) +@admin.register(models.ChatGPTResponse) class ChatGPTResponseAdmin(admin.ModelAdmin): """ Django admin for ChatGPTResponse model. """ - model = ChatGPTResponse + model = models.ChatGPTResponse list_display = ('uuid', 'enterprise_customer', 'prompt_hash', ) readonly_fields = ('prompt', 'response', 'prompt_hash', ) + + +@admin.register(models.EnterpriseCustomerSsoConfiguration) +class EnterpriseCustomerSsoConfigurationAdmin(DjangoObjectActions, admin.ModelAdmin): + """ + Django admin for models.EnterpriseCustomerSsoConfigurationAdmin model. + """ + + model = models.EnterpriseCustomerSsoConfiguration + list_display = ('uuid', 'enterprise_customer', 'active', 'identity_provider', 'created', 'configured_at') + change_actions = ['mark_configured'] + + @admin.action( + description="Allows for marking a config as configured. This is useful for testing while the SSO" + "orchestrator is under constructions.", + ) + def mark_configured(self, request, obj): + """ + Object tool handler method - marks the config as configured. + """ + obj.configured_at = localized_utcnow() + obj.save() + + mark_configured.label = "Mark as Configured" diff --git a/enterprise/api/v1/views/enterprise_customer_sso_configuration.py b/enterprise/api/v1/views/enterprise_customer_sso_configuration.py index a093fc7fce..b6eaad1adf 100644 --- a/enterprise/api/v1/views/enterprise_customer_sso_configuration.py +++ b/enterprise/api/v1/views/enterprise_customer_sso_configuration.py @@ -2,6 +2,9 @@ Views for the ``enterprise-customer-sso-configuration`` API endpoint. """ +from xml.etree.ElementTree import fromstring + +import requests from edx_rbac.decorators import permission_required from rest_framework import permissions, viewsets from rest_framework.decorators import action @@ -41,6 +44,18 @@ class EnterpriseCustomerInactiveException(Exception): """ +class SsoConfigurationApiError(requests.exceptions.RequestException): + """ + Exception raised when the Sso configuration api encounters an error while fetching provider metadata. + """ + + +class EntityIdNotFoundError(Exception): + """ + Exception raised by the SSO configuration api when it fails to fetch a customer IDP's entity ID from the metadata. + """ + + def check_user_part_of_customer(user, enterprise_customer): """ Checks if a user is in an enterprise customer. @@ -67,6 +82,28 @@ def fetch_configuration_record(kwargs): return EnterpriseCustomerSsoConfiguration.all_objects.filter(pk=kwargs.get('configuration_uuid')) +def get_metadata_xml_from_url(url): + """ + Gets the metadata xml from the given url. + """ + response = requests.get(url) + if response.status_code >= 300: + raise SsoConfigurationApiError(f'Error fetching metadata xml from provided url: {url}') + return response.text + + +def fetch_entity_id_from_metadata_xml(metadata_xml): + """ + Fetches the entity id from the metadata xml. + """ + root = fromstring(metadata_xml) + if entity_id := root.get('entityID'): + return entity_id + if entity_descriptor_child := root.find('EntityDescriptor'): + return entity_descriptor_child.get('entityID') + raise EntityIdNotFoundError('Could not find entity ID in metadata xml') + + class EnterpriseCustomerSsoConfigurationViewSet(viewsets.ModelViewSet): """ API views for the ``EnterpriseCustomerSsoConfiguration`` model. @@ -177,6 +214,26 @@ def create(self, request, *args, **kwargs): request_data['enterprise_customer'] = enterprise_customer else: return Response({'error': BAD_CUSTOMER_ERROR}, status=HTTP_400_BAD_REQUEST) + + # Parse the request data to see if the metadata url or xml has changed and update the entity id if so + sso_config_metadata_xml = None + if request_metadata_url := request_data.get('metadata_url'): + # If the metadata url has changed, we need to update the metadata xml + try: + sso_config_metadata_xml = get_metadata_xml_from_url(request_metadata_url) + except SsoConfigurationApiError as e: + LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}') + return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST) + request_data['metadata_xml'] = sso_config_metadata_xml + if sso_config_metadata_xml or (sso_config_metadata_xml := request_data.get('metadata_xml')): + try: + entity_id = fetch_entity_id_from_metadata_xml(sso_config_metadata_xml) + except (EntityIdNotFoundError) as e: + LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}') + return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST) + + request_data['entity_id'] = entity_id + try: new_record = EnterpriseCustomerSsoConfiguration.objects.create(**request_data) except TypeError as e: @@ -206,8 +263,32 @@ def update(self, request, *args, **kwargs): except EnterpriseCustomerInactiveException: return Response(status=HTTP_403_FORBIDDEN) + # Parse the request data to see if the metadata url or xml has changed and update the entity id if so + request_data = request.data.dict() + sso_config_metadata_xml = None + if request_metadata_url := request_data.get('metadata_url'): + sso_config_metadata_url = sso_configuration_record.first().metadata_url + if request_metadata_url != sso_config_metadata_url: + # If the metadata url has changed, we need to update the metadata xml + try: + sso_config_metadata_xml = get_metadata_xml_from_url(request_metadata_url) + except SsoConfigurationApiError as e: + LOGGER.error(f'{CONFIG_UPDATE_ERROR} {e}') + return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST) + request_data['metadata_xml'] = sso_config_metadata_xml + if request_metadata_xml := request_data.get('metadata_xml'): + if request_metadata_xml != sso_configuration_record.first().metadata_xml: + sso_config_metadata_xml = request_metadata_xml + if sso_config_metadata_xml: + try: + entity_id = fetch_entity_id_from_metadata_xml(sso_config_metadata_xml) + request_data['entity_id'] = entity_id + except (EntityIdNotFoundError) as e: + LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}') + return Response({'error': f'{CONFIG_UPDATE_ERROR} {e}'}, status=HTTP_400_BAD_REQUEST) + # If the request includes a customer uuid, ensure the new customer is valid - if new_customer := request.data.dict().get('enterprise_customer'): + if new_customer := request_data.get('enterprise_customer'): try: enterprise_customer = EnterpriseCustomer.objects.get(uuid=new_customer) except EnterpriseCustomer.DoesNotExist: @@ -219,7 +300,7 @@ def update(self, request, *args, **kwargs): return Response(status=HTTP_403_FORBIDDEN) try: with transaction.atomic(): - sso_configuration_record.update(**request.data.dict()) + sso_configuration_record.update(**request_data) sso_configuration_record.first().submit_for_configuration(updating_existing_record=True) except (TypeError, FieldDoesNotExist, ValidationError) as e: LOGGER.error(f'{CONFIG_UPDATE_ERROR}{e}') diff --git a/enterprise/migrations/0186_auto_20230921_1828.py b/enterprise/migrations/0186_auto_20230921_1828.py new file mode 100644 index 0000000000..014979ff7e --- /dev/null +++ b/enterprise/migrations/0186_auto_20230921_1828.py @@ -0,0 +1,33 @@ +# Generated by Django 3.2.20 on 2023-09-21 18:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('enterprise', '0185_auto_20230921_1007'), + ] + + operations = [ + migrations.AlterField( + model_name='enterprisecustomerssoconfiguration', + name='entity_id', + field=models.CharField(blank=True, help_text='The entity id of the identity provider.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='enterprisecustomerssoconfiguration', + name='metadata_url', + field=models.CharField(blank=True, help_text='The metadata url of the identity provider.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='historicalenterprisecustomerssoconfiguration', + name='entity_id', + field=models.CharField(blank=True, help_text='The entity id of the identity provider.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='historicalenterprisecustomerssoconfiguration', + name='metadata_url', + field=models.CharField(blank=True, help_text='The metadata url of the identity provider.', max_length=255, null=True), + ), + ] diff --git a/enterprise/models.py b/enterprise/models.py index 9803282c6b..6df6ebe7a1 100644 --- a/enterprise/models.py +++ b/enterprise/models.py @@ -3797,8 +3797,8 @@ class Meta: ) metadata_url = models.CharField( - blank=False, - null=False, + blank=True, + null=True, max_length=255, help_text=_( "The metadata url of the identity provider." @@ -3814,8 +3814,8 @@ class Meta: ) entity_id = models.CharField( - blank=False, - null=False, + blank=True, + null=True, max_length=255, help_text=_( "The entity id of the identity provider." @@ -4002,7 +4002,12 @@ def is_pending_configuration(self): """ Returns True if the configuration has been submitted but not completed configuration. """ - return self.submitted_at and not self.configured_at + if self.submitted_at: + if not self.configured_at: + return True + if self.submitted_at > self.configured_at: + return True + return False def submit_for_configuration(self, updating_existing_record=False): """ @@ -4018,14 +4023,14 @@ def submit_for_configuration(self, updating_existing_record=False): ) is_sap = False sap_data = {} + config_data = {} if self.identity_provider == self.SAP_SUCCESS_FACTORS: for field in self.sap_config_fields: sap_data[utils.camelCase(field)] = getattr(self, field) is_sap = True - - config_data = {} - for field in self.base_saml_config_fields: - config_data[utils.camelCase(field)] = getattr(self, field) + else: + for field in self.base_saml_config_fields: + config_data[utils.camelCase(field)] = getattr(self, field) EnterpriseSSOOrchestratorApiClient().configure_sso_orchestration_record( config_data=config_data, diff --git a/tests/test_enterprise/api/test_views.py b/tests/test_enterprise/api/test_views.py index 5de69468b6..28688b4a21 100644 --- a/tests/test_enterprise/api/test_views.py +++ b/tests/test_enterprise/api/test_views.py @@ -7467,10 +7467,19 @@ def test_sso_configuration_list_customer_filtering_while_staff(self): # -------------------------- create test suite -------------------------- @responses.activate - def test_sso_configuration_create_x(self): + def test_sso_configuration_create(self): """ Test expected response when successfully creating a new sso configuration. """ + xml_metadata = """ + + + """ + responses.add( + responses.GET, + "https://example.com/metadata.xml", + body=xml_metadata, + ) responses.add( responses.POST, urljoin(get_sso_orchestrator_api_base_url(), get_sso_orchestrator_configure_path()), @@ -7540,13 +7549,129 @@ def test_sso_configuration_create_bad_data_format(self): response = self.post_new_sso_configuration(data) assert "somewhackyvalue" in response.json()['error'] + def test_sso_configuration_create_bad_xml_url(self): + """ + Test expected response when creating a new sso configuration with a bad xml url. + """ + responses.add( + responses.GET, + "https://example.com/metadata.xml", + json={'error': 'some error'}, + status=400, + ) + data = { + "metadata_url": "https://example.com/metadata.xml", + "enterprise_customer": str(self.enterprise_customer.uuid), + "identity_provider": "cornerstone" + } + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) + config_pk = uuid.uuid4() + EnterpriseCustomerSsoConfigurationFactory( + uuid=config_pk, + enterprise_customer=self.enterprise_customer, + ) + response = self.update_sso_configuration(config_pk, data) + assert response.status_code == 400 + assert "Error fetching metadata xml" in response.json()['error'] + + @responses.activate + def test_sso_configuration_create_bad_xml_content(self): + """ + Test expected response when creating a new sso configuration with an xml string that doesn't contain an entity + id. + """ + xml_metadata = """ + + + """ + data = { + "metadata_url": "https://example.com/metadata.xml", + "enterprise_customer": str(self.enterprise_customer.uuid), + "identity_provider": "cornerstone" + } + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) + config_pk = uuid.uuid4() + EnterpriseCustomerSsoConfigurationFactory( + uuid=config_pk, + enterprise_customer=self.enterprise_customer, + ) + data = { + "metadata_xml": xml_metadata, + } + response = self.update_sso_configuration(config_pk, data) + assert response.status_code == 400 + assert "Could not find entity ID in metadata xml" in response.json()['error'] + # -------------------------- update test suite -------------------------- + @responses.activate + def test_sso_configurations_update_bad_xml_content(self): + """ + Test the expected response when updating an sso configuration with an xml string that doesn't contain an entity + id. + """ + xml_metadata = """ + + + """ + responses.add( + responses.GET, + "https://example.com/metadata.xml", + body=xml_metadata, + ) + + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) + config_pk = uuid.uuid4() + EnterpriseCustomerSsoConfigurationFactory( + uuid=config_pk, + enterprise_customer=self.enterprise_customer, + ) + data = { + "metadata_url": "https://example.com/metadata.xml", + } + response = self.update_sso_configuration(config_pk, data) + assert response.status_code == 400 + + @responses.activate + def test_sso_configurations_update_bad_xml_url(self): + """ + Test the expected response when updating an sso configuration with a bad xml url. + """ + responses.add( + responses.GET, + "https://example.com/metadata.xml", + json={'error': 'some error'}, + status=400, + ) + + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) + config_pk = uuid.uuid4() + EnterpriseCustomerSsoConfigurationFactory( + uuid=config_pk, + enterprise_customer=self.enterprise_customer, + ) + data = { + "metadata_url": "https://example.com/metadata.xml", + } + response = self.update_sso_configuration(config_pk, data) + assert response.status_code == 400 + assert "Error fetching metadata xml" in response.json()['error'] + @responses.activate def test_sso_configurations_update_submitted_config(self): """ Test the expected response when updating an sso configuration that's already been submitted for configuration. """ + xml_metadata = """ + + + """ + responses.add( + responses.GET, + "https://example.com/metadata.xml", + body=xml_metadata, + ) + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) config_pk = uuid.uuid4() enterprise_sso_orchestration_config = EnterpriseCustomerSsoConfigurationFactory( @@ -7570,11 +7695,11 @@ def test_sso_configurations_update_submitted_config(self): enterprise_sso_orchestration_config.save() response = self.update_sso_configuration(config_pk, data) assert response.status_code == 200 - sent_body_params = json.loads(responses.calls[0].request.body) + sent_body_params = json.loads(responses.calls[2].request.body) assert sent_body_params['requestIdentifier'] == str(config_pk) @responses.activate - def test_sso_configuration_update(self): + def test_sso_configuration_update_x(self): """ Test expected response when successfully updating an existing sso configuration. """ @@ -7583,6 +7708,16 @@ def test_sso_configuration_update(self): urljoin(get_sso_orchestrator_api_base_url(), get_sso_orchestrator_configure_path()), json={}, ) + xml_metadata = """ + + + """ + responses.add( + responses.GET, + "https://example.com/metadata_update.xml", + body=xml_metadata, + ) + self.set_jwt_cookie(ENTERPRISE_ADMIN_ROLE, self.enterprise_customer.uuid) config_pk = uuid.uuid4() enterprise_sso_orchestration_config = EnterpriseCustomerSsoConfigurationFactory( @@ -7591,15 +7726,15 @@ def test_sso_configuration_update(self): metadata_url="before_value" ) data = { - "metadata_url": "https://example.com/metadata.xml", + "metadata_url": "https://example.com/metadata_update.xml", } response = self.update_sso_configuration(config_pk, data) assert response.status_code == status.HTTP_200_OK assert response.json()['uuid'] == str(enterprise_sso_orchestration_config.uuid) - assert response.json()['metadata_url'] == "https://example.com/metadata.xml" + assert response.json()['metadata_url'] == "https://example.com/metadata_update.xml" enterprise_sso_orchestration_config.refresh_from_db() - assert enterprise_sso_orchestration_config.metadata_url == "https://example.com/metadata.xml" + assert enterprise_sso_orchestration_config.metadata_url == "https://example.com/metadata_update.xml" def test_sso_configuration_update_permissioning(self): """