From eff4493cbf51bc0516d3078bc3fa0f50dcd85805 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 28 Nov 2023 16:45:17 +0100 Subject: [PATCH] mean serializer --- extra/src/exp_unit_norm.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/extra/src/exp_unit_norm.rs b/extra/src/exp_unit_norm.rs index 33d9d73baa..7398a607e1 100644 --- a/extra/src/exp_unit_norm.rs +++ b/extra/src/exp_unit_norm.rs @@ -63,18 +63,20 @@ fn ser_eun( ) -> TractResult>> { let input = ast.mapping[&node.inputs[0]].clone(); let state = ast.mapping[&node.inputs[1]].clone(); - Ok(Some(invocation( - "tract_extra_exp_unit_norm", - &[input, state], - &[ - ("axis", numeric(op.axis)), - ("alpha", numeric(op.alpha)), - ("epsilon", numeric(op.epsilon)), - ("stateless", logical(op.stateless)), - ("complex", logical(op.complex)), - ("skip", numeric(op.skip)), - ], - ))) + let mut attributes = vec![ + ("axis", numeric(op.axis)), + ("alpha", numeric(op.alpha)), + ("stateless", logical(op.stateless)), + ("skip", numeric(op.skip)), + ]; + if op.mean { + attributes.push(("scaling_factor", numeric(op.scaling_factor))); + Ok(Some(invocation("tract_extra_exp_mean_norm", &[input, state], &attributes))) + } else { + attributes.push(("epsilon", numeric(op.epsilon))); + attributes.push(("complex", numeric(op.complex))); + Ok(Some(invocation("tract_extra_exp_unit_norm", &[input, state], &attributes))) + } } #[derive(Clone, Debug, PartialEq)]