Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to execute stateful submodules #13

Merged
merged 30 commits into from
Dec 6, 2023
Merged

Conversation

Jegp
Copy link
Collaborator

@Jegp Jegp commented Oct 10, 2023

This PR allows stateful modules by

  1. Adding and keeping track of the state in the call to forward
  2. Returning a dict that can be reused in subsequent calls to the GraphExecutor

@stevenabreu7
Copy link
Contributor

awesome work, thanks @Jegp! For now, I just tried importing a NIR graph into snntorch and everything works splendidly. However, when I try to export that back to NIR, I get an error. The graph is not able to figure out the root node - it returns multiple roots even though the graph I'm using is clearly sequential (I'm using norse_lif.nir from our paper experiments). Any hints what might be going wrong?

self.instantiate_modules()
self.execution_order = self.get_execution_order()
if len(self.execution_order) == 0:
raise ValueError("Graph is empty")

def _is_module_stateful(self, module: torch.nn.Module) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, we need to find a better way to implement this.. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to find other ways of doing this. But how?

Here's the challenge as far as I can tell

  • Most frameworks can live without state (snnTorch, Sinabs, Rockpool)
  • Norse requires a state parameter (similar to PyTorch RNNs)
  • snnTorch can take spk and mem inputs

Would an option be to look for state in the arguments to account for the norse case and spk and mem to account for the snnTorch case?

@@ -12,9 +13,16 @@ def extract_nir_graph(
model_map: Callable[[nn.Module], nir.NIRNode],
sample_data: Any,
model_name: Optional[str] = "model",
ignore_submodules_of=None,
model_fwd_args=[],
ignore_dims=[],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow negative numbers to do reverse indexing

@Jegp Jegp mentioned this pull request Oct 20, 2023
@Jegp Jegp linked an issue Oct 21, 2023 that may be closed by this pull request
@sheiksadique sheiksadique merged commit 5fa4c07 into main Dec 6, 2023
4 checks passed
@sheiksadique sheiksadique deleted the feature-state branch December 6, 2023 14:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow stateful modules in the graph executor
3 participants