From da40fb8f59a6aca784f7822504ac04e118a1bad8 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 29 Sep 2023 09:34:39 +0200 Subject: [PATCH] leakyrelu and scan --- core/src/half.rs | 4 ++++ core/src/ops/nn/mod.rs | 5 ++++- core/src/ops/scan/mir.rs | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/half.rs b/core/src/half.rs index a9a0b68cc5..568c7582f0 100644 --- a/core/src/half.rs +++ b/core/src/half.rs @@ -4,6 +4,7 @@ use crate::ops::array::{Pad, PadMode}; use crate::ops::cnn::{ConvUnary, DeconvUnary}; use crate::ops::einsum::EinSum; use crate::ops::konst::Const; +use crate::ops::scan::Scan; use crate::ops::source::TypedSource; #[derive(Debug)] @@ -27,6 +28,9 @@ impl Translate, TypedFact, Box> for Hal bias: op.bias.as_ref().map(tensor_f32_to_f16), ..op.clone() }) + } else if let Some(op) = node.op_as::() { + let body = HalfTranslator.translate_model(&op.body)?; + Box::new(Scan { body, .. op.clone() }) } else if let Some(op) = node.op_as::() { Box::new(EinSum { operating_dt: dt_f32_to_f16(op.operating_dt), ..op.clone() }) } else if let Some(op) = node.op_as::() { diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs index 1a3d52e814..777a8a370e 100644 --- a/core/src/ops/nn/mod.rs +++ b/core/src/ops/nn/mod.rs @@ -2,6 +2,8 @@ mod data_formats; mod reduce; mod softmax; +use tract_num_traits::{AsPrimitive, Zero}; + pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape}; pub use self::reduce::{Reduce, Reducer}; pub use self::softmax::Softmax; @@ -19,5 +21,6 @@ element_wise!(hard_swish, HardSwish, ); element_wise!(leaky_relu, LeakyRelu { alpha: f32 }, - [f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) } + [f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) }, + [f16] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < f16::zero() { AsPrimitive::::as_(op.alpha) } else { (1.0).as_() }); Ok(()) } ); diff --git a/core/src/ops/scan/mir.rs b/core/src/ops/scan/mir.rs index cd90ad0c9b..8ad0854c97 100644 --- a/core/src/ops/scan/mir.rs +++ b/core/src/ops/scan/mir.rs @@ -12,7 +12,7 @@ pub struct Scan { pub skip: usize, pub reset_every_turn: bool, pub body: TypedModel, - decluttered: bool, + pub decluttered: bool, pub input_mapping: Vec, pub output_mapping: Vec>, }