Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial code review for DLWMLS #2

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: pre-commit

on:
pull_request:
push:
branches: [main, spiros-dev]

jobs:
pre-commit:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: install pre commit
run: pip install pre-commit && pre-commit install
- name: pre-commit
uses: pre-commit/[email protected]
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
repos:
- repo: https://github.com/ambv/black
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
hooks:
- id: mypy
args: ["--ignore-missing-imports"]

41 changes: 15 additions & 26 deletions DLWMLS/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import json
import os
from pathlib import Path
import shutil
import sys
import warnings
from pathlib import Path

import torch

Expand All @@ -15,15 +15,16 @@

VERSION = 1.0


def main() -> None:
prog="DLWMLS"
prog = "DLWMLS"
parser = argparse.ArgumentParser(
prog=prog,
description="DLWMLS - MUlti-atlas region Segmentation utilizing Ensembles of registration algorithms and parameters.",
usage="""
DLWMLS v{VERSION}
Segment White Matter Lesions from the ICV-segmented (see DLICV method), LPS oriented brain image (Nifti/.nii.gz format).

Required arguments:
[-i, --in_dir] The filepath of the input directory
[-o, --out_dir] The filepath of the output directory
Expand All @@ -36,8 +37,10 @@ def main() -> None:
-o /path/to/output \
-device cpu|cuda|mps

""".format(VERSION=VERSION),
add_help=False
""".format(
VERSION=VERSION
),
add_help=False,
)

# Required Arguments
Expand All @@ -53,7 +56,7 @@ def main() -> None:
required=True,
help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).",
)

# Optional Arguments
parser.add_argument(
"-device",
Expand Down Expand Up @@ -99,7 +102,7 @@ def main() -> None:
action="store_true",
required=False,
default=False,
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed."
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed.",
)
parser.add_argument(
"--disable_tta",
Expand All @@ -109,13 +112,6 @@ def main() -> None:
help="[nnUnet Arg] Set this flag to disable test time data augmentation in the form of mirroring. "
"Faster, but less accurate inference. Not recommended.",
)
### DEPRECIATED ####
# parser.add_argument(
# "-m",
# type=str,
# required=True,
# help="Model folder path. The model folder should be named nnunet_results.",
# )
parser.add_argument(
"-d",
type=str,
Expand Down Expand Up @@ -208,25 +204,17 @@ def main() -> None:
required=False,
default=0,
help="[nnUnet Arg] If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"-num_parts 5 and use -part_id 0, 1, 2, 3 and 4. Note: You are yourself responsible to make these run on separate GPUs! "
"Use CUDA_VISIBLE_DEVICES.",
)



args = parser.parse_args()
args.f = [args.f]

if args.clear_cache:
shutil.rmtree(os.path.join(
Path(__file__).parent,
"nnunet_results"
))
shutil.rmtree(os.path.join(
Path(__file__).parent,
".cache"
))
shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results"))
shutil.rmtree(os.path.join(Path(__file__).parent, ".cache"))
if not args.i or not args.o:
print("Cache cleared and missing either -i / -o. Exiting.")
sys.exit(0)
Expand Down Expand Up @@ -263,14 +251,14 @@ def main() -> None:
% (args.d, args.d, args.c),
)


# Check if model exists. If not exist, download using HuggingFace
print(f"Using model folder: {model_folder}")
if not os.path.exists(model_folder):
# HF download model
print("DLWMLS model not found, downloading...")

from huggingface_hub import snapshot_download

local_src = Path(__file__).parent
snapshot_download(repo_id="nichart/DLWMLS", local_dir=local_src)

Expand All @@ -292,6 +280,7 @@ def main() -> None:

if args.device == "cpu":
import multiprocessing

# use half of the available threads in the system.
torch.set_num_threads(multiprocessing.cpu_count() // 2)
device = torch.device("cpu")
Expand Down

This file was deleted.

Loading