diff --git a/README.md b/README.md index 529c67a..4ec2edb 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0. - > Warning: only works with pytorch > v1.3 and CUDA >= v10.1 ## Description @@ -24,39 +23,60 @@ the output is of size as `(nrows, ncols_v)`. If all input tensors are on GPU, a ## Installation +After setting up an environment with pytorch >= 1.3, run either of these commands from the root folder of the repo: -Just `python setup.py install`, in the root folder of this repo. This will compile -and install the torchsearchsorted module. -be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is) +```bash +pip install -v . +``` -be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3 +```bash +python setup.py install -v +``` -## Usage +The verbose flag `-v` is not mandatory, but it will print whether the installer was able to find `nvcc` and install the CUDA version of `torchsearchsorted`. +If you're having problems with the installation, make sure `nvcc` and `gcc` are installed and available in your path, e.g.: +```bash +export PATH="/usr/local/cuda/bin:${PATH}" +export CPATH="/usr/local/cuda/include:${CPATH}" -Just import the torchsearchsorted package after installation. I typically do: +which gcc +which nvcc +pip install -v . ``` + +## Usage + +```python +import torch from torchsearchsorted import searchsorted + +a = torch.sort(torch.randn(5000, 300, device='cuda'), dim=1)[0] +v = torch.randn(5000, 100, device='cuda') +out = searchsorted(a, v) ``` -## Testing +## Testing and benchmarking -Try `python test.py` with `torch` available for an example. +Install test dependencies and run the unit tests: +```bash +pip install '.[test]' +pytest -v +``` +Run [benchmark.py](examples/benchmark.py) for a speed comparison: +```bash +python examples/benchmark.py ``` -Looking for 50000x1000 values in 50000x300 entries -NUMPY: searchsorted in 4851.592ms -CPU: searchsorted in 4805.432ms - difference between CPU and NUMPY: 0.000 -GPU: searchsorted in 1.055ms - difference between GPU and NUMPY: 0.000 - -Looking for 50000x1000 values in 50000x300 entries -NUMPY: searchsorted in 4333.964ms -CPU: searchsorted in 4753.958ms - difference between CPU and NUMPY: 0.000 -GPU: searchsorted in 0.391ms - difference between GPU and NUMPY: 0.000 +```text +Benchmark searchsorted: +- a [5000 x 300] +- v [5000 x 100] +- reporting fastest time of 20 runs +- each run executes searchsorted 100 times + +Numpy: 3.4524286500000017 +CPU: 10.617608329001087 +CUDA: 0.00124932999824523 ``` -The first run comprises the time of allocation, while the second one does not. 
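For illustration, a minimal CPU-only sketch of the `side` argument and the row broadcasting described in the updated README (the values shown are just worked examples of the usual searchsorted conventions; they are not part of the diff):

```python
import torch
from torchsearchsorted import searchsorted

# One sorted row of bins, broadcast against two rows of values.
a = torch.tensor([[0., 1., 2., 3., 4.]])          # shape (1, 5)
v = torch.tensor([[-1., 2., 2.5],
                  [4., 0., 9.]])                  # shape (2, 3)

print(searchsorted(a, v, side='left'))   # tensor([[0, 2, 3], [4, 0, 5]])
print(searchsorted(a, v, side='right'))  # tensor([[0, 3, 3], [5, 1, 5]])
```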
diff --git a/examples/benchmark.py b/examples/benchmark.py new file mode 100644 index 0000000..1373995 --- /dev/null +++ b/examples/benchmark.py @@ -0,0 +1,66 @@ +import timeit + +import torch +import numpy as np +from torchsearchsorted import searchsorted, numpy_searchsorted + +B = 5_000 +A = 300 +V = 100 + +repeats = 20 +number = 100 + +print( + f'Benchmark searchsorted:', + f'- a [{B} x {A}]', + f'- v [{B} x {V}]', + f'- reporting fastest time of {repeats} runs', + f'- each run executes searchsorted {number} times', + sep='\n', + end='\n\n' +) + + +def get_arrays(): + a = np.sort(np.random.randn(B, A), axis=1) + v = np.random.randn(B, V) + out = np.empty_like(v, dtype=np.long) + return a, v, out + + +def get_tensors(device): + a = torch.sort(torch.randn(B, A, device=device), dim=1)[0] + v = torch.randn(B, V, device=device) + out = torch.empty(B, V, device=device, dtype=torch.long) + return a, v, out + + +numpy = timeit.repeat( + stmt="numpy_searchsorted(a, v, out, side='left')", + setup="a, v, out = get_arrays()", + globals=globals(), + repeat=repeats, + number=number +) +print('Numpy: ', min(numpy), sep='\t') + +cpu = timeit.repeat( + stmt="searchsorted(a, v, out, side='left')", + setup="a, v, out = get_tensors(device='cpu')", + globals=globals(), + repeat=repeats, + number=number +) +print('CPU: ', min(cpu), sep='\t') + +if torch.cuda.is_available(): + gpu = timeit.repeat( + stmt="searchsorted(a, v, out, side='left')", + setup="a, v, out = get_tensors(device='cuda')", + globals=globals(), + repeat=repeats, + number=number + ) + print('CUDA: ', min(gpu), sep='\t') + diff --git a/examples/test.py b/examples/test.py index 1caead5..760f811 100644 --- a/examples/test.py +++ b/examples/test.py @@ -33,7 +33,7 @@ # v = torch.tensor([[1.]]) t0 = time.time() - test_NP = torch.tensor(numpy_searchsorted(a, v, side)) + test_NP = torch.tensor(numpy_searchsorted(a, v, side=side)) print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0))) t0 = time.time() test_CPU = searchsorted(a, v, test_CPU, side) diff --git a/setup.py b/setup.py index 092bcd1..23a46d2 100644 --- a/setup.py +++ b/setup.py @@ -8,13 +8,17 @@ ['src/cpu/searchsorted_cpu_wrapper.cpp']), ] -# If nvcc is available, add the CUDA extension +# If nvcc is available, add the CUDA extension, messages are +# printed when using `pip install -v .` or `python setup.py -v install` if CUDA_HOME: + print('torchsearchsorted will be installed with CUDA support') modules.append( CUDAExtension('torchsearchsorted.cuda', ['src/cuda/searchsorted_cuda_wrapper.cpp', 'src/cuda/searchsorted_cuda_kernel.cu']) ) +else: + print('torchsearchsorted will be installed for CPU only') tests_require = [ 'pytest', diff --git a/src/cpu/searchsorted_cpu_wrapper.cpp b/src/cpu/searchsorted_cpu_wrapper.cpp index 617d1d1..3a7e77d 100644 --- a/src/cpu/searchsorted_cpu_wrapper.cpp +++ b/src/cpu/searchsorted_cpu_wrapper.cpp @@ -1,121 +1,91 @@ #include "searchsorted_cpu_wrapper.h" #include -int eval(float val, float *a, int64_t row, int64_t col, int64_t ncol, bool side_left) -{ - /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ - if (col == ncol - 1) - { - // special case: we are on the right border - if (a[row * ncol + col] <= val){ - return 1;} - else { - return -1;} - } - bool is_lower; - bool is_next_higher; - - if (side_left) { - // a[row, col] < v <= a[row, col+1] - is_lower = (a[row * ncol + col] < val); - is_next_higher = (a[row*ncol + col + 1] >= val); - } else { - // a[row, col] <= v < a[row, col+1] - is_lower = (a[row * ncol + col] <= val); - 
is_next_higher = (a[row * ncol + col + 1] > val);
-  }
-  if (is_lower && is_next_higher) {
-    // we found the right spot
-    return 0;
-  } else if (is_lower) {
-    // answer is on the right side
-    return 1;
+int64_t bisect_left(float *array, float value, int64_t left, int64_t right) {
+/**
+ * Locate the insertion point of a value in a sorted array that would
+ * maintain the array sorted, i.e. the index i such that:
+ *   array[i - 1] < value <= array[i]
+ * Only the index range [left, right) is considered.
+ *
+ * If the value is already present in the array, the returned index would
+ * insert the value to the left of any existing entry.
+ * If value is less than every element, the returned index is equal to left.
+ * If value is greater than or equal to every element, the returned index is equal to right.
+ */
+  int64_t mid;
+  while (left < right) {
+    mid = (left + right) / 2;
+    if (value > array[mid]) {
+      left = mid + 1;
    } else {
-      // answer is on the left side
-      return -1;
+      right = mid;
    }
+  }
+  return left;
 }
-int64_t binary_search(float *a, int64_t row, float val, int64_t ncol, bool side_left)
-{
-  /* Look for the value `val` within row `row` of matrix `a`, which
-  has `ncol` columns.
-
-  the `a` matrix is assumed sorted in increasing order, row-wise
-
-  returns:
-  * -1 if `val` is smaller than the smallest value found within that row of `a`
-  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
-  * Otherwise, return the column index `res` such that:
-    - a[row, col] < val <= a[row, col+1]. (if side_left), or
-    - a[row, col] < val <= a[row, col+1] (if not side_left).
-  */
-
-  //start with left at 0 and right at number of columns of a
-  int64_t right = ncol;
-  int64_t left = 0;
-
-  while (right >= left) {
-    // take the midpoint of current left and right cursors
-    int64_t mid = left + (right-left)/2;
-
-    // check the relative position of val: are we good here ?
-    int rel_pos = eval(val, a, row, mid, ncol, side_left);
-    // we found the point
-    if(rel_pos == 0) {
-      return mid;
-    } else if (rel_pos > 0) {
-      if (mid==ncol-1){return ncol-1;}
-      // the answer is on the right side
-      left = mid;
-    } else {
-      if (mid==0){return -1;}
-      right = mid;
-    }
+int64_t bisect_right(float *array, float value, int64_t left, int64_t right) {
+/**
+ * Locate the insertion point of a value in a sorted array that would
+ * maintain the array sorted, i.e. the index i such that:
+ *   array[i - 1] <= value < array[i]
+ * Only the index range [left, right) is considered.
+ *
+ * If the value is already present in the array, the returned index would
+ * insert the value to the right of any existing entry.
+ * If value is less than or equal to every element, the returned index is equal to left.
+ * If value is greater than every element, the returned index is equal to right.
+ */ + int64_t mid; + while (left < right) { + mid = (left + right) / 2; + if (value >= array[mid]) { + left = mid + 1; + } else { + right = mid; + } } - return -1; + return left; } -void searchsorted_cpu_wrapper( - at::Tensor a, - at::Tensor v, - at::Tensor res, - bool side_left) -{ - - // Get the dimensions - auto nrow_a = a.size(/*dim=*/0); - auto ncol_a = a.size(/*dim=*/1); - auto nrow_v = v.size(/*dim=*/0); - auto ncol_v = v.size(/*dim=*/1); - - auto nrow_res = fmax(nrow_a, nrow_v); - - //auto acc_v = v.accessor(); - //auto acc_res = res.accessor(); - - float *a_data = a.data(); - float *v_data = v.data(); - - for (int64_t row = 0; row < nrow_res; row++) - { - for (int64_t col = 0; col < ncol_v; col++) - { - // get the value to look for - int64_t row_in_v = (nrow_v == 1) ? 0 : row; - int64_t row_in_a = (nrow_a == 1) ? 0 : row; - - int64_t idx_in_v = row_in_v * ncol_v + col; - int64_t idx_in_res = row * ncol_v + col; - - // apply binary search - res.data()[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1); - }} +void searchsorted_cpu_wrapper( + at::Tensor a, + at::Tensor v, + at::Tensor res, + bool side_left) { + float *a_data = a.data_ptr(); + float *v_data = v.data_ptr(); + int64_t *res_data = res.data_ptr(); + + int64_t (*bisect)(float*, float, int64_t, int64_t); + if (side_left) { + bisect = &bisect_left; + } else { + bisect = &bisect_right; } - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)"); + for (int64_t i = 0; i < v.size(0); i++) { + // Search values in the range [left, right), i.e. an entire row of a + int64_t left = i * a.stride(0); + int64_t right = i * a.stride(0) + a.size(1); + + for (int64_t j = 0; j < v.size(1); j++) { + // idx_v is the location of the value in the flattened tensor v + // idx_res is the where the result will go in the flattened tensor res + int64_t idx_v = i * v.stride(0) + j * v.stride(1); + int64_t idx_res = i * res.stride(0) + j * res.stride(1); + // idx is the insertion index in the flattened tensor a + int64_t idx = (*bisect)(a_data, v_data[idx_v], left, right); + res_data[idx_res] = idx - i * a.stride(0); + } } +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)"); +} diff --git a/src/cuda/searchsorted_cuda_kernel.cu b/src/cuda/searchsorted_cuda_kernel.cu index af6ed27..485a8bd 100644 --- a/src/cuda/searchsorted_cuda_kernel.cu +++ b/src/cuda/searchsorted_cuda_kernel.cu @@ -1,142 +1,135 @@ +#include +#include + #include "searchsorted_cuda_kernel.h" + template __device__ -int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left) -{ - /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ - - if (col == ncol - 1) - { - // special case: we are on the right border - if (a[row * ncol + col] <= val){ - return 1;} - else { - return -1;} - } - bool is_lower; - bool is_next_higher; - - if (side_left) { - // a[row, col] < v <= a[row, col+1] - is_lower = (a[row * ncol + col] < val); - is_next_higher = (a[row*ncol + col + 1] >= val); - } else { - // a[row, col] <= v < a[row, col+1] - is_lower = (a[row * ncol + col] <= val); - is_next_higher = (a[row * ncol + col + 1] > val); - } - if (is_lower && is_next_higher) { - // we found the right spot - return 0; - } else if (is_lower) { - // answer is on the right side - return 1; +int64_t bisect_left(scalar_t *array, scalar_t value, int64_t left, int64_t right) { +/** + * Locate the 
insertion point of a value in a sorted array that would
+ * maintain the array sorted, i.e. the index i such that:
+ *   array[i - 1] < value <= array[i]
+ * Only the index range [left, right) is considered.
+ *
+ * If the value is already present in the array, the returned index would
+ * insert the value to the left of any existing entry.
+ * If value is less than every element, the returned index is equal to left.
+ * If value is greater than or equal to every element, the returned index is equal to right.
+ */
+  int64_t mid;
+  while (left < right) {
+    mid = (left + right) / 2;
+    if (value > array[mid]) {
+      left = mid + 1;
    } else {
-      // answer is on the left side
-      return -1;
+      right = mid;
    }
+  }
+  return left;
 }
+
 template <typename scalar_t>
 __device__
-int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
-{
-  /* Look for the value `val` within row `row` of matrix `a`, which
-  has `ncol` columns.
-
-  the `a` matrix is assumed sorted in increasing order, row-wise
-
-  Returns
-  * -1 if `val` is smaller than the smallest value found within that row of `a`
-  * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
-  * Otherwise, return the column index `res` such that:
-    - a[row, col] < val <= a[row, col+1]. (if side_left), or
-    - a[row, col] < val <= a[row, col+1] (if not side_left).
-  */
-
-  //start with left at 0 and right at number of columns of a
-  int64_t right = ncol;
-  int64_t left = 0;
-
-  while (right >= left) {
-    // take the midpoint of current left and right cursors
-    int64_t mid = left + (right-left)/2;
-
-    // check the relative position of val: are we good here ?
-    int rel_pos = eval(val, a, row, mid, ncol, side_left);
-    // we found the point
-    if(rel_pos == 0) {
-      return mid;
-    } else if (rel_pos > 0) {
-      if (mid==ncol-1){return ncol-1;}
-      // the answer is on the right side
-      left = mid;
-    } else {
-      if (mid==0){return -1;}
-      right = mid;
-    }
+int64_t bisect_right(scalar_t *array, scalar_t value, int64_t left, int64_t right) {
+/**
+ * Locate the insertion point of a value in a sorted array that would
+ * maintain the array sorted, i.e. the index i such that:
+ *   array[i - 1] <= value < array[i]
+ * Only the index range [left, right) is considered.
+ *
+ * If the value is already present in the array, the returned index would
+ * insert the value to the right of any existing entry.
+ * If value is less than or equal to every element, the returned index is equal to left.
+ * If value is greater than every element, the returned index is equal to right.
+ */
+  int64_t mid;
+  while (left < right) {
+    mid = (left + right) / 2;
+    if (value >= array[mid]) {
+      left = mid + 1;
+    } else {
+      right = mid;
+    }
  }
-  return -1;
+  return left;
 }
+
 template <typename scalar_t>
 __global__
 void searchsorted_kernel(
-  int64_t *res,
-  scalar_t *a,
-  scalar_t *v,
-  int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
-{
-  // get current row and column
-  int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
-  int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
-
-  // check whether we are outside the bounds of what needs be computed.
-  if ((row >= nrow_res) || (col >= ncol_v)) {
-    return;}
-
-  // get the value to look for
-  int64_t row_in_v = (nrow_v==1) ? 0: row;
-  int64_t row_in_a = (nrow_a==1) ? 
0: row; - int64_t idx_in_v = row_in_v*ncol_v+col; - int64_t idx_in_res = row*ncol_v+col; + at::cuda::detail::TensorInfo a, + at::cuda::detail::TensorInfo v, + at::cuda::detail::TensorInfo res, + bool side_left) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + int64_t j = blockIdx.y * blockDim.y + threadIdx.y; + + if ((i >= res.sizes[0]) || (j >= res.sizes[1])) { + return; + } - // apply binary search - res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1; + // Search values in the range [left, right), i.e. an entire row of a + int64_t left = i * a.strides[0]; + int64_t right = i * a.strides[0] + a.sizes[1]; + + // idx_v is the location of the value in the flattened tensor v + // idx_res is the where the result will go in the flattened tensor res + int64_t idx_v = i * v.strides[0] + j * v.strides[1]; + int64_t idx_res = i * res.strides[0] + j * res.strides[1]; + + // idx is the insertion index in the flattened tensor a + int64_t idx; + /* TODO this "if" works, but would be nicer to use function pointers: + * check side_left in searchsorted_cuda (on CPU) and pass the right pointer + * to the kernels (on GPU), but the fact that the bisect functions are + * templated and are defined with __device__ makes it hard to get the pointers + * right (the address on the CPU and on the GPU are different), see + * https://stackoverflow.com/questions/15644261/cuda-function-pointers + */ + if (side_left) { + idx = bisect_left(a.data, v.data[idx_v], left, right); + } else { + idx = bisect_right(a.data, v.data[idx_v], left, right); + } + res.data[idx_res] = idx - i * a.strides[0]; } - +__host__ void searchsorted_cuda( - at::Tensor a, - at::Tensor v, - at::Tensor res, - bool side_left){ - - // Get the dimensions - auto nrow_a = a.size(/*dim=*/0); - auto nrow_v = v.size(/*dim=*/0); - auto ncol_a = a.size(/*dim=*/1); - auto ncol_v = v.size(/*dim=*/1); - - auto nrow_res = fmax(double(nrow_a), double(nrow_v)); - - // prepare the kernel configuration - dim3 threads(ncol_v, nrow_res); - dim3 blocks(1, 1); - if (nrow_res*ncol_v > 1024){ - threads.x = int(fmin(double(1024), double(ncol_v))); - threads.y = floor(1024/threads.x); - blocks.x = ceil(double(ncol_v)/double(threads.x)); - blocks.y = ceil(double(nrow_res)/double(threads.y)); - } - - AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] { - searchsorted_kernel<<>>( - res.data(), - a.data(), - v.data(), - nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left); - })); - - } + at::Tensor a, + at::Tensor v, + at::Tensor res, + bool side_left) { + // Kernel configuration: + // - 2D grid of size v.size(0) x v.size(1) + // - The grid is partitioned in blocks of 256 x 4 + // - Each thread [i, j] will search for the value v[i, j] in the i-th row of a + dim3 threads(256, 4); + dim3 blocks( + (v.size(0) + threads.x - 1) / threads.x, + (v.size(1) + threads.y - 1) / threads.y + ); + + AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] { + /* Related to the comment in searchsorted_kernel, getting the address of a + * __device__ function from a __host__ function isn't straightforward, + * but here's a start + */ + // int64_t (*bisect)(scalar_t*, scalar_t, int64_t, int64_t); + // if (side_left) { + // bisect = &bisect_left; + // } else { + // bisect = &bisect_right; + // } + + searchsorted_kernel<<>>( + at::cuda::detail::getTensorInfo(a), + at::cuda::detail::getTensorInfo(v), + at::cuda::detail::getTensorInfo(res), + side_left); + })); +} diff --git a/src/cuda/searchsorted_cuda_wrapper.cpp b/src/cuda/searchsorted_cuda_wrapper.cpp 
index c11372e..157f25a 100644 --- a/src/cuda/searchsorted_cuda_wrapper.cpp +++ b/src/cuda/searchsorted_cuda_wrapper.cpp @@ -3,14 +3,17 @@ // C++ interface #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_INPUT(x) CHECK_CUDA(x); -void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left) +void searchsorted_cuda_wrapper( + at::Tensor a, + at::Tensor v, + at::Tensor res, + bool side_left) { - CHECK_INPUT(a); - CHECK_INPUT(v); - CHECK_INPUT(res); + CHECK_INPUT(a); + CHECK_INPUT(v); + CHECK_INPUT(res); searchsorted_cuda(a, v, res, side_left); } diff --git a/src/torchsearchsorted/__init__.py b/src/torchsearchsorted/__init__.py index fc30292..8a99aa2 100644 --- a/src/torchsearchsorted/__init__.py +++ b/src/torchsearchsorted/__init__.py @@ -1,2 +1,6 @@ from .searchsorted import searchsorted from .utils import numpy_searchsorted + +__all__ = [ + 'searchsorted', +] diff --git a/src/torchsearchsorted/searchsorted.py b/src/torchsearchsorted/searchsorted.py index aaca900..e5eea4e 100644 --- a/src/torchsearchsorted/searchsorted.py +++ b/src/torchsearchsorted/searchsorted.py @@ -1,53 +1,60 @@ +import warnings from typing import Optional import torch -# trying to import the CPU searchsorted -SEARCHSORTED_CPU_AVAILABLE = True -try: - from torchsearchsorted.cpu import searchsorted_cpu_wrapper -except ImportError: - SEARCHSORTED_CPU_AVAILABLE = False +from torchsearchsorted.cpu import searchsorted_cpu_wrapper -# trying to import the CUDA searchsorted -SEARCHSORTED_GPU_AVAILABLE = True -try: - from torchsearchsorted.cuda import searchsorted_cuda_wrapper -except ImportError: - SEARCHSORTED_GPU_AVAILABLE = False +if torch.cuda.is_available(): + try: + from torchsearchsorted.cuda import searchsorted_cuda_wrapper + except ImportError as e: + warnings.warn("PyTorch is installed with CUDA support, but " + "torchsearchsorted for CUDA was not installed, " + "please repeat the installation or avoid passing " + "CUDA tensors to the `searchsorted`.") -def searchsorted(a: torch.Tensor, v: torch.Tensor, +def searchsorted(a: torch.Tensor, + v: torch.Tensor, out: Optional[torch.LongTensor] = None, side='left') -> torch.LongTensor: - assert len(a.shape) == 2, "input `a` must be 2-D." - assert len(v.shape) == 2, "input `v` mus(t be 2-D." 
-    assert (a.shape[0] == v.shape[0]
-            or a.shape[0] == 1
-            or v.shape[0] == 1), ("`a` and `v` must have the same number of "
-                                  "rows or one of them must have only one ")
-    assert a.device == v.device, '`a` and `v` must be on the same device'
-
-    result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
+    if a.ndimension() != 2:
+        raise ValueError(f"Input `a` must be 2D, got shape {a.shape}")
+    if v.ndimension() != 2:
+        raise ValueError(f"Input `v` must be 2D, got shape {v.shape}")
+    if a.device != v.device:
+        raise ValueError(f"Inputs `a` and `v` must be on the same device, "
+                         f"got {a.device} and {v.device}")
+
+    a, v = broadcast_tensors(a, v, dim=0)
+
     if out is not None:
-        assert out.device == a.device, "`out` must be on the same device as `a`"
-        assert out.dtype == torch.long, "out.dtype must be torch.long"
-        assert out.shape == result_shape, ("If the output tensor is provided, "
-                                           "its shape must be correct.")
+        if out.shape != v.shape:
+            raise ValueError(f"Output `out` must have the same shape as `v`, "
+                             f"got {out.shape} and {v.shape}")
+        if out.device != v.device:
+            raise ValueError(f"Output `out` must be on the same device as `v`, "
+                             f"got {out.device} and {v.device}")
+        if out.dtype != torch.long:
+            raise ValueError(f"Output `out` must have dtype `torch.long`, "
+                             f"got {out.dtype}")
     else:
-        out = torch.empty(result_shape, device=v.device, dtype=torch.long)
-
-    if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
-        raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
-                        'that it is not available. Please install it')
-    if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
-        raise Exception('torchsearchsorted on CPU is not available. '
-                        'Please install it.')
+        out = torch.empty(v.shape, device=v.device, dtype=torch.long)
-    left_side = 1 if side=='left' else 0
+    left_side = side == 'left'
     if a.is_cuda:
         searchsorted_cuda_wrapper(a, v, out, left_side)
     else:
         searchsorted_cpu_wrapper(a, v, out, left_side)
     return out
+
+
+def broadcast_tensors(*tensors, dim=0):
+    """Broadcast tensors along one dimension, leaving other dims unchanged"""
+    if dim < 0:
+        raise ValueError(f"Negative dimensions not supported, got {dim}")
+    dim_size = max(t.shape[dim] for t in tensors)
+    return [t.expand(*t.shape[:dim], dim_size, *t.shape[dim + 1:])
+            for t in tensors]
diff --git a/src/torchsearchsorted/utils.py b/src/torchsearchsorted/utils.py
index 68b9939..c98a456 100644
--- a/src/torchsearchsorted/utils.py
+++ b/src/torchsearchsorted/utils.py
@@ -1,15 +1,25 @@
 import numpy as np
-def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
-    """Numpy version of searchsorted that works batch-wise on pytorch tensors
-    """
-    nrows_a = a.shape[0]
-    (nrows_v, ncols_v) = v.shape
-    nrows_out = max(nrows_a, nrows_v)
-    out = np.empty((nrows_out, ncols_v), dtype=np.long)
-    def sel(data, row):
-        return data[0] if data.shape[0] == 1 else data[row]
-    for row in range(nrows_out):
-        out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
+def numpy_searchsorted(a: np.ndarray, v: np.ndarray,
+                       out: np.ndarray=None, side='left') -> np.ndarray:
+    """Batch-wise version of numpy's searchsorted"""
+    a = np.asarray(a)
+    v = np.asarray(v)
+    a, v = broadcast_arrays(a, v, axis=0)
+    if out is None:
+        out = np.empty(v.shape, dtype=np.long)
+    for i in range(v.shape[0]):
+        out[i] = np.searchsorted(a[i], v[i], side=side)
     return out
+
+
+def broadcast_arrays(*arrays, axis=0):
+    """Broadcast arrays along one axis, leaving other axes unchanged"""
+    if axis < 0:
+        raise ValueError(f"Negative axis not 
supported, got {axis}") + axis_size = max(a.shape[axis] for a in arrays) + return [ + np.broadcast_to(a, (*a.shape[:axis], axis_size, *a.shape[axis + 1:])) + for a in arrays + ] diff --git a/test/test_searchsorted.py b/test/test_searchsorted.py index 27bfb49..7c0cd37 100644 --- a/test/test_searchsorted.py +++ b/test/test_searchsorted.py @@ -6,39 +6,124 @@ from itertools import product, repeat -def test_searchsorted_output_dtype(device): +def test_output_dtype(): B = 100 A = 50 V = 12 - a = torch.sort(torch.rand(B, V, device=device), dim=1)[0] - v = torch.rand(B, A, device=device) + a = torch.sort(torch.rand(B, A), dim=1)[0] + v = torch.rand(B, V) out = searchsorted(a, v) - out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy()) assert out.dtype == torch.long - np.testing.assert_array_equal(out.cpu().numpy(), out_np) - out = torch.empty(v.shape, dtype=torch.long, device=device) + out = torch.empty(v.shape, dtype=torch.long) searchsorted(a, v, out) assert out.dtype == torch.long - np.testing.assert_array_equal(out.cpu().numpy(), out_np) -Ba_val = [1, 100, 200] -Bv_val = [1, 100, 200] -A_val = [1, 50, 500] -V_val = [1, 12, 120] -side_val = ['left', 'right'] -nrepeat = 100 - -@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val)) -def test_searchsorted_correct(Ba, Bv, A, V, side, device): - if Ba > 1 and Bv > 1 and Ba != Bv: - return - for test in range(nrepeat): - a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0] - v = torch.rand(Bv, V, device=device) - out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), - side=side) - out = searchsorted(a, v, side=side).cpu().numpy() - np.testing.assert_array_equal(out, out_np) + with pytest.raises(ValueError): + out = torch.empty(v.shape, dtype=torch.float) + searchsorted(a, v, out) + + +def test_broadcast_batch_dim(): + # Batch dimension: + # (B, A), (B, V) -> (B, A), (B, V) + # (B, A), (1, V) -> (B, A), (B, V) + # (1, A), (B, V) -> (B, A), (B, V) + # (1, A), (1, V) -> (1, A), (1, V) + # (X, A), (Y, V) -> RuntimeError + + B = 6 + A = 3 + V = 4 + + a = torch.sort(torch.rand(B, A), dim=1)[0] + v = torch.rand(B, V) + out = searchsorted(a, v) + assert out.shape == (B, V) + + a = torch.sort(torch.rand(1, A), dim=1)[0] + v = torch.rand(B, V) + out = searchsorted(a, v) + assert out.shape == (B, V) + + a = torch.sort(torch.rand(B, A), dim=1)[0] + v = torch.rand(1, V) + out = searchsorted(a, v) + assert out.shape == (B, V) + + a = torch.sort(torch.rand(B, A), dim=1)[0] + v = torch.rand(B, V) + out = searchsorted(a, v) + assert out.shape == (B, V) + + a = torch.sort(torch.rand(7, A), dim=1)[0] + v = torch.rand(9, V) + with pytest.raises(RuntimeError): + searchsorted(a, v) + + +tests = { + 'left': { + 'a': [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + 'v': [[-99, 99, 2], [5, 9, 8]], + 'side': 'left', + 'expected': [[0, 5, 2], [0, 4, 3]], + }, + 'right': { + 'a': [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + 'v': [[-99, 99, 2], [5, 9, 8]], + 'side': 'right', + 'expected': [[0, 5, 3], [1, 5, 4]], + }, + 'left-broadcast v': { + 'a': [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + 'v': [[-99, 99, 2]], + 'side': 'left', + 'expected': [[0, 5, 2], [0, 5, 0]], + }, + 'right-broadcast v': { + 'a': [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + 'v': [[-99, 99, 2]], + 'side': 'right', + 'expected': [[0, 5, 3], [0, 5, 0]], + }, + 'left-broadcast a': { + 'a': [[0, 1, 2, 3, 4]], + 'v': [[-99, 99, 2], [99, -99, 3]], + 'side': 'left', + 'expected': [[0, 5, 2], [5, 0, 3]], + }, + 'right-broadcast a': { + 'a': [[0, 1, 2, 3, 4]], + 'v': [[-99, 99, 2], [99, 
-99, 3]], + 'side': 'right', + 'expected': [[0, 5, 3], [5, 0, 4]], + }, +} +@pytest.mark.parametrize('test', tests.values(), ids=list(tests.keys())) +def test_correct(test, device): + a = torch.tensor(test['a'], dtype=torch.float, device=device) + v = torch.tensor(test['v'], dtype=torch.float, device=device) + expected = torch.tensor(test['expected'], dtype=torch.long) + + out = searchsorted(a, v, side=test['side']) + np.testing.assert_array_equal(out.cpu().numpy(), expected.numpy()) + + +@pytest.mark.parametrize('Ba, Bv', [ + (Ba, Bv) for Ba, Bv in + product([1, 150, 300], [1, 150, 300]) + if Ba == Bv or ((Ba == 1) ^ (Bv == 1)) +]) +@pytest.mark.parametrize('A', [1, 40, 80]) +@pytest.mark.parametrize('V', [1, 40, 80]) +@pytest.mark.parametrize('side', ['left', 'right']) +def test_bigger_random(Ba, Bv, A, V, side, device): + a = torch.sort(torch.randn(Ba, A, device=device), dim=1)[0] + v = torch.randn(Bv, V, device=device) + out = searchsorted(a, v, side=side) + + out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), side=side) + np.testing.assert_array_equal(out.cpu().numpy(), out_np)
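For reference, the `bisect_left`/`bisect_right` routines added to both the CPU and CUDA sources follow the standard insertion-point convention. A short pure-Python sketch of the `left` variant, mirroring the kernel logic above (illustrative only, not part of the diff):

```python
def bisect_left(array, value, left, right):
    """Return the first index i in [left, right) with array[i] >= value."""
    while left < right:
        mid = (left + right) // 2
        if value > array[mid]:
            left = mid + 1
        else:
            right = mid
    return left

row = [0., 1., 2., 3., 4.]
assert bisect_left(row, 2.0, 0, len(row)) == 2  # inserts before the existing 2.0
assert bisect_left(row, 9.0, 0, len(row)) == 5  # past the last element
```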