Skip to content

Commit

Permalink
fix: python warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
helen-m-lin committed Aug 23, 2024
1 parent 27ae10f commit f84ee0c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
27 changes: 21 additions & 6 deletions src/aind_data_transfer_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Optional, Union

from pydantic import AwareDatetime, BaseModel, Field, field_validator
from starlette.datastructures import QueryParams


class AirflowDagRun(BaseModel):
Expand Down Expand Up @@ -58,12 +59,24 @@ def validate_min_execution_date(cls, execution_date_gte: str):
return execution_date_gte

@classmethod
def from_query_params(cls, query_params: dict):
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
params = dict(query_params)
if "state" in params:
params["state"] = ast.literal_eval(params["state"])
return cls(**params)
return cls.model_validate(params)


class AirflowDagRunRequestParameters(BaseModel):
"""Model for parameters when requesting info from dag_run endpoint"""

dag_run_id: str = Field(..., min_length=1)

@classmethod
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
params = dict(query_params)
return cls.model_validate(params)


class AirflowTaskInstancesRequestParameters(BaseModel):
Expand All @@ -73,9 +86,10 @@ class AirflowTaskInstancesRequestParameters(BaseModel):
dag_run_id: str = Field(..., min_length=1)

@classmethod
def from_query_params(cls, query_params: dict):
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
return cls(**query_params)
params = dict(query_params)
return cls.model_validate(params)


class AirflowTaskInstance(BaseModel):
Expand Down Expand Up @@ -128,9 +142,10 @@ class AirflowTaskInstanceLogsRequestParameters(BaseModel):
full_content: bool = True

@classmethod
def from_query_params(cls, query_params: dict):
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
return cls(**query_params)
params = dict(query_params)
return cls.model_validate(params)


class JobStatus(BaseModel):
Expand Down
18 changes: 9 additions & 9 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.models import (
AirflowDagRun,
AirflowDagRunRequestParameters,
AirflowDagRunsRequestParameters,
AirflowDagRunsResponse,
AirflowTaskInstanceLogsRequestParameters,
Expand Down Expand Up @@ -398,15 +399,18 @@ async def get_job_status_list(request: Request):
url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
get_one_job = request.query_params.get("dag_run_id") is not None
if get_one_job:
dag_run_id = request.query_params["dag_run_id"]
params = AirflowDagRunRequestParameters.from_query_params(
request.query_params
)
url = f"{url}/{params.dag_run_id}"
else:
params = AirflowDagRunsRequestParameters.from_query_params(
request.query_params
)
params_dict = json.loads(params.model_dump_json())
params_dict = json.loads(params.model_dump_json())
# Send request to Airflow to ListDagRuns or GetDagRun
response_jobs = requests.get(
url=f"{url}/{dag_run_id}" if get_one_job else url,
url=url,
auth=(
os.getenv("AIND_AIRFLOW_SERVICE_USER"),
os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
Expand All @@ -431,9 +435,7 @@ async def get_job_status_list(request: Request):
]
message = "Retrieved job status list from airflow"
data = {
"params": (
{"dag_run_id": dag_run_id} if get_one_job else params_dict
),
"params": params_dict,
"total_entries": dag_runs.total_entries,
"job_status_list": [
json.loads(j.model_dump_json()) for j in job_status_list
Expand All @@ -442,9 +444,7 @@ async def get_job_status_list(request: Request):
else:
message = "Error retrieving job status list from airflow"
data = {
"params": (
{"dag_run_id": dag_run_id} if get_one_job else params_dict
),
"params": params_dict,
"errors": [response_jobs.json()],
}
except ValidationError as e:
Expand Down

0 comments on commit f84ee0c

Please sign in to comment.