diff --git a/dispatcher/backend/src/common/constants.py b/dispatcher/backend/src/common/constants.py index bc362b40a..5f4ee187a 100644 --- a/dispatcher/backend/src/common/constants.py +++ b/dispatcher/backend/src/common/constants.py @@ -66,7 +66,7 @@ # using the following, it is possible to automate # the update of a whitelist of workers IPs on Wasabi (S3 provider) # enable this feature (default is off) -USES_WORKERS_IPS_WHITELIST = bool(os.getenv("USES_WORKERS_IPS_WHITELIST", "")) +USES_WORKERS_IPS_WHITELIST = bool(os.getenv("USES_WORKERS_IPS_WHITELIST")) MAX_WORKER_IP_CHANGES_PER_DAY = 4 # wasabi URL with credentials to update policy WASABI_URL = os.getenv("WASABI_URL", "") diff --git a/dispatcher/backend/src/common/external.py b/dispatcher/backend/src/common/external.py index 8a4d76fd3..3706b0f07 100644 --- a/dispatcher/backend/src/common/external.py +++ b/dispatcher/backend/src/common/external.py @@ -27,12 +27,11 @@ logger = logging.getLogger(__name__) -def update_workers_whitelist(): +def update_workers_whitelist(session: so.Session): """update whitelist of workers on external services""" - update_wasabi_whitelist(build_workers_whitelist()) + ExternalIpUpdater.update(build_workers_whitelist(session=session)) -@dbsession def build_workers_whitelist(session: so.Session) -> typing.List[str]: """list of worker IP adresses and networks (text) to use as whitelist""" wl_networks = [] @@ -150,6 +149,16 @@ def get_statement(): ) +class ExternalIpUpdater: + """Class responsible to push IP updates to external system(s) + + `update` is called with the new list of all workers IPs everytime + a change is detected. + By default, this class update our IPs whitelist in Wasabi""" + + update = update_wasabi_whitelist + + @dbsession def advertise_books_to_cms(task_id: UUID, session: so.Session): """inform openZIM CMS of all created ZIMs in the farm for this task diff --git a/dispatcher/backend/src/db/__init__.py b/dispatcher/backend/src/db/__init__.py index 69bad2e96..5f03486a2 100644 --- a/dispatcher/backend/src/db/__init__.py +++ b/dispatcher/backend/src/db/__init__.py @@ -52,6 +52,20 @@ def inner(*args, **kwargs): return inner +def dbsession_manual(func): + """Decorator to create an SQLAlchemy ORM session object and wrap the function + inside the session. A `session` argument is automatically set. Transaction must + be managed by the developer (e.g. perform a commit / rollback). + """ + + def inner(*args, **kwargs): + with Session() as session: + kwargs["session"] = session + return func(*args, **kwargs) + + return inner + + def count_from_stmt(session: OrmSession, stmt: SelectBase) -> int: """Count all records returned by any statement `stmt` passed as parameter""" return session.execute( diff --git a/dispatcher/backend/src/routes/requested_tasks/requested_task.py b/dispatcher/backend/src/routes/requested_tasks/requested_task.py index 347192bc3..688051932 100644 --- a/dispatcher/backend/src/routes/requested_tasks/requested_task.py +++ b/dispatcher/backend/src/routes/requested_tasks/requested_task.py @@ -9,12 +9,8 @@ from marshmallow import ValidationError import db.models as dbm -from common import WorkersIpChangesCounts, getnow -from common.constants import ( - ENABLED_SCHEDULER, - MAX_WORKER_IP_CHANGES_PER_DAY, - USES_WORKERS_IPS_WHITELIST, -) +from common import WorkersIpChangesCounts, constants, getnow +from common.constants import ENABLED_SCHEDULER, MAX_WORKER_IP_CHANGES_PER_DAY from common.external import update_workers_whitelist from common.schemas.orms import RequestedTaskFullSchema, RequestedTaskLightSchema from common.schemas.parameters import ( @@ -24,8 +20,8 @@ WorkerRequestedTaskSchema, ) from common.utils import task_event_handler -from db import count_from_stmt, dbsession -from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound +from db import count_from_stmt, dbsession, dbsession_manual +from errors.http import HTTPBase, InvalidRequestJSON, TaskNotFound, WorkerNotFound from routes import auth_info_if_supplied, authenticate, require_perm, url_uuid from routes.base import BaseRoute from routes.errors import NotFound @@ -35,14 +31,14 @@ logger = logging.getLogger(__name__) -def record_ip_change(worker_name): +def record_ip_change(session: so.Session, worker_name: str): """record that this worker changed its IP and trigger whitelist changes""" today = datetime.date.today() # counts and limits are per-day so reset it if date changed if today != WorkersIpChangesCounts.today: WorkersIpChangesCounts.reset() if WorkersIpChangesCounts.add(worker_name) <= MAX_WORKER_IP_CHANGES_PER_DAY: - update_workers_whitelist() + update_workers_whitelist(session) else: logger.error( f"Worker {worker_name} IP changes for {today} " @@ -208,7 +204,7 @@ class RequestedTasksForWorkers(BaseRoute): methods = ["GET"] @authenticate - @dbsession + @dbsession_manual def get(self, session: so.Session, token: AccessToken.Payload): """list of requested tasks to be retrieved by workers, auth-only""" @@ -229,15 +225,26 @@ def get(self, session: so.Session, token: AccessToken.Payload): worker = dbm.Worker.get(session, worker_name, WorkerNotFound) if worker.user.username == token.username: worker.last_seen = getnow() - previous_ip = str(worker.last_ip) - worker.last_ip = worker_ip - - # flush to DB so that record_ip_change has access to updated IP - session.flush() # IP changed since last encounter - if USES_WORKERS_IPS_WHITELIST and previous_ip != worker_ip: - record_ip_change(worker_name) + if str(worker.last_ip) != worker_ip: + logger.info( + f"Worker IP changed detected for {worker_name}: " + f"IP changed from {worker.last_ip} to {worker_ip}" + ) + worker.last_ip = worker_ip + # commit explicitely since we are not using an explicit transaction, + # and do it before calling Wasabi so that changes are propagated + # quickly and transaction is not blocking + session.commit() + if constants.USES_WORKERS_IPS_WHITELIST: + try: + record_ip_change(session=session, worker_name=worker_name) + except Exception: + raise HTTPBase( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + error="Recording IP changes failed", + ) request_args = WorkerRequestedTaskSchema().load(request_args) diff --git a/dispatcher/backend/src/tests/conftest.py b/dispatcher/backend/src/tests/conftest.py new file mode 100644 index 000000000..230be5dae --- /dev/null +++ b/dispatcher/backend/src/tests/conftest.py @@ -0,0 +1,12 @@ +from typing import Generator + +import pytest +from sqlalchemy.orm import Session as OrmSession + +from db import Session + + +@pytest.fixture +def dbsession() -> Generator[OrmSession, None, None]: + with Session.begin() as session: + yield session diff --git a/dispatcher/backend/src/tests/integration/routes/workers/test_worker.py b/dispatcher/backend/src/tests/integration/routes/workers/test_worker.py index 48b1971c9..4767bdc6a 100644 --- a/dispatcher/backend/src/tests/integration/routes/workers/test_worker.py +++ b/dispatcher/backend/src/tests/integration/routes/workers/test_worker.py @@ -1,11 +1,14 @@ +from typing import List + import pytest -from common.external import build_workers_whitelist +from common import constants +from common.external import ExternalIpUpdater, build_workers_whitelist class TestWorkersCommon: - def test_build_workers_whitelist(self, workers): - whitelist = build_workers_whitelist() + def test_build_workers_whitelist(self, workers, dbsession): + whitelist = build_workers_whitelist(session=dbsession) # - 4 because: # 2 workers have a duplicate IP # 1 worker has an IP missing @@ -206,3 +209,130 @@ def test_checkin_another_user( # response.get_json()["error"] # == "worker with same name already exists for another user" # ) + + +class TestWorkerRequestedTasks: + def test_requested_task_worker_as_admin(self, client, access_token, worker): + response = client.get( + "/requested-tasks/worker", + query_string={ + "worker": worker["name"], + "avail_cpu": 4, + "avail_memory": 2048, + "avail_disk": 4096, + }, + headers={"Authorization": access_token}, + ) + assert response.status_code == 200 + + def test_requested_task_worker_as_worker(self, client, make_access_token, worker): + response = client.get( + "/requested-tasks/worker", + query_string={ + "worker": worker["name"], + "avail_cpu": 4, + "avail_memory": 2048, + "avail_disk": 4096, + }, + headers={"Authorization": make_access_token(worker["username"], "worker")}, + ) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "prev_ip, new_ip, external_update_enabled, external_update_fails," + " external_update_called", + [ + ("77.77.77.77", "88.88.88.88", False, False, False), # ip update disabled + ("77.77.77.77", "77.77.77.77", True, False, False), # ip did not changed + ("77.77.77.77", "88.88.88.88", True, False, True), # ip should be updated + ("77.77.77.77", "88.88.88.88", True, True, False), # ip update fails + ], + ) + def test_requested_task_worker_update_ip_whitelist( + self, + client, + make_access_token, + worker, + prev_ip, + new_ip, + external_update_enabled, + external_update_fails, + external_update_called, + ): + # call it once to set prev_ip + response = client.get( + "/requested-tasks/worker", + query_string={ + "worker": worker["name"], + "avail_cpu": 4, + "avail_memory": 2048, + "avail_disk": 4096, + }, + headers={ + "Authorization": make_access_token(worker["username"], "worker"), + "X-Forwarded-For": prev_ip, + }, + ) + assert response.status_code == 200 + + # check prev_ip has been set + response = client.get("/workers/") + assert response.status_code == 200 + response_data = response.get_json() + for item in response_data["items"]: + if item["name"] != worker["name"]: + continue + assert item["last_ip"] == prev_ip + + # setup custom ip updater to intercept Wasabi operations + updater = IpUpdaterAndChecker(should_fail=external_update_fails) + assert new_ip not in updater.ip_addresses + ExternalIpUpdater.update = updater.ip_update + constants.USES_WORKERS_IPS_WHITELIST = external_update_enabled + + # call it once to set next_ip + response = client.get( + "/requested-tasks/worker", + query_string={ + "worker": worker["name"], + "avail_cpu": 4, + "avail_memory": 2048, + "avail_disk": 4096, + }, + headers={ + "Authorization": make_access_token(worker["username"], "worker"), + "X-Forwarded-For": new_ip, + }, + ) + if external_update_fails: + assert response.status_code == 503 + else: + assert response.status_code == 200 + assert updater.ips_updated == external_update_called + if external_update_called: + assert new_ip in updater.ip_addresses + + # check new_ip has been set (even if ip update is disabled or has failed) + response = client.get("/workers/") + assert response.status_code == 200 + response_data = response.get_json() + for item in response_data["items"]: + if item["name"] != worker["name"]: + continue + assert item["last_ip"] == new_ip + + +class IpUpdaterAndChecker: + """Helper class to intercept Wasabi operations and perform assertions""" + + def __init__(self, should_fail: bool) -> None: + self.ips_updated = False + self.should_fail = should_fail + self.ip_addresses = [] + + def ip_update(self, ip_addresses: List): + if self.should_fail: + raise Exception() + else: + self.ips_updated = True + self.ip_addresses = ip_addresses