diff --git a/.vscode/settings.json b/.vscode/settings.json index ac2bb1cab..bc43e427c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -78,6 +78,7 @@ "copybutton", "cstride", "csys", + "cumsum", "datapoints", "datetime", "dcsys", @@ -149,6 +150,7 @@ "IGRA", "imageio", "imread", + "imshow", "intc", "interp", "Interquartile", @@ -258,6 +260,7 @@ "SRTM", "SRTMGL", "Stano", + "STFT", "subintervals", "suptitle", "ticklabel", diff --git a/CHANGELOG.md b/CHANGELOG.md index dced5f94c..0f1828872 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,8 @@ Attention: The newest changes should be on top --> ### Added -- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635) +- ENH: Add STFT function to Function class [#620](https://github.com/RocketPy-Team/RocketPy/pull/620) +- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635) ### Changed diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 7b4c2f23a..095b82268 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1007,6 +1007,142 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True): extrapolation="zero", ) + def short_time_fft( + self, + lower, + upper, + sampling_frequency, + window_size, + step_size, + remove_dc=True, + only_positive=True, + ): + r""" + Performs the Short-Time Fourier Transform (STFT) of the Function and + returns the result. The STFT is computed by applying the Fourier + transform to overlapping windows of the Function. + + Parameters + ---------- + lower : float + Lower bound of the time range. + upper : float + Upper bound of the time range. + sampling_frequency : float + Sampling frequency at which to perform the Fourier transform. + window_size : float + Size of the window for the STFT, in seconds. + step_size : float + Step size for the window, in seconds. + remove_dc : bool, optional + If True, the DC component is removed from each window before + computing the Fourier transform. + only_positive: bool, optional + If True, only the positive frequencies are returned. + + Returns + ------- + list[Function] + A list of Functions, each representing the STFT of a window. + + Examples + -------- + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from rocketpy import Function + + Generate a signal with varying frequency: + + >>> T_x, N = 1 / 20 , 1000 # 20 Hz sampling rate for 50 s signal + >>> t_x = np.arange(N) * T_x # time indexes for signal + >>> f_i = 1 * np.arctan((t_x - t_x[N // 2]) / 2) + 5 # varying frequency + >>> signal = np.sin(2 * np.pi * np.cumsum(f_i) * T_x) # the signal + + Create the Function object and perform the STFT: + + >>> time_domain = Function(np.array([t_x, signal]).T) + >>> stft_result = time_domain.short_time_fft( + ... lower=0, + ... upper=50, + ... sampling_frequency=95, + ... window_size=2, + ... step_size=0.5, + ... ) + + Plot the spectrogram: + + >>> Sx = np.abs([window[:, 1] for window in stft_result]) + >>> t_lo, t_hi = t_x[0], t_x[-1] + >>> fig1, ax1 = plt.subplots(figsize=(10, 6)) + >>> im1 = ax1.imshow( + ... Sx.T, + ... origin='lower', + ... aspect='auto', + ... extent=[t_lo, t_hi, 0, 50], + ... cmap='viridis' + ... ) + >>> _ = ax1.set_title(rf"STFT (2$\,s$ Gaussian window, $\sigma_t=0.4\,$s)") + >>> _ = ax1.set( + ... xlabel=f"Time $t$ in seconds", + ... ylabel=f"Freq. $f$ in Hz)", + ... xlim=(t_lo, t_hi) + ... ) + >>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$') + >>> _ = fig1.colorbar(im1, label="Magnitude $|S_x(t, f)|$") + >>> # Shade areas where window slices stick out to the side + >>> for t0_, t1_ in [(t_lo, 1), (49, t_hi)]: + ... _ = ax1.axvspan(t0_, t1_, color='w', linewidth=0, alpha=.2) + >>> # Mark signal borders with vertical line + >>> for t_ in [t_lo, t_hi]: + ... _ = ax1.axvline(t_, color='y', linestyle='--', alpha=0.5) + >>> # Add legend and finalize plot + >>> _ = ax1.legend() + >>> fig1.tight_layout() + >>> # plt.show() # uncomment to show the plot + + References + ---------- + Example adapted from the SciPy documentation: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.ShortTimeFFT.html + """ + # Get the time domain data + sampling_time_step = 1.0 / sampling_frequency + sampling_range = np.arange(lower, upper, sampling_time_step) + sampled_points = self(sampling_range) + samples_per_window = int(window_size * sampling_frequency) + samples_skipped_per_step = int(step_size * sampling_frequency) + stft_results = [] + + max_start = len(sampled_points) - samples_per_window + 1 + + for start in range(0, max_start, samples_skipped_per_step): + windowed_samples = sampled_points[start : start + samples_per_window] + if remove_dc: + windowed_samples -= np.mean(windowed_samples) + fourier_amplitude = np.abs( + np.fft.fft(windowed_samples) / (samples_per_window / 2) + ) + fourier_frequencies = np.fft.fftfreq(samples_per_window, sampling_time_step) + + # Filter to keep only positive frequencies if specified + if only_positive: + positive_indices = fourier_frequencies > 0 + fourier_frequencies = fourier_frequencies[positive_indices] + fourier_amplitude = fourier_amplitude[positive_indices] + + stft_results.append( + Function( + source=np.array([fourier_frequencies, fourier_amplitude]).T, + inputs="Frequency (Hz)", + outputs="Amplitude", + interpolation="linear", + extrapolation="zero", + ) + ) + + return stft_results + def low_pass_filter(self, alpha, file_path=None): """Implements a low pass filter with a moving average filter. This does not mutate the original Function object, but returns a new one with the