diff --git a/pyroomacoustics/bss/auxiva.py b/pyroomacoustics/bss/auxiva.py index 8825e6dc..73c6229c 100644 --- a/pyroomacoustics/bss/auxiva.py +++ b/pyroomacoustics/bss/auxiva.py @@ -227,7 +227,7 @@ def demix(Y, X, W): ) WV = np.matmul(W_hat, V) - W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, s])) + W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, [s]]))[..., 0] # normalize denom = np.matmul( diff --git a/pyroomacoustics/bss/fastmnmf.py b/pyroomacoustics/bss/fastmnmf.py index 6caf833e..05898c78 100644 --- a/pyroomacoustics/bss/fastmnmf.py +++ b/pyroomacoustics/bss/fastmnmf.py @@ -175,8 +175,8 @@ def separate(): try: tmp_FM = np.linalg.solve( - np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, m] - ) + np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, :, [m]] + )[..., 0] except np.linalg.LinAlgError: # If Gaussian elimination fails due to a singlular matrix, we # switch to the pseudo-inverse solution diff --git a/pyroomacoustics/bss/fastmnmf2.py b/pyroomacoustics/bss/fastmnmf2.py index 35edc102..5d310733 100644 --- a/pyroomacoustics/bss/fastmnmf2.py +++ b/pyroomacoustics/bss/fastmnmf2.py @@ -165,8 +165,8 @@ def separate(): np.einsum("ftij, ft -> fij", XX_FTMM, 1 / Y_FTM[..., m]) / n_frames ) tmp_FM = np.linalg.solve( - np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, m] - ) + np.matmul(Q_FMM, V_FMM), np.eye(n_chan)[None, :, [m]] + )[..., 0] Q_FMM[:, m] = ( tmp_FM / np.sqrt( diff --git a/pyroomacoustics/bss/ilrma.py b/pyroomacoustics/bss/ilrma.py index 1bd27840..f2d0ed15 100644 --- a/pyroomacoustics/bss/ilrma.py +++ b/pyroomacoustics/bss/ilrma.py @@ -159,7 +159,7 @@ def demix(Y, X, W): C = np.matmul((X * iR[s, :, None, :]), np.conj(X.swapaxes(1, 2))) / n_frames WV = np.matmul(W, C) - W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, s])) + W[:, s, :] = np.conj(np.linalg.solve(WV, eyes[:, :, [s]]))[..., 0] # normalize denom = np.matmul( diff --git a/pyroomacoustics/bss/sparseauxiva.py b/pyroomacoustics/bss/sparseauxiva.py index 37d5fdf0..d9fb8017 100644 --- a/pyroomacoustics/bss/sparseauxiva.py +++ b/pyroomacoustics/bss/sparseauxiva.py @@ -148,7 +148,7 @@ def demixsparse(Y, X, S, W): W_H = np.conj(np.swapaxes(W, 1, 2)) WV = np.matmul(W_H, V[:, s, :, :]) rhs = I[None, :, s][[0] * WV.shape[0], :] - W[:, :, s] = np.linalg.solve(WV, rhs) + W[:, :, s] = np.linalg.solve(WV, rhs[..., None])[..., 0] # normalize P1 = np.conj(W[:, :, s])