Skip to content

Commit

Permalink
slice and downsample tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 6, 2023
1 parent fef17b0 commit 4bd34f2
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 16 deletions.
3 changes: 2 additions & 1 deletion test-rt/suite-unit/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl Arbitrary for SliceProblem {
type Parameters = ();
type Strategy = BoxedStrategy<SliceProblem>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
vec(0..10usize, 1..5usize)
vec(1..10usize, 1..5usize)
.prop_flat_map(|input_shape| {
let rank = input_shape.len();
(Just(input_shape), 0..rank)
Expand All @@ -24,6 +24,7 @@ impl Arbitrary for SliceProblem {
let b1 = 0..=shape[axis];
(Just(shape), Just(axis), b0, b1)
})
.prop_filter("non empty slice", |(_, _, b0, b1)| b0 != b1)
.prop_map(|(input_shape, axis, b0, b1)| {
let start = b0.min(b1).to_dim();
let end = b0.max(b1).to_dim();
Expand Down
22 changes: 19 additions & 3 deletions test-rt/test-tflite/src/tflite_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ impl State for TfliteState {
ensure!(inputs.len() == interpreter.input_tensor_count());
for (ix, input) in inputs.iter().enumerate() {
let input_tensor = interpreter.input(ix)?;
dbg!(&input_tensor);
assert_eq!(input_tensor.shape().dimensions(), input.shape());
dbg!(&input);
input_tensor.set_data(unsafe { input.as_bytes() })?;
}
interpreter.invoke()?;
Expand All @@ -57,7 +55,10 @@ impl State for TfliteState {
DataType::Int64 => i64::datum_type(),
DataType::Int8 => {
if let Some(qp) = output_tensor.quantization_parameters() {
i8::datum_type().quantize(QParams::ZpScale { zero_point: qp.zero_point, scale: qp.scale })
i8::datum_type().quantize(QParams::ZpScale {
zero_point: qp.zero_point,
scale: qp.scale,
})
} else {
i8::datum_type()
}
Expand All @@ -81,3 +82,18 @@ fn runtime() -> &'static TfliteRuntime {
}

include!(concat!(env!("OUT_DIR"), "/tests/tests.rs"));

#[cfg(test)]
mod tests{
use super::*;

#[test]
fn test_trivial() -> TractResult<()> {
let mut model = TypedModel::default();
let wire = model.add_source("x", f32::fact(&[1]))?;
model.set_output_outlets(&[wire])?;
let out = runtime().prepare(model)?.run(tvec!(tensor1(&[0f32]).into_tvalue()))?.remove(0);
assert_eq!(out, tensor1(&[0f32]).into_tvalue());
Ok(())
}
}
5 changes: 3 additions & 2 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ pub fn suite() -> infra::TestSuite {
"proptest",
QConvProblemParams { conv: cv, no_kernel_zero_point: true, .. QConvProblemParams::default() },
);
infra::TestSuite::default().with("onnx", onnx).with("conv", unit)
infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}

fn ignore_onnx(t: &[String]) -> bool {
let name = t.last().unwrap();
let included = "_conv_ Conv1d Conv2d squeeze _transpose_ test_reshape test_flatten where less greater equal slice";
let excluded = "
test_slice_start_out_of_bounds
test_Conv1d_groups
test_Conv2d_groups
test_Conv1d_depthwise_with_multiplier
Expand All @@ -33,7 +34,7 @@ fn ignore_onnx(t: &[String]) -> bool {

fn ignore_conv(t: &[String]) -> bool {
let [section, unit] = t else { return false };
["deconv", "slice", "downsample"].contains(&&**section)
["deconv"].contains(&&**section)
// grouping and depthwise
|| unit.starts_with("group")
// conv 3D
Expand Down
1 change: 0 additions & 1 deletion tflite/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ impl TfliteProtoModel {

pub fn root(&self) -> tflite::Model {
unsafe { tflite::root_as_model_unchecked(&self.0) }
// tflite::model::Model::from_buffer(&self.0).context("Failed to read flat buffer model")
}
}

Expand Down
17 changes: 8 additions & 9 deletions tflite/src/ops/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,8 @@ fn ser_downsample(
node: &TypedNode,
op: &Downsample,
) -> TractResult<()> {
dbg!(node.to_string());
dbg!(op);
let input_fact = model.outlet_fact(node.inputs[0])?;
dbg!(input_fact);
let output_fact = model.outlet_fact(node.id.into())?;
dbg!(output_fact);
let begins = tvec!(0i32; input_fact.rank());
let mut begins = tvec!(0i32; input_fact.rank());
let mut ends = input_fact
.shape
.as_concrete()
Expand All @@ -258,8 +253,12 @@ fn ser_downsample(
.collect::<TVec<_>>();
let mut strides = tvec!(1; input_fact.rank());
strides[op.axis] = op.stride as i32;
ends[op.axis] =
begins[op.axis] + op.stride as i32 * output_fact.shape[op.axis].as_i64().unwrap() as i32 + op.stride.signum() as i32;
if op.modulo > 0 {
begins[op.axis] = op.modulo as i32;
} else if op.stride < 0 {
begins[op.axis] = -1;
ends[op.axis] = 0;
}
let mut inputs = tvec!(builder.outlets_to_tensors[&node.inputs[0]]);
inputs.push(builder.write_fact(format!("{}.begins", node.name), tensor1(&begins))?);
inputs.push(builder.write_fact(format!("{}.ends", node.name), tensor1(&ends))?);
Expand All @@ -269,7 +268,7 @@ fn ser_downsample(
builder.fb(),
&StridedSliceOptionsArgs {
begin_mask: 0,
end_mask: 0,
end_mask: 1 << op.axis,
ellipsis_mask: 0,
new_axis_mask: 0,
shrink_axis_mask: 0,
Expand Down

0 comments on commit 4bd34f2

Please sign in to comment.