diff --git a/Cargo.toml b/Cargo.toml
index bb64cc7..f95733d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
 license = "MIT OR Apache-2.0"
 name = "web-rwkv"
 repository = "https://github.com/cryscan/web-rwkv"
-version = "0.8.4"
+version = "0.8.5"
 
 [dependencies]
 ahash = "0.8"
diff --git a/src/context.rs b/src/context.rs
index 6f8dece..3141826 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -344,7 +344,15 @@ impl ContextInternal {
 
         let data = {
             let map = slice.get_mapped_range();
-            map.to_vec().into_boxed_slice()
+            let len = map.len();
+            let size = std::mem::size_of::<u32>();
+            let data = vec![0u32; (len + size - 1) / size].into_boxed_slice();
+            unsafe {
+                let data = Box::leak(data);
+                let data: &mut [u8] = bytemuck::cast_slice_mut(data);
+                data.copy_from_slice(&map);
+                Box::from_raw(data)
+            }
         };
         buffer.unmap();
         data
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index d0a6c5f..281d93d 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -428,7 +428,7 @@ impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorGpu<T, K> {
     fn transfer_into(self, context: &Context) -> Self {
         match context {
             context if context == &self.context => self,
-            _ => self.back_local().transfer_into(context),
+            _ => self.back_in_place().transfer_into(context),
         }
     }
 }
@@ -474,7 +474,7 @@ impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
 
 impl<T: Scalar, K: Kind> TensorGpu<T, K> {
     #[cfg(not(target_arch = "wasm32"))]
-    pub fn back_local(&self) -> TensorCpu<T> {
+    pub fn back_in_place(&self) -> TensorCpu<T> {
         use crate::context::ContextEvent;
 
         let context = &self.context;
@@ -495,10 +495,8 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
         });
         let data = receiver.blocking_recv().unwrap();
         let data = unsafe {
-            let len = data.len() / std::mem::size_of::<T>();
             let data = Box::leak(data);
-            let data = data.as_mut_ptr() as *mut T;
-            let slice = core::slice::from_raw_parts_mut(data, len);
+            let slice = bytemuck::cast_slice_mut::<_, T>(data);
             Box::from_raw(slice)
         };
         let data = data.into_vec().into();
@@ -534,10 +532,8 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
         });
         let data = receiver.await.unwrap();
         let data = unsafe {
-            let len = data.len() / std::mem::size_of::<T>();
             let data = Box::leak(data);
-            let data = data.as_mut_ptr() as *mut T;
-            let slice = core::slice::from_raw_parts_mut(data, len);
+            let slice = bytemuck::cast_slice_mut::<_, T>(data);
             Box::from_raw(slice)
         };
         let data = data.into_vec().into();
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index bca9567..89a953f 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -2457,7 +2457,7 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let x_host = x_dev.back_local().to_vec();
+        let x_host = x_dev.back_in_place().to_vec();
 
         let mut ans = vec![];
         for x in &x.into_iter().chunks(C) {
@@ -2524,8 +2524,8 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let x_host = x_dev.back_local().to_vec();
-        // let s_host = s_dev.back_local().to_vec();
+        let x_host = x_dev.back_in_place().to_vec();
+        // let s_host = s_dev.back_in_place().to_vec();
 
         // test recenter and rms norm
         let shape = Shape::new(C, T, B, 1);
@@ -2541,7 +2541,7 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let x_rms_host = x_dev.back_local().to_vec();
+        let x_rms_host = x_dev.back_in_place().to_vec();
 
         let mut ans = vec![];
         // let mut ans_stats = vec![];
@@ -2669,7 +2669,7 @@ mod tests {
         // profiler.resolve_queries(&mut encoder);
         context.queue.submit(Some(encoder.finish()));
 
-        let output_host = output_dev.back_local();
+        let output_host = output_dev.back_in_place();
         let output_host = Vec::from(output_host);
 
         // profiler.end_frame().unwrap();
@@ -2797,8 +2797,8 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let matrix_u8_host = matrix_u8_dev.back_local().to_vec();
-        let output_host = output_dev.back_local().to_vec();
+        let matrix_u8_host = matrix_u8_dev.back_in_place().to_vec();
+        let output_host = output_dev.back_in_place().to_vec();
 
         // let mut truth = vec![0.0; output_host.len()];
         // for token in 0..T {
@@ -2971,9 +2971,9 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let matrix_u4_host = matrix_u4_dev.back_local().to_vec();
-        let absmax_host = absmax_dev.back_local().to_vec();
-        let output_host = output_dev.back_local().to_vec();
+        let matrix_u4_host = matrix_u4_dev.back_in_place().to_vec();
+        let absmax_host = absmax_dev.back_in_place().to_vec();
+        let output_host = output_dev.back_in_place().to_vec();
 
         let mut truth = vec![0.0; output_host.len()];
         for token in 0..T {
@@ -3070,7 +3070,7 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let output_host = output.back_local();
+        let output_host = output.back_in_place();
         let output_host = Vec::from(output_host);
 
         assert_eq!(
@@ -3107,7 +3107,7 @@ mod tests {
         drop(pass);
         context.queue.submit(Some(encoder.finish()));
 
-        let output_host = output.back_local();
+        let output_host = output.back_in_place();
         let output_host: Vec<f32> = Vec::from(output_host);
 
         assert_eq!(
diff --git a/src/tensor/serialization.rs b/src/tensor/serialization.rs
index 5149e22..46f9ad0 100644
--- a/src/tensor/serialization.rs
+++ b/src/tensor/serialization.rs
@@ -60,7 +60,7 @@ impl<T: Scalar, K: Kind> Serialize for TensorGpu<T, K> {
     where
         S: serde::Serializer,
    {
-        TensorBlob::from(self.back_local()).serialize(serializer)
+        TensorBlob::from(self.back_in_place()).serialize(serializer)
     }
 }