Skip to content

Commit

Permalink
Add check for OpenMP compile state to options (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm authored Apr 19, 2024
1 parent 05ed009 commit c3d9cfe
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
36 changes: 24 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,42 @@ message(STATUS "Using CMake version: " ${CMAKE_VERSION})
# https://github.com/pybind/pybind11/issues/4825
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)

# Enable OpenMP if requested and available
option(JAX_FINUFFT_USE_OPENMP "Enable OpenMP" ON)
if(JAX_FINUFFT_USE_OPENMP)
find_package(OpenMP)
if(OpenMP_CXX_FOUND)
message(STATUS "jax_finufft: OpenMP found")
set(FINUFFT_USE_OPENMP ON)
else()
message(STATUS "jax_finufft: OpenMP not found")
set(FINUFFT_USE_OPENMP OFF)
endif()
else()
set(FINUFFT_USE_OPENMP OFF)
endif()

# Enable CUDA if requested and available
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF)

if(JAX_FINUFFT_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found; compiling with GPU support")
message(STATUS "jax_finufft: CUDA compiler found; compiling with GPU support")
enable_language(CUDA)
set(FINUFFT_USE_CUDA ON)
else()
message(FATAL_ERROR "No CUDA compiler found! Please ensure the CUDA Toolkit "
"is installed, or set JAX_FINUFFT_USE_CUDA=OFF to disable GPU support.")
message(FATAL_ERROR "jax_finufft: No CUDA compiler found! Please ensure the "
"CUDA Toolkit is installed, or set JAX_FINUFFT_USE_CUDA=OFF to disable "
"GPU support.")
set(FINUFFT_USE_CUDA OFF)
endif()
else()
message(STATUS "GPU support was not requested")
message(STATUS "jax_finufft: GPU support was not requested")
set(FINUFFT_USE_CUDA OFF)
endif()

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
# TODO(dfm): OpenMP segfaults on my system - can we enable this somehow?
set(FINUFFT_USE_OPENMP OFF)
else()
set(FINUFFT_USE_OPENMP ON)
endif()

# Add the FINUFFT project using the vendored version
add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/vendor/finufft")

Expand All @@ -49,6 +57,10 @@ pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cp
target_link_libraries(jax_finufft_cpu PRIVATE finufft_static)
install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .)

if (FINUFFT_USE_OPENMP)
target_compile_definitions(jax_finufft_cpu PRIVATE FINUFFT_USE_OPENMP)
endif()

# Include the CUDA extensions if possible - see above for where this is set
if(FINUFFT_USE_CUDA)
enable_language(CUDA)
Expand Down
8 changes: 8 additions & 0 deletions lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);

m.def("_omp_compile_check", []() {
#ifdef FINUFFT_USE_OPENMP
return true;
#else
return false;
#endif
});

m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE);
m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE);
m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT);
Expand Down
3 changes: 2 additions & 1 deletion src/jax_finufft/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ class Opts:
gpu_maxbatchsize: int = 0

def to_finufft_opts(self):
compiled_with_omp = jax_finufft_cpu._omp_compile_check()
return jax_finufft_cpu.FinufftOpts(
modeord=self.modeord,
chkbnds=self.chkbnds,
debug=self.debug,
spread_debug=self.spread_debug,
nthreads=self.nthreads,
nthreads=self.nthreads if compiled_with_omp else 1,
fftw=self.fftw,
spread_sort=self.spread_sort,
spread_kerevalmeth=self.spread_kerevalmeth,
Expand Down

0 comments on commit c3d9cfe

Please sign in to comment.