diff --git a/nirtorch/from_nir.py b/nirtorch/from_nir.py index 773ba11..e867129 100644 --- a/nirtorch/from_nir.py +++ b/nirtorch/from_nir.py @@ -116,7 +116,10 @@ def forward(self, data: torch.Tensor, state: Optional[Dict[str, Any]] = {}): 0 ) # Multiple inputs are summed outs[node.name] = self._apply_module(node, input_data, state) - return outs[node.name] + if len(state) > 0: + return outs[node.name], state + else: + return outs[node.name] def _mod_nir_to_graph(nir_graph: nir.NIRNode) -> Graph: diff --git a/tests/test_from_nir.py b/tests/test_from_nir.py index 717d7d6..aad54d0 100644 --- a/tests/test_from_nir.py +++ b/tests/test_from_nir.py @@ -79,5 +79,7 @@ def forward(self, x, state=None): edges=[("li", "li2")], ) # Mock node m = load(g, lambda m: StatefulModel()) - out = m(torch.ones(10)) + out, state = m(torch.ones(10)) assert torch.allclose(out, torch.ones(10) * 3) + assert state["li"] == (1, ) + assert state["li"] == (1, )