Skip to content

Commit

Permalink
wip de-unarizing conv
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 5, 2023
1 parent 271be98 commit 5436851
Show file tree
Hide file tree
Showing 20 changed files with 636 additions and 548 deletions.
59 changes: 47 additions & 12 deletions core/src/axes/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,22 +352,54 @@ impl AxesMapping {
AxesMapping::new(self.input_count, self.output_count, axes)
}

pub fn remove_input_axis(&self, slot: usize, position: usize) -> TractResult<AxesMapping> {
pub fn remove_axis_occurency(&self, slot: InOut, position: usize) -> TractResult<AxesMapping> {
let axis = self.axis((slot, position))?;
if axis.inputs.iter().map(|i| i.len()).sum::<usize>()
+ axis.outputs.iter().map(|i| i.len()).sum::<usize>()
== 1
{
return self.remove_axis(axis.repr);
}
let mut axes = self.axes.clone();
for axis in &mut axes {
axis.inputs[slot].retain(|pos| *pos != position);
axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
match slot {
InOut::In(slot) => {
for axis in &mut axes {
axis.inputs[slot].retain(|pos| *pos != position);
axis.inputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);
}
}
InOut::Out(slot) => {
for axis in &mut axes {
axis.outputs[slot].retain(|pos| *pos != position);
axis.outputs[slot]
.iter_mut()
.for_each(|pos| *pos -= (*pos > position) as usize);
}
}
}
AxesMapping::new(self.input_count, self.output_count, axes)
}

pub fn remove_output_axis(&self, slot: usize, position: usize) -> TractResult<AxesMapping> {
let mut axes = self.axes.clone();
for axis in &mut axes {
axis.outputs[slot].retain(|pos| *pos != position);
axis.outputs[slot].iter_mut().for_each(|pos| *pos -= (*pos > position) as usize);

pub fn remove_slot(&self, slot: InOut) -> TractResult<AxesMapping> {
let mut axes = self.clone();
while axes.rank(slot) > 0 {
axes = axes.remove_axis_occurency(slot, 0)?
}
AxesMapping::new(self.input_count, self.output_count, axes)
match slot {
InOut::In(slot) => {
for axis in &mut axes.axes {
axis.inputs.remove(slot);
}
axes.input_count -= 1;
}
InOut::Out(slot) => {
for axis in &mut axes.axes {
axis.outputs.remove(slot);
}
axes.output_count -= 1;
}
}
axes.sorted().check()
}

pub fn with_extra_input(self, slot: usize) -> TractResult<AxesMapping> {
Expand Down Expand Up @@ -822,7 +854,10 @@ mod test {

#[test]
fn test_translate_to_ops_add_0() {
assert_eq!(m("bacmn->bmn").translate_to_axis_ops().unwrap(), vec!(AxisOp::Rm(2), AxisOp::Rm(1)));
assert_eq!(
m("bacmn->bmn").translate_to_axis_ops().unwrap(),
vec!(AxisOp::Rm(2), AxisOp::Rm(1))
);
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions core/src/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ impl<T1: Datum + Float, T2: Datum + Float>
Box::new(TypedSource::new(fact_float_precision_conversion::<T1, T2>(&source.fact)))
} else if let Some(konst) = node.op_as::<Const>() {
Box::new(Const(tensor_float_precision_conversion::<T1, T2>(&konst.0)))
/*
} else if let Some(op) = node.op_as::<ConvUnary>() {
Box::new(ConvUnary {
kernel: tensor_float_precision_conversion::<T1, T2>(&op.kernel),
bias: op.bias.as_ref().map(tensor_float_precision_conversion::<T1, T2>),
..op.clone()
})
*/
} else if let Some(op) = node.op_as::<Scan>() {
let body = FloatPrecisionTranslator::<T1, T2>::default().translate_model(&op.body)?;
Box::new(Scan { body, ..op.clone() })
Expand Down
Loading

0 comments on commit 5436851

Please sign in to comment.