Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update ages with ARM NEON #15

Open
wants to merge 2 commits into
base: jbloedow/end_of_may_wip
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jb/src/sir_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def add_expansion_slots( columns, num_slots=settings.expansion_slots ):
num_slots = int(num_slots)
print( f"Adding {num_slots} expansion slots for future babies." )
new_ids = [ x for x in range( num_slots ) ]
new_nodes = np.ones( num_slots, dtype=np.uint32 )*-1
new_nodes = np.ones( num_slots, dtype=np.int32 )*-1
new_ages = np.ones( num_slots, dtype=np.float32 )*-1
new_infected = np.zeros( num_slots, dtype=bool )

Expand Down
21 changes: 21 additions & 0 deletions jb/src/test_ua.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ctypes

import numpy as np

# Requires openmp installed via homebrew
# Use .so for suffix rather than .dylib even though we are on macOS
# g++ -shared -fPIC -O3 -flto -fpermissive -I/opt/homebrew/Cellar/libomp/18.1.6/include -std=c++11 -Xpreprocessor -fopenmp -L/opt/homebrew/Cellar/libomp/18.1.6/lib -lomp -o update_ages.so update_ages.cpp

# Load the compiled shared library from the current working directory.
update_ages_lib = ctypes.CDLL("./update_ages.so")

# Declare the C signature: update_ages(start_idx, stop_idx, float *ages).
update_ages_lib.update_ages.argtypes = [
    ctypes.c_size_t,  # start_idx
    ctypes.c_size_t,  # stop_idx
    np.ctypeslib.ndpointer(dtype=np.float32, flags="C_CONTIGUOUS"),  # ages
]
# The C function returns void; without this ctypes assumes an int return value.
update_ages_lib.update_ages.restype = None

# 1024 random whole-number values in [-4, 28), stored as float32.
ages = np.random.randint(-4, 28, 1024).astype(np.float32)
# Separate copy for the library to modify in place; `ages` stays untouched
# so the two arrays can be compared below.
bges = np.array(ages)

update_ages_lib.update_ages(0, len(ages), bges)

# Show the first 16 values before (ages) and after (bges) the update.
print(ages[0:16])
print(bges[0:16])
34 changes: 31 additions & 3 deletions jb/src/update_ages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
#include <math.h>
#include <pthread.h>
#include <omp.h>

#ifdef __AVX2__
#include <immintrin.h>
#endif

#ifdef __ARM_NEON
#include <arm_neon.h>
#endif


#define SIMD_WIDTH 8 // Lanes in one 256-bit AVX2 register when each element is 32 bits wide
// File-scope counter, starts at zero.
// NOTE(review): presumably tallies recoveries -- its users are not visible here; confirm.
unsigned recovered_counter = 0;

extern "C" {
Expand Down Expand Up @@ -51,6 +58,7 @@ void update_ages_vanilla(unsigned long int start_idx, unsigned long int stop_idx
}
}

#ifdef __AVX2__
// AVX2 version of update_ages() function
void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float *ages) {
#pragma omp parallel for
Expand All @@ -71,6 +79,26 @@ void update_ages(unsigned long int start_idx, unsigned long int stop_idx, float
_mm256_storeu_ps(&ages[i], ages_vec);
}
}
#endif

#ifdef __ARM_NEON
// ARM NEON version of update_ages() function
void update_ages(size_t start_idx, size_t stop_idx, float *ages) {

const float32x4_t one_day_vec = vdupq_n_f32(one_day); // Create a vector with the constant value of one_day
const float32x4_t zero_vec = vdupq_n_f32(0.0f); // Create a vector with the constant value of zero

// Iterate over the indices in the specified range, processing 4 elements at a time
for (size_t i = start_idx; i < stop_idx; i += 4) {

float32x4_t current = vld1q_f32(&ages[i]); // Load the age values into a vector// Load 4 age values into a vector
uint32x4_t mask = vcgtq_f32(current, zero_vec); // Create a mask vector to identify the ages that are greater than zero
float32x4_t updated = vaddq_f32(current, one_day_vec); // Add one_day_vec to ages
float32x4_t write = vbslq_f32(mask, updated, current); // Use the mask to select the updated age values or the original age values
vst1q_f32(&ages[i], write); // Store the updated age values back into the array
}
}
#endif

/*
* Progress all infections. Collect the indexes of those who recover.
Expand Down Expand Up @@ -227,12 +255,12 @@ void handle_new_infections_mp(
int * num_eligible_agents_array
) {
//printf( "handle_new_infections_mp: start_idx=%ld, end_idx=%ld.\n", start_idx, end_idx );
std::unordered_map<int, std::vector<int>> node2sus;
std::unordered_map<int, std::vector<int> > node2sus;

#pragma omp parallel
{
// Thread-local buffers to collect susceptible indices by node
std::unordered_map<int, std::vector<int>> local_node2sus;
std::unordered_map<int, std::vector<int> > local_node2sus;

#pragma omp for nowait
for (unsigned long int i = start_idx; i <= end_idx; ++i) {
Expand Down
Loading