diff --git a/src/neurostatslib/solver.py b/src/neurostatslib/solver.py index 05b043f5..d3f8ac74 100644 --- a/src/neurostatslib/solver.py +++ b/src/neurostatslib/solver.py @@ -368,8 +368,6 @@ class ProxGradientSolver(Solver, abc.ABC): ---------- allowed_optimizers : List[...,str] A list of optimizer names that are allowed to be used with this solver. - mask : Optional[Union[NDArray, jnp.ndarray]] - An optional mask array for element-wise operations. Shape (n_groups, n_features) """ allowed_optimizers = ["ProximalGradient"] @@ -379,65 +377,11 @@ def __init__( solver_name: str, solver_kwargs: Optional[dict] = None, regularizer_strength: float = 1.0, - mask: Optional[Union[NDArray, jnp.ndarray]] = None, + **kwargs ): super().__init__(solver_name, solver_kwargs=solver_kwargs) - self.mask = mask self.regularizer_strength = regularizer_strength - @property - def mask(self): - return self._mask - - @mask.setter - def mask(self, mask: jnp.ndarray): - self._check_mask(mask) - self._mask = mask - - @staticmethod - def _check_mask(mask: jnp.ndarray): - """ - Validate the mask array. - - This method ensures the mask adheres to requirements: - - It should be 2-dimensional. - - Each element must be either 0 or 1. - - Each feature should belong to only one group. - - The mask should not be empty. - - The mask is an array of float type. - - Raises - ------ - ValueError - If any of the above conditions are not met. - """ - if mask.ndim != 2: - raise ValueError( - "`mask` must be 2-dimensional. " - f"{mask.ndim} dimensional mask provided instead!" - ) - - if mask.shape[0] == 0: - raise ValueError(f"Empty mask provided! Mask has shape {mask.shape}.") - - if jnp.any((mask != 1) & (mask != 0)): - raise ValueError("Mask elements be 0s and 1s!") - - if mask.sum() == 0: - raise ValueError("Empty mask provided!") - - if jnp.any(mask.sum(axis=0) > 1): - raise ValueError( - "Incorrect group assignment. Some of the features are assigned " - "to more then one group." - ) - - if not jnp.issubdtype(mask.dtype, jnp.floating): - raise ValueError( - "Mask should be a floating point jnp.ndarray. " - f"Data type {mask.dtype} provided instead!" - ) - @abc.abstractmethod def get_prox_operator( self, @@ -499,12 +443,10 @@ def __init__( solver_name: str = "ProximalGradient", solver_kwargs: Optional[dict] = None, regularizer_strength: float = 1.0, - mask: Optional[Union[NDArray, jnp.ndarray]] = None, ): super().__init__( solver_name, - solver_kwargs=solver_kwargs, - mask=mask, + solver_kwargs=solver_kwargs ) self.regularizer_strength = regularizer_strength @@ -564,11 +506,62 @@ def __init__( super().__init__( solver_name, solver_kwargs=solver_kwargs, - mask=mask, ) self.regularizer_strength = regularizer_strength - mask = jnp.asarray(mask) + self.mask = jnp.asarray(mask) + + @property + def mask(self): + return self._mask + + @mask.setter + def mask(self, mask: jnp.ndarray): self._check_mask(mask) + self._mask = mask + + @staticmethod + def _check_mask(mask: jnp.ndarray): + """ + Validate the mask array. + + This method ensures the mask adheres to requirements: + - It should be 2-dimensional. + - Each element must be either 0 or 1. + - Each feature should belong to only one group. + - The mask should not be empty. + - The mask is an array of float type. + + Raises + ------ + ValueError + If any of the above conditions are not met. + """ + if mask.ndim != 2: + raise ValueError( + "`mask` must be 2-dimensional. " + f"{mask.ndim} dimensional mask provided instead!" + ) + + if mask.shape[0] == 0: + raise ValueError(f"Empty mask provided! Mask has shape {mask.shape}.") + + if jnp.any((mask != 1) & (mask != 0)): + raise ValueError("Mask elements be 0s and 1s!") + + if mask.sum() == 0: + raise ValueError("Empty mask provided!") + + if jnp.any(mask.sum(axis=0) > 1): + raise ValueError( + "Incorrect group assignment. Some of the features are assigned " + "to more then one group." + ) + + if not jnp.issubdtype(mask.dtype, jnp.floating): + raise ValueError( + "Mask should be a floating point jnp.ndarray. " + f"Data type {mask.dtype} provided instead!" + ) def get_prox_operator( self,