diff --git a/src/aind_data_transfer_service/models.py b/src/aind_data_transfer_service/models.py index 36d4a2d..ebf9d04 100644 --- a/src/aind_data_transfer_service/models.py +++ b/src/aind_data_transfer_service/models.py @@ -2,9 +2,10 @@ import ast from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import List, Optional, Union from pydantic import AwareDatetime, BaseModel, Field, field_validator +from starlette.datastructures import QueryParams class AirflowDagRun(BaseModel): @@ -58,12 +59,93 @@ def validate_min_execution_date(cls, execution_date_gte: str): return execution_date_gte @classmethod - def from_query_params(cls, query_params: dict): + def from_query_params(cls, query_params: QueryParams): """Maps the query parameters to the model""" params = dict(query_params) if "state" in params: params["state"] = ast.literal_eval(params["state"]) - return cls(**params) + return cls.model_validate(params) + + +class AirflowDagRunRequestParameters(BaseModel): + """Model for parameters when requesting info from dag_run endpoint""" + + dag_run_id: str = Field(..., min_length=1) + + @classmethod + def from_query_params(cls, query_params: QueryParams): + """Maps the query parameters to the model""" + params = dict(query_params) + return cls.model_validate(params) + + +class AirflowTaskInstancesRequestParameters(BaseModel): + """Model for parameters when requesting info from task_instances + endpoint""" + + dag_run_id: str = Field(..., min_length=1) + + @classmethod + def from_query_params(cls, query_params: QueryParams): + """Maps the query parameters to the model""" + params = dict(query_params) + return cls.model_validate(params) + + +class AirflowTaskInstance(BaseModel): + """Data model for task_instance entry when requesting info from airflow""" + + dag_id: Optional[str] + dag_run_id: Optional[str] + duration: Optional[Union[int, float]] + end_date: Optional[AwareDatetime] + execution_date: Optional[AwareDatetime] + 
executor_config: Optional[str] + hostname: Optional[str] + map_index: Optional[int] + max_tries: Optional[int] + note: Optional[str] + operator: Optional[str] + pid: Optional[int] + pool: Optional[str] + pool_slots: Optional[int] + priority_weight: Optional[int] + queue: Optional[str] + queued_when: Optional[AwareDatetime] + rendered_fields: Optional[dict] + sla_miss: Optional[dict] + start_date: Optional[AwareDatetime] + state: Optional[str] + task_id: Optional[str] + trigger: Optional[dict] + triggerer_job: Optional[dict] + try_number: Optional[int] + unixname: Optional[str] + + +class AirflowTaskInstancesResponse(BaseModel): + """Data model for response when requesting info from task_instances + endpoint""" + + task_instances: List[AirflowTaskInstance] + total_entries: int + + +class AirflowTaskInstanceLogsRequestParameters(BaseModel): + """Model for parameters when requesting info from task_instance_logs + endpoint""" + + # excluded fields are used to build the url + dag_run_id: str = Field(..., min_length=1, exclude=True) + task_id: str = Field(..., min_length=1, exclude=True) + try_number: int = Field(..., ge=0, exclude=True) + full_content: bool = True + + @classmethod + def from_query_params(cls, query_params: QueryParams): + """Maps the query parameters to the model""" + params = dict(query_params) + return cls.model_validate(params) class JobStatus(BaseModel): @@ -95,3 +177,38 @@ def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun): def jinja_dict(self): """Map model to a dictionary that jinja can render""" return self.model_dump(exclude_none=True) + + +class JobTasks(BaseModel): + """Model for what is rendered to the user for each task.""" + + job_id: Optional[str] = Field(None) + task_id: Optional[str] = Field(None) + try_number: Optional[int] = Field(None) + task_state: Optional[str] = Field(None) + priority_weight: Optional[int] = Field(None) + map_index: Optional[int] = Field(None) + submit_time: Optional[datetime] = Field(None) + 
start_time: Optional[datetime] = Field(None) + end_time: Optional[datetime] = Field(None) + duration: Optional[Union[int, float]] = Field(None) + comment: Optional[str] = Field(None) + + @classmethod + def from_airflow_task_instance( + cls, airflow_task_instance: AirflowTaskInstance + ): + """Maps the fields from the AirflowTaskInstance to this model""" + return cls( + job_id=airflow_task_instance.dag_run_id, + task_id=airflow_task_instance.task_id, + try_number=airflow_task_instance.try_number, + task_state=airflow_task_instance.state, + priority_weight=airflow_task_instance.priority_weight, + map_index=airflow_task_instance.map_index, + submit_time=airflow_task_instance.execution_date, + start_time=airflow_task_instance.start_date, + end_time=airflow_task_instance.end_date, + duration=airflow_task_instance.duration, + comment=airflow_task_instance.note, + ) diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 19398b2..3a9c890 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -31,9 +31,14 @@ from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings from aind_data_transfer_service.models import ( AirflowDagRun, + AirflowDagRunRequestParameters, AirflowDagRunsRequestParameters, AirflowDagRunsResponse, + AirflowTaskInstanceLogsRequestParameters, + AirflowTaskInstancesRequestParameters, + AirflowTaskInstancesResponse, JobStatus, + JobTasks, ) template_directory = os.path.abspath( @@ -394,15 +399,18 @@ async def get_job_status_list(request: Request): url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/") get_one_job = request.query_params.get("dag_run_id") is not None if get_one_job: - dag_run_id = request.query_params["dag_run_id"] + params = AirflowDagRunRequestParameters.from_query_params( + request.query_params + ) + url = f"{url}/{params.dag_run_id}" else: params = AirflowDagRunsRequestParameters.from_query_params( request.query_params ) - 
params_dict = json.loads(params.model_dump_json()) + params_dict = json.loads(params.model_dump_json()) # Send request to Airflow to ListDagRuns or GetDagRun response_jobs = requests.get( - url=f"{url}/{dag_run_id}" if get_one_job else url, + url=url, auth=( os.getenv("AIND_AIRFLOW_SERVICE_USER"), os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"), @@ -427,9 +435,7 @@ async def get_job_status_list(request: Request): ] message = "Retrieved job status list from airflow" data = { - "params": ( - {"dag_run_id": dag_run_id} if get_one_job else params_dict - ), + "params": params_dict, "total_entries": dag_runs.total_entries, "job_status_list": [ json.loads(j.model_dump_json()) for j in job_status_list @@ -438,9 +444,7 @@ async def get_job_status_list(request: Request): else: message = "Error retrieving job status list from airflow" data = { - "params": ( - {"dag_run_id": dag_run_id} if get_one_job else params_dict - ), + "params": params_dict, "errors": [response_jobs.json()], } except ValidationError as e: @@ -462,6 +466,115 @@ async def get_job_status_list(request: Request): ) +async def get_tasks_list(request: Request): + """Get list of tasks instances given a job id.""" + try: + url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/") + params = AirflowTaskInstancesRequestParameters.from_query_params( + request.query_params + ) + params_dict = json.loads(params.model_dump_json()) + response_tasks = requests.get( + url=f"{url}/{params.dag_run_id}/taskInstances", + auth=( + os.getenv("AIND_AIRFLOW_SERVICE_USER"), + os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"), + ), + ) + status_code = response_tasks.status_code + if response_tasks.status_code == 200: + task_instances = AirflowTaskInstancesResponse.model_validate_json( + json.dumps(response_tasks.json()) + ) + job_tasks_list = sorted( + [ + JobTasks.from_airflow_task_instance(t) + for t in task_instances.task_instances + ], + key=lambda t: (-t.priority_weight, t.map_index), + ) + message = "Retrieved job tasks list from 
airflow" + data = { + "params": params_dict, + "total_entries": task_instances.total_entries, + "job_tasks_list": [ + json.loads(t.model_dump_json()) for t in job_tasks_list + ], + } + else: + message = "Error retrieving job tasks list from airflow" + data = { + "params": params_dict, + "errors": [response_tasks.json()], + } + except ValidationError as e: + logging.error(e) + status_code = 406 + message = "Error validating request parameters" + data = {"errors": json.loads(e.json())} + except Exception as e: + logging.error(e) + status_code = 500 + message = "Unable to retrieve job tasks list from airflow" + data = {"errors": [f"{e.__class__.__name__}{e.args}"]} + return JSONResponse( + status_code=status_code, + content={ + "message": message, + "data": data, + }, + ) + + +async def get_task_logs(request: Request): + """Get task logs given a job id, task id, and task try number.""" + try: + url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/") + params = AirflowTaskInstanceLogsRequestParameters.from_query_params( + request.query_params + ) + params_dict = json.loads(params.model_dump_json()) + params_full = dict(params) + response_logs = requests.get( + url=( + f"{url}/{params.dag_run_id}/taskInstances/{params.task_id}" + f"/logs/{params.try_number}" + ), + auth=( + os.getenv("AIND_AIRFLOW_SERVICE_USER"), + os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"), + ), + params=params_dict, + ) + status_code = response_logs.status_code + if response_logs.status_code == 200: + message = "Retrieved task logs from airflow" + data = {"params": params_full, "logs": response_logs.text} + else: + message = "Error retrieving task logs from airflow" + data = { + "params": params_full, + "errors": [response_logs.json()], + } + except ValidationError as e: + logging.error(e) + status_code = 406 + message = "Error validating request parameters" + data = {"errors": json.loads(e.json())} + except Exception as e: + logging.error(e) + status_code = 500 + message = "Unable to retrieve 
task logs from airflow" + data = {"errors": [f"{e.__class__.__name__}{e.args}"]} + return JSONResponse( + status_code=status_code, + content={ + "message": message, + "data": data, + }, + ) + + async def index(request: Request): """GET|POST /: form handler""" return templates.TemplateResponse( @@ -500,6 +613,45 @@ async def job_status_table(request: Request): ) +async def job_tasks_table(request: Request): + """Get Job Tasks table given a job id""" + response_tasks = await get_tasks_list(request) + response_tasks_json = json.loads(response_tasks.body) + data = response_tasks_json.get("data") + return templates.TemplateResponse( + name="job_tasks_table.html", + context=( + { + "request": request, + "status_code": response_tasks.status_code, + "message": response_tasks_json.get("message"), + "errors": data.get("errors", []), + "total_entries": data.get("total_entries", 0), + "job_tasks_list": data.get("job_tasks_list", []), + } + ), + ) + + +async def task_logs(request: Request): + """Get task logs given a job id, task id, and task try number.""" + response_tasks = await get_task_logs(request) + response_tasks_json = json.loads(response_tasks.body) + data = response_tasks_json.get("data") + return templates.TemplateResponse( + name="task_logs.html", + context=( + { + "request": request, + "status_code": response_tasks.status_code, + "message": response_tasks_json.get("message"), + "errors": data.get("errors", []), + "logs": data.get("logs"), + } + ), + ) + + async def jobs(request: Request): """Get Job Status page with pagination""" default_limit = AirflowDagRunsRequestParameters.model_fields[ @@ -571,8 +723,12 @@ async def download_job_template(_: Request): endpoint=get_job_status_list, methods=["GET"], ), + Route("/api/v1/get_tasks_list", endpoint=get_tasks_list, methods=["GET"]), + Route("/api/v1/get_task_logs", endpoint=get_task_logs, methods=["GET"]), Route("/jobs", endpoint=jobs, methods=["GET"]), Route("/job_status_table", endpoint=job_status_table, 
methods=["GET"]), + Route("/job_tasks_table", endpoint=job_tasks_table, methods=["GET"]), + Route("/task_logs", endpoint=task_logs, methods=["GET"]), Route( "/api/job_upload_template", endpoint=download_job_template, diff --git a/src/aind_data_transfer_service/templates/job_status_table.html b/src/aind_data_transfer_service/templates/job_status_table.html index b283c7f..b32f905 100644 --- a/src/aind_data_transfer_service/templates/job_status_table.html +++ b/src/aind_data_transfer_service/templates/job_status_table.html @@ -4,8 +4,14 @@ + + @@ -17,8 +23,9 @@ Status Submit Time Start Time - End time + End Time Comment + Tasks {% for job_status in job_status_list %} @@ -33,9 +40,40 @@ {{job_status.start_time}} {{job_status.end_time}} {{job_status.comment}} + + + {% endfor %} + + {% if status_code != 200 %}