Skip to content

Commit

Permalink
FEAT add two-stage refine
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewSZhang committed Nov 18, 2024
1 parent 1d007de commit bbb7678
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
run: |
FMT=xml pixi run test-coverage
- name: Upload coverage reports to Codecov
uses: codecov/[email protected].0
uses: codecov/[email protected].2
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Build SDist
Expand Down
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ API Reference
:toctree: generated/

FastCan
refine
ssc
ols

Expand Down
2 changes: 2 additions & 0 deletions fastcan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

from ._fastcan import FastCan
from ._refine import refine
from ._utils import ols, ssc

__all__ = [
"FastCan",
"ssc",
"ols",
"refine",
]
61 changes: 46 additions & 15 deletions fastcan/_fastcan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Feature selection
"""

from copy import deepcopy
from numbers import Integral, Real

import numpy as np
Expand Down Expand Up @@ -66,15 +67,15 @@ class FastCan(SelectorMixin, BaseEstimator):
The h-correlation/eta-cosine of selected features. The order of
the scores is corresponding to the feature selection process.
X_transformed_ : ndarray of shape (n_samples_, n_features), dtype=float, order='F'
X_transformed_ : ndarray of shape (`n_samples_`, n_features), dtype=float, order='F'
Transformed feature matrix.
When h-correlation method is used, n_samples_ = n_samples.
When eta-cosine method is used, n_samples_ = n_features+n_outputs.
When h-correlation method is used, `n_samples_` = n_samples.
When eta-cosine method is used, `n_samples_` = n_features+n_outputs.
y_transformed_ : ndarray of shape (n_samples_, n_outputs), dtype=float, order='F'
y_transformed_ : ndarray of shape (`n_samples_`, n_outputs), dtype=float, order='F'
Transformed target matrix.
When h-correlation method is used, n_samples_ = n_samples.
When eta-cosine method is used, n_samples_ = n_features+n_outputs.
When h-correlation method is used, `n_samples_` = n_samples.
When eta-cosine method is used, `n_samples_` = n_features+n_outputs.
References
----------
Expand Down Expand Up @@ -181,8 +182,26 @@ def fit(self, X, y):
raise ValueError(
"`eta` cannot be True, when n_samples < n_features+n_outputs."
)
indices_include = self._check_indices_params(self.indices_include, n_features)
indices_exclude = self._check_indices_params(self.indices_exclude, n_features)
self.indices_include_ = self._check_indices_params(
self.indices_include, n_features
)
self.indices_exclude_ = self._check_indices_params(
self.indices_exclude, n_features
)
if np.intersect1d(self.indices_include_, self.indices_exclude_).size != 0:
raise ValueError(
"`indices_include` and `indices_exclude` should not have intersection."
)

n_candidates = (
n_features - self.indices_exclude_.size - self.n_features_to_select
)
if n_candidates < 0:
raise ValueError(
"n_features - n_features_to_select - n_exclusions should >= 0."
)
if self.n_features_to_select - self.indices_include_.size < 0:
raise ValueError("n_features_to_select - n_inclusions should >= 0.")

if self.eta:
xy_hstack = np.hstack((X, y))
Expand All @@ -198,16 +217,16 @@ def fit(self, X, y):
self.X_transformed_ = X - X.mean(0)
self.y_transformed_ = orth(y - y.mean(0))

# initiated with -1
indices = np.full(self.n_features_to_select, -1, dtype=np.intc, order="F")
indices[: indices_include.size] = indices_include
scores = np.zeros(self.n_features_to_select, dtype=float, order="F")
mask = np.zeros(n_features, dtype=np.ubyte, order="F")
mask[indices_exclude] = True
indices, scores, mask = _prepare_search(
n_features,
self.n_features_to_select,
self.indices_include_,
self.indices_exclude_,
)

n_threads = _openmp_effective_n_threads()
_forward_search(
X=self.X_transformed_,
X=deepcopy(self.X_transformed_),
V=self.y_transformed_,
t=self.n_features_to_select,
tol=self.tol,
Expand Down Expand Up @@ -259,3 +278,15 @@ def _check_indices_params(self, indices_params, n_features):
def _get_support_mask(self):
check_is_fitted(self)
return self.support_


def _prepare_search(n_features, n_features_to_select, indices_include, indices_exclude):
""" """
# initiated with -1
indices = np.full(n_features_to_select, -1, dtype=np.intc, order="F")
indices[: indices_include.size] = indices_include
scores = np.zeros(n_features_to_select, dtype=float, order="F")
mask = np.zeros(n_features, dtype=np.ubyte, order="F")
mask[indices_exclude] = True

return indices, scores, mask
168 changes: 168 additions & 0 deletions fastcan/_refine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Refine fastcan selection results
"""

from copy import deepcopy
from numbers import Integral

import numpy as np
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
from sklearn.utils._param_validation import Interval, StrOptions, validate_params
from sklearn.utils.validation import check_is_fitted

from ._cancorr_fast import _forward_search # type: ignore
from ._fastcan import FastCan, _prepare_search


