diff --git a/examples/onnx/lstm_large/input.json b/examples/onnx/lstm_large/input.json new file mode 100644 index 000000000..2e54c296c --- /dev/null +++ b/examples/onnx/lstm_large/input.json @@ -0,0 +1,13 @@ +{ + "input_data": [ + [ + 0.8894134163856506, + 0.8894201517105103 + ] + ], + "output_data": [ + [ + 0.8436377 + ] + ] +} \ No newline at end of file diff --git a/examples/onnx/lstm_large/network.onnx b/examples/onnx/lstm_large/network.onnx new file mode 100644 index 000000000..bff5a4fe3 Binary files /dev/null and b/examples/onnx/lstm_large/network.onnx differ diff --git a/src/graph/model.rs b/src/graph/model.rs index d79fd4431..ec39d71f2 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -803,13 +803,18 @@ impl Model { let input_state_idx = input_state_idx(&input_mappings); let mut output_mappings = vec![]; - for mapping in b.output_mapping.iter() { + for (i, mapping) in b.output_mapping.iter().enumerate() { let mut mappings = vec![]; if let Some(outlet) = mapping.last_value_slot { mappings.push(OutputMapping::Single { outlet, is_state: mapping.state, }); + } else if mapping.state { + mappings.push(OutputMapping::Single { + outlet: i, + is_state: mapping.state, + }); } if let Some(last) = mapping.scan { mappings.push(OutputMapping::Stacked { @@ -818,6 +823,7 @@ impl Model { is_state: false, }); } + output_mappings.push(mappings); } @@ -1264,8 +1270,8 @@ impl Model { let num_iter = number_of_iterations(&input_mappings, input_dims.collect()); debug!( - "{} iteration(s) in a subgraph with inputs {:?} and sources {:?}", - num_iter, inputs, model.graph.inputs + "{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}", + num_iter, inputs, model.graph.inputs, model.graph.outputs ); let mut full_results: Vec> = vec![]; @@ -1297,6 +1303,7 @@ impl Model { let res = model.layout_nodes(config, region, &mut subgraph_results)?; let mut outlets = BTreeMap::new(); + let mut stacked_outlets = BTreeMap::new(); for (mappings, outlet_res) in output_mappings.iter().zip(res) { for mapping in mappings { @@ -1309,25 +1316,42 @@ impl Model { let stacked_res = full_results[*outlet] .clone() .concat_axis(outlet_res.clone(), axis)?; - - outlets.insert(outlet, stacked_res); - } else { - outlets.insert(outlet, outlet_res.clone()); + stacked_outlets.insert(outlet, stacked_res); } + outlets.insert(outlet, outlet_res.clone()); } } } } - full_results = outlets.into_values().collect_vec(); + // now extend with stacked elements + let mut pre_stacked_outlets = outlets.clone(); + pre_stacked_outlets.extend(stacked_outlets); + + let outlets = outlets.into_values().collect_vec(); + + full_results = pre_stacked_outlets.into_values().collect_vec(); let output_states = output_state_idx(output_mappings); let input_states = input_state_idx(&input_mappings); - assert_eq!(input_states.len(), output_states.len()); + assert_eq!( + input_states.len(), + output_states.len(), + "input and output states must be the same length, got {:?} and {:?}", + input_mappings, + output_mappings + ); for (input_idx, output_idx) in input_states.iter().zip(output_states) { - values[*input_idx] = full_results[output_idx].clone(); + assert_eq!( + values[*input_idx].dims(), + outlets[output_idx].dims(), + "input and output dims must be the same, got {:?} and {:?}", + values[*input_idx].dims(), + outlets[output_idx].dims() + ); + values[*input_idx] = outlets[output_idx].clone(); } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index a1893eed9..7c1c1cd28 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -200,7 +200,7 @@ mod native_tests { "1l_tiny_div", ]; - const TESTS: [&str; 91] = [ + const TESTS: [&str; 92] = [ "1l_mlp", //0 "1l_slice", "1l_concat", @@ -296,6 +296,7 @@ mod native_tests { "reducel1", "reducel2", // 89 "1l_lppool", + "lstm_large", // 91 ]; const WASM_TESTS: [&str; 46] = [ @@ -534,7 +535,7 @@ mod native_tests { } }); - seq!(N in 0..=90 { + seq!(N in 0..=91 { #(#[test_case(TESTS[N])])* #[ignore]