Skip to content

Commit

Permalink
fix: hold stacked outputs in a separate map
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Apr 2, 2024
1 parent ff563e9 commit 32c3a5e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
13 changes: 13 additions & 0 deletions examples/onnx/lstm_large/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}
Binary file added examples/onnx/lstm_large/network.onnx
Binary file not shown.
44 changes: 34 additions & 10 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,13 +803,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);

let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
Expand All @@ -818,6 +823,7 @@ impl Model {
is_state: false,
});
}

output_mappings.push(mappings);
}

Expand Down Expand Up @@ -1264,8 +1270,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());

debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);

let mut full_results: Vec<ValTensor<Fp>> = vec![];
Expand Down Expand Up @@ -1297,6 +1303,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;

let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();

for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
Expand All @@ -1309,25 +1316,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;

outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}

full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);

let outlets = outlets.into_values().collect_vec();

full_results = pre_stacked_outlets.into_values().collect_vec();

let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);

assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);

for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}

Expand Down
5 changes: 3 additions & 2 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];

const TESTS: [&str; 91] = [
const TESTS: [&str; 92] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
Expand Down Expand Up @@ -296,6 +296,7 @@ mod native_tests {
"reducel1",
"reducel2", // 89
"1l_lppool",
"lstm_large", // 91
];

const WASM_TESTS: [&str; 46] = [
Expand Down Expand Up @@ -534,7 +535,7 @@ mod native_tests {
}
});

seq!(N in 0..=90 {
seq!(N in 0..=91 {

#(#[test_case(TESTS[N])])*
#[ignore]
Expand Down

0 comments on commit 32c3a5e

Please sign in to comment.