Skip to content

Commit

Permalink
more usage of trivial paths
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Feb 19, 2024
1 parent 8a72a25 commit 6e0adf2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
6 changes: 3 additions & 3 deletions core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ fn lir_mat_mul_unary(
let c_fact = op.output_facts(&input_facts)?.remove(0);
let name = &node.name;
let geo = AddMatMulGeometry {
k: k.to_dim(),
a_storage: None,
b_storage: None,
k: k.clone(),
a_storage: k.as_i64().map(|k| unsafe { mmm.a_packed(a_dt.size_of(), k as usize) }),
b_storage: k.as_i64().map(|k| unsafe { mmm.b_packed(b_dt.size_of(), k as usize) }),
mmm: mmm.clone(),
c_to_a_axis_mapping: MapOutputAxisToInput(c_to_a_axis_mapping),
c_to_b_axis_mapping: MapOutputAxisToInput(c_to_b_axis_mapping),
Expand Down
24 changes: 8 additions & 16 deletions core/src/ops/matmul/lir_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ impl ProtoFusedSpec {
fs
}

pub fn has_symbols(&self) -> bool {
pub fn is_trivial(&self) -> bool {
match self {
ProtoFusedSpec::AddMatMul(geo, _, _) => geo.k.as_i64().is_none(),
_ => false,
ProtoFusedSpec::AddMatMul(geo, _, _) => {
geo.k.as_i64().is_some() && geo.a_storage.is_some() && geo.b_storage.is_some()
}
_ => true,
}
}

Expand All @@ -127,18 +129,8 @@ impl ProtoFusedSpec {
let b = inputs[*b].view();
unsafe {
let k = geo.k.as_i64().unwrap_unchecked() as usize;
// careful here. this work because a_packed() return a packer from which
// nothing is borrowed
let a = if let Some(sto) = &geo.a_storage {
sto.wrap(&a)
} else {
geo.mmm.a_packed(a.datum_type().size_of(), k).wrap(&a)
};
let b = if let Some(sto) = &geo.b_storage {
sto.wrap(&b)
} else {
geo.mmm.b_packed(b.datum_type().size_of(), k).wrap(&b)
};
let a = geo.a_storage.as_ref().unwrap().wrap(&a);
let b = geo.b_storage.as_ref().unwrap().wrap(&b);
FusedSpec::AddMatMul { k, a, b }
}
}
Expand Down Expand Up @@ -608,7 +600,7 @@ impl LirMatMulUnary {
.iter()
.enumerate()
.all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
&& self.micro_ops.iter().all(|o| !o.has_symbols())
&& self.micro_ops.iter().all(|o| o.is_trivial())
}

fn fuse_op(
Expand Down
4 changes: 2 additions & 2 deletions data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ impl Tensor {

#[inline]
pub fn view(&self) -> view::TensorView {
unsafe { view::TensorView::at_prefix_unchecked(self, &[]) }
unsafe { view::TensorView::view(self) }
}

#[inline]
Expand All @@ -1412,7 +1412,7 @@ impl Tensor {

#[inline]
pub fn view_mut(&mut self) -> view::TensorView {
unsafe { view::TensorView::at_prefix_unchecked(self, &[]) }
unsafe { view::TensorView::view(self) }
}

#[inline]
Expand Down
27 changes: 20 additions & 7 deletions data/src/tensor/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ impl<'a> TensorView<'a> {
}
}

/// Builds a `TensorView` over `tensor` that starts at the beginning of its data.
///
/// The returned view has a byte offset of zero and uses `Indexing::Prefix(0)`,
/// so no leading coordinate is fixed.
///
/// # Safety
///
/// NOTE(review): nothing in this body performs an unsafe operation by itself;
/// the `unsafe` marker presumably mirrors sibling constructors such as
/// `at_prefix_unchecked` — confirm what contract callers are expected to uphold.
#[inline]
pub unsafe fn view(tensor: &'a Tensor) -> TensorView<'a> {
TensorView { tensor, offset_bytes: 0, indexing: Indexing::Prefix(0), phantom: PhantomData }
}

#[inline]
pub fn datum_type(&self) -> DatumType {
self.tensor.datum_type()
Expand All @@ -99,7 +104,16 @@ impl<'a> TensorView<'a> {
#[inline]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.shape().iter().product::<usize>()
match &self.indexing {
Indexing::Prefix(i) => {
if *i == 0 {
self.tensor.len()
} else {
self.tensor.strides[*i - 1] as usize
}
}
Indexing::Custom { shape, .. } => shape.iter().product(),
}
}

#[inline]
Expand All @@ -110,7 +124,10 @@ impl<'a> TensorView<'a> {

#[inline]
pub fn rank(&self) -> usize {
self.shape().len()
match &self.indexing {
Indexing::Prefix(i) => &self.tensor.rank() - i,
Indexing::Custom { shape, .. } => shape.len(),
}
}

fn check_dt<D: Datum>(&self) -> anyhow::Result<()> {
Expand Down Expand Up @@ -205,11 +222,7 @@ impl<'a> TensorView<'a> {

#[inline]
fn offset_for_coords(&self, coords: &[usize]) -> isize {
self.strides()
.iter()
.zip(coords.as_ref())
.map(|(s, c)| *s * *c as isize)
.sum::<isize>()
self.strides().iter().zip(coords.as_ref()).map(|(s, c)| *s * *c as isize).sum::<isize>()
}

#[inline]
Expand Down

0 comments on commit 6e0adf2

Please sign in to comment.