diff --git a/nirtorch/to_nir.py b/nirtorch/to_nir.py index a6ef230..ac4c635 100644 --- a/nirtorch/to_nir.py +++ b/nirtorch/to_nir.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Sequence import nir import numpy as np @@ -12,7 +12,8 @@ def extract_nir_graph( model_map: Callable[[nn.Module], nir.NIRNode], sample_data: Any, model_name: Optional[str] = "model", - ignore_submodules_of=None + ignore_submodules_of: Optional[Sequence[nn.Module]] = None, + ignore_dims: Optional[Sequence[int]] = None, ) -> nir.NIRNode: """Given a `model`, generate an NIR representation using the specified `model_map`. @@ -23,7 +24,11 @@ def extract_nir_graph( sample_data (Any): Sample input data to be used for model extraction model_name (Optional[str], optional): The name of the top level module. Defaults to "model". - + ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified, + the corresponding module's children will not be traversed for graph. + ignore_dims (Optional[Sequence[int]]): Dimensions of data to be ignored for + type/shape inference. Typically the dimensions that you will want to ignore + are for batch and time. Returns: nir.NIR: Returns the generated NIR graph representation. """ @@ -49,7 +54,12 @@ def extract_nir_graph( # Convert the nodes and get indices nir_edges = [] - nir_nodes = {"input": nir.Input(np.array(sample_data.shape))} + input_shape = np.array(sample_data.shape) + if ignore_dims: + nir_nodes = {"input": nir.Input(np.delete(input_shape, ignore_dims))} + else: + nir_nodes = {"input": nir.Input(input_shape)} + nir_edges = [] # Get all the NIR nodes for indx, node in enumerate(torch_graph.node_list): @@ -85,7 +95,13 @@ def extract_nir_graph( if len(node.outgoing_nodes) == 0: out_name = "output" # Try to find shape of input to the Output node - output_node = nir.Output(torch_graph.module_output_types[node.elem]) + if ignore_dims: + out_shape = np.delete( + torch_graph.module_output_types[node.elem], ignore_dims + ) + else: + out_shape = torch_graph.module_output_types[node.elem] + output_node = nir.Output(out_shape) nir_nodes[out_name] = output_node nir_edges.append((node.name, out_name)) diff --git a/tests/test_conversion.py b/tests/test_conversion.py index 4e2f0db..bcbd527 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -7,7 +7,7 @@ def _torch_convert(module: nn.Module) -> nir.NIRNode: if isinstance(module, nn.Conv1d): - return nir.Conv1d(module.weight, 1, 1, 1, 1, module.bias) + return nir.Conv1d(None, module.weight, 1, 1, 1, 1, module.bias) elif isinstance(module, nn.Linear): return nir.Affine(module.weight, module.bias) else: diff --git a/tests/test_from_nir.py b/tests/test_from_nir.py index d6d09ae..55bfc01 100644 --- a/tests/test_from_nir.py +++ b/tests/test_from_nir.py @@ -17,7 +17,7 @@ def _torch_model_map(m: nir.NIRNode, device: str = "cpu") -> torch.nn.Module: lin.weight.data = torch.nn.Parameter(torch.tensor(m.weight).to(device)) return lin elif isinstance(m, nir.Input) or isinstance(m, nir.Output): - return None + return torch.nn.Identity() else: raise NotImplementedError(f"Unsupported module {m}") @@ -49,6 +49,7 @@ def test_extract_lin(): assert torch.allclose(m(x), y) +@pytest.mark.skip("Not yet supported") def test_extrac_recurrent(): w = np.random.randn(1, 1) g = nir.NIRGraph( diff --git a/tests/test_graph.py b/tests/test_graph.py index c99e1ed..60c9189 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -277,10 +277,15 @@ def test_output_type_when_single_node(): def test_sequential_flatten(): - d = torch.empty(2, 3, 4) + d = torch.empty(3, 4) g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d) - g.nodes["input"].input_type["input"] == (2, 3, 4) - g.nodes["output"].output_type["output"] == (2, 3 * 4) + assert tuple(g.nodes["input"].input_type["input"]) == (3, 4) + assert tuple(g.nodes["output"].output_type["output"]) == (3 * 4,) + + d = torch.empty(2, 3, 4) + g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0]) + assert tuple(g.nodes["input"].input_type["input"]) == (3, 4) + assert tuple(g.nodes["output"].output_type["output"]) == (3 * 4,) @pytest.mark.skip(reason="Not supported yet") diff --git a/tests/test_to_nir.py b/tests/test_to_nir.py index f0f0665..26fb0a1 100644 --- a/tests/test_to_nir.py +++ b/tests/test_to_nir.py @@ -105,6 +105,38 @@ def extractor(m): } +def test_ignore_batch_dim(): + model = nn.Linear(3, 1) + + def extractor(module: nn.Module): + return nir.Affine(module.weight, module.bias) + + raw_input_shape = (1, 3) + g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0]) + exp_input_shape = (3,) + assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape)) + assert g.nodes["model"].weight.shape == (1, 3) + assert np.alltrue(g.nodes["output"].output_type["output"] == np.array([1])) + + +def test_ignore_time_and_batch_dim(): + model = nn.Linear(3, 1) + + def extractor(module: nn.Module): + return nir.Affine(module.weight, module.bias) + + raw_input_shape = (1, 10, 3) + g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2]) + exp_input_shape = (3,) + assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape)) + assert g.nodes["model"].weight.shape == (1, 3) + + raw_input_shape = (1, 10, 3) + g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1]) + exp_input_shape = (3,) + assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape)) + + # def test_extract_stateful(): # model = norse.SequentialState(norse.LIFBoxCell(), nn.Linear(3, 1))