Skip to content

Commit

Permalink
Use thiserror.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Mar 6, 2024
1 parent 45a9222 commit a0b2740
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 77 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.6.23"
version = "0.6.24"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ahash = "0.8"
anyhow = "1"
anyhow = "1.0"
bitflags = "2.3"
bytemuck = { version = "1.13", features = ["extern_crate_alloc"] }
derive-getters = "0.3"
Expand All @@ -28,8 +28,9 @@ lazy_static = "1.4"
log = "0.4"
regex = "1.10"
safetensors = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
trait-variant = "0.1"
uid = "0.1"
wasm-bindgen = "0.2"
Expand Down
16 changes: 4 additions & 12 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{borrow::Cow, sync::Arc};

use thiserror::Error;
use wasm_bindgen::prelude::wasm_bindgen;
use web_rwkv_derive::{Deref, DerefMut};
use wgpu::{
Expand Down Expand Up @@ -91,23 +92,14 @@ pub struct ContextBuilder {
}

#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum CreateEnvironmentError {
#[error("failed to request adaptor")]
RequestAdapterFailed,
#[error("failed to request device")]
RequestDeviceFailed,
}

impl std::fmt::Display for CreateEnvironmentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CreateEnvironmentError::RequestAdapterFailed => write!(f, "failed to request adaptor"),
CreateEnvironmentError::RequestDeviceFailed => write!(f, "failed to request device"),
}
}
}

impl std::error::Error for CreateEnvironmentError {}

impl<'a> ContextBuilder {
pub fn new(adapter: Adapter) -> Self {
Self {
Expand Down
16 changes: 4 additions & 12 deletions src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{collections::HashMap, future::Future};
use anyhow::Result;
use half::f16;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use wasm_bindgen::prelude::wasm_bindgen;

use self::{
Expand Down Expand Up @@ -31,23 +32,14 @@ pub enum ModelVersion {
}

#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
pub enum ModelError {
#[error("invalid model version")]
InvalidVersion,
#[error("no viable chunk size found")]
NoViableChunkSize,
}

impl std::fmt::Display for ModelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelError::InvalidVersion => write!(f, "invalid model version"),
ModelError::NoViableChunkSize => write!(f, "no viable chunk size found"),
}
}
}

impl std::error::Error for ModelError {}

#[wasm_bindgen]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
Expand Down
40 changes: 12 additions & 28 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{borrow::Cow, marker::PhantomData, sync::Arc};

use itertools::Itertools;
use thiserror::Error;
use web_rwkv_derive::JsError;
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Expand Down Expand Up @@ -94,49 +95,32 @@ impl<K: Kind> Device for Gpu<K> {
type Data = TensorBuffer;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, JsError)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error, JsError)]
pub enum TensorError {
#[error("list must not be empty")]
Empty,
#[error("data type mismatch")]
Type,
#[error("data size not match: {0} vs. {1}")]
Size(usize, usize),
#[error("batch size not match: {0} vs. {1}")]
Batch(usize, usize),
#[error("tensor shape not match: {0} vs. {1}")]
Shape(Shape, Shape),
#[error("cannot deduce dimension")]
Deduce,
BatchOutOfRange {
batch: usize,
max: usize,
},
#[error("batch {batch} out of range of max {max}")]
BatchOutOfRange { batch: usize, max: usize },
#[error("slice {start}..{end} out of range for dimension size {dim}")]
SliceOutOfRange {
dim: usize,
start: usize,
end: usize,
},
#[error("slice not contiguous")]
Contiguous,
}

impl std::fmt::Display for TensorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TensorError::Empty => write!(f, "list must not be empty"),
TensorError::Type => write!(f, "data type mismatch"),
TensorError::Size(a, b) => write!(f, "data size not match: {a} vs. {b}"),
TensorError::Batch(a, b) => write!(f, "batch size not match: {a} vs. {b}"),
TensorError::Shape(a, b) => write!(f, "tensor shape not match: {a} vs. {b}"),
TensorError::Deduce => write!(f, "cannot deduce dimension"),
TensorError::BatchOutOfRange { batch, max } => {
write!(f, "batch {batch} out of range of max {max}")
}
TensorError::SliceOutOfRange { dim, start, end } => write!(
f,
"slice {start}..{end} out of range for dimension size {dim}",
),
TensorError::Contiguous => write!(f, "slice not contiguous"),
}
}
}

impl std::error::Error for TensorError {}

/// Data defining a tensor view in shader.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct View {
Expand Down
26 changes: 5 additions & 21 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,20 @@
use ahash::{AHashMap as HashMap, AHashSet as HashSet};
use derive_getters::Getters;
use std::collections::BTreeMap;
use thiserror::Error;
use wasm_bindgen::prelude::wasm_bindgen;
use web_rwkv_derive::JsError;

#[derive(Debug, JsError)]
#[derive(Debug, Error, JsError)]
pub enum TokenizerError {
#[error("failed to parse vocabulary: {0}")]
FailedToParseVocabulary(serde_json::Error),
#[error("no matching token found")]
NoMatchingTokenFound,
#[error("out of range token: {0}")]
OutOfRangeToken(u16),
}

impl std::fmt::Display for TokenizerError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
TokenizerError::FailedToParseVocabulary(error) => {
write!(fmt, "failed to parse vocabulary: {error}")?;
}
TokenizerError::NoMatchingTokenFound => {
write!(fmt, "no matching token found")?;
}
TokenizerError::OutOfRangeToken(token) => {
write!(fmt, "out of range token: {token}")?;
}
}

Ok(())
}
}

impl std::error::Error for TokenizerError {}

#[wasm_bindgen]
#[derive(Debug, Clone, Getters)]
pub struct Tokenizer {
Expand Down

0 comments on commit a0b2740

Please sign in to comment.