Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor quant as cast #1343

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 162 additions & 1 deletion core/src/ops/cast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use tract_data::itertools::Itertools;

use crate::internal::*;
use crate::plan::eval;

pub fn cast(to: DatumType) -> Cast {
Cast { to }
Expand All @@ -24,7 +27,8 @@ impl Cast {
Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
}
} else {
Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
let out = input.cast_to_dt(self.to)?;
Ok(tvec!(out.into_owned().into_tvalue()))
}
}
}
Expand Down Expand Up @@ -102,5 +106,162 @@ impl TypedOp for Cast {
Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
}

/// Code generation hook for `Cast`.
///
/// When this cast takes a one-byte quantized input to a float type, delegate to
/// `codegen_quant_ew_chain_to_lut`, which may return a patch replacing the chain that
/// starts at this node. In every other case no patch is proposed.
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let input_dt = model.node_input_facts(node.id)?[0].datum_type;
    let is_byte_dequant =
        input_dt.is_quantized() && input_dt.size_of() == 1 && self.to.is_float();
    if !is_byte_dequant {
        return Ok(None);
    }
    codegen_quant_ew_chain_to_lut(self, model, node)
}

as_op!();
}

/// Tries to replace a `dequantize -> element-wise ops -> requantize` chain by a lookup table.
///
/// Starting from `origin` (the dequantizing cast), the model is followed through
/// single-successor links. As soon as a `Cast` to a quantized datum type is met, the sequence
/// going from the input of `origin` to that cast is handed to `transform_quant_seq_to_lut`.
/// The walk stops when a traversed op is not element-wise unary or when there is no single
/// successor left.
///
/// `original_dequant` is not used by the current implementation; it is kept so the call site
/// stays unchanged.
///
/// Returns `Ok(None)` when no patch can be proposed.
fn codegen_quant_ew_chain_to_lut(
    original_dequant: &Cast,
    model: &TypedModel,
    origin: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let incoming_dt = model.node_input_facts(origin.id)?[0].datum_type;
    // The lookup table is only built for one-byte quantized inputs. This does not depend on the
    // position in the chain, so decide once instead of re-checking at every step.
    if !(incoming_dt.is_quantized() && incoming_dt.size_of() == 1) {
        return Ok(None);
    }

    let mut current = origin;
    while let Some(next) = model.single_succ(current.id)? {
        let requantizes = next.op_as::<Cast>().map_or(false, |c| c.to.is_quantized());
        if requantizes {
            return Ok(Some(
                transform_quant_seq_to_lut(model, origin.inputs[0], next.id.into())
                    .context("Transforming sequence to LUT")?,
            ));
        }
        let (input_facts, output_facts) = model.node_facts(next.id)?;
        let invariants = next
            .op
            .axes_mapping(&input_facts, &output_facts)
            .with_context(|| format!("Querying invariants for {next}"))?;
        // Only element-wise unary ops can be folded into a per-value table.
        if !invariants.is_element_wise_unary() {
            break;
        }
        current = next;
    }
    Ok(None)
}

