diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c391399..29c3074 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -22,9 +22,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .
           python -m pip install -r requirements.txt
           python -m pip install -r tests/test_requirements.txt
+          python -m pip install .
       - name: Lint with ruff
         run: |
           # stop the build if there are Python syntax errors or undefined names
diff --git a/nirtorch/__init__.py b/nirtorch/__init__.py
index 9469cc1..4b5b70c 100644
--- a/nirtorch/__init__.py
+++ b/nirtorch/__init__.py
@@ -1,2 +1,4 @@
 from .from_nir import load  # noqa F401
 from .to_nir import extract_nir_graph  # noqa F401
+
+__version__ = version = "0.2.1"
diff --git a/nirtorch/graph.py b/nirtorch/graph.py
index 787e0a4..ac5c7c1 100644
--- a/nirtorch/graph.py
+++ b/nirtorch/graph.py
@@ -21,9 +21,13 @@ def named_modules_map(
     """
     modules_map = {}
     for name, mod in model.named_modules():
+        # Ignore sequential modules
+        if isinstance(mod, nn.Sequential):
+            continue
         modules_map[mod] = name
     if model_name is None:
-        del modules_map[model]
+        if model in modules_map:
+            del modules_map[model]
     else:
         modules_map[model] = model_name
     return modules_map
@@ -34,20 +38,20 @@ def __init__(
         self,
         elem: Any,
         name: str,
-        outgoing_nodes: Optional[List["Node"]] = None,
+        outgoing_nodes: Optional[Dict["Node", torch.Tensor]] = None,
     ) -> None:
         self.elem = elem
         self.name = name
         if not outgoing_nodes:
-            self.outgoing_nodes = []
+            self.outgoing_nodes = {}
         else:
             self.outgoing_nodes = outgoing_nodes

-    def add_outgoing(self, node: "Node"):
-        self.outgoing_nodes.append(node)
+    def add_outgoing(self, node: "Node", shape=None) -> None:
+        self.outgoing_nodes[node] = shape

     def __str__(self) -> str:
-        return f"Node: {self.name}, Out: {len(self.outgoing_nodes)}"
+        return f"Node: {self.name} ({type(self.elem)}), Out: {len(self.outgoing_nodes)}"

     def __eq__(self, other: Any) -> bool:
         # Two nodes are meant to be the same if they refer to the same element
@@ -62,9 +66,14 @@ def __hash__(self):


 class Graph:
-    def __init__(self, module_names: Dict[nn.Module, str]) -> None:
+    def __init__(
+        self,
+        module_names: Dict[nn.Module, str],
+        module_output_types: Dict[nn.Module, torch.Tensor] = {},
+    ) -> None:
         self.module_names = module_names
         self.node_list: List[Node] = []
+        self.module_output_types = module_output_types
         self._last_used_tensor_id = None
         # Add modules to node_list
         for mod, name in self.module_names.items():
@@ -129,6 +138,7 @@ def add_edge(
         self,
         source: Union[torch.Tensor, nn.Module],
         destination: Union[torch.Tensor, nn.Module],
+        shape: torch.Tensor = None,
     ):
         if self._is_mod_and_not_in_module_names(source):
             return
@@ -140,7 +150,7 @@ def add_edge(
         source_node = self.add_or_get_node_for_elem(source)
         destination_node = self.add_or_get_node_for_elem(destination)

-        source_node.add_outgoing(destination_node)
+        source_node.add_outgoing(destination_node, shape)
         return source_node, destination_node

     def get_leaf_modules(self) -> Dict[nn.Module, str]:
@@ -174,9 +184,10 @@ def _is_mod_and_not_in_module_names(self, elem: Any) -> bool:
         return False

     def populate_from(self, other_graph: "Graph"):
+        self.module_output_types.update(other_graph.module_output_types)
         for node in other_graph.node_list:
-            for outgoing_node in node.outgoing_nodes:
-                self.add_edge(node.elem, outgoing_node.elem)
+            for outgoing_node, shape in node.outgoing_nodes.items():
+                self.add_edge(node.elem, outgoing_node.elem, shape)
     def __str__(self) -> str:
         return self.to_md()
@@ -188,7 +199,7 @@ def to_md(self) -> str:
         """
         for node in self.node_list:
             if node.outgoing_nodes:
-                for outgoing in node.outgoing_nodes:
+                for outgoing, _ in node.outgoing_nodes.items():
                     mermaid_md += f"{node.name} --> {outgoing.name};\n"
             else:
                 mermaid_md += f"{node.name};\n"
@@ -226,7 +237,7 @@ def ignore_submodules_of(self, classes: List[Type]) -> "Graph":
             if mod not in sub_modules_to_ignore:
                 new_named_modules[mod] = name
         # Create a new graph with the allowed modules
-        new_graph = Graph(new_named_modules)
+        new_graph = Graph(new_named_modules, self.module_output_types)
         new_graph.populate_from(self)
         return new_graph
@@ -241,7 +252,7 @@ def find_source_nodes_of(self, node: Node) -> List[Node]:
         """
         source_node_list = []
         for source_node in self.node_list:
-            for outnode in source_node.outgoing_nodes:
+            for outnode, shape in source_node.outgoing_nodes.items():
                 if node == outnode:
                     source_node_list.append(source_node)
         return source_node_list
@@ -261,7 +272,7 @@ def ignore_nodes(self, class_type: Type) -> "Graph":
         }

         # Generate the new graph with the filtered module names
-        graph = Graph(new_module_names)
+        graph = Graph(new_module_names, self.module_output_types)
         # Iterate over all the nodes
         for node in self.node_list:
             if isinstance(node.elem, class_type):
@@ -273,18 +284,18 @@ def ignore_nodes(self, class_type: Type) -> "Graph":
                 # Get all of its destinations
                 if node.outgoing_nodes:
                     # If no destinations, it is a leaf node, just drop it.
-                    for outgoing_node in node.outgoing_nodes:
+                    for outgoing_node, shape in node.outgoing_nodes.items():
                         # Directly add an edge from source to destination
                         for source_node in source_node_list:
-                            graph.add_edge(source_node.elem, outgoing_node.elem)
+                            graph.add_edge(source_node.elem, outgoing_node.elem, shape)
                         # NOTE: Assuming that the destination is not of the same
                         # type here
             else:
                 # This is to preserve the graph if executed on a graph that is
                 # already filtered
-                for outnode in node.outgoing_nodes:
+                for outnode, shape in node.outgoing_nodes.items():
                     if not isinstance(outnode.elem, class_type):
-                        graph.add_edge(node.elem, outnode.elem)
+                        graph.add_edge(node.elem, outnode.elem, shape)
         return graph

     def get_root(self) -> List[Node]:
@@ -296,7 +307,8 @@ def get_root(self) -> List[Node]:
         roots = []
         for node in self.node_list:
             sources = self.find_source_nodes_of(node)
-            if len(sources) == 0:
+            # Append root node if it has no sources (and it isn't a sequential module)
+            if len(sources) == 0 and not isinstance(node.elem, torch.nn.Sequential):
                 roots.append(node)
         return roots

@@ -304,23 +316,38 @@ def get_root(self) -> List[Node]:
 _torch_module_call = torch.nn.Module.__call__


-def module_forward_wrapper(model_graph: Graph) -> Callable[..., Any]:
+def module_forward_wrapper(
+    model_graph: Graph, output_types: Dict[nn.Module, torch.Tensor]
+) -> Callable[..., Any]:
     def my_forward(mod: nn.Module, *args, **kwargs) -> Any:
-        # Iterate over all inputs
-        for i, input_data in enumerate(args):
-            # Create nodes and edges
-            model_graph.add_edge(input_data, mod)
         out = _torch_module_call(mod, *args, **kwargs)
+
         if isinstance(out, tuple):
             out_tuple = (out[0],)
+            output_types[mod] = out[0].shape
         elif isinstance(out, torch.Tensor):
             out_tuple = (out,)
+            output_types[mod] = out.shape
         else:
             raise Exception("Unknown output format")
+
+        # Iterate over all inputs
+        for i, input_data in enumerate(args):
+            # Create nodes and edges
+            model_graph.add_edge(
+                input_data,
+                mod,
+                input_data.shape if isinstance(input_data, torch.Tensor) else None,
+            )
+
         # Iterate over all outputs and create nodes and edges
         for output_data in out_tuple:
             # Create nodes and edges
-            model_graph.add_edge(mod, output_data)
+            model_graph.add_edge(
+                mod,
+                output_data,
+                output_data.shape if isinstance(output_data, torch.Tensor) else None,
+            )
         return out

     return my_forward
@@ -341,11 +368,12 @@ class GraphTracer:
     def __init__(self, mod: nn.Module) -> None:
         self.original_torch_call = nn.Module.__call__
-        self.graph = Graph(mod)
+        self.output_types = {}
+        self.graph = Graph(mod, self.output_types)

     def __enter__(self) -> "GraphTracer":
         # Override the torch call method
-        nn.Module.__call__ = module_forward_wrapper(self.graph)
+        nn.Module.__call__ = module_forward_wrapper(self.graph, self.output_types)
         return self

     def __exit__(self, exc_type, exc_value, exc_tb):
@@ -367,7 +395,6 @@ def extract_torch_graph(
         If specified, it will be included in the graph. If set to None,
         only its submodules will be listed in the graph.
         Defaults to "model".
-
     Returns:
         Graph: A graph object representing the computational graph of the given model
     """
diff --git a/nirtorch/to_nir.py b/nirtorch/to_nir.py
index ac338c5..c12511d 100644
--- a/nirtorch/to_nir.py
+++ b/nirtorch/to_nir.py
@@ -51,15 +51,18 @@ def extract_nir_graph(
     for indx, node in enumerate(torch_graph.node_list):
         # Convert the node type to NIR subgraph
         mapped_node = model_map(node.elem)
-        print(mapped_node, node)

         if isinstance(mapped_node, nir.NIRGraph):
-            for n in mapped_node.nodes:
-                nir_nodes[n.name] = n
+            for k, v in mapped_node.nodes.items():
+                # For now, we add nodes in subgraphs to the top-level node list
+                # TODO: Parse graphs recursively
+                if isinstance(v, nir.NIRNode):
+                    nir_nodes[f"{node.name}.{k}"] = v
+                else:
+                    nir_nodes[v.name] = v
             # Add edges from graph
             for x, y in mapped_node.edges:
-                print(x, y)
-                nir_edges.append(x, y)
+                nir_edges.append((f"{node.name}.{x}", f"{node.name}.{y}"))
         else:
             nir_nodes[node.name] = mapped_node
@@ -72,12 +75,17 @@ def extract_nir_graph(

     # Get all the edges
     for node in torch_graph.node_list:
-        for destination in node.outgoing_nodes:
+        for destination, shape in node.outgoing_nodes.items():
             nir_edges.append((node.name, destination.name))
+
         if len(node.outgoing_nodes) == 0:
             out_name = "output"
-            output_node = nir.Output(None)
+            # Try to find shape of input to the Output node
+            output_node = nir.Output(torch_graph.module_output_types[node.elem])
             nir_nodes[out_name] = output_node
             nir_edges.append((node.name, out_name))

+    # Remove duplicate edges
+    nir_edges = list(set(nir_edges))
+
     return nir.NIRGraph(nir_nodes, nir_edges)
diff --git a/pyproject.toml b/pyproject.toml
index 3513a08..30ba287 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,6 @@ requires = ["setuptools"]

 [project]
 name = "nirtorch"
-version = "0.2.0"
 description = "Neuromorphic Intermediate Representation"
 authors = [
   { name = "Steven Abreu", email = "s.abreu@rug.nl" },
@@ -29,3 +28,7 @@ classifiers = [
   "Topic :: Software Development :: Libraries :: Python Modules",
 ]
 dependencies = [ "torch", "nir" ]
+dynamic = ["version"]  # Version number read from __init__.py
+
+[tool.setuptools.dynamic]
+version = {attr = "nirtorch.__version__"}
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_from_nir.py b/tests/test_from_nir.py
index f8014cc..c9d6bd4 100644
--- a/tests/test_from_nir.py
+++ b/tests/test_from_nir.py
@@ -8,8 +8,8 @@ def _torch_model_map(m: nir.NIRNode, device: str = "cpu") -> torch.nn.Module:
     if isinstance(m, nir.Affine):
         lin = torch.nn.Linear(*m.weight.shape[-2:])
-        lin.weight.data = torch.tensor(m.weight, device=device)
-        lin.bias.data = torch.tensor(m.bias, device=device)
+        lin.weight.data = torch.nn.Parameter(m.weight.to(device))
+        lin.bias.data = torch.nn.Parameter(m.bias.to(device))
         return lin
     elif isinstance(m, nir.Input) or isinstance(m, nir.Output):
         return None
@@ -27,8 +27,8 @@ def test_extract_lin():
     x = torch.randn(1, 1)
     lin = nir.Affine(x, torch.randn(1, 1))
     torchlin = torch.nn.Linear(1, 1)
-    torchlin.weight.data = torch.tensor(lin.weight)
-    torchlin.bias.data = torch.tensor(lin.bias)
+    torchlin.weight.data = torch.nn.Parameter(lin.weight)
+    torchlin.bias.data = torch.nn.Parameter(lin.bias)
     y = torchlin(torchlin(x))
     g = nir.NIRGraph({"a": lin, "b": lin}, [("a", "b")])
     m = load(g, _torch_model_map)
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 3be8211..b47a415 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -67,10 +67,10 @@ def forward(self, data):
         return out2


-input_shape = (2, 28, 28)
+input_type = (2, 28, 28)
 batch_size = 1

-data = torch.ones((batch_size, *input_shape))
+data = torch.ones((batch_size, *input_type))

 my_branched_model = SinabsBranchedModel()
@@ -107,8 +107,9 @@ def test_module_forward_wrapper():
     from nirtorch.graph import Graph, module_forward_wrapper, named_modules_map

+    output_types = {}
     model_graph = Graph(named_modules_map(mymodel))
-    new_call = module_forward_wrapper(model_graph)
+    new_call = module_forward_wrapper(model_graph, output_types)
     # Override call to the new wrapped call
     nn.Module.__call__ = new_call
@@ -123,6 +124,7 @@ def test_module_forward_wrapper():
     assert (
         len(model_graph.node_list) == 1 + 5 + 5 + 1
     )  # 1 top module + 5 submodules + 5 tensors + 1 output tensor
+    assert len(output_types) == 6  # 1 top module + 5 submodules


 def test_graph_tracer():
@@ -221,7 +223,7 @@ def test_snn_stateful():
     model = NorseStatefulModel()
     graph = extract_torch_graph(model, sample_data=torch.rand((1, 2, 3, 4)))

-    assert len(graph.node_list) == 7  # 2 + 1 nested + 4 tensors
+    assert len(graph.node_list) == 6  # 2 + 1 nested + 3 tensors


 def test_ignore_tensors():
@@ -264,3 +266,33 @@ def test_ignore_nodes_parent_model():

     with pytest.raises(ValueError):
         new_graph.find_node(my_branched_model)
+
+
+def test_input_output():
+    from norse.torch import to_nir as norse_to_nir
+
+    g = norse_to_nir(NorseStatefulModel(), data)
+    assert len(g.nodes) == 4  # in -> relu -> lif -> out
+    assert len(g.edges) == 3
+
+
+def test_output_type_when_single_node():
+    import nir
+    from nirtorch import extract_nir_graph
+
+    g = extract_nir_graph(
+        torch.nn.ReLU(),
+        lambda x: nir.Threshold(torch.tensor(0.1)),
+        sample_data=torch.rand((1,)),
+    )
+    assert g.nodes["output"].output_type["output"] == torch.Size([1])
+
+
+def test_sequential_flatten():
+    import nir
+    from nirtorch import extract_nir_graph
+
+    d = torch.empty(2, 3, 4)
+    g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d)
+    assert g.nodes["input"].input_type["input"] == (2, 3, 4)
+    assert g.nodes["output"].output_type["output"] == (2, 3 * 4)
diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt
index 589017c..351614b 100644
--- a/tests/test_requirements.txt
+++ b/tests/test_requirements.txt
@@ -2,3 +2,5 @@ pytest
 torch
 nir
 ruff
+git+https://github.com/norse/norse
+sinabs
\ No newline at end of file
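Reviewer note: the `module_forward_wrapper`/`GraphTracer` changes in `nirtorch/graph.py` record an output shape per module by temporarily overriding `nn.Module.__call__` and delegating to the saved original. A minimal standalone sketch of that pattern, assuming only `torch` (the helper name `make_recording_call` is illustrative, not part of this patch):

```python
import torch
import torch.nn as nn

# Keep a handle to the original __call__ so the wrapper can delegate to it,
# mirroring nirtorch's module-level _torch_module_call.
_original_call = nn.Module.__call__


def make_recording_call(output_types: dict):
    def recording_call(mod: nn.Module, *args, **kwargs):
        out = _original_call(mod, *args, **kwargs)
        # Record the output shape of every module that gets called
        if isinstance(out, torch.Tensor):
            output_types[mod] = out.shape
        elif isinstance(out, tuple) and isinstance(out[0], torch.Tensor):
            output_types[mod] = out[0].shape
        return out

    return recording_call


output_types = {}
nn.Module.__call__ = make_recording_call(output_types)
try:
    model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
    model(torch.rand(1, 4))
finally:
    # Always restore the original call, as GraphTracer.__exit__ does
    nn.Module.__call__ = _original_call

# Three entries: the Sequential container plus its two children
print({type(m).__name__: tuple(s) for m, s in output_types.items()})
```

The try/finally restore matters: if tracing raises, a dangling override would corrupt every later module call in the process.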
diff --git a/tests/test_to_nir.py b/tests/test_to_nir.py
index 99e7cb3..f0f0665 100644
--- a/tests/test_to_nir.py
+++ b/tests/test_to_nir.py
@@ -13,9 +13,9 @@ def test_extract_single():
         lambda x: nir.Affine(x.weight.detach().numpy(), x.bias.detach().numpy()),
         torch.rand(1, 1),
     )
-    assert g.edges == [("input", "model"), ("model", "output")]
+    assert set(g.edges) == {("input", "model"), ("model", "output")}
     assert isinstance(g.nodes["input"], nir.Input)
-    assert np.allclose(g.nodes["input"].shape, np.array([1, 1]))
+    assert np.allclose(g.nodes["input"].input_type["input"], np.array([1, 1]))
     assert isinstance(g.nodes["model"], nir.Affine)
     assert np.allclose(g.nodes["model"].weight, m.weight.detach().numpy())
     assert np.allclose(g.nodes["model"].bias, m.bias.detach().numpy())
@@ -40,7 +40,7 @@ def dummy_model_map(module: nn.Module) -> nir.NIRNode:

     assert len(nir_graph.nodes) == 8
     print(nir_graph.edges)
-    assert nir_graph.edges == [
+    assert set(nir_graph.edges) == {
         ("input", "0"),
         ("0", "1"),
         ("1", "2"),
         ("2", "3"),
         ("3", "4"),
         ("4", "5"),
         ("5", "output"),
-    ]
+    }


 class BranchedModel(nn.Module):
@@ -61,34 +61,53 @@ def forward(self, x):
         return x @ self.a @ self.b


-# def test_extract_multiple_explicit():
-#     model = nn.Sequential(BranchedModel(1, 2, 3), nn.Linear(3, 4))
-#
-#     def extractor(module: nn.Module):
-#         if isinstance(module, BranchedModel):
-#             return nir.NIRGraph(
-#                 nodes=[
-#                     nir.Input(np.array(module.a.shape[0])),
-#                     nir.Linear(module.a),
-#                     nir.Linear(module.b),
-#                     nir.Output(),
-#                 ],
-#                 edges=[(0, 1), (0, 2), (1, 3), (2, 3)],
-#             )
-#         else:
-#             return nir.Affine(module.weight, module.bias)
-#
-#     g = extract_nir_graph(model, extractor, torch.rand(1))
-#     print([type(n) for n in g.nodes])
-#     assert len(g.nodes) == 7
-#     assert len(g.edges) == 7
+def test_extract_multiple_explicit():
+    model = nn.Sequential(BranchedModel(1, 2, 3), nn.Linear(3, 4))
+
+    def extractor(module: nn.Module):
+        if isinstance(module, BranchedModel):
+            return nir.NIRGraph(
+                nodes={
+                    "0": nir.Input(np.array(module.a.shape[0])),
+                    "1": nir.Linear(module.a),
+                    "2": nir.Linear(module.b),
+                    "3": nir.Output(np.array(module.b.shape[0])),
+                },
+                edges=[("0", "1"), ("0", "2"), ("1", "2"), ("1", "3"), ("2", "3")],
+            )
+        else:
+            return nir.Affine(module.weight, module.bias)
+
+    g = extract_nir_graph(model, extractor, torch.rand(1))
+    print([type(n) for n in g.nodes])
+    assert len(g.nodes) == 7
+    assert len(g.edges) == 8  # in + 5 + 1 + out
+
+
+def test_extract_recursive():
+    class RecursiveModel(torch.nn.Module):
+        def forward(self, x, s=None):
+            if s is None:
+                s = torch.zeros_like(x)
+            return x + 1, s + x
+
+    model = RecursiveModel()
+
+    def extractor(m):
+        if isinstance(m, RecursiveModel):
+            return nir.Delay(np.array([1]))
+
+    g = extract_nir_graph(model, extractor, torch.rand(1))
+    assert set(g.edges) == {
+        ("input", "model"),
+        ("model", "output"),
+        # ("model", "model")  TODO: Detect and add recursive connections
+    }


-#
-#
 # def test_extract_stateful():
 #     model = norse.SequentialState(norse.LIFBoxCell(), nn.Linear(3, 1))
-#
+
 # def extract(module: torch.nn.Module):
 #     if isinstance(module, norse.LIFBoxCell):
 #         return nir.NIR(
 #         )
 #     elif isinstance(module, torch.nn.Linear):
 #         return nir.NIR(nodes=[nir.Linear(module.weight, module.bias)])
-#
+
 # graph = extract_nir_graph(model, extract, torch.rand(1, 3))
 # assert len(graph.nodes) == 4
 # assert len(graph.edges) == 3
 # assert graph.edges[0] == (0, 1)
-#
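Reviewer note: end-to-end, this patch means `extract_nir_graph` now emits an `output` node built from the traced `module_output_types` (previously `nir.Output(None)`) and deduplicates edges. A hedged usage sketch; the toy model and the Linear-to-`nir.Affine` mapping below are assumptions for illustration, not part of the patch:

```python
import nir
import torch
from nirtorch import extract_nir_graph

# Toy model; nn.Sequential containers are now skipped during tracing.
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))


def model_map(module: torch.nn.Module) -> nir.NIRNode:
    # Map each traced torch module to a NIR node (Linear -> Affine).
    return nir.Affine(module.weight.detach().numpy(), module.bias.detach().numpy())


g = extract_nir_graph(model, model_map, torch.rand(1, 4))

# With this patch the output node carries the traced output shape,
# and nir_edges contains no duplicates.
print(g.nodes["output"])
print(sorted(g.edges))
```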