Skip to content

Commit

Permalink
no more dep on hir
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 5, 2023
1 parent d30b91c commit 4d32c3b
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 73 deletions.
4 changes: 1 addition & 3 deletions harness/core-proptest-pulse/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ license = "MIT OR Apache-2.0"
edition = "2021"

[dependencies]
tract-hir = { path = "../../hir", version = "=0.20.19-pre" }
tract-core = { path = "../../core", version = "=0.20.19-pre" }
tract-pulse = { path = "../../pulse", version = "=0.20.19-pre" }
tract-onnx = { path = "../../onnx", version = "=0.20.19-pre" }
tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.19-pre" }

[dev-dependencies]
log.workspace = true
Expand Down
6 changes: 2 additions & 4 deletions harness/core-proptest-pulse/src/conv_plus_conv.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use proptest::proptest;
use proptest::test_runner::TestCaseResult;
use tract_hir::internal::*;
use tract_hir::ops::cnn::*;
use tract_hir::prelude::tract_itertools::Itertools;
use tract_core::tract_data::itertools::Itertools;

use super::*;

Expand All @@ -21,7 +19,7 @@ impl ConvOp {
name,
ConvUnary {
pool_spec: PoolSpec {
data_format: tract_hir::ops::nn::DataFormat::NCHW,
data_format: DataFormat::NCHW,
kernel_shape: self.ker.shape()[2..].into(),
padding: self.padding.clone(),
dilations: Some(tvec!(self.dilation)),
Expand Down
36 changes: 17 additions & 19 deletions harness/core-proptest-pulse/src/deconv.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use proptest::proptest;
use proptest::test_runner::TestCaseResult;
use tract_hir::internal::*;
use tract_hir::ops::cnn::{self, PoolSpec};

use super::*;

Expand All @@ -11,14 +9,14 @@ struct DeconvOp {
dilation: usize,
adj: usize,
ker: Array3<f32>,
padding: cnn::PaddingSpec,
padding: PaddingSpec,
}

impl DeconvOp {
fn chain(&self, name: &str, model: &mut TypedModel, after: OutletId) -> OutletId {
let deconv = tract_core::ops::cnn::DeconvUnary {
pool_spec: PoolSpec {
data_format: tract_hir::ops::nn::DataFormat::NCHW,
data_format: DataFormat::NCHW,
kernel_shape: tvec!(self.ker.shape()[2]),
padding: self.padding.clone(),
strides: Some(self.stride).filter(|d| *d > 1).map(|d| tvec!(d)),
Expand Down Expand Up @@ -46,15 +44,15 @@ impl Arbitrary for DeconvOp {
0usize..4,
vec(1usize..4),
prop_oneof![
Just(cnn::PaddingSpec::Valid),
Just(cnn::PaddingSpec::SameUpper),
Just(cnn::PaddingSpec::SameLower)
Just(PaddingSpec::Valid),
Just(PaddingSpec::SameUpper),
Just(PaddingSpec::SameLower)
],
)
.prop_filter(
"Same padding geometry constraint",
|(stride, dilation, _adj, ker, padding)| {
padding == &cnn::PaddingSpec::Valid || ((ker.len() - 1) * dilation > stride - 1)
padding == &PaddingSpec::Valid || ((ker.len() - 1) * dilation > stride - 1)
},
)
.prop_map(|(stride, dilation, adj, ker, padding)| DeconvOp {
Expand Down Expand Up @@ -122,7 +120,7 @@ fn example_0() {
dilation: 1,
adj: 0,
ker: arr3(&[[[1.0f32]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -138,7 +136,7 @@ fn example_1() {
dilation: 1,
adj: 0,
ker: arr3(&[[[0.0f32, 0.0]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -154,7 +152,7 @@ fn example_2() {
dilation: 1,
adj: 0,
ker: arr3(&[[[0.0f32, 1.0]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -170,7 +168,7 @@ fn example_3() {
dilation: 1,
adj: 0,
ker: arr3(&[[[0.0f32, 1.0]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -186,7 +184,7 @@ fn dilation_0() {
dilation: 2,
adj: 0,
ker: arr3(&[[[0.0f32, 0.0]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -202,7 +200,7 @@ fn dilation_1() {
dilation: 2,
adj: 0,
ker: arr3(&[[[0.0f32, 1.0]]]),
padding: cnn::PaddingSpec::SameUpper,
padding: PaddingSpec::SameUpper,
},
};
pb.run().unwrap()
Expand All @@ -218,7 +216,7 @@ fn stride_0() {
dilation: 1,
adj: 0,
ker: arr3(&[[[1.0f32]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -234,7 +232,7 @@ fn same_upper_0() {
dilation: 1,
adj: 0,
ker: arr3(&[[[0.0f32, 1.0]]]),
padding: cnn::PaddingSpec::SameUpper,
padding: PaddingSpec::SameUpper,
},
};
pb.run().unwrap()
Expand All @@ -250,7 +248,7 @@ fn adj_0() {
dilation: 1,
adj: 1,
ker: arr3(&[[[0.0f32]]]),
padding: cnn::PaddingSpec::Valid,
padding: PaddingSpec::Valid,
},
};
pb.run().unwrap()
Expand All @@ -265,9 +263,9 @@ fn deconv2d() {
kernel.as_slice_mut::<f32>().unwrap().iter_mut().enumerate().for_each(|(ix, x)| *x = ix as f32);
let deconv = tract_core::ops::cnn::DeconvUnary {
pool_spec: PoolSpec {
data_format: tract_hir::ops::nn::DataFormat::NCHW,
data_format: DataFormat::NCHW,
kernel_shape: tvec!(1, 3),
padding: cnn::PaddingSpec::Explicit(tvec!(0, 1), tvec!(0, 1)),
padding: PaddingSpec::Explicit(tvec!(0, 1), tvec!(0, 1)),
strides: Some(tvec!(1, 2)),
dilations: Some(tvec![1, 1]),
output_channel_override: Some(2),
Expand Down
22 changes: 9 additions & 13 deletions harness/core-proptest-pulse/src/delay_plus_downsample.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use proptest::proptest;
use proptest::test_runner::TestCaseResult;
use tract_hir::internal::*;
use tract_hir::ops::array;
use tract_hir::prelude::tract_itertools::Itertools;
use tract_hir::tract_core::ops::Downsample;
use tract_core::ops::Downsample;
use tract_core::tract_data::itertools::Itertools;

use super::*;

Expand Down Expand Up @@ -48,13 +46,12 @@ impl Arbitrary for DelayPlusDownsampleProblem {

impl DelayPlusDownsampleProblem {
pub fn run(&self) -> TestCaseResult {
let mut model = InferenceModel::default();
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model
.add_source("a", f32::fact(dims!(1, s, 1)).into())
.unwrap();
let a = model.add_source("a", f32::fact(dims!(1, s, 1)).into()).unwrap();
let crop =
model.wire_node("delay", expand(array::Crop::new(1, self.delay, 0)), &[a]).unwrap();
// model.wire_node("delay", expand(array::Crop::new(1, self.delay, 0)), &[a]).unwrap();
model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap();
let ds = model
.wire_node(
"ds",
Expand All @@ -63,8 +60,6 @@ impl DelayPlusDownsampleProblem {
)
.unwrap();
model.set_output_outlets(&ds).unwrap();
let model = model.into_typed().unwrap();
dbg!(&model);
proptest_regular_against_pulse(model, self.pulse as _, t(self.input), 1)
}
}
Expand All @@ -86,7 +81,9 @@ fn test_delay() {

#[test]
fn test_from_convs() {
DelayPlusDownsampleProblem { input: 5, pulse: 2, delay: 1, stride: 2, modulo: 0 }.run().unwrap();
DelayPlusDownsampleProblem { input: 5, pulse: 2, delay: 1, stride: 2, modulo: 0 }
.run()
.unwrap();
}

#[test]
Expand All @@ -103,4 +100,3 @@ fn test_big_delay() {
fn test_huge_delay() {
DelayPlusDownsampleProblem { input: 4, pulse: 2, delay: 1, stride: 2, modulo: 0 }.run().unwrap()
}

20 changes: 7 additions & 13 deletions harness/core-proptest-pulse/src/delay_plus_pool.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use proptest::proptest;
use proptest::test_runner::TestCaseResult;
use tract_hir::internal::*;
use tract_hir::ops::cnn::PaddingSpec;
use tract_hir::ops::{array, cnn, nn};
use tract_core::ops::cnn::MaxPool;

use super::*;

Expand Down Expand Up @@ -47,25 +45,21 @@ impl Arbitrary for DelayPlusPoolProblem {

impl DelayPlusPoolProblem {
pub fn run(&self) -> TestCaseResult {
let mut model = InferenceModel::default();
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model
.add_source("a", f32::fact(dims!(1, s, 1)).into())
.unwrap();
let crop =
model.wire_node("delay", expand(array::Crop::new(1, self.delay, 0)), &[a]).unwrap();
let pool_spec = cnn::PoolSpec::new(
nn::DataFormat::NHWC,
let a = model.add_source("a", f32::fact(dims!(1, s, 1)).into()).unwrap();
let crop = model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap();
let pool_spec = PoolSpec::new(
DataFormat::NHWC,
tvec!(self.pool_window),
self.padding.clone(),
None,
Some(tvec!(self.stride)),
None,
);
let pool = model.wire_node("pool", cnn::MaxPool::new(pool_spec, None), &crop).unwrap();
let pool = model.wire_node("pool", MaxPool::new(pool_spec, None), &crop).unwrap();
model.set_output_outlets(&pool).unwrap();
let input = arr1(&self.input).into_shape((1, self.input.len(), 1)).unwrap().into_dyn();
let model = model.into_typed().unwrap();
proptest_regular_against_pulse(model, self.pulse as _, input, 1)
}
}
Expand Down
1 change: 0 additions & 1 deletion harness/core-proptest-pulse/src/einsum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use tract_hir::internal::*;
use tract_core::ops::einsum::*;

use super::*;
Expand Down
32 changes: 17 additions & 15 deletions harness/core-proptest-pulse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ use proptest::prelude::*;
use proptest::proptest;
use proptest::test_runner::TestCaseResult;
use proptest::*;
use tract_hir::internal::*;
use tract_hir::tract_num_traits::Zero;
use tract_ndarray::*;
use tract_core::ndarray::arr3;
use tract_core::num_traits::Zero;
use tract_core::ops::array::Pad;
use tract_core::ops::array::PadMode;
use tract_core::ops::array::Slice;
use tract_core::ops::cnn::ConvUnary;
use tract_core::ops::cnn::KernelFormat;
use tract_core::ops::cnn::PaddingSpec;
use tract_core::ops::cnn::PoolSpec;
use tract_core::ops::nn::DataFormat;
use tract_ndarray::prelude::*;
use tract_pulse::internal::*;

mod conv_plus_conv;
Expand All @@ -27,7 +35,7 @@ fn setup_test_logger() {
fn proptest_regular_against_pulse(
model: TypedModel,
pulse: usize,
input_array: ArrayD<f32>,
input_array: tract_ndarray::ArrayD<f32>,
axis: usize,
) -> TestCaseResult {
setup_test_logger();
Expand Down Expand Up @@ -80,7 +88,7 @@ fn proptest_regular_against_pulse(
if to_write_in_chunk < pulse {
let mut filler_shape = input_array.shape().to_vec();
filler_shape[axis] = pulse - to_write_in_chunk;
chunk = concatenate(
chunk = tract_ndarray::concatenate(
Axis(axis),
&[chunk.view(), ArrayD::from_elem(filler_shape, std::f32::NAN).view()],
)
Expand All @@ -94,7 +102,7 @@ fn proptest_regular_against_pulse(
.map(|n| n.max(0) as usize);
}
let mut outputs = state.run(tvec!(chunk.into_tensor().into_tvalue())).unwrap();
got = concatenate(
got = tract_ndarray::concatenate(
Axis(output_stream_axis),
&[got.view(), outputs.remove(0).to_array_view::<f32>().unwrap()],
)
Expand Down Expand Up @@ -128,7 +136,6 @@ fn proptest_regular_against_pulse(
proptest! {
#[test]
fn proptest_crop(pulse in 1i32..3, input_len in 0i32..10, begin in 0i32..3, end in 0i32..3) {
use tract_hir::ops::array::Slice;
let full_len = input_len + begin + end;
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
Expand All @@ -142,11 +149,10 @@ proptest! {

#[test]
fn proptest_pad(pulse in 1i32..3, input_len in 0i32..10, begin in 0i32..3, end in 0i32..3) {
use tract_hir::ops::array::{ Pad, PadMode };
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
let pad = model.wire_node("pad",Pad::new(vec![(begin as _, end as _)],
let pad = model.wire_node("pad", Pad::new(vec![(begin as _, end as _)],
PadMode::Constant(Arc::new(Tensor::from(-1f32)))), &[a]).unwrap();
model.set_output_outlets(&pad).unwrap();

Expand All @@ -162,8 +168,6 @@ fn vec(len: impl Strategy<Value = usize>) -> impl Strategy<Value = Vec<f32>> {

#[test]
fn test_simple_conv() {
use tract_hir::ops::cnn::*;

let mut model = TypedModel::default();
let kernel = rctensor3(&[[[0.5f32, 1.0, -0.1]]]);
let s = model.symbol_table.sym("S");
Expand All @@ -174,14 +178,14 @@ fn test_simple_conv() {
"conv",
ConvUnary {
pool_spec: PoolSpec {
data_format: tract_hir::ops::nn::DataFormat::NCHW,
data_format: DataFormat::NCHW,
kernel_shape: tvec!(3),
padding: PaddingSpec::Valid,
dilations: None,
strides: None,
output_channel_override: Some(1),
},
kernel_fmt: tract_core::ops::cnn::KernelFormat::OIHW,
kernel_fmt: KernelFormat::OIHW,
kernel,
group: 1,
bias: None,
Expand All @@ -198,7 +202,6 @@ fn test_simple_conv() {

#[test]
fn test_pad_before_1() {
use tract_hir::ops::array::{Pad, PadMode};
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
Expand All @@ -217,7 +220,6 @@ fn test_pad_before_1() {

#[test]
fn test_pad_before_2() {
use tract_hir::ops::array::{Pad, PadMode};
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
Expand Down
Loading

0 comments on commit 4d32c3b

Please sign in to comment.