diff --git a/src/aind_data_transfer_service/__init__.py b/src/aind_data_transfer_service/__init__.py
index dd99e9c..bacf75a 100644
--- a/src/aind_data_transfer_service/__init__.py
+++ b/src/aind_data_transfer_service/__init__.py
@@ -1,7 +1,7 @@
 """Init package"""
 import os
 
-__version__ = "1.1.0"
+__version__ = "1.2.0"
 
 # Global constants
 OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "open")
diff --git a/src/aind_data_transfer_service/models.py b/src/aind_data_transfer_service/models.py
index 5e3a656..ebf9d04 100644
--- a/src/aind_data_transfer_service/models.py
+++ b/src/aind_data_transfer_service/models.py
@@ -1,9 +1,11 @@
 """Module for data models used in application"""
 
-from datetime import datetime
-from typing import List, Optional
+import ast
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional, Union
 
-from pydantic import AwareDatetime, BaseModel, Field
+from pydantic import AwareDatetime, BaseModel, Field, field_validator
+from starlette.datastructures import QueryParams
 
 
 class AirflowDagRun(BaseModel):
@@ -37,12 +39,113 @@ class AirflowDagRunsRequestParameters(BaseModel):
 
     limit: int = 25
     offset: int = 0
-    order_by: str = "-start_date"
+    state: Optional[list[str]] = []
+    execution_date_gte: Optional[str] = (
+        datetime.now(timezone.utc) - timedelta(weeks=2)
+    ).isoformat()
+    execution_date_lte: Optional[str] = None
+    order_by: str = "-execution_date"
+
+    @field_validator("execution_date_gte", mode="after")
+    def validate_min_execution_date(cls, execution_date_gte: str):
+        """Validate the earliest submit date filter is within 2 weeks"""
+        min_execution_date = datetime.now(timezone.utc) - timedelta(weeks=2)
+        # datetime.fromisoformat does not support Z in python < 3.11
+        date_to_check = execution_date_gte.replace("Z", "+00:00")
+        if datetime.fromisoformat(date_to_check) < min_execution_date:
+            raise ValueError(
+                "execution_date_gte must be within the last 2 weeks"
+            )
+        return execution_date_gte
 
     @classmethod
-    def from_query_params(cls, query_params: dict):
+    def from_query_params(cls, query_params: QueryParams):
         """Maps the query parameters to the model"""
-        return cls(**query_params)
+        params = dict(query_params)
+        if "state" in params:
+            params["state"] = ast.literal_eval(params["state"])
+        return cls.model_validate(params)
+
+
+class AirflowDagRunRequestParameters(BaseModel):
+    """Model for parameters when requesting info from dag_run endpoint"""
+
+    dag_run_id: str = Field(..., min_length=1)
+
+    @classmethod
+    def from_query_params(cls, query_params: QueryParams):
+        """Maps the query parameters to the model"""
+        params = dict(query_params)
+        return cls.model_validate(params)
+
+
+class AirflowTaskInstancesRequestParameters(BaseModel):
+    """Model for parameters when requesting info from task_instances
+    endpoint"""
+
+    dag_run_id: str = Field(..., min_length=1)
+
+    @classmethod
+    def from_query_params(cls, query_params: QueryParams):
+        """Maps the query parameters to the model"""
+        params = dict(query_params)
+        return cls.model_validate(params)
+
+
+class AirflowTaskInstance(BaseModel):
+    """Data model for task_instance entry when requesting info from airflow"""
+
+    dag_id: Optional[str]
+    dag_run_id: Optional[str]
+    duration: Optional[Union[int, float]]
+    end_date: Optional[AwareDatetime]
+    execution_date: Optional[AwareDatetime]
+    executor_config: Optional[str]
+    hostname: Optional[str]
+    map_index: Optional[int]
+    max_tries: Optional[int]
+    note: Optional[str]
+    operator: Optional[str]
+    pid: Optional[int]
+    pool: Optional[str]
+    pool_slots: Optional[int]
+    priority_weight: Optional[int]
+    queue: Optional[str]
+    queued_when: Optional[AwareDatetime]
+    rendered_fields: Optional[dict]
+    sla_miss: Optional[dict]
+    start_date: Optional[AwareDatetime]
+    state: Optional[str]
+    task_id: Optional[str]
+    trigger: Optional[dict]
+    triggerer_job: Optional[dict]
+    try_number: Optional[int]
+    unixname: Optional[str]
+
+
+class AirflowTaskInstancesResponse(BaseModel):
+    """Data model for response when requesting info from task_instances
+    endpoint"""
+
+    task_instances: List[AirflowTaskInstance]
+    total_entries: int
+
+
+class AirflowTaskInstanceLogsRequestParameters(BaseModel):
+    """Model for parameters when requesting info from task_instance_logs
+    endpoint"""
+
+    # excluded fields are used to build the url
+    dag_run_id: str = Field(..., min_length=1, exclude=True)
+    task_id: str = Field(..., min_length=1, exclude=True)
+    try_number: int = Field(..., ge=0, exclude=True)
+    full_content: bool = True
+
+    @classmethod
+    def from_query_params(cls, query_params: QueryParams):
+        """Maps the query parameters to the model"""
+        params = dict(query_params)
+        return cls.model_validate(params)
 
 
 class JobStatus(BaseModel):
