From 9e61a46ce812357e10fa08d60b1efd3fba57641f Mon Sep 17 00:00:00 2001
From: LinGeLin <1057445597@qq.com>
Date: Wed, 25 May 2022 21:27:36 +0800
Subject: [PATCH 01/13] Added ArrowS3Dataset

---
 tensorflow_io/arrow.py                        |   3 +
 .../core/kernels/arrow/arrow_dataset_ops.cc   | 198 ++++++++++++++++++
 .../core/kernels/arrow/arrow_util.cc          |  75 ++++++-
 tensorflow_io/core/kernels/arrow/arrow_util.h |  18 ++
 tensorflow_io/core/ops/arrow_ops.cc           |  28 +++
 tensorflow_io/python/ops/arrow_dataset_ops.py |  81 +++++++
 third_party/arrow.BUILD                       |  11 +
 third_party/aws-sdk-cpp.BUILD                 |  54 +++++
 8 files changed, 467 insertions(+), 1 deletion(-)

diff --git a/tensorflow_io/arrow.py b/tensorflow_io/arrow.py
index 44de3253c..e6265af2a 100644
--- a/tensorflow_io/arrow.py
+++ b/tensorflow_io/arrow.py
@@ -17,6 +17,7 @@
 @@ArrowDataset
 @@ArrowFeatherDataset
 @@ArrowStreamDataset
