From ded2c8eaf42b2bcd8477f67a274a7a79da347d50 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 28 Nov 2023 16:39:54 +0100 Subject: [PATCH] mean unit --- extra/src/exp_unit_norm.rs | 53 ++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/extra/src/exp_unit_norm.rs b/extra/src/exp_unit_norm.rs index a5a6d986db..33d9d73baa 100644 --- a/extra/src/exp_unit_norm.rs +++ b/extra/src/exp_unit_norm.rs @@ -23,6 +23,21 @@ pub fn register(registry: &mut Registry) { ); registry.register_dumper(ser_eun); + registry.register_primitive( + "tract_extra_exp_mean_norm", + &[ + TypeName::Scalar.tensor().named("input"), + TypeName::Scalar.tensor().named("state"), + TypeName::Integer.named("axis"), + TypeName::Scalar.named("alpha"), + TypeName::Integer.named("skip").default(0), + TypeName::Logical.named("stateless").default(false), + TypeName::Scalar.named("scaling_factor"), + ], + &[("output", TypeName::Scalar.tensor())], + de_eun, + ); + OpPulsifier::register::(pulsify).unwrap(); } @@ -31,11 +46,13 @@ fn de_eun(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractR let state = invocation.named_arg_as(builder, "state")?; let axis = invocation.named_arg_as::(builder, "axis")? as usize; let alpha = invocation.named_arg_as(builder, "alpha")?; - let epsilon = invocation.named_arg_as(builder, "epsilon")?; + let epsilon = invocation.get_named_arg_as(builder, "epsilon")?.unwrap_or(1e-14); let stateless = invocation.named_arg_as::(builder, "stateless")?; - let complex = invocation.named_arg_as::(builder, "complex")?; + let complex = invocation.get_named_arg_as::(builder, "complex")?.unwrap_or(false); let skip = invocation.named_arg_as::(builder, "skip")? as usize; - let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex }; + let scaling_factor = invocation.get_named_arg_as(builder, "scaling_factor")?.unwrap_or(1.0); + let mean = invocation.invocation.id == Identifier::from("tract_extra_exp_mean_norm"); + let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex, scaling_factor, mean }; builder.wire(op, &[wire, state]) } @@ -68,6 +85,8 @@ pub struct ExpUnitNorm { pub skip: usize, pub stateless: bool, pub complex: bool, + pub mean: bool, + pub scaling_factor: f32, } #[derive(Clone, Debug, PartialEq, Default)] @@ -118,16 +137,28 @@ impl ExpUnitNormState { let mut state = self.hidden.as_mut().unwrap().to_array_view_mut::()?; for mut time_slice in x_view.axis_iter_mut(Axis(op.axis)) { if self.index >= op.skip { - let normed = if op.complex { - time_slice.mapv(|x| x * x).sum_axis(Axis(time_slice.ndim() - 1)).mapv(|x| x.sqrt()) + if op.mean { + state.zip_mut_with(&time_slice, |s: &mut f32, x: &f32| { + *s = x * (1f32 - op.alpha) + *s * op.alpha; + }); } else { - time_slice.mapv(|x| x.abs()) - }; - state.zip_mut_with(&normed, |s: &mut f32, x: &f32| { - *s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha; - }); + // unit norms + let normed = if op.complex { + time_slice + .mapv(|x| x * x) + .sum_axis(Axis(time_slice.ndim() - 1)) + .mapv(|x| x.sqrt()) + } else { + time_slice.mapv(|x| x.abs()) + }; + state.zip_mut_with(&normed, |s: &mut f32, x: &f32| { + *s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha; + }); + } } - if op.complex { + if op.mean { + time_slice.zip_mut_with(&state, |x, s| *x = (*x - s) / op.scaling_factor); + } else if op.complex { let state_view = state.view().insert_axis(Axis(state.ndim())); time_slice.zip_mut_with(&state_view, |x, s| *x /= s.sqrt()); } else {