Skip to content

Commit

Permalink
Merge pull request #19 from MatthewSZhang/exclude
Browse files Browse the repository at this point in the history
FEAT add indices_exclude params
  • Loading branch information
MatthewSZhang authored Nov 15, 2024
2 parents 6eda8c7 + a74154a commit 0840101
Show file tree
Hide file tree
Showing 4 changed files with 807 additions and 797 deletions.
11 changes: 6 additions & 5 deletions fastcan/_cancorr_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ from cython.parallel import prange
from scipy.linalg.cython_blas cimport isamax, idamax
from sklearn.utils._cython_blas cimport ColMajor, NoTrans
from sklearn.utils._cython_blas cimport _dot, _scal, _nrm2, _gemm, _axpy
from sklearn.utils._typedefs cimport int32_t
from sklearn.utils._typedefs cimport int32_t, uint8_t


@final
cdef int _bsum(
const bint* x,
const uint8_t* x,
int n,
) noexcept nogil:
"""Computes the sum of the vector of bool elements.
Expand Down Expand Up @@ -129,6 +129,7 @@ cpdef int _forward_search(
floating tol, # IN
int num_threads, # IN
int verbose, # IN
uint8_t[::1] mask, # IN/TEMP
int32_t[::1] indices, # OUT
floating[::1] scores, # OUT
) except -1 nogil:
Expand All @@ -140,6 +141,7 @@ cpdef int _forward_search(
is orthonormal to selected features and M.
t : Non-negative integer. The number of features to be selected.
tol : Tolerance for linear dependence check.
mask (n_features, ) Mask for candidate features.
indices: (t,) The indices vector of selected features, initiated with -1.
scores: (t,) The h-correlation/eta-cosine of selected features.
"""
Expand All @@ -149,7 +151,6 @@ cpdef int _forward_search(
# OpenMP (in Windows) requires signed integral for prange
int n_features = X.shape[1]
floating* r2 = <floating*> malloc(sizeof(floating) * n_features)
bint* mask = <bint*> malloc(sizeof(bint) * n_features)
floating g, ssc = 0.0
int i, j
int index = -1
Expand All @@ -160,7 +161,8 @@ cpdef int _forward_search(
if i == 0:
# Preprocessing
for j in range(n_features):
mask[j] = _normv(&X[0, j], n_samples)
if not mask[j]:
mask[j] = _normv(&X[0, j], n_samples)
else:
mask[index] = True
r2[index] = 0
Expand Down Expand Up @@ -204,5 +206,4 @@ cpdef int _forward_search(
with gil:
print()
free(r2)
free(mask)
return 0
124 changes: 60 additions & 64 deletions fastcan/_fastcan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class FastCan(SelectorMixin, BaseEstimator):
indices_include : array-like of shape (n_inclusions,), default=None
The indices of the prerequisite features.
indices_exclude : array-like of shape (n_exclusions,), default=None
The indices of the excluded features.
eta : bool, default=False
Whether to use eta-cosine method.
Expand Down Expand Up @@ -63,6 +66,16 @@ class FastCan(SelectorMixin, BaseEstimator):
The h-correlation/eta-cosine of selected features. The order of
the scores is corresponding to the feature selection process.
X_transformed_ : ndarray of shape (n_samples_, n_features), dtype=float, order='F'
Transformed feature matrix.
When h-correlation method is used, n_samples_ = n_samples.
When eta-cosine method is used, n_samples_ = n_features+n_outputs.
y_transformed_ : ndarray of shape (n_samples_, n_outputs), dtype=float, order='F'
Transformed target matrix.
When h-correlation method is used, n_samples_ = n_samples.
When eta-cosine method is used, n_samples_ = n_features+n_outputs.
References
----------
* Zhang, S., & Lang, Z. Q. (2022).
Expand All @@ -88,6 +101,7 @@ class FastCan(SelectorMixin, BaseEstimator):
Interval(Integral, 1, None, closed="left"),
],
"indices_include": [None, "array-like"],
"indices_exclude": [None, "array-like"],
"eta": ["boolean"],
"tol": [Interval(Real, 0, None, closed="neither")],
"verbose": ["verbose"],
Expand All @@ -97,12 +111,14 @@ def __init__(
self,
n_features_to_select=1,
indices_include=None,
indices_exclude=None,
eta=False,
tol=0.01,
verbose=1,
):
self.n_features_to_select = n_features_to_select
self.indices_include = indices_include
self.indices_exclude = indices_exclude
self.eta = eta
self.tol = tol
self.verbose = verbose
Expand Down Expand Up @@ -152,17 +168,6 @@ def fit(self, X, y):
# [:, np.newaxis] that does not.
y = y.reshape(-1, 1)

# indices_include
if self.indices_include is None:
indices_include = np.zeros(0, dtype=int)
else:
indices_include = check_array(
self.indices_include,
ensure_2d=False,
dtype=int,
ensure_min_samples=0,
)

n_samples, n_features = X.shape
n_outputs = y.shape[1]

Expand All @@ -172,29 +177,12 @@ def fit(self, X, y):
f"must be <= n_features {n_features}."
)

if indices_include.ndim != 1:
raise ValueError(
f"Found indices_include with dim {indices_include.ndim}, "
"but expected == 1."
)

if indices_include.size >= n_features:
raise ValueError(
f"n_inclusions {indices_include.size} must "
f"be < n_features {n_features}."
)

if np.any((indices_include < 0) | (indices_include >= n_features)):
raise ValueError(
"Out of bounds. "
f"All items in indices_include should be in [0, {n_features}). "
f"But got indices_include = {indices_include}."
)

if (n_samples < n_features + n_outputs) and self.eta:
raise ValueError(
"`eta` cannot be True, when n_samples < n_features+n_outputs."
)
indices_include = self._check_indices_params(self.indices_include, n_features)
indices_exclude = self._check_indices_params(self.indices_exclude, n_features)

if self.eta:
xy_hstack = np.hstack((X, y))
Expand All @@ -204,23 +192,28 @@ def fit(self, X, y):
)[1:]
qxy_transformed = singular_values.reshape(-1, 1) * unitary_arrays
qxy_transformed = np.asfortranarray(qxy_transformed)
X_transformed = qxy_transformed[:, :n_features]
y_transformed = orth(qxy_transformed[:, n_features:])
self.X_transformed_ = qxy_transformed[:, :n_features]
self.y_transformed_ = orth(qxy_transformed[:, n_features:])
else:
X_transformed = X - X.mean(0)
y_transformed = orth(y - y.mean(0))
self.X_transformed_ = X - X.mean(0)
self.y_transformed_ = orth(y - y.mean(0))

