Skip to content

Commit

Permalink
unified management for u8->i8 strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 28, 2023
1 parent f6a0f71 commit 7a263b8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 38 deletions.
74 changes: 38 additions & 36 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,16 @@ impl ConvUnary {
use crate::ops::matmul::mir_quant as qmm;

let c_dt = self.q_params.unwrap();
let [a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires[1..] else {
let &[input, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires else {
bail!("Wrong number of inputs")
};
let b = wire_offset_u8_as_i8(model, name, wires[0], "b", &mut b0, "b0")?;
let kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?;
let kernel = wire_offset_u8_as_i8(model, name, kernel, "a", &mut a0, "a0")?;
let b = wire_offset_u8_as_i8(model, name, input, "b", &mut b0, "b0")?;

let a_fact = model.outlet_fact(kernel)?.clone();
let b_fact = model.outlet_fact(b)?.clone();
let (_, _, k, n, mmm) = self.compute_geo(&b_fact)?;
let (_, _, k, n, mmm) = self.compute_geo(&a_fact, &b_fact)?;
let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

if !model.outlet_fact(a_scale)?.shape.volume().is_one() {
Expand All @@ -183,7 +187,6 @@ impl ConvUnary {
&[b, b0],
)?[0];

let kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?;
let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
let g_o_ihw_as_i32 =
model.wire_node(format!("{name}.kernel_as_i32"), cast(i32::datum_type()), &g_o_ihw)?;
Expand Down Expand Up @@ -218,10 +221,11 @@ impl ConvUnary {
let b_dt = model.outlet_fact(b)?.datum_type;
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
let b_storage = unsafe { mmm.b_packed(b_dt.size_of(), k) };
let wire = self.wire_lir_matmatmul(
let wire = self.wire_mm_weights_bias(
model,
name,
&[im2col],
im2col,
g_o_ihw[0],
mmm,
i32::datum_type(),
mmm_output_shape.clone().into(),
Expand Down Expand Up @@ -275,11 +279,12 @@ impl ConvUnary {
name: &str,
wire: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let a_fact = TypedFact::shape_and_dt_of(&self.kernel);
let b_fact = model.outlet_fact(wire[0])?.clone();
let b_dt = b_fact.datum_type;
let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);

let (_, _, k, _, mmm) = self.compute_geo(model.outlet_fact(wire[0])?)?;
let (_, _, k, _, mmm) = self.compute_geo(&a_fact, &b_fact)?;
let geo_output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo_output_shape)?;

Expand All @@ -291,11 +296,16 @@ impl ConvUnary {
)?;

let b_storage = unsafe { mmm.b_packed(b_dt.size_of(), k) };

let kernel = model.add_const(format!("{name}.kernels"), self.kernel.clone())?;
let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;

let wire = self
.wire_lir_matmatmul(
.wire_mm_weights_bias(
model,
name,
&wire,
wire[0],
g_o_ihw[0],
mmm,
c_dt,
mmm_output_shape.clone().into(),
Expand Down Expand Up @@ -373,8 +383,9 @@ impl ConvUnary {
name: &str,
mut wire: OutletId,
) -> TractResult<TVec<OutletId>> {
let a_fact = TypedFact::shape_and_dt_of(&self.kernel);
let mut b_fact = model.outlet_fact(wire)?.clone();
let (geo, _, k, _, mmm) = self.compute_geo(&b_fact)?;
let (geo, _, k, _, mmm) = self.compute_geo(&a_fact, &b_fact)?;
let input_shape = b_fact.shape.as_concrete().unwrap().to_vec();
let mut geo = geo.to_concrete(&input_shape)?.into_owned();
let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
Expand Down Expand Up @@ -417,10 +428,14 @@ impl ConvUnary {
let b_storage = mmm.b_virtual_input(Box::new(virtual_input), k);
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;

let wire = self.wire_lir_matmatmul(
let kernel = model.add_const(format!("{name}.kernels"), self.kernel.clone())?;
let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;

let wire = self.wire_mm_weights_bias(
model,
name,
&[wire],
wire,
g_o_ihw[0],
mmm,
c_dt,
mmm_output_shape.clone().into(),
Expand All @@ -438,9 +453,10 @@ impl ConvUnary {
#[allow(clippy::type_complexity)]
fn compute_geo(
&self,
kernel_fact: &TypedFact,
input_fact: &TypedFact,
) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
let a_dt = self.kernel.datum_type();
let a_dt = kernel_fact.datum_type;
let b_dt = input_fact.datum_type;
let c_dt = crate::ops::matmul::output_type(b_dt);

Expand All @@ -461,11 +477,12 @@ impl ConvUnary {
}

#[allow(clippy::too_many_arguments)]
fn wire_lir_matmatmul(
fn wire_mm_weights_bias(
&self,
model: &mut TypedModel,
name: &str,
wire: &[OutletId],
input: OutletId,
g_o_ihw: OutletId,
mmm: Box<dyn MatMatMul>,
c_datum_type: DatumType,
mmm_output_shape: ShapeFact,
Expand All @@ -474,12 +491,10 @@ impl ConvUnary {
c_n_axis: usize,
b_storage: InputStoreSpec,
) -> TractResult<TVec<OutletId>> {
let kernel = model.add_const(format!("{name}.kernels"), self.kernel.clone())?;
let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
let kernels = self
.wire_pack_g_o_ihw(model, name, mmm.a_pack(), g_o_ihw[0])
let packed_ker = self
.wire_pack_g_o_ihw(model, name, mmm.a_pack(), g_o_ihw)
.context("in kernel_as_packed_as")?;
let a_dt = model.outlet_fact(kernels)?.datum_type;
let a_dt = model.outlet_fact(packed_ker)?.datum_type;
let a_storage = unsafe { mmm.a_packed(a_dt.size_of(), k) };
let (mut c_to_a_axis_mapping, mut c_to_b_axis_mapping) = (tvec!(), tvec!());

Expand All @@ -495,8 +510,7 @@ impl ConvUnary {
c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
};
let mut wires: TVec<OutletId> = wire.into();
wires.push(kernels);
let mut wires: TVec<OutletId> = tvec!(input, packed_ker);
let mut ops: Vec<ProtoFusedSpec> = vec![ProtoFusedSpec::AddMatMul(geo, 1, 0)];
if let Some(bias) = &self.bias {
let bias = model.add_const(format!("{name}.bias"), bias.clone())?;
Expand Down Expand Up @@ -789,16 +803,14 @@ impl EvalOp for ConvUnary {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let mut model = TypedModel::default();
let mut wire: TVec<OutletId> = inputs
let wire: TVec<OutletId> = inputs
.iter()
.enumerate()
.map(|(ix, v)| model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape())))
.collect::<TractResult<_>>()?;
let wire = unsafe {
if self.q_params.is_some() {
let new_op = self.kernel_offset_u8_as_i8(&mut wire, &mut model)?;
let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
op_ref.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
self.wire_as_quant_im2col(&mut model, "im2col-adhoc", &wire)?
} else {
self.wire_as_im2col_pair(&mut model, "im2col-adhoc", &wire)?
}
Expand Down Expand Up @@ -1050,16 +1062,6 @@ impl TypedOp for ConvUnary {
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let DatumType::U8 = self.kernel.datum_type().unquantized() {
let mut patch = TypedModelPatch::default();
let mut wire = patch.taps(model, &node.inputs)?;
let new_op = self.kernel_offset_u8_as_i8(&mut wire, &mut patch)?.unwrap();
let wire = patch.wire_node(&node.name, new_op, &wire)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
patch.obliterate(node.id)?;
return Ok(Some(patch.with_context("kernel-u8-to-i8")));
}

let input_fact = model.outlet_fact(node.inputs[0])?;
unsafe {
if self.q_params.is_some() {
Expand Down
4 changes: 2 additions & 2 deletions core/src/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,13 @@ where
F: Fact + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
// eprint!("{node} {input:?}");
// eprint!("{node} {input:?}");
let r = match state {
Some(ref mut state) => state.eval(session_state, node.op(), input),
None => node.op().eval(input),
}
.with_context(|| format!("Evaluating {node}"));
// eprintln!(" ==> {r:?}");
// eprintln!(" ==> {r:?}");
r
}

Expand Down

0 comments on commit 7a263b8

Please sign in to comment.