diff --git a/bnpm/spectral.py b/bnpm/spectral.py index b99f759..b2c8085 100644 --- a/bnpm/spectral.py +++ b/bnpm/spectral.py @@ -875,7 +875,7 @@ def spectrogram_magnitude_normalization(S: torch.Tensor, k: float = 0.05): -def ppc(phases, axis=None): +def ppc(phases, axis=-1): """ Computes the pairwise phase consistency (PPC0) for a (set of) vector of phases. Based on Vinck et al. 2010, and the implementation in the FieldTrip @@ -886,6 +886,8 @@ def ppc(phases, axis=None): Args: phases (np.ndarray): Vector of phases in radians. Bound to the range [-pi, pi]. + axis (int): + Axis along which to compute the pairwise phase consistency. Returns: float: @@ -896,7 +898,7 @@ def ppc(phases, axis=None): elif isinstance(phases, np.ndarray): sin, cos, abs, sum = np.sin, np.cos, np.abs, np.sum - N = phases.shape[0] + N = phases.shape[axis] if N < 2: raise ValueError("The input vector must contain at least two phase values.") @@ -907,7 +909,7 @@ def ppc(phases, axis=None): @torch.jit.script -def torch_ppc(phases: torch.Tensor, axis: Optional[List[int]] = None): +def torch_ppc(phases: torch.Tensor, axis: int = -1): """ Exactly the same as ``ppc`` but works with torch.jit.script. Computes the pairwise phase consistency (PPC0) for a (set of) vector of @@ -919,12 +921,14 @@ def torch_ppc(phases: torch.Tensor, axis: Optional[List[int]] = None): Args: phases (np.ndarray): Vector of phases in radians. Bound to the range [-pi, pi]. + axis (int): + Axis along which to compute the pairwise phase consistency. Returns: float: Pairwise phase consistency of the phases. """ - N = phases.shape[0] + N = phases.shape[axis] if N < 2: raise ValueError("The input vector must contain at least two phase values.")