Improved docs on types and renamed input_shape to input_type (#52)
* Renamed input_shape to input_type and output_shape to output_type
* Improved developer docs
Jegp authored Sep 8, 2023
1 parent ee0877d commit 000e17a
Showing 12 changed files with 260 additions and 115 deletions.
4 changes: 3 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -3,4 +3,6 @@ sphinx-book-theme
myst-parser
numpy
h5py
sphinx_external_toc
sphinx_external_toc
sphinxcontrib-mermaid
jupyter-book
7 changes: 7 additions & 0 deletions docs/source/_config.yml
@@ -0,0 +1,7 @@
title: NIR
author: NIR team
logo: ../logo_light.png

repository:
url: https://github.com/neuromorphs/nir
branch: main
3 changes: 1 addition & 2 deletions docs/source/toc.yml → docs/source/_toc.yml
@@ -25,10 +25,9 @@ parts:
- file: examples/spinnaker
- caption: Developer guide
chapters:
- file: api_design
- file: dev
title: Contributing
- file: internals
title: Internals
- caption: API
chapters:
- file: modindex
55 changes: 55 additions & 0 deletions docs/source/api_design.md
@@ -0,0 +1,55 @@
# NIR API design

The reference implementation simply consists of a series of Python classes that *represent* [NIR structures](primitives).
In other words, they do not implement the functionality of the nodes, but simply represent the necessary parameters required to *eventually* evaluate the node.

We chose Python because the language is straightforward, known by most, and has excellent [dataclasses](https://docs.python.org/3/library/dataclasses.html) that fit our purpose exactly.
This permits an incredibly simple layout: all the NIR primitives are encoded in a single [`ir.py` file](https://github.com/neuromorphs/NIR/blob/main/nir/ir.py), each following the same pattern:

```python
@dataclass
class MyNIRNode(NIRNode):
    some_parameter: np.ndarray
```

In this example, we create a class that inherits from the parent [`NIRNode`](https://github.com/neuromorphs/NIR/blob/main/nir/ir.py#L160) with a single parameter, `some_parameter`.
Instantiating the class is simply `MyNIRNode(np.array([...]))`.
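As a runnable sketch, with a stand-in `NIRNode` base class (the real one lives in [`ir.py`](https://github.com/neuromorphs/NIR/blob/main/nir/ir.py)):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class NIRNode:
    """Stand-in for the real base class defined in nir/ir.py."""


@dataclass
class MyNIRNode(NIRNode):
    some_parameter: np.ndarray


# Dataclasses generate __init__ for us, so positional init just works:
node = MyNIRNode(np.array([1.0, 2.0, 3.0]))
```

The dataclass decorator is what keeps each node definition down to a handful of lines.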

## NIR Graphs and edges
A collection of nodes is a `NIRGraph`, which is, you guessed it, a `NIRNode`.
But the graph node is special in that it contains a number of named nodes (`.nodes`) and connections between them (`.edges`).
The nodes are named because we need to uniquely distinguish them from each other, so `.nodes` is actually a dictionary (`Dict[str, NIRNode]`).
With our node above, we can define `nodes = {"my_node": MyNIRNode(np.array([...]))}`.

An edge is simply a tuple of two strings (`Tuple[str, str]`): the names of the source and target nodes.
There are no restrictions on edges; you can connect nodes to themselves---multiple times if you wish.
That would look like this: `edges = [("my_node", "my_node")]`.

In sum, a rudimentary, self-cyclic graph can be described in NIR as follows:

```python
NIRGraph(
    nodes={"my_node": MyNIRNode(np.array([...]))},
    edges=[("my_node", "my_node")],
)
```
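To see the pieces interact end-to-end, here is a self-contained sketch using the same stand-in classes (the real classes additionally infer input and output types, which we omit here):

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np


@dataclass
class NIRNode:
    """Stand-in for the real base class."""


@dataclass
class MyNIRNode(NIRNode):
    some_parameter: np.ndarray


@dataclass
class NIRGraph(NIRNode):
    nodes: Dict[str, NIRNode]     # named nodes
    edges: List[Tuple[str, str]]  # (source, target) name pairs


graph = NIRGraph(
    nodes={"my_node": MyNIRNode(np.array([0.5]))},
    edges=[("my_node", "my_node")],  # the node feeds back into itself
)
```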

## Input and output types
All nodes are expected to carry two internal variables (`input_type` and `output_type`) that describe the names *and* shapes of the inputs and outputs of the node as a dictionary (`Dict[str, np.ndarray]`).
The variables can be likened to function declarations in programming languages, where the names and types of the function arguments are given.
Most nodes have a single input (`"input"`) and output (`"output"`), in which case their `input_type` and `output_type` have single (trivial) entries: `{"input": ...}` and `{"output": ...}`.
Other nodes are more complicated and use explicit names for their arguments (see below).
In most cases the variables are inferred in the [`__post_init__`](https://docs.python.org/3/library/dataclasses.html#post-init-processing) method, but new implementations will have to assign them one way or another.
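For instance, the `Scale` node in `ir.py` infers both types from the shape of its `scale` parameter; a self-contained sketch of the same pattern (with a stand-in `NIRNode` base):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class NIRNode:
    """Stand-in for the real base class."""


@dataclass
class Scale(NIRNode):
    scale: np.ndarray  # elementwise scaling factor

    def __post_init__(self):
        # Infer both types from the parameter shape, mirroring
        # what the real Scale node in nir/ir.py does
        self.input_type = {"input": np.array(self.scale.shape)}
        self.output_type = {"output": np.array(self.scale.shape)}


node = Scale(np.ones((2, 3)))
```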

## Subgraphs and types
The decision to include the input and output types was made to disambiguate connectivity between nodes.
First, they allow us to ensure that an edge connecting two nodes is valid: that the type at one end corresponds to the type at the other.
But the types become truly necessary when we wish to connect *to* or *from* subgraphs.

Consider a subgraph `G` with two nodes `B` and `C`.
How can we specifically describe connectivity to `B` and not `C`?
By using the input types, we can *subscript* the edge to specify exactly which input we're connecting to: `G.B`.
An edge from `A` to `B` would then look like this: `("A", "G.B")`.
The same process works out of the graph, thanks to the `output_type`: we simply create an outgoing edge from, say, `G.C`.
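One possible way to resolve such subscripted endpoints is to split on the dot; note that `split_endpoint` below is a hypothetical helper for illustration, not part of the NIR API:

```python
from typing import Tuple


def split_endpoint(endpoint: str) -> Tuple[str, ...]:
    """Hypothetical helper (not part of NIR): split a possibly
    subscripted edge endpoint such as "G.B" into path components."""
    return tuple(endpoint.split("."))


# An edge from top-level node A into node B inside subgraph G:
edge = ("A", "G.B")
source_path = split_endpoint(edge[0])  # ("A",) - a plain top-level node
target_path = split_endpoint(edge[1])  # ("G", "B") - node B inside G
```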

See [the unit test file `test_architectures.py`](https://github.com/neuromorphs/NIR/blob/main/tests/test_architectures.py) for concrete examples on NIR graphs, input/output types, and subgraphs.
2 changes: 1 addition & 1 deletion nir/__init__.py
@@ -7,4 +7,4 @@
from .read import read # noqa: F401
from .write import write # noqa: F401

version = __version__ = "0.1.1"
version = __version__ = "0.2.0"
111 changes: 55 additions & 56 deletions nir/ir.py
@@ -9,10 +9,11 @@
Nodes = typing.Dict[str, "NIRNode"]
# Edges map one node id to another via the identity
Edges = typing.List[typing.Tuple[str, str]]
# Shape is a dict mapping strings to shapes
Shape = typing.Dict[str, np.ndarray]
# Types is a dict mapping strings to tensor shapes
Types = typing.Dict[str, np.ndarray]

def _parse_shape_argument(x: Shape, key: str):

def _parse_shape_argument(x: Types, key: str):
if isinstance(x, np.ndarray):
return {key: x}
elif isinstance(x, Sequence):
@@ -22,6 +23,7 @@ def _parse_shape_argument(x: Shape, key: str):
else:
raise ValueError("Unknown shape argument", x)


@dataclass
class NIRNode:
"""Base superclass of Neural Intermediate Representation Unit (NIR).
@@ -30,9 +32,9 @@ class NIRNode:
instantiated.
"""

# Note: Adding input/output shapes as follows is ideal, but requires Python 3.10
# input_shape: Shape = field(init=False, kw_only=True)
# output_shape: Shape = field(init=False, kw_only=True)
# Note: Adding input/output types as follows is ideal, but requires Python 3.10
# input_type: Types = field(init=False, kw_only=True)
# output_type: Types = field(init=False, kw_only=True)


@dataclass
@@ -64,14 +66,14 @@ def unique_node_name(node, counts):
return name

counts = Counter()
node_dict = {"input": Input(input_shape=nodes[0].input_shape)}
node_dict = {"input": Input(input_type=nodes[0].input_type)}
edges = []

for node in nodes:
name = unique_node_name(node, counts)
node_dict[name] = node

node_dict["output"] = Output(output_shape=nodes[-1].output_shape)
node_dict["output"] = Output(output_type=nodes[-1].output_type)

names = list(node_dict)
for i in range(len(names) - 1):
@@ -86,16 +88,16 @@ def __post_init__(self):
input_node_keys = [
k for k, node in self.nodes.items() if isinstance(node, Input)
]
self.input_shape = (
{node_key: self.nodes[node_key].input_shape for node_key in input_node_keys}
self.input_type = (
{node_key: self.nodes[node_key].input_type for node_key in input_node_keys}
if len(input_node_keys) > 0
else None
)
output_node_keys = [
k for k, node in self.nodes.items() if isinstance(node, Output)
]
self.output_shape = {
node_key: self.nodes[node_key].output_shape for node_key in output_node_keys
self.output_type = {
node_key: self.nodes[node_key].output_type for node_key in output_node_keys
}


@@ -116,12 +118,12 @@ class Affine(NIRNode):

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
self.input_shape = {
self.input_type = {
"input": np.array(
self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T)
)
}
self.output_shape = {
self.output_type = {
"output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],))
}

@@ -138,8 +140,8 @@ class Conv1d(NIRNode):
bias: np.ndarray # Bias C_out

def __post_init__(self):
self.input_shape = {"input": np.array(self.weight.shape)[1:]}
self.output_shape = {"output": np.array(self.weight.shape)[[0, 2]]}
self.input_type = {"input": np.array(self.weight.shape)[1:]}
self.output_type = {"output": np.array(self.weight.shape)[[0, 2]]}


@dataclass
@@ -160,8 +162,8 @@ def __post_init__(self):
self.padding = (self.padding, self.padding)
if isinstance(self.dilation, int):
self.dilation = (self.dilation, self.dilation)
self.input_shape = {"input": np.array(self.weight.shape)[1:]}
self.output_shape = {"output": np.array(self.weight.shape)[[0, 2, 3]]}
self.input_type = {"input": np.array(self.weight.shape)[1:]}
self.output_type = {"output": np.array(self.weight.shape)[[0, 2, 3]]}


@dataclass
@@ -216,8 +218,8 @@ def __post_init__(self):
), "All parameters must have the same shape"
# If w_in is a scalar, make it an array of same shape as v_threshold
self.w_in = np.ones_like(self.v_threshold) * self.w_in
self.input_shape = {"input": np.array(self.v_threshold.shape)}
self.output_shape = {"output": np.array(self.v_threshold.shape)}
self.input_type = {"input": np.array(self.v_threshold.shape)}
self.output_type = {"output": np.array(self.v_threshold.shape)}


@dataclass
@@ -234,39 +236,38 @@ class Delay(NIRNode):

def __post_init__(self):
# set input and output shape, if not set by user
self.input_shape = {"input": np.array(self.delay.shape)}
self.output_shape = {"output": np.array(self.delay.shape)}
self.input_type = {"input": np.array(self.delay.shape)}
self.output_type = {"output": np.array(self.delay.shape)}


@dataclass
class Flatten(NIRNode):
"""Flatten node.
This node flattens its input tensor.
input_shape must be a dict with one key: "input".
input_type must be a dict with one key: "input".
"""

# Shape of input tensor (overrrides input_shape from
# Shape of input tensor (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
input_shape: Shape
input_type: Types
start_dim: int = 1 # First dimension to flatten
end_dim: int = -1 # Last dimension to flatten

def __post_init__(self):
self.input_shape = _parse_shape_argument(self.input_shape, "input")
print(self.input_shape)
concat = self.input_shape["input"][self.start_dim : self.end_dim].prod()
self.output_shape = {
self.input_type = _parse_shape_argument(self.input_type, "input")
concat = self.input_type["input"][self.start_dim : self.end_dim].prod()
self.output_type = {
"output": np.array(
[
*self.input_shape["input"][: self.start_dim],
*self.input_type["input"][: self.start_dim],
concat,
*self.input_shape["input"][self.end_dim :],
*self.input_type["input"][self.end_dim :],
]
)
}
# make sure input and output shape are valid
if np.prod(self.input_shape["input"]) != np.prod(self.output_shape["output"]):
if np.prod(self.input_type["input"]) != np.prod(self.output_type["output"]):
raise ValueError("input and output shape must have same number of elements")


@@ -283,8 +284,8 @@ class I(NIRNode): # noqa: E742
r: np.ndarray

def __post_init__(self):
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}
self.input_type = {"input": np.array(self.r.shape)}
self.output_type = {"output": np.array(self.r.shape)}


@dataclass
@@ -316,8 +317,8 @@ def __post_init__(self):
assert (
self.r.shape == self.v_threshold.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}
self.input_type = {"input": np.array(self.r.shape)}
self.output_type = {"output": np.array(self.r.shape)}


@dataclass
@@ -327,13 +328,13 @@ class Input(NIRNode):
This is a virtual node, which allows feeding in data into the graph.
"""

# Shape of incoming data (overrrides input_shape from
# Shape of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
input_shape: Shape
input_type: Types

def __post_init__(self):
self.input_shape = _parse_shape_argument(self.input_shape, "input")
self.output_shape = {"output": self.input_shape["input"]}
self.input_type = _parse_shape_argument(self.input_type, "input")
self.output_type = {"output": self.input_type["input"]}


@dataclass
@@ -358,8 +359,8 @@ def __post_init__(self):
assert (
self.tau.shape == self.r.shape == self.v_leak.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}
self.input_type = {"input": np.array(self.r.shape)}
self.output_type = {"output": np.array(self.r.shape)}


@dataclass
@@ -373,14 +374,12 @@ class Linear(NIRNode):

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
self.input_shape = {
self.input_type = {
"input": np.array(
self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T)
)
}
self.output_shape = {
"output": self.weight.shape[:-2] + (self.weight.shape[-2],)
}
self.output_type = {"output": self.weight.shape[:-2] + (self.weight.shape[-2],)}


@dataclass
@@ -421,8 +420,8 @@ def __post_init__(self):
== self.v_leak.shape
== self.v_threshold.shape
), "All parameters must have the same shape"
self.input_shape = {"input": np.array(self.r.shape)}
self.output_shape = {"output": np.array(self.r.shape)}
self.input_type = {"input": np.array(self.r.shape)}
self.output_type = {"output": np.array(self.r.shape)}


@dataclass
@@ -432,13 +431,13 @@ class Output(NIRNode):
Defines an output of the graph.
"""

# Shape of incoming data (overrrides input_shape from
# Type of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
output_shape: Shape
output_type: Types

def __post_init__(self):
self.output_shape = _parse_shape_argument(self.output_shape, "output")
self.input_shape = {"input": self.output_shape["output"]}
self.output_type = _parse_shape_argument(self.output_type, "output")
self.input_type = {"input": self.output_type["output"]}


@dataclass
@@ -455,8 +454,8 @@ class Scale(NIRNode):
scale: np.ndarray # Scaling factor

def __post_init__(self):
self.input_shape = {"input": np.array(self.scale.shape)}
self.output_shape = {"output": np.array(self.scale.shape)}
self.input_type = {"input": np.array(self.scale.shape)}
self.output_type = {"output": np.array(self.scale.shape)}


@dataclass
@@ -475,5 +474,5 @@ class Threshold(NIRNode):
threshold: np.ndarray # Firing threshold

def __post_init__(self):
self.input_shape = {"input": np.array(self.threshold.shape)}
self.output_shape = {"output": np.array(self.threshold.shape)}
self.input_type = {"input": np.array(self.threshold.shape)}
self.output_type = {"output": np.array(self.threshold.shape)}