diff --git a/README.md b/README.md index 6b60d97..c959331 100644 --- a/README.md +++ b/README.md @@ -15,15 +15,28 @@ The goal of NIR is to provide a common format that different neuromorphic framew ## Computational primitives > Read more about in our [documentation about NIR primitives](https://nnir.readthedocs.io/en/latest/primitives.html) -On top of popular primitives such as convolutional or fully connected/linear computations, we define additional compuational primitives that are specific to neuromorphic computing and hardware implementations thereof. Computational units that are not specifically neuromorphic take inspiration from the Pytorch ecosystem in terms of naming and parameters (such as Conv2d that uses groups/strides). +On top of popular primitives such as convolutional or fully connected/linear computations, we define additional compuational primitives that are specific to neuromorphic computing and hardware implementations thereof. +Computational units that are not specifically neuromorphic take inspiration from the Pytorch ecosystem in terms of naming and parameters (such as Conv2d that uses groups/strides). ## Connectivity -Each computational unit is a node in a static graph. Given 3 nodes $A$ which is a LIF node, $B$ which is a Linear node and $C$ which is another LIF node, we can define edges in the graph such as: +Each computational unit is a node in a static graph. +Given 3 nodes $A$ which is a LIF node, $B$ which is a Linear node and $C$ which is another LIF node, we can define edges in the graph such as: -$$ -A \rightarrow B \\ -B \rightarrow C -$$ +```mermaid +graph LR; +A --> B; +B --> C; +``` + +Or more complicated graphs, such as + +```mermaid +graph LR; +A --> A; +A --> B; +B --> C; +A --> C; +``` ## Format The intermediate represenation can be stored as hdf5 file, which benefits from compression. diff --git a/example/sinabs/export_from_sinabs.py b/example/sinabs/export_from_sinabs.py index 1c52906..171caa4 100644 --- a/example/sinabs/export_from_sinabs.py +++ b/example/sinabs/export_from_sinabs.py @@ -3,7 +3,6 @@ import torch.nn as nn from sinabs import from_nir, to_nir - batch_size = 4 # Create Sinabs model diff --git a/nir/ir.py b/nir/ir.py index a992119..09322fa 100644 --- a/nir/ir.py +++ b/nir/ir.py @@ -4,7 +4,12 @@ import numpy as np -Edges = typing.NewType("Edges", typing.List[typing.Tuple[str, str]]) +# Nodes are uniquely named computational units +Nodes = typing.Dict[str, "NIRNode"] +# Edges map one node id to another via the identity +Edges = typing.List[typing.Tuple[str, str]] +# Shape is a dict mapping strings to shapes +Shape = typing.Dict[str, np.ndarray] @dataclass @@ -15,6 +20,10 @@ class NIRNode: instantiated. """ + # Note: Adding input/output shapes as follows is ideal, but requires Python 3.10 + # input_shape: Shape = field(init=False, kw_only=True) + # output_shape: Shape = field(init=False, kw_only=True) + @dataclass class NIRGraph(NIRNode): @@ -24,14 +33,19 @@ class NIRGraph(NIRNode): A graph of computational nodes and identity edges. """ - nodes: typing.Dict[str, NIRNode] # List of computational nodes - edges: Edges + nodes: Nodes # List of computational nodes + edges: Edges # List of edges between nodes @staticmethod def from_list(*nodes: NIRNode) -> "NIRGraph": """Create a sequential graph from a list of nodes by labelling them after indices.""" + if len(nodes) > 0 and ( + isinstance(nodes[0], list) or isinstance(nodes[0], tuple) + ): + nodes = [*nodes[0]] + def unique_node_name(node, counts): basename = node.__class__.__name__.lower() id = counts[basename] @@ -40,15 +54,17 @@ def unique_node_name(node, counts): return name counts = Counter() - node_dict = {} + node_dict = {"input": Input(input_shape=nodes[0].input_shape)} edges = [] for node in nodes: name = unique_node_name(node, counts) node_dict[name] = node + node_dict["output"] = Output(output_shape=nodes[-1].output_shape) + names = list(node_dict) - for i in range(len(nodes) - 1): + for i in range(len(names) - 1): edges.append((names[i], names[i + 1])) return NIRGraph( @@ -56,12 +72,31 @@ def unique_node_name(node, counts): edges=edges, ) + def __post_init__(self): + input_node_keys = [ + k for k, node in self.nodes.items() if isinstance(node, Input) + ] + self.input_shape = ( + {node_key: self.nodes[node_key].input_shape for node_key in input_node_keys} + if len(input_node_keys) > 0 + else None + ) + output_node_keys = [ + k for k, node in self.nodes.items() if isinstance(node, Output) + ] + self.output_shape = { + node_key: self.nodes[node_key].output_shape for node_key in output_node_keys + } + @dataclass class Affine(NIRNode): r"""Affine transform that linearly maps and translates the input signal. - This is equivalent to the `Affine transformation `_ + This is equivalent to the + `Affine transformation `_ + + Assumes a one-dimensional input vector of shape (N,). .. math:: y(t) = W*x(t) + b @@ -69,6 +104,17 @@ class Affine(NIRNode): weight: np.ndarray # Weight term bias: np.ndarray # Bias term + def __post_init__(self): + assert len(self.weight.shape) >= 2, "Weight must be at least 2D" + self.input_shape = { + "input": np.array( + self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T) + ) + } + self.output_shape = { + "output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],)) + } + @dataclass class Conv1d(NIRNode): @@ -81,6 +127,10 @@ class Conv1d(NIRNode): groups: int # Groups bias: np.ndarray # Bias C_out + def __post_init__(self): + self.input_shape = {"input": np.array(self.weight.shape)[1:]} + self.output_shape = {"output": np.array(self.weight.shape)[[0, 2]]} + @dataclass class Conv2d(NIRNode): @@ -100,6 +150,8 @@ def __post_init__(self): self.padding = (self.padding, self.padding) if isinstance(self.dilation, int): self.dilation = (self.dilation, self.dilation) + self.input_shape = {"input": np.array(self.weight.shape)[1:]} + self.output_shape = {"output": np.array(self.weight.shape)[[0, 2, 3]]} @dataclass @@ -145,8 +197,17 @@ class CubaLIF(NIRNode): w_in: np.ndarray = 1.0 # Input current weight def __post_init__(self): + assert ( + self.tau_syn.shape + == self.tau_mem.shape + == self.r.shape + == self.v_leak.shape + == self.v_threshold.shape + ), "All parameters must have the same shape" # If w_in is a scalar, make it an array of same shape as v_threshold self.w_in = np.ones_like(self.v_threshold) * self.w_in + self.input_shape = {"input": np.array(self.v_threshold.shape)} + self.output_shape = {"output": np.array(self.v_threshold.shape)} @dataclass @@ -161,17 +222,46 @@ class Delay(NIRNode): delay: np.ndarray # Delay + def __post_init__(self): + # set input and output shape, if not set by user + self.input_shape = {"input": np.array(self.delay.shape)} + self.output_shape = {"output": np.array(self.delay.shape)} + @dataclass class Flatten(NIRNode): """Flatten node. This node flattens its input tensor. + input_shape must be a dict with one key: "input". """ + # Shape of input tensor (overrrides input_shape from + # NIRNode to allow for non-keyword (positional) initialization) + input_shape: Shape start_dim: int = 1 # First dimension to flatten end_dim: int = -1 # Last dimension to flatten + def __post_init__(self): + assert list(self.input_shape.keys()) == [ + "input" + ], "Flatten must have one input: `input`" + if isinstance(self.input_shape, np.ndarray): + self.input_shape = {"input": self.input_shape} + concat = self.input_shape["input"][self.start_dim : self.end_dim].prod() + self.output_shape = { + "output": np.array( + [ + *self.input_shape["input"][: self.start_dim], + concat, + *self.input_shape["input"][self.end_dim :], + ] + ) + } + # make sure input and output shape are valid + if np.prod(self.input_shape["input"]) != np.prod(self.output_shape["output"]): + raise ValueError("input and output shape must have same number of elements") + @dataclass class I(NIRNode): # noqa: E742 @@ -185,6 +275,10 @@ class I(NIRNode): # noqa: E742 r: np.ndarray + def __post_init__(self): + self.input_shape = {"input": np.array(self.r.shape)} + self.output_shape = {"output": np.array(self.r.shape)} + @dataclass class IF(NIRNode): @@ -211,6 +305,13 @@ class IF(NIRNode): r: np.ndarray # Resistance v_threshold: np.ndarray # Firing threshold + def __post_init__(self): + assert ( + self.r.shape == self.v_threshold.shape + ), "All parameters must have the same shape" + self.input_shape = {"input": np.array(self.r.shape)} + self.output_shape = {"output": np.array(self.r.shape)} + @dataclass class Input(NIRNode): @@ -219,7 +320,14 @@ class Input(NIRNode): This is a virtual node, which allows feeding in data into the graph. """ - shape: np.ndarray # Shape of input data + # Shape of incoming data (overrrides input_shape from + # NIRNode to allow for non-keyword (positional) initialization) + input_shape: Shape + + def __post_init__(self): + if isinstance(self.input_shape, np.ndarray): + self.input_shape = {"input": self.input_shape} + self.output_shape = {"output": self.input_shape["input"]} @dataclass @@ -240,6 +348,13 @@ class LI(NIRNode): r: np.ndarray # Resistance v_leak: np.ndarray # Leak voltage + def __post_init__(self): + assert ( + self.tau.shape == self.r.shape == self.v_leak.shape + ), "All parameters must have the same shape" + self.input_shape = {"input": np.array(self.r.shape)} + self.output_shape = {"output": np.array(self.r.shape)} + @dataclass class Linear(NIRNode): @@ -250,6 +365,17 @@ class Linear(NIRNode): """ weight: np.ndarray # Weight term + def __post_init__(self): + assert len(self.weight.shape) >= 2, "Weight must be at least 2D" + self.input_shape = { + "input": np.array( + self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T) + ) + } + self.output_shape = { + "output": self.weight.shape[:-2] + (self.weight.shape[-2],) + } + @dataclass class LIF(NIRNode): @@ -282,6 +408,16 @@ class LIF(NIRNode): v_leak: np.ndarray # Leak voltage v_threshold: np.ndarray # Firing threshold + def __post_init__(self): + assert ( + self.tau.shape + == self.r.shape + == self.v_leak.shape + == self.v_threshold.shape + ), "All parameters must have the same shape" + self.input_shape = {"input": np.array(self.r.shape)} + self.output_shape = {"output": np.array(self.r.shape)} + @dataclass class Output(NIRNode): @@ -290,7 +426,14 @@ class Output(NIRNode): Defines an output of the graph. """ - shape: int # Size of output + # Shape of incoming data (overrrides input_shape from + # NIRNode to allow for non-keyword (positional) initialization) + output_shape: Shape + + def __post_init__(self): + if isinstance(self.output_shape, np.ndarray): + self.output_shape = {"output": self.output_shape} + self.input_shape = {"input": self.output_shape["output"]} @dataclass @@ -306,6 +449,10 @@ class Scale(NIRNode): scale: np.ndarray # Scaling factor + def __post_init__(self): + self.input_shape = {"input": np.array(self.scale.shape)} + self.output_shape = {"output": np.array(self.scale.shape)} + @dataclass class Threshold(NIRNode): @@ -321,3 +468,7 @@ class Threshold(NIRNode): """ threshold: np.ndarray # Firing threshold + + def __post_init__(self): + self.input_shape = {"input": np.array(self.threshold.shape)} + self.output_shape = {"output": np.array(self.threshold.shape)} diff --git a/nir/read.py b/nir/read.py index 496ff44..53be7b7 100644 --- a/nir/read.py +++ b/nir/read.py @@ -34,13 +34,14 @@ def read_node(node: typing.Any) -> nir.NIRNode: return nir.Flatten( start_dim=node["start_dim"][()], end_dim=node["end_dim"][()], + input_shape={"input": node["input_shape"][()]}, ) elif node["type"][()] == b"I": return nir.I(r=node["r"][()]) elif node["type"][()] == b"IF": return nir.IF(r=node["r"][()], v_threshold=node["v_threshold"][()]) elif node["type"][()] == b"Input": - return nir.Input(shape=node["shape"][()]) + return nir.Input(input_shape={"input": node["shape"][()]}) elif node["type"][()] == b"LI": return nir.LI( tau=node["tau"][()], @@ -67,10 +68,10 @@ def read_node(node: typing.Any) -> nir.NIRNode: elif node["type"][()] == b"NIRGraph": return nir.NIRGraph( nodes={k: read_node(n) for k, n in node["nodes"].items()}, - edges=node["edges"].asstr()[()], + edges=[(a.decode("utf8"), b.decode("utf8")) for a, b in node["edges"][()]], ) elif node["type"][()] == b"Output": - return nir.Output(shape=node["shape"][()]) + return nir.Output(output_shape={"output": node["shape"][()]}) elif node["type"][()] == b"Scale": return nir.Scale(scale=node["scale"][()]) elif node["type"][()] == b"Threshold": diff --git a/nir/write.py b/nir/write.py index 637ea37..8b2b5c2 100644 --- a/nir/write.py +++ b/nir/write.py @@ -41,6 +41,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "type": "Flatten", "start_dim": node.start_dim, "end_dim": node.end_dim, + "input_shape": node.input_shape["input"], } elif isinstance(node, nir.I): return {"type": "I", "r": node.r} @@ -51,7 +52,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "v_threshold": node.v_threshold, } elif isinstance(node, nir.Input): - return {"type": "Input", "shape": node.shape} + return {"type": "Input", "shape": node.input_shape["input"]} elif isinstance(node, nir.LI): return { "type": "LI", @@ -85,7 +86,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "edges": node.edges, } elif isinstance(node, nir.Output): - return {"type": "Output", "shape": node.shape} + return {"type": "Output", "shape": node.output_shape["output"]} elif isinstance(node, nir.Scale): return {"type": "Scale", "scale": node.scale} elif isinstance(node, nir.Threshold): diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..59592bf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,27 @@ +import numpy as np + +import nir + + +def mock_linear(*shape): + return nir.Linear(weight=np.random.randn(*shape).T) + + +def mock_affine(*shape): + return nir.Affine(weight=np.random.randn(*shape).T, bias=np.random.randn(shape[1])) + + +def mock_input(*shape): + return nir.Input(input_shape=np.array(shape)) + + +def mock_integrator(*shape): + return nir.I(r=np.random.randn(*shape)) + + +def mock_output(*shape): + return nir.Output(output_shape=np.array(shape)) + + +def mock_delay(*shape): + return nir.Delay(delay=np.random.randn(*shape)) diff --git a/tests/test_architectures.py b/tests/test_architectures.py new file mode 100644 index 0000000..c5e5832 --- /dev/null +++ b/tests/test_architectures.py @@ -0,0 +1,304 @@ +import numpy as np + +import nir +from .test_readwrite import factory_test_graph +from tests import mock_affine + + +def test_sequential(): + a = mock_affine(2, 2) + b = nir.Delay(np.array([0.5, 0.1, 0.2])) + c = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + d = mock_affine(3, 2) + + ir = nir.NIRGraph.from_list(a, b, c, d) + factory_test_graph(ir) + + +def test_two_independent_branches(): + # Branch 1 + a = mock_affine(2, 3) + b = nir.Delay(np.array([0.5, 0.1, 0.2])) + c = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + d = mock_affine(2, 3) + + branch_1 = nir.NIRGraph.from_list(a, b, c, d) + + # Branch 2 + e = mock_affine(2, 3) + f = nir.LIF( + tau=np.array([10, 20]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 2]), + ) + g = mock_affine(3, 2) + + branch_2 = nir.NIRGraph.from_list(e, f, g) + + ir = nir.NIRGraph( + nodes={"branch_1": branch_1, "branch_2": branch_2}, + edges=[], + ) + factory_test_graph(ir) + + +def test_two_independent_branches_merging(): + # Branch 1 + a = mock_affine(2, 3) + b = nir.Delay(np.array([0.5, 0.1, 0.2])) + c = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + d = mock_affine(2, 3) + + branch_1 = nir.NIRGraph.from_list(a, b, c, d) + + # Branch 2 + e = mock_affine(2, 3) + f = nir.LIF( + tau=np.array([10, 20]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 2]), + ) + g = mock_affine(3, 2) + + branch_2 = nir.NIRGraph.from_list(e, f, g) + + # Junction + # TODO: This should be a node that accepts two inputs + h = nir.LIF( + tau=np.array([5, 2]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 1]), + ) + + ir = nir.NIRGraph( + nodes={"branch_1": branch_1, "branch_2": branch_2, "junction": h}, + edges=[("branch_1", "junction"), ("branch_2", "junction")], + ) + factory_test_graph(ir) + + +def test_merge_and_split_single_output(): + # Part before split + a = mock_affine(2, 3) + b = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + pre_split = nir.NIRGraph.from_list(a, b) + + # Branch 1 + c = mock_affine(2, 3) + d = nir.LIF( + tau=np.array([10, 20]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 2]), + ) + branch_1 = nir.NIRGraph.from_list(c, d) + + # Branch 2 + e = mock_affine(2, 3) + f = nir.LIF( + tau=np.array([15, 5]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 1]), + ) + branch_2 = nir.NIRGraph.from_list([e, f]) + + # Junction + # TODO: This should be a node that accepts two inputs + g = nir.Affine(weight=np.array([[2, 0], [1, 3], [4, 1]]), bias=np.array([0, 1])) + + nodes = { + "pre_split": pre_split, + "branch_1": branch_1, + "branch_2": branch_2, + "junction": g, + } + edges = [ + ("pre_split", "branch_1"), + ("pre_split", "branch_2"), + ("branch_1", "junction"), + ("branch_2", "junction"), + ] + ir = nir.NIRGraph(nodes=nodes, edges=edges) + factory_test_graph(ir) + + +def merge_and_split_different_outputs(): + # Part before split + a = mock_affine(3, 2) + # TODO: This should be a node with two outputs + b = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + pre_split = nir.NIRGraph.from_list([a, b]) + + # Branch 1 + reduce_1 = nir.Project(output_indices=[0]) + c = mock_affine(3, 2) + d = nir.LIF( + tau=np.array([10, 20]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 2]), + ) + branch_1 = nir.NIRGraph.from_list([c, d]) + expand_1 = nir.Project(output_indices=[0, float("nan")]) + + # Branch 2 + reduce_2 = nir.Project(output_indices=[1]) + e = mock_affine(3, 2) + f = nir.LIF( + tau=np.array([15, 5]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 1]), + ) + branch_2 = nir.NIRGraph.from_list([e, f]) + expand_2 = nir.Project(output_indices=[float("nan"), 1]) + + # Junction + # TODO: This should be a node that accepts two inputs + g = mock_affine(3, 2) + + nodes = { + "pre_split": pre_split, + "reduce_1": reduce_1, + "reduce_2": reduce_2, + "branch_1": branch_1, + "branch_2": branch_2, + "expand_1": expand_1, + "expand_2": expand_2, + "junction": g, + } + edges = [ + ("pre_split", "reduce_1"), + ("pre_split", "reduce_2"), + ("reduce_1", "branch_1"), + ("reduce_2", "branch_2"), + ("branch_1", "expand_1"), + ("expand_1", "junction"), + ("branch_2", "expand_2"), + ("expand_2", "junction"), + ] + ir = nir.NIRGraph(nodes=nodes, edges=edges) + factory_test_graph(ir) + + +def test_residual(): + # Part before split + a = mock_affine(2, 3) + + # Residual block + b = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + c = mock_affine(3, 2) + d = nir.LIF( + tau=np.array([10, 20]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 2]), + ) + + # Junction + # TODO: This should be a node that accepts two inputs + e = mock_affine(3, 2) + f = nir.LIF( + tau=np.array([15, 5]), + r=np.array([1, 1]), + v_leak=np.array([0, 0]), + v_threshold=np.array([1, 1]), + ) + + nodes = { + "a": a, + "b": b, + "c": c, + "d": d, + "e": e, + "f": f, + } + edges = [ + ("a", "b"), + ("b", "c"), + ("c", "d"), + ("d", "e"), + ("a", "e"), + ("e", "f"), + ] + ir = nir.NIRGraph(nodes=nodes, edges=edges) + factory_test_graph(ir) + + +def test_complex(): + a = nir.Affine(weight=np.array([[1, 2, 3]]), bias=np.array([[0, 0, 0]])) + b = nir.LIF( + tau=np.array([10, 20, 30]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 2, 3]), + ) + c = nir.LIF( + tau=np.array([5, 20, 1]), + r=np.array([1, 1, 1]), + v_leak=np.array([0, 0, 0]), + v_threshold=np.array([1, 1, 1]), + ) + # TODO: This should be a node that accepts two inputs + d = nir.Affine( + weight=np.array([[[1, 3], [2, 3], [1, 4]], [[2, 3], [1, 2], [1, 4]]]), + bias=np.array([0, 0]), + ) + e = nir.Affine(weight=np.array([[1, 3], [2, 3], [1, 4]]), bias=np.array([0, 0])) + # TODO: This should be a node that accepts two inputs + f = nir.Affine( + weight=np.array([[[1, 3], [1, 4]], [[2, 3], [3, 4]]]), bias=np.array([0, 0]) + ) + nodes = { + "a": a, + "b": b, + "c": c, + "d": d, + "e": e, + "f": f, + } + edges = [ + ("a", "b"), + ("a", "c"), + ("b", "d"), + ("c", "d"), + ("c", "e"), + ("d", "f"), + ("e", "f"), + ] + ir = nir.NIRGraph(nodes=nodes, edges=edges) + factory_test_graph(ir) diff --git a/tests/test_ir.py b/tests/test_ir.py index e2efd57..c31c5e9 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -1,6 +1,13 @@ import numpy as np import nir +from tests import ( + mock_delay, + mock_affine, + mock_integrator, + mock_linear, + mock_output, +) def test_has_version(): @@ -9,70 +16,64 @@ def test_has_version(): def test_simple(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) - ir = nir.NIRGraph(nodes={"a": nir.Affine(weight=w, bias=b)}, edges=[("a", "a")]) - assert np.allclose(ir.nodes["a"].weight, w) - assert np.allclose(ir.nodes["a"].bias, b) + a = mock_affine(4, 3) + ir = nir.NIRGraph(nodes={"a": a}, edges=[("a", "a")]) + assert np.allclose(ir.nodes["a"].weight, a.weight) + assert np.allclose(ir.nodes["a"].bias, a.bias) assert ir.edges == [("a", "a")] def test_nested(): - r = np.array([1, 1]) - delay = np.array([2, 2]) - w = np.array([1, 2]) - b = np.array([4, 4]) + i = mock_integrator(3) + d = mock_delay(3) + a = mock_affine(3, 3) + nested = nir.NIRGraph( nodes={ - "integrator": nir.I(r=r), - "delay": nir.Delay(delay), + "integrator": i, + "delay": d, }, edges=[("integrator", "delay"), ("delay", "integrator")], ) ir = nir.NIRGraph( - nodes={"affine": nir.Affine(weight=w, bias=b), "inner": nested}, + nodes={"affine": a, "inner": nested}, edges=[("affine", "inner")], ) - assert np.allclose(ir.nodes["affine"].weight, w) - assert np.allclose(ir.nodes["affine"].bias, b) - assert np.allclose(ir.nodes["inner"].nodes["integrator"].r, r) - assert np.allclose(ir.nodes["inner"].nodes["delay"].delay, delay) + assert np.allclose(ir.nodes["affine"].weight, a.weight) + assert np.allclose(ir.nodes["affine"].bias, a.bias) + assert np.allclose(ir.nodes["inner"].nodes["integrator"].r, i.r) + assert np.allclose(ir.nodes["inner"].nodes["delay"].delay, d.delay) assert ir.nodes["inner"].edges == [("integrator", "delay"), ("delay", "integrator")] def test_simple_with_input_output(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) + a = mock_affine(3, 3) ir = nir.NIRGraph( nodes={ "in": nir.Input(np.array([3])), - "w": nir.Affine(weight=w, bias=b), + "w": a, "out": nir.Output(np.array([3])), }, edges=[("in", "w"), ("w", "out")], ) - assert ir.nodes["in"].shape == [ - 3, - ] - assert np.allclose(ir.nodes["w"].weight, w) - assert np.allclose(ir.nodes["w"].bias, b) + assert ir.nodes["in"].input_shape == {"input": np.array([3])} + assert np.allclose(ir.nodes["w"].weight, a.weight) + assert np.allclose(ir.nodes["w"].bias, a.bias) assert ir.edges == [("in", "w"), ("w", "out")] def test_delay(): - delay = np.array([1, 2, 3]) + d = mock_delay(3) ir = nir.NIRGraph( nodes={ "in": nir.Input(np.array([3])), - "d": nir.Delay(delay=delay), + "d": d, "out": nir.Output(np.array([3])), }, edges=[("in", "d"), ("d", "out")], ) - assert ir.nodes["in"].shape == [ - 3, - ] - assert np.allclose(ir.nodes["d"].delay, delay) + assert ir.nodes["in"].input_shape == {"input": np.array([3])} + assert np.allclose(ir.nodes["d"].delay, d.delay) assert ir.edges == [("in", "d"), ("d", "out")] @@ -92,36 +93,35 @@ def test_threshold(): }, edges=[("in", "thr"), ("thr", "out")], ) - assert ir.nodes["in"].shape == [ - 3, - ] + assert ir.nodes["in"].input_shape == {"input": np.array([3])} assert np.allclose(ir.nodes["thr"].threshold, threshold) assert ir.edges == [("in", "thr"), ("thr", "out")] def test_linear(): - w = np.array([1, 2, 3]) - ir = nir.NIRGraph(nodes={"a": nir.Linear(weight=w)}, edges=[("a", "a")]) - assert np.allclose(ir.nodes["a"].weight, w) + a = mock_linear(3, 3) + ir = nir.NIRGraph(nodes={"a": a}, edges=[("a", "a")]) + assert np.allclose(ir.nodes["a"].weight, a.weight) assert ir.edges == [("a", "a")] def test_flatten(): ir = nir.NIRGraph( nodes={ - "in": nir.Input(np.array([4, 5, 2])), - "flat": nir.Flatten(0), - "out": nir.Output(np.array([20, 2])), + "in": nir.Input(input_shape=np.array([4, 5, 2])), + "flat": nir.Flatten( + start_dim=0, end_dim=0, input_shape={"input": np.array([4, 5, 2])} + ), + "out": nir.Output(output_shape=np.array([20, 2])), }, edges=[("in", "flat"), ("flat", "out")], ) - assert np.allclose(ir.nodes["in"].shape, [4, 5, 2]) - assert np.allclose(ir.nodes["out"].shape, [20, 2]) + assert np.allclose(ir.nodes["in"].input_shape["input"], np.array([4, 5, 2])) + assert np.allclose(ir.nodes["out"].input_shape["input"], np.array([20, 2])) def test_from_list_naming(): ir = nir.NIRGraph.from_list( - nir.Input(shape=np.array([2])), nir.Linear(weight=np.array([[3, 1], [-1, 2], [1, 2]])), nir.Linear(weight=np.array([[3, 1], [-1, 4], [1, 2]]).T), nir.Affine( @@ -138,7 +138,6 @@ def test_from_list_naming(): nir.Affine( weight=np.array([[2, 1], [-1, 3], [1, 2]]).T, bias=np.array([-2, 3]) ), - nir.Output(shape=np.array([2])), ) assert "input" in ir.nodes.keys() assert "linear" in ir.nodes.keys() @@ -150,7 +149,7 @@ def test_from_list_naming(): assert "affine_2" in ir.nodes.keys() assert "affine_3" in ir.nodes.keys() assert "output" in ir.nodes.keys() - assert np.allclose(ir.nodes["input"].shape, [2]) + assert np.allclose(ir.nodes["input"].input_shape["input"], [2]) assert np.allclose(ir.nodes["linear"].weight, np.array([[3, 1], [-1, 2], [1, 2]])) assert np.allclose( ir.nodes["linear_1"].weight, np.array([[3, 1], [-1, 4], [1, 2]]).T @@ -171,7 +170,8 @@ def test_from_list_naming(): ir.nodes["affine_3"].weight, np.array([[2, 1], [-1, 3], [1, 2]]).T ) assert np.allclose(ir.nodes["affine_3"].bias, np.array([-2, 3])) - assert np.allclose(ir.nodes["output"].shape, [2]) + print(ir.nodes["output"].input_shape["input"]) + assert np.allclose(ir.nodes["output"].input_shape["input"], [2]) assert ir.edges == [ ("input", "linear"), ("linear", "linear_1"), @@ -183,3 +183,48 @@ def test_from_list_naming(): ("affine_2", "affine_3"), ("affine_3", "output"), ] + + +def test_from_list_tuple_or_list(): + nodes = [mock_affine(2, 3), mock_delay(1)] + assert len(nir.NIRGraph.from_list(*nodes).nodes) == 4 + assert len(nir.NIRGraph.from_list(*nodes).edges) == 3 + assert len(nir.NIRGraph.from_list(tuple(nodes)).nodes) == 4 + assert len(nir.NIRGraph.from_list(tuple(nodes)).nodes) == 4 + assert len(nir.NIRGraph.from_list(nodes[0], nodes[1]).edges) == 3 + assert len(nir.NIRGraph.from_list(nodes[0], nodes[1]).edges) == 3 + + +def test_subgraph_merge(): + """ + ```mermaid + graph TD; + A --> B; + C --> D; + D --> E; + B --> E; + ``` + """ + g1 = nir.NIRGraph.from_list(mock_linear(2, 3), mock_linear(3, 2)) + g2 = nir.NIRGraph.from_list(mock_linear(1, 3), mock_linear(3, 2)) + end = mock_output(2) + g = nir.NIRGraph( + nodes={"L": g1, "R": g2, "E": end}, + edges=[("L.output", "E.input"), ("R.output", "E.input")], + ) + assert np.allclose(g.nodes["L"].nodes["linear"].input_shape["input"], [2]) + assert np.allclose(g.nodes["L"].nodes["linear_1"].input_shape["input"], [3]) + assert np.allclose(g.nodes["R"].nodes["linear"].input_shape["input"], [1]) + assert np.allclose(g.nodes["R"].nodes["linear_1"].input_shape["input"], [3]) + assert np.allclose(g.nodes["E"].input_shape["input"], [2]) + assert g.edges == [("L.output", "E.input"), ("R.output", "E.input")] + assert g.nodes["L"].edges == [ + ("input", "linear"), + ("linear", "linear_1"), + ("linear_1", "output"), + ] + assert g.nodes["R"].edges == [ + ("input", "linear"), + ("linear", "linear_1"), + ("linear_1", "output"), + ] diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index ecc1f6c..4e0d7d5 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -3,6 +3,7 @@ import numpy as np import nir +from tests import mock_affine def assert_equivalence(ir: nir.NIRGraph, ir2: nir.NIRGraph): @@ -13,8 +14,16 @@ def assert_equivalence(ir: nir.NIRGraph, ir2: nir.NIRGraph): assert_equivalence(ir.nodes[ik], ir2.nodes[ik]) else: for k, v in ir.nodes[ik].__dict__.items(): - if isinstance(v, np.ndarray) or isinstance(v, list): + if ( + isinstance(v, np.ndarray) + or isinstance(v, list) + or isinstance(v, tuple) + ): assert np.array_equal(v, getattr(ir2.nodes[ik], k)) + elif isinstance(v, dict): + d = getattr(ir2.nodes[ik], k) + for a, b in d.items(): + assert np.array_equal(v[a], b) else: assert v == getattr(ir2.nodes[ik], k) for i, _ in enumerate(ir.edges): @@ -30,9 +39,7 @@ def factory_test_graph(ir: nir.NIRGraph): def test_simple(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) - ir = nir.NIRGraph(nodes={"a": nir.Affine(weight=w, bias=b)}, edges=[("a", "a")]) + ir = nir.NIRGraph(nodes={"a": mock_affine(2, 2)}, edges=[("a", "a")]) factory_test_graph(ir) @@ -40,7 +47,7 @@ def test_nested(): i = np.array([1, 1]) nested = nir.NIRGraph( nodes={ - "a": nir.I(r=[1, 1]), + "a": nir.I(r=np.array([1, 1])), "b": nir.NIRGraph( nodes={ "a": nir.Input(i), @@ -57,73 +64,61 @@ def test_nested(): def test_integrator(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) r = np.array([1, 1, 1]) ir = nir.NIRGraph( - nodes={"a": nir.Affine(weight=w, bias=b), "b": nir.I(r)}, + nodes={"a": mock_affine(2, 2), "b": nir.I(r)}, edges=[("a", "b")], ) factory_test_graph(ir) def test_integrate_and_fire(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) r = np.array([1, 1, 1]) v_threshold = np.array([1, 1, 1]) ir = nir.NIRGraph( - nodes={"a": nir.Affine(weight=w, bias=b), "b": nir.IF(r, v_threshold)}, + nodes={"a": mock_affine(2, 2), "b": nir.IF(r, v_threshold)}, edges=[("a", "b")], ) factory_test_graph(ir) def test_leaky_integrator(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) tau = np.array([1, 1, 1]) r = np.array([1, 1, 1]) v_leak = np.array([1, 1, 1]) - ir = nir.NIRGraph.from_list(nir.Affine(weight=w, bias=b), nir.LI(tau, r, v_leak)) + ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak)) factory_test_graph(ir) def test_linear(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) tau = np.array([1, 1, 1]) r = np.array([1, 1, 1]) v_leak = np.array([1, 1, 1]) - ir = nir.NIRGraph.from_list(nir.Affine(w, b), nir.LI(tau, r, v_leak)) + ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak)) factory_test_graph(ir) def test_leaky_integrator_and_fire(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) tau = np.array([1, 1, 1]) r = np.array([1, 1, 1]) v_leak = np.array([1, 1, 1]) v_threshold = np.array([3, 3, 3]) ir = nir.NIRGraph.from_list( - nir.Affine(w, b), + mock_affine(2, 2), nir.LIF(tau, r, v_leak, v_threshold), ) factory_test_graph(ir) def test_current_based_leaky_integrator_and_fire(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) tau_mem = np.array([1, 1, 1]) tau_syn = np.array([2, 2, 2]) r = np.array([1, 1, 1]) v_leak = np.array([1, 1, 1]) v_threshold = np.array([3, 3, 3]) ir = nir.NIRGraph.from_list( - nir.Affine(w, b), + mock_affine(2, 2), nir.CubaLIF(tau_mem, tau_syn, r, v_leak, v_threshold), ) factory_test_graph(ir) @@ -131,20 +126,18 @@ def test_current_based_leaky_integrator_and_fire(): def test_scale(): ir = nir.NIRGraph.from_list( - nir.Input(shape=np.array([3])), + nir.Input(input_shape=np.array([3])), nir.Scale(scale=np.array([1, 2, 3])), - nir.Output(shape=np.array([3])), + nir.Output(output_shape=np.array([3])), ) factory_test_graph(ir) def test_simple_with_read_write(): - w = np.array([1, 2, 3]) - b = np.array([4, 4, 4]) ir = nir.NIRGraph.from_list( - nir.Input(shape=np.array([3])), - nir.Affine(w, b), - nir.Output(shape=np.array([3])), + nir.Input(input_shape=np.array([3])), + mock_affine(2, 2), + nir.Output(output_shape=np.array([3])), ) factory_test_graph(ir) @@ -171,8 +164,12 @@ def test_threshold(): def test_flatten(): ir = nir.NIRGraph.from_list( - nir.Input(shape=np.array([2, 3])), - nir.Flatten(), - nir.Output(shape=np.array([6])), + nir.Input(input_shape=np.array([2, 3])), + nir.Flatten( + start_dim=0, + end_dim=0, + input_shape={"input": np.array([2, 3])}, + ), + nir.Output(output_shape=np.array([6])), ) factory_test_graph(ir)