Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove batch from shape spec #18

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions nirtorch/to_nir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Sequence

import nir
import numpy as np
Expand All @@ -12,7 +12,8 @@ def extract_nir_graph(
model_map: Callable[[nn.Module], nir.NIRNode],
sample_data: Any,
model_name: Optional[str] = "model",
ignore_submodules_of=None
ignore_submodules_of: Optional[Sequence[nn.Module]] = None,
ignore_dims: Optional[Sequence[int]] = None,
) -> nir.NIRNode:
"""Given a `model`, generate an NIR representation using the specified `model_map`.

Expand All @@ -23,7 +24,11 @@ def extract_nir_graph(
sample_data (Any): Sample input data to be used for model extraction
model_name (Optional[str], optional): The name of the top level module.
Defaults to "model".

ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified,
the corresponding module's children will not be traversed for graph.
ignore_dims (Optional[Sequence[int]]): Dimensions of data to be ignored for
type/shape inference. Typically the dimensions that you will want to ignore
are for batch and time.
Returns:
nir.NIR: Returns the generated NIR graph representation.
"""
Expand All @@ -49,7 +54,12 @@ def extract_nir_graph(

# Convert the nodes and get indices
nir_edges = []
nir_nodes = {"input": nir.Input(np.array(sample_data.shape))}
input_shape = np.array(sample_data.shape)
if ignore_dims:
nir_nodes = {"input": nir.Input(np.delete(input_shape, ignore_dims))}
else:
nir_nodes = {"input": nir.Input(input_shape)}
nir_edges = []

# Get all the NIR nodes
for indx, node in enumerate(torch_graph.node_list):
Expand Down Expand Up @@ -85,7 +95,13 @@ def extract_nir_graph(
if len(node.outgoing_nodes) == 0:
out_name = "output"
# Try to find shape of input to the Output node
output_node = nir.Output(torch_graph.module_output_types[node.elem])
if ignore_dims:
out_shape = np.delete(
torch_graph.module_output_types[node.elem], ignore_dims
)
else:
out_shape = torch_graph.module_output_types[node.elem]
output_node = nir.Output(out_shape)
nir_nodes[out_name] = output_node
nir_edges.append((node.name, out_name))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def _torch_convert(module: nn.Module) -> nir.NIRNode:
if isinstance(module, nn.Conv1d):
return nir.Conv1d(module.weight, 1, 1, 1, 1, module.bias)
return nir.Conv1d(None, module.weight, 1, 1, 1, 1, module.bias)
elif isinstance(module, nn.Linear):
return nir.Affine(module.weight, module.bias)
else:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_from_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _torch_model_map(m: nir.NIRNode, device: str = "cpu") -> torch.nn.Module:
lin.weight.data = torch.nn.Parameter(torch.tensor(m.weight).to(device))
return lin
elif isinstance(m, nir.Input) or isinstance(m, nir.Output):
return None
return torch.nn.Identity()
Jegp marked this conversation as resolved.
Show resolved Hide resolved
else:
raise NotImplementedError(f"Unsupported module {m}")

Expand Down Expand Up @@ -49,6 +49,7 @@ def test_extract_lin():
assert torch.allclose(m(x), y)


@pytest.mark.skip("Not yet supported")
def test_extrac_recurrent():
w = np.random.randn(1, 1)
g = nir.NIRGraph(
Expand Down
11 changes: 8 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,15 @@ def test_output_type_when_single_node():


def test_sequential_flatten():
d = torch.empty(2, 3, 4)
d = torch.empty(3, 4)
g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d)
g.nodes["input"].input_type["input"] == (2, 3, 4)
g.nodes["output"].output_type["output"] == (2, 3 * 4)
assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
assert tuple(g.nodes["output"].output_type["output"]) == (3 * 4,)

d = torch.empty(2, 3, 4)
g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0])
assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
assert tuple(g.nodes["output"].output_type["output"]) == (3 * 4,)


@pytest.mark.skip(reason="Not supported yet")
Expand Down
32 changes: 32 additions & 0 deletions tests/test_to_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ def extractor(m):
}


def test_ignore_batch_dim():
model = nn.Linear(3, 1)

def extractor(module: nn.Module):
return nir.Affine(module.weight, module.bias)

raw_input_shape = (1, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
assert g.nodes["model"].weight.shape == (1, 3)
assert np.alltrue(g.nodes["output"].output_type["output"] == np.array([1]))


def test_ignore_time_and_batch_dim():
model = nn.Linear(3, 1)

def extractor(module: nn.Module):
return nir.Affine(module.weight, module.bias)

raw_input_shape = (1, 10, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
assert g.nodes["model"].weight.shape == (1, 3)

raw_input_shape = (1, 10, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))


# def test_extract_stateful():
# model = norse.SequentialState(norse.LIFBoxCell(), nn.Linear(3, 1))

Expand Down
Loading