diff --git a/extra/src/exp_unit_norm.rs b/extra/src/exp_unit_norm.rs index 33d9d73baa..7398a607e1 100644 --- a/extra/src/exp_unit_norm.rs +++ b/extra/src/exp_unit_norm.rs @@ -63,18 +63,20 @@ fn ser_eun( ) -> TractResult>> { let input = ast.mapping[&node.inputs[0]].clone(); let state = ast.mapping[&node.inputs[1]].clone(); - Ok(Some(invocation( - "tract_extra_exp_unit_norm", - &[input, state], - &[ - ("axis", numeric(op.axis)), - ("alpha", numeric(op.alpha)), - ("epsilon", numeric(op.epsilon)), - ("stateless", logical(op.stateless)), - ("complex", logical(op.complex)), - ("skip", numeric(op.skip)), - ], - ))) + let mut attributes = vec![ + ("axis", numeric(op.axis)), + ("alpha", numeric(op.alpha)), + ("stateless", logical(op.stateless)), + ("skip", numeric(op.skip)), + ]; + if op.mean { + attributes.push(("scaling_factor", numeric(op.scaling_factor))); + Ok(Some(invocation("tract_extra_exp_mean_norm", &[input, state], &attributes))) + } else { + attributes.push(("epsilon", numeric(op.epsilon))); + attributes.push(("complex", numeric(op.complex))); + Ok(Some(invocation("tract_extra_exp_unit_norm", &[input, state], &attributes))) + } } #[derive(Clone, Debug, PartialEq)]