Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cm] Initial CPP forward implementation #10

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
*.nbc
*.nbi
build/
*.egg*
*.egg*
.idea/
.DS_Store
80 changes: 80 additions & 0 deletions tests/benchmark_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import itertools
import logging
import os

import numba
import torch as tr
import torch.utils.cpp_extension
from torch.utils import benchmark
from tqdm import tqdm

from test_forward import sample_wise_lpc_scriptable
from torchlpc import sample_wise_lpc

# Build the C++ extension in-process and register the tr.ops.torchlpc.* ops.
# NOTE(review): the relative source path assumes this script is run from the
# tests/ directory — confirm, or resolve the path against __file__.
tr.utils.cpp_extension.load(
name="torchlpc",
sources=["../torchlpc/csrc/torchlpc.cpp"],
is_python_module=False,
verbose=True,
)

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))

# Benchmark grid: one timing run per (batch size, length, order) combination
# and per implementation in `forward_funcs`.
batch_sizes = [32]
n_samples = [2048]
orders = [3]
forward_funcs = [
sample_wise_lpc,
sample_wise_lpc_scriptable,
tr.ops.torchlpc.forward,
tr.ops.torchlpc.forward_batch_parallel,
]
dtype = tr.float
# All thread pools (torch intra-op, inter-op, numba) are pinned to this count.
num_threads = 1


