GLM predictor representation with pytrees #38
Another idea is to use pytrees, so that each neuron could have its own tree of parameters. This is more flexible and more in line with JAX; we could use tree maps to evaluate the dot products with the model parameters.
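A minimal sketch of what a per-neuron parameter tree could look like. It assumes a hypothetical layout in which each neuron has a `stim` and a `coupl` block of predictors, each of shape `(n_samples, n_features)`, and an exponential inverse link. The parameter tree mirrors the predictor tree, and the linear predictor is the sum over blocks of tree-mapped dot products.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: one subtree of predictors per neuron. The number of
# coupling features may differ between neurons, which a single matrix can't hold.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n_samples = 100
X = {
    "neuron_1": {"stim": jax.random.normal(k1, (n_samples, 3)),
                 "coupl": jax.random.normal(k2, (n_samples, 5))},
    "neuron_2": {"stim": jax.random.normal(k3, (n_samples, 3)),
                 "coupl": jax.random.normal(k4, (n_samples, 4))},
}

# Parameters mirror the structure of X: one weight vector per predictor block.
W = jax.tree_util.tree_map(lambda x: jnp.full(x.shape[1], 0.1), X)

def linear_predictor(X_neuron, W_neuron):
    # Dot product of each block with its own weights, then sum over blocks.
    contributions = jax.tree_util.tree_map(jnp.dot, X_neuron, W_neuron)
    return sum(jax.tree_util.tree_leaves(contributions))

# The top-level keys are the neurons; each gets its own linear predictor.
eta = {name: linear_predictor(X[name], W[name]) for name in X}
rates = jax.tree_util.tree_map(jnp.exp, eta)  # assumed exponential link
```

Because each neuron's subtree can have its own shapes, nothing forces all neurons to share one model matrix.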
As part of making everything compatible with pytrees, several input-related issues are discussed in #41.
If we switch to pytrees, we should modify the proximal operator for the group lasso; the masking could be done with tree_map, which would simplify the code.
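A sketch of how the group-lasso proximal operator could look if every leaf of the parameter tree is treated as one group, so `tree_map` replaces the explicit group mask. It assumes `threshold` already combines step size and regularization strength, and it omits the common `sqrt(group size)` weighting; `prox_group_lasso` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def prox_group_lasso(params, threshold):
    """Group-lasso proximal operator where every leaf of `params` is a group.

    Each group w is mapped to max(0, 1 - threshold / ||w||_2) * w, so groups
    whose norm is below `threshold` are set exactly to zero.
    """
    def shrink(w):
        norm = jnp.linalg.norm(w)
        safe_norm = jnp.where(norm > 0, norm, 1.0)  # avoid division by zero
        factor = jnp.where(norm > 0, jnp.maximum(0.0, 1.0 - threshold / safe_norm), 0.0)
        return factor * w

    return jax.tree_util.tree_map(shrink, params)

params = {"stim": jnp.array([3.0, 4.0]), "coupl": jnp.array([0.1, 0.1])}
new_params = prox_group_lasso(params, threshold=1.0)
# "stim" has norm 5 and is scaled by 0.8; "coupl" has norm < 1 and is zeroed.
```

Since `tree_map` reaches leaves at any depth, the same function would work on a nested per-neuron tree, with each leaf still acting as one group.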
I'm adding a snippet that is helpful for computing operations on trees with different depths:

```python
import jax

x = {
    "outer_1": {
        "inner_1": {"stim": jax.numpy.array([0]), "coupl": jax.numpy.array([1])},
        "inner_2": {"stim": jax.numpy.array([2]), "coupl": jax.numpy.array([3])},
    },
    "outer_2": {
        "inner_1": {"stim": jax.numpy.array([4]), "coupl": jax.numpy.array([5])},
        "inner_2": {"stim": jax.numpy.array([6]), "coupl": jax.numpy.array([7])},
    },
}
b = {"stim": jax.numpy.array([10]), "coupl": jax.numpy.array([20])}

# Treat every subtree of x that has the same structure as b as a "leaf",
# then combine it leaf-by-leaf with b (here, elementwise addition).
res = jax.tree_util.tree_map(
    lambda subtree: jax.tree_util.tree_map(jax.numpy.add, subtree, b),
    x,
    is_leaf=lambda xsub: jax.tree_util.tree_structure(xsub) == jax.tree_util.tree_structure(b),
)
print(res)
```
I thought of a way to reduce a deep tree to a subtree with a given structure; I don't want to forget it, so I am posting it here. The issue is similar to the previous one, but the solution is different and can probably be polished. Say we have a nested tree and another tree of a different depth, and we wish to perform an operation on the deeper tree and then reduce the result to the structure of the shallower one. A concrete example: we have an output tree y = {neuron1: data_y1, neuron2: data_y2, ...}, where data_yi are spike counts. The operation could be transforming the input into a rate for each neuron (a transformation on the leaves of X), followed by a tree_map over the leaves of y. Here is a code snippet that performs the operation:

```python
import jax

tree1 = {"aa": {"a": 1, "b": 2}, "bb": {"a": 3, "b": 4}, "cc": {"a": 5, "b": 6}}
tree2 = {"a": 100, "b": 1000}

# operation to be performed on the leaves
operation_on_leaves = lambda y: y**2
# condition to assess which layer constitutes a "leaf":
# any subtree with the same structure as tree2
assess_leaves = lambda x: jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(tree2)
# apply the operation, then flatten to a list of trees such that every
# tree in the list has the same structure as tree2
flat_array = lambda x: jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(operation_on_leaves, x), is_leaf=assess_leaves
)[0]
# finally, perform the reduction step (leaf-wise sum across the list)
# returns {'a': 35, 'b': 56}
print(jax.tree_util.tree_map(lambda *x: sum(x), *flat_array(tree1)))
```
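One possible polish of the reduction above is to wrap it in a reusable function that takes the leaf operation and the reduction as arguments. `reduce_to_structure` is a hypothetical name, and the sketch assumes `reduce_op` accepts a tuple of leaves (as Python's `sum` and `max` do).

```python
import jax

def reduce_to_structure(tree, template, leaf_op, reduce_op):
    """Apply leaf_op to every leaf of `tree`, then combine, leaf by leaf,
    all subtrees sharing the structure of `template` using reduce_op."""
    template_def = jax.tree_util.tree_structure(template)
    is_template = lambda x: jax.tree_util.tree_structure(x) == template_def
    transformed = jax.tree_util.tree_map(leaf_op, tree)
    # Every subtree matching the template is returned as a single "leaf".
    subtrees = jax.tree_util.tree_leaves(transformed, is_leaf=is_template)
    return jax.tree_util.tree_map(lambda *leaves: reduce_op(leaves), *subtrees)

tree1 = {"aa": {"a": 1, "b": 2}, "bb": {"a": 3, "b": 4}, "cc": {"a": 5, "b": 6}}
tree2 = {"a": 100, "b": 1000}

print(reduce_to_structure(tree1, tree2, lambda y: y**2, sum))  # {'a': 35, 'b': 56}
print(reduce_to_structure(tree1, tree2, lambda y: y, max))     # {'a': 5, 'b': 6}
```

Separating the leaf operation from the reduction makes it easy to swap in, for example, a rate transformation followed by a sum over neurons.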
How do we represent neuron-specific model matrices?
One way is to use a tensor,
$$X = [X]_{tij},$$
where $t$ indexes the samples, $i$ the parameters, and $j$ the neurons.
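A short sketch of the tensor representation, assuming per-neuron weights $W_{ij}$ and an exponential inverse link (both assumptions, not fixed in the thread). The linear predictor for neuron $j$ at sample $t$ is $\sum_i X_{tij} W_{ij}$, which `einsum` computes for all neurons at once.

```python
import jax
import jax.numpy as jnp

n_samples, n_params, n_neurons = 200, 6, 4
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (n_samples, n_params, n_neurons))  # X[t, i, j]
W = 0.1 * jax.random.normal(key_w, (n_params, n_neurons))       # W[i, j]

# eta[t, j] = sum_i X[t, i, j] * W[i, j]
eta = jnp.einsum("tij,ij->tj", X, W)
rates = jnp.exp(eta)  # assumed exponential link

# Memory cost of X grows as n_samples * n_params * n_neurons.
```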
This works, but most of the predictors will be shared: every experimental input will probably use the same set of basis functions for all neurons, and the coupling filters between a neuron and all the other neurons will likely share the same basis. Is there a way to represent this efficiently?
Using a mask could be one way, although it is a bit messier to implement. Fitting one neuron at a time is another: it is less efficient computationally but doesn't require creating the massive tensor.
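A sketch of the mask idea, under the assumption that every neuron draws from one shared design matrix `X_shared[t, i]` and a boolean `mask[i, j]` marks which predictors neuron `j` uses. The mask below, where predictors `2j, 2j+1` are neuron `j`'s own coupling and are excluded, is purely hypothetical. Masking the weights gives the same result as the tensor $X_{tij} = X^{\text{shared}}_{ti} M_{ij}$ without building it; `jax.vmap` gives the one-neuron-at-a-time view.

```python
import jax
import jax.numpy as jnp

n_samples, n_params, n_neurons = 200, 6, 3
kx, kw = jax.random.split(jax.random.PRNGKey(1))
X_shared = jax.random.normal(kx, (n_samples, n_params))  # X_shared[t, i]
W = jax.random.normal(kw, (n_params, n_neurons))         # W[i, j]

# Hypothetical mask: predictors 2j and 2j+1 are neuron j's self-coupling,
# so neuron j ignores them. mask[i, j] is True if neuron j uses predictor i.
owner = jnp.arange(n_params) // 2
mask = owner[:, None] != jnp.arange(n_neurons)[None, :]

# Mask approach: zero out unused weights, one matmul for all neurons.
eta_masked = X_shared @ (W * mask)

# Equivalent tensor approach, materializing X[t, i, j] explicitly.
X_tensor = X_shared[:, :, None] * mask[None, :, :]
eta_tensor = jnp.einsum("tij,ij->tj", X_tensor, W)

# One neuron at a time, vectorized over neurons with vmap.
def eta_one(w_j, m_j):
    return X_shared @ (w_j * m_j)

eta_vmap = jax.vmap(eta_one, in_axes=(1, 1), out_axes=1)(W, mask)
```

During fitting, the masked weights would also need to stay at zero (for instance by applying the same mask to the gradient), which is part of what makes this approach messier.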