
GLM predictor representation with pytrees #38

Open
BalzaniEdoardo opened this issue Aug 7, 2023 · 5 comments

Comments

@BalzaniEdoardo
Collaborator

How do we represent neuron-specific model matrices?

One way is to use a tensor,
$$X = [X]_{tij},$$
where $t$ indexes the samples, $i$ the parameters, and $j$ the neurons.

This works, but most of the predictors will be shared (every experimental input will probably use the same set of basis functions for all the neurons, and the coupling filters between a neuron and all the other neurons will likely share the same basis). Is there a way to represent this efficiently?

Using a mask could be one way, though a bit messier to implement. Fitting one neuron at a time is another: less efficient computationally, but it doesn't require creating the massive tensor.
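A minimal sketch of the dense-tensor representation and a masked per-neuron predictor; the sizes and variable names here are hypothetical:

import jax.numpy as jnp

# hypothetical sizes: T samples, K parameters, N neurons
T, K, N = 1000, 20, 5

# if most predictors are shared, the dense X[t, i, j] is the same
# (T, K) block tiled N times along the neuron axis
X_shared = jnp.ones((T, K))
X = jnp.repeat(X_shared[:, :, None], N, axis=2)

# a boolean (K, N) mask could mark which parameters apply to which neuron
mask = jnp.ones((K, N), dtype=bool)

# per-neuron linear predictor: eta[t, j] = sum_i X[t, i, j] * w[i, j]
w = jnp.zeros((K, N))
eta = jnp.einsum("tij,ij->tj", X, w * mask)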

@BalzaniEdoardo
Collaborator Author

Another idea is to use pytrees, so that each neuron has its own tree of parameters. This is more flexible and more in line with JAX; we could use tree maps to evaluate the dot products with the model parameters.
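A minimal sketch of what this could look like; the predictor and parameter names are hypothetical:

import jax
import jax.numpy as jnp

# predictors grouped by input type, shared across neurons
X = {
    "stim": jnp.ones((100, 5)),   # stimulus design matrix (T x n_basis)
    "coupl": jnp.ones((100, 8)),  # coupling design matrix
}

# each neuron carries its own tree of coefficients
params = {
    "neuron_1": {"stim": jnp.ones(5), "coupl": jnp.ones(8)},
    "neuron_2": {"stim": jnp.ones(5), "coupl": jnp.ones(8)},
}

def linear_predictor(coef_tree):
    # dot each predictor block with its coefficients, then sum over blocks
    prods = jax.tree_util.tree_map(lambda Xb, w: Xb @ w, X, coef_tree)
    return sum(jax.tree_util.tree_leaves(prods))

rates = {name: linear_predictor(coef) for name, coef in params.items()}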

@BalzaniEdoardo BalzaniEdoardo changed the title GLM predictor representation GLM predictor representation with pytrees Dec 11, 2023
@billbrod
Member

billbrod commented Dec 11, 2023

As part of making it all compatible with pytrees, there are several discussions of input-related issues in #41:

  • Base class currently expects params to be a tuple of length 2; why not just separate that into two explicit arguments (see the sketch after this list)?
  • Base class _check_input_dimensionality allows x and y to be None, which seems weird. It's also GLM-specific right now, so maybe move it?
  • Similarly for Base class _check_and_convert_params.
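A minimal sketch of the suggested signature change; the function and argument names are hypothetical, not the actual nemos API:

# current style: parameters packed into a single tuple
def fit_packed(X, y, init_params=None):
    coef, intercept = init_params  # relies on an implicit length-2 contract
    ...

# suggested style: two explicit arguments that can be validated independently
def fit_explicit(X, y, init_coef=None, init_intercept=None):
    ...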

@BalzaniEdoardo
Collaborator Author

If we switch to pytrees we should modify the proximal operator for group lasso; the masking could be done using tree_map, which would simplify the code.
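A minimal sketch of a leaf-wise group-lasso proximal operator, assuming each leaf of the parameter tree is one penalty group (a hypothetical structure, not the current implementation):

import jax
import jax.numpy as jnp

def prox_group_lasso(coef_tree, scale):
    # block soft-thresholding applied leaf-wise: each leaf is one group
    def prox_leaf(w):
        norm = jnp.linalg.norm(w)
        shrink = jnp.maximum(0.0, 1.0 - scale / jnp.maximum(norm, 1e-12))
        return shrink * w
    return jax.tree_util.tree_map(prox_leaf, coef_tree)

params = {"stim": jnp.array([0.5, -1.0]), "coupl": jnp.array([2.0, 0.1])}
print(prox_group_lasso(params, scale=0.3))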

@BalzaniEdoardo
Collaborator Author

I'm adding a code snippet that is helpful for computing operations on trees of different depths:

import jax

# nested tree: two outer groups, each containing subtrees that share
# the structure of `b`
x = {
    "outer_1": {
        "inner_1": {"stim": jax.numpy.array([0]), "coupl": jax.numpy.array([1])},
        "inner_2": {"stim": jax.numpy.array([2]), "coupl": jax.numpy.array([3])},
    },
    "outer_2": {
        "inner_1": {"stim": jax.numpy.array([4]), "coupl": jax.numpy.array([5])},
        "inner_2": {"stim": jax.numpy.array([6]), "coupl": jax.numpy.array([7])},
    },
}

b = {"stim": jax.numpy.array([10]), "coupl": jax.numpy.array([20])}

# treat any subtree of x with the same structure as b as a leaf,
# then add b to that subtree element-wise
res = jax.tree_util.tree_map(
    lambda subtree: jax.tree_util.tree_map(lambda xs, bs: xs + bs, subtree, b),
    x,
    is_leaf=lambda xsub: jax.tree_util.tree_structure(xsub)
    == jax.tree_util.tree_structure(b),
)
print(res)
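For the example above, res keeps the outer structure of x with each inner subtree shifted by b, e.g. res["outer_1"]["inner_1"] is {"stim": Array([10]), "coupl": Array([21])}.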

@BalzaniEdoardo
Collaborator Author

BalzaniEdoardo commented Dec 20, 2023

I thought of a way to reduce a deep tree to a sub-tree with a given structure. I don't want to forget this, so I am posting it here. The problem is similar to the one above; the solution is different and can probably be polished.

Let's say we have a nested tree and another tree of a different depth, and we wish to perform an operation on the deeper tree such that, once the operation is applied, the result is reduced to the same structure as the shallower tree.

A concrete example: we have an output tree y = {neuron1: data_y1, neuron2: data_y2, ...}, where the data_yi are spike counts, and an input tree specific to each neuron, X = {input1: {neuron1: data_x1, neuron2: data_x2}, input2: {neuron1: data_x3, neuron2: data_x4}}.

The operation could be transforming each input into a rate for every neuron (a transformation on the leaves of X), followed by a reduction over inputs so that the result can be tree-mapped against the leaves of y.

Here is a code snippet that performs the operation:

import jax

# deep tree: three outer groups, each sharing the structure of tree2
tree1 = {"aa": {"a": 1, "b": 2}, "bb": {"a": 3, "b": 4}, "cc": {"a": 5, "b": 6}}
# shallow tree defining the target structure
tree2 = {"a": 100, "b": 1000}

# operation to be performed on the leaves
operation_on_leaves = lambda y: y**2

# predicate deciding which subtrees constitute a leaf: any subtree with
# the same structure as tree2
assess_leaves = lambda x: jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(tree2)

# apply the operation, then flatten to a list of subtrees, each with
# the same structure as tree2
flat_trees = lambda x: jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(operation_on_leaves, x), is_leaf=assess_leaves
)[0]

# finally, the reduction step sums the subtrees leaf-wise,
# returning {'a': 35, 'b': 56}
print(jax.tree_util.tree_map(lambda *leaves: sum(leaves), *flat_trees(tree1)))
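The key step is the is_leaf predicate: it stops the flatten at subtrees that already match tree2's structure, so the final variadic tree_map receives a list of identically structured trees and can reduce them leaf-wise.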
