From 271be98eca1e4f07dbb9e2417fb81388a16a7f84 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 30 Nov 2023 16:12:45 +0100 Subject: [PATCH] towards dyn conv --- core/src/ops/cnn/conv/unary.rs | 30 +++++++++++++++--------------- core/src/ops/einsum/codegen.rs | 6 +++--- core/src/ops/matmul/mir_quant.rs | 32 +++++++++++++++----------------- 3 files changed, 33 insertions(+), 35 deletions(-) diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs index 6cb2e58e70..f2c54dba4e 100644 --- a/core/src/ops/cnn/conv/unary.rs +++ b/core/src/ops/cnn/conv/unary.rs @@ -155,36 +155,36 @@ impl ConvUnary { use crate::ops::matmul::mir_quant as qmm; let c_dt = self.q_params.unwrap(); - let &[input, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires else { + let &[mut x, mut k0, mut k_scale, mut x0, x_scale, y0, y_scale] = wires else { bail!("Wrong number of inputs") }; - let kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?; - let kernel = wire_offset_u8_as_i8(model, name, kernel, "a", &mut a0, "a0")?; - let b = wire_offset_u8_as_i8(model, name, input, "b", &mut b0, "b0")?; + let mut kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?; + wire_offset_u8_as_i8(model, name, &mut kernel, "k", &mut k0)?; + wire_offset_u8_as_i8(model, name, &mut x, "x", &mut x0)?; let a_fact = model.outlet_fact(kernel)?.clone(); - let b_fact = model.outlet_fact(b)?.clone(); + let b_fact = model.outlet_fact(x)?.clone(); let (_, _, k, n, mmm) = self.compute_geo(&a_fact, &b_fact)?; let output_shape = self.pool_spec.output_shape(&b_fact.shape)?; - if !model.outlet_fact(a_scale)?.shape.volume().is_one() { + if !model.outlet_fact(k_scale)?.shape.volume().is_one() { // requant is performed before geo_reshape, so we need at most one geo axis to the // right if !output_shape.fmt.c_is_last() { - a_scale = model.wire_node( + k_scale = model.wire_node( format!("{name}.a_scale_axis_fix"), AxisOp::Add(1), - &[a_scale], + &[k_scale], )?[0]; } } - let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?; + let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?; let im2col = model.wire_node( format!("{name}.im2col"), Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?, - &[b, b0], + &[x, x0], )?[0]; let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?; @@ -218,9 +218,9 @@ impl ConvUnary { model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_b)?; } - let b_dt = model.outlet_fact(b)?.datum_type; + let x_dt = model.outlet_fact(x)?.datum_type; let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?; - let b_storage = unsafe { mmm.b_packed(b_dt.size_of(), k) }; + let b_storage = unsafe { mmm.b_packed(x_dt.size_of(), k) }; let wire = self.wire_mm_weights_bias( model, name, @@ -240,15 +240,15 @@ impl ConvUnary { name, wire[0], k.to_dim(), - a0, - b0, + k0, + x0, sum_a[0], sum_b[0], )?; let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?; let wire = self.wire_rm_n_if_needed(model, name, &wire)?; - let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, c0)?; + let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?; Self::wire_geo_reshape(model, name, &[wire], &output_shape) } diff --git a/core/src/ops/einsum/codegen.rs b/core/src/ops/einsum/codegen.rs index d84b860c54..b44ec13b21 100644 --- a/core/src/ops/einsum/codegen.rs +++ b/core/src/ops/einsum/codegen.rs @@ -224,7 +224,7 @@ fn dequant( let name = &node.name; let mut patch = TypedModelPatch::new("Dequantizing einsum"); let taps = patch.taps(model, &node.inputs)?; - let [a, b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else { + let [mut a, mut b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else { bail!("Expect exactly 9 inputs") }; @@ -240,8 +240,8 @@ fn dequant( } } - let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut a0, "a0")?; - let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut b0, "b0")?; + wire_offset_u8_as_i8(&mut patch, &node.name, &mut a, "a", &mut a0)?; + wire_offset_u8_as_i8(&mut patch, &node.name, &mut b, "b", &mut b0)?; let mut output = patch.wire_node( &node.name, diff --git a/core/src/ops/matmul/mir_quant.rs b/core/src/ops/matmul/mir_quant.rs index 562e2b8559..2bb5e8999c 100644 --- a/core/src/ops/matmul/mir_quant.rs +++ b/core/src/ops/matmul/mir_quant.rs @@ -10,47 +10,45 @@ use crate::ops::cast::cast; /// Only wires nodes of u8 type and leaves nodes of different type untouched. pub(crate) fn wire_offset_u8_as_i8( model: &mut TypedModel, - model_name: &str, - matrix: OutletId, - matrix_name: &str, + prefix: &str, + input: &mut OutletId, + input_name: &str, zero_point: &mut OutletId, - zero_point_name: &str, -) -> TractResult { - if let DatumType::U8 = model.outlet_fact(matrix)?.datum_type.unquantized() { +) -> TractResult<()> { + if let DatumType::U8 = model.outlet_fact(*input)?.datum_type.unquantized() { match model.outlet_fact(*zero_point)?.datum_type.unquantized() { DatumType::U8 => { *zero_point = model.wire_node( - format!("{model_name}.offset_{zero_point_name}_as_i8"), + format!("{prefix}.offset_{input_name}_zp_as_i8"), ops::quant::offset_u8_as_i8(), &[*zero_point], )?[0]; } DatumType::I32 | DatumType::I8 => { *zero_point = model.wire_node( - "{model_name}.{zero_point_name}.cast", + format!("{prefix}.{input_name}_zp.cast"), cast(i32::datum_type()), &[*zero_point], )?[0]; let cst = model.add_const( - format!("{model_name}.offset_{zero_point_name}_as_i8.min"), - tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())? + format!("{prefix}.offset_{input_name}_zp_as_i8.min"), + tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())?, )?; *zero_point = model.wire_node( - format!("{model_name}.offset_{zero_point_name}_as_i8"), + format!("{prefix}.offset_{input_name}_zp_as_i8"), ops::math::add(), &[*zero_point, cst], )?[0]; } _ => (), } - Ok(model.wire_node( - format!("{model_name}.offset_{matrix_name}_as_i8"), + *input = model.wire_node( + format!("{prefix}.offset_{input_name}_as_i8"), ops::quant::offset_u8_as_i8(), - &[matrix], - )?[0]) - } else { - Ok(matrix) + &[*input], + )?[0]; } + Ok(()) } pub(crate) fn combine_scales(