From e39be1776126bd08fd7c177476a080672b21155a Mon Sep 17 00:00:00 2001 From: Rudolph Pienaar Date: Tue, 30 Apr 2024 13:56:54 -0400 Subject: [PATCH] Updates and refactoring -- ready for production --- pfms/controllers/pfmsController_inference.py | 80 +++++++++++++++----- 1 file changed, 60 insertions(+), 20 deletions(-) diff --git a/pfms/controllers/pfmsController_inference.py b/pfms/controllers/pfmsController_inference.py index 6f626ef..85e42e1 100644 --- a/pfms/controllers/pfmsController_inference.py +++ b/pfms/controllers/pfmsController_inference.py @@ -1,11 +1,13 @@ -from fastapi import APIRouter, Query, Request -from fastapi import File, UploadFile +from fastapi import APIRouter, Query, Request, UploadFile + from fastapi.encoders import jsonable_encoder from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel, Field from typing import Optional, List, Dict, Callable, Any import asyncio + +from torch import parse_ir from models import iresponse import os from datetime import datetime @@ -26,6 +28,7 @@ from spleenseg import spleenseg as spleen from spleenseg import splparser as psr +from starlette.responses import FileResponse LOG = logger.debug @@ -50,16 +53,39 @@ def noop(): return {"status": True} -def showAll() -> iresponse.modelsAvailable: - resp: iresponse.modelsAvailable = iresponse.modelsAvailable() - resp.models = ["spleensegmentation.pth"] +def modelLocation_get(modelName) -> Path: + modelDir = settings.modelMeta.location / Path(modelName) + fileLocation: Path = modelDir / "model.pth" + return fileLocation + + +async def model_save( + remote: UploadFile, modelName: str +) -> iresponse.ModelUploadResponse: + modelUpload: iresponse.ModelUploadResponse = iresponse.ModelUploadResponse() + modelLocation: Path = modelLocation_get(modelName) + modelLocation.parent.mkdir(parents=True, exist_ok=True) + with modelLocation.open("wb") as buffer: + contents: bytes = await remote.read() + buffer.write(contents) + modelUpload.status = True + modelUpload.message = f"model {modelName} saved successfully" + modelUpload.location = modelLocation + return modelUpload + + +def models_list() -> iresponse.ModelsAvailable: + resp: iresponse.ModelsAvailable = iresponse.ModelsAvailable() + resp.models = os.listdir(str(settings.modelMeta.location)) return resp def IOpaths_create() -> tuple[Path, Path]: sessionUUID: UUID = uuid.uuid4() - baseParentPath: Path = Path( - Path(tempfile.mkdtemp()) / "imagesTs" / str(sessionUUID) + now: datetime = datetime.now() + nowstr: str = now.strftime("%Y-%m-%d-%H-%M-%S") + baseParentPath: Path = ( + settings.analysisMeta.location / f"{nowstr}-{str(sessionUUID)}" ) inputDir: Path = baseParentPath / "input" outputDir: Path = baseParentPath / "output" @@ -68,25 +94,39 @@ def IOpaths_create() -> tuple[Path, Path]: return inputDir, outputDir -def NIfTIvol_save(remote: File, inputDir: Path) -> Path: - savePath: Path = inputDir / "input.nii.gz" +def NIfTIvol_saveInput(remote: UploadFile, inputDir: Path) -> Path: + saveDir: Path = inputDir / "imagesTs" + saveDir.mkdir(parents=True, exist_ok=True) + savePath: Path = saveDir / "input.nii.gz" with savePath.open("wb") as buffer: shutil.copyfileobj(remote.file, buffer) return savePath -def NIfTIvol_infer(remote: File) -> Path: - inputDir: Path - outputDir: Path - inputDir, outputDir = IOpaths_create() - NIfTIvol_save(remote, inputDir) - options: Namespace = psr.parser_interpret(psr.parser_setup("")) +def NIfTIvol_infer(remote: UploadFile, inputDir: Path, outputDir: Path) -> Path: + NIfTIvol_saveInput(remote, inputDir) + options: Namespace = psr.parser_interpret(psr.parser_setup(""), asModule=True) options.mode = "inference" + options.device = "cuda:0" + options.logTransformVols = True + options.inputdir = str(inputDir) + options.outputdir = str(outputDir) spleen.main(options, inputDir, outputDir) - return outputDir + outputFile: Path = outputDir / "inference" / "output.nii.gz" + return outputFile + +def model_prep(modelID: str, targetDir: Path): + modelLocation: Path = modelLocation_get(modelID) + shutil.copy(modelLocation, targetDir) -async def inferenceOnNIfTI(uploadFile: File) -> iresponse.inferenceResponseNIFTI: - iresp: iresponse.inferenceResponseNIFTI = iresponse.inferenceResponseNIFTI() - result: Path = NIfTIvol_infer(uploadFile) - return iresp + +async def inferenceOnNIfTI(uploadFile: UploadFile, modelID: str = "") -> FileResponse: + inputDir: Path + outputDir: Path + inputDir, outputDir = IOpaths_create() + model_prep(modelID, inputDir) + result: Path = NIfTIvol_infer(uploadFile, inputDir, outputDir) + return FileResponse( + path=result, media_type="application/octet-stream", filename="output.nii.gz" + )