@validate_params(
{
"selector": [FastCan],
"drop": [
Interval(Integral, 1, None, closed="left"),
StrOptions({"all"}),
"array-like",
],
"max_iter": [
None,
Interval(Integral, 1, None, closed="left"),
],
"verbose": ["verbose"],
},
prefer_skip_nested_validation=True,
)
def refine(selector, drop=1, max_iter=None, verbose=1):
"""Two-Stage Refining.
In the refining process, the selected features will be dropped, and
the vacancy positions will be refilled from the candidate features.
The processing of a vacany position is refilled after searching all
candidate features is called an `iteration`.
The processing of a vacany position is refilled by a different features
from the dropped one, which increase the SSC of the selected features
is called a `valid iteration`.
Parameters
----------
selector : FastCan
FastCan selector.
drop : int or array-like of shape (n_drops,) or "all", default=1
The number of the selected features dropped for the consequencing
reselection.
max_iter : int, default=None
The maximum number of valid iterations in the refining process.
verbose : int, default=1
The verbosity level.
Returns
-------
indices : ndarray of shape (n_features_to_select,), dtype=int
The indices of the selected features.
scores : ndarray of shape (n_features_to_select,), dtype=float
The h-correlation/eta-cosine of selected features.
References
----------
* Zhang L., Li K., Bai E. W. and Irwin G. W. (2015).
Two-stage orthogonal least squares methods for neural network construction.
IEEE Transactions on Neural Networks and Learning Systems, 26(8), 1608-1621.
Examples
--------
>>> from fastcan import FastCan, refine
>>> X = [[1, 1, 0], [0.01, 0, 0], [-1, 0, 1], [0, 0, 0]]
>>> y = [1, 0, -1, 0]
>>> selector = FastCan(2, verbose=0).fit(X, y)
>>> print(f"Indices: {selector.indices_}", f", SSC: {selector.scores_.sum():.5f}")
Indices: [0 1] , SSC: 0.99998
>>> indices, scores = refine(selector, drop=1, verbose=0)
>>> print(f"Indices: {indices}", f", SSC: {scores.sum():.5f}")
Indices: [1 2] , SSC: 1.00000
"""
check_is_fitted(selector)
X_transformed_ = deepcopy(selector.X_transformed_)
n_features = selector.n_features_in_
n_features_to_select = selector.n_features_to_select
indices_include = selector.indices_include_
indices_exclude = selector.indices_exclude_

n_inclusions = indices_include.size
n_selections = n_features_to_select - n_inclusions

if drop == "all":
drop = np.arange(1, n_selections)
else:
drop = np.atleast_1d(drop).astype(int)

if (drop.max() >= n_selections) or (drop.min() < 1):
raise ValueError(
"`drop` should be between `1<=drop<n_features_to_select-n_inclusions`, "
f"but got drop={drop} and n_selections={n_selections}."
)

if max_iter is None:
max_iter = np.inf

n_iters = 0
n_valid_iters = 0
best_scores = selector.scores_
best_indices = selector.indices_
best_ssc = selector.scores_.sum()
indices_temp = best_indices
for drop_n in drop:
i = 0
while i < n_features:
rolled_indices = np.r_[
indices_include, np.roll(indices_temp[n_inclusions:], -1)
]
indices, scores, mask = _prepare_search(
n_features,
n_features_to_select,
rolled_indices[:-drop_n],
indices_exclude,
)
n_threads = _openmp_effective_n_threads()
_forward_search(
X=X_transformed_,
V=selector.y_transformed_,
t=selector.n_features_to_select,
tol=selector.tol,
num_threads=n_threads,
verbose=0,
mask=mask,
indices=indices,
scores=scores,
)

if (scores.sum() > best_ssc) and (set(indices) != set(best_indices)):
i = 0
n_valid_iters += 1
best_scores = scores
best_indices = indices
best_ssc = scores.sum()
else:
i += 1

indices_temp = indices
n_iters += 1
if verbose == 1:
print(
f"No. of iterations: {n_iters}, "
f"No. of valid iterations {n_valid_iters}, "
f"SSC: {best_scores.sum():.5f}",
end="\r",
)

if n_iters >= max_iter:
if verbose == 1:
print()
return best_indices, best_scores

if verbose == 1:
print()
return best_indices, best_scores
23 changes: 23 additions & 0 deletions tests/test_fastcan.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,20 @@ def test_raise_errors():
indices_include=[[0]]
)

selector_include_exclude_intersect = FastCan(
n_features_to_select=n_features,
indices_include=[0, 1],
indices_exclude=[1, 2],
)
selector_n_candidates = FastCan(
n_features_to_select=n_features,
indices_exclude=[1, 2],
)
selector_too_many_inclusions = FastCan(
n_features_to_select=2,
indices_include=[1, 2, 3],
)

with pytest.raises(ValueError, match=r"n_features_to_select .*"):
selector_n_select.fit(X, y)

Expand All @@ -214,6 +228,15 @@ def test_raise_errors():
with pytest.raises(ValueError, match=r"`eta` cannot be True, .*"):
selector_eta_for_small_size_samples.fit(X, y)

with pytest.raises(ValueError, match=r"`indices_include` and `indices_exclu.*"):
selector_include_exclude_intersect.fit(X, y)

with pytest.raises(ValueError, match=r"n_features - n_features_to_select - n_e.*"):
selector_n_candidates.fit(X, y)

with pytest.raises(ValueError, match=r"n_features_to_select - n_inclusions sho.*"):
selector_too_many_inclusions.fit(X, y)


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_cython_errors():
Expand Down
Loading

0 comments on commit bbb7678

Please sign in to comment.