feat: task instances and logs from airflow #135

Merged · 8 commits · Aug 23, 2024
123 changes: 120 additions & 3 deletions src/aind_data_transfer_service/models.py
@@ -2,9 +2,10 @@

import ast
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Union

from pydantic import AwareDatetime, BaseModel, Field, field_validator
from starlette.datastructures import QueryParams


class AirflowDagRun(BaseModel):
@@ -58,12 +59,93 @@ def validate_min_execution_date(cls, execution_date_gte: str):
        return execution_date_gte

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        if "state" in params:
            params["state"] = ast.literal_eval(params["state"])
        return cls.model_validate(params)
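Note that a query string carries every value as a plain string, so a list-valued filter like state arrives as its repr and is parsed back with ast.literal_eval before validation. A minimal standalone sketch of that round trip (not part of the diff; the query values are hypothetical):

# Illustration only: QueryParams flattens values to strings, so the
# list-valued "state" filter arrives as its repr and must be parsed back.
import ast
from starlette.datastructures import QueryParams

query = QueryParams("state=['queued', 'running']")
params = dict(query)                                 # {'state': "['queued', 'running']"}
params["state"] = ast.literal_eval(params["state"])  # ['queued', 'running']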


class AirflowDagRunRequestParameters(BaseModel):
    """Model for parameters when requesting info from dag_run endpoint"""

    dag_run_id: str = Field(..., min_length=1)

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)


class AirflowTaskInstancesRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instances
    endpoint"""

    dag_run_id: str = Field(..., min_length=1)

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)


class AirflowTaskInstance(BaseModel):
    """Data model for task_instance entry when requesting info from airflow"""

    dag_id: Optional[str]
    dag_run_id: Optional[str]
    duration: Optional[Union[int, float]]
    end_date: Optional[AwareDatetime]
    execution_date: Optional[AwareDatetime]
    executor_config: Optional[str]
    hostname: Optional[str]
    map_index: Optional[int]
    max_tries: Optional[int]
    note: Optional[str]
    operator: Optional[str]
    pid: Optional[int]
    pool: Optional[str]
    pool_slots: Optional[int]
    priority_weight: Optional[int]
    queue: Optional[str]
    queued_when: Optional[AwareDatetime]
    rendered_fields: Optional[dict]
    sla_miss: Optional[dict]
    start_date: Optional[AwareDatetime]
    state: Optional[str]
    task_id: Optional[str]
    trigger: Optional[dict]
    triggerer_job: Optional[dict]
    try_number: Optional[int]
    unixname: Optional[str]


class AirflowTaskInstancesResponse(BaseModel):
    """Data model for response when requesting info from task_instances
    endpoint"""

    task_instances: List[AirflowTaskInstance]
    total_entries: int


class AirflowTaskInstanceLogsRequestParameters(BaseModel):
    """Model for parameters when requesting info from task_instance_logs
    endpoint"""

    # excluded fields are used to build the url
    dag_run_id: str = Field(..., min_length=1, exclude=True)
    task_id: str = Field(..., min_length=1, exclude=True)
    try_number: int = Field(..., ge=0, exclude=True)
    full_content: bool = True

    @classmethod
    def from_query_params(cls, query_params: QueryParams):
        """Maps the query parameters to the model"""
        params = dict(query_params)
        return cls.model_validate(params)
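Because dag_run_id, task_id, and try_number are declared with exclude=True, they drop out of the serialized parameters and are used only to build the URL path, leaving full_content to pass through as an Airflow query parameter. A small sketch of that behavior, assuming the class as defined above (the field values are hypothetical):

# Illustration only: excluded fields are omitted from model_dump().
params = AirflowTaskInstanceLogsRequestParameters(
    dag_run_id="run_1", task_id="task_a", try_number=1  # hypothetical values
)
assert params.model_dump() == {"full_content": True}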


class JobStatus(BaseModel):

@@ -95,3 +177,38 @@ def from_airflow_dag_run(cls, airflow_dag_run: AirflowDagRun):

    def jinja_dict(self):
        """Map model to a dictionary that jinja can render"""
        return self.model_dump(exclude_none=True)


class JobTasks(BaseModel):
    """Model for what is rendered to the user for each task."""

    job_id: Optional[str] = Field(None)
    task_id: Optional[str] = Field(None)
    try_number: Optional[int] = Field(None)
    task_state: Optional[str] = Field(None)
    priority_weight: Optional[int] = Field(None)
    map_index: Optional[int] = Field(None)
    submit_time: Optional[datetime] = Field(None)
    start_time: Optional[datetime] = Field(None)
    end_time: Optional[datetime] = Field(None)
    duration: Optional[Union[int, float]] = Field(None)
    comment: Optional[str] = Field(None)

    @classmethod
    def from_airflow_task_instance(
        cls, airflow_task_instance: AirflowTaskInstance
    ):
        """Maps the fields from an AirflowTaskInstance to this model"""
        return cls(
            job_id=airflow_task_instance.dag_run_id,
            task_id=airflow_task_instance.task_id,
            try_number=airflow_task_instance.try_number,
            task_state=airflow_task_instance.state,
            priority_weight=airflow_task_instance.priority_weight,
            map_index=airflow_task_instance.map_index,
            submit_time=airflow_task_instance.execution_date,
            start_time=airflow_task_instance.start_date,
            end_time=airflow_task_instance.end_date,
            duration=airflow_task_instance.duration,
            comment=airflow_task_instance.note,
        )
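A usage sketch for the mapping above, with hypothetical field values; every AirflowTaskInstance field is passed explicitly, since none of them declare defaults:

# Illustration only: map one Airflow task instance to the rendered model.
instance = AirflowTaskInstance(
    dag_id="transform_and_upload",                  # hypothetical values
    dag_run_id="manual__2024-08-23T00:00:00+00:00",
    task_id="compress_data",
    state="success",
    try_number=1,
    priority_weight=5,
    map_index=-1,
    duration=42.5,
    note=None,
    # remaining fields are required (no defaults), so set them explicitly
    end_date=None, execution_date=None, executor_config=None, hostname=None,
    max_tries=None, operator=None, pid=None, pool=None, pool_slots=None,
    queue=None, queued_when=None, rendered_fields=None, sla_miss=None,
    start_date=None, trigger=None, triggerer_job=None, unixname=None,
)
row = JobTasks.from_airflow_task_instance(instance)
assert row.job_id == "manual__2024-08-23T00:00:00+00:00"
assert row.task_state == "success"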
174 changes: 165 additions & 9 deletions src/aind_data_transfer_service/server.py
@@ -31,9 +31,14 @@
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.models import (
    AirflowDagRun,
    AirflowDagRunRequestParameters,
    AirflowDagRunsRequestParameters,
    AirflowDagRunsResponse,
    AirflowTaskInstanceLogsRequestParameters,
    AirflowTaskInstancesRequestParameters,
    AirflowTaskInstancesResponse,
    JobStatus,
    JobTasks,
)

template_directory = os.path.abspath(
@@ -394,15 +399,18 @@ async def get_job_status_list(request: Request):
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        get_one_job = request.query_params.get("dag_run_id") is not None
        if get_one_job:
            params = AirflowDagRunRequestParameters.from_query_params(
                request.query_params
            )
            url = f"{url}/{params.dag_run_id}"
        else:
            params = AirflowDagRunsRequestParameters.from_query_params(
                request.query_params
            )
        params_dict = json.loads(params.model_dump_json())
        # Send request to Airflow to ListDagRuns or GetDagRun
        response_jobs = requests.get(
            url=url,
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
@@ -427,9 +435,7 @@
            ]
            message = "Retrieved job status list from airflow"
            data = {
                "params": params_dict,
                "total_entries": dag_runs.total_entries,
                "job_status_list": [
                    json.loads(j.model_dump_json()) for j in job_status_list
@@ -438,9 +444,7 @@
        else:
            message = "Error retrieving job status list from airflow"
            data = {
                "params": params_dict,
                "errors": [response_jobs.json()],
            }
    except ValidationError as e:
@@ -462,6 +466,115 @@
    )
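With this change, a single dag_run_id query parameter switches the handler between Airflow's ListDagRuns and GetDagRun calls. A hedged sketch of exercising both paths, assuming a hypothetical local deployment:

# Illustration only: host and dag_run_id are hypothetical.
import requests

base = "http://localhost:8000/api/v1/get_job_status_list"
all_jobs = requests.get(base).json()
one_job = requests.get(
    base, params={"dag_run_id": "manual__2024-08-23T00:00:00+00:00"}
).json()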


async def get_tasks_list(request: Request):
    """Get list of task instances given a job id."""
    try:
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        params = AirflowTaskInstancesRequestParameters.from_query_params(
            request.query_params
        )
        params_dict = json.loads(params.model_dump_json())
        response_tasks = requests.get(
            url=f"{url}/{params.dag_run_id}/taskInstances",
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
            ),
        )
        status_code = response_tasks.status_code
        if response_tasks.status_code == 200:
            task_instances = AirflowTaskInstancesResponse.model_validate_json(
                json.dumps(response_tasks.json())
            )
            job_tasks_list = sorted(
                [
                    JobTasks.from_airflow_task_instance(t)
                    for t in task_instances.task_instances
                ],
                key=lambda t: (-t.priority_weight, t.map_index),
            )
            message = "Retrieved job tasks list from airflow"
            data = {
                "params": params_dict,
                "total_entries": task_instances.total_entries,
                "job_tasks_list": [
                    json.loads(t.model_dump_json()) for t in job_tasks_list
                ],
            }
        else:
            message = "Error retrieving job tasks list from airflow"
            data = {
                "params": params_dict,
                "errors": [response_tasks.json()],
            }
    except ValidationError as e:
        logging.error(e)
        status_code = 406
        message = "Error validating request parameters"
        data = {"errors": json.loads(e.json())}
    except Exception as e:
        logging.error(e)
        status_code = 500
        message = "Unable to retrieve job tasks list from airflow"
        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
    return JSONResponse(
        status_code=status_code,
        content={
            "message": message,
            "data": data,
        },
    )
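The sort key orders tasks by descending priority_weight, breaking ties by ascending map_index, so higher-priority tasks appear first. A hedged sketch of calling the new endpoint, again assuming a hypothetical local deployment:

# Illustration only: host and dag_run_id are hypothetical.
import requests

response = requests.get(
    "http://localhost:8000/api/v1/get_tasks_list",
    params={"dag_run_id": "manual__2024-08-23T00:00:00+00:00"},
)
for task in response.json()["data"]["job_tasks_list"]:
    print(task["task_id"], task["task_state"])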


async def get_task_logs(request: Request):
    """Get task logs given a job id, task id, and task try number."""
    try:
        url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
        params = AirflowTaskInstanceLogsRequestParameters.from_query_params(
            request.query_params
        )
        params_dict = json.loads(params.model_dump_json())
        params_full = dict(params)
        response_logs = requests.get(
            url=(
                f"{url}/{params.dag_run_id}/taskInstances/{params.task_id}"
                f"/logs/{params.try_number}"
            ),
            auth=(
                os.getenv("AIND_AIRFLOW_SERVICE_USER"),
                os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
            ),
            params=params_dict,
        )
        status_code = response_logs.status_code
        if response_logs.status_code == 200:
            message = "Retrieved task logs from airflow"
            data = {"params": params_full, "logs": response_logs.text}
        else:
            message = "Error retrieving task logs from airflow"
            data = {
                "params": params_full,
                "errors": [response_logs.json()],
            }
    except ValidationError as e:
        logging.error(e)
        status_code = 406
        message = "Error validating request parameters"
        data = {"errors": json.loads(e.json())}
    except Exception as e:
        logging.error(e)
        status_code = 500
        message = "Unable to retrieve task logs from airflow"
        data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
    return JSONResponse(
        status_code=status_code,
        content={
            "message": message,
            "data": data,
        },
    )
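A matching sketch for log retrieval; dag_run_id, task_id, and try_number become URL path segments on the Airflow side, while full_content passes through as a query parameter:

# Illustration only: host and identifiers are hypothetical.
import requests

response = requests.get(
    "http://localhost:8000/api/v1/get_task_logs",
    params={
        "dag_run_id": "manual__2024-08-23T00:00:00+00:00",
        "task_id": "compress_data",
        "try_number": 1,
    },
)
print(response.json()["data"]["logs"])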


async def index(request: Request):
    """GET|POST /: form handler"""
    return templates.TemplateResponse(

@@ -500,6 +613,45 @@ async def job_status_table(request: Request):
    )

async def job_tasks_table(request: Request):
    """Get Job Tasks table given a job id"""
    response_tasks = await get_tasks_list(request)
    response_tasks_json = json.loads(response_tasks.body)
    data = response_tasks_json.get("data")
    return templates.TemplateResponse(
        name="job_tasks_table.html",
        context=(
            {
                "request": request,
                "status_code": response_tasks.status_code,
                "message": response_tasks_json.get("message"),
                "errors": data.get("errors", []),
                "total_entries": data.get("total_entries", 0),
                "job_tasks_list": data.get("job_tasks_list", []),
            }
        ),
    )


async def task_logs(request: Request):
    """Get task logs given a job id, task id, and task try number."""
    response_tasks = await get_task_logs(request)
    response_tasks_json = json.loads(response_tasks.body)
    data = response_tasks_json.get("data")
    return templates.TemplateResponse(
        name="task_logs.html",
        context=(
            {
                "request": request,
                "status_code": response_tasks.status_code,
                "message": response_tasks_json.get("message"),
                "errors": data.get("errors", []),
                "logs": data.get("logs"),
            }
        ),
    )


async def jobs(request: Request):
    """Get Job Status page with pagination"""
    default_limit = AirflowDagRunsRequestParameters.model_fields[

@@ -571,8 +723,12 @@ async def download_job_template(_: Request):
        endpoint=get_job_status_list,
        methods=["GET"],
    ),
    Route("/api/v1/get_tasks_list", endpoint=get_tasks_list, methods=["GET"]),
    Route("/api/v1/get_task_logs", endpoint=get_task_logs, methods=["GET"]),
    Route("/jobs", endpoint=jobs, methods=["GET"]),
    Route("/job_status_table", endpoint=job_status_table, methods=["GET"]),
    Route("/job_tasks_table", endpoint=job_tasks_table, methods=["GET"]),
    Route("/task_logs", endpoint=task_logs, methods=["GET"]),
    Route(
        "/api/job_upload_template",
        endpoint=download_job_template,