Commit

Support both f16 and f32 running precision.
cryscan committed Feb 16, 2024
1 parent eb0025d commit f6df239
Showing 18 changed files with 318 additions and 159 deletions.
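This commit threads the running (activation) precision through the model types as a generic float parameter, so callers choose `f16` or `f32` at the type level. A minimal sketch of a call site after the change, based on the example diffs below — the loader arguments are elided, and the `f32` variant is an assumption from the commit message rather than a line shown in this diff:

```rust
use half::f16;

// Half-precision activations, exactly as in the updated examples below:
let model: v5::Model<f16> = load_model(&context, &map, /* lora, quant, … */)?;

// Full-precision activations; assumed to use the same loader,
// with only the type annotation changing:
let model: v5::Model<f32> = load_model(&context, &map, /* lora, quant, … */)?;
```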
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.6.9"
version = "0.6.10"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
7 changes: 4 additions & 3 deletions examples/batch.rs
@@ -6,6 +6,7 @@ use crossterm::terminal::{
};
#[cfg(not(debug_assertions))]
use dialoguer::{theme::ColorfulTheme, Select};
+use half::f16;
use itertools::Itertools;
use memmap2::Mmap;
#[cfg(not(debug_assertions))]
@@ -148,7 +149,7 @@ async fn run(cli: Cli) -> Result<()> {

match info.version {
ModelVersion::V4 => {
-let model: v4::Model = load_model(
+let model: v4::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -167,7 +168,7 @@
run_internal(model, state, tokenizer, cli.batch).await
}
ModelVersion::V5 => {
-let model: v5::Model = load_model(
+let model: v5::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -186,7 +187,7 @@
run_internal(model, state, tokenizer, cli.batch).await
}
ModelVersion::V6 => {
-let model: v6::Model = load_model(
+let model: v6::Model<f16> = load_model(
&context,
&map,
cli.lora,
7 changes: 4 additions & 3 deletions examples/chat.rs
@@ -2,6 +2,7 @@ use anyhow::Result;
use clap::{Args, Parser, ValueEnum};
#[cfg(not(debug_assertions))]
use dialoguer::{theme::ColorfulTheme, Select};
+use half::f16;
use itertools::Itertools;
use memmap2::Mmap;
use serde::Deserialize;
@@ -188,7 +189,7 @@ async fn run(cli: Cli) -> Result<()> {

match info.version {
ModelVersion::V4 => {
-let model: v4::Model = load_model(
+let model: v4::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -202,7 +203,7 @@
run_internal(model, state, tokenizer, prompt, sampler).await
}
ModelVersion::V5 => {
-let model: v5::Model = load_model(
+let model: v5::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -216,7 +217,7 @@
run_internal(model, state, tokenizer, prompt, sampler).await
}
ModelVersion::V6 => {
-let model: v6::Model = load_model(
+let model: v6::Model<f16> = load_model(
&context,
&map,
cli.lora,
7 changes: 4 additions & 3 deletions examples/gen.rs
@@ -2,6 +2,7 @@ use anyhow::Result;
use clap::{Parser, ValueEnum};
#[cfg(not(debug_assertions))]
use dialoguer::{theme::ColorfulTheme, Select};
+use half::f16;
#[cfg(not(debug_assertions))]
use itertools::Itertools;
use memmap2::Mmap;
@@ -122,7 +123,7 @@ async fn run(cli: Cli) -> Result<()> {

match info.version {
ModelVersion::V4 => {
-let model: v4::Model = load_model(
+let model: v4::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -136,7 +137,7 @@
run_internal(model, state, tokenizer).await
}
ModelVersion::V5 => {
-let model: v5::Model = load_model(
+let model: v5::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -150,7 +151,7 @@
run_internal(model, state, tokenizer).await
}
ModelVersion::V6 => {
-let model: v6::Model = load_model(
+let model: v6::Model<f16> = load_model(
&context,
&map,
cli.lora,
10 changes: 7 additions & 3 deletions examples/inspector.rs
@@ -147,7 +147,7 @@ async fn run(cli: Cli) -> Result<()> {
println!("{:#?}", info);

let context = create_context(&info).await?;
-let model: v5::Model = load_model(
+let model: v5::Model<f16> = load_model(
&context,
&map,
cli.lora,
@@ -168,7 +168,11 @@ async fn run(cli: Cli) -> Result<()> {
hooks.insert(
v5::Hook::PostFfn(layer),
Box::new(
-move |_model, _state, runtime: &v5::Runtime| -> Result<TensorOp, TensorError> {
+move |_model,
+      _state,
+      runtime: &v5::Runtime<_>,
+      _header|
+      -> Result<TensorOp, TensorError> {
// figure out how many tokens this run has
let shape = runtime.ffn_x.shape();
let num_token = shape[1];
@@ -222,7 +226,7 @@ async fn run(cli: Cli) -> Result<()> {
&tensor.head.layer_norm.b,
&buffer.ffn_x,
None,
-                    v5::Model::LN_EPS,
+                    v5::Model::<f16>::LN_EPS,
)?,
tensor.head.w.matmul_mat_op(
buffer.ffn_x.view(.., .., .., ..)?,
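Note the new hook shape above: a hook now receives the model's tensors, the state, the (now generic) runtime, and the per-run header, instead of the whole model. A minimal sketch of registering a hook against the new `HookMap` alias from `src/model/run.rs` below; the layer index is illustrative and the returned op is left as a placeholder:

```rust
use std::collections::HashMap;

// In real code the map is declared with the HookMap alias, which now stores
// closures of shape Fn(&Tensor, &State, &Runtime, &Header) -> Result<TensorOp, TensorError>.
let mut hooks = HashMap::new();
hooks.insert(
    v5::Hook::PostFfn(0), // hook a single layer's FFN output
    Box::new(
        |_tensor, _state, runtime: &v5::Runtime<_>, _header| -> Result<TensorOp, TensorError> {
            // Runtime buffers are reachable as before, e.g. the FFN activation:
            let _num_token = runtime.ffn_x.shape()[1];
            todo!("build the TensorOp to splice into this layer's pass")
        },
    ),
);
```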
4 changes: 0 additions & 4 deletions src/model/mod.rs
@@ -163,12 +163,8 @@ pub trait ModelState: for<'a> FromBuilder<Builder<'a> = StateBuilder, Error = In
}

pub trait ModelBase {
-type ModelTensor;
-
fn context(&self) -> &Context;
fn info(&self) -> &ModelInfo;
-
-fn tensor(&self) -> &Self::ModelTensor;
}

pub trait Model:
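With the associated tensor type removed, `ModelBase` is reduced to precision-independent metadata; the tensor accessor resurfaces on the run trait in `src/model/run.rs` below. The trait after this commit reads:

```rust
pub trait ModelBase {
    fn context(&self) -> &Context;
    fn info(&self) -> &ModelInfo;
}
```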
54 changes: 39 additions & 15 deletions src/model/run.rs
@@ -9,6 +9,7 @@ use super::{
};
use crate::{
context::Context,
+    num::{Float, Hom},
tensor::{
kind::{ReadBack, ReadWrite},
ops::TensorOp,
@@ -18,13 +19,13 @@ use crate::{
};

#[derive(Debug)]
-pub struct Header {
-    pub head_x: TensorGpu<f16, ReadWrite>,
+pub struct Header<F: Float> {
+    pub head_x: TensorGpu<F, ReadWrite>,
pub head_o: TensorGpu<f32, ReadWrite>,
pub map: TensorGpu<f32, ReadBack>,
}

-impl Header {
+impl<F: Float> Header<F> {
pub fn new(context: &Context, info: &ModelInfo, num_batch: usize) -> Self {
let head_shape = Shape::new(info.num_emb, num_batch, 1, 1);
let output_shape = Shape::new(info.num_vocab, num_batch, 1, 1);
@@ -37,16 +38,20 @@ impl Header {
}
}

-pub type HookMap<Hook, Model, State, Runtime> =
-    HashMap<Hook, Box<dyn Fn(&Model, &State, &Runtime) -> Result<TensorOp, TensorError>>>;
+pub type HookMap<Hook, Tensor, State, Runtime, Header> =
+    HashMap<Hook, Box<dyn Fn(&Tensor, &State, &Runtime, &Header) -> Result<TensorOp, TensorError>>>;

pub(crate) trait ModelRunInternal: ModelBase {
type Hook: Hash;
type State: ModelState;
+    type Tensor;
type Runtime;
+    type Header;

+    fn tensor(&self) -> &Self::Tensor;
+
fn checkout_runtime(&self, num_batch: usize) -> Arc<Self::Runtime>;
-    fn checkout_header(&self, num_batch: usize) -> Arc<Header>;
+    fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header>;

/// To prevent the GPU device from being lost, this limits the maximum number of batch-tokens processed at one time.
fn token_chunk_size(&self) -> usize;
@@ -60,14 +65,14 @@
tokens: Vec<Vec<u16>>,
state: &Self::State,
outputs: Vec<Option<OutputType>>,
-        hooks: &HookMap<Self::Hook, Self, Self::State, Self::Runtime>,
+        hooks: &HookMap<Self::Hook, Self::Tensor, Self::State, Self::Runtime, Self::Header>,
) -> Result<(TensorGpu<f32, ReadBack>, Vec<std::ops::Range<usize>>), TensorError>;

-    fn create_input<'a>(
+    fn create_input<'a, F: Float>(
&self,
embed: &TensorCpu<'a, f16>,
tokens: &[Vec<u16>],
-    ) -> Result<TensorStack<'a, f16>, TensorError> {
+    ) -> Result<TensorStack<'a, F>, TensorError> {
let info = self.info();
let context = self.context();

@@ -80,7 +85,8 @@ pub(crate) trait ModelRunInternal: ModelBase {
.map(|&token| embed.slice(.., token as usize, .., ..))
.try_collect()?,
)
-        .unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1)));
+        .unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1)))
+        .map(|x| Hom::co_hom(*x));
stack.reshape(
TensorDimension::Full,
TensorDimension::Auto,
@@ -96,7 +102,11 @@
pub trait ModelRun {
type Hook: Hash;
type State: ModelState;
+    type Tensor;
type Runtime;
+    type Header;

+    fn tensor(&self) -> &Self::Tensor;
+
/// Run the model for a batch of tokens as input.
/// The length of `tokens` must match the number of batches in `state`.
@@ -110,23 +120,37 @@
/// Run the model for a batch of tokens as input, but with custom hooks.
/// The length of `tokens` must match the number of batches in `state`.
/// `tokens` may have slots with no tokens; for those slots, `run` skips computation and returns an empty vector in the corresponding output slot.
+    #[allow(clippy::type_complexity)]
fn run_with_hooks(
&self,
tokens: &mut Vec<ModelInput>,
state: &Self::State,
-        hooks: &HookMap<Self::Hook, Self, Self::State, Self::Runtime>,
+        hooks: &HookMap<Self::Hook, Self::Tensor, Self::State, Self::Runtime, Self::Header>,
) -> impl Future<Output = Result<Vec<ModelOutput>, TensorError>>;
}

-impl<Hook, Runtime, Model, State> ModelRun for Model
+impl<Hook, Model, Tensor, State, Runtime, Header> ModelRun for Model
where
Hook: Hash,
-    Model: ModelRunInternal<Hook = Hook, Runtime = Runtime, State = State>,
State: super::ModelState,
+    Model: ModelRunInternal<
+        Hook = Hook,
+        Tensor = Tensor,
+        State = State,
+        Runtime = Runtime,
+        Header = Header,
+    >,
{
type Hook = Hook;
-    type Runtime = Runtime;
type State = State;
+    type Tensor = Tensor;
+    type Runtime = Runtime;
+    type Header = Header;

+    #[inline]
+    fn tensor(&self) -> &Self::Tensor {
+        <Self as ModelRunInternal>::tensor(self)
+    }
+
async fn run(
&self,
@@ -141,7 +165,7 @@ where
&self,
tokens: &mut Vec<ModelInput>,
state: &Self::State,
-        hooks: &HookMap<Self::Hook, Self, Self::State, Self::Runtime>,
+        hooks: &HookMap<Self::Hook, Self::Tensor, Self::State, Self::Runtime, Self::Header>,
) -> Result<Vec<ModelOutput>, TensorError> {
let num_token: usize = tokens.iter().map(|input| input.tokens.len()).sum();
let num_batch = state.num_batch();
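The one functional change inside `create_input` is the trailing `.map(|x| Hom::co_hom(*x))`, which lifts the stored `f16` embedding values into whichever runtime float `F` the caller selected. The `Hom` trait's definition is not part of this diff; a plausible shape, stated purely as an assumption for orientation, would be:

```rust
use half::f16;

// Hypothetical sketch of the trait imported from crate::num above;
// the actual definition in the repository may differ.
pub trait Hom<T>: Sized {
    /// Map a stored value of type `T` into this runtime float type.
    fn co_hom(value: T) -> Self;
}

impl Hom<f16> for f16 {
    fn co_hom(value: f16) -> Self {
        value // identity when running at half precision
    }
}

impl Hom<f16> for f32 {
    fn co_hom(value: f16) -> Self {
        value.to_f32() // widen when running at full precision
    }
}
```

Under this reading, `F` in `create_input` is inferred from the requested `TensorStack<'a, F>`, so the same f16 embedding matrix serves both running precisions.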