From c3d9cfea136bc7cf89bacb892c0b4d1e3e638fe9 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 19 Apr 2024 21:16:37 +0100 Subject: [PATCH] Add check for OpenMP compile state to options (#82) --- CMakeLists.txt | 36 ++++++++++++++++++++++++------------ lib/jax_finufft_cpu.cc | 8 ++++++++ src/jax_finufft/options.py | 3 ++- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 04bd3c6..9691099 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,34 +9,42 @@ message(STATUS "Using CMake version: " ${CMAKE_VERSION}) # https://github.com/pybind/pybind11/issues/4825 set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) +# Enable OpenMP if requested and available +option(JAX_FINUFFT_USE_OPENMP "Enable OpenMP" ON) +if(JAX_FINUFFT_USE_OPENMP) + find_package(OpenMP) + if(OpenMP_CXX_FOUND) + message(STATUS "jax_finufft: OpenMP found") + set(FINUFFT_USE_OPENMP ON) + else() + message(STATUS "jax_finufft: OpenMP not found") + set(FINUFFT_USE_OPENMP OFF) + endif() +else() + set(FINUFFT_USE_OPENMP OFF) +endif() + # Enable CUDA if requested and available option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF) - if(JAX_FINUFFT_USE_CUDA) include(CheckLanguage) check_language(CUDA) if(CMAKE_CUDA_COMPILER) - message(STATUS "CUDA compiler found; compiling with GPU support") + message(STATUS "jax_finufft: CUDA compiler found; compiling with GPU support") enable_language(CUDA) set(FINUFFT_USE_CUDA ON) else() - message(FATAL_ERROR "No CUDA compiler found! Please ensure the CUDA Toolkit " - "is installed, or set JAX_FINUFFT_USE_CUDA=OFF to disable GPU support.") + message(FATAL_ERROR "jax_finufft: No CUDA compiler found! Please ensure the " + "CUDA Toolkit is installed, or set JAX_FINUFFT_USE_CUDA=OFF to disable " + "GPU support.") set(FINUFFT_USE_CUDA OFF) endif() else() - message(STATUS "GPU support was not requested") + message(STATUS "jax_finufft: GPU support was not requested") set(FINUFFT_USE_CUDA OFF) endif() -if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - # TODO(dfm): OpenMP segfaults on my system - can we enable this somehow? - set(FINUFFT_USE_OPENMP OFF) -else() - set(FINUFFT_USE_OPENMP ON) -endif() - # Add the FINUFFT project using the vendored version add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/vendor/finufft") @@ -49,6 +57,10 @@ pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cp target_link_libraries(jax_finufft_cpu PRIVATE finufft_static) install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .) +if (FINUFFT_USE_OPENMP) + target_compile_definitions(jax_finufft_cpu PRIVATE FINUFFT_USE_OPENMP) +endif() + # Include the CUDA extensions if possible - see above for where this is set if(FINUFFT_USE_CUDA) enable_language(CUDA) diff --git a/lib/jax_finufft_cpu.cc b/lib/jax_finufft_cpu.cc index f6e1a49..6156afc 100644 --- a/lib/jax_finufft_cpu.cc +++ b/lib/jax_finufft_cpu.cc @@ -126,6 +126,14 @@ PYBIND11_MODULE(jax_finufft_cpu, m) { m.def("build_descriptorf", &build_descriptor); m.def("build_descriptor", &build_descriptor); + m.def("_omp_compile_check", []() { +#ifdef FINUFFT_USE_OPENMP + return true; +#else + return false; +#endif + }); + m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE); m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE); m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT); diff --git a/src/jax_finufft/options.py b/src/jax_finufft/options.py index 90be3ed..c2f38e5 100644 --- a/src/jax_finufft/options.py +++ b/src/jax_finufft/options.py @@ -70,12 +70,13 @@ class Opts: gpu_maxbatchsize: int = 0 def to_finufft_opts(self): + compiled_with_omp = jax_finufft_cpu._omp_compile_check() return jax_finufft_cpu.FinufftOpts( modeord=self.modeord, chkbnds=self.chkbnds, debug=self.debug, spread_debug=self.spread_debug, - nthreads=self.nthreads, + nthreads=self.nthreads if compiled_with_omp else 1, fftw=self.fftw, spread_sort=self.spread_sort, spread_kerevalmeth=self.spread_kerevalmeth,