Skip to content

Commit

Permalink
fix optim issue
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 12, 2024
1 parent 8e87f1a commit 60ad0c8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ dyn-clone = "1.0.4"
env_logger = "0.10"
flatbuffers = "23.1.21"
flate2 = "1.0.20"
fs-err = "2"
fs2 = "0.4.3"
getrandom = "0.2"
half = { version="2.2.1", features = [ "std", "num-traits" ] }
Expand Down
16 changes: 13 additions & 3 deletions core/src/optim/push_split_down.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::internal::*;

use tract_itertools::Itertools;
use crate::optim::OptimizerSession;
use tract_itertools::Itertools;

#[derive(Clone, Debug)]
pub struct PushSplitDown;
Expand All @@ -10,16 +10,26 @@ impl super::TypedPass for PushSplitDown {
fn reset(&mut self) -> TractResult<()> {
Ok(())
}
fn next(&mut self, _session: &mut OptimizerSession, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
let mut patch = TypedModelPatch::default();
for node in model.eval_order()? {
for output in &model.node(node).outputs {
for (a, b) in output.successors.iter().tuple_combinations() {
if a.node == b.node {
// found where a square is implemented using a mul with duplicate input
continue;
}
if patch.obliterate.contains(&b.node) {
continue;
}
// dont merge outputs.
if model.outputs.contains(&a.node.into()) && model.outputs.contains(&b.node.into()) {
if model.outputs.contains(&a.node.into())
&& model.outputs.contains(&b.node.into())
{
continue;
}
let a = model.node(a.node);
Expand Down

0 comments on commit 60ad0c8

Please sign in to comment.