Skip to content

Commit

Permalink
feat: add wrapper models for form generation
Browse files Browse the repository at this point in the history
  • Loading branch information
helen-m-lin committed Nov 26, 2024
1 parent e524892 commit 9b1da6a
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 38 deletions.
51 changes: 50 additions & 1 deletion src/aind_data_transfer_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Union

from pydantic import AwareDatetime, BaseModel, Field, field_validator
from aind_data_transfer_models.core import (
BasicUploadJobConfigs,
ModalityConfigs,
SubmitJobRequest,
)
from pydantic import (
AwareDatetime,
BaseModel,
ConfigDict,
Field,
field_validator,
)
from pydantic.json_schema import SkipJsonSchema
from starlette.datastructures import QueryParams


Expand Down Expand Up @@ -212,3 +224,40 @@ def from_airflow_task_instance(
duration=airflow_task_instance.duration,
comment=airflow_task_instance.note,
)


class ModalityConfigsForm(ModalityConfigs):
"""Configurations for a modality type"""

model_config = ConfigDict(extra="forbid")
job_settings: Optional[str] = Field(
default=None,
description=(
"Configs to pass into modality compression job. "
"Must be serialized as json string."
),
)
# remove from json schema:
slurm_settings: SkipJsonSchema[None] = None


class BasicUploadJobConfigsForm(BasicUploadJobConfigs):
"""Configuration for a basic upload job"""

model_config = ConfigDict(extra="forbid")
modalities: List[ModalityConfigsForm]
# remove from json schema:
user_email: SkipJsonSchema[None] = None
email_notification_types: SkipJsonSchema[None] = None
input_data_mount: SkipJsonSchema[None] = None
process_capsule_id: SkipJsonSchema[None] = None
trigger_capsule_configs: SkipJsonSchema[None] = None


class SubmitJobRequestForm(SubmitJobRequest):
"""Form to submit a list of jobs to aind-data-transfer-service"""

model_config = ConfigDict(extra="forbid")
upload_jobs: List[BasicUploadJobConfigsForm]
# remove from json schema:
job_type: SkipJsonSchema[None] = None
92 changes: 55 additions & 37 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@
from pathlib import PurePosixPath

import requests
from aind_data_transfer_models.core import SubmitJobRequest, BasicUploadJobConfigs
from aind_data_transfer_models.core import (
BasicUploadJobConfigs,
SubmitJobRequest,
)
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from openpyxl import load_workbook
from pydantic import SecretStr, ValidationError
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route

from aind_data_transfer_service import OPEN_DATA_BUCKET_NAME
Expand All @@ -37,11 +42,11 @@
AirflowTaskInstanceLogsRequestParameters,
AirflowTaskInstancesRequestParameters,
AirflowTaskInstancesResponse,
BasicUploadJobConfigsForm,
JobStatus,
JobTasks,
SubmitJobRequestForm,
)
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware import Middleware

template_directory = os.path.abspath(
os.path.join(os.path.dirname(__file__), "templates")
Expand Down Expand Up @@ -72,6 +77,7 @@
Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
]


async def validate_csv(request: Request):
"""Validate a csv or xlsx file. Return parsed contents as json."""
async with request.form() as form:
Expand Down Expand Up @@ -177,44 +183,47 @@ async def get_json_schema_for_model(request: Request):
"""Get the JSON schema for models from aind-data-transfer-models."""
# GET /api/v1/models/{model_name}/schema
model_name = request.path_params.get("model_name")
match model_name:
case "BasicUploadJobConfigs":
model = BasicUploadJobConfigs
case "SubmitJobRequest":
model = SubmitJobRequest
case _:
return JSONResponse(
status_code=404,
content={
"message": f"Schema not found for {model_name}",
},
)
# simplified versions for form generation
if model_name == "BasicUploadJobConfigsForm":
model = BasicUploadJobConfigsForm
elif model_name == "SubmitJobRequestForm":
model = SubmitJobRequestForm
# full versions (from aind-data-transfer-models)
elif model_name == "BasicUploadJobConfigs":
model = BasicUploadJobConfigs
elif model_name == "SubmitJobRequest":
model = SubmitJobRequest
else:
return JSONResponse(
status_code=404,
content={
"message": f"Schema not found for {model_name}",
},
)
json_schema = model.model_json_schema()
return JSONResponse(
status_code=200,
content=json_schema,
)



async def validate_json_for_model(request: Request):
"""Validate a SubmitJobRequest raw json. Returns validated job request,
or errors if request is invalid."""
"""Validate raw json against aind-data-transfer-models. Returns validated
json or errors if request is invalid."""
# POST /api/v1/models/{model_name}/validate
model_name = request.path_params.get("model_name")
content = await request.json()
match model_name:
case "BasicUploadJobConfigs":
model = BasicUploadJobConfigs
case "SubmitJobRequest":
model = SubmitJobRequest
case _:
return JSONResponse(
status_code=404,
content={
"message": f"Model not found for {model_name}",
},
)
if model_name == "BasicUploadJobConfigs":
model = BasicUploadJobConfigs
elif model_name == "SubmitJobRequest":
model = SubmitJobRequest
else:
return JSONResponse(
status_code=404,
content={
"message": f"Model not found for {model_name}",
},
)
try:
validated_model = model.model_validate_json(json.dumps(content))
validated_content = json.loads(
Expand All @@ -223,10 +232,10 @@ async def validate_json_for_model(request: Request):
return JSONResponse(
status_code=200,
content={
"message": "Valid job request",
"message": "Valid model",
"data": {
"job_request": content,
"validated_job_request": validated_content,
"model_json": content,
"validated_model_json": validated_content,
},
},
)
Expand All @@ -236,7 +245,7 @@ async def validate_json_for_model(request: Request):
content={
"message": "There were validation errors",
"data": {
"job_request": content,
"model_json": content,
"errors": e.json(),
},
},
Expand All @@ -247,12 +256,13 @@ async def validate_json_for_model(request: Request):
content={
"message": "There was an internal server error",
"data": {
"job_request": content,
"model_json": content,
"errors": str(e.args),
},
},
)


async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
content = await request.json()
Expand Down Expand Up @@ -810,8 +820,16 @@ async def download_job_template(_: Request):
"/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
),
Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]),
Route("/api/v1/models/{model_name:str}/schema", endpoint=get_json_schema_for_model, methods=["GET"]),
Route("/api/v1/models/{model_name:str}/validate", endpoint=validate_json_for_model, methods=["POST"]),
Route(
"/api/v1/models/{model_name:str}/schema",
endpoint=get_json_schema_for_model,
methods=["GET"],
),
Route(
"/api/v1/models/{model_name:str}/validate",
endpoint=validate_json_for_model,
methods=["POST"],
),
Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]),
Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]),
Route(
Expand Down

0 comments on commit 9b1da6a

Please sign in to comment.