From 58769e8d4c0929817a4d6f1e1a4d8d03359fe57a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 28 Sep 2024 13:15:07 +0800 Subject: [PATCH] add paddle backend to warp --- README.md | 2 +- docs/conf.py | 1 + docs/index.rst | 2 +- docs/installation.rst | 1 + docs/modules/interoperability.rst | 181 ++++++++++++++++++++ exts/omni.warp.core/config/extension.toml | 1 + warp/__init__.py | 5 + warp/dlpack.py | 2 + warp/examples/benchmarks/benchmark.bat | 2 + warp/examples/benchmarks/benchmark.sh | 2 + warp/examples/benchmarks/benchmark_cloth.py | 10 ++ warp/stubs.py | 5 + warp/tests/test_dlpack.py | 118 +++++++++++++ warp/thirdparty/dlpack.py | 4 +- 14 files changed, 333 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 44c5d326..04b97538 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can Warp is designed for [spatial computing](https://en.wikipedia.org/wiki/Spatial_computing) and comes with a rich set of primitives that make it easy to write programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels -are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX. +are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle. Please refer to the project [Documentation](https://nvidia.github.io/warp/) for API and language reference and [CHANGELOG.md](./CHANGELOG.md) for release history. diff --git a/docs/conf.py b/docs/conf.py index 400d0c77..2e6ab2e1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,6 +68,7 @@ "numpy": ("https://numpy.org/doc/stable", None), "jax": ("https://jax.readthedocs.io/en/latest", None), "pytorch": ("https://pytorch.org/docs/stable", None), + "paddle": ("https://www.paddlepaddle.org.cn/", None), } extlinks = { diff --git a/docs/index.rst b/docs/index.rst index e3f45fa0..dc33aa38 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can Warp is designed for `spatial computing `_ and comes with a rich set of primitives that make it easy to write programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels -are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX. +are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle. Below are some examples of simulations implemented using Warp: diff --git a/docs/installation.rst b/docs/installation.rst index 016109f0..6e72a664 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -76,6 +76,7 @@ The following optional dependencies are required to support certain features: * `usd-core `_: Required for some Warp examples, ``warp.sim.parse_usd()``, and ``warp.render.UsdRenderer``. * `JAX `_: Required for JAX interoperability (see :ref:`jax-interop`). * `PyTorch `_: Required for PyTorch interoperability (see :ref:`pytorch-interop`). +* `Paddle `_: Required for Paddle interoperability (see :ref:`paddle-interop`). * `NVTX for Python `_: Required to use :class:`wp.ScopedTimer(use_nvtx=True) `. 
Building the Warp documentation requires: diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index 800d4e79..21daa09b 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -709,6 +709,7 @@ The canonical way to export a Warp array to an external framework is to use the jax_array = jax.dlpack.from_dlpack(warp_array) torch_tensor = torch.utils.dlpack.from_dlpack(warp_array) + paddle_tensor = paddle.utils.dlpack.from_dlpack(warp_array) For CUDA arrays, this will synchronize the current stream of the consumer framework with the current Warp stream on the array's device. Thus it should be safe to use the wrapped array in the consumer framework, even if the array was previously used in a Warp kernel @@ -719,9 +720,11 @@ This approach may be used for older versions of frameworks that do not support t warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array)) warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor)) + warp_array3 = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(paddle_tensor)) jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array)) torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array)) + paddle_tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array)) This approach is generally faster because it skips any stream synchronization, but another solution must be used to ensure correct ordering of operations. In situations where no synchronization is required, using this approach can yield better performance. @@ -733,3 +736,181 @@ This may be a good choice in situations like these: .. autofunction:: warp.from_dlpack .. autofunction:: warp.to_dlpack + +.. _paddle-interop: + +Paddle +------ + +Warp provides helper functions to convert arrays to/from Paddle:: + + w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu") + + # convert to Paddle tensor + t = wp.to_paddle(w) + + # convert from Paddle tensor + w = wp.from_paddle(t) + +These helper functions allow the conversion of Warp arrays to/from Paddle tensors without copying the underlying data. +At the same time, if available, gradient arrays and tensors are converted to/from Paddle autograd tensors, allowing the use of Warp arrays +in Paddle autograd computations. + +.. autofunction:: warp.from_paddle +.. autofunction:: warp.to_paddle +.. autofunction:: warp.device_from_paddle +.. autofunction:: warp.device_to_paddle +.. autofunction:: warp.dtype_from_paddle +.. autofunction:: warp.dtype_to_paddle + +To convert a Paddle CUDA stream to a Warp CUDA stream and vice versa, Warp provides the following functions: + +.. 
autofunction:: warp.stream_from_paddle
+
+Example: Optimization using ``warp.from_paddle()``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following example minimizes a loss function over an array of 2D points. The loss kernel is written in Warp,
+and the points are optimized with Paddle's Adam optimizer, using :func:`warp.from_paddle` to share the optimization
+variables between the two frameworks::
+
+    import warp as wp
+    import paddle
+
+    # initialize the Warp context
+    wp.context.init()
+
+    @wp.kernel()
+    def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
+        tid = wp.tid()
+        wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)
+
+    # enable gradients (stop_gradient=False) so that Warp can accumulate gradients in the grad buffers
+    xs = paddle.randn([100, 2])
+    xs.stop_gradient = False
+    l = paddle.zeros([1])
+    l.stop_gradient = False
+    opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[xs])
+
+    wp_xs = wp.from_paddle(xs)
+    wp_l = wp.from_paddle(l)
+
+    tape = wp.Tape()
+    with tape:
+        # record the loss function kernel launch on the tape
+        wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
+
+    for i in range(500):
+        tape.zero()
+        tape.backward(loss=wp_l)  # compute gradients
+        # now xs.grad will be populated with the gradients computed by Warp
+        opt.step()  # update xs (and thereby wp_xs)
+
+        # these lines are only needed for evaluating the loss
+        # (the optimization just needs the gradient, not the loss value)
+        wp_l.zero_()
+        wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
+        print(f"{i}\tloss: {l.item()}")
+
+Example: Optimization using ``warp.to_paddle``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Less code is needed when we declare the optimization variables directly in Warp and use :func:`warp.to_paddle` to convert them to Paddle tensors.
+Here, we revisit the same example from above, where now only a single conversion to a Paddle tensor is needed to supply Adam with the optimization variables::
+
+    import warp as wp
+    import numpy as np
+    import paddle
+
+    # initialize the Warp context
+    wp.context.init()
+
+    @wp.kernel()
+    def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
+        tid = wp.tid()
+        wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)
+
+    # initialize the optimization variables in Warp
+    xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
+    l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
+    # just a single wp.to_paddle call is needed; Adam optimizes using the Warp array gradients
+    opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[wp.to_paddle(xs)])
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
+
+    for i in range(500):
+        tape.zero()
+        tape.backward(loss=l)
+        opt.step()
+
+        l.zero_()
+        wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
+        print(f"{i}\tloss: {l.numpy()[0]}")
+
+Performance Notes
+^^^^^^^^^^^^^^^^^
+
+The ``wp.from_paddle()`` function creates a Warp array object that shares data with a Paddle tensor. Although this function does not copy the data, there is always some CPU overhead during the conversion. If these conversions happen frequently, the overall program performance may suffer. As a general rule, it's good to avoid repeated conversions of the same tensor. Instead of:
+
+.. code:: python
+
+    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+
+    for i in range(10):
+        x_w = wp.from_paddle(x_t)
+        y_w = wp.from_paddle(y_t)
+        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)
+
+Try converting the arrays only once and reusing them:
+
+.. code:: python
+
+    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+
+    x_w = wp.from_paddle(x_t)
+    y_w = wp.from_paddle(y_t)
+
+    for i in range(10):
+        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)
+
+If reusing arrays is not possible (e.g., a new Paddle tensor is constructed on every iteration), passing ``return_ctype=True`` to ``wp.from_paddle()`` should yield faster performance. Setting this argument to ``True`` avoids constructing a ``wp.array`` object and instead returns a low-level array descriptor. This descriptor is a simple C structure that can be passed to Warp kernels instead of a ``wp.array``, but cannot be used in other places that require a ``wp.array``.
+
+.. code:: python
+
+    for n in range(1, 10):
+        # get Paddle tensors for this iteration
+        x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+        y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+
+        # get Warp array descriptors
+        x_ctype = wp.from_paddle(x_t, return_ctype=True)
+        y_ctype = wp.from_paddle(y_t, return_ctype=True)
+
+        wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device)
+
+An alternative approach is to pass the Paddle tensors to Warp kernels directly. This avoids constructing temporary Warp arrays by leveraging standard array interfaces (like ``__cuda_array_interface__``) supported by both Paddle and Warp. The main advantage of this approach is convenience, since there is no need to call any conversion functions. The main limitation is that it does not handle gradients, because gradient information is not included in the standard array interfaces. This technique is therefore most suitable for algorithms that do not involve differentiation.
+
+.. code:: python
+
+    x = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+    y = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
+
+    for i in range(10):
+        wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)
+
+The conversion overhead of these approaches can be compared by running the following benchmark:
+
+.. code:: shell
+
+    python -m warp.examples.benchmarks.benchmark_interop_paddle
+
+Sample output:
+
+.. code::
+
+    13990 ms  from_paddle(...)
+     5990 ms  from_paddle(..., return_ctype=True)
+    35167 ms  direct from paddle
+
+In this sample run, passing ``return_ctype=True`` is the fastest because it skips creating temporary Warp array objects, while the default ``wp.from_paddle()`` conversion falls in between. Passing Paddle tensors to Warp kernels directly is the slowest here: it also avoids temporary Warp arrays, but accessing the ``__cuda_array_interface__`` attributes of Paddle tensors adds overhead because they are initialized on demand.
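The Performance Notes snippets above assume that a ``saxpy`` kernel and a ``device`` variable already exist; neither is defined in this patch. A minimal, self-contained setup they could run against might look like the following sketch (the kernel body, ``n``, and the choice of device are illustrative assumptions, not part of the patch):

.. code:: python

    import paddle
    import warp as wp

    wp.init()

    # assumed device; any CPU or CUDA device known to Warp would work here
    device = wp.get_device()

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        # y = a * x + y, matching the inputs=[x, y, a] launches used in the snippets above
        tid = wp.tid()
        y[tid] = a * x[tid] + y[tid]

    n = 1024
    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))

    x_w = wp.from_paddle(x_t)
    y_w = wp.from_paddle(y_t)

    wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

Similarly, the documentation lists :func:`warp.stream_from_paddle` without a usage example. A possible sketch, assuming a CUDA device, a CUDA-enabled Paddle build, and that the function accepts Paddle's current CUDA stream (mirroring ``warp.stream_from_torch``):

.. code:: python

    import paddle
    import warp as wp

    wp.init()

    # wrap the CUDA stream Paddle is currently using
    # (assumed producer API: paddle.device.cuda.current_stream())
    warp_stream = wp.stream_from_paddle(paddle.device.cuda.current_stream())

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        tid = wp.tid()
        a[tid] = a[tid] * s

    arr = wp.zeros(1024, dtype=wp.float32, device="cuda:0")

    # run Warp work on the same stream as Paddle, so no extra synchronization is needed
    with wp.ScopedStream(warp_stream):
        wp.launch(scale, dim=arr.size, inputs=[arr, 2.0])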
diff --git a/exts/omni.warp.core/config/extension.toml b/exts/omni.warp.core/config/extension.toml index 8b80e91f..e39e4dd2 100644 --- a/exts/omni.warp.core/config/extension.toml +++ b/exts/omni.warp.core/config/extension.toml @@ -38,6 +38,7 @@ pyCoverageOmit = [ "warp/stubs.py", "warp/jax.py", "warp/torch.py", + "warp/paddle.py", "warp/build.py", "warp/build_dll.py", "warp/sim/**", diff --git a/warp/__init__.py b/warp/__init__.py index 76672327..8ba489a0 100644 --- a/warp/__init__.py +++ b/warp/__init__.py @@ -99,6 +99,11 @@ from warp.dlpack import from_dlpack, to_dlpack +from warp.paddle import from_paddle, to_paddle +from warp.paddle import dtype_from_paddle, dtype_to_paddle +from warp.paddle import device_from_paddle, device_to_paddle +from warp.paddle import stream_from_paddle + from warp.build import clear_kernel_cache from warp.constants import * diff --git a/warp/dlpack.py b/warp/dlpack.py index 34de4264..20860c6e 100644 --- a/warp/dlpack.py +++ b/warp/dlpack.py @@ -124,6 +124,8 @@ def device_to_dlpack(wp_device) -> DLDevice: def dtype_to_dlpack(wp_dtype) -> DLDataType: + if wp_dtype == warp.bool: + return (DLDataTypeCode.kDLBool, 8, 1) if wp_dtype == warp.int8: return (DLDataTypeCode.kDLInt, 8, 1) elif wp_dtype == warp.uint8: diff --git a/warp/examples/benchmarks/benchmark.bat b/warp/examples/benchmarks/benchmark.bat index 9edec17d..66a5dab3 100644 --- a/warp/examples/benchmarks/benchmark.bat +++ b/warp/examples/benchmarks/benchmark.bat @@ -11,3 +11,5 @@ python benchmark_cloth.py numpy @REM python benchmark_cloth.py numba @REM python benchmark_cloth.py jax_cpu @REM python benchmark_cloth.py jax_gpu +@REM python benchmark_cloth.py paddle_cpu +@REM python benchmark_cloth.py paddle_gpu diff --git a/warp/examples/benchmarks/benchmark.sh b/warp/examples/benchmarks/benchmark.sh index f82289a6..a4d54386 100755 --- a/warp/examples/benchmarks/benchmark.sh +++ b/warp/examples/benchmarks/benchmark.sh @@ -11,3 +11,5 @@ python3 benchmark_cloth.py numpy # python3 benchmark_cloth.py jax_cpu # python3 benchmark_cloth.py jax_gpu # python3 benchmark_cloth.py numba +# python3 benchmark_cloth.py paddle_cpu +# python3 benchmark_cloth.py paddle_gpu diff --git a/warp/examples/benchmarks/benchmark_cloth.py b/warp/examples/benchmarks/benchmark_cloth.py index d28213da..3fc6a740 100644 --- a/warp/examples/benchmarks/benchmark_cloth.py +++ b/warp/examples/benchmarks/benchmark_cloth.py @@ -219,6 +219,16 @@ def run_benchmark(mode, dim, timers, render=False): integrator = benchmark_cloth_jax.JxIntegrator(cloth) + elif mode == "paddle_cpu": + import benchmark_cloth_paddle + + integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "cpu") + + elif mode == "paddle_gpu": + import benchmark_cloth_paddle + + integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "gpu") + else: raise RuntimeError("Unknown simulation backend") diff --git a/warp/stubs.py b/warp/stubs.py index f9d7be6b..f45f2b72 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -108,6 +108,11 @@ from warp.dlpack import from_dlpack, to_dlpack +from warp.paddle import from_paddle, to_paddle +from warp.paddle import dtype_from_paddle, dtype_to_paddle +from warp.paddle import device_from_paddle, device_to_paddle +from warp.paddle import stream_from_paddle + from warp.build import clear_kernel_cache from warp.constants import * diff --git a/warp/tests/test_dlpack.py b/warp/tests/test_dlpack.py index 45fbef13..30ef693a 100644 --- a/warp/tests/test_dlpack.py +++ b/warp/tests/test_dlpack.py @@ -350,6 +350,34 @@ def test_dlpack_torch_to_warp_v2(test, device): 
assert_np_equal(a.numpy(), t.cpu().numpy()) +def test_dlpack_paddle_to_warp(test, device): + import paddle + import paddle.utils.dlpack + + t = paddle.arange(N, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + # paddle do not implement __dlpack__ yet, so only test to_dlpack here + a = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(t)) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.numpy()) + + def test_dlpack_warp_to_jax(test, device): import jax import jax.dlpack @@ -421,6 +449,61 @@ def test_dlpack_warp_to_jax_v2(test, device): assert_np_equal(a.numpy(), np.asarray(j2)) +def test_dlpack_warp_to_paddle(test, device): + import paddle.utils.dlpack + + a = wp.array(data=np.arange(N, dtype=np.float32), device=device) + + t = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(a)) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + +def test_dlpack_warp_to_paddle_v2(test, device): + # same as original test, but uses newer __dlpack__() method + + import paddle.utils.dlpack + + a = wp.array(data=np.arange(N, dtype=np.float32), device=device) + + # pass the array directly + t = paddle.utils.dlpack.from_dlpack(a) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.numpy()) + + def test_dlpack_jax_to_warp(test, device): import jax import jax.dlpack @@ -575,6 +658,41 @@ class TestDLPack(unittest.TestCase): print(f"Skipping Jax DLPack tests due to exception: {e}") +# paddle interop via dlpack +try: + import paddle + import paddle.utils.dlpack + + # check which Warp devices work with paddle + # CUDA devices may fail if paddle was not compiled with CUDA support + test_devices = get_test_devices() + paddle_compatible_devices = [] + for d in test_devices: + try: + t = paddle.arange(10).to(device=wp.device_to_paddle(d)) + paddle.assign(t + 1, t) + paddle_compatible_devices.append(d) + except Exception as e: + print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}") + + if paddle_compatible_devices: + add_function_test( + TestDLPack, "test_dlpack_warp_to_paddle", test_dlpack_warp_to_paddle, devices=paddle_compatible_devices + ) + 
add_function_test( + TestDLPack, + "test_dlpack_warp_to_paddle_v2", + test_dlpack_warp_to_paddle_v2, + devices=paddle_compatible_devices, + ) + add_function_test( + TestDLPack, "test_dlpack_paddle_to_warp", test_dlpack_paddle_to_warp, devices=paddle_compatible_devices + ) + +except Exception as e: + print(f"Skipping Paddle DLPack tests due to exception: {e}") + + if __name__ == "__main__": wp.clear_kernel_cache() unittest.main(verbosity=2) diff --git a/warp/thirdparty/dlpack.py b/warp/thirdparty/dlpack.py index 0634474b..399e0002 100644 --- a/warp/thirdparty/dlpack.py +++ b/warp/thirdparty/dlpack.py @@ -58,6 +58,7 @@ class DLDataTypeCode(ctypes.c_uint8): kDLOpaquePointer = 3 kDLBfloat = 4 kDLComplex = 5 + kDLBool = 6 def __str__(self): return { @@ -66,6 +67,7 @@ def __str__(self): self.kDLFloat: "float", self.kDLBfloat: "bfloat", self.kDLComplex: "complex", + self.kDLBool: "bool", self.kDLOpaquePointer: "void_p", }[self.value] @@ -85,7 +87,7 @@ class DLDataType(ctypes.Structure): ("lanes", ctypes.c_uint16), ] TYPE_MAP = { - "bool": (DLDataTypeCode.kDLUInt, 1, 1), + "bool": (DLDataTypeCode.kDLBool, 8, 1), "int8": (DLDataTypeCode.kDLInt, 8, 1), "int16": (DLDataTypeCode.kDLInt, 16, 1), "int32": (DLDataTypeCode.kDLInt, 32, 1),