Commit c254907

Change the batch execution load balance strategy.
cryscan committed Feb 8, 2024
1 parent 70cb692 commit c254907
Showing 3 changed files with 26 additions and 38 deletions.
Cargo.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [package]
 name = "web-rwkv"
-version = "0.6.4"
+version = "0.6.5"
 edition = "2021"
 authors = ["Zhenyuan Zhang <[email protected]>"]
 license = "MIT OR Apache-2.0"
src/model/mod.rs (2 changes: 0 additions & 2 deletions)
@@ -32,7 +32,6 @@ pub enum ModelError {
     NoViableChunkSize,
     BatchSize(usize, usize),
     BatchOutOfRange { batch: usize, max: usize },
-    EmptyInput,
 }
 
 impl std::fmt::Display for ModelError {
@@ -44,7 +43,6 @@ impl std::fmt::Display for ModelError {
             ModelError::BatchOutOfRange { batch, max } => {
                 write!(f, "batch {batch} out of range of max {max}")
             }
-            ModelError::EmptyInput => write!(f, "input is empty"),
         }
     }
 }
src/model/run.rs (60 changes: 25 additions & 35 deletions)
@@ -151,7 +151,7 @@ where
             return Err(ModelError::BatchSize(tokens.len(), max_batch).into());
         }
         if num_token == 0 {
-            return Err(ModelError::EmptyInput.into());
+            return Ok(vec![ModelOutput::None; tokens.len()]);
         }
 
         // we only infer at most `token_chunk_size` tokens at a time
@@ -164,40 +164,30 @@ where
         let mut inputs = vec![vec![]; max_batch];
         let mut outputs: Vec<Option<OutputType>> = vec![None; max_batch];
 
-        // take `num_token` tokens out of all the inputs and put into `input`
-        // first pass, make sure each slot computes at least one token
-        for (output, input, slot) in
-            itertools::multizip((outputs.iter_mut(), inputs.iter_mut(), tokens.iter_mut()))
-        {
-            let mid = 1.min(slot.tokens.len()).min(num_token);
-            num_token -= mid;
-
-            if mid > 0 {
-                let (head, tail) = slot.tokens.split_at(mid);
-                *input = [&input, head].concat();
-                *output = match slot.ty {
-                    OutputType::Last => tail.is_empty().then_some(OutputType::Last),
-                    OutputType::Full => Some(OutputType::Full),
-                };
-                slot.tokens = tail.to_vec();
-            }
-        }
-
-        // second pass, assign rest token budgets from left to right
-        for (output, input, slot) in
-            itertools::multizip((outputs.iter_mut(), inputs.iter_mut(), tokens.iter_mut()))
-        {
-            let mid = slot.tokens.len().min(num_token);
-            num_token -= mid;
-
-            if mid > 0 {
-                let (head, tail) = slot.tokens.split_at(mid);
-                *input = [&input, head].concat();
-                *output = match slot.ty {
-                    OutputType::Last => tail.is_empty().then_some(OutputType::Last),
-                    OutputType::Full => Some(OutputType::Full),
-                };
-                slot.tokens = tail.to_vec();
+        // consume all available token counts
+        // assign them to as many slots as possible
+        while num_token > 0 {
+            let mid = tokens
+                .iter()
+                .map(|input| input.tokens.len())
+                .filter(|x| x > &0)
+                .min()
+                .unwrap_or_default();
+            for (output, input, slot) in
+                itertools::multizip((outputs.iter_mut(), inputs.iter_mut(), tokens.iter_mut()))
+            {
+                let mid = mid.min(slot.tokens.len()).min(num_token);
+                num_token -= mid;
+
+                if mid > 0 {
+                    let (head, tail) = slot.tokens.split_at(mid);
+                    *output = match slot.ty {
+                        OutputType::Last => tail.is_empty().then_some(OutputType::Last),
+                        OutputType::Full => Some(OutputType::Full),
+                    };
+                    input.append(&mut head.to_vec());
+                    slot.tokens = tail.to_vec();
+                }
             }
         }
 
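For readers skimming the diff, the sketch below replays the new scheduling idea outside the crate: each round hands every non-empty batch at most as many tokens as the shortest non-empty batch still holds, repeating until the chunk budget is spent. `Slot`, `plan`, and the numbers in `main` are illustrative only and are not part of the web-rwkv API.

// Minimal standalone sketch of the min-based allocation rounds (illustrative only).
struct Slot {
    tokens: Vec<u16>,
}

// Splits a chunk budget of `num_token` tokens across slots the way the new
// `while` loop does: each round grants every non-empty slot at most the
// length of the currently shortest non-empty slot.
fn plan(slots: &mut [Slot], mut num_token: usize) -> Vec<Vec<u16>> {
    let mut inputs = vec![vec![]; slots.len()];
    while num_token > 0 {
        let mid = slots
            .iter()
            .map(|slot| slot.tokens.len())
            .filter(|&x| x > 0)
            .min()
            .unwrap_or_default();
        if mid == 0 {
            // Guard for this standalone sketch: stop if the budget outlives
            // the pending tokens instead of looping forever.
            break;
        }
        for (input, slot) in inputs.iter_mut().zip(slots.iter_mut()) {
            let mid = mid.min(slot.tokens.len()).min(num_token);
            num_token -= mid;
            if mid > 0 {
                let (head, tail) = slot.tokens.split_at(mid);
                input.extend_from_slice(head);
                slot.tokens = tail.to_vec();
            }
        }
    }
    inputs
}

fn main() {
    // Three batches with 2, 9, and 5 pending tokens under a budget of 8:
    // round one grants 2 tokens to each slot (6 used), round two gives the
    // remaining 2 to the leftmost batch that still has tokens.
    let mut slots = vec![
        Slot { tokens: (0..2).collect() },
        Slot { tokens: (0..9).collect() },
        Slot { tokens: (0..5).collect() },
    ];
    let inputs = plan(&mut slots, 8);
    let sizes: Vec<usize> = inputs.iter().map(|input| input.len()).collect();
    println!("tokens per batch this chunk: {sizes:?}"); // [2, 4, 2]
}

Where the old two-pass scheme guaranteed one token per slot and then poured the whole remaining budget into the leftmost batches, these min-sized rounds grow all pending batches together, which appears to be the load-balance change named in the commit title.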
