-
Notifications
You must be signed in to change notification settings - Fork 0
/
bench_resample.py
51 lines (44 loc) · 1.5 KB
/
bench_resample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from kazane.resample import Resample
from julius import ResampleFrac
from torch.profiler import profiler
from torch.utils.benchmark import Timer, Compare
import argparse
def parse_args(argv=None):
    """Parse command-line options for the julius-vs-kazane resampling benchmark.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``device``, ``batch``, ``duration``,
        ``zeros`` and ``min_run_time`` attributes. Defaults match the values
        the benchmark previously hard-coded.

    Exits via ``parser.error`` (status 2) when ``--device`` is not a valid
    torch device string, or names a CUDA device while CUDA is unavailable.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark julius.ResampleFrac against kazane.Resample.')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch', type=int, default=4,
                        help='number of random signals per batch')
    parser.add_argument('--duration', type=float, default=10,
                        help='signal length in seconds at the source rate')
    parser.add_argument('--zeros', type=int, default=32,
                        help='`zeros` value passed positionally to both '
                             'resamplers')
    parser.add_argument('--min-run-time', type=float, default=1,
                        help='minimum seconds spent per measurement')
    args = parser.parse_args(argv)

    # Fail fast with a readable message instead of a RuntimeError raised from
    # inside the first tensor allocation.
    try:
        device_type = torch.device(args.device).type
    except RuntimeError:
        parser.error(f'invalid --device {args.device!r}')
    if device_type == 'cuda' and not torch.cuda.is_available():
        parser.error('--device cuda requested but CUDA is not available; '
                     'try --device cpu')
    return args


def benchmark_pair(old_sr, new_sr, *, device, batch, duration, zeros,
                   num_threads, min_run_time):
    """Time julius and kazane on a single ``old_sr -> new_sr`` conversion.

    Both resamplers receive the same random input of shape
    ``(batch, int(old_sr * duration))`` allocated on ``device``.

    Returns:
        A list of two ``blocked_autorange`` measurements (julius first, then
        kazane). They share ``label`` and ``sub_label`` so they can be
        compared side by side.
    """
    label = 'Resample'
    sub_label = f'ratio: {old_sr, new_sr}'
    x = torch.randn(batch, int(old_sr * duration), device=device)

    measurements = []
    # Both constructors are called as (old_sr, new_sr, zeros); each module is
    # built and moved to the device just before it is timed.
    for description, resampler_cls in (('julius', ResampleFrac),
                                       ('kazane', Resample)):
        module = resampler_cls(old_sr, new_sr, zeros).to(device)
        timer = Timer(
            stmt='m(x)',
            setup='',
            globals={'x': x, 'm': module},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description=description,
        )
        measurements.append(timer.blocked_autorange(min_run_time=min_run_time))
    return measurements


def main(argv=None):
    """Benchmark every (source rate, target rate) pair and print a comparison.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    # Up- and down-sampling at 4:5 / 5:4 and at 1:20 / 20:1 rate ratios.
    pairs = [(20000, 25000), (25000, 20000), (2205, 44100), (44100, 2205)]
    num_threads = torch.get_num_threads()
    print(f'Benchmarking on {num_threads} threads')

    results = []
    for old_sr, new_sr in pairs:
        results.extend(benchmark_pair(
            old_sr, new_sr,
            device=args.device,
            batch=args.batch,
            duration=args.duration,
            zeros=args.zeros,
            num_threads=num_threads,
            min_run_time=args.min_run_time,
        ))
    Compare(results).print()


if __name__ == '__main__':
    main()