Skip to content

Commit

Permalink
Add return_real parameter to FFTConvolve class
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 10, 2024
1 parent 6f083c1 commit cde800a
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion bnpm/timeSeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,7 @@ def __init__(
n: Optional[int]=None,
next_fast_length: bool=False,
use_x_fft: bool=True,
return_real: bool=True,
):
super(FFTConvolve, self).__init__()
if x is not None:
Expand All @@ -1250,6 +1251,7 @@ def __init__(
self.x_fft = None

self.use_x_fft = use_x_fft
self.return_real = return_real

def set_x_fft(self, x: torch.Tensor, n: Optional[int]=None, next_fast_length: bool=False):
if next_fast_length:
Expand All @@ -1272,7 +1274,17 @@ def forward(
n: Optional[int]=None,
fast_length: Union[int, bool]=False,
x_fft: Optional[torch.Tensor]=None,
return_real: bool=None,
) -> torch.Tensor:
x_fft = self.x_fft if x_fft is None else x_fft
return_real = self.return_real if return_real is None else return_real
n = self.n if n is None else n
return fftconvolve(x=x, y=y, mode=mode, n=n, fast_length=fast_length, x_fft=x_fft if self.use_x_fft else None)
return fftconvolve(
x=x,
y=y,
mode=mode,
n=n,
fast_length=fast_length,
x_fft=x_fft if self.use_x_fft else None,
return_real=return_real,
)

0 comments on commit cde800a

Please sign in to comment.