From 804a5aa1217821bae6eb9755d4470da186a3d787 Mon Sep 17 00:00:00 2001 From: christhetree Date: Wed, 19 Jun 2024 13:44:01 +0100 Subject: [PATCH] [cm] Adding 3 forward implementations --- torchlpc/csrc/experiments.py | 2 +- torchlpc/csrc/torchlpc.cpp | 31 +++++++++++++++++++++---------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/torchlpc/csrc/experiments.py b/torchlpc/csrc/experiments.py index 847e38c..557186b 100644 --- a/torchlpc/csrc/experiments.py +++ b/torchlpc/csrc/experiments.py @@ -44,7 +44,7 @@ def sample_wise_lpc_scriptable(x: T, a: T, zi: Optional[T] = None) -> T: # torch.set_num_threads(1) torch.utils.cpp_extension.load( - name="forward", + name="torchlpc", sources=["torchlpc.cpp"], is_python_module=False, verbose=True diff --git a/torchlpc/csrc/torchlpc.cpp b/torchlpc/csrc/torchlpc.cpp index eac4268..a4f1884 100644 --- a/torchlpc/csrc/torchlpc.cpp +++ b/torchlpc/csrc/torchlpc.cpp @@ -3,6 +3,7 @@ // Use for small T (less than a second of audio) +// TODO(cm): look into using associative scan for this torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor zi) { // Ensure input dimensions are correct TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional"); @@ -16,7 +17,8 @@ torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor z const auto order = a.size(2); // Ensure the zi tensor is the correct size - TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)"); + TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), + "zi must have shape (B, order)"); // Flip zi and a to match scipy.signal.lfilter zi = torch::flip(zi, {1}); @@ -38,7 +40,9 @@ torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor z // Use for large T (seconds of audio) or abnormally large B -torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a, torch::Tensor zi) { +torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, + torch::Tensor a, + torch::Tensor zi) { // Ensure input dimensions are correct TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional"); TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional"); @@ -51,7 +55,8 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a, const auto order = a.size(2); // Ensure the zi tensor is the correct size - TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)"); + TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), + "zi must have shape (B, order)"); // Flip zi and a to match scipy.signal.lfilter zi = torch::flip(zi, {1}); @@ -66,9 +71,12 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a, // The temporal loop cannot be parallelized for (int64_t t = 0; t < T; ++t) { auto a_slice = a.slice(0, b, b + 1).slice(1, t, t + 1); - auto y_slice = padded_y.slice(0, b, b + 1).slice(1, t, t + order).unsqueeze(2); + auto y_slice = padded_y.slice(0, b, b + 1) + .slice(1, t, t + order) + .unsqueeze(2); auto prod = torch::matmul(a_slice, y_slice).squeeze(2); - padded_y.slice(0, b, b + 1).slice(1, t + order, t + order + 1) -= prod; + padded_y.slice(0, b, b + 1) + .slice(1, t + order, t + order + 1) -= prod; } } }); @@ -81,7 +89,9 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a, // Use for large T (seconds of audio) and abnormally large order // TODO(cm): inner loop over order has a runtime error due to inplace tensor ops -torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tensor a, torch::Tensor zi) { +torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, + torch::Tensor a, + torch::Tensor zi) { // Ensure input dimensions are correct TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional"); TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional"); @@ -94,7 +104,8 @@ torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tens const auto order = a.size(2); // Ensure the zi tensor is the correct size - TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)"); + TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), + "zi must have shape (B, order)"); // Flip zi to match scipy.signal.lfilter zi = torch::flip(zi, {1}); @@ -127,7 +138,7 @@ torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tens } TORCH_LIBRARY(torchlpc, m) { - m.def("forward", torchlpc_forward); - m.def("forward_batch_parallel", torchlpc_forward_batch_parallel); - m.def("forward_batch_order_parallel", torchlpc_forward_batch_order_parallel); + m.def("forward", torchlpc_forward); + m.def("forward_batch_parallel", torchlpc_forward_batch_parallel); + m.def("forward_batch_order_parallel", torchlpc_forward_batch_order_parallel); }