Skip to content

Commit

Permalink
sample code for SIMD on ARM NEON
Browse files Browse the repository at this point in the history
  • Loading branch information
clorton committed Jun 19, 2024
1 parent 1455ae2 commit 8832833
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
21 changes: 21 additions & 0 deletions jb/src/test_ua.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ctypes

import numpy as np

# Requires openmp installed via homebrew
# Use .so for suffix rather then .dylib even though we are on macOS
# g++ -shared -fPIC -O3 -flto -fpermissive -I/opt/homebrew/Cellar/libomp/18.1.6/include -std=c++11 -Xpreprocessor -fopenmp -L/opt/homebrew/Cellar/libomp/18.1.6/lib -lomp -o update_ages.so update_ages.cpp

update_ages_lib = ctypes.CDLL("./update_ages.so")
update_ages_lib.update_ages.argtypes = [
ctypes.c_size_t, # start_idx
ctypes.c_size_t, # stop_idx
np.ctypeslib.ndpointer(dtype=np.float32, flags="C_CONTIGUOUS"),
]

ages = np.random.randint(-4, 28, 1024).astype(np.float32)
bges = np.array(ages)

update_ages_lib.update_ages(0, len(ages), bges)
print(ages[0:16])
print(bges[0:16])
34 changes: 31 additions & 3 deletions jb/src/update_ages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
#include <math.h>
#include <pthread.h>
#include <omp.h>

#ifdef __AVX2__
#include <immintrin.h>
#endif

#ifdef __ARM_NEON
#include <arm_neon.h>
#endif


#define SIMD_WIDTH 8 // AVX2 processes 8 integers at a time
unsigned recovered_counter = 0;

extern "C" {
Expand Down Expand Up @@ -51,6 +58,7 @@ void update_ages_vanilla(unsigned long int start_idx, unsigned long int stop_idx
}
}

#ifdef __AVX2__
// avxv2
void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float *ages) {
#pragma omp parallel for
Expand All @@ -71,6 +79,26 @@ void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float
_mm256_storeu_ps(&ages[i], ages_vec);
}
}
#endif

#ifdef __ARM_NEON
// ARM NEON version of update_ages() function
void update_ages(size_t start_idx, size_t stop_idx, float *ages) {

const float32x4_t one_day_vec = vdupq_n_f32(one_day); // Create a vector with the constant value of one_day
const float32x4_t zero_vec = vdupq_n_f32(0.0f); // Create a vector with the constant value of zero

// Iterate over the indices in the specified range, processing 4 elements at a time
for (size_t i = start_idx; i < stop_idx; i += 4) {

float32x4_t current = vld1q_f32(&ages[i]); // Load the age values into a vector// Load 4 age values into a vector
uint32x4_t mask = vcgtq_f32(current, zero_vec); // Create a mask vector to identify the ages that are greater than zero
float32x4_t updated = vaddq_f32(current, one_day_vec); // Add one_day_vec to ages
float32x4_t write = vbslq_f32(mask, updated, current); // Use the mask to select the updated age values or the original age values
vst1q_f32(&ages[i], write); // Store the updated age values back into the array
}
}
#endif

/*
* Progress all infections. Collect the indexes of those who recover.
Expand Down Expand Up @@ -227,12 +255,12 @@ void handle_new_infections_mp(
int * num_eligible_agents_array
) {
//printf( "handle_new_infections_mp: start_idx=%ld, end_idx=%ld.\n", start_idx, end_idx );
std::unordered_map<int, std::vector<int>> node2sus;
std::unordered_map<int, std::vector<int> > node2sus;

#pragma omp parallel
{
// Thread-local buffers to collect susceptible indices by node
std::unordered_map<int, std::vector<int>> local_node2sus;
std::unordered_map<int, std::vector<int> > local_node2sus;

#pragma omp for nowait
for (unsigned long int i = start_idx; i <= end_idx; ++i) {
Expand Down

0 comments on commit 8832833

Please sign in to comment.