diff --git a/extra/src/exp_unit_norm.rs b/extra/src/exp_unit_norm.rs index a580247b63..7398a607e1 100644 --- a/extra/src/exp_unit_norm.rs +++ b/extra/src/exp_unit_norm.rs @@ -15,6 +15,7 @@ pub fn register(registry: &mut Registry) { TypeName::Scalar.named("alpha"), TypeName::Integer.named("skip").default(0), TypeName::Logical.named("stateless").default(false), + TypeName::Logical.named("complex").default(false), TypeName::Scalar.named("epsilon").default(1e-14f32), ], &[("output", TypeName::Scalar.tensor())], @@ -22,6 +23,21 @@ pub fn register(registry: &mut Registry) { ); registry.register_dumper(ser_eun); + registry.register_primitive( + "tract_extra_exp_mean_norm", + &[ + TypeName::Scalar.tensor().named("input"), + TypeName::Scalar.tensor().named("state"), + TypeName::Integer.named("axis"), + TypeName::Scalar.named("alpha"), + TypeName::Integer.named("skip").default(0), + TypeName::Logical.named("stateless").default(false), + TypeName::Scalar.named("scaling_factor"), + ], + &[("output", TypeName::Scalar.tensor())], + de_eun, + ); + OpPulsifier::register::(pulsify).unwrap(); } @@ -30,10 +46,13 @@ fn de_eun(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractR let state = invocation.named_arg_as(builder, "state")?; let axis = invocation.named_arg_as::(builder, "axis")? as usize; let alpha = invocation.named_arg_as(builder, "alpha")?; - let epsilon = invocation.named_arg_as(builder, "epsilon")?; + let epsilon = invocation.get_named_arg_as(builder, "epsilon")?.unwrap_or(1e-14); let stateless = invocation.named_arg_as::(builder, "stateless")?; + let complex = invocation.get_named_arg_as::(builder, "complex")?.unwrap_or(false); let skip = invocation.named_arg_as::(builder, "skip")? as usize; - let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip }; + let scaling_factor = invocation.get_named_arg_as(builder, "scaling_factor")?.unwrap_or(1.0); + let mean = invocation.invocation.id == Identifier::from("tract_extra_exp_mean_norm"); + let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex, scaling_factor, mean }; builder.wire(op, &[wire, state]) } @@ -44,17 +63,20 @@ fn ser_eun( ) -> TractResult>> { let input = ast.mapping[&node.inputs[0]].clone(); let state = ast.mapping[&node.inputs[1]].clone(); - Ok(Some(invocation( - "tract_extra_exp_unit_norm", - &[input, state], - &[ - ("axis", numeric(op.axis)), - ("alpha", numeric(op.alpha)), - ("epsilon", numeric(op.epsilon)), - ("stateless", logical(op.stateless)), - ("skip", numeric(op.skip)), - ], - ))) + let mut attributes = vec![ + ("axis", numeric(op.axis)), + ("alpha", numeric(op.alpha)), + ("stateless", logical(op.stateless)), + ("skip", numeric(op.skip)), + ]; + if op.mean { + attributes.push(("scaling_factor", numeric(op.scaling_factor))); + Ok(Some(invocation("tract_extra_exp_mean_norm", &[input, state], &attributes))) + } else { + attributes.push(("epsilon", numeric(op.epsilon))); + attributes.push(("complex", numeric(op.complex))); + Ok(Some(invocation("tract_extra_exp_unit_norm", &[input, state], &attributes))) + } } #[derive(Clone, Debug, PartialEq)] @@ -64,6 +86,9 @@ pub struct ExpUnitNorm { pub axis: usize, pub skip: usize, pub stateless: bool, + pub complex: bool, + pub mean: bool, + pub scaling_factor: f32, } #[derive(Clone, Debug, PartialEq, Default)] @@ -101,20 +126,46 @@ impl EvalOp for ExpUnitNorm { impl ExpUnitNormState { fn eval(&mut self, op: &ExpUnitNorm, inputs: TVec) -> TractResult> { + use tract_ndarray::Axis; let (input, state0) = args_2!(inputs); let mut input = input.into_tensor(); - let mut view = input.to_array_view_mut::()?; + let mut x_view = input.to_array_view_mut::()?; if self.hidden.is_none() || op.stateless { self.hidden = Some(state0.into_tensor()); } + if op.complex { + ensure!(x_view.shape()[x_view.ndim() - 1] == 2); + } let mut state = self.hidden.as_mut().unwrap().to_array_view_mut::()?; - for mut time_slice in view.axis_iter_mut(tract_ndarray::Axis(op.axis)) { + for mut time_slice in x_view.axis_iter_mut(Axis(op.axis)) { if self.index >= op.skip { - state.zip_mut_with(&time_slice, |s: &mut f32, x: &f32| { - *s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha; - }); + if op.mean { + state.zip_mut_with(&time_slice, |s: &mut f32, x: &f32| { + *s = x * (1f32 - op.alpha) + *s * op.alpha; + }); + } else { + // unit norms + let normed = if op.complex { + time_slice + .mapv(|x| x * x) + .sum_axis(Axis(time_slice.ndim() - 1)) + .mapv(|x| x.sqrt()) + } else { + time_slice.mapv(|x| x.abs()) + }; + state.zip_mut_with(&normed, |s: &mut f32, x: &f32| { + *s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha; + }); + } + } + if op.mean { + time_slice.zip_mut_with(&state, |x, s| *x = (*x - s) / op.scaling_factor); + } else if op.complex { + let state_view = state.view().insert_axis(Axis(state.ndim())); + time_slice.zip_mut_with(&state_view, |x, s| *x /= s.sqrt()); + } else { + time_slice.zip_mut_with(&state, |x, s| *x /= s.sqrt()); } - time_slice.zip_mut_with(&state, |x, s| *x /= s.sqrt()); self.index += 1; } Ok(tvec!(input.into_tvalue())) @@ -136,7 +187,11 @@ impl OpState for ExpUnitNormState { impl TypedOp for ExpUnitNorm { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { let mut state_shape = inputs[0].shape.clone(); - let _ = state_shape.remove_axis(self.axis); + state_shape.remove_axis(self.axis)?; + if self.complex { + ensure!(inputs[0].shape[inputs[0].rank() - 1] == 2.to_dim()); + state_shape.remove_axis(state_shape.rank() - 1)?; + } ensure!(inputs[1].without_value() == inputs[0].datum_type.fact(state_shape)); Ok(tvec!(inputs[0].without_value())) }