@@ -74,3 +177,38 @@ def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun):
     def jinja_dict(self):
         """Map model to a dictionary that jinja can render"""
         return self.model_dump(exclude_none=True)
+
+
+class JobTasks(BaseModel):
+    """Model for what is rendered to the user for each task."""
+
+    job_id: Optional[str] = Field(None)
+    task_id: Optional[str] = Field(None)
+    try_number: Optional[int] = Field(None)
+    task_state: Optional[str] = Field(None)
+    priority_weight: Optional[int] = Field(None)
+    map_index: Optional[int] = Field(None)
+    submit_time: Optional[datetime] = Field(None)
+    start_time: Optional[datetime] = Field(None)
+    end_time: Optional[datetime] = Field(None)
+    duration: Optional[Union[int, float]] = Field(None)
+    comment: Optional[str] = Field(None)
+
+    @classmethod
+    def from_airflow_task_instance(
+        cls, airflow_task_instance: AirflowTaskInstance
+    ):
+        """Maps the fields from an AirflowTaskInstance to this model"""
+        return cls(
+            job_id=airflow_task_instance.dag_run_id,
+            task_id=airflow_task_instance.task_id,
+            try_number=airflow_task_instance.try_number,
+            task_state=airflow_task_instance.state,
+            priority_weight=airflow_task_instance.priority_weight,
+            map_index=airflow_task_instance.map_index,
+            submit_time=airflow_task_instance.execution_date,
+            start_time=airflow_task_instance.start_date,
+            end_time=airflow_task_instance.end_date,
+            duration=airflow_task_instance.duration,
+            comment=airflow_task_instance.note,
+        )
diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py
index ff8f815..3a9c890 100644
--- a/src/aind_data_transfer_service/server.py
+++ b/src/aind_data_transfer_service/server.py
@@ -30,9 +30,15 @@ from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
 from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
 from aind_data_transfer_service.models import (
+    AirflowDagRun,
+    AirflowDagRunRequestParameters,
     AirflowDagRunsRequestParameters,
     AirflowDagRunsResponse,
+    AirflowTaskInstanceLogsRequestParameters,
+    AirflowTaskInstancesRequestParameters,
+    AirflowTaskInstancesResponse,
     JobStatus,
+    JobTasks,
 )
 
 
 template_directory = os.path.abspath(
@@ -390,23 +396,40 @@ async def get_job_status_list(request: Request):
     """Get status of jobs with default pagination of limit=25 and offset=0."""
     # TODO: Use httpx async client
     try:
-        params = AirflowDagRunsRequestParameters.from_query_params(
-            request.query_params
-        )
+        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
+        get_one_job = request.query_params.get("dag_run_id") is not None
+        if get_one_job:
+            params = AirflowDagRunRequestParameters.from_query_params(
+                request.query_params
+            )
+            url = f"{url}/{params.dag_run_id}"
+        else:
+            params = AirflowDagRunsRequestParameters.from_query_params(
+                request.query_params
+            )
         params_dict = json.loads(params.model_dump_json())
+        # Send request to Airflow to ListDagRuns or GetDagRun
         response_jobs = requests.get(
-            url=os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL"),
+            url=url,
             auth=(
                 os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                 os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
             ),
-            params=params_dict,
+            params=None if get_one_job else params_dict,
         )
         status_code = response_jobs.status_code
         if response_jobs.status_code == 200:
-            dag_runs = AirflowDagRunsResponse.model_validate_json(
-                json.dumps(response_jobs.json())
-            )
+            if get_one_job:
+                dag_run = AirflowDagRun.model_validate_json(
+                    json.dumps(response_jobs.json())
+                )
+                dag_runs = AirflowDagRunsResponse(
+                    dag_runs=[dag_run], total_entries=1
+                )
+            else:
+                dag_runs = AirflowDagRunsResponse.model_validate_json(
+                    json.dumps(response_jobs.json())
+                )
             job_status_list = [
                 JobStatus.from_airflow_dag_run(d) for d in dag_runs.dag_runs
             ]
@@ -420,7 +443,10 @@
             }
         else:
             message = "Error retrieving job status list from airflow"
