Skip to content

Commit

Permalink
update to main
Browse files Browse the repository at this point in the history
  • Loading branch information
pranmod01 committed Oct 18, 2024
2 parents e56e475 + 64ea184 commit 5d4617b
Show file tree
Hide file tree
Showing 20 changed files with 1,576 additions and 374 deletions.
29 changes: 26 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ on:
- main

jobs:
tox:
tox_tests:
if: ${{ !github.event.pull_request.draft }}
strategy:
matrix:
Expand All @@ -38,7 +38,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
pip install tox
- name: Run tox
Expand All @@ -49,6 +48,27 @@ jobs:
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

tox_check:
if: ${{ !github.event.pull_request.draft }}
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4 # Use v4 for compatibility with pyproject.toml
with:
python-version: 3.12
cache: pip

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox
- name: Run tox
run: tox -e check

prevent_docs_absolute_links:
runs-on: ubuntu-latest
steps:
Expand All @@ -63,7 +83,10 @@ jobs:

check:
if: ${{ !github.event.pull_request.draft }}
needs: tox
needs:
- tox_tests
- prevent_docs_absolute_links
- tox_check
runs-on: ubuntu-latest
steps:
- name: Decide whether all tests and notebooks succeeded
Expand Down
26 changes: 17 additions & 9 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ See below for details on how to [add tests](#adding-tests) and properly [documen

Lastly, you should make sure that the existing tests all run successfully and that the codebase is formatted properly:

> [!TIP]
> The [NeMoS GitHub action](.github/workflows/ci.yml) runs tests and some additional style checks in an isolated environment using [`tox`](https://tox.wiki/en/). `tox` is not included in our optional dependencies, so if you want to replicate the action workflow locally, you need to install `tox` via pip and then run it. From the package directory:
> ```sh
> pip install tox
> tox -e check,py
> ```
> This will execute `tox` with a Python version that matches your local environment. If the above passes, then the Github action will pass and your PR is mergeable
>
> You can also use `tox` to use `black` and `isort` to try and fix your code if either of those are failing. To do so, run `tox -e fix`
>
> `tox` configurations can be found in the [`tox.ini`](tox.ini) file.
```bash
# run tests and make sure they all pass
pytest tests/
Expand All @@ -105,7 +118,10 @@ pytest --doctest-modules src/nemos/
# format the code base
black src/
isort src
isort src --profile=black
isort docs/how_to_guide --profile=black
isort docs/background --profile=black
isort docs/tutorials --profile=black
flake8 --config=tox.ini src
```
Expand All @@ -129,14 +145,6 @@ changes.
Additionally, every PR to `main` or `development` will automatically run linters and tests through a [GitHub action](https://docs.github.com/en/actions). Merges can happen only when all check passes.
> [!NOTE]
> The [NeMoS GitHub action](.github/workflows/ci.yml) runs tests in an isolated environment using [`tox`](https://tox.wiki/en/). `tox` is not included in our optional dependencies, so if you want to replicate the action workflow locally, you need to install `tox` via pip and then run it. From the package directory:
> ```sh
> pip install tox
> tox -e py
> ```
> This will execute `tox` with a Python version that matches your local environment. `tox` configurations can be found in the [`tox.ini`](tox.ini) file.
Once your changes are integrated, you will be added as a GitHub contributor and as one of the authors of the package. Thank you for being part of `nemos`!
### Style Guide
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dev = [
"isort", # Import sorter
"pip-tools", # Dependency management
"pytest", # Testing framework
"pytest-xdist", # Parallelize pytest
"flake8", # Code linter
"coverage", # Test coverage measurement
"pytest-cov", # Test coverage plugin for pytest
Expand Down Expand Up @@ -112,6 +113,7 @@ profile = "black"
# Configure pytest
[tool.pytest.ini_options]
testpaths = ["tests"] # Specify the directory where test files are located
addopts = "-n auto"

[tool.coverage.run]
omit = [
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Base:
Additionally, it has methods for selecting target devices and sending arrays to them.
"""

def get_params(self, deep=True):
def get_params(self, deep=True) -> dict:
"""
From scikit-learn, get parameters by inspecting init.
Expand Down
37 changes: 30 additions & 7 deletions src/nemos/base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ def solver_run(self) -> Union[None, SolverRun]:
"""
return self._solver_run

def set_params(self, **params: Any):
"""Manage warnings in case of multiple parameter settings."""
# if both regularizer and regularizer_strength are set, then only
# warn in case the strength is not expected for the regularizer type
if "regularizer" in params and "regularizer_strength" in params:
reg = params.pop("regularizer")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="Caution: regularizer strength.*"
"|Unused parameter `regularizer_strength`.*",
)
super().set_params(regularizer=reg)

return super().set_params(**params)

@property
def regularizer(self) -> Union[None, Regularizer]:
"""Getter for the regularizer attribute."""
Expand Down Expand Up @@ -171,19 +188,16 @@ def regularizer_strength(self) -> float:

@regularizer_strength.setter
def regularizer_strength(self, strength: Union[float, None]):
# if using unregularized, strength will be None no matter what
if isinstance(self._regularizer, UnRegularized):
self._regularizer_strength = None
# check regularizer strength
elif strength is None:
if strength is None and not isinstance(self._regularizer, UnRegularized):
warnings.warn(
UserWarning(
"Caution: regularizer strength has not been set. Defaulting to 1.0. Please see "
"the documentation for best practices in setting regularization strength."
)
)
self._regularizer_strength = 1.0
else:
strength = 1.0
elif strength is not None:
try:
# force conversion to float to prevent weird GPU issues
strength = float(strength)
Expand All @@ -192,7 +206,16 @@ def regularizer_strength(self, strength: Union[float, None]):
raise ValueError(
f"Could not convert the regularizer strength: {strength} to a float."
)
self._regularizer_strength = strength
if isinstance(self._regularizer, UnRegularized):
warnings.warn(
UserWarning(
"Unused parameter `regularizer_strength` for UnRegularized GLM. "
"The regularizer strength parameter is not required and won't be used when the regularizer "
"is set to UnRegularized."
)
)

self._regularizer_strength = strength

@property
def solver_name(self) -> str:
Expand Down
117 changes: 90 additions & 27 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def __sklearn_clone__(self) -> TransformerBasis:
return cloned_obj

def set_params(self, **parameters) -> TransformerBasis:
r"""
"""
Set TransformerBasis parameters.
When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly
Expand Down Expand Up @@ -497,54 +497,67 @@ def __init__(
) -> None:
self.n_basis_funcs = n_basis_funcs
self._n_input_dimensionality = 0
self._check_n_basis_min()
self._conv_kwargs = kwargs
self.bounds = bounds

# check mode
if mode not in ["conv", "eval"]:
raise ValueError(
f"`mode` should be either 'conv' or 'eval'. '{mode}' provided instead!"
)
if mode == "conv":
if window_size is None:
raise ValueError(
"If the basis is in `conv` mode, you must provide a window_size!"
)
elif not (isinstance(window_size, int) and window_size > 0):
raise ValueError(
f"`window_size` must be a positive integer. {window_size} provided instead!"
)
if bounds is not None:
raise ValueError("`bounds` should only be set when `mode=='eval'`.")
else:
if kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
)

self._window_size = window_size
self._mode = mode
self.window_size = window_size
self.bounds = bounds

if mode == "eval" and kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
)

self.kernel_ = None
self._identifiability_constraints = False

@property
def n_basis_funcs(self):
return self._n_basis_funcs

@n_basis_funcs.setter
def n_basis_funcs(self, value):
orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None))
self._n_basis_funcs = value
try:
self._check_n_basis_min()
except ValueError as e:
self._n_basis_funcs = orig_n_basis
raise e

@property
def bounds(self):
return self._bounds

@bounds.setter
def bounds(self, values: Union[None, Tuple[float, float]]):
"""Setter for bounds."""

if values is not None and self.mode == "conv":
raise ValueError("`bounds` should only be set when `mode=='eval'`.")

if values is not None and len(values) != 2:
raise ValueError(
f"The provided `bounds` must be of length two. Length {len(values)} provided instead!"
)

# convert to float and store
try:
self._bounds = values if values is None else tuple(map(float, values))
except (ValueError, TypeError):
raise TypeError("Could not convert `bounds` to float.")

if values is not None and values[1] <= values[0]:
raise ValueError(
f"Invalid bound {values}. Lower bound is greater or equal than the upper bound."
)

@property
def mode(self):
return self._mode
Expand All @@ -553,6 +566,28 @@ def mode(self):
def window_size(self):
return self._window_size

@window_size.setter
def window_size(self, window_size):
"""Setter for the window size parameter."""
if self.mode == "eval":
if window_size:
raise ValueError(
"If basis is in `mode=='eval'`, `window_size` should be None."
)

else:
if window_size is None:
raise ValueError(
"If the basis is in `conv` mode, you must provide a window_size!"
)

elif not (isinstance(window_size, int) and window_size > 0):
raise ValueError(
f"`window_size` must be a positive integer. {window_size} provided instead!"
)

self._window_size = window_size

@property
def identifiability_constraints(self):
return self._identifiability_constraints
Expand Down Expand Up @@ -1016,7 +1051,7 @@ def to_transformer(self) -> TransformerBasis:
>>> # load some data
>>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
>>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer()
>>> glm = nmo.glm.GLM(regularizer="Ridge")
>>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.)
>>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
>>> param_grid = dict(
... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
Expand Down Expand Up @@ -1296,10 +1331,34 @@ def __init__(
bounds=bounds,
**kwargs,
)

self._n_input_dimensionality = 1
if self.order < 1:

@property
def order(self):
return self._order

@order.setter
def order(self, value):
"""Setter for the order parameter."""

if value < 1:
raise ValueError("Spline order must be positive!")

# Set to None only the first time the setter is called.
orig_order = copy.deepcopy(getattr(self, "_order", None))

# Set the order
self._order = value

# If the order was already initialized, re-check basis
if orig_order is not None:
try:
self._check_n_basis_min()
except ValueError as e:
self._order = orig_order
raise e

def _generate_knots(
self,
sample_pts: NDArray,
Expand Down Expand Up @@ -2041,19 +2100,23 @@ def __init__(
# The samples are scaled appropriately in the self._transform_samples which scales
# and applies the log-stretch, no additional transform is needed.
self._rescale_samples = False
if time_scaling is None:
time_scaling = 50.0

self.time_scaling = time_scaling
self.enforce_decay_to_zero = enforce_decay_to_zero
if time_scaling is None:
self._time_scaling = 50.0
else:
self._check_time_scaling(time_scaling)
self._time_scaling = time_scaling

@property
def time_scaling(self):
"""Getter property for time_scaling."""
return self._time_scaling

@time_scaling.setter
def time_scaling(self, time_scaling):
"""Setter property for time_scaling."""
self._check_time_scaling(time_scaling)
self._time_scaling = time_scaling

@staticmethod
def _check_time_scaling(time_scaling: float) -> None:
if time_scaling <= 0:
Expand Down
Loading

0 comments on commit 5d4617b

Please sign in to comment.