+@@ArrowS3Dataset
 @@list_feather_columns
 """
 
@@ -26,6 +27,7 @@
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowDataset
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowFeatherDataset
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowStreamDataset
+from tensorflow_io.python.ops.arrow_dataset_ops import ArrowS3Dataset
 from tensorflow_io.python.ops.arrow_dataset_ops import list_feather_columns
 
 
@@ -33,6 +35,7 @@
     "ArrowDataset",
     "ArrowFeatherDataset",
     "ArrowStreamDataset",
+    "ArrowS3Dataset",
     "list_feather_columns",
 ]
 
diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc
index 7716391a9..4efb93bf2 100644
--- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc
+++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc
@@ -937,6 +937,201 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
   };
 };
 
+class ArrowS3DatasetOp : public ArrowOpKernelBase {
+ public:
+  explicit ArrowS3DatasetOp(OpKernelConstruction* ctx)
+      : ArrowOpKernelBase(ctx) {}
+
+  virtual void MakeArrowDataset(
+      OpKernelContext* ctx, const std::vector<int32>& columns,
+      const int64 batch_size, const ArrowBatchMode batch_mode,
+      const DataTypeVector& output_types,
+      const std::vector<PartialTensorShape>& output_shapes,
+      ArrowDatasetBase** output) override {
+    tstring aws_access_key;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "aws_access_key",
+                                                     &aws_access_key));
+
+    tstring aws_secret_key;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "aws_secret_key",
+                                                     &aws_secret_key));
+
+    tstring aws_endpoint_override;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<tstring>(ctx, "aws_endpoint_override",
+                                                &aws_endpoint_override));
+
+    const Tensor* parquet_files_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("parquet_files", &parquet_files_tensor));
+    OP_REQUIRES(
+        ctx, parquet_files_tensor->dims() <= 1,
+        errors::InvalidArgument("`parquet_files` must be a scalar or vector."));
+    std::vector<string> parquet_files;
+    parquet_files.reserve(parquet_files_tensor->NumElements());
+    for (int i = 0; i < parquet_files_tensor->NumElements(); ++i) {
+      parquet_files.push_back(parquet_files_tensor->flat<tstring>()(i));
+    }
+
+    const Tensor* column_names_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("column_names", &column_names_tensor));
+    OP_REQUIRES(
+        ctx, column_names_tensor->dims() <= 1,
+        errors::InvalidArgument("`column_names` must be a scalar or vector."));
+    std::vector<string> column_names;
+    column_names.reserve(column_names_tensor->NumElements());
+    for (int i = 0; i < column_names_tensor->NumElements(); ++i) {
+      column_names.push_back(column_names_tensor->flat<tstring>()(i));
+    }
+
+    std::vector<int32> column_cols;
+    auto s3Dataset = ArrowUtil::GetS3Dataset(
+        aws_access_key,
aws_secret_key, aws_endpoint_override, parquet_files); + auto status = ArrowUtil::GetColumns(s3Dataset, column_names, column_cols); + + int64 offset; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "offset", &offset)); + + int64 max_rows; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_rows", &max_rows)); + + *output = + new Dataset(ctx, aws_access_key, aws_secret_key, aws_endpoint_override, + parquet_files, column_names, offset, max_rows, column_cols, + batch_size, batch_mode, output_types_, output_shapes_); + } + + private: + class Dataset : public ArrowDatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::string& aws_access_key, + const std::string& aws_secret_key, + const std::string& aws_endpoint_override, + const std::vector& parquet_files, + const std::vector& column_names, int64 offset, + int64 max_rows, const std::vector columns, + const int64 batch_size, const ArrowBatchMode batch_mode, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : ArrowDatasetBase(ctx, columns, batch_size, batch_mode, output_types, + output_shapes), + aws_access_key_(aws_access_key), + aws_secret_key_(aws_secret_key), + aws_endpoint_override_(aws_endpoint_override), + parquet_files_(parquet_files), + column_names_(column_names), + offset_(offset), + max_rows_(max_rows) {} + + string DebugString() const override { return "ArrowS3DatasetOp::Dataset"; } + Status InputDatasets(std::vector* inputs) const { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* aws_access_key = nullptr; + tstring access_key = aws_access_key_; + TF_RETURN_IF_ERROR(b->AddScalar(access_key, &aws_access_key)); + Node* aws_secret_key = nullptr; + tstring secret_key = aws_secret_key_; + TF_RETURN_IF_ERROR(b->AddScalar(secret_key, &aws_secret_key)); + Node* aws_endpoint_override = nullptr; + tstring endpoint_override = aws_endpoint_override_; + TF_RETURN_IF_ERROR( + b->AddScalar(endpoint_override, &aws_endpoint_override)); + Node* parquet_files = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(parquet_files_, &parquet_files)); + Node* column_names = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(column_names_, &column_names)); + Node* offset = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(offset_, &offset)); + Node* max_rows = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(max_rows_, &max_rows)); + Node* columns = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns)); + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + Node* batch_mode = nullptr; + tstring batch_mode_str; + TF_RETURN_IF_ERROR(GetBatchModeStr(batch_mode_, &batch_mode_str)); + TF_RETURN_IF_ERROR(b->AddScalar(batch_mode_str, &batch_mode)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {aws_access_key, aws_secret_key, aws_endpoint_override, parquet_files, + column_names, offset, max_rows, columns, batch_size, batch_mode}, + output)); + return Status::OK(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ArrowS3")})); + } + + private: + class Iterator : public ArrowBaseIterator { + public: + explicit Iterator(const Params& params) + : ArrowBaseIterator(params) {} + + private: + Status SetupStreamsLocked(Env* env) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + auto parquet_dataset = 
ArrowUtil::GetS3Dataset( + dataset()->aws_access_key_, dataset()->aws_secret_key_, + dataset()->aws_endpoint_override_, dataset()->parquet_files_); + auto scanner_builder = parquet_dataset->NewScan().ValueOrDie(); + using arrow::compute::and_; + using arrow::compute::field_ref; + using arrow::compute::greater_equal; + using arrow::compute::literal; + arrow::compute::Expression filter_rowstokeep = and_( + {greater_equal(field_ref(arrow::FieldRef(K_ROW_INDEX_COLUMN_NAME)), + literal(dataset()->offset_)), + less(field_ref(arrow::FieldRef(K_ROW_INDEX_COLUMN_NAME)), + literal(dataset()->offset_ + dataset()->max_rows_))}); + + scanner_builder->Project(dataset()->column_names_); + if (dataset()->max_rows_ > 0) { + scanner_builder->Filter(filter_rowstokeep); + } + scanner_builder->BatchSize(K_DEFAULT_BATCH_SIZE); + auto scanner = scanner_builder->Finish().ValueOrDie(); + reader_ = scanner->ToRecordBatchReader().ValueOrDie(); + CHECK_ARROW(reader_->ReadNext(¤t_batch_)); + return Status::OK(); + } + + Status NextStreamLocked(Env* env) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + ArrowBaseIterator::NextStreamLocked(env); + CHECK_ARROW(reader_->ReadNext(¤t_batch_)); + return Status::OK(); + } + + void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + ArrowBaseIterator::ResetStreamsLocked(); + reader_.reset(); + } + + size_t current_endpoint_idx_ TF_GUARDED_BY(mu_) = 0; + std::shared_ptr reader_ TF_GUARDED_BY(mu_); + }; + + const std::string aws_access_key_; + const std::string aws_secret_key_; + const std::string aws_endpoint_override_; + const std::vector parquet_files_; + const std::vector column_names_; + const int64 offset_; + const int64 max_rows_; + }; +}; // class ArrowS3DatasetOp + REGISTER_KERNEL_BUILDER(Name("IO>ArrowZeroCopyDataset").Device(DEVICE_CPU), ArrowZeroCopyDatasetOp); @@ -949,5 +1144,8 @@ REGISTER_KERNEL_BUILDER(Name("IO>ArrowFeatherDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("IO>ArrowStreamDataset").Device(DEVICE_CPU), ArrowStreamDatasetOp); +REGISTER_KERNEL_BUILDER(Name("IO>ArrowS3Dataset").Device(DEVICE_CPU), + ArrowS3DatasetOp); + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.cc b/tensorflow_io/core/kernels/arrow/arrow_util.cc index b5d500883..d6ba18f59 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_util.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - #include "tensorflow_io/core/kernels/arrow/arrow_util.h" +#include + #include "arrow/adapters/tensorflow/convert.h" #include "arrow/api.h" #include "arrow/ipc/api.h" @@ -470,6 +471,78 @@ Status ParseHost(std::string host, std::string* host_address, return Status::OK(); } +std::shared_ptr GetS3Dataset( + const std::string& access_key, const std::string& secret_key, + const std::string& endpoint_override, + const std::vector& parquet_files) { + afs::EnsureS3Initialized(); + + afs::S3Options s3Options = + afs::S3Options::FromAccessKey(access_key, secret_key); + s3Options.endpoint_override = endpoint_override; + s3Options.scheme = "http"; + + std::shared_ptr s3fs = + afs::S3FileSystem::Make(s3Options).ValueOrDie(); + auto format = std::make_shared(); + ads::FileSystemFactoryOptions options; + + auto factory = + ads::FileSystemDatasetFactory::Make(s3fs, parquet_files, format, options) + .ValueOrDie(); + return factory->Finish().ValueOrDie(); +} + +Status GetColumns(const std::shared_ptr& dataset, + const std::vector& column_names, + std::vector& column_cols) { + std::vector v_cols; + std::list l_cols; + int column_count = column_names.size(); + column_cols.reserve(column_count); + v_cols.reserve(column_count); + auto schema = dataset->schema(); + for (const auto& name : column_names) { + int index = schema->GetFieldIndex(name); + if (index != -1) { + v_cols.push_back(index); + + auto iter = l_cols.begin(); + if (l_cols.empty()) { + l_cols.insert(iter, index); + } else { + while (iter != l_cols.end()) { + if (*iter > index) { + l_cols.insert(iter, index); + break; + } + iter++; + } + if (iter == l_cols.end()) { + l_cols.insert(iter, index); + } + } + } else { + return errors::InvalidArgument("Column name: " + name + + " does not exist"); + } + } + + for (const auto& v_index : v_cols) { + int32 index = 0; + for (const auto& l_index : l_cols) { + if (v_index == l_index) { + column_cols.push_back(index); + break; + } else { + index++; + } + } + } + + return Status::OK(); +} + } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.h b/tensorflow_io/core/kernels/arrow/arrow_util.h index 9dabc16c5..3b5859984 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.h +++ b/tensorflow_io/core/kernels/arrow/arrow_util.h @@ -17,12 +17,18 @@ limitations under the License. 
 #define TENSORFLOW_IO_CORE_KERNELS_ARROW_UTIL_H_
 
 #include "arrow/api.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_parquet.h"
+#include "arrow/filesystem/s3fs.h"
 #include "arrow/ipc/api.h"
 #include "arrow/util/io_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
+namespace afs = arrow::fs;
+namespace ads = arrow::dataset;
+
 namespace tensorflow {
 
 // Forward declaration
@@ -39,6 +45,9 @@ namespace data {
     }                            \
   } while (false)
 
+const std::string K_ROW_INDEX_COLUMN_NAME = "row_index";
+#define K_DEFAULT_BATCH_SIZE 100 * 1024
+
 namespace ArrowUtil {
 
 // Convert Arrow Data Type to TensorFlow
@@ -80,6 +89,15 @@ Status ParseEndpoint(std::string endpoint, std::string* endpoint_type,
 Status ParseHost(std::string host, std::string* host_address,
                  std::string* host_port);
 
+std::shared_ptr<ads::Dataset> GetS3Dataset(
+    const std::string& access_key, const std::string& secret_key,
+    const std::string& endpoint_override,
+    const std::vector<std::string>& parquet_files);
+
+Status GetColumns(const std::shared_ptr<ads::Dataset>& dataset,
+                  const std::vector<std::string>& column_names,
+                  std::vector<int32>& column_cols);
+
 }  // namespace ArrowUtil
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow_io/core/ops/arrow_ops.cc b/tensorflow_io/core/ops/arrow_ops.cc
index ceff0a7c7..cad90ce55 100644
--- a/tensorflow_io/core/ops/arrow_ops.cc
+++ b/tensorflow_io/core/ops/arrow_ops.cc
@@ -188,4 +188,32 @@ REGISTER_OP("IO>ArrowReadableRead")
       return Status::OK();
     });
 
+REGISTER_OP("IO>ArrowS3Dataset")
+    .Input("aws_access_key: string")
+    .Input("aws_secret_key: string")
+    .Input("aws_endpoint_override: string")
+    .Input("parquet_files: string")
+    .Input("column_names: string")
+    .Input("offset: int64")
+    .Input("max_rows: int64")
+    .Input("columns: int32")
+    .Input("batch_size: int64")
+    .Input("batch_mode: string")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset from s3 parqeut files
+
+aws_access_key: S3 access key.
+aws_secret_key: S3 secret_key.
+aws_endpoint_override: S3 endpoint override
+parquet_files: One or more parqeut file path on s3
+column_names: Select columns to read by names
+offset: Set the offset start to read
+max_rows: Max rows count to read
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow_io/python/ops/arrow_dataset_ops.py b/tensorflow_io/python/ops/arrow_dataset_ops.py
index e051c0b75..bd93b6a7d 100644
--- a/tensorflow_io/python/ops/arrow_dataset_ops.py
+++ b/tensorflow_io/python/ops/arrow_dataset_ops.py
@@ -651,7 +651,88 @@ def gen_record_batches():
         )
 
 
+class ArrowS3Dataset(ArrowBaseDataset):
+    """An Arrow Dataset for reading record batches from an input stream.
+    Currently supported input streams are a socket client or stdin.
+    """
+
+    def __init__(
+        self,
+        aws_access_key,
+        aws_secret_key,
+        aws_endpoint_override,
+        parquet_files,
+        column_names,
+        columns,
+        output_types,
+        offset=0,
+        max_rows=0,
+        output_shapes=None,
+        batch_size=None,
+        batch_mode="keep_remainder",
+    ):
+        """Create an ArrowDataset from an input stream.
+
+        Args:
+            aws_access_key: S3 access key
+            aws_secret_key: S3 secret key
+            aws_endpoint_override: S3 endpoint override
+            parquet_files: A list of parquet files path on s3
+            column_names: A list of column names to be used in the dataset
+            offset: Set the offset start to read
+            max_rows: Max rows count to read
+            columns: A list of column indices to be used in the Dataset
+            output_types: Tensor dtypes of the output tensors
+            output_shapes: TensorShapes of the output tensors or None to
+                infer partial
+            batch_size: Batch size of output tensors, setting a batch size here
+                will create batched tensors from Arrow memory and can be more
+                efficient than using tf.data.Dataset.batch().
+                NOTE: batch_size does not need to be set if batch_mode='auto'
+            batch_mode: Mode of batching, supported strings:
+                "keep_remainder" (default, keeps partial batch data),
+                "drop_remainder" (discard partial batch data),
+                "auto" (size to number of records in Arrow record batch)
+        """
+        aws_access_key = tf.convert_to_tensor(
+            aws_access_key, dtype=dtypes.string, name="aws_access_key"
+        )
+        aws_secret_key = tf.convert_to_tensor(
+            aws_secret_key, dtype=dtypes.string, name="aws_secret_key"
+        )
+        aws_endpoint_override = tf.convert_to_tensor(
+            aws_endpoint_override, dtype=dtypes.string, name="aws_endpoint_override"
+        )
+        parquet_files = tf.convert_to_tensor(
+            parquet_files, dtype=dtypes.string, name="parquet_files"
+        )
+        column_names = tf.convert_to_tensor(
+            column_names, dtype=dtypes.string, name="column_names"
+        )
+        offset = tf.convert_to_tensor(offset, dtype=dtypes.int64, name="offset")
+        max_rows = tf.convert_to_tensor(max_rows, dtype=dtypes.int64, name="max_rows")
+
+        super().__init__(
+            partial(
+                core_ops.io_arrow_s3_dataset,
+                aws_access_key,
+                aws_secret_key,
+                aws_endpoint_override,
+                parquet_files,
+                column_names,
+                offset,
+                max_rows,
+            ),
+            columns,
+            output_types,
+            output_shapes,
+            batch_size,
+            batch_mode,
+        )
+
+
 def list_feather_columns(filename, **kwargs):
+    """list_feather_columns"""
     if not tf.executing_eagerly():
         raise NotImplementedError("list_feather_columns only support eager mode")
 
diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD
index 4dbce8ede..5cc04d1c6 100644
--- a/third_party/arrow.BUILD
+++ b/third_party/arrow.BUILD
@@ -62,10 +62,18 @@ cc_library(
             "cpp/src/arrow/json/*.cc",
             "cpp/src/arrow/tensor/*.cc",
             "cpp/src/arrow/util/*.cc",
+            "cpp/src/arrow/dataset/*.cc",
+            "cpp/src/arrow/filesystem/*.cc",
+            "cpp/src/arrow/compute/*.cc",
+            "cpp/src/arrow/compute/kernels/*.cc",
+            "cpp/src/arrow/compute/exec/*.cc",
            "cpp/src/arrow/vendored/musl/strptime.c",
            "cpp/src/arrow/vendored/optional.hpp",
            "cpp/src/arrow/vendored/string_view.hpp",
            "cpp/src/arrow/vendored/variant.hpp",
+            "cpp/src/arrow/vendored/base64.cpp",
+            "cpp/src/arrow/vendored/datetime/tz.cpp",
+            "cpp/src/arrow/vendored/uriparser/*.c",
            "cpp/src/arrow/**/*.h",
            "cpp/src/parquet/**/*.h",
            "cpp/src/parquet/**/*.cc",
@@ -80,6 +88,7 @@ cc_library(
            "cpp/src/**/test_*.cc",
            "cpp/src/**/*hdfs*.cc",
            "cpp/src/**/*fuzz*.cc",
+            "cpp/src/**/*gcsfs*.cc",
            "cpp/src/**/file_to_stream.cc",
            "cpp/src/**/stream_to_file.cc",
            "cpp/src/arrow/util/bpacking_avx2.cc",
@@ -116,6 +125,8 @@ cc_library(
     ],
     deps = [
         ":arrow_format",
+        "@aws-sdk-cpp//:identity-management",
+        "@aws-sdk-cpp//:s3",
         "@boringssl//:crypto",
         "@brotli",
         "@bzip2",
diff --git a/third_party/aws-sdk-cpp.BUILD b/third_party/aws-sdk-cpp.BUILD
index ba7d90bcb..a1f1da060 100644
--- a/third_party/aws-sdk-cpp.BUILD
+++ b/third_party/aws-sdk-cpp.BUILD
@@ -163,6 +163,60 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "cognito-identity",
+    srcs = glob([
+        "aws-cpp-sdk-cognito-identity/source/*.cpp",
+        "aws-cpp-sdk-cognito-identity/source/model/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/*.h",
+        "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/model/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-cognito-identity/include",
+    ],
+    deps = [
+        ":core",
+    ],
+)
+
+cc_library(
+    name = "sts",
+    srcs = glob([
+        "aws-cpp-sdk-sts/source/*.cpp",
+        "aws-cpp-sdk-sts/source/model/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-sts/include/aws/sts/*.h",
+        "aws-cpp-sdk-sts/include/aws/sts/model/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-sts/include",
+    ],
+    deps = [
+        ":core",
+    ],
+)
+
+cc_library(
+    name = "identity-management",
+    srcs = glob([
+        "aws-cpp-sdk-identity-management/source/auth/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-identity-management/include/aws/identity-management/auth/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-identity-management/include",
+    ],
+    deps = [
+        ":cognito-identity",
+        ":core",
+        ":sts",
+    ],
+)
+
 genrule(
     name = "SDKConfig_h",
     outs = [

From c9b68ede701285edd1225f2156d5a28cd7553a06 Mon Sep 17 00:00:00 2001
From: LinGeLin <1057445597@qq.com>
Date: Tue, 7 Jun 2022 15:21:53 +0800
Subject: [PATCH 02/13] set protobuf to 3.20.0 to fix lint check error

---
 .github/workflows/build.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e93db6ef9..f27705d36 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -31,6 +31,7 @@ jobs:
         sudo python3 --version
         sudo python3 -m pip install dataclasses
         sudo python3 -m pip install setuptools
+        sudo python3 -m pip install --upgrade protobuf==3.20.0
         sudo python3 -m pip install -U git+https://github.com/tensorflow/docs
         find docs -name '*.ipynb' | xargs python3 -m tensorflow_docs.tools.nbfmt
         echo "Check for failed fmt: "

From 7ff9a44d915305e609b1b4a145d3071b3e264aaa Mon Sep 17 00:00:00 2001
From: LinGeLin <1057445597@qq.com>
Date: Thu, 23 Jun 2022 16:57:15 +0800
Subject: [PATCH 03/13] Modified read s3 parquet method for ArrowS3Dataset

---
 WORKSPACE                                     |   9 +-
 .../core/kernels/arrow/arrow_dataset_ops.cc   | 183 ++++++++++++------
 .../core/kernels/arrow/arrow_util.cc          |  72 -------
 tensorflow_io/core/kernels/arrow/arrow_util.h |  12 --
 tensorflow_io/core/ops/arrow_ops.cc           |   4 -
 tensorflow_io/python/ops/arrow_dataset_ops.py |   8 -
 6 files changed, 133 insertions(+), 155 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index 5eb4e72dc..f567a65cf 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -158,11 +158,12 @@ http_archive(
 http_archive(
     name = "arrow",
     build_file = "//third_party:arrow.BUILD",
-    sha256 = "57e13c62f27b710e1de54fd30faed612aefa22aa41fa2c0c3bacd204dd18a8f3",
-    strip_prefix = "arrow-apache-arrow-7.0.0",
+    patch_cmds = ["""sed -i.bak '24i\\'$'\\n#undef ARROW_WITH_OPENTELEMETRY\\n' cpp/src/arrow/util/tracing_internal.h"""],
+    sha256 = "19ece12de48e51ce4287d2dee00dc358fbc5ff02f41629d16076f77b8579e272",
+    strip_prefix = "arrow-apache-arrow-8.0.0",
     urls = [
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz",
-        "https://github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz",
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz",
+        "https://github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz",
     ],
 )
 
diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc
b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index 4efb93bf2..041882378 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "arrow/api.h" #include "arrow/io/stdio.h" #include "arrow/ipc/api.h" #include "arrow/result.h" +#include "parquet/arrow/reader.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/public/version.h" #include "tensorflow_io/core/kernels/arrow/arrow_kernels.h" #include "tensorflow_io/core/kernels/arrow/arrow_stream_client.h" @@ -101,7 +105,6 @@ class ArrowDatasetBase : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - // If in initial state, setup and read first batch if (current_batch_ == nullptr && current_row_idx_ == 0) { TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); @@ -163,6 +166,7 @@ class ArrowDatasetBase : public DatasetBase { } // Assign Tensors for each column in the current row + result_tensors->reserve(this->dataset()->columns_.size()); for (size_t i = 0; i < this->dataset()->columns_.size(); ++i) { int32 col = this->dataset()->columns_[i]; DataType output_type = this->dataset()->output_types_[i]; @@ -177,7 +181,6 @@ class ArrowDatasetBase : public DatasetBase { Tensor tensor(ctx->allocator({}), output_type, output_shape); TF_RETURN_IF_ERROR( ArrowUtil::AssignTensor(arr, current_row_idx_, &tensor)); - result_tensors->emplace_back(std::move(tensor)); } @@ -757,8 +760,6 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase { std::shared_ptr<::arrow::Table> table; CHECK_ARROW(reader->Read(&table)); - int64_t num_columns = table->num_columns(); - // Convert the table to a sequence of batches arrow::TableBatchReader tr(*table.get()); std::shared_ptr batch; @@ -983,21 +984,13 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { column_names.push_back(column_names_tensor->flat()(i)); } - std::vector column_cols; - auto s3Dataset = ArrowUtil::GetS3Dataset( - aws_access_key, aws_secret_key, aws_endpoint_override, parquet_files); - auto status = ArrowUtil::GetColumns(s3Dataset, column_names, column_cols); - - int64 offset; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "offset", &offset)); - - int64 max_rows; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_rows", &max_rows)); + std::vector column_cols(column_names.size()); + std::iota(column_cols.begin(), column_cols.end(), 0); *output = new Dataset(ctx, aws_access_key, aws_secret_key, aws_endpoint_override, - parquet_files, column_names, offset, max_rows, column_cols, - batch_size, batch_mode, output_types_, output_shapes_); + parquet_files, column_names, column_cols, batch_size, + batch_mode, output_types_, output_shapes_); } private: @@ -1007,10 +1000,9 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { const std::string& aws_secret_key, const std::string& aws_endpoint_override, const std::vector& parquet_files, - const std::vector& column_names, int64 offset, - int64 max_rows, const std::vector columns, - const int64 batch_size, const ArrowBatchMode batch_mode, - const DataTypeVector& output_types, + const std::vector& column_names, + const std::vector columns, const int64 batch_size, + const ArrowBatchMode batch_mode, const 
DataTypeVector& output_types, const std::vector& output_shapes) : ArrowDatasetBase(ctx, columns, batch_size, batch_mode, output_types, output_shapes), @@ -1018,9 +1010,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { aws_secret_key_(aws_secret_key), aws_endpoint_override_(aws_endpoint_override), parquet_files_(parquet_files), - column_names_(column_names), - offset_(offset), - max_rows_(max_rows) {} + column_names_(column_names) {} string DebugString() const override { return "ArrowS3DatasetOp::Dataset"; } Status InputDatasets(std::vector* inputs) const { @@ -1046,10 +1036,6 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { TF_RETURN_IF_ERROR(b->AddVector(parquet_files_, &parquet_files)); Node* column_names = nullptr; TF_RETURN_IF_ERROR(b->AddVector(column_names_, &column_names)); - Node* offset = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(offset_, &offset)); - Node* max_rows = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(max_rows_, &max_rows)); Node* columns = nullptr; TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns)); Node* batch_size = nullptr; @@ -1061,7 +1047,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { TF_RETURN_IF_ERROR(b->AddDataset( this, {aws_access_key, aws_secret_key, aws_endpoint_override, parquet_files, - column_names, offset, max_rows, columns, batch_size, batch_mode}, + column_names, columns, batch_size, batch_mode}, output)); return Status::OK(); } @@ -1081,45 +1067,134 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { private: Status SetupStreamsLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { - auto parquet_dataset = ArrowUtil::GetS3Dataset( - dataset()->aws_access_key_, dataset()->aws_secret_key_, - dataset()->aws_endpoint_override_, dataset()->parquet_files_); - auto scanner_builder = parquet_dataset->NewScan().ValueOrDie(); - using arrow::compute::and_; - using arrow::compute::field_ref; - using arrow::compute::greater_equal; - using arrow::compute::literal; - arrow::compute::Expression filter_rowstokeep = and_( - {greater_equal(field_ref(arrow::FieldRef(K_ROW_INDEX_COLUMN_NAME)), - literal(dataset()->offset_)), - less(field_ref(arrow::FieldRef(K_ROW_INDEX_COLUMN_NAME)), - literal(dataset()->offset_ + dataset()->max_rows_))}); - - scanner_builder->Project(dataset()->column_names_); - if (dataset()->max_rows_ > 0) { - scanner_builder->Filter(filter_rowstokeep); + if (!s3fs_) { + arrow::fs::EnsureS3Initialized(); + auto s3Options = arrow::fs::S3Options::FromAccessKey( + dataset()->aws_access_key_, dataset()->aws_secret_key_); + s3Options.endpoint_override = dataset()->aws_endpoint_override_; + s3fs_ = arrow::fs::S3FileSystem::Make(s3Options).ValueOrDie(); + } + ReadFile(current_file_idx_); + if (!background_worker_) { + background_worker_ = + std::make_shared(env, "download_next_workder"); + } + + if (current_batch_idx_ < record_batches_.size()) { + current_batch_ = record_batches_[current_batch_idx_]; + } + + if (current_file_idx_ + 1 < dataset()->parquet_files_.size()) { + background_worker_->Schedule(std::bind(&Iterator::ReadFile, this, + current_file_idx_ + 1, true)); } - scanner_builder->BatchSize(K_DEFAULT_BATCH_SIZE); - auto scanner = scanner_builder->Finish().ValueOrDie(); - reader_ = scanner->ToRecordBatchReader().ValueOrDie(); - CHECK_ARROW(reader_->ReadNext(¤t_batch_)); return Status::OK(); } Status NextStreamLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { ArrowBaseIterator::NextStreamLocked(env); - CHECK_ARROW(reader_->ReadNext(¤t_batch_)); + if (++current_batch_idx_ < record_batches_.size()) { + 
current_batch_ = record_batches_[current_batch_idx_]; + } else if (++current_file_idx_ < dataset()->parquet_files_.size()) { + current_batch_idx_ = 0; + { + mutex_lock lk(cv_mu_); + while (!background_thread_finished_) { + cv_.wait(lk); + } + } + + record_batches_.swap(next_record_batches_); + if (!record_batches_.empty()) { + current_batch_ = record_batches_[current_batch_idx_]; + } else { + current_batch_ = nullptr; + } + background_thread_finished_ = false; + if (current_file_idx_ + 1 < dataset()->parquet_files_.size()) { + background_worker_->Schedule(std::bind( + &Iterator::ReadFile, this, current_file_idx_ + 1, true)); + } + } return Status::OK(); } void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { ArrowBaseIterator::ResetStreamsLocked(); - reader_.reset(); + current_file_idx_ = 0; + current_batch_idx_ = 0; + record_batches_.clear(); + next_record_batches_.clear(); } - size_t current_endpoint_idx_ TF_GUARDED_BY(mu_) = 0; - std::shared_ptr reader_ TF_GUARDED_BY(mu_); + Status ReadFile(int file_index, bool background = false) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto access_file = + s3fs_->OpenInputFile(dataset()->parquet_files_[file_index]) + .ValueOrDie(); + + parquet::ArrowReaderProperties properties; + properties.set_use_threads(true); + properties.set_pre_buffer(true); + parquet::ReaderProperties parquet_properties = + parquet::default_reader_properties(); + + std::shared_ptr builder = + std::make_shared(); + builder->Open(access_file, parquet_properties); + + std::unique_ptr reader; + builder->properties(properties)->Build(&reader); + + if (column_indices_.empty()) { + std::shared_ptr schema; + reader->GetSchema(&schema); + for (const auto& name : dataset()->column_names_) { + column_indices_.push_back(schema->GetFieldIndex(name)); + } + } + // Read file columns and build a table + std::shared_ptr<::arrow::Table> table; + CHECK_ARROW(reader->ReadTable(column_indices_, &table)); + + // Convert the table to a sequence of batches + arrow::TableBatchReader tr(*table.get()); + std::shared_ptr batch; + CHECK_ARROW(tr.ReadNext(&batch)); + TF_RETURN_IF_ERROR(CheckBatchColumnTypes(batch)); + next_record_batches_.clear(); + while (batch != nullptr) { + if (!background) { + record_batches_.emplace_back(batch); + } else { + next_record_batches_.emplace_back(batch); + } + CHECK_ARROW(tr.ReadNext(&batch)); + } + + if (background) { + mutex_lock lk(cv_mu_); + background_thread_finished_ = true; + cv_.notify_all(); + } + + return Status::OK(); + } + + size_t current_file_idx_ TF_GUARDED_BY(mu_) = 0; + size_t current_batch_idx_ TF_GUARDED_BY(mu_) = 0; + std::vector> record_batches_ + TF_GUARDED_BY(mu_); + std::vector> next_record_batches_ + TF_GUARDED_BY(mu_); + std::shared_ptr s3fs_ TF_GUARDED_BY(mu_) = + nullptr; + std::vector column_indices_ TF_GUARDED_BY(mu_); + std::shared_ptr background_worker_ = nullptr; + mutex cv_mu_; + condition_variable cv_; + bool background_thread_finished_ = false; }; const std::string aws_access_key_; @@ -1127,8 +1202,6 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { const std::string aws_endpoint_override_; const std::vector parquet_files_; const std::vector column_names_; - const int64 offset_; - const int64 max_rows_; }; }; // class ArrowS3DatasetOp diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.cc b/tensorflow_io/core/kernels/arrow/arrow_util.cc index d6ba18f59..061699fc3 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_util.cc @@ -471,78 +471,6 @@ Status 
ParseHost(std::string host, std::string* host_address, return Status::OK(); } -std::shared_ptr GetS3Dataset( - const std::string& access_key, const std::string& secret_key, - const std::string& endpoint_override, - const std::vector& parquet_files) { - afs::EnsureS3Initialized(); - - afs::S3Options s3Options = - afs::S3Options::FromAccessKey(access_key, secret_key); - s3Options.endpoint_override = endpoint_override; - s3Options.scheme = "http"; - - std::shared_ptr s3fs = - afs::S3FileSystem::Make(s3Options).ValueOrDie(); - auto format = std::make_shared(); - ads::FileSystemFactoryOptions options; - - auto factory = - ads::FileSystemDatasetFactory::Make(s3fs, parquet_files, format, options) - .ValueOrDie(); - return factory->Finish().ValueOrDie(); -} - -Status GetColumns(const std::shared_ptr& dataset, - const std::vector& column_names, - std::vector& column_cols) { - std::vector v_cols; - std::list l_cols; - int column_count = column_names.size(); - column_cols.reserve(column_count); - v_cols.reserve(column_count); - auto schema = dataset->schema(); - for (const auto& name : column_names) { - int index = schema->GetFieldIndex(name); - if (index != -1) { - v_cols.push_back(index); - - auto iter = l_cols.begin(); - if (l_cols.empty()) { - l_cols.insert(iter, index); - } else { - while (iter != l_cols.end()) { - if (*iter > index) { - l_cols.insert(iter, index); - break; - } - iter++; - } - if (iter == l_cols.end()) { - l_cols.insert(iter, index); - } - } - } else { - return errors::InvalidArgument("Column name: " + name + - " does not exist"); - } - } - - for (const auto& v_index : v_cols) { - int32 index = 0; - for (const auto& l_index : l_cols) { - if (v_index == l_index) { - column_cols.push_back(index); - break; - } else { - index++; - } - } - } - - return Status::OK(); -} - } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.h b/tensorflow_io/core/kernels/arrow/arrow_util.h index 3b5859984..f37ec5fa1 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.h +++ b/tensorflow_io/core/kernels/arrow/arrow_util.h @@ -45,9 +45,6 @@ namespace data { } \ } while (false) -const std::string K_ROW_INDEX_COLUMN_NAME = "row_index"; -#define K_DEFAULT_BATCH_SIZE 100 * 1024 - namespace ArrowUtil { // Convert Arrow Data Type to TensorFlow @@ -89,15 +86,6 @@ Status ParseEndpoint(std::string endpoint, std::string* endpoint_type, Status ParseHost(std::string host, std::string* host_address, std::string* host_port); -std::shared_ptr GetS3Dataset( - const std::string& access_key, const std::string& secret_key, - const std::string& endpoint_override, - const std::vector& parquet_files); - -Status GetColumns(const std::shared_ptr& dataset, - const std::vector& column_names, - std::vector& column_cols); - } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/ops/arrow_ops.cc b/tensorflow_io/core/ops/arrow_ops.cc index cad90ce55..6c0f8a9c0 100644 --- a/tensorflow_io/core/ops/arrow_ops.cc +++ b/tensorflow_io/core/ops/arrow_ops.cc @@ -194,8 +194,6 @@ REGISTER_OP("IO>ArrowS3Dataset") .Input("aws_endpoint_override: string") .Input("parquet_files: string") .Input("column_names: string") - .Input("offset: int64") - .Input("max_rows: int64") .Input("columns: int32") .Input("batch_size: int64") .Input("batch_mode: string") @@ -212,8 +210,6 @@ aws_secret_key: S3 secret_key. 
aws_endpoint_override: S3 endpoint override parquet_files: One or more parqeut file path on s3 column_names: Select columns to read by names -offset: Set the offset start to read -max_rows: Max rows count to read )doc"); } // namespace tensorflow diff --git a/tensorflow_io/python/ops/arrow_dataset_ops.py b/tensorflow_io/python/ops/arrow_dataset_ops.py index bd93b6a7d..21dea18ad 100644 --- a/tensorflow_io/python/ops/arrow_dataset_ops.py +++ b/tensorflow_io/python/ops/arrow_dataset_ops.py @@ -665,8 +665,6 @@ def __init__( column_names, columns, output_types, - offset=0, - max_rows=0, output_shapes=None, batch_size=None, batch_mode="keep_remainder", @@ -679,8 +677,6 @@ def __init__( aws_endpoint_override: S3 endpoint override parquet_files: A list of parquet files path on s3 column_names: A list of column names to be used in the dataset - offset: Set the offset start to read - max_rows: Max rows count to read columns: A list of column indices to be used in the Dataset output_types: Tensor dtypes of the output tensors output_shapes: TensorShapes of the output tensors or None to @@ -709,8 +705,6 @@ def __init__( column_names = tf.convert_to_tensor( column_names, dtype=dtypes.string, name="column_names" ) - offset = tf.convert_to_tensor(offset, dtype=dtypes.int64, name="offset") - max_rows = tf.convert_to_tensor(max_rows, dtype=dtypes.int64, name="max_rows") super().__init__( partial( @@ -720,8 +714,6 @@ def __init__( aws_endpoint_override, parquet_files, column_names, - offset, - max_rows, ), columns, output_types, From 9db24c30736ee5e05a26f2e6550fe201cffc5640 Mon Sep 17 00:00:00 2001 From: LinGeLin <1057445597@qq.com> Date: Tue, 12 Jul 2022 16:01:11 +0800 Subject: [PATCH 04/13] Added wrong column names check --- .../core/kernels/arrow/arrow_dataset_ops.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index 041882378..b2bc6ca4c 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -1077,7 +1077,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { ReadFile(current_file_idx_); if (!background_worker_) { background_worker_ = - std::make_shared(env, "download_next_workder"); + std::make_shared(env, "download_next_worker"); } if (current_batch_idx_ < record_batches_.size()) { @@ -1150,8 +1150,18 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { if (column_indices_.empty()) { std::shared_ptr schema; reader->GetSchema(&schema); + // check column name exist + std::string err_column_names; for (const auto& name : dataset()->column_names_) { - column_indices_.push_back(schema->GetFieldIndex(name)); + int fieldIndex = schema->GetFieldIndex(name); + column_indices_.push_back(fieldIndex); + if (-1 == fieldIndex) { + err_column_names = err_column_names + " " + name; + } + } + + if (err_column_names.length() != 0) { + return errors::InvalidArgument("these column names don't exist: ", err_column_names); } } // Read file columns and build a table From a8c590c51cf6c42df6ca608ae6eb762f61a70dfd Mon Sep 17 00:00:00 2001 From: LinGeLin <1057445597@qq.com> Date: Mon, 18 Jul 2022 15:24:42 +0800 Subject: [PATCH 05/13] Added dims1 support --- tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index 
b2bc6ca4c..cb50247e0 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -177,6 +177,13 @@ class ArrowDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(ArrowUtil::AssignShape( arr, current_row_idx_, batch_size, &output_shape)); + if (output_shape.dims() == 1) { + auto&& output_shape_in = this->dataset()->output_shapes_[i]; + if (output_shape_in.dim_size(output_shape_in.dims() - 1) == 1) { + output_shape.AddDim(1); + } + } + // Allocate a new tensor and assign Arrow data to it Tensor tensor(ctx->allocator({}), output_type, output_shape); TF_RETURN_IF_ERROR( @@ -1161,7 +1168,8 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { } if (err_column_names.length() != 0) { - return errors::InvalidArgument("these column names don't exist: ", err_column_names); + return errors::InvalidArgument("these column names don't exist: ", + err_column_names); } } // Read file columns and build a table From 514eac011c170ed73bd473957202485630fcc32a Mon Sep 17 00:00:00 2001 From: LinGeLin <1057445597@qq.com> Date: Fri, 22 Jul 2022 14:26:03 +0800 Subject: [PATCH 06/13] Fixed error column remind --- tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index cb50247e0..ce0b5665f 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -1081,7 +1081,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { s3Options.endpoint_override = dataset()->aws_endpoint_override_; s3fs_ = arrow::fs::S3FileSystem::Make(s3Options).ValueOrDie(); } - ReadFile(current_file_idx_); + TF_RETURN_IF_ERROR(ReadFile(current_file_idx_)); if (!background_worker_) { background_worker_ = std::make_shared(env, "download_next_worker"); From a477f8e39fe1ac832a3109a892f6b0e31c4c83e6 Mon Sep 17 00:00:00 2001 From: LinGeLin <1057445597@qq.com> Date: Fri, 29 Jul 2022 14:35:02 +0800 Subject: [PATCH 07/13] Added arrow s3 dataset filter --- .../core/kernels/arrow/arrow_dataset_ops.cc | 46 ++- .../core/kernels/arrow/arrow_util.cc | 336 ++++++++++++++++++ tensorflow_io/core/kernels/arrow/arrow_util.h | 4 + tensorflow_io/core/ops/arrow_ops.cc | 1 + tensorflow_io/python/ops/arrow_dataset_ops.py | 4 + 5 files changed, 380 insertions(+), 11 deletions(-) diff --git a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc index ce0b5665f..e425567f0 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc +++ b/tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc @@ -994,10 +994,13 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { std::vector column_cols(column_names.size()); std::iota(column_cols.begin(), column_cols.end(), 0); + tstring filter; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "filter", &filter)); + *output = new Dataset(ctx, aws_access_key, aws_secret_key, aws_endpoint_override, - parquet_files, column_names, column_cols, batch_size, - batch_mode, output_types_, output_shapes_); + parquet_files, column_names, filter, column_cols, + batch_size, batch_mode, output_types_, output_shapes_); } private: @@ -1008,8 +1011,9 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { const std::string& aws_endpoint_override, const std::vector& parquet_files, const std::vector& column_names, - const std::vector columns, const int64 batch_size, - const 
ArrowBatchMode batch_mode, const DataTypeVector& output_types, + const std::string& filter, const std::vector columns, + const int64 batch_size, const ArrowBatchMode batch_mode, + const DataTypeVector& output_types, const std::vector& output_shapes) : ArrowDatasetBase(ctx, columns, batch_size, batch_mode, output_types, output_shapes), @@ -1017,7 +1021,8 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { aws_secret_key_(aws_secret_key), aws_endpoint_override_(aws_endpoint_override), parquet_files_(parquet_files), - column_names_(column_names) {} + column_names_(column_names), + filter_(filter) {} string DebugString() const override { return "ArrowS3DatasetOp::Dataset"; } Status InputDatasets(std::vector* inputs) const { @@ -1045,6 +1050,9 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { TF_RETURN_IF_ERROR(b->AddVector(column_names_, &column_names)); Node* columns = nullptr; TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns)); + Node* filter = nullptr; + tstring filter_str = filter_; + TF_RETURN_IF_ERROR(b->AddScalar(filter_str, &filter)); Node* batch_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); Node* batch_mode = nullptr; @@ -1054,7 +1062,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { TF_RETURN_IF_ERROR(b->AddDataset( this, {aws_access_key, aws_secret_key, aws_endpoint_override, parquet_files, - column_names, columns, batch_size, batch_mode}, + column_names, filter, columns, batch_size, batch_mode}, output)); return Status::OK(); } @@ -1105,6 +1113,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { current_batch_ = record_batches_[current_batch_idx_]; } else if (++current_file_idx_ < dataset()->parquet_files_.size()) { current_batch_idx_ = 0; + { mutex_lock lk(cv_mu_); while (!background_thread_finished_) { @@ -1175,11 +1184,25 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { // Read file columns and build a table std::shared_ptr<::arrow::Table> table; CHECK_ARROW(reader->ReadTable(column_indices_, &table)); - // Convert the table to a sequence of batches - arrow::TableBatchReader tr(*table.get()); - std::shared_ptr batch; - CHECK_ARROW(tr.ReadNext(&batch)); + std::shared_ptr batch_reader = + std::make_shared(table); + std::shared_ptr batch = nullptr; + + // filter + if (!dataset()->filter_.empty()) { + auto scanner_builder = + arrow::dataset::ScannerBuilder::FromRecordBatchReader( + batch_reader); + arrow::compute::Expression filter_expr; + TF_RETURN_IF_ERROR( + ArrowUtil::ParseExpression(dataset()->filter_, filter_expr)); + scanner_builder->Filter(filter_expr); + auto scanner = scanner_builder->Finish().ValueOrDie(); + batch_reader = scanner->ToRecordBatchReader().ValueOrDie(); + } + + CHECK_ARROW(batch_reader->ReadNext(&batch)); TF_RETURN_IF_ERROR(CheckBatchColumnTypes(batch)); next_record_batches_.clear(); while (batch != nullptr) { @@ -1188,7 +1211,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { } else { next_record_batches_.emplace_back(batch); } - CHECK_ARROW(tr.ReadNext(&batch)); + CHECK_ARROW(batch_reader->ReadNext(&batch)); } if (background) { @@ -1220,6 +1243,7 @@ class ArrowS3DatasetOp : public ArrowOpKernelBase { const std::string aws_endpoint_override_; const std::vector parquet_files_; const std::vector column_names_; + const std::string filter_; }; }; // class ArrowS3DatasetOp diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.cc b/tensorflow_io/core/kernels/arrow/arrow_util.cc index 061699fc3..f3c12791a 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.cc +++ 
b/tensorflow_io/core/kernels/arrow/arrow_util.cc @@ -471,6 +471,342 @@ Status ParseHost(std::string host, std::string* host_address, return Status::OK(); } +enum calType { + CONSTANT, + VARIABLE, + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + EQUAL, + NOT_EQUAL, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + AND, + OR, + LPAREN, + RPAREN, +}; + +enum OpType { + OPERATOR, + OPERAND, +}; + +typedef struct Token { + Token(calType type, int value) : type_(type) { + if (type_ == calType::CONSTANT) { + expression_ = arrow::compute::literal(value); + } + } + Token(calType type, float value) : type_(type) { + if (type_ == calType::CONSTANT) { + expression_ = arrow::compute::literal(value); + } + } + Token(calType type, std::string func) : type_(type), func_(func) { + if (type_ == calType::VARIABLE) { + expression_ = arrow::compute::field_ref(func_); + } + } + calType type_; + std::string func_; + arrow::compute::Expression expression_; +} Token; + +typedef struct ASTNode { + ASTNode(std::shared_ptr token, std::shared_ptr left, + std::shared_ptr right) + : token_(token), left_(left), right_(right) {} + std::shared_ptr token_; + std::shared_ptr left_; + std::shared_ptr right_; +} ASTNode; + +class Lexer { + public: + Lexer(const std::string& text) + : text_(text), position_(0), cur_op_(OPERATOR){}; + void skip_space() { + while (position_ < text_.length() && text_[position_] == ' ') { + position_++; + } + } + + std::string get_constant() { + int start = position_ - 1; + while (position_ < text_.length() && std::isdigit(text_[position_]) || + '.' == text_[position_]) { + position_++; + } + return text_.substr(start, position_ - start).c_str(); + } + + calType get_comparison_type() { + // == != >= <= + char begin_char = text_[position_ - 1]; + if (position_ < text_.length() && text_[position_] == '=') { + position_++; + if (begin_char == '=') { + return calType::EQUAL; + } else if (begin_char == '!') { + return calType::NOT_EQUAL; + } else if (begin_char == '>') { + return calType::GREATER_EQUAL; + } else if (begin_char == '<') { + return calType::LESS_EQUAL; + } + } else { + if (begin_char == '>') { + return calType::GREATER; + } else if (begin_char == '<') { + return calType::LESS; + } + } + } + + std::string get_variable() { + int start = position_ - 1; + while (position_ < text_.length() && + (std::isalnum(text_[position_]) || '_' == text_[position_])) { + position_++; + } + return text_.substr(start, position_ - start); + } + + std::shared_ptr get_next_token() { + while (position_ < text_.length()) { + char current_char = text_[position_++]; + if (' ' == current_char) { + skip_space(); + } else if (std::isdigit(current_char)) { + cur_op_ = OPERAND; + std::string constant = get_constant(); + if (std::string::npos == constant.find('.')) { + return std::make_shared(calType::CONSTANT, + std::stoi(constant)); + } else { + return std::make_shared(calType::CONSTANT, + std::stof(constant)); + } + } else if (std::isalpha(current_char) || '_' == current_char) { + cur_op_ = OPERAND; + return std::make_shared(calType::VARIABLE, get_variable()); + } else if ('+' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::ADD, "add"); + } else if ('-' == current_char) { + if (cur_op_ == OPERAND) { + cur_op_ = OPERATOR; + return std::make_shared(calType::SUBTRACT, "subtract"); + } else { + cur_op_ = OPERAND; + std::string constant = get_constant(); + if (constant.length() <= 1) { + return nullptr; + } + if (std::string::npos == constant.find('.')) { + return std::make_shared(calType::CONSTANT, 
+ std::stoi(constant)); + } else { + return std::make_shared(calType::CONSTANT, + std::stof(constant)); + } + } + } else if ('*' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::MULTIPLY, "multiply"); + } else if ('/' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::DIVIDE, "divide"); + } else if ('(' == current_char) { + cur_op_ = OPERATOR; + return std::make_shared(calType::LPAREN, "("); + } else if (')' == current_char) { + cur_op_ = OPERAND; + return std::make_shared(calType::RPAREN, ")"); + } else if ('=' == current_char || '!' == current_char || + '>' == current_char || '<' == current_char) { + cur_op_ = OPERATOR; + auto type = get_comparison_type(); + if (calType::EQUAL == type) { + return std::make_shared(calType::EQUAL, "equal"); + } else if (calType::NOT_EQUAL == type) { + return std::make_shared(calType::NOT_EQUAL, "not_equal"); + } else if (calType::LESS == type) { + return std::make_shared(calType::LESS, "less"); + } else if (calType::LESS_EQUAL == type) { + return std::make_shared(calType::LESS_EQUAL, "less_equal"); + } else if (calType::GREATER == type) { + return std::make_shared(calType::GREATER, "greater"); + } else if (calType::GREATER_EQUAL == type) { + return std::make_shared(calType::GREATER_EQUAL, + "greater_equal"); + } + } else if ('&' == current_char) { + cur_op_ = OPERATOR; + if (position_ < text_.length() && '&' == text_[position_]) { + position_++; + return std::make_shared(calType::AND, "and"); + } + } else if ('|' == current_char) { + cur_op_ = OPERATOR; + if (position_ < text_.length() && '|' == text_[position_]) { + position_++; + return std::make_shared(calType::OR, "or"); + } + } + } + return nullptr; + } + + private: + OpType cur_op_; + int position_; + std::string text_; +}; + +class Parser { + public: + Parser(std::shared_ptr ptr) : lexer_ptr_(ptr) { + current_token_ = lexer_ptr_->get_next_token(); + } + + inline void update_current_token() { + current_token_ = lexer_ptr_->get_next_token(); + } + + // constant, variable, lparen + std::shared_ptr factor() { + if (!current_token_) { + return nullptr; + } + auto token = current_token_; + if (token->type_ == calType::CONSTANT) { + update_current_token(); + return std::make_shared(token, nullptr, nullptr); + } else if (token->type_ == calType::VARIABLE) { + update_current_token(); + return std::make_shared(token, nullptr, nullptr); + } else if (token->type_ == calType::LPAREN) { + update_current_token(); + auto node = logical(); + update_current_token(); + return node; + } + return nullptr; + } + + // multiply, divide + std::shared_ptr term() { + auto node = factor(); + while (current_token_ && (current_token_->type_ == calType::MULTIPLY || + current_token_->type_ == calType::DIVIDE)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, factor()); + } + return node; + } + + // add, subtract + std::shared_ptr expr() { + auto node = term(); + while (current_token_ && (current_token_->type_ == calType::ADD || + current_token_->type_ == calType::SUBTRACT)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, term()); + } + return node; + } + + // Comparison + std::shared_ptr comparison() { + auto node = expr(); + while (current_token_ && + (current_token_->type_ >= calType::EQUAL && + current_token_->type_ <= calType::GREATER_EQUAL)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, expr()); + } + return node; + } + + // 
Logical + std::shared_ptr logical() { + auto node = comparison(); + while (current_token_ && (current_token_->type_ == calType::AND || + current_token_->type_ == calType::OR)) { + auto token = current_token_; + update_current_token(); + node = std::make_shared(token, node, comparison()); + } + return node; + } + + private: + std::shared_ptr lexer_ptr_; + std::shared_ptr current_token_; +}; + +class Interpreter { + public: + Interpreter(std::shared_ptr parser) : parser_(parser) {} + arrow::compute::Expression visit(std::shared_ptr root) { + auto rt = root->token_; + auto rlt = root->left_->token_; + auto rrt = root->right_->token_; + if (rlt->type_ != calType::CONSTANT && rlt->type_ != calType::VARIABLE) { + visit(root->left_); + } + if (rrt->type_ != calType::CONSTANT && rrt->type_ != calType::VARIABLE) { + visit(root->right_); + } + + if (rt->type_ >= calType::ADD && rt->type_ <= calType::OR) { + rt->expression_ = + arrow::compute::call(rt->func_, {rlt->expression_, rrt->expression_}); + } + rt->type_ = calType::VARIABLE; + return rt->expression_; + } + + Status interpreter(std::shared_ptr& ASTree) { + auto root = parser_->logical(); + if (!root || !root->left_ || !root->right_ || + root->token_->type_ < calType::EQUAL || + root->token_->type_ > calType::OR) { + return errors::InvalidArgument( + "Your filter expression is not supported!"); + } + ASTree = root; + return Status::OK(); + } + + private: + std::shared_ptr parser_; +}; + +Status ParseExpression(const std::string& text, + arrow::compute::Expression& expr) { + auto lexer_ptr = std::make_shared(text); + auto parser_ptr = std::make_shared(lexer_ptr); + auto interpreter_ptr = std::make_shared(parser_ptr); + + std::shared_ptr ASTree; + auto status = interpreter_ptr->interpreter(ASTree); + if (!status.ok()) { + return status; + } + + expr = interpreter_ptr->visit(ASTree); + return Status::OK(); +} + } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.h b/tensorflow_io/core/kernels/arrow/arrow_util.h index f37ec5fa1..3fe49f775 100644 --- a/tensorflow_io/core/kernels/arrow/arrow_util.h +++ b/tensorflow_io/core/kernels/arrow/arrow_util.h @@ -86,6 +86,10 @@ Status ParseEndpoint(std::string endpoint, std::string* endpoint_type, Status ParseHost(std::string host, std::string* host_address, std::string* host_port); +// Parse expr from string for scan filter +Status ParseExpression(const std::string& text, + arrow::compute::Expression& expr); + } // namespace ArrowUtil } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/ops/arrow_ops.cc b/tensorflow_io/core/ops/arrow_ops.cc index 6c0f8a9c0..dfd8f8196 100644 --- a/tensorflow_io/core/ops/arrow_ops.cc +++ b/tensorflow_io/core/ops/arrow_ops.cc @@ -194,6 +194,7 @@ REGISTER_OP("IO>ArrowS3Dataset") .Input("aws_endpoint_override: string") .Input("parquet_files: string") .Input("column_names: string") + .Input("filter: string") .Input("columns: int32") .Input("batch_size: int64") .Input("batch_mode: string") diff --git a/tensorflow_io/python/ops/arrow_dataset_ops.py b/tensorflow_io/python/ops/arrow_dataset_ops.py index 21dea18ad..6f7e59461 100644 --- a/tensorflow_io/python/ops/arrow_dataset_ops.py +++ b/tensorflow_io/python/ops/arrow_dataset_ops.py @@ -668,6 +668,7 @@ def __init__( output_shapes=None, batch_size=None, batch_mode="keep_remainder", + filter="", ): """Create an ArrowDataset from an input stream. 
@@ -689,6 +690,7 @@ def __init__(
             "keep_remainder" (default, keeps partial batch data),
             "drop_remainder" (discard partial batch data),
             "auto" (size to number of records in Arrow record batch)
+          filter: filter expression string used to select rows when reading
         """
         aws_access_key = tf.convert_to_tensor(
             aws_access_key, dtype=dtypes.string, name="aws_access_key"
         )
@@ -705,6 +707,7 @@ def __init__(
         column_names = tf.convert_to_tensor(
             column_names, dtype=dtypes.string, name="column_names"
         )
+        filter = tf.convert_to_tensor(filter, dtype=dtypes.string, name="filter")

         super().__init__(
             partial(
@@ -714,6 +717,7 @@ def __init__(
                 aws_endpoint_override,
                 parquet_files,
                 column_names,
+                filter,
             ),
             columns,
             output_types,

From c7dde746e86c9de23333e7c9615840386cef9446 Mon Sep 17 00:00:00 2001
From: LinGeLin <1057445597@qq.com>
Date: Mon, 1 Aug 2022 20:18:51 +0800
Subject: [PATCH 08/13] Fixed build error

---
 tensorflow_io/core/kernels/arrow/arrow_util.cc |  2 --
 tensorflow_io/core/kernels/arrow/arrow_util.h  |  3 ---
 third_party/arrow.BUILD                        | 14 +++++++-------
 third_party/aws-sdk-cpp.BUILD                  |  1 +
 4 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.cc b/tensorflow_io/core/kernels/arrow/arrow_util.cc
index f3c12791a..5f2abd845 100644
--- a/tensorflow_io/core/kernels/arrow/arrow_util.cc
+++ b/tensorflow_io/core/kernels/arrow/arrow_util.cc
@@ -14,8 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow_io/core/kernels/arrow/arrow_util.h"

-#include 
-
 #include "arrow/adapters/tensorflow/convert.h"
 #include "arrow/api.h"
 #include "arrow/ipc/api.h"
diff --git a/tensorflow_io/core/kernels/arrow/arrow_util.h b/tensorflow_io/core/kernels/arrow/arrow_util.h
index 3fe49f775..f78f98c48 100644
--- a/tensorflow_io/core/kernels/arrow/arrow_util.h
+++ b/tensorflow_io/core/kernels/arrow/arrow_util.h
@@ -26,9 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -namespace afs = arrow::fs; -namespace ads = arrow::dataset; - namespace tensorflow { // Forward declaration diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 5cc04d1c6..f00f83623 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -62,11 +62,10 @@ cc_library( "cpp/src/arrow/json/*.cc", "cpp/src/arrow/tensor/*.cc", "cpp/src/arrow/util/*.cc", - "cpp/src/arrow/dataset/*.cc", - "cpp/src/arrow/filesystem/*.cc", - "cpp/src/arrow/compute/*.cc", - "cpp/src/arrow/compute/kernels/*.cc", - "cpp/src/arrow/compute/exec/*.cc", + "cpp/src/arrow/dataset/dataset.cc", + "cpp/src/arrow/dataset/file_parquet.cc", + "cpp/src/arrow/filesystem/s3fs.cc", + "cpp/src/arrow/compute/exec/expression.cc", "cpp/src/arrow/vendored/musl/strptime.c", "cpp/src/arrow/vendored/optional.hpp", "cpp/src/arrow/vendored/string_view.hpp", @@ -74,6 +73,7 @@ cc_library( "cpp/src/arrow/vendored/base64.cpp", "cpp/src/arrow/vendored/datetime/tz.cpp", "cpp/src/arrow/vendored/uriparser/*.c", + "cpp/src/arrow/vendored/pcg/*.hpp", "cpp/src/arrow/**/*.h", "cpp/src/parquet/**/*.h", "cpp/src/parquet/**/*.cc", @@ -84,8 +84,8 @@ cc_library( "cpp/src/**/*_benchmark.cc", "cpp/src/**/*_main.cc", "cpp/src/**/*_nossl.cc", - "cpp/src/**/*_test.cc", - "cpp/src/**/test_*.cc", + "cpp/src/**/*test*.h", + "cpp/src/**/*test*.cc", "cpp/src/**/*hdfs*.cc", "cpp/src/**/*fuzz*.cc", "cpp/src/**/*gcsfs*.cc", diff --git a/third_party/aws-sdk-cpp.BUILD b/third_party/aws-sdk-cpp.BUILD index a1f1da060..16e9cc9d1 100644 --- a/third_party/aws-sdk-cpp.BUILD +++ b/third_party/aws-sdk-cpp.BUILD @@ -205,6 +205,7 @@ cc_library( "aws-cpp-sdk-identity-management/source/auth/*.cpp", ]), hdrs = glob([ + "aws-cpp-sdk-identity-management/include/aws/identity-management/*.h", "aws-cpp-sdk-identity-management/include/aws/identity-management/auth/*.h", ]), includes = [ From 01a671e723c6c6c4449672117ea06011705c8668 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 3 Aug 2022 01:49:21 +0000 Subject: [PATCH 09/13] Update arrow Bazel BUILD file Signed-off-by: Yong Tang --- third_party/arrow.BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index f00f83623..6fc9417f4 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -56,7 +56,11 @@ cc_library( [ "cpp/src/arrow/*.cc", "cpp/src/arrow/array/*.cc", + "cpp/src/arrow/compute/*.cc", + "cpp/src/arrow/compute/exec/*.cc", + "cpp/src/arrow/compute/kernels/*.cc", "cpp/src/arrow/csv/*.cc", + "cpp/src/arrow/dataset/*.cc", "cpp/src/arrow/io/*.cc", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/json/*.cc", @@ -64,7 +68,9 @@ cc_library( "cpp/src/arrow/util/*.cc", "cpp/src/arrow/dataset/dataset.cc", "cpp/src/arrow/dataset/file_parquet.cc", + "cpp/src/arrow/filesystem/filesystem.cc", "cpp/src/arrow/filesystem/s3fs.cc", + "cpp/src/arrow/filesystem/util_internal.cc", "cpp/src/arrow/compute/exec/expression.cc", "cpp/src/arrow/vendored/musl/strptime.c", "cpp/src/arrow/vendored/optional.hpp", @@ -119,6 +125,7 @@ cc_library( includes = [ "cpp/src", "cpp/src/arrow/vendored/xxhash", + "cpp/src/generated", ], textual_hdrs = [ "cpp/src/arrow/vendored/xxhash/xxhash.c", From 5ab9e2b66585a6252aea4afcba80bf3e675fc25c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 3 Aug 2022 03:08:58 +0000 Subject: [PATCH 10/13] Update Signed-off-by: Yong Tang --- third_party/arrow.BUILD | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/third_party/arrow.BUILD 
b/third_party/arrow.BUILD index 6fc9417f4..1cad316c7 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -61,17 +61,12 @@ cc_library( "cpp/src/arrow/compute/kernels/*.cc", "cpp/src/arrow/csv/*.cc", "cpp/src/arrow/dataset/*.cc", + "cpp/src/arrow/filesystem/*.cc", "cpp/src/arrow/io/*.cc", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/json/*.cc", "cpp/src/arrow/tensor/*.cc", "cpp/src/arrow/util/*.cc", - "cpp/src/arrow/dataset/dataset.cc", - "cpp/src/arrow/dataset/file_parquet.cc", - "cpp/src/arrow/filesystem/filesystem.cc", - "cpp/src/arrow/filesystem/s3fs.cc", - "cpp/src/arrow/filesystem/util_internal.cc", - "cpp/src/arrow/compute/exec/expression.cc", "cpp/src/arrow/vendored/musl/strptime.c", "cpp/src/arrow/vendored/optional.hpp", "cpp/src/arrow/vendored/string_view.hpp", From ff92b4b77a3311701d5c1794a0244e25fe691a41 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 3 Aug 2022 20:22:31 +0000 Subject: [PATCH 11/13] Update Signed-off-by: Yong Tang --- third_party/arrow.BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 1cad316c7..1fd4ef9e8 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -116,6 +116,7 @@ cc_library( "PARQUET_STATIC", "PARQUET_EXPORT=", "WIN32_LEAN_AND_MEAN", + "URI_STATIC_BUILD", ], includes = [ "cpp/src", From c5716b9020edceb97d7ae6767a50d2d67e6fb011 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 3 Aug 2022 21:26:58 +0000 Subject: [PATCH 12/13] Update Signed-off-by: Yong Tang --- third_party/arrow.BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 1fd4ef9e8..2606c2c9e 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -116,6 +116,7 @@ cc_library( "PARQUET_STATIC", "PARQUET_EXPORT=", "WIN32_LEAN_AND_MEAN", + "ARROW_DS_STATIC", "URI_STATIC_BUILD", ], includes = [ From 6f6a780637d6df1f1367cca45aec1dd42edbd864 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 3 Aug 2022 22:54:07 +0000 Subject: [PATCH 13/13] Update Signed-off-by: Yong Tang --- third_party/arrow.BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index 2606c2c9e..84e3f5cd6 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -104,6 +104,12 @@ cc_library( "cpp/src/parquet/parquet_version.h", ], copts = [], + linkopts = select({ + "@bazel_tools//src/conditions:windows": [ + "-DEFAULTLIB:Ole32.lib", + ], + "//conditions:default": [], + }), defines = [ "ARROW_WITH_BROTLI", "ARROW_WITH_SNAPPY",