-            data = {"params": params_dict, "errors": [response_jobs.json()]}
+            data = {
+                "params": params_dict,
+                "errors": [response_jobs.json()],
+            }
     except ValidationError as e:
         logging.error(e)
         status_code = 406
@@ -440,6 +466,115 @@
     )
 
 
+async def get_tasks_list(request: Request):
+    """Get list of task instances given a job id."""
+    try:
+        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
+        params = AirflowTaskInstancesRequestParameters.from_query_params(
+            request.query_params
+        )
+        params_dict = json.loads(params.model_dump_json())
+        response_tasks = requests.get(
+            url=f"{url}/{params.dag_run_id}/taskInstances",
+            auth=(
+                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
+                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
+            ),
+        )
+        status_code = response_tasks.status_code
+        if response_tasks.status_code == 200:
+            task_instances = AirflowTaskInstancesResponse.model_validate_json(
+                json.dumps(response_tasks.json())
+            )
+            job_tasks_list = sorted(
+                [
+                    JobTasks.from_airflow_task_instance(t)
+                    for t in task_instances.task_instances
+                ],
+                key=lambda t: (-t.priority_weight, t.map_index),
+            )
+            message = "Retrieved job tasks list from airflow"
+            data = {
+                "params": params_dict,
+                "total_entries": task_instances.total_entries,
+                "job_tasks_list": [
+                    json.loads(t.model_dump_json()) for t in job_tasks_list
+                ],
+            }
+        else:
+            message = "Error retrieving job tasks list from airflow"
+            data = {
+                "params": params_dict,
+                "errors": [response_tasks.json()],
+            }
+    except ValidationError as e:
+        logging.error(e)
+        status_code = 406
+        message = "Error validating request parameters"
+        data = {"errors": json.loads(e.json())}
+    except Exception as e:
+        logging.error(e)
+        status_code = 500
+        message = "Unable to retrieve job tasks list from airflow"
+        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
+    return JSONResponse(
+        status_code=status_code,
+        content={
+            "message": message,
+            "data": data,
+        },
+    )
+
+
+async def get_task_logs(request: Request):
+    """Get task logs given a job id, task id, and task try number."""
+    try:
+        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
+        params = AirflowTaskInstanceLogsRequestParameters.from_query_params(
+            request.query_params
+        )
+        params_dict = json.loads(params.model_dump_json())
+        params_full = dict(params)
+        response_logs = requests.get(
+            url=(
+                f"{url}/{params.dag_run_id}/taskInstances/{params.task_id}"
+                f"/logs/{params.try_number}"
+            ),
+            auth=(
+                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
+                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
+            ),
+            params=params_dict,
+        )
+        status_code = response_logs.status_code
+        if response_logs.status_code == 200:
+            message = "Retrieved task logs from airflow"
+            data = {"params": params_full, "logs": response_logs.text}
+        else:
+            message = "Error retrieving task logs from airflow"
+            data = {
+                "params": params_full,
+                "errors": [response_logs.json()],
+            }
+    except ValidationError as e:
+        logging.error(e)
+        status_code = 406
+        message = "Error validating request parameters"
+        data = {"errors": json.loads(e.json())}
+    except Exception as e:
+        logging.error(e)
+        status_code = 500
+        message = "Unable to retrieve task logs from airflow"
+        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
+    return JSONResponse(
+        status_code=status_code,
+        content={
+            "message": message,
+            "data": data,
+        },
+    )
+
+
 async def index(request: Request):
     """GET|POST /: form handler"""
     return templates.TemplateResponse(
@@ -478,6 +613,45 @@
     )
 
 
