Skip to content

Commit

Permalink
removed mask from proximal operator
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Nov 1, 2023
1 parent c340dd2 commit 448d370
Showing 1 changed file with 55 additions and 62 deletions.
117 changes: 55 additions & 62 deletions src/neurostatslib/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ class ProxGradientSolver(Solver, abc.ABC):
----------
allowed_optimizers : List[...,str]
A list of optimizer names that are allowed to be used with this solver.
mask : Optional[Union[NDArray, jnp.ndarray]]
An optional mask array for element-wise operations. Shape (n_groups, n_features)
"""

allowed_optimizers = ["ProximalGradient"]
Expand All @@ -379,65 +377,11 @@ def __init__(
solver_name: str,
solver_kwargs: Optional[dict] = None,
regularizer_strength: float = 1.0,
mask: Optional[Union[NDArray, jnp.ndarray]] = None,
**kwargs
):
super().__init__(solver_name, solver_kwargs=solver_kwargs)
self.mask = mask
self.regularizer_strength = regularizer_strength

@property
def mask(self):
return self._mask

@mask.setter
def mask(self, mask: jnp.ndarray):
self._check_mask(mask)
self._mask = mask

@staticmethod
def _check_mask(mask: jnp.ndarray):
"""
Validate the mask array.
This method ensures the mask adheres to requirements:
- It should be 2-dimensional.
- Each element must be either 0 or 1.
- Each feature should belong to only one group.
- The mask should not be empty.
- The mask is an array of float type.
Raises
------
ValueError
If any of the above conditions are not met.
"""
if mask.ndim != 2:
raise ValueError(
"`mask` must be 2-dimensional. "
f"{mask.ndim} dimensional mask provided instead!"
)

if mask.shape[0] == 0:
raise ValueError(f"Empty mask provided! Mask has shape {mask.shape}.")

if jnp.any((mask != 1) & (mask != 0)):
raise ValueError("Mask elements be 0s and 1s!")

if mask.sum() == 0:
raise ValueError("Empty mask provided!")

if jnp.any(mask.sum(axis=0) > 1):
raise ValueError(
"Incorrect group assignment. Some of the features are assigned "
"to more then one group."
)

if not jnp.issubdtype(mask.dtype, jnp.floating):
raise ValueError(
"Mask should be a floating point jnp.ndarray. "
f"Data type {mask.dtype} provided instead!"
)

@abc.abstractmethod
def get_prox_operator(
self,
Expand Down Expand Up @@ -499,12 +443,10 @@ def __init__(
solver_name: str = "ProximalGradient",
solver_kwargs: Optional[dict] = None,
regularizer_strength: float = 1.0,
mask: Optional[Union[NDArray, jnp.ndarray]] = None,
):
super().__init__(
solver_name,
solver_kwargs=solver_kwargs,
mask=mask,
solver_kwargs=solver_kwargs
)
self.regularizer_strength = regularizer_strength

Expand Down Expand Up @@ -564,11 +506,62 @@ def __init__(
super().__init__(
solver_name,
solver_kwargs=solver_kwargs,
mask=mask,
)
self.regularizer_strength = regularizer_strength
mask = jnp.asarray(mask)
self.mask = jnp.asarray(mask)

@property
def mask(self):
return self._mask

@mask.setter
def mask(self, mask: jnp.ndarray):
self._check_mask(mask)
self._mask = mask

@staticmethod
def _check_mask(mask: jnp.ndarray):
"""
Validate the mask array.
This method ensures the mask adheres to requirements:
- It should be 2-dimensional.
- Each element must be either 0 or 1.
- Each feature should belong to only one group.
- The mask should not be empty.
- The mask is an array of float type.
Raises
------
ValueError
If any of the above conditions are not met.
"""
if mask.ndim != 2:
raise ValueError(
"`mask` must be 2-dimensional. "
f"{mask.ndim} dimensional mask provided instead!"
)

if mask.shape[0] == 0:
raise ValueError(f"Empty mask provided! Mask has shape {mask.shape}.")

if jnp.any((mask != 1) & (mask != 0)):
raise ValueError("Mask elements be 0s and 1s!")

if mask.sum() == 0:
raise ValueError("Empty mask provided!")

if jnp.any(mask.sum(axis=0) > 1):
raise ValueError(
"Incorrect group assignment. Some of the features are assigned "
"to more then one group."
)

if not jnp.issubdtype(mask.dtype, jnp.floating):
raise ValueError(
"Mask should be a floating point jnp.ndarray. "
f"Data type {mask.dtype} provided instead!"
)

def get_prox_operator(
self,
Expand Down

0 comments on commit 448d370

Please sign in to comment.