-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
scale function (_get_mean_var) updated for dense arrays, speedup up to ~4.65x #3280
base: main
Are you sure you want to change the base?
Changes from all commits
3121e04
e4b82c8
82568da
061bf75
e05cf08
45aaeed
58805fb
ff92c96
746d9ec
9bcfa7c
56bd438
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,10 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray: | |
def _get_mean_var( | ||
X: _SupportedArray, *, axis: Literal[0, 1] = 0 | ||
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: | ||
if isinstance(X, sparse.spmatrix): | ||
if isinstance(X, np.ndarray): | ||
n_threads = numba.get_num_threads() | ||
mean, var = _compute_mean_var_dense(X, axis=axis, n_threads=n_threads) | ||
elif isinstance(X, sparse.spmatrix): | ||
mean, var = sparse_mean_variance_axis(X, axis=axis) | ||
else: | ||
mean = axis_mean(X, axis=axis, dtype=np.float64) | ||
|
@@ -46,6 +49,42 @@ def _get_mean_var( | |
return mean, var | ||
|
||
|
||
@numba.njit(cache=True, parallel=True) | ||
def _compute_mean_var_dense( | ||
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1 | ||
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: | ||
if axis == 0: | ||
axis_i = 1 | ||
sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64) | ||
sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64) | ||
mean = np.zeros(X.shape[axis_i], dtype=np.float64) | ||
var = np.zeros(X.shape[axis_i], dtype=np.float64) | ||
n = X.shape[axis] | ||
for i in numba.prange(n_threads): | ||
for r in range(i, n, n_threads): | ||
for c in range(X.shape[axis_i]): | ||
value = X[r, c] | ||
sums[i, c] += value | ||
sums_squared[i, c] += value * value | ||
for c in numba.prange(X.shape[axis_i]): | ||
sum_ = sums[:, c].sum() | ||
mean[c] = sum_ / n | ||
var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1) | ||
else: | ||
axis_i = 0 | ||
mean = np.zeros(X.shape[axis_i], dtype=np.float64) | ||
var = np.zeros(X.shape[axis_i], dtype=np.float64) | ||
Comment on lines
+74
to
+76
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Pull this out of the if branch. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean these assignments. When you have two branches, and both start with the same 3 lines, just do those before the if statement instead. |
||
for r in numba.prange(X.shape[0]): | ||
for c in range(X.shape[1]): | ||
value = X[r, c] | ||
mean[r] += value | ||
var[r] += value * value | ||
for c in numba.prange(X.shape[0]): | ||
mean[c] = mean[c] / X.shape[1] | ||
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no return statement. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have updated the code, please check. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well, the tests aren’t passing, so it still doesn’t seem to be working. |
||
|
||
return mean, var | ||
|
||
def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int): | ||
""" | ||
This code and internal functions are based on sklearns | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of adding a second code path that handles np.ndarray, you should replace the existing one above: