Skip to content

Commit

Permalink
Enable Sub optim at codegen using SubF
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Nov 28, 2024
1 parent 90a583c commit 7001755
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
18 changes: 16 additions & 2 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use downcast_rs::Downcast;
use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::LinalgFn;
use tract_linalg::{LinalgFn, BinOp};
use crate::ndarray::Dimension;

use super::cast::cast;
use super::{cast::cast, math::SubF};

pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
fn name(&self) -> &'static str;
Expand Down Expand Up @@ -308,6 +308,20 @@ fn declutter_broadcasting_operand_1(
)?));
}

// Special case for sub
let is_sub = mini_op.as_linalg_binop().map_or(false, |it| it == BinOp::Sub);
if a_should_be_broadcast & is_sub {
let subf_mini_op = Box::new(SubF {});
let mut swap_input = node.inputs.clone();
swap_input.swap(0, 1);
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&swap_input,
TypedBinOp(subf_mini_op, None),
)?));
}

Ok(None)
}

Expand Down
17 changes: 17 additions & 0 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ bin_to_super_type!(sub, Sub,
q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);

bin_to_super_type!(subf, SubF,
linalg:SubF,
is_commutative: false,
neutral_element: 0,
q: [i8, u8, i32, i32] => subf_quant;
q_op_on_f32: |a: f32, b: f32| -> f32 {b - a},
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = b.clone() - a);

fn sub_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
where
T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
Expand All @@ -51,6 +59,15 @@ where
*c = (a.as_() - b.as_() + zp as i16).clamp_cast()
}

fn subf_quant<T>(c: &mut T, a: &T, b: &T, zp: i32, _: f32)
where
T: PrimInt + Bounded + AsPrimitive<i16> + Datum,
i16: AsPrimitive<T>,
{
*c = (b.as_() - a.as_() + zp as i16).clamp_cast()
}


bin_to_super_type!(mul, Mul,
cost: |dt| tvec!((Cost::FMA(dt), 1)),
declutter: declutter_mul,
Expand Down

0 comments on commit 7001755

Please sign in to comment.