Commit
Added shapes to graph edges (#7)
* Added shapes to graph edges
* Updated to latest changes in NIR
* Renamed shape -> type
* Added norse and sinabs to test requirements
* Bumped version number to 0.2.1

---------

Co-authored-by: Sadique Sheik <[email protected]>
Jegp and sheiksadique authored Sep 12, 2023
1 parent 6b3ba5e commit 7b1976f
Showing 10 changed files with 171 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -22,9 +22,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r requirements.txt
python -m pip install -r tests/test_requirements.txt
python -m pip install .
- name: Lint with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
2 changes: 2 additions & 0 deletions nirtorch/__init__.py
@@ -1,2 +1,4 @@
from .from_nir import load # noqa F401
from .to_nir import extract_nir_graph # noqa F401

__version__ = version = "0.2.1"
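The package now re-exports both entry points at the top level alongside the version string; the resulting import surface, in brief:

    from nirtorch import load, extract_nir_graph, __version__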
87 changes: 59 additions & 28 deletions nirtorch/graph.py
@@ -21,9 +21,13 @@ def named_modules_map(
"""
modules_map = {}
for name, mod in model.named_modules():
# Ignore sequential modules
if isinstance(mod, nn.Sequential):
continue
modules_map[mod] = name
if model_name is None:
del modules_map[model]
if model in modules_map:
del modules_map[model]
else:
modules_map[model] = model_name
return modules_map
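For orientation, a minimal sketch of what named_modules_map yields for a small container (the output comment is an assumption; exact names come from model.named_modules()):

    import torch.nn as nn
    from nirtorch.graph import named_modules_map

    model = nn.Sequential(nn.Linear(2, 3), nn.ReLU())
    modules = named_modules_map(model)
    # Sequential containers are skipped in the loop, so the map holds the
    # leaves, plus the model itself under model_name when model_name is
    # not None: {Linear(...): "0", ReLU(): "1", model: "model"}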
@@ -34,20 +38,20 @@ def __init__(
self,
elem: Any,
name: str,
outgoing_nodes: Optional[List["Node"]] = None,
outgoing_nodes: Optional[Dict["Node", torch.Tensor]] = None,
) -> None:
self.elem = elem
self.name = name
if not outgoing_nodes:
self.outgoing_nodes = []
self.outgoing_nodes = {}
else:
self.outgoing_nodes = outgoing_nodes

def add_outgoing(self, node: "Node"):
self.outgoing_nodes.append(node)
def add_outgoing(self, node: "Node", shape=None) -> None:
self.outgoing_nodes[node] = shape

def __str__(self) -> str:
return f"Node: {self.name}, Out: {len(self.outgoing_nodes)}"
return f"Node: {self.name} ({type(self.elem)}), Out: {len(self.outgoing_nodes)}"

def __eq__(self, other: Any) -> bool:
# Two nodes are meant to be the same if they refer to the same element
@@ -62,9 +66,14 @@ def __hash__(self):


class Graph:
def __init__(self, module_names: Dict[nn.Module, str]) -> None:
def __init__(
self,
module_names: Dict[nn.Module, str],
module_output_types: Dict[nn.Module, torch.Tensor] = {},
) -> None:
self.module_names = module_names
self.node_list: List[Node] = []
self.module_output_types = module_output_types
self._last_used_tensor_id = None
# Add modules to node_list
for mod, name in self.module_names.items():
@@ -129,6 +138,7 @@ def add_edge(
self,
source: Union[torch.Tensor, nn.Module],
destination: Union[torch.Tensor, nn.Module],
shape: torch.Tensor = None,
):
if self._is_mod_and_not_in_module_names(source):
return
@@ -140,7 +150,7 @@

source_node = self.add_or_get_node_for_elem(source)
destination_node = self.add_or_get_node_for_elem(destination)
source_node.add_outgoing(destination_node)
source_node.add_outgoing(destination_node, shape)
return source_node, destination_node

def get_leaf_modules(self) -> Dict[nn.Module, str]:
@@ -174,9 +184,10 @@ def _is_mod_and_not_in_module_names(self, elem: Any) -> bool:
return False

def populate_from(self, other_graph: "Graph"):
self.module_output_types.update(other_graph.module_output_types)
for node in other_graph.node_list:
for outgoing_node in node.outgoing_nodes:
self.add_edge(node.elem, outgoing_node.elem)
for outgoing_node, shape in node.outgoing_nodes.items():
self.add_edge(node.elem, outgoing_node.elem, shape)

def __str__(self) -> str:
return self.to_md()
@@ -188,7 +199,7 @@ def to_md(self) -> str:
"""
for node in self.node_list:
if node.outgoing_nodes:
for outgoing in node.outgoing_nodes:
for outgoing, _ in node.outgoing_nodes.items():
mermaid_md += f"{node.name} --> {outgoing.name};\n"
else:
mermaid_md += f"{node.name};\n"
@@ -226,7 +237,7 @@ def ignore_submodules_of(self, classes: List[Type]) -> "Graph":
if mod not in sub_modules_to_ignore:
new_named_modules[mod] = name
# Create a new graph with the allowed modules
new_graph = Graph(new_named_modules)
new_graph = Graph(new_named_modules, self.module_output_types)
new_graph.populate_from(self)
return new_graph

@@ -241,7 +252,7 @@ def find_source_nodes_of(self, node: Node) -> List[Node]:
"""
source_node_list = []
for source_node in self.node_list:
for outnode in source_node.outgoing_nodes:
for outnode, shape in source_node.outgoing_nodes.items():
if node == outnode:
source_node_list.append(source_node)
return source_node_list
@@ -261,7 +272,7 @@ def ignore_nodes(self, class_type: Type) -> "Graph":
}

# Generate the new graph with the filtered module names
graph = Graph(new_module_names)
graph = Graph(new_module_names, self.module_output_types)
# Iterate over all the nodes
for node in self.node_list:
if isinstance(node.elem, class_type):
@@ -273,18 +284,18 @@ def ignore_nodes(self, class_type: Type) -> "Graph":
# Get all of its destinations
if node.outgoing_nodes:
# If no destinations, it is a leaf node, just drop it.
for outgoing_node in node.outgoing_nodes:
for outgoing_node, shape in node.outgoing_nodes.items():
# Directly add an edge from source to destination
for source_node in source_node_list:
graph.add_edge(source_node.elem, outgoing_node.elem)
graph.add_edge(source_node.elem, outgoing_node.elem, shape)
# NOTE: Assuming that the destination is not of the same
# type here
else:
# This is to preserve the graph if executed on a graph that is
# already filtered
for outnode in node.outgoing_nodes:
for outnode, shape in node.outgoing_nodes.items():
if not isinstance(outnode.elem, class_type):
graph.add_edge(node.elem, outnode.elem)
graph.add_edge(node.elem, outnode.elem, shape)
return graph

def get_root(self) -> List[Node]:
@@ -296,31 +307,47 @@ def get_root(self) -> List[Node]:
roots = []
for node in self.node_list:
sources = self.find_source_nodes_of(node)
if len(sources) == 0:
# Append root node if it has no sources (and it isn't a sequential module)
if len(sources) == 0 and not isinstance(node.elem, torch.nn.Sequential):
roots.append(node)
return roots


_torch_module_call = torch.nn.Module.__call__


def module_forward_wrapper(model_graph: Graph) -> Callable[..., Any]:
def module_forward_wrapper(
model_graph: Graph, output_types: Dict[nn.Module, torch.Tensor]
) -> Callable[..., Any]:
def my_forward(mod: nn.Module, *args, **kwargs) -> Any:
# Iterate over all inputs
for i, input_data in enumerate(args):
# Create nodes and edges
model_graph.add_edge(input_data, mod)
out = _torch_module_call(mod, *args, **kwargs)

if isinstance(out, tuple):
out_tuple = (out[0],)
output_types[mod] = out[0].shape
elif isinstance(out, torch.Tensor):
out_tuple = (out,)
output_types[mod] = out.shape
else:
raise Exception("Unknown output format")

# Iterate over all inputs
for i, input_data in enumerate(args):
# Create nodes and edges
model_graph.add_edge(
input_data,
mod,
input_data.shape if isinstance(input_data, torch.Tensor) else None,
)

# Iterate over all outputs and create nodes and edges
for output_data in out_tuple:
# Create nodes and edges
model_graph.add_edge(mod, output_data)
model_graph.add_edge(
mod,
output_data,
output_data.shape if isinstance(output_data, torch.Tensor) else None,
)
return out

return my_forward
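The wrapper hinges on temporarily swapping out nn.Module.__call__ and restoring it afterwards; the pattern in miniature (a sketch, not the library code — model and sample_data are placeholders):

    import torch
    import torch.nn as nn

    _orig_call = nn.Module.__call__

    def _recording_call(mod, *args, **kwargs):
        out = _orig_call(mod, *args, **kwargs)
        if isinstance(out, torch.Tensor):
            # Record the output shape, as the wrapper above does via output_types
            print(type(mod).__name__, tuple(out.shape))
        return out

    nn.Module.__call__ = _recording_call
    try:
        model(sample_data)
    finally:
        nn.Module.__call__ = _orig_call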
@@ -341,11 +368,12 @@ class GraphTracer:

def __init__(self, mod: nn.Module) -> None:
self.original_torch_call = nn.Module.__call__
self.graph = Graph(mod)
self.output_types = {}
self.graph = Graph(mod, self.output_types)

def __enter__(self) -> "GraphTracer":
# Override the torch call method
nn.Module.__call__ = module_forward_wrapper(self.graph)
nn.Module.__call__ = module_forward_wrapper(self.graph, self.output_types)
return self

def __exit__(self, exc_type, exc_value, exc_tb):
@@ -367,7 +395,10 @@ def extract_torch_graph(
If specified, it will be included in the graph.
If set to None, only its submodules will be listed in the graph.
Defaults to "model".
Returns:
Graph: A graph object representing the computational graph of the given model
"""
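Taken together, tracing now collects per-module output shapes alongside the edges. A hedged sketch of driving the tracer directly (presumably what extract_torch_graph does under the hood; model and sample_data are placeholders):

    with GraphTracer(named_modules_map(model)) as tracer:
        _ = model(sample_data)

    print(tracer.graph.to_md())  # mermaid edge list
    print(tracer.output_types)   # {module: torch.Size, ...}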
22 changes: 15 additions & 7 deletions nirtorch/to_nir.py
@@ -51,15 +51,18 @@ def extract_nir_graph(
for indx, node in enumerate(torch_graph.node_list):
# Convert the node type to NIR subgraph
mapped_node = model_map(node.elem)
print(mapped_node, node)

if isinstance(mapped_node, nir.NIRGraph):
for n in mapped_node.nodes:
nir_nodes[n.name] = n
for k, v in mapped_node.nodes.items():
# For now, we add nodes in subgraphs to the top-level node list
# TODO: Parse graphs recursively
if isinstance(v, nir.NIRNode):
nir_nodes[f"{node.name}.{k}"] = v
else:
nir_nodes[v.name] = v
# Add edges from graph
for x, y in mapped_node.edges:
print(x, y)
nir_edges.append(x, y)
nir_edges.append((f"{node.name}.{x}", f"{node.name}.{y}"))
else:
nir_nodes[node.name] = mapped_node

@@ -72,12 +75,17 @@

# Get all the edges
for node in torch_graph.node_list:
for destination in node.outgoing_nodes:
for destination, shape in node.outgoing_nodes.items():
nir_edges.append((node.name, destination.name))

if len(node.outgoing_nodes) == 0:
out_name = "output"
output_node = nir.Output(None)
# Try to find shape of input to the Output node
output_node = nir.Output(torch_graph.module_output_types[node.elem])
nir_nodes[out_name] = output_node
nir_edges.append((node.name, out_name))

# Remove duplicate edges
nir_edges = list(set(nir_edges))

return nir.NIRGraph(nir_nodes, nir_edges)
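End to end, a hypothetical model_map shows how the extractor is driven (the nir.Affine construction mirrors the tests below; everything else here is an assumption, not the library's prescribed mapping):

    import nir
    import torch
    import torch.nn as nn
    from nirtorch import extract_nir_graph

    def model_map(module: nn.Module) -> nir.NIRNode:
        if isinstance(module, nn.Linear):
            return nir.Affine(module.weight.detach(), module.bias.detach())
        raise NotImplementedError(type(module))

    graph = extract_nir_graph(nn.Linear(2, 3), model_map, torch.randn(1, 2))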
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -3,7 +3,6 @@ requires = ["setuptools"]

[project]
name = "nirtorch"
version = "0.2.0"
description = "Neuromorphic Intermediate Representation"
authors = [
{ name = "Steven Abreu", email = "[email protected]" },
@@ -29,3 +28,7 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [ "torch", "nir" ]
dynamic = ["version"] # Version number read from __init__.py

[tool.setuptools.dynamic]
version = {attr = "nirtorch.__version__"}
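With dynamic = ["version"], setuptools resolves the version from the nirtorch.__version__ attribute at build time, so the number lives in exactly one place (__init__.py, bumped above to 0.2.1). A quick check after install:

    import nirtorch
    print(nirtorch.__version__)  # 0.2.1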
Empty file added tests/__init__.py
8 changes: 4 additions & 4 deletions tests/test_from_nir.py
@@ -8,8 +8,8 @@
def _torch_model_map(m: nir.NIRNode, device: str = "cpu") -> torch.nn.Module:
if isinstance(m, nir.Affine):
lin = torch.nn.Linear(*m.weight.shape[-2:])
lin.weight.data = torch.tensor(m.weight, device=device)
lin.bias.data = torch.tensor(m.bias, device=device)
lin.weight.data = torch.nn.Parameter(m.weight.to(device))
lin.bias.data = torch.nn.Parameter(m.bias.to(device))
return lin
elif isinstance(m, nir.Input) or isinstance(m, nir.Output):
return None
@@ -27,8 +27,8 @@ def test_extract_lin():
x = torch.randn(1, 1)
lin = nir.Affine(x, torch.randn(1, 1))
torchlin = torch.nn.Linear(1, 1)
torchlin.weight.data = torch.tensor(lin.weight)
torchlin.bias.data = torch.tensor(lin.bias)
torchlin.weight.data = torch.nn.Parameter(lin.weight)
torchlin.bias.data = torch.nn.Parameter(lin.bias)
y = torchlin(torchlin(x))
g = nir.NIRGraph({"a": lin, "b": lin}, [("a", "b")])
m = load(g, _torch_model_map)
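The test fix avoids torch.tensor's copy-construction of an existing tensor (which emits a UserWarning) by wrapping the tensor instead; the idiom in isolation (a sketch):

    import torch

    lin = torch.nn.Linear(1, 1)
    w = torch.randn(1, 1)
    # torch.tensor(w) would copy and warn; re-binding the Parameter does not:
    lin.weight = torch.nn.Parameter(w)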