Skip to content

Commit

Permalink
Updates and refactoring -- ready for production
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 30, 2024
1 parent 63936c7 commit e39be17
Showing 1 changed file with 60 additions and 20 deletions.
80 changes: 60 additions & 20 deletions pfms/controllers/pfmsController_inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from fastapi import APIRouter, Query, Request
from fastapi import File, UploadFile
from fastapi import APIRouter, Query, Request, UploadFile

from fastapi.encoders import jsonable_encoder
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Callable, Any

import asyncio

from torch import parse_ir
from models import iresponse
import os
from datetime import datetime
Expand All @@ -26,6 +28,7 @@

from spleenseg import spleenseg as spleen
from spleenseg import splparser as psr
from starlette.responses import FileResponse

LOG = logger.debug

Expand All @@ -50,16 +53,39 @@ def noop():
return {"status": True}


def showAll() -> iresponse.modelsAvailable:
resp: iresponse.modelsAvailable = iresponse.modelsAvailable()
resp.models = ["spleensegmentation.pth"]
def modelLocation_get(modelName) -> Path:
modelDir = settings.modelMeta.location / Path(modelName)
fileLocation: Path = modelDir / "model.pth"
return fileLocation


async def model_save(
remote: UploadFile, modelName: str
) -> iresponse.ModelUploadResponse:
modelUpload: iresponse.ModelUploadResponse = iresponse.ModelUploadResponse()
modelLocation: Path = modelLocation_get(modelName)
modelLocation.parent.mkdir(parents=True, exist_ok=True)
with modelLocation.open("wb") as buffer:
contents: bytes = await remote.read()
buffer.write(contents)
modelUpload.status = True
modelUpload.message = f"model {modelName} saved successfully"
modelUpload.location = modelLocation
return modelUpload


def models_list() -> iresponse.ModelsAvailable:
resp: iresponse.ModelsAvailable = iresponse.ModelsAvailable()
resp.models = os.listdir(str(settings.modelMeta.location))
return resp


def IOpaths_create() -> tuple[Path, Path]:
sessionUUID: UUID = uuid.uuid4()
baseParentPath: Path = Path(
Path(tempfile.mkdtemp()) / "imagesTs" / str(sessionUUID)
now: datetime = datetime.now()
nowstr: str = now.strftime("%Y-%m-%d-%H-%M-%S")
baseParentPath: Path = (
settings.analysisMeta.location / f"{nowstr}-{str(sessionUUID)}"
)
inputDir: Path = baseParentPath / "input"
outputDir: Path = baseParentPath / "output"
Expand All @@ -68,25 +94,39 @@ def IOpaths_create() -> tuple[Path, Path]:
return inputDir, outputDir


def NIfTIvol_save(remote: File, inputDir: Path) -> Path:
savePath: Path = inputDir / "input.nii.gz"
def NIfTIvol_saveInput(remote: UploadFile, inputDir: Path) -> Path:
saveDir: Path = inputDir / "imagesTs"
saveDir.mkdir(parents=True, exist_ok=True)
savePath: Path = saveDir / "input.nii.gz"
with savePath.open("wb") as buffer:
shutil.copyfileobj(remote.file, buffer)
return savePath


def NIfTIvol_infer(remote: File) -> Path:
inputDir: Path
outputDir: Path
inputDir, outputDir = IOpaths_create()
NIfTIvol_save(remote, inputDir)
options: Namespace = psr.parser_interpret(psr.parser_setup(""))
def NIfTIvol_infer(remote: UploadFile, inputDir: Path, outputDir: Path) -> Path:
NIfTIvol_saveInput(remote, inputDir)
options: Namespace = psr.parser_interpret(psr.parser_setup(""), asModule=True)
options.mode = "inference"
options.device = "cuda:0"
options.logTransformVols = True
options.inputdir = str(inputDir)
options.outputdir = str(outputDir)
spleen.main(options, inputDir, outputDir)
return outputDir
outputFile: Path = outputDir / "inference" / "output.nii.gz"
return outputFile


def model_prep(modelID: str, targetDir: Path):
modelLocation: Path = modelLocation_get(modelID)
shutil.copy(modelLocation, targetDir)

async def inferenceOnNIfTI(uploadFile: File) -> iresponse.inferenceResponseNIFTI:
iresp: iresponse.inferenceResponseNIFTI = iresponse.inferenceResponseNIFTI()
result: Path = NIfTIvol_infer(uploadFile)
return iresp

async def inferenceOnNIfTI(uploadFile: UploadFile, modelID: str = "") -> FileResponse:
inputDir: Path
outputDir: Path
inputDir, outputDir = IOpaths_create()
model_prep(modelID, inputDir)
result: Path = NIfTIvol_infer(uploadFile, inputDir, outputDir)
return FileResponse(
path=result, media_type="application/octet-stream", filename="output.nii.gz"
)

0 comments on commit e39be17

Please sign in to comment.