From 1ffba85e0ee68636914e4fe38cb84db3a601f1e1 Mon Sep 17 00:00:00 2001
From: Helen Lin <46795546+helen-m-lin@users.noreply.github.com>
Date: Thu, 12 Dec 2024 23:29:42 -0800
Subject: [PATCH 1/5] feat: search all jobs by asset name or job id (#193)

* feat: get_all_jobs using httpx async client
* feat: render full jobs table
* feat: search full jobs table by asset name or job id
* feat: toggle advanced/full search table
* feat: custom rendering for full search table
* refactor: do not need to get_one_job
* test: update unit tests
* test: update env patch
* fix: missing docstring
---
 src/aind_data_transfer_service/models.py  |  12 -
 src/aind_data_transfer_service/server.py  | 109 ++++----
 .../templates/job_status.html             | 242 +++++++++++++-----
 tests/test_server.py                      |  88 +++----
 4 files changed, 274 insertions(+), 177 deletions(-)

diff --git a/src/aind_data_transfer_service/models.py b/src/aind_data_transfer_service/models.py
index ebf9d04..69fa428 100644
--- a/src/aind_data_transfer_service/models.py
+++ b/src/aind_data_transfer_service/models.py
@@ -67,18 +67,6 @@ def from_query_params(cls, query_params: QueryParams):
         return cls.model_validate(params)
 
 
-class AirflowDagRunRequestParameters(BaseModel):
-    """Model for parameters when requesting info from dag_run endpoint"""
-
-    dag_run_id: str = Field(..., min_length=1)
-
-    @classmethod
-    def from_query_params(cls, query_params: QueryParams):
-        """Maps the query parameters to the model"""
-        params = dict(query_params)
-        return cls.model_validate(params)
-
-
 class AirflowTaskInstancesRequestParameters(BaseModel):
     """Model for parameters when requesting info from task_instances
     endpoint"""
diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py
index e70f901..017e9ff 100644
--- a/src/aind_data_transfer_service/server.py
+++ b/src/aind_data_transfer_service/server.py
@@ -4,14 +4,16 @@
 import io
 import json
 import os
-from asyncio import sleep
+from asyncio import gather, sleep
 from pathlib import PurePosixPath
+from typing import Optional
 
 import requests
 from aind_data_transfer_models.core import SubmitJobRequest
 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.templating import Jinja2Templates
+from httpx import AsyncClient
 from openpyxl import load_workbook
 from pydantic import SecretStr, ValidationError
 from starlette.applications import Starlette
@@ -30,8 +32,6 @@
 from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
 from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
 from aind_data_transfer_service.models import (
-    AirflowDagRun,
-    AirflowDagRunRequestParameters,
     AirflowDagRunsRequestParameters,
     AirflowDagRunsResponse,
     AirflowTaskInstanceLogsRequestParameters,
@@ -422,59 +422,72 @@ async def submit_hpc_jobs(request: Request):  # noqa: C901
 async def get_job_status_list(request: Request):
     """Get status of jobs with default pagination of limit=25 and offset=0."""
-    # TODO: Use httpx async client
+
+    async def fetch_jobs(
+        client: AsyncClient, url: str, params: Optional[dict]
+    ):
+        """Helper method to fetch jobs using httpx async client"""
+        response = await client.get(url, params=params)
+        response.raise_for_status()
+        return response.json()
+
     try:
         url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
-        get_one_job = request.query_params.get("dag_run_id") is not None
-        if get_one_job:
-            params = AirflowDagRunRequestParameters.from_query_params(
-                request.query_params
-            )
-            url = f"{url}/{params.dag_run_id}"
-        else:
-            params = AirflowDagRunsRequestParameters.from_query_params(
-                request.query_params
-            )
-        params_dict = json.loads(params.model_dump_json())
-        # Send request to Airflow to ListDagRuns or GetDagRun
-        response_jobs = requests.get(
-            url=url,
+        get_all_jobs = request.query_params.get("get_all_jobs") is not None
+        params = AirflowDagRunsRequestParameters.from_query_params(
+            request.query_params
+        )
+        params_dict = json.loads(params.model_dump_json(exclude_none=True))
+        # Send request to Airflow to ListDagRuns
+        async with AsyncClient(
             auth=(
                 os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                 os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
-            ),
-            params=None if get_one_job else params_dict,
-        )
-        status_code = response_jobs.status_code
-        if response_jobs.status_code == 200:
-            if get_one_job:
-                dag_run = AirflowDagRun.model_validate_json(
-                    json.dumps(response_jobs.json())
-                )
-                dag_runs = AirflowDagRunsResponse(
-                    dag_runs=[dag_run], total_entries=1
-                )
-            else:
-                dag_runs = AirflowDagRunsResponse.model_validate_json(
-                    json.dumps(response_jobs.json())
-                )
+            )
+        ) as client:
+            # Fetch initial jobs
+            response_jobs = await fetch_jobs(
+                client=client,
+                url=url,
+                params=params_dict,
+            )
+            dag_runs = AirflowDagRunsResponse.model_validate_json(
+                json.dumps(response_jobs)
+            )
             job_status_list = [
                 JobStatus.from_airflow_dag_run(d) for d in dag_runs.dag_runs
             ]
-            message = "Retrieved job status list from airflow"
-            data = {
-                "params": params_dict,
-                "total_entries": dag_runs.total_entries,
-                "job_status_list": [
-                    json.loads(j.model_dump_json()) for j in job_status_list
-                ],
-            }
-        else:
-            message = "Error retrieving job status list from airflow"
-            data = {
-                "params": params_dict,
-                "errors": [response_jobs.json()],
-            }
+            total_entries = dag_runs.total_entries
+            if get_all_jobs:
+                # Fetch remaining jobs concurrently
+                tasks = []
+                offset = params_dict["offset"] + params_dict["limit"]
+                while offset < total_entries:
+                    params = {**params_dict, "limit": 100, "offset": offset}
+                    tasks.append(
+                        fetch_jobs(client=client, url=url, params=params)
+                    )
+                    offset += 100
+                batches = await gather(*tasks)
+                for batch in batches:
+                    dag_runs = AirflowDagRunsResponse.model_validate_json(
+                        json.dumps(batch)
+                    )
+                    job_status_list.extend(
+                        [
+                            JobStatus.from_airflow_dag_run(d)
+                            for d in dag_runs.dag_runs
+                        ]
+                    )
+        status_code = 200
+        message = "Retrieved job status list from airflow"
+        data = {
+            "params": params_dict,
+            "total_entries": total_entries,
+            "job_status_list": [
+                json.loads(j.model_dump_json()) for j in job_status_list
+            ],
+        }
     except ValidationError as e:
         logger.warning(
             f"There was a validation error process job_status list: {e}"
         )
diff --git a/src/aind_data_transfer_service/templates/job_status.html b/src/aind_data_transfer_service/templates/job_status.html
index 942623c..b65b647 100644
--- a/src/aind_data_transfer_service/templates/job_status.html
+++ b/src/aind_data_transfer_service/templates/job_status.html
@@ -4,10 +4,12 @@
[hunk body truncated in this excerpt: the added markup lines and the remaining job_status.html and tests/test_server.py changes are not shown]