Skip to content

Commit

Permalink
[cm] Adding 3 forward implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Jun 19, 2024
1 parent ab900f4 commit 804a5aa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion torchlpc/csrc/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def sample_wise_lpc_scriptable(x: T, a: T, zi: Optional[T] = None) -> T:
# torch.set_num_threads(1)

torch.utils.cpp_extension.load(
name="forward",
name="torchlpc",
sources=["torchlpc.cpp"],
is_python_module=False,
verbose=True
Expand Down
31 changes: 21 additions & 10 deletions torchlpc/csrc/torchlpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


// Use for small T (less than a second of audio)
// TODO(cm): look into using associative scan for this
torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor zi) {
// Ensure input dimensions are correct
TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional");
Expand All @@ -16,7 +17,8 @@ torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor z
const auto order = a.size(2);

// Ensure the zi tensor is the correct size
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)");
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}),
"zi must have shape (B, order)");

// Flip zi and a to match scipy.signal.lfilter
zi = torch::flip(zi, {1});
Expand All @@ -38,7 +40,9 @@ torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor z


// Use for large T (seconds of audio) or abnormally large B
torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a, torch::Tensor zi) {
torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x,
torch::Tensor a,
torch::Tensor zi) {
// Ensure input dimensions are correct
TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional");
TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
Expand All @@ -51,7 +55,8 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a,
const auto order = a.size(2);

// Ensure the zi tensor is the correct size
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)");
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}),
"zi must have shape (B, order)");

// Flip zi and a to match scipy.signal.lfilter
zi = torch::flip(zi, {1});
Expand All @@ -66,9 +71,12 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a,
// The temporal loop cannot be parallelized
for (int64_t t = 0; t < T; ++t) {
auto a_slice = a.slice(0, b, b + 1).slice(1, t, t + 1);
auto y_slice = padded_y.slice(0, b, b + 1).slice(1, t, t + order).unsqueeze(2);
auto y_slice = padded_y.slice(0, b, b + 1)
.slice(1, t, t + order)
.unsqueeze(2);
auto prod = torch::matmul(a_slice, y_slice).squeeze(2);
padded_y.slice(0, b, b + 1).slice(1, t + order, t + order + 1) -= prod;
padded_y.slice(0, b, b + 1)
.slice(1, t + order, t + order + 1) -= prod;
}
}
});
Expand All @@ -81,7 +89,9 @@ torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x, torch::Tensor a,

// Use for large T (seconds of audio) and abnormally large order
// TODO(cm): inner loop over order has a runtime error due to inplace tensor ops
torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tensor a, torch::Tensor zi) {
torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x,
torch::Tensor a,
torch::Tensor zi) {
// Ensure input dimensions are correct
TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional");
TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
Expand All @@ -94,7 +104,8 @@ torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tens
const auto order = a.size(2);

// Ensure the zi tensor is the correct size
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}), "zi must have shape (B, order)");
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}),
"zi must have shape (B, order)");

// Flip zi to match scipy.signal.lfilter
zi = torch::flip(zi, {1});
Expand Down Expand Up @@ -127,7 +138,7 @@ torch::Tensor torchlpc_forward_batch_order_parallel(torch::Tensor x, torch::Tens
}

TORCH_LIBRARY(torchlpc, m) {
m.def("forward", torchlpc_forward);
m.def("forward_batch_parallel", torchlpc_forward_batch_parallel);
m.def("forward_batch_order_parallel", torchlpc_forward_batch_order_parallel);
m.def("forward", torchlpc_forward);
m.def("forward_batch_parallel", torchlpc_forward_batch_parallel);
m.def("forward_batch_order_parallel", torchlpc_forward_batch_order_parallel);
}

0 comments on commit 804a5aa

Please sign in to comment.