From 816b4d58907653df0a9ad3ff62ffd704f5048793 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 22 Apr 2024 20:10:24 +0100 Subject: [PATCH] Replacing pybind11 with nanobind (#83) * Replacing pybind11 with nanobind * removing extra unused namespace * build: nanobind doesn't enable LTO by default, no need to globally disable * Including @lgarrison's suggestion --------- Co-authored-by: Lehman Garrison --- CMakeLists.txt | 26 ++++++----- lib/jax_finufft_cpu.cc | 85 +++++++++++++++-------------------- lib/jax_finufft_gpu.cc | 68 ++++++++++++---------------- lib/kernel_helpers.h | 5 --- lib/nanobind_kernel_helpers.h | 28 ++++++++++++ lib/pybind11_kernel_helpers.h | 28 ------------ pyproject.toml | 2 +- src/jax_finufft/options.py | 56 ++++++++++++----------- 8 files changed, 139 insertions(+), 159 deletions(-) create mode 100644 lib/nanobind_kernel_helpers.h delete mode 100644 lib/pybind11_kernel_helpers.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9691099..1b12ee5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,18 +1,16 @@ -cmake_minimum_required(VERSION 3.15...3.26) +cmake_minimum_required(VERSION 3.15...3.27) project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX) message(STATUS "Using CMake version: " ${CMAKE_VERSION}) # for cuda-gdb and verbose PTXAS output # set(CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} "-g -G -Xptxas -v") -# Workaround for LTO applied incorrectly to CUDA fatbin -# https://github.com/pybind/pybind11/issues/4825 -set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) - # Enable OpenMP if requested and available option(JAX_FINUFFT_USE_OPENMP "Enable OpenMP" ON) + if(JAX_FINUFFT_USE_OPENMP) find_package(OpenMP) + if(OpenMP_CXX_FOUND) message(STATUS "jax_finufft: OpenMP found") set(FINUFFT_USE_OPENMP ON) @@ -26,6 +24,7 @@ endif() # Enable CUDA if requested and available option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF) + if(JAX_FINUFFT_USE_CUDA) include(CheckLanguage) check_language(CUDA) @@ -48,16 +47,21 @@ endif() # Add the FINUFFT project using the vendored version add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/vendor/finufft") -# Find pybind11 -set(PYBIND11_NEWPYTHON ON) -find_package(pybind11 CONFIG REQUIRED) +# Find Python and nanobind +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) +find_package(nanobind CONFIG REQUIRED) + +if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() # Build the CPU XLA bindings -pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc) +nanobind_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc) target_link_libraries(jax_finufft_cpu PRIVATE finufft_static) install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .) -if (FINUFFT_USE_OPENMP) +if(FINUFFT_USE_OPENMP) target_compile_definitions(jax_finufft_cpu PRIVATE FINUFFT_USE_OPENMP) endif() @@ -75,7 +79,7 @@ if(FINUFFT_USE_CUDA) ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/cuda_samples ) - pybind11_add_module(jax_finufft_gpu + nanobind_add_module(jax_finufft_gpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc ${CMAKE_CURRENT_LIST_DIR}/lib/cufinufft_wrapper.cc ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu) diff --git a/lib/jax_finufft_cpu.cc b/lib/jax_finufft_cpu.cc index 6156afc..a89d8ba 100644 --- a/lib/jax_finufft_cpu.cc +++ b/lib/jax_finufft_cpu.cc @@ -1,14 +1,14 @@ // This file defines the Python interface to the XLA custom call implemented on the CPU. -// It is exposed as a standard pybind11 module defining "capsule" objects containing our +// It is exposed as a standard nanobind module defining "capsule" objects containing our // method. For simplicity, we export a separate capsule for each supported dtype. #include "jax_finufft_cpu.h" -#include "pybind11_kernel_helpers.h" +#include "nanobind_kernel_helpers.h" using namespace jax_finufft; using namespace jax_finufft::cpu; -namespace py = pybind11; +namespace nb = nanobind; namespace { @@ -68,41 +68,14 @@ void nufft2(void *out, void **in) { } template -py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j, +nb::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j, int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, finufft_opts opts) { return pack_descriptor( descriptor{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts}); } -template -finufft_opts *build_opts(bool modeord, bool chkbnds, int debug, int spread_debug, bool showwarn, - int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth, - bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize, - int spread_nthr_atomic, int spread_max_sp_size) { - finufft_opts *opts = new finufft_opts; - default_opts(opts); - - opts->modeord = int(modeord); - opts->chkbnds = int(chkbnds); - opts->debug = debug; - opts->spread_debug = spread_debug; - opts->showwarn = int(showwarn); - opts->nthreads = nthreads; - opts->fftw = fftw; - opts->spread_sort = spread_sort; - opts->spread_kerevalmeth = int(spread_kerevalmeth); - opts->spread_kerpad = int(spread_kerpad); - opts->upsampfac = upsampfac; - opts->spread_thread = int(spread_thread); - opts->maxbatchsize = maxbatchsize; - opts->spread_nthr_atomic = spread_nthr_atomic; - opts->spread_max_sp_size = spread_max_sp_size; - - return opts; -} - -pybind11::dict Registrations() { - pybind11::dict dict; +nb::dict Registrations() { + nb::dict dict; dict["nufft1d1f"] = encapsulate_function(nufft1<1, float>); dict["nufft1d2f"] = encapsulate_function(nufft2<1, float>); @@ -121,7 +94,7 @@ pybind11::dict Registrations() { return dict; } -PYBIND11_MODULE(jax_finufft_cpu, m) { +NB_MODULE(jax_finufft_cpu, m) { m.def("registrations", &Registrations); m.def("build_descriptorf", &build_descriptor); m.def("build_descriptor", &build_descriptor); @@ -134,20 +107,36 @@ PYBIND11_MODULE(jax_finufft_cpu, m) { #endif }); - m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE); - m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE); - m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT); - m.attr("FFTW_EXHAUSTIVE") = py::int_(FFTW_EXHAUSTIVE); - m.attr("FFTW_WISDOM_ONLY") = py::int_(FFTW_WISDOM_ONLY); - - py::class_ opts(m, "FinufftOpts"); - opts.def(py::init(&build_opts), py::arg("modeord") = false, py::arg("chkbnds") = true, - py::arg("debug") = 0, py::arg("spread_debug") = 0, py::arg("showwarn") = false, - py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE), - py::arg("spread_sort") = 2, py::arg("spread_kerevalmeth") = true, - py::arg("spread_kerpad") = true, py::arg("upsampfac") = 0.0, - py::arg("spread_thread") = 0, py::arg("maxbatchsize") = 0, - py::arg("spread_nthr_atomic") = -1, py::arg("spread_max_sp_size") = 0); + m.attr("FFTW_ESTIMATE") = nb::int_(FFTW_ESTIMATE); + m.attr("FFTW_MEASURE") = nb::int_(FFTW_MEASURE); + m.attr("FFTW_PATIENT") = nb::int_(FFTW_PATIENT); + m.attr("FFTW_EXHAUSTIVE") = nb::int_(FFTW_EXHAUSTIVE); + m.attr("FFTW_WISDOM_ONLY") = nb::int_(FFTW_WISDOM_ONLY); + + nb::class_ opts(m, "FinufftOpts"); + opts.def("__init__", + [](finufft_opts *self, bool modeord, bool chkbnds, int debug, int spread_debug, + bool showwarn, int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth, + bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize, + int spread_nthr_atomic, int spread_max_sp_size) { + new (self) finufft_opts; + default_opts(self); + self->modeord = int(modeord); + self->chkbnds = int(chkbnds); + self->debug = debug; + self->spread_debug = spread_debug; + self->showwarn = int(showwarn); + self->nthreads = nthreads; + self->fftw = fftw; + self->spread_sort = spread_sort; + self->spread_kerevalmeth = int(spread_kerevalmeth); + self->spread_kerpad = int(spread_kerpad); + self->upsampfac = upsampfac; + self->spread_thread = int(spread_thread); + self->maxbatchsize = maxbatchsize; + self->spread_nthr_atomic = spread_nthr_atomic; + self->spread_max_sp_size = spread_max_sp_size; + }); } } // namespace diff --git a/lib/jax_finufft_gpu.cc b/lib/jax_finufft_gpu.cc index 24a72be..56da6f8 100644 --- a/lib/jax_finufft_gpu.cc +++ b/lib/jax_finufft_gpu.cc @@ -1,51 +1,26 @@ // This file defines the Python interface to the XLA custom call implemented on the CPU. -// It is exposed as a standard pybind11 module defining "capsule" objects containing our +// It is exposed as a standard nanobind module defining "capsule" objects containing our // method. For simplicity, we export a separate capsule for each supported dtype. #include "cufinufft_wrapper.h" #include "kernels.h" -#include "pybind11_kernel_helpers.h" +#include "nanobind_kernel_helpers.h" using namespace jax_finufft; using namespace jax_finufft::gpu; -namespace py = pybind11; +namespace nb = nanobind; namespace { template -py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j, +nb::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j, int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, cufinufft_opts opts) { return pack_descriptor( descriptor{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts}); } -template -cufinufft_opts *build_opts(double upsampfac, int gpu_method, bool gpu_sort, int gpu_binsizex, - int gpu_binsizey, int gpu_binsizez, int gpu_obinsizex, - int gpu_obinsizey, int gpu_obinsizez, int gpu_maxsubprobsize, - bool gpu_kerevalmeth, int gpu_spreadinterponly, int gpu_maxbatchsize) { - cufinufft_opts *opts = new cufinufft_opts; - default_opts(opts); - - opts->upsampfac = upsampfac; - opts->gpu_method = gpu_method; - opts->gpu_sort = int(gpu_sort); - opts->gpu_binsizex = gpu_binsizex; - opts->gpu_binsizey = gpu_binsizey; - opts->gpu_binsizez = gpu_binsizez; - opts->gpu_obinsizex = gpu_obinsizex; - opts->gpu_obinsizey = gpu_obinsizey; - opts->gpu_obinsizez = gpu_obinsizez; - opts->gpu_maxsubprobsize = gpu_maxsubprobsize; - opts->gpu_kerevalmeth = gpu_kerevalmeth; - opts->gpu_spreadinterponly = gpu_spreadinterponly; - opts->gpu_maxbatchsize = gpu_maxbatchsize; - - return opts; -} - -pybind11::dict Registrations() { - pybind11::dict dict; +nb::dict Registrations() { + nb::dict dict; // TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"? // dict["nufft1d1f"] = encapsulate_function(nufft1d1f); @@ -65,18 +40,33 @@ pybind11::dict Registrations() { return dict; } -PYBIND11_MODULE(jax_finufft_gpu, m) { +NB_MODULE(jax_finufft_gpu, m) { m.def("registrations", &Registrations); m.def("build_descriptorf", &build_descriptor); m.def("build_descriptor", &build_descriptor); - py::class_ opts(m, "CufinufftOpts"); - opts.def(py::init(&build_opts), py::arg("upsampfac") = 2.0, py::arg("gpu_method") = 0, - py::arg("gpu_sort") = true, py::arg("gpu_binsizex") = -1, py::arg("gpu_binsizey") = -1, - py::arg("gpu_binsizez") = -1, py::arg("gpu_obinsizex") = -1, - py::arg("gpu_obinsizey") = -1, py::arg("gpu_obinsizez") = -1, - py::arg("gpu_maxsubprobsize") = 1024, py::arg("gpu_kerevalmeth") = true, - py::arg("gpu_spreadinterponly") = 0, py::arg("gpu_maxbatchsize") = 0); + nb::class_ opts(m, "CufinufftOpts"); + opts.def("__init__", [](cufinufft_opts *self, double upsampfac, int gpu_method, bool gpu_sort, + int gpu_binsizex, int gpu_binsizey, int gpu_binsizez, int gpu_obinsizex, + int gpu_obinsizey, int gpu_obinsizez, int gpu_maxsubprobsize, + bool gpu_kerevalmeth, int gpu_spreadinterponly, int gpu_maxbatchsize) { + new (self) cufinufft_opts; + default_opts(self); + + self->upsampfac = upsampfac; + self->gpu_method = gpu_method; + self->gpu_sort = int(gpu_sort); + self->gpu_binsizex = gpu_binsizex; + self->gpu_binsizey = gpu_binsizey; + self->gpu_binsizez = gpu_binsizez; + self->gpu_obinsizex = gpu_obinsizex; + self->gpu_obinsizey = gpu_obinsizey; + self->gpu_obinsizez = gpu_obinsizez; + self->gpu_maxsubprobsize = gpu_maxsubprobsize; + self->gpu_kerevalmeth = gpu_kerevalmeth; + self->gpu_spreadinterponly = gpu_spreadinterponly; + self->gpu_maxbatchsize = gpu_maxbatchsize; + }); } } // namespace diff --git a/lib/kernel_helpers.h b/lib/kernel_helpers.h index 766a4bf..f725e01 100644 --- a/lib/kernel_helpers.h +++ b/lib/kernel_helpers.h @@ -28,11 +28,6 @@ bit_cast(const From& src) noexcept { return dst; } -template -std::string pack_descriptor_as_string(const T& descriptor) { - return std::string(bit_cast(&descriptor), sizeof(T)); -} - template const T* unpack_descriptor(const char* opaque, std::size_t opaque_len) { if (opaque_len != sizeof(T)) { diff --git a/lib/nanobind_kernel_helpers.h b/lib/nanobind_kernel_helpers.h new file mode 100644 index 0000000..de7042e --- /dev/null +++ b/lib/nanobind_kernel_helpers.h @@ -0,0 +1,28 @@ +// This header extends kernel_helpers.h with the nanobind specific interface to +// serializing descriptors. It also adds a nanobind function for wrapping our +// custom calls in a Python capsule. This is separate from kernel_helpers so that +// the CUDA code itself doesn't include nanobind. I don't think that this is +// strictly necessary, but they do it in jaxlib, so let's do it here too. + +#ifndef _JAX_FINUFFT_NANOBIND_KERNEL_HELPERS_H_ +#define _JAX_FINUFFT_NANOBIND_KERNEL_HELPERS_H_ + +#include + +#include "kernel_helpers.h" + +namespace jax_finufft { + +template +nanobind::bytes pack_descriptor(const T& descriptor) { + return nanobind::bytes(bit_cast(&descriptor), sizeof(T)); +} + +template +nanobind::capsule encapsulate_function(T* fn) { + return nanobind::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +} // namespace jax_finufft + +#endif diff --git a/lib/pybind11_kernel_helpers.h b/lib/pybind11_kernel_helpers.h deleted file mode 100644 index 58f6c25..0000000 --- a/lib/pybind11_kernel_helpers.h +++ /dev/null @@ -1,28 +0,0 @@ -// This header extends kernel_helpers.h with the pybind11 specific interface to -// serializing descriptors. It also adds a pybind11 function for wrapping our -// custom calls in a Python capsule. This is separate from kernel_helpers so that -// the CUDA code itself doesn't include pybind11. I don't think that this is -// strictly necessary, but they do it in jaxlib, so let's do it here too. - -#ifndef _JAX_FINUFFT_PYBIND11_KERNEL_HELPERS_H_ -#define _JAX_FINUFFT_PYBIND11_KERNEL_HELPERS_H_ - -#include - -#include "kernel_helpers.h" - -namespace jax_finufft { - -template -pybind11::bytes pack_descriptor(const T& descriptor) { - return pybind11::bytes(pack_descriptor_as_string(descriptor)); -} - -template -pybind11::capsule encapsulate_function(T* fn) { - return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); -} - -} // namespace jax_finufft - -#endif diff --git a/pyproject.toml b/pyproject.toml index 98dbfd4..8a6e119 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["pybind11>=2.6", "scikit-build-core>=0.5"] +requires = ["nanobind", "scikit-build-core>=0.5"] build-backend = "scikit_build_core.build" [project] diff --git a/src/jax_finufft/options.py b/src/jax_finufft/options.py index c2f38e5..ec08f77 100644 --- a/src/jax_finufft/options.py +++ b/src/jax_finufft/options.py @@ -44,6 +44,7 @@ class Opts: chkbnds: bool = True debug: DebugLevel = DebugLevel.Silent spread_debug: DebugLevel = DebugLevel.Silent + showwarn: bool = False nthreads: int = 0 fftw: int = FftwFlags.Estimate spread_sort: SpreadSort = SpreadSort.Heuristic @@ -72,39 +73,40 @@ class Opts: def to_finufft_opts(self): compiled_with_omp = jax_finufft_cpu._omp_compile_check() return jax_finufft_cpu.FinufftOpts( - modeord=self.modeord, - chkbnds=self.chkbnds, - debug=self.debug, - spread_debug=self.spread_debug, - nthreads=self.nthreads if compiled_with_omp else 1, - fftw=self.fftw, - spread_sort=self.spread_sort, - spread_kerevalmeth=self.spread_kerevalmeth, - spread_kerpad=self.spread_kerpad, - upsampfac=self.upsampfac, - spread_thread=self.spread_thread, - maxbatchsize=self.maxbatchsize, - spread_nthr_atomic=self.spread_nthr_atomic, - spread_max_sp_size=self.spread_max_sp_size, + self.modeord, + self.chkbnds, + int(self.debug), + int(self.spread_debug), + self.showwarn, + self.nthreads if compiled_with_omp else 1, + int(self.fftw), + int(self.spread_sort), + self.spread_kerevalmeth, + self.spread_kerpad, + self.upsampfac, + int(self.spread_thread), + self.maxbatchsize, + self.spread_nthr_atomic, + self.spread_max_sp_size, ) def to_cufinufft_opts(self): from jax_finufft import jax_finufft_gpu return jax_finufft_gpu.CufinufftOpts( - upsampfac=self.gpu_upsampfac, - gpu_method=self.gpu_method, - gpu_sort=self.gpu_sort, - gpu_binsizex=self.gpu_binsizex, - gpu_binsizey=self.gpu_binsizey, - gpu_binsizez=self.gpu_binsizez, - gpu_obinsizex=self.gpu_obinsizex, - gpu_obinsizey=self.gpu_obinsizey, - gpu_obinsizez=self.gpu_obinsizez, - gpu_maxsubprobsize=self.gpu_maxsubprobsize, - gpu_kerevalmeth=self.gpu_kerevalmeth, - gpu_spreadinterponly=self.gpu_spreadinterponly, - gpu_maxbatchsize=self.gpu_maxbatchsize, + self.gpu_upsampfac, + int(self.gpu_method), + self.gpu_sort, + self.gpu_binsizex, + self.gpu_binsizey, + self.gpu_binsizez, + self.gpu_obinsizex, + self.gpu_obinsizey, + self.gpu_obinsizez, + self.gpu_maxsubprobsize, + self.gpu_kerevalmeth, + self.gpu_spreadinterponly, + self.gpu_maxbatchsize, )