Skip to content

Commit

Permalink
wip, split seq_to_lut
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Feb 27, 2024
1 parent e6f4757 commit 348a496
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 11 deletions.
204 changes: 201 additions & 3 deletions core/src/ops/cast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use tract_data::itertools::Itertools;

use crate::internal::*;
use crate::plan::eval;

pub fn cast(to: DatumType) -> Cast {
Cast { to }
Expand All @@ -11,8 +14,6 @@ pub struct Cast {

impl Cast {
fn do_eval(&self, input: &Tensor, symbols: &SymbolValues) -> TractResult<TVec<TValue>> {
dbg!(self);
dbg!(input);
if input.datum_type() == self.to {
Ok(tvec!(input.clone().into_tvalue()))
} else if input.datum_type() == TDim::datum_type() {
Expand All @@ -27,7 +28,6 @@ impl Cast {
}
} else {
let out = input.cast_to_dt(self.to)?;
dbg!(&out);
Ok(tvec!(out.into_owned().into_tvalue()))
}
}
Expand Down Expand Up @@ -106,5 +106,203 @@ impl TypedOp for Cast {
Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
}

fn codegen(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let src_dt = model.node_input_facts(node.id)?[0].datum_type;
if src_dt.is_quantized() && src_dt.size_of() == 1 && self.to.is_float() {
codegen_quant_ew_chain_to_lut(self, model, node)
} else {
Ok(None)
}
}

as_op!();
}

fn codegen_quant_ew_chain_to_lut(
original_dequant: &Cast,
model: &TypedModel,
origin: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut current = origin;
let incoming_dt = model.node_input_facts(origin.id)?[0].datum_type;
while let Some(next) = model.single_succ(current.id)? {
/*
let q_params = if let Some(op) = op.op_as::<ElementWiseOp>() {
if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
} else {
op.0.downcast_ref::<QuantizeLinearI8>()
.map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
}
} else {
None
};
*/
let q_dt_dst: Option<DatumType> =
next.op_as::<Cast>().map(|c| c.to).filter(|dt| dt.is_quantized());
if let Some(dt) = q_dt_dst {
let (zp, scale) = dt.zp_scale();
/*
// first, try Op::quantize() on all ops in the chain
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, origin.inputs[0])?;
let mut next = model.single_succ(origin.id)?.unwrap();
loop {
if let Some(op) = next
.op
.quantize(model, dequant, dt, scale, zero_point)
.with_context(|| format!("Quantizing {next}"))?
{
wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
} else {
break;
}
if next.id == current.id {
patch.shunt_outside(model, OutletId::new(op.id, 0), wire)?;
return Ok(Some(patch));
} else {
next = model.single_succ(next.id)?.unwrap();
}
}
*/
// or else make a lookup table
if incoming_dt.is_quantized() && incoming_dt.size_of() == 1 {
return Ok(Some(
transform_quant_seq_to_lut(model, origin.inputs[0], next.id.into())
.context("Transforming sequence to LUT")?,
));
/*
let mut adhoc_model = TypedModel::default();
let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
let mut next = model.single_succ(dequant.id)?.unwrap();
let mut name = None;
// plug in dequant
wire =
adhoc_model.wire_node(&*dequant.name, dequant.op.clone(), [wire].as_ref())?[0];
while next.id != op.id {
name.get_or_insert(&*next.name);
wire = adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?[0];
next = model.single_succ(next.id)?.unwrap();
}
// plug in quant
wire = adhoc_model.wire_node(&*op.name, op.op.clone(), [wire].as_ref())?[0];
adhoc_model.set_output_outlets(&[wire])?;
let input = (0u8..=255).collect::<Vec<u8>>();
let input = match dt {
DatumType::I8 => unsafe {
tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
},
DatumType::U8 => tensor1(&input),
_ => unreachable!(),
};
let output =
SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
let table: &[u8] = match dt {
DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
DatumType::U8 => output.as_slice::<u8>()?,
_ => unreachable!(),
};
let op = lookup_table((tract_linalg::ops().lut_u8)(table));
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
patch.shunt_outside(model, OutletId::new(op.id, 0), wire)?;
return Ok(Some(patch));
*/
}
}
let (input_facts, output_facts) = model.node_facts(next.id)?;
let invariants = next
.op
.axes_mapping(&input_facts, &output_facts)
.with_context(|| format!("Querying invariants for {next}"))?;
if invariants.is_element_wise_unary() {
current = next;
} else {
break;
}
}
Ok(None)
}

fn transform_quant_seq_to_lut(
model: &TypedModel,
src: OutletId, // wire before the dequant cast
dst: OutletId, // wire after the requant cast
) -> TractResult<TypedModelPatch> {
let incoming_dt = model.outlet_fact(src)?.datum_type;
let outgoing_dt = model.outlet_fact(dst)?.datum_type;
ensure!(incoming_dt.is_quantized() && incoming_dt.size_of() == 1);

let mut adhoc_model = TypedModel::default();
let wire = adhoc_model.add_source("ad-hoc", incoming_dt.fact([256]))?;
let mut next = model.single_succ(src.node)?.unwrap();
let mut name = None;
// plug in dequant
let dequant = model.node(src.node);
let mut wire = tvec!(wire);
while next.id != dst.node {
name.get_or_insert(&*next.name);
wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
next = model.single_succ(next.id)?.unwrap();
}
// plug in quant
wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
adhoc_model.set_output_outlets(&wire)?;

let input = tensor1(&(0u8..=255).collect_vec());
let input = input.cast_to_dt(incoming_dt.unquantized())?.cast_to_dt(incoming_dt)?.into_owned();
let output = SimpleState::new(SimplePlan::new(adhoc_model)?)?
.run_plan_with_eval(tvec!(input.into_tvalue()), |s, op, node, inputs| {
eprintln!("{node} {inputs:?}");
eval(s, op, node, inputs)
})?
.remove(0);
dbg!(&output);

let table: &[u8] = match incoming_dt.unquantized() {
DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
DatumType::U8 => output.as_slice::<u8>()?,
_ => unreachable!(),
};
let op = crate::ops::quant::lookup_table((tract_linalg::ops().lut_u8)(table));
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, src)?;

wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
patch.shunt_outside(model, dst, wire)?;
Ok(patch)
}

#[cfg(test)]
mod test {
use super::*;
use crate::ops::nn::sigmoid;

#[test]
fn test_lut() -> TractResult<()> {
let mut model = TypedModel::default();
let dt = i8::datum_type().with_zp_scale(0, 0.03);
let src = model.add_source("src", dt.fact(&[10]))?;
let mut wire = model.wire_node("dq", cast(f32::datum_type()), &[src])?;
wire = model.wire_node("sigmoid", sigmoid(), &wire)?;
wire = model.wire_node("q", cast(dt), &wire)?;
model.set_output_outlets(&wire)?;

let input =
tensor1(&(-5i32..5i32).collect_vec()).cast_to::<f32>()?.cast_to_dt(dt)?.into_owned();
let ref_output = model.clone().into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
dbg!(&input);
dbg!(&ref_output);

let codegen = model.into_optimized()?;
assert!(codegen.nodes.len() == 2); // Source then LookupTable
let output = codegen.into_runnable()?.run(tvec!(input.into_tvalue()))?;
output[0].close_enough(&ref_output[0], Approximation::Exact)
}
}
2 changes: 2 additions & 0 deletions core/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ macro_rules! element_wise {
$(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
$(; cost: $cost:expr )?
$(; declutter: $declutter:expr )?
$(; eval_override: $eval_override: expr)?
$(; operating_datum_type: $operating_datum_type:expr )?
$(; prefix: $prefix:expr )?
$(; quantize: $quantize:expr )?
Expand All @@ -177,6 +178,7 @@ macro_rules! element_wise {
format!("{}{}", self.prefix(), stringify!($Op))
}
fn eval_in_place(&self, t: &mut Tensor) -> TractResult<()> {
$( return $eval_override(self, t); )?
$(
$(if t.datum_type() == $typ::datum_type() {
let t: &mut[$typ] = t.as_slice_mut::<$typ>()?;
Expand Down
25 changes: 17 additions & 8 deletions core/src/ops/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,32 @@ impl TypedOp for DequantizeLinearF32 {
}
*/

element_wise_oop!(lookup_table,
element_wise!(lookup_table,
LookupTable {
table: Box<dyn Lut>
},
[i8] => i8 |op, xs, ys| {
ys.copy_from_slice(xs);
}, ;
eval_override: |op: &LookupTable, xs: &mut Tensor| {
// dbg!(&op.table.table());
// dbg!(&xs);
let bytes = unsafe { xs.as_bytes_mut() };
// dbg!(&bytes);
op.table.run(bytes);
// dbg!(&bytes);
Ok(())
}
/*
[i8] => |op, xs| {
unsafe {
let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
let casted = std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, xs.len());
op.table.run(casted);
}
Ok(())
},
[u8] => u8 |op, xs, ys| {
ys.copy_from_slice(xs);
op.table.run(ys);
[u8] => |op, xs| {
op.table.run(xs);
Ok(())
}
*/
);

#[derive(Debug, Clone, Hash)]
Expand Down

0 comments on commit 348a496

Please sign in to comment.