From f84ee0c1dd9085a0f30223c827f464018908420c Mon Sep 17 00:00:00 2001 From: Helen Lin Date: Fri, 23 Aug 2024 11:44:36 -0700 Subject: [PATCH] fix: python warnings --- src/aind_data_transfer_service/models.py | 27 ++++++++++++++++++------ src/aind_data_transfer_service/server.py | 18 ++++++++-------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/aind_data_transfer_service/models.py b/src/aind_data_transfer_service/models.py index 9f77fb6..ebf9d04 100644 --- a/src/aind_data_transfer_service/models.py +++ b/src/aind_data_transfer_service/models.py @@ -5,6 +5,7 @@ from typing import List, Optional, Union from pydantic import AwareDatetime, BaseModel, Field, field_validator +from starlette.datastructures import QueryParams class AirflowDagRun(BaseModel): @@ -58,12 +59,24 @@ def validate_min_execution_date(cls, execution_date_gte: str): return execution_date_gte @classmethod - def from_query_params(cls, query_params: dict): + def from_query_params(cls, query_params: QueryParams): """Maps the query parameters to the model""" params = dict(query_params) if "state" in params: params["state"] = ast.literal_eval(params["state"]) - return cls(**params) + return cls.model_validate(params) + + +class AirflowDagRunRequestParameters(BaseModel): + """Model for parameters when requesting info from dag_run endpoint""" + + dag_run_id: str = Field(..., min_length=1) + + @classmethod + def from_query_params(cls, query_params: QueryParams): + """Maps the query parameters to the model""" + params = dict(query_params) + return cls.model_validate(params) class AirflowTaskInstancesRequestParameters(BaseModel): @@ -73,9 +86,10 @@ class AirflowTaskInstancesRequestParameters(BaseModel): dag_run_id: str = Field(..., min_length=1) @classmethod - def from_query_params(cls, query_params: dict): + def from_query_params(cls, query_params: QueryParams): """Maps the query parameters to the model""" - return cls(**query_params) + params = dict(query_params) + return cls.model_validate(params) class AirflowTaskInstance(BaseModel): @@ -128,9 +142,10 @@ class AirflowTaskInstanceLogsRequestParameters(BaseModel): full_content: bool = True @classmethod - def from_query_params(cls, query_params: dict): + def from_query_params(cls, query_params: QueryParams): """Maps the query parameters to the model""" - return cls(**query_params) + params = dict(query_params) + return cls.model_validate(params) class JobStatus(BaseModel): diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 39184e4..3a9c890 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -31,6 +31,7 @@ from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings from aind_data_transfer_service.models import ( AirflowDagRun, + AirflowDagRunRequestParameters, AirflowDagRunsRequestParameters, AirflowDagRunsResponse, AirflowTaskInstanceLogsRequestParameters, @@ -398,15 +399,18 @@ async def get_job_status_list(request: Request): url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/") get_one_job = request.query_params.get("dag_run_id") is not None if get_one_job: - dag_run_id = request.query_params["dag_run_id"] + params = AirflowDagRunRequestParameters.from_query_params( + request.query_params + ) + url = f"{url}/{params.dag_run_id}" else: params = AirflowDagRunsRequestParameters.from_query_params( request.query_params ) - params_dict = json.loads(params.model_dump_json()) + params_dict = json.loads(params.model_dump_json()) # Send request to Airflow to ListDagRuns or GetDagRun response_jobs = requests.get( - url=f"{url}/{dag_run_id}" if get_one_job else url, + url=url, auth=( os.getenv("AIND_AIRFLOW_SERVICE_USER"), os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"), @@ -431,9 +435,7 @@ async def get_job_status_list(request: Request): ] message = "Retrieved job status list from airflow" data = { - "params": ( - {"dag_run_id": dag_run_id} if get_one_job else params_dict - ), + "params": params_dict, "total_entries": dag_runs.total_entries, "job_status_list": [ json.loads(j.model_dump_json()) for j in job_status_list @@ -442,9 +444,7 @@ async def get_job_status_list(request: Request): else: message = "Error retrieving job status list from airflow" data = { - "params": ( - {"dag_run_id": dag_run_id} if get_one_job else params_dict - ), + "params": params_dict, "errors": [response_jobs.json()], } except ValidationError as e: