diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py index e6b4db7..cbd6093 100644 --- a/torchcomp/__init__.py +++ b/torchcomp/__init__.py @@ -187,8 +187,8 @@ def limiter_gain( assert torch.all(at > 0) and torch.all(at < 1) assert torch.all(rt > 0) and torch.all(rt < 1) - factory_func = lambda x: torch.as_tensor( - x, device=x.device, dtype=x.dtype + factory_func = lambda h: torch.as_tensor( + h, device=x.device, dtype=x.dtype ).broadcast_to(x.shape[0]) threshold = factory_func(threshold) at = factory_func(at)