diff --git a/src/graph/model.rs b/src/graph/model.rs
index bff2ac21e..951dd7892 100644
--- a/src/graph/model.rs
+++ b/src/graph/model.rs
@@ -744,7 +744,8 @@ impl Model {
         reader: &mut dyn std::io::Read,
         run_args: &RunArgs,
     ) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
-        use tract_onnx::tract_hir::internal::GenericFactoid;
+        use maybe_rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+        use tract_onnx::{tract_core::internal::TDim, tract_hir::internal::GenericFactoid};
 
         let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
             error!("Error loading model: {}", e);
@@ -781,15 +782,54 @@ impl Model {
             let symbol = model.symbol_table.sym(symbol);
             symbol_values = symbol_values.with(&symbol, *value as i64);
             info!("set {} to {}", symbol, value);
+            println!("set {} to {}", symbol, value);
         }
 
         // Note: do not optimize the model, as the layout will depend on underlying hardware
-        let model = model
+        let mut typed_model = model
             .into_typed()?
-            .into_decluttered()?
-            .concretize_dims(&symbol_values)?;
+            .concretize_dims(&symbol_values)?
+            .into_decluttered()?;
+
+        // concretize constants
+        for node in typed_model.eval_order()? {
+            let node = typed_model.node_mut(node);
+            if node.op_is::<tract_onnx::tract_core::ops::konst::Const>() {
+                // map option to err
+                let op = node
+                    .op_as_mut::<tract_onnx::tract_core::ops::konst::Const>()
+                    .unwrap();
+                // get inner value to Arc<Tensor>
+                let constant = op.0.as_ref();
+
+                match constant.datum_type() {
+                    DatumType::TDim => {
+                        // Generally a shape or hyperparam
+                        let vec = constant.as_slice::<TDim>()?.to_vec();
+                        let data: Vec<TDim> =
+                            vec.par_iter().map(|x| x.eval(&symbol_values)).collect();
+
+                        // allow unsafe
+                        #[allow(unsafe_code)]
+                        unsafe {
+                            let bytes = std::slice::from_raw_parts(
+                                data.as_ptr() as *const u8,
+                                data.len() * DatumType::TDim.size_of(),
+                            );
+
+                            op.0 = std::sync::Arc::new(tract_onnx::prelude::Tensor::from_raw_dt(
+                                DatumType::TDim,
+                                constant.shape(),
+                                bytes,
+                            )?);
+                        }
+                    }
+                    _ => {}
+                }
+            }
+        }
 
-        Ok((model, symbol_values))
+        Ok((typed_model, symbol_values))
     }
 
     /// Loads an Onnx model from a specified path.
@@ -1082,7 +1122,7 @@ impl Model {
    ) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
        use tract_onnx::tract_core::internal::IntoArcTensor;
 
-        let (model, symbols) = Model::load_onnx_using_tract(
+        let (model, _) = Model::load_onnx_using_tract(
            &mut std::fs::File::open(model_path)
                .map_err(|_| format!("failed to load model at {}", model_path.display()))?,
            run_args,
@@ -1102,8 +1142,7 @@ impl Model {
            result
                .into_iter()
                .map(|t| {
-                    crate::graph::utilities::extract_tensor_value(t.into_arc_tensor(), &symbols)
-                        .unwrap()
+                    crate::graph::utilities::extract_tensor_value(t.into_arc_tensor()).unwrap()
                })
                .collect(),
        );
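Note on the loader change above: symbolic dimensions (`TDim` values such as a symbolic batch size) that end up inside constant tensors are now evaluated once at load time, right after `concretize_dims`, so downstream passes no longer need a `SymbolValues` table. A minimal sketch of that resolution step in isolation, using the same `TDim::eval` / `to_i64` calls the patch relies on (the `resolve_dims` helper and its error handling are illustrative assumptions, not code from this repository):

```rust
use tract_onnx::prelude::SymbolValues;
use tract_onnx::tract_core::internal::TDim;

/// Hypothetical helper: substitute known symbol values into each dimension
/// and require the result to be a plain integer.
fn resolve_dims(dims: &[TDim], symbol_values: &SymbolValues) -> Result<Vec<i64>, String> {
    dims.iter()
        .map(|d| {
            // `eval` replaces symbols that have concrete values; `to_i64`
            // errors out if anything symbolic is left over.
            d.eval(symbol_values)
                .to_i64()
                .map_err(|e| format!("dimension is still symbolic: {}", e))
        })
        .collect()
}
```

Because this resolution now happens once during `load_onnx_using_tract`, the utilities below can drop their `symbol_values` parameter.
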
diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs
index c92568335..7ca931dd2 100644
--- a/src/graph/utilities.rs
+++ b/src/graph/utilities.rs
@@ -106,7 +106,6 @@ use tract_onnx::prelude::SymbolValues;
 /// Extracts the raw values from a tensor.
 pub fn extract_tensor_value(
     input: Arc<tract_onnx::prelude::Tensor>,
-    symbol_values: &SymbolValues,
 ) -> Result<Tensor<f32>, Box<dyn Error>> {
     use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
 
@@ -196,7 +195,7 @@ pub fn extract_tensor_value(
                .par_iter()
                .map(|x| match x.to_i64() {
                    Ok(v) => Ok(v as f32),
-                    Err(_) => match x.eval(symbol_values).to_i64() {
+                    Err(_) => match x.to_i64() {
                        Ok(v) => Ok(v as f32),
                        Err(_) => Err("could not evaluate tdim"),
                    },
@@ -478,7 +477,7 @@ pub fn new_op_from_onnx(
            let op: Const = load_op::<Const>(node.op(), idx, node.op().name().to_string())?;
            let dt = op.0.datum_type();
            // Raw values are always f32
-            let raw_value = extract_tensor_value(op.0, symbol_values)?;
+            let raw_value = extract_tensor_value(op.0)?;
            // If bool or a tensor dimension then don't scale
            let constant_scale = match dt {
                DatumType::Bool
@@ -1075,12 +1074,12 @@ pub fn new_op_from_onnx(
                }
            };
 
-            let kernel = extract_tensor_value(conv_node.kernel.clone(), symbol_values)?;
+            let kernel = extract_tensor_value(conv_node.kernel.clone())?;
            let kernel = quantize_tensor(kernel, scales.params, param_visibility)?;
 
            let bias = match conv_node.bias.clone() {
                Some(b) => {
-                    let const_value = extract_tensor_value(b, symbol_values)?;
+                    let const_value = extract_tensor_value(b)?;
 
                    let val = quantize_tensor(
                        const_value,
@@ -1153,12 +1152,12 @@ pub fn new_op_from_onnx(
                }
            };
 
-            let kernel = extract_tensor_value(deconv_node.kernel.clone(), symbol_values)?;
+            let kernel = extract_tensor_value(deconv_node.kernel.clone())?;
            let kernel = quantize_tensor(kernel, scales.params, param_visibility)?;
 
            let bias = match deconv_node.bias.clone() {
                Some(b) => {
-                    let const_value = extract_tensor_value(b, symbol_values)?;
+                    let const_value = extract_tensor_value(b)?;
 
                    let val = quantize_tensor(
                        const_value,
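With constants concretized at load time, the `TDim` branch of `extract_tensor_value` reduces to a plain integer conversion: anything that still fails `to_i64()` surfaces as the "could not evaluate tdim" error. A standalone sketch of that branch's logic (the `tdim_to_f32` name is illustrative, not part of the patch):

```rust
use tract_onnx::tract_core::internal::TDim;

/// Hypothetical mirror of the per-element TDim conversion used above:
/// values are expected to already be concrete integers.
fn tdim_to_f32(x: &TDim) -> Result<f32, &'static str> {
    x.to_i64()
        .map(|v| v as f32)
        .map_err(|_| "could not evaluate tdim")
}
```

Callers such as `new_op_from_onnx` therefore drop the `symbol_values` argument entirely, as the hunks above show.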