Skip to content

Commit

Permalink
Fix occasional batch truncation error.
Browse files (browse the repository at this point in the history)
  • Loading branch information
cryscan committed Sep 2, 2023
1 parent 7e735e9 commit 3de9e9e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/model.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -784,15 +784,15 @@ impl<'a, 'b> Model<'a, 'b> {
// we only infer at most `token_chunk_size` tokens at a time
let mut num_token = num_token.min(self.token_chunk_size);
let mut inputs = vec![vec![]; max_batch];
let mut output = false;
let mut last_batch = None;

// take up to `num_token` tokens in total from the front of the batches in `tokens` and move them into the matching entries of `inputs`
for (batch, input) in tokens.iter_mut().zip(inputs.iter_mut()) {
for (index, (batch, input)) in tokens.iter_mut().zip(inputs.iter_mut()).enumerate() {
let mid = batch.len().min(num_token);
num_token -= mid;

let (head, tail) = batch.split_at(mid);
output = tail.is_empty();
last_batch = (!tail.is_empty()).then_some(index);
*input = head.to_vec();
*batch = tail.to_vec();

Expand All @@ -801,7 +801,7 @@ impl<'a, 'b> Model<'a, 'b> {
}
}

let (buffer, redirect) = self.run_internal(inputs, state, output)?;
let (buffer, redirect) = self.run_internal(inputs, state, last_batch)?;
let output = async { TensorCpu::from(buffer.map.clone()) }.await;

Ok(redirect
Expand All @@ -820,7 +820,7 @@ impl<'a, 'b> Model<'a, 'b> {
&self,
tokens: Vec<Vec<u16>>,
state: &ModelState,
output: bool,
last_batch: Option<usize>,
) -> Result<(Arc<ModelBuffer>, Vec<Option<usize>>)> {
let context = self.context;
let tensor = &self.tensor;
Expand Down Expand Up @@ -857,7 +857,7 @@ impl<'a, 'b> Model<'a, 'b> {
.cursors
.iter()
.filter(|cursor| cursor.len > 0)
.filter(|cursor| output || cursor.batch + 1 < max_batch)
.filter(|cursor| !last_batch.is_some_and(|index| cursor.batch == index))
.enumerate()
.map(|(index, cursor)| -> Result<TensorOp<'_>, TensorError> {
redirect[cursor.batch] = Some(index);
Expand Down

0 comments on commit 3de9e9e

Please sign in to comment.