Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

setup a mask for deconv pulsing mode #1233

Merged
merged 5 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
686 changes: 371 additions & 315 deletions core/src/ops/cnn/deconv/deconv_sum.rs

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pulse-opl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use tract_nnef::internal::*;
mod concat;
mod deconv_delay;
mod delay;
mod mask;
mod pad;
mod slice;

Expand All @@ -17,6 +18,7 @@ pub mod prelude {
pub mod ops {
pub use super::deconv_delay::DeconvDelay;
pub use super::delay::{ Delay, DelayState };
pub use super::mask::PulseMask;
pub use super::pad::PulsePad;
pub use super::slice::PulsedAxisSlice;
}
Expand All @@ -41,6 +43,7 @@ pub fn tract_nnef_registry() -> Registry {
let mut reg = Registry::new("tract_pulse");
reg.aliases.push("pulse".into());
delay::register(&mut reg);
mask::register(&mut reg);
pad::register(&mut reg);
reg
}
149 changes: 149 additions & 0 deletions pulse-opl/src/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use tract_nnef::internal::*;
use tract_nnef::ser::tdim;
use tract_nnef::tract_core::trivial_op_state_freeeze;

pub fn register(registry: &mut Registry) {
registry.register_primitive(
"tract_pulse_mask",
&[
TypeName::Scalar.tensor().named("input"),
TypeName::Integer.named("axis"),
TypeName::Integer.named("begin"),
TypeName::Integer.named("end"),
TypeName::Scalar.named("value"),
],
&[("output", TypeName::Scalar.tensor())],
deser,
);
registry.register_dumper(TypeId::of::<PulseMask>(), ser)
}

fn ser(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
let op = node.op_as::<PulseMask>().unwrap();
let wire = ast.mapping[&node.inputs[0]].clone();
let params = vec![
("axis", numeric(op.axis)),
("begin", numeric(op.begin)),
("end", tdim(&op.end)),
("value", numeric(op.value.cast_to_scalar::<f32>())),
];
Ok(Some(invocation("tract_pulse_mask", &[wire], &params)))
}

fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let wire = invocation.named_arg_as(builder, "input")?;
let axis = invocation.named_arg_as(builder, "axis")?;
let begin = invocation.named_arg_as(builder, "begin")?;
let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
let end = builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "end"))?;
let op = PulseMask { axis, begin, end, value };
builder.wire(op, &[wire])
}

#[derive(Debug, Clone, Default, Hash)]
struct PulseMaskOpState {
current_pos: usize,
}

impl OpState for PulseMaskOpState {
fn eval(
&mut self,
session: &mut SessionState,
op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs).into_tensor();
let op = op.downcast_ref::<PulseMask>().ok_or_else(|| format_err!("Wrong Op type"))?;
let tensor = self.pad(session, op, input)?;
Ok(tvec!(tensor.into_tvalue()))
}
}

impl PulseMaskOpState {
fn pad(
&mut self,
session: &SessionState,
op: &PulseMask,
mut input: Tensor,
) -> TractResult<Tensor> {
let pulse = input.shape()[op.axis];
let pulse_begin = self.current_pos;
let pulse_end = self.current_pos + pulse;
self.current_pos += pulse;
let end = op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(std::usize::MAX);

// pulse is entirely in valid input, just forward
if pulse_begin >= op.begin && pulse_end <= end {
return Ok(input);
}

if pulse_begin < op.begin {
let fill_up_to = (op.begin - pulse_begin).min(pulse);
unsafe {
dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
&mut input,
&op.value,
op.axis,
0..fill_up_to
))
};
}
if pulse_end > end {
let fill_from = pulse - (pulse_end - end).min(pulse);
unsafe {
dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
&mut input,
&op.value,
op.axis,
fill_from..pulse
))
}
}

Ok(input)
}
}

#[derive(Debug, Clone, Default, Hash)]
pub struct PulseMask {
pub axis: usize,
pub begin: usize,
pub end: TDim,
pub value: Tensor,
}

impl Op for PulseMask {
fn name(&self) -> Cow<str> {
"PulseMask".into()
}

fn info(&self) -> TractResult<Vec<String>> {
Ok(vec![format!("axis: {} begin: {} end: {}", self.axis, self.begin, self.end,)])
}

op_as_typed_op!();
}

impl EvalOp for PulseMask {
fn is_stateless(&self) -> bool {
false
}

fn state(
&self,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::<PulseMaskOpState>::default()))
}
}

impl TypedOp for PulseMask {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(inputs[0].clone()))
}

as_op!();
}

trivial_op_state_freeeze!(PulseMaskOpState);
54 changes: 27 additions & 27 deletions pulse-opl/src/pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractRe
builder.wire(op, &[wire])
}

pub(crate) unsafe fn fill_slice_constant<T: Datum + Copy>(
data: &mut Tensor,
constant: &Tensor,
axis: usize,
range: std::ops::Range<usize>,
) {
let c = constant.to_scalar_unchecked::<T>();
data.to_array_view_mut_unchecked::<T>().slice_axis_mut(Axis(axis), range.into()).fill(*c);
}

unsafe fn fill_slice_with_frame<T: Datum + Copy>(
data: &mut Tensor,
axis: usize,
valid: &Tensor,
range: std::ops::Range<usize>,
) {
let mut data = data.to_array_view_mut_unchecked::<T>();
let valid = valid.to_array_view_unchecked::<T>();
for i in range {
data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid);
}
}

