From 284e31a11da06fe25cf5dff1ded8433b4fb4288b Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Sat, 9 Mar 2024 20:59:42 +0000 Subject: [PATCH] Search sorted (#91) --- Cargo.lock | 290 +----------------- vortex-array/Cargo.toml | 3 - .../src/array/primitive/compute/mod.rs | 6 + .../array/primitive/compute/search_sorted.rs | 57 ++++ vortex-array/src/array/sparse/compute.rs | 32 +- vortex-array/src/array/sparse/mod.rs | 6 +- vortex-array/src/compute/mod.rs | 5 + vortex-array/src/compute/search_sorted.rs | 74 +---- vortex-array/src/error.rs | 18 -- vortex-array/src/lib.rs | 1 - vortex-array/src/polars.rs | 101 ------ vortex-ree/src/ree.rs | 2 +- 12 files changed, 117 insertions(+), 478 deletions(-) create mode 100644 vortex-array/src/array/primitive/compute/search_sorted.rs delete mode 100644 vortex-array/src/polars.rs diff --git a/Cargo.lock b/Cargo.lock index 989887765d..597b337547 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,15 +94,6 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b3d0060af21e8d11a926981cc00c6c1541aa91dd64b9f881985c3da1094425f" -[[package]] -name = "argminmax" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" -dependencies = [ - "num-traits", -] - [[package]] name = "arrayref" version = "0.3.7" @@ -325,12 +316,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "atoi_simd" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" - [[package]] name = "autocfg" version = "1.1.0" @@ -397,7 +382,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.52", + "syn", "which", ] @@ -445,20 +430,6 @@ name = "bytemuck" version = "1.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.52", -] [[package]] name = "byteorder" @@ -768,7 +739,7 @@ checksum = "27540baf49be0d484d8f0130d7d8da3011c32a44d4fc873368154f1510e574a2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -808,18 +779,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "ethnum" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" - -[[package]] -name = "fast-float" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" - [[package]] name = "fastlanez-sys" version = "0.1.0" @@ -879,12 +838,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" -[[package]] -name = "foreign_vec" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -949,10 +902,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", ] [[package]] @@ -1005,7 +956,6 @@ checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", - "rayon", ] [[package]] @@ -1344,7 +1294,7 @@ checksum = "adf157a4dc5a29b7b464aa8fe7edeff30076e07e13646a1c3874f58477dc99f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -1425,28 +1375,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "multiversion" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "target-features", -] - [[package]] name = "native-tls" version = "0.2.11" @@ -1585,7 +1513,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -1641,7 +1569,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -1790,143 +1718,6 @@ dependencies = [ "plotters-backend", ] -[[package]] -name = "polars-arrow" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faacd21a2548fa6d50c72d6b8d4649a8e029a0f3c6c5545b7f436f0610e49b0f" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "atoi_simd", - "bytemuck", - "chrono", - "dyn-clone", - "either", - "ethnum", - "fast-float", - "foreign_vec", - "getrandom", - "hashbrown", - "itoa", - "multiversion", - "num-traits", - "polars-error", - "polars-utils", - "ryu", - "simdutf8", - "streaming-iterator", - "strength_reduce", - "version_check", -] - -[[package]] -name = "polars-compute" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32d9dc87f8003ae0edeef5ad9ac92b2a345480bbe17adad64496113ae84706dd" -dependencies = [ - "bytemuck", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "version_check", -] - -[[package]] -name = "polars-core" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "befd4d280a82219a01035c4f901319ceba65998c594d0c64f9a439cdee1d7777" -dependencies = [ - "ahash", - "bitflags 2.4.2", - "bytemuck", - "either", - "hashbrown", - "indexmap", - "num-traits", - "once_cell", - "polars-arrow", - "polars-compute", - "polars-error", - "polars-row", - "polars-utils", - "rayon", - "smartstring", - "thiserror", - "version_check", - "xxhash-rust", -] - -[[package]] -name = "polars-error" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f2435b02d1ba36d8c1f6a722cad04e4c0b2705a3112c5706e6960d405d7798" -dependencies = [ - "simdutf8", - "thiserror", -] - -[[package]] -name = "polars-ops" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6395f5fd5e1adf016fd6403c0a493181c1a349a7a145b2687cdf50a0d630310a" -dependencies = [ - "ahash", - "argminmax", - "bytemuck", - "either", - "hashbrown", - "indexmap", - "memchr", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-utils", - "rayon", - "regex", - "smartstring", - "version_check", -] - -[[package]] -name = "polars-row" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4984d97aad3d0db92afe76ebcab10b5e37a1216618b5703ae0d2917ccd6168c" -dependencies = [ - "polars-arrow", - "polars-error", - "polars-utils", -] - -[[package]] -name = "polars-utils" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f9c955bb1e9b55d835aeb7fe4e4e8826e01abe5f0ada979ceb7d2b9af7b569" -dependencies = [ - "ahash", - "bytemuck", - "hashbrown", - "indexmap", - "num-traits", - "once_cell", - "polars-error", - "rayon", - "smartstring", - "version_check", -] - [[package]] name = "portable-atomic" version = "1.6.0" @@ -1952,7 +1743,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.52", + "syn", ] [[package]] @@ -2031,7 +1822,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -2044,7 +1835,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -2344,7 +2135,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -2376,12 +2167,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "simdutf8" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" - [[package]] name = "simplelog" version = "0.12.2" @@ -2409,17 +2194,6 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" -[[package]] -name = "smartstring" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" -dependencies = [ - "autocfg", - "static_assertions", - "version_check", -] - [[package]] name = "snap" version = "1.1.1" @@ -2442,29 +2216,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" - -[[package]] -name = "strength_reduce" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.52" @@ -2503,12 +2254,6 @@ dependencies = [ "libc", ] -[[package]] -name = "target-features" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" - [[package]] name = "target-lexicon" version = "0.12.14" @@ -2563,7 +2308,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] @@ -2833,9 +2578,6 @@ dependencies = [ "num-traits", "num_enum", "once_cell", - "polars-arrow", - "polars-core", - "polars-ops", "rand", "rayon", "roaring", @@ -2952,7 +2694,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn", "wasm-bindgen-shared", ] @@ -2986,7 +2728,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3210,12 +2952,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "xxhash-rust" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03" - [[package]] name = "zerocopy" version = "0.7.32" @@ -3233,7 +2969,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn", ] [[package]] diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 65b9537a5a..803f22ee3b 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -31,9 +31,6 @@ log = "0.4.20" num-traits = "0.2.18" num_enum = "0.7.2" once_cell = "1.19.0" -polars-arrow = { version = "0.37.0", features = ["arrow_rs"] } -polars-core = "0.37.0" -polars-ops = { version = "0.37.0", features = ["search_sorted"] } rand = { version = "0.8.5", features = [] } rayon = "1.8.1" roaring = "0.10.3" diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index e364de4bd3..95498da4bb 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -4,6 +4,7 @@ use crate::compute::cast::CastPrimitiveFn; use crate::compute::fill::FillForwardFn; use crate::compute::patch::PatchFn; use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::search_sorted::SearchSortedFn; use crate::compute::ArrayCompute; mod as_contiguous; @@ -11,6 +12,7 @@ mod cast; mod fill; mod patch; mod scalar_at; +mod search_sorted; impl ArrayCompute for PrimitiveArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -32,4 +34,8 @@ impl ArrayCompute for PrimitiveArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } + + fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } } diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs new file mode 100644 index 0000000000..bcc8396369 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -0,0 +1,57 @@ +use crate::array::primitive::PrimitiveArray; +use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide}; +use crate::error::VortexResult; +use crate::match_each_native_ptype; +use crate::ptype::NativePType; +use crate::scalar::Scalar; + +impl SearchSortedFn for PrimitiveArray { + fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult { + match_each_native_ptype!(self.ptype(), |$T| { + let pvalue: $T = value.try_into()?; + Ok(search_sorted(self.typed_data::<$T>(), pvalue, side)) + }) + } +} + +fn search_sorted(arr: &[T], target: T, side: SearchSortedSide) -> usize { + match side { + SearchSortedSide::Left => search_sorted_cmp(arr, target, |a, b| a < b), + SearchSortedSide::Right => search_sorted_cmp(arr, target, |a, b| a <= b), + } +} + +fn search_sorted_cmp(arr: &[T], target: T, cmp: Cmp) -> usize +where + Cmp: Fn(T, T) -> bool + 'static, +{ + let mut low = 0; + let mut high = arr.len(); + + while low < high { + let mid = low + (high - low) / 2; + + if cmp(arr[mid], target) { + low = mid + 1; + } else { + high = mid; + } + } + + low +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_searchsorted_primitive() { + let values = vec![1u16, 2, 3]; + + assert_eq!(search_sorted(&values, 0, SearchSortedSide::Left), 0); + assert_eq!(search_sorted(&values, 1, SearchSortedSide::Left), 0); + assert_eq!(search_sorted(&values, 1, SearchSortedSide::Right), 1); + assert_eq!(search_sorted(&values, 4, SearchSortedSide::Left), 3); + } +} diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs index 3d3e4216cb..91de0d9a48 100644 --- a/vortex-array/src/array/sparse/compute.rs +++ b/vortex-array/src/array/sparse/compute.rs @@ -3,7 +3,7 @@ use crate::array::sparse::SparseArray; use crate::array::{Array, ArrayRef}; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; -use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; +use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::error::VortexResult; use crate::scalar::{NullableScalar, Scalar, ScalarRef}; @@ -48,21 +48,19 @@ impl ScalarAtFn for SparseArray { // First, get the index of the patch index array that is the first index // greater than or equal to the true index let true_patch_index = index + self.indices_offset; - search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then( - |idx| { - // If the value at this index is equal to the true index, then it exists in the patch index array - // and we should return the value at the corresponding index in the patch values array - scalar_at(self.indices(), idx) - .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) - .and_then(usize::try_from) - .and_then(|patch_index| { - if patch_index == true_patch_index { - scalar_at(self.values(), idx) - } else { - Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) - } - }) - }, - ) + search_sorted(self.indices(), true_patch_index, SearchSortedSide::Left).and_then(|idx| { + // If the value at this index is equal to the true index, then it exists in the patch index array + // and we should return the value at the corresponding index in the patch values array + scalar_at(self.indices(), idx) + .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) + .and_then(usize::try_from) + .and_then(|patch_index| { + if patch_index == true_patch_index { + scalar_at(self.values(), idx) + } else { + Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) + } + }) + }) } } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 6a0e7ad943..bab10419f7 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -15,7 +15,7 @@ use crate::array::{ check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, }; use crate::compress::EncodingCompression; -use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide}; +use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; @@ -161,8 +161,8 @@ impl Array for SparseArray { check_slice_bounds(self, start, stop)?; // Find the index of the first patch index that is greater than or equal to the offset of this array - let index_start_index = search_sorted_usize(self.indices(), start, SearchSortedSide::Left)?; - let index_end_index = search_sorted_usize(self.indices(), stop, SearchSortedSide::Left)?; + let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?; + let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?; Ok(SparseArray { indices_offset: self.indices_offset + start, diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 70cbdefa0f..19c4dd8655 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -1,4 +1,5 @@ use crate::compute::as_contiguous::AsContiguousFn; +use crate::compute::search_sorted::SearchSortedFn; use cast::{CastBoolFn, CastPrimitiveFn}; use fill::FillForwardFn; use patch::PatchFn; @@ -40,6 +41,10 @@ pub trait ArrayCompute { None } + fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { + None + } + fn take(&self) -> Option<&dyn TakeFn> { None } diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 694548eb45..d6e836be94 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -1,69 +1,29 @@ use crate::array::Array; -use crate::error::VortexResult; -use crate::polars::IntoPolarsSeries; -use crate::polars::IntoPolarsValue; -use crate::scalar::ScalarRef; -use polars_core::prelude::*; -use polars_ops::prelude::*; +use crate::error::{VortexError, VortexResult}; +use crate::scalar::{Scalar, ScalarRef}; pub enum SearchSortedSide { Left, Right, } -impl From for polars_ops::prelude::SearchSortedSide { - fn from(side: SearchSortedSide) -> Self { - match side { - SearchSortedSide::Left => polars_ops::prelude::SearchSortedSide::Left, - SearchSortedSide::Right => polars_ops::prelude::SearchSortedSide::Right, - } - } +pub trait SearchSortedFn { + fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult; } -pub fn search_sorted_usize( - indices: &dyn Array, - index: usize, +pub fn search_sorted>( + array: &dyn Array, + target: T, side: SearchSortedSide, ) -> VortexResult { - let enc_scalar: ScalarRef = index.into(); - // Convert index into correctly typed Arrow scalar. - let enc_scalar = enc_scalar.cast(indices.dtype())?; - - let series: Series = indices.iter_arrow().into_polars(); - Ok(search_sorted( - &series, - &Series::from_any_values("needle", &[enc_scalar.into_polars()], true)?, - side.into(), - false, - )? - .get(0) - .unwrap() as usize) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::array::ArrayRef; - - #[test] - fn test_searchsorted_scalar() { - let haystack: ArrayRef = vec![1, 2, 3].into(); - - assert_eq!( - search_sorted_usize(haystack.as_ref(), 0, SearchSortedSide::Left).unwrap(), - 0 - ); - assert_eq!( - search_sorted_usize(haystack.as_ref(), 1, SearchSortedSide::Left).unwrap(), - 0 - ); - assert_eq!( - search_sorted_usize(haystack.as_ref(), 1, SearchSortedSide::Right).unwrap(), - 1 - ); - assert_eq!( - search_sorted_usize(haystack.as_ref(), 4, SearchSortedSide::Left).unwrap(), - 3 - ); - } + let scalar = target.into().cast(array.dtype())?; + array + .search_sorted() + .map(|f| f.search_sorted(scalar.as_ref(), side)) + .unwrap_or_else(|| { + Err(VortexError::NotImplemented( + "search_sorted", + array.encoding().id(), + )) + }) } diff --git a/vortex-array/src/error.rs b/vortex-array/src/error.rs index c82575f124..bfc9203feb 100644 --- a/vortex-array/src/error.rs +++ b/vortex-array/src/error.rs @@ -71,8 +71,6 @@ pub enum VortexError { MismatchedTypes(DType, DType), #[error("unexpected arrow data type: {0:?}")] InvalidArrowDataType(arrow::datatypes::DataType), - #[error("polars error: {0:?}")] - PolarsError(PolarsError), #[error("arrow error: {0:?}")] ArrowError(ArrowError), #[error("patch values may not be null for base dtype {0}")] @@ -102,19 +100,3 @@ impl From for VortexError { VortexError::ArrowError(ArrowError(err)) } } - -#[derive(Debug)] -#[allow(dead_code)] -pub struct PolarsError(polars_core::error::PolarsError); - -impl PartialEq for PolarsError { - fn eq(&self, _other: &Self) -> bool { - false - } -} - -impl From for VortexError { - fn from(err: polars_core::error::PolarsError) -> Self { - VortexError::PolarsError(PolarsError(err)) - } -} diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index fad5ab605f..f10a424d25 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -8,7 +8,6 @@ pub mod dtype; pub mod encode; pub mod error; pub mod formatter; -mod polars; pub mod ptype; mod sampling; pub mod serde; diff --git a/vortex-array/src/polars.rs b/vortex-array/src/polars.rs deleted file mode 100644 index 70b0845bcb..0000000000 --- a/vortex-array/src/polars.rs +++ /dev/null @@ -1,101 +0,0 @@ -use arrow::array::{Array as ArrowArray, ArrayRef as ArrowArrayRef}; -use polars_arrow::array::from_data; -use polars_core::prelude::{AnyValue, Series}; - -use crate::array::ArrowIterator; -use crate::dtype::DType; -use crate::scalar::{ - BinaryScalar, BoolScalar, NullableScalar, PScalar, Scalar, ScalarRef, Utf8Scalar, -}; - -pub trait IntoPolarsSeries { - fn into_polars(self) -> Series; -} - -impl IntoPolarsSeries for ArrowArrayRef { - fn into_polars(self) -> Series { - let polars_array = from_data(&self.to_data()); - ("array", polars_array).try_into().unwrap() - } -} - -impl IntoPolarsSeries for Vec { - fn into_polars(self) -> Series { - let chunks: Vec> = - self.iter().map(|a| from_data(&a.to_data())).collect(); - ("array", chunks).try_into().unwrap() - } -} - -impl IntoPolarsSeries for Box { - fn into_polars(self) -> Series { - let chunks: Vec> = - self.map(|a| from_data(&a.to_data())).collect(); - ("array", chunks).try_into().unwrap() - } -} - -pub trait IntoPolarsValue { - fn into_polars<'a>(self) -> AnyValue<'a>; -} - -impl IntoPolarsValue for ScalarRef { - fn into_polars<'a>(self) -> AnyValue<'a> { - self.as_ref().into_polars() - } -} - -impl IntoPolarsValue for &dyn Scalar { - fn into_polars<'a>(self) -> AnyValue<'a> { - if let Some(ns) = self.as_any().downcast_ref::() { - return match ns { - NullableScalar::Some(s, _) => s.as_ref().into_polars(), - NullableScalar::None(_) => AnyValue::Null, - }; - } - - match self.dtype() { - DType::Null => AnyValue::Null, - DType::Bool(_) => { - AnyValue::Boolean(self.as_any().downcast_ref::().unwrap().value()) - } - DType::Int(_, _, _) | DType::Float(_, _) => { - match self.as_any().downcast_ref::().unwrap() { - PScalar::U8(v) => AnyValue::UInt8(*v), - PScalar::U16(v) => AnyValue::UInt16(*v), - PScalar::U32(v) => AnyValue::UInt32(*v), - PScalar::U64(v) => AnyValue::UInt64(*v), - PScalar::I8(v) => AnyValue::Int8(*v), - PScalar::I16(v) => AnyValue::Int16(*v), - PScalar::I32(v) => AnyValue::Int32(*v), - PScalar::I64(v) => AnyValue::Int64(*v), - PScalar::F16(v) => AnyValue::Float32(v.to_f32()), - PScalar::F32(v) => AnyValue::Float32(*v), - PScalar::F64(v) => AnyValue::Float64(*v), - } - } - DType::Decimal(_, _, _) => todo!(), - DType::Utf8(_) => AnyValue::StringOwned( - self.as_any() - .downcast_ref::() - .unwrap() - .value() - .into(), - ), - DType::Binary(_) => AnyValue::BinaryOwned( - self.as_any() - .downcast_ref::() - .unwrap() - .value() - .clone(), - ), - DType::LocalTime(_, _) => todo!(), - DType::LocalDate(_) => todo!(), - DType::Instant(_, _) => todo!(), - DType::ZonedDateTime(_, _) => todo!(), - DType::Struct(_, _) => todo!(), - DType::List(_, _) => todo!(), - DType::Map(_, _, _) => todo!(), - } - } -} diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index a0256d2ed2..dfa246ea55 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -82,7 +82,7 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - compute::search_sorted::search_sorted_usize( + compute::search_sorted::search_sorted( self.ends(), index + self.offset, SearchSortedSide::Right,