Skip to content

Commit

Permalink
Fix the alignment issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 7, 2024
1 parent e771b15 commit ed7f2f5
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.8.4"
version = "0.8.5"

[dependencies]
ahash = "0.8"
Expand Down
10 changes: 9 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,15 @@ impl ContextInternal {

let data = {
let map = slice.get_mapped_range();
map.to_vec().into_boxed_slice()
let len = map.len();
let size = std::mem::size_of::<u32>();
let data = vec![0u32; (len + size - 1) / size].into_boxed_slice();
unsafe {
let data = Box::leak(data);
let data: &mut [u8] = bytemuck::cast_slice_mut(data);
data.copy_from_slice(&map);
Box::from_raw(data)
}
};
buffer.unmap();
data
Expand Down
12 changes: 4 additions & 8 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite>
fn transfer_into(self, context: &Context) -> Self {
match context {
context if context == &self.context => self,
_ => self.back_local().transfer_into(context),
_ => self.back_in_place().transfer_into(context),
}
}
}
Expand Down Expand Up @@ -474,7 +474,7 @@ impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
#[cfg(not(target_arch = "wasm32"))]
pub fn back_local(&self) -> TensorCpu<T> {
pub fn back_in_place(&self) -> TensorCpu<T> {
use crate::context::ContextEvent;

let context = &self.context;
Expand All @@ -495,10 +495,8 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
});
let data = receiver.blocking_recv().unwrap();
let data = unsafe {
let len = data.len() / std::mem::size_of::<T>();
let data = Box::leak(data);
let data = data.as_mut_ptr() as *mut T;
let slice = core::slice::from_raw_parts_mut(data, len);
let slice = bytemuck::cast_slice_mut::<_, T>(data);
Box::from_raw(slice)
};
let data = data.into_vec().into();
Expand Down Expand Up @@ -534,10 +532,8 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
});
let data = receiver.await.unwrap();
let data = unsafe {
let len = data.len() / std::mem::size_of::<T>();
let data = Box::leak(data);
let data = data.as_mut_ptr() as *mut T;
let slice = core::slice::from_raw_parts_mut(data, len);
let slice = bytemuck::cast_slice_mut::<_, T>(data);
Box::from_raw(slice)
};
let data = data.into_vec().into();
Expand Down
24 changes: 12 additions & 12 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let x_host = x_dev.back_local().to_vec();
let x_host = x_dev.back_in_place().to_vec();

let mut ans = vec![];
for x in &x.into_iter().chunks(C) {
Expand Down Expand Up @@ -2524,8 +2524,8 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let x_host = x_dev.back_local().to_vec();
// let s_host = s_dev.back_local().to_vec();
let x_host = x_dev.back_in_place().to_vec();
// let s_host = s_dev.back_in_place().to_vec();

// test recenter and rms norm
let shape = Shape::new(C, T, B, 1);
Expand All @@ -2541,7 +2541,7 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let x_rms_host = x_dev.back_local().to_vec();
let x_rms_host = x_dev.back_in_place().to_vec();

let mut ans = vec![];
// let mut ans_stats = vec![];
Expand Down Expand Up @@ -2669,7 +2669,7 @@ mod tests {
// profiler.resolve_queries(&mut encoder);
context.queue.submit(Some(encoder.finish()));

let output_host = output_dev.back_local();
let output_host = output_dev.back_in_place();
let output_host = Vec::from(output_host);

// profiler.end_frame().unwrap();
Expand Down Expand Up @@ -2797,8 +2797,8 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let matrix_u8_host = matrix_u8_dev.back_local().to_vec();
let output_host = output_dev.back_local().to_vec();
let matrix_u8_host = matrix_u8_dev.back_in_place().to_vec();
let output_host = output_dev.back_in_place().to_vec();

// let mut truth = vec![0.0; output_host.len()];
// for token in 0..T {
Expand Down Expand Up @@ -2971,9 +2971,9 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let matrix_u4_host = matrix_u4_dev.back_local().to_vec();
let absmax_host = absmax_dev.back_local().to_vec();
let output_host = output_dev.back_local().to_vec();
let matrix_u4_host = matrix_u4_dev.back_in_place().to_vec();
let absmax_host = absmax_dev.back_in_place().to_vec();
let output_host = output_dev.back_in_place().to_vec();

let mut truth = vec![0.0; output_host.len()];
for token in 0..T {
Expand Down Expand Up @@ -3070,7 +3070,7 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let output_host = output.back_local();
let output_host = output.back_in_place();
let output_host = Vec::from(output_host);

assert_eq!(
Expand Down Expand Up @@ -3107,7 +3107,7 @@ mod tests {
drop(pass);
context.queue.submit(Some(encoder.finish()));

let output_host = output.back_local();
let output_host = output.back_in_place();
let output_host: Vec<f32> = Vec::from(output_host);

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion src/tensor/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<T: Scalar + Serialize, K: Kind> Serialize for TensorGpu<T, K> {
where
S: serde::Serializer,
{
TensorBlob::from(self.back_local()).serialize(serializer)
TensorBlob::from(self.back_in_place()).serialize(serializer)
}
}

Expand Down

0 comments on commit ed7f2f5

Please sign in to comment.