Skip to content

Commit

Permalink
Implement concat for ModelOutput.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 2, 2024
1 parent 6d905d8 commit f6bef59
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.6.0"
version = "0.6.1"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
13 changes: 13 additions & 0 deletions src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ pub enum ModelOutput {
Full(Vec<Vec<f32>>),
}

impl ModelOutput {
pub fn concat(self, other: Self) -> Self {
match (self, other) {
(Self::None, y) => y,
(x, Self::None) => x,
(Self::Last(x), Self::Last(y)) => Self::Full(vec![x, y]),
(Self::Last(x), Self::Full(y)) => Self::Full([vec![x], y].concat()),
(Self::Full(x), Self::Last(y)) => Self::Full([x, vec![y]].concat()),
(Self::Full(x), Self::Full(y)) => Self::Full([x, y].concat()),
}
}
}

#[derive(Debug, Default, Clone, Copy)]
pub enum OutputType {
/// Only the prediction of the last token.
Expand Down

0 comments on commit f6bef59

Please sign in to comment.