From 03d42a8c8f0e4a68949be14292479065ddb21ba7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 10 Dec 2024 21:54:41 +0100 Subject: [PATCH] Fix issues with inference code --- synapse_net/inference/util.py | 10 ++++++---- synapse_net/tools/cli.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 1ad3a73..ea92f29 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: return {"tile": tile, "halo": halo} if torch.cuda.is_available(): - # We always use the same default halo. - halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8 + # The default halo size. + halo = {"x": 64, "y": 64, "z": 16} # Determine the GPU RAM and derive a suitable tiling. vram = torch.cuda.get_device_properties(0).total_memory / 1e9 @@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: tile = {"x": 512, "y": 512, "z": 64} elif vram >= 20: tile = {"x": 352, "y": 352, "z": 48} + elif vram >= 10: + tile = {"x": 256, "y": 256, "z": 32} + halo = {"x": 64, "y": 64, "z": 8} # Choose a smaller halo in z. else: - # TODO determine tilings for smaller VRAM - raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.") + raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.") print(f"Determined tile size: {tile}") tiling = {"tile": tile, "halo": halo} diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py index 11caeb7..609bb0e 100644 --- a/synapse_net/tools/cli.py +++ b/synapse_net/tools/cli.py @@ -124,11 +124,11 @@ def segmentation_cli(): ) parser.add_argument( "--tile_shape", type=int, nargs=3, - help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient." + help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient." ) parser.add_argument( "--halo", type=int, nargs=3, - help="The halo for prediction. Increase the halo to minimize boundary artifacts." + help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts." ) parser.add_argument( "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."