diff --git a/docs/requirements.txt b/docs/requirements.txt
index 052f01c..828099d 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,4 +3,6 @@ sphinx-book-theme
 myst-parser
 numpy
 h5py
-sphinx_external_toc
\ No newline at end of file
+sphinx_external_toc
+sphinxcontrib-mermaid
+jupyter-book
\ No newline at end of file
diff --git a/docs/source/_config.yml b/docs/source/_config.yml
new file mode 100644
index 0000000..0129a78
--- /dev/null
+++ b/docs/source/_config.yml
@@ -0,0 +1,7 @@
+title: NIR
+author: NIR team
+logo: ../logo_light.png
+
+repository:
+  url: https://github.com/neuromorphs/nir
+  branch: main
\ No newline at end of file
diff --git a/docs/source/toc.yml b/docs/source/_toc.yml
similarity index 94%
rename from docs/source/toc.yml
rename to docs/source/_toc.yml
index 05d28b5..a496e3b 100644
--- a/docs/source/toc.yml
+++ b/docs/source/_toc.yml
@@ -25,10 +25,9 @@ parts:
       - file: examples/spinnaker
   - caption: Developer guide
     chapters:
+      - file: api_design
       - file: dev
         title: Contributing
-      - file: internals
-        title: Internals
   - caption: API
     chapters:
      - file: modindex
diff --git a/docs/source/api_design.md b/docs/source/api_design.md
new file mode 100644
index 0000000..e2ae2b9
--- /dev/null
+++ b/docs/source/api_design.md
@@ -0,0 +1,55 @@
+# NIR API design
+
+The reference implementation simply consists of a series of Python classes that *represent* [NIR structures](primitives).
+In other words, they do not implement the functionality of the nodes; they merely hold the parameters required to *eventually* evaluate the node.
+
+We chose Python because the language is straightforward, widely known, and has excellent [dataclasses](https://docs.python.org/3/library/dataclasses.html) that suit our purpose exactly.
+This permits an incredibly simple design: all the NIR primitives are encoded in a single [`ir.py` file](https://github.com/neuromorphs/NIR/blob/main/nir/ir.py), each following the same structure:
+
+```python
+@dataclass
+class MyNIRNode(NIRNode):
+    some_parameter: np.ndarray
+```
+
+In this example, we create a class that inherits from the parent [`NIRNode`](https://github.com/neuromorphs/NIR/blob/main/nir/ir.py#L160) with a single parameter, `some_parameter`.
+Instantiating the class is simply `MyNIRNode(np.array([...]))`.
+
+## NIR Graphs and edges
+A collection of nodes is a `NIRGraph`, which is, you guessed it, a `NIRNode`.
+But the graph node is special in that it contains a number of named nodes (`.nodes`) and connections between them (`.edges`).
+The nodes are named because we need to uniquely distinguish them from each other, so `.nodes` is actually a dictionary (`Dict[str, NIRNode]`).
+With our node above, we can define `nodes = {"my_node": MyNIRNode(np.array([...]))}`.
+
+Edges are simply pairs of strings: a source and a target node in a tuple (`Tuple[str, str]`).
+There are no restrictions on edges; you can connect nodes to themselves---multiple times if you wish.
+That would look like this: `edges = [("my_node", "my_node")]`.
+
+In sum, a rudimentary, self-cyclic graph can be described in NIR as follows:
+
+```python
+NIRGraph(
+    nodes = {"my_node": MyNIRNode(np.array([...]))},
+    edges = [("my_node", "my_node")]
+)
+```
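+
+A more typical feed-forward wiring connects two differently named nodes instead. As a minimal sketch (reusing the hypothetical `MyNIRNode` from above):
+
+```python
+NIRGraph(
+    nodes = {
+        "first": MyNIRNode(np.array([...])),
+        "second": MyNIRNode(np.array([...])),
+    },
+    edges = [("first", "second")]
+)
+```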
+
+## Input and output types
+All nodes are expected to carry two internal variables (`input_type` and `output_type`) that describe the names *and* shapes of the node's inputs and outputs as a dictionary (`Dict[str, np.ndarray]`).
+The variables are analogous to function declarations in programming languages, where the names and types of the function arguments are given.
+Most nodes have a single input (`"input"`) and output (`"output"`), in which case their `input_type` and `output_type` have single (trivial) entries: `{"input": ...}` and `{"output": ...}`.
+Other nodes are more complicated and use explicit names for their arguments (see below).
+In most cases the variables are inferred in the [`__post_init__`](https://docs.python.org/3/library/dataclasses.html#post-init-processing) method, but new implementations must assign them explicitly.
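+
+As an illustration, the `Affine` node infers both variables from the shape of its weight matrix in `__post_init__`. A sketch of what that looks like (a 2x3 weight maps a 3-vector to a 2-vector):
+
+```python
+affine = Affine(weight=np.zeros((2, 3)), bias=np.zeros(2))
+affine.input_type   # {"input": np.array([3])}
+affine.output_type  # {"output": np.array([2])}
+```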
""" - # Note: Adding input/output shapes as follows is ideal, but requires Python 3.10 - # input_shape: Shape = field(init=False, kw_only=True) - # output_shape: Shape = field(init=False, kw_only=True) + # Note: Adding input/output types as follows is ideal, but requires Python 3.10 + # input_type: Types = field(init=False, kw_only=True) + # output_type: Types = field(init=False, kw_only=True) @dataclass @@ -64,14 +66,14 @@ def unique_node_name(node, counts): return name counts = Counter() - node_dict = {"input": Input(input_shape=nodes[0].input_shape)} + node_dict = {"input": Input(input_type=nodes[0].input_type)} edges = [] for node in nodes: name = unique_node_name(node, counts) node_dict[name] = node - node_dict["output"] = Output(output_shape=nodes[-1].output_shape) + node_dict["output"] = Output(output_type=nodes[-1].output_type) names = list(node_dict) for i in range(len(names) - 1): @@ -86,16 +88,16 @@ def __post_init__(self): input_node_keys = [ k for k, node in self.nodes.items() if isinstance(node, Input) ] - self.input_shape = ( - {node_key: self.nodes[node_key].input_shape for node_key in input_node_keys} + self.input_type = ( + {node_key: self.nodes[node_key].input_type for node_key in input_node_keys} if len(input_node_keys) > 0 else None ) output_node_keys = [ k for k, node in self.nodes.items() if isinstance(node, Output) ] - self.output_shape = { - node_key: self.nodes[node_key].output_shape for node_key in output_node_keys + self.output_type = { + node_key: self.nodes[node_key].output_type for node_key in output_node_keys } @@ -116,12 +118,12 @@ class Affine(NIRNode): def __post_init__(self): assert len(self.weight.shape) >= 2, "Weight must be at least 2D" - self.input_shape = { + self.input_type = { "input": np.array( self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T) ) } - self.output_shape = { + self.output_type = { "output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],)) } @@ -138,8 +140,8 @@ class Conv1d(NIRNode): bias: np.ndarray # Bias C_out def __post_init__(self): - self.input_shape = {"input": np.array(self.weight.shape)[1:]} - self.output_shape = {"output": np.array(self.weight.shape)[[0, 2]]} + self.input_type = {"input": np.array(self.weight.shape)[1:]} + self.output_type = {"output": np.array(self.weight.shape)[[0, 2]]} @dataclass @@ -160,8 +162,8 @@ def __post_init__(self): self.padding = (self.padding, self.padding) if isinstance(self.dilation, int): self.dilation = (self.dilation, self.dilation) - self.input_shape = {"input": np.array(self.weight.shape)[1:]} - self.output_shape = {"output": np.array(self.weight.shape)[[0, 2, 3]]} + self.input_type = {"input": np.array(self.weight.shape)[1:]} + self.output_type = {"output": np.array(self.weight.shape)[[0, 2, 3]]} @dataclass @@ -216,8 +218,8 @@ def __post_init__(self): ), "All parameters must have the same shape" # If w_in is a scalar, make it an array of same shape as v_threshold self.w_in = np.ones_like(self.v_threshold) * self.w_in - self.input_shape = {"input": np.array(self.v_threshold.shape)} - self.output_shape = {"output": np.array(self.v_threshold.shape)} + self.input_type = {"input": np.array(self.v_threshold.shape)} + self.output_type = {"output": np.array(self.v_threshold.shape)} @dataclass @@ -234,8 +236,8 @@ class Delay(NIRNode): def __post_init__(self): # set input and output shape, if not set by user - self.input_shape = {"input": np.array(self.delay.shape)} - self.output_shape = {"output": np.array(self.delay.shape)} + self.input_type = {"input": 
+
+See [the unit test file `test_architectures.py`](https://github.com/neuromorphs/NIR/blob/main/tests/test_architectures.py) for concrete examples of NIR graphs, input/output types, and subgraphs.
\ No newline at end of file
diff --git a/nir/__init__.py b/nir/__init__.py
index dd2f877..6b3a1b9 100644
--- a/nir/__init__.py
+++ b/nir/__init__.py
@@ -7,4 +7,4 @@
 from .read import read  # noqa: F401
 from .write import write  # noqa: F401
 
-version = __version__ = "0.1.1"
+version = __version__ = "0.2.0"
diff --git a/nir/ir.py b/nir/ir.py
index d73222c..3e7670f 100644
--- a/nir/ir.py
+++ b/nir/ir.py
@@ -9,10 +9,11 @@
 Nodes = typing.Dict[str, "NIRNode"]
 # Edges map one node id to another via the identity
 Edges = typing.List[typing.Tuple[str, str]]
-# Shape is a dict mapping strings to shapes
-Shape = typing.Dict[str, np.ndarray]
+# Types is a dict mapping strings to tensor shapes
+Types = typing.Dict[str, np.ndarray]
 
-def _parse_shape_argument(x: Shape, key: str):
+
+def _parse_shape_argument(x: Types, key: str):
     if isinstance(x, np.ndarray):
         return {key: x}
     elif isinstance(x, Sequence):
@@ -22,6 +23,7 @@
     else:
         raise ValueError("Unknown shape argument", x)
 
+
 @dataclass
 class NIRNode:
     """Base superclass of Neural Intermediate Representation Unit (NIR).
@@ -30,9 +32,9 @@ class NIRNode:
     instantiated.
     """
 
-    # Note: Adding input/output shapes as follows is ideal, but requires Python 3.10
-    # input_shape: Shape = field(init=False, kw_only=True)
-    # output_shape: Shape = field(init=False, kw_only=True)
+    # Note: Adding input/output types as follows is ideal, but requires Python 3.10
+    # input_type: Types = field(init=False, kw_only=True)
+    # output_type: Types = field(init=False, kw_only=True)
 
 
 @dataclass
@@ -64,14 +66,14 @@ def unique_node_name(node, counts):
                 return name
 
         counts = Counter()
-        node_dict = {"input": Input(input_shape=nodes[0].input_shape)}
+        node_dict = {"input": Input(input_type=nodes[0].input_type)}
         edges = []
 
         for node in nodes:
             name = unique_node_name(node, counts)
             node_dict[name] = node
 
-        node_dict["output"] = Output(output_shape=nodes[-1].output_shape)
+        node_dict["output"] = Output(output_type=nodes[-1].output_type)
 
         names = list(node_dict)
         for i in range(len(names) - 1):
@@ -86,16 +88,16 @@ def __post_init__(self):
         input_node_keys = [
             k for k, node in self.nodes.items() if isinstance(node, Input)
         ]
-        self.input_shape = (
-            {node_key: self.nodes[node_key].input_shape for node_key in input_node_keys}
+        self.input_type = (
+            {node_key: self.nodes[node_key].input_type for node_key in input_node_keys}
             if len(input_node_keys) > 0
             else None
         )
         output_node_keys = [
             k for k, node in self.nodes.items() if isinstance(node, Output)
         ]
-        self.output_shape = {
-            node_key: self.nodes[node_key].output_shape for node_key in output_node_keys
+        self.output_type = {
+            node_key: self.nodes[node_key].output_type for node_key in output_node_keys
         }
@@ -116,12 +118,12 @@ class Affine(NIRNode):
 
     def __post_init__(self):
         assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
-        self.input_shape = {
+        self.input_type = {
             "input": np.array(
                 self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T)
             )
         }
-        self.output_shape = {
+        self.output_type = {
             "output": np.array(self.weight.shape[:-2] + (self.weight.shape[-2],))
         }
@@ -138,8 +140,8 @@ class Conv1d(NIRNode):
     bias: np.ndarray  # Bias C_out
 
     def __post_init__(self):
-        self.input_shape = {"input": np.array(self.weight.shape)[1:]}
-        self.output_shape = {"output": np.array(self.weight.shape)[[0, 2]]}
+        self.input_type = {"input": np.array(self.weight.shape)[1:]}
+        self.output_type = {"output": np.array(self.weight.shape)[[0, 2]]}
@@ -160,8 +162,8 @@ def __post_init__(self):
             self.padding = (self.padding, self.padding)
         if isinstance(self.dilation, int):
             self.dilation = (self.dilation, self.dilation)
-        self.input_shape = {"input": np.array(self.weight.shape)[1:]}
-        self.output_shape = {"output": np.array(self.weight.shape)[[0, 2, 3]]}
+        self.input_type = {"input": np.array(self.weight.shape)[1:]}
+        self.output_type = {"output": np.array(self.weight.shape)[[0, 2, 3]]}
@@ -216,8 +218,8 @@ def __post_init__(self):
         ), "All parameters must have the same shape"
         # If w_in is a scalar, make it an array of same shape as v_threshold
         self.w_in = np.ones_like(self.v_threshold) * self.w_in
-        self.input_shape = {"input": np.array(self.v_threshold.shape)}
-        self.output_shape = {"output": np.array(self.v_threshold.shape)}
+        self.input_type = {"input": np.array(self.v_threshold.shape)}
+        self.output_type = {"output": np.array(self.v_threshold.shape)}
@@ -234,8 +236,8 @@ class Delay(NIRNode):
 
     def __post_init__(self):
         # set input and output shape, if not set by user
-        self.input_shape = {"input": np.array(self.delay.shape)}
-        self.output_shape = {"output": np.array(self.delay.shape)}
+        self.input_type = {"input": np.array(self.delay.shape)}
+        self.output_type = {"output": np.array(self.delay.shape)}
@@ -243,30 +245,29 @@ class Flatten(NIRNode):
     """Flatten node.
 
     This node flattens its input tensor.
-    input_shape must be a dict with one key: "input".
+    input_type must be a dict with one key: "input".
     """
 
-    # Shape of input tensor (overrrides input_shape from
+    # Shape of input tensor (overrides input_type from
     # NIRNode to allow for non-keyword (positional) initialization)
-    input_shape: Shape
+    input_type: Types
     start_dim: int = 1  # First dimension to flatten
     end_dim: int = -1  # Last dimension to flatten
 
     def __post_init__(self):
-        self.input_shape = _parse_shape_argument(self.input_shape, "input")
-        print(self.input_shape)
-        concat = self.input_shape["input"][self.start_dim : self.end_dim].prod()
-        self.output_shape = {
+        self.input_type = _parse_shape_argument(self.input_type, "input")
+        concat = self.input_type["input"][self.start_dim : self.end_dim].prod()
+        self.output_type = {
             "output": np.array(
                 [
-                    *self.input_shape["input"][: self.start_dim],
+                    *self.input_type["input"][: self.start_dim],
                     concat,
-                    *self.input_shape["input"][self.end_dim :],
+                    *self.input_type["input"][self.end_dim :],
                 ]
             )
         }
         # make sure input and output shape are valid
-        if np.prod(self.input_shape["input"]) != np.prod(self.output_shape["output"]):
+        if np.prod(self.input_type["input"]) != np.prod(self.output_type["output"]):
             raise ValueError("input and output shape must have same number of elements")
@@ -283,8 +284,8 @@ class I(NIRNode):  # noqa: E742
     r: np.ndarray
 
     def __post_init__(self):
-        self.input_shape = {"input": np.array(self.r.shape)}
-        self.output_shape = {"output": np.array(self.r.shape)}
+        self.input_type = {"input": np.array(self.r.shape)}
+        self.output_type = {"output": np.array(self.r.shape)}
@@ -316,8 +317,8 @@ def __post_init__(self):
         assert (
             self.r.shape == self.v_threshold.shape
         ), "All parameters must have the same shape"
-        self.input_shape = {"input": np.array(self.r.shape)}
-        self.output_shape = {"output": np.array(self.r.shape)}
+        self.input_type = {"input": np.array(self.r.shape)}
+        self.output_type = {"output": np.array(self.r.shape)}
@@ -327,13 +328,13 @@ class Input(NIRNode):
 
     This is a virtual node, which allows feeding in data into the graph.
     """
 
-    # Shape of incoming data (overrrides input_shape from
+    # Shape of incoming data (overrides input_type from
     # NIRNode to allow for non-keyword (positional) initialization)
-    input_shape: Shape
+    input_type: Types
 
     def __post_init__(self):
-        self.input_shape = _parse_shape_argument(self.input_shape, "input")
-        self.output_shape = {"output": self.input_shape["input"]}
+        self.input_type = _parse_shape_argument(self.input_type, "input")
+        self.output_type = {"output": self.input_type["input"]}
""" - # Shape of incoming data (overrrides input_shape from + # Shape of incoming data (overrrides input_type from # NIRNode to allow for non-keyword (positional) initialization) - input_shape: Shape + input_type: Types def __post_init__(self): - self.input_shape = _parse_shape_argument(self.input_shape, "input") - self.output_shape = {"output": self.input_shape["input"]} + self.input_type = _parse_shape_argument(self.input_type, "input") + self.output_type = {"output": self.input_type["input"]} @dataclass @@ -358,8 +359,8 @@ def __post_init__(self): assert ( self.tau.shape == self.r.shape == self.v_leak.shape ), "All parameters must have the same shape" - self.input_shape = {"input": np.array(self.r.shape)} - self.output_shape = {"output": np.array(self.r.shape)} + self.input_type = {"input": np.array(self.r.shape)} + self.output_type = {"output": np.array(self.r.shape)} @dataclass @@ -373,14 +374,12 @@ class Linear(NIRNode): def __post_init__(self): assert len(self.weight.shape) >= 2, "Weight must be at least 2D" - self.input_shape = { + self.input_type = { "input": np.array( self.weight.shape[:-2] + tuple(np.array(self.weight.shape[-1:]).T) ) } - self.output_shape = { - "output": self.weight.shape[:-2] + (self.weight.shape[-2],) - } + self.output_type = {"output": self.weight.shape[:-2] + (self.weight.shape[-2],)} @dataclass @@ -421,8 +420,8 @@ def __post_init__(self): == self.v_leak.shape == self.v_threshold.shape ), "All parameters must have the same shape" - self.input_shape = {"input": np.array(self.r.shape)} - self.output_shape = {"output": np.array(self.r.shape)} + self.input_type = {"input": np.array(self.r.shape)} + self.output_type = {"output": np.array(self.r.shape)} @dataclass @@ -432,13 +431,13 @@ class Output(NIRNode): Defines an output of the graph. 
""" - # Shape of incoming data (overrrides input_shape from + # Type of incoming data (overrrides input_type from # NIRNode to allow for non-keyword (positional) initialization) - output_shape: Shape + output_type: Types def __post_init__(self): - self.output_shape = _parse_shape_argument(self.output_shape, "output") - self.input_shape = {"input": self.output_shape["output"]} + self.output_type = _parse_shape_argument(self.output_type, "output") + self.input_type = {"input": self.output_type["output"]} @dataclass @@ -455,8 +454,8 @@ class Scale(NIRNode): scale: np.ndarray # Scaling factor def __post_init__(self): - self.input_shape = {"input": np.array(self.scale.shape)} - self.output_shape = {"output": np.array(self.scale.shape)} + self.input_type = {"input": np.array(self.scale.shape)} + self.output_type = {"output": np.array(self.scale.shape)} @dataclass @@ -475,5 +474,5 @@ class Threshold(NIRNode): threshold: np.ndarray # Firing threshold def __post_init__(self): - self.input_shape = {"input": np.array(self.threshold.shape)} - self.output_shape = {"output": np.array(self.threshold.shape)} + self.input_type = {"input": np.array(self.threshold.shape)} + self.output_type = {"output": np.array(self.threshold.shape)} diff --git a/nir/read.py b/nir/read.py index 53be7b7..4fb0471 100644 --- a/nir/read.py +++ b/nir/read.py @@ -34,14 +34,14 @@ def read_node(node: typing.Any) -> nir.NIRNode: return nir.Flatten( start_dim=node["start_dim"][()], end_dim=node["end_dim"][()], - input_shape={"input": node["input_shape"][()]}, + input_type={"input": node["input_type"][()]}, ) elif node["type"][()] == b"I": return nir.I(r=node["r"][()]) elif node["type"][()] == b"IF": return nir.IF(r=node["r"][()], v_threshold=node["v_threshold"][()]) elif node["type"][()] == b"Input": - return nir.Input(input_shape={"input": node["shape"][()]}) + return nir.Input(input_type={"input": node["shape"][()]}) elif node["type"][()] == b"LI": return nir.LI( tau=node["tau"][()], @@ -71,7 +71,7 @@ def read_node(node: typing.Any) -> nir.NIRNode: edges=[(a.decode("utf8"), b.decode("utf8")) for a, b in node["edges"][()]], ) elif node["type"][()] == b"Output": - return nir.Output(output_shape={"output": node["shape"][()]}) + return nir.Output(output_type={"output": node["shape"][()]}) elif node["type"][()] == b"Scale": return nir.Scale(scale=node["scale"][()]) elif node["type"][()] == b"Threshold": diff --git a/nir/write.py b/nir/write.py index 8b2b5c2..c118b08 100644 --- a/nir/write.py +++ b/nir/write.py @@ -41,7 +41,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "type": "Flatten", "start_dim": node.start_dim, "end_dim": node.end_dim, - "input_shape": node.input_shape["input"], + "input_type": node.input_type["input"], } elif isinstance(node, nir.I): return {"type": "I", "r": node.r} @@ -52,7 +52,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "v_threshold": node.v_threshold, } elif isinstance(node, nir.Input): - return {"type": "Input", "shape": node.input_shape["input"]} + return {"type": "Input", "shape": node.input_type["input"]} elif isinstance(node, nir.LI): return { "type": "LI", @@ -86,7 +86,7 @@ def _convert_node(node: nir.NIRNode) -> dict: "edges": node.edges, } elif isinstance(node, nir.Output): - return {"type": "Output", "shape": node.output_shape["output"]} + return {"type": "Output", "shape": node.output_type["output"]} elif isinstance(node, nir.Scale): return {"type": "Scale", "scale": node.scale} elif isinstance(node, nir.Threshold): diff --git a/tests/__init__.py b/tests/__init__.py index 59592bf..b15cf84 
diff --git a/tests/test_architectures.py b/tests/test_architectures.py
index c5e5832..ba31e58 100644
--- a/tests/test_architectures.py
+++ b/tests/test_architectures.py
@@ -21,6 +21,16 @@ def test_sequential():
 
 
 def test_two_independent_branches():
+    """
+    ```mermaid
+    graph TD;
+    A --> B;
+    B --> C;
+    C --> D;
+    E --> F;
+    F --> G
+    ```
+    """
     # Branch 1
     a = mock_affine(2, 3)
     b = nir.Delay(np.array([0.5, 0.1, 0.2]))
@@ -54,6 +64,18 @@
 
 
 def test_two_independent_branches_merging():
+    """
+    ```mermaid
+    graph TD;
+    A --> B;
+    B --> C;
+    C --> D;
+    E --> F;
+    F --> G;
+    G --> H;
+    D --> H;
+    ```
+    """
     # Branch 1
     a = mock_affine(2, 3)
     b = nir.Delay(np.array([0.5, 0.1, 0.2]))
@@ -80,7 +102,7 @@
     branch_2 = nir.NIRGraph.from_list(e, f)
 
     # Junction
-    # TODO: This should be a node that accepts two inputs
+    # TODO: This should be a node that accepts two input types
     h = nir.LIF(
         tau=np.array([5, 2]),
         r=np.array([1, 1]),
@@ -96,6 +118,18 @@
 
 
 def test_merge_and_split_single_output():
+    """
+    ```mermaid
+    graph TD;
+    A --> B;
+    B --> C;
+    C --> D;
+    B --> F;
+    F --> G;
+    G --> H;
+    D --> H;
+    ```
+    """
     # Part before split
     a = mock_affine(2, 3)
     b = nir.LIF(
@@ -127,7 +161,7 @@
     branch_2 = nir.NIRGraph.from_list([e, f])
 
     # Junction
-    # TODO: This should be a node that accepts two inputs
+    # TODO: This should be a node that accepts two input types
     g = nir.Affine(weight=np.array([[2, 0], [1, 3], [4, 1]]), bias=np.array([0, 1]))
 
     nodes = {
@@ -146,10 +180,10 @@
     factory_test_graph(ir)
 
 
-def merge_and_split_different_outputs():
+def test_merge_and_split_different_outputs():
     # Part before split
     a = mock_affine(3, 2)
-    # TODO: This should be a node with two outputs
+    # TODO: This should be a node with two output types
     b = nir.LIF(
         tau=np.array([10, 20, 30]),
@@ -159,7 +193,6 @@
     pre_split = nir.NIRGraph.from_list([a, b])
 
     # Branch 1
-    reduce_1 = nir.Project(output_indices=[0])
     c = mock_affine(3, 2)
     d = nir.LIF(
         tau=np.array([10, 20]),
@@ -168,10 +201,8 @@
         v_threshold=np.array([1, 2]),
     )
     branch_1 = nir.NIRGraph.from_list([c, d])
-    expand_1 = nir.Project(output_indices=[0, float("nan")])
 
     # Branch 2
-    reduce_2 = nir.Project(output_indices=[1])
     e = mock_affine(3, 2)
     f = nir.LIF(
         tau=np.array([15, 5]),
@@ -180,37 +211,40 @@
         v_threshold=np.array([1, 1]),
     )
     branch_2 = nir.NIRGraph.from_list([e, f])
-    expand_2 = nir.Project(output_indices=[float("nan"), 1])
 
     # Junction
-    # TODO: This should be a node that accepts two inputs
+    # TODO: This should be a node that accepts two input types
     g = mock_affine(3, 2)
 
     nodes = {
         "pre_split": pre_split,
-        "reduce_1": reduce_1,
-        "reduce_2": reduce_2,
         "branch_1": branch_1,
         "branch_2": branch_2,
-        "expand_1": expand_1,
-        "expand_2": expand_2,
         "junction": g,
     }
     edges = [
-        ("pre_split", "reduce_1"),
-        ("pre_split", "reduce_2"),
-        ("reduce_1", "branch_1"),
-        ("reduce_2", "branch_2"),
-        ("branch_1", "expand_1"),
-        ("expand_1", "junction"),
-        ("branch_2", "expand_2"),
-        ("expand_2", "junction"),
+        ("pre_split", "branch_1"),
+        ("pre_split", "branch_2"),
+        ("branch_1", "junction"),
+        ("branch_2", "junction"),
     ]
     ir = nir.NIRGraph(nodes=nodes, edges=edges)
     factory_test_graph(ir)
"reduce_2"), - ("reduce_1", "branch_1"), - ("reduce_2", "branch_2"), - ("branch_1", "expand_1"), - ("expand_1", "junction"), - ("branch_2", "expand_2"), - ("expand_2", "junction"), + ("pre_split", "branch_1"), + ("pre_split", "branch_2"), + ("branch_1", "junction"), + ("branch_2", "junction"), ] ir = nir.NIRGraph(nodes=nodes, edges=edges) factory_test_graph(ir) def test_residual(): + """ + ```mermaid + graph TD; + A --> B; + B --> C; + C --> D; + A --> E; + D --> E; + E --> F; + ``` + """ + # Part before split a = mock_affine(2, 3) @@ -230,7 +264,7 @@ def test_residual(): ) # Junction - # TODO: This should be a node that accepts two inputs + # TODO: This should be a node that accepts two input_type e = mock_affine(3, 2) f = nir.LIF( tau=np.array([15, 5]), @@ -260,6 +294,18 @@ def test_residual(): def test_complex(): + """ + ```mermaid + graph TD; + A --> B; + A --> C; + C --> D; + C --> E; + B --> D; + D --> F; + E --> F; + ``` + """ a = nir.Affine(weight=np.array([[1, 2, 3]]), bias=np.array([[0, 0, 0]])) b = nir.LIF( tau=np.array([10, 20, 30]), @@ -273,13 +319,13 @@ def test_complex(): v_leak=np.array([0, 0, 0]), v_threshold=np.array([1, 1, 1]), ) - # TODO: This should be a node that accepts two inputs + # TODO: This should be a node that accepts two input_type d = nir.Affine( weight=np.array([[[1, 3], [2, 3], [1, 4]], [[2, 3], [1, 2], [1, 4]]]), bias=np.array([0, 0]), ) e = nir.Affine(weight=np.array([[1, 3], [2, 3], [1, 4]]), bias=np.array([0, 0])) - # TODO: This should be a node that accepts two inputs + # TODO: This should be a node that accepts two input_type f = nir.Affine( weight=np.array([[[1, 3], [1, 4]], [[2, 3], [3, 4]]]), bias=np.array([0, 0]) ) @@ -302,3 +348,40 @@ def test_complex(): ] ir = nir.NIRGraph(nodes=nodes, edges=edges) factory_test_graph(ir) + + +def test_subgraph_multiple_input_output(): + """ + ```mermaid + graph TD; + subgraph G + B; C; + end + A --> B; + A --> C; + B --> D; + C --> D; + ``` + """ + a = mock_affine(1, 3) + b = mock_affine(3, 2) + c = mock_affine(3, 2) + d = mock_affine(2, 1) + + # Subgraph + bi = nir.Input(b.input_type) + ci = nir.Input(b.input_type) + bo = nir.Output(b.output_type) + co = nir.Output(c.output_type) + g = nir.NIRGraph( + nodes={"b": b, "c": c, "bi": bi, "ci": ci, "bo": bo, "co": co}, + edges=[("bi", "b"), ("b", "bo"), ("ci", "c"), ("c"), "co"], + ) + + # Supgraph + nir.NIRGraph( + nodes={"a": a, "g": g, "d": d}, + edges=[("a", "g.bi"), ("a", "g.ci"), ("g.bo", "d"), ("g.co", "d")], + ) + + # TODO: Add type checking... 
diff --git a/tests/test_ir.py b/tests/test_ir.py
index c31c5e9..ba72a30 100644
--- a/tests/test_ir.py
+++ b/tests/test_ir.py
@@ -56,7 +56,7 @@ def test_simple_with_input_output():
         },
         edges=[("in", "w"), ("w", "out")],
     )
-    assert ir.nodes["in"].input_shape == {"input": np.array([3])}
+    assert ir.nodes["in"].input_type == {"input": np.array([3])}
     assert np.allclose(ir.nodes["w"].weight, a.weight)
     assert np.allclose(ir.nodes["w"].bias, a.bias)
     assert ir.edges == [("in", "w"), ("w", "out")]
@@ -72,7 +72,7 @@ def test_delay():
         },
         edges=[("in", "d"), ("d", "out")],
     )
-    assert ir.nodes["in"].input_shape == {"input": np.array([3])}
+    assert ir.nodes["in"].input_type == {"input": np.array([3])}
     assert np.allclose(ir.nodes["d"].delay, d.delay)
     assert ir.edges == [("in", "d"), ("d", "out")]
@@ -93,7 +93,7 @@ def test_threshold():
         },
         edges=[("in", "thr"), ("thr", "out")],
     )
-    assert ir.nodes["in"].input_shape == {"input": np.array([3])}
+    assert ir.nodes["in"].input_type == {"input": np.array([3])}
     assert np.allclose(ir.nodes["thr"].threshold, threshold)
     assert ir.edges == [("in", "thr"), ("thr", "out")]
@@ -108,16 +108,16 @@ def test_linear():
 
 
 def test_flatten():
     ir = nir.NIRGraph(
         nodes={
-            "in": nir.Input(input_shape=np.array([4, 5, 2])),
+            "in": nir.Input(input_type=np.array([4, 5, 2])),
             "flat": nir.Flatten(
-                start_dim=0, end_dim=0, input_shape={"input": np.array([4, 5, 2])}
+                start_dim=0, end_dim=0, input_type={"input": np.array([4, 5, 2])}
             ),
-            "out": nir.Output(output_shape=np.array([20, 2])),
+            "out": nir.Output(output_type=np.array([20, 2])),
         },
         edges=[("in", "flat"), ("flat", "out")],
     )
-    assert np.allclose(ir.nodes["in"].input_shape["input"], np.array([4, 5, 2]))
-    assert np.allclose(ir.nodes["out"].input_shape["input"], np.array([20, 2]))
+    assert np.allclose(ir.nodes["in"].input_type["input"], np.array([4, 5, 2]))
+    assert np.allclose(ir.nodes["out"].input_type["input"], np.array([20, 2]))
 
 
 def test_from_list_naming():
@@ -149,7 +149,7 @@ def test_from_list_naming():
     assert "affine_2" in ir.nodes.keys()
     assert "affine_3" in ir.nodes.keys()
     assert "output" in ir.nodes.keys()
-    assert np.allclose(ir.nodes["input"].input_shape["input"], [2])
+    assert np.allclose(ir.nodes["input"].input_type["input"], [2])
     assert np.allclose(ir.nodes["linear"].weight, np.array([[3, 1], [-1, 2], [1, 2]]))
     assert np.allclose(
         ir.nodes["linear_1"].weight, np.array([[3, 1], [-1, 4], [1, 2]]).T
@@ -170,8 +170,8 @@ def test_from_list_naming():
         ir.nodes["affine_3"].weight, np.array([[2, 1], [-1, 3], [1, 2]]).T
     )
     assert np.allclose(ir.nodes["affine_3"].bias, np.array([-2, 3]))
-    print(ir.nodes["output"].input_shape["input"])
-    assert np.allclose(ir.nodes["output"].input_shape["input"], [2])
+    print(ir.nodes["output"].input_type["input"])
+    assert np.allclose(ir.nodes["output"].input_type["input"], [2])
     assert ir.edges == [
         ("input", "linear"),
         ("linear", "linear_1"),
@@ -212,11 +212,11 @@ def test_subgraph_merge():
         nodes={"L": g1, "R": g2, "E": end},
         edges=[("L.output", "E.input"), ("R.output", "E.input")],
     )
-    assert np.allclose(g.nodes["L"].nodes["linear"].input_shape["input"], [2])
-    assert np.allclose(g.nodes["L"].nodes["linear_1"].input_shape["input"], [3])
-    assert np.allclose(g.nodes["R"].nodes["linear"].input_shape["input"], [1])
-    assert np.allclose(g.nodes["R"].nodes["linear_1"].input_shape["input"], [3])
-    assert np.allclose(g.nodes["E"].input_shape["input"], [2])
+    assert np.allclose(g.nodes["L"].nodes["linear"].input_type["input"], [2])
+    assert np.allclose(g.nodes["L"].nodes["linear_1"].input_type["input"], [3])
+    assert np.allclose(g.nodes["R"].nodes["linear"].input_type["input"], [1])
+    assert np.allclose(g.nodes["R"].nodes["linear_1"].input_type["input"], [3])
+    assert np.allclose(g.nodes["E"].input_type["input"], [2])
     assert g.edges == [("L.output", "E.input"), ("R.output", "E.input")]
     assert g.nodes["L"].edges == [
         ("input", "linear"),
diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py
index 4e0d7d5..c9c2cc1 100644
--- a/tests/test_readwrite.py
+++ b/tests/test_readwrite.py
@@ -126,18 +126,18 @@ def test_current_based_leaky_integrator_and_fire():
 
 
 def test_scale():
     ir = nir.NIRGraph.from_list(
-        nir.Input(input_shape=np.array([3])),
+        nir.Input(input_type=np.array([3])),
         nir.Scale(scale=np.array([1, 2, 3])),
-        nir.Output(output_shape=np.array([3])),
+        nir.Output(output_type=np.array([3])),
     )
     factory_test_graph(ir)
 
 
 def test_simple_with_read_write():
     ir = nir.NIRGraph.from_list(
-        nir.Input(input_shape=np.array([3])),
+        nir.Input(input_type=np.array([3])),
         mock_affine(2, 2),
-        nir.Output(output_shape=np.array([3])),
+        nir.Output(output_type=np.array([3])),
     )
     factory_test_graph(ir)
 
@@ -164,12 +164,12 @@ def test_threshold():
 
 
 def test_flatten():
     ir = nir.NIRGraph.from_list(
-        nir.Input(input_shape=np.array([2, 3])),
+        nir.Input(input_type=np.array([2, 3])),
         nir.Flatten(
             start_dim=0,
             end_dim=0,
-            input_shape={"input": np.array([2, 3])},
+            input_type={"input": np.array([2, 3])},
         ),
-        nir.Output(output_shape=np.array([6])),
+        nir.Output(output_type=np.array([6])),
     )
     factory_test_graph(ir)