From f9dd4f4e1109f764b08ac97357716beeff695086 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 7 May 2024 19:56:53 +0100 Subject: [PATCH 1/4] DType Serialization (#298) * Clean up DType serde * Create build-vortex package to unify flatbuffer/protobuffer code generation. * Add protobuf DType definitions. FLUP questions: * Should we remove nullable from DType and define a separate `FieldType { dtype: DType, nullable: bool }`? Recursive types such as Struct and List would then use FieldType. It does introduce an interesting quirk of a nullable null... But maybe that's less annoying that duplicating nullability everywhere? * If we don't want a separate FieldType, we could still pull `nullable` up in both flatbuffer and protobuf definitions. Start of #277 --- Cargo.lock | 221 ++++++++++-------- Cargo.toml | 5 + build-vortex/Cargo.toml | 19 ++ build-vortex/README.md | 12 + build-vortex/src/lib.rs | 96 ++++++++ vortex-array/Cargo.toml | 4 +- vortex-dtype/Cargo.toml | 12 +- vortex-dtype/build.rs | 4 +- vortex-dtype/flatbuffers/dtype.fbs | 26 +-- vortex-dtype/proto/dtype.proto | 73 ++++++ vortex-dtype/src/deserialize.rs | 83 ------- vortex-dtype/src/dtype.rs | 39 +--- vortex-dtype/src/lib.rs | 10 +- vortex-dtype/src/nullability.rs | 36 +++ vortex-dtype/src/ptype.rs | 3 +- vortex-dtype/src/serde.rs | 1 - .../{serialize.rs => serde/flatbuffers.rs} | 111 ++++++--- vortex-dtype/src/serde/mod.rs | 31 +++ vortex-dtype/src/serde/proto.rs | 129 ++++++++++ vortex-dtype/src/serde/serde.rs | 25 ++ vortex-ipc/benches/ipc_array_reader_take.rs | 3 +- vortex-ipc/src/reader.rs | 5 +- vortex-roaring/src/boolean/mod.rs | 3 +- vortex-scalar/Cargo.toml | 6 +- vortex-scalar/build.rs | 4 +- vortex-scalar/src/lib.rs | 1 + vortex-scalar/src/serde.rs | 2 +- 27 files changed, 670 insertions(+), 294 deletions(-) create mode 100644 build-vortex/Cargo.toml create mode 100644 build-vortex/README.md create mode 100644 build-vortex/src/lib.rs mode change 120000 => 100644 vortex-dtype/build.rs create mode 100644 vortex-dtype/proto/dtype.proto delete mode 100644 vortex-dtype/src/deserialize.rs create mode 100644 vortex-dtype/src/nullability.rs delete mode 100644 vortex-dtype/src/serde.rs rename vortex-dtype/src/{serialize.rs => serde/flatbuffers.rs} (64%) create mode 100644 vortex-dtype/src/serde/mod.rs create mode 100644 vortex-dtype/src/serde/proto.rs create mode 100644 vortex-dtype/src/serde/serde.rs mode change 120000 => 100644 vortex-scalar/build.rs diff --git a/Cargo.lock b/Cargo.lock index af8518321c..11f09cdd3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -84,15 +84,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" [[package]] name = "anyhow" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3" [[package]] name = "arc-swap" @@ -343,7 +343,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -354,7 +354,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -374,9 +374,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "aws-config" @@ -785,7 +785,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.60", + "syn 2.0.61", "which", ] @@ -853,6 +853,15 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "build-vortex" +version = "0.1.0" +dependencies = [ + "flatc", + "prost-build", + "walkdir", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -953,9 +962,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.96" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd" +checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" dependencies = [ "jobserver", "libc", @@ -1620,7 +1629,7 @@ checksum = "27540baf49be0d484d8f0130d7d8da3011c32a44d4fc873368154f1510e574a2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -1661,7 +1670,7 @@ checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -1861,7 +1870,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -1906,9 +1915,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "js-sys", @@ -2337,9 +2346,9 @@ dependencies = [ [[package]] name = "lance" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c831c2367f290c0c704db0f8a27d006b448001de266fe391bb136fedf4398" +checksum = "362b480df322cd9a1d0ae9336197001d29140843d3716a31a8f55ced9c000a54" dependencies = [ "arrow", "arrow-arith", @@ -2396,9 +2405,9 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0261647477271b4c5a4e17c184c48a60df5fd3e2f7378270ade4ab13fe6353" +checksum = "a227e1c408cce3db3270a41e80b353f4abb90d0891de7bc53aa9ee7683f77d8f" dependencies = [ "arrow-array", "arrow-buffer", @@ -2415,9 +2424,9 @@ dependencies = [ [[package]] name = "lance-core" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8de18d3460c09860df3c02e97ff3142e1f81532720d50955d75d2545fb7b56c" +checksum = "a88115f28618a2ca888fe41aeecca49bb90d4c61ba4c25f457c5a86f29fce5f9" dependencies = [ "arrow-array", "arrow-buffer", @@ -2451,9 +2460,9 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a99c234ffb02922a073e35d2d814528b371dd9376b2d205108ce5cdfcfd233" +checksum = "c1d43fb160792ce395d76ed864b7e1c6ef3a9fe46e5fe54f5339409aaf89c037" dependencies = [ "arrow", "arrow-array", @@ -2475,9 +2484,9 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f72e133550f1a90dd9364d0f6d387371a97b3280a8db13adb2891535178d75bd" +checksum = "7abd3dc7a5420ba82abb8dc3f826ff0d2b08c8e61ec20e5c9b087695ff583aa8" dependencies = [ "arrow", "arrow-array", @@ -2492,9 +2501,9 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6d0421c41d9da1ac5ac4f2e66ab8a63efeab2291ba10d7a5441893b5227c3c0" +checksum = "694a0d2ce6617d83ccbccbe03b092343e03d0d3de389773143a6994575bb6c40" dependencies = [ "arrow-arith", "arrow-array", @@ -2519,9 +2528,9 @@ dependencies = [ [[package]] name = "lance-file" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9439893d05ca544efce2d3555f2ee9ece2763fa4f966cd3af7b0bd46713dd0" +checksum = "fa0d5fa512333bf420475d34d631262ba4de6d604c93eaa7fb3d7fa2760c1758" dependencies = [ "arrow-arith", "arrow-array", @@ -2555,9 +2564,9 @@ dependencies = [ [[package]] name = "lance-index" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f552f3fffc64b97be3667fdf4b6161bc96fdb920e4317d84c4070baf10902804" +checksum = "b2f1892475738d35b46adb4f474934999545f46d23cdbb38181b45d8cdcbf471" dependencies = [ "arrow", "arrow-array", @@ -2601,9 +2610,9 @@ dependencies = [ [[package]] name = "lance-io" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8387b17cb0c1f3c8b5c918a04e8ea67ec6c76a96e68f976caa1a851e2ae36fea" +checksum = "e7143cf3e0175ed720be9515a188d7c37f748f2ffe32b08af65d45aed5a0af00" dependencies = [ "arrow", "arrow-arith", @@ -2641,9 +2650,9 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd628dba9812f210c427faabf16b8de194ba1c3e68be2182e6759985b1f96c14" +checksum = "d6d737aabcaa7c34c1bd70282d667b45df7549b5cf80e44ed2a3481fbdebd260" dependencies = [ "arrow-array", "arrow-ord", @@ -2664,9 +2673,9 @@ dependencies = [ [[package]] name = "lance-table" -version = "0.10.16" +version = "0.10.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce31ce988552ac4ef5da92138b8032227faea6f2b0c3445646654721b46428e6" +checksum = "b9017c6d5aa45b233a835ac156a48621f61e552bebb8fdbb11952e96581c6aca" dependencies = [ "arrow", "arrow-array", @@ -3029,11 +3038,10 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" dependencies = [ - "autocfg", "num-integer", "num-traits", ] @@ -3064,9 +3072,9 @@ dependencies = [ [[package]] name = "num-iter" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" dependencies = [ "autocfg", "num-integer", @@ -3087,9 +3095,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", @@ -3144,7 +3152,7 @@ dependencies = [ "proc-macro-crate 3.1.0", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3231,7 +3239,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3345,18 +3353,18 @@ dependencies = [ [[package]] name = "parse-zoneinfo" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" dependencies = [ "regex", ] [[package]] name = "paste" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "path_abs" @@ -3441,7 +3449,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3510,12 +3518,12 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3539,9 +3547,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" dependencies = [ "unicode-ident", ] @@ -3573,7 +3581,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn 2.0.60", + "syn 2.0.61", "tempfile", ] @@ -3587,7 +3595,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3668,7 +3676,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3681,7 +3689,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -3990,9 +3998,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" @@ -4067,9 +4075,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" +checksum = "51f344d206c5e1b010eec27349b815a4805f70a778895959d70b74b9b529b30a" [[package]] name = "rustls-webpki" @@ -4083,15 +4091,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" +checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0" [[package]] name = "ryu" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "same-file" @@ -4138,11 +4146,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.5.0", "core-foundation", "core-foundation-sys", "libc", @@ -4151,9 +4159,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" dependencies = [ "core-foundation-sys", "libc", @@ -4161,9 +4169,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" dependencies = [ "serde", ] @@ -4191,7 +4199,7 @@ checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4205,6 +4213,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_test" +version = "1.0.176" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a2f49ace1498612d14f7e0b8245519584db8299541dfe31a06374a828d620ab" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -4359,7 +4376,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4399,7 +4416,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4421,9 +4438,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.60" +version = "2.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9" dependencies = [ "proc-macro2", "quote", @@ -4502,22 +4519,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4623,7 +4640,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4659,16 +4676,15 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" dependencies = [ "bytes", "futures-core", "futures-sink", "pin-project-lite", "tokio", - "tracing", ] [[package]] @@ -4747,7 +4763,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] @@ -4979,15 +4995,17 @@ dependencies = [ name = "vortex-dtype" version = "0.1.0" dependencies = [ + "build-vortex", "flatbuffers", - "flatc", "half", "itertools 0.12.1", "num-traits", + "prost", "serde", + "serde_json", + "serde_test", "vortex-error", "vortex-flatbuffers", - "walkdir", ] [[package]] @@ -5089,8 +5107,8 @@ dependencies = [ name = "vortex-scalar" version = "0.1.0" dependencies = [ + "build-vortex", "flatbuffers", - "flatc", "flexbuffers", "itertools 0.12.1", "num-traits", @@ -5099,7 +5117,6 @@ dependencies = [ "vortex-dtype", "vortex-error", "vortex-flatbuffers", - "walkdir", ] [[package]] @@ -5167,7 +5184,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", "wasm-bindgen-shared", ] @@ -5201,7 +5218,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -5463,22 +5480,22 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "zerocopy" -version = "0.7.32" +version = "0.7.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.32" +version = "0.7.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.61", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 22e8c20135..adb1e60e59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "bench-vortex", + "build-vortex", "fastlanez", "fastlanez-sys", "pyvortex", @@ -74,12 +75,16 @@ num-traits = "0.2.18" num_enum = "0.7.2" parquet = "51.0.0" paste = "1.0.14" +prost = "0.12.4" +prost-build = "0.12.4" pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py311"] } pyo3-log = "0.9.0" rand = "0.8.5" reqwest = { version = "0.12.0", features = ["blocking"] } seq-macro = "0.3.5" serde = "1.0.197" +serde_json = "1.0.116" +serde_test = "1.0.176" simplelog = { version = "0.12.2", features = ["paris"] } thiserror = "1.0.58" tokio = "1.37.0" diff --git a/build-vortex/Cargo.toml b/build-vortex/Cargo.toml new file mode 100644 index 0000000000..1ed3cd4905 --- /dev/null +++ b/build-vortex/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "build-vortex" +version.workspace = true +homepage.workspace = true +repository.workspace = true +authors.workspace = true +license.workspace = true +keywords.workspace = true +include.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +flatc = { workspace = true } +prost-build = { workspace = true } +walkdir = { workspace = true } + +[lints] +workspace = true diff --git a/build-vortex/README.md b/build-vortex/README.md new file mode 100644 index 0000000000..ef392c09e4 --- /dev/null +++ b/build-vortex/README.md @@ -0,0 +1,12 @@ +# Build Vortex + +A crate containing configuration logic for Vortex build.rs files. + +## Usage + +## Features + +Depending on the enabled features, this script supports: + +* FlatBuffers +* Protocol Buffers \ No newline at end of file diff --git a/build-vortex/src/lib.rs b/build-vortex/src/lib.rs new file mode 100644 index 0000000000..650292159e --- /dev/null +++ b/build-vortex/src/lib.rs @@ -0,0 +1,96 @@ +use std::env; +use std::ffi::OsStr; +use std::fs::create_dir_all; +use std::path::{Path, PathBuf}; +use std::process::Command; + +use flatc::flatc; +use walkdir::WalkDir; + +fn manifest_dir() -> PathBuf { + PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()) + .canonicalize() + .expect("Failed to canonicalize CARGO_MANIFEST_DIR") +} + +fn out_dir() -> PathBuf { + PathBuf::from(env::var("OUT_DIR").unwrap()) + .canonicalize() + .expect("Failed to canonicalize OUT_DIR") +} + +pub fn build() { + // FlatBuffers + if env::var("CARGO_FEATURE_FLATBUFFERS").ok().is_some() { + build_flatbuffers(); + } + + // Proto (prost) + if env::var("CARGO_FEATURE_PROST").ok().is_some() { + build_proto(); + } +} + +pub fn build_proto() { + let proto_dir = manifest_dir().join("proto"); + let proto_files = walk_files(&proto_dir, "proto"); + let proto_out = out_dir().join("proto"); + + create_dir_all(&proto_out).expect("Failed to create proto output directory"); + + prost_build::Config::new() + .out_dir(&proto_out) + .compile_protos(&proto_files, &[&proto_dir, &proto_dir.join("../../")]) + .expect("Failed to compile protos"); +} + +pub fn build_flatbuffers() { + let flatbuffers_dir = manifest_dir().join("flatbuffers"); + let fbs_files = walk_files(&flatbuffers_dir, "fbs"); + check_call( + Command::new(flatc()) + .arg("--rust") + .arg("--filename-suffix") + .arg("") + .arg("-I") + .arg(flatbuffers_dir.join("../../")) + .arg("--include-prefix") + .arg("flatbuffers::deps") + .arg("-o") + .arg(out_dir().join("flatbuffers")) + .args(fbs_files), + ) +} + +/// Recursively walk for files with the given extension, adding them to rerun-if-changed. +fn walk_files(dir: &Path, ext: &str) -> Vec { + WalkDir::new(dir) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| e.path().extension() == Some(OsStr::new(ext))) + .map(|e| { + rerun_if_changed(e.path()); + e.path().to_path_buf() + }) + .collect() +} + +fn rerun_if_changed(path: &Path) { + println!( + "cargo:rerun-if-changed={}", + path.canonicalize() + .unwrap_or_else(|_| panic!("failed to canonicalize {}", path.to_str().unwrap())) + .to_str() + .unwrap() + ); +} + +fn check_call(command: &mut Command) { + let name = command.get_program().to_str().unwrap().to_string(); + let Ok(status) = command.status() else { + panic!("Failed to launch {}", &name) + }; + if !status.success() { + panic!("{} failed with status {}", &name, status.code().unwrap()); + } +} diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 26f6cd800f..e18e1680ee 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -34,10 +34,10 @@ num_enum = { workspace = true } paste = { workspace = true } rand = { workspace = true } vortex-buffer = { path = "../vortex-buffer" } -vortex-dtype = { path = "../vortex-dtype", features = ["serde"] } +vortex-dtype = { path = "../vortex-dtype", features = ["flatbuffers", "serde"] } vortex-error = { path = "../vortex-error", features = ["flexbuffers"] } vortex-flatbuffers = { path = "../vortex-flatbuffers" } -vortex-scalar = { path = "../vortex-scalar", features = ["serde"] } +vortex-scalar = { path = "../vortex-scalar", features = ["flatbuffers", "serde"] } serde = { workspace = true, features = ["derive"] } [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index 2ba80e1064..3dcdc20299 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -16,17 +16,21 @@ name = "vortex_dtype" path = "src/lib.rs" [dependencies] -flatbuffers = { workspace = true } +flatbuffers = { workspace = true, optional = true } half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } -serde = { workspace = true, optional = true, features = ["rc"] } +prost = { workspace = true, optional = true } +serde = { workspace = true, optional = true, features = ["rc", "derive"] } vortex-error = { path = "../vortex-error" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } +[dev-dependencies] +serde_json = { workspace = true } +serde_test = { workspace = true } + [build-dependencies] -flatc = { workspace = true } -walkdir = { workspace = true } +build-vortex = { path = "../build-vortex" } [lints] workspace = true diff --git a/vortex-dtype/build.rs b/vortex-dtype/build.rs deleted file mode 120000 index 7cb528993c..0000000000 --- a/vortex-dtype/build.rs +++ /dev/null @@ -1 +0,0 @@ -../flatbuffers.build.rs \ No newline at end of file diff --git a/vortex-dtype/build.rs b/vortex-dtype/build.rs new file mode 100644 index 0000000000..3ce2fd1cb5 --- /dev/null +++ b/vortex-dtype/build.rs @@ -0,0 +1,3 @@ +pub fn main() { + build_vortex::build(); +} diff --git a/vortex-dtype/flatbuffers/dtype.fbs b/vortex-dtype/flatbuffers/dtype.fbs index 6ceaa4a459..e5d47a54a0 100644 --- a/vortex-dtype/flatbuffers/dtype.fbs +++ b/vortex-dtype/flatbuffers/dtype.fbs @@ -1,10 +1,5 @@ namespace vortex.dtype; -enum Nullability: uint8 { - NonNullable, - Nullable, -} - enum PType: uint8 { U8, U16, @@ -22,46 +17,45 @@ enum PType: uint8 { table Null {} table Bool { - nullability: Nullability; + nullable: bool; } table Primitive { ptype: PType; - nullability: Nullability; + nullable: bool; } table Decimal { /// Total number of decimal digits precision: uint8; - /// Number of digits after the decimal point "." - scale: int8; - nullability: Nullability; + scale: uint8; + nullable: bool; } table Utf8 { - nullability: Nullability; + nullable: bool; } table Binary { - nullability: Nullability; + nullable: bool; } table Struct_ { names: [string]; - fields: [DType]; - nullability: Nullability; + dtypes: [DType]; + nullable: bool; } table List { element_type: DType; - nullability: Nullability; + nullable: bool; } table Extension { id: string; metadata: [ubyte]; - nullability: Nullability; + nullable: bool; } union Type { diff --git a/vortex-dtype/proto/dtype.proto b/vortex-dtype/proto/dtype.proto new file mode 100644 index 0000000000..d3f82c9bd9 --- /dev/null +++ b/vortex-dtype/proto/dtype.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; + +package vortex.dtype; + +enum PType { + U8 = 0; + U16 = 1; + U32 = 2; + U64 = 3; + I8 = 4; + I16 = 5; + I32 = 6; + I64 = 7; + F16 = 8; + F32 = 9; + F64 = 10; +} + +message Null {} + +message Bool { + bool nullable = 1; +} + +message Primitive { + PType type = 1; + bool nullable = 2; +} + +message Decimal { + uint32 precision = 1; + uint32 scale = 2; + bool nullable = 3; +} + +message Utf8 { + bool nullable = 1; +} + +message Binary { + bool nullable = 1; +} + +message Struct { + repeated string names = 1; + repeated DType dtypes = 2; + bool nullable = 3; +} + +message List { + DType element_type = 1; + bool nullable = 2; +} + +message Extension { + string id = 1; + optional bytes metadata = 2; + bool nullable = 3; +} + +message DType { + oneof type { + Null null = 1; + Bool bool = 2; + Primitive primitive = 3; + Decimal decimal = 4; + Utf8 utf8 = 5; + Binary binary = 6; + Struct struct = 7; + List list = 8; + Extension extension = 9; + } +} \ No newline at end of file diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs deleted file mode 100644 index 2645ab1fcc..0000000000 --- a/vortex-dtype/src/deserialize.rs +++ /dev/null @@ -1,83 +0,0 @@ -use itertools::Itertools; -use vortex_error::{vortex_err, VortexError, VortexResult}; -use vortex_flatbuffers::ReadFlatBuffer; - -use crate::{flatbuffers as fb, ExtDType, ExtID, ExtMetadata, Nullability}; -use crate::{DType, StructDType}; - -impl ReadFlatBuffer for DType { - type Source<'a> = fb::DType<'a>; - type Error = VortexError; - - fn read_flatbuffer(fb: &Self::Source<'_>) -> Result { - match fb.type_type() { - fb::Type::Null => Ok(DType::Null), - fb::Type::Bool => Ok(DType::Bool( - fb.type__as_bool().unwrap().nullability().try_into()?, - )), - fb::Type::Primitive => { - let fb_primitive = fb.type__as_primitive().unwrap(); - Ok(DType::Primitive( - fb_primitive.ptype().try_into()?, - fb_primitive.nullability().try_into()?, - )) - } - fb::Type::Binary => Ok(DType::Binary( - fb.type__as_binary().unwrap().nullability().try_into()?, - )), - fb::Type::Utf8 => Ok(DType::Utf8( - fb.type__as_utf_8().unwrap().nullability().try_into()?, - )), - fb::Type::List => { - let fb_list = fb.type__as_list().unwrap(); - let element_dtype = DType::read_flatbuffer(&fb_list.element_type().unwrap())?; - Ok(DType::List( - Box::new(element_dtype), - fb_list.nullability().try_into()?, - )) - } - fb::Type::Struct_ => { - let fb_struct = fb.type__as_struct_().unwrap(); - let names = fb_struct - .names() - .unwrap() - .iter() - .map(|n| n.into()) - .collect_vec() - .into(); - let fields: Vec = fb_struct - .fields() - .unwrap() - .iter() - .map(|f| DType::read_flatbuffer(&f)) - .collect::>>()?; - Ok(DType::Struct( - StructDType::new(names, fields), - fb_struct.nullability().try_into()?, - )) - } - fb::Type::Extension => { - let fb_ext = fb.type__as_extension().unwrap(); - let id = ExtID::from(fb_ext.id().unwrap()); - let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes())); - Ok(DType::Extension( - ExtDType::new(id, metadata), - fb_ext.nullability().try_into()?, - )) - } - _ => Err(vortex_err!("Unknown DType variant")), - } - } -} - -impl TryFrom for Nullability { - type Error = VortexError; - - fn try_from(value: fb::Nullability) -> VortexResult { - match value { - fb::Nullability::NonNullable => Ok(Nullability::NonNullable), - fb::Nullability::Nullable => Ok(Nullability::Nullable), - _ => Err(vortex_err!("Unknown nullability value")), - } - } -} diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index abdf2cce1a..9da2cbf363 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -5,44 +5,9 @@ use std::sync::Arc; use itertools::Itertools; use DType::*; +use crate::nullability::Nullability; use crate::{ExtDType, PType}; -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -pub enum Nullability { - #[default] - NonNullable, - Nullable, -} - -impl From for Nullability { - fn from(value: bool) -> Self { - if value { - Nullability::Nullable - } else { - Nullability::NonNullable - } - } -} - -impl From for bool { - fn from(value: Nullability) -> Self { - match value { - Nullability::NonNullable => false, - Nullability::Nullable => true, - } - } -} - -impl Display for Nullability { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Nullability::NonNullable => write!(f, ""), - Nullability::Nullable => write!(f, "?"), - } - } -} - pub type FieldNames = Arc<[Arc]>; pub type Metadata = Vec; @@ -71,7 +36,7 @@ impl DType { } pub fn is_nullable(&self) -> bool { - use Nullability::*; + use crate::nullability::Nullability::*; match self { Null => true, diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index 869eff844a..44f70828cb 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -1,14 +1,20 @@ pub use dtype::*; pub use extension::*; pub use half; +pub use nullability::*; pub use ptype::*; -mod deserialize; mod dtype; mod extension; +mod nullability; mod ptype; mod serde; -mod serialize; +#[cfg(feature = "prost")] +pub mod proto { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.dtype.rs")); +} + +#[cfg(feature = "flatbuffers")] pub mod flatbuffers { #[allow(unused_imports)] #[allow(dead_code)] diff --git a/vortex-dtype/src/nullability.rs b/vortex-dtype/src/nullability.rs new file mode 100644 index 0000000000..e8416215bd --- /dev/null +++ b/vortex-dtype/src/nullability.rs @@ -0,0 +1,36 @@ +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub enum Nullability { + #[default] + NonNullable, + Nullable, +} + +impl From for Nullability { + fn from(value: bool) -> Self { + if value { + Nullability::Nullable + } else { + Nullability::NonNullable + } + } +} + +impl From for bool { + fn from(value: Nullability) -> Self { + match value { + Nullability::NonNullable => false, + Nullability::Nullable => true, + } + } +} + +impl Display for Nullability { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Nullability::NonNullable => write!(f, ""), + Nullability::Nullable => write!(f, "?"), + } + } +} diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index d9bd7d5aca..c2318001c7 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -5,12 +5,13 @@ use num_traits::{FromPrimitive, Num, NumCast}; use vortex_error::{vortex_err, VortexError, VortexResult}; use crate::half::f16; +use crate::nullability::Nullability::NonNullable; use crate::DType; use crate::DType::*; -use crate::Nullability::NonNullable; #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))] pub enum PType { U8, U16, diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs deleted file mode 100644 index 29df755186..0000000000 --- a/vortex-dtype/src/serde.rs +++ /dev/null @@ -1 +0,0 @@ -#![cfg(feature = "serde")] diff --git a/vortex-dtype/src/serialize.rs b/vortex-dtype/src/serde/flatbuffers.rs similarity index 64% rename from vortex-dtype/src/serialize.rs rename to vortex-dtype/src/serde/flatbuffers.rs index f08064de90..add24ad831 100644 --- a/vortex-dtype/src/serialize.rs +++ b/vortex-dtype/src/serde/flatbuffers.rs @@ -1,10 +1,72 @@ +#![cfg(feature = "flatbuffers")] + use flatbuffers::{FlatBufferBuilder, WIPOffset}; use itertools::Itertools; -use vortex_error::{vortex_bail, VortexError}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::{FlatBufferRoot, WriteFlatBuffer}; -use crate::{flatbuffers as fb, PType}; -use crate::{DType, Nullability}; +use crate::{flatbuffers as fb, ExtDType, ExtID, ExtMetadata, PType}; +use crate::{DType, StructDType}; + +impl TryFrom> for DType { + type Error = VortexError; + + fn try_from(fb: fb::DType<'_>) -> Result { + match fb.type_type() { + fb::Type::Null => Ok(DType::Null), + fb::Type::Bool => Ok(DType::Bool(fb.type__as_bool().unwrap().nullable().into())), + fb::Type::Primitive => { + let fb_primitive = fb.type__as_primitive().unwrap(); + Ok(DType::Primitive( + fb_primitive.ptype().try_into()?, + fb_primitive.nullable().into(), + )) + } + fb::Type::Binary => Ok(DType::Binary( + fb.type__as_binary().unwrap().nullable().into(), + )), + fb::Type::Utf8 => Ok(DType::Utf8(fb.type__as_utf_8().unwrap().nullable().into())), + fb::Type::List => { + let fb_list = fb.type__as_list().unwrap(); + let element_dtype = DType::try_from(fb_list.element_type().unwrap())?; + Ok(DType::List( + Box::new(element_dtype), + fb_list.nullable().into(), + )) + } + fb::Type::Struct_ => { + let fb_struct = fb.type__as_struct_().unwrap(); + let names = fb_struct + .names() + .unwrap() + .iter() + .map(|n| (*n).into()) + .collect_vec() + .into(); + let dtypes: Vec = fb_struct + .dtypes() + .unwrap() + .iter() + .map(DType::try_from) + .collect::>>()?; + Ok(DType::Struct( + StructDType::new(names, dtypes), + fb_struct.nullable().into(), + )) + } + fb::Type::Extension => { + let fb_ext = fb.type__as_extension().unwrap(); + let id = ExtID::from(fb_ext.id().unwrap()); + let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes())); + Ok(DType::Extension( + ExtDType::new(id, metadata), + fb_ext.nullable().into(), + )) + } + _ => Err(vortex_err!("Unknown DType variant")), + } + } +} impl FlatBufferRoot for DType {} impl WriteFlatBuffer for DType { @@ -19,7 +81,7 @@ impl WriteFlatBuffer for DType { DType::Bool(n) => fb::Bool::create( fbb, &fb::BoolArgs { - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value(), @@ -27,21 +89,21 @@ impl WriteFlatBuffer for DType { fbb, &fb::PrimitiveArgs { ptype: (*ptype).into(), - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value(), DType::Utf8(n) => fb::Utf8::create( fbb, &fb::Utf8Args { - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value(), DType::Binary(n) => fb::Binary::create( fbb, &fb::BinaryArgs { - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value(), @@ -58,14 +120,14 @@ impl WriteFlatBuffer for DType { .iter() .map(|dtype| dtype.write_flatbuffer(fbb)) .collect_vec(); - let fields = Some(fbb.create_vector(&dtypes)); + let dtypes = Some(fbb.create_vector(&dtypes)); fb::Struct_::create( fbb, &fb::Struct_Args { names, - fields, - nullability: n.into(), + dtypes, + nullable: (*n).into(), }, ) .as_union_value() @@ -76,7 +138,7 @@ impl WriteFlatBuffer for DType { fbb, &fb::ListArgs { element_type, - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value() @@ -89,7 +151,7 @@ impl WriteFlatBuffer for DType { &fb::ExtensionArgs { id, metadata, - nullability: n.into(), + nullable: (*n).into(), }, ) .as_union_value() @@ -117,24 +179,6 @@ impl WriteFlatBuffer for DType { } } -impl From for fb::Nullability { - fn from(value: Nullability) -> Self { - match value { - Nullability::NonNullable => fb::Nullability::NonNullable, - Nullability::Nullable => fb::Nullability::Nullable, - } - } -} - -impl From<&Nullability> for fb::Nullability { - fn from(value: &Nullability) -> Self { - match value { - Nullability::NonNullable => fb::Nullability::NonNullable, - Nullability::Nullable => fb::Nullability::Nullable, - } - } -} - impl From for fb::PType { fn from(value: PType) -> Self { match value { @@ -178,14 +222,15 @@ impl TryFrom for PType { mod test { use flatbuffers::root; - use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; + use vortex_flatbuffers::FlatBufferToBytes; + use crate::nullability::Nullability; + use crate::DType; use crate::{flatbuffers as fb, PType, StructDType}; - use crate::{DType, Nullability}; fn roundtrip_dtype(dtype: DType) { let bytes = dtype.with_flatbuffer_bytes(|bytes| bytes.to_vec()); - let deserialized = DType::read_flatbuffer(&root::(&bytes).unwrap()).unwrap(); + let deserialized = DType::try_from(root::(&bytes).unwrap()).unwrap(); assert_eq!(dtype, deserialized); } diff --git a/vortex-dtype/src/serde/mod.rs b/vortex-dtype/src/serde/mod.rs new file mode 100644 index 0000000000..859dd4a584 --- /dev/null +++ b/vortex-dtype/src/serde/mod.rs @@ -0,0 +1,31 @@ +mod flatbuffers; +mod proto; +#[allow(clippy::module_inception)] +mod serde; + +#[cfg(test)] +#[cfg(feature = "serde")] +mod test { + use serde_test::{assert_tokens, Token}; + + use crate::PType; + + #[test] + fn test_serde_ptype_json() { + // Ensure we serialize PTypes to lowercase. + let serialized = serde_json::to_string(&PType::U8).unwrap(); + assert_eq!(serialized, "\"u8\""); + assert_eq!(serde_json::from_str::("\"u8\"").unwrap(), PType::U8); + } + + #[test] + fn test_serde_ptype() { + assert_tokens( + &PType::U8, + &[Token::UnitVariant { + name: "PType", + variant: "u8", + }], + ); + } +} diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs new file mode 100644 index 0000000000..0e49ded1cf --- /dev/null +++ b/vortex-dtype/src/serde/proto.rs @@ -0,0 +1,129 @@ +#![cfg(feature = "prost")] + +use vortex_error::{vortex_err, VortexError, VortexResult}; + +use crate::proto::d_type::Type; +use crate::{proto as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; + +impl TryFrom<&pb::DType> for DType { + type Error = VortexError; + + fn try_from(value: &pb::DType) -> Result { + match value + .r#type + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Unrecognized DType"))? + { + Type::Null(_) => Ok(DType::Null), + Type::Bool(b) => Ok(DType::Bool(b.nullable.into())), + Type::Primitive(p) => Ok(DType::Primitive(p.r#type().into(), p.nullable.into())), + Type::Decimal(_) => todo!("Not Implemented"), + Type::Utf8(u) => Ok(DType::Utf8(u.nullable.into())), + Type::Binary(b) => Ok(DType::Binary(b.nullable.into())), + Type::Struct(s) => Ok(DType::Struct( + StructDType::new( + s.names.iter().map(|s| s.as_str().into()).collect(), + s.dtypes + .iter() + .map(TryInto::::try_into) + .collect::>>()?, + ), + s.nullable.into(), + )), + Type::List(l) => { + let nullable = l.nullable.into(); + Ok(DType::List( + l.element_type + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Invalid list element type"))? + .as_ref() + .try_into() + .map(Box::new)?, + nullable, + )) + } + Type::Extension(e) => Ok(DType::Extension( + ExtDType::new( + ExtID::from(e.id.as_str()), + e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())), + ), + e.nullable.into(), + )), + } + } +} + +impl From<&DType> for pb::DType { + fn from(value: &DType) -> Self { + pb::DType { + r#type: Some(match value { + DType::Null => Type::Null(pb::Null {}), + DType::Bool(n) => Type::Bool(pb::Bool { + nullable: (*n).into(), + }), + DType::Primitive(ptype, n) => Type::Primitive(pb::Primitive { + r#type: pb::PType::from(*ptype).into(), + nullable: (*n).into(), + }), + DType::Utf8(n) => Type::Utf8(pb::Utf8 { + nullable: (*n).into(), + }), + DType::Binary(n) => Type::Binary(pb::Binary { + nullable: (*n).into(), + }), + DType::Struct(s, n) => Type::Struct(pb::Struct { + names: s.names().iter().map(|s| s.as_ref().to_string()).collect(), + dtypes: s.dtypes().iter().map(Into::into).collect(), + nullable: (*n).into(), + }), + DType::List(l, n) => Type::List(Box::new(pb::List { + element_type: Some(Box::new(l.as_ref().into())), + nullable: (*n).into(), + })), + DType::Extension(e, n) => Type::Extension(pb::Extension { + id: e.id().as_ref().into(), + metadata: e.metadata().map(|m| m.as_ref().into()), + nullable: (*n).into(), + }), + }), + } + } +} + +impl From for PType { + fn from(value: pb::PType) -> Self { + use pb::PType::*; + match value { + U8 => PType::U8, + U16 => PType::U16, + U32 => PType::U32, + U64 => PType::U64, + I8 => PType::I8, + I16 => PType::I16, + I32 => PType::I32, + I64 => PType::I64, + F16 => PType::F16, + F32 => PType::F32, + F64 => PType::F64, + } + } +} + +impl From for pb::PType { + fn from(value: PType) -> Self { + use pb::PType::*; + match value { + PType::U8 => U8, + PType::U16 => U16, + PType::U32 => U32, + PType::U64 => U64, + PType::I8 => I8, + PType::I16 => I16, + PType::I32 => I32, + PType::I64 => I64, + PType::F16 => F16, + PType::F32 => F32, + PType::F64 => F64, + } + } +} diff --git a/vortex-dtype/src/serde/serde.rs b/vortex-dtype/src/serde/serde.rs new file mode 100644 index 0000000000..2f7db0dbfa --- /dev/null +++ b/vortex-dtype/src/serde/serde.rs @@ -0,0 +1,25 @@ +#![cfg(feature = "serde")] + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::Nullability; + +/// Serialize Nullability as a boolean +impl Serialize for Nullability { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + bool::from(*self).serialize(serializer) + } +} + +/// Deserialize Nullability from a boolean +impl<'de> Deserialize<'de> for Nullability { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + bool::deserialize(deserializer).map(Nullability::from) + } +} diff --git a/vortex-ipc/benches/ipc_array_reader_take.rs b/vortex-ipc/benches/ipc_array_reader_take.rs index 962c16596c..3b42cede12 100644 --- a/vortex-ipc/benches/ipc_array_reader_take.rs +++ b/vortex-ipc/benches/ipc_array_reader_take.rs @@ -5,7 +5,8 @@ use fallible_iterator::FallibleIterator; use itertools::Itertools; use vortex::array::primitive::PrimitiveArray; use vortex::{Context, IntoArray}; -use vortex_dtype::{DType, Nullability, PType}; +use vortex_dtype::Nullability; +use vortex_dtype::{DType, PType}; use vortex_ipc::iter::FallibleLendingIterator; use vortex_ipc::reader::StreamReader; use vortex_ipc::writer::StreamWriter; diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 7a5f2d6ab0..a7502d4b22 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -19,7 +19,6 @@ use vortex::{ use vortex_buffer::Buffer; use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; -use vortex_flatbuffers::ReadFlatBuffer; use vortex_scalar::Scalar; use crate::flatbuffers::ipc::Message; @@ -104,8 +103,8 @@ impl FallibleLendingIterator for StreamReader { .header_as_schema() .unwrap(); - let dtype = DType::read_flatbuffer( - &schema_msg + let dtype = DType::try_from( + schema_msg .dtype() .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?, ) diff --git a/vortex-roaring/src/boolean/mod.rs b/vortex-roaring/src/boolean/mod.rs index 656688df7d..2e2a085580 100644 --- a/vortex-roaring/src/boolean/mod.rs +++ b/vortex-roaring/src/boolean/mod.rs @@ -9,10 +9,9 @@ use vortex::validity::{ArrayValidity, LogicalValidity, Validity}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, ArrayFlatten, OwnedArray}; use vortex_buffer::Buffer; -use vortex_dtype::Nullability; use vortex_dtype::Nullability::NonNullable; +use vortex_dtype::Nullability::Nullable; use vortex_error::{vortex_bail, vortex_err}; -use Nullability::Nullable; mod compress; mod compute; diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index abc30ffc6a..0f18be7574 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -12,7 +12,7 @@ edition = { workspace = true } rust-version = { workspace = true } [dependencies] -flatbuffers = { workspace = true } +flatbuffers = { workspace = true, optional = true } flexbuffers = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } @@ -22,10 +22,8 @@ vortex-dtype = { path = "../vortex-dtype" } vortex-error = { path = "../vortex-error" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } - [build-dependencies] -flatc = { workspace = true } -walkdir = { workspace = true } +build-vortex = { path = "../build-vortex" } [lints] workspace = true diff --git a/vortex-scalar/build.rs b/vortex-scalar/build.rs deleted file mode 120000 index 7cb528993c..0000000000 --- a/vortex-scalar/build.rs +++ /dev/null @@ -1 +0,0 @@ -../flatbuffers.build.rs \ No newline at end of file diff --git a/vortex-scalar/build.rs b/vortex-scalar/build.rs new file mode 100644 index 0000000000..46df240a6b --- /dev/null +++ b/vortex-scalar/build.rs @@ -0,0 +1,3 @@ +fn main() { + build_vortex::build(); +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index e6a4279169..429640e477 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -23,6 +23,7 @@ mod struct_; mod utf8; mod value; +#[cfg(feature = "flatbuffers")] pub mod flatbuffers { pub use gen_scalar::vortex::*; diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs index a03ab231dc..203865913d 100644 --- a/vortex-scalar/src/serde.rs +++ b/vortex-scalar/src/serde.rs @@ -1,5 +1,5 @@ #![cfg(feature = "serde")] - +#![cfg(feature = "flatbuffers")] use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; use serde::de::Visitor; use serde::{Deserialize, Deserializer, Serialize, Serializer}; From 6fb0e8a09de17db621e7f8819a4acec36db453eb Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 8 May 2024 11:11:09 +0100 Subject: [PATCH 2/4] Add ScalarView (#301) Fixes #290 Makes scalar generic over heap allocated data or serialised scalar data. Serialized scalars will no longer include DType information to avoid duplication meaning they require a DType to deserialise. --- Cargo.lock | 4 + Cargo.toml | 1 + build-vortex/src/lib.rs | 2 +- vortex-alp/src/compress.rs | 2 +- vortex-alp/src/compute.rs | 2 +- .../src/array/bool/compute/scalar_at.rs | 14 +- vortex-array/src/array/bool/mod.rs | 2 +- vortex-array/src/array/chunked/mod.rs | 6 +- vortex-array/src/array/constant/flatten.rs | 36 +- vortex-array/src/array/constant/stats.rs | 15 +- vortex-array/src/array/extension/compute.rs | 36 +- .../src/array/primitive/compute/scalar_at.rs | 12 +- .../primitive/compute/subtract_scalar.rs | 25 +- vortex-array/src/array/primitive/stats.rs | 117 +++-- vortex-array/src/array/sparse/compute/mod.rs | 2 +- vortex-array/src/array/sparse/mod.rs | 12 +- vortex-array/src/array/struct/compute.rs | 9 +- vortex-array/src/array/varbin/builder.rs | 4 +- vortex-array/src/array/varbin/compute/mod.rs | 2 +- vortex-array/src/array/varbin/mod.rs | 21 +- vortex-array/src/array/varbin/stats.rs | 31 +- vortex-array/src/array/varbinview/compute.rs | 2 +- vortex-array/src/arrow/array.rs | 4 +- vortex-array/src/arrow/dtype.rs | 4 +- vortex-array/src/stats/mod.rs | 12 +- vortex-array/src/stats/statsset.rs | 21 +- vortex-array/src/validity.rs | 2 +- vortex-buffer/Cargo.toml | 1 + vortex-buffer/src/flexbuffers.rs | 23 + vortex-buffer/src/lib.rs | 29 +- vortex-buffer/src/string.rs | 52 ++ vortex-dict/src/compress.rs | 23 +- vortex-dict/src/compute.rs | 2 +- vortex-dict/src/dict.rs | 6 +- vortex-dtype/Cargo.toml | 9 +- vortex-dtype/src/dtype.rs | 6 +- vortex-dtype/src/lib.rs | 9 +- vortex-dtype/src/ptype.rs | 44 ++ vortex-dtype/src/serde/flatbuffers.rs | 7 +- vortex-dtype/src/serde/proto.rs | 10 +- vortex-error/src/lib.rs | 6 + vortex-fastlanes/src/bitpacking/compress.rs | 16 +- .../src/bitpacking/compute/mod.rs | 9 +- vortex-fastlanes/src/for/compress.rs | 12 +- vortex-fastlanes/src/for/compute.rs | 26 +- vortex-fastlanes/src/for/mod.rs | 2 +- vortex-flatbuffers/src/lib.rs | 18 +- vortex-ree/src/ree.rs | 2 +- vortex-scalar/Cargo.toml | 24 +- vortex-scalar/flatbuffers/scalar.fbs | 52 +- vortex-scalar/proto/scalar.proto | 21 + vortex-scalar/src/binary.rs | 79 ++- vortex-scalar/src/bool.rs | 81 +-- vortex-scalar/src/display.rs | 41 ++ vortex-scalar/src/extension.rs | 107 ++-- vortex-scalar/src/lib.rs | 178 +++---- vortex-scalar/src/list.rs | 146 +++--- vortex-scalar/src/null.rs | 41 -- vortex-scalar/src/primitive.rs | 470 ++++-------------- vortex-scalar/src/pvalue.rs | 132 +++++ vortex-scalar/src/serde.rs | 250 ---------- vortex-scalar/src/serde/flatbuffers.rs | 29 ++ vortex-scalar/src/serde/mod.rs | 4 + vortex-scalar/src/serde/proto.rs | 41 ++ vortex-scalar/src/serde/serde.rs | 175 +++++++ vortex-scalar/src/struct_.rs | 114 ++--- vortex-scalar/src/utf8.rs | 86 ++-- vortex-scalar/src/value.rs | 81 +-- vortex-zigzag/src/compute.rs | 26 +- 69 files changed, 1389 insertions(+), 1501 deletions(-) create mode 100644 vortex-buffer/src/flexbuffers.rs create mode 100644 vortex-buffer/src/string.rs create mode 100644 vortex-scalar/proto/scalar.proto create mode 100644 vortex-scalar/src/display.rs delete mode 100644 vortex-scalar/src/null.rs create mode 100644 vortex-scalar/src/pvalue.rs delete mode 100644 vortex-scalar/src/serde.rs create mode 100644 vortex-scalar/src/serde/flatbuffers.rs create mode 100644 vortex-scalar/src/serde/mod.rs create mode 100644 vortex-scalar/src/serde/proto.rs create mode 100644 vortex-scalar/src/serde/serde.rs diff --git a/Cargo.lock b/Cargo.lock index 11f09cdd3d..8d8f187969 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4958,6 +4958,7 @@ version = "0.1.0" dependencies = [ "arrow-buffer", "bytes", + "flexbuffers", "vortex-dtype", ] @@ -5112,6 +5113,9 @@ dependencies = [ "flexbuffers", "itertools 0.12.1", "num-traits", + "paste", + "prost", + "prost-types", "serde", "vortex-buffer", "vortex-dtype", diff --git a/Cargo.toml b/Cargo.toml index adb1e60e59..bf92f3ed1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ parquet = "51.0.0" paste = "1.0.14" prost = "0.12.4" prost-build = "0.12.4" +prost-types = "0.12.4" pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py311"] } pyo3-log = "0.9.0" rand = "0.8.5" diff --git a/build-vortex/src/lib.rs b/build-vortex/src/lib.rs index 650292159e..5e70edc4c0 100644 --- a/build-vortex/src/lib.rs +++ b/build-vortex/src/lib.rs @@ -26,7 +26,7 @@ pub fn build() { } // Proto (prost) - if env::var("CARGO_FEATURE_PROST").ok().is_some() { + if env::var("CARGO_FEATURE_PROTO").ok().is_some() { build_proto(); } } diff --git a/vortex-alp/src/compress.rs b/vortex-alp/src/compress.rs index 1889992df0..d648f23571 100644 --- a/vortex-alp/src/compress.rs +++ b/vortex-alp/src/compress.rs @@ -98,7 +98,7 @@ where PrimitiveArray::from(exc_pos).into_array(), PrimitiveArray::from_vec(exc, Validity::AllValid).into_array(), len, - Scalar::null(&values.dtype().as_nullable()), + Scalar::null(values.dtype().as_nullable()), ) .into_array() }), diff --git a/vortex-alp/src/compute.rs b/vortex-alp/src/compute.rs index 169bb56a8d..562bf77861 100644 --- a/vortex-alp/src/compute.rs +++ b/vortex-alp/src/compute.rs @@ -30,7 +30,7 @@ impl ScalarAtFn for ALPArray<'_> { use crate::ALPFloat; let encoded_val = scalar_at(&self.encoded(), index)?; match_each_alp_float_ptype!(self.dtype().try_into().unwrap(), |$T| { - let encoded_val: <$T as ALPFloat>::ALPInt = encoded_val.try_into().unwrap(); + let encoded_val: <$T as ALPFloat>::ALPInt = encoded_val.as_ref().try_into().unwrap(); Scalar::from(<$T as ALPFloat>::decode_single( encoded_val, self.exponents(), diff --git a/vortex-array/src/array/bool/compute/scalar_at.rs b/vortex-array/src/array/bool/compute/scalar_at.rs index 06bcd23e82..06a03e96c1 100644 --- a/vortex-array/src/array/bool/compute/scalar_at.rs +++ b/vortex-array/src/array/bool/compute/scalar_at.rs @@ -1,5 +1,5 @@ use vortex_error::VortexResult; -use vortex_scalar::{BoolScalar, Scalar}; +use vortex_scalar::Scalar; use crate::array::bool::BoolArray; use crate::compute::scalar_at::ScalarAtFn; @@ -8,12 +8,10 @@ use crate::ArrayDType; impl ScalarAtFn for BoolArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { - Ok(BoolScalar::try_new( - self.is_valid(index) - .then(|| self.boolean_buffer().value(index)), - self.dtype().nullability(), - ) - .unwrap() - .into()) + if self.is_valid(index) { + Ok(self.boolean_buffer().value(index).into()) + } else { + return Ok(Scalar::null(self.dtype().clone())); + } } } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 55d8e4ebbc..9e75aee753 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -129,7 +129,7 @@ mod tests { #[test] fn bool_array() { let arr = BoolArray::from(vec![true, false, true]).into_array(); - let scalar: bool = scalar_at(&arr, 0).unwrap().try_into().unwrap(); + let scalar = bool::try_from(&scalar_at(&arr, 0).unwrap()).unwrap(); assert!(scalar); } } diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index b06d7f5133..a7a5389891 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -74,12 +74,12 @@ impl ChunkedArray<'_> { .unwrap() .to_index(); let mut chunk_start = - usize::try_from(scalar_at(&self.chunk_ends(), index_chunk).unwrap()).unwrap(); + usize::try_from(&scalar_at(&self.chunk_ends(), index_chunk).unwrap()).unwrap(); if chunk_start != index { index_chunk -= 1; chunk_start = - usize::try_from(scalar_at(&self.chunk_ends(), index_chunk).unwrap()).unwrap(); + usize::try_from(&scalar_at(&self.chunk_ends(), index_chunk).unwrap()).unwrap(); } let index_in_chunk = index - chunk_start; @@ -125,7 +125,7 @@ impl AcceptArrayVisitor for ChunkedArray<'_> { impl ArrayTrait for ChunkedArray<'_> { fn len(&self) -> usize { - usize::try_from(scalar_at(&self.chunk_ends(), self.nchunks()).unwrap()).unwrap() + usize::try_from(&scalar_at(&self.chunk_ends(), self.nchunks()).unwrap()).unwrap() } } diff --git a/vortex-array/src/array/constant/flatten.rs b/vortex-array/src/array/constant/flatten.rs index 86d2658819..c83c897b9c 100644 --- a/vortex-array/src/array/constant/flatten.rs +++ b/vortex-array/src/array/constant/flatten.rs @@ -1,6 +1,6 @@ -use vortex_dtype::{match_each_native_ptype, Nullability}; -use vortex_error::VortexResult; -use vortex_scalar::Scalar; +use vortex_dtype::{match_each_native_ptype, Nullability, PType}; +use vortex_error::{vortex_bail, VortexResult}; +use vortex_scalar::BoolScalar; use crate::array::bool::BoolArray; use crate::array::constant::ConstantArray; @@ -22,20 +22,22 @@ impl ArrayFlatten for ConstantArray<'_> { }, }; - Ok(match self.scalar() { - Scalar::Bool(b) => Flattened::Bool(BoolArray::from_vec( - vec![b.value().copied().unwrap_or_default(); self.len()], + if let Ok(b) = BoolScalar::try_from(self.scalar()) { + return Ok(Flattened::Bool(BoolArray::from_vec( + vec![b.value().unwrap_or_default(); self.len()], validity, - )), - Scalar::Primitive(p) => { - match_each_native_ptype!(p.ptype(), |$P| { - Flattened::Primitive(PrimitiveArray::from_vec::<$P>( - vec![$P::try_from(self.scalar())?; self.len()], - validity, - )) - }) - } - _ => panic!("Unsupported scalar type {}", self.dtype()), - }) + ))); + } + + if let Ok(ptype) = PType::try_from(self.scalar().dtype()) { + return match_each_native_ptype!(ptype, |$P| { + Ok(Flattened::Primitive(PrimitiveArray::from_vec::<$P>( + vec![$P::try_from(self.scalar())?; self.len()], + validity, + ))) + }); + } + + vortex_bail!("Unsupported scalar type {}", self.dtype()) } } diff --git a/vortex-array/src/array/constant/stats.rs b/vortex-array/src/array/constant/stats.rs index 7b948f0e7f..dc957a9fae 100644 --- a/vortex-array/src/array/constant/stats.rs +++ b/vortex-array/src/array/constant/stats.rs @@ -1,22 +1,23 @@ use std::collections::HashMap; -use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_scalar::Scalar; +use vortex_scalar::BoolScalar; use crate::array::constant::ConstantArray; use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet}; -use crate::{ArrayDType, ArrayTrait}; +use crate::ArrayTrait; impl ArrayStatisticsCompute for ConstantArray<'_> { fn compute_statistics(&self, _stat: Stat) -> VortexResult { - if matches!(self.dtype(), &DType::Bool(_)) { - let Scalar::Bool(b) = self.scalar() else { - unreachable!("Got bool dtype without bool scalar") + if let Ok(b) = BoolScalar::try_from(self.scalar()) { + let true_count = if b.value().unwrap_or(false) { + self.len() as u64 + } else { + 0 }; return Ok(StatsSet::from(HashMap::from([( Stat::TrueCount, - (self.len() as u64 * b.value().cloned().map(|v| v as u64).unwrap_or(0)).into(), + true_count.into(), )]))); } diff --git a/vortex-array/src/array/extension/compute.rs b/vortex-array/src/array/extension/compute.rs index 9c111e671a..b933d8c730 100644 --- a/vortex-array/src/array/extension/compute.rs +++ b/vortex-array/src/array/extension/compute.rs @@ -1,6 +1,6 @@ use arrow_array::ArrayRef as ArrowArrayRef; use vortex_error::{vortex_bail, VortexResult}; -use vortex_scalar::{ExtScalar, Scalar}; +use vortex_scalar::Scalar; use crate::array::datetime::LocalDateTimeArray; use crate::array::extension::ExtensionArray; @@ -8,13 +8,10 @@ use crate::compute::as_arrow::AsArrowArray; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::cast::CastFn; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; -use crate::compute::search_sorted::{ - search_sorted, SearchResult, SearchSortedFn, SearchSortedSide, -}; use crate::compute::slice::{slice, SliceFn}; use crate::compute::take::{take, TakeFn}; use crate::compute::ArrayCompute; -use crate::{Array, ArrayDType, IntoArray, OwnedArray, ToStatic}; +use crate::{Array, IntoArray, OwnedArray, ToStatic}; impl ArrayCompute for ExtensionArray<'_> { fn as_arrow(&self) -> Option<&dyn AsArrowArray> { @@ -36,10 +33,6 @@ impl ArrayCompute for ExtensionArray<'_> { Some(self) } - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } - fn slice(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -82,29 +75,10 @@ impl AsContiguousFn for ExtensionArray<'_> { impl ScalarAtFn for ExtensionArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { - Ok(Scalar::Extension(ExtScalar::try_new( + Ok(Scalar::extension( self.ext_dtype().clone(), - self.dtype().nullability(), - Some(scalar_at(&self.storage(), index)?), - )?)) - } -} - -impl SearchSortedFn for ExtensionArray<'_> { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - if value.dtype() != self.dtype() { - vortex_bail!("Value dtype does not match array dtype"); - } - let Scalar::Extension(ext) = value else { - unreachable!(); - }; - - let storage_scalar = ext - .value() - .map(|v| (**v).clone()) - .unwrap_or_else(|| Scalar::null(self.dtype())); - - search_sorted(&self.storage(), storage_scalar, side) + scalar_at(&self.storage(), index)?, + )) } } diff --git a/vortex-array/src/array/primitive/compute/scalar_at.rs b/vortex-array/src/array/primitive/compute/scalar_at.rs index 0cfd8a95ab..33c0940b4d 100644 --- a/vortex-array/src/array/primitive/compute/scalar_at.rs +++ b/vortex-array/src/array/primitive/compute/scalar_at.rs @@ -1,6 +1,5 @@ use vortex_dtype::match_each_native_ptype; use vortex_error::VortexResult; -use vortex_scalar::PrimitiveScalar; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; @@ -11,12 +10,11 @@ use crate::ArrayDType; impl ScalarAtFn for PrimitiveArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { match_each_native_ptype!(self.ptype(), |$T| { - Ok(PrimitiveScalar::try_new( - self.is_valid(index) - .then(|| self.typed_data::<$T>()[index]), - self.dtype().nullability(), - )? - .into()) + if self.is_valid(index) { + Ok(Scalar::primitive(self.typed_data::<$T>()[index], self.dtype().nullability())) + } else { + Ok(Scalar::null(self.dtype().clone())) + } }) } } diff --git a/vortex-array/src/array/primitive/compute/subtract_scalar.rs b/vortex-array/src/array/primitive/compute/subtract_scalar.rs index b4dbbc3d96..3d6ecab8d5 100644 --- a/vortex-array/src/array/primitive/compute/subtract_scalar.rs +++ b/vortex-array/src/array/primitive/compute/subtract_scalar.rs @@ -3,7 +3,8 @@ use num_traits::ops::overflowing::OverflowingSub; use num_traits::SaturatingSub; use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, NativePType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; -use vortex_scalar::{PScalarType, Scalar}; +use vortex_scalar::PrimitiveScalar; +use vortex_scalar::Scalar; use crate::array::constant::ConstantArray; use crate::array::primitive::PrimitiveArray; @@ -20,24 +21,22 @@ impl SubtractScalarFn for PrimitiveArray<'_> { let validity = self.validity().to_logical(self.len()); if validity.all_invalid() { - return Ok(ConstantArray::new(Scalar::null(self.dtype()), self.len()).into_array()); + return Ok( + ConstantArray::new(Scalar::null(self.dtype().clone()), self.len()).into_array(), + ); } - let to_subtract = match to_subtract { - Scalar::Primitive(prim_scalar) => prim_scalar, - _ => vortex_bail!("Expected primitive scalar"), - }; - let result = if to_subtract.dtype().is_int() { match_each_integer_ptype!(self.ptype(), |$T| { - let to_subtract: $T = to_subtract - .typed_value() + let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)? + .typed_value::<$T>() .ok_or_else(|| vortex_err!("expected primitive"))?; subtract_scalar_integer::<$T>(self, to_subtract)? }) } else { match_each_float_ptype!(self.ptype(), |$T| { - let to_subtract: $T = to_subtract.typed_value() + let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)? + .typed_value::<$T>() .ok_or_else(|| vortex_err!("expected primitive"))?; let sub_vec : Vec<$T> = self.typed_data::<$T>() .iter() @@ -51,11 +50,7 @@ impl SubtractScalarFn for PrimitiveArray<'_> { fn subtract_scalar_integer< 'a, - T: NativePType - + OverflowingSub - + SaturatingSub - + PScalarType - + TryFrom, + T: NativePType + OverflowingSub + SaturatingSub + for<'b> TryFrom<&'b Scalar, Error = VortexError>, >( subtract_from: &PrimitiveArray<'a>, to_subtract: T, diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index fd1d4f1166..9e29ffd974 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use std::mem::size_of; use arrow_buffer::buffer::BooleanBuffer; -use vortex_dtype::match_each_native_ptype; +use num_traits::PrimInt; +use vortex_dtype::half::f16; +use vortex_dtype::Nullability::Nullable; +use vortex_dtype::{match_each_native_ptype, DType, NativePType}; use vortex_error::VortexResult; -use vortex_scalar::{ListScalarVec, PScalar}; -use vortex_scalar::{PScalarType, Scalar}; +use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet}; @@ -13,8 +15,8 @@ use crate::validity::ArrayValidity; use crate::validity::LogicalValidity; use crate::IntoArray; -trait PStatsType: PScalarType + Into {} -impl> PStatsType for T {} +trait PStatsType: NativePType + Into + BitWidth {} +impl + BitWidth> PStatsType for T {} impl ArrayStatisticsCompute for PrimitiveArray<'_> { fn compute_statistics(&self, stat: Stat) -> VortexResult { @@ -45,20 +47,23 @@ impl ArrayStatisticsCompute for &[T] { fn all_null_stats(len: usize) -> VortexResult { Ok(StatsSet::from(HashMap::from([ - (Stat::Min, Option::::None.into()), - (Stat::Max, Option::::None.into()), + ( + Stat::Min, + Scalar::null(DType::Primitive(T::PTYPE, Nullable)), + ), + ( + Stat::Max, + Scalar::null(DType::Primitive(T::PTYPE, Nullable)), + ), (Stat::IsConstant, true.into()), (Stat::IsSorted, true.into()), (Stat::IsStrictSorted, (len < 2).into()), (Stat::RunCount, 1.into()), (Stat::NullCount, len.into()), - ( - Stat::BitWidthFreq, - ListScalarVec(vec![0; size_of::() * 8 + 1]).into(), - ), + (Stat::BitWidthFreq, vec![0; size_of::() * 8 + 1].into()), ( Stat::TrailingZeroFreq, - ListScalarVec(vec![size_of::() * 8; size_of::() * 8 + 1]).into(), + vec![size_of::() * 8; size_of::() * 8 + 1].into(), ), ]))) } @@ -96,47 +101,52 @@ impl<'a, T: PStatsType> ArrayStatisticsCompute for NullableValues<'a, T> { } trait BitWidth { - fn bit_width(self) -> usize; - fn trailing_zeros(self) -> usize; + fn bit_width(self) -> u32; + fn trailing_zeros(self) -> u32; } -impl BitWidth for T { - fn bit_width(self) -> usize { - let bit_width = size_of::() * 8; - let scalar: PScalar = self.into(); - match scalar { - PScalar::U8(i) => bit_width - i.leading_zeros() as usize, - PScalar::U16(i) => bit_width - i.leading_zeros() as usize, - PScalar::U32(i) => bit_width - i.leading_zeros() as usize, - PScalar::U64(i) => bit_width - i.leading_zeros() as usize, - PScalar::I8(i) => bit_width - i.leading_zeros() as usize, - PScalar::I16(i) => bit_width - i.leading_zeros() as usize, - PScalar::I32(i) => bit_width - i.leading_zeros() as usize, - PScalar::I64(i) => bit_width - i.leading_zeros() as usize, - PScalar::F16(_) => bit_width, - PScalar::F32(_) => bit_width, - PScalar::F64(_) => bit_width, +macro_rules! int_bit_width { + ($T:ty) => { + impl BitWidth for $T { + fn bit_width(self) -> u32 { + Self::BITS - PrimInt::leading_zeros(self) + } + + fn trailing_zeros(self) -> u32 { + PrimInt::trailing_zeros(self) + } } - } + }; +} + +int_bit_width!(u8); +int_bit_width!(u16); +int_bit_width!(u32); +int_bit_width!(u64); +int_bit_width!(i8); +int_bit_width!(i16); +int_bit_width!(i32); +int_bit_width!(i64); - fn trailing_zeros(self) -> usize { - let scalar: PScalar = self.into(); - match scalar { - PScalar::U8(i) => i.trailing_zeros() as usize, - PScalar::U16(i) => i.trailing_zeros() as usize, - PScalar::U32(i) => i.trailing_zeros() as usize, - PScalar::U64(i) => i.trailing_zeros() as usize, - PScalar::I8(i) => i.trailing_zeros() as usize, - PScalar::I16(i) => i.trailing_zeros() as usize, - PScalar::I32(i) => i.trailing_zeros() as usize, - PScalar::I64(i) => i.trailing_zeros() as usize, - PScalar::F16(_) => 0, - PScalar::F32(_) => 0, - PScalar::F64(_) => 0, +// TODO(ngates): just skip counting this in the implementation. +macro_rules! float_bit_width { + ($T:ty) => { + impl BitWidth for $T { + fn bit_width(self) -> u32 { + (size_of::() * 8) as u32 + } + + fn trailing_zeros(self) -> u32 { + 0 + } } - } + }; } +float_bit_width!(f16); +float_bit_width!(f32); +float_bit_width!(f64); + struct StatsAccumulator { prev: T, min: T, @@ -162,8 +172,8 @@ impl StatsAccumulator { bit_widths: vec![0; size_of::() * 8 + 1], trailing_zeros: vec![0; size_of::() * 8 + 1], }; - stats.bit_widths[first_value.bit_width()] += 1; - stats.trailing_zeros[first_value.trailing_zeros()] += 1; + stats.bit_widths[first_value.bit_width() as usize] += 1; + stats.trailing_zeros[first_value.trailing_zeros() as usize] += 1; stats } @@ -187,8 +197,8 @@ impl StatsAccumulator { } pub fn next(&mut self, next: T) { - self.bit_widths[next.bit_width()] += 1; - self.trailing_zeros[next.trailing_zeros()] += 1; + self.bit_widths[next.bit_width() as usize] += 1; + self.trailing_zeros[next.trailing_zeros() as usize] += 1; if self.prev == next { self.is_strict_sorted = false; @@ -212,11 +222,8 @@ impl StatsAccumulator { (Stat::Max, self.max.into()), (Stat::NullCount, self.null_count.into()), (Stat::IsConstant, (self.min == self.max).into()), - (Stat::BitWidthFreq, ListScalarVec(self.bit_widths).into()), - ( - Stat::TrailingZeroFreq, - ListScalarVec(self.trailing_zeros).into(), - ), + (Stat::BitWidthFreq, self.bit_widths.into()), + (Stat::TrailingZeroFreq, self.trailing_zeros.into()), (Stat::IsSorted, self.is_sorted.into()), ( Stat::IsStrictSorted, diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index ff1cba531c..15ecdd9cea 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -156,7 +156,7 @@ mod test { PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid) .into_array(), 100, - Scalar::null(&DType::Primitive(PType::F64, Nullability::Nullable)), + Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)), ) .into_array() } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index ae2d06940a..611ea82ef4 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -187,7 +187,7 @@ mod test { use crate::{Array, IntoArray, OwnedArray}; fn nullable_fill() -> Scalar { - Scalar::null(&DType::Primitive(PType::I32, Nullable)) + Scalar::null(DType::Primitive(PType::I32, Nullable)) } #[allow(dead_code)] @@ -233,7 +233,7 @@ mod test { #[test] pub fn iter_sliced() { - let p_fill_val = Some(non_nullable_fill().try_into().unwrap()); + let p_fill_val = Some(non_nullable_fill().as_ref().try_into().unwrap()); assert_sparse_array( &slice(&sparse_array(non_nullable_fill()), 2, 7).unwrap(), &[Some(100), p_fill_val, p_fill_val, Some(200), p_fill_val], @@ -272,7 +272,7 @@ mod test { #[test] pub fn test_scalar_at() { assert_eq!( - usize::try_from(scalar_at(&sparse_array(nullable_fill()), 2).unwrap()).unwrap(), + usize::try_from(&scalar_at(&sparse_array(nullable_fill()), 2).unwrap()).unwrap(), 100 ); let error = scalar_at(&sparse_array(nullable_fill()), 10).err().unwrap(); @@ -288,7 +288,7 @@ mod test { pub fn scalar_at_sliced() { let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap(); assert_eq!( - usize::try_from(scalar_at(&sliced, 0).unwrap()).unwrap(), + usize::try_from(&scalar_at(&sliced, 0).unwrap()).unwrap(), 100 ); let error = scalar_at(&sliced, 5).err().unwrap(); @@ -304,7 +304,7 @@ mod test { pub fn scalar_at_sliced_twice() { let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap(); assert_eq!( - usize::try_from(scalar_at(&sliced_once, 1).unwrap()).unwrap(), + usize::try_from(&scalar_at(&sliced_once, 1).unwrap()).unwrap(), 100 ); let error = scalar_at(&sliced_once, 7).err().unwrap(); @@ -317,7 +317,7 @@ mod test { let sliced_twice = slice(&sliced_once, 1, 6).unwrap(); assert_eq!( - usize::try_from(scalar_at(&sliced_twice, 3).unwrap()).unwrap(), + usize::try_from(&scalar_at(&sliced_twice, 3).unwrap()).unwrap(), 200 ); let error2 = scalar_at(&sliced_twice, 5).err().unwrap(); diff --git a/vortex-array/src/array/struct/compute.rs b/vortex-array/src/array/struct/compute.rs index 3579f23649..ecb842bb23 100644 --- a/vortex-array/src/array/struct/compute.rs +++ b/vortex-array/src/array/struct/compute.rs @@ -6,7 +6,7 @@ use arrow_array::{ use arrow_schema::{Field, Fields}; use itertools::Itertools; use vortex_error::VortexResult; -use vortex_scalar::{Scalar, StructScalar}; +use vortex_scalar::Scalar; use crate::array::r#struct::StructArray; use crate::compute::as_arrow::{as_arrow, AsArrowArray}; @@ -103,13 +103,12 @@ impl AsContiguousFn for StructArray<'_> { impl ScalarAtFn for StructArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { - Ok(StructScalar::new( + Ok(Scalar::r#struct( self.dtype().clone(), self.children() - .map(|field| scalar_at(&field, index)) + .map(|field| scalar_at(&field, index).map(|s| s.into_value())) .try_collect()?, - ) - .into()) + )) } } diff --git a/vortex-array/src/array/varbin/builder.rs b/vortex-array/src/array/varbin/builder.rs index 7a73472959..da9104420a 100644 --- a/vortex-array/src/array/varbin/builder.rs +++ b/vortex-array/src/array/varbin/builder.rs @@ -69,7 +69,7 @@ impl VarBinBuilder { mod test { use vortex_dtype::DType; use vortex_dtype::Nullability::Nullable; - use vortex_scalar::Utf8Scalar; + use vortex_scalar::Scalar; use crate::array::varbin::builder::VarBinBuilder; use crate::compute::scalar_at::scalar_at; @@ -87,7 +87,7 @@ mod test { assert_eq!(array.dtype().nullability(), Nullable); assert_eq!( scalar_at(&array, 0).unwrap(), - Utf8Scalar::nullable("hello".to_owned()).into() + Scalar::utf8("hello".to_string(), Nullable) ); assert!(scalar_at(&array, 1).unwrap().is_null()); } diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index 8a56aa9535..9823652238 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -144,7 +144,7 @@ impl ScalarAtFn for VarBinArray<'_> { self.dtype(), )) } else { - Ok(Scalar::null(self.dtype())) + Ok(Scalar::null(self.dtype().clone())) } } } diff --git a/vortex-array/src/array/varbin/mod.rs b/vortex-array/src/array/varbin/mod.rs index 20058f16b8..cbd6ac3254 100644 --- a/vortex-array/src/array/varbin/mod.rs +++ b/vortex-array/src/array/varbin/mod.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use vortex_dtype::Nullability; use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::vortex_bail; -use vortex_scalar::{BinaryScalar, Scalar, Utf8Scalar}; +use vortex_scalar::Scalar; use crate::array::varbin::builder::VarBinBuilder; use crate::compute::scalar_at::scalar_at; @@ -73,11 +73,12 @@ impl VarBinArray<'_> { .expect("missing offsets") } - pub fn first_offset>( + pub fn first_offset TryFrom<&'a Scalar, Error = VortexError>>( &self, ) -> VortexResult { scalar_at(&self.offsets(), 0)? .cast(&DType::from(T::PTYPE))? + .as_ref() .try_into() } @@ -93,9 +94,10 @@ impl VarBinArray<'_> { } pub fn sliced_bytes(&self) -> VortexResult { - let first_offset: usize = scalar_at(&self.offsets(), 0)?.try_into()?; - let last_offset: usize = - scalar_at(&self.offsets(), self.offsets().len() - 1)?.try_into()?; + let first_offset: usize = scalar_at(&self.offsets(), 0)?.as_ref().try_into()?; + let last_offset: usize = scalar_at(&self.offsets(), self.offsets().len() - 1)? + .as_ref() + .try_into()?; slice(&self.bytes(), first_offset, last_offset) } @@ -143,6 +145,7 @@ impl VarBinArray<'_> { .unwrap_or_else(|| { scalar_at(&self.offsets(), index) .unwrap() + .as_ref() .try_into() .unwrap() }) @@ -207,13 +210,9 @@ impl<'a> FromIterator> for VarBinArray<'_> { pub fn varbin_scalar(value: Vec, dtype: &DType) -> Scalar { if matches!(dtype, DType::Utf8(_)) { let str = unsafe { String::from_utf8_unchecked(value) }; - Utf8Scalar::try_new(Some(str), dtype.nullability()) - .unwrap() - .into() + Scalar::utf8(str, dtype.nullability()) } else { - BinaryScalar::try_new(Some(value), dtype.nullability()) - .unwrap() - .into() + Scalar::binary(value.into(), dtype.nullability()) } } diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index 5d161b6fdf..7060ed1351 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -43,8 +43,8 @@ pub fn compute_stats(iter: &mut dyn Iterator>, dtype: &DTyp fn all_null_stats(len: usize, dtype: &DType) -> StatsSet { StatsSet::from(HashMap::from([ - (Stat::Min, Scalar::null(dtype)), - (Stat::Max, Scalar::null(dtype)), + (Stat::Min, Scalar::null(dtype.clone())), + (Stat::Max, Scalar::null(dtype.clone())), (Stat::IsConstant, true.into()), (Stat::IsSorted, true.into()), (Stat::IsStrictSorted, (len < 2).into()), @@ -124,6 +124,9 @@ impl<'a> VarBinAccumulator<'a> { #[cfg(test)] mod test { + use std::ops::Deref; + + use vortex_buffer::{Buffer, BufferString}; use vortex_dtype::{DType, Nullability}; use crate::array::varbin::{OwnedVarBinArray, VarBinArray}; @@ -140,12 +143,12 @@ mod test { fn utf8_stats() { let arr = array(DType::Utf8(Nullability::NonNullable)); assert_eq!( - arr.statistics().compute_min::().unwrap(), - "hello world".to_owned() + arr.statistics().compute_min::().unwrap(), + BufferString::from("hello world".to_string()) ); assert_eq!( - arr.statistics().compute_max::().unwrap(), - "hello world this is a long string".to_owned() + arr.statistics().compute_max::().unwrap(), + BufferString::from("hello world this is a long string".to_string()) ); assert_eq!(arr.statistics().compute_run_count().unwrap(), 2); assert!(!arr.statistics().compute_is_constant().unwrap()); @@ -156,12 +159,12 @@ mod test { fn binary_stats() { let arr = array(DType::Binary(Nullability::NonNullable)); assert_eq!( - arr.statistics().compute_min::>().unwrap(), - "hello world".as_bytes().to_vec() + arr.statistics().compute_min::().unwrap().deref(), + "hello world".as_bytes() ); assert_eq!( - arr.statistics().compute_max::>().unwrap(), - "hello world this is a long string".as_bytes().to_vec() + arr.statistics().compute_max::().unwrap().deref(), + "hello world this is a long string".as_bytes() ); assert_eq!(arr.statistics().compute_run_count().unwrap(), 2); assert!(!arr.statistics().compute_is_constant().unwrap()); @@ -180,12 +183,12 @@ mod test { DType::Utf8(Nullability::Nullable), ); assert_eq!( - array.statistics().compute_min::().unwrap(), - "hello world".to_owned() + array.statistics().compute_min::().unwrap(), + BufferString::from("hello world".to_string()) ); assert_eq!( - array.statistics().compute_max::().unwrap(), - "hello world this is a long string".to_owned() + array.statistics().compute_max::().unwrap(), + BufferString::from("hello world this is a long string".to_string()) ); } diff --git a/vortex-array/src/array/varbinview/compute.rs b/vortex-array/src/array/varbinview/compute.rs index 30668059db..8151637221 100644 --- a/vortex-array/src/array/varbinview/compute.rs +++ b/vortex-array/src/array/varbinview/compute.rs @@ -38,7 +38,7 @@ impl ScalarAtFn for VarBinViewArray<'_> { self.bytes_at(index) .map(|bytes| varbin_scalar(bytes, self.dtype())) } else { - Ok(Scalar::null(self.dtype())) + Ok(Scalar::null(self.dtype().clone())) } } } diff --git a/vortex-array/src/arrow/array.rs b/vortex-array/src/arrow/array.rs index c93cd9dc34..92e466e6c7 100644 --- a/vortex-array/src/arrow/array.rs +++ b/vortex-array/src/arrow/array.rs @@ -22,7 +22,7 @@ use arrow_schema::{DataType, TimeUnit}; use itertools::Itertools; use vortex_dtype::DType; use vortex_dtype::NativePType; -use vortex_scalar::NullScalar; +use vortex_scalar::Scalar; use crate::array::bool::BoolArray; use crate::array::constant::ConstantArray; @@ -197,7 +197,7 @@ impl FromArrowArray<&ArrowStructArray> for ArrayData { impl FromArrowArray<&ArrowNullArray> for ArrayData { fn from_arrow(value: &ArrowNullArray, nullable: bool) -> Self { assert!(nullable); - ConstantArray::new(NullScalar::new(), value.len()).into_array_data() + ConstantArray::new(Scalar::null(DType::Null), value.len()).into_array_data() } } diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index f8c60bc1b2..4dd328fd04 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use arrow_schema::TimeUnit as ArrowTimeUnit; use arrow_schema::{DataType, Field, SchemaRef}; use itertools::Itertools; @@ -81,7 +83,7 @@ impl FromArrowType<&Field> for DType { // DataType::Time32(u) => localtime(u.into(), IntWidth::_32, nullability), // DataType::Time64(u) => localtime(u.into(), IntWidth::_64, nullability), DataType::List(e) | DataType::LargeList(e) => { - List(Box::new(DType::from_arrow(e.as_ref())), nullability) + List(Arc::new(DType::from_arrow(e.as_ref())), nullability) } DataType::Struct(f) => Struct( StructDType::new( diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index d3c3b9b824..46e9fad434 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -5,7 +5,7 @@ use enum_iterator::Sequence; pub use statsset::*; use vortex_dtype::{DType, NativePType}; use vortex_error::{VortexError, VortexResult}; -use vortex_scalar::{ListScalarVec, Scalar}; +use vortex_scalar::Scalar; mod statsset; @@ -137,13 +137,13 @@ impl dyn Statistics + '_ { Ok(res.expect("Result should have been populated by previous call")) } - pub fn compute_as_cast>( + pub fn compute_as_cast TryFrom<&'a Scalar, Error = VortexError>>( &self, stat: Stat, ) -> VortexResult { let mut res: Option = None; self.with_computed_stat_value(stat, &mut |s| { - res = Some(U::try_from(s.cast(&DType::from(U::PTYPE))?)?); + res = Some(U::try_from(s.cast(&DType::from(U::PTYPE))?.as_ref())?); Ok(()) })?; Ok(res.expect("Result should have been populated by previous call")) @@ -186,12 +186,10 @@ impl dyn Statistics + '_ { } pub fn compute_bit_width_freq(&self) -> VortexResult> { - self.compute_as::>(Stat::BitWidthFreq) - .map(|s| s.0) + self.compute_as::>(Stat::BitWidthFreq) } pub fn compute_trailing_zero_freq(&self) -> VortexResult> { - self.compute_as::>(Stat::TrailingZeroFreq) - .map(|s| s.0) + self.compute_as::>(Stat::TrailingZeroFreq) } } diff --git a/vortex-array/src/stats/statsset.rs b/vortex-array/src/stats/statsset.rs index b853ec53ba..efd5988dff 100644 --- a/vortex-array/src/stats/statsset.rs +++ b/vortex-array/src/stats/statsset.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use enum_iterator::all; use itertools::Itertools; use vortex_error::VortexError; -use vortex_scalar::{ListScalarVec, Scalar}; +use vortex_scalar::Scalar; use crate::stats::Stat; @@ -156,19 +156,16 @@ impl StatsSet { fn merge_freq_stat(&mut self, other: &Self, stat: Stat) { match self.values.entry(stat) { Entry::Occupied(mut e) => { - if let Some(other_value) = other.get_as::>(stat) { + if let Some(other_value) = other.get_as::>(stat) { // TODO(robert): Avoid the copy here. We could e.get_mut() but need to figure out casting - let self_value: ListScalarVec = e.get().try_into().unwrap(); + let self_value: Vec = e.get().try_into().unwrap(); e.insert( - ListScalarVec( - self_value - .0 - .iter() - .zip_eq(other_value.0.iter()) - .map(|(s, o)| *s + *o) - .collect::>(), - ) - .into(), + self_value + .iter() + .zip_eq(other_value.iter()) + .map(|(s, o)| *s + *o) + .collect::>() + .into(), ); } } diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 998dd93a09..200d458e94 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -96,7 +96,7 @@ impl<'v> Validity<'v> { match self { Validity::NonNullable | Validity::AllValid => true, Validity::AllInvalid => false, - Validity::Array(a) => scalar_at(a, index).unwrap().try_into().unwrap(), + Validity::Array(a) => bool::try_from(&scalar_at(a, index).unwrap()).unwrap(), } } diff --git a/vortex-buffer/Cargo.toml b/vortex-buffer/Cargo.toml index b5f2f98033..db1d71e79f 100644 --- a/vortex-buffer/Cargo.toml +++ b/vortex-buffer/Cargo.toml @@ -13,6 +13,7 @@ rust-version.workspace = true [dependencies] arrow-buffer = { workspace = true } bytes = { workspace = true } +flexbuffers = { workspace = true, optional = true } vortex-dtype = { path = "../vortex-dtype" } [lints] diff --git a/vortex-buffer/src/flexbuffers.rs b/vortex-buffer/src/flexbuffers.rs new file mode 100644 index 0000000000..885c407ea5 --- /dev/null +++ b/vortex-buffer/src/flexbuffers.rs @@ -0,0 +1,23 @@ +#![cfg(feature = "flexbuffers")] +use std::ops::Range; +use std::str::Utf8Error; + +use crate::string::BufferString; +use crate::Buffer; + +impl flexbuffers::Buffer for Buffer { + type BufferString = BufferString; + + fn slice(&self, range: Range) -> Option { + // TODO(ngates): bounds-check and return None? + Some(Buffer::slice(self, range)) + } + + fn empty() -> Self { + Buffer::from(vec![]) + } + + fn buffer_str(&self) -> Result { + BufferString::try_from(self.clone()) + } +} diff --git a/vortex-buffer/src/lib.rs b/vortex-buffer/src/lib.rs index c0415e73d5..4c9f112d78 100644 --- a/vortex-buffer/src/lib.rs +++ b/vortex-buffer/src/lib.rs @@ -1,6 +1,11 @@ -use std::ops::Deref; +mod flexbuffers; +mod string; + +use std::cmp::Ordering; +use std::ops::{Deref, Range}; use arrow_buffer::Buffer as ArrowBuffer; +pub use string::*; use vortex_dtype::{match_each_native_ptype, NativePType}; #[derive(Debug, Clone)] @@ -28,6 +33,15 @@ impl Buffer { } } + pub fn slice(&self, range: Range) -> Self { + match self { + Buffer::Arrow(b) => { + Buffer::Arrow(b.slice_with_length(range.start, range.end - range.start)) + } + Buffer::Bytes(b) => Buffer::Bytes(b.slice(range)), + } + } + pub fn typed_data(&self) -> &[T] { match self { Buffer::Arrow(buffer) => unsafe { @@ -78,6 +92,13 @@ impl AsRef<[u8]> for Buffer { } } +impl From<&[u8]> for Buffer { + fn from(value: &[u8]) -> Self { + // We prefer Arrow since it retains mutability + Buffer::Arrow(ArrowBuffer::from(value)) + } +} + impl From> for Buffer { fn from(value: Vec) -> Self { // We prefer Arrow since it retains mutability @@ -98,3 +119,9 @@ impl PartialEq for Buffer { } impl Eq for Buffer {} + +impl PartialOrd for Buffer { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_ref().partial_cmp(other.as_ref()) + } +} diff --git a/vortex-buffer/src/string.rs b/vortex-buffer/src/string.rs new file mode 100644 index 0000000000..b0023490eb --- /dev/null +++ b/vortex-buffer/src/string.rs @@ -0,0 +1,52 @@ +use std::ops::Deref; +use std::str::Utf8Error; + +use crate::Buffer; + +/// A wrapper around a `Buffer` that guarantees that the buffer contains valid UTF-8. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] +pub struct BufferString(Buffer); + +impl BufferString { + /// Creates a new `BufferString` from a `Buffer`. + /// + /// # Safety + /// Assumes that the buffer contains valid UTF-8. + pub unsafe fn new_unchecked(buffer: Buffer) -> Self { + Self(buffer) + } + + pub fn as_str(&self) -> &str { + // SAFETY: We have already validated that the buffer is valid UTF-8 + unsafe { std::str::from_utf8_unchecked(self.0.as_ref()) } + } +} + +impl From for Buffer { + fn from(value: BufferString) -> Self { + value.0 + } +} + +impl From for BufferString { + fn from(value: String) -> Self { + BufferString(Buffer::from(value.into_bytes())) + } +} + +impl TryFrom for BufferString { + type Error = Utf8Error; + + fn try_from(value: Buffer) -> Result { + let _ = std::str::from_utf8(value.as_ref())?; + Ok(Self(value)) + } +} + +impl Deref for BufferString { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index 4f9059391c..ae516537f2 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -11,10 +11,9 @@ use vortex::compress::{CompressConfig, Compressor, EncodingCompression}; use vortex::stats::ArrayStatistics; use vortex::validity::Validity; use vortex::{Array, ArrayDType, ArrayDef, IntoArray, OwnedArray, ToArray}; -use vortex_dtype::NativePType; use vortex_dtype::{match_each_native_ptype, DType}; +use vortex_dtype::{NativePType, ToBytes}; use vortex_error::VortexResult; -use vortex_scalar::AsBytes; use crate::dict::{DictArray, DictEncoding}; @@ -93,19 +92,19 @@ impl EncodingCompression for DictEncoding { #[derive(Debug)] struct Value(T); -impl Hash for Value { +impl Hash for Value { fn hash(&self, state: &mut H) { - self.0.as_bytes().hash(state) + self.0.to_le_bytes().hash(state) } } -impl PartialEq for Value { +impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { - self.0.as_bytes().eq(other.0.as_bytes()) + self.0.to_le_bytes().eq(other.0.to_le_bytes()) } } -impl Eq for Value {} +impl Eq for Value {} /// Dictionary encode primitive array with given PType. /// Null values in the original array are encoded in the dictionary. @@ -255,7 +254,9 @@ mod test { use vortex::array::varbin::VarBinArray; use vortex::compute::scalar_at::scalar_at; use vortex::ToArray; - use vortex_scalar::PrimitiveScalar; + use vortex_dtype::Nullability::Nullable; + use vortex_dtype::{DType, PType}; + use vortex_scalar::Scalar; use crate::compress::{dict_encode_typed_primitive, dict_encode_varbin}; @@ -286,15 +287,15 @@ mod test { ); assert_eq!( scalar_at(&values.to_array(), 0).unwrap(), - PrimitiveScalar::nullable::(None).into() + Scalar::null(DType::Primitive(PType::I32, Nullable)) ); assert_eq!( scalar_at(&values.to_array(), 1).unwrap(), - PrimitiveScalar::nullable(Some(1)).into() + Scalar::primitive(1, Nullable) ); assert_eq!( scalar_at(&values.to_array(), 2).unwrap(), - PrimitiveScalar::nullable(Some(3)).into() + Scalar::primitive(3, Nullable) ); } diff --git a/vortex-dict/src/compute.rs b/vortex-dict/src/compute.rs index ef26882dd9..2c5a9c8153 100644 --- a/vortex-dict/src/compute.rs +++ b/vortex-dict/src/compute.rs @@ -24,7 +24,7 @@ impl ArrayCompute for DictArray<'_> { impl ScalarAtFn for DictArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { - let dict_index: usize = scalar_at(&self.codes(), index)?.try_into()?; + let dict_index: usize = scalar_at(&self.codes(), index)?.as_ref().try_into()?; scalar_at(&self.values(), dict_index) } } diff --git a/vortex-dict/src/dict.rs b/vortex-dict/src/dict.rs index b048629faa..2f5ebba372 100644 --- a/vortex-dict/src/dict.rs +++ b/vortex-dict/src/dict.rs @@ -56,7 +56,11 @@ impl ArrayFlatten for DictArray<'_> { impl ArrayValidity for DictArray<'_> { fn is_valid(&self, index: usize) -> bool { - let values_index = scalar_at(&self.codes(), index).unwrap().try_into().unwrap(); + let values_index = scalar_at(&self.codes(), index) + .unwrap() + .as_ref() + .try_into() + .unwrap(); self.values().with_dyn(|a| a.is_valid(values_index)) } diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index 3dcdc20299..5c38ea9f38 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] flatbuffers = { workspace = true, optional = true } -half = { workspace = true } +half = { workspace = true, features = ["num-traits"] } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true, optional = true } @@ -34,3 +34,10 @@ build-vortex = { path = "../build-vortex" } [lints] workspace = true + +[features] +# Uncomment for improved IntelliJ support +# default = ["flatbuffers", "proto", "serde"] +flatbuffers = ["dep:flatbuffers"] +proto = ["dep:prost"] +serde = ["dep:serde"] \ No newline at end of file diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 9da2cbf363..f767b26562 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -21,7 +21,7 @@ pub enum DType { Utf8(Nullability), Binary(Nullability), Struct(StructDType, Nullability), - List(Box, Nullability), + List(Arc, Nullability), Extension(ExtDType, Nullability), } @@ -127,6 +127,10 @@ impl StructDType { &self.names } + pub fn find_name(&self, name: &str) -> Option { + self.names.iter().position(|n| n.as_ref() == name) + } + pub fn dtypes(&self) -> &Arc<[DType]> { &self.dtypes } diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index 44f70828cb..179d9d3a61 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg(target_endian = "little")] + pub use dtype::*; pub use extension::*; pub use half; @@ -9,15 +11,18 @@ mod nullability; mod ptype; mod serde; -#[cfg(feature = "prost")] +#[cfg(feature = "proto")] pub mod proto { - include!(concat!(env!("OUT_DIR"), "/proto/vortex.dtype.rs")); + pub mod dtype { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.dtype.rs")); + } } #[cfg(feature = "flatbuffers")] pub mod flatbuffers { #[allow(unused_imports)] #[allow(dead_code)] + #[allow(dead_code)] #[allow(clippy::all)] #[allow(non_camel_case_types)] mod generated { diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index c2318001c7..94cc99c34d 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -1,4 +1,5 @@ use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; use std::panic::RefUnwindSafe; use num_traits::{FromPrimitive, Num, NumCast}; @@ -41,6 +42,8 @@ pub trait NativePType: + Num + NumCast + FromPrimitive + + ToBytes + + TryFromBytes { const PTYPE: PType; } @@ -246,3 +249,44 @@ impl From for DType { Primitive(item, NonNullable) } } + +pub trait ToBytes: Sized { + fn to_le_bytes(&self) -> &[u8]; +} + +pub trait TryFromBytes: Sized { + fn try_from_le_bytes(bytes: &[u8]) -> VortexResult; +} + +macro_rules! try_from_bytes { + ($T:ty) => { + impl ToBytes for $T { + #[inline] + #[allow(clippy::size_of_in_element_count)] + fn to_le_bytes(&self) -> &[u8] { + // NOTE(ngates): this assumes the platform is little-endian. Currently enforced + // with a flag cfg(target_endian = "little") + let raw_ptr = self as *const $T as *const u8; + unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of::<$T>()) } + } + } + + impl TryFromBytes for $T { + fn try_from_le_bytes(bytes: &[u8]) -> VortexResult { + Ok(<$T>::from_le_bytes(bytes.try_into()?)) + } + } + }; +} + +try_from_bytes!(u8); +try_from_bytes!(u16); +try_from_bytes!(u32); +try_from_bytes!(u64); +try_from_bytes!(i8); +try_from_bytes!(i16); +try_from_bytes!(i32); +try_from_bytes!(i64); +try_from_bytes!(f16); +try_from_bytes!(f32); +try_from_bytes!(f64); diff --git a/vortex-dtype/src/serde/flatbuffers.rs b/vortex-dtype/src/serde/flatbuffers.rs index add24ad831..affd425de0 100644 --- a/vortex-dtype/src/serde/flatbuffers.rs +++ b/vortex-dtype/src/serde/flatbuffers.rs @@ -1,5 +1,7 @@ #![cfg(feature = "flatbuffers")] +use std::sync::Arc; + use flatbuffers::{FlatBufferBuilder, WIPOffset}; use itertools::Itertools; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; @@ -30,7 +32,7 @@ impl TryFrom> for DType { let fb_list = fb.type__as_list().unwrap(); let element_dtype = DType::try_from(fb_list.element_type().unwrap())?; Ok(DType::List( - Box::new(element_dtype), + Arc::new(element_dtype), fb_list.nullable().into(), )) } @@ -220,6 +222,7 @@ impl TryFrom for PType { #[cfg(test)] mod test { + use std::sync::Arc; use flatbuffers::root; use vortex_flatbuffers::FlatBufferToBytes; @@ -242,7 +245,7 @@ mod test { roundtrip_dtype(DType::Binary(Nullability::NonNullable)); roundtrip_dtype(DType::Utf8(Nullability::NonNullable)); roundtrip_dtype(DType::List( - Box::new(DType::Primitive(PType::F32, Nullability::Nullable)), + Arc::new(DType::Primitive(PType::F32, Nullability::Nullable)), Nullability::NonNullable, )); roundtrip_dtype(DType::Struct( diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs index 0e49ded1cf..247b806c62 100644 --- a/vortex-dtype/src/serde/proto.rs +++ b/vortex-dtype/src/serde/proto.rs @@ -1,9 +1,11 @@ -#![cfg(feature = "prost")] +#![cfg(feature = "proto")] + +use std::sync::Arc; use vortex_error::{vortex_err, VortexError, VortexResult}; -use crate::proto::d_type::Type; -use crate::{proto as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; +use crate::proto::dtype::d_type::Type; +use crate::{proto::dtype as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; impl TryFrom<&pb::DType> for DType { type Error = VortexError; @@ -38,7 +40,7 @@ impl TryFrom<&pb::DType> for DType { .ok_or_else(|| vortex_err!(InvalidSerde: "Invalid list element type"))? .as_ref() .try_into() - .map(Box::new)?, + .map(Arc::new)?, nullable, )) } diff --git a/vortex-error/src/lib.rs b/vortex-error/src/lib.rs index 0be97d9d47..e1fd36e529 100644 --- a/vortex-error/src/lib.rs +++ b/vortex-error/src/lib.rs @@ -101,6 +101,12 @@ pub enum VortexError { #[backtrace] io::Error, ), + #[error(transparent)] + Utf8Error( + #[from] + #[backtrace] + std::str::Utf8Error, + ), #[cfg(feature = "parquet")] #[error(transparent)] ParquetError( diff --git a/vortex-fastlanes/src/bitpacking/compress.rs b/vortex-fastlanes/src/bitpacking/compress.rs index bfa6c1ef7c..483aec0024 100644 --- a/vortex-fastlanes/src/bitpacking/compress.rs +++ b/vortex-fastlanes/src/bitpacking/compress.rs @@ -173,7 +173,7 @@ fn bitpack_patches( indices.into_array(), PrimitiveArray::from_vec(values, Validity::AllValid).into_array(), parray.len(), - Scalar::null(&parray.dtype().as_nullable()), + Scalar::null(parray.dtype().as_nullable()), ).unwrap().into_array() }) } @@ -358,6 +358,7 @@ fn count_exceptions(bit_width: usize, bit_width_freq: &[usize]) -> usize { mod test { use vortex::encoding::ArrayEncoding; use vortex::{Context, ToArray}; + use vortex_scalar::PrimitiveScalar; use super::*; @@ -407,13 +408,12 @@ mod test { .iter() .enumerate() .for_each(|(i, v)| { - let scalar_at: u16 = - if let Scalar::Primitive(pscalar) = unpack_single(&compressed, i).unwrap() { - pscalar.value().unwrap().try_into().unwrap() - } else { - panic!("expected u8 scalar") - }; - assert_eq!(scalar_at, *v); + let scalar = unpack_single(&compressed, i).unwrap(); + let scalar = PrimitiveScalar::try_from(&scalar) + .unwrap() + .typed_value::() + .unwrap(); + assert_eq!(scalar, *v); }); } } diff --git a/vortex-fastlanes/src/bitpacking/compute/mod.rs b/vortex-fastlanes/src/bitpacking/compute/mod.rs index dbb396ead3..0960fe9e29 100644 --- a/vortex-fastlanes/src/bitpacking/compute/mod.rs +++ b/vortex-fastlanes/src/bitpacking/compute/mod.rs @@ -59,7 +59,7 @@ impl TakeFn for BitPackedArray<'_> { Ok(primitive_patches.into_array()) } else { Ok( - ConstantArray::new(Scalar::null(&self.dtype().as_nullable()), indices.len()) + ConstantArray::new(Scalar::null(self.dtype().as_nullable()), indices.len()) .into_array(), ) }; @@ -236,11 +236,12 @@ mod test { .enumerate() .for_each(|(ti, i)| { assert_eq!( - u32::try_from(scalar_at(packed.array(), *i as usize).unwrap()).unwrap(), + u32::try_from(scalar_at(packed.array(), *i as usize).unwrap().as_ref()) + .unwrap(), values[*i as usize] ); assert_eq!( - u32::try_from(scalar_at(&taken, ti).unwrap()).unwrap(), + u32::try_from(scalar_at(&taken, ti).unwrap().as_ref()).unwrap(), values[*i as usize] ); }); @@ -258,7 +259,7 @@ mod test { values.iter().enumerate().for_each(|(i, v)| { assert_eq!( - u32::try_from(scalar_at(packed.array(), i).unwrap()).unwrap(), + u32::try_from(scalar_at(packed.array(), i).unwrap().as_ref()).unwrap(), *v ); }); diff --git a/vortex-fastlanes/src/for/compress.rs b/vortex-fastlanes/src/for/compress.rs index 3faf308be3..e59b95f79f 100644 --- a/vortex-fastlanes/src/for/compress.rs +++ b/vortex-fastlanes/src/for/compress.rs @@ -55,7 +55,7 @@ impl EncodingCompression for FoREncoding { if shift == <$T>::PTYPE.bit_width() as u8 { ConstantArray::new($T::default(), parray.len()).into_array() } else { - compress_primitive::<$T>(parray, shift, $T::try_from(min.clone())?).into_array() + compress_primitive::<$T>(parray, shift, $T::try_from(&min)?).into_array() } }); let for_like = like.map(|like_arr| FoRArray::try_from(like_arr).unwrap()); @@ -204,7 +204,15 @@ mod test { .iter() .enumerate() .for_each(|(i, v)| { - assert_eq!(*v, compressed.scalar_at(i).unwrap().try_into().unwrap()); + assert_eq!( + *v, + compressed + .scalar_at(i) + .unwrap() + .as_ref() + .try_into() + .unwrap() + ); }); } } diff --git a/vortex-fastlanes/src/for/compute.rs b/vortex-fastlanes/src/for/compute.rs index 0706a2a89a..59c3c1c95c 100644 --- a/vortex-fastlanes/src/for/compute.rs +++ b/vortex-fastlanes/src/for/compute.rs @@ -4,7 +4,7 @@ use vortex::compute::take::{take, TakeFn}; use vortex::compute::ArrayCompute; use vortex::{Array, IntoArray, OwnedArray}; use vortex_dtype::match_each_integer_ptype; -use vortex_error::VortexResult; +use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::{PrimitiveScalar, Scalar}; use crate::FoRArray; @@ -37,20 +37,20 @@ impl TakeFn for FoRArray<'_> { impl ScalarAtFn for FoRArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { let encoded_scalar = scalar_at(&self.encoded(), index)?; + let encoded = PrimitiveScalar::try_from(&encoded_scalar)?; + let reference = PrimitiveScalar::try_from(self.reference())?; - match (&encoded_scalar, self.reference()) { - (Scalar::Primitive(p), Scalar::Primitive(r)) => match p.value() { - None => Ok(encoded_scalar), - Some(pv) => match_each_integer_ptype!(pv.ptype(), |$P| { - use num_traits::WrappingAdd; - Ok(PrimitiveScalar::try_new::<$P>( - Some((p.typed_value::<$P>().unwrap() << self.shift()).wrapping_add(r.typed_value::<$P>().unwrap())), - p.dtype().nullability() - ).unwrap().into()) - }), - }, - _ => unreachable!("Reference and encoded values had different dtypes"), + if encoded.ptype() != reference.ptype() { + vortex_bail!("Reference and encoded values had different dtypes"); } + + match_each_integer_ptype!(encoded.ptype(), |$P| { + use num_traits::WrappingAdd; + Ok(Scalar::primitive::<$P>( + (encoded.typed_value::<$P>().unwrap() << self.shift()).wrapping_add(reference.typed_value::<$P>().unwrap()), + encoded.dtype().nullability() + )) + }) } } diff --git a/vortex-fastlanes/src/for/mod.rs b/vortex-fastlanes/src/for/mod.rs index 161b9e068b..c8844a84f4 100644 --- a/vortex-fastlanes/src/for/mod.rs +++ b/vortex-fastlanes/src/for/mod.rs @@ -84,6 +84,6 @@ impl ArrayTrait for FoRArray<'_> { } fn nbytes(&self) -> usize { - self.reference().nbytes() + self.encoded().nbytes() + self.encoded().nbytes() } } diff --git a/vortex-flatbuffers/src/lib.rs b/vortex-flatbuffers/src/lib.rs index e120c5248a..0f453d1d83 100644 --- a/vortex-flatbuffers/src/lib.rs +++ b/vortex-flatbuffers/src/lib.rs @@ -1,15 +1,25 @@ use std::io; use std::io::Write; -use flatbuffers::{FlatBufferBuilder, WIPOffset}; +use flatbuffers::{root, FlatBufferBuilder, Follow, InvalidFlatbuffer, Verifiable, WIPOffset}; pub trait FlatBufferRoot {} pub trait ReadFlatBuffer: Sized { - type Source<'a>; - type Error; + type Source<'a>: Verifiable + Follow<'a>; + type Error: From; - fn read_flatbuffer(fb: &Self::Source<'_>) -> Result; + fn read_flatbuffer<'buf>( + fb: & as Follow<'buf>>::Inner, + ) -> Result; + + fn read_flatbuffer_bytes<'buf>(bytes: &'buf [u8]) -> Result + where + ::Source<'buf>: 'buf, + { + let fb = root::>(bytes)?; + Self::read_flatbuffer(&fb) + } } pub trait WriteFlatBuffer { diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index e80b67d012..648c3d5296 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -22,7 +22,7 @@ pub struct REEMetadata { impl REEArray<'_> { pub fn try_new(ends: Array, values: Array, validity: Validity) -> VortexResult { - let length: usize = scalar_at(&ends, ends.len() - 1)?.try_into()?; + let length: usize = scalar_at(&ends, ends.len() - 1)?.as_ref().try_into()?; Self::with_offset_and_size(ends, values, validity, length, 0) } diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index 0f18be7574..b2aae9bd7f 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -13,10 +13,13 @@ rust-version = { workspace = true } [dependencies] flatbuffers = { workspace = true, optional = true } -flexbuffers = { workspace = true } +flexbuffers = { workspace = true, optional = true } itertools = { workspace = true } +paste = { workspace = true } +prost = { workspace = true, optional = true } +prost-types = { workspace = true, optional = true } num-traits = { workspace = true } -serde = { workspace = true, optional = true } +serde = { workspace = true, optional = true, features = ["rc"] } vortex-buffer = { path = "../vortex-buffer" } vortex-dtype = { path = "../vortex-dtype" } vortex-error = { path = "../vortex-error" } @@ -27,3 +30,20 @@ build-vortex = { path = "../build-vortex" } [lints] workspace = true + +[features] +# Uncomment for improved IntelliJ support +# default = ["flatbuffers", "proto", "serde"] +flatbuffers = [ + "dep:flatbuffers", + "dep:flexbuffers", + "dep:serde", + "vortex-buffer/flexbuffers", + "vortex-error/flexbuffers", +] +proto = [ + "dep:prost", + "dep:prost-types", + "vortex-dtype/proto", +] +serde = ["dep:serde", "serde/derive"] diff --git a/vortex-scalar/flatbuffers/scalar.fbs b/vortex-scalar/flatbuffers/scalar.fbs index eff54ee082..90519fd0da 100644 --- a/vortex-scalar/flatbuffers/scalar.fbs +++ b/vortex-scalar/flatbuffers/scalar.fbs @@ -2,57 +2,9 @@ include "vortex-dtype/flatbuffers/dtype.fbs"; namespace vortex.scalar; -table Binary { - value: [ubyte]; -} - -table Bool { - value: bool; -} - -table List { - value: [Scalar]; -} - -table Null { -} - -table Primitive { - ptype: dtype.PType; - // TODO(ngates): this isn't an ideal way to store the bytes. - bytes: [ubyte]; -} - -table Struct_ { - names: [string]; - scalars: [Scalar]; -} - -table UTF8 { - value: string; -} - -table Extension { - id: string; - metadata: [ubyte]; - value: Scalar; -} - -union Type { - Binary, - Bool, - List, - Null, - Primitive, - Struct_, - UTF8, - Extension, -} - -// TODO(ngates): separate out ScalarValue from Scalar, even in-memory, so we can avoid duplicating dtype information (e.g. Struct field names). table Scalar { - type: Type; - nullability: bool; + dtype: vortex.dtype.DType (required); + value: [ubyte] (flexbuffer); } root_type Scalar; \ No newline at end of file diff --git a/vortex-scalar/proto/scalar.proto b/vortex-scalar/proto/scalar.proto new file mode 100644 index 0000000000..308ec0670a --- /dev/null +++ b/vortex-scalar/proto/scalar.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; +import "vortex-dtype/proto/dtype.proto"; + +package vortex.scalar; + +message Scalar { + vortex.dtype.DType dtype = 1; + oneof value { + bool bool = 2; + uint32 uint32 = 3; + uint64 uint64 = 4; + sint32 sint32 = 5; + sint64 sint64 = 6; + float float = 7; + double double = 8; + bytes bytes = 9; + string string = 10; + } +} \ No newline at end of file diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index 385aa3b133..8d9f811aca 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -1,74 +1,59 @@ -use std::fmt::{Display, Formatter}; - -use vortex_dtype::DType; -use vortex_dtype::Nullability::{NonNullable, Nullable}; +use vortex_buffer::Buffer; +use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub type BinaryScalar = ScalarValue>; +pub struct BinaryScalar<'a> { + dtype: &'a DType, + value: Option, +} -impl BinaryScalar { +impl<'a> BinaryScalar<'a> { #[inline] - pub fn dtype(&self) -> &DType { - match self.nullability() { - NonNullable => &DType::Binary(NonNullable), - Nullable => &DType::Binary(Nullable), - } + pub fn dtype(&self) -> &'a DType { + self.dtype } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn value(&self) -> Option { + self.value.as_ref().cloned() } - pub fn nbytes(&self) -> usize { - self.value().map(|s| s.len()).unwrap_or(1) - } -} - -impl From> for Scalar { - fn from(value: Vec) -> Self { - BinaryScalar::some(value).into() + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() } } -impl From<&[u8]> for Scalar { - fn from(value: &[u8]) -> Self { - BinaryScalar::some(value.to_vec()).into() +impl Scalar { + pub fn binary(buffer: Buffer, nullability: Nullability) -> Self { + Scalar { + dtype: DType::Binary(nullability), + value: ScalarValue::Buffer(buffer), + } } } -impl TryFrom for Vec { +impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> { type Error = VortexError; - fn try_from(value: Scalar) -> VortexResult { - let Scalar::Binary(b) = value else { - vortex_bail!(MismatchedTypes: "binary", value.dtype()); - }; - b.into_value() - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::Binary(_)) { + vortex_bail!("Expected binary scalar, found {}", value.dtype()) + } + Ok(Self { + dtype: value.dtype(), + value: value.value.as_buffer()?, + }) } } -impl TryFrom<&Scalar> for Vec { +impl<'a> TryFrom<&'a Scalar> for Buffer { type Error = VortexError; - fn try_from(value: &Scalar) -> VortexResult { - let Scalar::Binary(b) = value else { - vortex_bail!(MismatchedTypes: "binary", value.dtype()); - }; - b.value() - .cloned() + fn try_from(value: &'a Scalar) -> VortexResult { + BinaryScalar::try_from(value)? + .value() .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } - -impl Display for BinaryScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.value() { - None => write!(f, "bytes[none]"), - Some(b) => write!(f, "bytes[{}]", b.len()), - } - } -} diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index c069750487..cb6da48a51 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -1,71 +1,74 @@ -use std::fmt::{Display, Formatter}; - +use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub type BoolScalar = ScalarValue; +pub struct BoolScalar<'a> { + dtype: &'a DType, + value: Option, +} -impl BoolScalar { +impl<'a> BoolScalar<'a> { #[inline] - pub fn dtype(&self) -> &DType { - match self.nullability() { - Nullability::NonNullable => &DType::Bool(Nullability::NonNullable), - Nullability::Nullable => &DType::Bool(Nullability::Nullable), - } + pub fn dtype(&self) -> &'a DType { + self.dtype + } + + pub fn value(&self) -> Option { + self.value } pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { - DType::Bool(_) => Ok(self.clone().into()), - _ => Err(vortex_err!(MismatchedTypes: "bool", dtype)), + DType::Bool(_) => Ok(Scalar::bool( + self.value().ok_or_else(|| vortex_err!("not a bool"))?, + dtype.nullability(), + )), + _ => vortex_bail!("Can't cast {} to bool", dtype), } } - - pub fn nbytes(&self) -> usize { - 1 - } } -impl From for Scalar { - #[inline] - fn from(value: bool) -> Self { - BoolScalar::some(value).into() +impl Scalar { + pub fn bool(value: bool, nullability: Nullability) -> Self { + Scalar { + dtype: DType::Bool(nullability), + value: ScalarValue::Bool(value), + } } } -impl TryFrom<&Scalar> for bool { +impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> { type Error = VortexError; - fn try_from(value: &Scalar) -> VortexResult { - let Scalar::Bool(b) = value else { - vortex_bail!(MismatchedTypes: "bool", value.dtype()); - }; - b.value() - .cloned() - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::Bool(_)) { + vortex_bail!("Expected bool scalar, found {}", value.dtype()) + } + Ok(Self { + dtype: value.dtype(), + value: value.value.as_bool()?, + }) } } -impl TryFrom for bool { +impl TryFrom<&Scalar> for bool { type Error = VortexError; - fn try_from(value: Scalar) -> VortexResult { - let Scalar::Bool(b) = value else { - vortex_bail!(MismatchedTypes: "bool", value.dtype()); - }; - b.into_value() + fn try_from(value: &Scalar) -> VortexResult { + BoolScalar::try_from(value)? + .value() .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } -impl Display for BoolScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.value() { - None => write!(f, "null"), - Some(b) => Display::fmt(&b, f), +impl From for Scalar { + fn from(value: bool) -> Self { + Scalar { + dtype: DType::Bool(NonNullable), + value: ScalarValue::Bool(value), } } } @@ -77,6 +80,6 @@ mod test { #[test] fn into_from() { let scalar: Scalar = false.into(); - assert!(!bool::try_from(scalar).unwrap()); + assert!(!bool::try_from(&scalar).unwrap()); } } diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs new file mode 100644 index 0000000000..b04da0e500 --- /dev/null +++ b/vortex-scalar/src/display.rs @@ -0,0 +1,41 @@ +use std::fmt::{Display, Formatter}; + +use vortex_dtype::{match_each_native_ptype, DType}; + +use crate::bool::BoolScalar; +use crate::primitive::PrimitiveScalar; +use crate::Scalar; + +impl Display for Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.dtype() { + DType::Null => write!(f, "null"), + DType::Bool(_) => match BoolScalar::try_from(self).expect("bool").value() { + None => write!(f, "null"), + Some(b) => write!(f, "{}", b), + }, + DType::Primitive(ptype, _) => match_each_native_ptype!(ptype, |$T| { + match PrimitiveScalar::try_from(self).expect("primitive").typed_value::<$T>() { + None => write!(f, "null"), + Some(v) => write!(f, "{}", v), + } + }), + DType::Utf8(_) => todo!(), + DType::Binary(_) => todo!(), + DType::Struct(..) => todo!(), + DType::List(..) => todo!(), + DType::Extension(..) => todo!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::Scalar; + + #[test] + fn display() { + let scalar = Scalar::from(false); + assert_eq!(format!("{}", scalar), "false"); + } +} diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index 18bb247927..73cc9e9461 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -1,97 +1,52 @@ -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability}; -use vortex_error::{vortex_bail, VortexResult}; +use vortex_dtype::{DType, ExtDType}; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use crate::value::ScalarValue; use crate::Scalar; -#[derive(Debug, Clone, PartialEq)] -pub struct ExtScalar { - dtype: DType, - value: Option>, +pub struct ExtScalar<'a> { + dtype: &'a DType, + // TODO(ngates): we may need to serialize the value's dtype too so we can pull + // it out as a scalar. + value: &'a ScalarValue, } -impl ExtScalar { - pub fn try_new( - ext: ExtDType, - nullability: Nullability, - value: Option, - ) -> VortexResult { - if value.is_none() && nullability == Nullability::NonNullable { - vortex_bail!("Value cannot be None for NonNullable Scalar"); - } - - // Throw away the inner scalar if it is null. - let value = value - .and_then(|scalar| if scalar.is_null() { None } else { Some(scalar) }) - .map(Arc::new); - - Ok(Self { - dtype: DType::Extension(ext, nullability), - value, - }) - } - - pub fn null(ext: ExtDType) -> Self { - Self::try_new(ext, Nullability::Nullable, None).expect("Incorrect nullability check") - } - +impl<'a> ExtScalar<'a> { #[inline] - pub fn id(&self) -> &ExtID { - self.ext_dtype().id() + pub fn dtype(&self) -> &'a DType { + self.dtype } - #[inline] - pub fn metadata(&self) -> Option<&ExtMetadata> { - self.ext_dtype().metadata() - } - - #[inline] - pub fn ext_dtype(&self) -> &ExtDType { - let DType::Extension(ext, _) = &self.dtype else { - unreachable!() - }; - ext - } - - #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype - } - - pub fn value(&self) -> Option<&Arc> { - self.value.as_ref() + /// Returns the stored value of the extension scalar. + pub fn value(&self) -> &'a ScalarValue { + self.value } pub fn cast(&self, _dtype: &DType) -> VortexResult { todo!() } - - pub fn nbytes(&self) -> usize { - todo!() - } } -impl PartialOrd for ExtScalar { - fn partial_cmp(&self, other: &Self) -> Option { - if let (Some(s), Some(o)) = (self.value(), other.value()) { - s.partial_cmp(o) - } else { - None +impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { + type Error = VortexError; + + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::Extension(..)) { + vortex_bail!("Expected extension scalar, found {}", value.dtype()) } + + Ok(Self { + dtype: value.dtype(), + value: &value.value, + }) } } -impl Display for ExtScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{} ({})", - self.value() - .map(|s| format!("{}", s)) - .unwrap_or_else(|| "".to_string()), - self.dtype - ) +impl Scalar { + pub fn extension(ext_dtype: ExtDType, storage: Scalar) -> Self { + Scalar { + dtype: DType::Extension(ext_dtype, storage.dtype().nullability()), + value: storage.value, + } } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 429640e477..4a359b8ea2 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -1,28 +1,39 @@ -use std::fmt::{Debug, Display, Formatter}; +use std::cmp::Ordering; -pub use binary::*; -pub use bool::*; -pub use extension::*; -pub use list::*; -pub use null::*; -pub use primitive::*; -pub use struct_::*; -pub use utf8::*; -use vortex_dtype::NativePType; -use vortex_dtype::{DType, Nullability}; -use vortex_error::VortexResult; +use vortex_dtype::DType; mod binary; mod bool; +mod display; mod extension; mod list; -mod null; mod primitive; +mod pvalue; mod serde; mod struct_; mod utf8; mod value; +pub use binary::*; +pub use bool::*; +pub use extension::*; +pub use list::*; +pub use primitive::*; +pub use struct_::*; +pub use utf8::*; +pub use value::*; +use vortex_error::{vortex_bail, VortexResult}; + +#[cfg(feature = "proto")] +pub mod proto { + #[allow(clippy::module_inception)] + pub mod scalar { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.scalar.rs")); + } + + pub use vortex_dtype::proto::dtype; +} + #[cfg(feature = "flatbuffers")] pub mod flatbuffers { pub use gen_scalar::vortex::*; @@ -43,127 +54,74 @@ pub mod flatbuffers { } } -#[derive(Debug, Clone, PartialEq, PartialOrd)] -pub enum Scalar { - Binary(BinaryScalar), - Bool(BoolScalar), - List(ListScalar), - Null(NullScalar), - Primitive(PrimitiveScalar), - Struct(StructScalar), - Utf8(Utf8Scalar), - Extension(ExtScalar), -} - -macro_rules! impls_for_scalars { - ($variant:tt, $E:ty) => { - impl From<$E> for Scalar { - fn from(arr: $E) -> Self { - Self::$variant(arr) - } - } - }; -} - -impls_for_scalars!(Binary, BinaryScalar); -impls_for_scalars!(Bool, BoolScalar); -impls_for_scalars!(List, ListScalar); -impls_for_scalars!(Null, NullScalar); -impls_for_scalars!(Primitive, PrimitiveScalar); -impls_for_scalars!(Struct, StructScalar); -impls_for_scalars!(Utf8, Utf8Scalar); -impls_for_scalars!(Extension, ExtScalar); - -macro_rules! match_each_scalar { - ($self:expr, | $_:tt $scalar:ident | $($body:tt)*) => ({ - macro_rules! __with_scalar__ {( $_ $scalar:ident ) => ( $($body)* )} - match $self { - Scalar::Binary(s) => __with_scalar__! { s }, - Scalar::Bool(s) => __with_scalar__! { s }, - Scalar::List(s) => __with_scalar__! { s }, - Scalar::Null(s) => __with_scalar__! { s }, - Scalar::Primitive(s) => __with_scalar__! { s }, - Scalar::Struct(s) => __with_scalar__! { s }, - Scalar::Utf8(s) => __with_scalar__! { s }, - Scalar::Extension(s) => __with_scalar__! { s }, - } - }) +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] +pub struct Scalar { + pub(crate) dtype: DType, + pub(crate) value: ScalarValue, } impl Scalar { pub fn dtype(&self) -> &DType { - match_each_scalar! { self, |$s| $s.dtype() } + &self.dtype } - pub fn cast(&self, dtype: &DType) -> VortexResult { - match_each_scalar! { self, |$s| $s.cast(dtype) } + pub fn into_value(self) -> ScalarValue { + self.value } - pub fn nbytes(&self) -> usize { - match_each_scalar! { self, |$s| $s.nbytes() } + pub fn is_null(&self) -> bool { + self.value.is_null() } - pub fn nullability(&self) -> Nullability { - self.dtype().nullability() + pub fn null(dtype: DType) -> Self { + assert!(dtype.is_nullable()); + Self { + dtype, + value: ScalarValue::Null, + } } - pub fn is_null(&self) -> bool { - match self { - Scalar::Binary(b) => b.value().is_none(), - Scalar::Bool(b) => b.value().is_none(), - Scalar::List(l) => l.values().is_none(), - Scalar::Null(_) => true, - Scalar::Primitive(p) => p.value().is_none(), - // FIXME(ngates): can't have a null struct? - Scalar::Struct(_) => false, - Scalar::Utf8(u) => u.value().is_none(), - Scalar::Extension(e) => e.value().is_none(), + pub fn cast(&self, dtype: &DType) -> VortexResult { + if self.dtype() == dtype { + return Ok(self.clone()); + } + + if self.is_null() && !dtype.is_nullable() { + vortex_bail!("Can't cast null scalar to non-nullable type") } - } - pub fn null(dtype: &DType) -> Self { - assert!(dtype.is_nullable()); match dtype { - DType::Null => NullScalar::new().into(), - DType::Bool(_) => BoolScalar::none().into(), - DType::Primitive(p, _) => PrimitiveScalar::none_from_ptype(*p).into(), - DType::Utf8(_) => Utf8Scalar::none().into(), - DType::Binary(_) => BinaryScalar::none().into(), - DType::Struct(..) => StructScalar::new(dtype.clone(), vec![]).into(), - DType::List(..) => ListScalar::new(dtype.clone(), None).into(), - DType::Extension(ext, _) => ExtScalar::null(ext.clone()).into(), + DType::Null => vortex_bail!("Can't cast non-null to null"), + DType::Bool(_) => BoolScalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::Primitive(..) => PrimitiveScalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::Utf8(_) => Utf8Scalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::Binary(_) => BinaryScalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::Struct(..) => StructScalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::List(..) => ListScalar::try_from(self).and_then(|s| s.cast(dtype)), + DType::Extension(..) => ExtScalar::try_from(self).and_then(|s| s.cast(dtype)), } } } -impl Display for Scalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match_each_scalar! { self, |$s| Display::fmt($s, f) } +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.dtype == other.dtype && self.value == other.value } } -/// Allows conversion from Enc scalars to a byte slice. -pub trait AsBytes { - /// Converts this instance into a byte slice - fn as_bytes(&self) -> &[u8]; -} - -impl AsBytes for T { - #[inline] - fn as_bytes(&self) -> &[u8] { - let raw_ptr = self as *const T as *const u8; - unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of::()) } +impl PartialOrd for Scalar { + fn partial_cmp(&self, other: &Self) -> Option { + if self.dtype() == other.dtype() { + self.value.partial_cmp(&other.value) + } else { + None + } } } -#[cfg(test)] -mod test { - use std::mem; - - use crate::Scalar; - - #[test] - fn size_of() { - assert_eq!(mem::size_of::(), 72); +impl AsRef for Scalar { + fn as_ref(&self) -> &Scalar { + self } } diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 7ea3590895..e07e56074f 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,117 +1,111 @@ -use std::fmt::{Display, Formatter}; +use std::ops::Deref; +use std::sync::Arc; use itertools::Itertools; use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexError, VortexResult}; +use vortex_dtype::Nullability::NonNullable; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use crate::value::ScalarValue; use crate::Scalar; -#[derive(Debug, Clone, PartialEq, PartialOrd)] -pub struct ListScalar { - dtype: DType, - values: Option>, +pub struct ListScalar<'a> { + dtype: &'a DType, + elements: Option>, } -impl ListScalar { +impl<'a> ListScalar<'a> { #[inline] - pub fn new(dtype: DType, values: Option>) -> Self { - Self { dtype, values } + pub fn dtype(&self) -> &'a DType { + self.dtype } #[inline] - pub fn values(&self) -> Option<&[Scalar]> { - self.values.as_deref() + pub fn len(&self) -> usize { + self.elements.as_ref().map(|e| e.len()).unwrap_or(0) } #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype + pub fn is_empty(&self) -> bool { + match self.elements.as_ref() { + None => true, + Some(l) => l.is_empty(), + } } - pub fn cast(&self, dtype: &DType) -> VortexResult { - match dtype { - DType::List(field_dtype, n) => { - let new_fields: Option> = self - .values() - .map(|v| v.iter().map(|field| field.cast(field_dtype)).try_collect()) - .transpose()?; - - let new_type = if let Some(nf) = new_fields.as_ref() { - if nf.is_empty() { - dtype.clone() - } else { - DType::List(Box::new(nf[0].dtype().clone()), *n) - } - } else { - dtype.clone() - }; - Ok(ListScalar::new(new_type, new_fields).into()) - } - _ => Err(vortex_err!(MismatchedTypes: "any list", dtype)), - } + pub fn element_dtype(&self) -> DType { + let DType::List(element_type, _) = self.dtype() else { + unreachable!(); + }; + (*element_type).deref().clone() } - pub fn nbytes(&self) -> usize { - self.values() - .map(|v| v.iter().map(|s| s.nbytes()).sum()) - .unwrap_or(0) + pub fn element(&self, idx: usize) -> Option { + self.elements + .as_ref() + .and_then(|l| l.get(idx)) + .map(|value| Scalar { + dtype: self.element_dtype(), + value: value.clone(), + }) } -} -#[derive(Debug, Clone, Default, PartialEq)] -pub struct ListScalarVec(pub Vec); + pub fn elements(&self) -> impl Iterator + '_ { + self.elements + .as_ref() + .map(|e| e.as_ref()) + .unwrap_or_else(|| &[] as &[ScalarValue]) + .iter() + .map(|e| Scalar { + dtype: self.element_dtype(), + value: e.clone(), + }) + } -impl> From> for Scalar { - fn from(value: ListScalarVec) -> Self { - let values: Vec = value.0.into_iter().map(|v| v.into()).collect(); - if values.is_empty() { - panic!("can't implicitly convert empty list into ListScalar"); - } - ListScalar::new(values[0].dtype().clone(), Some(values)).into() + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() } } -impl> TryFrom for ListScalarVec { +impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { type Error = VortexError; - fn try_from(value: Scalar) -> Result { - if let Scalar::List(ls) = value { - if let Some(vs) = ls.values { - Ok(ListScalarVec( - vs.into_iter().map(|v| v.try_into()).try_collect()?, - )) - } else { - Err(vortex_err!("can't extract present value from null scalar")) - } - } else { - Err(vortex_err!(MismatchedTypes: "any list", value.dtype())) + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::List(..)) { + vortex_bail!("Expected list scalar, found {}", value.dtype()) } + + Ok(Self { + dtype: value.dtype(), + elements: value.value.as_list()?.cloned(), + }) } } -impl<'a, T: TryFrom<&'a Scalar, Error = VortexError>> TryFrom<&'a Scalar> for ListScalarVec { +impl<'a, T: for<'b> TryFrom<&'b Scalar, Error = VortexError>> TryFrom<&'a Scalar> for Vec { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if let Scalar::List(ls) = value { - if let Some(vs) = ls.values() { - Ok(ListScalarVec( - vs.iter().map(|v| v.try_into()).try_collect()?, - )) - } else { - Err(vortex_err!("can't extract present value from null scalar")) - } - } else { - Err(vortex_err!(MismatchedTypes: "any list", value.dtype())) + let value = ListScalar::try_from(value)?; + let mut elems = vec![]; + for e in value.elements() { + elems.push(T::try_from(&e)?); } + Ok(elems) } } -impl Display for ListScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.values() { - None => write!(f, ""), - Some(vs) => write!(f, "{}", vs.iter().format(", ")), +impl From> for Scalar +where + Scalar: From, +{ + fn from(value: Vec) -> Self { + let scalars = value.into_iter().map(|v| Scalar::from(v)).collect_vec(); + let element_dtype = scalars.first().expect("Empty list").dtype().clone(); + let dtype = DType::List(Arc::new(element_dtype), NonNullable); + Scalar { + dtype, + value: ScalarValue::List(scalars.into_iter().map(|s| s.value).collect_vec().into()), } } } diff --git a/vortex-scalar/src/null.rs b/vortex-scalar/src/null.rs deleted file mode 100644 index 5e8208a9ba..0000000000 --- a/vortex-scalar/src/null.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::fmt::{Display, Formatter}; - -use vortex_dtype::DType; -use vortex_error::VortexResult; - -use crate::Scalar; - -#[derive(Debug, Clone, PartialEq, PartialOrd)] -pub struct NullScalar; - -impl Default for NullScalar { - fn default() -> Self { - Self::new() - } -} - -impl NullScalar { - #[inline] - pub fn new() -> Self { - Self {} - } - - #[inline] - pub fn dtype(&self) -> &DType { - &DType::Null - } - - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() - } - - pub fn nbytes(&self) -> usize { - 1 - } -} - -impl Display for NullScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "null") - } -} diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index d739458d21..9eef4da6b4 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,78 +1,22 @@ -use std::any; -use std::cmp::Ordering; -use std::fmt::{Display, Formatter}; -use std::mem::size_of; - -use num_traits::identities::Zero; +use num_traits::NumCast; use vortex_dtype::half::f16; -use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype}; -use vortex_dtype::{DType, Nullability}; -use vortex_dtype::{NativePType, PType}; +use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use crate::pvalue::PValue; +use crate::value::ScalarValue; use crate::Scalar; -pub trait PScalarType: NativePType + Into + TryFrom {} - -impl + TryFrom> PScalarType for T {} - -#[derive(Debug, Clone, PartialEq)] -pub struct PrimitiveScalar { +pub struct PrimitiveScalar<'a> { + dtype: &'a DType, ptype: PType, - dtype: DType, - nullability: Nullability, - value: Option, + pvalue: Option, } -impl PrimitiveScalar { - pub fn try_new( - value: Option, - nullability: Nullability, - ) -> VortexResult { - if value.is_none() && nullability == Nullability::NonNullable { - vortex_bail!("Value cannot be None for NonNullable Scalar"); - } - Ok(Self { - ptype: T::PTYPE, - dtype: DType::from(T::PTYPE).with_nullability(nullability), - nullability, - value: value.map(|v| Into::::into(v)), - }) - } - - pub fn none_from_ptype(ptype: PType) -> Self { - Self { - ptype, - dtype: DType::from(ptype).with_nullability(Nullability::Nullable), - nullability: Nullability::Nullable, - value: None, - } - } - - pub fn nullable(value: Option) -> Self { - Self::try_new(value, Nullability::Nullable).unwrap() - } - - pub fn some(value: T) -> Self { - Self::try_new::(Some(value), Nullability::default()).unwrap() - } - - pub fn none() -> Self { - Self::try_new::(None, Nullability::Nullable).unwrap() - } - +impl<'a> PrimitiveScalar<'a> { #[inline] - pub fn value(&self) -> Option { - self.value - } - - pub fn typed_value(&self) -> Option { - assert_eq!( - T::PTYPE, - self.ptype, - "typed_value called with incorrect ptype" - ); - self.value.map(|v| v.try_into().unwrap()) + pub fn dtype(&self) -> &'a DType { + self.dtype } #[inline] @@ -80,364 +24,128 @@ impl PrimitiveScalar { self.ptype } - #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype + pub fn typed_value>(&self) -> Option { + if self.ptype != T::PTYPE { + panic!("Attempting to read {} scalar as {}", self.ptype, T::PTYPE); + } + self.pvalue + .as_ref() + .map(|pv| T::try_from(*pv).expect("checked on construction")) } pub fn cast(&self, dtype: &DType) -> VortexResult { - let ptype: PType = dtype.try_into()?; - match_each_native_ptype!(ptype, |$T| { - Ok(PrimitiveScalar::try_new( - self.value() - .map(|ps| ps.cast_ptype(ptype)) - .transpose()? - .map(|s| $T::try_from(s)) - .transpose()?, - self.nullability, - )?.into()) + let ptype = PType::try_from(dtype)?; + match_each_native_ptype!(ptype, |$Q| { + match_each_native_ptype!(self.ptype(), |$T| { + Ok(Scalar::primitive::<$Q>( + <$Q as NumCast>::from(self.typed_value::<$T>().expect("Invalid value")) + .ok_or_else(|| vortex_err!("Can't cast scalar to {}", dtype))?, + dtype.nullability(), + )) + }) }) } - - pub fn nbytes(&self) -> usize { - size_of::() - } -} - -impl PartialOrd for PrimitiveScalar { - fn partial_cmp(&self, other: &Self) -> Option { - if let (Some(s), Some(o)) = (self.value, other.value) { - s.partial_cmp(&o) - } else { - None - } - } -} - -impl Display for PrimitiveScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.value() { - None => write!(f, "({}?)", self.ptype), - Some(v) => write!(f, "{}({})", v, self.ptype), - } - } } -#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] -pub enum PScalar { - U8(u8), - U16(u16), - U32(u32), - U64(u64), - I8(i8), - I16(i16), - I32(i32), - I64(i64), - F16(f16), - F32(f32), - F64(f64), -} - -impl PScalar { - pub fn ptype(&self) -> PType { - match self { - PScalar::U8(_) => PType::U8, - PScalar::U16(_) => PType::U16, - PScalar::U32(_) => PType::U32, - PScalar::U64(_) => PType::U64, - PScalar::I8(_) => PType::I8, - PScalar::I16(_) => PType::I16, - PScalar::I32(_) => PType::I32, - PScalar::I64(_) => PType::I64, - PScalar::F16(_) => PType::F16, - PScalar::F32(_) => PType::F32, - PScalar::F64(_) => PType::F64, - } - } - - pub fn cast_ptype(&self, ptype: PType) -> VortexResult { - macro_rules! from_int { - ($ptype:ident, $v:ident) => { - match $ptype { - PType::U8 => Ok((*$v as u8).into()), - PType::U16 => Ok((*$v as u16).into()), - PType::U32 => Ok((*$v as u32).into()), - PType::U64 => Ok((*$v as u64).into()), - PType::I8 => Ok((*$v as i8).into()), - PType::I16 => Ok((*$v as i16).into()), - PType::I32 => Ok((*$v as i32).into()), - PType::I64 => Ok((*$v as i64).into()), - PType::F16 => Ok(f16::from_f32(*$v as f32).into()), - PType::F32 => Ok((*$v as f32).into()), - PType::F64 => Ok((*$v as f64).into()), - } - }; - } +impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> { + type Error = VortexError; - macro_rules! from_floating { - ($ptype:ident , $v:ident) => { - match $ptype { - PType::F16 => Ok((f16::from_f32(*$v as f32)).into()), - PType::F32 => Ok((*$v as f32).into()), - PType::F64 => Ok((*$v as f64).into()), - _ => Err(vortex_err!(MismatchedTypes: "any float", ptype)), - } - }; + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::Primitive(..)) { + vortex_bail!("Expected primitive scalar, found {}", value.dtype()) } - match self { - PScalar::U8(v) => from_int!(ptype, v), - PScalar::U16(v) => from_int!(ptype, v), - PScalar::U32(v) => from_int!(ptype, v), - PScalar::U64(v) => from_int!(ptype, v), - PScalar::I8(v) => from_int!(ptype, v), - PScalar::I16(v) => from_int!(ptype, v), - PScalar::I32(v) => from_int!(ptype, v), - PScalar::I64(v) => from_int!(ptype, v), - PScalar::F16(v) => match ptype { - PType::F16 => Ok((*v).into()), - PType::F32 => Ok(v.to_f32().into()), - PType::F64 => Ok(v.to_f64().into()), - _ => Err(vortex_err!(MismatchedTypes: "any float", ptype)), - }, - PScalar::F32(v) => from_floating!(ptype, v), - PScalar::F64(v) => from_floating!(ptype, v), - } - } + let ptype = PType::try_from(value.dtype())?; - pub fn is_positive(&self) -> bool { - match self { - PScalar::U8(v) => *v > 0, - PScalar::U16(v) => *v > 0, - PScalar::U32(v) => *v > 0, - PScalar::U64(v) => *v > 0, - PScalar::I8(v) => *v > 0, - PScalar::I16(v) => *v > 0, - PScalar::I32(v) => *v > 0, - PScalar::I64(v) => *v > 0, - PScalar::F16(v) => v.to_f32() > 0.0, - PScalar::F32(v) => *v > 0.0, - PScalar::F64(v) => *v > 0.0, - } - } + // Read the serialized value into the correct PValue. + // The serialized form may come back over the wire as e.g. any integer type. + let pvalue = match_each_native_ptype!(ptype, |$T| { + if let Some(pvalue) = value.value.as_pvalue()? { + Some(PValue::from(<$T>::try_from(pvalue)?)) + } else { + None + } + }); - pub fn is_negative(&self) -> bool { - match self { - PScalar::U8(_) => false, - PScalar::U16(_) => false, - PScalar::U32(_) => false, - PScalar::U64(_) => false, - PScalar::I8(v) => *v < 0, - PScalar::I16(v) => *v < 0, - PScalar::I32(v) => *v < 0, - PScalar::I64(v) => *v < 0, - PScalar::F16(v) => v.to_f32() < 0.0, - PScalar::F32(v) => *v < 0.0, - PScalar::F64(v) => *v < 0.0, - } + Ok(Self { + dtype: value.dtype(), + ptype, + pvalue, + }) } +} - pub fn is_zero(&self) -> bool { - match self { - PScalar::U8(v) => *v == 0, - PScalar::U16(v) => *v == 0, - PScalar::U32(v) => *v == 0, - PScalar::U64(v) => *v == 0, - PScalar::I8(v) => *v == 0, - PScalar::I16(v) => *v == 0, - PScalar::I32(v) => *v == 0, - PScalar::I64(v) => *v == 0, - PScalar::F16(v) => (*v).is_zero(), - PScalar::F32(v) => (*v).is_zero(), - PScalar::F64(v) => (*v).is_zero(), +impl Scalar { + pub fn primitive>(value: T, nullability: Nullability) -> Scalar { + Scalar { + dtype: DType::Primitive(T::PTYPE, nullability), + value: ScalarValue::Primitive(value.into()), } } } -#[inline] -fn is_negative(value: T) -> bool { - value < T::default() -} - -macro_rules! pscalar { - ($T:ty, $ptype:tt) => { - impl From<$T> for PScalar { - fn from(value: $T) -> Self { - PScalar::$ptype(value) - } - } - +macro_rules! primitive_scalar { + ($T:ty) => { impl From<$T> for Scalar { fn from(value: $T) -> Self { - PrimitiveScalar::some(value).into() - } - } - - impl TryFrom<&Scalar> for $T { - type Error = VortexError; - - fn try_from(value: &Scalar) -> VortexResult { - match value { - Scalar::Primitive(PrimitiveScalar { - value: Some(pscalar), - .. - }) => match pscalar { - PScalar::$ptype(v) => Ok(*v), - _ => Err(vortex_err!(MismatchedTypes: any::type_name::(), pscalar.ptype())), - }, - _ => Err(vortex_err!("can't extract {} from scalar: {}", any::type_name::(), value)), + Scalar { + dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable), + value: ScalarValue::Primitive(value.into()), } } } - impl TryFrom for $T { - type Error = VortexError; - - fn try_from(value: Scalar) -> VortexResult { - match value { - Scalar::Primitive(PrimitiveScalar { - value: Some(pscalar), - .. - }) => pscalar.try_into(), - _ => Err(vortex_err!( - "Can't extract value of type {} from primitive scalar: {}", - any::type_name::(), - value - )), + impl From> for Scalar { + fn from(value: Option<$T>) -> Self { + Scalar { + dtype: DType::Primitive(<$T>::PTYPE, Nullability::Nullable), + value: value + .map(|v| ScalarValue::Primitive(v.into())) + .unwrap_or_else(|| ScalarValue::Null), } } } - impl TryFrom for $T { + impl TryFrom<&Scalar> for $T { type Error = VortexError; - fn try_from(value: PScalar) -> Result { - match value { - PScalar::$ptype(v) => Ok(v), - _ => Err(vortex_err!( - "Expected {} type but got {}", - any::type_name::(), - value - )), - } + fn try_from(value: &Scalar) -> Result { + PrimitiveScalar::try_from(value)? + .typed_value::<$T>() + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } }; } -pscalar!(u8, U8); -pscalar!(u16, U16); -pscalar!(u32, U32); -pscalar!(u64, U64); -pscalar!(i8, I8); -pscalar!(i16, I16); -pscalar!(i32, I32); -pscalar!(i64, I64); -pscalar!(f16, F16); -pscalar!(f32, F32); -pscalar!(f64, F64); - -impl From> for Scalar { - fn from(value: Option) -> Self { - PrimitiveScalar::nullable(value).into() - } -} +primitive_scalar!(u8); +primitive_scalar!(u16); +primitive_scalar!(u32); +primitive_scalar!(u64); +primitive_scalar!(i8); +primitive_scalar!(i16); +primitive_scalar!(i32); +primitive_scalar!(i64); +primitive_scalar!(f16); +primitive_scalar!(f32); +primitive_scalar!(f64); impl From for Scalar { - #[inline] fn from(value: usize) -> Self { - PrimitiveScalar::some::(value as u64).into() - } -} - -impl TryFrom<&PrimitiveScalar> for usize { - type Error = VortexError; - - fn try_from(value: &PrimitiveScalar) -> Result { - match_each_integer_ptype!(value.ptype(), |$V| { - match value.typed_value::<$V>() { - None => Err(vortex_err!(ComputeError: "required non null scalar")), - Some(v) => { - if is_negative(v) { - vortex_bail!(ComputeError: "required positive integer"); - } - Ok(v as usize) - } - } - }) - } -} - -impl TryFrom for usize { - type Error = VortexError; - - fn try_from(value: Scalar) -> VortexResult { - match value { - Scalar::Primitive(p) => (&p).try_into(), - _ => Err(vortex_err!("can't extract usize out of scalar: {}", value)), - } + Scalar::from(value as u64) } } +/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. impl TryFrom<&Scalar> for usize { type Error = VortexError; - fn try_from(value: &Scalar) -> VortexResult { - match value { - Scalar::Primitive(p) => p.try_into(), - _ => Err(vortex_err!("can't extract usize out of scalar: {}", value)), - } - } -} - -impl Display for PScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - PScalar::U8(p) => Display::fmt(p, f), - PScalar::U16(p) => Display::fmt(p, f), - PScalar::U32(p) => Display::fmt(p, f), - PScalar::U64(p) => Display::fmt(p, f), - PScalar::I8(p) => Display::fmt(p, f), - PScalar::I16(p) => Display::fmt(p, f), - PScalar::I32(p) => Display::fmt(p, f), - PScalar::I64(p) => Display::fmt(p, f), - PScalar::F16(p) => Display::fmt(p, f), - PScalar::F32(p) => Display::fmt(p, f), - PScalar::F64(p) => Display::fmt(p, f), - } - } -} - -#[cfg(test)] -mod test { - use vortex_dtype::PType; - use vortex_dtype::{DType, Nullability}; - use vortex_error::VortexError; - - use crate::Scalar; - - #[test] - fn into_from() { - let scalar: Scalar = 10u16.into(); - assert_eq!(u16::try_from(scalar.clone()).unwrap(), 10u16); - // All integers should be convertible to usize - assert_eq!(usize::try_from(scalar).unwrap(), 10usize); - - let scalar: Scalar = (-10i16).into(); - let error = usize::try_from(scalar).err().unwrap(); - let VortexError::ComputeError(s, _) = error else { - unreachable!() - }; - assert_eq!(s.to_string(), "required positive integer"); - } - - #[test] - fn cast() { - let scalar: Scalar = 10u16.into(); - let u32_scalar = scalar - .cast(&DType::Primitive(PType::U32, Nullability::NonNullable)) - .unwrap(); - let u32_scalar_ptype: PType = u32_scalar.dtype().try_into().unwrap(); - assert_eq!(u32_scalar_ptype, PType::U32); + fn try_from(value: &Scalar) -> Result { + u64::try_from( + value + .cast(&DType::Primitive(PType::U64, Nullability::NonNullable))? + .as_ref(), + ) + .map(|v| v as usize) } } diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/pvalue.rs new file mode 100644 index 0000000000..b4bca018ab --- /dev/null +++ b/vortex-scalar/src/pvalue.rs @@ -0,0 +1,132 @@ +use num_traits::NumCast; +use vortex_dtype::half::f16; +use vortex_dtype::PType; +use vortex_error::vortex_err; +use vortex_error::VortexError; + +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub enum PValue { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + F16(f16), + F32(f32), + F64(f64), +} + +impl PValue { + pub fn ptype(&self) -> PType { + match self { + PValue::U8(_) => PType::U8, + PValue::U16(_) => PType::U16, + PValue::U32(_) => PType::U32, + PValue::U64(_) => PType::U64, + PValue::I8(_) => PType::I8, + PValue::I16(_) => PType::I16, + PValue::I32(_) => PType::I32, + PValue::I64(_) => PType::I64, + PValue::F16(_) => PType::F16, + PValue::F32(_) => PType::F32, + PValue::F64(_) => PType::F64, + } + } +} + +macro_rules! int_pvalue { + ($T:ty, $PT:tt) => { + impl TryFrom for $T { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + match value { + PValue::U8(v) => <$T as NumCast>::from(v), + PValue::U16(v) => <$T as NumCast>::from(v), + PValue::U32(v) => <$T as NumCast>::from(v), + PValue::U64(v) => <$T as NumCast>::from(v), + PValue::I8(v) => <$T as NumCast>::from(v), + PValue::I16(v) => <$T as NumCast>::from(v), + PValue::I32(v) => <$T as NumCast>::from(v), + PValue::I64(v) => <$T as NumCast>::from(v), + _ => None, + } + .ok_or_else(|| { + vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT) + }) + } + } + }; +} + +int_pvalue!(u8, U8); +int_pvalue!(u16, U16); +int_pvalue!(u32, U32); +int_pvalue!(u64, U64); +int_pvalue!(i8, I8); +int_pvalue!(i16, I16); +int_pvalue!(i32, I32); +int_pvalue!(i64, I64); + +macro_rules! float_pvalue { + ($T:ty, $PT:tt) => { + impl TryFrom for $T { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + match value { + PValue::F16(f) => <$T as NumCast>::from(f), + PValue::F32(f) => <$T as NumCast>::from(f), + PValue::F64(f) => <$T as NumCast>::from(f), + _ => None, + } + .ok_or_else(|| { + vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT) + }) + } + } + }; +} + +float_pvalue!(f32, F32); +float_pvalue!(f64, F64); + +impl TryFrom for f16 { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + // We serialize f16 as u16. + match value { + PValue::U16(u) => Some(f16::from_bits(u)), + PValue::F32(f) => ::from(f), + PValue::F64(f) => ::from(f), + _ => None, + } + .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F16)) + } +} + +macro_rules! impl_pvalue { + ($T:ty, $PT:tt) => { + impl From<$T> for PValue { + fn from(value: $T) -> Self { + PValue::$PT(value) + } + } + }; +} + +impl_pvalue!(u8, U8); +impl_pvalue!(u16, U16); +impl_pvalue!(u32, U32); +impl_pvalue!(u64, U64); +impl_pvalue!(i8, I8); +impl_pvalue!(i16, I16); +impl_pvalue!(i32, I32); +impl_pvalue!(i64, I64); +impl_pvalue!(f16, F16); +impl_pvalue!(f32, F32); +impl_pvalue!(f64, F64); diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs deleted file mode 100644 index 203865913d..0000000000 --- a/vortex-scalar/src/serde.rs +++ /dev/null @@ -1,250 +0,0 @@ -#![cfg(feature = "serde")] -#![cfg(feature = "flatbuffers")] -use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; -use serde::de::Visitor; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use vortex_dtype::match_each_native_ptype; -use vortex_dtype::Nullability; -use vortex_error::{vortex_bail, VortexError}; -use vortex_flatbuffers::{FlatBufferRoot, FlatBufferToBytes, ReadFlatBuffer, WriteFlatBuffer}; - -use crate::flatbuffers::scalar as fb; -use crate::{PScalar, PrimitiveScalar, Scalar, Utf8Scalar}; - -impl FlatBufferRoot for Scalar {} - -impl WriteFlatBuffer for Scalar { - type Target<'a> = fb::Scalar<'a>; - - fn write_flatbuffer<'fb>( - &self, - fbb: &mut FlatBufferBuilder<'fb>, - ) -> WIPOffset> { - let union = match self { - Scalar::Binary(b) => { - let bytes = b.value().map(|bytes| fbb.create_vector(bytes)); - fb::ScalarArgs { - type_type: fb::Type::Binary, - type_: Some( - fb::Binary::create(fbb, &fb::BinaryArgs { value: bytes }).as_union_value(), - ), - nullability: self.nullability().into(), - } - } - Scalar::Bool(b) => fb::ScalarArgs { - type_type: fb::Type::Bool, - // TODO(ngates): I think this optional is in the wrong place and should be inside BoolArgs. - // However I think Rust Flatbuffers has incorrectly generated non-optional BoolArgs. - type_: b - .value() - .map(|&value| fb::Bool::create(fbb, &fb::BoolArgs { value }).as_union_value()), - nullability: self.nullability().into(), - }, - Scalar::List(_) => panic!("List not supported in scalar serde"), - Scalar::Null(_) => fb::ScalarArgs { - type_type: fb::Type::Null, - type_: Some(fb::Null::create(fbb, &fb::NullArgs {}).as_union_value()), - nullability: self.nullability().into(), - }, - Scalar::Primitive(p) => { - let bytes = p.value().map(|pscalar| match pscalar { - PScalar::U8(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::U16(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::U32(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::U64(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::I8(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::I16(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::I32(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::I64(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::F16(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::F32(v) => fbb.create_vector(&v.to_le_bytes()), - PScalar::F64(v) => fbb.create_vector(&v.to_le_bytes()), - }); - let primitive = fb::Primitive::create( - fbb, - &fb::PrimitiveArgs { - ptype: p.ptype().into(), - bytes, - }, - ); - fb::ScalarArgs { - type_type: fb::Type::Primitive, - type_: Some(primitive.as_union_value()), - nullability: self.nullability().into(), - } - } - Scalar::Struct(_) => panic!(), - Scalar::Utf8(utf) => { - let value = utf.value().map(|utf| fbb.create_string(utf)); - let value = fb::UTF8::create(fbb, &fb::UTF8Args { value }).as_union_value(); - fb::ScalarArgs { - type_type: fb::Type::UTF8, - type_: Some(value), - nullability: self.nullability().into(), - } - } - Scalar::Extension(ext) => { - let id = Some(fbb.create_string(ext.id().as_ref())); - let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref())); - let value = ext.value().map(|s| s.write_flatbuffer(fbb)); - fb::ScalarArgs { - type_type: fb::Type::Extension, - type_: Some( - fb::Extension::create( - fbb, - &fb::ExtensionArgs { - id, - metadata, - value, - }, - ) - .as_union_value(), - ), - nullability: self.nullability().into(), - } - } - }; - - fb::Scalar::create(fbb, &union) - } -} - -impl ReadFlatBuffer for Scalar { - type Source<'a> = fb::Scalar<'a>; - type Error = VortexError; - - fn read_flatbuffer(fb: &Self::Source<'_>) -> Result { - let nullability = Nullability::from(fb.nullability()); - match fb.type_type() { - fb::Type::Binary => { - todo!() - } - fb::Type::Bool => { - todo!() - } - fb::Type::List => { - todo!() - } - fb::Type::Null => { - todo!() - } - fb::Type::Primitive => { - let primitive = fb.type__as_primitive().expect("missing Primitive value"); - let ptype = primitive.ptype().try_into()?; - Ok(match_each_native_ptype!(ptype, |$T| { - Scalar::Primitive(PrimitiveScalar::try_new( - if let Some(bytes) = primitive.bytes() { - Some($T::from_le_bytes(bytes.bytes().try_into()?)) - } else { - None - }, - nullability, - )?) - })) - } - fb::Type::Struct_ => { - todo!() - } - fb::Type::UTF8 => Ok(Scalar::Utf8(Utf8Scalar::try_new( - fb.type__as_utf8() - .expect("missing UTF8 value") - .value() - .map(|s| s.to_string()), - nullability, - )?)), - _ => vortex_bail!(InvalidSerde: "Unrecognized scalar type"), - } - } -} - -impl Serialize for Scalar { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - self.with_flatbuffer_bytes(|bytes| serializer.serialize_bytes(bytes)) - } -} - -struct ScalarDeserializer; - -impl<'de> Visitor<'de> for ScalarDeserializer { - type Value = Scalar; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a vortex dtype") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - let fb = root::(v).map_err(E::custom)?; - Scalar::read_flatbuffer(&fb).map_err(E::custom) - } -} - -// TODO(ngates): Should we just inline composites in scalars? -impl<'de> Deserialize<'de> for Scalar { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_bytes(ScalarDeserializer) - } -} - -// impl<'a, 'b> ScalarReader<'a, 'b> { -// pub fn read(&mut self) -> VortexResult { -// let bytes = self.reader.read_slice()?; -// let scalar = root::(&bytes) -// .map_err(|_e| VortexError::InvalidArgument("Invalid FlatBuffer".into())) -// .unwrap(); - -// } -// -// fn read_primitive_scalar(&mut self) -> VortexResult { -// let ptype = self.reader.ptype()?; -// let is_present = self.reader.read_option_tag()?; -// if is_present { -// let pscalar = match ptype { -// PType::U8 => PrimitiveScalar::some(PScalar::U8(u8::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::U16 => PrimitiveScalar::some(PScalar::U16(u16::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::U32 => PrimitiveScalar::some(PScalar::U32(u32::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::U64 => PrimitiveScalar::some(PScalar::U64(u64::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::I8 => PrimitiveScalar::some(PScalar::I8(i8::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::I16 => PrimitiveScalar::some(PScalar::I16(i16::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::I32 => PrimitiveScalar::some(PScalar::I32(i32::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::I64 => PrimitiveScalar::some(PScalar::I64(i64::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::F16 => PrimitiveScalar::some(PScalar::F16(f16::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::F32 => PrimitiveScalar::some(PScalar::F32(f32::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// PType::F64 => PrimitiveScalar::some(PScalar::F64(f64::from_le_bytes( -// self.reader.read_nbytes()?, -// ))), -// }; -// Ok(pscalar) -// } else { -// Ok(PrimitiveScalar::none(ptype)) -// } -// } -// } diff --git a/vortex-scalar/src/serde/flatbuffers.rs b/vortex-scalar/src/serde/flatbuffers.rs new file mode 100644 index 0000000000..f1e8925e96 --- /dev/null +++ b/vortex-scalar/src/serde/flatbuffers.rs @@ -0,0 +1,29 @@ +#![cfg(feature = "flatbuffers")] + +use itertools::Itertools; +use serde::Deserialize; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexError}; + +use crate::flatbuffers::scalar as fb; +use crate::{Scalar, ScalarValue}; + +impl TryFrom> for Scalar { + type Error = VortexError; + + fn try_from(value: fb::Scalar<'_>) -> Result { + let dtype = value.dtype(); + let dtype = DType::try_from(dtype)?; + + let flex_value = value + .value() + .ok_or_else(|| vortex_err!("Missing scalar value"))?; + + // TODO(ngates): what's the point of all this if I have to copy the data into a Vec? + let flex_value = flex_value.iter().collect_vec(); + let reader = flexbuffers::Reader::get_root(flex_value.as_slice())?; + let value = ScalarValue::deserialize(reader)?; + + Ok(Scalar { dtype, value }) + } +} diff --git a/vortex-scalar/src/serde/mod.rs b/vortex-scalar/src/serde/mod.rs new file mode 100644 index 0000000000..a5e9e2a19a --- /dev/null +++ b/vortex-scalar/src/serde/mod.rs @@ -0,0 +1,4 @@ +mod flatbuffers; +mod proto; +#[allow(clippy::module_inception)] +mod serde; diff --git a/vortex-scalar/src/serde/proto.rs b/vortex-scalar/src/serde/proto.rs new file mode 100644 index 0000000000..72dcee1e4e --- /dev/null +++ b/vortex-scalar/src/serde/proto.rs @@ -0,0 +1,41 @@ +#![cfg(feature = "proto")] + +use vortex_buffer::{Buffer, BufferString}; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexError}; + +use crate::proto::scalar::scalar::Value; +use crate::pvalue::PValue; +use crate::{proto::scalar as pb, Scalar, ScalarValue}; + +impl TryFrom<&pb::Scalar> for Scalar { + type Error = VortexError; + + fn try_from(value: &pb::Scalar) -> Result { + let dtype = DType::try_from( + value + .dtype + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?, + )?; + + let value = value + .value + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?; + + let value = match value { + Value::Bool(b) => ScalarValue::Bool(*b), + Value::Uint32(v) => ScalarValue::Primitive(PValue::U32(*v)), + Value::Uint64(v) => ScalarValue::Primitive(PValue::U64(*v)), + Value::Sint32(v) => ScalarValue::Primitive(PValue::I32(*v)), + Value::Sint64(v) => ScalarValue::Primitive(PValue::I64(*v)), + Value::Float(v) => ScalarValue::Primitive(PValue::F32(*v)), + Value::Double(v) => ScalarValue::Primitive(PValue::F64(*v)), + Value::Bytes(v) => ScalarValue::Buffer(Buffer::from(v.clone())), + Value::String(v) => ScalarValue::BufferString(BufferString::from(v.clone())), + }; + + Ok(Scalar { dtype, value }) + } +} diff --git a/vortex-scalar/src/serde/serde.rs b/vortex-scalar/src/serde/serde.rs new file mode 100644 index 0000000000..5f2f59f4cb --- /dev/null +++ b/vortex-scalar/src/serde/serde.rs @@ -0,0 +1,175 @@ +#![cfg(feature = "serde")] + +use std::fmt::Formatter; + +use serde::de::{Error, SeqAccess, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use vortex_buffer::BufferString; + +use crate::pvalue::PValue; +use crate::value::ScalarValue; + +impl Serialize for ScalarValue { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + ScalarValue::Null => ().serialize(serializer), + ScalarValue::Bool(b) => b.serialize(serializer), + ScalarValue::Primitive(p) => p.serialize(serializer), + ScalarValue::Buffer(buffer) => buffer.as_ref().serialize(serializer), + ScalarValue::BufferString(buffer) => buffer.as_str().serialize(serializer), + ScalarValue::List(l) => l.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for ScalarValue { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ScalarValueVisitor; + impl<'v> Visitor<'v> for ScalarValueVisitor { + type Value = ScalarValue; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + write!(formatter, "a scalar data value") + } + + fn visit_bool(self, v: bool) -> Result + where + E: Error, + { + Ok(ScalarValue::Bool(v)) + } + + fn visit_i8(self, v: i8) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::I8(v))) + } + + fn visit_i16(self, v: i16) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::I16(v))) + } + + fn visit_i32(self, v: i32) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::I32(v))) + } + + fn visit_i64(self, v: i64) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::I64(v))) + } + + fn visit_u8(self, v: u8) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::U8(v))) + } + + fn visit_u16(self, v: u16) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::U16(v))) + } + + fn visit_u32(self, v: u32) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::U32(v))) + } + + fn visit_u64(self, v: u64) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::U64(v))) + } + + fn visit_f32(self, v: f32) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::F32(v))) + } + + fn visit_f64(self, v: f64) -> Result + where + E: Error, + { + Ok(ScalarValue::Primitive(PValue::F64(v))) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(ScalarValue::BufferString(BufferString::from(v.to_string()))) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + Ok(ScalarValue::Buffer(v.to_vec().into())) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + Ok(ScalarValue::Null) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'v>, + { + let mut elems = vec![]; + while let Some(e) = seq.next_element::()? { + elems.push(e); + } + Ok(ScalarValue::List(elems.into())) + } + } + + deserializer.deserialize_any(ScalarValueVisitor) + } +} + +impl Serialize for PValue { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + PValue::U8(v) => serializer.serialize_u8(*v), + PValue::U16(v) => serializer.serialize_u16(*v), + PValue::U32(v) => serializer.serialize_u32(*v), + PValue::U64(v) => serializer.serialize_u64(*v), + PValue::I8(v) => serializer.serialize_i8(*v), + PValue::I16(v) => serializer.serialize_i16(*v), + PValue::I32(v) => serializer.serialize_i32(*v), + PValue::I64(v) => serializer.serialize_i64(*v), + // NOTE(ngates): f16's are serialized bit-wise as u16. + PValue::F16(v) => serializer.serialize_u16(v.to_bits()), + PValue::F32(v) => serializer.serialize_f32(*v), + PValue::F64(v) => serializer.serialize_f64(*v), + } + } +} diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index d52e5f771f..08dcb6472c 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -1,97 +1,65 @@ -use std::cmp::Ordering; -use std::fmt::{Display, Formatter}; +use std::sync::Arc; -use itertools::Itertools; -use vortex_dtype::{DType, FieldNames, Nullability, StructDType}; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use crate::value::ScalarValue; use crate::Scalar; -#[derive(Debug, Clone, PartialEq)] -pub struct StructScalar { - dtype: DType, - values: Vec, +pub struct StructScalar<'a> { + dtype: &'a DType, + fields: Option>, } -impl StructScalar { +impl<'a> StructScalar<'a> { #[inline] - pub fn new(dtype: DType, values: Vec) -> Self { - Self { dtype, values } + pub fn dtype(&self) -> &'a DType { + self.dtype } - #[inline] - pub fn values(&self) -> &[Scalar] { - self.values.as_ref() - } - - #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype + pub fn field_by_idx(&self, idx: usize, dtype: DType) -> Option { + self.fields + .as_ref() + .and_then(|fields| fields.get(idx)) + .map(|field| Scalar { + dtype, + value: field.clone(), + }) } - pub fn names(&self) -> &FieldNames { - let DType::Struct(st, _) = self.dtype() else { - unreachable!("Not a scalar dtype"); + pub fn field(&self, name: &str, dtype: DType) -> Option { + let DType::Struct(struct_dtype, _) = self.dtype() else { + unreachable!() }; - st.names() - } - - pub fn cast(&self, dtype: &DType) -> VortexResult { - match dtype { - DType::Struct(st, n) => { - // TODO(ngates): check nullability. - assert_eq!(Nullability::NonNullable, *n); - - if st.dtypes().len() != self.values.len() { - vortex_bail!( - MismatchedTypes: format!("Struct with {} fields", self.values.len()), - dtype - ); - } - - let new_fields: Vec = self - .values - .iter() - .zip_eq(st.dtypes().iter()) - .map(|(field, field_dtype)| field.cast(field_dtype)) - .try_collect()?; - - let new_type = DType::Struct( - StructDType::new( - st.names().clone(), - new_fields.iter().map(|x| x.dtype().clone()).collect(), - ), - dtype.nullability(), - ); - Ok(StructScalar::new(new_type, new_fields).into()) - } - _ => Err(vortex_err!(MismatchedTypes: "struct", dtype)), - } + struct_dtype + .find_name(name) + .and_then(|idx| self.field_by_idx(idx, dtype)) } - pub fn nbytes(&self) -> usize { - self.values().iter().map(|s| s.nbytes()).sum() + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() } } -impl PartialOrd for StructScalar { - fn partial_cmp(&self, other: &Self) -> Option { - if self.dtype != other.dtype { - None - } else { - self.values.partial_cmp(&other.values) +impl Scalar { + pub fn r#struct(dtype: DType, children: Vec) -> Scalar { + Scalar { + dtype, + value: ScalarValue::List(children.into()), } } } -impl Display for StructScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let DType::Struct(st, _) = self.dtype() else { - unreachable!() - }; - for (n, v) in st.names().iter().zip(self.values.iter()) { - write!(f, "{} = {}", n, v)?; +impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> { + type Error = VortexError; + + fn try_from(value: &'a Scalar) -> Result { + if matches!(value.dtype(), DType::Struct(..)) { + vortex_bail!("Expected struct scalar, found {}", value.dtype()) } - Ok(()) + Ok(Self { + dtype: value.dtype(), + fields: value.value.as_list()?.cloned(), + }) } } diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index b3923f8ffd..e5eb25df49 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -1,75 +1,73 @@ -use std::fmt::{Display, Formatter}; - -use vortex_dtype::{DType, Nullability::NonNullable, Nullability::Nullable}; +use vortex_buffer::BufferString; +use vortex_dtype::DType; +use vortex_dtype::Nullability; +use vortex_dtype::Nullability::NonNullable; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub type Utf8Scalar = ScalarValue; +pub struct Utf8Scalar<'a> { + dtype: &'a DType, + value: Option, +} -impl Utf8Scalar { +impl<'a> Utf8Scalar<'a> { #[inline] - pub fn dtype(&self) -> &DType { - match self.nullability() { - NonNullable => &DType::Utf8(NonNullable), - Nullable => &DType::Utf8(Nullable), - } - } - - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn dtype(&self) -> &'a DType { + self.dtype } - pub fn nbytes(&self) -> usize { - self.value().map(|v| v.len()).unwrap_or(0) + pub fn value(&self) -> Option { + self.value.as_ref().cloned() } -} -impl From for Scalar { - fn from(value: String) -> Self { - Utf8Scalar::some(value).into() + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() } } -impl From<&str> for Scalar { - fn from(value: &str) -> Self { - Utf8Scalar::some(value.to_string()).into() +impl Scalar { + pub fn utf8(str: B, nullability: Nullability) -> Self + where + BufferString: From, + { + Scalar { + dtype: DType::Utf8(nullability), + value: ScalarValue::BufferString(BufferString::from(str)), + } } } -impl TryFrom for String { +impl<'a> TryFrom<&'a Scalar> for Utf8Scalar<'a> { type Error = VortexError; - fn try_from(value: Scalar) -> Result { - let Scalar::Utf8(u) = value else { - vortex_bail!(MismatchedTypes: "Utf8", value.dtype()); - }; - - u.into_value() - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + fn try_from(value: &'a Scalar) -> Result { + if !matches!(value.dtype(), DType::Utf8(_)) { + vortex_bail!("Expected utf8 scalar, found {}", value.dtype()) + } + Ok(Self { + dtype: value.dtype(), + value: value.value.as_buffer_string()?, + }) } } -impl TryFrom<&Scalar> for String { +impl<'a> TryFrom<&'a Scalar> for BufferString { type Error = VortexError; - fn try_from(value: &Scalar) -> Result { - let Scalar::Utf8(u) = value else { - vortex_bail!(MismatchedTypes: "Utf8", value.dtype()); - }; - - u.value() - .cloned() + fn try_from(value: &'a Scalar) -> VortexResult { + Utf8Scalar::try_from(value)? + .value() .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } -impl Display for Utf8Scalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.value() { - None => write!(f, ""), - Some(v) => write!(f, "\"{}\"", v), +impl From<&str> for Scalar { + fn from(value: &str) -> Self { + Scalar { + dtype: DType::Utf8(NonNullable), + value: ScalarValue::BufferString(value.to_string().into()), } } } diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 3618fbd19c..cebf1a3033 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -1,45 +1,68 @@ -use vortex_dtype::Nullability; -use vortex_error::{vortex_bail, VortexResult}; +use std::sync::Arc; -#[derive(Debug, Clone, PartialEq, PartialOrd)] -pub struct ScalarValue { - nullability: Nullability, - value: Option, -} +use vortex_buffer::{Buffer, BufferString}; +use vortex_error::{vortex_err, VortexResult}; -impl ScalarValue { - pub fn try_new(value: Option, nullability: Nullability) -> VortexResult { - if value.is_none() && nullability == Nullability::NonNullable { - vortex_bail!("Value cannot be None for NonNullable Scalar"); - } - Ok(Self { value, nullability }) - } +use crate::pvalue::PValue; - pub fn non_nullable(value: T) -> Self { - Self::try_new(Some(value), Nullability::NonNullable).unwrap() - } +/// Represents the internal data of a scalar value. Must be interpreted by wrapping +/// up with a DType to make a Scalar. +/// +/// Note that these values can be deserialized from JSON or other formats. So a PValue may not +/// have the correct width for what the DType expects. This means primitive values must be +/// cast on-read. +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub enum ScalarValue { + Null, + Bool(bool), + Primitive(PValue), + Buffer(Buffer), + BufferString(BufferString), + List(Arc<[ScalarValue]>), +} - pub fn nullable(value: T) -> Self { - Self::try_new(Some(value), Nullability::Nullable).unwrap() +impl ScalarValue { + pub fn is_null(&self) -> bool { + matches!(self, ScalarValue::Null) } - pub fn some(value: T) -> Self { - Self::try_new(Some(value), Nullability::default()).unwrap() + pub fn as_bool(&self) -> VortexResult> { + match self { + ScalarValue::Null => Ok(None), + ScalarValue::Bool(b) => Ok(Some(*b)), + _ => Err(vortex_err!("Expected a bool scalar, found {:?}", self)), + } } - pub fn none() -> Self { - Self::try_new(None, Nullability::Nullable).unwrap() + pub fn as_pvalue(&self) -> VortexResult> { + match self { + ScalarValue::Null => Ok(None), + ScalarValue::Primitive(p) => Ok(Some(*p)), + _ => Err(vortex_err!("Expected a primitive scalar, found {:?}", self)), + } } - pub fn value(&self) -> Option<&T> { - self.value.as_ref() + pub fn as_buffer(&self) -> VortexResult> { + match self { + ScalarValue::Null => Ok(None), + ScalarValue::Buffer(b) => Ok(Some(b.clone())), + _ => Err(vortex_err!("Expected a binary scalar, found {:?}", self)), + } } - pub fn into_value(self) -> Option { - self.value + pub fn as_buffer_string(&self) -> VortexResult> { + match self { + ScalarValue::Null => Ok(None), + ScalarValue::Buffer(b) => Ok(Some(BufferString::try_from(b.clone())?)), + ScalarValue::BufferString(b) => Ok(Some(b.clone())), + _ => Err(vortex_err!("Expected a string scalar, found {:?}", self)), + } } - pub fn nullability(&self) -> Nullability { - self.nullability + pub fn as_list(&self) -> VortexResult>> { + match self { + ScalarValue::List(l) => Ok(Some(l)), + _ => Err(vortex_err!("Expected a list scalar, found {:?}", self)), + } } } diff --git a/vortex-zigzag/src/compute.rs b/vortex-zigzag/src/compute.rs index fc43083924..0c76cc6502 100644 --- a/vortex-zigzag/src/compute.rs +++ b/vortex-zigzag/src/compute.rs @@ -1,9 +1,10 @@ use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::slice::{slice, SliceFn}; use vortex::compute::ArrayCompute; -use vortex::{ArrayDType, IntoArray, OwnedArray}; -use vortex_error::{vortex_err, VortexResult}; -use vortex_scalar::{PScalar, Scalar}; +use vortex::{IntoArray, OwnedArray}; +use vortex_dtype::PType; +use vortex_error::VortexResult; +use vortex_scalar::{PrimitiveScalar, Scalar}; use zigzag::ZigZag as ExternalZigZag; use crate::ZigZagArray; @@ -21,18 +22,13 @@ impl ArrayCompute for ZigZagArray<'_> { impl ScalarAtFn for ZigZagArray<'_> { fn scalar_at(&self, index: usize) -> VortexResult { let scalar = scalar_at(&self.encoded(), index)?; - match scalar { - Scalar::Primitive(p) => match p.value() { - None => Ok(Scalar::null(self.dtype())), - Some(p) => match p { - PScalar::U8(u) => Ok(i8::decode(u).into()), - PScalar::U16(u) => Ok(i16::decode(u).into()), - PScalar::U32(u) => Ok(i32::decode(u).into()), - PScalar::U64(u) => Ok(i64::decode(u).into()), - _ => Err(vortex_err!(MismatchedTypes: "unsigned int", self.dtype())), - }, - }, - _ => Err(vortex_err!(MismatchedTypes: "primitive scalar", self.dtype())), + let pscalar = PrimitiveScalar::try_from(&scalar)?; + match pscalar.ptype() { + PType::U8 => Ok(i8::decode(pscalar.typed_value::().unwrap()).into()), + PType::U16 => Ok(i16::decode(pscalar.typed_value::().unwrap()).into()), + PType::U32 => Ok(i32::decode(pscalar.typed_value::().unwrap()).into()), + PType::U64 => Ok(i64::decode(pscalar.typed_value::().unwrap()).into()), + _ => unreachable!(), } } } From c09d2d629db9f04b34ea3d97757708e67427d316 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 8 May 2024 14:35:58 +0100 Subject: [PATCH 3/4] StatsView2 (#305) Adds a few more flatbuffer implementations that I forgot and removes the closure methods from the stats trait since cloning is now much cheaper. --- vortex-array/src/data.rs | 33 +------------- vortex-array/src/stats/mod.rs | 61 +++++--------------------- vortex-scalar/flatbuffers/scalar.fbs | 6 ++- vortex-scalar/src/lib.rs | 16 ++++++- vortex-scalar/src/list.rs | 9 ++++ vortex-scalar/src/serde/flatbuffers.rs | 53 ++++++++++++++++++---- 6 files changed, 86 insertions(+), 92 deletions(-) diff --git a/vortex-array/src/data.rs b/vortex-array/src/data.rs index 8b64dba74a..f6fcdd46dd 100644 --- a/vortex-array/src/data.rs +++ b/vortex-array/src/data.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock}; use vortex_buffer::Buffer; use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexResult}; +use vortex_error::VortexResult; use vortex_scalar::Scalar; use crate::encoding::EncodingRef; @@ -162,35 +162,4 @@ impl Statistics for ArrayData { ); self.get(stat) } - - #[inline] - fn with_stat_value<'a>( - &self, - stat: Stat, - f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()> { - self.stats_map - .read() - .unwrap() - .get(stat) - .ok_or_else(|| vortex_err!(ComputeError: "statistic {} missing", stat)) - .and_then(f) - } - - #[inline] - fn with_computed_stat_value<'a>( - &self, - stat: Stat, - f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()> { - if let Some(s) = self.stats_map.read().unwrap().get(stat) { - return f(s); - } - - self.stats_map - .write() - .unwrap() - .extend(self.to_array().with_dyn(|a| a.compute_statistics(stat))?); - self.with_stat_value(stat, f) - } } diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 46e9fad434..150a150c13 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -3,8 +3,9 @@ use std::hash::Hash; use enum_iterator::Sequence; pub use statsset::*; +use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, NativePType}; -use vortex_error::{VortexError, VortexResult}; +use vortex_error::{vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; mod statsset; @@ -51,18 +52,6 @@ pub trait Statistics { /// Computes the value of the stat if it's not present fn compute(&self, stat: Stat) -> Option; - - fn with_stat_value<'a>( - &self, - stat: Stat, - f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()>; - - fn with_computed_stat_value<'a>( - &self, - stat: Stat, - f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()>; } pub struct EmptyStatistics; @@ -81,24 +70,6 @@ impl Statistics for EmptyStatistics { fn compute(&self, _stat: Stat) -> Option { None } - - #[inline] - fn with_stat_value<'a>( - &self, - _stat: Stat, - _f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()> { - Ok(()) - } - - #[inline] - fn with_computed_stat_value<'a>( - &self, - _stat: Stat, - _f: &'a mut dyn FnMut(&Scalar) -> VortexResult<()>, - ) -> VortexResult<()> { - Ok(()) - } } pub trait ArrayStatistics { @@ -117,36 +88,28 @@ impl dyn Statistics + '_ { &self, stat: Stat, ) -> VortexResult { - let mut res: Option = None; - self.with_stat_value(stat, &mut |s| { - res = Some(U::try_from(s)?); - Ok(()) - })?; - Ok(res.expect("Result should have been populated by previous call")) + self.get(stat) + .ok_or_else(|| vortex_err!(ComputeError: "statistic {} missing", stat)) + .and_then(|s| U::try_from(&s)) } pub fn compute_as TryFrom<&'a Scalar, Error = VortexError>>( &self, stat: Stat, ) -> VortexResult { - let mut res: Option = None; - self.with_computed_stat_value(stat, &mut |s| { - res = Some(U::try_from(s)?); - Ok(()) - })?; - Ok(res.expect("Result should have been populated by previous call")) + self.compute(stat) + .ok_or_else(|| vortex_err!(ComputeError: "statistic {} missing", stat)) + .and_then(|s| U::try_from(&s)) } pub fn compute_as_cast TryFrom<&'a Scalar, Error = VortexError>>( &self, stat: Stat, ) -> VortexResult { - let mut res: Option = None; - self.with_computed_stat_value(stat, &mut |s| { - res = Some(U::try_from(s.cast(&DType::from(U::PTYPE))?.as_ref())?); - Ok(()) - })?; - Ok(res.expect("Result should have been populated by previous call")) + self.compute(stat) + .ok_or_else(|| vortex_err!(ComputeError: "statistic {} missing", stat)) + .and_then(|s| s.cast(&DType::Primitive(U::PTYPE, NonNullable))) + .and_then(|s| U::try_from(s.as_ref())) } pub fn compute_min TryFrom<&'a Scalar, Error = VortexError>>( diff --git a/vortex-scalar/flatbuffers/scalar.fbs b/vortex-scalar/flatbuffers/scalar.fbs index 90519fd0da..f2b3a7667c 100644 --- a/vortex-scalar/flatbuffers/scalar.fbs +++ b/vortex-scalar/flatbuffers/scalar.fbs @@ -4,7 +4,11 @@ namespace vortex.scalar; table Scalar { dtype: vortex.dtype.DType (required); - value: [ubyte] (flexbuffer); + value: ScalarValue (required); +} + +table ScalarValue { + flex: [ubyte] (required, flexbuffer); } root_type Scalar; \ No newline at end of file diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 4a359b8ea2..faeee748cd 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -19,6 +19,7 @@ pub use bool::*; pub use extension::*; pub use list::*; pub use primitive::*; +pub use pvalue::*; pub use struct_::*; pub use utf8::*; pub use value::*; @@ -36,13 +37,13 @@ pub mod proto { #[cfg(feature = "flatbuffers")] pub mod flatbuffers { - pub use gen_scalar::vortex::*; + pub use generated::vortex::scalar::*; #[allow(unused_imports)] #[allow(dead_code)] #[allow(non_camel_case_types)] #[allow(clippy::all)] - mod gen_scalar { + pub mod generated { include!(concat!(env!("OUT_DIR"), "/flatbuffers/scalar.rs")); } @@ -62,10 +63,21 @@ pub struct Scalar { } impl Scalar { + pub fn new(dtype: DType, value: ScalarValue) -> Self { + Self { dtype, value } + } + + #[inline] pub fn dtype(&self) -> &DType { &self.dtype } + #[inline] + pub fn value(&self) -> &ScalarValue { + &self.value + } + + #[inline] pub fn into_value(self) -> ScalarValue { self.value } diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index e07e56074f..7dd42c2146 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -67,6 +67,15 @@ impl<'a> ListScalar<'a> { } } +impl Scalar { + pub fn list(element_dtype: DType, children: Vec) -> Scalar { + Scalar { + dtype: DType::List(Arc::new(element_dtype), NonNullable), + value: ScalarValue::List(children.into()), + } + } +} + impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { type Error = VortexError; diff --git a/vortex-scalar/src/serde/flatbuffers.rs b/vortex-scalar/src/serde/flatbuffers.rs index f1e8925e96..4ef47e2cbe 100644 --- a/vortex-scalar/src/serde/flatbuffers.rs +++ b/vortex-scalar/src/serde/flatbuffers.rs @@ -1,11 +1,13 @@ #![cfg(feature = "flatbuffers")] +use flatbuffers::{FlatBufferBuilder, WIPOffset}; use itertools::Itertools; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexError}; +use vortex_error::VortexError; +use vortex_flatbuffers::WriteFlatBuffer; -use crate::flatbuffers::scalar as fb; +use crate::flatbuffers as fb; use crate::{Scalar, ScalarValue}; impl TryFrom> for Scalar { @@ -15,15 +17,50 @@ impl TryFrom> for Scalar { let dtype = value.dtype(); let dtype = DType::try_from(dtype)?; - let flex_value = value - .value() - .ok_or_else(|| vortex_err!("Missing scalar value"))?; - // TODO(ngates): what's the point of all this if I have to copy the data into a Vec? - let flex_value = flex_value.iter().collect_vec(); + let flex_value = value.value().flex().iter().collect_vec(); let reader = flexbuffers::Reader::get_root(flex_value.as_slice())?; let value = ScalarValue::deserialize(reader)?; Ok(Scalar { dtype, value }) } } + +impl TryFrom> for ScalarValue { + type Error = VortexError; + + fn try_from(value: fb::ScalarValue<'_>) -> Result { + // TODO(ngates): what's the point of all this if I have to copy the data into a Vec? + let flex_value = value.flex().iter().collect_vec(); + let reader = flexbuffers::Reader::get_root(flex_value.as_slice())?; + Ok(ScalarValue::deserialize(reader)?) + } +} + +impl WriteFlatBuffer for Scalar { + type Target<'a> = fb::Scalar<'a>; + + fn write_flatbuffer<'fb>( + &self, + fbb: &mut FlatBufferBuilder<'fb>, + ) -> WIPOffset> { + let dtype = Some(self.dtype.write_flatbuffer(fbb)); + let value = Some(self.value.write_flatbuffer(fbb)); + fb::Scalar::create(fbb, &fb::ScalarArgs { dtype, value }) + } +} + +impl WriteFlatBuffer for ScalarValue { + type Target<'a> = fb::ScalarValue<'a>; + + fn write_flatbuffer<'fb>( + &self, + fbb: &mut FlatBufferBuilder<'fb>, + ) -> WIPOffset> { + let mut value_se = flexbuffers::FlexbufferSerializer::new(); + self.serialize(&mut value_se) + .expect("Failed to serialize ScalarValue"); + let flex = Some(fbb.create_vector(value_se.view())); + fb::ScalarValue::create(fbb, &fb::ScalarValueArgs { flex }) + } +} From 59d4db07af60f2861091a4695ff0baaa05a3e04e Mon Sep 17 00:00:00 2001 From: Josh Casale Date: Wed, 8 May 2024 15:05:34 -0400 Subject: [PATCH 4/4] Include stats in IPC messages (#302) - Modifies the flatbuffer array type to include an ArrayStats table - Array Table contains optional fields corresponding to each currently available stats type - Current implementation populates all stats - Implementation of the Statistics trait for ArrayView - Implementation does no allocations or computation, only references values that exist in the underlying flatbuffer - For demonstrative purposes, I've [written (and removed) an implementation](https://github.com/spiraldb/vortex/commit/de4077c35238795565f27203b0b0a550f355b1f0) that allocates only if someone calls set(stat, value) to populate additional, possibly missing, stats. I've removed this because we don't currently have a use for it, but it's easy enough to do without any unsafe shenanigans. - Callers can specify which stats should be included with a serialized IPC array when constructing a ViewContext. By default, all stats are included. - Tests demonstrating the presence of correct stats after a round-trip through IPC for primitive and chunked arrays ~I've included a mechanism to configure all of the statistics by default here because the overhead they add to the flatbuffer message is relatively small, given that the arrays themselves are sufficiently large. I considered adding a mechanism to check the length of the arrays [here](https://github.com/spiraldb/vortex/pull/302/files#diff-b7cc44a4bd1e1c769cb029b5ecaa98f080fdb7aa48b79566a9c8bb1306b84149R212) to choose a subset of stats based on the size of the array (probably just drop the two frequency arrays, because they're much larger than everything else), but decided against it for now. I don't think we expect to frequently see arrays small enough that these stats would add a relatively significant amount of wire overhead~ --------- Co-authored-by: Nicholas Gates --- vortex-array/flatbuffers/array.fbs | 17 ++++ vortex-array/src/array/chunked/stats.rs | 5 +- vortex-array/src/lib.rs | 4 +- vortex-array/src/stats/flatbuffers.rs | 49 ++++++++++ vortex-array/src/stats/mod.rs | 11 +++ vortex-array/src/view.rs | 89 ++++++++++++++++-- vortex-ipc/src/lib.rs | 8 +- vortex-ipc/src/messages.rs | 9 +- vortex-ipc/src/reader.rs | 117 +++++++++++++++++++++++- 9 files changed, 296 insertions(+), 13 deletions(-) create mode 100644 vortex-array/src/stats/flatbuffers.rs diff --git a/vortex-array/flatbuffers/array.fbs b/vortex-array/flatbuffers/array.fbs index 86961eedc6..838b9cb43a 100644 --- a/vortex-array/flatbuffers/array.fbs +++ b/vortex-array/flatbuffers/array.fbs @@ -1,3 +1,5 @@ +include "vortex-scalar/flatbuffers/scalar.fbs"; + namespace vortex.array; enum Version: uint8 { @@ -9,7 +11,22 @@ table Array { has_buffer: bool; encoding: uint16; metadata: [ubyte]; + stats: ArrayStats; children: [Array]; } +table ArrayStats { + min: vortex.scalar.ScalarValue; + max: vortex.scalar.ScalarValue; + is_sorted: bool = null; + is_strict_sorted: bool = null; + is_constant: bool = null; + run_count: uint64 = null; + true_count: uint64 = null; + null_count: uint64 = null; + bit_width_freq: [uint64]; + trailing_zero_freq: [uint64]; +} + + root_type Array; diff --git a/vortex-array/src/array/chunked/stats.rs b/vortex-array/src/array/chunked/stats.rs index 876f563ffd..be2bf3f02e 100644 --- a/vortex-array/src/array/chunked/stats.rs +++ b/vortex-array/src/array/chunked/stats.rs @@ -14,9 +14,10 @@ impl ArrayStatisticsCompute for ChunkedArray<'_> { s.compute(stat); s.to_set() }) - .fold(StatsSet::new(), |mut acc, x| { + .reduce(|mut acc, x| { acc.merge(&x); acc - })) + }) + .unwrap_or_else(StatsSet::new)) } } diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index 028038e860..19933b628e 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -39,13 +39,13 @@ use crate::validity::ArrayValidity; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; pub mod flatbuffers { - pub use gen_array::vortex::*; + pub use generated::vortex::array::*; #[allow(unused_imports)] #[allow(dead_code)] #[allow(non_camel_case_types)] #[allow(clippy::all)] - mod gen_array { + mod generated { include!(concat!(env!("OUT_DIR"), "/flatbuffers/array.rs")); } diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs new file mode 100644 index 0000000000..3dfee1bc3b --- /dev/null +++ b/vortex-array/src/stats/flatbuffers.rs @@ -0,0 +1,49 @@ +use flatbuffers::{FlatBufferBuilder, WIPOffset}; +use itertools::Itertools; +use vortex_flatbuffers::WriteFlatBuffer; + +use crate::stats::{Stat, Statistics}; + +impl WriteFlatBuffer for &dyn Statistics { + type Target<'t> = crate::flatbuffers::ArrayStats<'t>; + + fn write_flatbuffer<'fb>( + &self, + fbb: &mut FlatBufferBuilder<'fb>, + ) -> WIPOffset> { + let trailing_zero_freq = self + .get_as::>(Stat::TrailingZeroFreq) + .ok() + .map(|v| v.iter().copied().collect_vec()) + .map(|v| fbb.create_vector(v.as_slice())); + + let bit_width_freq = self + .get_as::>(Stat::BitWidthFreq) + .ok() + .map(|v| v.iter().copied().collect_vec()) + .map(|v| fbb.create_vector(v.as_slice())); + + let min = self + .get(Stat::Min) + .map(|min| min.value().write_flatbuffer(fbb)); + + let max = self + .get(Stat::Max) + .map(|max| max.value().write_flatbuffer(fbb)); + + let stat_args = &crate::flatbuffers::ArrayStatsArgs { + min, + max, + is_sorted: self.get_as::(Stat::IsSorted).ok(), + is_strict_sorted: self.get_as::(Stat::IsStrictSorted).ok(), + is_constant: self.get_as::(Stat::IsConstant).ok(), + run_count: self.get_as_cast::(Stat::RunCount).ok(), + true_count: self.get_as_cast::(Stat::TrueCount).ok(), + null_count: self.get_as_cast::(Stat::NullCount).ok(), + bit_width_freq, + trailing_zero_freq, + }; + + crate::flatbuffers::ArrayStats::create(fbb, stat_args) + } +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 150a150c13..59e12d691b 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -8,6 +8,7 @@ use vortex_dtype::{DType, NativePType}; use vortex_error::{vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; +pub mod flatbuffers; mod statsset; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Sequence)] @@ -93,6 +94,16 @@ impl dyn Statistics + '_ { .and_then(|s| U::try_from(&s)) } + pub fn get_as_cast TryFrom<&'a Scalar, Error = VortexError>>( + &self, + stat: Stat, + ) -> VortexResult { + self.get(stat) + .ok_or_else(|| vortex_err!(ComputeError: "statistic {} missing", stat)) + .and_then(|s| s.cast(&DType::Primitive(U::PTYPE, NonNullable))) + .and_then(|s| U::try_from(&s)) + } + pub fn compute_as TryFrom<&'a Scalar, Error = VortexError>>( &self, stat: Stat, diff --git a/vortex-array/src/view.rs b/vortex-array/src/view.rs index 6a5f7388f6..96d4fbb036 100644 --- a/vortex-array/src/view.rs +++ b/vortex-array/src/view.rs @@ -1,13 +1,16 @@ use std::fmt::{Debug, Formatter}; +use enum_iterator::all; use itertools::Itertools; +use log::warn; use vortex_buffer::Buffer; -use vortex_dtype::DType; +use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_scalar::{PValue, Scalar, ScalarValue}; use crate::encoding::{EncodingId, EncodingRef}; -use crate::flatbuffers::array as fb; -use crate::stats::{EmptyStatistics, Statistics}; +use crate::flatbuffers as fb; +use crate::stats::{Stat, Statistics, StatsSet}; use crate::Context; use crate::{Array, IntoArray, ToArray}; @@ -53,7 +56,6 @@ impl<'v> ArrayView<'v> { Self::cumulative_nbuffers(array) ) } - let view = Self { encoding, dtype, @@ -136,8 +138,83 @@ impl<'v> ArrayView<'v> { } pub fn statistics(&self) -> &dyn Statistics { - // TODO(ngates): store statistics in FlatBuffers - &EmptyStatistics + self + } +} + +impl Statistics for ArrayView<'_> { + fn get(&self, stat: Stat) -> Option { + match stat { + Stat::Max => { + let max = self.array.stats()?.max(); + max.and_then(|v| ScalarValue::try_from(v).ok()) + .map(|v| Scalar::new(self.dtype.clone(), v)) + } + Stat::Min => { + let min = self.array.stats()?.min(); + min.and_then(|v| ScalarValue::try_from(v).ok()) + .map(|v| Scalar::new(self.dtype.clone(), v)) + } + Stat::IsConstant => self.array.stats()?.is_constant().map(bool::into), + Stat::IsSorted => self.array.stats()?.is_sorted().map(bool::into), + Stat::IsStrictSorted => self.array.stats()?.is_strict_sorted().map(bool::into), + Stat::RunCount => self.array.stats()?.run_count().map(u64::into), + Stat::TrueCount => self.array.stats()?.true_count().map(u64::into), + Stat::NullCount => self.array.stats()?.null_count().map(u64::into), + Stat::BitWidthFreq => self + .array + .stats()? + .bit_width_freq() + .map(|v| { + v.iter() + .map(|v| ScalarValue::Primitive(PValue::U64(v))) + .collect_vec() + }) + .map(|v| { + Scalar::list( + DType::Primitive(vortex_dtype::PType::U64, Nullability::NonNullable), + v, + ) + }), + Stat::TrailingZeroFreq => self + .array + .stats()? + .trailing_zero_freq() + .map(|v| v.iter().collect_vec()) + .map(|v| v.into()), + } + } + + /// NB: part of the contract for to_set is that it does not do any expensive computation. + /// In other implementations, this means returning the underlying stats map, but for the flatbuffer + /// implemetation, we have 'precalculated' stats in the flatbuffer itself, so we need to + /// alllocate a stats map and populate it with those fields. + fn to_set(&self) -> StatsSet { + let mut result = StatsSet::new(); + for stat in all::() { + if let Some(value) = self.get(stat) { + result.set(stat, value) + } + } + result + } + + /// We want to avoid any sort of allocation on instantiation of the ArrayView, so we + /// do not allocate a stats_set to cache values. + fn set(&self, _stat: Stat, _value: Scalar) { + warn!("Cannot write stats to a view") + } + + fn compute(&self, stat: Stat) -> Option { + if let Some(s) = self.get(stat) { + return Some(s); + } + + self.to_array() + .with_dyn(|a| a.compute_statistics(stat)) + .ok()? + .get(stat) + .cloned() } } diff --git a/vortex-ipc/src/lib.rs b/vortex-ipc/src/lib.rs index dbc4620191..d15a0ff6dd 100644 --- a/vortex-ipc/src/lib.rs +++ b/vortex-ipc/src/lib.rs @@ -15,11 +15,17 @@ pub mod flatbuffers { mod deps { pub mod array { - pub use vortex::flatbuffers::array; + pub use vortex::flatbuffers as array; } + pub mod dtype { pub use vortex_dtype::flatbuffers as dtype; } + + pub mod scalar { + #[allow(unused_imports)] + pub use vortex_scalar::flatbuffers as scalar; + } } } diff --git a/vortex-ipc/src/messages.rs b/vortex-ipc/src/messages.rs index 46a744b5c8..80b34f333e 100644 --- a/vortex-ipc/src/messages.rs +++ b/vortex-ipc/src/messages.rs @@ -1,6 +1,6 @@ use flatbuffers::{FlatBufferBuilder, WIPOffset}; use itertools::Itertools; -use vortex::flatbuffers::array as fba; +use vortex::flatbuffers as fba; use vortex::{ArrayData, Context, ViewContext}; use vortex_dtype::DType; use vortex_error::{vortex_err, VortexError}; @@ -17,11 +17,15 @@ pub(crate) enum IPCMessage<'a> { } pub(crate) struct IPCContext<'a>(pub &'a ViewContext); + pub(crate) struct IPCSchema<'a>(pub &'a DType); + pub(crate) struct IPCChunk<'a>(pub &'a ViewContext, pub &'a ArrayData); + pub(crate) struct IPCArray<'a>(pub &'a ViewContext, pub &'a ArrayData); impl FlatBufferRoot for IPCMessage<'_> {} + impl WriteFlatBuffer for IPCMessage<'_> { type Target<'a> = fb::Message<'a>; @@ -186,6 +190,8 @@ impl<'a> WriteFlatBuffer for IPCArray<'a> { .collect_vec(); let children = Some(fbb.create_vector(&children)); + let stats = Some(self.1.statistics().write_flatbuffer(fbb)); + fba::Array::create( fbb, &fba::ArrayArgs { @@ -193,6 +199,7 @@ impl<'a> WriteFlatBuffer for IPCArray<'a> { has_buffer: column_data.buffer().is_some(), encoding, metadata, + stats, children, }, ) diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index a7502d4b22..f34e44fcab 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -278,6 +278,7 @@ impl<'iter, R: Read> FallibleLendingIterator for StreamArrayReader<'iter, R> { .unwrap() .array() .unwrap(); + let view = ArrayView::try_new(self.ctx, &self.dtype, col_array, self.buffers.as_slice())?; // Validate it @@ -384,7 +385,8 @@ mod tests { use vortex::array::chunked::{Chunked, ChunkedArray}; use vortex::array::primitive::{Primitive, PrimitiveArray, PrimitiveEncoding}; use vortex::encoding::{ArrayEncoding, EncodingId, EncodingRef}; - use vortex::{Array, ArrayDType, ArrayDef, Context, IntoArray, OwnedArray}; + use vortex::stats::{ArrayStatistics, Stat}; + use vortex::{Array, ArrayDType, ArrayDef, Context, IntoArray, OwnedArray, ToStatic}; use vortex_alp::{ALPArray, ALPEncoding}; use vortex_dtype::NativePType; use vortex_error::VortexResult; @@ -469,6 +471,46 @@ mod tests { ); } + #[test] + fn test_stats() { + let data = PrimitiveArray::from((0i32..3_000_000).collect_vec()).into_array(); + // calculate stats on the input array so that the output array will also have stats + data.statistics().compute_min::().unwrap(); + + let data = round_trip(&data); + verify_stats(&data); + + let run_count: u64 = data.statistics().get_as::(Stat::RunCount).unwrap(); + assert_eq!(run_count, 3000000); + } + + #[test] + fn test_stats_chunked() { + let array = PrimitiveArray::from((0i32..1_500_000).collect_vec()).into_array(); + let array2 = PrimitiveArray::from((1_500_000i32..3_000_000).collect_vec()).into_array(); + + // calculate stats on the input array so that the output array will also have stats + array.statistics().compute_min::().unwrap(); + array2.statistics().compute_min::().unwrap(); + let chunked_array = + ChunkedArray::try_new(vec![array.clone(), array2], array.dtype().clone()) + .unwrap() + .into_array(); + + let data = round_trip(&chunked_array); + + // NB: data is an ArrayData constructed from the result of calling read_array on an array + // reader. compute on a ChunkedArray calls get_or_compute on the underlying chunks and + // merges the results, while get does not. Thus we need to compute a stat and force this + // merge computation before we can test get() + data.statistics().compute(Stat::Min).unwrap(); + verify_stats(&data); + + // TODO(@jcasale): run_count calculation is wrong for chunked arrays, this should be 3mm + let run_count: u64 = data.statistics().get_as::(Stat::RunCount).unwrap(); + assert_eq!(run_count, 3000001); + } + #[test] fn test_empty_index() { let data = PrimitiveArray::from((0i32..3_000_000).rev().collect_vec()).into_array(); @@ -743,4 +785,77 @@ mod tests { assert!(result_iter.next().unwrap().is_none()); Ok(result.unwrap()) } + + fn verify_stats(data: &Array) { + let min: i32 = data + .statistics() + .get(Stat::Min) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + assert_eq!(min, 0); + let max: i32 = data + .statistics() + .get(Stat::Max) + .unwrap() + .as_ref() + .try_into() + .unwrap(); + assert_eq!(max, 2_999_999); + let is_sorted = data.statistics().get_as::(Stat::IsSorted).unwrap(); + assert!(is_sorted); + let is_strict_sorted: bool = data + .statistics() + .get_as::(Stat::IsStrictSorted) + .unwrap(); + assert!(is_strict_sorted); + let is_constant: bool = data.statistics().get_as::(Stat::IsConstant).unwrap(); + assert!(!is_constant); + + let null_ct: u64 = data.statistics().get_as::(Stat::NullCount).unwrap(); + assert_eq!(null_ct, 0); + let bit_width_freq = data + .statistics() + .get_as::>(Stat::BitWidthFreq) + .unwrap(); + assert_eq!( + bit_width_freq, + vec![ + 1, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, + 65536, 131072, 262144, 524288, 1048576, 902848, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ] + ); + let trailing_zero_freq = data + .statistics() + .get_as::>(Stat::TrailingZeroFreq) + .unwrap(); + assert_eq!( + trailing_zero_freq, + vec![ + 1500000, 750000, 375000, 187500, 93750, 46875, 23437, 11719, 5859, 2930, 1465, 732, + 366, 183, 92, 46, 23, 11, 6, 3, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ] + ); + data.statistics() + .compute_true_count() + .expect_err("Should not be able to calculate true count for non-boolean array"); + } + + fn round_trip<'a>(chunked_array: &'a Array<'a>) -> Array<'a> { + let context = Context::default(); + let mut buffer = vec![]; + { + let mut cursor = Cursor::new(&mut buffer); + { + let mut writer = StreamWriter::try_new(&mut cursor, &context).unwrap(); + writer.write_array(chunked_array).unwrap(); + } + } + + let mut cursor = Cursor::new(&buffer); + let mut reader = StreamReader::try_new(&mut cursor, &context).unwrap(); + let data = reader.read_array().unwrap(); + data.to_static() + } }