-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added option to execute stateful submodules #13
Conversation
awesome work, thanks @Jegp! For now, I just tried importing a NIR graph into snntorch and everything works splendidly. However, when I try to export that back to NIR, I get an error. The graph is not able to figure out the root node - it returns multiple roots even though the graph I'm using is clearly sequential (I'm using |
self.instantiate_modules() | ||
self.execution_order = self.get_execution_order() | ||
if len(self.execution_order) == 0: | ||
raise ValueError("Graph is empty") | ||
|
||
def _is_module_stateful(self, module: torch.nn.Module) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, we need to find a better way to implement this.. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy to find other ways of doing this. But how?
Here's the challenge as far as I can tell
- Most frameworks can live without state (snnTorch, Sinabs, Rockpool)
- Norse requires a
state
parameter (similar to PyTorch RNNs) - snnTorch can take
spk
andmem
inputs
Would an option be to look for state
in the arguments to account for the norse case and spk
and mem
to account for the snnTorch case?
nirtorch/to_nir.py
Outdated
@@ -12,9 +13,16 @@ def extract_nir_graph( | |||
model_map: Callable[[nn.Module], nir.NIRNode], | |||
sample_data: Any, | |||
model_name: Optional[str] = "model", | |||
ignore_submodules_of=None, | |||
model_fwd_args=[], | |||
ignore_dims=[], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
allow negative numbers to do reverse indexing
This PR allows stateful modules by
forward