Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 16, 2023
1 parent 8367ccb commit 1e454ae
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
4 changes: 4 additions & 0 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ fn ignore_onnx(t: &[String]) -> bool {
test_clip
test_batchnorm
test_hardswish
test_leakyrelu
test_prelu
test_relu
test_selu
test_thresholdrelu
",

);
Expand Down
21 changes: 18 additions & 3 deletions tflite/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use crate::registry::{DeserOp, Registry};
use crate::ser::{BuiltinOp, SubgraphBuilder};
use crate::tflite::{
AbsOptions, AbsOptionsArgs, BuiltinOperator, BuiltinOptions, CosOptions, CosOptionsArgs,
ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, SquareOptions,
SquareOptionsArgs,
ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, LeakyReluOptions,
LeakyReluOptionsArgs, SquareOptions, SquareOptionsArgs,
};
use tract_core::internal::*;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::*;
use tract_core::ops::nn::{hard_swish, HardSwish};
use tract_core::ops::nn::{hard_swish, leaky_relu, HardSwish, LeakyRelu};

pub fn register_all(reg: &mut Registry) {
reg.reg_to_tflite(ser);
Expand All @@ -17,6 +17,7 @@ pub fn register_all(reg: &mut Registry) {
reg.reg_to_tract(BuiltinOperator::COS, |op| deser(op, cos()));
reg.reg_to_tract(BuiltinOperator::EXP, |op| deser(op, exp()));
reg.reg_to_tract(BuiltinOperator::HARD_SWISH, |op| deser(op, hard_swish()));
reg.reg_to_tract(BuiltinOperator::LEAKY_RELU, de_leaky_relu);
reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln()));
reg.reg_to_tract(BuiltinOperator::SIN, |op| deser(op, sin()));
reg.reg_to_tract(BuiltinOperator::SQRT, |op| deser(op, sqrt()));
Expand All @@ -28,6 +29,11 @@ fn deser(op: &mut DeserOp, ew: ElementWiseOp) -> TractResult<TVec<OutletId>> {
op.ctx.target.wire_node(op.prefix, ew, op.inputs)
}

fn de_leaky_relu(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let options = builtin!(op, builtin_options_as_leaky_relu_options);
op.ctx.target.wire_node(op.prefix, leaky_relu(options.alpha()), op.inputs)
}

fn ser(
builder: &mut SubgraphBuilder,
model: &TypedModel,
Expand Down Expand Up @@ -68,6 +74,15 @@ fn ser(
BuiltinOp::new(117, 1, BuiltinOperator::HARD_SWISH, BuiltinOptions::HardSwishOptions),
options.as_union_value(),
)
} else if let Some(leaky) = (*op.0).downcast_ref::<LeakyRelu>() {
let options =
LeakyReluOptions::create(builder.fb(), &LeakyReluOptionsArgs { alpha: leaky.alpha });
builder.write_op_with_options(
&[input],
&[output],
BuiltinOp::new(98, 1, BuiltinOperator::LEAKY_RELU, BuiltinOptions::LeakyReluOptions),
options.as_union_value(),
)
} else if (*op.0).is::<Square>() {
let options = SquareOptions::create(builder.fb(), &SquareOptionsArgs {});
builder.write_op_with_options(
Expand Down

0 comments on commit 1e454ae

Please sign in to comment.