Skip to content

Commit

Permalink
towards dyn conv
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 30, 2023
1 parent 119b083 commit 271be98
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 35 deletions.
30 changes: 15 additions & 15 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,36 +155,36 @@ impl ConvUnary {
use crate::ops::matmul::mir_quant as qmm;

let c_dt = self.q_params.unwrap();
let &[input, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = wires else {
let &[mut x, mut k0, mut k_scale, mut x0, x_scale, y0, y_scale] = wires else {
bail!("Wrong number of inputs")
};
let kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?;
let kernel = wire_offset_u8_as_i8(model, name, kernel, "a", &mut a0, "a0")?;
let b = wire_offset_u8_as_i8(model, name, input, "b", &mut b0, "b0")?;
let mut kernel = model.add_const(format!("{name}.kernel"), self.kernel.clone())?;
wire_offset_u8_as_i8(model, name, &mut kernel, "k", &mut k0)?;
wire_offset_u8_as_i8(model, name, &mut x, "x", &mut x0)?;

let a_fact = model.outlet_fact(kernel)?.clone();
let b_fact = model.outlet_fact(b)?.clone();
let b_fact = model.outlet_fact(x)?.clone();
let (_, _, k, n, mmm) = self.compute_geo(&a_fact, &b_fact)?;
let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

if !model.outlet_fact(a_scale)?.shape.volume().is_one() {
if !model.outlet_fact(k_scale)?.shape.volume().is_one() {
// requant is performed before geo_reshape, so we need at most one geo axis to the
// right
if !output_shape.fmt.c_is_last() {
a_scale = model.wire_node(
k_scale = model.wire_node(
format!("{name}.a_scale_axis_fix"),
AxisOp::Add(1),
&[a_scale],
&[k_scale],
)?[0];
}
}

let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;
let abc_scale = qmm::combine_scales(model, name, k_scale, x_scale, y_scale)?;

let im2col = model.wire_node(
format!("{name}.im2col"),
Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
&[b, b0],
&[x, x0],
)?[0];

let g_o_ihw = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
Expand Down Expand Up @@ -218,9 +218,9 @@ impl ConvUnary {
model.wire_node(format!("{name}.transpose_sum_b"), AxisOp::Move(3, 1), &sum_b)?;
}

let b_dt = model.outlet_fact(b)?.datum_type;
let x_dt = model.outlet_fact(x)?.datum_type;
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
let b_storage = unsafe { mmm.b_packed(b_dt.size_of(), k) };
let b_storage = unsafe { mmm.b_packed(x_dt.size_of(), k) };
let wire = self.wire_mm_weights_bias(
model,
name,
Expand All @@ -240,15 +240,15 @@ impl ConvUnary {
name,
wire[0],
k.to_dim(),
a0,
b0,
k0,
x0,
sum_a[0],
sum_b[0],
)?;

let wire = self.wire_remove_group(model, name, &[wire], &mmm_output_shape, c_axis)?;
let wire = self.wire_rm_n_if_needed(model, name, &wire)?;
let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, c0)?;
let wire = qmm::requant(model, name, wire[0], c_dt, abc_scale, y0)?;
Self::wire_geo_reshape(model, name, &[wire], &output_shape)
}

Expand Down
6 changes: 3 additions & 3 deletions core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ fn dequant(
let name = &node.name;
let mut patch = TypedModelPatch::new("Dequantizing einsum");
let taps = patch.taps(model, &node.inputs)?;
let [a, b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
let [mut a, mut b, bias, mut a0, mut a_scale, mut b0, b_scale, c0, c_scale] = *taps else {
bail!("Expect exactly 9 inputs")
};

Expand All @@ -240,8 +240,8 @@ fn dequant(
}
}

let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut a0, "a0")?;
let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut b0, "b0")?;
wire_offset_u8_as_i8(&mut patch, &node.name, &mut a, "a", &mut a0)?;
wire_offset_u8_as_i8(&mut patch, &node.name, &mut b, "b", &mut b0)?;

let mut output = patch.wire_node(
&node.name,
Expand Down
32 changes: 15 additions & 17 deletions core/src/ops/matmul/mir_quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,45 @@ use crate::ops::cast::cast;
/// Only wires nodes of u8 type and leaves nodes of different type untouched.
///
/// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so
/// every removed line is immediately followed by the line that replaces it.
/// The removed signature took `matrix: OutletId` by value and returned the
/// (possibly rewired) outlet as `TractResult<OutletId>`; the replacement takes
/// `input: &mut OutletId`, overwrites it in place and returns `TractResult<()>`.
///
/// Behaviour common to both versions, when the input's unquantized datum type
/// is `U8`:
/// * a `U8` zero point is rewired through `ops::quant::offset_u8_as_i8()`;
/// * an `I32` or `I8` zero point is cast to `i32`, then a `-128` constant
///   (broadcast to the zero point's rank) is added with `ops::math::add()`;
/// * a zero point of any other type is left as it is;
/// * finally the input itself is rewired through `ops::quant::offset_u8_as_i8()`.
///
/// For any other input type neither outlet is modified.
pub(crate) fn wire_offset_u8_as_i8(
model: &mut TypedModel,
// Removed parameters: model_name, matrix, matrix_name.
model_name: &str,
matrix: OutletId,
matrix_name: &str,
// Replacement parameters: prefix, input (now updated in place), input_name.
prefix: &str,
input: &mut OutletId,
input_name: &str,
zero_point: &mut OutletId,
// Removed: zero-point node names are now derived from `input_name` + "_zp".
zero_point_name: &str,
) -> TractResult<OutletId> {
if let DatumType::U8 = model.outlet_fact(matrix)?.datum_type.unquantized() {
) -> TractResult<()> {
if let DatumType::U8 = model.outlet_fact(*input)?.datum_type.unquantized() {
// Bring the zero point into the same offset domain as the input.
match model.outlet_fact(*zero_point)?.datum_type.unquantized() {
DatumType::U8 => {
*zero_point = model.wire_node(
format!("{model_name}.offset_{zero_point_name}_as_i8"),
format!("{prefix}.offset_{input_name}_zp_as_i8"),
ops::quant::offset_u8_as_i8(),
&[*zero_point],
)?[0];
}
DatumType::I32 | DatumType::I8 => {
*zero_point = model.wire_node(
// The removed line below is a plain string literal, not a `format!` call,
// so its `{...}` placeholders were never interpolated; the replacement
// line wraps the name in `format!`.
"{model_name}.{zero_point_name}.cast",
format!("{prefix}.{input_name}_zp.cast"),
cast(i32::datum_type()),
&[*zero_point],
)?[0];
// Constant -128, given the same rank as the (now i32) zero point.
let cst = model.add_const(
format!("{model_name}.offset_{zero_point_name}_as_i8.min"),
tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())?
format!("{prefix}.offset_{input_name}_zp_as_i8.min"),
tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())?,
)?;
*zero_point = model.wire_node(
format!("{model_name}.offset_{zero_point_name}_as_i8"),
format!("{prefix}.offset_{input_name}_zp_as_i8"),
ops::math::add(),
&[*zero_point, cst],
)?[0];
}
// Any other zero-point type is left untouched.
_ => (),
}
// Removed version returned the rewired outlet (or `matrix` unchanged in the
// `else` branch); the replacement assigns to `*input` and drops the `else`.
Ok(model.wire_node(
format!("{model_name}.offset_{matrix_name}_as_i8"),
*input = model.wire_node(
format!("{prefix}.offset_{input_name}_as_i8"),
ops::quant::offset_u8_as_i8(),
&[matrix],
)?[0])
} else {
Ok(matrix)
&[*input],
)?[0];
}
Ok(())
}

pub(crate) fn combine_scales(
Expand Down

0 comments on commit 271be98

Please sign in to comment.