diff --git a/src/state.rs b/src/state.rs index da97e84..c469081 100644 --- a/src/state.rs +++ b/src/state.rs @@ -70,13 +70,15 @@ impl Value for u32 { #[derive(Debug, Clone, Copy)] pub struct WaveValue { pub value: u32, + pub warp_size: usize, pub default_lane: Option, pub mutations: Option<[bool; 32]>, } impl WaveValue { - pub fn new(value: u32) -> Self { + pub fn new(value: u32, warp_size: usize) -> Self { Self { value, + warp_size, default_lane: None, mutations: None, } @@ -92,7 +94,7 @@ impl WaveValue { } pub fn apply_muts(&mut self) { self.value = 0; - for lane in 0..32 { + for lane in 0..self.warp_size { if self.mutations.unwrap()[lane] { self.value |= 1 << lane; } @@ -106,16 +108,40 @@ mod test_state { #[test] fn test_wave_value() { - let mut val = WaveValue::new(0b11000000000000011111111111101110); + let mut val = WaveValue::new(0b11000000000000011111111111101110, 32); val.default_lane = Some(0); assert!(!val.read()); val.default_lane = Some(31); assert!(val.read()); } + #[test] + fn test_wave_value_small() { + let mut val = WaveValue::new(0, 1); + val.default_lane = Some(0); + assert!(!val.read()); + assert_eq!(val.value, 0); + val.set_lane(true); + val.apply_muts(); + assert!(val.read()); + assert_eq!(val.value, 1); + } + + #[test] + fn test_wave_value_small_alt() { + let mut val = WaveValue::new(0, 2); + val.default_lane = Some(0); + assert!(!val.read()); + assert_eq!(val.value, 0); + val.set_lane(true); + val.apply_muts(); + assert!(val.read()); + assert_eq!(val.value, 1); + } + #[test] fn test_wave_value_mutations() { - let mut val = WaveValue::new(0b10001); + let mut val = WaveValue::new(0b10001, 32); val.default_lane = Some(0); val.set_lane(false); assert!(val.mutations.unwrap().iter().all(|x| !x)); diff --git a/src/thread.rs b/src/thread.rs index 68b09d3..2097eb0 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -28,6 +28,7 @@ pub struct Thread<'a> { pub stream: Vec, pub simm: Option, pub sgpr_co: &'a mut Option<(usize, WaveValue)>, + pub warp_size: usize, pub scalar: bool, } @@ -992,7 +993,7 @@ impl<'a> Thread<'a> { let sdst = ((instr >> 8) & 0x7f) as usize; let f = |i: u32| -> usize { ((instr >> i) & 0x1ff) as usize }; let (s0, s1, s2) = (f(32), f(41), f(50)); - let mut carry_in = WaveValue::new(self.val(s2)); + let mut carry_in = WaveValue::new(self.val(s2), self.warp_size); carry_in.default_lane = self.vcc.default_lane; let omod = (instr >> 59) & 0x3; let _neg = (instr >> 61) & 0x7; @@ -1352,7 +1353,8 @@ impl<'a> Thread<'a> { 796 => s0 * 2f32.powi(s1.to_bits() as i32), // cnd_mask isn't a float only ALU but supports neg 257 => { - let mut cond = WaveValue::new(s2.to_bits()); + let mut cond = + WaveValue::new(s2.to_bits(), self.warp_size); cond.default_lane = self.vcc.default_lane; match cond.read() { true => s1, @@ -1795,7 +1797,7 @@ impl<'a> Thread<'a> { let mut wv = self .sgpr_co .map(|(_, wv)| wv) - .unwrap_or_else(|| WaveValue::new(0)); + .unwrap_or_else(|| WaveValue::new(0, self.warp_size)); wv.default_lane = self.vcc.default_lane; wv.set_lane(val); *self.sgpr_co = Some((idx, wv)); @@ -3777,8 +3779,8 @@ fn _helper_test_thread() -> Thread<'static> { let static_sgpr: &'static mut Vec = Box::leak(Box::new(vec![0; 256])); let static_vgpr: &'static mut VGPR = Box::leak(Box::new(VGPR::new())); let static_scc: &'static mut u32 = Box::leak(Box::new(0)); - let static_exec: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(u32::MAX))); - let static_vcc: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(0))); + let static_exec: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(u32::MAX, 32))); + let static_vcc: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(0, 32))); let static_sds: &'static mut VecDataStore = Box::leak(Box::new(VecDataStore::new())); let static_co: &'static mut Option<(usize, WaveValue)> = Box::leak(Box::new(None)); @@ -3794,6 +3796,7 @@ fn _helper_test_thread() -> Thread<'static> { pc_offset: 0, stream: vec![], sgpr_co: static_co, + warp_size: 32, scalar: false, }; thread.vec_reg.default_lane = Some(0); diff --git a/src/work_group.rs b/src/work_group.rs index ec2d075..3db488a 100644 --- a/src/work_group.rs +++ b/src/work_group.rs @@ -106,7 +106,11 @@ impl<'a> WorkGroup<'a> { }; let (mut vec_reg, mut vcc, mut exec) = match wave_state { Some(val) => (val.2.clone(), val.3.clone(), val.4.clone()), - _ => (VGPR::new(), WaveValue::new(0), WaveValue::new(u32::MAX)), + _ => ( + VGPR::new(), + WaveValue::new(0, threads.len()), + WaveValue::new((1 << threads.len()) - 1, threads.len()), + ), }; let mut seeded_lanes = vec![]; @@ -167,6 +171,7 @@ impl<'a> WorkGroup<'a> { stream: self.kernel[pc..self.kernel.len()].to_vec(), scalar: false, simm: None, + warp_size: threads.len(), sgpr_co: &mut sgpr_co, }; thread.interpret()?;