From ebbf04d0acf74474f06541444144aa5a4f52387e Mon Sep 17 00:00:00 2001 From: d-schindler <60650591+d-schindler@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:12:30 +0200 Subject: [PATCH] add scale selection method --- src/pygenstability/data_clustering.py | 38 ++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/pygenstability/data_clustering.py b/src/pygenstability/data_clustering.py index f0bece2..764d73a 100644 --- a/src/pygenstability/data_clustering.py +++ b/src/pygenstability/data_clustering.py @@ -8,6 +8,7 @@ from pygenstability.pygenstability import run as pgs_run from pygenstability.plotting import plot_scan as pgs_plot_scan +from pygenstability.optimal_scales import identify_optimal_scales def compute_kNN(D, k=5): @@ -193,6 +194,16 @@ def fit(self, X): self.constructor_kwargs, ) + # store labels of robust partitions + self._postprocess_selected_partitions() + + return self.results_ + + def _postprocess_selected_partitions(self): + """Postprocess selected partitions.""" + + self.labels_ = [] + # store labels of robust partitions for i in self.results_["selected_partitions"]: @@ -201,7 +212,32 @@ def fit(self, X): if not np.allclose(robust_partition, np.zeros(self.adjacency_.shape[0])): self.labels_.append(robust_partition) - return self.results_ + def scale_selection( + self, kernel_size=0.1, window_size=0.1, max_nvi=1, basin_radius=0.01 + ): + """Identify optimal scales.""" + + # transform relative values to absolute values + if kernel_size < 1: + kernel_size = int(kernel_size * self.results_["run_params"]["n_scale"]) + if window_size < 1: + window_size = int(window_size * self.results_["run_params"]["n_scale"]) + if basin_radius < 1: + basin_radius = int(basin_radius * self.results_["run_params"]["n_scale"]) + + # apply scale selection algorithm + self.results_ = identify_optimal_scales( + self.results_, + kernel_size=kernel_size, + window_size=window_size, + max_nvi=max_nvi, + basin_radius=basin_radius, + ) + + # store labels of robust partitions + self._postprocess_selected_partitions() + + return self.labels_ def plot_scan(self): """Plot PyGenStability scan."""