diff --git a/src/aind_data_transfer_service/models.py b/src/aind_data_transfer_service/models.py index ebf9d04..816edc1 100644 --- a/src/aind_data_transfer_service/models.py +++ b/src/aind_data_transfer_service/models.py @@ -4,7 +4,19 @@ from datetime import datetime, timedelta, timezone from typing import List, Optional, Union -from pydantic import AwareDatetime, BaseModel, Field, field_validator +from aind_data_transfer_models.core import ( + BasicUploadJobConfigs, + ModalityConfigs, + SubmitJobRequest, +) +from pydantic import ( + AwareDatetime, + BaseModel, + ConfigDict, + Field, + field_validator, +) +from pydantic.json_schema import SkipJsonSchema from starlette.datastructures import QueryParams @@ -212,3 +224,40 @@ def from_airflow_task_instance( duration=airflow_task_instance.duration, comment=airflow_task_instance.note, ) + + +class ModalityConfigsForm(ModalityConfigs): + """Configurations for a modality type""" + + model_config = ConfigDict(extra="forbid") + job_settings: Optional[str] = Field( + default=None, + description=( + "Configs to pass into modality compression job. " + "Must be serialized as json string." + ), + ) + # remove from json schema: + slurm_settings: SkipJsonSchema[None] = None + + +class BasicUploadJobConfigsForm(BasicUploadJobConfigs): + """Configuration for a basic upload job""" + + model_config = ConfigDict(extra="forbid") + modalities: List[ModalityConfigsForm] + # remove from json schema: + user_email: SkipJsonSchema[None] = None + email_notification_types: SkipJsonSchema[None] = None + input_data_mount: SkipJsonSchema[None] = None + process_capsule_id: SkipJsonSchema[None] = None + trigger_capsule_configs: SkipJsonSchema[None] = None + + +class SubmitJobRequestForm(SubmitJobRequest): + """Form to submit a list of jobs to aind-data-transfer-service""" + + model_config = ConfigDict(extra="forbid") + upload_jobs: List[BasicUploadJobConfigsForm] + # remove from json schema: + job_type: SkipJsonSchema[None] = None diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index b9270d7..3f3783b 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -9,13 +9,18 @@ from pathlib import PurePosixPath import requests -from aind_data_transfer_models.core import SubmitJobRequest, BasicUploadJobConfigs +from aind_data_transfer_models.core import ( + BasicUploadJobConfigs, + SubmitJobRequest, +) from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from fastapi.templating import Jinja2Templates from openpyxl import load_workbook from pydantic import SecretStr, ValidationError from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route from aind_data_transfer_service import OPEN_DATA_BUCKET_NAME @@ -37,11 +42,11 @@ AirflowTaskInstanceLogsRequestParameters, AirflowTaskInstancesRequestParameters, AirflowTaskInstancesResponse, + BasicUploadJobConfigsForm, JobStatus, JobTasks, + SubmitJobRequestForm, ) -from starlette.middleware.cors import CORSMiddleware -from starlette.middleware import Middleware template_directory = os.path.abspath( os.path.join(os.path.dirname(__file__), "templates") @@ -72,6 +77,7 @@ Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) ] + async def validate_csv(request: Request): """Validate a csv or xlsx file. Return parsed contents as json.""" async with request.form() as form: @@ -177,44 +183,47 @@ async def get_json_schema_for_model(request: Request): """Get the JSON schema for models from aind-data-transfer-models.""" # GET /api/v1/models/{model_name}/schema model_name = request.path_params.get("model_name") - match model_name: - case "BasicUploadJobConfigs": - model = BasicUploadJobConfigs - case "SubmitJobRequest": - model = SubmitJobRequest - case _: - return JSONResponse( - status_code=404, - content={ - "message": f"Schema not found for {model_name}", - }, - ) + # simplified versions for form generation + if model_name == "BasicUploadJobConfigsForm": + model = BasicUploadJobConfigsForm + elif model_name == "SubmitJobRequestForm": + model = SubmitJobRequestForm + # full versions (from aind-data-transfer-models) + elif model_name == "BasicUploadJobConfigs": + model = BasicUploadJobConfigs + elif model_name == "SubmitJobRequest": + model = SubmitJobRequest + else: + return JSONResponse( + status_code=404, + content={ + "message": f"Schema not found for {model_name}", + }, + ) json_schema = model.model_json_schema() return JSONResponse( status_code=200, content=json_schema, ) - async def validate_json_for_model(request: Request): - """Validate a SubmitJobRequest raw json. Returns validated job request, - or errors if request is invalid.""" + """Validate raw json against aind-data-transfer-models. Returns validated + json or errors if request is invalid.""" # POST /api/v1/models/{model_name}/validate model_name = request.path_params.get("model_name") content = await request.json() - match model_name: - case "BasicUploadJobConfigs": - model = BasicUploadJobConfigs - case "SubmitJobRequest": - model = SubmitJobRequest - case _: - return JSONResponse( - status_code=404, - content={ - "message": f"Model not found for {model_name}", - }, - ) + if model_name == "BasicUploadJobConfigs": + model = BasicUploadJobConfigs + elif model_name == "SubmitJobRequest": + model = SubmitJobRequest + else: + return JSONResponse( + status_code=404, + content={ + "message": f"Model not found for {model_name}", + }, + ) try: validated_model = model.model_validate_json(json.dumps(content)) validated_content = json.loads( @@ -223,10 +232,10 @@ async def validate_json_for_model(request: Request): return JSONResponse( status_code=200, content={ - "message": "Valid job request", + "message": "Valid model", "data": { - "job_request": content, - "validated_job_request": validated_content, + "model_json": content, + "validated_model_json": validated_content, }, }, ) @@ -236,7 +245,7 @@ async def validate_json_for_model(request: Request): content={ "message": "There were validation errors", "data": { - "job_request": content, + "model_json": content, "errors": e.json(), }, }, @@ -247,12 +256,13 @@ async def validate_json_for_model(request: Request): content={ "message": "There was an internal server error", "data": { - "job_request": content, + "model_json": content, "errors": str(e.args), }, }, ) + async def submit_jobs(request: Request): """Post BasicJobConfigs raw json to hpc server to process.""" content = await request.json() @@ -810,8 +820,16 @@ async def download_job_template(_: Request): "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"] ), Route("/api/submit_hpc_jobs", endpoint=submit_hpc_jobs, methods=["POST"]), - Route("/api/v1/models/{model_name:str}/schema", endpoint=get_json_schema_for_model, methods=["GET"]), - Route("/api/v1/models/{model_name:str}/validate", endpoint=validate_json_for_model, methods=["POST"]), + Route( + "/api/v1/models/{model_name:str}/schema", + endpoint=get_json_schema_for_model, + methods=["GET"], + ), + Route( + "/api/v1/models/{model_name:str}/validate", + endpoint=validate_json_for_model, + methods=["POST"], + ), Route("/api/v1/validate_csv", endpoint=validate_csv, methods=["POST"]), Route("/api/v1/submit_jobs", endpoint=submit_jobs, methods=["POST"]), Route(