Skip to content

Commit

Permalink
allow for empty inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Dec 5, 2023
1 parent 5850bc7 commit 8ce620b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ pub(crate) fn calibrate(
std::mem::drop(_r);

let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());

info!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
Expand All @@ -720,8 +722,6 @@ pub(crate) fn calibrate(
}
};

info!("num of calibration batches: {}", chunks.len());

let mut found_params: Vec<GraphSettings> = vec![];

let scale_rebase_multiplier = [1, 2, 10];
Expand Down
12 changes: 7 additions & 5 deletions src/graph/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,13 @@ impl GraphData {
match &self.input_data {
DataSource::File(data) => {
for (i, input) in data.iter().enumerate() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
let tt = TractTensor::from_shape(&shapes[i], &input)?;
let tt = tt.cast_to_dt(dt)?;
inputs.push(tt.into_owned().into());
if !input.is_empty() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
let tt = TractTensor::from_shape(&shapes[i], &input)?;
let tt = tt.cast_to_dt(dt)?;
inputs.push(tt.into_owned().into());
}
}
}
_ => {
Expand Down

0 comments on commit 8ce620b

Please sign in to comment.