def main() -> None:
    """Time every forward implementation on the module-level benchmark grid.

    Prints a ``torch.utils.benchmark.Compare`` table with one row per
    (batch size, n_samples, order, thread count) configuration and one
    column per implementation in ``forward_funcs``.
    """
    tr.manual_seed(42)
    # Pin intra-op, inter-op, and numba thread pools so timings are comparable.
    tr.set_num_threads(num_threads)
    tr.set_num_interop_threads(num_threads)
    numba.set_num_threads(num_threads)
    log.info(f"numba.get_num_threads(): {numba.get_num_threads()}")

    results = []

    for bs, n, order in tqdm(itertools.product(batch_sizes, n_samples, orders)):
        x = tr.randn(bs, n, dtype=dtype)
        a = tr.randn(bs, n, order, dtype=dtype)
        zi = tr.randn(bs, order, dtype=dtype)

        # Benchmarking inference only: make sure autograd tracking is off.
        x.requires_grad_(False)
        a.requires_grad_(False)
        zi.requires_grad_(False)

        for forward_func in tqdm(forward_funcs):
            # Renamed from `globals` to avoid shadowing the builtin globals().
            timer_globals = {
                "forward_func": forward_func,
                "x": x,
                "a": a,
                "zi": zi,
            }
            results.append(
                benchmark.Timer(
                    stmt="y = forward_func(x, a, zi)",
                    globals=timer_globals,
                    sub_label=f"bs_{bs}__n_{n}__order_{order}__threads_{num_threads}",
                    # NOTE(review): assumes every entry (incl. the loaded
                    # tr.ops.* callables) exposes __name__ — confirm.
                    description=forward_func.__name__,
                    num_threads=num_threads,
                ).blocked_autorange(min_run_time=0.5)
            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    main()
91 changes: 91 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import logging
import os
from typing import Optional, Callable

import pytest
import torch as tr
import torch.utils.cpp_extension
from torch import Tensor as T

from torchlpc import sample_wise_lpc

# Build the C++ extension in-process and register the tr.ops.torchlpc.* ops.
# NOTE(review): the relative source path assumes pytest is invoked from the
# tests/ directory — confirm, or resolve the path against __file__.
tr.utils.cpp_extension.load(
name="torchlpc",
sources=["../torchlpc/csrc/torchlpc.cpp"],
is_python_module=False,
verbose=True
)

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


# TorchScript compatible pure torch implementation of torchlpc.forward()
def sample_wise_lpc_scriptable(x: T, a: T, zi: Optional[T] = None) -> T:
    """All-pole IIR filtering with time-varying coefficients, in pure torch.

    Args:
        x: Input signal of shape (B, n_samples).
        a: Per-sample LPC coefficients of shape (B, n_samples, order).
        zi: Optional initial conditions of shape (B, order); zeros if omitted.

    Returns:
        Filtered signal of shape (B, n_samples), where
        y[t] = x[t] - sum_i a[t, i] * y[t - 1 - i].
    """
    assert x.ndim == 2
    assert a.ndim == 3
    assert x.size(0) == a.size(0)
    assert x.size(1) == a.size(1)

    # Renamed the time dimension from `T` so it no longer shadows the
    # `Tensor as T` alias used in the annotations above.
    B, n_samples, order = a.shape
    if zi is None:
        zi = a.new_zeros(B, order)
    else:
        assert zi.shape == (B, order)

    # Flip so the oldest state / highest-lag coefficient comes first,
    # matching scipy.signal.lfilter's layout.
    zi = tr.flip(zi, dims=[1])
    a_flip = tr.flip(a, dims=[2])
    # Prepend initial conditions so each step can read `order` history samples.
    padded_y = tr.cat([zi, x], dim=1)

    for t in range(n_samples):
        # Batched dot product: (B, 1, order) @ (B, order, 1) -> (B, 1, 1).
        prod = a_flip[:, t: t + 1] @ padded_y[:, t: t + order, None]
        prod = prod[:, 0, 0]
        padded_y[:, t + order] -= prod

    # Drop the zi padding.
    return padded_y[:, order:]


def compare_forward(forward_a: Callable[[T, T, Optional[T]], T],
                    forward_b: Callable[[T, T, Optional[T]], T],
                    bs: int,
                    n_samples: int,
                    order: int,
                    use_double: bool = True,
                    rtol: float = 1e-5,
                    atol: float = 1e-8) -> None:
    """Assert that two forward implementations agree on random inputs.

    Draws a random signal, coefficient tensor, and initial state of the given
    shapes, runs both callables on them, and asserts elementwise closeness.
    """
    dtype = tr.double if use_double else tr.float
    x = tr.randn(bs, n_samples, dtype=dtype)
    a = tr.randn(bs, n_samples, order, dtype=dtype)
    zi = tr.randn(bs, order, dtype=dtype)
    out_a = forward_a(x, a, zi)
    out_b = forward_b(x, a, zi)
    assert tr.allclose(out_a, out_b, rtol=rtol, atol=atol)


@pytest.mark.parametrize(
"bs",
[1, 2, 10],
)
@pytest.mark.parametrize(
"n_samples",
[1, 2, 2048],
)
@pytest.mark.parametrize(
"order",
[1, 2, 3, 6],
)
def test_forward(bs: int, n_samples: int, order: int) -> None:
    """Compare each alternative forward implementation against sample_wise_lpc."""
    forward_a = sample_wise_lpc
    # TorchScript-compatible pure-torch implementation
    forward_b = sample_wise_lpc_scriptable
    compare_forward(forward_a, forward_b, bs, n_samples, order)
    # CPP forward
    forward_b = tr.ops.torchlpc.forward
    compare_forward(forward_a, forward_b, bs, n_samples, order)
    # CPP forward_batch_parallel
    forward_b = tr.ops.torchlpc.forward_batch_parallel
    compare_forward(forward_a, forward_b, bs, n_samples, order)
11 changes: 11 additions & 0 deletions torchlpc/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Standalone LibTorch build for the torchlpc C++ ops.
# Configure with -DCMAKE_PREFIX_PATH=<libtorch dir> so find_package(Torch) works.
cmake_minimum_required(VERSION 3.15 FATAL_ERROR)
project(torchlpc)

find_package(Torch REQUIRED)

# Define our library target
add_library(torchlpc SHARED torchlpc.cpp)
# Enable C++17
target_compile_features(torchlpc PRIVATE cxx_std_17)
# Link against LibTorch
target_link_libraries(torchlpc "${TORCH_LIBRARIES}")
93 changes: 93 additions & 0 deletions torchlpc/csrc/torchlpc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include <torch/script.h>
#include <ATen/Parallel.h>


// Use for small T (less than a second of audio)
// TODO(cm): look into using associative scan for this
//
// Computes the all-pole recurrence y[t] = x[t] - a[t] . y[t-order .. t-1]
// batched over the leading dimension, one time step at a time.
//   x:  (B, T) input signal
//   a:  (B, T, order) per-sample coefficients
//   zi: (B, order) initial conditions
// Returns the filtered signal of shape (B, T).
torch::Tensor torchlpc_forward(torch::Tensor x, torch::Tensor a, torch::Tensor zi) {
    // Ensure input dimensions are correct
    TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional");
    TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
    TORCH_CHECK(x.size(0) == a.size(0), "Batch size of x and a must match");
    TORCH_CHECK(x.size(1) == a.size(1), "Time dimension of x and a must match");

    const auto n_batch = a.size(0);
    const auto n_steps = a.size(1);
    const auto order = a.size(2);

    // Ensure the zi tensor is the correct size
    TORCH_CHECK(zi.sizes() == torch::IntArrayRef({n_batch, order}),
                "zi must have shape (B, order)");

    // Reverse state and coefficients so index 0 holds the oldest sample /
    // highest-lag coefficient, matching scipy.signal.lfilter.
    zi = torch::flip(zi, {1});
    a = torch::flip(a, {2});

    // Prepend the initial conditions so every step can read `order` samples
    // of history from the same buffer it writes into.
    auto padded_y = torch::cat({zi, x}, 1);
    for (int64_t t = 0; t < n_steps; ++t) {
        auto coeffs = a.slice(1, t, t + 1);                           // (B, 1, order)
        auto history = padded_y.slice(1, t, t + order).unsqueeze(2);  // (B, order, 1)
        // Batched dot product, subtracted in place from the current output.
        padded_y.slice(1, t + order, t + order + 1) -=
            torch::matmul(coeffs, history).squeeze(2);
    }

    // Drop the zi padding and return the result
    return padded_y.slice(1, order, n_steps + order);
}


// Use for large T (seconds of audio) or abnormally large B
//
// Same recurrence as torchlpc_forward, but the batch dimension is split
// across threads with at::parallel_for; each thread owns a disjoint set of
// batch rows, so the in-place writes to padded_y never overlap.
//   x:  (B, T) input signal
//   a:  (B, T, order) per-sample coefficients
//   zi: (B, order) initial conditions
// Returns the filtered signal of shape (B, T).
torch::Tensor torchlpc_forward_batch_parallel(torch::Tensor x,
torch::Tensor a,
torch::Tensor zi) {
// Ensure input dimensions are correct
TORCH_CHECK(x.dim() == 2, "x must be 2-dimensional");
TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
TORCH_CHECK(x.size(0) == a.size(0), "Batch size of x and a must match");
TORCH_CHECK(x.size(1) == a.size(1), "Time dimension of x and a must match");

// Get the dimensions
const auto B = a.size(0);
const auto T = a.size(1);
const auto order = a.size(2);

// Ensure the zi tensor is the correct size
TORCH_CHECK(zi.sizes() == torch::IntArrayRef({B, order}),
"zi must have shape (B, order)");

// Flip zi and a to match scipy.signal.lfilter
zi = torch::flip(zi, {1});
a = torch::flip(a, {2});

// Concatenate zi and x along the time dimension
auto padded_y = torch::cat({zi, x}, 1);

// Perform the computation for each time step
// Grain size 1: each batch row is an independent unit of work.
at::parallel_for(0, B, 1, [&](int64_t begin_b, int64_t end_b) {
for (auto b = begin_b; b < end_b; ++b) {
// The temporal loop cannot be parallelized
for (int64_t t = 0; t < T; ++t) {
// (1, 1, order) coefficients for this row and step.
auto a_slice = a.slice(0, b, b + 1).slice(1, t, t + 1);
// (1, order, 1) window of the most recent outputs.
auto y_slice = padded_y.slice(0, b, b + 1)
.slice(1, t, t + order)
.unsqueeze(2);
auto prod = torch::matmul(a_slice, y_slice).squeeze(2);
// In-place update: y[t] = x[t] - a[t] . y[t-order .. t-1].
padded_y.slice(0, b, b + 1)
.slice(1, t + order, t + order + 1) -= prod;
}
}
});

// Remove the padding and return the result
auto y = padded_y.slice(1, order, T + order);
return y;
}


// Register the ops so Python can call them as torch.ops.torchlpc.forward and
// torch.ops.torchlpc.forward_batch_parallel once this library is loaded.
TORCH_LIBRARY(torchlpc, m) {
m.def("forward", torchlpc_forward);
m.def("forward_batch_parallel", torchlpc_forward_batch_parallel);
}
Loading