diff --git a/.github/workflows/units-test-scripts-user-retirement.yml b/.github/workflows/units-test-scripts-user-retirement.yml new file mode 100644 index 000000000000..e11d70193f0e --- /dev/null +++ b/.github/workflows/units-test-scripts-user-retirement.yml @@ -0,0 +1,33 @@ +name: units-test-scripts-user-retirement + +on: + pull_request: + push: + branches: + - master + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: [ '3.8' ] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r scripts/user_retirement/requirements/testing.txt + + - name: Run pytest + run: | + pytest scripts/user_retirement diff --git a/Makefile b/Makefile index 2db357b090b1..055923d1afcd 100644 --- a/Makefile +++ b/Makefile @@ -137,7 +137,9 @@ REQ_FILES = \ requirements/edx/development \ requirements/edx/assets \ requirements/edx/semgrep \ - scripts/xblock/requirements + scripts/xblock/requirements \ + scripts/user_retirement/requirements/base \ + scripts/user_retirement/requirements/testing define COMMON_CONSTRAINTS_TEMP_COMMENT # This is a temporary solution to override the real common_constraints.txt\n# In edx-lint, until the pyjwt constraint in edx-lint has been removed.\n# See BOM-2721 for more details.\n# Below is the copied and edited version of common_constraints\n diff --git a/scripts/user_retirement/README.rst b/scripts/user_retirement/README.rst new file mode 100644 index 000000000000..20c99197ed6d --- /dev/null +++ b/scripts/user_retirement/README.rst @@ -0,0 +1,100 @@ +User Retirement Scripts +======================= + +`This `_ directory contains Python scripts migrated from the `tubular `_ repository. +These scripts drive the user retirement workflow, which handles the deactivation or removal of user accounts as part of the platform's management process. + +These scripts can be called from any automation/CD framework. + +How to run the scripts +====================== + +Download the Scripts +-------------------- + +To download the scripts, you can perform a partial clone of the edx-platform repository to obtain only the required scripts. The following steps demonstrate how to achieve this. Alternatively, you may choose other utilities or libraries for the partial clone. + +.. code-block:: bash + + repo_url=git@github.com:openedx/edx-platform.git + branch=master + directory=scripts/user_retirement + + git clone --branch $branch --single-branch --depth=1 --filter=tree:0 $repo_url + cd edx-platform + git sparse-checkout init --cone + git sparse-checkout set $directory + +Create Python Virtual Environment +--------------------------------- + +Create a Python virtual environment using Python 3.8: + +.. code-block:: bash + + python3.8 -m venv ../venv + source ../venv/bin/activate + +Install Pip Packages +-------------------- + +Install the required pip packages using the provided requirements file: + +.. code-block:: bash + + pip install -r scripts/user_retirement/requirements/base.txt + +In-depth Documentation and Configuration Steps +---------------------------------------------- + +For in-depth documentation and essential configuration steps, follow these docs: + +`Documentation `_ + +`Configuration Docs `_
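 + +Example Configuration +--------------------- + +As a quick orientation before diving into the docs above, here is a minimal sketch of the YAML config file the scripts consume, based on the format documented in the ``retire_one_learner.py`` docstring in this directory (the URLs and pipeline entries below are illustrative placeholders for your own environment): + +.. code-block:: yaml + + client_id: <retirement-service-client-id> + client_secret: <retirement-service-client-secret> + base_urls: + lms: http://localhost:18000/ + ecommerce: http://localhost:18130/ + credentials: http://localhost:18150/ + retirement_pipeline: + - ['RETIRING_ENROLLMENTS', 'ENROLLMENTS_COMPLETE', 'LMS', 'retirement_unenroll'] + - ['RETIRING_LMS', 'LMS_COMPLETE', 'LMS', 'retirement_lms_retire'] + + +Execute Script +-------------- + +Execute the following shell command to establish entry points for the scripts: + +.. 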
code-block:: bash + + chmod +x scripts/user_retirement/entry_points.sh + source scripts/user_retirement/entry_points.sh + +To retire a specific learner, you can use the provided example script: + +.. code-block:: bash + + retire_one_learner.py \ + --config_file=src/config.yml \ + --username=user1 + +Make sure to replace ``src/config.yml`` with the actual path to your configuration file and ``user1`` with the actual username. + +You can also execute Python scripts directly using the file path: + +.. code-block:: bash + + python scripts/user_retirement/retire_one_learner.py \ + --config_file=src/config.yml \ + --username=user1 + +Feel free to customize these steps according to your specific environment and requirements. + +Run Test Cases +============== + +Before running test cases, install the testing requirements: + +.. code-block:: bash + + pip install -r scripts/user_retirement/requirements/testing.txt + +Run the test cases using pytest: + +.. code-block:: bash + + pytest scripts/user_retirement diff --git a/scripts/user_retirement/__init__.py b/scripts/user_retirement/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/entry_points.sh b/scripts/user_retirement/entry_points.sh new file mode 100755 index 000000000000..ec16776e0108 --- /dev/null +++ b/scripts/user_retirement/entry_points.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +alias get_learners_to_retire.py='python scripts/user_retirement/get_learners_to_retire.py' +alias replace_usernames.py='python scripts/user_retirement/replace_usernames.py' +alias retire_one_learner.py='python scripts/user_retirement/retire_one_learner.py' +alias retirement_archive_and_cleanup.py='python scripts/user_retirement/retirement_archive_and_cleanup.py' +alias retirement_bulk_status_update.py='python scripts/user_retirement/retirement_bulk_status_update.py' +alias retirement_partner_report.py='python scripts/user_retirement/retirement_partner_report.py' diff --git a/scripts/user_retirement/get_learners_to_retire.py b/scripts/user_retirement/get_learners_to_retire.py new file mode 100755 index 000000000000..6c02196e4412 --- /dev/null +++ b/scripts/user_retirement/get_learners_to_retire.py @@ -0,0 +1,105 @@ +#! /usr/bin/env python3 + +""" +Command-line script to retrieve list of learners that have requested to be retired. +The script calls the appropriate LMS endpoint to get this list of learners. +""" + +import io +import logging +import sys +from os import path + +import click +import yaml + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..'))) + +from scripts.user_retirement.utils.edx_api import LmsApi +from scripts.user_retirement.utils.jenkins import export_learner_job_properties + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +LOG = logging.getLogger(__name__) + + +@click.command("get_learners_to_retire") +@click.option( + '--config_file', + help='File in which YAML config exists that overrides all other params.' +) +@click.option( + '--cool_off_days', + help='Number of days a learner should be in the retirement queue before being actually retired.', + default=7 +) +@click.option( + '--output_dir', + help="Directory in which to write the Jenkins properties files.", + default='./jenkins_props' +) +@click.option( + '--user_count_error_threshold', + help="If more users than this number are returned we will error out instead of retiring. 
This is a failsafe" + "against attacks that somehow manage to add users to the retirement queue.", + default=300 +) +@click.option( + '--max_user_batch_size', + help="This setting will only get at most X number of users. If this number is lower than the user_count_error_threshold" + "setting then it will not error.", + default=200 +) +def get_learners_to_retire(config_file, + cool_off_days, + output_dir, + user_count_error_threshold, + max_user_batch_size): + """ + Retrieves a JWT token as the retirement service user, then calls the LMS + endpoint to retrieve the list of learners awaiting retirement. + """ + if not config_file: + click.echo('A config file is required.') + sys.exit(-1) + + with io.open(config_file, 'r') as config: + config_yaml = yaml.safe_load(config) + + user_count_error_threshold = int(user_count_error_threshold) + cool_off_days = int(cool_off_days) + + client_id = config_yaml['client_id'] + client_secret = config_yaml['client_secret'] + lms_base_url = config_yaml['base_urls']['lms'] + retirement_pipeline = config_yaml['retirement_pipeline'] + end_states = [state[1] for state in retirement_pipeline] + states_to_request = ['PENDING'] + end_states + + api = LmsApi(lms_base_url, lms_base_url, client_id, client_secret) + + # Retrieve the learners to retire and export them to separate Jenkins property files. + learners_to_retire = api.learners_to_retire(states_to_request, cool_off_days, max_user_batch_size) + if max_user_batch_size: + learners_to_retire = learners_to_retire[:max_user_batch_size] + learners_to_retire_cnt = len(learners_to_retire) + + if learners_to_retire_cnt > user_count_error_threshold: + click.echo( + 'Too many learners to retire! Expected {} or fewer, got {}!'.format( + user_count_error_threshold, + learners_to_retire_cnt + ) + ) + sys.exit(-1) + + export_learner_job_properties( + learners_to_retire, + output_dir + ) + + +if __name__ == "__main__": + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + # If using env vars to provide params, prefix them with "RETIREMENT_", e.g. RETIREMENT_CLIENT_ID + get_learners_to_retire(auto_envvar_prefix='RETIREMENT') diff --git a/scripts/user_retirement/pytest.ini b/scripts/user_retirement/pytest.ini new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/replace_usernames.py b/scripts/user_retirement/replace_usernames.py new file mode 100644 index 000000000000..2034fce0e6a3 --- /dev/null +++ b/scripts/user_retirement/replace_usernames.py @@ -0,0 +1,153 @@ +#! /usr/bin/env python3 + +""" +Command-line script to replace the usernames for all passed in learners. +Accepts a list of current usernames and their preferred new username. This +script will call LMS first which generates a unique username if the passed in +new username is not unique. It then calls all other services to replace the +username in their DBs. 
+ +""" + +import csv +import io +import logging +import sys +from os import path + +import click +import yaml + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..'))) + +from scripts.user_retirement.utils.edx_api import ( # pylint: disable=wrong-import-position + CredentialsApi, + DiscoveryApi, + EcommerceApi, + LmsApi +) + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +LOG = logging.getLogger(__name__) + + +def write_responses(writer, replacements, status): + for replacement in replacements: + original_username = list(replacement.keys())[0] + new_username = list(replacement.values())[0] + writer.writerow([original_username, new_username, status]) + + +@click.command("replace_usernames") +@click.option( + '--config_file', + help='File in which YAML config exists that overrides all other params.' +) +@click.option( + '--username_replacement_csv', + help='File in which YAML config exists that overrides all other params.' +) +def replace_usernames(config_file, username_replacement_csv): + """ + Retrieves a JWT token as the retirement service user, then calls the LMS + endpoint to retrieve the list of learners awaiting retirement. + + Config file example: + ``` + client_id: xxx + client_secret: xxx + base_urls: + lms: http://localhost:18000 + ecommerce: http://localhost:18130 + discovery: http://localhost:18381 + credentials: http://localhost:18150 + ``` + + Username file example: + ``` + current_un_1,desired_un_1 + current_un_2,desired_un_2, + current_un_3,desired_un_3 + ``` + """ + if not config_file: + click.echo('A config file is required.') + sys.exit(-1) + + if not username_replacement_csv: + click.echo('A username replacement CSV file is required') + sys.exit(-1) + + with io.open(config_file, 'r') as config: + config_yaml = yaml.safe_load(config) + + with io.open(username_replacement_csv, 'r') as replacement_file: + csv_reader = csv.reader(replacement_file) + lms_username_mappings = [ + {current_username: desired_username} + for (current_username, desired_username) + in csv_reader + ] + + client_id = config_yaml['client_id'] + client_secret = config_yaml['client_secret'] + lms_base_url = config_yaml['base_urls']['lms'] + ecommerce_base_url = config_yaml['base_urls']['ecommerce'] + discovery_base_url = config_yaml['base_urls']['discovery'] + credentials_base_url = config_yaml['base_urls']['credentials'] + + # Note that though partially_failed sounds better than completely_failed, + # it's actually worse since the user is not consistant across DBs. + # Partially failed username replacements will need to be triaged so the + # user isn't in a broken state + successful_replacements = [] + partially_failed_replacements = [] + fully_failed_replacements = [] + + lms_api = LmsApi(lms_base_url, lms_base_url, client_id, client_secret) + ecommerce_api = EcommerceApi(lms_base_url, ecommerce_base_url, client_id, client_secret) + discovery_api = DiscoveryApi(lms_base_url, discovery_base_url, client_id, client_secret) + credentials_api = CredentialsApi(lms_base_url, credentials_base_url, client_id, client_secret) + + # Call LMS with current and desired usernames + response = lms_api.replace_lms_usernames(lms_username_mappings) + fully_failed_replacements += response['failed_replacements'] + in_progress_replacements = response['successful_replacements'] + + # Step through each services endpoints with the list returned from LMS. 
+ # The LMS list has already verified usernames and made any duplicate + # usernames unique (e.g. 'matt' => 'mattf56a'). We pass successful + # replacements onto the next service and store all failed replacements. + replacement_methods = [ + ecommerce_api.replace_usernames, + discovery_api.replace_usernames, + credentials_api.replace_usernames, + lms_api.replace_forums_usernames, + ] + # Iterate through the endpoints above and if the APIs return any failures + # capture these in partially_failed_replacements. Only successful + # replacements will continue to be passed to the next service. + for replacement_method in replacement_methods: + response = replacement_method(in_progress_replacements) + partially_failed_replacements += response['failed_replacements'] + in_progress_replacements = response['successful_replacements'] + + successful_replacements = in_progress_replacements + + with open('username_replacement_results.csv', 'w', newline='') as output_file: + csv_writer = csv.writer(output_file) + # Write header + csv_writer.writerow(['Original Username', 'New Username', 'Status']) + write_responses(csv_writer, successful_replacements, "SUCCESS") + write_responses(csv_writer, partially_failed_replacements, "PARTIALLY FAILED") + write_responses(csv_writer, fully_failed_replacements, "FAILED") + + if partially_failed_replacements or fully_failed_replacements: + sys.exit(-1) + + +if __name__ == "__main__": + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + # If using env vars to provide params, prefix them with "USERNAME_REPLACEMENT_", e.g. USERNAME_REPLACEMENT_CLIENT_ID + replace_usernames(auto_envvar_prefix='USERNAME_REPLACEMENT') diff --git a/scripts/user_retirement/requirements/base.in b/scripts/user_retirement/requirements/base.in new file mode 100644 index 000000000000..7377ff996df4 --- /dev/null +++ b/scripts/user_retirement/requirements/base.in @@ -0,0 +1,11 @@ +boto3 +click +pyyaml +backoff +requests +edx-rest-api-client +jenkinsapi +unicodecsv +simplejson +simple-salesforce +google-api-python-client diff --git a/scripts/user_retirement/requirements/base.txt b/scripts/user_retirement/requirements/base.txt new file mode 100644 index 000000000000..94aefc2dcbc5 --- /dev/null +++ b/scripts/user_retirement/requirements/base.txt @@ -0,0 +1,178 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# make upgrade +# +asgiref==3.7.2 + # via django +attrs==23.2.0 + # via zeep +backoff==2.2.1 + # via -r scripts/user_retirement/requirements/base.in +backports-zoneinfo==0.2.1 + # via + # django + # pendulum +boto3==1.34.26 + # via -r scripts/user_retirement/requirements/base.in +botocore==1.34.26 + # via + # boto3 + # s3transfer +cachetools==5.3.2 + # via google-auth +certifi==2023.11.17 + # via requests +cffi==1.16.0 + # via + # cryptography + # pynacl +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via + # -r scripts/user_retirement/requirements/base.in + # edx-django-utils +cryptography==42.0.0 + # via simple-salesforce +django==4.2.9 + # via + # django-crum + # django-waffle + # edx-django-utils +django-crum==0.7.9 + # via edx-django-utils +django-waffle==4.1.0 + # via edx-django-utils +edx-django-utils==5.10.1 + # via edx-rest-api-client +edx-rest-api-client==5.6.1 + # via -r scripts/user_retirement/requirements/base.in +google-api-core==2.15.0 + # via google-api-python-client +google-api-python-client==2.115.0 + # via -r scripts/user_retirement/requirements/base.in +google-auth==2.26.2 + # via + # google-api-core + # 
google-api-python-client + # google-auth-httplib2 +google-auth-httplib2==0.2.0 + # via google-api-python-client +googleapis-common-protos==1.62.0 + # via google-api-core +httplib2==0.22.0 + # via + # google-api-python-client + # google-auth-httplib2 +idna==3.6 + # via requests +importlib-resources==6.1.1 + # via pendulum +isodate==0.6.1 + # via zeep +jenkinsapi==0.3.13 + # via -r scripts/user_retirement/requirements/base.in +jmespath==1.0.1 + # via + # boto3 + # botocore +lxml==4.9.3 + # via zeep +more-itertools==10.2.0 + # via simple-salesforce +newrelic==9.5.0 + # via edx-django-utils +pbr==6.0.0 + # via stevedore +pendulum==3.0.0 + # via simple-salesforce +platformdirs==4.1.0 + # via zeep +protobuf==4.25.2 + # via + # google-api-core + # googleapis-common-protos +psutil==5.9.8 + # via edx-django-utils +pyasn1==0.5.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pyjwt==2.8.0 + # via + # edx-rest-api-client + # simple-salesforce +pynacl==1.5.0 + # via edx-django-utils +pyparsing==3.1.1 + # via httplib2 +python-dateutil==2.8.2 + # via + # botocore + # pendulum + # time-machine +pytz==2023.3.post1 + # via + # jenkinsapi + # zeep +pyyaml==6.0.1 + # via -r scripts/user_retirement/requirements/base.in +requests==2.31.0 + # via + # -r scripts/user_retirement/requirements/base.in + # edx-rest-api-client + # google-api-core + # jenkinsapi + # requests-file + # requests-toolbelt + # simple-salesforce + # slumber + # zeep +requests-file==1.5.1 + # via zeep +requests-toolbelt==1.0.0 + # via zeep +rsa==4.9 + # via google-auth +s3transfer==0.10.0 + # via boto3 +simple-salesforce==1.12.5 + # via -r scripts/user_retirement/requirements/base.in +simplejson==3.19.2 + # via -r scripts/user_retirement/requirements/base.in +six==1.16.0 + # via + # isodate + # jenkinsapi + # python-dateutil + # requests-file +slumber==0.7.1 + # via edx-rest-api-client +sqlparse==0.4.4 + # via django +stevedore==5.1.0 + # via edx-django-utils +time-machine==2.13.0 + # via pendulum +typing-extensions==4.9.0 + # via asgiref +tzdata==2023.4 + # via pendulum +unicodecsv==0.14.1 + # via -r scripts/user_retirement/requirements/base.in +uritemplate==4.1.1 + # via google-api-python-client +urllib3==1.26.18 + # via + # botocore + # requests +zeep==4.2.1 + # via simple-salesforce +zipp==3.17.0 + # via importlib-resources diff --git a/scripts/user_retirement/requirements/testing.in b/scripts/user_retirement/requirements/testing.in new file mode 100644 index 000000000000..49a4297b2292 --- /dev/null +++ b/scripts/user_retirement/requirements/testing.in @@ -0,0 +1,8 @@ +-r base.txt + +moto +pytest +requests_mock +responses +mock +ddt diff --git a/scripts/user_retirement/requirements/testing.txt b/scripts/user_retirement/requirements/testing.txt new file mode 100644 index 000000000000..233f83a4911e --- /dev/null +++ b/scripts/user_retirement/requirements/testing.txt @@ -0,0 +1,316 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# make upgrade +# +asgiref==3.7.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # django +attrs==23.2.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +backoff==2.2.1 + # via -r scripts/user_retirement/requirements/base.txt +backports-zoneinfo==0.2.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # django + # pendulum +boto3==1.34.26 + # via + # -r scripts/user_retirement/requirements/base.txt + # moto +botocore==1.34.26 + # via + # -r 
scripts/user_retirement/requirements/base.txt + # boto3 + # moto + # s3transfer +cachetools==5.3.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-auth +certifi==2023.11.17 + # via + # -r scripts/user_retirement/requirements/base.txt + # requests +cffi==1.16.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # cryptography + # pynacl +charset-normalizer==3.3.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # requests +click==8.1.7 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +cryptography==42.0.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # moto + # simple-salesforce +ddt==1.7.1 + # via -r scripts/user_retirement/requirements/testing.in +django==4.2.9 + # via + # -r scripts/user_retirement/requirements/base.txt + # django-crum + # django-waffle + # edx-django-utils +django-crum==0.7.9 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +django-waffle==4.1.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +edx-django-utils==5.10.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-rest-api-client +edx-rest-api-client==5.6.1 + # via -r scripts/user_retirement/requirements/base.txt +exceptiongroup==1.2.0 + # via pytest +google-api-core==2.15.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-python-client +google-api-python-client==2.115.0 + # via -r scripts/user_retirement/requirements/base.txt +google-auth==2.26.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-core + # google-api-python-client + # google-auth-httplib2 +google-auth-httplib2==0.2.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-python-client +googleapis-common-protos==1.62.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-core +httplib2==0.22.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-python-client + # google-auth-httplib2 +idna==3.6 + # via + # -r scripts/user_retirement/requirements/base.txt + # requests +importlib-resources==6.1.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # pendulum +iniconfig==2.0.0 + # via pytest +isodate==0.6.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +jenkinsapi==0.3.13 + # via -r scripts/user_retirement/requirements/base.txt +jinja2==3.1.3 + # via moto +jmespath==1.0.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # boto3 + # botocore +lxml==4.9.3 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +markupsafe==2.1.4 + # via + # jinja2 + # werkzeug +mock==5.1.0 + # via -r scripts/user_retirement/requirements/testing.in +more-itertools==10.2.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # simple-salesforce +moto==4.2.13 + # via -r scripts/user_retirement/requirements/testing.in +newrelic==9.5.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +packaging==23.2 + # via pytest +pbr==6.0.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # stevedore +pendulum==3.0.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # simple-salesforce +platformdirs==4.1.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +pluggy==1.3.0 + # via pytest +protobuf==4.25.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-core + # googleapis-common-protos +psutil==5.9.8 + # via + # -r 
scripts/user_retirement/requirements/base.txt + # edx-django-utils +pyasn1==0.5.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-auth +pycparser==2.21 + # via + # -r scripts/user_retirement/requirements/base.txt + # cffi +pyjwt==2.8.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-rest-api-client + # simple-salesforce +pynacl==1.5.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +pyparsing==3.1.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # httplib2 +pytest==7.4.4 + # via -r scripts/user_retirement/requirements/testing.in +python-dateutil==2.8.2 + # via + # -r scripts/user_retirement/requirements/base.txt + # botocore + # moto + # pendulum + # time-machine +pytz==2023.3.post1 + # via + # -r scripts/user_retirement/requirements/base.txt + # jenkinsapi + # zeep +pyyaml==6.0.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # responses +requests==2.31.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-rest-api-client + # google-api-core + # jenkinsapi + # moto + # requests-file + # requests-mock + # requests-toolbelt + # responses + # simple-salesforce + # slumber + # zeep +requests-file==1.5.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +requests-mock==1.11.0 + # via -r scripts/user_retirement/requirements/testing.in +requests-toolbelt==1.0.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # zeep +responses==0.24.1 + # via + # -r scripts/user_retirement/requirements/testing.in + # moto +rsa==4.9 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-auth +s3transfer==0.10.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # boto3 +simple-salesforce==1.12.5 + # via -r scripts/user_retirement/requirements/base.txt +simplejson==3.19.2 + # via -r scripts/user_retirement/requirements/base.txt +six==1.16.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # isodate + # jenkinsapi + # python-dateutil + # requests-file + # requests-mock +slumber==0.7.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-rest-api-client +sqlparse==0.4.4 + # via + # -r scripts/user_retirement/requirements/base.txt + # django +stevedore==5.1.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # edx-django-utils +time-machine==2.13.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # pendulum +tomli==2.0.1 + # via pytest +typing-extensions==4.9.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # asgiref +tzdata==2023.4 + # via + # -r scripts/user_retirement/requirements/base.txt + # pendulum +unicodecsv==0.14.1 + # via -r scripts/user_retirement/requirements/base.txt +uritemplate==4.1.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # google-api-python-client +urllib3==1.26.18 + # via + # -r scripts/user_retirement/requirements/base.txt + # botocore + # requests + # responses +werkzeug==3.0.1 + # via moto +xmltodict==0.13.0 + # via moto +zeep==4.2.1 + # via + # -r scripts/user_retirement/requirements/base.txt + # simple-salesforce +zipp==3.17.0 + # via + # -r scripts/user_retirement/requirements/base.txt + # importlib-resources diff --git a/scripts/user_retirement/retire_one_learner.py b/scripts/user_retirement/retire_one_learner.py new file mode 100755 index 000000000000..2d298c0729b4 --- /dev/null +++ 
b/scripts/user_retirement/retire_one_learner.py @@ -0,0 +1,224 @@ +#! /usr/bin/env python3 +""" +Command-line script to drive the user retirement workflow for a single user + +To run this script you will need a username to run against and a YAML config file in the format: + +client_id: +client_secret: +base_urls: + lms: http://localhost:18000/ + ecommerce: http://localhost:18130/ + credentials: http://localhost:18150/ + demographics: http://localhost:18360/ +retirement_pipeline: + - ['RETIRING_CREDENTIALS', 'CREDENTIALS_COMPLETE', 'CREDENTIALS', 'retire_learner'] + - ['RETIRING_ECOM', 'ECOM_COMPLETE', 'ECOMMERCE', 'retire_learner'] + - ['RETIRING_DEMOGRAPHICS', 'DEMOGRAPHICS_COMPLETE', 'DEMOGRAPHICS', 'retire_learner'] + - ['RETIRING_LICENSE_MANAGER', 'LICENSE_MANAGER_COMPLETE', 'LICENSE_MANAGER', 'retire_learner'] + - ['RETIRING_FORUMS', 'FORUMS_COMPLETE', 'LMS', 'retirement_retire_forum'] + - ['RETIRING_EMAIL_LISTS', 'EMAIL_LISTS_COMPLETE', 'LMS', 'retirement_retire_mailings'] + - ['RETIRING_ENROLLMENTS', 'ENROLLMENTS_COMPLETE', 'LMS', 'retirement_unenroll'] + - ['RETIRING_LMS', 'LMS_COMPLETE', 'LMS', 'retirement_lms_retire'] +""" + +import logging +import sys +from functools import partial +from os import path +from time import time + +import click + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..'))) + +from scripts.user_retirement.utils.exception import HttpDoesNotExistException +# pylint: disable=wrong-import-position +from scripts.user_retirement.utils.helpers import ( + _config_or_exit, + _fail, + _fail_exception, + _get_error_str_from_exception, + _log, + _setup_all_apis_or_exit +) + +# Return codes for various fail cases +ERR_SETUP_FAILED = -1 +ERR_USER_AT_END_STATE = -2 +ERR_USER_IN_WORKING_STATE = -3 +ERR_WHILE_RETIRING = -4 +ERR_BAD_LEARNER = -5 +ERR_UNKNOWN_STATE = -6 +ERR_BAD_CONFIG = -7 + +SCRIPT_SHORTNAME = 'Learner Retirement' +LOG = partial(_log, SCRIPT_SHORTNAME) +FAIL = partial(_fail, SCRIPT_SHORTNAME) +FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME) +CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG) +SETUP_ALL_APIS_OR_EXIT = partial(_setup_all_apis_or_exit, FAIL_EXCEPTION, ERR_SETUP_FAILED) + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# "Magic" states with special meaning, these are required to be in LMS +START_STATE = 'PENDING' +ERROR_STATE = 'ERRORED' +COMPLETE_STATE = 'COMPLETE' +ABORTED_STATE = 'ABORTED' +END_STATES = (ERROR_STATE, ABORTED_STATE, COMPLETE_STATE) + +# We'll store the access token here once retrieved +AUTH_HEADER = {} + + +def _get_learner_state_index_or_exit(learner, config): + """ + Returns the index in the ALL_STATES retirement state list, validating that it is in + an appropriate state to work on. + """ + try: + learner_state = learner['current_state']['state_name'] + learner_state_index = config['all_states'].index(learner_state) + + if learner_state in END_STATES: + FAIL(ERR_USER_AT_END_STATE, 'User already in end state: {}'.format(learner_state)) + + if learner_state in config['working_states']: + FAIL(ERR_USER_IN_WORKING_STATE, 'User is already in a working state! 
{}'.format(learner_state)) + + return learner_state_index + except KeyError: + FAIL(ERR_BAD_LEARNER, 'Bad learner response missing current_state or state_name: {}'.format(learner)) + except ValueError: + FAIL(ERR_UNKNOWN_STATE, 'Unknown learner retirement state for learner: {}'.format(learner)) + + +def _config_retirement_pipeline(config): + """ + Organizes the pipeline and populates the various state types. + """ + # List of states where an API call is currently in progress + retirement_pipeline = config['retirement_pipeline'] + config['working_states'] = [state[0] for state in retirement_pipeline] + + # Create the full list of all of our states + config['all_states'] = [START_STATE] + for working in config['retirement_pipeline']: + config['all_states'].append(working[0]) + config['all_states'].append(working[1]) + for end in END_STATES: + config['all_states'].append(end) + + +def _get_learner_and_state_index_or_exit(config, username): + """ + Double-checks the current learner state, contacting LMS, and maps that state to its + index in the pipeline. Exits if the learner is in an invalid state or not found + in LMS. + """ + try: + learner = config['LMS'].get_learner_retirement_state(username) + learner_state_index = _get_learner_state_index_or_exit(learner, config) + return learner, learner_state_index + except HttpDoesNotExistException: + FAIL(ERR_BAD_LEARNER, 'Learner {} not found. Please check that the learner is present in ' + 'UserRetirementStatus, is not already retired, ' + 'and is in an appropriate state to be acted upon.'.format(username)) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_SETUP_FAILED, 'Unexpected error fetching user state!', str(exc)) + + +def _get_ecom_segment_id(config, learner): + """ + Calls Ecommerce to get the ecom-specific Segment tracking id that we need to retire. + This is only available from Ecommerce, unfortunately, and makes more sense to handle + here than to pass all of the config down to SegmentApi. + """ + try: + return config['ECOMMERCE'].get_tracking_key(learner) + except HttpDoesNotExistException: + LOG('Learner {} not found in Ecommerce. Setting Ecommerce Segment ID to None'.format(learner)) + return None + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_SETUP_FAILED, 'Unexpected error fetching Ecommerce tracking id!', str(exc)) + + +@click.command("retire_learner") +@click.option( + '--username', + help='The original username of the user to retire' +) +@click.option( + '--config_file', + help='File in which YAML config exists that overrides all other params.' 
+) +def retire_learner( + username, + config_file +): + """ + Retrieves a JWT token as the retirement service user, then performs the retirement process as + defined in the retirement_pipeline section of the config file. + """ + LOG('Starting learner retirement for {} using config file {}'.format(username, config_file)) + + if not config_file: + FAIL(ERR_BAD_CONFIG, 'No config file passed in.') + + config = CONFIG_OR_EXIT(config_file) + _config_retirement_pipeline(config) + SETUP_ALL_APIS_OR_EXIT(config) + + learner, learner_state_index = _get_learner_and_state_index_or_exit(config, username) + + if config.get('fetch_ecommerce_segment_id', False): + learner['ecommerce_segment_id'] = _get_ecom_segment_id(config, learner) + + start_state = None + try: + for start_state, end_state, service, method in config['retirement_pipeline']: + # Skip anything that has already been done + if config['all_states'].index(start_state) < learner_state_index: + LOG('State {} completed in previous run, skipping'.format(start_state)) + continue + + LOG('Starting state {}'.format(start_state)) + + config['LMS'].update_learner_retirement_state(username, start_state, 'Starting: {}'.format(start_state)) + + # This does the actual API call + start_time = time() + response = getattr(config[service], method)(learner) + end_time = time() + + LOG('State {} completed in {} seconds'.format(start_state, end_time - start_time)) + + config['LMS'].update_learner_retirement_state( + username, + end_state, + 'Ending: {} with response:\n{}'.format(end_state, response) + ) + + learner_state_index += 1 + + LOG('Progressing to state {}'.format(end_state)) + + config['LMS'].update_learner_retirement_state(username, COMPLETE_STATE, 'Learner retirement complete.') + LOG('Retirement complete for learner {}'.format(username)) + except Exception as exc: # pylint: disable=broad-except + exc_msg = _get_error_str_from_exception(exc) + + try: + LOG('Error in retirement state {}: {}'.format(start_state, exc_msg)) + config['LMS'].update_learner_retirement_state(username, ERROR_STATE, exc_msg) + except Exception as update_exc: # pylint: disable=broad-except + LOG('Critical error attempting to change learner state to ERRORED: {}'.format(update_exc)) + + FAIL_EXCEPTION(ERR_WHILE_RETIRING, 'Error encountered in state "{}"'.format(start_state), exc) + + +if __name__ == '__main__': + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + retire_learner(auto_envvar_prefix='RETIREMENT') diff --git a/scripts/user_retirement/retirement_archive_and_cleanup.py b/scripts/user_retirement/retirement_archive_and_cleanup.py new file mode 100644 index 000000000000..d44a5e9fa8c9 --- /dev/null +++ b/scripts/user_retirement/retirement_archive_and_cleanup.py @@ -0,0 +1,329 @@ +#! 
/usr/bin/env python3 +""" +Command-line script to bulk archive and clean up retired learners from LMS +""" + +import datetime +import gzip +import json +import logging +import sys +import time +from functools import partial +from os import path + +import backoff +import boto3 +import click +from botocore.exceptions import BotoCoreError, ClientError +from six import text_type + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..'))) + +# pylint: disable=wrong-import-position +from scripts.user_retirement.utils.helpers import _config_or_exit, _fail, _fail_exception, _log, _setup_lms_api_or_exit + +SCRIPT_SHORTNAME = 'Archive and Cleanup' + +# Return codes for various fail cases +ERR_NO_CONFIG = -1 +ERR_BAD_CONFIG = -2 +ERR_FETCHING = -3 +ERR_ARCHIVING = -4 +ERR_DELETING = -5 +ERR_SETUP_FAILED = -5 +ERR_BAD_CLI_PARAM = -6 + +LOG = partial(_log, SCRIPT_SHORTNAME) +FAIL = partial(_fail, SCRIPT_SHORTNAME) +FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME) +CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG) +SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED) + +DELAY = 10 + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logging.getLogger('boto').setLevel(logging.INFO) + + +def _fetch_learners_to_archive_or_exit(config, start_date, end_date, initial_state): + """ + Makes the call to fetch learners to be cleaned up, returns the list of learners or exits. + """ + LOG('Fetching users in state {} created from {} to {}'.format(initial_state, start_date, end_date)) + try: + learners = config['LMS'].get_learners_by_date_and_status(initial_state, start_date, end_date) + LOG('Successfully fetched {} learners'.format(str(len(learners)))) + return learners + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_FETCHING, 'Unexpected error occurred fetching users to update!', exc) + + +def _batch_learners(learners=None, batch_size=None): + """ + To avoid potentially overwhelming the LMS with a large number of user retirements to + delete, create a list of smaller batches of users to iterate over. This has the + added benefit of reducing the number of user retirement archive requests that can + get into a bad state should this script experience an error. + + Args: + learners (list): List of learners to portion into smaller batches (lists) + batch_size (int): The number of learners to portion into each batch. If this + parameter is not supplied, this function will return one batch containing + all of the learners supplied to it. + """ + if batch_size: + return [ + learners[i:i + batch_size] for i, _ in list(enumerate(learners))[::batch_size] + ] + else: + return [learners] + + +def _on_s3_backoff(details): + """ + Callback that is called when backoff... backs off + """ + LOG("Backing off {wait:0.1f} seconds after {tries} tries calling function {target}".format(**details)) + + +@backoff.on_exception( + backoff.expo, + ( + ClientError, + BotoCoreError + ), + on_backoff=lambda details: _on_s3_backoff(details), # pylint: disable=unnecessary-lambda + max_time=120, # 2 minutes +) +def _upload_to_s3(config, filename, dry_run=False): + """ + Upload the archive file to S3 + """ + try: + datestr = datetime.datetime.now().strftime('%Y/%m/') + s3 = boto3.resource('s3') + bucket_name = config['s3_archive']['bucket_name'] + # Dry runs of this script should only generate the retirement archive file, not push it to s3. 
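+ # The key computed below follows the layout raw/YYYY/MM/<archive filename>, + # partitioning archives by upload month for the AWS Athena queries + # described in _archive_retirements_or_exit.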
+ bucket = s3.Bucket(bucket_name) + key = 'raw/' + datestr + filename + if dry_run: + LOG('Dry run. Skipping the step to upload data to {}'.format(key)) + return + else: + bucket.upload_file(filename, key) + LOG('Successfully uploaded retirement data to {}'.format(key)) + except Exception as exc: + LOG(text_type(exc)) + raise + + +def _format_datetime_for_athena(timestamp): + """ + Takes a JSON serialized timestamp string and returns a format of it that is queryable as a datetime in Athena + """ + return timestamp.replace('T', ' ').rstrip('Z') + + +def _archive_retirements_or_exit(config, learners, dry_run=False): + """ + Creates an archive file with all of the retirements and uploads it to S3 + + The format of learners from LMS should be a list of these: + { + 'id': 46, # This is the UserRetirementStatus ID! + 'user': + { + 'id': 5213599, # THIS is the LMS User ID + 'username': 'retired__user_88ad587896920805c26041a2e75c767c75471ee9', + 'email': 'retired__user_d08919da55a0e03c032425567e4a33e860488a96@retired.invalid', + 'profile': + { + 'id': 2842382, + 'name': '' + } + }, + 'current_state': + { + 'id': 41, + 'state_name': 'COMPLETE', + 'state_execution_order': 13 + }, + 'last_state': { + 'id': 1, + 'state_name': 'PENDING', + 'state_execution_order': 1 + }, + 'created': '2018-10-18T20:35:52.349757Z', # This is the UserRetirementStatus creation date + 'modified': '2018-10-18T20:35:52.350050Z', # This is the UserRetirementStatus last touched date + 'original_username': 'retirement_test', + 'original_email': 'orig@foo.invalid', + 'original_name': 'Retirement Test', + 'retired_username': 'retired__user_88ad587896920805c26041a2e75c767c75471ee9', + 'retired_email': 'retired__user_d08919da55a0e03c032425567e4a33e860488a96@retired.invalid' + } + """ + LOG('Archiving retirements for {} learners to {}'.format(len(learners), config['s3_archive']['bucket_name'])) + try: + now = _get_utc_now() + filename = 'retirement_archive_{}.json.gz'.format(now.strftime('%Y_%d_%m_%H_%M_%S')) + LOG('Creating retirement archive file {}'.format(filename)) + + # The file format is one JSON object per line with the newline as a separator. This allows for + # easy queries via AWS Athena if we need to confirm learner deletion. + with gzip.open(filename, 'wt') as out: + for learner in learners: + user = { + 'user_id': learner['user']['id'], + 'original_username': learner['original_username'], + 'original_email': learner['original_email'], + 'original_name': learner['original_name'], + 'retired_username': learner['retired_username'], + 'retired_email': learner['retired_email'], + 'retirement_request_date': _format_datetime_for_athena(learner['created']), + 'last_modified_date': _format_datetime_for_athena(learner['modified']), + } + json.dump(user, out) + out.write("\n") + if dry_run: + LOG('Dry run. 
Logging the contents of {} for debugging'.format(filename)) + with gzip.open(filename, 'r') as archive_file: + for line in archive_file.readlines(): + LOG(line) + _upload_to_s3(config, filename, dry_run) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_ARCHIVING, 'Unexpected error occurred archiving retirements!', exc) + + +def _cleanup_retirements_or_exit(config, learners): + """ + Bulk deletes the retirements for this run + """ + LOG('Cleaning up retirements for {} learners'.format(len(learners))) + try: + usernames = [l['original_username'] for l in learners] + config['LMS'].bulk_cleanup_retirements(usernames) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_DELETING, 'Unexpected error occurred deleting retirements!', exc) + + +def _get_utc_now(): + """ + Helper function only used to make unit test mocking/patching easier. + """ + return datetime.datetime.utcnow() + + +@click.command("archive_and_cleanup") +@click.option( + '--config_file', + help='YAML file that contains retirement-related configuration for this environment.' +) +@click.option( + '--cool_off_days', + help='Number of days a retirement should exist before being archived and deleted.', + type=int, + default=37 # 7 days before retirement, 30 after +) +@click.option( + '--dry_run', + help=''' + Should this script be run in a dry-run mode, in which generated retirement + archive files are not pushed to s3 and retirements are not cleaned up in the LMS + ''', + type=bool, + default=False +) +@click.option( + '--start_date', + help=''' + Start of window used to select user retirements for archival. Only user retirements + added to the retirement queue after this date will be processed. + ''', + type=click.DateTime(formats=['%Y-%m-%d']) +) +@click.option( + '--end_date', + help=''' + End of window used to select user retirements for archival. Only user retirements + added to the retirement queue before this date will be processed. If this date + is more recent than `cool_off_days` days before now, an error will be thrown. + If this parameter is not used, the script will default to + using an end_date based upon the `cool_off_days` parameter. 
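+ For example, with the default cool_off_days of 37, any end_date passed here + must be at least 37 days in the past.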
+ ''', + type=click.DateTime(formats=['%Y-%m-%d']) +) +@click.option( + '--batch_size', + help='Number of user retirements to process in each batch.', + type=int +) +def archive_and_cleanup(config_file, cool_off_days, dry_run, start_date, end_date, batch_size): + """ + Cleans up UserRetirementStatus rows in LMS by: + 1- Getting all rows currently in COMPLETE that were created --cool_off_days ago or more, + unless a specific timeframe is specified + 2- Archiving them to S3 in an Athena-queryable format + 3- Deleting them from LMS (by username) + """ + try: + LOG('Starting archive and cleanup script: Config: {}'.format(config_file)) + + if not config_file: + FAIL(ERR_NO_CONFIG, 'No config file passed in.') + + config = CONFIG_OR_EXIT(config_file) + SETUP_LMS_OR_EXIT(config) + + if not start_date: + # This date is just a bogus "earliest possible value" since the call requires one + start_date = datetime.datetime.strptime('2018-01-01', '%Y-%m-%d') + if end_date: + if end_date > _get_utc_now() - datetime.timedelta(days=cool_off_days): + FAIL(ERR_BAD_CLI_PARAM, 'End date cannot occur within the cool_off_days period') + else: + # Set an end_date of `cool_off_days` days before the time that this script is run + end_date = _get_utc_now() - datetime.timedelta(days=cool_off_days) + + if start_date >= end_date: + FAIL(ERR_BAD_CLI_PARAM, 'Conflicting start and end dates passed on CLI') + + LOG( + 'Fetching retirements for learners that have a COMPLETE status and were created ' + 'between {} and {}.'.format( + start_date, end_date + ) + ) + learners = _fetch_learners_to_archive_or_exit( + config, start_date, end_date, 'COMPLETE' + ) + + learners_to_process = _batch_learners(learners, batch_size) + num_batches = len(learners_to_process) + + if learners_to_process: + for index, batch in enumerate(learners_to_process): + LOG( + 'Processing batch {} out of {} of user retirement requests'.format( + str(index + 1), str(num_batches) + ) + ) + _archive_retirements_or_exit(config, batch, dry_run) + + if dry_run: + LOG('This is a dry-run. Exiting before any retirements are cleaned up') + else: + _cleanup_retirements_or_exit(config, batch) + LOG('Archive and cleanup complete for batch #{}'.format(str(index + 1))) + time.sleep(DELAY) + else: + LOG('No learners found!') + except Exception as exc: + LOG(text_type(exc)) + raise + + +if __name__ == '__main__': + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + archive_and_cleanup(auto_envvar_prefix='RETIREMENT') diff --git a/scripts/user_retirement/retirement_bulk_status_update.py b/scripts/user_retirement/retirement_bulk_status_update.py new file mode 100755 index 000000000000..da8b05879373 --- /dev/null +++ b/scripts/user_retirement/retirement_bulk_status_update.py @@ -0,0 +1,146 @@ +#! 
/usr/bin/env python3 +""" +Command-line script to bulk update retirement states in LMS +""" + +import logging +import sys +from datetime import datetime +from functools import partial +from os import path + +import click +from six import text_type + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(path.abspath(path.join(path.dirname(__file__), '../..'))) + +# pylint: disable=wrong-import-position +from scripts.user_retirement.utils.helpers import _config_or_exit, _fail, _fail_exception, _log, _setup_lms_api_or_exit + +SCRIPT_SHORTNAME = 'Bulk Status' + +# Return codes for various fail cases +ERR_NO_CONFIG = -1 +ERR_BAD_CONFIG = -2 +ERR_FETCHING = -3 +ERR_UPDATING = -4 +ERR_SETUP_FAILED = -5 + +LOG = partial(_log, SCRIPT_SHORTNAME) +FAIL = partial(_fail, SCRIPT_SHORTNAME) +FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME) +CONFIG_OR_EXIT = partial(_config_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG) +SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED) + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + + +def validate_dates(_, __, value): + """ + Click input validator for date options. + - Validates string format + - Transforms the string into a datetime.Date object + - Validates the date is less than or equal to today + - Returns the Date, or raises a click.BadParameter + """ + try: + date = datetime.strptime(value, '%Y-%m-%d').date() + if date > datetime.now().date(): + raise ValueError() + return date + except ValueError: + raise click.BadParameter('Dates need to be in the format of YYYY-MM-DD and today or earlier.') + + +def _fetch_learners_to_update_or_exit(config, start_date, end_date, initial_state): + """ + Makes the call to fetch learners to be bulk updated, returns the list of learners + or exits. + """ + LOG('Fetching users in state {} created from {} to {}'.format(initial_state, start_date, end_date)) + try: + return config['LMS'].get_learners_by_date_and_status(initial_state, start_date, end_date) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_FETCHING, 'Unexpected error occurred fetching users to update!', exc) + + +def _update_learners_or_exit(config, learners, new_state=None, rewind_state=False): + """ + Iterates the list of learners, setting each to the new state. On any error + it will exit the script. If rewind_state is set to True then the learner + will be reset to their previous state. + """ + if (not new_state and not rewind_state) or (rewind_state and new_state): + FAIL(ERR_BAD_CONFIG, "You must specify either the boolean rewind_state or a new state to set learners to.") + LOG('Updating {} learners to {}'.format(len(learners), new_state)) + try: + for learner in learners: + if rewind_state: + new_state = learner['last_state']['state_name'] + config['LMS'].update_learner_retirement_state( + learner['original_username'], + new_state, + 'Force updated via retirement_bulk_status_update script', + force=True + ) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_UPDATING, 'Unexpected error occurred updating users!', exc) + + +@click.command("update_statuses") +@click.option( + '--config_file', + help='YAML file that contains retirement-related configuration for this environment.' +) +@click.option( + '--initial_state', + help='Find learners in this retirement state. Use the state name ex: PENDING, COMPLETE' +) +@click.option( + '--new_state', + help='Set any found learners to this new state. 
Use the state name ex: PENDING, COMPLETE', + default=None +) +@click.option( + '--start_date', + callback=validate_dates, + help='(YYYY-MM-DD) Earliest creation date for retirements to act on.' +) +@click.option( + '--end_date', + callback=validate_dates, + help='(YYYY-MM-DD) Latest creation date for retirements to act on.' +) +@click.option( + '--rewind-state', + help='Rewinds to the last_state for learners. Useful for resetting ERRORED users', + default=False, + is_flag=True +) +def update_statuses(config_file, initial_state, new_state, start_date, end_date, rewind_state): + """ + Bulk-updates user retirement statuses which are in the specified state -and- retirement was + requested between a start date and end date. + """ + try: + LOG('Starting bulk update script: Config: {}'.format(config_file)) + + if not config_file: + FAIL(ERR_NO_CONFIG, 'No config file passed in.') + + config = CONFIG_OR_EXIT(config_file) + SETUP_LMS_OR_EXIT(config) + + learners = _fetch_learners_to_update_or_exit(config, start_date, end_date, initial_state) + _update_learners_or_exit(config, learners, new_state, rewind_state) + + LOG('Bulk update complete') + except Exception as exc: + print(text_type(exc)) + raise + + +if __name__ == '__main__': + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + update_statuses(auto_envvar_prefix='RETIREMENT') diff --git a/scripts/user_retirement/retirement_partner_report.py b/scripts/user_retirement/retirement_partner_report.py new file mode 100755 index 000000000000..53bcfa685660 --- /dev/null +++ b/scripts/user_retirement/retirement_partner_report.py @@ -0,0 +1,404 @@ +#! /usr/bin/env python3 +# coding=utf-8 + +""" +Command-line script to drive the partner reporting part of the retirement process +""" + +import logging +import os +import sys +import unicodedata +from collections import OrderedDict, defaultdict +from datetime import date +from functools import partial + +import click +import unicodecsv as csv +from six import text_type + +# Add top-level project path to sys.path before importing scripts code +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from scripts.user_retirement.utils.thirdparty_apis.google_api import DriveApi # pylint: disable=wrong-import-position +# pylint: disable=wrong-import-position +from scripts.user_retirement.utils.helpers import ( + _config_with_drive_or_exit, + _fail, + _fail_exception, + _log, + _setup_lms_api_or_exit +) + +# Return codes for various fail cases +ERR_SETUP_FAILED = -1 +ERR_FETCHING_LEARNERS = -2 +ERR_NO_CONFIG = -3 +ERR_NO_SECRETS = -4 +ERR_NO_OUTPUT_DIR = -5 +ERR_BAD_CONFIG = -6 +ERR_BAD_SECRETS = -7 +ERR_UNKNOWN_ORG = -8 +ERR_REPORTING = -9 +ERR_DRIVE_UPLOAD = -10 +ERR_CLEANUP = -11 +ERR_DRIVE_LISTING = -12 + +SCRIPT_SHORTNAME = 'Partner report' +LOG = partial(_log, SCRIPT_SHORTNAME) +FAIL = partial(_fail, SCRIPT_SHORTNAME) +FAIL_EXCEPTION = partial(_fail_exception, SCRIPT_SHORTNAME) +CONFIG_WITH_DRIVE_OR_EXIT = partial(_config_with_drive_or_exit, FAIL_EXCEPTION, ERR_BAD_CONFIG, ERR_BAD_SECRETS) +SETUP_LMS_OR_EXIT = partial(_setup_lms_api_or_exit, FAIL, ERR_SETUP_FAILED) + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# Prefix which starts all generated report filenames. +REPORTING_FILENAME_PREFIX = 'user_retirement' + +# We'll store the access token here once retrieved +AUTH_HEADER = {} + +# This text template will be the comment body for all new CSV uploads. 
The +# following format variables need to be provided: +# tags: space delimited list of google user tags, e.g. "+user1@gmail.com +user2@gmail.com" +NOTIFICATION_MESSAGE_TEMPLATE = """ +Hello from edX. Dear {tags}, a new report listing the learners enrolled in your institution’s courses on edx.org that have requested deletion of their edX account and associated personal data within the last week has been published to Google Drive. Please access your folder to see the latest report. +""".strip() + +LEARNER_CREATED_KEY = 'created' # This key is currently required to exist in the learner +LEARNER_ORIGINAL_USERNAME_KEY = 'original_username' # This key is currently required to exist in the learner +ORGS_KEY = 'orgs' +ORGS_CONFIG_KEY = 'orgs_config' +ORGS_CONFIG_ORG_KEY = 'org' +ORGS_CONFIG_FIELD_HEADINGS_KEY = 'field_headings' +ORGS_CONFIG_LEARNERS_KEY = 'learners' + +# Default field headings for the CSV file +DEFAULT_FIELD_HEADINGS = ['user_id', 'original_username', 'original_email', 'original_name', 'deletion_completed'] + + +def _check_all_learner_orgs_or_exit(config, learners): + """ + Checks all learners and their orgs, ensuring that each org has a mapping to a partner Drive folder. + If any orgs are missing a mapping, fails after printing the mismatched orgs. + """ + # Loop through all learner orgs, checking for their mappings. + mismatched_orgs = set() + for learner in learners: + # Check the orgs with standard fields + if ORGS_KEY in learner: + for org in learner[ORGS_KEY]: + if org not in config['org_partner_mapping']: + mismatched_orgs.add(org) + + # Check the orgs with custom configurations (orgs with custom fields) + if ORGS_CONFIG_KEY in learner: + for org_config in learner[ORGS_CONFIG_KEY]: + org_name = org_config[ORGS_CONFIG_ORG_KEY] + if org_name not in config['org_partner_mapping']: + mismatched_orgs.add(org_name) + if mismatched_orgs: + FAIL( + ERR_UNKNOWN_ORG, + 'Partners for organizations {} do not exist in configuration.'.format(text_type(mismatched_orgs)) + ) + + +def _get_orgs_and_learners_or_exit(config): + """ + Contacts LMS to get the list of learners to report on and the orgs they belong to. + Reformats them into dicts with keys of the orgs and lists of learners as the value + and returns a tuple of that dict plus a list of all of the learner usernames. + """ + try: + LOG('Retrieving all learners on which to report from the LMS.') + learners = config['LMS'].retirement_partner_report() + LOG('Retrieved {} learners from the LMS.'.format(len(learners))) + + _check_all_learner_orgs_or_exit(config, learners) + + orgs = defaultdict() + usernames = [] + + # Organize the learners, create separate dicts per partner, making sure each partner is in the mapping. + # Learners can appear in more than one dict. It is assumed that each org has 1 and only 1 set of field headings. + for learner in learners: + usernames.append({'original_username': learner[LEARNER_ORIGINAL_USERNAME_KEY]}) + + # Use the datetime upon which the record was 'created' in the partner reporting queue + # as the approximate time upon which user retirement was completed ('deletion_completed') + # for the record's user. 
+ learner['deletion_completed'] = learner[LEARNER_CREATED_KEY] + + # Create a list of orgs who should be notified about this user + if ORGS_KEY in learner: + for org_name in learner[ORGS_KEY]: + reporting_org_names = config['org_partner_mapping'][org_name] + _add_reporting_org(orgs, reporting_org_names, DEFAULT_FIELD_HEADINGS, learner) + + # Check for orgs with custom fields + if ORGS_CONFIG_KEY in learner: + for org_config in learner[ORGS_CONFIG_KEY]: + org_name = org_config[ORGS_CONFIG_ORG_KEY] + org_headings = org_config[ORGS_CONFIG_FIELD_HEADINGS_KEY] + reporting_org_names = config['org_partner_mapping'][org_name] + _add_reporting_org(orgs, reporting_org_names, org_headings, learner) + + return orgs, usernames + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_FETCHING_LEARNERS, 'Unexpected exception occurred!', exc) + + +def _add_reporting_org(orgs, org_names, org_headings, learner): + """ + Add the learner to the org + """ + for org_name in org_names: + # Create the org, if necessary + orgs[org_name] = orgs.get( + org_name, + { + ORGS_CONFIG_FIELD_HEADINGS_KEY: org_headings, + ORGS_CONFIG_LEARNERS_KEY: [] + } + ) + + # Add the learner to the list of learners in the org + orgs[org_name][ORGS_CONFIG_LEARNERS_KEY].append(learner) + + +def _generate_report_files_or_exit(config, report_data, output_dir): + """ + Spins through the partners, creating a single CSV file for each + """ + # We'll store all of the partner to file links here so we can be sure all files generated successfully + # before trying to push to Google, minimizing the cases where we might have to overwrite files + # already up there. + partner_filenames = {} + + for partner_name in report_data: + try: + partner = report_data[partner_name] + partner_headings = partner[ORGS_CONFIG_FIELD_HEADINGS_KEY] + partner_learners = partner[ORGS_CONFIG_LEARNERS_KEY] + outfile = _generate_report_file_or_exit(config, output_dir, partner_name, partner_headings, + partner_learners) + partner_filenames[partner_name] = outfile + LOG('Report complete for partner {}'.format(partner_name)) + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_REPORTING, 'Error reporting retirement for partner {}'.format(partner_name), exc) + + return partner_filenames + + +def _generate_report_file_or_exit(config, output_dir, partner, field_headings, field_values): + """ + Create a CSV file for the partner + """ + LOG('Starting report for partner {}: {} learners to add. Field headings are {}'.format( + partner, + len(field_values), + field_headings + )) + + outfile = os.path.join(output_dir, '{}_{}_{}_{}.csv'.format( + REPORTING_FILENAME_PREFIX, config['partner_report_platform_name'], partner, date.today().isoformat() + )) + + # If there is already a file for this date, assume it is bad and replace it + try: + os.remove(outfile) + except OSError: + pass + + with open(outfile, 'wb') as f: + writer = csv.DictWriter(f, field_headings, dialect=csv.excel, extrasaction='ignore') + writer.writeheader() + writer.writerows(field_values) + + return outfile + + +def _config_drive_folder_map_or_exit(config): + """ + Lists folders under our top level parent for this environment and returns + a dict of {partner name: folder id}. 
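+    (For example, a returned mapping might look like {'Org1X': 'folder-id-1a2b3c'},
+    with hypothetical folder ids.)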
+    Partner names should match the values
+    in config['org_partner_mapping']
+    """
+    drive = DriveApi(config['google_secrets_file'])
+
+    try:
+        LOG('Attempting to find all partner sub-directories on Drive.')
+        folders = drive.walk_files(
+            config['drive_partners_folder'],
+            mimetype='application/vnd.google-apps.folder',
+            recurse=False
+        )
+    except Exception as exc:  # pylint: disable=broad-except
+        FAIL_EXCEPTION(ERR_DRIVE_LISTING, 'Finding partner directories on Drive failed.', exc)
+
+    if not folders:
+        FAIL(ERR_DRIVE_LISTING, 'Finding partner directories on Drive failed. Check your permissions.')
+
+    # As in _config_or_exit we force normalize the unicode here to make sure the keys
+    # match. Otherwise the name we get back from Google won't match what's in the YAML config.
+    config['partner_folder_mapping'] = OrderedDict()
+    for folder in folders:
+        folder['name'] = unicodedata.normalize('NFKC', text_type(folder['name']))
+        config['partner_folder_mapping'][folder['name']] = folder['id']
+
+
+def _push_files_to_google(config, partner_filenames):
+    """
+    Copy each partner's report file to that partner's Google Drive folder.
+
+    Returns:
+        Dict mapping each partner name to the Drive file ID of its uploaded csv file.
+    """
+    # First make sure we have Drive folders for all partners
+    failed_partners = []
+    for partner in partner_filenames:
+        if partner not in config['partner_folder_mapping']:
+            failed_partners.append(partner)
+
+    if failed_partners:
+        FAIL(ERR_BAD_CONFIG, 'These partners have retiring learners, but no Drive folder: {}'.format(failed_partners))
+
+    file_ids = {}
+    drive = DriveApi(config['google_secrets_file'])
+    for partner in partner_filenames:
+        # This is populated on the fly in _config_drive_folder_map_or_exit
+        folder_id = config['partner_folder_mapping'][partner]
+        file_id = None
+        with open(partner_filenames[partner], 'rb') as f:
+            try:
+                drive_filename = os.path.basename(partner_filenames[partner])
+                LOG('Attempting to upload {} to {} Drive folder.'.format(drive_filename, partner))
+                file_id = drive.create_file_in_folder(folder_id, drive_filename, f, "text/csv")
+            except Exception as exc:  # pylint: disable=broad-except
+                FAIL_EXCEPTION(ERR_DRIVE_UPLOAD, 'Drive upload failed for: {}'.format(drive_filename), exc)
+        file_ids[partner] = file_id
+    return file_ids
+
+
+def _add_comments_to_files(config, file_ids):
+    """
+    Add comments to the uploaded csv files, triggering email notification.
+
+    Args:
+        file_ids (dict): Mapping of partner names to Drive file IDs corresponding to the newly uploaded csv files.
+    """
+    drive = DriveApi(config['google_secrets_file'])
+
+    partner_folders_to_permissions = drive.list_permissions_for_files(
+        config['partner_folder_mapping'].values(),
+        fields='emailAddress',
+    )
+
+    # create a mapping of partners to a list of permissions dicts:
+    permissions = {
+        partner: partner_folders_to_permissions[config['partner_folder_mapping'][partner]]
+        for partner in file_ids
+    }
+
+    # throw out all denied addresses, and flatten the permissions dicts to just the email:
+    external_emails = {
+        partner: [
+            perm['emailAddress']
+            for perm in permissions[partner]
+            if not any(
+                perm['emailAddress'].lower().endswith(denied_domain.lower())
+                for denied_domain in config['denied_notification_domains']
+            )
+        ]
+        for partner in permissions
+    }
+
+    file_ids_and_comments = []
+    for partner in file_ids:
+        if not external_emails[partner]:
+            LOG(
+                'WARNING: could not find a POC for the following partner: "{}". '
+                'Double check the partner folder permissions in Google Drive.'
+ .format(partner) + ) + else: + tag_string = ' '.join('+' + email for email in external_emails[partner]) + comment_content = NOTIFICATION_MESSAGE_TEMPLATE.format(tags=tag_string) + file_ids_and_comments.append((file_ids[partner], comment_content)) + + try: + LOG('Adding notification comments to uploaded csv files.') + drive.create_comments_for_files(file_ids_and_comments) + except Exception as exc: # pylint: disable=broad-except + # do not fail the script here, since comment errors are non-critical + LOG('WARNING: there was an error adding Google Drive comments to the csv files: {}'.format(exc)) + + +@click.command("generate_report") +@click.option( + '--config_file', + help='YAML file that contains retirement related configuration for this environment.' +) +@click.option( + '--google_secrets_file', + help='JSON file with Google service account credentials for uploading.' +) +@click.option( + '--output_dir', + help='The local directory that the script will write the reports to.' +) +@click.option( + '--comments/--no_comments', + default=True, + help='Do or skip adding notification comments to the reports.' +) +def generate_report(config_file, google_secrets_file, output_dir, comments): + """ + Retrieves a JWT token as the retirement service learner, then performs the reporting process as that user. + + - Accepts the configuration file with all necessary credentials and URLs for a single environment + - Gets the users in the LMS reporting queue and the partners they need to be reported to + - Generates a single report per partner + - Pushes the reports to Google Drive + - On success tells LMS to remove the users who succeeded from the reporting queue + """ + LOG('Starting partner report using config file {} and Google config {}'.format(config_file, google_secrets_file)) + + try: + if not config_file: + FAIL(ERR_NO_CONFIG, 'No config file passed in.') + + if not google_secrets_file: + FAIL(ERR_NO_SECRETS, 'No secrets file passed in.') + + # The Jenkins DSL is supposed to create this path for us + if not output_dir or not os.path.exists(output_dir): + FAIL(ERR_NO_OUTPUT_DIR, 'No output_dir passed in or path does not exist.') + + config = CONFIG_WITH_DRIVE_OR_EXIT(config_file, google_secrets_file) + SETUP_LMS_OR_EXIT(config) + _config_drive_folder_map_or_exit(config) + report_data, all_usernames = _get_orgs_and_learners_or_exit(config) + # If no usernames were returned, then no reports need to be generated. + if all_usernames: + partner_filenames = _generate_report_files_or_exit(config, report_data, output_dir) + + # All files generated successfully, now push them to Google + report_file_ids = _push_files_to_google(config, partner_filenames) + + if comments: + # All files uploaded successfully, now add comments to them to trigger notifications + _add_comments_to_files(config, report_file_ids) + + # Success, tell LMS to remove these users from the queue + config['LMS'].retirement_partner_cleanup(all_usernames) + LOG('All reports completed and uploaded to Google.') + except Exception as exc: # pylint: disable=broad-except + FAIL_EXCEPTION(ERR_CLEANUP, 'Unexpected error occurred! 
Users may be stuck in the processing state!', exc) + + +if __name__ == '__main__': + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + generate_report(auto_envvar_prefix='RETIREMENT') diff --git a/scripts/user_retirement/tests/__init__.py b/scripts/user_retirement/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/tests/mixins.py b/scripts/user_retirement/tests/mixins.py new file mode 100644 index 000000000000..3d4664d4268f --- /dev/null +++ b/scripts/user_retirement/tests/mixins.py @@ -0,0 +1,23 @@ +from urllib.parse import urljoin + +import responses + +from scripts.user_retirement.utils import edx_api + +FAKE_ACCESS_TOKEN = 'THIS_IS_A_JWT' +CONTENT_TYPE = 'application/json' + + +class OAuth2Mixin: + @staticmethod + def mock_access_token_response(status=200): + """ + Mock POST requests to retrieve an access token for this site's service user. + """ + responses.add( + responses.POST, + urljoin('http://localhost:18000/', edx_api.OAUTH_ACCESS_TOKEN_URL), + status=status, + json={'access_token': FAKE_ACCESS_TOKEN, 'expires_in': 60}, + content_type=CONTENT_TYPE + ) diff --git a/scripts/user_retirement/tests/retirement_helpers.py b/scripts/user_retirement/tests/retirement_helpers.py new file mode 100644 index 000000000000..1347743e0c4c --- /dev/null +++ b/scripts/user_retirement/tests/retirement_helpers.py @@ -0,0 +1,166 @@ +# coding=utf-8 + +""" +Common functionality for retirement related tests +""" +import json +import unicodedata +from datetime import datetime + +import yaml + +TEST_RETIREMENT_PIPELINE = [ + ['RETIRING_FORUMS', 'FORUMS_COMPLETE', 'LMS', 'retirement_retire_forum'], + ['RETIRING_EMAIL_LISTS', 'EMAIL_LISTS_COMPLETE', 'LMS', 'retirement_retire_mailings'], + ['RETIRING_ENROLLMENTS', 'ENROLLMENTS_COMPLETE', 'LMS', 'retirement_unenroll'], + ['RETIRING_LMS', 'LMS_COMPLETE', 'LMS', 'retirement_lms_retire'] +] + +TEST_RETIREMENT_END_STATES = [state[1] for state in TEST_RETIREMENT_PIPELINE] +TEST_RETIREMENT_QUEUE_STATES = ['PENDING'] + TEST_RETIREMENT_END_STATES +TEST_RETIREMENT_STATE = 'PENDING' + +FAKE_DATETIME_OBJECT = datetime(2022, 1, 1) +FAKE_DATETIME_STR = '2022-01-01' +FAKE_ORIGINAL_USERNAME = 'foo_username' +FAKE_USERNAMES = [FAKE_ORIGINAL_USERNAME, FAKE_ORIGINAL_USERNAME] +FAKE_RESPONSE_MESSAGE = 'fake response message' +FAKE_USERNAME_MAPPING = [ + {"fake_current_username_1": "fake_desired_username_1"}, + {"fake_current_username_2": "fake_desired_username_2"} +] + +FAKE_ORGS = { + # Make sure unicode names, as they should come in from the yaml config, work + 'org1': [unicodedata.normalize('NFKC', u'TéstX')], + 'org2': ['Org2X'], + 'org3': ['Org3X', 'Org4X'], +} + +TEST_PLATFORM_NAME = 'fakename' + +TEST_DENIED_NOTIFICATION_DOMAINS = { + '@edx.org', + '@partner-reporting-automation.iam.gserviceaccount.com', +} + + +def flatten_partner_list(partner_list): + """ + Flattens a list of lists into a list. + [["Org1X"], ["Org2X"], ["Org3X", "Org4X"]] => ["Org1X", "Org2X", "Org3X", "Org4X"] + """ + return [partner for sublist in partner_list for partner in sublist] + + +def fake_config_file(f, orgs=None, fetch_ecom_segment_id=False): + """ + Create a config file for a single test. Combined with CliRunner.isolated_filesystem() to + ensure the file lifetime is limited to the test. See _call_script for usage. 
+ """ + if orgs is None: + orgs = FAKE_ORGS + + config = { + 'client_id': 'bogus id', + 'client_secret': 'supersecret', + 'base_urls': { + 'credentials': 'https://credentials.stage.edx.invalid/', + 'lms': 'https://stage-edx-edxapp.edx.invalid/', + 'ecommerce': 'https://ecommerce.stage.edx.invalid/', + 'segment': 'https://segment.invalid/graphql', + }, + 'retirement_pipeline': TEST_RETIREMENT_PIPELINE, + 'partner_report_platform_name': TEST_PLATFORM_NAME, + 'org_partner_mapping': orgs, + 'drive_partners_folder': 'FakeDriveID', + 'denied_notification_domains': TEST_DENIED_NOTIFICATION_DOMAINS, + 'sailthru_key': 'fake_sailthru_key', + 'sailthru_secret': 'fake_sailthru_secret', + 's3_archive': { + 'bucket_name': 'fake_test_bucket', + 'region': 'fake_region', + }, + 'segment_workspace_slug': 'test_slug', + 'segment_auth_token': 'fakeauthtoken', + } + + if fetch_ecom_segment_id: + config['fetch_ecommerce_segment_id'] = True + + yaml.safe_dump(config, f) + + +def get_fake_user_retirement( + retirement_id=1, + original_username="foo_username", + original_email="foo@edx.invalid", + original_name="Foo User", + retired_username="retired_user__asdf123", + retired_email="retired_user__asdf123", + ecommerce_segment_id="ecommerce-90", + user_id=9009, + current_username="foo_username", + current_email="foo@edx.invalid", + current_name="Foo User", + current_state_name="PENDING", + last_state_name="PENDING", +): + """ + Return a "learner" used in retirment in the serialized format we get from LMS. + """ + return { + "id": retirement_id, + "current_state": { + "id": 1, + "state_name": current_state_name, + "state_execution_order": 10, + }, + "last_state": { + "id": 1, + "state_name": last_state_name, + "state_execution_order": 10, + }, + "original_username": original_username, + "original_email": original_email, + "original_name": original_name, + "retired_username": retired_username, + "retired_email": retired_email, + "ecommerce_segment_id": ecommerce_segment_id, + "user": { + "id": user_id, + "username": current_username, + "email": current_email, + "profile": { + "id": 10009, + "name": current_name + } + }, + "created": "2018-10-18T20:08:03.724805", + "modified": "2018-10-18T20:08:03.724805", + } + + +def fake_google_secrets_file(f): + """ + Create a fake google secrets file for a single test. + """ + fake_private_key = """ +-----BEGIN PRIVATE KEY----- +-----END PRIVATE KEY----- + r""" + + secrets = { + "type": "service_account", + "project_id": "partner-reporting-automation", + "private_key_id": "foo", + "private_key": fake_private_key, + "client_email": "bogus@serviceacct.invalid", + "client_id": "411", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://accounts.google.com/o/oauth2/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/foo" + } + + json.dump(secrets, f) diff --git a/scripts/user_retirement/tests/test_data/uploading.txt b/scripts/user_retirement/tests/test_data/uploading.txt new file mode 100644 index 000000000000..eb2b87e692a3 --- /dev/null +++ b/scripts/user_retirement/tests/test_data/uploading.txt @@ -0,0 +1 @@ +Upload this file on s3 in tests. 
\ No newline at end of file
diff --git a/scripts/user_retirement/tests/test_get_learners_to_retire.py b/scripts/user_retirement/tests/test_get_learners_to_retire.py
new file mode 100644
index 000000000000..480de92185db
--- /dev/null
+++ b/scripts/user_retirement/tests/test_get_learners_to_retire.py
@@ -0,0 +1,159 @@
+"""
+Test the get_learners_to_retire.py script
+"""
+
+import os
+
+from click.testing import CliRunner
+from mock import DEFAULT, patch
+from requests.exceptions import HTTPError
+
+from scripts.user_retirement.get_learners_to_retire import get_learners_to_retire
+from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
+
+
+def _call_script(expected_user_files, cool_off_days=1, output_dir='test', user_count_error_threshold=200,
+                 max_user_batch_size=201):
+    """
+    Call the get_learners_to_retire script with the given parameters and a generic, temporary config file.
+    Returns the CliRunner.invoke results
+    """
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        with open('test_config.yml', 'w') as f:
+            fake_config_file(f)
+        result = runner.invoke(
+            get_learners_to_retire,
+            args=[
+                '--config_file', 'test_config.yml',
+                '--cool_off_days', cool_off_days,
+                '--output_dir', output_dir,
+                '--user_count_error_threshold', user_count_error_threshold,
+                '--max_user_batch_size', max_user_batch_size
+            ]
+        )
+        print(result)
+        print(result.output)
+
+        # expected_user_files is the number of users in the mocked call; each should have
+        # a file if the number is greater than 0, otherwise a failure is expected and the
+        # output dir should not exist
+        if expected_user_files:
+            assert len(os.listdir(output_dir)) == expected_user_files
+        else:
+            assert not os.path.exists(output_dir)
+    return result
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
+@patch.multiple(
+    'scripts.user_retirement.utils.edx_api.LmsApi',
+    learners_to_retire=DEFAULT
+)
+def test_success(*args, **kwargs):
+    mock_get_access_token = args[0]
+    mock_get_learners_to_retire = kwargs['learners_to_retire']
+
+    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
+    mock_get_learners_to_retire.return_value = [
+        get_fake_user_retirement(original_username='test_user1'),
+        get_fake_user_retirement(original_username='test_user2'),
+    ]
+
+    result = _call_script(2)
+
+    # Called once to get the LMS access token (the only API this script instantiates)
+    assert mock_get_access_token.call_count == 1
+    mock_get_learners_to_retire.assert_called_once()
+
+    assert result.exit_code == 0
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
+@patch.multiple(
+    'scripts.user_retirement.utils.edx_api.LmsApi',
+    learners_to_retire=DEFAULT
+)
+def test_lms_down(*args, **kwargs):
+    mock_get_access_token = args[0]
+    mock_get_learners_to_retire = kwargs['learners_to_retire']
+
+    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
+    mock_get_learners_to_retire.side_effect = HTTPError
+
+    result = _call_script(0)
+
+    # Called once to get the LMS access token (the only API this script instantiates)
+    assert mock_get_access_token.call_count == 1
+    mock_get_learners_to_retire.assert_called_once()
+
+    assert result.exit_code == 1
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
+@patch.multiple(
+    'scripts.user_retirement.utils.edx_api.LmsApi',
+    learners_to_retire=DEFAULT
+)
+def test_misconfigured(*args, **kwargs):
+    mock_get_access_token = args[0]
+    mock_get_learners_to_retire = kwargs['learners_to_retire']
+
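+    # A misconfigured environment surfaces the same way as an LMS outage here: the
+    # learners_to_retire call raises, the script exits non-zero, and no output dir is created.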
+    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
+    mock_get_learners_to_retire.side_effect = HTTPError
+
+    result = _call_script(0)
+
+    # Called once to get the LMS access token (the only API this script instantiates)
+    assert mock_get_access_token.call_count == 1
+    mock_get_learners_to_retire.assert_called_once()
+
+    assert result.exit_code == 1
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
+@patch.multiple(
+    'scripts.user_retirement.utils.edx_api.LmsApi',
+    learners_to_retire=DEFAULT
+)
+def test_too_many_users(*args, **kwargs):
+    mock_get_access_token = args[0]
+    mock_get_learners_to_retire = kwargs['learners_to_retire']
+
+    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
+    mock_get_learners_to_retire.return_value = [
+        get_fake_user_retirement(original_username='test_user1'),
+        get_fake_user_retirement(original_username='test_user2'),
+    ]
+
+    result = _call_script(0, user_count_error_threshold=1)
+
+    # Called once to get the LMS access token (the only API this script instantiates)
+    assert mock_get_access_token.call_count == 1
+    mock_get_learners_to_retire.assert_called_once()
+
+    assert result.exit_code == -1
+    assert 'Too many learners' in result.output
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token')
+@patch.multiple(
+    'scripts.user_retirement.utils.edx_api.LmsApi',
+    learners_to_retire=DEFAULT
+)
+def test_users_limit(*args, **kwargs):
+    mock_get_access_token = args[0]
+    mock_get_learners_to_retire = kwargs['learners_to_retire']
+
+    mock_get_access_token.return_value = ('THIS_IS_A_JWT', None)
+    mock_get_learners_to_retire.return_value = [
+        get_fake_user_retirement(original_username='test_user1'),
+        get_fake_user_retirement(original_username='test_user2'),
+    ]
+
+    result = _call_script(1, user_count_error_threshold=200, max_user_batch_size=1)
+
+    # Called once to get the LMS access token (the only API this script instantiates)
+    assert mock_get_access_token.call_count == 1
+    mock_get_learners_to_retire.assert_called_once()
+
+    assert result.exit_code == 0
diff --git a/scripts/user_retirement/tests/test_retire_one_learner.py b/scripts/user_retirement/tests/test_retire_one_learner.py
new file mode 100644
index 000000000000..a78b6b787acb
--- /dev/null
+++ b/scripts/user_retirement/tests/test_retire_one_learner.py
@@ -0,0 +1,412 @@
+"""
+Test the retire_one_learner.py script
+"""
+
+from click.testing import CliRunner
+from mock import DEFAULT, patch
+
+from scripts.user_retirement.retire_one_learner import (
+    END_STATES,
+    ERR_BAD_CONFIG,
+    ERR_BAD_LEARNER,
+    ERR_SETUP_FAILED,
+    ERR_UNKNOWN_STATE,
+    ERR_USER_AT_END_STATE,
+    ERR_USER_IN_WORKING_STATE,
+    retire_learner
+)
+from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement
+from scripts.user_retirement.utils.exception import HttpDoesNotExistException
+
+
+def _call_script(username, fetch_ecom_segment_id=False):
+    """
+    Call the retire_one_learner script with the given username and a generic, temporary config file.
+ Returns the CliRunner.invoke results + """ + runner = CliRunner() + with runner.isolated_filesystem(): + with open('test_config.yml', 'w') as f: + fake_config_file(f, fetch_ecom_segment_id=fetch_ecom_segment_id) + result = runner.invoke(retire_learner, args=['--username', username, '--config_file', 'test_config.yml']) + print(result) + print(result.output) + return result + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_lms_retire=DEFAULT +) +def test_successful_retirement(*args, **kwargs): + username = 'test_username' + + mock_get_access_token = args[1] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + mock_retire_forum = kwargs['retirement_retire_forum'] + mock_retire_mailings = kwargs['retirement_retire_mailings'] + mock_unenroll = kwargs['retirement_unenroll'] + mock_lms_retire = kwargs['retirement_lms_retire'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_retirement_state.return_value = get_fake_user_retirement(original_username=username) + + result = _call_script(username, fetch_ecom_segment_id=True) + + # Called once per API we instantiate (LMS, ECommerce, Credentials) + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + assert mock_update_learner_state.call_count == 9 + + # Called once per retirement + for mock_call in ( + mock_retire_forum, + mock_retire_mailings, + mock_unenroll, + mock_lms_retire + ): + mock_call.assert_called_once_with(mock_get_retirement_state.return_value) + + assert result.exit_code == 0 + assert 'Retirement complete' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_user_does_not_exist(*args, **kwargs): + username = 'test_username' + + mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_retirement_state.side_effect = Exception + + result = _call_script(username) + + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_SETUP_FAILED + assert 'Exception' in result.output + + +def test_bad_config(): + username = 'test_username' + runner = CliRunner() + result = runner.invoke(retire_learner, args=['--username', username, '--config_file', 'does_not_exist.yml']) + assert result.exit_code == ERR_BAD_CONFIG + assert 'does_not_exist.yml' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_bad_learner(*args, **kwargs): + username = 'test_username' + + 
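+    # A 404 (HttpDoesNotExistException) from the retirement-state fetch should map to
+    # ERR_BAD_LEARNER rather than a generic setup failure.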
mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + + # Broken API call, no state returned + mock_get_retirement_state.side_effect = HttpDoesNotExistException + result = _call_script(username) + + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_BAD_LEARNER + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_user_in_working_state(*args, **kwargs): + username = 'test_username' + + mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + current_state_name='RETIRING_FORUMS' + ) + + result = _call_script(username) + + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_USER_IN_WORKING_STATE + assert 'in a working state' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_user_in_bad_state(*args, **kwargs): + username = 'test_username' + bad_state = 'BOGUS_STATE' + mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + current_state_name=bad_state + ) + result = _call_script(username) + + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_UNKNOWN_STATE + assert bad_state in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_user_in_end_state(*args, **kwargs): + username = 'test_username' + + mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + + # pytest.parameterize doesn't play nicely with patch.multiple, this seemed more + # readable than the alternatives. 
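+    # Instead, every terminal state is exercised in turn within this one test body,
+    # with the mocks reset at the end of each loop iteration.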
+ for end_state in END_STATES: + mock_get_retirement_state.return_value = { + 'original_username': username, + 'current_state': { + 'state_name': end_state + } + } + + result = _call_script(username) + + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_USER_AT_END_STATE + assert end_state in result.output + + # Reset our call counts for the next test + mock_get_access_token.reset_mock() + mock_get_retirement_state.reset_mock() + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_lms_retire=DEFAULT +) +def test_skipping_states(*args, **kwargs): + username = 'test_username' + + mock_get_access_token = args[0] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + mock_retire_forum = kwargs['retirement_retire_forum'] + mock_retire_mailings = kwargs['retirement_retire_mailings'] + mock_unenroll = kwargs['retirement_unenroll'] + mock_lms_retire = kwargs['retirement_lms_retire'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + current_state_name='EMAIL_LISTS_COMPLETE' + ) + + result = _call_script(username) + + # Called once per API we instantiate (LMS, ECommerce, Credentials) + assert mock_get_access_token.call_count == 3 + mock_get_retirement_state.assert_called_once_with(username) + assert mock_update_learner_state.call_count == 5 + + # Skipped + for mock_call in ( + mock_retire_forum, + mock_retire_mailings + ): + mock_call.assert_not_called() + + # Called once per retirement + for mock_call in ( + mock_unenroll, + mock_lms_retire + ): + mock_call.assert_called_once_with(mock_get_retirement_state.return_value) + + assert result.exit_code == 0 + + for required_output in ( + 'RETIRING_FORUMS completed in previous run', + 'RETIRING_EMAIL_LISTS completed in previous run', + 'Starting state RETIRING_ENROLLMENTS', + 'State RETIRING_ENROLLMENTS completed', + 'Starting state RETIRING_LMS', + 'State RETIRING_LMS completed', + 'Retirement complete' + ): + assert required_output in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_lms_retire=DEFAULT +) +def test_get_segment_id_success(*args, **kwargs): + username = 'test_username' + + mock_get_tracking_key = args[0] + mock_get_access_token = args[1] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_retirement_retire_forum = kwargs['retirement_retire_forum'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_tracking_key.return_value = {'id': 1, 'ecommerce_tracking_id': 'ecommerce-1'} + + # The learner starts off with these values, 'ecommerce_segment_id' is added during script + # 
startup + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + ) + + _call_script(username, fetch_ecom_segment_id=True) + mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value) + + config_after_get_segment_id = mock_get_retirement_state.return_value + config_after_get_segment_id['ecommerce_segment_id'] = 'ecommerce-1' + + mock_retirement_retire_forum.assert_called_once_with(config_after_get_segment_id) + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_lms_retire=DEFAULT +) +def test_get_segment_id_not_found(*args, **kwargs): + username = 'test_username' + + mock_get_tracking_key = args[0] + mock_get_access_token = args[1] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_get_tracking_key.side_effect = HttpDoesNotExistException('{} not found'.format(username)) + + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + ) + + result = _call_script(username, fetch_ecom_segment_id=True) + mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value) + assert 'Setting Ecommerce Segment ID to None' in result.output + + # Reset our call counts for the next test + mock_get_access_token.reset_mock() + mock_get_retirement_state.reset_mock() + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch('scripts.user_retirement.utils.edx_api.EcommerceApi.get_tracking_key') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learner_retirement_state=DEFAULT, + update_learner_retirement_state=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_lms_retire=DEFAULT +) +def test_get_segment_id_error(*args, **kwargs): + username = 'test_username' + + mock_get_tracking_key = args[0] + mock_get_access_token = args[1] + mock_get_retirement_state = kwargs['get_learner_retirement_state'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + + test_exception_message = 'Test Exception!' + mock_get_tracking_key.side_effect = Exception(test_exception_message) + + mock_get_retirement_state.return_value = get_fake_user_retirement( + original_username=username, + ) + + mock_get_retirement_state.return_value = { + 'original_username': username, + 'current_state': { + 'state_name': 'PENDING' + } + } + + result = _call_script(username, fetch_ecom_segment_id=True) + mock_get_tracking_key.assert_called_once_with(mock_get_retirement_state.return_value) + mock_update_learner_state.assert_not_called() + + assert result.exit_code == ERR_SETUP_FAILED + assert 'Unexpected error fetching Ecommerce tracking id!' 
in result.output + assert test_exception_message in result.output diff --git a/scripts/user_retirement/tests/test_retirement_archive_and_cleanup.py b/scripts/user_retirement/tests/test_retirement_archive_and_cleanup.py new file mode 100644 index 000000000000..3a6a847e1d6e --- /dev/null +++ b/scripts/user_retirement/tests/test_retirement_archive_and_cleanup.py @@ -0,0 +1,277 @@ +""" +Test the retirement_archive_and_cleanup.py script +""" + +import datetime +import os + +import boto3 +import pytest +from botocore.exceptions import ClientError +from click.testing import CliRunner +from mock import DEFAULT, call, patch +from moto import mock_ec2, mock_s3 + +from scripts.user_retirement.retirement_archive_and_cleanup import ( + ERR_ARCHIVING, + ERR_BAD_CLI_PARAM, + ERR_BAD_CONFIG, + ERR_DELETING, + ERR_FETCHING, + ERR_NO_CONFIG, + ERR_SETUP_FAILED, + _upload_to_s3, + archive_and_cleanup +) +from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement + +FAKE_BUCKET_NAME = "fake_test_bucket" + + +def _call_script(cool_off_days=37, batch_size=None, dry_run=None, start_date=None, end_date=None): + """ + Call the archive script with the given params and a generic config file. + Returns the CliRunner.invoke results + """ + runner = CliRunner() + with runner.isolated_filesystem(): + with open('test_config.yml', 'w') as f: + fake_config_file(f) + + base_args = [ + '--config_file', 'test_config.yml', + '--cool_off_days', cool_off_days, + ] + if batch_size: + base_args += ['--batch_size', batch_size] + if dry_run: + base_args += ['--dry_run', dry_run] + if start_date: + base_args += ['--start_date', start_date] + if end_date: + base_args += ['--end_date', end_date] + + result = runner.invoke(archive_and_cleanup, args=base_args) + print(result) + print(result.output) + return result + + +def _fake_learner(ordinal): + """ + Creates a simple fake learner + """ + return get_fake_user_retirement( + user_id=ordinal, + original_username='test{}'.format(ordinal), + original_email='test{}@edx.invalid'.format(ordinal), + original_name='test {}'.format(ordinal), + retired_username='retired_{}'.format(ordinal), + retired_email='retired_test{}@edx.invalid'.format(ordinal), + last_state_name='COMPLETE' + ) + + +def fake_learners_to_retire(): + """ + A simple hard-coded list of fake learners + """ + return [ + _fake_learner(1), + _fake_learner(2), + _fake_learner(3) + ] + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + bulk_cleanup_retirements=DEFAULT +) +@mock_s3 +def test_successful(*args, **kwargs): + conn = boto3.resource('s3') + conn.create_bucket(Bucket=FAKE_BUCKET_NAME) + + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements'] + + mock_get_learners.return_value = fake_learners_to_retire() + + result = _call_script() + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + mock_bulk_cleanup_retirements.assert_called_once_with( + ['test1', 'test2', 'test3']) + + assert result.exit_code == 0 + assert 'Archive and cleanup complete' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 
'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + bulk_cleanup_retirements=DEFAULT +) +@mock_ec2 +@mock_s3 +def test_successful_with_batching(*args, **kwargs): + conn = boto3.resource('s3') + conn.create_bucket(Bucket=FAKE_BUCKET_NAME) + + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements'] + + mock_get_learners.return_value = fake_learners_to_retire() + + result = _call_script(batch_size=2) + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + get_learner_calls = [call(['test1', 'test2']), call(['test3'])] + mock_bulk_cleanup_retirements.assert_has_calls(get_learner_calls) + + assert result.exit_code == 0 + assert 'Archive and cleanup complete for batch #1' in result.output + assert 'Archive and cleanup complete for batch #2' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + bulk_cleanup_retirements=DEFAULT +) +@mock_s3 +def test_successful_dry_run(*args, **kwargs): + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + mock_bulk_cleanup_retirements = kwargs['bulk_cleanup_retirements'] + + mock_get_learners.return_value = fake_learners_to_retire() + + result = _call_script(dry_run=True) + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + mock_bulk_cleanup_retirements.assert_not_called() + + assert result.exit_code == 0 + assert 'Dry run. Skipping the step to upload data to' in result.output + assert 'This is a dry-run. Exiting before any retirements are cleaned up' in result.output + + +def test_no_config(): + runner = CliRunner() + result = runner.invoke( + archive_and_cleanup, + args=[ + '--cool_off_days', 37 + ] + ) + assert result.exit_code == ERR_NO_CONFIG + assert 'No config file passed in.' in result.output + + +def test_bad_config(): + runner = CliRunner() + result = runner.invoke( + archive_and_cleanup, + args=[ + '--config_file', 'does_not_exist.yml', + '--cool_off_days', 37 + ] + ) + assert result.exit_code == ERR_BAD_CONFIG + assert 'does_not_exist.yml' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.__init__', side_effect=Exception) +def test_setup_failed(*_): + result = _call_script() + assert result.exit_code == ERR_SETUP_FAILED + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status', side_effect=Exception) +def test_bad_fetch(*_): + result = _call_script() + assert result.exit_code == ERR_FETCHING + assert 'Unexpected error occurred fetching users to update!' 
in result.output
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
+@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status',
+       return_value=fake_learners_to_retire())
+@patch('scripts.user_retirement.utils.edx_api.LmsApi.bulk_cleanup_retirements', side_effect=Exception)
+@patch('scripts.user_retirement.retirement_archive_and_cleanup._upload_to_s3')
+def test_bad_lms_deletion(*_):
+    result = _call_script()
+    assert result.exit_code == ERR_DELETING
+    assert 'Unexpected error occurred deleting retirements!' in result.output
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
+@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status',
+       return_value=fake_learners_to_retire())
+@patch('scripts.user_retirement.utils.edx_api.LmsApi.bulk_cleanup_retirements')
+@patch('scripts.user_retirement.retirement_archive_and_cleanup._upload_to_s3', side_effect=Exception)
+def test_bad_s3_upload(*_):
+    result = _call_script()
+    assert result.exit_code == ERR_ARCHIVING
+    assert 'Unexpected error occurred archiving retirements!' in result.output
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
+def test_conflicting_dates(*_):
+    result = _call_script(start_date=datetime.datetime(
+        2021, 10, 10), end_date=datetime.datetime(2018, 10, 10))
+    assert result.exit_code == ERR_BAD_CLI_PARAM
+    assert 'Conflicting start and end dates passed on CLI' in result.output
+
+
+@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None))
+@patch(
+    'scripts.user_retirement.retirement_archive_and_cleanup._get_utc_now',
+    return_value=datetime.datetime(2021, 2, 2, 0, 0)
+)
+def test_conflicting_cool_off_date(*_):
+    result = _call_script(
+        cool_off_days=10,
+        start_date=datetime.datetime(2021, 1, 1), end_date=datetime.datetime(2021, 2, 1)
+    )
+    assert result.exit_code == ERR_BAD_CLI_PARAM
+    assert 'End date cannot occur within the cool_off_days period' in result.output
+
+
+@mock_s3
+def test_s3_upload_data():
+    """
+    Test case to verify s3 upload and download.
+    """
+    s3 = boto3.client("s3")
+    s3.create_bucket(Bucket=FAKE_BUCKET_NAME)
+    config = {'s3_archive': {'bucket_name': FAKE_BUCKET_NAME}}
+    filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data', 'uploading.txt')
+    key = 'raw/' + datetime.datetime.now().strftime('%Y/%m/') + filename
+
+    # First, do a dry run, which should skip the upload; fetching the object afterward
+    # should therefore raise NoSuchKey.
+    with pytest.raises(ClientError) as exc_info:
+        _upload_to_s3(config, filename, True)
+        s3.get_object(Bucket=FAKE_BUCKET_NAME, Key=key)
+    assert exc_info.value.response['Error']['Code'] == 'NoSuchKey'
+
+    # Then upload the file for real, download it, and compare the contents.
+    _upload_to_s3(config, filename, False)
+    resp = s3.get_object(Bucket=FAKE_BUCKET_NAME, Key=key)
+    data = resp["Body"].read()
+    assert data.decode() == "Upload this file on s3 in tests."
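+    # (moto's mock_s3 decorator keeps the bucket entirely in memory, so this test
+    # round-trips the file without touching real AWS.)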
diff --git a/scripts/user_retirement/tests/test_retirement_bulk_status_update.py b/scripts/user_retirement/tests/test_retirement_bulk_status_update.py new file mode 100644 index 000000000000..d2a8bab60ba3 --- /dev/null +++ b/scripts/user_retirement/tests/test_retirement_bulk_status_update.py @@ -0,0 +1,182 @@ +""" +Test the retirement_bulk_status_update.py script +""" + +from click.testing import CliRunner +from mock import DEFAULT, patch + +from scripts.user_retirement.retirement_bulk_status_update import ( + ERR_BAD_CONFIG, + ERR_FETCHING, + ERR_NO_CONFIG, + ERR_SETUP_FAILED, + ERR_UPDATING, + update_statuses +) +from scripts.user_retirement.tests.retirement_helpers import fake_config_file, get_fake_user_retirement + + +def _call_script(initial_state='COMPLETE', new_state='PENDING', start_date='2018-01-01', end_date='2018-01-15', + rewind_state=False): + """ + Call the bulk update statuses script with the given params and a generic config file. + Returns the CliRunner.invoke results + """ + runner = CliRunner() + with runner.isolated_filesystem(): + with open('test_config.yml', 'w') as f: + fake_config_file(f) + args = [ + '--config_file', 'test_config.yml', + '--initial_state', initial_state, + '--start_date', start_date, + '--end_date', end_date + ] + args.extend(['--new_state', new_state]) if new_state else None + args.append('--rewind-state') if rewind_state else None + result = runner.invoke( + update_statuses, + args=args + ) + print(result) + print(result.output) + return result + + +def fake_learners_to_retire(**overrides): + """ + A simple hard-coded list of fake learners with the only piece of + information this script cares about. + """ + + return [ + get_fake_user_retirement(**{"original_username": "user1", **overrides}), + get_fake_user_retirement(**{"original_username": "user2", **overrides}), + get_fake_user_retirement(**{"original_username": "user3", **overrides}), + ] + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_successful_update(*args, **kwargs): + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_learners.return_value = fake_learners_to_retire() + + result = _call_script() + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + assert mock_update_learner_state.call_count == 3 + + assert result.exit_code == 0 + assert 'Bulk update complete' in result.output + + +def test_no_config(): + runner = CliRunner() + result = runner.invoke( + update_statuses, + args=[ + '--initial_state', 'COMPLETE', + '--new_state', 'PENDING', + '--start_date', '2018-01-01', + '--end_date', '2018-01-15' + ] + ) + assert result.exit_code == ERR_NO_CONFIG + assert 'No config file passed in.' 
in result.output + + +def test_bad_config(): + runner = CliRunner() + result = runner.invoke( + update_statuses, + args=[ + '--config_file', 'does_not_exist.yml', + '--initial_state', 'COMPLETE', + '--new_state', 'PENDING', + '--start_date', '2018-01-01', + '--end_date', '2018-01-15' + ] + ) + assert result.exit_code == ERR_BAD_CONFIG + assert 'does_not_exist.yml' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_successful_rewind(*args, **kwargs): + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + mock_update_learner_state = kwargs['update_learner_retirement_state'] + + mock_get_learners.return_value = fake_learners_to_retire(current_state_name='ERRORED') + + result = _call_script(new_state=None, rewind_state=True) + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + assert mock_update_learner_state.call_count == 3 + + assert result.exit_code == 0 + assert 'Bulk update complete' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + get_learners_by_date_and_status=DEFAULT, + update_learner_retirement_state=DEFAULT +) +def test_rewind_bad_args(*args, **kwargs): + mock_get_access_token = args[0] + mock_get_learners = kwargs['get_learners_by_date_and_status'] + + mock_get_learners.return_value = fake_learners_to_retire(current_state_name='ERRORED') + + result = _call_script(rewind_state=True) + + # Called once to get the LMS token + assert mock_get_access_token.call_count == 1 + mock_get_learners.assert_called_once() + + assert result.exit_code == ERR_BAD_CONFIG + assert 'boolean rewind_state or a new state to set learners to' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.__init__', side_effect=Exception) +def test_setup_failed(*_): + result = _call_script() + assert result.exit_code == ERR_SETUP_FAILED + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status', side_effect=Exception) +def test_bad_fetch(*_): + result = _call_script() + assert result.exit_code == ERR_FETCHING + assert 'Unexpected error occurred fetching users to update!' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token', return_value=('THIS_IS_A_JWT', None)) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.get_learners_by_date_and_status', + return_value=fake_learners_to_retire()) +@patch('scripts.user_retirement.utils.edx_api.LmsApi.update_learner_retirement_state', side_effect=Exception) +def test_bad_update(*_): + result = _call_script() + assert result.exit_code == ERR_UPDATING + assert 'Unexpected error occurred updating users!' 
in result.output diff --git a/scripts/user_retirement/tests/test_retirement_partner_report.py b/scripts/user_retirement/tests/test_retirement_partner_report.py new file mode 100644 index 000000000000..c9d05a974396 --- /dev/null +++ b/scripts/user_retirement/tests/test_retirement_partner_report.py @@ -0,0 +1,818 @@ +# coding=utf-8 +""" +Test the retire_one_learner.py script +""" + +import csv +import os +import time +import unicodedata +from datetime import date + +from click.testing import CliRunner +from mock import DEFAULT, patch +from six import PY2, itervalues + +from scripts.user_retirement.retirement_partner_report import \ + _generate_report_files_or_exit # pylint: disable=protected-access +from scripts.user_retirement.retirement_partner_report import \ + _get_orgs_and_learners_or_exit # pylint: disable=protected-access +from scripts.user_retirement.retirement_partner_report import ( + DEFAULT_FIELD_HEADINGS, + ERR_BAD_CONFIG, + ERR_BAD_SECRETS, + ERR_CLEANUP, + ERR_DRIVE_LISTING, + ERR_FETCHING_LEARNERS, + ERR_NO_CONFIG, + ERR_NO_OUTPUT_DIR, + ERR_NO_SECRETS, + ERR_REPORTING, + ERR_SETUP_FAILED, + ERR_UNKNOWN_ORG, + LEARNER_CREATED_KEY, + LEARNER_ORIGINAL_USERNAME_KEY, + ORGS_CONFIG_FIELD_HEADINGS_KEY, + ORGS_CONFIG_KEY, + ORGS_CONFIG_LEARNERS_KEY, + ORGS_CONFIG_ORG_KEY, + ORGS_KEY, + REPORTING_FILENAME_PREFIX, + SETUP_LMS_OR_EXIT, + generate_report +) +from scripts.user_retirement.tests.retirement_helpers import ( + FAKE_ORGS, + TEST_PLATFORM_NAME, + fake_config_file, + fake_google_secrets_file, + flatten_partner_list +) + +TEST_CONFIG_YML_NAME = 'test_config.yml' +TEST_GOOGLE_SECRETS_FILENAME = 'test_google_secrets.json' +DELETION_TIME = time.strftime("%Y-%m-%dT%H:%M:%S") +UNICODE_NAME_CONSTANT = '阿碧' +USER_ID = '12345' +TEST_ORGS_CONFIG = [ + { + ORGS_CONFIG_ORG_KEY: 'orgCustom', + ORGS_CONFIG_FIELD_HEADINGS_KEY: ['heading_1', 'heading_2', 'heading_3'] + }, + { + ORGS_CONFIG_ORG_KEY: 'otherCustomOrg', + ORGS_CONFIG_FIELD_HEADINGS_KEY: ['unique_id'] + } +] +DEFAULT_FIELD_VALUES = { + 'user_id': USER_ID, + LEARNER_ORIGINAL_USERNAME_KEY: 'username', + 'original_email': 'invalid', + 'original_name': UNICODE_NAME_CONSTANT, + 'deletion_completed': DELETION_TIME +} + + +def _call_script(expect_success=True, expected_num_rows=10, config_orgs=None, expected_fields=None): + """ + Call the retired learner script with the given username and a generic, temporary config file. 
+ Returns the CliRunner.invoke results + """ + if expected_fields is None: + expected_fields = DEFAULT_FIELD_VALUES + if config_orgs is None: + config_orgs = FAKE_ORGS + + runner = CliRunner() + with runner.isolated_filesystem(): + with open(TEST_CONFIG_YML_NAME, 'w') as config_f: + fake_config_file(config_f, config_orgs) + with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as secrets_f: + fake_google_secrets_file(secrets_f) + + tmp_output_dir = 'test_output_dir' + os.mkdir(tmp_output_dir) + + result = runner.invoke( + generate_report, + args=[ + '--config_file', + TEST_CONFIG_YML_NAME, + '--google_secrets_file', + TEST_GOOGLE_SECRETS_FILENAME, + '--output_dir', + tmp_output_dir + ] + ) + + print(result) + print(result.output) + + if expect_success: + assert result.exit_code == 0 + + if config_orgs is None: + # These are the orgs + config_org_vals = flatten_partner_list(FAKE_ORGS.values()) + else: + config_org_vals = flatten_partner_list(config_orgs.values()) + + # Normalize the unicode as the script does + if PY2: + config_org_vals = [org.decode('utf-8') for org in config_org_vals] + + config_org_vals = [unicodedata.normalize('NFKC', org) for org in config_org_vals] + + for org in config_org_vals: + outfile = os.path.join(tmp_output_dir, '{}_{}_{}_{}.csv'.format( + REPORTING_FILENAME_PREFIX, TEST_PLATFORM_NAME, org, date.today().isoformat() + )) + + with open(outfile, 'r') as csvfile: + reader = csv.DictReader(csvfile) + rows = [] + for row in reader: + for field_key in expected_fields: + field_value = expected_fields[field_key] + assert field_value in row[field_key] + rows.append(row) + + # Confirm the number of rows + assert len(rows) == expected_num_rows + return result + + +def _fake_retirement_report_user(seed_val, user_orgs=None, user_orgs_config=None): + """ + Creates unique user to populate a fake report with. + - seed_val is a number or other unique value for this user, will be formatted into + user values to make sure they're distinct. + - user_orgs, if given, should be a list of orgs that will be associated with the user. + - user_orgs_config, if given, should be a list of dicts mapping orgs to their customized + field headings. These orgs will also be associated with the user. 
+ """ + user_info = { + 'user_id': USER_ID, + LEARNER_ORIGINAL_USERNAME_KEY: 'username_{}'.format(seed_val), + 'original_email': 'user_{}@foo.invalid'.format(seed_val), + 'original_name': '{} {}'.format(UNICODE_NAME_CONSTANT, seed_val), + LEARNER_CREATED_KEY: DELETION_TIME, + } + + if user_orgs is not None: + user_info[ORGS_KEY] = user_orgs + + if user_orgs_config is not None: + user_info[ORGS_CONFIG_KEY] = user_orgs_config + + return user_info + + +def _fake_retirement_report(num_users=10, user_orgs=None, user_orgs_config=None): + """ + Fake the output of a retirement report with unique users + """ + return [_fake_retirement_report_user(i, user_orgs, user_orgs_config) for i in range(num_users)] + + +@patch('scripts.user_retirement.utils.edx_api.LmsApi.retirement_partner_report') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +def test_report_generation_multiple_partners(*args, **kwargs): + mock_get_access_token = args[0] + mock_retirement_report = args[1] + + org_1_users = [_fake_retirement_report_user(i, user_orgs=['org1']) for i in range(1, 3)] + org_2_users = [_fake_retirement_report_user(i, user_orgs=['org2']) for i in range(3, 5)] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_retirement_report.return_value = org_1_users + org_2_users + + config = { + 'client_id': 'bogus id', + 'client_secret': 'supersecret', + 'base_urls': { + 'lms': 'https://stage-edx-edxapp.edx.invalid/', + }, + 'org_partner_mapping': { + 'org1': ['Org1X'], + 'org2': ['Org2X', 'Org2Xb'], + } + } + SETUP_LMS_OR_EXIT(config) + orgs, usernames = _get_orgs_and_learners_or_exit(config) + + assert usernames == [{'original_username': 'username_{}'.format(username)} for username in range(1, 5)] + + def _get_learner_usernames(org_data): + return [learner['original_username'] for learner in org_data['learners']] + + assert _get_learner_usernames(orgs['Org1X']) == ['username_1', 'username_2'] + + # Org2X and Org2Xb should have the same learners in their report data + assert _get_learner_usernames(orgs['Org2X']) == _get_learner_usernames(orgs['Org2Xb']) == ['username_3', + 'username_4'] + + # Org2X and Org2Xb report data should match + assert orgs['Org2X'] == orgs['Org2Xb'] + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT, + retirement_partner_cleanup=DEFAULT +) +def test_successful_report(*args, **kwargs): + mock_get_access_token = args[0] + mock_create_comments = args[1] + mock_list_permissions = args[2] + mock_walk_files = args[3] + mock_create_files = args[4] + mock_driveapi = args[5] + mock_retirement_report = kwargs['retirement_partner_report'] + mock_retirement_cleanup = kwargs['retirement_partner_cleanup'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_create_comments.return_value = None + fake_partners = list(itervalues(FAKE_ORGS)) + # Generate the list_permissions return value. + # The first few have POCs. 
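+    # Keys are the fake Drive folder ids ('folder' + partner name), matching the
+    # walk_files mock below; values are the permission entries the script scans
+    # when looking for a partner point of contact (POC) to tag in the comment.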
+ mock_list_permissions.return_value = { + 'folder' + partner: [ + {'emailAddress': 'some.contact@example.com'}, # The POC. + {'emailAddress': 'another.contact@edx.org'}, + ] + for partner in flatten_partner_list(fake_partners[:2]) + } + # The last one does not have any POCs. + mock_list_permissions.return_value.update({ + 'folder' + partner: [ + {'emailAddress': 'another.contact@edx.org'}, + ] + for partner in fake_partners[2] + }) + mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in + flatten_partner_list(FAKE_ORGS.values())] + mock_create_files.side_effect = ['foo', 'bar', 'baz', 'qux'] + mock_driveapi.return_value = None + mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys())) + + result = _call_script() + + # Make sure we're getting the LMS token + mock_get_access_token.assert_called_once() + + # Make sure that we get the report + mock_retirement_report.assert_called_once() + + # Make sure we tried to upload the files + assert mock_create_files.call_count == 4 + + # Make sure we tried to add comments to the files + assert mock_create_comments.call_count == 1 + # First [0] returns all positional args, second [0] gets the first positional arg. + create_comments_file_ids, create_comments_messages = zip(*mock_create_comments.call_args[0][0]) + assert set(create_comments_file_ids).issubset(set(['foo', 'bar', 'baz', 'qux'])) + assert len(create_comments_file_ids) == 2 # only two comments created, the third didn't have a POC. + assert all('+some.contact@example.com' in msg for msg in create_comments_messages) + assert all('+another.contact@edx.org' not in msg for msg in create_comments_messages) + assert 'WARNING: could not find a POC' in result.output + + # Make sure we tried to remove the users from the queue + mock_retirement_cleanup.assert_called_with( + [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value] + ) + + assert 'All reports completed and uploaded to Google.' in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT, + retirement_partner_cleanup=DEFAULT +) +def test_successful_report_org_config(*args, **kwargs): + mock_get_access_token = args[0] + mock_create_comments = args[1] + mock_list_permissions = args[2] + mock_walk_files = args[3] + mock_create_files = args[4] + mock_driveapi = args[5] + mock_retirement_report = kwargs['retirement_partner_report'] + mock_retirement_cleanup = kwargs['retirement_partner_cleanup'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_create_comments.return_value = None + fake_custom_orgs = { + 'orgCustom': ['firstBlah'] + } + fake_partners = list(itervalues(fake_custom_orgs)) + mock_list_permissions.return_value = { + 'folder' + partner: [ + {'emailAddress': 'some.contact@example.com'}, # The POC. 
+ {'emailAddress': 'another.contact@edx.org'}, + ] + for partner in flatten_partner_list(fake_partners[:2]) + } + mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in + flatten_partner_list(fake_custom_orgs.values())] + mock_create_files.side_effect = ['foo', 'bar', 'baz'] + mock_driveapi.return_value = None + expected_num_users = 1 + + orgs_config = [ + { + ORGS_CONFIG_ORG_KEY: 'orgCustom', + ORGS_CONFIG_FIELD_HEADINGS_KEY: ['heading_1', 'heading_2', 'heading_3'] + } + ] + + # Input from the LMS + report_data = [ + { + 'heading_1': 'h1val', + 'heading_2': 'h2val', + 'heading_3': 'h3val', + LEARNER_ORIGINAL_USERNAME_KEY: 'blah', + LEARNER_CREATED_KEY: DELETION_TIME, + ORGS_CONFIG_KEY: orgs_config + } + ] + + # Resulting csv file content + expected_fields = { + 'heading_1': 'h1val', + 'heading_2': 'h2val', + 'heading_3': 'h3val', + } + + mock_retirement_report.return_value = report_data + + result = _call_script(expected_num_rows=expected_num_users, config_orgs=fake_custom_orgs, + expected_fields=expected_fields) + + # Make sure we're getting the LMS token + mock_get_access_token.assert_called_once() + + # Make sure that we get the report + mock_retirement_report.assert_called_once() + + # Make sure we tried to remove the users from the queue + mock_retirement_cleanup.assert_called_with( + [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value] + ) + + assert 'All reports completed and uploaded to Google.' in result.output + + +def test_no_config(): + runner = CliRunner() + result = runner.invoke(generate_report) + print(result.output) + assert result.exit_code == ERR_NO_CONFIG + assert 'No config file' in result.output + + +def test_no_secrets(): + runner = CliRunner() + result = runner.invoke(generate_report, args=['--config_file', 'does_not_exist.yml']) + print(result.output) + assert result.exit_code == ERR_NO_SECRETS + assert 'No secrets file' in result.output + + +def test_no_output_dir(): + runner = CliRunner() + with runner.isolated_filesystem(): + with open(TEST_CONFIG_YML_NAME, 'w') as config_f: + config_f.write('irrelevant') + + with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as config_f: + config_f.write('irrelevant') + + result = runner.invoke( + generate_report, + args=[ + '--config_file', + TEST_CONFIG_YML_NAME, + '--google_secrets_file', + TEST_GOOGLE_SECRETS_FILENAME + ] + ) + print(result.output) + assert result.exit_code == ERR_NO_OUTPUT_DIR + assert 'No output_dir' in result.output + + +def test_bad_config(): + runner = CliRunner() + with runner.isolated_filesystem(): + with open(TEST_CONFIG_YML_NAME, 'w') as config_f: + config_f.write(']this is bad yaml') + + with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as config_f: + config_f.write('{this is bad json but we should not get to parsing it') + + tmp_output_dir = 'test_output_dir' + os.mkdir(tmp_output_dir) + + result = runner.invoke( + generate_report, + args=[ + '--config_file', + TEST_CONFIG_YML_NAME, + '--google_secrets_file', + TEST_GOOGLE_SECRETS_FILENAME, + '--output_dir', + tmp_output_dir + ] + ) + print(result.output) + assert result.exit_code == ERR_BAD_CONFIG + assert 'Failed to read' in result.output + + +def test_bad_secrets(): + runner = CliRunner() + with runner.isolated_filesystem(): + with open(TEST_CONFIG_YML_NAME, 'w') as config_f: + fake_config_file(config_f) + + with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as config_f: + config_f.write('{this is bad json') + + tmp_output_dir = 'test_output_dir' + 
os.mkdir(tmp_output_dir) + + result = runner.invoke( + generate_report, + args=[ + '--config_file', + TEST_CONFIG_YML_NAME, + '--google_secrets_file', + TEST_GOOGLE_SECRETS_FILENAME, + '--output_dir', + tmp_output_dir + ] + ) + print(result.output) + assert result.exit_code == ERR_BAD_SECRETS + assert 'Failed to read' in result.output + + +def test_bad_output_dir(): + runner = CliRunner() + with runner.isolated_filesystem(): + with open(TEST_CONFIG_YML_NAME, 'w') as config_f: + fake_config_file(config_f) + + with open(TEST_GOOGLE_SECRETS_FILENAME, 'w') as config_f: + fake_google_secrets_file(config_f) + + result = runner.invoke( + generate_report, + args=[ + '--config_file', + TEST_CONFIG_YML_NAME, + '--google_secrets_file', + TEST_GOOGLE_SECRETS_FILENAME, + '--output_dir', + 'does_not_exist/at_all' + ] + ) + print(result.output) + assert result.exit_code == ERR_NO_OUTPUT_DIR + assert 'or path does not exist' in result.output + + +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +def test_setup_failed(*args): + mock_get_access_token = args[0] + mock_get_access_token.side_effect = Exception('boom') + + result = _call_script(expect_success=False) + mock_get_access_token.assert_called_once() + assert result.exit_code == ERR_SETUP_FAILED + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT) +def test_fetching_learners_failed(*args, **kwargs): + mock_get_access_token = args[0] + mock_walk_files = args[1] + mock_drive_init = args[2] + mock_retirement_report = kwargs['retirement_partner_report'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_walk_files.return_value = [{'name': 'dummy_file_name', 'id': 'dummy_file_id'}] + mock_drive_init.return_value = None + mock_retirement_report.side_effect = Exception('failed to get learners') + + result = _call_script(expect_success=False) + + assert result.exit_code == ERR_FETCHING_LEARNERS + assert 'failed to get learners' in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +def test_listing_folders_failed(*args): + mock_get_access_token = args[0] + mock_walk_files = args[1] + mock_drive_init = args[2] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_walk_files.side_effect = [[], Exception()] + mock_drive_init.return_value = None + + # call it once; this time walk_files will return an empty list. + result = _call_script(expect_success=False) + + assert result.exit_code == ERR_DRIVE_LISTING + assert 'Finding partner directories on Drive failed' in result.output + + # call it a second time; this time walk_files will throw an exception. 
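+    # (With an iterable side_effect, each call consumes the next item, and any
+    # exception instance is raised rather than returned, which is what produces
+    # the two distinct failure paths exercised here.)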
+ result = _call_script(expect_success=False) + + assert result.exit_code == ERR_DRIVE_LISTING + assert 'Finding partner directories on Drive failed' in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT) +def test_unknown_org(*args, **kwargs): + mock_get_access_token = args[0] + mock_drive_init = args[2] + mock_retirement_report = kwargs['retirement_partner_report'] + + mock_drive_init.return_value = None + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + + orgs = ['orgA', 'orgB'] + + mock_retirement_report.return_value = [_fake_retirement_report_user(i, orgs, TEST_ORGS_CONFIG) for i in range(10)] + + result = _call_script(expect_success=False) + + assert result.exit_code == ERR_UNKNOWN_ORG + assert 'orgA' in result.output + assert 'orgB' in result.output + assert 'orgCustom' in result.output + assert 'otherCustomOrg' in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT) +def test_unknown_org_custom(*args, **kwargs): + mock_get_access_token = args[0] + mock_drive_init = args[2] + mock_retirement_report = kwargs['retirement_partner_report'] + + mock_drive_init.return_value = None + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + + custom_orgs_config = [ + { + ORGS_CONFIG_ORG_KEY: 'singleCustomOrg', + ORGS_CONFIG_FIELD_HEADINGS_KEY: ['first_heading', 'second_heading'] + } + ] + + mock_retirement_report.return_value = [_fake_retirement_report_user(i, None, custom_orgs_config) for i in range(2)] + + result = _call_script(expect_success=False) + + assert result.exit_code == ERR_UNKNOWN_ORG + assert 'organizations {\'singleCustomOrg\'} do not exist' in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch('unicodecsv.DictWriter') +@patch('scripts.user_retirement.utils.edx_api.LmsApi.retirement_partner_report') +def test_reporting_error(*args): + mock_retirement_report = args[0] + mock_dictwriter = args[1] + mock_get_access_token = args[2] + mock_drive_init = args[4] + + error_msg = 'Fake unable to write csv' + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_dictwriter.side_effect = Exception(error_msg) + mock_drive_init.return_value = None + mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys())) + + result = _call_script(expect_success=False) + + assert result.exit_code == ERR_REPORTING + assert error_msg in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') 
+@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT, + retirement_partner_cleanup=DEFAULT +) +def test_cleanup_error(*args, **kwargs): + mock_get_access_token = args[0] + mock_create_files = args[1] + mock_driveapi = args[2] + mock_walk_files = args[3] + mock_create_comments = args[4] + mock_list_permissions = args[5] + mock_retirement_report = kwargs['retirement_partner_report'] + mock_retirement_cleanup = kwargs['retirement_partner_cleanup'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_create_files.return_value = True + mock_driveapi.return_value = None + mock_walk_files.return_value = [{'name': partner, 'id': 'folder' + partner} for partner in + flatten_partner_list(FAKE_ORGS.values())] + fake_partners = list(itervalues(FAKE_ORGS)) + # Generate the list_permissions return value. + mock_list_permissions.return_value = { + 'folder' + partner: [ + {'emailAddress': 'some.contact@example.com'}, # The POC. + {'emailAddress': 'another.contact@edx.org'}, + {'emailAddress': 'third@edx.org'} + ] + for partner in flatten_partner_list(fake_partners) + } + mock_create_comments.return_value = None + + mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys())) + mock_retirement_cleanup.side_effect = Exception('Mock cleanup exception') + + result = _call_script(expect_success=False) + + mock_retirement_cleanup.assert_called_with( + [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value] + ) + + assert result.exit_code == ERR_CLEANUP + assert 'Users may be stuck in the processing state!' 
in result.output + + +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.__init__') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_file_in_folder') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.walk_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.list_permissions_for_files') +@patch('scripts.user_retirement.utils.thirdparty_apis.google_api.DriveApi.create_comments_for_files') +@patch('scripts.user_retirement.utils.edx_api.BaseApiClient.get_access_token') +@patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_partner_report=DEFAULT, + retirement_partner_cleanup=DEFAULT +) +def test_google_unicode_folder_names(*args, **kwargs): + mock_get_access_token = args[0] + mock_create_comments = args[1] + mock_list_permissions = args[2] + mock_walk_files = args[3] + mock_create_files = args[4] + mock_driveapi = args[5] + mock_retirement_report = kwargs['retirement_partner_report'] + mock_retirement_cleanup = kwargs['retirement_partner_cleanup'] + + mock_get_access_token.return_value = ('THIS_IS_A_JWT', None) + mock_list_permissions.return_value = { + 'folder' + partner: [ + {'emailAddress': 'some.contact@example.com'}, + {'emailAddress': 'another.contact@edx.org'}, + ] + for partner in [ + unicodedata.normalize('NFKC', u'TéstX'), + unicodedata.normalize('NFKC', u'TéstX2'), + unicodedata.normalize('NFKC', u'TéstX3'), + ] + } + mock_walk_files.return_value = [ + {'name': partner, 'id': 'folder' + partner} + for partner in [ + unicodedata.normalize('NFKC', u'TéstX'), + unicodedata.normalize('NFKC', u'TéstX2'), + unicodedata.normalize('NFKC', u'TéstX3'), + ] + ] + mock_create_files.side_effect = ['foo', 'bar', 'baz'] + mock_driveapi.return_value = None + mock_retirement_report.return_value = _fake_retirement_report(user_orgs=list(FAKE_ORGS.keys())) + + config_orgs = { + 'org1': [unicodedata.normalize('NFKC', u'TéstX')], + 'org2': [unicodedata.normalize('NFD', u'TéstX2')], + 'org3': [unicodedata.normalize('NFKD', u'TéstX3')], + } + + result = _call_script(config_orgs=config_orgs) + + # Make sure we're getting the LMS token + mock_get_access_token.assert_called_once() + + # Make sure that we get the report + mock_retirement_report.assert_called_once() + + # Make sure we tried to upload the files + assert mock_create_files.call_count == 3 + + # Make sure we tried to add comments to the files + assert mock_create_comments.call_count == 1 + # First [0] returns all positional args, second [0] gets the first positional arg. + create_comments_file_ids, create_comments_messages = zip(*mock_create_comments.call_args[0][0]) + assert set(create_comments_file_ids) == set(['foo', 'bar', 'baz']) + assert all('+some.contact@example.com' in msg for msg in create_comments_messages) + assert all('+another.contact@edx.org' not in msg for msg in create_comments_messages) + + # Make sure we tried to remove the users from the queue + mock_retirement_cleanup.assert_called_with( + [{'original_username': user[LEARNER_ORIGINAL_USERNAME_KEY]} for user in mock_retirement_report.return_value] + ) + + assert 'All reports completed and uploaded to Google.' 
in result.output + + +def test_file_content_custom_headings(): + runner = CliRunner() + with runner.isolated_filesystem(): + config = {'partner_report_platform_name': 'fake_platform_name'} + tmp_output_dir = 'test_output_dir' + os.mkdir(tmp_output_dir) + + # Custom headings and values + ch1 = 'special_id' + ch1v = '134456765432' + ch2 = 'alternate_heading_for_email' + ch2v = 'zxcvbvcxz@blah.com' + custom_field_headings = [ch1, ch2] + + org_name = 'my_delightful_org' + username = 'unique_user' + learner_data = [ + { + ch1: ch1v, + ch2: ch2v, + LEARNER_ORIGINAL_USERNAME_KEY: username, + LEARNER_CREATED_KEY: DELETION_TIME, + } + ] + report_data = { + org_name: { + ORGS_CONFIG_FIELD_HEADINGS_KEY: custom_field_headings, + ORGS_CONFIG_LEARNERS_KEY: learner_data + } + } + + partner_filenames = _generate_report_files_or_exit(config, report_data, tmp_output_dir) + + assert len(partner_filenames) == 1 + filename = partner_filenames[org_name] + with open(filename) as f: + file_content = f.read() + + # Custom field headings + for ch in custom_field_headings: + # Verify custom field headings are present + assert ch in file_content + # Verify custom field values are present + assert ch1v in file_content + assert ch2v in file_content + + # Default field headings + for h in DEFAULT_FIELD_HEADINGS: + # Verify default field headings are not present + assert h not in file_content + # Verify default field values are not present + assert username not in file_content + assert DELETION_TIME not in file_content diff --git a/scripts/user_retirement/tests/utils/__init__.py b/scripts/user_retirement/tests/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/tests/utils/test_edx_api.py b/scripts/user_retirement/tests/utils/test_edx_api.py new file mode 100644 index 000000000000..1c1aa78c39bd --- /dev/null +++ b/scripts/user_retirement/tests/utils/test_edx_api.py @@ -0,0 +1,584 @@ +""" +Tests for edX API calls. +""" +import unittest +from urllib.parse import urljoin + +import requests +import responses +from ddt import data, ddt, unpack +from mock import DEFAULT, patch +from requests.exceptions import ConnectionError, HTTPError +from responses import GET, PATCH, POST, matchers +from responses.registries import OrderedRegistry + +from scripts.user_retirement.tests.mixins import OAuth2Mixin +from scripts.user_retirement.tests.retirement_helpers import ( + FAKE_DATETIME_OBJECT, + FAKE_ORIGINAL_USERNAME, + FAKE_RESPONSE_MESSAGE, + FAKE_USERNAME_MAPPING, + FAKE_USERNAMES, + TEST_RETIREMENT_QUEUE_STATES, + TEST_RETIREMENT_STATE, + get_fake_user_retirement +) +from scripts.user_retirement.utils import edx_api + + +class BackoffTriedException(Exception): + """ + Raise this from a backoff handler to indicate that backoff was tried. + """ + + +@ddt +class TestLmsApi(OAuth2Mixin, unittest.TestCase): + """ + Test the edX LMS API client. 
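+
+    Each test registers the expected endpoint and payload with the responses
+    library, then patches the client method and asserts that it was invoked
+    with the expected arguments.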
+    """
+
+    @responses.activate(registry=OrderedRegistry)
+    def setUp(self):
+        super().setUp()
+        self.mock_access_token_response()
+        self.lms_base_url = 'http://localhost:18000/'
+        self.lms_api = edx_api.LmsApi(
+            self.lms_base_url,
+            self.lms_base_url,
+            'the_client_id',
+            'the_client_secret'
+        )
+
+    def tearDown(self):
+        super().tearDown()
+        responses.reset()
+
+    @patch.object(edx_api.LmsApi, 'learners_to_retire')
+    def test_learners_to_retire(self, mock_method):
+        params = {
+            'states': TEST_RETIREMENT_QUEUE_STATES,
+            'cool_off_days': 365,
+        }
+        responses.add(
+            GET,
+            urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_queue/'),
+            match=[matchers.query_param_matcher(params)],
+        )
+        self.lms_api.learners_to_retire(
+            TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365)
+        mock_method.assert_called_once_with(
+            TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365)
+
+    @patch.object(edx_api.LmsApi, 'get_learners_by_date_and_status')
+    def test_get_learners_by_date_and_status(self, mock_method):
+        query_params = {
+            'start_date': FAKE_DATETIME_OBJECT.strftime('%Y-%m-%d'),
+            'end_date': FAKE_DATETIME_OBJECT.strftime('%Y-%m-%d'),
+            'state': TEST_RETIREMENT_STATE,
+        }
+        responses.add(
+            GET,
+            urljoin(self.lms_base_url, 'api/user/v1/accounts/retirements_by_status_and_date/'),
+            match=[matchers.query_param_matcher(query_params)]
+        )
+        self.lms_api.get_learners_by_date_and_status(
+            state_to_request=TEST_RETIREMENT_STATE,
+            start_date=FAKE_DATETIME_OBJECT,
+            end_date=FAKE_DATETIME_OBJECT
+        )
+        mock_method.assert_called_once_with(
+            state_to_request=TEST_RETIREMENT_STATE,
+            start_date=FAKE_DATETIME_OBJECT,
+            end_date=FAKE_DATETIME_OBJECT
+        )
+
+    @patch.object(edx_api.LmsApi, 'get_learner_retirement_state')
+    def test_get_learner_retirement_state(self, mock_method):
+        responses.add(
+            GET,
+            urljoin(self.lms_base_url, f'api/user/v1/accounts/{FAKE_ORIGINAL_USERNAME}/retirement_status/'),
+        )
+        self.lms_api.get_learner_retirement_state(
+            username=FAKE_ORIGINAL_USERNAME
+        )
+        mock_method.assert_called_once_with(
+            username=FAKE_ORIGINAL_USERNAME
+        )
+
+    @patch.object(edx_api.LmsApi, 'update_learner_retirement_state')
+    def test_update_learner_retirement_state(self, mock_method):
+        json_data = {
+            'username': FAKE_ORIGINAL_USERNAME,
+            'new_state': TEST_RETIREMENT_STATE,
+            'response': FAKE_RESPONSE_MESSAGE,
+        }
+        responses.add(
+            PATCH,
+            urljoin(self.lms_base_url, 'api/user/v1/accounts/update_retirement_status/'),
+            match=[matchers.json_params_matcher(json_data)]
+        )
+        self.lms_api.update_learner_retirement_state(
+            username=FAKE_ORIGINAL_USERNAME,
+            new_state_name=TEST_RETIREMENT_STATE,
+            message=FAKE_RESPONSE_MESSAGE
+        )
+        mock_method.assert_called_once_with(
+            username=FAKE_ORIGINAL_USERNAME,
+            new_state_name=TEST_RETIREMENT_STATE,
+            message=FAKE_RESPONSE_MESSAGE
+        )
+
+    @data(
+        {
+            'api_url': 'api/user/v1/accounts/deactivate_logout/',
+            'mock_method': 'retirement_deactivate_logout',
+            'method': 'POST',
+        },
+        {
+            'api_url': 'api/discussion/v1/accounts/retire_forum/',
+            'mock_method': 'retirement_retire_forum',
+            'method': 'POST',
+        },
+        {
+            'api_url': 'api/user/v1/accounts/retire_mailings/',
+            'mock_method': 'retirement_retire_mailings',
+            'method': 'POST',
+        },
+        {
+            'api_url': 'api/enrollment/v1/unenroll/',
+            'mock_method': 'retirement_unenroll',
+            'method': 'POST',
+        },
+        {
+            'api_url': 'api/edxnotes/v1/retire_user/',
+            'mock_method': 'retirement_retire_notes',
+            'method': 'POST',
+        },
+        {
+            'api_url': 'api/user/v1/accounts/retire_misc/',
+            'mock_method': 'retirement_lms_retire_misc',
+            'method':
'POST', + }, + { + 'api_url': 'api/user/v1/accounts/retire/', + 'mock_method': 'retirement_lms_retire', + 'method': 'POST', + }, + { + 'api_url': 'api/user/v1/accounts/retirement_partner_report/', + 'mock_method': 'retirement_partner_queue', + 'method': 'PUT', + }, + ) + @unpack + @patch.multiple( + 'scripts.user_retirement.utils.edx_api.LmsApi', + retirement_deactivate_logout=DEFAULT, + retirement_retire_forum=DEFAULT, + retirement_retire_mailings=DEFAULT, + retirement_unenroll=DEFAULT, + retirement_retire_notes=DEFAULT, + retirement_lms_retire_misc=DEFAULT, + retirement_lms_retire=DEFAULT, + retirement_partner_queue=DEFAULT, + ) + def test_learner_retirement(self, api_url, mock_method, method, **kwargs): + json_data = { + 'username': FAKE_ORIGINAL_USERNAME, + } + responses.add( + method, + urljoin(self.lms_base_url, api_url), + match=[matchers.json_params_matcher(json_data)] + ) + getattr(self.lms_api, mock_method)(get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)) + kwargs[mock_method].assert_called_once_with(get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)) + + @patch.object(edx_api.LmsApi, 'retirement_partner_report') + def test_retirement_partner_report(self, mock_method): + responses.add( + POST, + urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_partner_report/') + ) + self.lms_api.retirement_partner_report( + learner=get_fake_user_retirement( + original_username=FAKE_ORIGINAL_USERNAME + ) + ) + mock_method.assert_called_once_with( + learner=get_fake_user_retirement( + original_username=FAKE_ORIGINAL_USERNAME + ) + ) + + @patch.object(edx_api.LmsApi, 'retirement_partner_cleanup') + def test_retirement_partner_cleanup(self, mock_method): + json_data = FAKE_USERNAMES + responses.add( + POST, + urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_partner_report_cleanup/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.lms_api.retirement_partner_cleanup( + usernames=FAKE_USERNAMES + ) + mock_method.assert_called_once_with( + usernames=FAKE_USERNAMES + ) + + @patch.object(edx_api.LmsApi, 'retirement_retire_proctoring_data') + def test_retirement_retire_proctoring_data(self, mock_method): + learner = get_fake_user_retirement() + responses.add( + POST, + urljoin(self.lms_base_url, f"api/edx_proctoring/v1/retire_user/{learner['user']['id']}/"), + ) + self.lms_api.retirement_retire_proctoring_data() + mock_method.assert_called_once() + + @patch.object(edx_api.LmsApi, 'retirement_retire_proctoring_backend_data') + def test_retirement_retire_proctoring_backend_data(self, mock_method): + learner = get_fake_user_retirement() + responses.add( + POST, + urljoin(self.lms_base_url, f"api/edx_proctoring/v1/retire_backend_user/{learner['user']['id']}/"), + ) + self.lms_api.retirement_retire_proctoring_backend_data() + mock_method.assert_called_once() + + @patch.object(edx_api.LmsApi, 'replace_lms_usernames') + def test_replace_lms_usernames(self, mock_method): + json_data = { + 'username_mappings': FAKE_USERNAME_MAPPING + } + responses.add( + POST, + urljoin(self.lms_base_url, 'api/user/v1/accounts/replace_usernames/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.lms_api.replace_lms_usernames( + username_mappings=FAKE_USERNAME_MAPPING + ) + mock_method.assert_called_once_with( + username_mappings=FAKE_USERNAME_MAPPING + ) + + @patch.object(edx_api.LmsApi, 'replace_forums_usernames') + def test_replace_forums_usernames(self, mock_method): + json_data = { + 'username_mappings': FAKE_USERNAME_MAPPING + } + 
responses.add( + POST, + urljoin(self.lms_base_url, 'api/discussion/v1/accounts/replace_usernames/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.lms_api.replace_forums_usernames( + username_mappings=FAKE_USERNAME_MAPPING + ) + mock_method.assert_called_once_with( + username_mappings=FAKE_USERNAME_MAPPING + ) + + @data(504, 500) + @patch('scripts.user_retirement.utils.edx_api._backoff_handler') + @patch.object(edx_api.LmsApi, 'learners_to_retire') + def test_retrieve_learner_queue_backoff( + self, + svr_status_code, + mock_backoff_handler, + mock_learners_to_retire + ): + mock_backoff_handler.side_effect = BackoffTriedException + params = { + 'states': TEST_RETIREMENT_QUEUE_STATES, + 'cool_off_days': 365, + } + response = requests.Response() + response.status_code = svr_status_code + responses.add( + GET, + urljoin(self.lms_base_url, 'api/user/v1/accounts/retirement_queue/'), + status=200, + match=[matchers.query_param_matcher(params)], + ) + + mock_learners_to_retire.side_effect = HTTPError(response=response) + with self.assertRaises(BackoffTriedException): + self.lms_api.learners_to_retire( + TEST_RETIREMENT_QUEUE_STATES, cool_off_days=365) + + @data(104) + @responses.activate + @patch('scripts.user_retirement.utils.edx_api._backoff_handler') + @patch.object(edx_api.LmsApi, 'retirement_partner_cleanup') + def test_retirement_partner_cleanup_backoff_on_connection_error( + self, + svr_status_code, + mock_backoff_handler, + mock_retirement_partner_cleanup + ): + mock_backoff_handler.side_effect = BackoffTriedException + response = requests.Response() + response.status_code = svr_status_code + mock_retirement_partner_cleanup.retirement_partner_cleanup.side_effect = ConnectionError( + response=response + ) + with self.assertRaises(BackoffTriedException): + self.lms_api.retirement_partner_cleanup([{'original_username': 'test'}]) + + +class TestEcommerceApi(OAuth2Mixin, unittest.TestCase): + """ + Test the edX Ecommerce API client. 
+    """
+
+    @responses.activate(registry=OrderedRegistry)
+    def setUp(self):
+        super().setUp()
+        self.mock_access_token_response()
+        self.lms_base_url = 'http://localhost:18000/'
+        self.ecommerce_base_url = 'http://localhost:18130/'
+        self.ecommerce_api = edx_api.EcommerceApi(
+            self.lms_base_url,
+            self.ecommerce_base_url,
+            'the_client_id',
+            'the_client_secret'
+        )
+
+    def tearDown(self):
+        super().tearDown()
+        responses.reset()
+
+    @patch.object(edx_api.EcommerceApi, 'retire_learner')
+    def test_retire_learner(self, mock_method):
+        json_data = {
+            'username': FAKE_ORIGINAL_USERNAME,
+        }
+        responses.add(
+            POST,
+            urljoin(self.lms_base_url, 'api/v2/user/retire/'),
+            match=[matchers.json_params_matcher(json_data)]
+        )
+        self.ecommerce_api.retire_learner(
+            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
+        )
+        mock_method.assert_called_once_with(
+            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
+        )
+
+    @patch.object(edx_api.EcommerceApi, 'get_tracking_key')
+    def test_get_tracking_key(self, mock_method):
+        original_username = {
+            'original_username': get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
+        }
+        responses.add(
+            GET,
+            urljoin(self.lms_base_url, f"api/v2/retirement/tracking_id/{original_username}/"),
+        )
+        self.ecommerce_api.get_tracking_key(
+            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
+        )
+        mock_method.assert_called_once_with(
+            learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME)
+        )
+
+    @patch.object(edx_api.EcommerceApi, 'replace_usernames')
+    def test_replace_usernames(self, mock_method):
+        json_data = {
+            "username_mappings": FAKE_USERNAME_MAPPING
+        }
+        responses.add(
+            POST,
+            urljoin(self.lms_base_url, 'api/v2/user_management/replace_usernames/'),
+            match=[matchers.json_params_matcher(json_data)]
+        )
+        self.ecommerce_api.replace_usernames(
+            username_mappings=FAKE_USERNAME_MAPPING
+        )
+        mock_method.assert_called_once_with(
+            username_mappings=FAKE_USERNAME_MAPPING
+        )
+
+
+class TestCredentialApi(OAuth2Mixin, unittest.TestCase):
+    """
+    Test the edX Credential API client.
+ """ + + @responses.activate(registry=OrderedRegistry) + def setUp(self): + super().setUp() + self.mock_access_token_response() + self.lms_base_url = 'http://localhost:18000/' + self.credentials_base_url = 'http://localhost:18150/' + self.credentials_api = edx_api.CredentialsApi( + self.lms_base_url, + self.credentials_base_url, + 'the_client_id', + 'the_client_secret' + ) + + def tearDown(self): + super().tearDown() + responses.reset() + + @patch.object(edx_api.CredentialsApi, 'retire_learner') + def test_retire_learner(self, mock_method): + json_data = { + 'username': FAKE_ORIGINAL_USERNAME + } + responses.add( + POST, + urljoin(self.credentials_base_url, 'user/retire/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.credentials_api.retire_learner( + learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME) + ) + mock_method.assert_called_once_with( + learner=get_fake_user_retirement(original_username=FAKE_ORIGINAL_USERNAME) + ) + + @patch.object(edx_api.CredentialsApi, 'replace_usernames') + def test_replace_usernames(self, mock_method): + json_data = { + "username_mappings": FAKE_USERNAME_MAPPING + } + responses.add( + POST, + urljoin(self.credentials_base_url, 'api/v2/replace_usernames/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.credentials_api.replace_usernames( + username_mappings=FAKE_USERNAME_MAPPING + ) + mock_method.assert_called_once_with( + username_mappings=FAKE_USERNAME_MAPPING + ) + + +class TestDiscoveryApi(OAuth2Mixin, unittest.TestCase): + """ + Test the edX Discovery API client. + """ + + @responses.activate(registry=OrderedRegistry) + def setUp(self): + super().setUp() + self.mock_access_token_response() + self.lms_base_url = 'http://localhost:18000/' + self.discovery_base_url = 'http://localhost:18150/' + self.discovery_api = edx_api.DiscoveryApi( + self.lms_base_url, + self.discovery_base_url, + 'the_client_id', + 'the_client_secret' + ) + + def tearDown(self): + super().tearDown() + responses.reset() + + @patch.object(edx_api.DiscoveryApi, 'replace_usernames') + def test_replace_usernames(self, mock_method): + json_data = { + "username_mappings": FAKE_USERNAME_MAPPING + } + responses.add( + POST, + urljoin(self.discovery_base_url, 'api/v1/replace_usernames/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.discovery_api.replace_usernames( + username_mappings=FAKE_USERNAME_MAPPING + ) + mock_method.assert_called_once_with( + username_mappings=FAKE_USERNAME_MAPPING + ) + + +class TestDemographicsApi(OAuth2Mixin, unittest.TestCase): + """ + Test the edX Demographics API client. 
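+
+    Unlike most of the other clients in this module, the demographics
+    retirement endpoint identifies the learner by LMS user id rather than by
+    username.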
+ """ + + @responses.activate(registry=OrderedRegistry) + def setUp(self): + super().setUp() + self.mock_access_token_response() + self.lms_base_url = 'http://localhost:18000/' + self.demographics_base_url = 'http://localhost:18360/' + self.demographics_api = edx_api.DemographicsApi( + self.lms_base_url, + self.demographics_base_url, + 'the_client_id', + 'the_client_secret' + ) + + def tearDown(self): + super().tearDown() + responses.reset() + + @patch.object(edx_api.DemographicsApi, 'retire_learner') + def test_retire_learner(self, mock_method): + json_data = { + 'lms_user_id': get_fake_user_retirement()['user']['id'] + } + responses.add( + POST, + urljoin(self.demographics_base_url, 'demographics/api/v1/retire_demographics/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.demographics_api.retire_learner( + learner=get_fake_user_retirement() + ) + mock_method.assert_called_once_with( + learner=get_fake_user_retirement() + ) + + +class TestLicenseManagerApi(OAuth2Mixin, unittest.TestCase): + """ + Test the edX License Manager API client. + """ + + @responses.activate(registry=OrderedRegistry) + def setUp(self): + super().setUp() + self.mock_access_token_response() + self.lms_base_url = 'http://localhost:18000/' + self.license_manager_base_url = 'http://localhost:18170/' + self.license_manager_api = edx_api.LicenseManagerApi( + self.lms_base_url, + self.license_manager_base_url, + 'the_client_id', + 'the_client_secret' + ) + + def tearDown(self): + super().tearDown() + responses.reset() + + @patch.object(edx_api.LicenseManagerApi, 'retire_learner') + def test_retire_learner(self, mock_method): + json_data = { + 'lms_user_id': get_fake_user_retirement()['user']['id'], + 'original_username': FAKE_ORIGINAL_USERNAME, + } + responses.add( + POST, + urljoin(self.license_manager_base_url, 'api/v1/retire_user/'), + match=[matchers.json_params_matcher(json_data)] + ) + self.license_manager_api.retire_learner( + learner=get_fake_user_retirement( + original_username=FAKE_ORIGINAL_USERNAME + ) + ) + mock_method.assert_called_once_with( + learner=get_fake_user_retirement( + original_username=FAKE_ORIGINAL_USERNAME + ) + ) diff --git a/scripts/user_retirement/tests/utils/test_jenkins.py b/scripts/user_retirement/tests/utils/test_jenkins.py new file mode 100644 index 000000000000..8d32ee24c2df --- /dev/null +++ b/scripts/user_retirement/tests/utils/test_jenkins.py @@ -0,0 +1,193 @@ +""" +Tests for triggering a Jenkins job. 
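+
+Covers the learner properties-file export, the custom backoff timeout
+calculation, and triggering builds via the Jenkins REST API (mocked with
+requests_mock).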
+""" + +import json +import re +import unittest +from itertools import islice + +import backoff +import ddt +import requests_mock +from mock import Mock, call, mock_open, patch + +import scripts.user_retirement.utils.jenkins as jenkins +from scripts.user_retirement.utils.exception import BackendError + +BASE_URL = u'https://test-jenkins' +USER_ID = u'foo' +USER_TOKEN = u'12345678901234567890123456789012' +JOB = u'test-job' +TOKEN = u'asdf' +BUILD_NUM = 456 +JOBS_URL = u'{}/job/{}/'.format(BASE_URL, JOB) +JOB_URL = u'{}{}'.format(JOBS_URL, BUILD_NUM) +MOCK_BUILD = {u'number': BUILD_NUM, u'url': JOB_URL} +MOCK_JENKINS_DATA = {'jobs': [{'name': JOB, 'url': JOBS_URL, 'color': 'blue'}]} +MOCK_BUILDS_DATA = { + 'actions': [ + {'parameterDefinitions': [ + {'defaultParameterValue': {'value': '0'}, 'name': 'EXIT_CODE', 'type': 'StringParameterDefinition'} + ]} + ], + 'builds': [MOCK_BUILD], + 'lastBuild': MOCK_BUILD +} +MOCK_QUEUE_DATA = { + 'id': 123, + 'task': {'name': JOB, 'url': JOBS_URL}, + 'executable': {'number': BUILD_NUM, 'url': JOB_URL} +} +MOCK_BUILD_DATA = { + 'actions': [{}], + 'fullDisplayName': 'foo', + 'number': BUILD_NUM, + 'result': 'SUCCESS', + 'url': JOB_URL, +} +MOCK_CRUMB_DATA = { + 'crumbRequestField': 'Jenkins-Crumb', + 'crumb': '1234567890' +} + + +class TestProperties(unittest.TestCase): + """ + Test the Jenkins property-creating methods. + """ + + def test_properties_files(self): + learners = [ + { + 'original_username': 'learnerA' + }, + { + 'original_username': 'learnerB' + }, + ] + open_mocker = mock_open() + with patch('scripts.user_retirement.utils.jenkins.open', open_mocker, create=True): + jenkins._recreate_directory = Mock() # pylint: disable=protected-access + jenkins.export_learner_job_properties(learners, "tmpdir") + jenkins._recreate_directory.assert_called_once() # pylint: disable=protected-access + self.assertIn(call('tmpdir/learner_retire_learnera', 'w'), open_mocker.call_args_list) + self.assertIn(call('tmpdir/learner_retire_learnerb', 'w'), open_mocker.call_args_list) + handle = open_mocker() + self.assertIn(call('RETIREMENT_USERNAME=learnerA\n'), handle.write.call_args_list) + self.assertIn(call('RETIREMENT_USERNAME=learnerB\n'), handle.write.call_args_list) + + +@ddt.ddt +class TestBackoff(unittest.TestCase): + u""" + Test of custom backoff code (wait time generator and max_tries) + """ + + @ddt.data( + (2, 1, 1, 2, [1]), + (2, 1, 2, 3, [1, 1]), + (2, 1, 3, 3, [1, 2]), + (2, 100, 90, 2, [90]), + (2, 1, 90, 8, [1, 2, 4, 8, 16, 32, 27]), + (3, 5, 1000, 7, [5, 15, 45, 135, 405, 395]), + (2, 1, 3600, 13, [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1553]), + ) + @ddt.unpack + def test_max_timeout(self, base, factor, timeout, expected_max_tries, expected_waits): + # pylint: disable=protected-access + wait_gen, max_tries = jenkins._backoff_timeout(timeout, base, factor) + self.assertEqual(expected_max_tries, max_tries) + + # Use max_tries-1, because we only wait that many times + waits = list(islice(wait_gen(), max_tries - 1)) + self.assertEqual(expected_waits, waits) + + self.assertEqual(timeout, sum(waits)) + + def test_backoff_call(self): + # pylint: disable=protected-access + wait_gen, max_tries = jenkins._backoff_timeout(timeout=.36, base=2, factor=.0001) + always_false = Mock(return_value=False) + + count_retries = backoff.on_predicate( + wait_gen, + max_tries=max_tries, + on_backoff=print, + jitter=None, + )(always_false.__call__) + + count_retries() + + self.assertEqual(always_false.call_count, 13) + + +@ddt.ddt +class 
TestJenkinsAPI(unittest.TestCase): + """ + Tests for interacting with the Jenkins API + """ + + @requests_mock.Mocker() + def test_failure(self, mock): + """ + Test the failure condition when triggering a jenkins job + """ + # Mock all network interactions + mock.get( + re.compile(".*"), + status_code=404, + ) + with self.assertRaises(BackendError): + jenkins.trigger_build(BASE_URL, USER_ID, USER_TOKEN, JOB, TOKEN, None, ()) + + @ddt.data( + (None, ()), + ('my cause', ()), + (None, ((u'FOO', u'bar'),)), + (None, ((u'FOO', u'bar'), (u'BAZ', u'biz'))), + ('my cause', ((u'FOO', u'bar'),)), + ) + @ddt.unpack + @requests_mock.Mocker() + def test_success(self, cause, param, mock): + u""" + Test triggering a jenkins job + """ + + def text_callback(request, context): + u""" What to return from the mock. """ + # This is the initial call that jenkinsapi uses to + # establish connectivity to Jenkins + # https://test-jenkins/api/python?tree=jobs[name,color,url] + context.status_code = 200 + if request.url.startswith(u'https://test-jenkins/api/python'): + return json.dumps(MOCK_JENKINS_DATA) + elif request.url.startswith(u'https://test-jenkins/job/test-job/456'): + return json.dumps(MOCK_BUILD_DATA) + elif request.url.startswith(u'https://test-jenkins/job/test-job'): + return json.dumps(MOCK_BUILDS_DATA) + elif request.url.startswith(u'https://test-jenkins/queue/item/123/api/python'): + return json.dumps(MOCK_QUEUE_DATA) + elif request.url.startswith(u'https://test-jenkins/crumbIssuer/api/python'): + return json.dumps(MOCK_CRUMB_DATA) + else: + # We should never get here, unless the jenkinsapi implementation changes. + # This response will catch that condition. + context.status_code = 500 + return None + + # Mock all network interactions + mock.get( + re.compile('.*'), + text=text_callback + ) + mock.post( + '{}/job/test-job/buildWithParameters'.format(BASE_URL), + status_code=201, # Jenkins responds with a 201 Created on success + headers={'location': '{}/queue/item/123'.format(BASE_URL)} + ) + + # Make the call to the Jenkins API + result = jenkins.trigger_build(BASE_URL, USER_ID, USER_TOKEN, JOB, TOKEN, cause, param) + self.assertEqual(result, 'SUCCESS') diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/__init__.py b/scripts/user_retirement/tests/utils/thirdparty_apis/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/test_amplitude.py b/scripts/user_retirement/tests/utils/thirdparty_apis/test_amplitude.py new file mode 100644 index 000000000000..253834e19cb8 --- /dev/null +++ b/scripts/user_retirement/tests/utils/thirdparty_apis/test_amplitude.py @@ -0,0 +1,86 @@ +""" +Tests for the Amplitude API functionality +""" +import logging +import os +import unittest +from unittest import mock + +import ddt +import requests_mock + +MAX_ATTEMPTS = int(os.environ.get("RETRY_MAX_ATTEMPTS", 5)) +from scripts.user_retirement.utils.thirdparty_apis.amplitude_api import ( + AmplitudeApi, + AmplitudeException, + AmplitudeRecoverableException +) + + +@ddt.ddt +@requests_mock.Mocker() +class TestAmplitude(unittest.TestCase): + """ + Class containing tests of all code interacting with Amplitude. + """ + + def setUp(self): + super().setUp() + self.user = {"user": {"id": "1234"}} + self.amplitude = AmplitudeApi("test-api-key", "test-secret-key") + + def _mock_delete(self, req_mock, status_code, message=None): + """ + Send a mock request with dummy headers and status code. 
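+        The message argument is unused here; the mocked endpoint always
+        returns an empty JSON body.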
+        """
+        req_mock.post(
+            "https://amplitude.com/api/2/deletions/users",
+            headers={"Content-Type": "application/json"},
+            json={},
+            status_code=status_code
+        )
+
+    def test_delete_happy_path(self, req_mock):
+        """
+        This test passes status_code 200 to _mock_delete to see how AmplitudeApi responds in the happy path.
+        """
+        self._mock_delete(req_mock, 200)
+        logger = logging.getLogger("scripts.user_retirement.utils.thirdparty_apis.amplitude_api")
+        with mock.patch.object(logger, "info") as mock_info:
+            self.amplitude.delete_user(self.user)
+
+        self.assertEqual(mock_info.call_args, [("Amplitude user deletion succeeded",)])
+
+        self.assertEqual(len(req_mock.request_history), 1)
+        request = req_mock.request_history[0]
+        self.assertEqual(request.json(),
+                         {"user_ids": ["1234"], 'ignore_invalid_id': 'true', "requester": "user-retirement-pipeline"})
+
+    def test_delete_fatal_error(self, req_mock):
+        """
+        This test passes status_code 404 to see how AmplitudeApi responds to a fatal error.
+        """
+        self._mock_delete(req_mock, 404)
+        message = None
+        logger = logging.getLogger("scripts.user_retirement.utils.thirdparty_apis.amplitude_api")
+        with mock.patch.object(logger, "error") as mock_error:
+            with self.assertRaises(AmplitudeException) as exc:
+                self.amplitude.delete_user(self.user)
+        error = "Amplitude user deletion failed due to {message}".format(message=message)
+        self.assertEqual(mock_error.call_args, [(error,)])
+        self.assertEqual(str(exc.exception), error)
+
+    @ddt.data(429, 500)
+    def test_delete_recoverable_error(self, status_code, req_mock):
+        """
+        This test passes status_codes 429 and 500 to see how AmplitudeApi responds to recoverable errors.
+        """
+        self._mock_delete(req_mock, status_code)
+
+        with self.assertRaises(AmplitudeRecoverableException):
+            self.amplitude.delete_user(self.user)
+        self.assertEqual(len(req_mock.request_history), MAX_ATTEMPTS)
diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/test_braze.py b/scripts/user_retirement/tests/utils/thirdparty_apis/test_braze.py
new file mode 100644
index 000000000000..57ca77929fa7
--- /dev/null
+++ b/scripts/user_retirement/tests/utils/thirdparty_apis/test_braze.py
@@ -0,0 +1,66 @@
+"""
+Tests for the Braze API functionality
+"""
+import logging
+import unittest
+from unittest import mock
+
+import ddt
+import requests_mock
+
+from scripts.user_retirement.utils.thirdparty_apis.braze_api import BrazeApi, BrazeException, BrazeRecoverableException
+
+
+@ddt.ddt
+@requests_mock.Mocker()
+class TestBraze(unittest.TestCase):
+    """
+    Class containing tests of all code interacting with Braze.
+    """
+
+    def setUp(self):
+        super().setUp()
+        self.learner = {'user': {'id': 1234}}
+        self.braze = BrazeApi('test-key', 'test-instance')
+
+    def _mock_delete(self, req_mock, status_code, message=None):
+        req_mock.post(
+            'https://rest.test-instance.braze.com/users/delete',
+            request_headers={'Authorization': 'Bearer test-key'},
+            json={'message': message} if message else {},
+            status_code=status_code
+        )
+
+    def test_delete_happy_path(self, req_mock):
+        self._mock_delete(req_mock, 200)
+
+        logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.braze_api')
+        with mock.patch.object(logger, 'info') as mock_info:
+            self.braze.delete_user(self.learner)
+
+        self.assertEqual(mock_info.call_args, [('Braze user deletion succeeded',)])
+
+        self.assertEqual(len(req_mock.request_history), 1)
+        request = req_mock.request_history[0]
+        self.assertEqual(request.json(), {'external_ids': [1234]})
+
+    def test_delete_fatal_error(self, req_mock):
+        self._mock_delete(req_mock, 404, message='Test Error Message')
+
+        logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.braze_api')
+        with mock.patch.object(logger, 'error') as mock_error:
+            with self.assertRaises(BrazeException) as exc:
+                self.braze.delete_user(self.learner)
+
+        error = 'Braze user deletion failed due to Test Error Message'
+        self.assertEqual(mock_error.call_args, [(error,)])
+        self.assertEqual(str(exc.exception), error)
+
+    @ddt.data(429, 500)
+    def test_delete_recoverable_error(self, status_code, req_mock):
+        self._mock_delete(req_mock, status_code)
+
+        with self.assertRaises(BrazeRecoverableException):
+            self.braze.delete_user(self.learner)
+
+        self.assertEqual(len(req_mock.request_history), 5)
diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/test_hubspot.py b/scripts/user_retirement/tests/utils/thirdparty_apis/test_hubspot.py
new file mode 100644
index 000000000000..85e58b3d953e
--- /dev/null
+++ b/scripts/user_retirement/tests/utils/thirdparty_apis/test_hubspot.py
@@ -0,0 +1,159 @@
+"""
+Tests for the Hubspot API functionality
+"""
+import logging
+import os
+import unittest
+from unittest import mock
+
+import requests_mock
+from six.moves import reload_module
+
+# This module is imported separately solely so it can be re-loaded below.
+from scripts.user_retirement.utils.thirdparty_apis import hubspot_api
+# This HubspotAPI class will be used without being re-loaded.
+from scripts.user_retirement.utils.thirdparty_apis.hubspot_api import HubspotAPI
+
+# Change the number of retries for Hubspot API's delete_user call to 1.
+# Then reload hubspot_api so only a single retry is performed.
+os.environ['RETRY_HUBSPOT_MAX_ATTEMPTS'] = "1"
+reload_module(hubspot_api)  # pylint: disable=too-many-function-args
+
+
+@requests_mock.Mocker()
+@mock.patch.object(HubspotAPI, 'send_marketing_alert')
+class TestHubspot(unittest.TestCase):
+    """
+    Class containing tests of all code interacting with Hubspot.
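+
+    RETRY_HUBSPOT_MAX_ATTEMPTS is forced to "1" above (followed by a module
+    reload) so that each failing call is attempted only once, which keeps the
+    error-path tests fast.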
+ """ + + def setUp(self): + super(TestHubspot, self).setUp() + self.test_learner = {'original_email': 'foo@bar.com'} + self.api_key = 'example_key' + self.test_vid = 12345 + self.test_region = 'test-east-1' + self.from_address = 'no-reply@example.com' + self.alert_email = 'marketing@example.com' + + def _mock_get_vid(self, req_mock, status_code): + req_mock.get( + hubspot_api.GET_VID_FROM_EMAIL_URL_TEMPLATE.format( + email=self.test_learner['original_email'] + ), + json={'vid': self.test_vid}, + status_code=status_code + ) + + def _mock_delete(self, req_mock, status_code): + req_mock.delete( + hubspot_api.DELETE_USER_FROM_VID_TEMPLATE.format( + vid=self.test_vid + ), + json={}, + status_code=status_code + ) + + def test_delete_no_email(self, req_mock, mock_alert): # pylint: disable=unused-argument + with self.assertRaises(TypeError) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user({}) + self.assertIn('Expected an email address for user to delete, but received None.', str(exc)) + mock_alert.assert_not_called() + + def test_delete_success(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 200) + self._mock_delete(req_mock, 200) + logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.hubspot_api') + + with mock.patch.object(logger, 'info') as mock_info: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + mock_info.assert_called_once_with("User successfully deleted from Hubspot") + mock_alert.assert_called_once_with(12345) + + def test_delete_email_does_not_exist(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 404) + logger = logging.getLogger('scripts.user_retirement.utils.thirdparty_apis.hubspot_api') + with mock.patch.object(logger, 'info') as mock_info: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + mock_info.assert_called_once_with("No action taken because no user was found in Hubspot.") + mock_alert.assert_not_called() + + def test_delete_server_failure_on_user_retrieval(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 500) + with self.assertRaises(hubspot_api.HubspotException) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + self.assertIn("Error attempted to get user_vid from Hubspot", str(exc)) + mock_alert.assert_not_called() + + def test_delete_unauthorized_deletion(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 200) + self._mock_delete(req_mock, 401) + with self.assertRaises(hubspot_api.HubspotException) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + self.assertIn("Hubspot user deletion failed due to authorized API call", str(exc)) + mock_alert.assert_not_called() + + def test_delete_vid_not_found(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 200) + self._mock_delete(req_mock, 404) + with self.assertRaises(hubspot_api.HubspotException) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + self.assertIn("Hubspot user deletion failed because vid doesn't match user", str(exc)) + mock_alert.assert_not_called() + + def test_delete_server_failure_on_deletion(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 200) + 
self._mock_delete(req_mock, 500) + with self.assertRaises(hubspot_api.HubspotException) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + self.assertIn("Hubspot user deletion failed due to server-side (Hubspot) issues", str(exc)) + mock_alert.assert_not_called() + + def test_delete_catch_all_on_deletion(self, req_mock, mock_alert): + self._mock_get_vid(req_mock, 200) + # Testing 403 as it's not a response type per the Hubspot API docs, so it doesn't have its own error. + self._mock_delete(req_mock, 403) + with self.assertRaises(hubspot_api.HubspotException) as exc: + HubspotAPI( + self.api_key, + self.test_region, + self.from_address, + self.alert_email + ).delete_user(self.test_learner) + self.assertIn("Hubspot user deletion failed due to unknown reasons", str(exc)) + mock_alert.assert_not_called() diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/test_salesforce.py b/scripts/user_retirement/tests/utils/thirdparty_apis/test_salesforce.py new file mode 100644 index 000000000000..47770439533e --- /dev/null +++ b/scripts/user_retirement/tests/utils/thirdparty_apis/test_salesforce.py @@ -0,0 +1,155 @@ +""" +Tests for the Salesforce API functionality +""" +import logging +from contextlib import contextmanager + +import mock +import pytest +from simple_salesforce import SalesforceError + +from scripts.user_retirement.utils.thirdparty_apis import salesforce_api + + +@pytest.fixture +def test_learner(): + return {'original_email': 'foo@bar.com'} + + +def make_api(): + """ + Helper function to create salesforce api object + """ + return salesforce_api.SalesforceApi("user", "pass", "key", "domain", "user") + + +@contextmanager +def mock_get_user(): + """ + Context manager method to mock getting the assignee user id when the api object is created + """ + with mock.patch( + 'scripts.user_retirement.utils.thirdparty_apis.salesforce_api.SalesforceApi.get_user_id' + ) as getuser: + getuser.return_value = "userid" + yield + + +def test_no_assignee_email(): + with mock.patch( + 'scripts.user_retirement.utils.thirdparty_apis.salesforce_api.SalesforceApi.get_user_id' + ) as getuser: + getuser.return_value = None + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + with pytest.raises(Exception) as exc: + make_api() + print(str(exc)) + assert 'Could not find Salesforce user with username user' in str(exc) + + +def test_retire_no_email(): + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + with pytest.raises(TypeError) as exc: + make_api().retire_learner({}) + assert 'Expected an email address for user to delete, but received None.'
in str(exc) + + +def test_retire_get_id_error(test_learner): # pylint: disable=redefined-outer-name + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + api._sf.query.side_effect = SalesforceError("", "", "", "") # pylint: disable=protected-access + with pytest.raises(SalesforceError): + api.retire_learner(test_learner) + + +# pylint: disable=protected-access +def test_escape_email(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_response = {'totalSize': 0, 'records': []} + api._sf.query.return_value = mock_response + api.get_lead_ids_by_email("Robert'); DROP TABLE students;--") + api._sf.query.assert_called_with( + "SELECT Id FROM Lead WHERE Email = 'Robert\\'); DROP TABLE students;--'" + ) + + +# pylint: disable=protected-access +def test_escape_username(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_response = {'totalSize': 0, 'records': []} + api._sf.query.return_value = mock_response + api.get_user_id("Robert'); DROP TABLE students;--") + api._sf.query.assert_called_with( + "SELECT Id FROM User WHERE Username = 'Robert\\'); DROP TABLE students;--'" + ) + + +def test_retire_learner_not_found(test_learner, caplog): # pylint: disable=redefined-outer-name + caplog.set_level(logging.INFO) + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_response = {'totalSize': 0, 'records': []} + api._sf.query.return_value = mock_response # pylint: disable=protected-access + api.retire_learner(test_learner) + assert not api._sf.Task.create.called # pylint: disable=protected-access + assert 'No action taken because no lead was found in Salesforce.' in caplog.text + + +def test_retire_task_error(test_learner, caplog): # pylint: disable=redefined-outer-name + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]} + api._sf.query.return_value = mock_query_response # pylint: disable=protected-access + mock_task_response = {'success': False, 'errors': ["This is an error!"]} + api._sf.Task.create.return_value = mock_task_response # pylint: disable=protected-access + with pytest.raises(Exception) as exc: + api.retire_learner(test_learner) + assert "Errors while creating task:" in caplog.text + assert "This is an error!" 
in caplog.text + assert "Unable to create retirement task for email foo@bar.com" in str(exc) + + +def test_retire_task_exception(test_learner): # pylint: disable=redefined-outer-name + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]} + api._sf.query.return_value = mock_query_response # pylint: disable=protected-access + api._sf.Task.create.side_effect = SalesforceError("", "", "", "") # pylint: disable=protected-access + with pytest.raises(SalesforceError): + api.retire_learner(test_learner) + + +def test_retire_success(test_learner, caplog): # pylint: disable=redefined-outer-name + caplog.set_level(logging.INFO) + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_query_response = {'totalSize': 1, 'records': [{'Id': 1}]} + api._sf.query.return_value = mock_query_response # pylint: disable=protected-access + mock_task_response = {'success': True, 'id': 'task-id'} + api._sf.Task.create.return_value = mock_task_response # pylint: disable=protected-access + api.retire_learner(test_learner) + assert "Successfully salesforce task created task task-id" in caplog.text + + +def test_retire_multiple_learners(test_learner, caplog): # pylint: disable=redefined-outer-name + caplog.set_level(logging.INFO) + with mock_get_user(): + with mock.patch('scripts.user_retirement.utils.thirdparty_apis.salesforce_api.Salesforce'): + api = make_api() + mock_response = {'totalSize': 2, 'records': [{'Id': 1}, {'Id': 2}]} + api._sf.query.return_value = mock_response # pylint: disable=protected-access + mock_task_response = {'success': True, 'id': 'task-id'} + api._sf.Task.create.return_value = mock_task_response # pylint: disable=protected-access + api.retire_learner(test_learner) + assert "Multiple Ids returned for Lead with email foo@bar.com" in caplog.text + assert "Successfully salesforce task created task task-id" in caplog.text + note = "Notice: Multiple leads were identified with the same email. 
Please retire all following leads:" + assert note in api._sf.Task.create.call_args[0][0]['Description'] # pylint: disable=protected-access diff --git a/scripts/user_retirement/tests/utils/thirdparty_apis/test_segment_api.py b/scripts/user_retirement/tests/utils/thirdparty_apis/test_segment_api.py new file mode 100644 index 000000000000..8d413c7d3282 --- /dev/null +++ b/scripts/user_retirement/tests/utils/thirdparty_apis/test_segment_api.py @@ -0,0 +1,170 @@ +""" +Tests for the Segment API functionality +""" +import json + +import mock +import pytest +import requests +from six import text_type + +from scripts.user_retirement.tests.retirement_helpers import get_fake_user_retirement +from scripts.user_retirement.utils.thirdparty_apis.segment_api import BULK_REGULATE_URL, SegmentApi + +FAKE_AUTH_TOKEN = 'FakeToken' +TEST_SEGMENT_CONFIG = { + 'projects_to_retire': ['project_1', 'project_2'], + 'learner': [get_fake_user_retirement(), ], + 'fake_base_url': 'https://segment.invalid/', + 'fake_auth_token': FAKE_AUTH_TOKEN, + 'fake_workspace': 'FakeEdx', + 'headers': {"Authorization": "Bearer {}".format(FAKE_AUTH_TOKEN), "Content-Type": "application/json"} +} + + +class FakeResponse: + """ + Fakes out requests.post response + """ + + def json(self): + """ + Returns fake Segment retirement response data in the correct format + """ + return {'regulate_id': 1} + + def raise_for_status(self): + pass + + +class FakeErrorResponse: + """ + Fakes an error response + """ + status_code = 500 + text = "{'error': 'Test error message'}" + + def json(self): + """ + Returns fake Segment retirement response error in the correct format + """ + return json.loads(self.text) + + def raise_for_status(self): + raise requests.exceptions.HTTPError("", response=self) + + +@pytest.fixture +def setup_regulation_api(): + """ + Fixture to setup common bulk delete items. 
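+ Patches requests.post and yields the mock together with a SegmentApi instance built from TEST_SEGMENT_CONFIG, so tests can both stub responses and inspect outgoing calls.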
+ """ + with mock.patch('requests.post') as mock_post: + segment = SegmentApi( + *[TEST_SEGMENT_CONFIG[key] for key in [ + 'fake_base_url', 'fake_auth_token', 'fake_workspace' + ]] + ) + + yield mock_post, segment + + +def test_bulk_delete_success(setup_regulation_api): # pylint: disable=redefined-outer-name + """ + Test simple success case + """ + mock_post, segment = setup_regulation_api + mock_post.return_value = FakeResponse() + + learner = TEST_SEGMENT_CONFIG['learner'] + segment.delete_and_suppress_learners(learner, 1000) + + assert mock_post.call_count == 1 + + expected_learner = get_fake_user_retirement() + learners_vals = [ + text_type(expected_learner['user']['id']), + expected_learner['original_username'], + expected_learner['ecommerce_segment_id'], + ] + + fake_json = { + "regulation_type": "Suppress_With_Delete", + "attributes": { + "name": "userId", + "values": learners_vals + } + } + + url = TEST_SEGMENT_CONFIG['fake_base_url'] + BULK_REGULATE_URL.format(TEST_SEGMENT_CONFIG['fake_workspace']) + mock_post.assert_any_call( + url, json=fake_json, headers=TEST_SEGMENT_CONFIG['headers'] + ) + + +def test_bulk_delete_error(setup_regulation_api, caplog): # pylint: disable=redefined-outer-name + """ + Test simple error case + """ + mock_post, segment = setup_regulation_api + mock_post.return_value = FakeErrorResponse() + + learner = TEST_SEGMENT_CONFIG['learner'] + with pytest.raises(Exception): + segment.delete_and_suppress_learners(learner, 1000) + + assert mock_post.call_count == 4 + assert "Error was encountered for params:" in caplog.text + assert "9009" in caplog.text + assert "foo_username" in caplog.text + assert "ecommerce-90" in caplog.text + assert "Suppress_With_Delete" in caplog.text + assert "Test error message" in caplog.text + + +def test_bulk_unsuppress_success(setup_regulation_api): # pylint: disable=redefined-outer-name + """ + Test simple success case + """ + mock_post, segment = setup_regulation_api + mock_post.return_value = FakeResponse() + + learner = TEST_SEGMENT_CONFIG['learner'] + segment.unsuppress_learners_by_key('original_username', learner, 100) + + assert mock_post.call_count == 1 + + expected_learner = get_fake_user_retirement() + + fake_json = { + "regulation_type": "Unsuppress", + "attributes": { + "name": "userId", + "values": [expected_learner['original_username'], ] + } + } + + url = TEST_SEGMENT_CONFIG['fake_base_url'] + BULK_REGULATE_URL.format(TEST_SEGMENT_CONFIG['fake_workspace']) + mock_post.assert_any_call( + url, json=fake_json, headers=TEST_SEGMENT_CONFIG['headers'] + ) + + +def test_bulk_unsuppress_error(setup_regulation_api, caplog): # pylint: disable=redefined-outer-name + """ + Test simple error case + """ + mock_post, segment = setup_regulation_api + mock_post.return_value = FakeErrorResponse() + + learner = TEST_SEGMENT_CONFIG['learner'] + with pytest.raises(Exception): + segment.unsuppress_learners_by_key('original_username', learner, 100) + + assert mock_post.call_count == 4 + assert "Error was encountered for params:" in caplog.text + assert "9009" not in caplog.text + assert "foo_username" in caplog.text + assert "ecommerce-90" not in caplog.text + assert "Unsuppress" in caplog.text + assert "Test error message" in caplog.text diff --git a/scripts/user_retirement/utils/__init__.py b/scripts/user_retirement/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/utils/edx_api.py b/scripts/user_retirement/utils/edx_api.py new file mode 100644 index 000000000000..10903640456a 
--- /dev/null +++ b/scripts/user_retirement/utils/edx_api.py @@ -0,0 +1,522 @@ +""" +edX API classes which call edX service REST API endpoints using the edx-rest-api-client module. +""" +import logging +from urllib.parse import urljoin + +import backoff +import requests +from edx_rest_api_client.auth import SuppliedJwtAuth +from edx_rest_api_client.client import REQUEST_CONNECT_TIMEOUT, REQUEST_READ_TIMEOUT +from requests.exceptions import ConnectionError, HTTPError, Timeout + +from scripts.user_retirement.utils.exception import HttpDoesNotExistException + +LOG = logging.getLogger(__name__) + +OAUTH_ACCESS_TOKEN_URL = "/oauth2/access_token" + + +class EdxGatewayTimeoutError(Exception): + """ + Exception used to indicate a 504 server error was returned. + Differentiates from other 5xx errors. + """ + + +class BaseApiClient: + """ + API client base class used to submit API requests to a particular web service. + """ + append_slash = True + _access_token = None + + def __init__(self, lms_base_url, api_base_url, client_id, client_secret): + """ + Retrieves OAuth access token from the LMS and creates REST API client instance. + """ + self.api_base_url = api_base_url + self._access_token = self.get_access_token(lms_base_url, client_id, client_secret) + + def get_api_url(self, path): + """ + Construct the full API URL using the api_base_url and path. + + Args: + path (str): API endpoint path. + """ + path = path.strip('/') + if self.append_slash: + path += '/' + + return urljoin(f'{self.api_base_url}/', path) + + def _request(self, method, url, log_404_as_error=True, **kwargs): + if 'headers' not in kwargs: + kwargs['headers'] = {'Content-type': 'application/json'} + + try: + response = requests.request(method, url, auth=SuppliedJwtAuth(self._access_token), **kwargs) + response.raise_for_status() + + if response.status_code != 204: + return response.json() + except HTTPError as exc: + status_code = exc.response.status_code + + if status_code == 404 and not log_404_as_error: + # Immediately raise the error so that a 404 isn't logged as an API error in this case. + raise HttpDoesNotExistException(str(exc)) + + LOG.error(f'API Error: {str(exc)} with status code: {status_code}') + + if status_code == 504: + # Differentiate gateway errors so different backoff can be used. + raise EdxGatewayTimeoutError(str(exc)) + + if status_code == 404: + raise HttpDoesNotExistException(str(exc)) + raise + + except Timeout: + LOG.error("The request timed out.") + raise + + return response + + @staticmethod + def get_access_token(oauth_base_url, client_id, client_secret): + """ + Returns an access token for this site's service user. + + Returns: + str: JWT access token + """ + oauth_access_token_url = urljoin(f'{oauth_base_url}/', OAUTH_ACCESS_TOKEN_URL) + data = { + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': client_secret, + 'token_type': 'jwt', + } + try: + response = requests.post( + oauth_access_token_url, + data=data, + headers={ + 'User-Agent': 'scripts.user_retirement', + }, + timeout=(REQUEST_CONNECT_TIMEOUT, REQUEST_READ_TIMEOUT) + ) + response.raise_for_status() + return response.json()['access_token'] + except KeyError as exc: + LOG.error(f'Failed to get token. {str(exc)} does not exist.') + raise + + except HTTPError as exc: + LOG.error( + f'API Error: {str(exc)} with status code: {exc.response.status_code} fetching access token: {client_id}' + ) + raise + + +def _backoff_handler(details): + """ + Simple logging handler for when timeout backoff occurs.
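+ The details dict supplied by backoff contains the 'wait', 'tries', and 'target' keys interpolated into the log message below.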
+ """ + LOG.info('Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'.format(**details)) + + +def _wait_one_minute(): + """ + Backoff generator that waits for 60 seconds. + """ + return backoff.constant(interval=60) + + +def _giveup_on_unexpected_exception(exc): + """ + Giveup method that gives up backoff upon any unexpected exception. + """ + keep_retrying = ( + # Treat a ConnectionError as retryable. + isinstance(exc, ConnectionError) + # All 5xx status codes are retryable except for 504 Gateway Timeout. + or ( + 500 <= exc.response.status_code < 600 + and exc.response.status_code != 504 # Gateway Timeout + ) + # Status code 104 is unreserved, but we must have added this because we observed retryable 104 responses. + or exc.response.status_code == 104 + ) + return not keep_retrying + + +def _retry_lms_api(): + """ + Decorator which enables retries with sane backoff defaults for LMS APIs. + """ + + def inner(func): # pylint: disable=missing-docstring + func_with_backoff = backoff.on_exception( + backoff.expo, + (HTTPError, ConnectionError), + max_time=600, # 10 minutes + giveup=_giveup_on_unexpected_exception, + # Wrap the actual _backoff_handler so that we can patch the real one in unit tests. Otherwise, the func + # will get decorated on import, embedding this handler as a python object reference, precluding our ability + # to patch it in tests. + on_backoff=lambda details: _backoff_handler(details) # pylint: disable=unnecessary-lambda + ) + func_with_timeout_backoff = backoff.on_exception( + _wait_one_minute, + EdxGatewayTimeoutError, + max_tries=2, + # Wrap the actual _backoff_handler so that we can patch the real one in unit tests. Otherwise, the func + # will get decorated on import, embedding this handler as a python object reference, precluding our ability + # to patch it in tests. + on_backoff=lambda details: _backoff_handler(details) # pylint: disable=unnecessary-lambda + ) + return func_with_backoff(func_with_timeout_backoff(func)) + + return inner + + +class LmsApi(BaseApiClient): + """ + LMS API client with convenience methods for making API calls. + """ + + @_retry_lms_api() + def learners_to_retire(self, states_to_request, cool_off_days=7, limit=None): + """ + Retrieves a list of learners awaiting retirement actions. + """ + params = { + 'cool_off_days': cool_off_days, + 'states': states_to_request + } + if limit: + params['limit'] = limit + api_url = self.get_api_url('api/user/v1/accounts/retirement_queue') + return self._request('GET', api_url, params=params) + + @_retry_lms_api() + def get_learners_by_date_and_status(self, state_to_request, start_date, end_date): + """ + Retrieves a list of learners in the given retirement state that were + created in the retirement queue between the dates given. Date range + is inclusive, so to get one day you would set both dates to that day. + + :param state_to_request: String LMS UserRetirementState state name (ex. COMPLETE) + :param start_date: Date or Datetime object + :param end_date: Date or Datetime + """ + params = { + 'start_date': start_date.strftime('%Y-%m-%d'), + 'end_date': end_date.strftime('%Y-%m-%d'), + 'state': state_to_request + } + api_url = self.get_api_url('api/user/v1/accounts/retirements_by_status_and_date') + return self._request('GET', api_url, params=params) + + @_retry_lms_api() + def get_learner_retirement_state(self, username): + """ + Retrieves the given learner's retirement state. 
+ """ + api_url = self.get_api_url(f'api/user/v1/accounts/{username}/retirement_status') + return self._request('GET', api_url) + + @_retry_lms_api() + def update_learner_retirement_state(self, username, new_state_name, message, force=False): + """ + Updates the given learner's retirement state to the retirement state name new_string + with the additional string information in message (for logging purposes). + """ + data = { + 'username': username, + 'new_state': new_state_name, + 'response': message + } + + if force: + data['force'] = True + + api_url = self.get_api_url('api/user/v1/accounts/update_retirement_status') + return self._request('PATCH', api_url, json=data) + + @_retry_lms_api() + def retirement_deactivate_logout(self, learner): + """ + Performs the user deactivation and forced logout step of learner retirement + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/user/v1/accounts/deactivate_logout') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def retirement_retire_forum(self, learner): + """ + Performs the forum retirement step of learner retirement + """ + # api/discussion/ + data = {'username': learner['original_username']} + try: + api_url = self.get_api_url('api/discussion/v1/accounts/retire_forum') + return self._request('POST', api_url, json=data) + except HttpDoesNotExistException: + LOG.info("No information about learner retirement") + return True + + @_retry_lms_api() + def retirement_retire_mailings(self, learner): + """ + Performs the email list retirement step of learner retirement + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/user/v1/accounts/retire_mailings') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def retirement_unenroll(self, learner): + """ + Unenrolls the user from all courses + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/enrollment/v1/unenroll') + return self._request('POST', api_url, json=data) + + # This endpoint additionally returns 500 when the EdxNotes backend service is unavailable. + @_retry_lms_api() + def retirement_retire_notes(self, learner): + """ + Deletes all the user's notes (aka. annotations) + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/edxnotes/v1/retire_user') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def retirement_lms_retire_misc(self, learner): + """ + Deletes, blanks, or one-way hashes personal information in LMS as + defined in EDUCATOR-2802 and sub-tasks. 
+ """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/user/v1/accounts/retire_misc') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def retirement_lms_retire(self, learner): + """ + Deletes, blanks, or one-way hashes all remaining personal information in LMS + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/user/v1/accounts/retire') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def retirement_partner_queue(self, learner): + """ + Calls LMS to add the given user to the retirement reporting queue + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report') + return self._request('PUT', api_url, json=data) + + @_retry_lms_api() + def retirement_partner_report(self): + """ + Retrieves the list of users to create partner reports for and set their status to + processing + """ + api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report') + return self._request('POST', api_url) + + @_retry_lms_api() + def retirement_partner_cleanup(self, usernames): + """ + Removes the given users from the partner reporting queue + """ + api_url = self.get_api_url('api/user/v1/accounts/retirement_partner_report_cleanup') + return self._request('POST', api_url, json=usernames) + + @_retry_lms_api() + def retirement_retire_proctoring_data(self, learner): + """ + Deletes or hashes learner data from edx-proctoring + """ + api_url = self.get_api_url(f"api/edx_proctoring/v1/retire_user/{learner['user']['id']}") + return self._request('POST', api_url) + + @_retry_lms_api() + def retirement_retire_proctoring_backend_data(self, learner): + """ + Removes the given learner from 3rd party proctoring backends + """ + api_url = self.get_api_url(f"api/edx_proctoring/v1/retire_backend_user/{learner['user']['id']}") + return self._request('POST', api_url) + + @_retry_lms_api() + def bulk_cleanup_retirements(self, usernames): + """ + Deletes the retirements for all given usernames + """ + data = {'usernames': usernames} + api_url = self.get_api_url('api/user/v1/accounts/retirement_cleanup') + return self._request('POST', api_url, json=data) + + def replace_lms_usernames(self, username_mappings): + """ + Calls LMS API to replace usernames. + + Param: + username_mappings: list of dicts where key is current username and value is new desired username + [{current_un_1: desired_un_1}, {current_un_2: desired_un_2}] + """ + data = {"username_mappings": username_mappings} + api_url = self.get_api_url('api/user/v1/accounts/replace_usernames') + return self._request('POST', api_url, json=data) + + def replace_forums_usernames(self, username_mappings): + """ + Calls the discussion forums API inside of LMS to replace usernames. + + Param: + username_mappings: list of dicts where key is current username and value is new unique username + [{current_un_1: new_un_1}, {current_un_2: new_un_2}] + """ + data = {"username_mappings": username_mappings} + api_url = self.get_api_url('api/discussion/v1/accounts/replace_usernames') + return self._request('POST', api_url, json=data) + + +class EcommerceApi(BaseApiClient): + """ + Ecommerce API client with convenience methods for making API calls. 
+ """ + + @_retry_lms_api() + def retire_learner(self, learner): + """ + Performs the learner retirement step for Ecommerce + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('api/v2/user/retire') + return self._request('POST', api_url, json=data) + + @_retry_lms_api() + def get_tracking_key(self, learner): + """ + Fetches the ecommerce tracking id used for Segment tracking when + ecommerce doesn't have access to the LMS user id. + """ + api_url = self.get_api_url(f"api/v2/retirement/tracking_id/{learner['original_username']}") + return self._request('GET', api_url)['ecommerce_tracking_id'] + + def replace_usernames(self, username_mappings): + """ + Calls the ecommerce API to replace usernames. + + Param: + username_mappings: list of dicts where key is current username and value is new unique username + [{current_un_1: new_un_1}, {current_un_2: new_un_2}] + """ + data = {"username_mappings": username_mappings} + api_url = self.get_api_url('api/v2/user_management/replace_usernames') + return self._request('POST', api_url, json=data) + + +class CredentialsApi(BaseApiClient): + """ + Credentials API client with convenience methods for making API calls. + """ + + @_retry_lms_api() + def retire_learner(self, learner): + """ + Performs the learner retirement step for Credentials + """ + data = {'username': learner['original_username']} + api_url = self.get_api_url('user/retire') + return self._request('POST', api_url, json=data) + + def replace_usernames(self, username_mappings): + """ + Calls the credentials API to replace usernames. + + Param: + username_mappings: list of dicts where key is current username and value is new unique username + [{current_un_1: new_un_1}, {current_un_2: new_un_2}] + """ + data = {"username_mappings": username_mappings} + api_url = self.get_api_url('api/v2/replace_usernames') + return self._request('POST', api_url, json=data) + + +class DiscoveryApi(BaseApiClient): + """ + Discovery API client with convenience methods for making API calls. + """ + + def replace_usernames(self, username_mappings): + """ + Calls the discovery API to replace usernames. + + Param: + username_mappings: list of dicts where key is current username and value is new unique username + [{current_un_1: new_un_1}, {current_un_2: new_un_2}] + """ + data = {"username_mappings": username_mappings} + api_url = self.get_api_url('api/v1/replace_usernames') + return self._request('POST', api_url, json=data) + + +class DemographicsApi(BaseApiClient): + """ + Demographics API client. + """ + + @_retry_lms_api() + def retire_learner(self, learner): + """ + Performs the learner retirement step for Demographics. Passes the learner's LMS User Id instead of username. + """ + data = {'lms_user_id': learner['user']['id']} + # If the user we are retiring has no data in the Demographics DB the request will return a 404. We + # catch the HTTPError and return True in order to prevent this error getting raised and + # incorrectly causing the learner to enter an ERROR state during retirement. + try: + api_url = self.get_api_url('demographics/api/v1/retire_demographics') + return self._request('POST', api_url, log_404_as_error=False, json=data) + except HttpDoesNotExistException: + LOG.info("No demographics data found for user") + return True + + +class LicenseManagerApi(BaseApiClient): + """ + License Manager API client. + """ + + @_retry_lms_api() + def retire_learner(self, learner): + """ + Performs the learner retirement step for License manager. 
Passes the learner's LMS User Id in addition to + username. + """ + data = { + 'lms_user_id': learner['user']['id'], + 'original_username': learner['original_username'], + } + # If the user we are retiring has no data in the License Manager DB the request will return a 404. We + # catch the HTTPError and return True in order to prevent this error getting raised and + # incorrectly causing the learner to enter an ERROR state during retirement. + try: + api_url = self.get_api_url('api/v1/retire_user') + return self._request('POST', api_url, log_404_as_error=False, json=data) + except HttpDoesNotExistException: + LOG.info("No license manager data found for user") + return True diff --git a/scripts/user_retirement/utils/email_utils.py b/scripts/user_retirement/utils/email_utils.py new file mode 100644 index 000000000000..1f21a3d24914 --- /dev/null +++ b/scripts/user_retirement/utils/email_utils.py @@ -0,0 +1,85 @@ +""" +Convenience functions using boto and AWS SES to send email. +""" + +import logging + +import backoff +import boto3 + +from scripts.user_retirement.utils.exception import BackendError +from scripts.user_retirement.utils.utils import envvar_get_int + +LOG = logging.getLogger(__name__) + +# Default maximum number of attempts to send email. +MAX_EMAIL_TRIES_DEFAULT = 10 + + +def _poll_giveup(results): + """ + Raise an error when the polling tries are exceeded. + """ + orig_args = results['args'] + msg = 'Timed out after {tries} attempts to send email with subject "{subject}".'.format( + tries=results['tries'], + subject=orig_args[3] + ) + raise BackendError(msg) + + +@backoff.on_exception(backoff.expo, + Exception, + max_tries=envvar_get_int("MAX_EMAIL_TRIES", MAX_EMAIL_TRIES_DEFAULT), + on_giveup=_poll_giveup) +def _send_email_with_retry(ses_conn, + from_address, + to_addresses, + subject, + body): + """ + Send email, retrying upon exception. + """ + ses_conn.send_email( + Source=from_address, + Message={ + "Body": { + "Text": { + "Charset": "UTF-8", + "Data": body, + }, + }, + "Subject": { + "Charset": "UTF-8", + "Data": subject, + }, + }, + Destination={ + "ToAddresses": to_addresses, + }, + ) + + +def send_email(aws_region, + from_address, + to_addresses, + subject, + body): + """ + Send an email via AWS SES using boto with the specified subject/body/recipients. + + Args: + aws_region (str): AWS region whose SES service will be used, e.g. "us-east-1". + from_address (str): Email address to use as the From: address. Must be an SES verified address. + to_addresses (list(str)): List of email addresses to which to send the email. + subject (str): Subject to use in the email. + body (str): Body to use in the email - text format. + """ + ses_conn = boto3.client("ses", region_name=aws_region) + _send_email_with_retry( + ses_conn, + from_address, + to_addresses, + subject, + body + ) diff --git a/scripts/user_retirement/utils/exception.py b/scripts/user_retirement/utils/exception.py new file mode 100644 index 000000000000..977272f0a267 --- /dev/null +++ b/scripts/user_retirement/utils/exception.py @@ -0,0 +1,14 @@ +""" +Exceptions used by various utilities. +""" + + +class BackendError(Exception): + pass + + +class HttpDoesNotExistException(Exception): + """ + Raised when the server sends a 404 error.
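+ Callers such as DemographicsApi and LicenseManagerApi catch it to treat a 404 as "no data to retire" rather than as a fatal error.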
+ """ + pass diff --git a/scripts/user_retirement/utils/helpers.py b/scripts/user_retirement/utils/helpers.py new file mode 100644 index 000000000000..8203e363593c --- /dev/null +++ b/scripts/user_retirement/utils/helpers.py @@ -0,0 +1,244 @@ +""" +Common helper methods to use in user retirement scripts. +""" +# NOTE: Make sure that all non-ascii text written to standard output (including +# print statements and logging) is manually encoded to bytes using a utf-8 or +# other encoding. We currently make use of this library within a context that +# does NOT tolerate unicode text on sys.stdout, namely python 2 on Build +# Jenkins. PLAT-2287 tracks this Tech Debt. + + +import io +import json +import sys +import traceback +import unicodedata + +import yaml +from six import text_type + +from scripts.user_retirement.utils.edx_api import LmsApi # pylint: disable=wrong-import-position +from scripts.user_retirement.utils.edx_api import CredentialsApi, DemographicsApi, EcommerceApi, LicenseManagerApi +from scripts.user_retirement.utils.thirdparty_apis.amplitude_api import \ + AmplitudeApi # pylint: disable=wrong-import-position +from scripts.user_retirement.utils.thirdparty_apis.braze_api import BrazeApi # pylint: disable=wrong-import-position +from scripts.user_retirement.utils.thirdparty_apis.hubspot_api import \ + HubspotAPI # pylint: disable=wrong-import-position +from scripts.user_retirement.utils.thirdparty_apis.salesforce_api import \ + SalesforceApi # pylint: disable=wrong-import-position +from scripts.user_retirement.utils.thirdparty_apis.segment_api import \ + SegmentApi # pylint: disable=wrong-import-position + + +def _log(kind, message): + """ + Convenience method to log text. Prepended "kind" text makes finding log entries easier. + """ + print(u'{}: {}'.format(kind, message).encode('utf-8')) # See note at the top of this file. + + +def _fail(kind, code, message): + """ + Convenience method to fail out of the command with a message and traceback. + """ + _log(kind, message) + + # Try to get a traceback, if there is one. On Python 3.4 this raises an AttributeError + # if there is no current exception, so we eat that here. + try: + _log(kind, traceback.format_exc()) + except AttributeError: + pass + + sys.exit(code) + + +def _fail_exception(kind, code, message, exc): + """ + A version of fail that takes an exception to be utf-8 decoded + """ + exc_msg = _get_error_str_from_exception(exc) + message += '\n' + exc_msg + _fail(kind, code, message) + + +def _get_error_str_from_exception(exc): + """ + Return a string from an exception that may or may not have a .content (Slumber) + """ + exc_msg = text_type(exc) + + if hasattr(exc, 'content'): + # Slumber inconveniently discards the decoded .text attribute from the Response object, + # and instead gives us the raw encoded .content attribute, so we need to decode it first. + # Python 2 needs the decode, Py3 does not have it. + try: + exc_msg += '\n' + str(exc.content).decode('utf-8') + except AttributeError: + exc_msg += '\n' + str(exc.content) + + return exc_msg + + +def _config_or_exit(fail_func, fail_code, config_file): + """ + Returns the config values from the given file, allows overriding of passed in values. 
+ """ + try: + with io.open(config_file, 'r') as config: + config = yaml.safe_load(config) + + return config + except Exception as exc: # pylint: disable=broad-except + fail_func(fail_code, 'Failed to read config file {}'.format(config_file), exc) + + +def _config_with_drive_or_exit(fail_func, config_fail_code, google_fail_code, config_file, google_secrets_file): + """ + Returns the config values from the given file, allows overriding of passed in values. + """ + try: + with io.open(config_file, 'r') as config: + config = yaml.safe_load(config) + + # Check required values + for var in ('org_partner_mapping', 'drive_partners_folder'): + if var not in config or not config[var]: + fail_func(config_fail_code, 'No {} in config, or it is empty!'.format(var), ValueError()) + + # Force the partner names into NFKC here and when we get the folders to ensure + # they are using the same characters. Otherwise accented characters will not match. + for org in config['org_partner_mapping']: + partner = config['org_partner_mapping'][org] + config['org_partner_mapping'][org] = [unicodedata.normalize('NFKC', text_type(partner)) for partner in + config['org_partner_mapping'][org]] + except Exception as exc: # pylint: disable=broad-except + fail_func(config_fail_code, 'Failed to read config file {}'.format(config_file), exc) + + try: + # Just load and parse the file to make sure it's legit JSON before doing + # all of the work to get the users. + with open(google_secrets_file, 'r') as secrets_f: + json.load(secrets_f) + + config['google_secrets_file'] = google_secrets_file + return config + except Exception as exc: # pylint: disable=broad-except + fail_func(google_fail_code, 'Failed to read secrets file {}'.format(google_secrets_file), exc) + + +def _setup_lms_api_or_exit(fail_func, fail_code, config): + """ + Performs setup of EdxRestClientApi for LMS and returns the validated, sorted list of users to report on. + """ + try: + lms_base_url = config['base_urls']['lms'] + client_id = config['client_id'] + client_secret = config['client_secret'] + + config['LMS'] = LmsApi(lms_base_url, lms_base_url, client_id, client_secret) + except Exception as exc: # pylint: disable=broad-except + fail_func(fail_code, text_type(exc)) + + +def _setup_all_apis_or_exit(fail_func, fail_code, config): + """ + Performs setup of EdxRestClientApi instances for LMS, E-Commerce, Credentials, and + Demographics, as well as fetching the learner's record from LMS and validating that + it is in a state to work on. Returns the learner dict and their current stage in the + retirement flow. 
+ """ + try: + lms_base_url = config['base_urls']['lms'] + ecommerce_base_url = config['base_urls'].get('ecommerce', None) + credentials_base_url = config['base_urls'].get('credentials', None) + segment_base_url = config['base_urls'].get('segment', None) + demographics_base_url = config['base_urls'].get('demographics', None) + license_manager_base_url = config['base_urls'].get('license_manager', None) + client_id = config['client_id'] + client_secret = config['client_secret'] + braze_api_key = config.get('braze_api_key', None) + braze_instance = config.get('braze_instance', None) + amplitude_api_key = config.get('amplitude_api_key', None) + amplitude_secret_key = config.get('amplitude_secret_key', None) + salesforce_user = config.get('salesforce_user', None) + salesforce_password = config.get('salesforce_password', None) + salesforce_token = config.get('salesforce_token', None) + salesforce_domain = config.get('salesforce_domain', None) + salesforce_assignee = config.get('salesforce_assignee', None) + segment_auth_token = config.get('segment_auth_token', None) + segment_workspace_slug = config.get('segment_workspace_slug', None) + hubspot_api_key = config.get('hubspot_api_key', None) + hubspot_aws_region = config.get('hubspot_aws_region', None) + hubspot_from_address = config.get('hubspot_from_address', None) + hubspot_alert_email = config.get('hubspot_alert_email', None) + + for state in config['retirement_pipeline']: + for service, service_url in ( + ('BRAZE', braze_api_key), + ('AMPLITUDE', amplitude_api_key), + ('ECOMMERCE', ecommerce_base_url), + ('CREDENTIALS', credentials_base_url), + ('SEGMENT', segment_base_url), + ('HUBSPOT', hubspot_api_key), + ('DEMOGRAPHICS', demographics_base_url) + ): + if state[2] == service and service_url is None: + fail_func(fail_code, 'Service URL is not configured, but required for state {}'.format(state)) + + config['LMS'] = LmsApi(lms_base_url, lms_base_url, client_id, client_secret) + + if braze_api_key: + config['BRAZE'] = BrazeApi( + braze_api_key, + braze_instance, + ) + + if amplitude_api_key and amplitude_secret_key: + config['AMPLITUDE'] = AmplitudeApi( + amplitude_api_key, + amplitude_secret_key, + ) + + if salesforce_user and salesforce_password and salesforce_token: + config['SALESFORCE'] = SalesforceApi( + salesforce_user, + salesforce_password, + salesforce_token, + salesforce_domain, + salesforce_assignee + ) + + if hubspot_api_key: + config['HUBSPOT'] = HubspotAPI( + hubspot_api_key, + hubspot_aws_region, + hubspot_from_address, + hubspot_alert_email + ) + + if ecommerce_base_url: + config['ECOMMERCE'] = EcommerceApi(lms_base_url, ecommerce_base_url, client_id, client_secret) + + if credentials_base_url: + config['CREDENTIALS'] = CredentialsApi(lms_base_url, credentials_base_url, client_id, client_secret) + + if demographics_base_url: + config['DEMOGRAPHICS'] = DemographicsApi(lms_base_url, demographics_base_url, client_id, client_secret) + + if license_manager_base_url: + config['LICENSE_MANAGER'] = LicenseManagerApi( + lms_base_url, + license_manager_base_url, + client_id, + client_secret, + ) + + if segment_base_url: + config['SEGMENT'] = SegmentApi( + segment_base_url, + segment_auth_token, + segment_workspace_slug + ) + except Exception as exc: # pylint: disable=broad-except + fail_func(fail_code, 'Unexpected error occurred!', exc) diff --git a/scripts/user_retirement/utils/jenkins.py b/scripts/user_retirement/utils/jenkins.py new file mode 100644 index 000000000000..0870af9f5d5b --- /dev/null +++ 
b/scripts/user_retirement/utils/jenkins.py @@ -0,0 +1,201 @@ +""" +Methods to interact with the Jenkins API to perform various tasks. +""" + +import logging +import math +import os.path +import shutil +import sys + +import backoff +from jenkinsapi.custom_exceptions import JenkinsAPIException +from jenkinsapi.jenkins import Jenkins +from jenkinsapi.utils.crumb_requester import CrumbRequester +from requests.exceptions import HTTPError + +from scripts.user_retirement.utils.exception import BackendError + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +LOG = logging.getLogger(__name__) + + +def _recreate_directory(directory): + """ + Deletes an existing directory recursively (if exists) and (re-)creates it. + """ + if os.path.exists(directory): + shutil.rmtree(directory) + os.mkdir(directory) + + +def export_learner_job_properties(learners, directory): + """ + Creates a Jenkins properties file for each learner in order to make + a retirement slave job for each learner. + + Args: + learners (list of dicts): List of learners for which to create properties files. + directory (str): Directory in which to create the properties files. + """ + _recreate_directory(directory) + + for learner in learners: + learner_name = learner['original_username'].lower() + filename = os.path.join(directory, 'learner_retire_{}'.format(learner_name)) + with open(filename, 'w') as learner_prop_file: + learner_prop_file.write('RETIREMENT_USERNAME={}\n'.format(learner['original_username'])) + + +def _poll_giveup(data): + u""" Raise an error when the polling tries are exceeded.""" + orig_args = data.get(u'args') + # The Build object was the only parameter to the original method call, + # and so it's the first and only item in the args. + build = orig_args[0] + msg = u'Timed out waiting for build {} to finish.'.format(build.name) + raise BackendError(msg) + + +def _backoff_timeout(timeout, base=2, factor=1): + u""" + Return a tuple of (wait_gen, max_tries) so that backoff will only try up to `timeout` seconds. + + |timeout (s)|max attempts|wait durations | + |----------:|-----------:|---------------------:| + |1 |2 |1 | + |5 |4 |1, 2, 2 | + |10 |5 |1, 2, 4, 3 | + |30 |6 |1, 2, 4, 8, 13 | + |60 |8 |1, 2, 4, 8, 16, 32, 37| + |300 |10 |1, 2, 4, 8, 16, 32, 64| + | | |128, 44 | + |600 |11 |1, 2, 4, 8, 16, 32, 64| + | | |128, 256, 89 | + |3600 |13 |1, 2, 4, 8, 16, 32, 64| + | | |128, 256, 512, 1024, | + | | |1553 | + + """ + # Total duration of sum(factor * base ** n for n in range(K)) = factor*(base**K - 1)/(base - 1), + # where K is the number of retries, or max_tries - 1 (since the first try doesn't require a wait) + # + # Solving for K, K = log(timeout * (base - 1) / factor + 1, base) + # + # Using the next smallest integer K will give us a number of elements from + # the exponential sequence to take and still be less than the timeout. + tries = int(math.log(timeout * (base - 1) / factor + 1, base)) + + remainder = timeout - (factor * (base ** tries - 1)) / (base - 1) + + def expo(): + u"""Compute an exponential backoff wait period, but capped to an expected max timeout""" + # pylint: disable=invalid-name + n = 0 + while True: + a = factor * base ** n + if n >= tries: + yield remainder + else: + yield a + n += 1 + + # tries tells us the largest standard wait using the standard progression (before being capped) + # tries + 1 because backoff waits one fewer times than max_tries (the first attempt has no wait time). 
+ # If a remainder, then we need to make one last attempt to get the target timeout (so tries + 2) + if remainder == 0: + return expo, tries + 1 + else: + return expo, tries + 2 + + +def trigger_build(base_url, user_name, user_token, job_name, job_token, + job_cause=None, job_params=None, timeout=60 * 30): + u""" + Trigger a jenkins job/project (note that jenkins uses these terms interchangeably) + + Args: + base_url (str): The base URL for the jenkins server, e.g. https://test-jenkins.testeng.edx.org + user_name (str): The jenkins username + user_token (str): API token for the user. Available at {base_url}/user/{user_name}/configure + job_name (str): The Jenkins job name, e.g. test-project + job_token (str): Jobs must be configured with the option "Trigger builds remotely" selected. + Under this option, you must provide an authorization token (configured in the job) + in the form of a string so that only those who know it would be able to remotely + trigger this project's builds. + job_cause (str): Text that will be included in the recorded build cause + job_params (set of tuples): Parameter names and their values to pass to the job + timeout (int): The maximum number of seconds to wait for the jenkins build to complete (measured + from when the job is triggered). + + Returns: + The status of the build that was triggered + + Raises: + BackendError: if the Jenkins job could not be triggered successfully + """ + + @backoff.on_predicate( + backoff.constant, + interval=60, + max_tries=timeout / 60 + 1, + on_giveup=_poll_giveup, + # We aren't worried about concurrent access, so turn off jitter + jitter=None, + ) + def poll_build_for_result(build): + u""" + Poll for the build result, with a constant backoff, capped to ``timeout`` seconds. + The on_predicate decorator retries while the return value + of the target function is falsey, i.e. until the build is no longer running. + """ + return not build.is_running() + + # Create a dict with key/value pairs from the job_params + # that were passed in like this: --param FOO bar --param BAZ biz + # These will get passed to the job as string parameters like this: + # {u'FOO': u'bar', u'BAZ': u'biz'} + request_params = {} + for param in job_params: + request_params[param[0]] = param[1] + + # Contact jenkins, log in, and get the base data on the system. + try: + crumb_requester = CrumbRequester( + baseurl=base_url, username=user_name, password=user_token, + ssl_verify=True + ) + jenkins = Jenkins( + base_url, username=user_name, password=user_token, + requester=crumb_requester + ) + except (JenkinsAPIException, HTTPError) as err: + raise BackendError(str(err)) + + if not jenkins.has_job(job_name): + msg = u'Job not found: {}.'.format(job_name) + msg += u' Verify that you have permissions for the job and double check the spelling of its name.' + raise BackendError(msg) + + # This will start the job and will return a QueueItem object which can be used to get build results + job = jenkins[job_name] + queue_item = job.invoke(securitytoken=job_token, build_params=request_params, cause=job_cause) + LOG.info(u'Added item to jenkins. Server: {} Job: {} '.format( + jenkins.base_server_url(), queue_item + )) + + # Block this script until we are through the queue and the job has begun to build. + queue_item.block_until_building() + build = queue_item.get_build() + LOG.info(u'Created build {}'.format(build)) + LOG.info(u'See {}'.format(build.baseurl)) + + # Now block until you get a result back from the build.
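+ # poll_build_for_result polls once a minute until the build stops running; if the timeout is reached, _poll_giveup raises BackendError.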
+ poll_build_for_result(build) + + # Update the build's internal state, so that the final status is available + build.poll() + + status = build.get_status() + LOG.info(u'Build status: {status}'.format(status=status)) + return status diff --git a/scripts/user_retirement/utils/thirdparty_apis/__init__.py b/scripts/user_retirement/utils/thirdparty_apis/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/scripts/user_retirement/utils/thirdparty_apis/amplitude_api.py b/scripts/user_retirement/utils/thirdparty_apis/amplitude_api.py new file mode 100644 index 000000000000..ef930f602f0b --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/amplitude_api.py @@ -0,0 +1,91 @@ +""" +Amplitude API class that is used to delete a user from Amplitude. +""" +import logging +import os + +import backoff +import requests + +logger = logging.getLogger(__name__) +MAX_ATTEMPTS = int(os.environ.get("RETRY_MAX_ATTEMPTS", 5)) + + +class AmplitudeException(Exception): + """ + AmplitudeException will be raised when there is a fatal error that is not recoverable. + """ + pass + + +class AmplitudeRecoverableException(AmplitudeException): + """ + AmplitudeRecoverableException will be raised when the request can be retried. + """ + pass + + +class AmplitudeApi: + """ + Amplitude API class used to handle communication with the Amplitude APIs. + """ + + def __init__(self, amplitude_api_key, amplitude_secret_key): + self.amplitude_api_key = amplitude_api_key + self.amplitude_secret_key = amplitude_secret_key + self.base_url = "https://amplitude.com/" + self.delete_user_path = "api/2/deletions/users" + + def auth(self): + """ + Returns auth credentials for Amplitude authorization. + + Returns: + Tuple: the (api_key, secret_key) authorization tuple. + """ + return (self.amplitude_api_key, self.amplitude_secret_key) + + @backoff.on_exception( + backoff.expo, + AmplitudeRecoverableException, + max_tries=MAX_ATTEMPTS, + ) + def delete_user(self, user): + """ + This function sends an API request to delete a user from Amplitude. It then parses the response and + tries again if the error is recoverable. + + Returns: + None + + Args: + user (dict): raw data of user to delete. + + Raises: + AmplitudeException: if the error from amplitude is unrecoverable/unretryable. + AmplitudeRecoverableException: if the error from amplitude is recoverable/retryable. + """ + response = requests.post( + self.base_url + self.delete_user_path, + headers={"Content-Type": "application/json"}, + json={ + "user_ids": [user["user"]["id"]], + 'ignore_invalid_id': 'true', # When true, the job ignores users that don't exist in the project. + "requester": "user-retirement-pipeline", + }, + auth=self.auth() + ) + + if response.status_code == 200: + logger.info("Amplitude user deletion succeeded") + return + + # We have some sort of error. Parse it, log it, and retry as needed. + error_msg = "Amplitude user deletion failed due to {reason}".format(reason=response.reason) + logger.error(error_msg) + # Status 429 is returned when there are too many requests and can be resolved by retrying the + # request.
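+ # Server-side errors (5xx) are likewise assumed to be transient and are retried; any other status is treated as fatal.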
+ if response.status_code == 429 or 500 <= response.status_code < 600: + raise AmplitudeRecoverableException(error_msg) + else: + raise AmplitudeException(error_msg) diff --git a/scripts/user_retirement/utils/thirdparty_apis/braze_api.py b/scripts/user_retirement/utils/thirdparty_apis/braze_api.py new file mode 100644 index 000000000000..247ceccc6f17 --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/braze_api.py @@ -0,0 +1,85 @@ +""" +Helper API classes for calling Braze APIs. +""" +import logging +import os + +import backoff +import requests + +LOG = logging.getLogger(__name__) +MAX_ATTEMPTS = int(os.environ.get('RETRY_BRAZE_MAX_ATTEMPTS', 5)) + + +class BrazeException(Exception): + pass + + +class BrazeRecoverableException(BrazeException): + pass + + +class BrazeApi: + """ + Braze API client used to make calls to Braze + """ + + def __init__(self, braze_api_key, braze_instance): + self.api_key = braze_api_key + + # https://www.braze.com/docs/api/basics/#endpoints + self.base_url = 'https://rest.{instance}.braze.com'.format(instance=braze_instance) + + def auth_headers(self): + """Returns authorization headers suitable for passing to the requests library""" + return { + 'Authorization': 'Bearer ' + self.api_key, + } + + @staticmethod + def get_error_message(response): + """Returns a string suitable for logging""" + try: + json = response.json() + except ValueError: + json = {} + + # https://www.braze.com/docs/api/errors + message = json.get('message') + + return message or response.reason + + def process_response(self, response, action): + """Log response status and raise an error as needed""" + if response.ok: + LOG.info('Braze {action} succeeded'.format(action=action)) + return + + # We have some sort of error. Parse it, log it, and retry as needed. + error_msg = 'Braze {action} failed due to {msg}'.format(action=action, msg=self.get_error_message(response)) + LOG.error(error_msg) + + if response.status_code == 429 or 500 <= response.status_code < 600: + raise BrazeRecoverableException(error_msg) + else: + raise BrazeException(error_msg) + + @backoff.on_exception( + backoff.expo, + BrazeRecoverableException, + max_tries=MAX_ATTEMPTS, + ) + def delete_user(self, learner): + """ + Delete a learner from Braze. + """ + # https://www.braze.com/docs/help/gdpr_compliance/#the-right-to-erasure + # https://www.braze.com/docs/api/endpoints/user_data/post_user_delete + response = requests.post( + self.base_url + '/users/delete', + headers=self.auth_headers(), + json={ + 'external_ids': [learner['user']['id']], # Braze external ids are LMS user ids + }, + ) + self.process_response(response, 'user deletion') diff --git a/scripts/user_retirement/utils/thirdparty_apis/google_api.py b/scripts/user_retirement/utils/thirdparty_apis/google_api.py new file mode 100644 index 000000000000..e43670c01caf --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/google_api.py @@ -0,0 +1,530 @@ +""" +Helper API classes for calling google APIs. + +DriveApi is for managing files in google drive. +""" +# NOTE: Make sure that all non-ascii text written to standard output (including +# print statements and logging) is manually encoded to bytes using a utf-8 or +# other encoding. We currently make use of this library within a context that +# does NOT tolerate unicode text on sys.stdout, namely python 2 on Build +# Jenkins. PLAT-2287 tracks this Tech Debt.
+ +import json +import logging +from itertools import count + +import backoff +from dateutil.parser import parse +from google.oauth2 import service_account +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError +# I'm not super happy about this since the function is protected with a leading +# underscore, but the next best thing is literally copying this ~40 line +# function verbatim. +from googleapiclient.http import MediaIoBaseUpload, _should_retry_response +from six import iteritems, text_type + +from scripts.user_retirement.utils.utils import batch + +LOG = logging.getLogger(__name__) + +# The maximum number of requests per batch is 100, according to the google API docs. +# However, cap our number lower than that maximum to avoid throttling errors and backoff. +GOOGLE_API_MAX_BATCH_SIZE = 10 + +# Mimetype used for Google Drive folders. +FOLDER_MIMETYPE = 'application/vnd.google-apps.folder' + +# Fields to be extracted from OAuth2 JSON token files +OAUTH2_TOKEN_FIELDS = [ + 'client_id', 'client_secret', 'refresh_token', + 'token_uri', 'id_token', 'scopes', 'access_token' +] + + +class BatchRequestError(Exception): + """ + Exception which indicates one or more failed requests inside of a batch request. + """ + + +class TriggerRetryException(Exception): + """ + Exception which indicates one or more throttled requests inside of a batch request. + """ + + +class BaseApiClient: + """ + Base API client for google services. + + To add a new service, extend this class and override these class variables: + + _api_name (e.g. "drive") + _api_version (e.g. "v3") + _api_scopes + """ + _api_name = None + _api_version = None + _api_scopes = None + + def __init__(self, client_secrets_file_path, **kwargs): + self._build_client(client_secrets_file_path, **kwargs) + + def _build_client(self, client_secrets_file_path, **kwargs): + """ + Build the google API client, specific to a single google service. + """ + # as_user_account is an indicator that the authentication + # is using a user account. + # If not true, assume a service account. Otherwise, read in the JSON + # file, set the scope, and use the info to instantiate Credentials. + # For more information about user account authentication, go to + # https://google-auth.readthedocs.io/en/master/user-guide.html#user-credentials + as_user_account = kwargs.pop('as_user_account', False) + if not as_user_account: + credentials = service_account.Credentials.from_service_account_file( + client_secrets_file_path, scopes=self._api_scopes + ) + else: + with open(client_secrets_file_path) as fh: + token_info = json.load(fh) + token_info = {k: token_info.get(k) for k in OAUTH2_TOKEN_FIELDS} + # Take the access_token field and change it to token + token = token_info.pop('access_token', None) + token_info['token'] = token + # Set the scopes + token_info['scopes'] = self._api_scopes + credentials = Credentials(**token_info) + self._client = build(self._api_name, self._api_version, credentials=credentials, **kwargs) + LOG.info("Client built.") + + def _batch_with_retry(self, requests): + """ + Send the given Google API requests in a single batch requests, and retry only requests that are throttled. + + Args: + requests (list of googleapiclient.http.HttpRequest): The requests to send. + + Returns: + dict mapping of request object to response + """ + + # Mapping of request object to the corresponding response. + responses = {} + + # This is our working "request queue". 
Initially, populate the request queue with all the given requests. + try_requests = [] + try_requests.extend(requests) + + # This is the queue of requests that are to be retried, populated by the batch callback function. + retry_requests = [] + + # Generate arbitrary (but unique in this batch request) IDs for each request, so that we can recall the + # corresponding response within a batch response. + request_object_to_request_id = dict(zip( + requests, + (text_type(n) for n in count()), + )) + # Create a flipped mapping for convenience. + request_id_to_request_object = {v: k for k, v in iteritems(request_object_to_request_id)} + + def batch_callback(request_id, response, exception): # pylint: disable=unused-argument,missing-docstring + """ + Handle individual responses in the batch request. + """ + request_object = request_id_to_request_object[request_id] + if exception: + if _should_retry_google_api(exception): + LOG.error(u'Request throttled, adding to the retry queue: {}'.format(exception).encode('utf-8')) + retry_requests.append(request_object) + else: + # In this case, probably nothing can be done, so we just give up on this particular request and + # do not include it in the responses dict. + LOG.error(u'Error processing request {}'.format(request_object).encode('utf-8')) + LOG.error(text_type(exception).encode('utf-8')) + else: + responses[request_object] = response + LOG.info(u'Successfully processed request {}.'.format(request_object).encode('utf-8')) + + # Retry on API throttling at the HTTP request level. + @backoff.on_exception( + backoff.expo, + HttpError, + max_time=600, # 10 minutes + giveup=lambda e: not _should_retry_google_api(e), + on_backoff=lambda details: _backoff_handler(details), # pylint: disable=unnecessary-lambda + ) + # Retry on API throttling at the BATCH ITEM request level. + @backoff.on_exception( + backoff.expo, + TriggerRetryException, + max_time=600, # 10 minutes + on_backoff=lambda details: _backoff_handler(details), # pylint: disable=unnecessary-lambda + ) + def func(): + """ + Core function which constitutes the retry loop. It has no inputs or outputs, only side-effects which + populates the `responses` variable within the scope of _batch_with_retry(). + """ + # Construct a new batch request object containing the current iteration of requests to "try". + batch_request = self._client.new_batch_http_request(callback=batch_callback) # pylint: disable=no-member + for request_object in try_requests: + batch_request.add( + request_object, + request_id=request_object_to_request_id[request_object] + ) + + # Empty the retry queue in preparation of filling it back up with requests that need to be retried. + del retry_requests[:] + + # Send the batch request. If the API responds with HTTP 403 or some other retryable error, we should + # immediately retry this function func() with the same requests in the try_requests queue. If the response + # is HTTP 200, we *still* may raise TriggerRetryException and retry a subset of requests if some, but not + # all requests need to be retried. + batch_request.execute() + + # If the API throttled some requests, batch_callback would have populated the retry queue. Reset the + # try_requests queue and indicate to backoff that there are requests to retry. + if retry_requests: + del try_requests[:] + try_requests.extend(retry_requests) + raise TriggerRetryException() + + # func()'s side-effect is that it indirectly calls batch_callback which populates the responses dict. 
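Before following the call to func() below, it may help to see the two-level retry scaffolding in isolation. This is a self-contained sketch of the same pattern, not the module's code; send_items and ConnectionError stand in for the real transport call and googleapiclient's HttpError:

.. code-block:: python

    import backoff


    class TriggerRetryException(Exception):
        """Raised when only some items in a batch need another attempt."""


    def send_items(items):
        # Stand-in transport: return the subset of items that were throttled.
        return []


    # Outer decorator: retry when the whole HTTP call fails.
    # Inner decorator: retry when the callback queued throttled items.
    @backoff.on_exception(backoff.expo, ConnectionError, max_time=600)
    @backoff.on_exception(backoff.expo, TriggerRetryException, max_time=600)
    def send_batch(pending):
        throttled = send_items(pending)
        if throttled:
            pending[:] = throttled  # retry only the throttled subset
            raise TriggerRetryException()


    send_batch(['req-1', 'req-2'])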
+ func() + return responses + + +def _backoff_handler(details): + """ + Simple logging handler for when timeout backoff occurs. + """ + LOG.info('Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'.format(**details)) + + +def _should_retry_google_api(exc): + """ + General logic for determining if a google API response is retryable. + + Args: + exc (googleapiclient.errors.HttpError): The exception thrown by googleapiclient. + + Returns: + bool: True if the caller should retry the API call. + """ + retry = False + if hasattr(exc, 'resp') and exc.resp: # bizarre and disappointing that sometimes `resp` doesn't exist. + retry = _should_retry_response(exc.resp.status, exc.content) + return retry + + +class DriveApi(BaseApiClient): + """ + Google Drive API client. + """ + _api_name = 'drive' + _api_version = 'v3' + _api_scopes = [ + # basic file read-write functionality. + # 'https://www.googleapis.com/auth/drive.file', + # Full read write functionality + 'https://www.googleapis.com/auth/drive', + # additional scope for being able to see folders not owned by this account. + 'https://www.googleapis.com/auth/drive.metadata', + ] + + @backoff.on_exception( + backoff.expo, + HttpError, + max_time=600, # 10 minutes + giveup=lambda e: not _should_retry_google_api(e), + on_backoff=lambda details: _backoff_handler(details), # pylint: disable=unnecessary-lambda + ) + def create_file_in_folder(self, folder_id, filename, file_stream, mimetype): + """ + Creates a new file in the specified folder. + + Args: + folder_id (str): google resource ID for the drive folder to put the file into. + filename (str): name of the uploaded file. + file_stream (file-like/stream): contents of the file to upload. + mimetype (str): mimetype of the given file. + + Returns: file ID (str). + + Throws: + googleapiclient.errors.HttpError: + For some non-retryable 4xx or 5xx error. See the full list here: + https://developers.google.com/drive/api/v3/handle-errors + """ + file_metadata = { + 'name': filename, + 'parents': [folder_id], + } + media = MediaIoBaseUpload(file_stream, mimetype=mimetype) + uploaded_file = self._client.files().create( # pylint: disable=no-member + body=file_metadata, + media_body=media, + fields='id' + ).execute() + LOG.info(u'File uploaded: ID="{}", name="{}"'.format(uploaded_file.get('id'), filename).encode('utf-8')) + return uploaded_file.get('id') + + # NOTE: Do not decorate this function with backoff since it already calls retryable methods. + def delete_files(self, file_ids): + """ + Delete multiple files forever, bypassing the "trash". + + This function takes advantage of request batching to reduce request volume. + + Args: + file_ids (list of str): list of IDs for files to delete. + + Returns: nothing + + Throws: + BatchRequestError: + One or more files could not be deleted (could even mean the file does not exist). + """ + if len(set(file_ids)) != len(file_ids): + raise ValueError('duplicates detected in the file_ids list.') + + # mapping of request object to the new comment resource returned in the response. + responses = {} + + # process the list of file ids in batches of size GOOGLE_API_MAX_BATCH_SIZE. 
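The batching below leans on the small batch() helper from scripts/user_retirement/utils/utils.py (included at the end of this diff). For a batch size of 2 it behaves like this (the file IDs are placeholders):

.. code-block:: python

    from scripts.user_retirement.utils.utils import batch

    file_ids = ['id-1', 'id-2', 'id-3', 'id-4', 'id-5']  # hypothetical Drive file IDs
    print(list(batch(file_ids, batch_size=2)))
    # [['id-1', 'id-2'], ['id-3', 'id-4'], ['id-5']]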
+ for file_ids_batch in batch(file_ids, batch_size=GOOGLE_API_MAX_BATCH_SIZE): + request_objects = [] + for file_id in file_ids_batch: + request_objects.append(self._client.files().delete(fileId=file_id)) # pylint: disable=no-member + + # this generic helper function will handle the retry logic + responses_batch = self._batch_with_retry(request_objects) + + responses.update(responses_batch) + + if len(responses) != len(file_ids): + raise BatchRequestError('Error deleting one or more files/folders.') + + def delete_files_older_than(self, top_level, delete_before_dt, mimetype=None, prefix=None): + """ + Delete all files beneath a given top level folder that are older than a certain datetime. + Optionally, specify a file mimetype and a filename prefix. + + Args: + top_level (str): ID of top level folder. + delete_before_dt (datetime.datetime): Datetime to use for file age. All files created before this datetime + will be permanently deleted. Should be timezone offset-aware. + mimetype (str): Mimetype of files to delete. If not specified, all non-folders will be found. + prefix (str): Filename prefix - only files started with this prefix will be deleted. + """ + LOG.info("Walking files...") + all_files = self.walk_files( + top_level, 'id, name, createdTime', mimetype + ) + LOG.info("Files walked. {} files found before filtering.".format(len(all_files))) + file_ids_to_delete = [] + for file in all_files: + if (not prefix or file['name'].startswith(prefix)) and parse(file['createdTime']) < delete_before_dt: + file_ids_to_delete.append(file['id']) + if file_ids_to_delete: + LOG.info("{} files remaining after filtering.".format(len(file_ids_to_delete))) + self.delete_files(file_ids_to_delete) + + @backoff.on_exception( + backoff.expo, + HttpError, + max_time=600, # 10 minutes + giveup=lambda e: not _should_retry_google_api(e), + on_backoff=lambda details: _backoff_handler(details), # pylint: disable=unnecessary-lambda + ) + def walk_files(self, top_folder_id, file_fields='id, name', mimetype=None, recurse=True): + """ + List all files of a particular mimetype within a given top level folder, traversing all folders recursively. + + This function may make multiple HTTP requests depending on how many pages the response contains. The default + page size for the python google API client is 100 items. + + Args: + top_folder_id (str): ID of top level folder. + file_fields (str): Comma-separated list of metadata fields to return for each folder/file. + For a full list of file metadata fields, see https://developers.google.com/drive/api/v3/reference/files + mimetype (str): Mimetype of files to find. If not specified, all items will be returned, including folders. + recurse (bool): True to recurse into all found folders for items, False to only return top-level items. + + Returns: List of dicts, where each dict contains file metadata and each dict key corresponds to fields + specified in the `file_fields` arg. + + Throws: + googleapiclient.errors.HttpError: + For some non-retryable 4xx or 5xx error. See the full list here: + https://developers.google.com/drive/api/v3/handle-errors + """ + # Sent to list() call and used only for sending the pageToken. + extra_kwargs = {} + # Cumulative list of file metadata dicts for found files. + results = [] + # List of IDs of all visited folders. + visited_folders = [] + # List of IDs of all found files. + found_ids = [] + # List of folder IDs remaining to be listed. + folders_to_visit = [top_folder_id] + # Mimetype part of file-listing query. 
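For reference, this is the q expression the next few lines assemble, assuming a CSV mimetype and a folder ID of 'abc123' (both illustrative):

.. code-block:: python

    FOLDER_MIMETYPE = 'application/vnd.google-apps.folder'

    mimetype = 'text/csv'      # hypothetical target mimetype
    current_folder = 'abc123'  # hypothetical folder ID

    mimetype_clause = "( mimeType = '{}' or mimeType = '{}') and ".format(
        FOLDER_MIMETYPE, mimetype
    )
    print("{}'{}' in parents".format(mimetype_clause, current_folder))
    # ( mimeType = 'application/vnd.google-apps.folder' or mimeType = 'text/csv') and 'abc123' in parents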
+ mimetype_clause = "" + if mimetype: + # Return both folders and the specified mimetype. + mimetype_clause = "( mimeType = '{}' or mimeType = '{}') and ".format(FOLDER_MIMETYPE, mimetype) + + while folders_to_visit: + current_folder = folders_to_visit.pop() + LOG.info("Current folder: {}".format(current_folder)) + visited_folders.append(current_folder) + extra_kwargs = {} + + while True: + resp = self._client.files().list( # pylint: disable=no-member + q="{}'{}' in parents".format(mimetype_clause, current_folder), + fields='nextPageToken, files({})'.format( + file_fields + ', mimeType, parents' + ), + **extra_kwargs + ).execute() + page_results = resp.get('files', []) + + LOG.info("walk_files: Returned %s results.", len(page_results)) + + # Examine returned results to separate folders from non-folders. + for result in page_results: + LOG.info(u"walk_files: Result: {}".format(result).encode('utf-8')) + # Folders contain files - and get special treatment. + if result['mimeType'] == FOLDER_MIMETYPE: + if recurse and result['id'] not in visited_folders: + # Add any undiscovered folders to the list of folders to check. + folders_to_visit.append(result['id']) + # Determine if this result is a file to return. + if result['id'] not in found_ids and (not mimetype or result['mimeType'] == mimetype): + found_ids.append(result['id']) + # Return only the fields specified in file_fields. + results.append({k.strip(): result.get(k.strip(), None) for k in file_fields.split(',')}) + + LOG.info("walk_files: %s files found and %s folders to check.", len(results), len(folders_to_visit)) + + if page_results and 'nextPageToken' in resp and resp['nextPageToken']: + # Only call for more result pages if results were actually returned -and + # a nextPageToken is returned. + extra_kwargs['pageToken'] = resp['nextPageToken'] + else: + break + return results + + # NOTE: Do not decorate this function with backoff since it already calls retryable methods. + def create_comments_for_files(self, file_ids_and_content, fields='id'): + """ + Create comments for files. + + This function is NOT idempotent. It will blindly create the comments it was asked to create, regardless of the + existence of other identical comments. + + Args: + file_ids_and_content (list of tuple(str, str)): list of (file_id, content) tuples. + fields (str): comma separated list of fields to describe each comment resource in the response. + + Returns: dict mapping of file_id to comment resource (dict). The contents of the comment resources are dictated + by the `fields` arg. + + Throws: + googleapiclient.errors.HttpError: + For some non-retryable 4xx or 5xx error. See the full list here: + https://developers.google.com/drive/api/v3/handle-errors + BatchRequestError: + One or more files resulted in an error when adding comments. + """ + file_ids, _ = zip(*file_ids_and_content) + if len(set(file_ids)) != len(file_ids): + raise ValueError('Duplicates detected in the file_ids_and_content list.') + + # Mapping of file_id to the new comment resource returned in the response. + responses = {} + + # Process the list of file IDs in batches of size GOOGLE_API_MAX_BATCH_SIZE. 
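Before the batching loop below, a hedged usage sketch of this method's contract; drive is assumed to be an authenticated DriveApi instance and the file IDs are placeholders:

.. code-block:: python

    pairs = [
        ('file-id-1', 'Retirement report delivered.'),  # hypothetical IDs
        ('file-id-2', 'Retirement report delivered.'),
    ]
    responses = drive.create_comments_for_files(pairs)  # drive: DriveApi instance
    for file_id, comment in responses.items():
        print(file_id, comment['id'])                   # default fields='id'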
+ for file_ids_and_content_batch in batch(file_ids_and_content, batch_size=GOOGLE_API_MAX_BATCH_SIZE): + request_objects_to_file_id = {} + for file_id, content in file_ids_and_content_batch: + request_object = self._client.comments().create( # pylint: disable=no-member + fileId=file_id, + body={u'content': content}, + fields=fields + ) + request_objects_to_file_id[request_object] = file_id + + # This generic helper function will handle the retry logic + responses_batch = self._batch_with_retry(request_objects_to_file_id.keys()) + + # Transform the mapping FROM request objects -> comment resource TO file IDs -> comment resources. + responses_batch = { + request_objects_to_file_id[request_object]: resp + for request_object, resp in responses_batch.items() + } + responses.update(responses_batch) + + if len(responses) != len(file_ids_and_content): + raise BatchRequestError('Error creating comments for one or more files/folders.') + + return responses + + # NOTE: Do not decorate this function with backoff since it already calls retryable methods. + def list_permissions_for_files(self, file_ids, fields='emailAddress, role'): + """ + List permissions for files. + + Args: + file_ids (list of str): list of Drive file IDs for which to list permissions. + fields (str): comma separated list of fields to describe each permissions resource in the response. + + Returns: dict mapping of file_id to permission resource list (list of dict). The contents of the permission + resources are dictated by the `fields` arg. + + Throws: + googleapiclient.errors.HttpError: + For some non-retryable 4xx or 5xx error. See the full list here: + https://developers.google.com/drive/api/v3/handle-errors + BatchRequestError: + One or more files resulted in an error when having their permissions listed. + """ + + if len(set(file_ids)) != len(file_ids): + raise ValueError('duplicates detected in the file_ids list.') + + # mapping of file_id to the new comment resource returned in the response. + responses = {} + + # process the list of file ids in batches of size GOOGLE_API_MAX_BATCH_SIZE. + for file_ids_batch in batch(file_ids, batch_size=GOOGLE_API_MAX_BATCH_SIZE): + request_objects_to_file_id = {} + for file_id in file_ids_batch: + request_object = self._client.permissions().list( # pylint: disable=no-member + fileId=file_id, + fields='permissions({})'.format(fields) + ) + request_objects_to_file_id[request_object] = file_id + + # this generic helper function will handle the retry logic + responses_batch = self._batch_with_retry(request_objects_to_file_id.keys()) + + # transform the mapping from request objects -> response dicts to file ids -> permissions resource lists. + responses_batch_transformed = {} + for request_object, resp in responses_batch.items(): + permissions = None + if resp and 'permissions' in resp: + permissions = resp['permissions'] + responses_batch_transformed[request_objects_to_file_id[request_object]] = permissions + + responses.update(responses_batch_transformed) + + if len(responses) != len(file_ids): + raise BatchRequestError('Error listing permissions for one or more files/folders.') + + return responses diff --git a/scripts/user_retirement/utils/thirdparty_apis/hubspot_api.py b/scripts/user_retirement/utils/thirdparty_apis/hubspot_api.py new file mode 100644 index 000000000000..e357fd660264 --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/hubspot_api.py @@ -0,0 +1,123 @@ +""" +Helper API classes for calling Hubspot APIs. 
+""" +import logging +import os + +import backoff +import requests + +from scripts.user_retirement.utils.email_utils import send_email + +LOG = logging.getLogger(__name__) +MAX_ATTEMPTS = int(os.environ.get('RETRY_HUBSPOT_MAX_ATTEMPTS', 5)) + +GET_VID_FROM_EMAIL_URL_TEMPLATE = "https://api.hubapi.com/contacts/v1/contact/email/{email}/profile" +DELETE_USER_FROM_VID_TEMPLATE = "https://api.hubapi.com/contacts/v1/contact/vid/{vid}" + + +class HubspotException(Exception): + pass + + +class HubspotAPI: + """ + Hubspot API client used to make calls to Hubspot + """ + + def __init__( + self, + hubspot_api_key, + aws_region, + from_address, + alert_email + ): + self.api_key = hubspot_api_key + self.aws_region = aws_region + self.from_address = from_address + self.alert_email = alert_email + + @backoff.on_exception( + backoff.expo, + HubspotException, + max_tries=MAX_ATTEMPTS + ) + def delete_user(self, learner): + """ + Delete a learner from hubspot using their email address. + """ + email = learner.get('original_email', None) + if not email: + raise TypeError('Expected an email address for user to delete, but received None.') + + user_vid = self.get_user_vid(email) + if user_vid: + self.delete_user_by_vid(user_vid) + + def delete_user_by_vid(self, vid): + """ + Delete a learner from hubspot using their Hubspot `vid` (unique identifier) + """ + headers = { + 'content-type': 'application/json', + 'authorization': f'Bearer {self.api_key}' + } + + req = requests.delete(DELETE_USER_FROM_VID_TEMPLATE.format( + vid=vid + ), headers=headers) + error_msg = "" + if req.status_code == 200: + LOG.info("User successfully deleted from Hubspot") + self.send_marketing_alert(vid) + elif req.status_code == 401: + error_msg = "Hubspot user deletion failed due to authorized API call" + elif req.status_code == 404: + error_msg = "Hubspot user deletion failed because vid doesn't match user" + elif req.status_code == 500: + error_msg = "Hubspot user deletion failed due to server-side (Hubspot) issues" + else: + error_msg = "Hubspot user deletion failed due to unknown reasons" + + if error_msg: + LOG.error(error_msg) + raise HubspotException(error_msg) + + def get_user_vid(self, email): + """ + Get a user's `vid` from Hubspot. `vid` is the terminology that hubspot uses for a user ids + """ + headers = { + 'content-type': 'application/json', + 'authorization': f'Bearer {self.api_key}' + } + + req = requests.get(GET_VID_FROM_EMAIL_URL_TEMPLATE.format( + email=email + ), headers=headers) + if req.status_code == 200: + req_data = req.json() + return req_data.get('vid') + elif req.status_code == 404: + LOG.info("No action taken because no user was found in Hubspot.") + return + else: + error_msg = "Error attempted to get user_vid from Hubspot. Error: {}".format( + req.text + ) + LOG.error(error_msg) + raise HubspotException(error_msg) + + def send_marketing_alert(self, vid): + """ + Notify marketing with user's Hubspot `vid` upon successful deletion. 
+ """ + subject = "Alert: Hubspot Deletion" + body = "Learner with the VID \"{}\" has been deleted from Hubspot.".format(vid) + send_email( + self.aws_region, + self.from_address, + [self.alert_email], + subject, + body + ) diff --git a/scripts/user_retirement/utils/thirdparty_apis/salesforce_api.py b/scripts/user_retirement/utils/thirdparty_apis/salesforce_api.py new file mode 100644 index 000000000000..354e723739c8 --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/salesforce_api.py @@ -0,0 +1,137 @@ +""" +Salesforce API class that will call the Salesforce REST API using simple-salesforce. +""" +import logging +import os + +import backoff +from requests.exceptions import ConnectionError as RequestsConnectionError +from simple_salesforce import Salesforce, format_soql + +LOG = logging.getLogger(__name__) + +MAX_ATTEMPTS = int(os.environ.get('RETRY_SALESFORCE_MAX_ATTEMPTS', 5)) +RETIREMENT_TASK_DESCRIPTION = ( + "A user data retirement request has been made for " + "{email} who has been identified as a lead in Salesforce. " + "Please manually retire the user data for this lead." +) + + +class SalesforceApi: + """ + Class for making Salesforce API calls + """ + + def __init__(self, username, password, security_token, domain, assignee_username): + """ + Create API with credentials + """ + self._sf = self._get_salesforce_client( + username=username, + password=password, + security_token=security_token, + domain=domain + ) + self.assignee_id = self.get_user_id(assignee_username) + if not self.assignee_id: + raise Exception("Could not find Salesforce user with username " + assignee_username) + + @backoff.on_exception( + backoff.expo, + RequestsConnectionError, + max_tries=MAX_ATTEMPTS + ) + def _get_salesforce_client(self, username, password, security_token, domain): + """ + Returns a constructed Salesforce client and retries upon failure. 
+ """ + return Salesforce( + username=username, + password=password, + security_token=security_token, + domain=domain + ) + + @backoff.on_exception( + backoff.expo, + RequestsConnectionError, + max_tries=MAX_ATTEMPTS + ) + def get_lead_ids_by_email(self, email): + """ + Given an id, query for a Lead with that email + Returns a list of ids tht have that email or None if none are found + """ + id_query = self._sf.query(format_soql("SELECT Id FROM Lead WHERE Email = {email}", email=email)) + total_size = int(id_query['totalSize']) + if total_size == 0: + return None + else: + ids = [record['Id'] for record in id_query['records']] + if len(ids) > 1: + LOG.warning("Multiple Ids returned for Lead with email {}".format(email)) + return ids + + @backoff.on_exception( + backoff.expo, + RequestsConnectionError, + max_tries=MAX_ATTEMPTS + ) + def get_user_id(self, username): + """ + Given a username, returns the user id for the User with that username + or None if no user is found + Used to get a the user id of the user we will assign the retirement task to + """ + id_query = self._sf.query(format_soql("SELECT Id FROM User WHERE Username = {username}", username=username)) + total_size = int(id_query['totalSize']) + if total_size == 0: + return None + else: + return id_query['records'][0]['Id'] + + @backoff.on_exception( + backoff.expo, + RequestsConnectionError, + max_tries=MAX_ATTEMPTS + ) + def _create_retirement_task(self, email, lead_ids): + """ + Creates a Salesforce Task instructing a user to manually retire the + given lead + """ + task_params = { + 'Description': RETIREMENT_TASK_DESCRIPTION.format(email=email), + 'Subject': "GDPR Request: " + email, + 'WhoId': lead_ids[0], + 'OwnerId': self.assignee_id + } + if len(lead_ids) > 1: + note = "\nNotice: Multiple leads were identified with the same email. Please retire all following leads:" + for lead_id in lead_ids: + note += "\n{}".format(lead_id) + task_params['Description'] += note + created_task = self._sf.Task.create(task_params) + if created_task['success']: + LOG.info("Successfully salesforce task created task %s", created_task['id']) + else: + LOG.error("Errors while creating task:") + for error in created_task['errors']: + LOG.error(error) + raise Exception("Unable to create retirement task for email " + email) + + def retire_learner(self, learner): + """ + Given a learner email, check if that learner exists as a lead in Salesforce + If they do, create a Salesforce Task instructing someone to manually retire + their personal information + """ + email = learner.get('original_email', None) + if not email: + raise TypeError('Expected an email address for user to delete, but received None.') + lead_ids = self.get_lead_ids_by_email(email) + if lead_ids is None: + LOG.info("No action taken because no lead was found in Salesforce.") + return + self._create_retirement_task(email, lead_ids) diff --git a/scripts/user_retirement/utils/thirdparty_apis/segment_api.py b/scripts/user_retirement/utils/thirdparty_apis/segment_api.py new file mode 100644 index 000000000000..09df5d069392 --- /dev/null +++ b/scripts/user_retirement/utils/thirdparty_apis/segment_api.py @@ -0,0 +1,283 @@ +""" +Segment API call wrappers +""" +import logging +import sys +import traceback + +import backoff +import requests +from simplejson.errors import JSONDecodeError +from six import text_type + +# Maximum number of tries on Segment API calls +MAX_TRIES = 4 + +# These are the required/optional keys in the learner dict that contain IDs we need to retire from Segment. 
+REQUIRED_IDENTIFYING_KEYS = [('user', 'id'), 'original_username']
+OPTIONAL_IDENTIFYING_KEYS = ['ecommerce_segment_id']
+
+# The Segment Config API for bulk deleting users for a particular workspace
+BULK_REGULATE_URL = 'v1beta/workspaces/{}/regulations'
+
+# The Segment Config API for querying the status of a bulk user deletion request for a particular workspace
+BULK_REGULATE_STATUS_URL = 'v1beta/workspaces/{}/regulations/{}'
+
+# According to Segment this represents the maximum limit of the bulk regulation call.
+# https://reference.segmentapis.com/?version=latest#57a69434-76cc-43cc-a547-98c319182247
+MAXIMUM_USERS_IN_REGULATION_REQUEST = 5000
+
+LOG = logging.getLogger(__name__)
+
+
+def _backoff_handler(details):
+    """
+    Simple logging handler for when timeout backoff occurs.
+    """
+    LOG.error('Trying again in {wait:0.1f} seconds after {tries} tries calling {target}'.format(**details))
+
+    # Log the text response from any HTTPErrors, if possible
+    try:
+        LOG.error(traceback.format_exc())
+        exc = sys.exc_info()[1]
+        LOG.error("HTTPError code {}: {}".format(exc.response.status_code, exc.response.text))
+    except Exception:  # pylint: disable=broad-except
+        pass
+
+
+def _wait_30_seconds():
+    """
+    Backoff generator that waits for 30 seconds.
+    """
+    return backoff.constant(interval=30)
+
+
+def _http_status_giveup(exc):
+    """
+    Giveup method: stop retrying unless the response is a 429 or a 5xx server error.
+    """
+    return exc.response.status_code != 429 and not 500 <= exc.response.status_code < 600
+
+
+def _retry_segment_api():
+    """
+    Decorator which enables retries with sane backoff defaults.
+    """
+
+    def inner(func):  # pylint: disable=missing-docstring
+        func_with_decode_backoff = backoff.on_exception(
+            backoff.expo,
+            JSONDecodeError,
+            max_tries=MAX_TRIES,
+            on_backoff=lambda details: _backoff_handler(details)  # pylint: disable=unnecessary-lambda
+        )
+        func_with_backoff = backoff.on_exception(
+            backoff.expo,
+            requests.exceptions.HTTPError,
+            max_tries=MAX_TRIES,
+            giveup=_http_status_giveup,
+            on_backoff=lambda details: _backoff_handler(details)  # pylint: disable=unnecessary-lambda
+        )
+        func_with_timeout_backoff = backoff.on_exception(
+            _wait_30_seconds,
+            requests.exceptions.Timeout,
+            max_tries=MAX_TRIES,
+            on_backoff=lambda details: _backoff_handler(details)  # pylint: disable=unnecessary-lambda
+        )
+        return func_with_decode_backoff(func_with_backoff(func_with_timeout_backoff(func)))
+
+    return inner
+
+
+class SegmentApi:
+    """
+    Segment API client with convenience methods
+    """
+
+    def __init__(self, base_url, auth_token, workspace_slug):
+        self.base_url = base_url
+        self.auth_token = auth_token
+        self.workspace_slug = workspace_slug
+
+    @_retry_segment_api()
+    def _call_segment_post(self, url, params):
+        """
+        Actually makes the Segment REST POST call.
+
+        429s, 5xx errors, timeouts, and JSON decode errors will be retried
+        via _retry_segment_api; all others will bubble up.
+        """
+        headers = {
+            "Authorization": "Bearer {}".format(self.auth_token),
+            "Content-Type": "application/json"
+        }
+        resp = requests.post(self.base_url + url, json=params, headers=headers)
+        resp.raise_for_status()
+        return resp
+
+    @_retry_segment_api()
+    def _call_segment_get(self, url):
+        """
+        Actually makes the Segment REST GET call.
+
+        429s, 5xx errors, timeouts, and JSON decode errors will be retried
+        via _retry_segment_api; all others will bubble up.
+ """ + headers = { + "Authorization": "Bearer {}".format(self.auth_token) + } + resp = requests.get(self.base_url + url, headers=headers) + resp.raise_for_status() + return resp + + def _get_value_from_learner(self, learner, key): + """ + Return the value from a learner dict for the given key or 2-tuple of keys. + + Allows us to map things like learner['user']['id'] in a single entry in REQUIRED_IDENTIFYING_KEYS. + """ + if isinstance(key, tuple): + val = learner[key[0]][key[1]] + else: + val = learner[key] + + return text_type(val) + + def _send_regulation_request(self, params): + """ + Make the call to the Segment Regulate API, cleanly report any errors + """ + resp_json = "" + + try: + resp = self._call_segment_post(BULK_REGULATE_URL.format(self.workspace_slug), params) + try: + resp_json = resp.json() + bulk_user_delete_id = resp_json['regulate_id'] + LOG.info('Bulk user regulation queued. Id: {}'.format(bulk_user_delete_id)) + except JSONDecodeError: + resp_json = resp.text + raise + + # If we get here we got some kind of JSON response from Segment, we'll try to get + # the data we need. If it doesn't exist we'll bubble up the error from Segment and + # eat the TypeError / KeyError since they won't be relevant. + except (TypeError, KeyError, requests.exceptions.HTTPError, JSONDecodeError) as exc: + LOG.exception(exc) + err = u'Error was encountered for params: {} \n\n Response: {}'.format( + params, + text_type(resp_json) + ).encode('utf-8') + LOG.error(err) + + raise Exception(err) + + def delete_and_suppress_learner(self, learner): + """ + Delete AND suppress a single Segment user using the bulk user deletion REST API. + + :param learner: Single user retirement status row with its fields. + """ + # Send a list of one learner to be deleted by the multiple learner deletion call. + return self.delete_and_suppress_learners([learner], 1) + + def unsuppress_learners_by_key(self, key, learners, chunk_size, beginning_idx=0): + """ + Sets up the Segment REST API calls to UNSUPPRESS users in chunks. + + :param key: Key in the learner dict to pull the ID we care about from. + :param learners: List of learner dicts to be worked on. We only use the key passed in. + :param chunk_size: How many learners should be retired in this batch. + :param beginning_idx: Index into learners where this batch should start. + """ + curr_idx = beginning_idx + while curr_idx < len(learners): + start_idx = curr_idx + end_idx = min(start_idx + chunk_size - 1, len(learners) - 1) + + LOG.info( + "Attempting unsuppress for key '%s', start index %s, end index %s for learners '%s' through '%s'", + key, + start_idx, end_idx, + learners[start_idx]['original_username'], + learners[end_idx]['original_username'] + ) + + learner_vals = [] + for idx in range(start_idx, end_idx + 1): + learner_vals.append(self._get_value_from_learner(learners[idx], key)) + + if len(learner_vals) >= MAXIMUM_USERS_IN_REGULATION_REQUEST: + LOG.error( + 'Attempting to UNSUPPRESS too many user values (%s) at once in bulk request - decrease chunk_size.', + len(learner_vals) + ) + return + + params = { + "regulation_type": "Unsuppress", + "attributes": { + "name": "userId", + "values": learner_vals + } + } + + self._send_regulation_request(params) + + curr_idx += chunk_size + + def delete_and_suppress_learners(self, learners, chunk_size, beginning_idx=0): + """ + Sets up the Segment REST API calls to GDPR-delete users in chunks. + + :param learners: List of learner dicts returned from LMS, should contain all we need to retire this learner. 
+ :param chunk_size: How many learners should be retired in this batch. + :param beginning_idx: Index into learners where this batch should start. + """ + curr_idx = beginning_idx + while curr_idx < len(learners): + start_idx = curr_idx + end_idx = min(start_idx + chunk_size - 1, len(learners) - 1) + LOG.info( + "Attempting Segment deletion with start index %s, end index %s for learners (%s, %s) through (%s, %s)", + start_idx, end_idx, + learners[start_idx]['user']['id'], learners[start_idx]['original_username'], + learners[end_idx]['user']['id'], learners[end_idx]['original_username'] + ) + + learner_vals = [] + for idx in range(start_idx, end_idx + 1): + for id_key in REQUIRED_IDENTIFYING_KEYS: + learner_vals.append(self._get_value_from_learner(learners[idx], id_key)) + for id_key in OPTIONAL_IDENTIFYING_KEYS: + if id_key in learners[idx]: + learner_vals.append(self._get_value_from_learner(learners[idx], id_key)) + + if len(learner_vals) >= MAXIMUM_USERS_IN_REGULATION_REQUEST: + LOG.error( + 'Attempting to delete too many user values (%s) at once in bulk request - decrease chunk_size.', + len(learner_vals) + ) + return + + params = { + "regulation_type": "Suppress_With_Delete", + "attributes": { + "name": "userId", + "values": learner_vals + } + } + + self._send_regulation_request(params) + + curr_idx += chunk_size + + def get_bulk_delete_status(self, bulk_delete_id): + """ + Queries the status of a previously submitted bulk delete request. + + :param bulk_delete_id: ID returned from a previously-submitted bulk delete request. + """ + resp = self._call_segment_get(BULK_REGULATE_STATUS_URL.format(self.workspace_slug, bulk_delete_id)) + resp_json = resp.json() + LOG.info(text_type(resp_json)) diff --git a/scripts/user_retirement/utils/utils.py b/scripts/user_retirement/utils/utils.py new file mode 100644 index 000000000000..b0c8c000cdd7 --- /dev/null +++ b/scripts/user_retirement/utils/utils.py @@ -0,0 +1,25 @@ +import os + + +def envvar_get_int(var_name, default): + """ + Grab an environment variable and return it as an integer. + If the environment variable does not exist, return the default. + """ + return int(os.environ.get(var_name, default)) + + +def batch(batchable, batch_size=1): + """ + Utility to facilitate batched iteration over a list. + + Arguments: + batchable (list): The list to break into batches. + + Yields: + list + """ + batchable_list = list(batchable) + length = len(batchable_list) + for index in range(0, length, batch_size): + yield batchable_list[index:index + batch_size]
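A closing note on these helpers: envvar_get_int mirrors the int(os.environ.get(...)) pattern used for the RETRY_*_MAX_ATTEMPTS knobs above, and batch() drives the Google Drive batching. A quick, self-contained check of both (the environment variable name is made up):

.. code-block:: python

    import os

    from scripts.user_retirement.utils.utils import batch, envvar_get_int

    os.environ['RETRY_EXAMPLE_MAX_ATTEMPTS'] = '3'   # hypothetical variable
    assert envvar_get_int('RETRY_EXAMPLE_MAX_ATTEMPTS', 5) == 3
    assert envvar_get_int('UNSET_VARIABLE', 5) == 5  # falls back to the default

    assert list(batch([1, 2, 3], batch_size=2)) == [[1, 2], [3]]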