From 3b634d04218cb03865c66fea91fe88dae1db13d4 Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 23 Jul 2024 14:01:21 +0200
Subject: [PATCH 1/4] Fix check_labels function

---
 python/cuml/cuml/prims/label/classlabels.py | 70 ++++++++++++++-------
 python/cuml/cuml/tests/test_prims.py        | 10 +++
 2 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/python/cuml/cuml/prims/label/classlabels.py b/python/cuml/cuml/prims/label/classlabels.py
index 978c436626..a59f5e0431 100644
--- a/python/cuml/cuml/prims/label/classlabels.py
+++ b/python/cuml/cuml/prims/label/classlabels.py
@@ -50,27 +50,41 @@
 validate_kernel_str = r"""
-({0} *x, int x_n, {0} *labels, int n_labels, int *out) {
+({0} *labels, int n_labels,
+ {0} *classes, int n_classes,
+ int n_passes, int pass_size,
+ int *out) {
 
   int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  bool found = false;
 
-  extern __shared__ {0} label_cache[];
-  for(int i = threadIdx.x; i < n_labels; i+=blockDim.x)
-    label_cache[i] = labels[i];
-
-  if(tid >= x_n) return;
-
-  __syncthreads();
+  extern __shared__ {0} class_cache[];
+
+  int unmapped_class;
+  if (tid < n_labels)
+    unmapped_class = labels[tid];
 
+  for (int pass = 0; pass < n_passes; pass++) {
+    int offset = pass * pass_size;
+    int to_analyze = min(pass_size, n_classes - offset);
+    for (int i = threadIdx.x; i < to_analyze; i+=blockDim.x)
+      class_cache[i] = classes[offset + i];
+
+    __syncthreads();
+
+    if (!found && tid < n_labels) {
+      for(int i = 0; i < to_analyze; i++) {
+        if(class_cache[i] == unmapped_class) {
+          found = true;
+          break;
+        }
+      }
+    }
-  int unmapped_label = x[tid];
-  bool found = false;
-  for(int i = 0; i < n_labels; i++) {
-    if(label_cache[i] == unmapped_label) {
-      found = true;
-      break;
+    if (pass < n_passes - 1) {
+      __syncthreads();
+    } else {
+      if (!found && tid < n_labels) out[0] = 0;
     }
   }
-
-  if(!found) out[0] = 0;
 }
 """
@@ -191,14 +205,28 @@ def check_labels(labels, classes) -> bool:
     if labels.ndim != 1:
         raise ValueError("Labels array must be 1D")
 
-    valid = cp.array([1])
+    n_labels = int(labels.shape[0])
+    n_classes = int(classes.shape[0])
 
-    smem = labels.dtype.itemsize * int(classes.shape[0])
+    device = cp.cuda.Device()
+    device_properties = device.attributes
+    shared_mem_per_block = device_properties['MaxSharedMemoryPerBlock']
+    pass_size = min(n_classes, math.floor(shared_mem_per_block / labels.dtype.itemsize))
+    n_passes = math.ceil(n_classes / pass_size)
+
+    threads_per_block = 512
+    n_blocks = math.ceil(n_labels / threads_per_block)
+    smem = labels.dtype.itemsize * pass_size
+
+    valid = cp.array([1])
 
     validate = _validate_kernel(labels.dtype)
     validate(
-        (math.ceil(labels.shape[0] / 32),),
-        (32,),
-        (labels, labels.shape[0], classes, classes.shape[0], valid),
+        (n_blocks,),
+        (threads_per_block,),
+        (labels, n_labels,
+         classes, n_classes,
+         n_passes, pass_size,
+         valid),
         shared_mem=smem,
     )
diff --git a/python/cuml/cuml/tests/test_prims.py b/python/cuml/cuml/tests/test_prims.py
index c15aa4cf24..7584840dd0 100644
--- a/python/cuml/cuml/tests/test_prims.py
+++ b/python/cuml/cuml/tests/test_prims.py
@@ -86,3 +86,13 @@ def test_monotonic_validate_invert_labels(arr_type, dtype, copy):
 
     assert array_equal(monotonic, arr_orig)
     assert array_equal(inverted, original)
+
+
+def test_check_labels():
+    n_labels, n_classes = 1_000_000, 8000
+    labels = cp.random.choice(n_classes, size=n_labels)
+    classes = cp.arange(n_classes)
+
+    assert check_labels(labels, classes) == True
+    labels[534_122] = 9123
+    assert check_labels(labels, classes) == False
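For review context, the rewritten kernel amounts to a chunked membership test: each pass stages one pass_size-sized slice of `classes` in shared memory, every label is checked against that slice, and a label counts as valid as soon as any pass finds it. Below is a minimal host-side sketch of that logic, assuming NumPy semantics; the name check_labels_reference and the 48 KB shared-memory default are illustrative only and not part of the patch.

import math

import numpy as np


def check_labels_reference(labels, classes, shared_mem_per_block=48 * 1024):
    # One pass can cache only as many classes as fit in shared memory.
    pass_size = min(len(classes), shared_mem_per_block // labels.dtype.itemsize)
    n_passes = math.ceil(len(classes) / pass_size)

    found = np.zeros(len(labels), dtype=bool)
    for p in range(n_passes):
        # The slice of classes that a single pass would stage in shared memory.
        chunk = classes[p * pass_size : (p + 1) * pass_size]
        found |= np.isin(labels, chunk)

    # Equivalent to out[0] staying 1 only if every label was found in some pass.
    return bool(found.all())


labels = np.random.randint(0, 8000, size=1_000_000)
classes = np.arange(8000)
assert check_labels_reference(labels, classes)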
From 8964859aa73012e66b3c2c0aac8d8808569e69bf Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Fri, 26 Jul 2024 15:31:42 +0200
Subject: [PATCH 2/4] Fix style

---
 python/cuml/cuml/prims/label/classlabels.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/cuml/cuml/prims/label/classlabels.py b/python/cuml/cuml/prims/label/classlabels.py
index a59f5e0431..c287e80176 100644
--- a/python/cuml/cuml/prims/label/classlabels.py
+++ b/python/cuml/cuml/prims/label/classlabels.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -211,7 +211,8 @@ def check_labels(labels, classes) -> bool:
     device = cp.cuda.Device()
     device_properties = device.attributes
     shared_mem_per_block = device_properties['MaxSharedMemoryPerBlock']
-    pass_size = min(n_classes, math.floor(shared_mem_per_block / labels.dtype.itemsize))
+    pass_size = min(n_classes,
+                    math.floor(shared_mem_per_block / labels.dtype.itemsize))
     n_passes = math.ceil(n_classes / pass_size)
 
     threads_per_block = 512

From 2ac5c9086042bb6fdc151778551b1159c4ea1210 Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Mon, 29 Jul 2024 12:38:27 +0200
Subject: [PATCH 3/4] Fix style

---
 python/cuml/cuml/prims/label/classlabels.py | 12 +++++-------
 python/cuml/cuml/tests/test_prims.py        |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/cuml/cuml/prims/label/classlabels.py b/python/cuml/cuml/prims/label/classlabels.py
index c287e80176..e23c985b7e 100644
--- a/python/cuml/cuml/prims/label/classlabels.py
+++ b/python/cuml/cuml/prims/label/classlabels.py
@@ -210,9 +210,10 @@ def check_labels(labels, classes) -> bool:
 
     device = cp.cuda.Device()
     device_properties = device.attributes
-    shared_mem_per_block = device_properties['MaxSharedMemoryPerBlock']
-    pass_size = min(n_classes,
-                    math.floor(shared_mem_per_block / labels.dtype.itemsize))
+    shared_mem_per_block = device_properties["MaxSharedMemoryPerBlock"]
+    pass_size = min(
+        n_classes, math.floor(shared_mem_per_block / labels.dtype.itemsize)
+    )
     n_passes = math.ceil(n_classes / pass_size)
 
     threads_per_block = 512
@@ -224,10 +225,7 @@ def check_labels(labels, classes) -> bool:
     validate(
         (n_blocks,),
         (threads_per_block,),
-        (labels, n_labels,
-         classes, n_classes,
-         n_passes, pass_size,
-         valid),
+        (labels, n_labels, classes, n_classes, n_passes, pass_size, valid),
         shared_mem=smem,
     )
diff --git a/python/cuml/cuml/tests/test_prims.py b/python/cuml/cuml/tests/test_prims.py
index 7584840dd0..df77777804 100644
--- a/python/cuml/cuml/tests/test_prims.py
+++ b/python/cuml/cuml/tests/test_prims.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
From 05cea8a1be7bc52b407737465479593dcbbe85cc Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 30 Jul 2024 12:42:32 +0200
Subject: [PATCH 4/4] Fix style

---
 python/cuml/cuml/tests/test_prims.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/cuml/cuml/tests/test_prims.py b/python/cuml/cuml/tests/test_prims.py
index df77777804..87a43f22a2 100644
--- a/python/cuml/cuml/tests/test_prims.py
+++ b/python/cuml/cuml/tests/test_prims.py
@@ -93,6 +93,11 @@ def test_check_labels():
     labels = cp.random.choice(n_classes, size=n_labels)
     classes = cp.arange(n_classes)
 
-    assert check_labels(labels, classes) == True
+    assert check_labels(labels, classes)
     labels[534_122] = 9123
-    assert check_labels(labels, classes) == False
+    assert not check_labels(labels, classes)
+    labels[534_122] = 0
+    labels[11_728] = 9123
+    assert not check_labels(labels, classes)
+    labels[11_728] = 0
+    assert check_labels(labels, classes)
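With the series applied, the primitive is exercised the same way the new test does. A short usage sketch, assuming the import path already used by test_prims.py:

import cupy as cp

from cuml.prims.label import check_labels

# 8000 int64 classes (~62 KB of cache) exceed a typical 48 KB per-block
# shared-memory budget, so this exercises the new multi-pass path.
classes = cp.arange(8000)
labels = cp.random.choice(8000, size=1_000_000)

# All labels are drawn from the known class set, so the check passes.
assert check_labels(labels, classes)

# A single out-of-set value anywhere in the array must be caught.
labels[11_728] = 9123
assert not check_labels(labels, classes)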