Skip to content

Commit

Permalink
DeconvUnary->Deconv
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 14, 2023
1 parent 2c6c232 commit 4ca30a5
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::ops::cnn::wire_reshape_bias_for_bin;
use crate::ops::einsum::EinSum;

#[derive(Clone, Debug, new, Hash)]
pub struct DeconvUnary {
pub struct Deconv {
pub pool_spec: PoolSpec,
pub kernel_format: KernelFormat,
pub adjustments: TVec<usize>,
pub group: usize,
}

impl DeconvUnary {
impl Deconv {
fn wire_with_deconv_sum(
&self,
name: &str,
Expand Down Expand Up @@ -123,9 +123,9 @@ impl DeconvUnary {
}
}

impl Op for DeconvUnary {
impl Op for Deconv {
fn name(&self) -> Cow<str> {
"DeconvUnary".into()
"Deconv".into()
}

fn info(&self) -> TractResult<Vec<String>> {
Expand All @@ -135,7 +135,7 @@ impl Op for DeconvUnary {
op_as_typed_op!();
}

impl EvalOp for DeconvUnary {
impl EvalOp for Deconv {
fn is_stateless(&self) -> bool {
true
}
Expand All @@ -154,7 +154,7 @@ impl EvalOp for DeconvUnary {
}
}

impl TypedOp for DeconvUnary {
impl TypedOp for Deconv {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs.len() == 3);
let x_fact = inputs[0];
Expand Down
5 changes: 3 additions & 2 deletions core/src/ops/cnn/deconv/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::internal::*;
use crate::ops::cnn::{PaddingSpec, PoolSpec};

#[allow(clippy::module_inception)]
mod deconv;
mod deconv_sum;
mod unary;

pub use unary::DeconvUnary;
pub use deconv::Deconv;

pub fn output_shape<D: DimLike>(
pool_spec: &PoolSpec,
Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/cnn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod pools;
mod sumpool;

pub use self::conv::{Conv, KernelFormat};
pub use self::deconv::DeconvUnary;
pub use self::deconv::Deconv;
pub use self::maxpool::MaxPool;
pub use self::padding::PaddingSpec;
pub use self::patch_axis::PatchAxis;
Expand Down Expand Up @@ -95,7 +95,7 @@ pub fn rewrite_deconv_with_n_axis(
model: &TypedModel,
node: &TypedNode,
name: &str,
deconv: &DeconvUnary,
deconv: &Deconv,
) -> TractResult<Option<TypedModelPatch>> {
if !deconv.pool_spec.data_format.has_n() {
let mut new = deconv.clone();
Expand Down
4 changes: 2 additions & 2 deletions harness/core-proptest-pulse/src/deconv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct DeconvOp {

impl DeconvOp {
fn chain(&self, name: &str, model: &mut TypedModel, after: OutletId) -> OutletId {
let deconv = tract_core::ops::cnn::DeconvUnary {
let deconv = tract_core::ops::cnn::Deconv {
pool_spec: PoolSpec {
data_format: DataFormat::NCHW,
kernel_shape: tvec!(self.ker.shape()[2]),
Expand Down Expand Up @@ -262,7 +262,7 @@ fn deconv2d() {
let a = model.add_source("a", f32::fact(dims!(1, 2, s, 8))).unwrap();
let mut kernel = Tensor::zero::<f32>(&[2, 2, 1, 3]).unwrap();
kernel.as_slice_mut::<f32>().unwrap().iter_mut().enumerate().for_each(|(ix, x)| *x = ix as f32);
let deconv = tract_core::ops::cnn::DeconvUnary {
let deconv = tract_core::ops::cnn::Deconv {
pool_spec: PoolSpec {
data_format: DataFormat::NCHW,
kernel_shape: tvec!(1, 3),
Expand Down
4 changes: 2 additions & 2 deletions nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::ast::*;
use crate::deser::Value;
use ops::cnn::deconv::DeconvUnary;
use ops::cnn::deconv::Deconv;
use ops::cnn::{Conv, KernelFormat};
use tract_core::internal::*;
use tract_core::ops::array::PadMode;
Expand Down Expand Up @@ -367,7 +367,7 @@ pub fn conv_or_deconv(
} else {
tvec!(0; pool_spec.rank())
};
Box::new(DeconvUnary::new(pool_spec, KernelFormat::OIHW, adjustments, group))
Box::new(Deconv::new(pool_spec, KernelFormat::OIHW, adjustments, group))
} else {
if let Some(odt) = &output_dt {
for dt in &[&input_fact.datum_type, &kernel_fact.datum_type, odt] {
Expand Down
8 changes: 4 additions & 4 deletions nnef/src/ops/nnef/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tract_core::num_traits::Zero;
use tract_core::ops;
use tract_core::ops::cast::cast;
use tract_core::ops::cnn::Conv;
use tract_core::ops::cnn::DeconvUnary;
use tract_core::ops::cnn::Deconv;
use tract_core::ops::cnn::KernelFormat;
use tract_core::ops::cnn::PoolSpec;
use tract_core::ops::einsum::BasicMatMul;
Expand Down Expand Up @@ -255,7 +255,7 @@ pub fn conv(
pub fn deconv(
ast: &mut IntoAst,
node: &TypedNode,
op: &ops::cnn::deconv::DeconvUnary,
op: &ops::cnn::deconv::Deconv,
) -> TractResult<Option<Arc<RValue>>> {
conv_or_deconv(ast, node, &op.pool_spec, op.group, true, Some(&op.adjustments))
}
Expand Down Expand Up @@ -516,15 +516,15 @@ pub fn rewrite_kernel_deconv_in_oihw(
model: &TypedModel,
node: &TypedNode,
name: &str,
conv: &DeconvUnary,
conv: &Deconv,
) -> TractResult<Option<TypedModelPatch>> {
rewrite_kernel_in_oihw(
model,
node,
name,
conv.kernel_format,
conv.group,
Box::new(DeconvUnary { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
)
}

Expand Down
4 changes: 2 additions & 2 deletions onnx/src/ops/nn/conv_transpose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ impl Expansion for ConvTranspose {
&x_shape.as_concrete().context("expects concrete dim for deconv")?[2..],
output_shape,
)?;
tract_core::ops::cnn::DeconvUnary::new(
tract_core::ops::cnn::Deconv::new(
pool_spec,
KernelFormat::OIHW,
adjustments,
self.group,
)
} else {
tract_core::ops::cnn::DeconvUnary::new(
tract_core::ops::cnn::Deconv::new(
pool_spec,
KernelFormat::OIHW,
self.adjustments.clone().unwrap_or_else(|| tvec!(0; kernel_shape.len() - 2)),
Expand Down
10 changes: 5 additions & 5 deletions pulse/src/ops/cnn/deconv.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::internal::*;
use tract_core::num_traits::Zero;
use tract_core::ops::cnn::DeconvUnary;
use tract_core::ops::cnn::Deconv;
use tract_core::ops::cnn::PaddingSpec;
use tract_pulse_opl::ops::DeconvDelay;
use tract_pulse_opl::ops::PulseMask;

register_all!(DeconvUnary: pulsify);
register_all!(Deconv: pulsify);

fn pulsify(
op: &DeconvUnary,
op: &Deconv,
source: &TypedModel,
node: &TypedNode,
target: &mut PulsedModel,
Expand Down Expand Up @@ -97,12 +97,12 @@ fn pulsify(
Ok(Some(wire))
}

fn overlap(pulse_axis: usize, op: &DeconvUnary) -> usize {
fn overlap(pulse_axis: usize, op: &Deconv) -> usize {
let geo_axis = pulse_axis - op.pool_spec.data_format.h_axis();
(op.pool_spec.kernel_shape[geo_axis] - 1) * op.pool_spec.dilation(geo_axis)
}

impl PulsedOp for DeconvUnary {
impl PulsedOp for Deconv {
fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
let mut fact = inputs[0].clone();
let stream = fact.stream.as_mut().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions test-rt/suite-unit/src/deconv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl Arbitrary for DeconvProblem {
}

impl DeconvProblem {
fn as_op(&self) -> TractResult<DeconvUnary> {
fn as_op(&self) -> TractResult<Deconv> {
let pool_spec = PoolSpec::new(
self.data_format,
self.kernel_format.spatial_shape(self.kernel.shape()).into(),
Expand All @@ -146,7 +146,7 @@ impl DeconvProblem {
self.kernel_format.output_channels(self.kernel.shape(), self.group).into_owned(),
);
let op =
DeconvUnary::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group);
Deconv::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group);
Ok(op)
}

Expand Down

0 comments on commit 4ca30a5

Please sign in to comment.