Merge pull request #198 from AllenNeuralDynamics/release-v1.9.0
Release v1.9.0
jtyoung84 authored Dec 23, 2024
2 parents c4dd457 + 44aafa9 commit 856491f
Showing 8 changed files with 493 additions and 185 deletions.
8 changes: 4 additions & 4 deletions docs/source/UserGuide.rst
@@ -84,10 +84,6 @@ portal can accessed at
- s3_bucket: As default, data will be uploaded to a default bucket
in S3 managed by AIND. Please reach out to the Scientific
Computing department if you wish to upload to a different bucket.
- metadata_dir_force: We will automatically pull subject and
procedures data for a mouse. By setting this ``True``, we will
overwrite any data in the ``metadata_dir`` folder with data
acquired automatically from our service
- force_cloud_sync: We run a check to verify whether there is
already a data asset with this name saved in our S3 bucket. If
this field is set to ``True``, we will sync the data to the
@@ -151,6 +147,10 @@ endpoint:
)
post_request_content = json.loads(submit_request.model_dump_json(exclude_none=True))
# Optionally validate the submit_request before submitting
validate_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/validate_json", json=post_request_content)
print(validate_job_response.status_code)
print(validate_job_response.json())
# Uncomment the following to submit the request
# submit_job_response = requests.post(url="http://aind-data-transfer-service/api/v1/submit_jobs", json=post_request_content)
# print(submit_job_response.status_code)
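
When validation fails, the new endpoint responds with status 406 and carries the pydantic errors as a JSON string under data["errors"] (see the validate_json handler added to server.py in this release). A minimal client-side sketch of surfacing those errors; the payload here is hypothetical and intentionally incomplete:

import json
import requests

# Hypothetical, intentionally incomplete payload used only to trigger validation errors
bad_request_content = {"upload_jobs": []}
response = requests.post(
    url="http://aind-data-transfer-service/api/v1/validate_json",
    json=bad_request_content,
)
if response.status_code == 406:
    # data["errors"] is the JSON string produced by pydantic's ValidationError.json()
    for error in json.loads(response.json()["data"]["errors"]):
        print(error.get("loc"), error.get("msg"))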
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
'pydantic>=2.7,<2.9',
'pydantic-settings>=2.0',
'aind-data-schema>=1.0.0,<2.0',
'aind-data-transfer-models==0.14.1'
'aind-data-transfer-models==0.15.0'
]

[project.optional-dependencies]
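The service reports the installed aind-data-transfer-models version back in its API responses (the server imports __version__ directly, as shown in the server.py diff below), so a quick post-upgrade sanity check might look like this sketch; the exact pin comes from this release's pyproject.toml:

# Sketch: confirm the locally installed models package matches the 0.15.0 pin
from aind_data_transfer_models import __version__ as models_version

assert models_version == "0.15.0", f"unexpected version: {models_version}"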
2 changes: 1 addition & 1 deletion src/aind_data_transfer_service/__init__.py
@@ -1,7 +1,7 @@
"""Init package"""
import os

__version__ = "1.8.0"
__version__ = "1.9.0"

# Global constants
OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "open")
12 changes: 0 additions & 12 deletions src/aind_data_transfer_service/models.py
@@ -67,18 +67,6 @@ def from_query_params(cls, query_params: QueryParams):
return cls.model_validate(params)


class AirflowDagRunRequestParameters(BaseModel):
"""Model for parameters when requesting info from dag_run endpoint"""

dag_run_id: str = Field(..., min_length=1)

@classmethod
def from_query_params(cls, query_params: QueryParams):
"""Maps the query parameters to the model"""
params = dict(query_params)
return cls.model_validate(params)


class AirflowTaskInstancesRequestParameters(BaseModel):
"""Model for parameters when requesting info from task_instances
endpoint"""
165 changes: 117 additions & 48 deletions src/aind_data_transfer_service/server.py
@@ -4,14 +4,19 @@
import io
import json
import os
from asyncio import sleep
from asyncio import gather, sleep
from pathlib import PurePosixPath
from typing import Optional

import requests
from aind_data_transfer_models import (
__version__ as aind_data_transfer_models_version,
)
from aind_data_transfer_models.core import SubmitJobRequest
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from httpx import AsyncClient
from openpyxl import load_workbook
from pydantic import SecretStr, ValidationError
from starlette.applications import Starlette
@@ -30,8 +35,6 @@
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
from aind_data_transfer_service.models import (
AirflowDagRun,
AirflowDagRunRequestParameters,
AirflowDagRunsRequestParameters,
AirflowDagRunsResponse,
AirflowTaskInstanceLogsRequestParameters,
@@ -173,6 +176,58 @@ async def validate_csv_legacy(request: Request):
)


async def validate_json(request: Request):
"""Validate raw json against aind-data-transfer-models. Returns validated
json or errors if request is invalid."""
logger.info("Received request to validate json")
content = await request.json()
try:
validated_model = SubmitJobRequest.model_validate_json(
json.dumps(content)
)
validated_content = json.loads(
validated_model.model_dump_json(warnings=False, exclude_none=True)
)
logger.info("Valid model detected")
return JSONResponse(
status_code=200,
content={
"message": "Valid model",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"validated_model_json": validated_content,
},
},
)
except ValidationError as e:
logger.warning(f"There were validation errors processing {content}")
return JSONResponse(
status_code=406,
content={
"message": "There were validation errors",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": e.json(),
},
},
)
except Exception as e:
logger.exception("Internal Server Error.")
return JSONResponse(
status_code=500,
content={
"message": "There was an internal server error",
"data": {
"version": aind_data_transfer_models_version,
"model_json": content,
"errors": str(e.args),
},
},
)


async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
logger.info("Received request to submit jobs")
@@ -422,59 +477,72 @@ async def submit_hpc_jobs(request: Request):  # noqa: C901

async def get_job_status_list(request: Request):
"""Get status of jobs with default pagination of limit=25 and offset=0."""
# TODO: Use httpx async client

async def fetch_jobs(
client: AsyncClient, url: str, params: Optional[dict]
):
"""Helper method to fetch jobs using httpx async client"""
response = await client.get(url, params=params)
response.raise_for_status()
return response.json()

try:
url = os.getenv("AIND_AIRFLOW_SERVICE_JOBS_URL", "").strip("/")
get_one_job = request.query_params.get("dag_run_id") is not None
if get_one_job:
params = AirflowDagRunRequestParameters.from_query_params(
request.query_params
)
url = f"{url}/{params.dag_run_id}"
else:
params = AirflowDagRunsRequestParameters.from_query_params(
request.query_params
)
params_dict = json.loads(params.model_dump_json())
# Send request to Airflow to ListDagRuns or GetDagRun
response_jobs = requests.get(
url=url,
get_all_jobs = request.query_params.get("get_all_jobs") is not None
params = AirflowDagRunsRequestParameters.from_query_params(
request.query_params
)
params_dict = json.loads(params.model_dump_json(exclude_none=True))
# Send request to Airflow to ListDagRuns
async with AsyncClient(
auth=(
os.getenv("AIND_AIRFLOW_SERVICE_USER"),
os.getenv("AIND_AIRFLOW_SERVICE_PASSWORD"),
),
params=None if get_one_job else params_dict,
)
status_code = response_jobs.status_code
if response_jobs.status_code == 200:
if get_one_job:
dag_run = AirflowDagRun.model_validate_json(
json.dumps(response_jobs.json())
)
dag_runs = AirflowDagRunsResponse(
dag_runs=[dag_run], total_entries=1
)
else:
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(response_jobs.json())
)
)
) as client:
# Fetch initial jobs
response_jobs = await fetch_jobs(
client=client,
url=url,
params=params_dict,
)
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(response_jobs)
)
job_status_list = [
JobStatus.from_airflow_dag_run(d) for d in dag_runs.dag_runs
]
message = "Retrieved job status list from airflow"
data = {
"params": params_dict,
"total_entries": dag_runs.total_entries,
"job_status_list": [
json.loads(j.model_dump_json()) for j in job_status_list
],
}
else:
message = "Error retrieving job status list from airflow"
data = {
"params": params_dict,
"errors": [response_jobs.json()],
}
total_entries = dag_runs.total_entries
if get_all_jobs:
# Fetch remaining jobs concurrently
tasks = []
offset = params_dict["offset"] + params_dict["limit"]
while offset < total_entries:
params = {**params_dict, "limit": 100, "offset": offset}
tasks.append(
fetch_jobs(client=client, url=url, params=params)
)
offset += 100
batches = await gather(*tasks)
for batch in batches:
dag_runs = AirflowDagRunsResponse.model_validate_json(
json.dumps(batch)
)
job_status_list.extend(
[
JobStatus.from_airflow_dag_run(d)
for d in dag_runs.dag_runs
]
)
status_code = 200
message = "Retrieved job status list from airflow"
data = {
"params": params_dict,
"total_entries": total_entries,
"job_status_list": [
json.loads(j.model_dump_json()) for j in job_status_list
],
}
except ValidationError as e:
logger.warning(
f"There was a validation error process job_status list: {e}"
@@ -746,6 +814,7 @@ async def download_job_template(_: Request):
"/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
),
Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]),
Route("/api/v1/validate_json", endpoint=validate_json, methods=["POST"]),
Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]),
Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]),
Route(
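The get_job_status_list change above replaces the blocking requests call with an httpx AsyncClient and, when get_all_jobs is set, fans out the remaining pages with asyncio.gather after reading total_entries from the first page. A stripped-down sketch of that fan-out pattern against a hypothetical paginated endpoint; the URL and a configurable page size are illustrative, not the service's own configuration:

import asyncio
from httpx import AsyncClient


async def fetch_page(client: AsyncClient, url: str, params: dict) -> dict:
    """Fetch one page and return its JSON body."""
    response = await client.get(url, params=params)
    response.raise_for_status()
    return response.json()


async def fetch_all(url: str, limit: int = 100) -> list:
    """Fetch the first page, then gather the remaining pages concurrently."""
    async with AsyncClient() as client:
        first = await fetch_page(client, url, {"limit": limit, "offset": 0})
        total = first["total_entries"]
        items = list(first["dag_runs"])
        tasks = [
            fetch_page(client, url, {"limit": limit, "offset": offset})
            for offset in range(limit, total, limit)
        ]
        for page in await asyncio.gather(*tasks):
            items.extend(page["dag_runs"])
    return items


# Example usage (hypothetical URL):
# dag_runs = asyncio.run(fetch_all("http://airflow.example/api/v1/dags/some_dag/dagRuns"))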