# initiated with -1
indices = np.full(self.n_features_to_select, -1, dtype=np.intc, order="F")
indices[: indices_include.size] = indices_include
scores = np.zeros(self.n_features_to_select, dtype=float, order="F")
mask = np.zeros(n_features, dtype=np.ubyte, order="F")
mask[indices_exclude] = True

indices, scores = self._prepare_data(
indices_include,
)
n_threads = _openmp_effective_n_threads()
_forward_search(
X=X_transformed,
V=y_transformed,
X=self.X_transformed_,
V=self.y_transformed_,
t=self.n_features_to_select,
tol=self.tol,
num_threads=n_threads,
verbose=self.verbose,
mask=mask,
indices=indices,
scores=scores,
)
Expand All @@ -231,34 +224,37 @@ def fit(self, X, y):
self.scores_ = scores
return self

def _prepare_data(self, indices_include):
"""Prepare data for _forward_search()
When h-correlation method is used, n_samples_ = n_samples.
When eta-cosine method is used, n_samples_ = n_features+n_outputs.
Parameters
----------
indices_include : array-like of shape (n_inclusions,), dtype=int
The indices of the prerequisite features.
def _check_indices_params(self, indices_params, n_features):
"""Check indices_include or indices_exclude."""
if indices_params is None:
indices_params = np.zeros(0, dtype=int)
else:
indices_params = check_array(
indices_params,
ensure_2d=False,
dtype=int,
ensure_min_samples=0,
)

Returns
-------
mask : ndarray of shape (n_features,), dtype=np.ubyte, order='F'
Mask for invalid candidate features.
The data type is unsigned char.
if indices_params.ndim != 1:
raise ValueError(
f"Found indices_params with dim {indices_params.ndim}, "
"but expected == 1."
)

indices: ndarray of shape (n_features_to_select,), dtype=np.intc, order='F'
The indices vector of selected features, initiated with -1.
The data type is signed int.
if indices_params.size >= n_features:
raise ValueError(
f"The number of indices in indices_params {indices_params.size} must "
f"be < n_features {n_features}."
)

scores: ndarray of shape (n_features_to_select,), dtype=float, order='F'
The h-correlation/eta-cosine of selected features.
"""
# initiated with -1
indices = np.full(self.n_features_to_select, -1, dtype=np.intc, order="F")
indices[: indices_include.size] = indices_include
scores = np.zeros(self.n_features_to_select, dtype=float, order="F")
return indices, scores
if np.any((indices_params < 0) | (indices_params >= n_features)):
raise ValueError(
"Out of bounds. "
f"All items in indices_params should be in [0, {n_features}). "
f"But got indices_params = {indices_params}."
)
return indices_params

def _get_support_mask(self):
check_is_fitted(self)
Expand Down
Loading

0 comments on commit 0840101

Please sign in to comment.