Skip to content

Commit

Permalink
celu, floor, ceil
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 21, 2023
1 parent b488dfa commit 61591d0
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 0 deletions.
16 changes: 16 additions & 0 deletions hir/src/ops/activations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ activation!(Softsign, |_op, name: &str, model: &mut TypedModel, inputs| {
Ok(wire)
});

#[derive(Debug, Clone, new)]
pub struct Celu(pub f32);

activation!(Celu, |op, name: &str, model: &mut TypedModel, inputs| {
cst!(model, inputs, name, zero, 0.0);
cst!(model, inputs, name, one, 1.0);
cst!(model, inputs, name, alpha, op.0);
let x_over_alpha = model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?;
let x_over_alpha_exp = model.wire_node(name.to_string() + ".exp", exp(), &[x_over_alpha[0]])?;
let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?;
let wire = model.wire_node(name.to_string() + ".sat-zero", min(), &[zero, minus_one[0]])?;
let relu = model.wire_node(name.to_string() + ".relu", max(), &[zero, inputs[0]])?;
let wire = model.wire_node(name.to_string(), add(), &[relu[0], wire[0]])?;
Ok(wire)
});

#[derive(Debug, Clone, new)]
pub struct Elu(pub f32);

Expand Down
9 changes: 9 additions & 0 deletions onnx/src/ops/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) {
reg.insert("ArgMin", arg_max_min);
reg.insert("AveragePool", average_pool);
reg.insert("BatchNormalization", batch_normalization);
reg.insert("Celu", celu);
reg.insert("Conv", conv);
reg.insert("ConvInteger", conv_integer);
reg.insert("ConvTranspose", conv_transpose::conv_transpose);
Expand Down Expand Up @@ -216,6 +217,14 @@ pub fn average_pool(
))
}

pub fn celu(
_ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let alpha = node.get_attr_opt("alpha")?.unwrap_or(1.);
Ok((expand(ops::activations::Celu(alpha)), vec![]))
}

pub fn elu(
_ctx: &ParsingContext,
node: &NodeProto,
Expand Down
1 change: 1 addition & 0 deletions test-rt/suite-onnx/node.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ test_castlike_STRING_to_FLOAT
test_castlike_STRING_to_FLOAT_expanded not-nnef
test_ceil
test_ceil_example
test_celu
test_celu_expanded
test_clip
test_clip_default_inbounds
Expand Down
2 changes: 2 additions & 0 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ fn ignore_onnx(t: &[String]) -> bool {
test_softmax
test_abs
test_ceil
test_exp
test_floor
test_log
test_reciprocal
test_square
Expand Down
6 changes: 6 additions & 0 deletions tflite/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ pub fn register_all(reg: &mut Registry) {
reg.reg_to_tflite(ser);

reg.reg_to_tract(BuiltinOperator::ABS, |op| deser(op, abs()));
reg.reg_to_tract(BuiltinOperator::CEIL, |op| deser(op, ceil()));
reg.reg_to_tract(BuiltinOperator::COS, |op| deser(op, cos()));
reg.reg_to_tract(BuiltinOperator::EXP, |op| deser(op, exp()));
reg.reg_to_tract(BuiltinOperator::FLOOR, |op| deser(op, floor()));
reg.reg_to_tract(BuiltinOperator::HARD_SWISH, |op| deser(op, hard_swish()));
reg.reg_to_tract(BuiltinOperator::LEAKY_RELU, de_leaky_relu);
reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln()));
Expand Down Expand Up @@ -101,6 +103,10 @@ fn ser(
BuiltinOp::new(92, 1, BuiltinOperator::SQUARE, BuiltinOptions::SquareOptions),
options.as_union_value(),
)
} else if (*op.0).is::<Ceil>() {
builder.write_op(&[input], &[output], 104, 1, BuiltinOperator::CEIL)
} else if (*op.0).is::<Floor>() {
builder.write_op(&[input], &[output], 8, 1, BuiltinOperator::FLOOR)
} else if (*op.0).is::<Sin>() {
builder.write_op(&[input], &[output], 66, 1, BuiltinOperator::SIN)
} else if (*op.0).is::<Sqrt>() {
Expand Down

0 comments on commit 61591d0

Please sign in to comment.