Skip to content

Commit

Permalink
bias as wire
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 27, 2023
1 parent fd9aa10 commit 9a4f2db
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,32 +120,28 @@ impl ConvUnary {
// group,bias
fn bias_as_non_linear(
&self,
model: &mut TypedModel,
name: &str,
bias: OutletId,
c_group_axis: usize,
) -> TractResult<Option<(ProtoFusedSpec, Tensor)>> {
) -> TractResult<(ProtoFusedSpec, OutletId)> {
use tract_linalg::mmm::BinOp::Add;
if let Some(bias) = &self.bias {
if let Some(uni) = bias.as_uniform() {
if uni.is_zero()? {
Ok(None)
} else {
Ok(Some((ProtoFusedSpec::BinScalar(2, Add), uni)))
}
} else {
let bias = bias
.clone()
.into_tensor()
.into_shape(&[self.group, bias.len() / self.group])?;
Ok(Some((
ProtoFusedSpec::BinPerRow(
2,
Add,
MapOutputAxisToInput(tvec!((c_group_axis, 0))),
),
bias,
)))
}
let fact = model.outlet_fact(bias)?;
if fact.uniform.is_some() {
Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
} else {
Ok(None)
let bias = model.wire_node(
format!("{name}.reformat_bias"),
AxisOp::Reshape(
0,
tvec!(fact.shape.volume()),
tvec!(self.group.to_dim(), fact.shape.volume() / self.group),
),
&[bias],
)?[0];
let pfs =
ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0))));
Ok((pfs, bias))
}
}

Expand Down Expand Up @@ -502,8 +498,9 @@ impl ConvUnary {
let mut wires: TVec<OutletId> = wire.into();
wires.push(kernels);
let mut ops: Vec<ProtoFusedSpec> = vec![ProtoFusedSpec::AddMatMul(geo, 1, 0)];
if let Some((fused, tensor)) = self.bias_as_non_linear(c_m_axis - 1)? {
let bias = model.add_const(format!("{name}.bias"), tensor)?;
if let Some(bias) = &self.bias {
let bias = model.add_const(format!("{name}.bias"), bias.clone())?;
let (fused, bias) = self.bias_as_non_linear(model, name, bias, c_m_axis - 1)?;
wires.push(bias);
ops.push(fused);
}
Expand Down

0 comments on commit 9a4f2db

Please sign in to comment.