Skip to content

Commit

Permalink
simple unit tests for slice and downsample
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 4, 2023
1 parent 6762298 commit 05dfa8f
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 5 deletions.
15 changes: 10 additions & 5 deletions test-rt/infra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ impl TestSuite {
writeln!(
rs,
" {test_suite}.get({full_id:?}).run_with_approx({full_id:?}, {runtime}, {approx})",
)
.unwrap();
)
.unwrap();
writeln!(rs, "}}").unwrap();
}
}
Expand All @@ -196,14 +196,19 @@ impl<A: Arbitrary + Test + Clone> Test for ProptestWrapper<A>
where
A::Parameters: Clone + Send + Sync,
{
fn run_with_approx(&self, id: &str, runtime: &dyn Runtime, approx: Approximation) -> TestResult {
fn run_with_approx(
&self,
id: &str,
runtime: &dyn Runtime,
approx: Approximation,
) -> TestResult {
let mut runner = TestRunner::new(Config {
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any_with::<A>(self.0.clone()), |v| {
v.run_with_approx(id, runtime, approx).unwrap();
Ok(())
v.run_with_approx(id, runtime, approx)
.map_err(|e| proptest::test_runner::TestCaseError::Fail(format!("{e:?}").into()))
})?;
Ok(())
}
Expand Down
73 changes: 73 additions & 0 deletions test-rt/suite-unit/src/downsample.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use infra::{Test, TestSuite};
use proptest::collection::vec;
use proptest::prelude::*;
use tract_core::internal::*;
use tract_core::ops::Downsample;

#[derive(Debug, Clone, Default)]
struct DownsampleProblem {
input_shape: Vec<usize>,
op: Downsample,
}

impl Arbitrary for DownsampleProblem {
type Parameters = ();
type Strategy = BoxedStrategy<DownsampleProblem>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
vec(1..10usize, 1..5usize)
.prop_flat_map(|input_shape| {
let rank = input_shape.len();
(Just(input_shape), 0..rank, 1..4isize, any::<bool>())
})
.prop_map(|(input_shape, axis, stride, dir)| DownsampleProblem {
input_shape,
op: Downsample { axis, stride: if dir { -stride } else { stride }, modulo: 0 },
})
.boxed()
}
}

impl Test for DownsampleProblem {
fn run_with_approx(
&self,
id: &str,
runtime: &dyn Runtime,
approx: Approximation,
) -> infra::TestResult {
let mut input = Tensor::zero::<f32>(&self.input_shape)?;
input.as_slice_mut::<f32>()?.iter_mut().enumerate().for_each(|(ix, x)| *x = ix as f32);

let mut slices = vec![];
let mut current =
if self.op.stride > 0 { 0isize } else { input.shape()[self.op.axis] as isize - 1 };
while current >= 0 && current < input.shape()[self.op.axis] as isize {
slices.push(input.slice(self.op.axis, current as usize, current as usize + 1)?);
current += self.op.stride;
}
let reference = Tensor::stack_tensors(self.op.axis, &slices)?;

let mut model = TypedModel::default();
model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
let wire = model.add_source("input", TypedFact::from(&input).without_value())?;
let output = model.wire_node("downsample", self.op.clone(), &[wire])?;
model.set_output_outlets(&output)?;
let mut output = runtime.prepare(model)?.run(tvec![input.clone().into_tvalue()])?;
let output = output.remove(0).into_tensor();
output.close_enough(&reference, approx)
}
}

pub fn suite() -> TractResult<TestSuite> {
let mut suite = TestSuite::default();
suite.add_arbitrary::<DownsampleProblem>("proptest", ());

suite.add_test(
"neg_0",
DownsampleProblem {
input_shape: vec![1],
op: Downsample { axis: 0, stride: -1, modulo: 0 },
},
false,
);
Ok(suite)
}
4 changes: 4 additions & 0 deletions test-rt/suite-unit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ use tract_ndarray::*;
pub mod conv_f32;
pub mod conv_q;
pub mod deconv;
pub mod downsample;
pub mod slice;

pub fn suite() -> TractResult<TestSuite> {
let mut suite: TestSuite = Default::default();
suite.add("conv_f32", conv_f32::suite()?);
suite.add("conv_q", conv_q::suite()?);
suite.add("deconv", deconv::suite()?);
suite.add("downsample", downsample::suite()?);
suite.add("slice", slice::suite()?);
Ok(suite)
}

Expand Down
65 changes: 65 additions & 0 deletions test-rt/suite-unit/src/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use infra::{Test, TestSuite};
use proptest::collection::vec;
use proptest::prelude::*;
use tract_core::internal::*;
use tract_core::ops::array::Slice;

#[derive(Debug, Clone, Default)]
struct SliceProblem {
input_shape: Vec<usize>,
op: Slice,
}

impl Arbitrary for SliceProblem {
type Parameters = ();
type Strategy = BoxedStrategy<SliceProblem>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
vec(0..10usize, 1..5usize)
.prop_flat_map(|input_shape| {
let rank = input_shape.len();
(Just(input_shape), 0..rank)
})
.prop_flat_map(|(shape, axis)| {
let b0 = 0..=shape[axis];
let b1 = 0..=shape[axis];
(Just(shape), Just(axis), b0, b1)
})
.prop_map(|(input_shape, axis, b0, b1)| {
let start = b0.min(b1).to_dim();
let end = b0.max(b1).to_dim();
SliceProblem { input_shape, op: Slice { axis, start, end } }
})
.boxed()
}
}

impl Test for SliceProblem {
fn run_with_approx(
&self,
id: &str,
runtime: &dyn Runtime,
approx: Approximation,
) -> infra::TestResult {
let mut input = Tensor::zero::<f32>(&self.input_shape)?;
input.as_slice_mut::<f32>()?.iter_mut().enumerate().for_each(|(ix, x)| *x = ix as f32);
let reference = input.slice(
self.op.axis,
self.op.start.to_usize().unwrap(),
self.op.end.to_usize().unwrap(),
)?;
let mut model = TypedModel::default();
model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string()));
let wire = model.add_source("input", TypedFact::from(&input).without_value())?;
let output = model.wire_node("slice", self.op.clone(), &[wire])?;
model.set_output_outlets(&output)?;
let mut output = runtime.prepare(model)?.run(tvec![input.clone().into_tvalue()])?;
let output = output.remove(0).into_tensor();
output.close_enough(&reference, approx)
}
}

pub fn suite() -> TractResult<TestSuite> {
let mut suite = TestSuite::default();
suite.add_arbitrary::<SliceProblem>("proptest", ());
Ok(suite)
}

0 comments on commit 05dfa8f

Please sign in to comment.