diff --git a/src/pygenstability/data_clustering.py b/src/pygenstability/data_clustering.py index 764d73a..e00c50b 100644 --- a/src/pygenstability/data_clustering.py +++ b/src/pygenstability/data_clustering.py @@ -65,6 +65,14 @@ def __init__( def fit(self, X): """Construct graph from samples-by-features matrix.""" + # if precomputed take X as adjacency matrix + if self.method == "precomputed": + assert ( + X.shape[0] == X.shape[1] + ), "Precomputed matrix should be a square matrix." + self.adjacency_ = X + return self.adjacency_ + # compute normalised distance matrix D = squareform(pdist(X, metric=self.metric)) D_norm = D / np.amax(D)