From b92467ceae7eb5068a6580d3e7d2fc87aea03620 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 1 Sep 2023 16:26:51 +0200 Subject: [PATCH 1/3] switch to tflitec --- Cargo.toml | 2 +- test-rt/test-tflite/Cargo.toml | 2 +- test-rt/test-tflite/src/tflite_runtime.rs | 59 ++++++++++++----------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4e527e346c..e4ef14f175 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,7 +112,7 @@ string-interner = "0.14" tar = "0.4.37" tempfile = "3.8" tensorflow = "0.17.0" -tflite = { git = "https://github.com/kali/tflite-rs.git", rev="61d2aa7" } +tflitec = { git = "https://github.com/kali/tflitec-rs.git", rev="b37d65a" } time = "=0.3.23" time-macros = "=0.2.10" tokenizers = "0.13" diff --git a/test-rt/test-tflite/Cargo.toml b/test-rt/test-tflite/Cargo.toml index c524331711..4da24b1bc7 100644 --- a/test-rt/test-tflite/Cargo.toml +++ b/test-rt/test-tflite/Cargo.toml @@ -13,7 +13,7 @@ suite-conv = { path = "../suite-conv" } [dev-dependencies] lazy_static.workspace = true log.workspace = true -tflite.workspace = true +tflitec.workspace = true tract-tflite = { path = "../../tflite", version = "=0.20.19-pre" } tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.19-pre" } infra = { path = "../infra" } diff --git a/test-rt/test-tflite/src/tflite_runtime.rs b/test-rt/test-tflite/src/tflite_runtime.rs index 40398116ee..c8f089f081 100644 --- a/test-rt/test-tflite/src/tflite_runtime.rs +++ b/test-rt/test-tflite/src/tflite_runtime.rs @@ -1,8 +1,8 @@ +use tflitec::interpreter::Interpreter; +use tflitec::model::Model; +use tflitec::tensor::DataType; + use super::*; -use tflite::ops::builtin::BuiltinOpResolver; -use tflite::FlatBufferModel; -use tflite::Interpreter; -use tflite::InterpreterBuilder; struct TfliteRuntime(Tflite); @@ -23,47 +23,48 @@ impl Runtime for TfliteRuntime { } } +#[derive(Clone)] struct TfliteRunnable(Vec, TVec); impl Runnable for TfliteRunnable { fn spawn(&self) -> TractResult> { - let fb = FlatBufferModel::build_from_buffer(self.0.clone())?; - let resolver = BuiltinOpResolver::default(); - let builder = InterpreterBuilder::new(fb, resolver)?; - let mut interpreter = builder.build()?; - interpreter.allocate_tensors()?; - Ok(Box::new(TfliteState(interpreter, self.1.clone()))) + Ok(Box::new(TfliteState(self.clone()))) } } -struct TfliteState(Interpreter<'static, BuiltinOpResolver>, TVec); +struct TfliteState(TfliteRunnable); impl State for TfliteState { fn run(&mut self, inputs: TVec) -> TractResult> { - ensure!(inputs.len() == self.0.inputs().len()); + let model = Model::from_bytes(&self.0 .0)?; + let interpreter = Interpreter::new(&model, None)?; + interpreter.allocate_tensors()?; + ensure!(inputs.len() == interpreter.input_tensor_count()); for (ix, input) in inputs.iter().enumerate() { - let input_ix = self.0.inputs()[ix]; - let input_tensor = self.0.tensor_info(input_ix).unwrap(); - assert_eq!(input_tensor.dims, input.shape()); - self.0.tensor_buffer_mut(input_ix).unwrap().copy_from_slice(unsafe { input.as_bytes() }) + let input_tensor = interpreter.input(ix)?; + dbg!(&input_tensor); + assert_eq!(input_tensor.shape().dimensions(), input.shape()); + dbg!(&input); + input_tensor.set_data(unsafe { input.as_bytes() })?; } - self.0.invoke()?; + interpreter.invoke()?; let mut outputs = tvec![]; - for ix in 0..self.0.outputs().len() { - let output_ix = self.0.outputs()[ix]; - let output_tensor = self.0.tensor_info(output_ix).unwrap(); - let dt = match output_tensor.element_kind as u32 { - 1 => f32::datum_type(), - 9 => self.1[ix].clone(), // impossible to retrieve QP from this TFL binding + for ix in 0..interpreter.output_tensor_count() { + let output_tensor = interpreter.output(ix)?; + let dt = match output_tensor.data_type() { + DataType::Float32 => f32::datum_type(), + DataType::Int64 => i64::datum_type(), + DataType::Int8 => { + if let Some(qp) = output_tensor.quantization_parameters() { + i8::datum_type().quantize(QParams::ZpScale { zero_point: qp.zero_point, scale: qp.scale }) + } else { + i8::datum_type() + } + } _ => bail!("unknown type"), }; - dbg!(self.0.tensor_buffer(output_ix)); let tensor = unsafe { - Tensor::from_raw_dt( - dt, - &output_tensor.dims, - self.0.tensor_buffer(output_ix).unwrap(), - )? + Tensor::from_raw_dt(dt, &output_tensor.shape().dimensions(), output_tensor.data())? }; outputs.push(tensor.into_tvalue()); } From 8fedcdd88ce5e6c0b68f0e6b316eab33013eb5a7 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Sun, 3 Sep 2023 10:04:01 +0200 Subject: [PATCH 2/3] numpy needed for tflitec --- .travis/onnx-tests.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis/onnx-tests.sh b/.travis/onnx-tests.sh index 5e12d6a67f..5d7e9be5db 100755 --- a/.travis/onnx-tests.sh +++ b/.travis/onnx-tests.sh @@ -17,8 +17,7 @@ then brew install coreutils elif [ -n "$GITHUB_ACTIONS" ] then - # this seems to help with tflite / bindgen obscure bug - sudo apt-get install -y libclang-dev + pip install numpy fi # if [ `uname` = "Linux" -a -z "$TRAVIS" ] From d2bf631f138d0c5e8e7194e6b5deee42c556f489 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Sun, 3 Sep 2023 12:02:41 +0200 Subject: [PATCH 3/3] more space management ? --- .travis/regular-tests.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis/regular-tests.sh b/.travis/regular-tests.sh index e01b85be14..3c87bf0ac1 100755 --- a/.travis/regular-tests.sh +++ b/.travis/regular-tests.sh @@ -52,8 +52,10 @@ fi for c in data linalg core nnef hir onnx pulse onnx-opl pulse-opl rs proxy do + df -h cargo -q test $CARGO_EXTRA -q -p tract-$c done + # doc test are not finding libtensorflow.so if ! cargo -q test $CARGO_EXTRA -q -p tract-tensorflow --lib $ALL_FEATURES then