Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multithreading #496

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions bench/large_array_vs_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#################################################################################
# To mimic the scenario that computation is i/o bound and constrained by memory
#
# It's a much simplified version that the chunk is computed in a loop,
# and expression is evaluated in a sequence, which is not true in reality.
# Neverthless, numexpr outperforms numpy.
#################################################################################
"""
Benchmarking Expression 1:
NumPy time (threaded over 32 chunks with 2 threads): 4.612313 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 0.951172 seconds
numexpr speedup: 4.85x
----------------------------------------
Benchmarking Expression 2:
NumPy time (threaded over 32 chunks with 2 threads): 23.862752 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.182058 seconds
numexpr speedup: 10.94x
----------------------------------------
Benchmarking Expression 3:
NumPy time (threaded over 32 chunks with 2 threads): 20.594895 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.927881 seconds
numexpr speedup: 7.03x
----------------------------------------
Benchmarking Expression 4:
NumPy time (threaded over 32 chunks with 2 threads): 12.834101 seconds
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 5.392480 seconds
numexpr speedup: 2.38x
----------------------------------------
"""

import os

os.environ["NUMEXPR_NUM_THREADS"] = "16"
import numpy as np
import numexpr as ne
import timeit
import threading

array_size = 10**8
num_runs = 10
num_chunks = 32 # Number of chunks
num_threads = 2 # Number of threads constrained by how many chunks memory can hold

a = np.random.rand(array_size).reshape(10**4, -1)
b = np.random.rand(array_size).reshape(10**4, -1)
c = np.random.rand(array_size).reshape(10**4, -1)

chunk_size = array_size // num_chunks

expressions_numpy = [
lambda a, b, c: a + b * c,
lambda a, b, c: a**2 + b**2 - 2 * a * b * np.cos(c),
lambda a, b, c: np.sin(a) + np.log(b) * np.sqrt(c),
lambda a, b, c: np.exp(a) + np.tan(b) - np.sinh(c),
]

expressions_numexpr = [
"a + b * c",
"a**2 + b**2 - 2 * a * b * cos(c)",
"sin(a) + log(b) * sqrt(c)",
"exp(a) + tan(b) - sinh(c)",
]


def benchmark_numpy_chunk(func, a, b, c, results, indices):
for index in indices:
start = index * chunk_size
end = (index + 1) * chunk_size
time_taken = timeit.timeit(
lambda: func(a[start:end], b[start:end], c[start:end]), number=num_runs
)
results.append(time_taken)


def benchmark_numexpr_re_evaluate(expr, a, b, c, results, indices):
for index in indices:
start = index * chunk_size
end = (index + 1) * chunk_size
if index == 0:
# Evaluate the first chunk with evaluate
time_taken = timeit.timeit(
lambda: ne.evaluate(
expr,
local_dict={
"a": a[start:end],
"b": b[start:end],
"c": c[start:end],
},
),
number=num_runs,
)
else:
# Re-evaluate subsequent chunks with re_evaluate
time_taken = timeit.timeit(
lambda: ne.re_evaluate(
local_dict={"a": a[start:end], "b": b[start:end], "c": c[start:end]}
),
number=num_runs,
)
results.append(time_taken)


def run_benchmark_threaded():
chunk_indices = list(range(num_chunks))

for i in range(len(expressions_numpy)):
print(f"Benchmarking Expression {i+1}:")

results_numpy = []
results_numexpr = []

threads_numpy = []
for j in range(num_threads):
indices = chunk_indices[j::num_threads] # Distribute chunks across threads
thread = threading.Thread(
target=benchmark_numpy_chunk,
args=(expressions_numpy[i], a, b, c, results_numpy, indices),
)
threads_numpy.append(thread)
thread.start()

for thread in threads_numpy:
thread.join()

numpy_time = sum(results_numpy)
print(
f"NumPy time (threaded over {num_chunks} chunks with {num_threads} threads): {numpy_time:.6f} seconds"
)

threads_numexpr = []
for j in range(num_threads):
indices = chunk_indices[j::num_threads] # Distribute chunks across threads
thread = threading.Thread(
target=benchmark_numexpr_re_evaluate,
args=(expressions_numexpr[i], a, b, c, results_numexpr, indices),
)
threads_numexpr.append(thread)
thread.start()

for thread in threads_numexpr:
thread.join()

numexpr_time = sum(results_numexpr)
print(
f"numexpr time (threaded with re_evaluate over {num_chunks} chunks with {num_threads} threads): {numexpr_time:.6f} seconds"
)
print(f"numexpr speedup: {numpy_time / numexpr_time:.2f}x")
print("-" * 40)


if __name__ == "__main__":
run_benchmark_threaded()
8 changes: 3 additions & 5 deletions numexpr/necompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
from numexpr import interpreter, expressions, use_vml
from numexpr.utils import CacheDict
from numexpr.utils import CacheDict, ContextDict

