Skip to content

Commit

Permalink
more tests (selu)
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 16, 2023
1 parent 0ed127d commit 8367ccb
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
5 changes: 3 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ pub const STAGES: &[&str] = &[
"set-declutter",
"nnef-cycle",
"nnef-cycle-declutter",
"tlite-cycle",
"tfile-cycle-declutter",
"tflite-cycle-predump",
"tflite-cycle",
"tflite-cycle-declutter",
"before-optimize",
"optimize",
];
Expand Down
4 changes: 4 additions & 0 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,10 @@ impl Parameters {
}
#[cfg(feature = "tflite")]
if matches.is_present("tflite-cycle") {
stage!("tflite-cycle-predump", typed_model -> typed_model, |mut m:TypedModel| {
tract_tflite::rewriter::rewrite_for_tflite(&mut m)?;
Ok(m)
});
stage!("tflite-cycle", typed_model -> typed_model, |m:TypedModel| {
let tflite = tract_tflite::tflite();
let mut vec = vec!();
Expand Down
1 change: 0 additions & 1 deletion test-rt/test-tflite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ mod tflite_cycle {
info!("Store to Tflite");
let mut buffer = vec![];
self.0.write(&model, &mut buffer)?;
std::fs::write("foo.tfllite", &buffer)?;
info!("Reload from Tflite");
let mut reloaded = self.0.model_for_read(&mut &*buffer)?;
for i in 0..model.inputs.len() {
Expand Down
7 changes: 5 additions & 2 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,23 @@ fn ignore_onnx(t: &[String]) -> bool {
test_abs
test_exp
test_hardswish
test_log
test_reciprocal
test_square
test_sqrt
test_rsqrt
test_cos
test_sin
# lol, no tan :)
test_clip
test_batchnorm
test_hardswish
test_selu
",

test_reciprocal",
);
let excluded = patterns("
test_slice_start_out_of_bounds
Expand Down
10 changes: 9 additions & 1 deletion tflite/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl<'f, 'b> ModelBuilder<'f, 'b> {
pub struct SubgraphBuilder<'f, 'b, 'mb> {
pub model: &'mb mut ModelBuilder<'f, 'b>,
pub tensors: Vec<WIPOffset<Tensor<'f>>>,
pub const_cache: Vec<(Arc<tract_core::prelude::Tensor>, i32)>,
pub operators: Vec<WIPOffset<Operator<'f>>>,
pub outlets_to_tensors: HashMap<OutletId, i32>,
}
Expand All @@ -93,6 +94,7 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
tensors: vec![],
operators: vec![],
outlets_to_tensors: HashMap::new(),
const_cache: vec!(),
}
}

Expand All @@ -119,6 +121,7 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
) -> TractResult<TVec<i32>> {
outlets.into_iter().map(|o| self.map_outlet(model, *o.borrow())).collect()
}

pub fn write_fact(
&mut self,
name: impl AsRef<str>,
Expand Down Expand Up @@ -197,6 +200,11 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
) -> TractResult<i32> {
let fact = fact.into();
let buffer = if let Some(k) = &fact.konst {
if let Some(pair) = self.const_cache.iter().find(|(t, _id)| t == k) {
return Ok(pair.1);
}
self.const_cache.push((k.clone(), self.tensors.len() as i32));

let data = self.fb().create_vector(unsafe { k.as_bytes() });
let buffer = Buffer::create(self.fb(), &BufferArgs { data: Some(data) });
self.model.buffers.push(buffer);
Expand Down Expand Up @@ -326,7 +334,7 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
}

fn finish(self, model: &TypedModel) -> TractResult<WIPOffset<SubGraph<'f>>> {
let Self { model: ModelBuilder { builder, .. }, tensors, operators, outlets_to_tensors } =
let Self { model: ModelBuilder { builder, .. }, tensors, operators, outlets_to_tensors, .. } =
self;
let inputs = model.inputs.iter().map(|i| outlets_to_tensors[i]).collect_vec();
let outputs = model.outputs.iter().map(|i| outlets_to_tensors[i]).collect_vec();
Expand Down

0 comments on commit 8367ccb

Please sign in to comment.