diff --git a/hir/src/ops/activations.rs b/hir/src/ops/activations.rs index 816f22f0d2..44fc10bd38 100644 --- a/hir/src/ops/activations.rs +++ b/hir/src/ops/activations.rs @@ -82,6 +82,22 @@ activation!(Softsign, |_op, name: &str, model: &mut TypedModel, inputs| { Ok(wire) }); +#[derive(Debug, Clone, new)] +pub struct Celu(pub f32); + +activation!(Celu, |op, name: &str, model: &mut TypedModel, inputs| { + cst!(model, inputs, name, zero, 0.0); + cst!(model, inputs, name, one, 1.0); + cst!(model, inputs, name, alpha, op.0); + let x_over_alpha = model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?; + let x_over_alpha_exp = model.wire_node(name.to_string() + ".exp", exp(), &[x_over_alpha[0]])?; + let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?; + let wire = model.wire_node(name.to_string() + ".sat-zero", min(), &[zero, minus_one[0]])?; + let relu = model.wire_node(name.to_string() + ".relu", max(), &[zero, inputs[0]])?; + let wire = model.wire_node(name.to_string(), add(), &[relu[0], wire[0]])?; + Ok(wire) +}); + #[derive(Debug, Clone, new)] pub struct Elu(pub f32); diff --git a/onnx/src/ops/nn/mod.rs b/onnx/src/ops/nn/mod.rs index 736621180f..9a24e3ccfd 100644 --- a/onnx/src/ops/nn/mod.rs +++ b/onnx/src/ops/nn/mod.rs @@ -33,6 +33,7 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) { reg.insert("ArgMin", arg_max_min); reg.insert("AveragePool", average_pool); reg.insert("BatchNormalization", batch_normalization); + reg.insert("Celu", celu); reg.insert("Conv", conv); reg.insert("ConvInteger", conv_integer); reg.insert("ConvTranspose", conv_transpose::conv_transpose); @@ -216,6 +217,14 @@ pub fn average_pool( )) } +pub fn celu( + _ctx: &ParsingContext, + node: &NodeProto, +) -> TractResult<(Box, Vec)> { + let alpha = node.get_attr_opt("alpha")?.unwrap_or(1.); + Ok((expand(ops::activations::Celu(alpha)), vec![])) +} + pub fn elu( _ctx: &ParsingContext, node: &NodeProto, diff --git a/test-rt/suite-onnx/node.txt b/test-rt/suite-onnx/node.txt index fc2776b76b..8941ff966c 100644 --- a/test-rt/suite-onnx/node.txt +++ b/test-rt/suite-onnx/node.txt @@ -126,6 +126,7 @@ test_castlike_STRING_to_FLOAT test_castlike_STRING_to_FLOAT_expanded not-nnef test_ceil test_ceil_example +test_celu test_celu_expanded test_clip test_clip_default_inbounds diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 25233e2413..5c3d260e4e 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -54,10 +54,11 @@ fn ignore_onnx(t: &[String]) -> bool { test_slice test_split - where - less - greater - equal + test_where + test_less + test_greater + test_equal + test_not test_add test_mul @@ -70,7 +71,9 @@ fn ignore_onnx(t: &[String]) -> bool { test_softmax test_abs + test_ceil test_exp + test_floor test_log test_reciprocal test_square diff --git a/tflite/src/ops/element_wise.rs b/tflite/src/ops/element_wise.rs index ac2cd0169b..cb750dc80f 100644 --- a/tflite/src/ops/element_wise.rs +++ b/tflite/src/ops/element_wise.rs @@ -3,10 +3,11 @@ use crate::ser::{BuiltinOp, SubgraphBuilder}; use crate::tflite::{ AbsOptions, AbsOptionsArgs, BuiltinOperator, BuiltinOptions, CosOptions, CosOptionsArgs, ExpOptions, ExpOptionsArgs, HardSwishOptions, HardSwishOptionsArgs, LeakyReluOptions, - LeakyReluOptionsArgs, SquareOptions, SquareOptionsArgs, + LeakyReluOptionsArgs, LogicalNotOptionsArgs, SquareOptions, SquareOptionsArgs, LogicalNotOptions, }; use tract_core::internal::*; use tract_core::ops::element_wise::ElementWiseOp; +use tract_core::ops::logic::{ Not, not }; use tract_core::ops::math::*; use tract_core::ops::nn::{hard_swish, leaky_relu, HardSwish, LeakyRelu}; @@ -14,11 +15,14 @@ pub fn register_all(reg: &mut Registry) { reg.reg_to_tflite(ser); reg.reg_to_tract(BuiltinOperator::ABS, |op| deser(op, abs())); + reg.reg_to_tract(BuiltinOperator::CEIL, |op| deser(op, ceil())); reg.reg_to_tract(BuiltinOperator::COS, |op| deser(op, cos())); reg.reg_to_tract(BuiltinOperator::EXP, |op| deser(op, exp())); + reg.reg_to_tract(BuiltinOperator::FLOOR, |op| deser(op, floor())); reg.reg_to_tract(BuiltinOperator::HARD_SWISH, |op| deser(op, hard_swish())); reg.reg_to_tract(BuiltinOperator::LEAKY_RELU, de_leaky_relu); reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln())); + reg.reg_to_tract(BuiltinOperator::LOGICAL_NOT, |op| deser(op, not())); reg.reg_to_tract(BuiltinOperator::SIN, |op| deser(op, sin())); reg.reg_to_tract(BuiltinOperator::SQRT, |op| deser(op, sqrt())); reg.reg_to_tract(BuiltinOperator::SQUARE, |op| deser(op, square())); @@ -83,6 +87,14 @@ fn ser( BuiltinOp::new(98, 1, BuiltinOperator::LEAKY_RELU, BuiltinOptions::LeakyReluOptions), options.as_union_value(), ) + } else if (*op.0).is::() { + let options = LogicalNotOptions::create(builder.fb(), &LogicalNotOptionsArgs {}); + builder.write_op_with_options( + &[input], + &[output], + BuiltinOp::new(87, 1, BuiltinOperator::LOGICAL_NOT, BuiltinOptions::LogicalNotOptions), + options.as_union_value(), + ) } else if (*op.0).is::() { let options = SquareOptions::create(builder.fb(), &SquareOptionsArgs {}); builder.write_op_with_options( @@ -91,6 +103,10 @@ fn ser( BuiltinOp::new(92, 1, BuiltinOperator::SQUARE, BuiltinOptions::SquareOptions), options.as_union_value(), ) + } else if (*op.0).is::() { + builder.write_op(&[input], &[output], 104, 1, BuiltinOperator::CEIL) + } else if (*op.0).is::() { + builder.write_op(&[input], &[output], 8, 1, BuiltinOperator::FLOOR) } else if (*op.0).is::() { builder.write_op(&[input], &[output], 66, 1, BuiltinOperator::SIN) } else if (*op.0).is::() {