Skip to content

Commit

Permalink
mean serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 28, 2023
1 parent ded2c8e commit eff4493
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions extra/src/exp_unit_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,20 @@ fn ser_eun(
) -> TractResult<Option<Arc<RValue>>> {
let input = ast.mapping[&node.inputs[0]].clone();
let state = ast.mapping[&node.inputs[1]].clone();
Ok(Some(invocation(
"tract_extra_exp_unit_norm",
&[input, state],
&[
("axis", numeric(op.axis)),
("alpha", numeric(op.alpha)),
("epsilon", numeric(op.epsilon)),
("stateless", logical(op.stateless)),
("complex", logical(op.complex)),
("skip", numeric(op.skip)),
],
)))
let mut attributes = vec![
("axis", numeric(op.axis)),
("alpha", numeric(op.alpha)),
("stateless", logical(op.stateless)),
("skip", numeric(op.skip)),
];
if op.mean {
attributes.push(("scaling_factor", numeric(op.scaling_factor)));
Ok(Some(invocation("tract_extra_exp_mean_norm", &[input, state], &attributes)))
} else {
attributes.push(("epsilon", numeric(op.epsilon)));
attributes.push(("complex", numeric(op.complex)));
Ok(Some(invocation("tract_extra_exp_unit_norm", &[input, state], &attributes)))
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down

0 comments on commit eff4493

Please sign in to comment.