Skip to content

Commit

Permalink
replace type aliases that are incompatible with python 3.7, make valu…
Browse files Browse the repository at this point in the history
…e of 'n_init' KMeans parameter dependent on sklearn version
  • Loading branch information
mamei16 committed Oct 29, 2023
1 parent 4fc7b42 commit 683038b
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
7 changes: 5 additions & 2 deletions pypsf/clustering.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from typing import List


def run_clustering(cycles: list[np.array], n_clusters: int) -> KMeans:
def run_clustering(cycles: List[np.array] or np.array, n_clusters: int) -> KMeans:
"""
Apply K-means clustering to the provided list of cycles.
Args:
Expand All @@ -15,4 +17,5 @@ def run_clustering(cycles: list[np.array], n_clusters: int) -> KMeans:
kmeans (KMeans):
The fitted K-means clustering object
"""
return KMeans(n_clusters=n_clusters, init='random', n_init="auto", random_state=3683475120).fit(np.array(cycles))
return KMeans(n_clusters=n_clusters, init='random',
n_init="auto" if sklearn.__version__ >= "1.2" else 10, random_state=3683475120).fit(np.array(cycles))
6 changes: 4 additions & 2 deletions pypsf/hyperparameter_search.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Tuple

import numpy as np
from sklearn.metrics import silhouette_score, mean_absolute_error

from pypsf.clustering import run_clustering
from pypsf.predict import psf_predict


def optimum_k(data: np.array, k_values: tuple[int]) -> int:
def optimum_k(data: np.array, k_values: Tuple[int]) -> int:
"""
Perform a hyperparameter search using the provided values of 'k' to
determine the number of clusters 'k' that achieve the highest silhouette
Expand Down Expand Up @@ -40,7 +42,7 @@ def optimum_k(data: np.array, k_values: tuple[int]) -> int:
return best_k


def optimum_w(data: np.array, k: int, cycle_length: int, w_values: tuple[int]) -> int:
def optimum_w(data: np.array, k: int, cycle_length: int, w_values: Tuple[int]) -> int:
"""
Perform a hyperparameter search using the provided values of 'w' to
determine the window size that results in the lowest mean absolute error.
Expand Down
3 changes: 2 additions & 1 deletion pypsf/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
from typing import List

import numpy as np


def neighbor_indices(cluster_labels: np.array, w: int) -> list[int]:
def neighbor_indices(cluster_labels: np.array, w: int) -> List[int]:
"""
Take the last 'w' cluster labels and return all matching previous occurrences of this pattern in the rest of the
list of cluster labels (so-called neighbors). The is done by first converting the list of ints
Expand Down
4 changes: 3 additions & 1 deletion pypsf/psf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import numpy as np
from numpy.typing import ArrayLike
from sklearn.linear_model import LinearRegression
Expand Down Expand Up @@ -152,7 +154,7 @@ def predict(self, n_ahead: int) -> np.array:
self.preds = self.postprocessing(preds, orig_n_ahead)
return self.preds

def postprocessing(self, preds: list[np.array], orig_n_ahead: int) -> np.array:
def postprocessing(self, preds: List[np.array], orig_n_ahead: int) -> np.array:
"""
Performs the inverse of 'preprocessing', i.e.:
1. (Optional) Re-add a linear trend from the data if self.detrend is
Expand Down
4 changes: 2 additions & 2 deletions pypsf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ def psf_warn(message: str):
info = inspect.getframeinfo(caller_frame)
orig_formatwarning = warnings.formatwarning
warnings.formatwarning = (lambda message, category,
filename, lineno, _: f"{filename}:{lineno}:{category.__name__}:{message}\n")
filename, lineno, _: f"{filename}:{lineno}:{category.__name__}:{message}\n")
warnings.warn_explicit(message, UserWarning, info.filename, info.lineno)
warnings.formatwarning = orig_formatwarning
warnings.formatwarning = orig_formatwarning

0 comments on commit 683038b

Please sign in to comment.