From 6e7125b61f2ff587a09dbe45ab05d2f28632a702 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 14 Aug 2024 10:41:38 +0900 Subject: [PATCH] GH-43454: [C++][Python] Add Opaque canonical extension type (#43458) ### Rationale for this change Add the newly ratified extension type. ### What changes are included in this PR? The C++/Python implementation only. ### Are these changes tested? Yes ### Are there any user-facing changes? No. * GitHub Issue: #43454 Lead-authored-by: David Li Co-authored-by: Weston Pace Signed-off-by: David Li --- cpp/src/arrow/CMakeLists.txt | 1 + .../compute/kernels/scalar_cast_numeric.cc | 23 ++ cpp/src/arrow/extension/CMakeLists.txt | 6 + cpp/src/arrow/extension/opaque.cc | 109 ++++++++++ cpp/src/arrow/extension/opaque.h | 69 ++++++ cpp/src/arrow/extension/opaque_test.cc | 197 ++++++++++++++++++ docs/source/python/api/arrays.rst | 3 + docs/source/python/api/datatypes.rst | 10 + python/pyarrow/__init__.py | 8 +- python/pyarrow/array.pxi | 28 +++ python/pyarrow/includes/libarrow.pxd | 13 ++ python/pyarrow/lib.pxd | 5 + python/pyarrow/public-api.pxi | 2 + python/pyarrow/scalar.pxi | 6 + python/pyarrow/tests/test_extension_type.py | 46 ++++ python/pyarrow/tests/test_misc.py | 3 + python/pyarrow/types.pxi | 101 +++++++++ 17 files changed, 627 insertions(+), 3 deletions(-) create mode 100644 cpp/src/arrow/extension/opaque.cc create mode 100644 cpp/src/arrow/extension/opaque.h create mode 100644 cpp/src/arrow/extension/opaque_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 9c66a58c54261..67d2c19f98a2d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -907,6 +907,7 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON extension/fixed_shape_tensor.cc + extension/opaque.cc json/options.cc json/chunked_builder.cc json/chunker.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 3df86e7d6936c..bd9be3e8a9532 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -865,6 +865,25 @@ std::shared_ptr GetCastToHalfFloat() { return func; } +struct NullExtensionTypeMatcher : public TypeMatcher { + ~NullExtensionTypeMatcher() override = default; + + bool Matches(const DataType& type) const override { + return type.id() == Type::EXTENSION && + checked_cast(type).storage_id() == Type::NA; + } + + std::string ToString() const override { return "extension"; } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + return casted != nullptr; + } +}; + } // namespace std::vector> GetNumericCasts() { @@ -875,6 +894,10 @@ std::vector> GetNumericCasts() { auto cast_null = std::make_shared("cast_null", Type::NA); DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), OutputAllNull)); + // Explicitly allow casting extension type with null backing array to null + DCHECK_OK(cast_null->AddKernel( + Type::EXTENSION, {InputType(std::make_shared())}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index c15c42874d4de..6741ab602f50b 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -21,4 +21,10 @@ add_arrow_test(test PREFIX "arrow-fixed-shape-tensor") +add_arrow_test(test + SOURCES + opaque_test.cc + PREFIX + "arrow-extension-opaque") + arrow_install_all_headers("arrow/extension") diff --git a/cpp/src/arrow/extension/opaque.cc b/cpp/src/arrow/extension/opaque.cc new file mode 100644 index 0000000000000..c430bb5d2eaab --- /dev/null +++ b/cpp/src/arrow/extension/opaque.cc @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/opaque.h" + +#include + +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/util/logging.h" + +#include +#include +#include + +namespace arrow::extension { + +std::string OpaqueType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() + << "[storage_type=" << storage_type_->ToString(show_metadata) + << ", type_name=" << type_name_ << ", vendor_name=" << vendor_name_ << "]>"; + return ss.str(); +} + +bool OpaqueType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& opaque = internal::checked_cast(other); + return storage_type()->Equals(*opaque.storage_type()) && + type_name() == opaque.type_name() && vendor_name() == opaque.vendor_name(); +} + +std::string OpaqueType::Serialize() const { + rapidjson::Document document; + document.SetObject(); + rapidjson::Document::AllocatorType& allocator = document.GetAllocator(); + + rapidjson::Value type_name(rapidjson::StringRef(type_name_)); + document.AddMember(rapidjson::Value("type_name", allocator), type_name, allocator); + rapidjson::Value vendor_name(rapidjson::StringRef(vendor_name_)); + document.AddMember(rapidjson::Value("vendor_name", allocator), vendor_name, allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result> OpaqueType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + rapidjson::Document document; + const auto& parsed = document.Parse(serialized_data.data(), serialized_data.length()); + if (parsed.HasParseError()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: ", + rapidjson::GetParseError_En(parsed.GetParseError()), ": ", + serialized_data); + } else if (!document.IsObject()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: not an object"); + } + if (!document.HasMember("type_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing type_name"); + } else if (!document.HasMember("vendor_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing vendor_name"); + } + + const auto& type_name = document["type_name"]; + const auto& vendor_name = document["vendor_name"]; + if (!type_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: type_name is not a string"); + } else if (!vendor_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: vendor_name is not a string"); + } + + return opaque(std::move(storage_type), type_name.GetString(), vendor_name.GetString()); +} + +std::shared_ptr OpaqueType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.opaque", + internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, std::string vendor_name) { + return std::make_shared(std::move(storage_type), std::move(type_name), + std::move(vendor_name)); +} + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque.h b/cpp/src/arrow/extension/opaque.h new file mode 100644 index 0000000000000..9814b391cbad6 --- /dev/null +++ b/cpp/src/arrow/extension/opaque.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" +#include "arrow/type.h" + +namespace arrow::extension { + +/// \brief Opaque is a placeholder for a type from an external (usually +/// non-Arrow) system that could not be interpreted. +class ARROW_EXPORT OpaqueType : public ExtensionType { + public: + /// \brief Construct an OpaqueType. + /// + /// \param[in] storage_type The underlying storage type. Should be + /// arrow::null if there is no data. + /// \param[in] type_name The name of the type in the external system. + /// \param[in] vendor_name The name of the external system. + explicit OpaqueType(std::shared_ptr storage_type, std::string type_name, + std::string vendor_name) + : ExtensionType(std::move(storage_type)), + type_name_(std::move(type_name)), + vendor_name_(std::move(vendor_name)) {} + + std::string extension_name() const override { return "arrow.opaque"; } + std::string ToString(bool show_metadata) const override; + bool ExtensionEquals(const ExtensionType& other) const override; + std::string Serialize() const override; + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + /// Create an OpaqueArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + std::string_view type_name() const { return type_name_; } + std::string_view vendor_name() const { return vendor_name_; } + + private: + std::string type_name_; + std::string vendor_name_; +}; + +/// \brief Opaque is a wrapper for (usually binary) data from an external +/// (often non-Arrow) system that could not be interpreted. +class ARROW_EXPORT OpaqueArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Return an OpaqueType instance. +ARROW_EXPORT std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, + std::string vendor_name); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque_test.cc b/cpp/src/arrow/extension/opaque_test.cc new file mode 100644 index 0000000000000..1629cdb39651c --- /dev/null +++ b/cpp/src/arrow/extension/opaque_test.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/opaque.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +TEST(OpaqueType, Basics) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + ASSERT_EQ("arrow.opaque", type->extension_name()); + ASSERT_EQ(*type, *type); + ASSERT_NE(*arrow::null(), *type); + ASSERT_NE(*type, *type2); + ASSERT_EQ(*arrow::null(), *type->storage_type()); + ASSERT_THAT(type->Serialize(), ::testing::Not(::testing::IsEmpty())); + ASSERT_EQ(R"({"type_name":"type","vendor_name":"vendor"})", type->Serialize()); + ASSERT_EQ("type", type->type_name()); + ASSERT_EQ("vendor", type->vendor_name()); + ASSERT_EQ( + "extension", + type->ToString(false)); +} + +TEST(OpaqueType, Equals) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + auto type3 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor2")); + auto type4 = internal::checked_pointer_cast( + extension::opaque(int64(), "type", "vendor")); + auto type5 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type6 = internal::checked_pointer_cast( + extension::fixed_shape_tensor(float64(), {1})); + + ASSERT_EQ(*type, *type); + ASSERT_EQ(*type2, *type2); + ASSERT_EQ(*type3, *type3); + ASSERT_EQ(*type4, *type4); + ASSERT_EQ(*type5, *type5); + + ASSERT_EQ(*type, *type5); + + ASSERT_NE(*type, *type2); + ASSERT_NE(*type, *type3); + ASSERT_NE(*type, *type4); + ASSERT_NE(*type, *type6); + + ASSERT_NE(*type2, *type); + ASSERT_NE(*type2, *type3); + ASSERT_NE(*type2, *type4); + ASSERT_NE(*type2, *type6); + + ASSERT_NE(*type3, *type); + ASSERT_NE(*type3, *type2); + ASSERT_NE(*type3, *type4); + ASSERT_NE(*type3, *type6); + + ASSERT_NE(*type4, *type); + ASSERT_NE(*type4, *type2); + ASSERT_NE(*type4, *type3); + ASSERT_NE(*type4, *type6); + ASSERT_NE(*type6, *type4); +} + +TEST(OpaqueType, CreateFromArray) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + ASSERT_EQ(2, array->length()); + ASSERT_EQ(1, array->null_count()); +} + +void CheckDeserialize(const std::string& serialized, + const std::shared_ptr& expected) { + auto type = internal::checked_pointer_cast(expected); + ASSERT_OK_AND_ASSIGN(auto deserialized, + type->Deserialize(type->storage_type(), serialized)); + ASSERT_EQ(*expected, *deserialized); +} + +TEST(OpaqueType, Deserialize) { + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "type", "vendor_name": "vendor"})", + extension::opaque(null(), "type", "vendor"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "long name", "vendor_name": "long name"})", + extension::opaque(null(), "long name", "long name"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "名前", "vendor_name": "名字"})", + extension::opaque(null(), "名前", "名字"))); + ASSERT_NO_FATAL_FAILURE(CheckDeserialize( + R"({"type_name": "type", "vendor_name": "vendor", "extra_field": 2})", + extension::opaque(null(), "type", "vendor"))); + + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("The document is empty"), + type->Deserialize(null(), R"()")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + testing::HasSubstr("Missing a name for object member"), + type->Deserialize(null(), R"({)")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("not an object"), + type->Deserialize(null(), R"([])")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("missing type_name"), + type->Deserialize(null(), R"({})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": 2, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": null, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": 2, "type_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": null, "type_name": ""})")); +} + +TEST(OpaqueType, MetadataRoundTrip) { + for (const auto& type : { + extension::opaque(null(), "foo", "bar"), + extension::opaque(binary(), "geometry", "postgis"), + extension::opaque(fixed_size_list(int64(), 4), "foo", "bar"), + extension::opaque(utf8(), "foo", "bar"), + }) { + auto opaque = internal::checked_pointer_cast(type); + std::string serialized = opaque->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + opaque->Deserialize(opaque->storage_type(), serialized)); + ASSERT_EQ(*type, *deserialized); + } +} + +TEST(OpaqueType, BatchRoundTrip) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + ExtensionTypeGuard guard(type); + + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + auto batch = + RecordBatch::Make(schema({field("field", type)}), array->length(), {array}); + + std::shared_ptr written; + { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(&written)); + } + + ASSERT_EQ(*batch->schema(), *written->schema()); + ASSERT_BATCHES_EQUAL(*batch, *written); +} + +} // namespace arrow diff --git a/docs/source/python/api/arrays.rst b/docs/source/python/api/arrays.rst index aefed00b3d2e0..4ad35b190cdd0 100644 --- a/docs/source/python/api/arrays.rst +++ b/docs/source/python/api/arrays.rst @@ -85,6 +85,7 @@ may expose data type-specific methods or properties. UnionArray ExtensionArray FixedShapeTensorArray + OpaqueArray .. _api.scalar: @@ -143,3 +144,5 @@ classes may expose data type-specific methods or properties. StructScalar UnionScalar ExtensionScalar + FixedShapeTensorScalar + OpaqueScalar diff --git a/docs/source/python/api/datatypes.rst b/docs/source/python/api/datatypes.rst index 7edb4e161541d..a43c5299eae51 100644 --- a/docs/source/python/api/datatypes.rst +++ b/docs/source/python/api/datatypes.rst @@ -67,6 +67,8 @@ These should be used to create Arrow data types and schemas. struct dictionary run_end_encoded + fixed_shape_tensor + opaque field schema from_numpy_dtype @@ -117,6 +119,14 @@ Specific classes and functions for extension types. register_extension_type unregister_extension_type +:doc:`Canonical extension types <../../format/CanonicalExtensions>` +implemented by PyArrow. + +.. autosummary:: + :toctree: ../generated/ + + FixedShapeTensorType + OpaqueType .. _api.types.checking: .. currentmodule:: pyarrow.types diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index e52e0d242bee5..aa7bab9f97e05 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -173,6 +173,7 @@ def print_entry(label, value): dictionary, run_end_encoded, fixed_shape_tensor, + opaque, field, type_for_alias, DataType, DictionaryType, StructType, @@ -182,7 +183,7 @@ def print_entry(label, value): TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, - RunEndEncodedType, FixedShapeTensorType, + RunEndEncodedType, FixedShapeTensorType, OpaqueType, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, @@ -216,7 +217,7 @@ def print_entry(label, value): Time32Array, Time64Array, DurationArray, MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, - RunEndEncodedArray, FixedShapeTensorArray, + RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, @@ -233,7 +234,8 @@ def print_entry(label, value): StringScalar, LargeStringScalar, StringViewScalar, FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, - RunEndEncodedScalar, ExtensionScalar) + RunEndEncodedScalar, ExtensionScalar, + FixedShapeTensorScalar, OpaqueScalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 997f208a5dec4..6c40a21db96ca 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4448,6 +4448,34 @@ cdef class FixedShapeTensorArray(ExtensionArray): ) +cdef class OpaqueArray(ExtensionArray): + """ + Concrete class for opaque extension arrays. + + Examples + -------- + Define the extension type for an opaque array + + >>> import pyarrow as pa + >>> opaque_type = pa.opaque( + ... pa.binary(), + ... type_name="geometry", + ... vendor_name="postgis", + ... ) + + Create an extension array + + >>> arr = [None, b"data"] + >>> storage = pa.array(arr, pa.binary()) + >>> pa.ExtensionArray.from_storage(opaque_type, storage) + + [ + null, + 64617461 + ] + """ + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0d871f411b11b..9b008d150f1f1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2882,6 +2882,19 @@ cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extens " arrow::extension::FixedShapeTensorArray"(CExtensionArray): const CResult[shared_ptr[CTensor]] ToTensor() const + +cdef extern from "arrow/extension/opaque.h" namespace "arrow::extension" nogil: + cdef cppclass COpaqueType \ + " arrow::extension::OpaqueType"(CExtensionType): + + c_string type_name() + c_string vendor_name() + + cdef cppclass COpaqueArray \ + " arrow::extension::OpaqueArray"(CExtensionArray): + pass + + cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: cdef enum CCompressionType" arrow::Compression::type": CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 082d8470cdbb0..2cb302d20a8ac 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -215,6 +215,11 @@ cdef class FixedShapeTensorType(BaseExtensionType): const CFixedShapeTensorType* tensor_ext_type +cdef class OpaqueType(BaseExtensionType): + cdef: + const COpaqueType* opaque_ext_type + + cdef class PyExtensionType(ExtensionType): pass diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 966273b4bea84..2f9fc1c554209 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -124,6 +124,8 @@ cdef api object pyarrow_wrap_data_type( return cpy_ext_type.GetInstance() elif ext_type.extension_name() == b"arrow.fixed_shape_tensor": out = FixedShapeTensorType.__new__(FixedShapeTensorType) + elif ext_type.extension_name() == b"arrow.opaque": + out = OpaqueType.__new__(OpaqueType) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 41bfde39adb6f..12a99c2aece63 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1085,6 +1085,12 @@ cdef class FixedShapeTensorScalar(ExtensionScalar): return pyarrow_wrap_tensor(ctensor) +cdef class OpaqueScalar(ExtensionScalar): + """ + Concrete class for opaque extension scalar. + """ + + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 1c4d0175a2d97..58c54189f223e 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1661,3 +1661,49 @@ def test_legacy_int_type(): batch = ipc_read_batch(buf) assert isinstance(batch.column(0).type, LegacyIntType) assert batch.column(0) == ext_arr + + +@pytest.mark.parametrize("storage_type,storage", [ + (pa.null(), [None] * 4), + (pa.int64(), [1, 2, None, 4]), + (pa.binary(), [None, b"foobar"]), + (pa.list_(pa.int64()), [[], [1, 2], None, [3, None]]), +]) +def test_opaque_type(pickle_module, storage_type, storage): + opaque_type = pa.opaque(storage_type, "type", "vendor") + assert opaque_type.extension_name == "arrow.opaque" + assert opaque_type.storage_type == storage_type + assert opaque_type.type_name == "type" + assert opaque_type.vendor_name == "vendor" + assert "arrow.opaque" in str(opaque_type) + + assert opaque_type == opaque_type + assert opaque_type != storage_type + assert opaque_type != pa.opaque(storage_type, "type2", "vendor") + assert opaque_type != pa.opaque(storage_type, "type", "vendor2") + assert opaque_type != pa.opaque(pa.decimal128(12, 3), "type", "vendor") + + # Pickle roundtrip + result = pickle_module.loads(pickle_module.dumps(opaque_type)) + assert result == opaque_type + + # IPC roundtrip + opaque_arr_class = opaque_type.__arrow_ext_class__() + storage = pa.array(storage, storage_type) + arr = pa.ExtensionArray.from_storage(opaque_type, storage) + assert isinstance(arr, opaque_arr_class) + + with registered_extension_type(opaque_type): + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) + + assert batch.column(0).type.extension_name == "arrow.opaque" + assert isinstance(batch.column(0), opaque_arr_class) + + # cast storage -> extension type + result = storage.cast(opaque_type) + assert result == arr + + # cast extension type -> storage type + inner = arr.cast(storage_type) + assert inner == storage diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index c42e4fbdfc2e8..9a55a38177fc8 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -247,6 +247,9 @@ def test_set_timezone_db_path_non_windows(): pa.ProxyMemoryPool, pa.Device, pa.MemoryManager, + pa.OpaqueArray, + pa.OpaqueScalar, + pa.OpaqueType, ]) def test_extension_type_constructor_errors(klass): # ARROW-2638: prevent calling extension class constructors directly diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 039870accddcb..93d68fb847890 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1837,6 +1837,50 @@ cdef class FixedShapeTensorType(BaseExtensionType): return FixedShapeTensorScalar +cdef class OpaqueType(BaseExtensionType): + """ + Concrete class for opaque extension type. + + Opaque is a placeholder for a type from an external (often non-Arrow) + system that could not be interpreted. + + Examples + -------- + Create an instance of opaque extension type: + + >>> import pyarrow as pa + >>> pa.opaque(pa.int32(), "geometry", "postgis") + OpaqueType(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.opaque_ext_type = type.get() + + @property + def type_name(self): + """ + The name of the type in the external system. + """ + return frombytes(c_string(self.opaque_ext_type.type_name())) + + @property + def vendor_name(self): + """ + The name of the external system. + """ + return frombytes(c_string(self.opaque_ext_type.vendor_name())) + + def __arrow_ext_class__(self): + return OpaqueArray + + def __reduce__(self): + return opaque, (self.storage_type, self.type_name, self.vendor_name) + + def __arrow_ext_scalar_class__(self): + return OpaqueScalar + + _py_extension_type_auto_load = False @@ -5234,6 +5278,63 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N return out +def opaque(DataType storage_type, str type_name not None, str vendor_name not None): + """ + Create instance of opaque extension type. + + Parameters + ---------- + storage_type : DataType + The underlying data type. + type_name : str + The name of the type in the external system. + vendor_name : str + The name of the external system. + + Examples + -------- + Create an instance of an opaque extension type: + + >>> import pyarrow as pa + >>> type = pa.opaque(pa.binary(), "other", "jdbc") + >>> type + OpaqueType(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(binary) + >>> type.type_name + 'other' + >>> type.vendor_name + 'jdbc' + + Create a table with an opaque array: + + >>> arr = [None, b"foobar"] + >>> storage = pa.array(arr, pa.binary()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[null,666F6F626172]] + + Returns + ------- + type : OpaqueType + """ + + cdef: + c_string c_type_name = tobytes(type_name) + c_string c_vendor_name = tobytes(vendor_name) + shared_ptr[CDataType] c_type = make_shared[COpaqueType]( + storage_type.sp_type, c_type_name, c_vendor_name) + OpaqueType out = OpaqueType.__new__(OpaqueType) + out.init(c_type) + return out + + cdef dict _type_aliases = { 'null': null, 'bool': bool_,