Skip to content

Commit

Permalink
FIX(coeff) (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty authored Aug 26, 2023
1 parent 9ad7880 commit aceaa04
Showing 1 changed file with 25 additions and 41 deletions.
66 changes: 25 additions & 41 deletions interpol/coeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def get_gain(poles: List[float]) -> float:
return lam


@torch.jit.script
def _dot(x, y):
return x.unsqueeze(-2).matmul(y.unsqueeze(-1)).squeeze(-1).squeeze(-1)


@torch.jit.script
def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):

Expand All @@ -90,15 +95,13 @@ def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
inp0 = inp[0]
inp = inp[1-max_iter:]
inp = movedim1(inp, 0, -1)
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
out = out + inp0.unsqueeze(-1)
if keepdim:
out = movedim1(out, -1, dim)
else:
out = out.squeeze(-1)
out = _dot(inp, poles) + inp0

pole = pole ** max_iter
out = out / (1 - pole)

if keepdim:
out = out.unsqueeze(dim)
return out


Expand All @@ -112,29 +115,20 @@ def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):

poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
poles = poles.pow(
torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device)
torch.arange(0, max_iter, dtype=inp.dtype, device=inp.device)
)

inp = movedim1(inp, dim, 0)
inp0 = inp[0]
inp = inp[1:max_iter]
inp = inp[:max_iter]
inp = movedim1(inp, 0, -1)
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
out = out + inp0.unsqueeze(-1)
if keepdim:
out = movedim1(out, -1, dim)
else:
out = out.squeeze(-1)
out = _dot(inp, poles)

else:
max_iter = n

polen = pole ** (n - 1)
inp0 = inp[0] + polen * inp[-1]
inp = inp[1:-1]
inp = movedim1(inp, 0, -1)

out = inp0.unsqueeze(-1)
out = inp[0] + polen * inp[-1]

if n > 2:
poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device)
Expand All @@ -143,26 +137,18 @@ def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
)
poles = poles + (polen * polen) / poles

out = out + torch.matmul(
inp.unsqueeze(-2), poles.unsqueeze(-1)
).squeeze(-1)

if keepdim:
out = movedim1(out, -1, dim)
else:
out = out.squeeze(-1)
inp = inp[1:-1]
inp = movedim1(inp, 0, -1)
out = out + _dot(inp, poles)

pole = pole ** (max_iter - 1)
out = out / (1 - pole * pole)

if keepdim:
out = out.unsqueeze(dim)
return out


@torch.jit.script
def _dot(x, y):
return x.unsqueeze(-2).matmul(y.unsqueeze(-1)).squeeze(-1).squeeze(-1)


@torch.jit.script
def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):
# Ported from scipy:
Expand Down Expand Up @@ -190,7 +176,6 @@ def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False):

if keepdim:
out = out.unsqueeze(dim)

return out


Expand All @@ -210,15 +195,14 @@ def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
inp0 = inp[-1]
inp = inp[:max_iter-1]
inp = movedim1(inp, 0, -1)
out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1)
out = out.add(inp0.unsqueeze(-1), alpha=pole)
if keepdim:
out = movedim1(out, -1, dim)
else:
out = out.squeeze(-1)
out = _dot(inp, poles)
out = out.add(inp0, alpha=pole)

pole = pole ** max_iter
out = out / (pole - 1)

if keepdim:
out = out.unsqueeze(dim)
return out


Expand All @@ -228,7 +212,7 @@ def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
out = pole * inp[-2] + inp[-1]
out = out * (pole / (pole*pole - 1))
if keepdim:
out = movedim1(out.unsqueeze(0), 0, dim)
out = out.unsqueeze(dim)
return out


Expand All @@ -239,7 +223,7 @@ def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False):
inp = movedim1(inp, dim, 0)
out = inp[-1] * (pole / (pole - 1))
if keepdim:
out = movedim1(out.unsqueeze(0), 0, dim)
out = out.unsqueeze(dim)
return out


Expand Down

0 comments on commit aceaa04

Please sign in to comment.