Eun 2 #1273
Merged · 3 commits · Nov 28, 2023
95 changes: 75 additions & 20 deletions extra/src/exp_unit_norm.rs
@@ -15,13 +15,29 @@ pub fn register(registry: &mut Registry) {
TypeName::Scalar.named("alpha"),
TypeName::Integer.named("skip").default(0),
TypeName::Logical.named("stateless").default(false),
TypeName::Logical.named("complex").default(false),
TypeName::Scalar.named("epsilon").default(1e-14f32),
],
&[("output", TypeName::Scalar.tensor())],
de_eun,
);
registry.register_dumper(ser_eun);

registry.register_primitive(
"tract_extra_exp_mean_norm",
&[
TypeName::Scalar.tensor().named("input"),
TypeName::Scalar.tensor().named("state"),
TypeName::Integer.named("axis"),
TypeName::Scalar.named("alpha"),
TypeName::Integer.named("skip").default(0),
TypeName::Logical.named("stateless").default(false),
TypeName::Scalar.named("scaling_factor"),
],
&[("output", TypeName::Scalar.tensor())],
de_eun,
);

OpPulsifier::register::<ExpUnitNorm>(pulsify).unwrap();
}
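For orientation, here is a minimal sketch — not part of the diff, with made-up parameter values — of the op configurations these two primitives deserialize to. Both map onto the same `ExpUnitNorm` struct (defined further down); only the `mean` flag distinguishes them, as `de_eun` below derives it from the invocation identifier.

```rust
// Illustration only: hypothetical configurations, all values chosen arbitrarily.
fn example_ops() -> (ExpUnitNorm, ExpUnitNorm) {
    let unit_norm = ExpUnitNorm {
        alpha: 0.99,          // EMA smoothing factor
        axis: 1,              // time axis iterated by eval()
        epsilon: 1e-14,       // floor applied to the norm before averaging
        stateless: false,     // true => the state input is reloaded on every call
        skip: 0,              // leading time steps that leave the running state untouched
        complex: false,       // true => trailing axis of size 2 holds [re, im] pairs
        mean: false,          // tract_extra_exp_unit_norm
        scaling_factor: 1.0,  // only read by the mean variant
    };
    // tract_extra_exp_mean_norm: same fields, mean = true, plus a scaling factor.
    let mean_norm = ExpUnitNorm { mean: true, scaling_factor: 32768.0, ..unit_norm.clone() };
    (unit_norm, mean_norm)
}
```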

@@ -30,10 +46,13 @@ fn de_eun(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractR
let state = invocation.named_arg_as(builder, "state")?;
let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
let alpha = invocation.named_arg_as(builder, "alpha")?;
let epsilon = invocation.get_named_arg_as(builder, "epsilon")?.unwrap_or(1e-14);
let stateless = invocation.named_arg_as::<bool>(builder, "stateless")?;
let complex = invocation.get_named_arg_as::<bool>(builder, "complex")?.unwrap_or(false);
let skip = invocation.named_arg_as::<i64>(builder, "skip")? as usize;
let scaling_factor = invocation.get_named_arg_as(builder, "scaling_factor")?.unwrap_or(1.0);
let mean = invocation.invocation.id == Identifier::from("tract_extra_exp_mean_norm");
let op = ExpUnitNorm { alpha, axis, epsilon, stateless, skip, complex, scaling_factor, mean };
builder.wire(op, &[wire, state])
}

@@ -44,17 +63,20 @@ fn ser_eun(
) -> TractResult<Option<Arc<RValue>>> {
let input = ast.mapping[&node.inputs[0]].clone();
let state = ast.mapping[&node.inputs[1]].clone();
let mut attributes = vec![
("axis", numeric(op.axis)),
("alpha", numeric(op.alpha)),
("stateless", logical(op.stateless)),
("skip", numeric(op.skip)),
];
if op.mean {
attributes.push(("scaling_factor", numeric(op.scaling_factor)));
Ok(Some(invocation("tract_extra_exp_mean_norm", &[input, state], &attributes)))
} else {
attributes.push(("epsilon", numeric(op.epsilon)));
attributes.push(("complex", numeric(op.complex)));
Ok(Some(invocation("tract_extra_exp_unit_norm", &[input, state], &attributes)))
}
}

#[derive(Clone, Debug, PartialEq)]
@@ -64,6 +86,9 @@ pub struct ExpUnitNorm {
pub axis: usize,
pub skip: usize,
pub stateless: bool,
pub complex: bool,
pub mean: bool,
pub scaling_factor: f32,
}

#[derive(Clone, Debug, PartialEq, Default)]
@@ -101,20 +126,46 @@ impl EvalOp for ExpUnitNorm {

impl ExpUnitNormState {
fn eval(&mut self, op: &ExpUnitNorm, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
use tract_ndarray::Axis;
let (input, state0) = args_2!(inputs);
let mut input = input.into_tensor();
let mut x_view = input.to_array_view_mut::<f32>()?;
if self.hidden.is_none() || op.stateless {
self.hidden = Some(state0.into_tensor());
}
if op.complex {
ensure!(x_view.shape()[x_view.ndim() - 1] == 2);
}
let mut state = self.hidden.as_mut().unwrap().to_array_view_mut::<f32>()?;
for mut time_slice in x_view.axis_iter_mut(Axis(op.axis)) {
if self.index >= op.skip {
if op.mean {
state.zip_mut_with(&time_slice, |s: &mut f32, x: &f32| {
*s = x * (1f32 - op.alpha) + *s * op.alpha;
});
} else {
// unit norms
let normed = if op.complex {
time_slice
.mapv(|x| x * x)
.sum_axis(Axis(time_slice.ndim() - 1))
.mapv(|x| x.sqrt())
} else {
time_slice.mapv(|x| x.abs())
};
state.zip_mut_with(&normed, |s: &mut f32, x: &f32| {
*s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha;
});
}
}
if op.mean {
time_slice.zip_mut_with(&state, |x, s| *x = (*x - s) / op.scaling_factor);
} else if op.complex {
let state_view = state.view().insert_axis(Axis(state.ndim()));
time_slice.zip_mut_with(&state_view, |x, s| *x /= s.sqrt());
} else {
time_slice.zip_mut_with(&state, |x, s| *x /= s.sqrt());
}
self.index += 1;
}
Ok(tvec!(input.into_tvalue()))
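Spelling out the loop above as equations (x_t is the slice selected along `axis`, s the running state, α = `alpha`, ε = `epsilon`; all operations are elementwise over the non-time axes): the state update is skipped for the first `skip` time steps, while the output rewrite is applied at every step.

$$
\begin{aligned}
\textbf{unit norm: } & n_t = |x_t| \quad\text{(or } \sqrt{x_{t,\mathrm{re}}^2 + x_{t,\mathrm{im}}^2}\text{ per complex pair)},\\
& s \leftarrow (1-\alpha)\,\max(n_t,\varepsilon) + \alpha\,s, \qquad x_t \leftarrow x_t / \sqrt{s},\\
\textbf{mean norm: } & s \leftarrow (1-\alpha)\,x_t + \alpha\,s, \qquad x_t \leftarrow (x_t - s)\,/\,\text{scaling\_factor}.
\end{aligned}
$$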
@@ -136,7 +187,11 @@ impl OpState for ExpUnitNormState {
impl TypedOp for ExpUnitNorm {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut state_shape = inputs[0].shape.clone();
state_shape.remove_axis(self.axis)?;
if self.complex {
ensure!(inputs[0].shape[inputs[0].rank() - 1] == 2.to_dim());
state_shape.remove_axis(state_shape.rank() - 1)?;
}
ensure!(inputs[1].without_value() == inputs[0].datum_type.fact(state_shape));
Ok(tvec!(inputs[0].without_value()))
}