Skip to content

Commit

Permalink
Fix issues with inference code
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 10, 2024
1 parent 9ceb256 commit 03d42a8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 6 additions & 4 deletions synapse_net/inference/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
return {"tile": tile, "halo": halo}

if torch.cuda.is_available():
# We always use the same default halo.
halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8
# The default halo size.
halo = {"x": 64, "y": 64, "z": 16}

# Determine the GPU RAM and derive a suitable tiling.
vram = torch.cuda.get_device_properties(0).total_memory / 1e9
Expand All @@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
tile = {"x": 512, "y": 512, "z": 64}
elif vram >= 20:
tile = {"x": 352, "y": 352, "z": 48}
elif vram >= 10:
tile = {"x": 256, "y": 256, "z": 32}
halo = {"x": 64, "y": 64, "z": 8} # Choose a smaller halo in z.
else:
# TODO determine tilings for smaller VRAM
raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.")
raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")

print(f"Determined tile size: {tile}")
tiling = {"tile": tile, "halo": halo}
Expand Down
4 changes: 2 additions & 2 deletions synapse_net/tools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ def segmentation_cli():
)
parser.add_argument(
"--tile_shape", type=int, nargs=3,
help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient."
)
parser.add_argument(
"--halo", type=int, nargs=3,
help="The halo for prediction. Increase the halo to minimize boundary artifacts."
help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts."
)
parser.add_argument(
"--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
Expand Down

0 comments on commit 03d42a8

Please sign in to comment.