feat: search all jobs by asset name or job id #193

Merged
merged 9 commits on Dec 13, 2024
12 changes: 0 additions & 12 deletions src/aind_data_transfer_service/models.py
@@ -67,18 +67,6 @@ def from_query_params(cls, query_params: QueryParams):
return cls.model_validate(params)


class AirflowDagRunRequestParameters(BaseModel):
"""Model for parameters when requesting info from dag_run endpoint"""

dag_run_id: str = Field(..., min_length=1)

@classmethod
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
params = dict(query_params)
return cls.model_validate(params)


class AirflowTaskInstancesRequestParameters(BaseModel):
"""Model for parameters when requesting info from task_instances
endpoint"""
109 changes: 61 additions & 48 deletions src/aind_data_transfer_service/server.py
@@ -4,14 +4,16 @@
import io
import json
import os
from asyncio import sleep
from asyncio import gather, sleep
from pathlib import PurePosixPath
from typing import Optional

import requests
from aind_data_transfer_models.core import SubmitJobRequest
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from httpx import AsyncClient
from openpyxl import load_workbook
from pydantic import SecretStr, ValidationError
from starlette.applications import Starlette
@@ -30,8 +32,6 @@
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
from aind_data_transfer_service.models import (
AirflowDagRun,
AirflowDagRunRequestParameters,
AirflowDagRunsRequestParameters,
AirflowDagRunsResponse,
AirflowTaskInstanceLogsRequestParameters,
@@ -422,59 +422,72 @@ async def submit_hpc_jobs(request: Request):  # noqa: C901

async def get_job_status_list(request: Request):
"""Get status of jobs with default pagination of limit=25 and offset=0."""
# TODO: Use httpx async client

async def fetch_jobs(
client: AsyncClient, url: str, params: Optional[dict]
):
"""Helper method to fetch jobs using httpx async client"""
response = await client.get(url, params=params)
response.raise_for_status()
return response.json()

try:
url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
get_one_job = request.query_params.get("dag_run_id") is not None
Contributor Author: get_one_job was originally created to allow users to retrieve a job by its exact job id. It is no longer needed since the UI now lets users search fully client-side.
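For illustration, a caller can now pull the full job list once and filter locally by job id or asset name instead of requesting a single dag_run_id. A minimal sketch only, not code from this PR: the service URL, the response envelope, and the per-job field names (job_id, name) are assumptions.

    import requests

    # Hypothetical usage sketch -- URL, query parameter value, and response
    # keys are assumptions for illustration, not part of this PR.
    response = requests.get(
        "http://localhost:5000/api/v1/get_job_status_list",
        params={"get_all_jobs": "true"},  # presence of the key triggers the full fetch
    )
    response.raise_for_status()
    jobs = response.json()["data"]["job_status_list"]

    search_term = "ecephys_123456"  # an asset name or job id fragment
    matches = [
        job
        for job in jobs
        if search_term in str(job.get("job_id", ""))
        or search_term in str(job.get("name", ""))
    ]
    print(f"{len(matches)} of {len(jobs)} jobs match '{search_term}'")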

if get_one_job:
params = AirflowDagRunRequestParameters.from_query_params(
request.query_params
)
url = f"{url}/{params.dag_run_id}"
else:
params = AirflowDagRunsRequestParameters.from_query_params(
request.query_params
)
params_dict = json.loads(params.model_dump_json())
# Send request to Airflow to ListDagRuns or GetDagRun
response_jobs = requests.get(
url=url,
get_all_jobs = request.query_params.get("get_all_jobs") is not None
params = AirflowDagRunsRequestParameters.from_query_params(
request.query_params
)
params_dict = json.loads(params.model_dump_json(exclude_none=True))
# Send request to Airflow to ListDagRuns
async with AsyncClient(
auth=(
os.getenv("AIND_AIRFLOW_SERVICE_USER"),
os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
),
params=None if get_one_job else params_dict,
)
status_code = response_jobs.status_code
if response_jobs.status_code == 200:
if get_one_job:
dag_run = AirflowDagRun.model_validate_json(
json.dumps(response_jobs.json())
)
dag_runs = AirflowDagRunsResponse(
dag_runs=[dag_run], total_entries=1
)
else:
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(response_jobs.json())
)
)
) as client:
# Fetch initial jobs
response_jobs = await fetch_jobs(
client=client,
url=url,
params=params_dict,
)
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(response_jobs)
)
job_status_list = [
JobStatus.from_airflow_dag_run(d) for d in dag_runs.dag_runs
]
message = "Retrieved job status list from airflow"
data = {
"params": params_dict,
"total_entries": dag_runs.total_entries,
"job_status_list": [
json.loads(j.model_dump_json()) for j in job_status_list
],
}
else:
message = "Error retrieving job status list from airflow"
data = {
"params": params_dict,
"errors": [response_jobs.json()],
}
total_entries = dag_runs.total_entries
if get_all_jobs:
# Fetch remaining jobs concurrently
tasks = []
offset = params_dict["offset"] + params_dict["limit"]
while offset < total_entries:
params = {**params_dict, "limit": 100, "offset": offset}
tasks.append(
fetch_jobs(client=client, url=url, params=params)
)
offset += 100
batches = await gather(*tasks)
for batch in batches:
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(batch)
)
job_status_list.extend(
[
JobStatus.from_airflow_dag_run(d)
for d in dag_runs.dag_runs
]
)
status_code = 200
message = "Retrieved job status list from airflow"
data = {
"params": params_dict,
"total_entries": total_entries,
"job_status_list": [
json.loads(j.model_dump_json()) for j in job_status_list
],
}
except ValidationError as e:
logger.warning(
f"There was a validation error process job_status list: {e}"
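The pagination fan-out in get_job_status_list (fetch the first page, read total_entries, then gather the remaining pages concurrently in batches of 100) can be exercised in isolation. A minimal sketch with a stub standing in for the httpx-based fetch_jobs helper, assuming the same limit/offset semantics:

    import asyncio

    TOTAL = 260  # pretend Airflow reports 260 dag runs

    async def fake_fetch(limit: int, offset: int) -> dict:
        """Stub standing in for fetch_jobs(); returns one page of ids."""
        ids = list(range(offset, min(offset + limit, TOTAL)))
        return {"total_entries": TOTAL, "dag_runs": ids}

    async def get_all_runs(first_limit: int = 25) -> list:
        # First request gives the total; remaining pages are fetched concurrently.
        first = await fake_fetch(limit=first_limit, offset=0)
        runs = list(first["dag_runs"])
        total = first["total_entries"]
        tasks = []
        offset = first_limit
        while offset < total:
            tasks.append(fake_fetch(limit=100, offset=offset))
            offset += 100
        for batch in await asyncio.gather(*tasks):
            runs.extend(batch["dag_runs"])
        return runs

    print(len(asyncio.run(get_all_runs())))  # 260

With total_entries = 260 and an initial limit of 25, the loop schedules three extra requests (offsets 25, 125, and 225), and gather runs them concurrently, mirroring the batching in the handler above.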