Commit 6666fe3

feat: task instances and logs from airflow (#135)

helen-m-lin authored Aug 23, 2024
1 parent 8d6d37e commit 6666fe3

Showing 7 changed files with 1,481 additions and 13 deletions.
14 changes: 120 additions & 3 deletions src/aind_data_transfer_service/models.py
123 changes: 120 additions & 3 deletions src/aind_data_transfer_service/models.py
@@ -2,9 +2,10 @@

import ast
from datetime import datetime, timedelta, timezone
-from typing import List, Optional
+from typing import List, Optional, Union

from pydantic import AwareDatetime, BaseModel, Field, field_validator
+from starlette.datastructures import QueryParams


class AirflowDagRun(BaseModel):
@@ -58,12 +59,93 @@ def validate_min_execution_date(cls, execution_date_gte: str):
        return execution_date_gte

    @classmethod
-    def from_query_params(cls, query_params: dict):
+    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        if "state" in params:
            params["state"] = ast.literal_eval(params["state"])
-        return cls(**params)
+        return cls.model_validate(params)


class AirflowDagRunRequestParameters(BaseModel):
    """Model for parameters when requesting info from dag_run endpoint"""

    dag_run_id: str = Field(..., min_length=1)

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)
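
For reference, a minimal sketch (not part of the diff) of how these parameter models are used; the run id is a made-up example, and the service builds the QueryParams from the incoming Request rather than by hand:

from starlette.datastructures import QueryParams

query = QueryParams("dag_run_id=manual__2024-08-23T00:00:00")
params = AirflowDagRunRequestParameters.from_query_params(query)
assert params.dag_run_id == "manual__2024-08-23T00:00:00"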


class AirflowTaskInstancesRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instances
    endpoint"""

    dag_run_id: str = Field(..., min_length=1)

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)


class AirflowTaskInstance(BaseModel):
    """Data model for task_instance entry when requesting info from airflow"""

    dag_id: Optional[str]
    dag_run_id: Optional[str]
    duration: Optional[Union[int, float]]
    end_date: Optional[AwareDatetime]
    execution_date: Optional[AwareDatetime]
    executor_config: Optional[str]
    hostname: Optional[str]
    map_index: Optional[int]
    max_tries: Optional[int]
    note: Optional[str]
    operator: Optional[str]
    pid: Optional[int]
    pool: Optional[str]
    pool_slots: Optional[int]
    priority_weight: Optional[int]
    queue: Optional[str]
    queued_when: Optional[AwareDatetime]
    rendered_fields: Optional[dict]
    sla_miss: Optional[dict]
    start_date: Optional[AwareDatetime]
    state: Optional[str]
    task_id: Optional[str]
    trigger: Optional[dict]
    triggerer_job: Optional[dict]
    try_number: Optional[int]
    unixname: Optional[str]


class AirflowTaskInstancesResponse(BaseModel):
    """Data model for response when requesting info from task_instances
    endpoint"""

    task_instances: List[AirflowTaskInstance]
    total_entries: int
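
A quick validation sketch (assumed payload, not from the commit). Note that the Optional fields above have no defaults, so pydantic v2 still requires every key to be present, which matches a payload where unset values come back as null; the dag id, task id, and values below are all hypothetical:

payload = {
    "task_instances": [
        {
            "dag_id": "transform_and_upload",  # hypothetical dag id
            "dag_run_id": "manual__2024-08-23T00:00:00",
            "duration": 12.5,
            "end_date": "2024-08-23T00:01:00+00:00",
            "execution_date": "2024-08-23T00:00:00+00:00",
            "executor_config": "{}",
            "hostname": "worker-1",
            "map_index": -1,
            "max_tries": 2,
            "note": None,
            "operator": "PythonOperator",
            "pid": 1234,
            "pool": "default_pool",
            "pool_slots": 1,
            "priority_weight": 10,
            "queue": "default",
            "queued_when": None,
            "rendered_fields": {},
            "sla_miss": None,
            "start_date": "2024-08-23T00:00:47+00:00",
            "state": "success",
            "task_id": "send_job_end_email",  # hypothetical task id
            "trigger": None,
            "triggerer_job": None,
            "try_number": 1,
            "unixname": "airflow",
        }
    ],
    "total_entries": 1,
}
response = AirflowTaskInstancesResponse.model_validate(payload)
assert response.task_instances[0].state == "success"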


class AirflowTaskInstanceLogsRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instance_logs
    endpoint"""

    # excluded fields are used to build the url
    dag_run_id: str = Field(..., min_length=1, exclude=True)
    task_id: str = Field(..., min_length=1, exclude=True)
    try_number: int = Field(..., ge=0, exclude=True)
    full_content: bool = True

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)
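
The fields marked exclude=True are validated but dropped from serialization, so dag_run_id, task_id, and try_number can drive the URL path while full_content is the only value forwarded to Airflow as a query parameter. A small sketch (values are made up):

params = AirflowTaskInstanceLogsRequestParameters(
    dag_run_id="manual__2024-08-23T00:00:00",  # hypothetical run id
    task_id="send_job_end_email",              # hypothetical task id
    try_number=1,
)
assert params.model_dump() == {"full_content": True}  # excluded fields omitted
assert params.try_number == 1  # still available for building the URL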


class JobStatus(BaseModel):
@@ -95,3 +177,38 @@ def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun):
    def jinja_dict(self):
        """Map model to a dictionary that jinja can render"""
        return self.model_dump(exclude_none=True)


class JobTasks(BaseModel):
    """Model for what is rendered to the user for each task."""

    job_id: Optional[str] = Field(None)
    task_id: Optional[str] = Field(None)
    try_number: Optional[int] = Field(None)
    task_state: Optional[str] = Field(None)
    priority_weight: Optional[int] = Field(None)
    map_index: Optional[int] = Field(None)
    submit_time: Optional[datetime] = Field(None)
    start_time: Optional[datetime] = Field(None)
    end_time: Optional[datetime] = Field(None)
    duration: Optional[Union[int, float]] = Field(None)
    comment: Optional[str] = Field(None)

    @classmethod
    def from_airflow_task_instance(
        cls, airflow_task_instance: AirflowTaskInstance
    ):
        """Maps the fields from an AirflowTaskInstance to this model"""
        return cls(
            job_id=airflow_task_instance.dag_run_id,
            task_id=airflow_task_instance.task_id,
            try_number=airflow_task_instance.try_number,
            task_state=airflow_task_instance.state,
            priority_weight=airflow_task_instance.priority_weight,
            map_index=airflow_task_instance.map_index,
            submit_time=airflow_task_instance.execution_date,
            start_time=airflow_task_instance.start_date,
            end_time=airflow_task_instance.end_date,
            duration=airflow_task_instance.duration,
            comment=airflow_task_instance.note,
        )
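
A short sketch (hypothetical task ids, not part of the diff) of the ordering that get_tasks_list in server.py below applies to these models, highest priority_weight first with map_index ascending as a tie-breaker:

tasks = [
    JobTasks(task_id="send_job_end_email", priority_weight=1, map_index=-1),
    JobTasks(task_id="make_modality_list", priority_weight=3, map_index=-1),
]
ordered = sorted(tasks, key=lambda t: (-t.priority_weight, t.map_index))
assert [t.task_id for t in ordered] == [
    "make_modality_list",
    "send_job_end_email",
]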
174 changes: 165 additions & 9 deletions src/aind_data_transfer_service/server.py
@@ -31,9 +31,14 @@
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.models import (
    AirflowDagRun,
    AirflowDagRunRequestParameters,
    AirflowDagRunsRequestParameters,
    AirflowDagRunsResponse,
    AirflowTaskInstanceLogsRequestParameters,
    AirflowTaskInstancesRequestParameters,
    AirflowTaskInstancesResponse,
    JobStatus,
    JobTasks,
)

template_directory = os.path.abspath(
@@ -394,15 +399,18 @@ async def get_job_status_list(request: Request):
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        get_one_job = request.query_params.get("dag_run_id") is not None
        if get_one_job:
-            dag_run_id = request.query_params["dag_run_id"]
+            params = AirflowDagRunRequestParameters.from_query_params(
+                request.query_params
+            )
+            url = f"{url}/{params.dag_run_id}"
        else:
            params = AirflowDagRunsRequestParameters.from_query_params(
                request.query_params
            )
-            params_dict = json.loads(params.model_dump_json())
+        params_dict = json.loads(params.model_dump_json())
        # Send request to Airflow to ListDagRuns or GetDagRun
        response_jobs = requests.get(
-            url=f"{url}/{dag_run_id}" if get_one_job else url,
+            url=url,
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
@@ -427,9 +435,7 @@ async def get_job_status_list(request: Request):
            ]
            message = "Retrieved job status list from airflow"
            data = {
-                "params": (
-                    {"dag_run_id": dag_run_id} if get_one_job else params_dict
-                ),
+                "params": params_dict,
                "total_entries": dag_runs.total_entries,
                "job_status_list": [
                    json.loads(j.model_dump_json()) for j in job_status_list
@@ -438,9 +444,7 @@ async def get_job_status_list(request: Request):
        else:
            message = "Error retrieving job status list from airflow"
            data = {
-                "params": (
-                    {"dag_run_id": dag_run_id} if get_one_job else params_dict
-                ),
+                "params": params_dict,
                "errors": [response_jobs.json()],
            }
    except ValidationError as e:
@@ -462,6 +466,115 @@ async def get_job_status_list(request: Request):
    )


async def get_tasks_list(request: Request):
    """Get list of task instances given a job id."""
    try:
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        params = AirflowTaskInstancesRequestParameters.from_query_params(
            request.query_params
        )
        params_dict = json.loads(params.model_dump_json())
        response_tasks = requests.get(
            url=f"{url}/{params.dag_run_id}/taskInstances",
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
            ),
        )
        status_code = response_tasks.status_code
        if response_tasks.status_code == 200:
            task_instances = AirflowTaskInstancesResponse.model_validate_json(
                json.dumps(response_tasks.json())
            )
            job_tasks_list = sorted(
                [
                    JobTasks.from_airflow_task_instance(t)
                    for t in task_instances.task_instances
                ],
                key=lambda t: (-t.priority_weight, t.map_index),
            )
            message = "Retrieved job tasks list from airflow"
            data = {
                "params": params_dict,
                "total_entries": task_instances.total_entries,
                "job_tasks_list": [
                    json.loads(t.model_dump_json()) for t in job_tasks_list
                ],
            }
        else:
            message = "Error retrieving job tasks list from airflow"
            data = {
                "params": params_dict,
                "errors": [response_tasks.json()],
            }
    except ValidationError as e:
        logging.error(e)
        status_code = 406
        message = "Error validating request parameters"
        data = {"errors": json.loads(e.json())}
    except Exception as e:
        logging.error(e)
        status_code = 500
        message = "Unable to retrieve job tasks list from airflow"
        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
    return JSONResponse(
        status_code=status_code,
        content={
            "message": message,
            "data": data,
        },
    )
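
A hypothetical call against a locally running service (the base URL and run id are assumptions for illustration, not from the commit):

import requests

resp = requests.get(
    "http://localhost:8000/api/v1/get_tasks_list",
    params={"dag_run_id": "manual__2024-08-23T00:00:00"},
)
print(resp.json()["data"]["total_entries"])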


async def get_task_logs(request: Request):
    """Get task logs given a job id, task id, and task try number."""
    try:
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        params = AirflowTaskInstanceLogsRequestParameters.from_query_params(
            request.query_params
        )
        params_dict = json.loads(params.model_dump_json())
        params_full = dict(params)
        response_logs = requests.get(
            url=(
                f"{url}/{params.dag_run_id}/taskInstances/{params.task_id}"
                f"/logs/{params.try_number}"
            ),
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
            ),
            params=params_dict,
        )
        status_code = response_logs.status_code
        if response_logs.status_code == 200:
            message = "Retrieved task logs from airflow"
            data = {"params": params_full, "logs": response_logs.text}
        else:
            message = "Error retrieving task logs from airflow"
            data = {
                "params": params_full,
                "errors": [response_logs.json()],
            }
    except ValidationError as e:
        logging.error(e)
        status_code = 406
        message = "Error validating request parameters"
        data = {"errors": json.loads(e.json())}
    except Exception as e:
        logging.error(e)
        status_code = 500
        message = "Unable to retrieve task logs from airflow"
        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
    return JSONResponse(
        status_code=status_code,
        content={
            "message": message,
            "data": data,
        },
    )
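
And a matching sketch for the logs endpoint, again with made-up values; full_content is optional and defaults to True:

import requests

resp = requests.get(
    "http://localhost:8000/api/v1/get_task_logs",
    params={
        "dag_run_id": "manual__2024-08-23T00:00:00",
        "task_id": "send_job_end_email",
        "try_number": 1,
    },
)
print(resp.json()["data"]["logs"])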


async def index(request: Request):
    """GET|POST /: form handler"""
    return templates.TemplateResponse(
@@ -500,6 +613,45 @@ async def job_status_table(request: Request):
    )


async def job_tasks_table(request: Request):
    """Get Job Tasks table given a job id"""
    response_tasks = await get_tasks_list(request)
    response_tasks_json = json.loads(response_tasks.body)
    data = response_tasks_json.get("data")
    return templates.TemplateResponse(
        name="job_tasks_table.html",
        context=(
            {
                "request": request,
                "status_code": response_tasks.status_code,
                "message": response_tasks_json.get("message"),
                "errors": data.get("errors", []),
                "total_entries": data.get("total_entries", 0),
                "job_tasks_list": data.get("job_tasks_list", []),
            }
        ),
    )


async def task_logs(request: Request):
    """Get task logs given a job id, task id, and task try number."""
    response_tasks = await get_task_logs(request)
    response_tasks_json = json.loads(response_tasks.body)
    data = response_tasks_json.get("data")
    return templates.TemplateResponse(
        name="task_logs.html",
        context=(
            {
                "request": request,
                "status_code": response_tasks.status_code,
                "message": response_tasks_json.get("message"),
                "errors": data.get("errors", []),
                "logs": data.get("logs"),
            }
        ),
    )


async def jobs(request: Request):
    """Get Job Status page with pagination"""
    default_limit = AirflowDagRunsRequestParameters.model_fields[
@@ -571,8 +723,12 @@ async def download_job_template(_: Request):
        endpoint=get_job_status_list,
        methods=["GET"],
    ),
    Route("/api/v1/get_tasks_list", endpoint=get_tasks_list, methods=["GET"]),
    Route("/api/v1/get_task_logs", endpoint=get_task_logs, methods=["GET"]),
    Route("/jobs", endpoint=jobs, methods=["GET"]),
    Route("/job_status_table", endpoint=job_status_table, methods=["GET"]),
    Route("/job_tasks_table", endpoint=job_tasks_table, methods=["GET"]),
    Route("/task_logs", endpoint=task_logs, methods=["GET"]),
    Route(
        "/api/job_upload_template",
        endpoint=download_job_template,