add input/output shape annotations for all nodes (#50)
* add I/O shape annotations for all nodes

---------

Co-authored-by: Jens E. Pedersen <[email protected]>
Co-authored-by: Felix Bauer <[email protected]>
3 people authored Sep 4, 2023
1 parent 84790cb commit b6cb9d0
Showing 9 changed files with 636 additions and 98 deletions.
25 changes: 19 additions & 6 deletions README.md
@@ -15,15 +15,28 @@ The goal of NIR is to provide a common format that different neuromorphic frameworks
## Computational primitives
> Read more in our [documentation about NIR primitives](https://nnir.readthedocs.io/en/latest/primitives.html)
On top of popular primitives such as convolutional or fully connected/linear computations, we define additional computational primitives that are specific to neuromorphic computing and hardware implementations thereof.
Computational units that are not specifically neuromorphic take inspiration from the PyTorch ecosystem in terms of naming and parameters (such as Conv2d, which uses groups/strides).

## Connectivity
Each computational unit is a node in a static graph.
Given three nodes $A$ (a LIF node), $B$ (a Linear node), and $C$ (another LIF node), we can define edges in the graph such as:

```mermaid
graph LR;
A --> B;
B --> C;
```

Or more complicated graphs, such as

```mermaid
graph LR;
A --> A;
A --> B;
B --> C;
A --> C;
```
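
For concreteness, the first chain above could be built with the node classes this commit annotates. A minimal sketch (all parameter values are illustrative; the classes are exposed at package level, as `nir/read.py` in this commit assumes):

```python
import numpy as np
import nir

# A (LIF) -> B (Linear) -> C (LIF); parameter values are illustrative
A = nir.LIF(tau=np.ones(2), r=np.ones(2), v_leak=np.zeros(2), v_threshold=np.ones(2))
B = nir.Linear(weight=np.eye(2))
C = nir.LIF(tau=np.ones(2), r=np.ones(2), v_leak=np.zeros(2), v_threshold=np.ones(2))

graph = nir.NIRGraph(nodes={"A": A, "B": B, "C": C}, edges=[("A", "B"), ("B", "C")])
```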

## Format
The intermediate representation can be stored as an hdf5 file, which benefits from compression.
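
A minimal round-trip sketch, assuming the `write`/`read` helpers that accompany `nir/read.py` (only the reader appears in this diff):

```python
import numpy as np
import nir

graph = nir.NIRGraph.from_list(nir.Linear(weight=np.eye(2)))
nir.write("model.h5", graph)   # serialize the NIRGraph to hdf5 (helper name assumed)
loaded = nir.read("model.h5")  # reconstructs nodes via read_node in nir/read.py
```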
1 change: 0 additions & 1 deletion example/sinabs/export_from_sinabs.py
@@ -3,7 +3,6 @@
import torch.nn as nn
from sinabs import from_nir, to_nir


batch_size = 4

# Create Sinabs model
167 changes: 159 additions & 8 deletions nir/ir.py
@@ -4,7 +4,12 @@

import numpy as np

# Nodes are uniquely named computational units
Nodes = typing.Dict[str, "NIRNode"]
# Edges map one node id to another via the identity function
Edges = typing.List[typing.Tuple[str, str]]
# Shape is a dict mapping port names (strings) to shape arrays
Shape = typing.Dict[str, np.ndarray]


@dataclass
@@ -15,6 +20,10 @@ class NIRNode:
instantiated.
"""

# Note: Adding input/output shapes as follows is ideal, but requires Python 3.10
# input_shape: Shape = field(init=False, kw_only=True)
# output_shape: Shape = field(init=False, kw_only=True)


@dataclass
class NIRGraph(NIRNode):
@@ -24,14 +33,19 @@ class NIRGraph(NIRNode):
A graph of computational nodes and identity edges.
"""

nodes: Nodes  # Dictionary of computational nodes, keyed by name
edges: Edges  # List of edges between nodes

@staticmethod
def from_list(*nodes: NIRNode) -> "NIRGraph":
"""Create a sequential graph from a list of nodes by labelling them after
indices."""

if len(nodes) > 0 and (
isinstance(nodes[0], list) or isinstance(nodes[0], tuple)
):
nodes = [*nodes[0]]

def unique_node_name(node, counts):
basename = node.__class__.__name__.lower()
id = counts[basename]
@@ -40,35 +54,67 @@ def unique_node_name(node, counts):
return name

counts = Counter()
node_dict = {"input": Input(input_shape=nodes[0].input_shape)}
edges = []

for node in nodes:
name = unique_node_name(node, counts)
node_dict[name] = node

node_dict["output"] = Output(output_shape=nodes[-1].output_shape)

names = list(node_dict)
for i in range(len(names) - 1):
edges.append((names[i], names[i + 1]))

return NIRGraph(
nodes=node_dict,
edges=edges,
)

def __post_init__(self):
input_node_keys = [
k for k, node in self.nodes.items() if isinstance(node, Input)
]
self.input_shape = (
{node_key: self.nodes[node_key].input_shape for node_key in input_node_keys}
if len(input_node_keys) > 0
else None
)
output_node_keys = [
k for k, node in self.nodes.items() if isinstance(node, Output)
]
self.output_shape = {
node_key: self.nodes[node_key].output_shape for node_key in output_node_keys
}
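
For illustration, a sketch of what `from_list` builds (parameters are made up; the node keys assume `unique_node_name` yields the lowercased class name):

```python
import numpy as np
import nir

graph = nir.NIRGraph.from_list(
    nir.Affine(weight=np.zeros((3, 2)), bias=np.zeros(3)),
    nir.LIF(tau=np.ones(3), r=np.ones(3), v_leak=np.zeros(3), v_threshold=np.ones(3)),
)
print(graph.edges)
# expected: [('input', 'affine'), ('affine', 'lif'), ('lif', 'output')]
# __post_init__ above then maps each virtual node key ('input'/'output')
# to that node's shape dict in graph.input_shape / graph.output_shape
```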


@dataclass
class Affine(NIRNode):
r"""Affine transform that linearly maps and translates the input signal.
This is equivalent to the
`Affine transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_
Assumes a one-dimensional input vector of shape (N,).
.. math::
y(t) = W*x(t) + b
"""
weight: np.ndarray # Weight term
bias: np.ndarray # Bias term

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
self.input_shape = {
"input": np.array(
self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T)
)
}
self.output_shape = {
"output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],))
}
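
A quick check of the shape arithmetic above, sketched with a plain 2-D weight:

```python
import numpy as np
import nir

aff = nir.Affine(weight=np.zeros((4, 7)), bias=np.zeros(4))
print(aff.input_shape)   # {'input': array([7])}, the last weight axis
print(aff.output_shape)  # {'output': array([4])}, the second-to-last weight axis
```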


@dataclass
class Conv1d(NIRNode):
@@ -81,6 +127,10 @@ class Conv1d(NIRNode):
groups: int # Groups
bias: np.ndarray # Bias C_out

def __post_init__(self):
self.input_shape = {"input": np.array(self.weight.shape)[1:]}
self.output_shape = {"output": np.array(self.weight.shape)[[0, 2]]}


@dataclass
class Conv2d(NIRNode):
@@ -100,6 +150,8 @@ def __post_init__(self):
self.padding = (self.padding, self.padding)
if isinstance(self.dilation, int):
self.dilation = (self.dilation, self.dilation)
self.input_shape = {"input": np.array(self.weight.shape)[1:]}
self.output_shape = {"output": np.array(self.weight.shape)[[0, 2, 3]]}


@dataclass
@@ -145,8 +197,17 @@ class CubaLIF(NIRNode):
w_in: np.ndarray = 1.0 # Input current weight

def __post_init__(self):
assert (
self.tau_syn.shape
== self.tau_mem.shape
== self.r.shape
== self.v_leak.shape
== self.v_threshold.shape
), "All parameters must have the same shape"
# If w_in is a scalar, make it an array of same shape as v_threshold
self.w_in = np.ones_like(self.v_threshold) * self.w_in
self.input_shape = {"input": np.array(self.v_threshold.shape)}
self.output_shape = {"output": np.array(self.v_threshold.shape)}


@dataclass
@@ -161,17 +222,46 @@ class Delay(NIRNode):

delay: np.ndarray # Delay

def __post_init__(self):
# set the input and output shape from the shape of the delay parameter
self.input_shape = {"input": np.array(self.delay.shape)}
self.output_shape = {"output": np.array(self.delay.shape)}


@dataclass
class Flatten(NIRNode):
"""Flatten node.
This node flattens its input tensor.
input_shape must be a dict with one key: "input".
"""

# Shape of input tensor (overrides input_shape from
# NIRNode to allow for non-keyword (positional) initialization)
input_shape: Shape
start_dim: int = 1 # First dimension to flatten
end_dim: int = -1 # Last dimension to flatten

def __post_init__(self):
# promote a bare array to the dict form before validating the keys
if isinstance(self.input_shape, np.ndarray):
self.input_shape = {"input": self.input_shape}
assert list(self.input_shape.keys()) == [
"input"
], "Flatten must have one input: `input`"
concat = self.input_shape["input"][self.start_dim : self.end_dim].prod()
self.output_shape = {
"output": np.array(
[
*self.input_shape["input"][: self.start_dim],
concat,
*self.input_shape["input"][self.end_dim :],
]
)
}
# make sure input and output shape are valid
if np.prod(self.input_shape["input"]) != np.prod(self.output_shape["output"]):
raise ValueError("input and output shape must have same number of elements")
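
A worked sketch of the computation above; note that `end_dim` is applied as an exclusive slice bound in this implementation:

```python
import numpy as np
import nir

# collapse dims 0..1 of a (2, 3, 4) input: 2 * 3 = 6
flat = nir.Flatten(input_shape={"input": np.array([2, 3, 4])}, start_dim=0, end_dim=2)
print(flat.output_shape)  # {'output': array([6, 4])}
```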


@dataclass
class I(NIRNode): # noqa: E742
@@ -185,6 +275,10 @@ class I(NIRNode): # noqa: E742

r: np.ndarray

def __post_init__(self):
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}


@dataclass
class IF(NIRNode):
@@ -211,6 +305,13 @@ class IF(NIRNode):
r: np.ndarray # Resistance
v_threshold: np.ndarray # Firing threshold

def __post_init__(self):
assert (
self.r.shape == self.v_threshold.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}


@dataclass
class Input(NIRNode):
@@ -219,7 +320,14 @@ class Input(NIRNode):
This is a virtual node, which allows feeding in data into the graph.
"""

# Shape of incoming data (overrides input_shape from
# NIRNode to allow for non-keyword (positional) initialization)
input_shape: Shape

def __post_init__(self):
if isinstance(self.input_shape, np.ndarray):
self.input_shape = {"input": self.input_shape}
self.output_shape = {"output": self.input_shape["input"]}


@dataclass
@@ -240,6 +348,13 @@ class LI(NIRNode):
r: np.ndarray # Resistance
v_leak: np.ndarray # Leak voltage

def __post_init__(self):
assert (
self.tau.shape == self.r.shape == self.v_leak.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}


@dataclass
class Linear(NIRNode):
@@ -250,6 +365,17 @@
"""
weight: np.ndarray # Weight term

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
self.input_shape = {
"input": np.array(
self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T)
)
}
self.output_shape = {
"output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],))
}


@dataclass
class LIF(NIRNode):
@@ -282,6 +408,16 @@ class LIF(NIRNode):
v_leak: np.ndarray # Leak voltage
v_threshold: np.ndarray # Firing threshold

def __post_init__(self):
assert (
self.tau.shape
== self.r.shape
== self.v_leak.shape
== self.v_threshold.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}


@dataclass
class Output(NIRNode):
@@ -290,7 +426,14 @@
Defines an output of the graph.
"""

# Shape of outgoing data (overrides output_shape from
# NIRNode to allow for non-keyword (positional) initialization)
output_shape: Shape

def __post_init__(self):
if isinstance(self.output_shape, np.ndarray):
self.output_shape = {"output": self.output_shape}
self.input_shape = {"input": self.output_shape["output"]}
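
A small sketch of the virtual nodes' shape handling: a bare array is promoted to the dict form, and each virtual node mirrors its counterpart shape:

```python
import numpy as np
import nir

inp = nir.Input(input_shape=np.array([28, 28]))
out = nir.Output(output_shape=np.array([10]))
print(inp.output_shape)  # {'output': array([28, 28])}
print(out.input_shape)   # {'input': array([10])}
```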


@dataclass
@@ -306,6 +449,10 @@ class Scale(NIRNode):

scale: np.ndarray # Scaling factor

def __post_init__(self):
self.input_shape = {"input": np.array(self.scale.shape)}
self.output_shape = {"output": np.array(self.scale.shape)}


@dataclass
class Threshold(NIRNode):
@@ -321,3 +468,7 @@
"""

threshold: np.ndarray # Firing threshold

def __post_init__(self):
self.input_shape = {"input": np.array(self.threshold.shape)}
self.output_shape = {"output": np.array(self.threshold.shape)}
7 changes: 4 additions & 3 deletions nir/read.py
@@ -34,13 +34,14 @@ def read_node(node: typing.Any) -> nir.NIRNode:
return nir.Flatten(
start_dim=node["start_dim"][()],
end_dim=node["end_dim"][()],
input_shape={"input": node["input_shape"][()]},
)
elif node["type"][()] == b"I":
return nir.I(r=node["r"][()])
elif node["type"][()] == b"IF":
return nir.IF(r=node["r"][()], v_threshold=node["v_threshold"][()])
elif node["type"][()] == b"Input":
return nir.Input(input_shape={"input": node["shape"][()]})
elif node["type"][()] == b"LI":
return nir.LI(
tau=node["tau"][()],
@@ -67,10 +68,10 @@ def read_node(node: typing.Any) -> nir.NIRNode:
elif node["type"][()] == b"NIRGraph":
return nir.NIRGraph(
nodes={k: read_node(n) for k, n in node["nodes"].items()},
edges=node["edges"].asstr()[()],
edges=[(a.decode("utf8"), b.decode("utf8")) for a, b in node["edges"][()]],
)
elif node["type"][()] == b"Output":
return nir.Output(output_shape={"output": node["shape"][()]})
elif node["type"][()] == b"Scale":
return nir.Scale(scale=node["scale"][()])
elif node["type"][()] == b"Threshold":
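
The explicit `.decode` above is needed because hdf5 returns variable-length strings as bytes. A sketch against a hypothetical file layout (group and dataset names are assumptions, not the writer's actual schema):

```python
import h5py

with h5py.File("model.h5", "r") as f:
    raw = f["node"]["edges"][()]  # e.g. array([[b'input', b'affine'], ...])
    edges = [(a.decode("utf8"), b.decode("utf8")) for a, b in raw]
```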
