Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kali authored and mathieupoumeyrolsonos committed Feb 27, 2024
1 parent 4f08538 commit 9be466a
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 22 deletions.
77 changes: 57 additions & 20 deletions core/src/ops/einsum/as_blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ fn matmul_to_sgemm(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
node_name: &str,
_node_name: &str,
op: &BasicMatMul,
) -> TractResult<Option<TypedModelPatch>> {
if !op.transpose_a && !op.transpose_b && !op.transpose_c && op.quantize_output.is_none() {
if !op.transpose_a
&& !op.transpose_b
&& !op.transpose_c
&& op.quantize_output.is_none()
&& model.node_input_facts(node.id)?.iter().all(|f| f.datum_type == f32::datum_type())
{
TypedModelPatch::replace_single_op(model, node, &node.inputs, SGemm::default()).map(Some)
} else {
Ok(None)
Expand Down Expand Up @@ -88,30 +93,62 @@ impl EvalOp for SGemm {
let mut b_ptr = b_ptr;
let mut c_ptr = c_ptr;
for (axis, x) in prefix.as_array_view().iter().enumerate() {
if axis > silent_a_axis && a.shape()[axis - silent_a_axis] != 1 {
if axis >= silent_a_axis && a.shape()[axis - silent_a_axis] != 1 {
a_ptr = a_ptr.offset(*x as isize * a.strides()[axis - silent_a_axis]);
}
if axis > silent_b_axis && b.shape()[axis - silent_b_axis] != 1 {
if axis >= silent_b_axis && b.shape()[axis - silent_b_axis] != 1 {
b_ptr = b_ptr.offset(*x as isize * b.strides()[axis - silent_b_axis]);
}
c_ptr = c_ptr.offset(*x as isize * c.strides()[axis]);
}
cblas::sgemm(
cblas::Layout::RowMajor,
cblas::Transpose::None,
cblas::Transpose::None,
m as _,
n as _,
k as _,
1.0,
std::slice::from_raw_parts(a_ptr, m * k),
k as _,
std::slice::from_raw_parts(b_ptr, k * n),
n as _,
0.0,
std::slice::from_raw_parts_mut(c_ptr, m * n),
n as _,
)
if m == 1 {
cblas::sgemv(
cblas::Layout::RowMajor,
cblas::Transpose::Ordinary,
k as _,
n as _,
1.0,
std::slice::from_raw_parts(b_ptr, n * k),
n as _,
std::slice::from_raw_parts(a_ptr, k),
1,
0.0,
std::slice::from_raw_parts_mut(c_ptr, n),
1,
)
} else if n == 1 {
cblas::sgemv(
cblas::Layout::RowMajor,
cblas::Transpose::None,
m as _,
k as _,
1.0,
std::slice::from_raw_parts(a_ptr, m * k),
k as _,
std::slice::from_raw_parts(b_ptr, k),
1,
0.0,
std::slice::from_raw_parts_mut(c_ptr, m),
1,
)
} else {
cblas::sgemm(
cblas::Layout::RowMajor,
cblas::Transpose::None,
cblas::Transpose::None,
m as _,
n as _,
k as _,
1.0,
std::slice::from_raw_parts(a_ptr, m * k),
k as _,
std::slice::from_raw_parts(b_ptr, k * n),
n as _,
0.0,
std::slice::from_raw_parts_mut(c_ptr, m * n),
n as _,
)
}
}

Ok(tvec!(c.into_tvalue()))
Expand Down
6 changes: 6 additions & 0 deletions test-rt/test-onnx-core/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,10 @@ fn main() {
"unoptimized()",
"Approximation::Approximate",
);
suite.test_runtime(
"as_blas",
"suite_onnx::suite()",
"as_blas()",
"Approximation::Approximate",
);
}
28 changes: 26 additions & 2 deletions test-rt/test-onnx-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#![cfg(test)]
use tract_core::internal::*;

mod default {
use tract_core::internal::*;
use super::*;
pub fn default() -> &'static DefaultRuntime {
&DefaultRuntime
}
include!(concat!(env!("OUT_DIR"), "/tests/default.rs"));
}

mod unoptimized {
use tract_core::internal::*;
use super::*;

pub fn unoptimized() -> &'static UnoptimizedRuntime {
&UnoptimizedRuntime
Expand All @@ -29,3 +31,25 @@ mod unoptimized {
include!(concat!(env!("OUT_DIR"), "/tests/unoptimized.rs"));
}

mod as_blas {
use super::*;

pub fn as_blas() -> &'static AsBlasRuntime {
&AsBlasRuntime
}

#[derive(Debug)]
pub struct AsBlasRuntime;

impl Runtime for AsBlasRuntime {
fn name(&self) -> Cow<str> {
Cow::Borrowed("as_blas")
}
fn prepare(&self, mut model: TypedModel) -> TractResult<Box<dyn Runnable>> {
tract_core::transform::get_transformer("as-blas").unwrap().transform(&mut model)?;
Ok(Box::new(Arc::new(model.into_runnable()?)))
}
}

include!(concat!(env!("OUT_DIR"), "/tests/as_blas.rs"));
}
6 changes: 6 additions & 0 deletions test-rt/test-unit-core/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ fn main() {
"unoptimized()",
"Approximation::Approximate",
);
suite.test_runtime(
"as_blas",
"suite_unit::suite().unwrap()",
"as_blas()",
"Approximation::Approximate",
);
}
23 changes: 23 additions & 0 deletions test-rt/test-unit-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,26 @@ mod unoptimized {

include!(concat!(env!("OUT_DIR"), "/tests/unoptimized.rs"));
}

mod as_blas {
use super::*;

pub fn as_blas() -> &'static AsBlasRuntime {
&AsBlasRuntime
}

#[derive(Debug)]
pub struct AsBlasRuntime;

impl Runtime for AsBlasRuntime {
fn name(&self) -> Cow<str> {
Cow::Borrowed("as_blas")
}
fn prepare(&self, mut model: TypedModel) -> TractResult<Box<dyn Runnable>> {
tract_core::transform::get_transformer("as-blas").unwrap().transform(&mut model)?;
Ok(Box::new(Arc::new(model.into_runnable()?)))
}
}

include!(concat!(env!("OUT_DIR"), "/tests/as_blas.rs"));
}

0 comments on commit 9be466a

Please sign in to comment.