
Commit

wip
kali committed Dec 6, 2023
1 parent 5436851 commit deee548
Showing 17 changed files with 356 additions and 347 deletions.
15 changes: 0 additions & 15 deletions core/src/floats.rs
@@ -3,7 +3,6 @@ use tract_num_traits::Float;
use crate::internal::translator::Translate;
use crate::internal::*;
use crate::ops::array::{Pad, PadMode};
use crate::ops::cnn::{ConvUnary, DeconvUnary};
use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::ops::scan::Scan;
@@ -27,14 +26,6 @@ impl<T1: Datum + Float, T2: Datum + Float>
Box::new(TypedSource::new(fact_float_precision_conversion::<T1, T2>(&source.fact)))
} else if let Some(konst) = node.op_as::<Const>() {
Box::new(Const(tensor_float_precision_conversion::<T1, T2>(&konst.0)))
/*
} else if let Some(op) = node.op_as::<ConvUnary>() {
Box::new(ConvUnary {
kernel: tensor_float_precision_conversion::<T1, T2>(&op.kernel),
bias: op.bias.as_ref().map(tensor_float_precision_conversion::<T1, T2>),
..op.clone()
})
*/
} else if let Some(op) = node.op_as::<Scan>() {
let body = FloatPrecisionTranslator::<T1, T2>::default().translate_model(&op.body)?;
Box::new(Scan { body, ..op.clone() })
@@ -43,12 +34,6 @@
operating_dt: dt_float_precision_conversion::<T1, T2>(op.operating_dt),
..op.clone()
})
} else if let Some(op) = node.op_as::<DeconvUnary>() {
Box::new(DeconvUnary {
kernel: tensor_float_precision_conversion::<T1, T2>(&op.kernel),
bias: op.bias.as_ref().map(tensor_float_precision_conversion::<T1, T2>),
..op.clone()
})
} else if let Some(op) = node.op_as::<Pad>() {
if let PadMode::Constant(t) = &op.mode {
Box::new(Pad {
9 changes: 5 additions & 4 deletions core/src/ops/change_axes.rs
@@ -267,6 +267,7 @@ impl AxisOp {
shape.insert(*to, axis);
}
Reshape(at, from, to) => {
ensure!(from.iter().product::<TDim>() == to.iter().product::<TDim>());
if shape.len() >= from.len() + *at
&& tract_itertools::izip!(shape.iter().skip(*at), from)
.all(|(shape, spec)| shape.to_dim() == *spec)
@@ -435,7 +436,7 @@ impl AxisOp {

pub fn wire_split_axis(
model: &mut TypedModel,
name: &str,
name: impl ToString,
outlet: OutletId,
axis: usize,
outer_dim: usize,
@@ -444,20 +445,20 @@
let dim: TDim = fact.shape[axis].clone();
let inner_dim = dim.clone() / outer_dim;
let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim));
model.wire_node(name, op, &[outlet])
model.wire_node(name.to_string(), op, &[outlet])
}

pub fn wire_collapse_axis(
model: &mut TypedModel,
name: &str,
name: impl ToString,
outlet: OutletId,
axis: usize,
) -> TractResult<TVec<OutletId>> {
let fact = model.outlet_fact(outlet)?;
let dim: TDim = fact.shape[axis].clone();
let next_dim: TDim = fact.shape[axis + 1].clone();
let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim));
model.wire_node(name, op, &[outlet])
model.wire_node(name.to_string(), op, &[outlet])
}
}

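Context for the change_axes.rs change above: wire_split_axis and wire_collapse_axis now take `name: impl ToString`, so call sites can pass a `format!(..)` result or a plain `&str` without an extra borrow. Below is a minimal usage sketch, not part of this commit; the helper name and node-name prefix are made up, and it assumes the `tract_core::internal` prelude is available.

use tract_core::internal::*; // assumed prelude: TypedModel, OutletId, AxisOp, TVec, ...

// Hypothetical helper (illustration only): the owned String produced by format!()
// is handed to the wiring helpers directly thanks to `impl ToString`.
fn split_then_collapse(
    model: &mut TypedModel,
    outlet: OutletId,
    prefix: &str,
) -> TractResult<TVec<OutletId>> {
    // Split axis 0 into an outer dim of 2 and the inferred inner dim.
    let split = AxisOp::wire_split_axis(model, format!("{prefix}.split"), outlet, 0, 2)?;
    // Collapse axes 0 and 1 back into a single axis.
    AxisOp::wire_collapse_axis(model, format!("{prefix}.collapse"), split[0], 0)
}

The helpers convert once with `name.to_string()` when wiring the node, so the relaxed bound only changes ergonomics at the call site.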
188 changes: 97 additions & 91 deletions core/src/ops/cnn/conv/unary.rs
@@ -5,17 +5,15 @@ use tract_num_traits::Zero;
use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::MultiBroadcastTo;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::cnn::wire_reshape_bias;
use crate::ops::cnn::PaddingSpec::*;
use crate::ops::einsum::EinSum;
use crate::ops::math::Add;
use crate::ops::math::Div;
use crate::ops::math::Mul;
use crate::ops::math::Sub;
use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::lir_unary::AddMatMulGeometry;
use crate::ops::matmul::lir_unary::MapOutputAxisToInput;
use crate::ops::matmul::mir_quant::wire_offset_u8_as_i8;
@@ -159,7 +157,8 @@ impl ConvUnary {
Reduce::new(tvec!(2), ops::nn::Reducer::Sum),
&g_o_ihw_as_i32,
)?;
let sum_ker_a_g_c = model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
let sum_ker_a_g_c =
model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?;
// align sum_A from G,C to "C" shape: N,HW,G,C (or N,G,C,HW)
let sum_ker_n_g_c =
model.wire_node(format!("{name}.sum_ker_n_g_c"), AxisOp::Add(0), &sum_ker_a_g_c)?;
@@ -185,7 +184,8 @@
let x_dt = model.outlet_fact(x)?.datum_type;
let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
let b_storage = unsafe { mmm.b_packed(x_dt.size_of(), k) };
let bias = model.wire_node(format!("{name}.cast_bias"), cast(mmm.internal_type()), &[bias])?[0];
let bias =
model.wire_node(format!("{name}.cast_bias"), cast(mmm.internal_type()), &[bias])?[0];
let wire = self.wire_mm_weights_bias(
model,
name,
@@ -503,20 +503,15 @@
name: &str,
wire: &[OutletId],
) -> TractResult<OutletId> {
let &[x, kernel, bias] = wire else { bail!("Wrong number of inputs") };
let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") };
let x_fact = model.outlet_fact(x)?.clone();
let x_shape = x_fact.shape.as_concrete().unwrap();
let ConcretePoolGeometry { input_shape, patch, output_shape } =
self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned();
let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?;
let bias_fact = model.outlet_fact(bias)?;
let need_bias_shape = tvec!(self.output_channels().to_dim());
let bias_reshape: Box<dyn TypedOp> = if bias_fact.shape.volume().is_one() {
Box::new(MultiBroadcastTo { shape: need_bias_shape.into() })
} else {
Box::new(AxisOp::Reshape(0, bias_fact.shape.to_tvec(), need_bias_shape))
};
let bias = model.wire_node(format!("{name}.bias"), bias_reshape, &[bias])?[0];
let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis();
bias =
wire_reshape_bias(model, name, bias, x_fact.rank(), c_axis, self.output_channels())?[0];
let op = DepthWise::new(patch, input_shape, output_shape);
Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0])
}
@@ -666,80 +661,91 @@ impl ConvUnary {
Ok(Some(patch))
}

/*
fn declutter_channel_arithmetic_succ(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.q_params.is_some() || self.group != 1 {
return Ok(None);
}
let &[succ] = &*node.outputs[0].successors else { return Ok(None) };
let Some(bin) = model.node(succ.node).op_as::<TypedBinOp>() else { return Ok(None) };
let other_input = model.node(succ.node).inputs[1 - succ.slot];
let other_fact = &model.outlet_fact(other_input)?;
let Some(konst) = &other_fact.konst else { return Ok(None) };
let axes_mapping = model.node_axes_mapping(succ.node)?;
let input_shape =
self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
let conv_c_axis = input_shape.c_axis();
let &[konst_c_axis] =
&*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1 - succ.slot]
else {
return Ok(None);
};
let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else { return Ok(None) };
let operand_for_bias = if konst.shape()[konst_c_axis] == co && konst.len() == co {
konst.clone().into_tensor().into_shape(&[co])?
} else if konst.len() == 1 {
konst.clone().to_scalar_tensor()?.broadcast_scalar_to_shape(&[co])?
} else {
return Ok(None);
};
let mut bias = if let Some(b) = &self.bias {
b.clone()
} else {
Tensor::zero_dt(other_fact.datum_type, &[co])?.into_arc_tensor()
};
let mut kernel = self.kernel.clone();
let mut operand_shape_for_kernel = tvec!(1; 2 + input_shape.hw_rank());
let o_axis = if self.kernel_fmt == KernelFormat::OIHW { 0 } else { self.kernel.rank() - 1 };
operand_shape_for_kernel[o_axis] = co;
let operand_for_kernel = operand_for_bias.clone().into_shape(&operand_shape_for_kernel)?;
if bin.0.is::<Sub>() && succ.slot == 0 {
bias = (bias.into_tensor().into_array::<f32>()?
- operand_for_bias.to_array_view::<f32>()?)
.into_arc_tensor()
} else if bin.0.is::<Div>() && succ.slot == 0 {
bias = (bias.into_tensor().into_array::<f32>()?
/ operand_for_bias.to_array_view::<f32>()?)
.into_arc_tensor();
kernel = (kernel.into_tensor().into_array::<f32>()?
/ operand_for_kernel.to_array_view::<f32>()?)
.into_arc_tensor();
} else if bin.0.is::<Add>() {
bias = (bias.into_tensor().into_array::<f32>()?
+ operand_for_bias.to_array_view::<f32>()?)
.into_arc_tensor();
} else if bin.0.is::<Mul>() {
bias = (bias.into_tensor().into_array::<f32>()?
* operand_for_bias.to_array_view::<f32>()?)
.into_arc_tensor();
kernel = (kernel.into_tensor().into_array::<f32>()?
* operand_for_kernel.to_array_view::<f32>()?)
.into_arc_tensor();
} else {
return Ok(None);
};
let new_op = ConvUnary { bias: Some(bias), kernel, ..self.clone() };
let mut patch = TypedModelPatch::default();
let wire = patch.tap_model(model, node.inputs[0])?;
let wire = patch.wire_node(&node.name, new_op, &[wire])?[0];
patch.shunt_outside(model, succ.node.into(), wire)?;
fn declutter_channel_arithmetic_succ(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.q_params.is_some() || self.group != 1 {
return Ok(None);
}
let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) };
let succ = model.node(succ_outlet.node);
let Some(bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
let other_input = succ.inputs[1 - succ_outlet.slot];
let axes_mapping = model.node_axes_mapping(succ.id)?;
let input_shape =
self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
let conv_c_axis = input_shape.c_axis();
if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs
[1 - succ_outlet.slot]
.len()
!= 1
{
return Ok(None);
};

let mut patch = TypedModelPatch::default();
let [input, mut kernel, mut bias] = &*patch.taps(model, &node.inputs)? else {
panic!("Expect three inputs");
};
let name = &node.name;
let succ_name = &succ.name;

let operand = patch.tap_model(model, other_input)?;

let renamed = format!("{name}.{succ_name}");
bias = wire_reshape_bias(
&mut patch,
format!("{renamed}.reshape_bias"),
bias,
1,
0,
self.output_channels(),
)?[0];

let operand = wire_reshape_bias(
&mut patch,
format!("{renamed}.reshape_operand"),
operand,
1,
0,
self.output_channels(),
)?[0];

let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec();
let kernel_fact = patch.outlet_fact(kernel)?;
let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank());
operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
self.output_channels().to_dim();
let operand_for_kernel = patch.wire_node(
format!("{renamed}.reshape_operand_for_kernel"),
AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
&[operand],
)?[0];

if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
bias = patch.wire_node(&renamed, sub(), &[bias, operand])?[0];
} else if bin.0.is::<Sub>() {
bias = patch.wire_node(&renamed, sub(), &[operand, bias])?[0];
} else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
bias = patch.wire_node(&renamed, div(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed, div(), &[kernel, operand_for_kernel])?[0];
} else if bin.0.is::<Div>() {
bias = patch.wire_node(&renamed, div(), &[operand, bias])?[0];
kernel = patch.wire_node(&renamed, div(), &[operand_for_kernel, kernel])?[0];
} else if bin.0.is::<Add>() {
bias = patch.wire_node(&renamed, add(), &[bias, operand])?[0];
} else if bin.0.is::<Mul>() {
bias = patch.wire_node(&renamed, mul(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed, mul(), &[kernel, operand_for_kernel])?[0];
} else {
return Ok(None);
};
let wire = patch.wire_node(&node.name, self.clone(), &[*input, kernel, bias])?[0];
patch.shunt_outside(model, succ_outlet.node.into(), wire)?;
Ok(Some(patch))
}
*/
}

impl Op for ConvUnary {
@@ -894,7 +900,7 @@ impl TypedOp for ConvUnary {
}
pass!(declutter_stride_slice_to_downsample);
pass!(declutter_as_einsum);
// pass!(declutter_channel_arithmetic_succ);
pass!(declutter_channel_arithmetic_succ);
pass!(declutter_precursor_padding);
Ok(None)
}
@@ -1150,8 +1156,8 @@ mod test {
)?;
model.set_output_outlets(&wire)?;
model.declutter()?;
assert_eq!(model.nodes().len(), 2); // source + conv
let cv = model.nodes()[1].op_as::<ConvUnary>().unwrap();
assert_eq!(model.nodes().len(), 4); // source + conv + kernel + bias
let cv = model.nodes()[3].op_as::<ConvUnary>().unwrap();
assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0])); // source + conv
Ok(())
}
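The declutter_channel_arithmetic_succ pass is re-enabled above and now folds a following per-channel TypedBinOp (Add, Sub, Mul, Div) into the convolution by wiring the arithmetic onto the kernel and bias inputs instead of patching constant tensors. The rewrite rests on a simple algebraic identity; the standalone sketch below (not code from this commit, with toy shapes and values invented for illustration) checks the Mul case with a naive 1D convolution. The per-channel factor is a power of two so the floating-point equality in the assert is exact.

fn conv1d(x: &[Vec<f32>], w: &[Vec<Vec<f32>>], b: &[f32]) -> Vec<Vec<f32>> {
    // Naive valid 1D convolution, kernel layout [out_ch][in_ch][k], one bias per out_ch.
    let k = w[0][0].len();
    let out_len = x[0].len() - k + 1;
    w.iter()
        .zip(b)
        .map(|(wo, bias)| {
            (0..out_len)
                .map(|t| {
                    bias + wo
                        .iter()
                        .zip(x)
                        .map(|(wi, xi)| (0..k).map(|i| wi[i] * xi[t + i]).sum::<f32>())
                        .sum::<f32>()
                })
                .collect()
        })
        .collect()
}

fn main() {
    let x = vec![vec![1.0, 2.0, 3.0, 4.0], vec![0.5, -1.0, 2.0, 0.0]]; // 2 input channels
    let w = vec![
        vec![vec![1.0, -1.0], vec![0.5, 0.5]], // output channel 0
        vec![vec![2.0, 0.0], vec![-1.0, 1.0]], // output channel 1
    ];
    let b = vec![0.25, -0.5];
    let c = vec![2.0, 0.5]; // per-output-channel multiplier (powers of two)

    // Mul applied after the convolution...
    let post: Vec<Vec<f32>> = conv1d(&x, &w, &b)
        .iter()
        .zip(&c)
        .map(|(row, s)| row.iter().map(|v| v * s).collect())
        .collect();

    // ...equals the convolution with the multiplier folded into kernel and bias,
    // which is what the declutter pass wires up for group == 1.
    let w_folded: Vec<Vec<Vec<f32>>> = w
        .iter()
        .zip(&c)
        .map(|(wo, s)| wo.iter().map(|wi| wi.iter().map(|v| v * s).collect()).collect())
        .collect();
    let b_folded: Vec<f32> = b.iter().zip(&c).map(|(v, s)| v * s).collect();
    let folded = conv1d(&x, &w_folded, &b_folded);

    assert_eq!(post, folded);
}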
19 changes: 8 additions & 11 deletions core/src/ops/cnn/deconv/deconv_sum.rs
@@ -26,7 +26,6 @@ pub struct DeconvSum {
/// shape of the deconvolution input
pub input_shape: ShapeFact,
pub adjustments: TVec<usize>,
pub bias: Option<Arc<Tensor>>,
pub group: usize,
}

@@ -74,8 +73,7 @@ impl DeconvSum {
inputs: TVec<TValue>,
values: &SymbolValues,
) -> TractResult<TVec<TValue>> {
let gemm = args_1!(inputs).into_tensor();
let dt = gemm.datum_type();
let (gemm, bias) = args_2!(inputs);
let input_shape = self.input_shape.eval_to_usize(values)?.into_owned();
let input_shape = self.pool_spec.data_format.shape(input_shape)?;
let output_shape =
@@ -88,14 +86,10 @@
&self.pool_spec.strides(),
&self.adjustments,
)?;
let mut tensor = if let Some(b) = &self.bias {
b.broadcast_vector_to_shape(&output_shape.shape, output_shape.c_axis())?
} else {
Tensor::zero_dt(dt, &output_shape.shape)?
};
let mut tensor = bias.broadcast_to_shape(&output_shape.shape)?;
let hw = *gemm.shape().last().unwrap();
let n = *output_shape.n().unwrap_or(&1);
let n_o_hkwk_hw = gemm.into_shape(&[
let n_o_hkwk_hw = gemm.into_tensor().into_shape(&[
n,
*output_shape.c(),
self.pool_spec.kernel_shape.iter().product(),
@@ -121,8 +115,11 @@

impl TypedOp for DeconvSum {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs.len() == 2);
let shape = super::output_shape(&self.pool_spec, &self.input_shape, &self.adjustments)?;
Ok(tvec!(inputs[0].datum_type.fact(&*shape)))
ensure!(inputs[1].rank() == shape.len());
ensure!(inputs[1].shape.iter().zip(shape.iter()).all(|(b, o)| b.is_one() || b == o.to_dim()));
Ok(tvec!(inputs[0].datum_type.fact(shape)))
}

fn concretize_dims(
@@ -136,7 +133,7 @@
target.wire_node(
&node.name,
Self { input_shape: self.input_shape.eval(values)?.into_owned(), ..self.clone() },
&[mapping[&node.inputs[0]]],
&[mapping[&node.inputs[0]], mapping[&node.inputs[1]]],
)
}

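DeconvSum now takes the bias as a second input and seeds the output accumulator with it, so output_facts checks that the bias fact has the same rank as the output and that every bias dimension is either 1 or equal to the corresponding output dimension. A standalone sketch of that broadcast rule on plain usize shapes follows; it is an illustration, not the TDim-based check from the commit.

fn bias_broadcasts_to(bias: &[usize], output: &[usize]) -> bool {
    // Rank must match, then each axis must be 1 or agree with the output.
    bias.len() == output.len() && bias.iter().zip(output).all(|(b, o)| *b == 1 || b == o)
}

fn main() {
    // NCHW output 1x16x8x8: a bias carried as 1x16x1x1 is accepted...
    assert!(bias_broadcasts_to(&[1, 16, 1, 1], &[1, 16, 8, 8]));
    // ...but a rank-1 vector of 16 is not; it must be reshaped to full rank first.
    assert!(!bias_broadcasts_to(&[16], &[1, 16, 8, 8]));
}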
