Reset kernels automatically #278

Open
BalzaniEdoardo opened this issue Dec 9, 2024 · 1 comment

@BalzaniEdoardo
Collaborator

set_kernel and kernel computation

Changing the window size or the number of basis functions requires recomputing the convolution kernels, which have shape (window_size, n_basis_funcs). The set_kernel method is responsible for computing the kernels. It is meant to be called as part of the fit call of the TransformerBasis, which prepares the basis for the transform method (which maps to _compute_features).
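For concreteness, here is a minimal NumPy sketch of what computing and applying the kernels amounts to. It does not use nemos: the Gaussian-bump basis and the helpers `evaluate_basis`, `compute_kernel` and `compute_features` are hypothetical stand-ins for `_evaluate`, `set_kernel` and `_compute_features`, and the causal convolution below is just one simple convention that may differ from the actual implementation.

```python
import numpy as np


def evaluate_basis(x: np.ndarray, n_basis_funcs: int) -> np.ndarray:
    """Hypothetical basis: Gaussian bumps with centers evenly spaced in [0, 1]."""
    centers = np.linspace(0, 1, n_basis_funcs)
    width = 1.0 / n_basis_funcs
    # Shape (len(x), n_basis_funcs): one column per basis function.
    return np.exp(-0.5 * ((x[:, None] - centers[None, :]) / width) ** 2)


def compute_kernel(window_size: int, n_basis_funcs: int) -> np.ndarray:
    """Sample the basis on window_size points: shape (window_size, n_basis_funcs)."""
    return evaluate_basis(np.linspace(0, 1, window_size), n_basis_funcs)


def compute_features(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Convolve a 1D signal with each kernel column, using current and past samples only."""
    n = signal.shape[0]
    return np.stack(
        [np.convolve(signal, kernel[:, k], mode="full")[:n] for k in range(kernel.shape[1])],
        axis=1,
    )


kernel = compute_kernel(window_size=10, n_basis_funcs=5)
features = compute_features(np.random.default_rng(0).normal(size=100), kernel)
print(kernel.shape, features.shape)  # (10, 5) (100, 5)
```

The kernel depends only on window_size and n_basis_funcs, never on the input, which is why it can be precomputed once and must be recomputed whenever either parameter changes.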
Currently, if one changes window_size or n_basis_funcs after the kernels have been set, the kernel shape no longer matches the basis parameters, and a call to transform (or _compute_features for regular bases) does not produce the expected results.
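A minimal reproduction of the problem, using a hypothetical `NaiveConvBasis` (not a nemos class) that caches its kernel in `kernel_` without any invalidation:

```python
import numpy as np


class NaiveConvBasis:
    """Hypothetical basis that caches its kernel but never invalidates it."""

    def __init__(self, n_basis_funcs: int, window_size: int):
        self.n_basis_funcs = n_basis_funcs
        self.window_size = window_size
        self.kernel_ = None

    def _evaluate(self, x: np.ndarray) -> np.ndarray:
        # Placeholder evaluation: any (len(x), n_basis_funcs) array works here.
        return np.ones((x.shape[0], self.n_basis_funcs))

    def set_kernel(self) -> "NaiveConvBasis":
        self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size))
        return self


bas = NaiveConvBasis(n_basis_funcs=5, window_size=10).set_kernel()
bas.window_size = 11  # the parameter changes...
print(bas.kernel_.shape)  # (10, 5): ...but the cached kernel is stale
```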

One possible solution is to override the __setattr__ of Basis and intercept updates to the initialization parameters, such as window_size and n_basis_funcs:

import warnings
from typing import Any

import numpy as np

# these are methods of nemos._basis.Basis
    def _recompute_kernels(self):
        """Recompute all kernels if needed.

        Traverse the tree upwards and reset all input-independent states.
        If the node is the root, directly update its states; otherwise, propagate
        the request to the parent node.
        """
        # Assumes that state updates in the basis tree can be handled independently for each node.
        # This is currently true but may change if dependencies are introduced.
        # The only such state is self.kernel_, which is set independently for each basis component.
        # If dependencies are introduced, use `self.set_kernel` at the root level instead.
        # (A basis is the root of the tree if self._parent is None.)
        # Note: `self.set_kernel` is more expensive as it recomputes kernels for the entire tree.
        update_states = getattr(self, "_reset_all_input_independent_states", None)
        if update_states:
            update_states()
        if getattr(self, "_parent", None) is not None:
            self._parent._recompute_kernels()

    def _is_init_params_updated(self, name: str, value: Any):
        """Check whether ``name`` is a parameter defined in the ``__init__`` signature."""
        return name in self._get_param_names()

    def __setattr__(self, name: str, value: Any):
        """
        Set an attribute and, if it is an initialization parameter, recompute the kernels.

        This __setattr__ refreshes the states of the basis, i.e. the attributes computed
        by a method, such as `kernel_` or `_n_input_shape_`, whenever an initialization
        parameter is updated.
        A Basis class must respect the following naming convention: the names of all
        attributes that are set by a method (like `kernel_`, computed in `set_kernel`)
        must end in "_".

        Parameters
        ----------
        name :
            The name of the attribute to set.
        value :
            The value to set the attribute to.
        """
        # Check whether the attribute is defined in the __init__ signature
        # and, if so, recompute all input-independent states.
        super().__setattr__(name, value)
        if self._is_init_params_updated(name, value):
            self._recompute_kernels()
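A self-contained sketch of the proposed mechanism, assuming an sklearn-style `_get_param_names` that reads the `__init__` signature. `Node`, `ToyConv` and `ToyAdditive` are hypothetical stand-ins for `Basis`, a convolutional basis and a composite basis; the counter `n_resets_` exists only to show that the request propagates up to the root.

```python
import inspect
from typing import Any

import numpy as np


class Node:
    """Hypothetical base class holding the __setattr__ hook (not nemos' Basis)."""

    @classmethod
    def _get_param_names(cls) -> list:
        # sklearn-style convention: parameters are the __init__ arguments.
        return [p for p in inspect.signature(cls.__init__).parameters if p != "self"]

    def _recompute_kernels(self):
        # Reset this node's own states, then walk up to the root.
        update_states = getattr(self, "_reset_all_input_independent_states", None)
        if update_states is not None:
            update_states()
        parent = getattr(self, "_parent", None)
        if parent is not None:
            parent._recompute_kernels()

    def __setattr__(self, name: str, value: Any):
        super().__setattr__(name, value)
        if name in self._get_param_names():
            self._recompute_kernels()


class ToyConv(Node):
    def __init__(self, n_basis_funcs: int, window_size: int):
        self.kernel_ = None  # no kernel until set_kernel() is called
        self.n_basis_funcs = n_basis_funcs
        self.window_size = window_size

    def set_kernel(self):
        # Placeholder values; only the shape matters for this sketch.
        self.kernel_ = np.ones((self.window_size, self.n_basis_funcs))
        return self

    def _reset_all_input_independent_states(self):
        # Recompute only a kernel that was already set.
        if getattr(self, "kernel_", None) is not None:
            self.set_kernel()


class ToyAdditive(Node):
    def __init__(self, basis1: Node, basis2: Node):
        self.basis1 = basis1
        self.basis2 = basis2
        basis1._parent = self
        basis2._parent = self

    def _reset_all_input_independent_states(self):
        # Demonstration-only counter recording that the request reached the root.
        self.n_resets_ = getattr(self, "n_resets_", 0) + 1


a = ToyConv(n_basis_funcs=5, window_size=10).set_kernel()
b = ToyConv(n_basis_funcs=3, window_size=10)  # kernel never set
add = ToyAdditive(a, b)
before = add.n_resets_
a.window_size = 11
print(a.kernel_.shape, b.kernel_, add.n_resets_ - before)  # (11, 5) None 1
```

Because the reset only refreshes a kernel that already exists, the hook is harmless during `__init__`, where every parameter assignment also triggers it.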

And in the ConvBasisMixin:

    def _reset_all_input_independent_states(self):
        """Recompute the input-independent states of this basis only.

        This reimplements an abstract method of basis. It differs from ``set_kernel``
        because it does not traverse the basis tree: it acts only on this node, even for
        composite bases, while ``set_kernel`` applies to the whole tree.
        Called by the __setattr__ of basis.
        """
        current_kernel = getattr(self, "kernel_", None)
        try:
            # Only refresh a kernel that has already been computed; if it was
            # never set, leave it as None until ``set_kernel`` is called.
            self.kernel_ = (
                current_kernel
                if current_kernel is None
                else self._evaluate(np.linspace(0, 1, self.window_size))
            )
        except Exception as e:
            # Evaluation failed, e.g. because the updated parameters are invalid
            # or the basis is not fully initialized yet.
            kernel = getattr(self, "kernel_", None)
            if kernel is not None:
                warnings.warn(
                    message=f"Unable to re-initialize the kernel for basis {self.label}, "
                    f"with exception:\n{repr(e)}. \n"
                    f"Resetting the kernel to `None`.",
                    category=UserWarning,
                )
                self.kernel_ = None
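To exercise the fallback branch in isolation, here is a sketch with a hypothetical single-node `ToyConv` whose `_evaluate` rejects windows shorter than 2 samples (an invented constraint, used only to trigger the exception). It hardcodes the parameter names instead of using `_get_param_names`, and restructures the conditional as an early return, which should be equivalent to the snippet above.

```python
import warnings
from typing import Any

import numpy as np


class ToyConv:
    """Hypothetical single-node basis used only to exercise the fallback path."""

    def __init__(self, n_basis_funcs: int, window_size: int, label: str = "toy"):
        self.kernel_ = None
        self.label = label
        self.n_basis_funcs = n_basis_funcs
        self.window_size = window_size

    def _evaluate(self, x: np.ndarray) -> np.ndarray:
        if x.shape[0] < 2:
            raise ValueError("window_size must be at least 2")  # invented constraint
        return np.ones((x.shape[0], self.n_basis_funcs))

    def _reset_all_input_independent_states(self):
        if getattr(self, "kernel_", None) is None:
            return  # nothing computed yet, so nothing to refresh
        try:
            self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size))
        except Exception as e:
            warnings.warn(
                f"Unable to re-initialize the kernel for basis {self.label}, "
                f"with exception:\n{e!r}.\nResetting the kernel to `None`.",
                category=UserWarning,
            )
            self.kernel_ = None

    def __setattr__(self, name: str, value: Any):
        super().__setattr__(name, value)
        if name in ("n_basis_funcs", "window_size"):
            self._reset_all_input_independent_states()


bas = ToyConv(n_basis_funcs=5, window_size=10)
bas.window_size = 12  # kernel never set: it stays None, no warning
bas.kernel_ = bas._evaluate(np.linspace(0, 1, bas.window_size))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    bas.window_size = 1  # _evaluate raises: warning emitted, kernel reset
print(bas.kernel_, len(caught))  # None 1
```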
@BalzaniEdoardo
Collaborator Author

    @pytest.mark.parametrize("window_size_old, window_size_new", [(10, 11)])
    @pytest.mark.parametrize("n_basis_old, n_basis_new", [(5, 6)])
    def test_reset_kernel(
        self, window_size_old, window_size_new, n_basis_old, n_basis_new, cls
    ):
        bas = cls["conv"](
            n_basis_funcs=n_basis_old,
            window_size=window_size_old,
            **extra_decay_rates(cls["eval"], n_basis_old),
        )
        # The kernel is not computed at initialization...
        assert bas.kernel_ is None
        bas.set_kernel()
        assert bas.kernel_.shape == (window_size_old, n_basis_old)
        # ...but once set, it must follow parameter updates automatically.
        bas.window_size = window_size_new
        assert bas.kernel_.shape == (window_size_new, n_basis_old)
        if not isinstance(bas, OrthExponentialBasis):
            bas.n_basis_funcs = n_basis_new
            assert bas.kernel_.shape == (window_size_new, n_basis_new)
        else:
            # OrthExponentialBasis checks n_basis_funcs against its decay rates,
            # so changing it alone is expected to raise.
            with pytest.raises(ValueError, match="The number of basis functions "):
                bas.n_basis_funcs = n_basis_new

Test for the kernel reset, to be added to the shared tests.
