Added option to execute stateful submodules #13

Merged
merged 30 commits on Dec 6, 2023

Commits (30)
a0a2f47
Added option to execute stateful submodules
Jegp Oct 10, 2023
1465f29
Returned state if stateful module
Jegp Oct 10, 2023
bbe54a0
Ruff
Jegp Oct 10, 2023
a64c23d
Added recurrent execution
Jegp Oct 11, 2023
0a1d97d
Added tests for recurrent execution
Jegp Oct 12, 2023
5cc87de
test for NIR -> NIRTorch -> NIR
stevenabreu7 Oct 13, 2023
70f4447
refactoring + expose ignore_submodules_of
stevenabreu7 Oct 13, 2023
175d34f
fix and test for issue #16
stevenabreu7 Oct 13, 2023
37e4237
fix recurrent test
stevenabreu7 Oct 16, 2023
c76fb6f
remove batch froms shape spec
sheiksadique Oct 18, 2023
26cadf6
Merge branch 'main' into 17-input-node-retains-batch-dimension
sheiksadique Oct 18, 2023
48f9842
bug from hell
stevenabreu7 Oct 18, 2023
26242d3
from_nir hacks for snnTorch
stevenabreu7 Oct 19, 2023
668e023
+ optional model.forward args for stateful modules
stevenabreu7 Oct 19, 2023
c555b2a
change subgraphs handlign (flatten + remove I/O)
stevenabreu7 Oct 19, 2023
60c01f8
model fwd args + ignore_dims arg
stevenabreu7 Oct 19, 2023
d4b1afb
[hack] remove wrong RNN self-connection (NIRTorch)
stevenabreu7 Oct 19, 2023
c736c0e
Added proper graph tracing
Jegp Oct 19, 2023
fe7188a
+ arg to ignore dims in to_nir
stevenabreu7 Oct 20, 2023
a21819f
add tests
stevenabreu7 Oct 20, 2023
bef454b
output_shape also uses ignore_dims
sheiksadique Oct 20, 2023
b95ad5c
Added test for flatten
Jegp Oct 20, 2023
6c1d81e
Merged changes from #18
Jegp Oct 20, 2023
3bc8bd2
minor correction to default value
sheiksadique Oct 20, 2023
84e3cc8
Added ability to ignore state in executor
Jegp Oct 21, 2023
8278437
Added flag in nirtorch parsing
Jegp Oct 21, 2023
ec8cded
Added flag in nirtorch parsing
Jegp Oct 21, 2023
5845167
Merged sinabs test changes
Jegp Oct 21, 2023
0325c80
minor changes to the doc strings
sheiksadique Dec 5, 2023
53109c3
formatting fixes
sheiksadique Dec 5, 2023
208 changes: 142 additions & 66 deletions nirtorch/from_nir.py
Changes shown below are from 16 of the 30 commits.
@@ -1,126 +1,202 @@
from typing import Callable, Dict, List, Optional
import dataclasses
import inspect
from typing import Callable, Dict, List, Optional, Any, Union

import nir
import torch
import torch.nn as nn

from .graph import Graph, Node
from .graph_utils import trace_execution
from .utils import sanitize_name


def execution_order_up_to_node(
node: Node,
graph: Graph,
execution_order: List[Node],
visited: Optional[Dict[Node, bool]] = None,
) -> List[Node]:
"""Recursive function to evaluate execution order until a given node.
@dataclasses.dataclass
class GraphExecutorState:
"""State for the GraphExecutor that keeps track of both the state of hidden
units and caches the output of previous modules, for use in (future) recurrent
computations."""

Args:
node (Node): Execution order for the node of interest
graph (Graph): Graph object describing the network
execution_order (List[Node]): The current known execution order.

Returns:
List[Node]: Execution order
"""
if visited is None:
visited = {n: False for n in graph.node_list}
is_recursive = False
if len(execution_order) == list(graph.node_list):
# All nodes are executed
return execution_order
for parent in graph.find_source_nodes_of(node):
if parent not in execution_order and not visited[parent]:
visited[parent] = True
execution_order = execution_order_up_to_node(
parent, graph, execution_order, visited
)
if node in parent.outgoing_nodes:
is_recursive = True
# Ensure we're not re-adding a recursive node
if is_recursive and node in execution_order:
return execution_order
else: # Finally since all parents are known and executed
return execution_order + [node]
state: Dict[str, Any] = dataclasses.field(default_factory=dict)
cache: Dict[str, Any] = dataclasses.field(default_factory=dict)


class GraphExecutor(nn.Module):
def __init__(self, graph: Graph) -> None:
super().__init__()
self.graph = graph
self.stateful_modules = {}
self.instantiate_modules()
self.execution_order = self.get_execution_order()
if len(self.execution_order) == 0:
raise ValueError("Graph is empty")

def _is_module_stateful(self, module: torch.nn.Module) -> bool:
Collaborator:

The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!

Contributor:

I agree, we need to find a better way to implement this. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself).

Collaborator (Author):

I'm happy to find other ways of doing this. But how?

Here's the challenge as far as I can tell

  • Most frameworks can live without state (snnTorch, Sinabs, Rockpool)
  • Norse requires a state parameter (similar to PyTorch RNNs)
  • snnTorch can take spk and mem inputs

Would an option be to look for state in the arguments to account for the Norse case, and spk and mem to account for the snnTorch case?
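
A minimal sketch of the signature-based check proposed above, for illustration only and not part of this PR's diff; the helper name and the exact set of argument names are assumptions:

import inspect
import torch

# Assumed marker names: Norse-style "state", snnTorch-style "spk"/"mem".
STATE_ARG_NAMES = {"state", "spk", "mem"}

def looks_stateful(module: torch.nn.Module) -> bool:
    # Flag a module as stateful only if its forward() exposes a known
    # state-like argument, rather than whenever it accepts more than one argument.
    params = inspect.signature(module.forward).parameters
    return any(name in STATE_ARG_NAMES for name in params)

Under such a heuristic, a plain torch.nn.Linear (forward(input)) would stay stateless, while a Norse cell with forward(input, state=None) or an snnTorch neuron exposing spk and mem arguments would be flagged as stateful.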

signature = inspect.signature(module.forward)
arguments = len(signature.parameters)
# HACK for snntorch modules
if "snntorch" in str(module.__class__):
if module.__class__.__name__ in [
"Synaptic",
"RSynaptic",
"Leaky",
"RLeaky",
]:
return not module.init_hidden
return arguments > 1

def get_execution_order(self) -> List[Node]:
"""Evaluate the execution order and instantiate that as a list."""
execution_order = []
# Then loop over all nodes and check that they are added to the execution order.
for node in self.graph.node_list:
if node not in execution_order:
execution_order = execution_order_up_to_node(
node, self.graph, execution_order
)
return execution_order
# TODO: Adapt this for graphs with multiple inputs
inputs = self.graph.inputs
if len(inputs) != 1:
raise ValueError(
f"Currently, only one input is supported, but {len(inputs)} was given"
)
return trace_execution(inputs[0], lambda n: n.outgoing_nodes.keys())

def instantiate_modules(self):
for mod, name in self.graph.module_names.items():
self.add_module(sanitize_name(name), mod)
if mod is not None:
self.add_module(sanitize_name(name), mod)
self.stateful_modules[sanitize_name(name)] = self._is_module_stateful(
mod
)

def get_input_nodes(self) -> List[Node]:
# NOTE: This is a hack. Should use the input nodes from NIR graph
return self.graph.get_root()

def forward(self, data: torch.Tensor):
outs = {}
def _apply_module(
self,
node: Node,
input_nodes: List[Node],
new_state: GraphExecutorState,
old_state: GraphExecutorState,
data: Optional[torch.Tensor] = None,
):
"""Applies a module and keeps track of its state.
TODO: Use pytree to recursively construct the state
"""
inputs = []
# Append state if needed
if node.name in self.stateful_modules and node.name in old_state.state:
inputs.extend(old_state.state[node.name])

# Sum recurrence if needed
summed_inputs = [] if data is None else [data]
for input_node in input_nodes:
if (
input_node.name not in new_state.cache
and input_node.name in old_state.cache
):
summed_inputs.append(old_state.cache[input_node.name])
elif input_node.name in new_state.cache:
summed_inputs.append(new_state.cache[input_node.name])

if len(summed_inputs) == 0:
raise ValueError("No inputs found for node {}".format(node.name))
elif len(summed_inputs) == 1:
inputs.insert(0, summed_inputs[0])
elif len(summed_inputs) > 1:
inputs.insert(0, torch.stack(summed_inputs).sum(0))

out = node.elem(*inputs)
# If the module is stateful, we know the output is (at least) a tuple
# HACK to make it work for snnTorch
is_rsynaptic = "snntorch._neurons.rsynaptic.RSynaptic" in str(
node.elem.__class__
)
if is_rsynaptic and not node.elem.init_hidden:
assert "lif" in node.name, "this shouldnt happen.."
new_state.state[node.name] = out # snnTorch requires output inside state
out = out[0]
elif self.stateful_modules[node.name]:
new_state.state[node.name] = out[1:] # Store the new state
out = out[0]
return out, new_state

def forward(
self, data: torch.Tensor, old_state: Optional[GraphExecutorState] = None
):
if old_state is None:
old_state = GraphExecutorState()
new_state = GraphExecutorState()
first_node = True
# NOTE: This logic is not yet consistent for models with multiple input nodes
for node in self.execution_order:
input_nodes = self.graph.find_source_nodes_of(node)
if node.elem is None:
continue
if len(input_nodes) == 0 or len(outs) == 0:
# This is the root node
outs[node.name] = node.elem(data)
else:
# Intermediate nodes
input_data = (outs[node.name] for node in input_nodes)
outs[node.name] = node.elem(*input_data)
return outs[node.name]


def _mod_nir_to_graph(nir_graph: nir.NIRNode) -> Graph:
module_names = {module: name for name, module in nir_graph.nodes.items()}
graph = Graph(module_names=module_names)
for src, dst in nir_graph.edges:
graph.add_edge(nir_graph.nodes[src], nir_graph.nodes[dst])
out, new_state = self._apply_module(
node,
input_nodes,
new_state=new_state,
old_state=old_state,
data=data if first_node else None,
)
new_state.cache[node.name] = out
first_node = False

# If the output node is a dummy nir.Output node, use the second-to-last node
if node.name not in new_state.cache:
node = self.execution_order[-2]
return new_state.cache[node.name], new_state


def _mod_nir_to_graph(
torch_graph: nir.NIRGraph, nir_nodes: Dict[str, nir.NIRNode]
) -> Graph:
module_names = {module: name for name, module in torch_graph.nodes.items()}
inputs = [name for name, node in nir_nodes.items() if isinstance(node, nir.Input)]
graph = Graph(module_names=module_names, inputs=inputs)
for src, dst in torch_graph.edges:
# Allow edges to refer to subgraph inputs and outputs
if not src in torch_graph.nodes and f"{src}.output" in torch_graph.nodes:
src = f"{src}.output"
if not dst in torch_graph.nodes and f"{dst}.input" in torch_graph.nodes:
dst = f"{dst}.input"
graph.add_edge(torch_graph.nodes[src], torch_graph.nodes[dst])
return graph


def _switch_default_models(nir_graph: nir.NIRNode) -> Optional[torch.nn.Module]:
if isinstance(nir_graph, nir.Input) or isinstance(nir_graph, nir.Output):
return torch.nn.Identity()


def _switch_models_with_map(
nir_graph: nir.NIRNode, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIRNode:
nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
nodes = {}
for name, node in nir_graph.nodes.items():
mapped_module = model_map(node)
if mapped_module is None:
mapped_module = _switch_default_models(node)
nodes[name] = mapped_module
# nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
return nir.NIRGraph(nodes, nir_graph.edges)


def load(
nir_graph: nir.NIRNode, model_map: Callable[[nir.NIRNode], nn.Module]
nir_graph: Union[nir.NIRNode, str], model_map: Callable[[nir.NIRNode], nn.Module]
) -> nn.Module:
"""Load a NIR object and convert it to a torch module using the given model map.
"""Load a NIR graph and convert it to a torch module using the given model map.

Args:
nir_graph (nir.NIRNode): NIR object
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing
the path to the NIR object.
model_map (Callable[[nir.NIRNode], nn.Module]): A method that returns a torch
module that corresponds to each NIR node.

Returns:
nn.Module: The generated torch module
"""
if isinstance(nir_graph, str):
nir_graph = nir.read(nir_graph)
# Map modules to the target modules using the model map
nir_module_graph = _switch_models_with_map(nir_graph, model_map)
# Build a nirtorch.Graph based on the nir_graph
graph = _mod_nir_to_graph(nir_module_graph)
graph = _mod_nir_to_graph(nir_module_graph, nir_nodes=nir_graph.nodes)
# Build and return a graph executor module
return GraphExecutor(graph)
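
For reference, a minimal usage sketch of the resulting API, assuming load is exported at the package level; the graph file name and the model_map below are placeholders, and only the string-path support and the (output, state) return value follow from the code in this diff:

import nir
import torch
import nirtorch

def model_map(node: nir.NIRNode) -> torch.nn.Module:
    # Placeholder mapping for illustration: a real map would translate each
    # NIR node type (nir.Affine, nir.LIF, ...) into the target framework's module.
    return torch.nn.Identity()

module = nirtorch.load("my_graph.nir", model_map)  # string paths are read via nir.read

state = None
for x in torch.rand(10, 2):  # ten timesteps of dummy input
    out, state = module(x, state)  # forward returns (output, GraphExecutorState)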