# Declare a double type that does not exist in Python space
double = numpy.double
Expand Down Expand Up @@ -776,11 +776,9 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
# Dictionaries for caching variable names and compiled expressions
_names_cache = CacheDict(256)
_numexpr_cache = CacheDict(256)
_numexpr_last = {}
_numexpr_last = ContextDict()
evaluate_lock = threading.Lock()

# MAYBE: decorate this function to add attributes instead of having the
# _numexpr_last dictionary?
def validate(ex: str,
local_dict: Optional[Dict] = None,
global_dict: Optional[Dict] = None,
Expand Down Expand Up @@ -887,7 +885,7 @@ def validate(ex: str,
compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
kwargs = {'out': out, 'order': order, 'casting': casting,
'ex_uses_vml': ex_uses_vml}
_numexpr_last = dict(ex=compiled_ex, argnames=names, kwargs=kwargs)
_numexpr_last.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
except Exception as e:
return e
return None
Expand Down
72 changes: 72 additions & 0 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,7 @@ def run(self):
test.join()

def test_multithread(self):

import threading

# Running evaluate() from multiple threads shouldn't crash
Expand All @@ -1218,6 +1219,77 @@ def work(n):
for t in threads:
t.join()

def test_thread_safety(self):
"""
Expected output

When not safe (before the pr this test is commited)
AssertionError: Thread-0 failed: result does not match expected

When safe (after the pr this test is commited)
Should pass without failure
"""
import threading
import time

barrier = threading.Barrier(4)

# Function that each thread will run with different expressions
def thread_function(a_value, b_value, expression, expected_result, results, index):
validate(expression, local_dict={"a": a_value, "b": b_value})
# Wait for all threads to reach this point
# such that they all set _numexpr_last
barrier.wait()

# Simulate some work or a context switch delay
time.sleep(0.1)

result = re_evaluate(local_dict={"a": a_value, "b": b_value})
results[index] = np.array_equal(result, expected_result)

def test_thread_safety_with_numexpr():
num_threads = 4
array_size = 1000000

expressions = [
"a + b",
"a - b",
"a * b",
"a / b"
]

a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]

expected_results = [
a_value[i] + b_value[i] if expr == "a + b" else
a_value[i] - b_value[i] if expr == "a - b" else
a_value[i] * b_value[i] if expr == "a * b" else
a_value[i] / b_value[i] if expr == "a / b" else None
for i, expr in enumerate(expressions)
]

results = [None] * num_threads
threads = []

# Create and start threads with different expressions
for i in range(num_threads):
thread = threading.Thread(
target=thread_function,
args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

for i in range(num_threads):
if not results[i]:
self.fail(f"Thread-{i} failed: result does not match expected")

test_thread_safety_with_numexpr()


# The worker function for the subprocess (needs to be here because Windows
# has problems pickling nested functions with the multiprocess module :-/)
Expand Down
81 changes: 81 additions & 0 deletions numexpr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
import subprocess
import contextvars

from numexpr.interpreter import _set_num_threads, _get_num_threads, MAX_THREADS
from numexpr import use_vml
Expand Down Expand Up @@ -226,3 +227,83 @@ def __setitem__(self, key, value):
super(CacheDict, self).__delitem__(k)
super(CacheDict, self).__setitem__(key, value)


class ContextDict:
"""
A context aware version dictionary
"""
def __init__(self):
self._context_data = contextvars.ContextVar('context_data', default={})

def set(self, key=None, value=None, **kwargs):
data = self._context_data.get().copy()

if key is not None:
data[key] = value

for k, v in kwargs.items():
data[k] = v

self._context_data.set(data)

def get(self, key, default=None):
data = self._context_data.get()
return data.get(key, default)

def delete(self, key):
data = self._context_data.get().copy()
if key in data:
del data[key]
self._context_data.set(data)

def clear(self):
self._context_data.set({})

def all(self):
return self._context_data.get()

def update(self, *args, **kwargs):
data = self._context_data.get().copy()

if args:
if len(args) > 1:
raise TypeError(f"update() takes at most 1 positional argument ({len(args)} given)")
other = args[0]
if isinstance(other, dict):
data.update(other)
else:
for k, v in other:
data[k] = v

data.update(kwargs)
self._context_data.set(data)

def keys(self):
return self._context_data.get().keys()

def values(self):
return self._context_data.get().values()

def items(self):
return self._context_data.get().items()

def __getitem__(self, key):
return self.get(key)

def __setitem__(self, key, value):
self.set(key, value)

def __delitem__(self, key):
self.delete(key)

def __contains__(self, key):
return key in self._context_data.get()

def __len__(self):
return len(self._context_data.get())

def __iter__(self):
return iter(self._context_data.get())

def __repr__(self):
return repr(self._context_data.get())
Loading