From 1e454aec3a4310736df0aea11b3a2daf89e3726a Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 15 Sep 2023 17:40:55 +0200 Subject: [PATCH] more tests --- test-rt/test-tflite/suite.rs | 4 ++++ tflite/src/ops/element_wise.rs | 21 ++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 2a001ee86d..ed5c021539 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -85,7 +85,11 @@ fn ignore_onnx(t: &[String]) -> bool { test_clip test_batchnorm test_hardswish + test_leakyrelu + test_prelu + test_relu test_selu + test_thresholdrelu ", ); diff --git a/tflite/src/ops/element_wise.rs b/tflite/src/ops/element_wise.rs index 64624347df..ac2cd0169b 100644 --- a/tflite/src/ops/element_wise.rs +++ b/tflite/src/ops/element_wise.rs @@ -2,13 +2,13 @@ use crate::registry::{DeserOp, Registry}; use crate::ser::{BuiltinOp, SubgraphBuilder}; use crate::tflite::{ AbsOptions, AbsOptionsArgs, BuiltinOperator, BuiltinOptions, CosOptions, CosOptionsArgs, - ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, SquareOptions, - SquareOptionsArgs, + ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, LeakyReluOptions, + LeakyReluOptionsArgs, SquareOptions, SquareOptionsArgs, }; use tract_core::internal::*; use tract_core::ops::element_wise::ElementWiseOp; use tract_core::ops::math::*; -use tract_core::ops::nn::{hard_swish, HardSwish}; +use tract_core::ops::nn::{hard_swish, leaky_relu, HardSwish, LeakyRelu}; pub fn register_all(reg: &mut Registry) { reg.reg_to_tflite(ser); @@ -17,6 +17,7 @@ pub fn register_all(reg: &mut Registry) { reg.reg_to_tract(BuiltinOperator::COS, |op| deser(op, cos())); reg.reg_to_tract(BuiltinOperator::EXP, |op| deser(op, exp())); reg.reg_to_tract(BuiltinOperator::HARD_SWISH, |op| deser(op, hard_swish())); + reg.reg_to_tract(BuiltinOperator::LEAKY_RELU, de_leaky_relu); reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln())); reg.reg_to_tract(BuiltinOperator::SIN, |op| deser(op, sin())); reg.reg_to_tract(BuiltinOperator::SQRT, |op| deser(op, sqrt())); @@ -28,6 +29,11 @@ fn deser(op: &mut DeserOp, ew: ElementWiseOp) -> TractResult> { op.ctx.target.wire_node(op.prefix, ew, op.inputs) } +fn de_leaky_relu(op: &mut DeserOp) -> TractResult> { + let options = builtin!(op, builtin_options_as_leaky_relu_options); + op.ctx.target.wire_node(op.prefix, leaky_relu(options.alpha()), op.inputs) +} + fn ser( builder: &mut SubgraphBuilder, model: &TypedModel, @@ -68,6 +74,15 @@ fn ser( BuiltinOp::new(117, 1, BuiltinOperator::HARD_SWISH, BuiltinOptions::HardSwishOptions), options.as_union_value(), ) + } else if let Some(leaky) = (*op.0).downcast_ref::() { + let options = + LeakyReluOptions::create(builder.fb(), &LeakyReluOptionsArgs { alpha: leaky.alpha }); + builder.write_op_with_options( + &[input], + &[output], + BuiltinOp::new(98, 1, BuiltinOperator::LEAKY_RELU, BuiltinOptions::LeakyReluOptions), + options.as_union_value(), + ) } else if (*op.0).is::() { let options = SquareOptions::create(builder.fb(), &SquareOptionsArgs {}); builder.write_op_with_options(