fn transform_quant_seq_to_lut(
model: &TypedModel,
src: OutletId, // wire before the dequant cast
dst: OutletId, // wire after the requant cast
) -> TractResult<TypedModelPatch> {
let incoming_dt = model.outlet_fact(src)?.datum_type;
let outgoing_dt = model.outlet_fact(dst)?.datum_type;
ensure!(incoming_dt.is_quantized() && incoming_dt.size_of() == 1);

let mut adhoc_model = TypedModel::default();
let wire = adhoc_model.add_source("ad-hoc", incoming_dt.fact([256]))?;
let mut next = model.single_succ(src.node)?.unwrap();
// plug in dequant
let dequant = model.node(src.node);
let name = &dequant.name;
let mut wire: TVec<OutletId> = tvec!(wire);
while next.id != dst.node {
wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
next = model.single_succ(next.id)?.unwrap();
}
// plug in quant
wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
adhoc_model.set_output_outlets(&wire)?;

let input = tensor1(&(0u8..=255).collect_vec());
let input = input.cast_to_dt(incoming_dt.unquantized())?.cast_to_dt(incoming_dt)?.into_owned();
let output = SimpleState::new(SimplePlan::new(adhoc_model)?)?
.run_plan_with_eval(tvec!(input.into_tvalue()), |s, op, node, inputs| {
eprintln!("{node} {inputs:?}");
eval(s, op, node, inputs)
})?
.remove(0);

let table: &[u8] = match incoming_dt.unquantized() {
DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
DatumType::U8 => output.as_slice::<u8>()?,
_ => unreachable!(),
};
let op = crate::ops::quant::lookup_table((tract_linalg::ops().lut_u8)(table));
let mut patch = TypedModelPatch::default();
let mut wire = patch.taps(model, &[src])?;
wire = patch.wire_node(format!("{name}.lut"), op, &wire)?;
wire = patch.wire_node(format!("{name}.cast"), cast(outgoing_dt), &wire)?;
patch.shunt_outside(model, dst, wire[0])?;
Ok(patch)
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::sigmoid;

    /// A dequant -> sigmoid -> requant chain must be optimized down to two nodes and still
    /// produce exactly the same output as the unoptimized model.
    #[test]
    fn test_lut() -> TractResult<()> {
        let mut model = TypedModel::default();
        let dt = i8::datum_type().with_zp_scale(0, 0.03);
        let src = model.add_source("src", dt.fact(&[10]))?;
        let mut wire = model.wire_node("dq", cast(f32::datum_type()), &[src])?;
        wire = model.wire_node("sigmoid", sigmoid(), &wire)?;
        wire = model.wire_node("q", cast(dt), &wire)?;
        model.set_output_outlets(&wire)?;

        let input =
            tensor1(&(-5i32..5i32).collect_vec()).cast_to::<f32>()?.cast_to_dt(dt)?.into_owned();
        // Reference result from the plain, unoptimized model.
        let ref_output = model.clone().into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;

        let codegen = model.into_optimized()?;
        assert_eq!(codegen.nodes.len(), 2); // Source then LookupTable
        let output = codegen.into_runnable()?.run(tvec!(input.into_tvalue()))?;
        output[0].close_enough(&ref_output[0], Approximation::Exact)
    }
}
2 changes: 2 additions & 0 deletions core/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ macro_rules! element_wise {
$(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
$(; cost: $cost:expr )?
$(; declutter: $declutter:expr )?
$(; eval_override: $eval_override: expr)?
$(; operating_datum_type: $operating_datum_type:expr )?
$(; prefix: $prefix:expr )?
$(; quantize: $quantize:expr )?
Expand All @@ -177,6 +178,7 @@ macro_rules! element_wise {
format!("{}{}", self.prefix(), stringify!($Op))
}
fn eval_in_place(&self, t: &mut Tensor) -> TractResult<()> {
$( return $eval_override(self, t); )?
$(
$(if t.datum_type() == $typ::datum_type() {
let t: &mut[$typ] = t.as_slice_mut::<$typ>()?;
Expand Down
29 changes: 21 additions & 8 deletions core/src/ops/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tract_linalg::Scaler;
use super::binary::TypedBinOp;
use super::math::round_ties_to_even;

/*
pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
(((x * scale).round() as i32) + zero_point)
.clamp(u8::min_value() as i32, u8::max_value() as i32) as u8
Expand All @@ -20,7 +21,9 @@ pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
(((x * scale).round() as i32) + zero_point)
.clamp(i8::min_value() as i32, i8::max_value() as i32) as i8
}
*/

/*
element_wise_oop!(quantize_linear_u8,
QuantizeLinearU8 {
scale: f32,
Expand Down Expand Up @@ -250,24 +253,34 @@ impl TypedOp for DequantizeLinearF32 {

as_op!();
}
*/

element_wise_oop!(lookup_table,
element_wise!(lookup_table,
LookupTable {
table: Box<dyn Lut>
},
[i8] => i8 |op, xs, ys| {
ys.copy_from_slice(xs);
}, ;
eval_override: |op: &LookupTable, xs: &mut Tensor| {
// dbg!(&op.table.table());
// dbg!(&xs);
let bytes = unsafe { xs.as_bytes_mut() };
// dbg!(&bytes);
op.table.run(bytes);
// dbg!(&bytes);
Ok(())
}
/*
[i8] => |op, xs| {
unsafe {
let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
let casted = std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, xs.len());
op.table.run(casted);
}
Ok(())
},
[u8] => u8 |op, xs, ys| {
ys.copy_from_slice(xs);
op.table.run(ys);
[u8] => |op, xs| {
op.table.run(xs);
Ok(())
}
*/
);

#[derive(Debug, Clone, Hash)]
Expand Down
Loading