diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 25233e2413..8be4a8eea3 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -54,10 +54,11 @@ fn ignore_onnx(t: &[String]) -> bool { test_slice test_split - where - less - greater - equal + test_where + test_less + test_greater + test_equal + test_not test_add test_mul diff --git a/tflite/src/ops/element_wise.rs b/tflite/src/ops/element_wise.rs index ac2cd0169b..80480c1b3f 100644 --- a/tflite/src/ops/element_wise.rs +++ b/tflite/src/ops/element_wise.rs @@ -3,10 +3,11 @@ use crate::ser::{BuiltinOp, SubgraphBuilder}; use crate::tflite::{ AbsOptions, AbsOptionsArgs, BuiltinOperator, BuiltinOptions, CosOptions, CosOptionsArgs, ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, LeakyReluOptions, - LeakyReluOptionsArgs, SquareOptions, SquareOptionsArgs, + LeakyReluOptionsArgs, LogicalNotOptionsArgs, SquareOptions, SquareOptionsArgs, LogicalNotOptions, }; use tract_core::internal::*; use tract_core::ops::element_wise::ElementWiseOp; +use tract_core::ops::logic::{ Not, not }; use tract_core::ops::math::*; use tract_core::ops::nn::{hard_swish, leaky_relu, HardSwish, LeakyRelu}; @@ -19,6 +20,7 @@ pub fn register_all(reg: &mut Registry) { reg.reg_to_tract(BuiltinOperator::HARD_SWISH, |op| deser(op, hard_swish())); reg.reg_to_tract(BuiltinOperator::LEAKY_RELU, de_leaky_relu); reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln())); + reg.reg_to_tract(BuiltinOperator::LOGICAL_NOT, |op| deser(op, not())); reg.reg_to_tract(BuiltinOperator::SIN, |op| deser(op, sin())); reg.reg_to_tract(BuiltinOperator::SQRT, |op| deser(op, sqrt())); reg.reg_to_tract(BuiltinOperator::SQUARE, |op| deser(op, square())); @@ -83,6 +85,14 @@ fn ser( BuiltinOp::new(98, 1, BuiltinOperator::LEAKY_RELU, BuiltinOptions::LeakyReluOptions), options.as_union_value(), ) + } else if (*op.0).is::() { + let options = LogicalNotOptions::create(builder.fb(), &LogicalNotOptionsArgs {}); + builder.write_op_with_options( + &[input], + &[output], + BuiltinOp::new(87, 1, BuiltinOperator::LOGICAL_NOT, BuiltinOptions::LogicalNotOptions), + options.as_union_value(), + ) } else if (*op.0).is::() { let options = SquareOptions::create(builder.fb(), &SquareOptionsArgs {}); builder.write_op_with_options(