Skip to content

Commit

Permalink
leakyrelu and scan
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 29, 2023
1 parent 2b71121 commit da40fb8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
4 changes: 4 additions & 0 deletions core/src/half.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ops::array::{Pad, PadMode};
use crate::ops::cnn::{ConvUnary, DeconvUnary};
use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::ops::scan::Scan;
use crate::ops::source::TypedSource;

#[derive(Debug)]
Expand All @@ -27,6 +28,9 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Hal
bias: op.bias.as_ref().map(tensor_f32_to_f16),
..op.clone()
})
} else if let Some(op) = node.op_as::<Scan>() {
let body = HalfTranslator.translate_model(&op.body)?;
Box::new(Scan { body, .. op.clone() })
} else if let Some(op) = node.op_as::<EinSum>() {
Box::new(EinSum { operating_dt: dt_f32_to_f16(op.operating_dt), ..op.clone() })
} else if let Some(op) = node.op_as::<DeconvUnary>() {
Expand Down
5 changes: 4 additions & 1 deletion core/src/ops/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod data_formats;
mod reduce;
mod softmax;

use tract_num_traits::{AsPrimitive, Zero};

pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
pub use self::reduce::{Reduce, Reducer};
pub use self::softmax::Softmax;
Expand All @@ -19,5 +21,6 @@ element_wise!(hard_swish, HardSwish,
);

element_wise!(leaky_relu, LeakyRelu { alpha: f32 },
[f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) }
[f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) },
[f16] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < f16::zero() { AsPrimitive::<f16>::as_(op.alpha) } else { (1.0).as_() }); Ok(()) }
);
2 changes: 1 addition & 1 deletion core/src/ops/scan/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct Scan {
pub skip: usize,
pub reset_every_turn: bool,
pub body: TypedModel,
decluttered: bool,
pub decluttered: bool,
pub input_mapping: Vec<InputMapping>,
pub output_mapping: Vec<OutputMapping<TDim>>,
}
Expand Down

0 comments on commit da40fb8

Please sign in to comment.