Skip to content

Commit

Permalink
conditional blas selection
Browse files Browse the repository at this point in the history
  • Loading branch information
kali authored and mathieupoumeyrolsonos committed Feb 27, 2024
1 parent 9be466a commit b9c9e12
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
9 changes: 6 additions & 3 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ maintenance = { status = "actively-developed" }
[dependencies]
anyhow.workspace = true
bit-set.workspace = true
cblas = { version = "0.4" }
# blis-src = { version = "0.2", features = ["static", "pthreads"] }
openblas-src = { version = "0.10", features = ["static"] }
cblas = { version = "0.4", optional = true }
blis-src = { version = "0.2", features = ["static", "pthreads"], optional = true }
openblas-src = { version = "0.10", features = ["static"], optional = true }
derive-new.workspace = true
downcast-rs.workspace = true
dyn-clone.workspace = true
Expand All @@ -39,6 +39,9 @@ tract-data = { version = "=0.21.2-pre", path = "../data" }
[features]
default = [ ]
complex = [ "tract-data/complex", "tract-linalg/complex" ]
blas = [ "cblas" ]
blis = [ "blas", "blis-src" ]
openblas = [ "blas", "openblas-src" ]
paranoid_assertions = []

[dev-dependencies]
Expand Down
7 changes: 5 additions & 2 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@
//! tract-tensorflow or tract-onnx crates.
//!
#[cfg(feature="blas")]
extern crate cblas;
// extern crate blis_src;
extern crate openblas_src;
#[cfg(feature="blis")]
extern crate blis_src;
#[cfg(feature="openblas")]
extern crate openblas_src;

pub extern crate anyhow;
extern crate bit_set;
Expand Down
6 changes: 6 additions & 0 deletions core/src/ops/einsum/as_blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,11 @@ impl TypedOp for SGemm {
Ok(tvec!(f32::fact(&self.output_shape(&*inputs[0].shape, &*inputs[1].shape)?)))
}

/// Reports the arithmetic cost of this op as a count of `f32` fused
/// multiply-adds.
///
/// The count is the product of every dimension of the output shape
/// (as computed by `output_shape` from the two input shapes), multiplied
/// by the last dimension of the first input.
/// NOTE(review): presumably that last axis is the contraction (`k`)
/// dimension of the matrix product; confirm against `output_shape`.
///
/// # Errors
/// Propagates any error returned by `output_shape`.
///
/// # Panics
/// Panics if the first input's shape has no dimensions.
fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
    // Resolve the output shape first so a shape error surfaces before
    // anything else is inspected.
    let out_shape = self.output_shape(&*inputs[0].shape, &*inputs[1].shape)?;
    let out_volume = out_shape.iter().product::<TDim>();
    let last_dim = inputs[0].shape.last().unwrap();
    let fma_count = out_volume * last_dim;
    Ok(tvec!((Cost::FMA(f32::datum_type()), fma_count)))
}

as_op!();
}

0 comments on commit b9c9e12

Please sign in to comment.