Skip to content

Commit

Permalink
Added CI workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
Jegp committed Jul 5, 2023
1 parent 599eb46 commit ccba439
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 46 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Build

on: [pull_request]

jobs:
build_python:
name: Build on ${{ matrix.os }}

strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
pip install ruff pytest sinabs norse
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
ruff --format=github --select=E9,F63,F7,F82 --target-version=py37 .
# default set of ruff rules with GitHub Annotations
ruff --format=github --target-version=py37 .
- name: Test with pytest
run: |
pytest
19 changes: 19 additions & 0 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
on:
release:

jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/nirtorch
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Build Python package
run: pip install build && python -m build
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Byte-compiled / optimized / DLL files
bin/
__pycache__/
*.py[cod]
*$py.class
Expand Down Expand Up @@ -160,4 +161,4 @@ cython_debug/
#.idea/

# VSCode
.vscode
.vscode
4 changes: 2 additions & 2 deletions nirtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .to_nir import extract_nir_graph
from .graph import extract_torch_graph
from .to_nir import extract_nir_graph # noqa F401
from .to_nir import extract_nir_graph # noqa F401
25 changes: 17 additions & 8 deletions nirtorch/from_nir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.nn as nn
import nir
from typing import Callable
Expand All @@ -16,26 +15,36 @@ def instantiate_modules(self, graph: Graph):
def forward(self, x):
raise NotImplementedError()

def _convert_number_to_legal_variable_name(num: int)->str:

def _convert_number_to_legal_variable_name(num: int) -> str:
return f"mod_{num}"

def _mod_nir_to_graph(nir_graph: nir.NIR)->Graph:
module_names = {module: _convert_number_to_legal_variable_name(idx) for idx, module in enumerate(nir_graph.nodes)}

def _mod_nir_to_graph(nir_graph: nir.NIR) -> Graph:
module_names = {
module: _convert_number_to_legal_variable_name(idx)
for idx, module in enumerate(nir_graph.nodes)
}
graph = Graph(module_names=module_names)
for src, dst in nir_graph.edges:
graph.add_edge(src, dst)
return graph

def _switch_models_with_map(nir_graph: nir.NIR, model_map: Callable[[nn.Module], nn.Module])->nir.NIR:

def _switch_models_with_map(
nir_graph: nir.NIR, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIR:
nodes = [model_map(node) for node in nir_graph.nodes]
return nir.NIR(nodes, nir_graph.edges)

def load(nir_graph: nir.NIR, model_map: Callable[[nn.Module], nn.Module])->nn.Module:

def load(nir_graph: nir.NIR, model_map: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""Load a NIR object and convert it to a torch module using the given model map
Args:
nir_graph (nir.NIR): NIR object
model_map (Callable[[nn.Module], nn.Module]): A method that returns the a torch module that corresponds to each NIR node.
model_map (Callable[[nn.Module], nn.Module]): A method that returns the a torch
module that corresponds to each NIR node.
Returns:
nn.Module: The generated torch module
Expand All @@ -46,4 +55,4 @@ def load(nir_graph: nir.NIR, model_map: Callable[[nn.Module], nn.Module])->nn.Mo
graph = _mod_nir_to_graph(nir_module_graph)
# Build and return ExtractedModule
model = ExtractedModel(graph)
return model
return model
21 changes: 14 additions & 7 deletions nirtorch/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
import torch.nn as nn
Expand All @@ -12,7 +12,8 @@ def named_modules_map(
Args:
model (nn.Module): The module to be hashed
model_name (str | None): Name of the top level module. If this doesn't need to be include, this option can be set to None
model_name (str | None): Name of the top level module. If this doesn't need
to be include, this option can be set to None
Returns:
Dict[str, nn.Module]: A dictionary with modules as keys, and names as values
Expand Down Expand Up @@ -130,6 +131,9 @@ def add_edge(
if self._is_mod_and_not_in_module_names(destination):
return

if source is None or destination is None:
return # Stateful models may have Nones

source_node = self.add_or_get_node_for_elem(source)
destination_node = self.add_or_get_node_for_elem(destination)
source_node.add_outgoing(destination_node)
Expand Down Expand Up @@ -268,9 +272,11 @@ def ignore_nodes(self, class_type: Type) -> "Graph":
# Directly add an edge from source to destination
for source_node in source_node_list:
graph.add_edge(source_node.elem, outgoing_node.elem)
# NOTE: Assuming that the destination is not of the same type here
# NOTE: Assuming that the destination is not of the same
# type here
else:
# This is to preserve the graph if executed on a graph that is already filtered
# This is to preserve the graph if executed on a graph that is
# already filtered
for outnode in node.outgoing_nodes:
if not isinstance(outnode.elem, class_type):
graph.add_edge(node.elem, outnode.elem)
Expand Down Expand Up @@ -301,7 +307,7 @@ def my_forward(mod: nn.Module, *args, **kwargs) -> Any:
model_graph.add_edge(input_data, mod)
out = _torch_module_call(mod, *args, **kwargs)
if isinstance(out, tuple):
out_tuple
out_tuple = (out[0],)
elif isinstance(out, torch.Tensor):
out_tuple = (out,)
else:
Expand Down Expand Up @@ -347,7 +353,8 @@ def extract_torch_graph(
model: nn.Module, sample_data: Any, model_name: Optional[str] = "model"
) -> Graph:
"""Extract computational graph between various modules in the model
NOTE: This method is not capable of any compute happening outside of module definitions.
NOTE: This method is not capable of any compute happening outside of module
definitions.
Args:
model (nn.Module): The module to be analysed
Expand All @@ -363,6 +370,6 @@ def extract_torch_graph(
with GraphTracer(
named_modules_map(model, model_name=model_name)
) as tracer, torch.no_grad():
out = model(sample_data)
_ = model(sample_data)

return tracer.graph
35 changes: 19 additions & 16 deletions nirtorch/to_nir.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import torch.nn as nn
from torch import Tensor
from typing import Any, Optional, Callable, List
from typing import Any, Optional, Callable, List, Union
import nir
from .graph import extract_torch_graph


def extract_nir_graph(
model: nn.Module,
model_map: Callable[[nn.Module], nir.NIRNode],
model_map: Callable[[nn.Module], Union[nir.NIRNode, List[nir.NIRNode]]],
sample_data: Any,
model_name: Optional[str] = "model",
) -> nir.NIR:
"""Given a `model`, generate an NIR representation using the specified `model_map`.
Args:
model (nn.Module): The model of interest
model_map (Callable[[nn.Module], nir.NIRNode]): A method that converts a given module type to an NIRNode type
model_map (Callable[[nn.Module], nir.NIRNode]): A method that converts a given
module type to an NIRNode type
sample_data (Any): Sample input data to be used for model extraction
model_name (Optional[str], optional): The name of the top level module. Defaults to "model".
model_name (Optional[str], optional): The name of the top level module.
Defaults to "model".
Returns:
nir.NIR: Returns the generated NIR graph representation.
Expand All @@ -27,7 +29,9 @@ def extract_nir_graph(
model_name = None

# Extract a torch graph given the model
torch_graph = extract_torch_graph(model, sample_data=sample_data, model_name=model_name).ignore_tensors()
torch_graph = extract_torch_graph(
model, sample_data=sample_data, model_name=model_name
).ignore_tensors()

# Convert the nodes and get indices
edges = []
Expand All @@ -37,21 +41,20 @@ def extract_nir_graph(
# Get all the NIR nodes
for indx, node in enumerate(torch_graph.node_list):
# Convert the node type
nir_nodes.append(model_map(node.elem))
mapped_node = model_map(node.elem)
if isinstance(mapped_node, list): # Node maps to multiple nodes
nir_nodes.extend(mapped_node)
# Add edges sequentially between the nodes
for n_idx in range(len(mapped_node[:-1])):
edges.append((indx + n_idx, indx + n_idx + 1))
elif isinstance(mapped_node, nir.NIRNode):
nir_nodes.append(mapped_node)

indices[node] = indx

# Get all the edges
for node in torch_graph.node_list:
for destination in node.outgoing_nodes:
edges.append((indices[node], indices[destination]))

print(indices)
print(nir_nodes)
print(edges)

return nir.NIR(nir_nodes, edges)





34 changes: 30 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import torch.nn as nn
import torch
from sinabs.layers import Merge
from norse.torch import LIFCell, SequentialState
import pytest


class TupleModule(torch.nn.Module):
def forward(self, data):
return (data, data)


def test_sequential_graph_extract():
from nirtorch.graph import extract_torch_graph

Expand Down Expand Up @@ -49,6 +55,18 @@ def forward(self, data):
return out4


class MyStatefulModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu1 = nn.ReLU()
self.lif = SequentialState(LIFCell())

def forward(self, data):
out1 = self.relu1(data)
out2, _ = self.lif(out1)
return out2


input_shape = (2, 28, 28)
batch_size = 1

Expand Down Expand Up @@ -96,7 +114,7 @@ def test_module_forward_wrapper():
nn.Module.__call__ = new_call

with torch.no_grad():
out = mymodel(data)
_ = mymodel(data)

# Restore normal behavior
nn.Module.__call__ = orig_call
Expand All @@ -111,7 +129,7 @@ def test_graph_tracer():
from nirtorch.graph import GraphTracer, named_modules_map

with GraphTracer(named_modules_map(my_branched_model)) as tracer, torch.no_grad():
out = my_branched_model(data)
_ = my_branched_model(data)

print(tracer.graph)
assert (
Expand All @@ -123,7 +141,7 @@ def test_leaf_only_graph():
from nirtorch.graph import GraphTracer, named_modules_map

with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad():
out = mydeepmodel(data)
_ = mydeepmodel(data)

print(tracer.graph)

Expand All @@ -139,7 +157,7 @@ def test_ignore_submodules_of():
from nirtorch.graph import GraphTracer, named_modules_map

with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad():
out = mydeepmodel(data)
_ = mydeepmodel(data)

top_overview_graph = tracer.graph.ignore_submodules_of(
[MyBranchedModel]
Expand Down Expand Up @@ -197,6 +215,14 @@ def forward(self, spikes):
assert len(graph.node_list) == 27 # 2*13 + 1


def test_snn_stateful():
from nirtorch.graph import extract_torch_graph

model = MyStatefulModel()
graph = extract_torch_graph(model, sample_data=torch.rand((1, 2, 3, 4)))
assert len(graph.node_list) == 7 # 2 + 1 nested + 4 tensors


def test_ignore_tensors():
from nirtorch.graph import extract_torch_graph

Expand Down
Loading

0 comments on commit ccba439

Please sign in to comment.