Skip to content

Commit

Permalink
patch
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Jan 6, 2024
1 parent 5345966 commit e5026bb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
55 changes: 47 additions & 8 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ impl Model {
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
use tract_onnx::tract_hir::internal::GenericFactoid;
use maybe_rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use tract_onnx::{tract_core::internal::TDim, tract_hir::internal::GenericFactoid};

let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
error!("Error loading model: {}", e);
Expand Down Expand Up @@ -781,15 +782,54 @@ impl Model {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
info!("set {} to {}", symbol, value);
println!("set {} to {}", symbol, value);
}

// Note: do not optimize the model, as the layout will depend on underlying hardware
let model = model
let mut typed_model = model
.into_typed()?
.into_decluttered()?
.concretize_dims(&symbol_values)?;
.concretize_dims(&symbol_values)?
.into_decluttered()?;

// concretize constants
for node in typed_model.eval_order()? {
let node = typed_model.node_mut(node);
if node.op_is::<tract_onnx::tract_hir::ops::konst::Const>() {
// map option to err
let op = node
.op_as_mut::<tract_onnx::tract_hir::ops::konst::Const>()
.unwrap();
// get inner value to Arc<Tensor>
let constant = op.0.as_ref();

match constant.datum_type() {
DatumType::TDim => {
// Generally a shape or hyperparam
let vec = constant.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
let data: Vec<TDim> =
vec.par_iter().map(|x| x.eval(&symbol_values)).collect();

// allow unsafe
#[allow(unsafe_code)]
unsafe {
let bytes = std::slice::from_raw_parts(
data.as_ptr() as *const u8,
data.len() * DatumType::TDim.size_of(),
);

op.0 = std::sync::Arc::new(tract_onnx::prelude::Tensor::from_raw_dt(
DatumType::TDim,
constant.shape(),
bytes,
)?);
}
}
_ => {}
}
}
}

Ok((model, symbol_values))
Ok((typed_model, symbol_values))
}

/// Loads an Onnx model from a specified path.
Expand Down Expand Up @@ -1082,7 +1122,7 @@ impl Model {
) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
use tract_onnx::tract_core::internal::IntoArcTensor;

let (model, symbols) = Model::load_onnx_using_tract(
let (model, _) = Model::load_onnx_using_tract(
&mut std::fs::File::open(model_path)
.map_err(|_| format!("failed to load model at {}", model_path.display()))?,
run_args,
Expand All @@ -1102,8 +1142,7 @@ impl Model {
result
.into_iter()
.map(|t| {
crate::graph::utilities::extract_tensor_value(t.into_arc_tensor(), &symbols)
.unwrap()
crate::graph::utilities::extract_tensor_value(t.into_arc_tensor()).unwrap()
})
.collect(),
);
Expand Down
13 changes: 6 additions & 7 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ use tract_onnx::prelude::SymbolValues;
/// Extracts the raw values from a tensor.
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
symbol_values: &SymbolValues,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};

Expand Down Expand Up @@ -196,7 +195,7 @@ pub fn extract_tensor_value(
.par_iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.eval(symbol_values).to_i64() {
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err("could not evaluate tdim"),
},
Expand Down Expand Up @@ -478,7 +477,7 @@ pub fn new_op_from_onnx(
let op: Const = load_op::<Const>(node.op(), idx, node.op().name().to_string())?;
let dt = op.0.datum_type();
// Raw values are always f32
let raw_value = extract_tensor_value(op.0, symbol_values)?;
let raw_value = extract_tensor_value(op.0)?;
// If bool or a tensor dimension then don't scale
let constant_scale = match dt {
DatumType::Bool
Expand Down Expand Up @@ -1075,12 +1074,12 @@ pub fn new_op_from_onnx(
}
};

let kernel = extract_tensor_value(conv_node.kernel.clone(), symbol_values)?;
let kernel = extract_tensor_value(conv_node.kernel.clone())?;
let kernel = quantize_tensor(kernel, scales.params, param_visibility)?;

let bias = match conv_node.bias.clone() {
Some(b) => {
let const_value = extract_tensor_value(b, symbol_values)?;
let const_value = extract_tensor_value(b)?;

let val = quantize_tensor(
const_value,
Expand Down Expand Up @@ -1153,12 +1152,12 @@ pub fn new_op_from_onnx(
}
};

let kernel = extract_tensor_value(deconv_node.kernel.clone(), symbol_values)?;
let kernel = extract_tensor_value(deconv_node.kernel.clone())?;
let kernel = quantize_tensor(kernel, scales.params, param_visibility)?;

let bias = match deconv_node.bias.clone() {
Some(b) => {
let const_value = extract_tensor_value(b, symbol_values)?;
let const_value = extract_tensor_value(b)?;

let val = quantize_tensor(
const_value,
Expand Down

0 comments on commit e5026bb

Please sign in to comment.