prevent race condition when JIT in multiprocess (#312)
Co-authored-by: Ruilong Li <[email protected]>
liruilong940607 and Ruilong Li authored Aug 3, 2024
1 parent 40fd2fe commit 8a0e500
Showing 2 changed files with 77 additions and 6 deletions.
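Context for the fix: when several processes (for example, one per GPU in a distributed training run) JIT-compile the same extension at once, they all contend for the lock file that torch creates in the shared build directory, and a process that loses the cleanup race can crash even though the build itself succeeded. A minimal sketch of that scenario; the extension name and source file ("my_ext", "my_ext.cu") are placeholders, not part of this commit:

    # Hypothetical repro sketch: N processes JIT-load the same extension at once.
    import torch.multiprocessing as mp
    from torch.utils.cpp_extension import load

    def worker(rank: int):
        # Every rank resolves the same default build directory, so all of
        # them contend for the same lock file inside load().
        _C = load(name="my_ext", sources=["my_ext.cu"])  # placeholder sources
        print(f"rank {rank}: loaded {_C.__name__}")

    if __name__ == "__main__":
        # Before this commit, one rank could die with an OSError raised
        # while another rank removed the shared lock file.
        mp.spawn(worker, nprocs=4, join=True)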
42 changes: 39 additions & 3 deletions gsplat/cuda/_backend.py
@@ -5,7 +5,11 @@
from subprocess import DEVNULL, call

from rich.console import Console
-from torch.utils.cpp_extension import _get_build_directory, load
+from torch.utils.cpp_extension import (
+    _get_build_directory,
+    _import_module_from_library,
+    load,
+)

PATH = os.path.dirname(os.path.abspath(__file__))
NO_FAST_MATH = os.getenv("NO_FAST_MATH", "0") == "1"
@@ -16,6 +20,35 @@
os.environ["MAX_JOBS"] = "10"


+def load_extension(
+    name,
+    sources,
+    extra_cflags=None,
+    extra_cuda_cflags=None,
+    extra_ldflags=None,
+    extra_include_paths=None,
+    build_directory=None,
+):
+    """Load a JIT compiled extension."""
+    # If the JIT build happens concurrently in multiple processes, a race
+    # condition can occur when removing the lock file at:
+    # https://github.com/pytorch/pytorch/blob/e3513fb2af7951ddf725d8c5b6f6d962a053c9da/torch/utils/cpp_extension.py#L1736
+    # That failure is harmless, so we catch the exception and fall back to
+    # importing the already-built module.
+    try:
+        return load(
+            name,
+            sources,
+            extra_cflags=extra_cflags,
+            extra_cuda_cflags=extra_cuda_cflags,
+            extra_ldflags=extra_ldflags,
+            extra_include_paths=extra_include_paths,
+            build_directory=build_directory,
+        )
+    except OSError:
+        # The module should already be compiled; import it directly from
+        # the build directory.
+        return _import_module_from_library(name, build_directory, True)
+
+
def cuda_toolkit_available():
"""Check if the nvcc is avaiable on the machine."""
try:
@@ -74,12 +107,13 @@ def cuda_toolkit_version():
):
    # If the build exists, we assume the extension has been built
    # and we can load it.
-    _C = load(
+    _C = load_extension(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_include_paths=extra_include_paths,
+        build_directory=build_dir,
    )
else:
    # Build from scratch. Remove the build directory just to be safe: pytorch jit might get stuck
@@ -89,13 +123,15 @@ def cuda_toolkit_version():
f"[bold yellow]gsplat: Setting up CUDA with MAX_JOBS={os.environ['MAX_JOBS']} (This may take a few minutes the first time)",
spinner="bouncingBall",
):
-    _C = load(
+    _C = load_extension(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_include_paths=extra_include_paths,
+        build_directory=build_dir,
    )
+
else:
    Console().print(
        "[yellow]gsplat: No CUDA toolkit found. gsplat will be disabled.[/yellow]"
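Why catching OSError is sufficient: torch's JIT build serializes concurrent builders with a FileBaton, a plain lock file in the build directory, and releasing the baton removes that file with os.remove() (the line linked in the comment above). If another process has already deleted it, os.remove() raises FileNotFoundError, a subclass of OSError. A simplified stand-in for the torch-side mechanism, not torch's actual code:

    import os

    class FileBatonSketch:
        """Simplified stand-in for torch.utils.file_baton.FileBaton."""

        def __init__(self, lock_file_path: str):
            self.lock_file_path = lock_file_path

        def release(self):
            # The unguarded remove is the race: when two processes release
            # the same baton, the second one raises FileNotFoundError.
            os.remove(self.lock_file_path)

The gsplat-side assumption, per the comment in the diff, is that when this failure surfaces the module has already been compiled, so the except branch imports it directly via _import_module_from_library(name, build_directory, True).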
41 changes: 38 additions & 3 deletions gsplat/cuda_legacy/_backend.py
@@ -5,7 +5,11 @@
from subprocess import DEVNULL, call

from rich.console import Console
-from torch.utils.cpp_extension import _get_build_directory, load
+from torch.utils.cpp_extension import (
+    _get_build_directory,
+    _import_module_from_library,
+    load,
+)

PATH = os.path.dirname(os.path.abspath(__file__))
MAX_JOBS = os.getenv("MAX_JOBS")
@@ -15,6 +19,35 @@
os.environ["MAX_JOBS"] = "10"


+def load_extension(
+    name,
+    sources,
+    extra_cflags=None,
+    extra_cuda_cflags=None,
+    extra_ldflags=None,
+    extra_include_paths=None,
+    build_directory=None,
+):
+    """Load a JIT compiled extension."""
+    # If the JIT build happens concurrently in multiple processes, a race
+    # condition can occur when removing the lock file at:
+    # https://github.com/pytorch/pytorch/blob/e3513fb2af7951ddf725d8c5b6f6d962a053c9da/torch/utils/cpp_extension.py#L1736
+    # That failure is harmless, so we catch the exception and fall back to
+    # importing the already-built module.
+    try:
+        return load(
+            name,
+            sources,
+            extra_cflags=extra_cflags,
+            extra_cuda_cflags=extra_cuda_cflags,
+            extra_ldflags=extra_ldflags,
+            extra_include_paths=extra_include_paths,
+            build_directory=build_directory,
+        )
+    except OSError:
+        # The module should already be compiled; import it directly from
+        # the build directory.
+        return _import_module_from_library(name, build_directory, True)
+
+
def cuda_toolkit_available():
"""Check if the nvcc is avaiable on the machine."""
try:
@@ -71,12 +104,13 @@ def cuda_toolkit_version():
    # If the build exists, we assume the extension has been built
    # and we can load it.

-    _C = load(
+    _C = load_extension(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_include_paths=extra_include_paths,
+        build_directory=build_dir,
    )
else:
    # Build from scratch. Remove the build directory just to be safe: pytorch jit might get stuck
@@ -86,12 +120,13 @@ def cuda_toolkit_version():
"[bold yellow]gsplat (legacy): Setting up CUDA (This may take a few minutes the first time)",
spinner="bouncingBall",
):
-    _C = load(
+    _C = load_extension(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_include_paths=extra_include_paths,
+        build_directory=build_dir,
    )
else:
    Console().print(
