Fix check_labels function #5971

Open: wants to merge 7 commits into base: branch-24.08
python/cuml/cuml/prims/label/classlabels.py (71 changes: 49 additions & 22 deletions)
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,27 +50,41 @@


 validate_kernel_str = r"""
-({0} *x, int x_n, {0} *labels, int n_labels, int *out) {
+({0} *labels, int n_labels,
+ {0} *classes, int n_classes,
+ int n_passes, int pass_size,
+ int *out) {
   int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  bool found = false;
 
-  extern __shared__ {0} label_cache[];
-  for(int i = threadIdx.x; i < n_labels; i+=blockDim.x)
-    label_cache[i] = labels[i];
-
-  if(tid >= x_n) return;
-
-  __syncthreads();
+  extern __shared__ {0} class_cache[];
+
+  int unmapped_class;
+  if (tid < n_labels)
+    unmapped_class = labels[tid];
+
+  for (int pass = 0; pass < n_passes; pass++) {
+    int offset = pass * pass_size;
+    int to_analyze = min(pass_size, n_classes - offset);
+    for (int i = threadIdx.x; i < to_analyze; i+=blockDim.x)
+      class_cache[i] = classes[offset + i];
+
+    __syncthreads();
+
+    if (!found && tid < n_labels) {

Member: found will always be false at this point?

Contributor (Author): A thread with tid >= n_labels does not check the presence of a class for any label, but it is still necessary so that the class data is fetched correctly on every pass. Its found value will indeed always be false, but a thread with tid >= n_labels never writes the final output, so it should be safe. (A host-side sketch of this pass structure follows the diff below.)

+      for(int i = 0; i < to_analyze; i++) {
+        if(class_cache[i] == unmapped_class) {
+          found = true;
+          break;
+        }
+      }
+    }
-
-  int unmapped_label = x[tid];
-  bool found = false;
-  for(int i = 0; i < n_labels; i++) {
-    if(label_cache[i] == unmapped_label) {
-      found = true;
-      break;
-    }
-  }
-
-  if(!found) out[0] = 0;
+
+    if (pass < n_passes - 1) {
+      __syncthreads();
+    } else {
+      if (!found && tid < n_labels) out[0] = 0;
+    }
+  }
 }
 """

@@ -191,14 +205,27 @@ def check_labels(labels, classes) -> bool:
     if labels.ndim != 1:
         raise ValueError("Labels array must be 1D")
 
-    valid = cp.array([1])
-
-    smem = labels.dtype.itemsize * int(classes.shape[0])
+    n_labels = int(labels.shape[0])
+    n_classes = int(classes.shape[0])
+
+    device = cp.cuda.Device()
+    device_properties = device.attributes
+    shared_mem_per_block = device_properties["MaxSharedMemoryPerBlock"]
+    pass_size = min(
+        n_classes, math.floor(shared_mem_per_block / labels.dtype.itemsize)
+    )
+    n_passes = math.ceil(n_classes / pass_size)
+
+    threads_per_block = 512
+    n_blocks = math.ceil(n_labels / threads_per_block)
+    smem = labels.dtype.itemsize * pass_size
 
+    valid = cp.array([1])
     validate = _validate_kernel(labels.dtype)
     validate(
-        (math.ceil(labels.shape[0] / 32),),
-        (32,),
-        (labels, labels.shape[0], classes, classes.shape[0], valid),
+        (n_blocks,),
+        (threads_per_block,),
+        (labels, n_labels, classes, n_classes, n_passes, pass_size, valid),
         shared_mem=smem,
     )

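As a quick sanity check on the new launch configuration, here is the arithmetic for one plausible setup. The 48 KB per-block shared-memory limit and int64 labels are assumptions for illustration; the real value comes from device.attributes["MaxSharedMemoryPerBlock"]. Note that the old code requested labels.dtype.itemsize * n_classes bytes of shared memory, which exceeds such a limit once the class count grows large enough; capping each pass at pass_size is what keeps the kernel launchable.

import math

shared_mem_per_block = 48 * 1024  # assumed 48 KB per-block limit
itemsize = 8                      # assumed int64 labels
n_classes = 8000

pass_size = min(n_classes, math.floor(shared_mem_per_block / itemsize))
n_passes = math.ceil(n_classes / pass_size)
smem = itemsize * pass_size

print(pass_size, n_passes, smem)  # 6144 2 49152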
python/cuml/cuml/tests/test_prims.py (17 changes: 16 additions & 1 deletion)

@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -86,3 +86,18 @@ def test_monotonic_validate_invert_labels(arr_type, dtype, copy):
     assert array_equal(monotonic, arr_orig)
 
     assert array_equal(inverted, original)
+
+
+def test_check_labels():
+    n_labels, n_classes = 1_000_000, 8000
+    labels = cp.random.choice(n_classes, size=n_labels)
+    classes = cp.arange(n_classes)
+
+    assert check_labels(labels, classes)
+    labels[534_122] = 9123
+    assert not check_labels(labels, classes)
+    labels[534_122] = 0
+    labels[11_728] = 9123
+    assert not check_labels(labels, classes)
+    labels[11_728] = 0
+    assert check_labels(labels, classes)
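
As a usage-level cross-check when running this test locally, the same predicate can be expressed in one line of CuPy; this is only a reference oracle, not how check_labels runs on device:

import cupy as cp

labels = cp.random.choice(8000, size=1_000_000)
classes = cp.arange(8000)

# True iff every label appears in classes, matching check_labels' semantics.
expected = bool(cp.isin(labels, classes).all())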