diff --git a/src/svf_simper.rs b/src/svf_simper.rs index 2e743c3..0aa10d8 100644 --- a/src/svf_simper.rs +++ b/src/svf_simper.rs @@ -319,7 +319,7 @@ where pub ic2eq: Simd, pi_over_sr: Simd, _behavior: PhantomData, - simd_impl: SimdImpl, + process_fn: fn(&mut Self, Simd) -> (Simd, Simd), } impl SVFSimper @@ -330,6 +330,19 @@ where let pi_over_sr = consts::PI / sample_rate; let (k, a1, a2, a3) = Self::compute_parameters(cutoff, resonance, pi_over_sr); + // Choose the processing function at initialization + let process_fn = match SimdImpl::detect() { + #[cfg(target_arch = "x86_64")] + SimdImpl::Avx512 if LANES == 16 => Self::process_avx512_safe, + #[cfg(target_arch = "x86_64")] + SimdImpl::Avx2 if LANES == 8 => Self::process_avx2_safe, + #[cfg(target_arch = "x86_64")] + SimdImpl::Sse2 if LANES == 4 => Self::process_sse2_safe, + #[cfg(target_arch = "aarch64")] + SimdImpl::Neon if LANES == 4 => Self::process_neon_safe, + _ => Self::process_generic, + }; + Self { k: Simd::splat(k), a1: Simd::splat(a1), @@ -338,8 +351,8 @@ where ic1eq: Simd::splat(0.0), ic2eq: Simd::splat(0.0), pi_over_sr: Simd::splat(pi_over_sr), + process_fn, _behavior: PhantomData, - simd_impl: SimdImpl::detect(), } } @@ -543,70 +556,112 @@ where (v1, v2) } - #[inline] + #[inline(always)] fn process(&mut self, v0: Simd) -> (Simd, Simd) { - unsafe { - match self.simd_impl { - #[cfg(target_arch = "x86_64")] - SimdImpl::Avx512 => { - if LANES == 16 { - // Convert from Simd to array - let input_array = v0.to_array(); - // Convert array to __m512 - let input_avx = _mm512_loadu_ps(input_array.as_ptr()); - // Process - let result = self.filter_avx512(input_avx); - // Convert back to Simd - let mut output_array = [0.0; LANES]; - _mm512_storeu_ps(output_array.as_mut_ptr(), result); - let output = Simd::from_array(output_array); - return (output, output); - } - } - #[cfg(target_arch = "x86_64")] - SimdImpl::Avx2 => { - if LANES == 8 { - let input_array = v0.to_array(); - let input_avx = _mm256_loadu_ps(input_array.as_ptr()); - let result = self.filter_avx2(input_avx); - let mut output_array = [0.0; LANES]; - _mm256_storeu_ps(output_array.as_mut_ptr(), result); - let output = Simd::from_array(output_array); - return (output, output); - } - } - #[cfg(target_arch = "x86_64")] - SimdImpl::Sse2 => { - if LANES == 4 { - let input_array = v0.to_array(); - let input_sse = _mm_loadu_ps(input_array.as_ptr()); - let result = self.filter_sse2(input_sse); - let mut output_array = [0.0; LANES]; - _mm_storeu_ps(output_array.as_mut_ptr(), result); - let output = Simd::from_array(output_array); - return (output, output); - } - } - #[cfg(target_arch = "aarch64")] - SimdImpl::Neon => { - if LANES == 4 { - let input_array = v0.to_array(); - let input_neon = vld1q_f32(input_array.as_ptr()); - let result = self.filter_neon(input_neon); - let mut output_array = [0.0; LANES]; - vst1q_f32(output_array.as_mut_ptr(), result); - let output = Simd::from_array(output_array); - return (output, output); - } - } - _ => {} - } - } + // Just call the function pointer - no runtime checks needed + (self.process_fn)(self, v0) + } - // Fall back to the generic implementation if no SIMD match - self.process_generic(v0) + // Implement separate process functions for each SIMD type + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx512f")] + unsafe fn process_avx512( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + let input_array = v0.to_array(); + let input_avx = _mm512_loadu_ps(input_array.as_ptr()); + let result = this.filter_avx512(input_avx); + let mut output_array = [0.0; LANES]; + _mm512_storeu_ps(output_array.as_mut_ptr(), result); + let output = Simd::from_array(output_array); + (output, output) + } + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn process_avx2( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + let input_array = v0.to_array(); + let input_avx = _mm256_loadu_ps(input_array.as_ptr()); + let result = this.filter_avx2(input_avx); + let mut output_array = [0.0; LANES]; + _mm256_storeu_ps(output_array.as_mut_ptr(), result); + let output = Simd::from_array(output_array); + (output, output) } + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "sse2")] + unsafe fn process_sse2( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + let input_array = v0.to_array(); + let input_sse = _mm_loadu_ps(input_array.as_ptr()); + let result = this.filter_sse2(input_sse); + let mut output_array = [0.0; LANES]; + _mm_storeu_ps(output_array.as_mut_ptr(), result); + let output = Simd::from_array(output_array); + (output, output) + } + + #[cfg(target_arch = "aarch64")] + #[target_feature(enable = "neon")] + unsafe fn process_neon( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + let input_array = v0.to_array(); + let input_neon = vld1q_f32(input_array.as_ptr()); + let result = this.filter_neon(input_neon); + let mut output_array = [0.0; LANES]; + vst1q_f32(output_array.as_mut_ptr(), result); + let output = Simd::from_array(output_array); + (output, output) + } + // Add safe wrapper functions for each SIMD variant + #[cfg(target_arch = "x86_64")] + fn process_avx512_safe( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + unsafe { Self::process_avx512(this, v0) } + } + + #[cfg(target_arch = "x86_64")] + fn process_avx2_safe( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + unsafe { Self::process_avx2(this, v0) } + } + + #[cfg(target_arch = "x86_64")] + fn process_sse2_safe( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + unsafe { Self::process_sse2(this, v0) } + } + + #[cfg(target_arch = "aarch64")] + fn process_neon_safe( + this: &mut Self, + v0: Simd, + ) -> (Simd, Simd) { + unsafe { Self::process_neon(this, v0) } + } + // The generic fallback + // #[inline(always)] + // fn process_generic( + // this: &mut Self, + // v0: Simd, + // ) -> (Simd, Simd) { + // this.process_generic(v0) + // } + #[inline] pub fn lowpass(&mut self, v0: Simd) -> Simd { let (_, v2) = self.process(v0);