Added option to execute stateful submodules #13
Merged
Commits (30; the diff below shows changes from 16 of them):
a0a2f47  Added option to execute stateful submodules (Jegp)
1465f29  Returned state if stateful module (Jegp)
bbe54a0  Ruff (Jegp)
a64c23d  Added recurrent execution (Jegp)
0a1d97d  Added tests for recurrent execution (Jegp)
5cc87de  test for NIR -> NIRTorch -> NIR (stevenabreu7)
70f4447  refactoring + expose ignore_submodules_of (stevenabreu7)
175d34f  fix and test for issue #16 (stevenabreu7)
37e4237  fix recurrent test (stevenabreu7)
c76fb6f  remove batch froms shape spec (sheiksadique)
26cadf6  Merge branch 'main' into 17-input-node-retains-batch-dimension (sheiksadique)
48f9842  bug from hell (stevenabreu7)
26242d3  from_nir hacks for snnTorch (stevenabreu7)
668e023  + optional model.forward args for stateful modules (stevenabreu7)
c555b2a  change subgraphs handlign (flatten + remove I/O) (stevenabreu7)
60c01f8  model fwd args + ignore_dims arg (stevenabreu7)
d4b1afb  [hack] remove wrong RNN self-connection (NIRTorch) (stevenabreu7)
c736c0e  Added proper graph tracing (Jegp)
fe7188a  + arg to ignore dims in to_nir (stevenabreu7)
a21819f  add tests (stevenabreu7)
bef454b  output_shape also uses ignore_dims (sheiksadique)
b95ad5c  Added test for flatten (Jegp)
6c1d81e  Merged changes from #18 (Jegp)
3bc8bd2  minor correction to default value (sheiksadique)
84e3cc8  Added ability to ignore state in executor (Jegp)
8278437  Added flag in nirtorch parsing (Jegp)
ec8cded  Added flag in nirtorch parsing (Jegp)
5845167  Merged sinabs test changes (Jegp)
0325c80  minor changes to the doc strings (sheiksadique)
53109c3  formatting fixes (sheiksadique)
Diff (lines removed in this PR are prefixed with "-"; unprefixed lines show the resulting code):

@@ -1,126 +1,202 @@
-from typing import Callable, Dict, List, Optional
import dataclasses
import inspect
from typing import Callable, Dict, List, Optional, Any, Union

import nir
import torch
import torch.nn as nn

from .graph import Graph, Node
from .graph_utils import trace_execution
from .utils import sanitize_name


-def execution_order_up_to_node(
-    node: Node,
-    graph: Graph,
-    execution_order: List[Node],
-    visited: Optional[Dict[Node, bool]] = None,
-) -> List[Node]:
-    """Recursive function to evaluate execution order until a given node.
@dataclasses.dataclass
class GraphExecutorState:
    """State for the GraphExecutor that keeps track of both the state of hidden
    units and caches the output of previous modules, for use in (future) recurrent
    computations."""

-    Args:
-        node (Node): Execution order for the node of interest
-        graph (Graph): Graph object describing the network
-        execution_order (List[Node]): The current known execution order.
-
-    Returns:
-        List[Node]: Execution order
-    """
-    if visited is None:
-        visited = {n: False for n in graph.node_list}
-    is_recursive = False
-    if len(execution_order) == list(graph.node_list):
-        # All nodes are executed
-        return execution_order
-    for parent in graph.find_source_nodes_of(node):
-        if parent not in execution_order and not visited[parent]:
-            visited[parent] = True
-            execution_order = execution_order_up_to_node(
-                parent, graph, execution_order, visited
-            )
-        if node in parent.outgoing_nodes:
-            is_recursive = True
-    # Ensure we're not re-adding a recursive node
-    if is_recursive and node in execution_order:
-        return execution_order
-    else: # Finally since all parents are known and executed
-        return execution_order + [node]
    state: Dict[str, Any] = dataclasses.field(default_factory=dict)
    cache: Dict[str, Any] = dataclasses.field(default_factory=dict)


class GraphExecutor(nn.Module):
    def __init__(self, graph: Graph) -> None:
        super().__init__()
        self.graph = graph
        self.stateful_modules = {}
        self.instantiate_modules()
        self.execution_order = self.get_execution_order()
        if len(self.execution_order) == 0:
            raise ValueError("Graph is empty")

    def _is_module_stateful(self, module: torch.nn.Module) -> bool:
        signature = inspect.signature(module.forward)
        arguments = len(signature.parameters)
        # HACK for snntorch modules
        if "snntorch" in str(module.__class__):
            if module.__class__.__name__ in [
                "Synaptic",
                "RSynaptic",
                "Leaky",
                "RLeaky",
            ]:
                return not module.init_hidden
        return arguments > 1

    def get_execution_order(self) -> List[Node]:
        """Evaluate the execution order and instantiate that as a list."""
-        execution_order = []
-        # Then loop over all nodes and check that they are added to the execution order.
-        for node in self.graph.node_list:
-            if node not in execution_order:
-                execution_order = execution_order_up_to_node(
-                    node, self.graph, execution_order
-                )
-        return execution_order
        # TODO: Adapt this for graphs with multiple inputs
        inputs = self.graph.inputs
        if len(inputs) != 1:
            raise ValueError(
                f"Currently, only one input is supported, but {len(inputs)} was given"
            )
        return trace_execution(inputs[0], lambda n: n.outgoing_nodes.keys())

    def instantiate_modules(self):
        for mod, name in self.graph.module_names.items():
-            self.add_module(sanitize_name(name), mod)
            if mod is not None:
                self.add_module(sanitize_name(name), mod)
                self.stateful_modules[sanitize_name(name)] = self._is_module_stateful(
                    mod
                )

    def get_input_nodes(self) -> List[Node]:
        # NOTE: This is a hack. Should use the input nodes from NIR graph
        return self.graph.get_root()

-    def forward(self, data: torch.Tensor):
-        outs = {}
    def _apply_module(
        self,
        node: Node,
        input_nodes: List[Node],
        new_state: GraphExecutorState,
        old_state: GraphExecutorState,
        data: Optional[torch.Tensor] = None,
    ):
        """Applies a module and keeps track of its state.
        TODO: Use pytree to recursively construct the state
        """
        inputs = []
        # Append state if needed
        if node.name in self.stateful_modules and node.name in old_state.state:
            inputs.extend(old_state.state[node.name])

        # Sum recurrence if needed
        summed_inputs = [] if data is None else [data]
        for input_node in input_nodes:
            if (
                input_node.name not in new_state.cache
                and input_node.name in old_state.cache
            ):
                summed_inputs.append(old_state.cache[input_node.name])
            elif input_node.name in new_state.cache:
                summed_inputs.append(new_state.cache[input_node.name])

        if len(summed_inputs) == 0:
            raise ValueError("No inputs found for node {}".format(node.name))
        elif len(summed_inputs) == 1:
            inputs.insert(0, summed_inputs[0])
        elif len(summed_inputs) > 1:
            inputs.insert(0, torch.stack(summed_inputs).sum(0))

        out = node.elem(*inputs)
        # If the module is stateful, we know the output is (at least) a tuple
        # HACK to make it work for snnTorch
        is_rsynaptic = "snntorch._neurons.rsynaptic.RSynaptic" in str(
            node.elem.__class__
        )
        if is_rsynaptic and not node.elem.init_hidden:
            assert "lif" in node.name, "this shouldnt happen.."
            new_state.state[node.name] = out  # snnTorch requires output inside state
            out = out[0]
        elif self.stateful_modules[node.name]:
            new_state.state[node.name] = out[1:]  # Store the new state
            out = out[0]
        return out, new_state

    def forward(
        self, data: torch.Tensor, old_state: Optional[GraphExecutorState] = None
    ):
        if old_state is None:
            old_state = GraphExecutorState()
        new_state = GraphExecutorState()
        first_node = True
        # NOTE: This logic is not yet consistent for models with multiple input nodes
        for node in self.execution_order:
            input_nodes = self.graph.find_source_nodes_of(node)
            if node.elem is None:
                continue
-            if len(input_nodes) == 0 or len(outs) == 0:
-                # This is the root node
-                outs[node.name] = node.elem(data)
-            else:
-                # Intermediate nodes
-                input_data = (outs[node.name] for node in input_nodes)
-                outs[node.name] = node.elem(*input_data)
-        return outs[node.name]
-
-
-def _mod_nir_to_graph(nir_graph: nir.NIRNode) -> Graph:
-    module_names = {module: name for name, module in nir_graph.nodes.items()}
-    graph = Graph(module_names=module_names)
-    for src, dst in nir_graph.edges:
-        graph.add_edge(nir_graph.nodes[src], nir_graph.nodes[dst])
            out, new_state = self._apply_module(
                node,
                input_nodes,
                new_state=new_state,
                old_state=old_state,
                data=data if first_node else None,
            )
            new_state.cache[node.name] = out
            first_node = False

        # If the output node is a dummy nir.Output node, use the second-to-last node
        if node.name not in new_state.cache:
            node = self.execution_order[-2]
        return new_state.cache[node.name], new_state


def _mod_nir_to_graph(
    torch_graph: nir.NIRGraph, nir_nodes: Dict[str, nir.NIRNode]
) -> Graph:
    module_names = {module: name for name, module in torch_graph.nodes.items()}
    inputs = [name for name, node in nir_nodes.items() if isinstance(node, nir.Input)]
    graph = Graph(module_names=module_names, inputs=inputs)
    for src, dst in torch_graph.edges:
        # Allow edges to refer to subgraph inputs and outputs
        if not src in torch_graph.nodes and f"{src}.output" in torch_graph.nodes:
            src = f"{src}.output"
        if not dst in torch_graph.nodes and f"{dst}.input" in torch_graph.nodes:
            dst = f"{dst}.input"
        graph.add_edge(torch_graph.nodes[src], torch_graph.nodes[dst])
    return graph


def _switch_default_models(nir_graph: nir.NIRNode) -> Optional[torch.nn.Module]:
    if isinstance(nir_graph, nir.Input) or isinstance(nir_graph, nir.Output):
        return torch.nn.Identity()


def _switch_models_with_map(
    nir_graph: nir.NIRNode, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIRNode:
-    nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
    nodes = {}
    for name, node in nir_graph.nodes.items():
        mapped_module = model_map(node)
        if mapped_module is None:
            mapped_module = _switch_default_models(node)
        nodes[name] = mapped_module
    # nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
    return nir.NIRGraph(nodes, nir_graph.edges)


def load(
-    nir_graph: nir.NIRNode, model_map: Callable[[nir.NIRNode], nn.Module]
    nir_graph: Union[nir.NIRNode, str], model_map: Callable[[nir.NIRNode], nn.Module]
) -> nn.Module:
-    """Load a NIR object and convert it to a torch module using the given model map.
    """Load a NIR graph and convert it to a torch module using the given model map.

    Args:
-        nir_graph (nir.NIRNode): NIR object
        nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing
            the path to the NIR object.
        model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
            module that corresponds to each NIR node.

    Returns:
        nn.Module: The generated torch module
    """
    if isinstance(nir_graph, str):
        nir_graph = nir.read(nir_graph)
    # Map modules to the target modules using th emodel map
    nir_module_graph = _switch_models_with_map(nir_graph, model_map)
    # Build a nirtorch.Graph based on the nir_graph
-    graph = _mod_nir_to_graph(nir_module_graph)
    graph = _mod_nir_to_graph(nir_module_graph, nir_nodes=nir_graph.nodes)
    # Build and return a graph executor module
    return GraphExecutor(graph)
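For context, a minimal sketch of how the new stateful executor is meant to be driven, assuming `load` is exposed as `nirtorch.load`. The file name `network.nir` and the identity-only `model_map` are hypothetical placeholders; a real map would return framework-specific modules (e.g. Norse or snnTorch neurons) whose outputs follow the `(output, state, ...)` tuple convention that `_apply_module` unpacks.

```python
import torch

import nirtorch


def model_map(node):
    # Placeholder mapping: every NIR node becomes an identity module.
    # A real map would return stateful neuron modules for LIF-type nodes.
    return torch.nn.Identity()


# load() now also accepts a path to a stored NIR graph (str) and reads it
# with nir.read before building the GraphExecutor.
executor = nirtorch.load("network.nir", model_map)

# forward() returns an (output, GraphExecutorState) tuple; passing the
# returned state back in threads hidden state and cached recurrent
# outputs across timesteps.
state = None
for x in torch.rand(10, 1, 16):  # 10 hypothetical timesteps of input
    out, state = executor(x, state)
```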
The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!
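As an illustration (hypothetical module, not from this PR): a stateless module whose `forward` takes several data inputs would be flagged as stateful by the `arguments > 1` check in `_is_module_stateful` above.

```python
import torch


class Residual(torch.nn.Module):
    """Stateless: it just adds two inputs and keeps no hidden state."""

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        return x + skip


# inspect.signature(Residual().forward) reports two parameters (x, skip),
# so `arguments > 1` would incorrectly mark this module as stateful.
```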
I agree, we need to find a better way to implement this.. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself)
I'm happy to find other ways of doing this. But how? Here's the challenge as far as I can tell:
- `state` parameter (similar to PyTorch RNNs)
- `spk` and `mem` inputs

Would an option be to look for `state` in the arguments to account for the norse case and `spk` and `mem` to account for the snnTorch case?
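A rough sketch of that idea, detecting state by parameter name instead of argument count. The names `state`, `spk`, and `mem` come from the discussion above; this is only a proposal sketch, not the implementation merged in this PR, and modules that manage their own hidden state internally (e.g. snnTorch with `init_hidden=True`) would still need the class-specific check used in the diff.

```python
import inspect

import torch

# Parameter names that signal an explicit-state forward(): "state" covers the
# Norse / PyTorch-RNN convention, "spk" and "mem" cover snnTorch neurons.
STATE_ARG_NAMES = {"state", "spk", "mem"}


def is_module_stateful(module: torch.nn.Module) -> bool:
    params = inspect.signature(module.forward).parameters
    # Only treat the module as stateful if forward() explicitly accepts a
    # state-like argument; extra plain data inputs no longer trigger it.
    return any(name in STATE_ARG_NAMES for name in params)
```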