More linalg protocols #1308
Merged · 26 commits · Jan 23, 2024
1 change: 1 addition & 0 deletions cli/src/main.rs
@@ -127,6 +127,7 @@ fn main() -> tract_core::anyhow::Result<()> {

.arg(Arg::new("f32-to-f16").long("f32-to-f16").alias("half-floats").long_help("Convert the decluttered network from f32 to f16"))
.arg(arg!(--"f16-to-f32" "Convert the decluttered network from f16 to f32"))
.arg(Arg::new("transform").long("transform").multiple_occurrences(true).takes_value(true).help("Apply a built-in transformation to the model"))
.arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true)
.long_help("Set a symbol to a concrete value after decluttering"))

14 changes: 11 additions & 3 deletions cli/src/params.rs
@@ -617,13 +617,13 @@ impl Parameters {
macro_rules! stage {
($name:expr, $from:ident -> $to:ident, $block:expr) => {
if let Some(from) = $from.take() {
info!(concat!("Running '", $name, "'"));
info!("Running {:?}", $name);
let mut last_model: Option<Box<dyn Model>> =
if keep_last { Some(Box::new(from.as_ref().clone())) } else { None };
let block: &dyn Fn(_) -> TractResult<_> = &$block;
let owned_model =
Arc::try_unwrap(from).unwrap_or_else(|from| from.as_ref().clone());
match block(owned_model).context(concat!("Error at stage ", $name)) {
match block(owned_model).with_context(|| format!("Error at stage {:?}", $name)) {
Ok(it) => {
$to = Some(Arc::new(it));
}
@@ -637,7 +637,7 @@ impl Parameters {
}
}
}
info_usage(concat!("after ", $name), probe);
info_usage(&format!("after {:?}", $name), probe);
if reference_stage.as_deref() == Some($name) {
reference_model = Some($to.as_ref().unwrap().clone());
}
@@ -724,6 +724,14 @@ impl Parameters {
tract_core::floats::FloatPrecisionTranslator::<f16, f32>::default().translate_model(&m)
});
}
if let Some(transform) = matches.values_of("transform") {
for transform in transform {
stage!(transform, typed_model -> typed_model, |m:TypedModel| {
let transformer = tract_core::transform::get_transformer(transform).with_context(|| format!("Could not find transformer named {}", transform))?;
transformer.transform_into(&m)
});
}
}
if let Some(set) = matches.values_of("set") {
let mut values = SymbolValues::default();
for set in set {
13 changes: 13 additions & 0 deletions core/src/floats.rs
@@ -7,10 +7,23 @@ use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::ops::scan::Scan;
use crate::ops::source::TypedSource;
use crate::transform::ModelTransformer;

#[derive(Debug, Default)]
pub struct FloatPrecisionTranslator<T1: Datum + Float, T2: Datum + Float>(PhantomData<(T1, T2)>);

impl<T1: Datum + Float, T2: Datum + Float> ModelTransformer for FloatPrecisionTranslator<T1, T2> {
fn name(&self) -> Cow<str> {
format!("{:?}-to-{:?}", T1::datum_type(), T2::datum_type()).into()
}

fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
let new = self.translate_model(model)?;
*model = new;
Ok(())
}
}

impl<T1: Datum + Float, T2: Datum + Float>
Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
for FloatPrecisionTranslator<T1, T2>
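Aside, not part of the diff: with this impl in place, the float-precision pass can be driven through the new `ModelTransformer` surface instead of calling `translate_model` directly. A minimal sketch, assuming the usual `tract_core` preludes; `halve_floats` is an illustrative name, not an API in the crate.

```rust
// Minimal sketch, not part of the PR: use FloatPrecisionTranslator through the new
// ModelTransformer trait instead of calling translate_model directly.
use tract_core::floats::FloatPrecisionTranslator;
use tract_core::internal::*;
use tract_core::transform::ModelTransformer;

fn halve_floats(model: &TypedModel) -> TractResult<TypedModel> {
    let translator = FloatPrecisionTranslator::<f32, f16>::default();
    // transform_into (default trait method) clones the model and rewrites the clone;
    // transform() would rewrite a &mut TypedModel in place instead.
    translator.transform_into(model)
}
```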
1 change: 1 addition & 0 deletions core/src/lib.rs
@@ -79,6 +79,7 @@ pub mod model;
pub mod optim;
pub mod plan;
pub mod runtime;
pub mod transform;
pub mod value;

pub use dyn_clone;
6 changes: 6 additions & 0 deletions core/src/model/typed.rs
@@ -3,6 +3,7 @@ use crate::model::*;
use crate::ops;
use crate::optim::OptimizerSession;
use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState};
use crate::transform::ModelTransformer;

/// A model with completely determined types and shapes.
pub type TypedModel = Graph<TypedFact, Box<dyn TypedOp>>;
@@ -149,6 +150,11 @@ impl TypedModel {
Ok(self)
}

/// Apply a transformation to the network.
pub fn transform(&mut self, transformer: &dyn ModelTransformer) -> TractResult<()> {
transformer.transform(self)
}

/// Perform declutter passes on the network.
pub fn declutter(&mut self) -> TractResult<()> {
crate::optim::Optimizer::declutter().session().optimize(self)
2 changes: 1 addition & 1 deletion core/src/ops/nn/mod.rs
@@ -4,7 +4,7 @@ mod softmax;

pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
pub use self::reduce::{Reduce, Reducer};
pub use self::softmax::Softmax;
pub use self::softmax::{Softmax, SoftmaxExp};

pub use crate::internal::*;

37 changes: 24 additions & 13 deletions core/src/ops/nn/reduce.rs
@@ -1,6 +1,7 @@
use crate::internal::Axis;
use crate::internal::*;
use std::convert::TryFrom;
use std::mem::transmute;
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
use tract_ndarray::prelude::*;
@@ -210,6 +211,12 @@ fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
where
T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
if T::datum_type() == f32::datum_type() {
if let Some(slice) = v.as_slice() {
let slice = unsafe { transmute(slice) };
(tract_linalg::ops().max_f32)().run(slice).unwrap();
}
}
v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
}

@@ -297,19 +304,23 @@ impl TypedOp for Reduce {
outputs: &[&TypedFact],
) -> TractResult<AxesMapping> {
let mut letters = 'a'..;
let axes = (0..inputs[0].rank()).flat_map(|ix| {
if self.axes.contains(&ix) {
tvec!(
Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()).input(0, ix),
Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()).output(0, ix),
)
} else {
tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
.input(0, ix)
.output(0, ix))
}
.into_iter()
}).collect_vec();
let axes = (0..inputs[0].rank())
.flat_map(|ix| {
if self.axes.contains(&ix) {
tvec!(
Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
.input(0, ix),
Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
.output(0, ix),
)
} else {
tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
.input(0, ix)
.output(0, ix))
}
.into_iter()
})
.collect_vec();
AxesMapping::new(1, 1, axes)
}

53 changes: 46 additions & 7 deletions core/src/ops/nn/softmax/mod.rs
@@ -12,10 +12,19 @@ use std::fmt::Debug;
use crate::internal::*;
use ndarray::prelude::*;

#[derive(Debug, Clone, new, Hash)]
#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)]
pub enum SoftmaxExp {
#[default]
Libc,
// https://nic.schraudolph.org/pubs/Schraudolph99.pdf
FastCompact,
}

#[derive(Debug, Clone, new, Hash, Default)]
pub struct Softmax {
pub axes: TVec<usize>,
pub quant_output_dt: Option<DatumType>,
pub exp: SoftmaxExp,
}

impl Op for Softmax {
@@ -24,7 +33,7 @@ impl Op for Softmax {
}

fn info(&self) -> TractResult<Vec<String>> {
Ok(vec![format!("Axis: {:?}", self.axes)])
Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)])
}

op_as_typed_op!();
@@ -122,16 +131,24 @@ impl Softmax {
}
}

let mut output = input.into_tensor().into_array::<T>()?;
let mut output = input.into_tensor();
let mut view = output.to_array_view_mut::<T>()?;

for it_coords in tract_ndarray::indices(&*iterating_shape) {
let mut view = output.view_mut();
let mut view = view.view_mut();
for ix in 0..iterating_shape.len() {
if !self.axes.contains(&ix) {
view.collapse_axis(Axis(ix), it_coords[ix]);
}
}
softmax_inner(view);
if let Some(slice) =
view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type())
{
let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
self.softmax_inner_slice_f32(slice)?;
} else {
softmax_inner(view);
}
}

Ok(tvec!(output.into_tvalue()))
@@ -169,6 +186,27 @@ impl Softmax {
unsafe { output_tensor.set_datum_type(output_dt) };
Ok(tvec!(output_tensor.into_tvalue()))
}

fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> {
let max = (tract_linalg::ops().max_f32)().run(slice)?;
let sum = match self.exp {
SoftmaxExp::Libc => {
let mut s = 0f32;
for x in slice.iter_mut() {
let y = (*x - max).exp();
s += y;
*x = y;
}
s
}
SoftmaxExp::FastCompact => {
(tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)?
}
};
let rsum = sum.recip();
(tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?;
Ok(())
}
}

fn softmax_inner<T: Float + Datum + std::iter::Sum, D: Dimension>(mut view: ArrayViewMut<T, D>) {
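For context (not part of the diff): the `FastCompact` path delegates to `tract_linalg`'s fused `softmax2_fastcompact_f32` kernel, and the paper linked on the `SoftmaxExp` enum (Schraudolph 1999) describes computing `exp` by writing the scaled argument straight into a float's exponent bits. A rough standalone sketch of that idea, using the commonly quoted single-precision constants rather than whatever the kernel actually uses:

```rust
// Rough illustration of the Schraudolph-style exp behind the FastCompact variant.
// Not the tract_linalg kernel; constants are the commonly quoted f32 adaptation.
fn fast_exp_approx(x: f32) -> f32 {
    const SCALE: f32 = 12_102_203.0; // ≈ 2^23 / ln(2): one unit of x per exponent step
    const OFFSET: f32 = 1_064_866_805.0; // ≈ 127 * 2^23, nudged down to reduce average error
    // The saturating float-to-int cast flushes deeply negative arguments to 0.0 (= exp(-inf)).
    f32::from_bits((SCALE * x + OFFSET) as u32)
}

// Shape of the softmax exp pass with the approximation plugged in: inputs are
// pre-shifted by the running max, so every argument to exp is <= 0.
fn softmax_exp_pass(slice: &mut [f32], max: f32) -> f32 {
    let mut sum = 0f32;
    for v in slice.iter_mut() {
        *v = fast_exp_approx(*v - max);
        sum += *v;
    }
    sum
}
```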
@@ -328,7 +366,8 @@ mod test {
fn check(&self) -> Result<()> {
let inputs = tvec!(self.data.clone().into_tvalue());
let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
let softmax = Softmax { axes: self.axes.clone(), quant_output_dt };
let softmax =
Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };

// Compute quantized output
let result = softmax.eval(inputs)?;
Expand All @@ -338,7 +377,7 @@ mod test {
// Compute reference output
let input_float = self.data.cast_to::<f32>()?;
let inputs_float = tvec!(input_float.into_owned().into_tvalue());
let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None };
let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
let reference_float = softmax_float.eval(inputs_float)?;
let reference_array = args_1!(reference_float);
let reference = reference_array.to_array_view::<f32>()?;
45 changes: 45 additions & 0 deletions core/src/transform.rs
@@ -0,0 +1,45 @@
use crate::internal::*;
use std::borrow::Cow;
use std::fmt::Debug;

use tract_data::TractResult;

use crate::floats::FloatPrecisionTranslator;
use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel};

pub fn get_transformer(name: &str) -> Option<Box<dyn ModelTransformer>> {
match name {
"f32-to-f16" => Some(Box::<FloatPrecisionTranslator<f32, f16>>::default()),
"f16-to-f32" => Some(Box::<FloatPrecisionTranslator<f32, f16>>::default()),
"softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
_ => None,
}
}

pub trait ModelTransformer: Debug {
fn name(&self) -> Cow<str>;
fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
fn transform_into(&self, model: &TypedModel) -> TractResult<TypedModel> {
let mut model = model.clone();
self.transform(&mut model)?;
Ok(model)
}
}

#[derive(Debug)]
struct SoftmaxFastCompact;

impl ModelTransformer for SoftmaxFastCompact {
fn name(&self) -> Cow<str> {
"softmax-fast-compact".into()
}

fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
for node in &mut model.nodes {
if let Some(softmax) = node.op_as_mut::<Softmax>() {
softmax.exp = SoftmaxExp::FastCompact;
}
}
Ok(())
}
}
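A minimal usage sketch (not in the diff) tying the registry to the rest of the PR: this is essentially what the new repeatable `--transform` CLI flag does for each requested name after decluttering (e.g. `--transform softmax-fast-compact`). The function name below is illustrative only.

```rust
// Minimal sketch, not part of the PR: resolve a transformer by name and apply it,
// mirroring the new --transform stage in cli/src/params.rs.
use tract_core::anyhow::Context;
use tract_core::internal::*;
use tract_core::transform::get_transformer;

fn apply_named_transform(model: &mut TypedModel, name: &str) -> TractResult<()> {
    let transformer = get_transformer(name)
        .with_context(|| format!("Could not find transformer named {}", name))?;
    // ModelTransformer::transform rewrites the model in place;
    // transform_into would return a rewritten clone instead.
    model.transform(transformer.as_ref())
}
```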
15 changes: 5 additions & 10 deletions hir/src/ops/nn/layer_max.rs
@@ -1,3 +1,5 @@
use tract_core::ops::nn::Softmax;

use crate::infer::*;
use crate::internal::*;

@@ -9,8 +11,6 @@ pub struct LayerHardmax {
coerce_to_2d: bool,
}



impl Expansion for LayerHardmax {
fn name(&self) -> Cow<str> {
"LayerHardmax".into()
@@ -83,14 +83,11 @@ pub struct LayerLogSoftmax {
pub coerce_to_2d: bool,
}



impl Expansion for LayerLogSoftmax {
fn name(&self) -> Cow<str> {
"LayerLogSoftmax".into()
}


fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
solver: &mut Solver<'r>,
@@ -118,8 +115,6 @@ pub struct LayerSoftmax {
coerce_to_2d: bool,
}



impl Expansion for LayerSoftmax {
fn name(&self) -> Cow<str> {
"LayerSoftmax".into()
@@ -144,10 +139,10 @@ impl Expansion for LayerSoftmax {
let rank = target.outlet_fact(input)?.rank();
let dt = target.outlet_fact(input)?.datum_type;
let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize;
let reducing_axes =
let axes =
if self.coerce_to_2d { (axis..rank).collect::<TVec<usize>>() } else { tvec!(axis) };
let dt = if dt.is_float() { None } else { Some(dt) };
target.wire_node(name, tract_core::ops::nn::Softmax::new(reducing_axes, dt), inputs)
let quant_output_dt = if dt.is_float() { None } else { Some(dt) };
target.wire_node(name, Softmax { axes, quant_output_dt, ..Softmax::default() }, inputs)
}
}

9 changes: 5 additions & 4 deletions hir/src/ops/nn/softmax.rs
@@ -1,13 +1,10 @@
//use tract_core::ops::nn::Softmax;
use crate::internal::*;

#[derive(Debug, Clone, new, Hash)]
pub struct Softmax {
axis: isize,
}



impl Expansion for Softmax {
fn name(&self) -> Cow<str> {
"Softmax".into()
@@ -54,7 +51,11 @@ impl Expansion for Softmax {

target.wire_node(
name,
tract_core::ops::nn::Softmax { axes: tvec![axis], quant_output_dt },
tract_core::ops::nn::Softmax {
axes: tvec![axis],
quant_output_dt,
..tract_core::ops::nn::Softmax::default()
},
inputs,
)
}