Skip to content

Commit

Permalink
mean unit
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 28, 2023
1 parent 8c2d5a5 commit ded2c8e
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions extra/src/exp_unit_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ pub fn register(registry: &mut Registry) {
);
registry.register_dumper(ser_eun);

registry.register_primitive(
"tract_extra_exp_mean_norm",
&[
TypeName::Scalar.tensor().named("input"),
TypeName::Scalar.tensor().named("state"),
TypeName::Integer.named("axis"),
TypeName::Scalar.named("alpha"),
TypeName::Integer.named("skip").default(0),
TypeName::Logical.named("stateless").default(false),
TypeName::Scalar.named("scaling_factor"),
],
&[("output", TypeName::Scalar.tensor())],
de_eun,
);

OpPulsifier::register::<ExpUnitNorm>(pulsify).unwrap();
}

Expand All @@ -31,11 +46,13 @@ fn de_eun(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractR
let state = invocation.named_arg_as(builder, "state")?;
let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
let alpha = invocation.named_arg_as(builder, "alpha")?;
let epsilon = invocation.named_arg_as(builder, "epsilon")?;
let epsilon = invocation.get_named_arg_as(builder, "epsilon")?.unwrap_or(1e-14);
let stateless = invocation.named_arg_as::<bool>(builder, "stateless")?;
let complex = invocation.named_arg_as::<bool>(builder, "complex")?;
let complex = invocation.get_named_arg_as::<bool>(builder, "complex")?.unwrap_or(false);
let skip = invocation.named_arg_as::<i64>(builder, "skip")? as usize;
let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex };
let scaling_factor = invocation.get_named_arg_as(builder, "scaling_factor")?.unwrap_or(1.0);
let mean = invocation.invocation.id == Identifier::from("tract_extra_exp_mean_norm");
let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex, scaling_factor, mean };
builder.wire(op, &[wire, state])
}

Expand Down Expand Up @@ -68,6 +85,8 @@ pub struct ExpUnitNorm {
pub skip: usize,
pub stateless: bool,
pub complex: bool,
pub mean: bool,
pub scaling_factor: f32,
}

#[derive(Clone, Debug, PartialEq, Default)]
Expand Down Expand Up @@ -118,16 +137,28 @@ impl ExpUnitNormState {
let mut state = self.hidden.as_mut().unwrap().to_array_view_mut::<f32>()?;
for mut time_slice in x_view.axis_iter_mut(Axis(op.axis)) {
if self.index >= op.skip {
let normed = if op.complex {
time_slice.mapv(|x| x * x).sum_axis(Axis(time_slice.ndim() - 1)).mapv(|x| x.sqrt())
if op.mean {
state.zip_mut_with(&time_slice, |s: &mut f32, x: &f32| {
*s = x * (1f32 - op.alpha) + *s * op.alpha;
});
} else {
time_slice.mapv(|x| x.abs())
};
state.zip_mut_with(&normed, |s: &mut f32, x: &f32| {
*s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha;
});
// unit norms
let normed = if op.complex {
time_slice
.mapv(|x| x * x)
.sum_axis(Axis(time_slice.ndim() - 1))
.mapv(|x| x.sqrt())
} else {
time_slice.mapv(|x| x.abs())
};
state.zip_mut_with(&normed, |s: &mut f32, x: &f32| {
*s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha;
});
}
}
if op.complex {
if op.mean {
time_slice.zip_mut_with(&state, |x, s| *x = (*x - s) / op.scaling_factor);
} else if op.complex {
let state_view = state.view().insert_axis(Axis(state.ndim()));
time_slice.zip_mut_with(&state_view, |x, s| *x /= s.sqrt());
} else {
Expand Down

0 comments on commit ded2c8e

Please sign in to comment.