#[derive(Debug, Clone, Default, Hash)]
struct PulsePadOpState {
current_pos: usize,
Expand Down Expand Up @@ -91,29 +114,6 @@ impl PulsePadOpState {
Some(data.index_axis(Axis(op.axis), frame).to_owned().into_tensor());
}

unsafe fn fill_slice_constant<T: Datum + Copy>(
data: &mut Tensor,
constant: &Tensor,
axis: usize,
range: std::ops::Range<usize>,
) {
let c = constant.to_scalar_unchecked::<T>();
data.to_array_view_mut_unchecked::<T>().slice_axis_mut(Axis(axis), range.into()).fill(*c);
}

unsafe fn fill_slice_with_frame<T: Datum + Copy>(
data: &mut Tensor,
axis: usize,
valid: &Tensor,
range: std::ops::Range<usize>,
) {
let mut data = data.to_array_view_mut_unchecked::<T>();
let valid = valid.to_array_view_unchecked::<T>();
for i in range {
data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid);
}
}

fn pad(
&mut self,
session: &SessionState,
Expand Down Expand Up @@ -156,7 +156,7 @@ impl PulsePadOpState {
let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
match &op.mode {
PadMode::Constant(c) => unsafe {
dispatch_copy_by_size!(Self::fill_slice_constant(input.datum_type())(
dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
&mut input,
c,
op.axis,
Expand All @@ -166,7 +166,7 @@ impl PulsePadOpState {
PadMode::Edge => {
let frame = input.slice(op.axis, fill_up_to, fill_up_to + 1)?;
unsafe {
dispatch_copy_by_size!(Self::fill_slice_with_frame(input.datum_type())(
dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
&mut input,
op.axis,
&frame,
Expand All @@ -181,7 +181,7 @@ impl PulsePadOpState {
let fill_from = pulse - (pulse_end - end_input).min(pulse);
match &op.mode {
PadMode::Constant(c) => unsafe {
dispatch_copy_by_size!(Self::fill_slice_constant(input.datum_type())(
dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
&mut input,
c,
op.axis,
Expand All @@ -191,7 +191,7 @@ impl PulsePadOpState {
PadMode::Edge => {
let last_frame = self.last_valid_frame.as_ref().unwrap();
unsafe {
dispatch_copy_by_size!(Self::fill_slice_with_frame(input.datum_type())(
dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
&mut input,
op.axis,
last_frame,
Expand Down
11 changes: 11 additions & 0 deletions pulse/src/ops/array/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use crate::internal::*;
use tract_pulse_opl::ops::PulseMask;

impl PulsedOp for PulseMask {
fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
Ok(inputs.iter().cloned().cloned().collect())
}

as_op!();
pulsed_op_to_typed_op!();
}
1 change: 1 addition & 0 deletions pulse/src/ops/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::internal::*;

mod broadcast;
mod concat;
mod mask;
mod pad;
mod slice;

Expand Down
17 changes: 12 additions & 5 deletions pulse/src/ops/cnn/deconv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use tract_core::num_traits::Zero;
use tract_core::ops::cnn::DeconvUnary;
use tract_core::ops::cnn::PaddingSpec;
use tract_pulse_opl::ops::DeconvDelay;
use tract_pulse_opl::ops::PulseMask;

register_all!(DeconvUnary: pulsify);

Expand Down Expand Up @@ -37,9 +38,15 @@ fn pulsify(
let mut pulse_op = op.clone();
pulse_op.adjustments[geo_axis] = stride - 1;
pulse_op.pool_spec.padding = PaddingSpec::Valid;
let deconv =
target.wire_node(format!("{}.deconv", node.name), pulse_op, &[mapping[&node.inputs[0]]])?
[0];
let mut wire = tvec![mapping[&node.inputs[0]]];
let mask = PulseMask {
axis: stream.axis,
begin: stream.delay,
end: stream.dim.clone() + stream.delay,
value: Tensor::zero_scalar_dt(fact.datum_type)?,
};
wire = target.wire_node(format!("{}.mask", node.name), mask, &wire)?;
wire = target.wire_node(format!("{}.deconv", node.name), pulse_op, &wire)?;
let overlap = overlap(stream.axis, op);
let deconv_input_dim = (stream.dim.clone() - 1) * stride + 1;
let output_shape = tract_core::ops::cnn::deconv::output_shape(
Expand All @@ -56,7 +63,7 @@ fn pulsify(
&op.pool_spec.strides(),
&op.adjustments,
)?;
let mut wire = target.wire_node(
wire = target.wire_node(
&node.name,
DeconvDelay {
axis: stream.axis,
Expand All @@ -67,7 +74,7 @@ fn pulsify(
pulse: pulse.to_owned(),
deconv_output_dim: output_shape[stream.axis].clone(),
},
&[deconv],
&wire,
)?;

for (geo_axis, padding) in paddings.iter().enumerate() {
Expand Down
Empty file added pulse/src/ops/mask.rs
Empty file.
1 change: 1 addition & 0 deletions pulse/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod cnn;
pub mod delay;
pub mod downsample;
pub mod dummy;
pub mod mask;
pub mod scan;
pub mod slice;
pub mod source;
Expand Down