From 88328337b0eeb116100a6ae95fb2ba22227b9dcc Mon Sep 17 00:00:00 2001 From: Christopher Lorton Date: Tue, 18 Jun 2024 17:08:27 -0700 Subject: [PATCH] sample code for SIMD on ARM NEON --- jb/src/test_ua.py | 21 +++++++++++++++++++++ jb/src/update_ages.cpp | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 jb/src/test_ua.py diff --git a/jb/src/test_ua.py b/jb/src/test_ua.py new file mode 100644 index 0000000..e439211 --- /dev/null +++ b/jb/src/test_ua.py @@ -0,0 +1,21 @@ +import ctypes + +import numpy as np + +# Requires openmp installed via homebrew +# Use .so for suffix rather then .dylib even though we are on macOS +# g++ -shared -fPIC -O3 -flto -fpermissive -I/opt/homebrew/Cellar/libomp/18.1.6/include -std=c++11 -Xpreprocessor -fopenmp -L/opt/homebrew/Cellar/libomp/18.1.6/lib -lomp -o update_ages.so update_ages.cpp + +update_ages_lib = ctypes.CDLL("./update_ages.so") +update_ages_lib.update_ages.argtypes = [ + ctypes.c_size_t, # start_idx + ctypes.c_size_t, # stop_idx + np.ctypeslib.ndpointer(dtype=np.float32, flags="C_CONTIGUOUS"), +] + +ages = np.random.randint(-4, 28, 1024).astype(np.float32) +bges = np.array(ages) + +update_ages_lib.update_ages(0, len(ages), bges) +print(ages[0:16]) +print(bges[0:16]) diff --git a/jb/src/update_ages.cpp b/jb/src/update_ages.cpp index a4afeed..608159a 100644 --- a/jb/src/update_ages.cpp +++ b/jb/src/update_ages.cpp @@ -15,9 +15,16 @@ #include #include #include + +#ifdef __AVX2__ #include +#endif + +#ifdef __ARM_NEON +#include +#endif + -#define SIMD_WIDTH 8 // AVX2 processes 8 integers at a time unsigned recovered_counter = 0; extern "C" { @@ -51,6 +58,7 @@ void update_ages_vanilla(unsigned long int start_idx, unsigned long int stop_idx } } +#ifdef __AVX2__ // avxv2 void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float *ages) { #pragma omp parallel for @@ -71,6 +79,26 @@ void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float _mm256_storeu_ps(&ages[i], ages_vec); } } +#endif + +#ifdef __ARM_NEON +// ARM NEON version of update_ages() function +void update_ages(size_t start_idx, size_t stop_idx, float *ages) { + + const float32x4_t one_day_vec = vdupq_n_f32(one_day); // Create a vector with the constant value of one_day + const float32x4_t zero_vec = vdupq_n_f32(0.0f); // Create a vector with the constant value of zero + + // Iterate over the indices in the specified range, processing 4 elements at a time + for (size_t i = start_idx; i < stop_idx; i += 4) { + + float32x4_t current = vld1q_f32(&ages[i]); // Load the age values into a vector// Load 4 age values into a vector + uint32x4_t mask = vcgtq_f32(current, zero_vec); // Create a mask vector to identify the ages that are greater than zero + float32x4_t updated = vaddq_f32(current, one_day_vec); // Add one_day_vec to ages + float32x4_t write = vbslq_f32(mask, updated, current); // Use the mask to select the updated age values or the original age values + vst1q_f32(&ages[i], write); // Store the updated age values back into the array + } +} +#endif /* * Progress all infections. Collect the indexes of those who recover. @@ -227,12 +255,12 @@ void handle_new_infections_mp( int * num_eligible_agents_array ) { //printf( "handle_new_infections_mp: start_idx=%ld, end_idx=%ld.\n", start_idx, end_idx ); - std::unordered_map> node2sus; + std::unordered_map > node2sus; #pragma omp parallel { // Thread-local buffers to collect susceptible indices by node - std::unordered_map> local_node2sus; + std::unordered_map > local_node2sus; #pragma omp for nowait for (unsigned long int i = start_idx; i <= end_idx; ++i) {