+async def job_tasks_table(request: Request):
+    """Get Job Tasks table given a job id"""
+    response_tasks = await get_tasks_list(request)
+    response_tasks_json = json.loads(response_tasks.body)
+    data = response_tasks_json.get("data")
+    return templates.TemplateResponse(
+        name="job_tasks_table.html",
+        context=(
+            {
+                "request": request,
+                "status_code": response_tasks.status_code,
+                "message": response_tasks_json.get("message"),
+                "errors": data.get("errors", []),
+                "total_entries": data.get("total_entries", 0),
+                "job_tasks_list": data.get("job_tasks_list", []),
+            }
+        ),
+    )
+
+
+async def task_logs(request: Request):
+    """Get task logs given a job id, task id, and task try number."""
+    response_tasks = await get_task_logs(request)
+    response_tasks_json = json.loads(response_tasks.body)
+    data = response_tasks_json.get("data")
+    return templates.TemplateResponse(
+        name="task_logs.html",
+        context=(
+            {
+                "request": request,
+                "status_code": response_tasks.status_code,
+                "message": response_tasks_json.get("message"),
+                "errors": data.get("errors", []),
+                "logs": data.get("logs"),
+            }
+        ),
+    )
+
+
 async def jobs(request: Request):
     """Get Job Status page with pagination"""
     default_limit = AirflowDagRunsRequestParameters.model_fields[
@@ -486,6 +660,9 @@ ].default
     default_offset = AirflowDagRunsRequestParameters.model_fields[
         "offset"
     ].default
+    default_state = AirflowDagRunsRequestParameters.model_fields[
+        "state"
+    ].default
     return templates.TemplateResponse(
         name="job_status.html",
         context=(
@@ -493,6 +670,7 @@
                 "request": request,
                 "default_limit": default_limit,
                 "default_offset": default_offset,
+                "default_state": default_state,
                 "project_names_url": os.getenv(
                     "AIND_METADATA_SERVICE_PROJECT_NAMES_URL"
                 ),
@@ -545,8 +723,12 @@ ),
         endpoint=get_job_status_list,
         methods=["GET"],
     ),
+    Route("/api/v1/get_tasks_list", endpoint=get_tasks_list, methods=["GET"]),
+    Route("/api/v1/get_task_logs", endpoint=get_task_logs, methods=["GET"]),
     Route("/jobs", endpoint=jobs, methods=["GET"]),
     Route("/job_status_table", endpoint=job_status_table, methods=["GET"]),
+    Route("/job_tasks_table", endpoint=job_tasks_table, methods=["GET"]),
+    Route("/task_logs", endpoint=task_logs, methods=["GET"]),
     Route(
         "/api/job_upload_template",
         endpoint=download_job_template,
diff --git a/src/aind_data_transfer_service/templates/job_status.html b/src/aind_data_transfer_service/templates/job_status.html
index 75890f7..ebdac0f 100644
--- a/src/aind_data_transfer_service/templates/job_status.html
+++ b/src/aind_data_transfer_service/templates/job_status.html
@@ -3,8 +3,12 @@
[template markup stripped in this copy: this hunk adds several lines to the page head and adjusts the lines around the "{% block title %} {% endblock %} AIND Data Transfer Service Jobs" title]
@@ -27,14 +34,38 @@
[template markup stripped in this copy: this hunk updates the job status table — the header row (Asset Name, Job ID, Status, Submit Time, Start Time, End Time, Comment) renames "End time" to "End Time" and gains a new "Tasks" column, and each row rendered inside the {% for job_status in job_status_list %} loop keeps its {{job_status.job_state}}, {{job_status.submit_time}}, {{job_status.start_time}}, {{job_status.end_time}}, and {{job_status.comment}} cells and gains task-related controls, with further markup added just before the {% if status_code != 200 %} block]