diff --git a/.bazelversion b/.bazelversion index a0cd9f0cc..47b6be3fa 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.1.0 \ No newline at end of file +3.7.2 \ No newline at end of file diff --git a/.github/workflows/api.yml b/.github/workflows/api.yml index a601448fc..66a8e8d31 100644 --- a/.github/workflows/api.yml +++ b/.github/workflows/api.yml @@ -1,6 +1,9 @@ name: API Compatibility on: + push: + branches: + - master schedule: - cron: "0 12 * * *" @@ -13,12 +16,12 @@ jobs: name: macOS ${{ matrix.python }} + ${{ matrix.version }} runs-on: macos-latest strategy: + fail-fast: false matrix: python: ['3.8'] - version: ['tensorflow==2.4.0rc4:tensorflow-io-nightly', 'tf-nightly:tensorflow-io-nightly'] + version: ['tensorflow==2.4.0:tensorflow-io-nightly', 'tf-nightly:tensorflow-io==0.17.0', 'tf-nightly:tensorflow-io-nightly'] steps: - uses: actions/checkout@v2 - - uses: docker-practice/actions-setup-docker@v1 - uses: actions/setup-python@v1 with: python-version: ${{ matrix.python }} @@ -29,9 +32,7 @@ jobs: - name: Setup macOS run: | set -x -e - docker version bash -x -e tests/test_azure/start_azure.sh - bash -x -e tests/test_aws/aws_test.sh - name: Test macOS run: | set -x -e @@ -40,21 +41,23 @@ jobs: rm -rf tensorflow_io echo ${{ matrix.version }} | awk -F: '{print $1}' | xargs python -m pip install -U echo ${{ matrix.version }} | awk -F: '{print $2}' | xargs python -m pip install --no-deps -U - python -m pip install pytest-benchmark boto3 + python -m pip install pytest-benchmark boto3 google-cloud-storage==1.32.0 python -m pip freeze python -c 'import tensorflow as tf; print(tf.version.VERSION)' python -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)' - python -m pytest -s -v tests/test_http_eager.py - python -m pytest -s -v tests/test_s3_eager.py + python -m pytest -s -v tests/test_http.py + python -m pytest -s -v tests/test_s3.py python -m pytest -s -v tests/test_azure.py + python -m pytest -s -v tests/test_gcs.py linux: name: Linux ${{ matrix.python }} + ${{ matrix.version }} runs-on: ubuntu-20.04 strategy: + fail-fast: false matrix: python: ['3.8'] - version: ['tensorflow==2.4.0rc4:tensorflow-io-nightly', 'tf-nightly:tensorflow-io-nightly'] + version: ['tensorflow==2.4.0:tensorflow-io-nightly', 'tf-nightly:tensorflow-io==0.17.0', 'tf-nightly:tensorflow-io-nightly'] steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v1 @@ -75,21 +78,23 @@ jobs: rm -rf tensorflow_io echo ${{ matrix.version }} | awk -F: '{print $1}' | xargs python -m pip install -U echo ${{ matrix.version }} | awk -F: '{print $2}' | xargs python -m pip install --no-deps -U - python -m pip install pytest-benchmark boto3 + python -m pip install pytest-benchmark boto3 google-cloud-storage==1.32.0 python -m pip freeze python -c 'import tensorflow as tf; print(tf.version.VERSION)' python -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)' - python -m pytest -s -v tests/test_http_eager.py - python -m pytest -s -v tests/test_s3_eager.py + python -m pytest -s -v tests/test_http.py + python -m pytest -s -v tests/test_s3.py python -m pytest -s -v tests/test_azure.py + if [[ "${{ matrix.version }}" != "tf-nightly:tensorflow-io==0.17.0" ]]; then python -m pytest -s -v tests/test_gcs.py ; fi windows: name: Windows ${{ matrix.python }} + ${{ matrix.version }} runs-on: windows-latest strategy: + fail-fast: false matrix: python: ['3.8'] - version: ['tensorflow==2.4.0rc4:tensorflow-io-nightly', 'tf-nightly:tensorflow-io-nightly'] + version: ['tensorflow==2.4.0:tensorflow-io-nightly', 
'tf-nightly:tensorflow-io==0.17.0', 'tf-nightly:tensorflow-io-nightly'] steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v1 @@ -115,4 +120,4 @@ jobs: python -m pip freeze python -c 'import tensorflow as tf; print(tf.version.VERSION)' python -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)' - python -m pytest -s -v tests/test_http_eager.py -k remote + python -m pytest -s -v tests/test_http.py -k remote diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml new file mode 100644 index 000000000..68fa2255f --- /dev/null +++ b/.github/workflows/benchmarks.yml @@ -0,0 +1,81 @@ +name: API Performance Benchmarks + +on: + push: + branches: + - master + +jobs: + macos: + name: macOS ${{ matrix.python }} + ${{ matrix.version }} + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + python: ['3.8'] + version: ['tensorflow==2.4.0:tensorflow-io-nightly', 'tensorflow==2.4.0:tensorflow-io'] + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + - name: Setup macOS + run: | + set -x -e + python -m pip install -U wheel setuptools + python --version + - name: Benchmark on macOS + run: | + set -x -e + python --version + df -h + rm -rf tensorflow_io + echo ${{ matrix.version }} | awk -F: '{print $1}' | xargs python -m pip install -U + echo ${{ matrix.version }} | awk -F: '{print $2}' | xargs python -m pip install --no-deps -U + python -m pip install -q scikit-image pytest pytest-benchmark boto3 fastavro avro-python3 scikit-image pandas pyarrow==2.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 + python -m pip freeze + python -c 'import tensorflow as tf; print(tf.version.VERSION)' + python -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)' + python -m pytest --benchmark-only -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_v1.py" -o -iname "test_bigquery.py" \) \)) + + linux: + name: Linux ${{ matrix.python }} + ${{ matrix.version }} + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python: ['3.8'] + version: ['tensorflow==2.4.0:tensorflow-io-nightly', 'tensorflow==2.4.0:tensorflow-io'] + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + - name: Setup Linux + run: | + set -x -e + bash -x -e .github/workflows/build.space.sh + bash -x -e tests/test_sql/sql_test.sh + - name: Benchmark on Linux + run: | + set -x -e + python --version + df -h + rm -rf tensorflow_io + echo ${{ matrix.version }} | awk -F: '{print $1}' | xargs python -m pip install -U + echo ${{ matrix.version }} | awk -F: '{print $2}' | xargs python -m pip install --no-deps -U + python -m pip install -q scikit-image pytest pytest-benchmark boto3 fastavro avro-python3 scikit-image pandas pyarrow==2.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 + python -m pip freeze + python -c 'import tensorflow as tf; print(tf.version.VERSION)' + python -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)' + python -m pytest --benchmark-only --benchmark-json benchmark.json -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! 
\( -iname "test_*_v1.py" -o -iname "test_bigquery.py" \) \)) + - name: Store benchmark result + uses: rhysd/github-action-benchmark@v1 + with: + name: Tensorflow-IO Benchmarks + tool: 'pytest' + output-file-path: benchmark.json + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true diff --git a/.github/workflows/build.wheel.sh b/.github/workflows/build.wheel.sh index 1130a7cc9..c2a525de8 100755 --- a/.github/workflows/build.wheel.sh +++ b/.github/workflows/build.wheel.sh @@ -6,11 +6,11 @@ run_test() { entry=$1 CPYTHON_VERSION=$($entry -c 'import sys; print(str(sys.version_info[0])+str(sys.version_info[1]))') (cd wheelhouse && $entry -m pip install tensorflow_io-*-cp${CPYTHON_VERSION}-*.whl) - $entry -m pip install -q pytest pytest-benchmark boto3 fastavro avro-python3 scikit-image pandas pyarrow==0.16.0 google-cloud-pubsub==2.1.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_eager.py" \) \))) - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*_eager.py" ! \( -iname "test_bigquery_eager.py" \) \))) - # GRPC and test_bigquery_eager tests have to be executed separately because of https://github.com/grpc/grpc/issues/20034 - (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_bigquery_eager.py" \))) + $entry -m pip install -q pytest pytest-benchmark boto3 fastavro avro-python3 scikit-image pandas pyarrow==3.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigtable==1.6.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 PyYAML==5.3.1 + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*_v1.py" \))) + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_v1.py" -o -iname "test_bigquery.py" \) \))) + # GRPC and test_bigquery tests have to be executed separately because of https://github.com/grpc/grpc/issues/20034 + (cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . 
-type f \( -iname "test_bigquery.py" \))) } PYTHON_VERSION=python diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fbe84c670..87acc8d51 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,6 +9,9 @@ on: - master env: + GCP_CREDS: ${{ secrets.GCP_CREDS }} + REPO_NAME: ${{ github.repository }} + EVENT_NAME: ${{ github.event_name }} BAZEL_OPTIMIZATION: --copt=-msse4.2 --copt=-mavx --compilation_mode=opt jobs: @@ -49,54 +52,42 @@ jobs: steps: - uses: actions/checkout@v2 - run: | + if [[ "${EVENT_NAME}" == "push" && "${REPO_NAME}" == "tensorflow/io" ]]; then + printf '%s\n' "${GCP_CREDS}" >service_account_creds.json + export BAZEL_OPTIMIZATION="--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=true --google_credentials=service_account_creds.json" + else + export BAZEL_OPTIMIZATION="--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + fi set -x -e echo "Bring /usr/bin to front as GitHub does not use system python3 by default" export PATH=/usr/bin:$PATH echo $PATH python3 --version python3 -c 'import site; print(site.getsitepackages())' - python3 .github/workflows/build.instruction.py --sudo=true README.md "#### macOS" > source.sh + python3 .github/workflows/build.instruction.py --sudo=true docs/development.md "#### macOS" > source.sh bash -x -e source.sh python3 -c 'import tensorflow as tf; print(tf.version.VERSION)' - ubuntu-2004: - name: Ubuntu 20.04 + linux: + name: Linux runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - run: | - set -x -e - bash -x -e .github/workflows/build.space.sh - python3 .github/workflows/build.instruction.py README.md "##### Ubuntu 20.04" > source.sh - cat source.sh - docker run -i --rm -v $PWD:/v -w /v --net=host ubuntu:20.04 \ - bash -x -e source.sh - - centos-8: - name: CentOS 8 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - run: | - set -x -e - bash -x -e .github/workflows/build.space.sh - python3 .github/workflows/build.instruction.py README.md "##### CentOS 8" > source.sh - cat source.sh - docker run -i --rm -v $PWD:/v -w /v --net=host centos:8 \ - bash -x -e source.sh - - centos-7: - name: CentOS 7 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - run: | + - name: Ubuntu 20.04 + run: | + if [[ "${EVENT_NAME}" == "push" && "${REPO_NAME}" == "tensorflow/io" ]]; then + printf '%s\n' "${GCP_CREDS}" >service_account_creds.json + export BAZEL_OPTIMIZATION="--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=true --google_credentials=service_account_creds.json" + else + export BAZEL_OPTIMIZATION="--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + fi set -x -e bash -x -e .github/workflows/build.space.sh - python3 .github/workflows/build.instruction.py README.md "##### CentOS 7" > source.sh + python3 .github/workflows/build.instruction.py docs/development.md "##### Ubuntu 20.04" > source.sh cat source.sh - docker run -i --rm -v $PWD:/v -w /v --net=host centos:7 \ - bash -x -e source.sh + docker run -i --rm -v $PWD:/v -w /v --net=host \ + -e BAZEL_OPTIMIZATION="${BAZEL_OPTIMIZATION}" \ + ubuntu:20.04 bash -x -e source.sh macos-bazel: name: Bazel macOS @@ -105,6 +96,12 @@ jobs: - uses: actions/checkout@v2 - name: Bazel on macOS run: | + if [[ "${EVENT_NAME}" == "push" && "${REPO_NAME}" == "tensorflow/io" ]]; then + printf '%s\n' "${GCP_CREDS}" >service_account_creds.json + 
export BAZEL_OPTIMIZATION="${BAZEL_OPTIMIZATION} --remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=true --google_credentials=service_account_creds.json" + else + export BAZEL_OPTIMIZATION="${BAZEL_OPTIMIZATION} --remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + fi set -x -e echo "Bring /usr/bin to front as GitHub does not use system python3 by default" export PATH=/usr/bin:$PATH @@ -135,7 +132,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python: ['3.6', '3.7', '3.8'] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v1 @@ -150,7 +147,7 @@ jobs: set -x -e python -m pip install -U wheel setuptools python --version - python setup.py --data bazel-bin -q bdist_wheel --plat-name macosx_10_13_x86_64 + python setup.py --data bazel-bin -q bdist_wheel --plat-name macosx_10_14_x86_64 - name: Auditwheel ${{ matrix.python }} macOS run: | set -x -e @@ -172,10 +169,9 @@ jobs: runs-on: macos-latest strategy: matrix: - python: ['3.7', '3.8'] + python: ['3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - - uses: docker-practice/actions-setup-docker@v1 - uses: actions/download-artifact@v1 with: name: ${{ runner.os }}-${{ matrix.python }}-wheel @@ -190,14 +186,10 @@ jobs: - name: Setup ${{ matrix.python }} macOS run: | set -x -e - docker version bash -x -e tests/test_kafka/kafka_test.sh bash -x -e tests/test_azure/start_azure.sh - bash -x -e tests/test_pubsub/pubsub_test.sh - bash -x -e tests/test_aws/aws_test.sh + bash -x -e tests/test_gcloud/test_pubsub_bigtable.sh bash -x -e tests/test_pulsar/pulsar_test.sh - bash -x -e tests/test_elasticsearch/elasticsearch_test.sh start - bash -x -e tests/test_mongodb/mongodb_test.sh start - name: Install ${{ matrix.python }} macOS run: | set -x -e @@ -217,6 +209,12 @@ jobs: - uses: actions/checkout@v2 - name: Bazel on Linux run: | + if [[ "${EVENT_NAME}" == "push" && "${REPO_NAME}" == "tensorflow/io" ]]; then + printf '%s\n' "${GCP_CREDS}" >service_account_creds.json + export BAZEL_OPTIMIZATION="${BAZEL_OPTIMIZATION} --remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=true --google_credentials=service_account_creds.json" + else + export BAZEL_OPTIMIZATION="${BAZEL_OPTIMIZATION} --remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + fi set -x -e bash -x -e .github/workflows/build.space.sh BAZEL_OS=$(uname | tr '[:upper:]' '[:lower:]') @@ -239,7 +237,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ['3.6', '3.7', '3.8'] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v1 @@ -268,16 +266,10 @@ jobs: linux-test: name: Test ${{ matrix.python }} Linux needs: linux-wheel - runs-on: ${{ matrix.os }} + runs-on: ubuntu-20.04 strategy: matrix: - os: [ubuntu-18.04, ubuntu-20.04] - python: ['3.7', '3.8'] - exclude: - - os: ubuntu-18.04 - python: '3.8' - - os: ubuntu-20.04 - python: '3.7' + python: ['3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v1 @@ -290,7 +282,7 @@ jobs: bash -x -e .github/workflows/build.space.sh bash -x -e tests/test_kafka/kafka_test.sh bash -x -e tests/test_aws/aws_test.sh - bash -x -e tests/test_pubsub/pubsub_test.sh + bash -x -e tests/test_gcloud/test_pubsub_bigtable.sh bash -x -e tests/test_prometheus/prometheus_test.sh start bash -x -e tests/test_elasticsearch/elasticsearch_test.sh start 
bash -x -e tests/test_mongodb/mongodb_test.sh start @@ -304,7 +296,7 @@ jobs: set -x -e df -h docker run -i --rm -v $PWD:/v -w /v --net=host \ - buildpack-deps:$(echo ${{ matrix.os }} | awk -F- '{print $2}') \ + buildpack-deps:20.04 \ bash -x -e .github/workflows/build.wheel.sh python${{ matrix.python }} windows-bazel: @@ -320,6 +312,18 @@ jobs: BAZEL_VC: "C:/Program Files (x86)/Microsoft Visual Studio/2019/Enterprise/VC/" shell: cmd run: | + if "%EVENT_NAME%" == "push" ( + if "%REPO_NAME%" == "tensorflow/io" ( + printenv GCP_CREDS > service_account_creds.json + set "BAZEL_OPTIMIZATION=--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=true --google_credentials=service_account_creds.json" + ) else ( + echo %REPO_NAME% + set "BAZEL_OPTIMIZATION=--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + ) + ) else ( + echo %EVENT_NAME% + set "BAZEL_OPTIMIZATION=--remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" + ) @echo on set /P BAZEL_VERSION=< .bazelversion curl -sSL -o bazel.exe https://github.com/bazelbuild/bazel/releases/download/%BAZEL_VERSION%/bazel-%BAZEL_VERSION%-windows-x86_64.exe @@ -333,7 +337,7 @@ jobs: python3 setup.py --package-version | xargs python3 -m pip install python3 tools/build/configure.py cat .bazelrc - bazel build -s --verbose_failures //tensorflow_io/core:python/ops/libtensorflow_io.so //tensorflow_io/core:python/ops/libtensorflow_io_plugins.so + bazel build -s --verbose_failures %BAZEL_OPTIMIZATION% //tensorflow_io/core:python/ops/libtensorflow_io.so //tensorflow_io/core:python/ops/libtensorflow_io_plugins.so - uses: actions/upload-artifact@v1 with: name: ${{ runner.os }}-bazel-bin @@ -345,7 +349,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: ['3.6', '3.7', '3.8'] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v1 @@ -374,7 +378,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: ['3.7', '3.8'] + python: ['3.7', '3.8', '3.9'] steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v1 @@ -405,14 +409,13 @@ jobs: python --version python -m pip install -U pytest-benchmark rm -rf tensorflow_io - (cd tests && python -m pytest -s -v test_lmdb_eager.py) - (python -m pytest -s -v test_image_eager.py -k "webp or ppm or bmp or bounding or exif or hdr or openexr or tiff or avif") - (python -m pytest -s -v test_serialization_eager.py) - (python -m pytest -s -v test_io_dataset_eager.py -k "numpy or hdf5 or audio or to_file") - (python -m pytest -s -v test_http_eager.py) + (cd tests && python -m pytest -s -v test_lmdb.py) + (python -m pytest -s -v test_image.py -k "webp or ppm or bmp or bounding or exif or hdr or openexr or tiff or avif") + (python -m pytest -s -v test_serialization.py) + (python -m pytest -s -v test_io_dataset.py -k "numpy or hdf5 or audio or to_file") + (python -m pytest -s -v test_http.py) python -m pip install google-cloud-bigquery-storage==0.7.0 google-cloud-bigquery==1.22.0 fastavro - (python -m pytest -s -v test_bigquery_eager.py) - (python -m pytest -s -v test_dicom_eager.py) + (python -m pytest -s -v test_bigquery.py) (python -m pytest -s -v test_dicom.py) release: @@ -433,6 +436,10 @@ jobs: with: name: macOS-3.8-wheel path: macOS-3.8-wheel + - uses: actions/download-artifact@v1 + with: + name: macOS-3.9-wheel + path: macOS-3.9-wheel - uses: actions/download-artifact@v1 with: name: Linux-3.6-wheel @@ -445,6 +452,10 
@@ jobs: with: name: Linux-3.8-wheel path: Linux-3.8-wheel + - uses: actions/download-artifact@v1 + with: + name: Linux-3.9-wheel + path: Linux-3.9-wheel - uses: actions/download-artifact@v1 with: name: Windows-3.6-wheel @@ -457,18 +468,25 @@ jobs: with: name: Windows-3.8-wheel path: Windows-3.8-wheel + - uses: actions/download-artifact@v1 + with: + name: Windows-3.9-wheel + path: Windows-3.9-wheel - run: | set -e -x mkdir -p wheelhouse cp macOS-3.6-wheel/*.whl wheelhouse/ cp macOS-3.7-wheel/*.whl wheelhouse/ cp macOS-3.8-wheel/*.whl wheelhouse/ + cp macOS-3.9-wheel/*.whl wheelhouse/ cp Linux-3.6-wheel/*.whl wheelhouse/ cp Linux-3.7-wheel/*.whl wheelhouse/ cp Linux-3.8-wheel/*.whl wheelhouse/ + cp Linux-3.9-wheel/*.whl wheelhouse/ cp Windows-3.6-wheel/*.whl wheelhouse/ cp Windows-3.7-wheel/*.whl wheelhouse/ cp Windows-3.8-wheel/*.whl wheelhouse/ + cp Windows-3.9-wheel/*.whl wheelhouse/ ls -la wheelhouse/ sha256sum wheelhouse/*.whl - uses: actions/upload-artifact@v1 @@ -517,11 +535,11 @@ jobs: macos-nightly: name: Nightly ${{ matrix.python }} macOS if: github.event_name == 'push' - needs: [build-number, release] + needs: [build-number, macos-wheel] runs-on: macos-latest strategy: matrix: - python: ['3.6', '3.7', '3.8'] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/download-artifact@v1 with: @@ -541,7 +559,7 @@ jobs: set -x -e python -m pip install -U wheel setuptools python --version - python setup.py --data bazel-bin -q bdist_wheel --plat-name macosx_10_13_x86_64 --nightly $BUILD_NUMBER + python setup.py --data bazel-bin -q bdist_wheel --plat-name macosx_10_14_x86_64 --nightly $BUILD_NUMBER - name: Auditwheel ${{ matrix.python }} macOS run: | set -x -e @@ -560,12 +578,11 @@ jobs: linux-nightly: name: Nightly ${{ matrix.python }} Linux if: github.event_name == 'push' - needs: [build-number, release] - runs-on: ${{ matrix.os }} + needs: [build-number, linux-wheel] + runs-on: ubuntu-20.04 strategy: matrix: - os: [ubuntu-18.04] - python: ['3.6', '3.7', '3.8'] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/download-artifact@v1 with: @@ -599,11 +616,11 @@ jobs: windows-nightly: name: Nightly ${{ matrix.python }} Windows if: github.event_name == 'push' - needs: [build-number, release] + needs: [build-number, windows-wheel] runs-on: windows-latest strategy: matrix: - python: ['3.6', '3.7', 3.8] + python: ['3.6', '3.7', '3.8', '3.9'] steps: - uses: actions/download-artifact@v1 with: @@ -649,6 +666,10 @@ jobs: with: name: macOS-3.8-nightly path: macOS-3.8-nightly + - uses: actions/download-artifact@v1 + with: + name: macOS-3.9-nightly + path: macOS-3.9-nightly - uses: actions/download-artifact@v1 with: name: Linux-3.6-nightly @@ -661,6 +682,10 @@ jobs: with: name: Linux-3.8-nightly path: Linux-3.8-nightly + - uses: actions/download-artifact@v1 + with: + name: Linux-3.9-nightly + path: Linux-3.9-nightly - uses: actions/download-artifact@v1 with: name: Windows-3.6-nightly @@ -673,18 +698,25 @@ jobs: with: name: Windows-3.8-nightly path: Windows-3.8-nightly + - uses: actions/download-artifact@v1 + with: + name: Windows-3.9-nightly + path: Windows-3.9-nightly - run: | set -e -x mkdir -p dist cp macOS-3.6-nightly/*.whl dist/ cp macOS-3.7-nightly/*.whl dist/ cp macOS-3.8-nightly/*.whl dist/ + cp macOS-3.9-nightly/*.whl dist/ cp Linux-3.6-nightly/*.whl dist/ cp Linux-3.7-nightly/*.whl dist/ cp Linux-3.8-nightly/*.whl dist/ + cp Linux-3.9-nightly/*.whl dist/ cp Windows-3.6-nightly/*.whl dist/ cp Windows-3.7-nightly/*.whl dist/ cp Windows-3.8-nightly/*.whl dist/ + cp 
Windows-3.9-nightly/*.whl dist/ ls -la dist/ sha256sum dist/*.whl - uses: pypa/gh-action-pypi-publish@master diff --git a/.kokorun/io_cpu.sh b/.kokorun/io_cpu.sh index 9c62f914c..ab4a43480 100755 --- a/.kokorun/io_cpu.sh +++ b/.kokorun/io_cpu.sh @@ -48,7 +48,7 @@ docker --version export PYTHON_VERSION=3.8 export BAZEL_VERSION=$(cat .bazelversion) -export BAZEL_OPTIMIZATION="--copt=-msse4.2 --copt=-mavx --compilation_mode=opt" +export BAZEL_OPTIMIZATION="--copt=-msse4.2 --copt=-mavx --compilation_mode=opt --remote_cache=https://storage.googleapis.com/tensorflow-sigs-io --remote_upload_local_results=false" export BAZEL_OS=$(uname | tr '[:upper:]' '[:lower:]') docker run -i --rm -v $PWD:/v -w /v --net=host \ @@ -75,7 +75,7 @@ bash -x -e tests/test_gcloud/test_gcs.sh gcs-emulator bash -x -e tests/test_kafka/kafka_test.sh bash -x -e tests/test_pulsar/pulsar_test.sh bash -x -e tests/test_aws/aws_test.sh -bash -x -e tests/test_pubsub/pubsub_test.sh pubsub +bash -x -e tests/test_gcloud/test_pubsub_bigtable.sh bash -x -e tests/test_prometheus/prometheus_test.sh start bash -x -e tests/test_azure/start_azure.sh bash -x -e tests/test_sql/sql_test.sh sql diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e8e02ea5f..2c4cc2378 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,7 +1,6 @@ # Contributing -Tensorflow I/O project welcomes all kinds of contributions, be it code changes, bug-fixes or documentation changes. -This guide should help you in taking care of some basic setups & code conventions. +Tensorflow I/O project welcomes all kinds of contributions, be it code changes, bug-fixes or documentation changes. This guide should help you in taking care of some basic setups & code conventions. ## Contributor License Agreement @@ -17,8 +16,7 @@ again. ## Coding Style -Tensorflow project wide code style guidelines can be followed at [TensorFlow Style Guide - Conventions](https://www.tensorflow.org/community/contribute/code_style) and Tensorflow I/O project specific -code style guidelines can be followed at [Style Guide](STYLE_GUIDE.md). +Tensorflow project wide code style guidelines can be followed at [TensorFlow Style Guide - Conventions](https://www.tensorflow.org/community/contribute/code_style) and Tensorflow I/O project specific code style guidelines can be followed at [Style Guide](STYLE_GUIDE.md). ## Code Reviews diff --git a/README.md b/README.md index 321ed0139..82ab09ade 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,11 @@ import tensorflow as tf import tensorflow_io as tfio # Read the MNIST data into the IODataset. +dataset_url = "http://storage.googleapis.com/cvdf-datasets/mnist/" d_train = tfio.IODataset.from_mnist( - 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', - 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz') + dataset_url + "train-images-idx3-ubyte.gz", + dataset_url + "train-labels-idx1-ubyte.gz", +) # Shuffle the elements of the dataset. d_train = d_train.shuffle(buffer_size=1024) @@ -38,17 +40,19 @@ d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), d_train = d_train.batch(32) # Build the model. 
-model = tf.keras.models.Sequential([ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(512, activation=tf.nn.relu), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation=tf.nn.softmax) -]) +model = tf.keras.models.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(512, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation=tf.nn.softmax), + ] +) # Compile the model. -model.compile(optimizer='adam', - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) +model.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] +) # Fit the model. model.fit(d_train, epochs=5, steps_per_epoch=200) @@ -79,6 +83,8 @@ People who are a little more adventurous can also try our nightly binaries: $ pip install tensorflow-io-nightly ``` +### Docker Images + In addition to the pip packages, the docker images can be used to quickly get started. For stable builds: @@ -132,343 +138,20 @@ of releases [here](https://github.com/tensorflow/io/releases). | 0.2.0 | 1.12.0 | Jan 29, 2019 | | 0.1.0 | 1.12.0 | Dec 16, 2018 | -## Development - -### IDE Setup - -For instructions on how to configure Visual Studio Code for developing TensorFlow I/O, please refer to -https://github.com/tensorflow/io/blob/master/docs/vscode.md - -### Lint - -TensorFlow I/O's code conforms to Bazel Buildifier, Clang Format, Black, and Pyupgrade. -Please use the following command to check the source code and identify lint issues: -``` -$ bazel run //tools/lint:check -``` - -For Bazel Buildifier and Clang Format, the following command will automatically identify -and fix any lint errors: -``` -$ bazel run //tools/lint:lint -``` - -Alternatively, if you only want to perform lint check using individual linters, -then you can selectively pass `black`, `pyupgrade`, `bazel`, or `clang` to the above commands. - -For example, a `black` specific lint check can be done using: -``` -$ bazel run //tools/lint:check -- black -``` - -Lint fix using Bazel Buildifier and Clang Format can be done using: -``` -$ bazel run //tools/lint:lint -- bazel clang -``` - -Lint check using `black` and `pyupgrade` for an individual python file can be done using: -``` -$ bazel run //tools/lint:check -- black pyupgrade -- tensorflow_io/core/python/ops/version_ops.py -``` - -Lint fix an individual python file with black and pyupgrade using: -``` -$ bazel run //tools/lint:lint -- black pyupgrade -- tensorflow_io/core/python/ops/version_ops.py -``` - - -### Python - -#### macOS - -On macOS Catalina 10.15.7, it is possible to build tensorflow-io with -system provided python 3.8.2. Both `tensorflow` and `bazel` are needed. - -NOTE: The system default python 3.8.2 on macOS 10.15.7 will cause `regex` installation -error caused by compiler option of `-arch arm64 -arch x86_64` (similar to the issue -mentioned in https://github.com/giampaolo/psutil/issues/1832). To overcome this issue -`export ARCHFLAGS="-arch x86_64"` will be needed to remove arm64 build option. - -```sh -#!/usr/bin/env bash - -# Disable arm64 build by specifying only x86_64 arch. 
-# Only needed for macOS's system default python 3.8.2 on macOS 10.15.7 -export ARCHFLAGS="-arch x86_64" -# Use following command to check if Xcode is correctly installed: -xcodebuild -version +## Performance Benchmarking -# Show macOS's default python3 -python3 --version - -# Install Bazel version specified in .bazelversion -curl -OL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-darwin-x86_64.sh -sudo bash -x -e bazel-$(cat .bazelversion)-installer-darwin-x86_64.sh - -# Install tensorflow and configure bazel -sudo ./configure.sh - -# Build shared libraries -bazel build -s --verbose_failures //tensorflow_io/... - -# Once build is complete, shared libraries will be available in -# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible -# to run tests with `pytest`, e.g.: -sudo python3 -m pip install pytest -TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization_eager.py -``` - -NOTE: When running pytest, `TFIO_DATAPATH=bazel-bin` has to be passed so that python can utilize the generated shared libraries after the build process. - -##### Troubleshoot - -If Xcode is installed, but `$ xcodebuild -version` is not displaying the expected output, you might need to enable Xcode command line with the command: - -`$ xcode-select -s /Applications/Xcode.app/Contents/Developer`. - -A terminal restart might be required for the changes to take effect. - -Sample output: - -``` -$ xcodebuild -version -Xcode 11.6 -Build version 11E708 -``` - - -#### Linux - -Development of tensorflow-io on Linux is similar to macOS. The required packages -are gcc, g++, git, bazel, and python 3. Newer versions of gcc or python, other than the default system installed -versions might be required though. - -##### Ubuntu 20.04 - -Ubuntu 20.04 requires gcc/g++, git, and python 3. The following will install dependencies and build -the shared libraries on Ubuntu 20.04: -```sh -#!/usr/bin/env bash - -# Install gcc/g++, git, unzip/curl (for bazel), and python3 -sudo apt-get -y -qq update -sudo apt-get -y -qq install gcc g++ git unzip curl python3-pip - -# Install Bazel version specified in .bazelversion -curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh -sudo bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh - -# Upgrade pip -sudo python3 -m pip install -U pip - -# Install tensorflow and configure bazel -sudo ./configure.sh - -# Build shared libraries -bazel build -s --verbose_failures //tensorflow_io/... - -# Once build is complete, shared libraries will be available in -# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible -# to run tests with `pytest`, e.g.: -sudo python3 -m pip install pytest -TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization_eager.py -``` - -##### CentOS 8 - -CentOS 8 requires gcc/g++, git, and python 3. 
The following will install dependencies and build -the shared libraries on CentOS 8: -```sh -#!/usr/bin/env bash - -# Install gcc/g++, git, unzip/which (for bazel), and python3 -sudo yum install -y python3 python3-devel gcc gcc-c++ git unzip which make - -# Install Bazel version specified in .bazelversion -curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh -sudo bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh - -# Upgrade pip -sudo python3 -m pip install -U pip - -# Install tensorflow and configure bazel -sudo ./configure.sh - -# Build shared libraries -bazel build -s --verbose_failures //tensorflow_io/... - -# Once build is complete, shared libraries will be available in -# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible -# to run tests with `pytest`, e.g.: -sudo python3 -m pip install pytest -TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization_eager.py -``` - -##### CentOS 7 - -On CentOS 7, the default python and gcc version are too old to build tensorflow-io's shared -libraries (.so). The gcc provided by Developer Toolset and rh-python36 should be used instead. -Also, the libstdc++ has to be linked statically to avoid discrepancy of libstdc++ installed on -CentOS vs. newer gcc version by devtoolset. - -Furthermore, a special flag `--//tensorflow_io/core:static_build` has to be passed to Bazel -in order to avoid duplication of symbols in statically linked libraries for file system -plugins. - -The following will install bazel, devtoolset-9, rh-python36, and build the shared libraries: -```sh -#!/usr/bin/env bash - -# Install centos-release-scl, then install gcc/g++ (devtoolset), git, and python 3 -sudo yum install -y centos-release-scl -sudo yum install -y devtoolset-9 git rh-python36 make - -# Install Bazel version specified in .bazelversion -curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh -sudo bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh - -# Upgrade pip -scl enable rh-python36 devtoolset-9 \ - 'python3 -m pip install -U pip' - -# Install tensorflow and configure bazel with rh-python36 -scl enable rh-python36 devtoolset-9 \ - './configure.sh' - -# Build shared libraries, notice the passing of --//tensorflow_io/core:static_build -BAZEL_LINKOPTS="-static-libstdc++ -static-libgcc" BAZEL_LINKLIBS="-lm -l%:libstdc++.a" \ - scl enable rh-python36 devtoolset-9 \ - 'bazel build -s --verbose_failures --//tensorflow_io/core:static_build //tensorflow_io/...' - -# Once build is complete, shared libraries will be available in -# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible -# to run tests with `pytest`, e.g.: -scl enable rh-python36 devtoolset-9 \ - 'python3 -m pip install pytest' - -TFIO_DATAPATH=bazel-bin \ - scl enable rh-python36 devtoolset-9 \ - 'python3 -m pytest -s -v tests/test_serialization_eager.py' -``` - -#### Python Wheels - -It is possible to build python wheels after bazel build is complete with the following command: -``` -$ python3 setup.py bdist_wheel --data bazel-bin -``` -The .whl file will be available in dist directory. Note the bazel binary directory `bazel-bin` -has to be passed with `--data` args in order for setup.py to locate the necessary share objects, -as `bazel-bin` is outside of the `tensorflow_io` package directory. 
- -Alternatively, source install could be done with: -``` -$ TFIO_DATAPATH=bazel-bin python3 -m pip install . -``` -with `TFIO_DATAPATH=bazel-bin` passed for the same reason. - -Note installing with `-e` is different from the above. The -``` -$ TFIO_DATAPATH=bazel-bin python3 -m pip install -e . -``` -will not install shared object automatically even with `TFIO_DATAPATH=bazel-bin`. Instead, -`TFIO_DATAPATH=bazel-bin` has to be passed everytime the program is run after the install: -``` -$ TFIO_DATAPATH=bazel-bin python3 - ->>> import tensorflow_io as tfio ->>> ... -``` - -#### Docker - -For Python development, a reference Dockerfile [here](tools/docker/devel.Dockerfile) can be -used to build the TensorFlow I/O package (`tensorflow-io`) from source. Additionally, the -pre-built devel images can be used as well: -```sh -# Pull (if necessary) and start the devel container -$ docker run -it --rm --name tfio-dev --net=host -v ${PWD}:/v -w /v tfsigio/tfio:latest-devel bash - -# Inside the docker container, ./configure.sh will install TensorFlow or use existing install -(tfio-dev) root@docker-desktop:/v$ ./configure.sh - -# Clean up exisiting bazel build's (if any) -(tfio-dev) root@docker-desktop:/v$ rm -rf bazel-* - -# Build TensorFlow I/O C++. For compilation optimization flags, the default (-march=native) -# optimizes the generated code for your machine's CPU type. -# Reference: https://www.tensorflow.orginstall/source#configuration_options). - -# NOTE: Based on the available resources, please change the number of job workers to: -# -j 4/8/16 to prevent bazel server terminations and resource oriented build errors. - -(tfio-dev) root@docker-desktop:/v$ bazel build -j 8 --copt=-msse4.2 --copt=-mavx --compilation_mode=opt --verbose_failures --test_output=errors --crosstool_top=//third_party/toolchains/gcc7_manylinux2010:toolchain //tensorflow_io/... - - -# Run tests with PyTest, note: some tests require launching additional containers to run (see below) -(tfio-dev) root@docker-desktop:/v$ pytest -s -v tests/ -# Build the TensorFlow I/O package -(tfio-dev) root@docker-desktop:/v$ python setup.py bdist_wheel -``` - -A package file `dist/tensorflow_io-*.whl` will be generated after a build is successful. - -NOTE: When working in the Python development container, an environment variable -`TFIO_DATAPATH` is automatically set to point tensorflow-io to the shared C++ -libraries built by Bazel to run `pytest` and build the `bdist_wheel`. Python -`setup.py` can also accept `--data [path]` as an argument, for example -`python setup.py --data bazel-bin bdist_wheel`. - -NOTE: While the tfio-dev container gives developers an easy to work with -environment, the released whl packages are built differently due to manylinux2010 -requirements. Please check [Build Status and CI] section for more details -on how the released whl packages are generated. - -#### Starting Test Containers - -Some tests require launching a test container before running. In order -to run all tests, execute the following commands: - -```sh -$ bash -x -e tests/test_ignite/start_ignite.sh -$ bash -x -e tests/test_kafka/kafka_test.sh -$ bash -x -e tests/test_kinesis/kinesis_test.sh -``` - -### R - -We provide a reference Dockerfile [here](R-package/scripts/Dockerfile) for you -so that you can use the R package directly for testing. You can build it via: -```sh -$ docker build -t tfio-r-dev -f R-package/scripts/Dockerfile . 
-``` - -Inside the container, you can start your R session, instantiate a `SequenceFileDataset` -from an example [Hadoop SequenceFile](https://wiki.apache.org/hadoop/SequenceFile) -[string.seq](R-package/tests/testthat/testdata/string.seq), and then use any [transformation functions](https://tensorflow.rstudio.com/tools/tfdatasets/articles/introduction.html#transformations) provided by [tfdatasets package](https://tensorflow.rstudio.com/tools/tfdatasets/) on the dataset like the following: - -```r -library(tfio) -dataset <- sequence_file_dataset("R-package/tests/testthat/testdata/string.seq") %>% - dataset_repeat(2) - -sess <- tf$Session() -iterator <- make_iterator_one_shot(dataset) -next_batch <- iterator_get_next(iterator) - -until_out_of_range({ - batch <- sess$run(next_batch) - print(batch) -}) -``` +We use [github-pages](https://tensorflow.github.io/io/dev/bench/) to document the results of API performance benchmarks. The benchmark job is triggered on every commit to the `master` branch and +facilitates tracking performance w.r.t. commits. ## Contributing Tensorflow I/O is a community led open source project. As such, the project -depends on public contributions, bug-fixes, and documentation. Please -see [contribution guidelines](CONTRIBUTING.md) for a guide on how to -contribute. +depends on public contributions, bug-fixes, and documentation. Please see: + +- [contribution guidelines](CONTRIBUTING.md) for a guide on how to contribute. +- [development doc](docs/development.md) for instructions on the development environment setup. +- [tutorials](docs/tutorials) for a list of tutorial notebooks and instructions on how to write one. ### Build Status and CI @@ -500,7 +183,7 @@ It takes some time to build, but once complete, there will be python `3.5`, `3.6`, `3.7` compatible whl packages available in `wheelhouse` directory. -On macOS, the same command could be used though the script expect `python` in shell +On macOS, the same command could be used. However, the script expects `python` in shell and will only generate a whl package that matches the version of `python` in shell. If you want to build a whl package for a specific python then you have to alias this version of python to `python` in shell. See [.github/workflows/build.yml](.github/workflows/build.yml) @@ -512,17 +195,16 @@ TensorFlow I/O uses both GitHub Workflows and Google CI (Kokoro) for continuous GitHub Workflows is used for macOS build and test. Kokoro is used for Linux build and test. Again, because of the manylinux2010 requirement, on Linux whl packages are always built with Ubuntu 16.04 + Developer Toolset 7.
Tests are done on a variatiy of systems -with different python version to ensure a good coverage: +with different python3 versions to ensure a good coverage: -| Python | Ubuntu 16.04| Ubuntu 18.04 | macOS + osx9 | -| ------- | ----- | ------- | ------- | -| 2.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -| 3.5 | :heavy_check_mark: | N/A | :heavy_check_mark: | -| 3.6 | N/A | :heavy_check_mark: | :heavy_check_mark: | -| 3.7 | N/A | :heavy_check_mark: | N/A | +| Python | Ubuntu 18.04| Ubuntu 20.04 | macOS + osx9 | Windows-2019 | +| ------- | ----- | ------- | ------- | --------- | +| 2.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | +| 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -TensorFlow I/O has integrations with may systems and cloud vendors such as +TensorFlow I/O has integrations with many systems and cloud vendors such as Prometheus, Apache Kafka, Apache Ignite, Google Cloud PubSub, AWS Kinesis, Microsoft Azure Storage, Alibaba Cloud OSS etc. @@ -545,8 +227,11 @@ level of coverage as live systems or emulators. | AWS Kinesis | | :heavy_check_mark: |:heavy_check_mark:| | | Alibaba Cloud OSS | | | | :heavy_check_mark: | | Google BigTable/BigQuery | | to be added | | | +| Elasticsearch (experimental) | :heavy_check_mark: | |:heavy_check_mark:| | +| MongoDB (experimental) | :heavy_check_mark: | |:heavy_check_mark:| | + -Note: +References for emulators: - Official [PubSub Emulator](https://cloud.google.com/sdk/gcloud/reference/beta/emulators/pubsub/) by Google Cloud for Cloud PubSub. - Official [Azurite Emulator](https://github.com/Azure/Azurite) by Azure for Azure Storage. - None-official [LocalStack emulator](https://github.com/localstack/localstack) by LocalStack for AWS Kinesis. 
@@ -558,7 +243,7 @@ Note: * SIG IO [Monthly Meeting Notes](https://docs.google.com/document/d/1CB51yJxns5WA4Ylv89D-a5qReiGTC0GYum6DU-9nKGo/edit) * Gitter room: [tensorflow/sig-io](https://gitter.im/tensorflow/sig-io) -## More Information +## Additional Information * [Streaming Machine Learning with Tiered Storage and Without a Data Lake](https://www.confluent.io/blog/streaming-machine-learning-with-tiered-storage/) - [Kai Waehner](https://github.com/kaiwaehner) * [TensorFlow with Apache Arrow Datasets](https://medium.com/tensorflow/tensorflow-with-apache-arrow-datasets-cdbcfe80a59f) - [Bryan Cutler](https://github.com/BryanCutler) diff --git a/WORKSPACE b/WORKSPACE index 907aedd40..45de019a8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -80,11 +80,11 @@ http_archive( http_archive( name = "avro", build_file = "//third_party:avro.BUILD", - sha256 = "e382ac6685544ae9539084793ac0a4ffd377ba476ea756439625552e14d212b0", - strip_prefix = "avro-release-1.9.1/lang/c++", + sha256 = "8fd1f850ce37e60835e6d8335c0027a959aaa316773da8a9660f7d33a66ac142", + strip_prefix = "avro-release-1.10.1/lang/c++", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/avro/archive/release-1.9.1.tar.gz", - "https://github.com/apache/avro/archive/release-1.9.1.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/avro/archive/release-1.10.1.tar.gz", + "https://github.com/apache/avro/archive/release-1.10.1.tar.gz", ], ) @@ -142,22 +142,22 @@ new_git_repository( http_archive( name = "libgeotiff", build_file = "//third_party:libgeotiff.BUILD", - sha256 = "12c26422e89da7032efcd60d48f3d82c7c0b4c9f3f61aa30c5e3df512946c6cf", - strip_prefix = "libgeotiff-1.5.1", + sha256 = "9452dadd126223a22ce6b97d202066d3873792aaefa7ce739519635a3fe34034", + strip_prefix = "libgeotiff-1.6.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/download.osgeo.org/geotiff/libgeotiff/libgeotiff-1.5.1.zip", - "https://download.osgeo.org/geotiff/libgeotiff/libgeotiff-1.5.1.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/OSGeo/libgeotiff/releases/download/1.6.0/libgeotiff-1.6.0.zip", + "https://github.com/OSGeo/libgeotiff/releases/download/1.6.0/libgeotiff-1.6.0.zip", ], ) http_archive( name = "proj", build_file = "//third_party:proj.BUILD", - sha256 = "0b157e1aa81df4d0dbd89368a0005916bb717f0c09143b4dbc1b20d59204e9f2", - strip_prefix = "proj-6.2.0", + sha256 = "219c6e11b2baa9a3e2bd7ec54ce19702909591032cf6f7d1004b406f10b7c9ad", + strip_prefix = "proj-7.2.1", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/download.osgeo.org/proj/proj-6.2.0.zip", - "https://download.osgeo.org/proj/proj-6.2.0.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/OSGeo/PROJ/releases/download/7.2.1/proj-7.2.1.zip", + "https://github.com/OSGeo/PROJ/releases/download/7.2.1/proj-7.2.1.zip", ], ) @@ -180,11 +180,11 @@ http_archive( "echo '' >> include/base64.h", "echo '#include ' >> include/base64.h", ], - sha256 = "597d9894061f4871a909f1c2c3f56725a69c188ea17784cc71e1e170687faf00", - strip_prefix = "azure-storage-cpplite-0.2.0", + sha256 = "25f34354fb0400ffe1b5a5c09c793c9fc8104d375910f6c84ab10fa50c0059cb", + strip_prefix = "azure-storage-cpplite-0.3.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Azure/azure-storage-cpplite/archive/v0.2.0.tar.gz", - "https://github.com/Azure/azure-storage-cpplite/archive/v0.2.0.tar.gz", + 
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/Azure/azure-storage-cpplite/archive/v0.3.0.tar.gz", + "https://github.com/Azure/azure-storage-cpplite/archive/v0.3.0.tar.gz", ], ) @@ -256,22 +256,27 @@ http_archive( http_archive( name = "thrift", build_file = "//third_party:thrift.BUILD", - sha256 = "b7452d1873c6c43a580d2b4ae38cfaf8fa098ee6dc2925bae98dce0c010b1366", - strip_prefix = "thrift-0.12.0", + sha256 = "5da60088e60984f4f0801deeea628d193c33cec621e78c8a43a5d8c4055f7ad9", + strip_prefix = "thrift-0.13.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/thrift/archive/0.12.0.tar.gz", - "https://github.com/apache/thrift/archive/0.12.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/thrift/archive/v0.13.0.tar.gz", + "https://github.com/apache/thrift/archive/v0.13.0.tar.gz", ], ) http_archive( name = "arrow", build_file = "//third_party:arrow.BUILD", - sha256 = "d7b3838758a365c8c47d55ab0df1006a70db951c6964440ba354f81f518b8d8d", - strip_prefix = "arrow-apache-arrow-0.16.0", + patch_cmds = [ + # TODO: Remove the fowllowing once arrow issue is resolved. + """sed -i.bak 's/type_traits/std::max(sizeof(int16_t), type_traits/g' cpp/src/parquet/column_reader.cc""", + """sed -i.bak 's/value_byte_size/value_byte_size)/g' cpp/src/parquet/column_reader.cc""", + ], + sha256 = "fc461c4f0a60e7470a7c58b28e9344aa8fb0be5cc982e9658970217e084c3a82", + strip_prefix = "arrow-apache-arrow-3.0.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-0.16.0.tar.gz", - "https://github.com/apache/arrow/archive/apache-arrow-0.16.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-3.0.0.tar.gz", + "https://github.com/apache/arrow/archive/apache-arrow-3.0.0.tar.gz", ], ) @@ -305,22 +310,22 @@ http_archive( http_archive( name = "openjpeg", build_file = "//third_party:openjpeg.BUILD", - sha256 = "63f5a4713ecafc86de51bfad89cc07bb788e9bba24ebbf0c4ca637621aadb6a9", - strip_prefix = "openjpeg-2.3.1", + sha256 = "8702ba68b442657f11aaeb2b338443ca8d5fb95b0d845757968a7be31ef7f16d", + strip_prefix = "openjpeg-2.4.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/uclouvain/openjpeg/archive/v2.3.1.tar.gz", - "https://github.com/uclouvain/openjpeg/archive/v2.3.1.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/uclouvain/openjpeg/archive/v2.4.0.tar.gz", + "https://github.com/uclouvain/openjpeg/archive/v2.4.0.tar.gz", ], ) http_archive( name = "libtiff", build_file = "//third_party:libtiff.BUILD", - sha256 = "5d29f32517dadb6dbcd1255ea5bbc93a2b54b94fbf83653b4d65c7d6775b8634", - strip_prefix = "tiff-4.1.0", + sha256 = "eb0484e568ead8fa23b513e9b0041df7e327f4ee2d22db5a533929dfc19633cb", + strip_prefix = "tiff-4.2.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/download.osgeo.org/libtiff/tiff-4.1.0.tar.gz", - "https://download.osgeo.org/libtiff/tiff-4.1.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/download.osgeo.org/libtiff/tiff-4.2.0.tar.gz", + "https://download.osgeo.org/libtiff/tiff-4.2.0.tar.gz", ], ) @@ -395,11 +400,11 @@ http_archive( http_archive( name = "com_google_absl", - sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a", - strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d", + sha256 = "62c27e7a633e965a2f40ff16b487c3b778eae440bab64cad83b34ef1cbe3aa93", + strip_prefix = 
"abseil-cpp-6f9d96a1f41439ac172ee2ef7ccd8edf0e5d068c", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/6f9d96a1f41439ac172ee2ef7ccd8edf0e5d068c.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/6f9d96a1f41439ac172ee2ef7ccd8edf0e5d068c.tar.gz", ], ) @@ -419,32 +424,32 @@ http_archive( http_archive( name = "curl", build_file = "//third_party:curl.BUILD", - sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5", - strip_prefix = "curl-7.60.0", + sha256 = "01ae0c123dee45b01bbaef94c0bc00ed2aec89cb2ee0fd598e0d302a6b5e0a98", + strip_prefix = "curl-7.69.1", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/curl.haxx.se/download/curl-7.60.0.tar.gz", - "https://curl.haxx.se/download/curl-7.60.0.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/curl.haxx.se/download/curl-7.69.1.tar.gz", + "https://curl.haxx.se/download/curl-7.69.1.tar.gz", ], ) http_archive( name = "com_github_google_flatbuffers", - sha256 = "12a13686cab7ffaf8ea01711b8f55e1dbd3bf059b7c46a25fefa1250bdd9dd23", - strip_prefix = "flatbuffers-b99332efd732e6faf60bb7ce1ce5902ed65d5ba3", + sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45", + strip_prefix = "flatbuffers-1.12.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/b99332efd732e6faf60bb7ce1ce5902ed65d5ba3.tar.gz", - "https://github.com/google/flatbuffers/archive/b99332efd732e6faf60bb7ce1ce5902ed65d5ba3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz", ], ) http_archive( name = "xz", build_file = "//third_party:xz.BUILD", - sha256 = "b512f3b726d3b37b6dc4c8570e137b9311e7552e8ccbab4d39d47ce5f4177145", - strip_prefix = "xz-5.2.4", + sha256 = "0d2b89629f13dd1a0602810529327195eff5f62a0142ccd65b903bc16a4ac78a", + strip_prefix = "xz-5.2.5", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/tukaani.org/xz/xz-5.2.4.tar.gz", - "https://tukaani.org/xz/xz-5.2.4.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/xz-mirror/xz/archive/v5.2.5.tar.gz", + "https://github.com/xz-mirror/xz/archive/v5.2.5.tar.gz", ], ) @@ -584,11 +589,11 @@ http_archive( "@com_github_curl_curl": "@curl", "@com_github_nlohmann_json": "@nlohmann_json_lib", }, - sha256 = "ff82045b9491f0d880fc8e5c83fd9542eafb156dcac9ff8c6209ced66ed2a7f0", - strip_prefix = "google-cloud-cpp-1.17.1", + sha256 = "14bf9bf97431b890e0ae5dca8f8904841d4883b8596a7108a42f5700ae58d711", + strip_prefix = "google-cloud-cpp-1.21.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/google-cloud-cpp/archive/v1.17.1.tar.gz", - "https://github.com/googleapis/google-cloud-cpp/archive/v1.17.1.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/google-cloud-cpp/archive/v1.21.0.tar.gz", + "https://github.com/googleapis/google-cloud-cpp/archive/v1.21.0.tar.gz", ], ) @@ -676,7 +681,7 @@ http_archive( patches = [ "//third_party:libapr1.patch", ], - sha256 = "1a0909a1146a214a6ab9de28902045461901baab4e0ee43797539ec05b6dbae0", + sha256 = 
"096968a363b2374f7450a3c65f3cc0b50561204a8da7bc03a2c39e080febd6e1", strip_prefix = "apr-1.6.5", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/apr/archive/1.6.5.tar.gz", @@ -1105,3 +1110,14 @@ http_archive( "https://github.com/mongodb/mongo-c-driver/releases/download/1.16.2/mongo-c-driver-1.16.2.tar.gz", ], ) + +http_archive( + name = "tinyobjloader", + build_file = "//third_party:tinyobjloader.BUILD", + sha256 = "b8c972dfbbcef33d55554e7c9031abe7040795b67778ad3660a50afa7df6ec56", + strip_prefix = "tinyobjloader-2.0.0rc8", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz", + "https://github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz", + ], +) diff --git a/docs/development.md b/docs/development.md new file mode 100644 index 000000000..5c43a509a --- /dev/null +++ b/docs/development.md @@ -0,0 +1,351 @@ + +## Development + +The document contains the necessary information for setting up the development environement +and building the `tensorflow-io` package from source on various platforms. + +### IDE Setup + +For instructions on how to configure Visual Studio Code for developing TensorFlow I/O, please refer to this [doc](https://github.com/tensorflow/io/blob/master/docs/vscode.md). + +### Lint + +TensorFlow I/O's code conforms to Bazel Buildifier, Clang Format, Black, and Pyupgrade. +Please use the following command to check the source code and identify lint issues: +``` +$ bazel run //tools/lint:check +``` + +For Bazel Buildifier and Clang Format, the following command will automatically identify +and fix any lint errors: +``` +$ bazel run //tools/lint:lint +``` + +Alternatively, if you only want to perform lint check using individual linters, +then you can selectively pass `black`, `pyupgrade`, `bazel`, or `clang` to the above commands. + +For example, a `black` specific lint check can be done using: +``` +$ bazel run //tools/lint:check -- black +``` + +Lint fix using Bazel Buildifier and Clang Format can be done using: +``` +$ bazel run //tools/lint:lint -- bazel clang +``` + +Lint check using `black` and `pyupgrade` for an individual python file can be done using: +``` +$ bazel run //tools/lint:check -- black pyupgrade -- tensorflow_io/core/python/ops/version_ops.py +``` + +Lint fix an individual python file with black and pyupgrade using: +``` +$ bazel run //tools/lint:lint -- black pyupgrade -- tensorflow_io/core/python/ops/version_ops.py +``` + +### Python + +#### macOS + +On macOS Catalina 10.15.7, it is possible to build tensorflow-io with +system provided python 3.8.2. Both `tensorflow` and `bazel` are needed to do so. + +NOTE: The system default python 3.8.2 on macOS 10.15.7 will cause `regex` installation +error caused by compiler option of `-arch arm64 -arch x86_64` (similar to the issue +mentioned in https://github.com/giampaolo/psutil/issues/1832). To overcome this issue +`export ARCHFLAGS="-arch x86_64"` will be needed to remove arm64 build option. + +```sh +#!/usr/bin/env bash + +# Disable arm64 build by specifying only x86_64 arch. 
+# Only needed for macOS's system default python 3.8.2 on macOS 10.15.7 +export ARCHFLAGS="-arch x86_64" + +# Use the following command to check if Xcode is correctly installed: +xcodebuild -version + +# Show macOS's default python3 +python3 --version + +# Install Bazel version specified in .bazelversion +curl -OL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-darwin-x86_64.sh +sudo bash -x -e bazel-$(cat .bazelversion)-installer-darwin-x86_64.sh + +# Install tensorflow and configure bazel +sudo ./configure.sh + +# Add any optimization on bazel command, e.g., --compilation_mode=opt, +# --copt=-msse4.2, --remote_cache=, etc. +# export BAZEL_OPTIMIZATION= + +# Build shared libraries +bazel build -s --verbose_failures $BAZEL_OPTIMIZATION //tensorflow_io/... + +# Once build is complete, shared libraries will be available in +# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible +# to run tests with `pytest`, e.g.: +sudo python3 -m pip install pytest +TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization.py +``` + +NOTE: When running pytest, `TFIO_DATAPATH=bazel-bin` has to be passed so that python can utilize the generated shared libraries after the build process. + +##### Troubleshoot + +If Xcode is installed, but `$ xcodebuild -version` is not displaying the expected output, you might need to enable the Xcode command line tools with the command: + +`$ xcode-select -s /Applications/Xcode.app/Contents/Developer`. + +A terminal restart might be required for the changes to take effect. + +Sample output: + +``` +$ xcodebuild -version +Xcode 12.2 +Build version 12B45b +``` + +#### Linux + +Development of tensorflow-io on Linux is similar to macOS. The required packages +are gcc, g++, git, bazel, and python 3. Newer versions of gcc or python than the default system-installed +versions might be required, though. + +##### Ubuntu 20.04 + +Ubuntu 20.04 requires gcc/g++, git, and python 3. The following will install dependencies and build +the shared libraries on Ubuntu 20.04: +```sh +#!/usr/bin/env bash + +# Install gcc/g++, git, unzip/curl (for bazel), and python3 +sudo apt-get -y -qq update +sudo apt-get -y -qq install gcc g++ git unzip curl python3-pip + +# Install Bazel version specified in .bazelversion +curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh +sudo bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh + +# Upgrade pip +sudo python3 -m pip install -U pip + +# Install tensorflow and configure bazel +sudo ./configure.sh + +# Add any optimization on bazel command, e.g., --compilation_mode=opt, +# --copt=-msse4.2, --remote_cache=, etc. +# export BAZEL_OPTIMIZATION= + +# Build shared libraries +bazel build -s --verbose_failures $BAZEL_OPTIMIZATION //tensorflow_io/... + +# Once build is complete, shared libraries will be available in +# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible +# to run tests with `pytest`, e.g.: +sudo python3 -m pip install pytest +TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization.py +``` + +##### CentOS 8 + +The steps to build the shared libraries on CentOS 8 are similar to Ubuntu 20.04 above, +except that +``` +sudo yum install -y python3 python3-devel gcc gcc-c++ git unzip which make +``` +should be used instead to install gcc/g++, git, unzip/which (for bazel), and python3. 
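For either of the Linux setups above, once the Bazel build has finished, a quick sanity check is to import the package from the repository root with `TFIO_DATAPATH` pointing at the Bazel output, mirroring the pytest invocation above (a minimal sketch; it only prints the locally built version string):

```sh
# Run from the repository root after `bazel build` has completed.
# TFIO_DATAPATH tells tensorflow-io where to find the generated shared libraries.
TFIO_DATAPATH=bazel-bin python3 -c 'import tensorflow_io as tfio; print(tfio.version.VERSION)'
```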
+ +##### CentOS 7 + +On CentOS 7, the default python and gcc versions are too old to build tensorflow-io's shared +libraries (.so). The gcc provided by Developer Toolset and rh-python36 should be used instead. +Also, libstdc++ has to be linked statically to avoid a mismatch between the libstdc++ installed on +CentOS and the newer gcc version provided by devtoolset. + +Furthermore, a special flag `--//tensorflow_io/core:static_build` has to be passed to Bazel +in order to avoid duplication of symbols in statically linked libraries for file system +plugins. + +The following will install bazel, devtoolset-9, rh-python36, and build the shared libraries: +```sh +#!/usr/bin/env bash + +# Install centos-release-scl, then install gcc/g++ (devtoolset), git, and python 3 +sudo yum install -y centos-release-scl +sudo yum install -y devtoolset-9 git rh-python36 make + +# Install Bazel version specified in .bazelversion +curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh +sudo bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh + +# Upgrade pip +scl enable rh-python36 devtoolset-9 \ + 'python3 -m pip install -U pip' + +# Install tensorflow and configure bazel with rh-python36 +scl enable rh-python36 devtoolset-9 \ + './configure.sh' + +# Add any optimization on bazel command, e.g., --compilation_mode=opt, +# --copt=-msse4.2, --remote_cache=, etc. +# export BAZEL_OPTIMIZATION= + +# Build shared libraries, notice the passing of --//tensorflow_io/core:static_build +BAZEL_LINKOPTS="-static-libstdc++ -static-libgcc" BAZEL_LINKLIBS="-lm -l%:libstdc++.a" \ + scl enable rh-python36 devtoolset-9 \ + 'bazel build -s --verbose_failures $BAZEL_OPTIMIZATION --//tensorflow_io/core:static_build //tensorflow_io/...' + +# Once build is complete, shared libraries will be available in +# `bazel-bin/tensorflow_io/core/python/ops/` and it is possible +# to run tests with `pytest`, e.g.: +scl enable rh-python36 devtoolset-9 \ + 'python3 -m pip install pytest' + +TFIO_DATAPATH=bazel-bin \ + scl enable rh-python36 devtoolset-9 \ + 'python3 -m pytest -s -v tests/test_serialization.py' +``` + +#### Python Wheels + +It is possible to build python wheels after the bazel build is complete with the following command: +``` +$ python3 setup.py bdist_wheel --data bazel-bin +``` +The .whl file will be available in the dist directory. Note that the bazel binary directory `bazel-bin` +has to be passed with the `--data` arg in order for setup.py to locate the necessary shared objects, +as `bazel-bin` is outside of the `tensorflow_io` package directory. + +Alternatively, a source install can be done with: +``` +$ TFIO_DATAPATH=bazel-bin python3 -m pip install . +``` +with `TFIO_DATAPATH=bazel-bin` passed for the same reason. + +Note that installing with `-e` is different from the above. The +``` +$ TFIO_DATAPATH=bazel-bin python3 -m pip install -e . +``` +will not install the shared objects automatically even with `TFIO_DATAPATH=bazel-bin`. Instead, +`TFIO_DATAPATH=bazel-bin` has to be passed every time the program is run after the install: +``` +$ TFIO_DATAPATH=bazel-bin python3 + +>>> import tensorflow_io as tfio +>>> ... +``` + +#### Docker + +For Python development, a reference Dockerfile [here](tools/docker/devel.Dockerfile) can be +used to build the TensorFlow I/O package (`tensorflow-io`) from source. 
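If you prefer to build the devel image yourself rather than pulling it, a build along these lines may work (a sketch only; the `tfio-dev-local` tag is an arbitrary name chosen here, and the repository root is assumed to be the build context):

```sh
# Build a local development image from the reference Dockerfile in the repository.
docker build -t tfio-dev-local -f tools/docker/devel.Dockerfile .
```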
Additionally, the +pre-built devel images can be used as well: +```sh +# Pull (if necessary) and start the devel container +$ docker run -it --rm --name tfio-dev --net=host -v ${PWD}:/v -w /v tfsigio/tfio:latest-devel bash + +# Inside the docker container, ./configure.sh will install TensorFlow or use existing install +(tfio-dev) root@docker-desktop:/v$ ./configure.sh + +# Clean up existing bazel builds (if any) +(tfio-dev) root@docker-desktop:/v$ rm -rf bazel-* + +# Build TensorFlow I/O C++. For compilation optimization flags, the default (-march=native) +# optimizes the generated code for your machine's CPU type. +# Reference: https://www.tensorflow.org/install/source#configuration_options + +# NOTE: Based on the available resources, please change the number of job workers to: +# -j 4/8/16 to prevent bazel server terminations and resource-oriented build errors. + +(tfio-dev) root@docker-desktop:/v$ bazel build -j 8 --copt=-msse4.2 --copt=-mavx --compilation_mode=opt --verbose_failures --test_output=errors --crosstool_top=//third_party/toolchains/gcc7_manylinux2010:toolchain //tensorflow_io/... + + +# Run tests with PyTest, note: some tests require launching additional containers to run (see below) +(tfio-dev) root@docker-desktop:/v$ pytest -s -v tests/ +# Build the TensorFlow I/O package +(tfio-dev) root@docker-desktop:/v$ python setup.py bdist_wheel +``` + +A package file `dist/tensorflow_io-*.whl` will be generated after a build is successful. + +NOTE: When working in the Python development container, an environment variable +`TFIO_DATAPATH` is automatically set to point tensorflow-io to the shared C++ +libraries built by Bazel to run `pytest` and build the `bdist_wheel`. Python +`setup.py` can also accept `--data [path]` as an argument, for example +`python setup.py --data bazel-bin bdist_wheel`. + +NOTE: While the tfio-dev container gives developers an easy-to-work-with +environment, the released whl packages are built differently due to manylinux2010 +requirements. Please check the [Build Status and CI] section for more details +on how the released whl packages are generated. + +#### Testing + +Some tests require launching a test container or starting a local instance +of the associated tool before running. For example, to run kafka-related +tests, which will start a local instance of kafka, zookeeper and schema-registry, +use: + +```sh +# Start the local instances of kafka, zookeeper and schema-registry +$ bash -x -e tests/test_kafka/kafka_test.sh + +# Run the tests +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_kafka.py +``` + +Testing `Datasets` associated with tools such as `Elasticsearch` or `MongoDB` +requires docker to be available on the system. In such scenarios, use: + + +```sh +# Start elasticsearch within docker container +$ bash tests/test_elasticsearch/elasticsearch_test.sh start + +# Run the tests +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_elasticsearch.py + +# Stop and remove the container +$ bash tests/test_elasticsearch/elasticsearch_test.sh stop +``` + +Additionally, testing some features of `tensorflow-io` doesn't require you to spin up +any additional tools as the data has been provided in the `tests` directory itself. +For example, to run tests related to the `parquet` dataset, use: + +```sh +# Just run the test +$ TFIO_DATAPATH=bazel-bin pytest -s -vv tests/test_parquet.py +``` + + +### R + +We provide a reference Dockerfile [here](R-package/scripts/Dockerfile) for you +so that you can use the R package directly for testing. 
You can build it via: +```sh +$ docker build -t tfio-r-dev -f R-package/scripts/Dockerfile . +``` + +Inside the container, you can start your R session, instantiate a `SequenceFileDataset` +from an example [Hadoop SequenceFile](https://wiki.apache.org/hadoop/SequenceFile) +[string.seq](R-package/tests/testthat/testdata/string.seq), and then use any [transformation functions](https://tensorflow.rstudio.com/tools/tfdatasets/articles/introduction.html#transformations) provided by [tfdatasets package](https://tensorflow.rstudio.com/tools/tfdatasets/) on the dataset like the following: + +```r +library(tfio) +dataset <- sequence_file_dataset("R-package/tests/testthat/testdata/string.seq") %>% + dataset_repeat(2) + +sess <- tf$Session() +iterator <- make_iterator_one_shot(dataset) +next_batch <- iterator_get_next(iterator) + +until_out_of_range({ + batch <- sess$run(next_batch) + print(batch) +}) +``` diff --git a/docs/tutorials/_toc.yaml b/docs/tutorials/_toc.yaml index 3ec2d2f26..1c2ee891d 100644 --- a/docs/tutorials/_toc.yaml +++ b/docs/tutorials/_toc.yaml @@ -34,3 +34,6 @@ toc: path: /io/tutorials/kafka - title: "Elasticsearch" path: /io/tutorials/elasticsearch +- title: "Avro" + path: /io/tutorials/avro + diff --git a/docs/tutorials/avro.ipynb b/docs/tutorials/avro.ipynb new file mode 100644 index 000000000..3d7271911 --- /dev/null +++ b/docs/tutorials/avro.ipynb @@ -0,0 +1,567 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Tce3stUlHN0L" + }, + "source": [ + "##### Copyright 2020 The TensorFlow IO Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "tuOe1ymfHZPu" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qFdPvlXBOdUN" + }, + "source": [ + "# Avro Dataset API" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MfBg1C5NB3X0" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xHxb-dlhMIzW" + }, + "source": [ + "## Overview\n", + "\n", + "The objective of Avro Dataset API is to load Avro formatted data natively into TensorFlow as TensorFlow dataset. Avro is a data serialization system similiar to Protocol Buffers. It's widely used in Apache Hadoop where it can provide both a serialization format for persistent data, and a wire format for communication between Hadoop nodes. Avro data is a row-oriented, compacted binary data format. It relies on schema which is stored as a separate JSON file. For the spec of Avro format and schema declaration, please refer to the official manual.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MUXex9ctTuDB" + }, + "source": [ + "## Setup package\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upgCc3gXybsA" + }, + "source": [ + "### Install the required tensorflow-io package" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uUDYyMZRfkX4" + }, + "outputs": [], + "source": [ + "!pip install tensorflow-io" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gjrZNJQRJP-U" + }, + "source": [ + "### Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "m6KXZuTBWgRm" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_io as tfio\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eCgO11GTJaTj" + }, + "source": [ + "### Validate tf and tfio imports" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "dX74RKfZ_TdF" + }, + "outputs": [], + "source": [ + "print(\"tensorflow-io version: {}\".format(tfio.__version__))\n", + "print(\"tensorflow version: {}\".format(tf.__version__))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J0ZKhA6s0Pjp" + }, + "source": [ + "## Usage" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4CfKVmCvwcL7" + }, + "source": [ + "### Explore the dataset\n", + "\n", + "For the purpose of this tutorial, let's download the sample Avro dataset. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IGnbXuVnSo8T" + }, + "source": [ + "Download a sample Avro file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tu01THzWcE-J" + }, + "outputs": [], + "source": [ + "!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avro\n", + "!ls -l train.avro" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IGnbXuVnSo8T" + }, + "source": [ + "Download the corresponding schema file of the sample Avro file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tu01THzWcE-J" + }, + "outputs": [], + "source": [ + "!curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avsc\n", + "!ls -l train.avsc" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9GCyPWNuOm7" + }, + "source": [ + "In the above example, a testing Avro dataset were created based on mnist dataset. The original mnist dataset in TFRecord format is generated from TF named dataset. However, the mnist dataset is too large as a demo dataset. For simplicity purpose, most of it were trimmed and first few records only were kept. Moreover, additional trimming was done for `image` field in original mnist dataset and mapped it to `features` field in Avro. 
So the avro file `train.avro` has 4 records, each of which has 3 fields: `features`, which is an array of int, `label`, an int or null, and `dataType`, an enum. To view the decoded `train.avro` (Note the original avro data file is not human readable as avro is a compacted format):\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upgCc3gXybsB" + }, + "source": [ + "Install the required package to read Avro file:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O4" + }, + "outputs": [], + "source": [ + "!pip install avro\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upgCc3gXybsB" + }, + "source": [ + "To read and print an Avro file in a human-readable format:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "from avro.io import DatumReader\n", + "from avro.datafile import DataFileReader\n", + "\n", + "import json\n", + "\n", + "def print_avro(avro_file, max_record_num=None):\n", + " if max_record_num is not None and max_record_num <= 0:\n", + " return\n", + "\n", + " with open(avro_file, 'rb') as avro_handler:\n", + " reader = DataFileReader(avro_handler, DatumReader())\n", + " record_count = 0\n", + " for record in reader:\n", + " record_count = record_count+1\n", + " print(record)\n", + " if max_record_num is not None and record_count == max_record_num:\n", + " break\n", + "\n", + "print_avro(avro_file='train.avro')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9GCyPWNuOm7" + }, + "source": [ + "And the schema of `train.avro` which is represented by `train.avsc` is a JSON-formatted file.\n", + "To view the `train.avsc`: \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "def print_schema(avro_schema_file):\n", + " with open(avro_schema_file, 'r') as handle:\n", + " parsed = json.load(handle)\n", + " print(json.dumps(parsed, indent=4, sort_keys=True))\n", + "\n", + "print_schema('train.avsc')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4CfKVmCvwcL7" + }, + "source": [ + "### Prepare the dataset\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9GCyPWNuOm7" + }, + "source": [ + "Load `train.avro` as TensorFlow dataset with Avro dataset API: \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "features = {\n", + " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),\n", + " 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),\n", + " 'dataType': tf.io.FixedLenFeature(shape=[], dtype=tf.string)\n", + "}\n", + "\n", + "schema = tf.io.gfile.GFile('train.avsc').read()\n", + "\n", + "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n", + " reader_schema=schema,\n", + " features=features,\n", + " shuffle=False,\n", + " batch_size=3,\n", + " num_epochs=1)\n", + "\n", + "for record in dataset:\n", + " print(record['features[*]'])\n", + " print(record['label'])\n", + " print(record['dataType'])\n", + " print(\"--------------------\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IF_kYz_o2DH4" + }, + "source": [ + "The above example converts `train.avro` into tensorflow dataset. 
Each element of the dataset is a dictionary whose key is the feature name and whose value is the converted sparse or dense tensor. \n", + "E.g., it converts the `features`, `label`, and `dataType` fields to a VarLenFeature(SparseTensor), FixedLenFeature(DenseTensor), and FixedLenFeature(DenseTensor) respectively. Since batch_size is 3, it coerces 3 records from `train.avro` into one element in the result dataset.\n", + "For the first record in `train.avro`, whose label is null, the avro reader replaces it with the specified default value (-100).\n", + "In this example, there are 4 records in total in `train.avro`. Since the batch size is 3, the result dataset contains 2 elements, the last of which has a batch size of 1. However, the user can also drop the last batch if its size is smaller than the batch size by enabling `drop_final_batch`. E.g.: \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n", + " reader_schema=schema,\n", + " features=features,\n", + " shuffle=False,\n", + " batch_size=3,\n", + " drop_final_batch=True,\n", + " num_epochs=1)\n", + "\n", + "for record in dataset:\n", + " print(record)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IF_kYz_o2DH4" + }, + "source": [ + "One can also increase num_parallel_reads to expedite Avro data processing by increasing avro parse/read parallelism.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n", + " reader_schema=schema,\n", + " features=features,\n", + " shuffle=False,\n", + " num_parallel_reads=16,\n", + " batch_size=3,\n", + " drop_final_batch=True,\n", + " num_epochs=1)\n", + "\n", + "for record in dataset:\n", + " print(record)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IF_kYz_o2DH4" + }, + "source": [ + "For detailed usage of `make_avro_record_dataset`, please refer to the API doc.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4CfKVmCvwcL7" + }, + "source": [ + "### Train tf.keras models with Avro dataset\n", + "\n", + "Now let's walk through an end-to-end example of tf.keras model training with the Avro dataset based on the mnist dataset.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9GCyPWNuOm7" + }, + "source": [ + "Load `train.avro` as TensorFlow dataset with Avro dataset API: \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nS3eTBvjt-O5" + }, + "outputs": [], + "source": [ + "features = {\n", + " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32)\n", + "}\n", + "\n", + "schema = tf.io.gfile.GFile('train.avsc').read()\n", + "\n", + "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n", + " reader_schema=schema,\n", + " features=features,\n", + " shuffle=False,\n", + " batch_size=1,\n", + " num_epochs=1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9GCyPWNuOm7" + }, + "source": [ + "Define a simple keras model: \n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "m6KXZuTBWgRm" + }, + "outputs": [], + "source": [ + "def build_and_compile_cnn_model():\n", + " model = tf.keras.Sequential()\n", + " model.compile(optimizer='sgd', loss='mse')\n", + " return 
model\n", + "\n", + "model = build_and_compile_cnn_model()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4CfKVmCvwcL7" + }, + "source": [ + "### Train the keras model with Avro dataset:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "m6KXZuTBWgRm" + }, + "outputs": [], + "source": [ + "model.fit(x=dataset, epochs=1, steps_per_epoch=1, verbose=1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IF_kYz_o2DH4" + }, + "source": [ + "The avro dataset can parse and coerce any avro data into TensorFlow tensors, including records in records, maps, arrays, branches, and enumerations. The parsing information is passed into the avro dataset implementation as a map where \n", + "keys encode how to parse the data \n", + "values encode on how to coerce the data into TensorFlow tensors – deciding the primitive type (e.g. bool, int, long, float, double, string) as well as the tensor type (e.g. sparse or dense). A listing of TensorFlow's parser types (see Table 1) and the coercion of primitive types (Table 2) is provided. \n", + "\n", + "Table 1 the supported TensorFlow parser types:\n", + "\n", + "TensorFlow Parser Types|TensorFlow Tensors|Explanation\n", + "----|----|------\n", + "tf.FixedLenFeature([], tf.int32)|dense tensor|Parse a fixed length feature; that is all rows have the same constant number of elements, e.g. just one element or an array that has always the same number of elements for each row \n", + "tf.SparseFeature(index_key=['key_1st_index', 'key_2nd_index'], value_key='key_value', dtype=tf.int64, size=[20, 50]) |sparse tensor|Parse a sparse feature where each row has a variable length list of indices and values. The 'index_key' identifies the indices. The 'value_key' identifies the value. The 'dtype' is the data type. The 'size' is the expected maximum index value for each index entry\n", + "tfio.experimental.columnar.VarLenFeatureWithRank([],tf.int64) |sparse tensor|Parse a variable length feature; that means each data row can have a variable number of elements, e.g. 
the 1st row has 5 elements, the 2nd row has 7 elements\n", + "\n", + "Table 2 the supported conversion from Avro types to TensorFlow's types:\n", + "\n", + "Avro Primitive Type|TensorFlow Primitive Type\n", + "----|----\n", + "boolean: a binary value|tf.bool\n", + "bytes: a sequence of 8-bit unsigned bytes|tf.string\n", + "double: double precision 64-bit IEEE floating point number|tf.float64\n", + "enum: enumeration type|tf.string using the symbol name\n", + "float: single precision 32-bit IEEE floating point number|tf.float32\n", + "int: 32-bit signed integer|tf.int32\n", + "long: 64-bit signed integer|tf.int64\n", + "null: no value|uses default value\n", + "string: unicode character sequence|tf.string\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IF_kYz_o2DH4" + }, + "source": [ + "A comprehensive set of examples of Avro dataset API is provided within the tests.\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "Tce3stUlHN0L" + ], + "name": "avro.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tutorials/avro/train.avro b/docs/tutorials/avro/train.avro new file mode 100644 index 000000000..35f63a6b6 Binary files /dev/null and b/docs/tutorials/avro/train.avro differ diff --git a/docs/tutorials/avro/train.avsc b/docs/tutorials/avro/train.avsc new file mode 100644 index 000000000..904864b37 --- /dev/null +++ b/docs/tutorials/avro/train.avsc @@ -0,0 +1 @@ +{"name": "ImageDataset", "type": "record", "fields": [{"name": "features", "type": {"type": "array", "items": "int"}}, {"name": "label", "type": ["int", "null"]}, {"name": "dataType", "type": {"type": "enum", "name": "dataTypes", "symbols": ["TRAINING", "VALIDATION"]}}]} \ No newline at end of file diff --git a/docs/tutorials/kafka.ipynb b/docs/tutorials/kafka.ipynb index 51c0ee9a7..a5af9decc 100644 --- a/docs/tutorials/kafka.ipynb +++ b/docs/tutorials/kafka.ipynb @@ -70,11 +70,13 @@ "source": [ "## Overview\n", "\n", - "This tutorial focuses on streaming data from a [Kafka](https://docs.confluent.io/current/getting-started.html) cluster into a `tf.data.Dataset` which is then used in conjunction with `tf.keras` for training and inference.\n", + "This tutorial focuses on streaming data from a [Kafka](https://kafka.apache.org/quickstart) cluster into a `tf.data.Dataset` which is then used in conjunction with `tf.keras` for training and inference.\n", "\n", "Kafka is primarily a distributed event-streaming platform which provides scalable and fault-tolerant streaming data across data pipelines. It is an essential technical component of a plethora of major enterprises where mission-critical data delivery is a primary requirement.\n", "\n", - "**NOTE:** A basic understanding of the [kafka components](https://docs.confluent.io/current/kafka/introduction.html) will help you in following the tutorial with ease." + "**NOTE:** A basic understanding of the [kafka components](https://kafka.apache.org/documentation/#intro_concepts_and_terms) will help you in following the tutorial with ease.\n", + "\n", + "**NOTE:** A Java runtime environment is required to run this tutorial." 
] }, { @@ -180,8 +182,8 @@ }, "outputs": [], "source": [ - "!curl -sSOL http://packages.confluent.io/archive/5.4/confluent-community-5.4.1-2.12.tar.gz\n", - "!tar -xzf confluent-community-5.4.1-2.12.tar.gz" + "!curl -sSOL https://downloads.apache.org/kafka/2.7.0/kafka_2.13-2.7.0.tgz\n", + "!tar -xzf kafka_2.13-2.7.0.tgz" ] }, { @@ -190,7 +192,7 @@ "id": "vAzfu_WiEs4F" }, "source": [ - "Using the default configurations (provided by the confluent package) for spinning up the instances." + "Using the default configurations (provided by Apache Kafka) for spinning up the instances." ] }, { @@ -201,8 +203,8 @@ }, "outputs": [], "source": [ - "!cd confluent-5.4.1 && bin/zookeeper-server-start -daemon etc/kafka/zookeeper.properties\n", - "!cd confluent-5.4.1 && bin/kafka-server-start -daemon etc/kafka/server.properties\n", + "!./kafka_2.13-2.7.0/bin/zookeeper-server-start.sh -daemon ./kafka_2.13-2.7.0/config/zookeeper.properties\n", + "!./kafka_2.13-2.7.0/bin/kafka-server-start.sh -daemon ./kafka_2.13-2.7.0/config/server.properties\n", "!echo \"Waiting for 10 secs until kafka and zookeeper services are up and running\"\n", "!sleep 10\n" ] @@ -247,8 +249,8 @@ }, "outputs": [], "source": [ - "!confluent-5.4.1/bin/kafka-topics --create --zookeeper 127.0.0.1:2181 --replication-factor 1 --partitions 1 --topic susy-train\n", - "!confluent-5.4.1/bin/kafka-topics --create --zookeeper 127.0.0.1:2181 --replication-factor 1 --partitions 2 --topic susy-test\n" + "!./kafka_2.13-2.7.0/bin/kafka-topics.sh --create --bootstrap-server 127.0.0.1:9092 --replication-factor 1 --partitions 1 --topic susy-train\n", + "!./kafka_2.13-2.7.0/bin/kafka-topics.sh --create --bootstrap-server 127.0.0.1:9092 --replication-factor 1 --partitions 2 --topic susy-test\n" ] }, { @@ -268,8 +270,8 @@ }, "outputs": [], "source": [ - "!confluent-5.4.1/bin/kafka-topics --bootstrap-server 127.0.0.1:9092 --describe --topic susy-train\n", - "!confluent-5.4.1/bin/kafka-topics --bootstrap-server 127.0.0.1:9092 --describe --topic susy-test\n" + "!./kafka_2.13-2.7.0/bin/kafka-topics.sh --describe --bootstrap-server 127.0.0.1:9092 --topic susy-train\n", + "!./kafka_2.13-2.7.0/bin/kafka-topics.sh --describe --bootstrap-server 127.0.0.1:9092 --topic susy-test\n" ] }, { @@ -720,7 +722,7 @@ }, "outputs": [], "source": [ - "!confluent-5.4.1/bin/kafka-consumer-groups --bootstrap-server 127.0.0.1:9092 --describe --group testcg\n" + "!./kafka_2.13-2.7.0/bin/kafka-consumer-groups.sh --bootstrap-server 127.0.0.1:9092 --describe --group testcg\n" ] }, { @@ -753,7 +755,7 @@ "source": [ "### The tfio training dataset for online learning\n", "\n", - "The `streaming.KafkaBatchIODataset` is similar to the `streaming.KafkaGroupIODataset` in it's API. Additionally, it is recommended to utilize the `stream_timeout` parameter to configure the duration for which the dataset will block for new messages before timing out. In the instance below, the dataset is configured with a `stream_timeout` of `30000` milliseconds. This implies that, after all the messages from the topic have been consumed, the dataset will wait for an additional 30 seconds before timing out and disconnecting from the kafka cluster. If new messages are streamed into the topic before timing out, the data consumption and model training resumes for those newly consumed data points. To block indefinitely, set it to `-1`." + "The `streaming.KafkaBatchIODataset` is similar to the `streaming.KafkaGroupIODataset` in it's API. 
Additionally, it is recommended to utilize the `stream_timeout` parameter to configure the duration for which the dataset will block for new messages before timing out. In the instance below, the dataset is configured with a `stream_timeout` of `10000` milliseconds. This implies that, after all the messages from the topic have been consumed, the dataset will wait for an additional 10 seconds before timing out and disconnecting from the kafka cluster. If new messages are streamed into the topic before timing out, the data consumption and model training resumes for those newly consumed data points. To block indefinitely, set it to `-1`." ] }, { @@ -768,7 +770,7 @@ " topics=[\"susy-train\"],\n", " group_id=\"cgonline\",\n", " servers=\"127.0.0.1:9092\",\n", - " stream_timeout=30000, # in milliseconds, to block indefinitely, set it to -1.\n", + " stream_timeout=10000, # in milliseconds, to block indefinitely, set it to -1.\n", " configuration=[\n", " \"session.timeout.ms=7000\",\n", " \"max.poll.interval.ms=8000\",\n", @@ -777,50 +779,6 @@ ")" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "sJronJPnZhyR" - }, - "source": [ - "In addition to training the model on existing data, a background thread will be started, which will start streaming additional data into the `susy-train` topic after a sleep duration of 30 seconds. This demonstrates the functionality of resuming the training as soons as new data is fed into the topic without the need for building the dataset over and over again." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iaBjhFkmZd1C" - }, - "outputs": [], - "source": [ - "def error_callback(exc):\n", - " raise Exception('Error while sendig data to kafka: {0}'.format(str(exc)))\n", - "\n", - "def write_to_kafka_after_sleep(topic_name, items):\n", - " time.sleep(30)\n", - " print(\"#\"*100)\n", - " print(\"Writing messages into topic: {0} after a nice sleep !\".format(topic_name))\n", - " print(\"#\"*100)\n", - " count=0\n", - " producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092'])\n", - " for message, key in items:\n", - " producer.send(topic_name,\n", - " key=key.encode('utf-8'),\n", - " value=message.encode('utf-8')\n", - " ).add_errback(error_callback)\n", - " count+=1\n", - " producer.flush()\n", - " print(\"#\"*100)\n", - " print(\"Wrote {0} messages into topic: {1}\".format(count, topic_name))\n", - " print(\"#\"*100)\n", - "\n", - "def decode_kafka_online_item(raw_message, raw_key):\n", - " message = tf.io.decode_csv(raw_message, [[0.0] for i in range(NUM_COLUMNS)])\n", - " key = tf.strings.to_number(raw_key)\n", - " return (message, key)\n" - ] - }, { "cell_type": "markdown", "metadata": { @@ -838,11 +796,11 @@ }, "outputs": [], "source": [ - "thread = threading.Thread(target=write_to_kafka_after_sleep,\n", - " args=(\"susy-train\", zip(x_train, y_train)))\n", - "thread.daemon = True\n", - "thread.start()\n", - "\n", + "def decode_kafka_online_item(raw_message, raw_key):\n", + " message = tf.io.decode_csv(raw_message, [[0.0] for i in range(NUM_COLUMNS)])\n", + " key = tf.strings.to_number(raw_key)\n", + " return (message, key)\n", + " \n", "for mini_ds in online_train_ds:\n", " mini_ds = mini_ds.shuffle(buffer_size=32)\n", " mini_ds = mini_ds.map(decode_kafka_online_item)\n", diff --git a/docs/tutorials/postgresql.ipynb b/docs/tutorials/postgresql.ipynb index b5b3d7f35..c5d9426b3 100644 --- a/docs/tutorials/postgresql.ipynb +++ b/docs/tutorials/postgresql.ipynb @@ -275,14 +275,7 @@ "id": "8y-VpwcWNYTF" }, 
"source": [ - "As you could see from the output of `dataset.element_spec` above, the element of the created `Dataset` is a python dict object with column names of the database table as keys:\n", - "```\n", - "{\n", - " 'co': TensorSpec(shape=(), dtype=tf.float32, name=None),\n", - " 'pt08s1': TensorSpec(shape=(), dtype=tf.int32, name=None),\n", - "}\n", - "```\n", - "\n", + "As you could see from the output of `dataset.element_spec` above, the element of the created `Dataset` is a python dict object with column names of the database table as keys.\n", "It is quite convenient to apply further operations. For example, you could select both `nox` and `no2` field of the `Dataset`, and calculate the difference:" ] }, diff --git a/setup.py b/setup.py index 74a891891..8e7835c4f 100644 --- a/setup.py +++ b/setup.py @@ -134,7 +134,7 @@ def has_ext_modules(self): ], keywords="tensorflow io machine learning", packages=setuptools.find_packages(where=".", exclude=["tests"]), - python_requires=">=3.5, <3.9", + python_requires=">=3.5, <3.10", install_requires=[package], package_data={".": ["*.so"],}, project_urls={ diff --git a/tensorflow_io/arrow/README.md b/tensorflow_io/arrow/README.md index e0454cbc9..72dccf3e0 100644 --- a/tensorflow_io/arrow/README.md +++ b/tensorflow_io/arrow/README.md @@ -20,12 +20,11 @@ import tensorflow_io.arrow as arrow_io # Assume `df` is an existing Pandas DataFrame dataset = arrow_io.ArrowDataset.from_pandas(df) -iterator = dataset.make_one_shot_iterator() -next_element = iterator.get_next() +# All `tf.data.Dataset` operations can now be performed, for ex: +dataset = dataset.batch(2) -with tf.Session() as sess: - for i in range(len(df)): - print(sess.run(next_element)) +for row in dataset: + print(row) ``` NOTE: The entire DataFrame will be serialized to the Dataset and is not @@ -59,16 +58,9 @@ dataset = arrow_io.ArrowFeatherDataset( output_types=(tf.int32, tf.float32), output_shapes=([], [])) -iterator = dataset.make_one_shot_iterator() -next_element = iterator.get_next() - # This will iterate over each row of each file provided -with tf.Session() as sess: - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break +for row in dataset: + print(row) ``` An alternate constructor can also be used to infer output types and shapes from @@ -101,17 +93,10 @@ dataset = arrow_io.ArrowStreamDataset( output_types=(tf.int32, tf.float32), output_shapes=([], [])) -iterator = dataset.make_one_shot_iterator() -next_element = iterator.get_next() - # The host connection is made when the Dataset op is run and will iterate over # each row of each record batch until the Arrow stream is finished -with tf.Session() as sess: - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break +for row in dataset: + print(row) ``` An alternate constructor can also be used to infer output types and shapes from diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc index 3b4e41917..6ae8b1457 100644 --- a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc +++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "arrow/api.h" #include "arrow/ipc/api.h" +#include "arrow/result.h" #include "arrow/util/io_util.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/graph/graph.h" @@ -476,12 +477,17 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase { buffer_ = std::make_shared(dataset()->buffer_ptr_, dataset()->buffer_size_); buffer_reader_ = std::make_shared(buffer_); - CHECK_ARROW(arrow::ipc::RecordBatchFileReader::Open( - buffer_reader_.get(), buffer_->size(), &reader_)); + arrow::Result> + result = arrow::ipc::RecordBatchFileReader::Open( + buffer_reader_.get(), buffer_->size()); + CHECK_ARROW(result.status()); + reader_ = std::move(result).ValueUnsafe(); num_batches_ = reader_->num_record_batches(); if (num_batches_ > 0) { - CHECK_ARROW( - reader_->ReadRecordBatch(current_batch_idx_, ¤t_batch_)); + arrow::Result> result = + reader_->ReadRecordBatch(current_batch_idx_); + CHECK_ARROW(result.status()); + current_batch_ = std::move(result).ValueUnsafe(); TF_RETURN_IF_ERROR(CheckBatchColumnTypes(current_batch_)); } return Status::OK(); @@ -491,8 +497,10 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { ArrowBaseIterator::NextStreamLocked(env); if (++current_batch_idx_ < num_batches_) { - CHECK_ARROW( - reader_->ReadRecordBatch(current_batch_idx_, ¤t_batch_)); + arrow::Result> result = + reader_->ReadRecordBatch(current_batch_idx_); + CHECK_ARROW(result.status()); + current_batch_ = std::move(result).ValueUnsafe(); } return Status::OK(); } @@ -604,12 +612,14 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase { const string& batches = dataset()->batches_.scalar()(); auto buffer = std::make_shared(batches); auto buffer_reader = std::make_shared(buffer); - CHECK_ARROW( - arrow::ipc::RecordBatchFileReader::Open(buffer_reader, &reader_)); + auto result = arrow::ipc::RecordBatchFileReader::Open(buffer_reader); + CHECK_ARROW(result.status()); + reader_ = std::move(result).ValueUnsafe(); num_batches_ = reader_->num_record_batches(); if (num_batches_ > 0) { - CHECK_ARROW( - reader_->ReadRecordBatch(current_batch_idx_, ¤t_batch_)); + auto result = reader_->ReadRecordBatch(current_batch_idx_); + CHECK_ARROW(result.status()); + current_batch_ = std::move(result).ValueUnsafe(); TF_RETURN_IF_ERROR(CheckBatchColumnTypes(current_batch_)); } return Status::OK(); @@ -619,8 +629,9 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { ArrowBaseIterator::NextStreamLocked(env); if (++current_batch_idx_ < num_batches_) { - CHECK_ARROW( - reader_->ReadRecordBatch(current_batch_idx_, ¤t_batch_)); + auto result = reader_->ReadRecordBatch(current_batch_idx_); + CHECK_ARROW(result.status()); + current_batch_ = std::move(result).ValueUnsafe(); } return Status::OK(); } @@ -736,14 +747,18 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase { new ArrowRandomAccessFile(tf_file.get(), size)); // Create the Feather reader - std::unique_ptr reader; - CHECK_ARROW(arrow::ipc::feather::TableReader::Open(in_file, &reader)); + std::shared_ptr reader; + arrow::Result> result = + arrow::ipc::feather::Reader::Open(in_file); + CHECK_ARROW(result.status()); + reader = std::move(result).ValueUnsafe(); // Read file columns and build a table - int64_t num_columns = reader->num_columns(); std::shared_ptr<::arrow::Table> table; CHECK_ARROW(reader->Read(&table)); + int64_t num_columns = table->num_columns(); + // Convert the table to a sequence of batches arrow::TableBatchReader 
tr(*table.get()); std::shared_ptr batch; @@ -885,8 +900,10 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase { in_stream_ = socket_stream; } - CHECK_ARROW(arrow::ipc::RecordBatchStreamReader::Open(in_stream_.get(), - &reader_)); + auto result = + arrow::ipc::RecordBatchStreamReader::Open(in_stream_.get()); + CHECK_ARROW(result.status()); + reader_ = std::move(result).ValueUnsafe(); CHECK_ARROW(reader_->ReadNext(¤t_batch_)); TF_RETURN_IF_ERROR(CheckBatchColumnTypes(current_batch_)); return Status::OK(); diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.cc b/tensorflow_io/arrow/kernels/arrow_kernels.cc index 27e64b0bf..08b4007e4 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.cc +++ b/tensorflow_io/arrow/kernels/arrow_kernels.cc @@ -161,10 +161,11 @@ class ArrowReadableFromMemoryInitOp auto buffer_reader = std::make_shared(buffer_); std::shared_ptr schema; - arrow::Status status = - arrow::ipc::ReadSchema(buffer_reader.get(), nullptr, &schema); - OP_REQUIRES(context, status.ok(), + arrow::Result> result = + arrow::ipc::ReadSchema(buffer_reader.get(), nullptr); + OP_REQUIRES(context, result.ok(), errors::Internal("Error reading Arrow Schema")); + schema = std::move(result).ValueUnsafe(); const Tensor* array_buffer_addrs_tensor; OP_REQUIRES_OK(context, context->input("array_buffer_addresses", @@ -429,10 +430,10 @@ class ListFeatherColumnsOp : public OpKernel { ::arrow::ipc::feather::fbs::GetCTable(buffer.data()); OP_REQUIRES(context, - (table->version() >= ::arrow::ipc::feather::kFeatherVersion), + (table->version() >= ::arrow::ipc::feather::kFeatherV1Version), errors::InvalidArgument( "feather file is old: ", table->version(), " vs. ", - ::arrow::ipc::feather::kFeatherVersion)); + ::arrow::ipc::feather::kFeatherV1Version)); std::vector columns; std::vector dtypes; @@ -577,10 +578,10 @@ class FeatherReadable : public IOReadableInterface { const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data()); - if (table->version() < ::arrow::ipc::feather::kFeatherVersion) { + if (table->version() < ::arrow::ipc::feather::kFeatherV1Version) { return errors::InvalidArgument("feather file is old: ", table->version(), " vs. 
", - ::arrow::ipc::feather::kFeatherVersion); + ::arrow::ipc::feather::kFeatherV1Version); } for (size_t i = 0; i < table->columns()->size(); i++) { @@ -683,18 +684,20 @@ class FeatherReadable : public IOReadableInterface { if (feather_file_.get() == nullptr) { feather_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); - arrow::Status s = - arrow::ipc::feather::TableReader::Open(feather_file_, &reader_); - if (!s.ok()) { - return errors::Internal(s.ToString()); + arrow::Result> result = + arrow::ipc::feather::Reader::Open(feather_file_); + if (!result.ok()) { + return errors::Internal(result.status().ToString()); } + reader_ = std::move(result).ValueUnsafe(); } - std::shared_ptr column; - arrow::Status s = reader_->GetColumn(column_index, &column); + std::shared_ptr table; + arrow::Status s = reader_->Read(&table); if (!s.ok()) { return errors::Internal(s.ToString()); } + std::shared_ptr column = table->column(column_index); std::shared_ptr<::arrow::ChunkedArray> slice = column->Slice(element_start, element_stop); @@ -767,7 +770,7 @@ class FeatherReadable : public IOReadableInterface { std::unique_ptr file_ TF_GUARDED_BY(mu_); uint64 file_size_ TF_GUARDED_BY(mu_); std::shared_ptr feather_file_ TF_GUARDED_BY(mu_); - std::unique_ptr reader_ TF_GUARDED_BY(mu_); + std::shared_ptr reader_ TF_GUARDED_BY(mu_); std::vector dtypes_; std::vector shapes_; diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index 2ad9e57b5..4a6b88e43 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -51,8 +51,12 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { return result.size(); } arrow::Result> Read(int64_t nbytes) override { - std::shared_ptr buffer; - RETURN_NOT_OK(AllocateResizableBuffer(nbytes, &buffer)); + arrow::Result> result = + arrow::AllocateResizableBuffer(nbytes); + ARROW_RETURN_NOT_OK(result); + std::shared_ptr buffer = + std::move(result).ValueUnsafe(); + ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data())); RETURN_NOT_OK(buffer->Resize(bytes_read)); diff --git a/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc b/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc index 2bc3d0f2d..c685eb14c 100644 --- a/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc +++ b/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc @@ -132,8 +132,11 @@ arrow::Result ArrowStreamClient::Read(int64_t nbytes, void* out) { arrow::Result> ArrowStreamClient::Read( int64_t nbytes) { - std::shared_ptr buffer; - ARROW_RETURN_NOT_OK(arrow::AllocateResizableBuffer(nbytes, &buffer)); + arrow::Result> result = + arrow::AllocateResizableBuffer(nbytes); + ARROW_RETURN_NOT_OK(result); + std::shared_ptr buffer = + std::move(result).ValueUnsafe(); int64_t bytes_read; ARROW_ASSIGN_OR_RAISE(bytes_read, Read(nbytes, buffer->mutable_data())); ARROW_RETURN_NOT_OK(buffer->Resize(bytes_read, false)); diff --git a/tensorflow_io/arrow/kernels/arrow_stream_client_windows.cc b/tensorflow_io/arrow/kernels/arrow_stream_client_windows.cc index a1c9cd641..0bc1fdacd 100644 --- a/tensorflow_io/arrow/kernels/arrow_stream_client_windows.cc +++ b/tensorflow_io/arrow/kernels/arrow_stream_client_windows.cc @@ -154,8 +154,11 @@ arrow::Result ArrowStreamClient::Read(int64_t nbytes, void* out) { arrow::Result> ArrowStreamClient::Read( int64_t nbytes) { - std::shared_ptr buffer; - ARROW_RETURN_NOT_OK(arrow::AllocateResizableBuffer(nbytes, &buffer)); + arrow::Result> result = 
+ arrow::AllocateResizableBuffer(nbytes); + ARROW_RETURN_NOT_OK(result); + std::shared_ptr buffer = + std::move(result).ValueUnsafe(); int64_t bytes_read; ARROW_ASSIGN_OR_RAISE(bytes_read, Read(nbytes, buffer->mutable_data())); ARROW_RETURN_NOT_OK(buffer->Resize(bytes_read, false)); diff --git a/tensorflow_io/bigquery/kernels/bigquery_dataset_op.cc b/tensorflow_io/bigquery/kernels/bigquery_dataset_op.cc index a6253de99..66a228235 100644 --- a/tensorflow_io/bigquery/kernels/bigquery_dataset_op.cc +++ b/tensorflow_io/bigquery/kernels/bigquery_dataset_op.cc @@ -102,11 +102,12 @@ class BigQueryDatasetOp : public DatasetOpKernel { arrow::ipc::DictionaryMemo dict_memo; arrow::io::BufferReader input(buffer_); - arrow::Status arrow_status = - arrow::ipc::ReadSchema(&input, &dict_memo, &arrow_schema_); - OP_REQUIRES(ctx, arrow_status.ok(), + arrow::Result> result = + arrow::ipc::ReadSchema(&input, &dict_memo); + OP_REQUIRES(ctx, result.ok(), errors::Internal("Error reading Arrow Schema", - arrow_status.message())); + result.status().message())); + arrow_schema_ = std::move(result).ValueUnsafe(); } else { ctx->CtxFailure(errors::InvalidArgument("Invalid data_format")); } diff --git a/tensorflow_io/bigquery/kernels/bigquery_lib.h b/tensorflow_io/bigquery/kernels/bigquery_lib.h index 3164824ea..0e6914077 100644 --- a/tensorflow_io/bigquery/kernels/bigquery_lib.h +++ b/tensorflow_io/bigquery/kernels/bigquery_lib.h @@ -224,12 +224,13 @@ class BigQueryReaderArrowDatasetIterator arrow::io::BufferReader buffer_reader_(buffer_); arrow::ipc::DictionaryMemo dict_memo; - auto arrow_status = - arrow::ipc::ReadRecordBatch(this->dataset()->arrow_schema(), &dict_memo, - &buffer_reader_, &this->record_batch_); - if (!arrow_status.ok()) { - return errors::Internal(arrow_status.ToString()); + auto result = arrow::ipc::ReadRecordBatch( + this->dataset()->arrow_schema(), &dict_memo, + arrow::ipc::IpcReadOptions::Defaults(), &buffer_reader_); + if (!result.ok()) { + return errors::Internal(result.status().ToString()); } + this->record_batch_ = std::move(result).ValueUnsafe(); VLOG(3) << "got record batch, rows:" << record_batch_->num_rows(); diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 3563b0432..f05d9329c 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -463,7 +463,7 @@ cc_library( linkstatic = True, deps = [ "//tensorflow_io/core:dataset_ops", - "@aws-sdk-cpp", + "@aws-sdk-cpp//:kinesis", ], alwayslink = 1, ) @@ -695,6 +695,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "obj_ops", + srcs = [ + "kernels/obj_kernels.cc", + "ops/obj_ops.cc", + ], + copts = tf_io_copts(), + linkstatic = True, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + "@tinyobjloader", + ], + alwayslink = 1, +) + cc_binary( name = "python/ops/libtensorflow_io.so", copts = tf_io_copts(), @@ -717,6 +733,7 @@ cc_binary( "//tensorflow_io/core:parquet_ops", "//tensorflow_io/core:pcap_ops", "//tensorflow_io/core:pulsar_ops", + "//tensorflow_io/core:obj_ops", "//tensorflow_io/core:operation_ops", "//tensorflow_io/core:pubsub_ops", "//tensorflow_io/core:serialization_ops", @@ -734,6 +751,7 @@ cc_binary( "//tensorflow_io/core:genome_ops", "//tensorflow_io/core:optimization", "//tensorflow_io/core:oss_ops", + "//tensorflow_io/gcs:gcs_config_ops", "//tensorflow_io/core/kernels/gsmemcachedfs:gs_memcached_file_system", ], }) + select({ diff --git a/tensorflow_io/core/kernels/audio_video_ogg_kernels.cc 
b/tensorflow_io/core/kernels/audio_video_ogg_kernels.cc index 7b2633c14..2d2d9350f 100644 --- a/tensorflow_io/core/kernels/audio_video_ogg_kernels.cc +++ b/tensorflow_io/core/kernels/audio_video_ogg_kernels.cc @@ -83,7 +83,10 @@ static ov_callbacks OggVorbisCallbacks = { class OggVorbisReadableResource : public AudioReadableResourceBase { public: OggVorbisReadableResource(Env* env) : env_(env) {} - ~OggVorbisReadableResource() {} + ~OggVorbisReadableResource() { + // Cleanup the vorbis file + ov_clear(&ogg_vorbis_file_); + } Status Init(const string& filename, const void* optional_memory, const size_t optional_length) override { @@ -142,8 +145,8 @@ class OggVorbisReadableResource : public AudioReadableResourceBase { long samples_read = 0; long samples_to_read = value->shape().dim_size(0); + float** buffer; while (samples_read < samples_to_read) { - float** buffer; int bitstream = 0; long chunk = ov_read_float(&ogg_vorbis_file_, &buffer, samples_to_read - samples_read, &bitstream); @@ -160,6 +163,7 @@ class OggVorbisReadableResource : public AudioReadableResourceBase { } samples_read += chunk; } + return Status::OK(); } string DebugString() const override { return "OggVorbisReadableResource"; } diff --git a/tensorflow_io/core/kernels/avro/parse_avro_kernels.cc b/tensorflow_io/core/kernels/avro/parse_avro_kernels.cc index 6eecf841c..156d6c490 100644 --- a/tensorflow_io/core/kernels/avro/parse_avro_kernels.cc +++ b/tensorflow_io/core/kernels/avro/parse_avro_kernels.cc @@ -172,7 +172,9 @@ Status ParseAvro(const AvroParserConfig& config, const gtl::ArraySlice& serialized, thread::ThreadPool* thread_pool, AvroResult* result) { DCHECK(result != nullptr); - + using clock = std::chrono::system_clock; + using ms = std::chrono::duration; + const auto before = clock::now(); // Allocate dense output for fixed length dense values // (variable-length dense and sparse and ragged have to be buffered). /* std::vector fixed_len_dense_values(config.dense.size()); @@ -189,6 +191,10 @@ Status ParseAvro(const AvroParserConfig& config, // This parameter affects performance in a big and data-dependent way. const size_t kMiniBatchSizeBytes = 50000; + // avro_num_minibatches_ is int64 in the op interface. If not set + // the default value is 0. + size_t avro_num_minibatches_; + // Calculate number of minibatches. // In main regime make each minibatch around kMiniBatchSizeBytes bytes. // Apply 'special logic' below for small and big regimes. @@ -204,8 +210,13 @@ Status ParseAvro(const AvroParserConfig& config, minibatch_bytes = 0; } } - // 'special logic' - const size_t min_minibatches = std::min(8, serialized.size()); + if (avro_num_minibatches_) { + VLOG(5) << "Overriding num_minibatches with " << avro_num_minibatches_; + result = avro_num_minibatches_; + } + // This is to ensure users can control the num minibatches all the way down + // to size of 1(no parallelism). 
+ const size_t min_minibatches = std::min(1, serialized.size()); const size_t max_minibatches = 64; return std::max(min_minibatches, std::min(max_minibatches, result)); @@ -245,13 +256,16 @@ Status ParseAvro(const AvroParserConfig& config, auto read_value = [&](avro::GenericDatum& d) { return range_reader.read(d); }; - + VLOG(5) << "Processing minibatch " << minibatch; status_of_minibatch[minibatch] = parser_tree.ParseValues( &buffers[minibatch], read_value, reader_schema, defaults); }; - + const auto before_parse = clock::now(); ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool); - + const auto after_parse = clock::now(); + const ms parse_read_duration = after_parse - before_parse; + VLOG(5) << "PARSER_TIMING: Time spend reading and parsing " + << parse_read_duration.count() << " ms "; for (Status& status : status_of_minibatch) { TF_RETURN_IF_ERROR(status); } @@ -367,15 +381,22 @@ Status ParseAvro(const AvroParserConfig& config, return Status::OK(); }; - + const auto before_sparse_merge = clock::now(); for (size_t d = 0; d < config.sparse.size(); ++d) { TF_RETURN_IF_ERROR(MergeSparseMinibatches(d)); } - + const auto after_sparse_merge = clock::now(); + const ms s_merge_duration = after_sparse_merge - before_sparse_merge; for (size_t d = 0; d < config.dense.size(); ++d) { TF_RETURN_IF_ERROR(MergeDenseMinibatches(d)); } + const auto after_dense_merge = clock::now(); + const ms d_merge_duration = after_dense_merge - after_sparse_merge; + VLOG(5) << "PARSER_TIMING: Sparse merge duration" << s_merge_duration.count() + << " ms "; + VLOG(5) << "PARSER_TIMING: Dense merge duration" << d_merge_duration.count() + << " ms "; return Status::OK(); } @@ -388,6 +409,8 @@ class ParseAvroOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_types", &dense_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("avro_num_minibatches", &avro_num_minibatches_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_)); @@ -401,6 +424,11 @@ class ParseAvroOp : public OpKernel { dense_shapes_[d].dims() > 1 && dense_shapes_[d].dim_size(0) == -1; } + // Check that avro_num_minibatches is not negative + OP_REQUIRES(ctx, avro_num_minibatches_ >= 0, + errors::InvalidArgument("Need avro_num_minibatches >= 0, got ", + avro_num_minibatches_)); + string reader_schema_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("reader_schema", &reader_schema_str)); @@ -495,6 +523,7 @@ class ParseAvroOp : public OpKernel { avro::ValidSchema reader_schema_; size_t num_dense_; size_t num_sparse_; + int64 avro_num_minibatches_; private: std::vector> CreateKeysAndTypes() { diff --git a/tensorflow_io/core/kernels/avro/utils/avro_parser_tree.cc b/tensorflow_io/core/kernels/avro/utils/avro_parser_tree.cc index b1b12573e..b8d5616d5 100644 --- a/tensorflow_io/core/kernels/avro/utils/avro_parser_tree.cc +++ b/tensorflow_io/core/kernels/avro/utils/avro_parser_tree.cc @@ -81,6 +81,9 @@ Status AvroParserTree::ParseValues( const std::function read_value, const avro::ValidSchema& reader_schema, const std::map& defaults) const { + using clock = std::chrono::system_clock; + using ms = std::chrono::duration; + // new assignment of all buffers TF_RETURN_IF_ERROR(InitializeValueBuffers(key_to_value)); @@ -90,11 +93,24 @@ Status AvroParserTree::ParseValues( avro::GenericDatum datum(reader_schema); bool has_value = false; - - while ((has_value = 
read_value(datum))) { + ms parse_duration; + ms read_duration; + while (true) { + const auto before_read = clock::now(); + if (!(has_value = read_value(datum))) { + break; + } + const auto after_read = clock::now(); TF_RETURN_IF_ERROR((*root_).Parse(key_to_value, datum, defaults)); + const auto after_parse = clock::now(); + parse_duration += after_parse - after_read; + read_duration += after_read - before_read; } + VLOG(5) << "PARSER_TIMING: Avro Read times " << read_duration.count() + << " ms "; + VLOG(5) << "PARSER_TIMING: Avro Parse times " << parse_duration.count() + << " ms "; // add end marks to all buffers for batch TF_RETURN_IF_ERROR(AddFinishMarks(key_to_value)); diff --git a/tensorflow_io/core/kernels/csv_kernels.cc b/tensorflow_io/core/kernels/csv_kernels.cc index c2d5c4f5e..196d1a31c 100644 --- a/tensorflow_io/core/kernels/csv_kernels.cc +++ b/tensorflow_io/core/kernels/csv_kernels.cc @@ -44,19 +44,24 @@ class CSVReadable : public IOReadableInterface { csv_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); - ::arrow::Status status; - - status = ::arrow::csv::TableReader::Make( + auto result = ::arrow::csv::TableReader::Make( ::arrow::default_memory_pool(), csv_file_, ::arrow::csv::ReadOptions::Defaults(), ::arrow::csv::ParseOptions::Defaults(), - ::arrow::csv::ConvertOptions::Defaults(), &reader_); - if (!status.ok()) { - return errors::InvalidArgument("unable to make a TableReader: ", status); + ::arrow::csv::ConvertOptions::Defaults()); + if (!result.status().ok()) { + return errors::InvalidArgument("unable to make a TableReader: ", + result.status()); } - status = reader_->Read(&table_); - if (!status.ok()) { - return errors::InvalidArgument("unable to read table: ", status); + reader_ = std::move(result).ValueUnsafe(); + + { + auto result = reader_->Read(); + if (!result.status().ok()) { + return errors::InvalidArgument("unable to read table: ", + result.status()); + } + table_ = std::move(result).ValueUnsafe(); } for (int i = 0; i < table_->num_columns(); i++) { @@ -108,11 +113,9 @@ class CSVReadable : public IOReadableInterface { case ::arrow::Type::TIMESTAMP: case ::arrow::Type::TIME32: case ::arrow::Type::TIME64: - case ::arrow::Type::INTERVAL: case ::arrow::Type::DECIMAL: case ::arrow::Type::LIST: case ::arrow::Type::STRUCT: - case ::arrow::Type::UNION: case ::arrow::Type::DICTIONARY: case ::arrow::Type::MAP: default: diff --git a/tensorflow_io/core/kernels/mongodb_kernels.cc b/tensorflow_io/core/kernels/mongodb_kernels.cc index cc342ab27..cc4356c21 100644 --- a/tensorflow_io/core/kernels/mongodb_kernels.cc +++ b/tensorflow_io/core/kernels/mongodb_kernels.cc @@ -59,7 +59,6 @@ class MongoDBReadableResource : public ResourceBase { // Register the application name so we can track it in the profile logs // on the server. This can also be done from the URI. 
- mongoc_client_set_appname(client_obj_, "tfio-mongo-read"); // Get a handle on the database "db_name" and collection "coll_name" @@ -87,12 +86,15 @@ class MongoDBReadableResource : public ResourceBase { const bson_t* doc; int num_records = 0; - while (mongoc_cursor_next(cursor_obj_, &doc) && - num_records < max_num_records) { - char* record = bson_as_canonical_extended_json(doc, NULL); - records.emplace_back(record); - num_records++; - bson_free(record); + while (num_records < max_num_records) { + if (mongoc_cursor_next(cursor_obj_, &doc)) { + char* record = bson_as_canonical_extended_json(doc, NULL); + records.emplace_back(record); + num_records++; + bson_free(record); + } else { + break; + } } TensorShape shape({static_cast(records.size())}); diff --git a/tensorflow_io/core/kernels/obj_kernels.cc b/tensorflow_io/core/kernels/obj_kernels.cc new file mode 100644 index 000000000..e619431e1 --- /dev/null +++ b/tensorflow_io/core/kernels/obj_kernels.cc @@ -0,0 +1,70 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tiny_obj_loader.h" + +namespace tensorflow { +namespace io { +namespace { + +class DecodeObjOp : public OpKernel { + public: + explicit DecodeObjOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_tensor->shape()), + errors::InvalidArgument("input must be scalar, got shape ", + input_tensor->shape().DebugString())); + const tstring& input = input_tensor->scalar()(); + + tinyobj::ObjReader reader; + + if (!reader.ParseFromString(input.c_str(), "")) { + OP_REQUIRES( + context, false, + errors::Internal("Unable to read obj file: ", reader.Error())); + } + + if (!reader.Warning().empty()) { + LOG(WARNING) << "TinyObjReader: " << reader.Warning(); + } + + auto& attrib = reader.GetAttrib(); + + int64 count = attrib.vertices.size() / 3; + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({count, 3}), + &output_tensor)); + // Loop over attrib.vertices: + for (int64 i = 0; i < count; i++) { + tinyobj::real_t x = attrib.vertices[i * 3 + 0]; + tinyobj::real_t y = attrib.vertices[i * 3 + 1]; + tinyobj::real_t z = attrib.vertices[i * 3 + 2]; + output_tensor->tensor()(i, 0) = x; + output_tensor->tensor()(i, 1) = y; + output_tensor->tensor()(i, 2) = z; + } + } +}; +REGISTER_KERNEL_BUILDER(Name("IO>DecodeObj").Device(DEVICE_CPU), DecodeObjOp); + +} // namespace +} // namespace io +} // namespace tensorflow diff --git a/tensorflow_io/core/kernels/parquet_kernels.cc b/tensorflow_io/core/kernels/parquet_kernels.cc index c2eafb2c3..56224c643 100644 --- a/tensorflow_io/core/kernels/parquet_kernels.cc +++ 
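The new `DecodeObjOp` above parses the vertex entries of a Wavefront file into an `[n, 3]` float32 tensor; the matching `IO>DecodeObj` op registration and the Python wrapper `decode_obj` appear later in this patch. A hedged usage sketch: the `tfio.experimental.image.decode_obj` path follows from the API file touched below, and the file content is a made-up two-vertex example rather than a real .obj file.

```python
import tensorflow as tf
import tensorflow_io as tfio

# Two vertices, written inline instead of read from a real .obj file.
obj_contents = tf.constant("v 1.0 2.0 3.0\nv 4.0 5.0 6.0\n")
vertices = tfio.experimental.image.decode_obj(obj_contents)
print(vertices.shape)  # expected: (2, 3)
```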
b/tensorflow_io/core/kernels/parquet_kernels.cc @@ -162,6 +162,9 @@ class ParquetReadableResource : public ResourceBase { row_group_reader->Column(column_index); // buffer to fill location is value.data()[row_to_read_start - start] + // Note: ReadBatch may not be able to read the elements requested + // (row_to_read_count) in one shot, as such we use while loop of + // `while (row_left > 0) {...}` to read until complete. #define PARQUET_PROCESS_TYPE(ptype, type) \ { \ @@ -172,11 +175,16 @@ class ParquetReadableResource : public ResourceBase { } \ ptype::c_type* value_p = (ptype::c_type*)(void*)(&( \ value->flat().data()[row_to_read_start - element_start])); \ - int64_t values_read; \ - int64_t levels_read = reader->ReadBatch(row_to_read_count, nullptr, \ - nullptr, value_p, &values_read); \ - if (!(levels_read == values_read && levels_read == row_to_read_count)) { \ - return errors::InvalidArgument("null value in column: ", column); \ + int64_t row_left = row_to_read_count; \ + while (row_left > 0) { \ + int64_t values_read; \ + int64_t levels_read = reader->ReadBatch( \ + row_left, nullptr, nullptr, &value_p[row_to_read_count - row_left], \ + &values_read); \ + if (!(levels_read == values_read && levels_read > 0)) { \ + return errors::InvalidArgument("null value in column: ", column); \ + } \ + row_left -= levels_read; \ } \ } @@ -189,13 +197,18 @@ class ParquetReadableResource : public ResourceBase { } \ std::unique_ptr value_p( \ new ptype::c_type[row_to_read_count]); \ - int64_t values_read; \ - int64_t levels_read = reader->ReadBatch( \ - row_to_read_count, nullptr, nullptr, value_p.get(), &values_read); \ - if (!(levels_read == values_read && levels_read == row_to_read_count)) { \ - return errors::InvalidArgument("null value in column: ", column); \ + int64_t row_left = row_to_read_count; \ + while (row_left > 0) { \ + int64_t values_read; \ + int64_t levels_read = reader->ReadBatch( \ + row_left, nullptr, nullptr, \ + &value_p.get()[row_to_read_count - row_left], &values_read); \ + if (!(levels_read == values_read && levels_read > 0)) { \ + return errors::InvalidArgument("null value in column: ", column); \ + } \ + row_left -= levels_read; \ } \ - for (int64_t index = 0; index < values_read; index++) { \ + for (int64_t index = 0; index < row_to_read_count; index++) { \ value->flat()(row_to_read_start - element_start + index) = \ ByteArrayToString(value_p[index]); \ } \ @@ -210,13 +223,18 @@ class ParquetReadableResource : public ResourceBase { } \ std::unique_ptr value_p( \ new ptype::c_type[row_to_read_count]); \ - int64_t values_read; \ - int64_t levels_read = reader->ReadBatch( \ - row_to_read_count, nullptr, nullptr, value_p.get(), &values_read); \ - if (!(levels_read == values_read && levels_read == row_to_read_count)) { \ - return errors::InvalidArgument("null value in column: ", column); \ + int64_t row_left = row_to_read_count; \ + while (row_left > 0) { \ + int64_t values_read; \ + int64_t levels_read = reader->ReadBatch( \ + row_left, nullptr, nullptr, \ + &value_p.get()[row_to_read_count - row_left], &values_read); \ + if (!(levels_read == values_read && levels_read > 0)) { \ + return errors::InvalidArgument("null value in column: ", column); \ + } \ + row_left -= levels_read; \ } \ - for (int64_t index = 0; index < values_read; index++) { \ + for (int64_t index = 0; index < row_to_read_count; index++) { \ value->flat()(row_to_read_start - element_start + index) = \ string((const char*)value_p[index].ptr, len); \ } \ diff --git a/tensorflow_io/core/kernels/text_kernels.cc 
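The Parquet change above replaces a single `ReadBatch` call with a loop because `ReadBatch` may legitimately return fewer rows than requested. A minimal sketch of the read-until-complete pattern, in Python with invented names (the real code drives `parquet::ColumnReader::ReadBatch` through the macros above):

```python
def read_rows(read_batch, rows_to_read):
    """read_batch(n) -> (levels_read, values); may return fewer than n rows."""
    out = []
    rows_left = rows_to_read
    while rows_left > 0:
        levels_read, values = read_batch(rows_left)
        # Mirrors the kernel's check: a zero-length or mismatched read is an error.
        if levels_read <= 0 or levels_read != len(values):
            raise ValueError("null value in column")
        out.extend(values)
        rows_left -= levels_read
    return out
```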
b/tensorflow_io/core/kernels/text_kernels.cc index 119437705..dea82134f 100644 --- a/tensorflow_io/core/kernels/text_kernels.cc +++ b/tensorflow_io/core/kernels/text_kernels.cc @@ -16,10 +16,11 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow_io/core/kernels/io_stream.h" - #if defined(_MSC_VER) #include #define STDIN_FILENO _fileno(stdin) +#else +#include #endif namespace tensorflow { diff --git a/tensorflow_io/core/kernels/video_kernels.h b/tensorflow_io/core/kernels/video_kernels.h index cddcfe594..91838dbe5 100644 --- a/tensorflow_io/core/kernels/video_kernels.h +++ b/tensorflow_io/core/kernels/video_kernels.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include static int xioctl(int fh, int request, void* arg) { int r; diff --git a/tensorflow_io/core/ops/avro_ops.cc b/tensorflow_io/core/ops/avro_ops.cc index 2432292a0..34a8b19fb 100644 --- a/tensorflow_io/core/ops/avro_ops.cc +++ b/tensorflow_io/core/ops/avro_ops.cc @@ -83,6 +83,7 @@ REGISTER_OP("IO>ParseAvro") .Output("sparse_values: sparse_types") .Output("sparse_shapes: num_sparse * int64") .Output("dense_values: dense_types") + .Attr("avro_num_minibatches: int >= 0") .Attr("num_sparse: int >= 0") .Attr("reader_schema: string") .Attr("sparse_keys: list(string) >= 0") @@ -94,6 +95,7 @@ REGISTER_OP("IO>ParseAvro") .SetShapeFn([](shape_inference::InferenceContext* c) { size_t num_dense; size_t num_sparse; + int64 avro_num_minibatches; int64 num_sparse_from_user; std::vector sparse_types; std::vector dense_types; @@ -106,6 +108,8 @@ REGISTER_OP("IO>ParseAvro") TF_RETURN_IF_ERROR(c->GetAttr("sparse_types", &sparse_types)); TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types)); TF_RETURN_IF_ERROR(c->GetAttr("dense_shapes", &dense_shapes)); + TF_RETURN_IF_ERROR( + c->GetAttr("avro_num_minibatches", &avro_num_minibatches)); TF_RETURN_IF_ERROR(c->GetAttr("sparse_keys", &sparse_keys)); TF_RETURN_IF_ERROR(c->GetAttr("sparse_ranks", &sparse_ranks)); diff --git a/tensorflow_io/core/ops/obj_ops.cc b/tensorflow_io/core/ops/obj_ops.cc new file mode 100644 index 000000000..e3c45653a --- /dev/null +++ b/tensorflow_io/core/ops/obj_ops.cc @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace io { +namespace { + +REGISTER_OP("IO>DecodeObj") + .Input("input: string") + .Output("output: float32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + c->set_output(0, c->MakeShape({c->UnknownDim(), 3})); + return Status::OK(); + }); + +} // namespace +} // namespace io +} // namespace tensorflow diff --git a/tensorflow_io/core/plugins/BUILD b/tensorflow_io/core/plugins/BUILD index 6559a6ba6..d6b2d83c2 100644 --- a/tensorflow_io/core/plugins/BUILD +++ b/tensorflow_io/core/plugins/BUILD @@ -16,7 +16,7 @@ cc_library( linkstatic = True, deps = [ "@local_config_tf//:libtensorflow_framework", - "@local_config_tf//:tf_header_lib", + "@local_config_tf//:tf_c_header_lib", ], alwayslink = 1, ) diff --git a/tensorflow_io/core/plugins/az/az_file_system.cc b/tensorflow_io/core/plugins/az/az_file_system.cc index ddf0aefd4..18ff9b7f2 100644 --- a/tensorflow_io/core/plugins/az/az_file_system.cc +++ b/tensorflow_io/core/plugins/az/az_file_system.cc @@ -161,7 +161,7 @@ std::string errno_to_string() { case container_delete_fail: return "container_delete_fail"; /* blob level */ - case blob__already_exists: + case blob_already_exists: return "blob__already_exists"; case blob_not_exists: return "blob_not_exists"; diff --git a/tensorflow_io/core/plugins/file_system_plugins.cc b/tensorflow_io/core/plugins/file_system_plugins.cc index bc408dcb7..5000d7c83 100644 --- a/tensorflow_io/core/plugins/file_system_plugins.cc +++ b/tensorflow_io/core/plugins/file_system_plugins.cc @@ -15,7 +15,13 @@ limitations under the License. #include "tensorflow_io/core/plugins/file_system_plugins.h" +#include "absl/strings/ascii.h" + void TF_InitPlugin(TF_FilesystemPluginInfo* info) { + const char* enable_legacy_env = getenv("TF_ENABLE_LEGACY_FILESYSTEM"); + std::string enable_legacy = + enable_legacy_env ? 
absl::AsciiStrToLower(enable_legacy_env) : ""; + info->plugin_memory_allocate = tensorflow::io::plugin_memory_allocate; info->plugin_memory_free = tensorflow::io::plugin_memory_free; info->num_schemes = 7; @@ -24,9 +30,17 @@ void TF_InitPlugin(TF_FilesystemPluginInfo* info) { sizeof(info->ops[0]))); tensorflow::io::az::ProvideFilesystemSupportFor(&info->ops[0], "az"); tensorflow::io::http::ProvideFilesystemSupportFor(&info->ops[1], "http"); - tensorflow::io::s3::ProvideFilesystemSupportFor(&info->ops[2], "s3e"); - tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[3], "hdfse"); - tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[4], "viewfse"); - tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[5], "hare"); - tensorflow::io::gs::ProvideFilesystemSupportFor(&info->ops[6], "gse"); + if (enable_legacy == "true" || enable_legacy == "1") { + tensorflow::io::s3::ProvideFilesystemSupportFor(&info->ops[2], "s3e"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[3], "hdfse"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[4], "viewfse"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[5], "hare"); + tensorflow::io::gs::ProvideFilesystemSupportFor(&info->ops[6], "gse"); + } else { + tensorflow::io::s3::ProvideFilesystemSupportFor(&info->ops[2], "s3"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[3], "hdfs"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[4], "viewfs"); + tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[5], "har"); + tensorflow::io::gs::ProvideFilesystemSupportFor(&info->ops[6], "gs"); + } } diff --git a/tensorflow_io/core/plugins/gs/BUILD b/tensorflow_io/core/plugins/gs/BUILD index 7932a1c0a..ca0495936 100644 --- a/tensorflow_io/core/plugins/gs/BUILD +++ b/tensorflow_io/core/plugins/gs/BUILD @@ -12,8 +12,6 @@ cc_library( srcs = [ "cleanup.h", "expiring_lru_cache.h", - "gcs_env.cc", - "gcs_env.h", "gcs_filesystem.cc", "gcs_helper.cc", "gcs_helper.h", diff --git a/tensorflow_io/core/plugins/gs/expiring_lru_cache.h b/tensorflow_io/core/plugins/gs/expiring_lru_cache.h index 8902bec6b..110791b97 100644 --- a/tensorflow_io/core/plugins/gs/expiring_lru_cache.h +++ b/tensorflow_io/core/plugins/gs/expiring_lru_cache.h @@ -24,8 +24,8 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/c/env.h" #include "tensorflow/c/tf_status.h" -#include "tensorflow_io/core/plugins/gs/gcs_env.h" namespace tensorflow { namespace io { @@ -44,7 +44,7 @@ class ExpiringLRUCache { /// that there is no limit on the number of entries in the cache (however, if /// `max_age` is also 0, the cache will not be populated). ExpiringLRUCache(uint64_t max_age, size_t max_entries, - std::function timer_seconds = GCSNowSeconds) + std::function timer_seconds = TF_NowSeconds) : max_age_(max_age), max_entries_(max_entries), timer_seconds_(timer_seconds) {} diff --git a/tensorflow_io/core/plugins/gs/gcs_env.cc b/tensorflow_io/core/plugins/gs/gcs_env.cc deleted file mode 100644 index a8cbb1418..000000000 --- a/tensorflow_io/core/plugins/gs/gcs_env.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
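The `TF_InitPlugin` change above keys the registered scheme names off a `TF_ENABLE_LEGACY_FILESYSTEM` environment variable: `true`/`1` keeps the legacy `s3e`/`hdfse`/`viewfse`/`hare`/`gse` schemes, anything else claims the standard `s3`/`hdfs`/`viewfs`/`har`/`gs` schemes. A sketch of how this is expected to be used from Python, assuming the plugin is registered when `tensorflow_io` is imported (set the variable first); the bucket name is a placeholder:

```python
import os

# Opt back into the legacy scheme names; leave unset for the standard schemes.
# Must be set before tensorflow_io is imported.
os.environ["TF_ENABLE_LEGACY_FILESYSTEM"] = "1"

import tensorflow as tf
import tensorflow_io  # noqa: F401  # registers the modular filesystem plugin

print(tf.io.gfile.exists("s3e://my-bucket/path/to/object"))
```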
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#if defined(_MSC_VER) -#include -#else -#include -#endif -#include -#include - -#include "tensorflow/c/logging.h" -#include "tensorflow_io/core/plugins/gs/gcs_env.h" - -namespace tensorflow { -namespace io { -namespace gs { -namespace { -// Returns a unique number every time it is called. -int64_t UniqueId() { - static absl::Mutex mu; - static int64_t id = 0; - absl::MutexLock l(&mu); - return ++id; -} - -static bool IsAbsolutePath(absl::string_view path) { - return !path.empty() && path[0] == '/'; -} - -std::string JoinPath(std::initializer_list paths) { - std::string result; - - for (absl::string_view path : paths) { - if (path.empty()) continue; - - if (result.empty()) { - result = std::string(path); - continue; - } - - if (result[result.size() - 1] == '/') { - if (IsAbsolutePath(path)) { - absl::StrAppend(&result, path.substr(1)); - } else { - absl::StrAppend(&result, path); - } - } else { - if (IsAbsolutePath(path)) { - absl::StrAppend(&result, path); - } else { - absl::StrAppend(&result, "/", path); - } - } - } - - return result; -} - -} // namespace - -uint64_t GCSNowSeconds(void) { - // TODO: Either implement NowSeconds here, or have TensorFlow API exposed - std::abort(); -} - -void GCSDefaultThreadOptions(GCSThreadOptions* options) { - options->stack_size = 0; - options->guard_size = 0; - options->numa_node = -1; -} - -std::string GCSGetTempFileName(const std::string& extension) { -#if defined(_MSC_VER) - char temp_dir[_MAX_PATH]; - DWORD retval; - retval = GetTempPath(_MAX_PATH, temp_dir); - if (retval > _MAX_PATH || retval == 0) { - TF_Log(TF_FATAL, "Cannot get the directory for temporary files."); - } - - char temp_file_name[_MAX_PATH]; - retval = GetTempFileNameA(temp_dir, "", UniqueId(), temp_file_name); - if (retval > _MAX_PATH || retval == 0) { - TF_Log(TF_FATAL, "Cannot get a temporary file in: %s", temp_dir); - } - - std::string full_tmp_file_name(temp_file_name); - full_tmp_file_name.append(extension); - return full_tmp_file_name; -#else - for (const char* dir : std::vector( - {getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) { - if (!dir || !dir[0]) { - continue; - } - struct stat statbuf; - if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - // UniqueId is added here because mkstemps is not as thread safe as it - // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows - // the problem. 
- std::string tmp_filepath; - int fd; - if (extension.length()) { - tmp_filepath = - JoinPath({dir, absl::StrCat("tmp_file_tensorflow_", UniqueId(), - "_XXXXXX.", extension)}); - fd = mkstemps(&tmp_filepath[0], extension.length() + 1); - } else { - tmp_filepath = JoinPath( - {dir, absl::StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX")}); - fd = mkstemp(&tmp_filepath[0]); - } - if (fd < 0) { - TF_Log(TF_FATAL, "Failed to create temp file."); - } else { - if (close(fd) < 0) { - TF_Log(TF_ERROR, "close() failed: %s", strerror(errno)); - } - return tmp_filepath; - } - } - } - TF_Log(TF_FATAL, "No temp directory found."); - std::abort(); -#endif -} - -GCSThread* GCSStartThread(const GCSThreadOptions* options, - const char* thread_name, void (*work_func)(void*), - void* param) { - // TODO: Either implement StartThread here, or have TensorFlow API exposed - std::abort(); - return nullptr; -} - -void GCSJoinThread(GCSThread* thread) { - // TODO: Either implement JoinThread here, or have TensorFlow API exposed - std::abort(); -} - -} // namespace gs -} // namespace io -} // namespace tensorflow diff --git a/tensorflow_io/core/plugins/gs/gcs_env.h b/tensorflow_io/core/plugins/gs/gcs_env.h deleted file mode 100644 index d4af2e726..000000000 --- a/tensorflow_io/core/plugins/gs/gcs_env.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_IO_CORE_PLUGINS_GS_GCS_ENV_H_ -#define TENSORFLOW_IO_CORE_PLUGINS_GS_GCS_ENV_H_ - -#include "inttypes.h" -#include "tensorflow/c/tf_status.h" - -namespace tensorflow { -namespace io { -namespace gs { - -typedef struct GCSThread GCSThread; -typedef struct GCSThreadOptions { - size_t stack_size; - size_t guard_size; - int numa_node; -} GCSThreadOptions; - -std::string GCSGetTempFileName(const std::string& extension); -uint64_t GCSNowSeconds(void); -void GCSDefaultThreadOptions(GCSThreadOptions* options); -GCSThread* GCSStartThread(const GCSThreadOptions* options, - const char* thread_name, void (*work_func)(void*), - void* param); -void GCSJoinThread(GCSThread* thread); - -} // namespace gs -} // namespace io -} // namespace tensorflow - -#endif // TENSORFLOW_IO_CORE_PLUGINS_GS_GCS_ENV_H_ diff --git a/tensorflow_io/core/plugins/gs/gcs_helper.cc b/tensorflow_io/core/plugins/gs/gcs_helper.cc index 1368bd98a..9948e6524 100644 --- a/tensorflow_io/core/plugins/gs/gcs_helper.cc +++ b/tensorflow_io/core/plugins/gs/gcs_helper.cc @@ -16,10 +16,13 @@ limitations under the License. 
#include +#include #include #include #include +#include "tensorflow/c/env.h" + TempFile::TempFile(const std::string& temp_file_name, std::ios::openmode mode) : std::fstream(temp_file_name, mode), name_(temp_file_name) {} @@ -38,3 +41,11 @@ bool TempFile::truncate() { std::fstream::open(name_, std::ios::binary | std::ios::out); return std::fstream::is_open(); } + +std::string GCSGetTempFileName(const std::string& extension) { + char* raw_temp_file_name = TF_GetTempFileName(extension.c_str()); + if (!raw_temp_file_name) return ""; + std::string temp_file_name(raw_temp_file_name); + std::free(raw_temp_file_name); + return temp_file_name; +} diff --git a/tensorflow_io/core/plugins/gs/gcs_helper.h b/tensorflow_io/core/plugins/gs/gcs_helper.h index 034777c3b..33c7926a7 100644 --- a/tensorflow_io/core/plugins/gs/gcs_helper.h +++ b/tensorflow_io/core/plugins/gs/gcs_helper.h @@ -31,4 +31,6 @@ class TempFile : public std::fstream { const std::string name_; }; +std::string GCSGetTempFileName(const std::string& extension); + #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ diff --git a/tensorflow_io/core/plugins/gs/ram_file_block_cache.h b/tensorflow_io/core/plugins/gs/ram_file_block_cache.h index d93be1a01..788f1baa8 100644 --- a/tensorflow_io/core/plugins/gs/ram_file_block_cache.h +++ b/tensorflow_io/core/plugins/gs/ram_file_block_cache.h @@ -27,9 +27,9 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "tensorflow/c/env.h" #include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" -#include "tensorflow_io/core/plugins/gs/gcs_env.h" namespace tensorflow { namespace io { @@ -56,19 +56,19 @@ class RamFileBlockCache { RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, BlockFetcher block_fetcher, - std::function timer_seconds = GCSNowSeconds) + std::function timer_seconds = TF_NowSeconds) : block_size_(block_size), max_bytes_(max_bytes), max_staleness_(max_staleness), block_fetcher_(block_fetcher), timer_seconds_(timer_seconds), pruning_thread_(nullptr, - [](GCSThread* thread) { GCSJoinThread(thread); }) { + [](TF_Thread* thread) { TF_JoinThread(thread); }) { if (max_staleness_ > 0) { - GCSThreadOptions thread_options; - GCSDefaultThreadOptions(&thread_options); + TF_ThreadOptions thread_options; + TF_DefaultThreadOptions(&thread_options); pruning_thread_.reset( - GCSStartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); + TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); } TF_VLog(1, "GCS file block cache is %s.\n", (IsCacheEnabled() ? "enabled" : "disabled")); @@ -236,7 +236,7 @@ class RamFileBlockCache { void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); /// The cache pruning thread that removes files with expired blocks. - std::unique_ptr> pruning_thread_; + std::unique_ptr> pruning_thread_; /// Notification for stopping the cache pruning thread. 
absl::Notification stop_pruning_thread_; diff --git a/tensorflow_io/core/plugins/hdfs/BUILD b/tensorflow_io/core/plugins/hdfs/BUILD index d4571712f..867310815 100644 --- a/tensorflow_io/core/plugins/hdfs/BUILD +++ b/tensorflow_io/core/plugins/hdfs/BUILD @@ -16,6 +16,8 @@ cc_library( linkstatic = True, deps = [ "//tensorflow_io/core/plugins:plugins_header", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@hadoop", ], alwayslink = 1, diff --git a/tensorflow_io/core/plugins/hdfs/hadoop_filesystem.cc b/tensorflow_io/core/plugins/hdfs/hadoop_filesystem.cc index 83bdb0a1d..4946a2b3c 100644 --- a/tensorflow_io/core/plugins/hdfs/hadoop_filesystem.cc +++ b/tensorflow_io/core/plugins/hdfs/hadoop_filesystem.cc @@ -774,6 +774,11 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_SetStatus(status, TF_OK, ""); } +static void RecursivelyCreateDir(const TF_Filesystem* filesystem, + const char* path, TF_Status* status) { + CreateDir(filesystem, path, status); +} + void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { auto hadoop_file = @@ -931,6 +936,8 @@ void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) { ops->filesystem_ops->new_read_only_memory_region_from_file = tf_hdfs_filesystem::NewReadOnlyMemoryRegionFromFile; ops->filesystem_ops->create_dir = tf_hdfs_filesystem::CreateDir; + ops->filesystem_ops->recursively_create_dir = + tf_hdfs_filesystem::RecursivelyCreateDir; ops->filesystem_ops->delete_file = tf_hdfs_filesystem::DeleteFile; ops->filesystem_ops->delete_dir = tf_hdfs_filesystem::DeleteDir; ops->filesystem_ops->rename_file = tf_hdfs_filesystem::RenameFile; diff --git a/tensorflow_io/core/plugins/s3/BUILD b/tensorflow_io/core/plugins/s3/BUILD index 6007d4d0d..7b262c053 100644 --- a/tensorflow_io/core/plugins/s3/BUILD +++ b/tensorflow_io/core/plugins/s3/BUILD @@ -21,7 +21,8 @@ cc_library( linkstatic = True, deps = [ "//tensorflow_io/core/plugins:plugins_header", - "@aws-sdk-cpp", + "@aws-sdk-cpp//:s3", + "@aws-sdk-cpp//:transfer", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], diff --git a/tensorflow_io/core/plugins/s3/s3_filesystem.cc b/tensorflow_io/core/plugins/s3/s3_filesystem.cc index 7bc5dd41d..17f4ce084 100644 --- a/tensorflow_io/core/plugins/s3/s3_filesystem.cc +++ b/tensorflow_io/core/plugins/s3/s3_filesystem.cc @@ -62,29 +62,41 @@ constexpr size_t kUploadRetries = 3; constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024; // 1 MB -static void* plugin_memory_allocate(size_t size) { return calloc(1, size); } -static void plugin_memory_free(void* ptr) { free(ptr); } - static inline void TF_SetStatusFromAWSError( const Aws::Client::AWSError& error, TF_Status* status) { - switch (error.GetResponseCode()) { - case Aws::Http::HttpResponseCode::FORBIDDEN: - TF_SetStatus(status, TF_FAILED_PRECONDITION, - "AWS Credentials have not been set properly. 
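Both the HDFS change above and the matching S3 change later in this patch wire up `recursively_create_dir` by forwarding to `CreateDir`; since directories on these stores are purely logical, a single call is enough. The practical effect, sketched below with placeholder paths, is that recursive directory creation through the TensorFlow file API is expected to work on these schemes:

```python
import tensorflow as tf
import tensorflow_io  # noqa: F401

# tf.io.gfile.makedirs is the API that exercises recursively_create_dir.
tf.io.gfile.makedirs("hdfs://namenode:8020/tmp/nested/output/dir")
tf.io.gfile.makedirs("s3://my-bucket/nested/output/dir")
```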
" - "Unable to access the specified S3 location"); - break; - case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE: - TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); - break; - case Aws::Http::HttpResponseCode::NOT_FOUND: - TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str()); - break; - default: - TF_SetStatus( - status, TF_UNKNOWN, - (error.GetExceptionName() + ": " + error.GetMessage()).c_str()); - break; + auto http_code = error.GetResponseCode(); + auto status_msg = error.GetExceptionName() + ": " + error.GetMessage(); + if (http_code == Aws::Http::HttpResponseCode::BAD_REQUEST) { + return TF_SetStatus(status, TF_INVALID_ARGUMENT, status_msg.c_str()); + } + if (http_code == Aws::Http::HttpResponseCode::UNAUTHORIZED) { + return TF_SetStatus(status, TF_UNAUTHENTICATED, status_msg.c_str()); + } + if (http_code == Aws::Http::HttpResponseCode::FORBIDDEN) { + return TF_SetStatus(status, TF_PERMISSION_DENIED, status_msg.c_str()); + } + if (http_code == Aws::Http::HttpResponseCode::NOT_FOUND) { + return TF_SetStatus(status, TF_NOT_FOUND, status_msg.c_str()); } + if (http_code == Aws::Http::HttpResponseCode::METHOD_NOT_ALLOWED || + http_code == Aws::Http::HttpResponseCode::NOT_ACCEPTABLE || + http_code == Aws::Http::HttpResponseCode::PROXY_AUTHENTICATION_REQUIRED) { + return TF_SetStatus(status, TF_PERMISSION_DENIED, status_msg.c_str()); + } + if (http_code == Aws::Http::HttpResponseCode::REQUEST_TIMEOUT) { + return TF_SetStatus(status, TF_RESOURCE_EXHAUSTED, status_msg.c_str()); + } + if (http_code == Aws::Http::HttpResponseCode::PRECONDITION_FAILED) { + return TF_SetStatus(status, TF_FAILED_PRECONDITION, status_msg.c_str()); + } + if (http_code == + Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) { + return TF_SetStatus(status, TF_OUT_OF_RANGE, status_msg.c_str()); + } + if (Aws::Http::HttpResponseCode::INTERNAL_SERVER_ERROR <= http_code) { + return TF_SetStatus(status, TF_INTERNAL, status_msg.c_str()); + } + return TF_SetStatus(status, TF_UNKNOWN, status_msg.c_str()); } void ParseS3Path(const Aws::String& fname, bool object_empty_ok, @@ -123,8 +135,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() { absl::MutexLock l(&cfg_lock); if (!init) { - const char* endpoint = getenv("S3_ENDPOINT"); - if (endpoint) cfg.endpointOverride = Aws::String(endpoint); const char* region = getenv("AWS_REGION"); // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. if (!region) region = getenv("S3_REGION"); @@ -156,20 +166,6 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() { cfg.region = profiles["default"].GetRegion(); } } - const char* use_https = getenv("S3_USE_HTTPS"); - if (use_https) { - if (use_https[0] == '0') - cfg.scheme = Aws::Http::Scheme::HTTP; - else - cfg.scheme = Aws::Http::Scheme::HTTPS; - } - const char* verify_ssl = getenv("S3_VERIFY_SSL"); - if (verify_ssl) { - if (verify_ssl[0] == '0') - cfg.verifySSL = false; - else - cfg.verifySSL = true; - } // if these timeouts are low, you may see an error when // uploading/downloading large files: Unable to connect to endpoint int64_t timeout; @@ -218,9 +214,24 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) { // in the bucket name. Due to TLS hostname validation or DNS rules, // the bucket may not be resolved. Disabling of virtual addressing // should address the issue. See GitHub issue 16397 for details. 
- s3_file->s3_client = Aws::MakeShared( - kS3ClientAllocationTag, GetDefaultClientConfig(), - Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false); + s3_file->s3_client = std::shared_ptr( + Aws::New( + kS3ClientAllocationTag, GetDefaultClientConfig(), + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false), + [&options](Aws::S3::S3Client* s3_client) { + if (s3_client != nullptr) { + Aws::Delete(s3_client); + Aws::ShutdownAPI(options); + tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging(); + } + }); + + int temp_value; + if (absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value)) + s3_file->use_multi_part_download = (temp_value != 1); + + const char* endpoint = getenv("S3_ENDPOINT"); + if (endpoint) s3_file->s3_client->OverrideEndpoint(endpoint); } } @@ -243,24 +254,26 @@ static void GetTransferManager( absl::MutexLock l(&s3_file->initialization_lock); - if (s3_file->transfer_managers[direction].get() == nullptr) { + if (s3_file->transfer_managers.count(direction) == 0) { + uint64_t temp_value; + if (direction == Aws::Transfer::TransferDirection::UPLOAD) { + if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), + &temp_value)) + temp_value = kS3MultiPartUploadChunkSize; + } else if (direction == Aws::Transfer::TransferDirection::DOWNLOAD) { + if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), + &temp_value)) + temp_value = kS3MultiPartDownloadChunkSize; + } + s3_file->multi_part_chunk_sizes.emplace(direction, temp_value); + Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get()); config.s3Client = s3_file->s3_client; - config.bufferSize = s3_file->multi_part_chunk_sizes[direction]; + config.bufferSize = temp_value; // must be larger than pool size * multi part chunk size - config.transferBufferMaxHeapSize = - (kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction]; - s3_file->transfer_managers[direction] = - Aws::Transfer::TransferManager::Create(config); - } -} - -static void ShutdownClient(Aws::S3::S3Client* s3_client) { - if (s3_client != nullptr) { - delete s3_client; - Aws::SDKOptions options; - Aws::ShutdownAPI(options); - tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging(); + config.transferBufferMaxHeapSize = (kExecutorPoolSize + 1) * temp_value; + s3_file->transfer_managers.emplace( + direction, Aws::Transfer::TransferManager::Create(config)); } } @@ -319,11 +332,10 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { TF_VLog(3, "Using TransferManager\n"); - auto create_download_stream = [&]() { - return Aws::New( - "S3ReadStream", - Aws::New( - "S3ReadStream", reinterpret_cast(buffer), n)); + auto stream_buf = Aws::MakeShared( + "S3StreamBuf", reinterpret_cast(buffer), n); + auto create_download_stream = [stream_buf]() { + return Aws::New("S3ReadStream", stream_buf.get()); }; TF_VLog(3, "Created stream to read with transferManager\n"); auto handle = s3_file->transfer_manager->DownloadFile( @@ -439,9 +451,11 @@ void Sync(const TF_WritableFile* file, TF_Status* status) { TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(), s3_file->object.c_str()); auto position = static_cast(s3_file->outfile->tellp()); + // We always re-upload the whole file. 
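The hunks above move the `S3_ENDPOINT` override into `GetS3Client` (via `OverrideEndpoint`), drop the `S3_USE_HTTPS`/`S3_VERIFY_SSL` handling, and read the multipart knobs lazily in `GetS3Client`/`GetTransferManager`. A hedged sketch of setting those knobs from Python before the filesystem is first used; every value below is a placeholder, and the exact endpoint format is an assumption:

```python
import os

os.environ["AWS_REGION"] = "us-east-1"
os.environ["S3_ENDPOINT"] = "localhost:4566"                   # e.g. a local S3 emulator
os.environ["S3_DISABLE_MULTI_PART_DOWNLOAD"] = "1"             # now parsed in GetS3Client
os.environ["S3_MULTI_PART_UPLOAD_CHUNK_SIZE"] = str(50 << 20)  # bytes, read in GetTransferManager
os.environ["S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"] = str(50 << 20)

import tensorflow as tf
import tensorflow_io  # noqa: F401

print(tf.io.gfile.listdir("s3://my-bucket/some/prefix"))  # "my-bucket" is a placeholder
```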
+ s3_file->outfile->seekg(0); auto handle = s3_file->transfer_manager->UploadFile( s3_file->outfile, s3_file->bucket, s3_file->object, - "application/octet-stream", Aws::Map()); + "application/octet-stream", /*metadata*/ {}); handle->WaitUntilFinished(); size_t retries = 0; @@ -512,29 +526,12 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) { // is an error - should return OUT_OF_RANGE with less bytes. namespace tf_s3_filesystem { S3File::S3File() - : s3_client(nullptr, ShutdownClient), + : s3_client(nullptr), executor(nullptr), transfer_managers(), multi_part_chunk_sizes(), use_multi_part_download(false), // TODO: change to true after fix - initialization_lock() { - uint64_t temp_value; - multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] = - absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value) - ? temp_value - : kS3MultiPartUploadChunkSize; - multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] = - absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value) - ? temp_value - : kS3MultiPartDownloadChunkSize; - use_multi_part_download = - absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value) - ? (temp_value != 1) - : use_multi_part_download; - transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr); - transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD, - nullptr); -} + initialization_lock() {} void Init(TF_Filesystem* filesystem, TF_Status* status) { filesystem->plugin_filesystem = new S3File(); TF_SetStatus(status, TF_OK, ""); @@ -768,8 +765,8 @@ static void SimpleCopyFile(const Aws::String& source, const Aws::String& bucket_dst, const Aws::String& object_dst, S3File* s3_file, TF_Status* status) { - TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(), - object_dst.c_str()); + TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", source.c_str(), + bucket_dst.c_str(), object_dst.c_str()); Aws::S3::Model::CopyObjectRequest copy_object_request; copy_object_request.WithCopySource(source) .WithBucket(bucket_dst) @@ -834,8 +831,8 @@ static void MultiPartCopy(const Aws::String& source, const Aws::String& object_dst, const size_t num_parts, const uint64_t file_size, S3File* s3_file, TF_Status* status) { - TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(), - object_dst.c_str()); + TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", source.c_str(), + bucket_dst.c_str(), object_dst.c_str()); Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request; create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst); @@ -1066,6 +1063,11 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_SetStatus(status, TF_OK, ""); } +void RecursivelyCreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + CreateDir(filesystem, path, status); +} + void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { TF_VLog(1, "DeleteDir: %s\n", path); @@ -1262,6 +1264,8 @@ void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) { ops->filesystem_ops->new_read_only_memory_region_from_file = tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile; ops->filesystem_ops->create_dir = tf_s3_filesystem::CreateDir; + ops->filesystem_ops->recursively_create_dir = + tf_s3_filesystem::RecursivelyCreateDir; ops->filesystem_ops->delete_file = tf_s3_filesystem::DeleteFile; ops->filesystem_ops->delete_dir = tf_s3_filesystem::DeleteDir; ops->filesystem_ops->copy_file = 
tf_s3_filesystem::CopyFile; diff --git a/tensorflow_io/core/plugins/s3/s3_filesystem.h b/tensorflow_io/core/plugins/s3/s3_filesystem.h index 85e71b0d6..882785854 100644 --- a/tensorflow_io/core/plugins/s3/s3_filesystem.h +++ b/tensorflow_io/core/plugins/s3/s3_filesystem.h @@ -60,11 +60,12 @@ typedef struct S3File { std::shared_ptr s3_client; std::shared_ptr executor; // We need 2 `TransferManager`, for multipart upload/download. - Aws::Map> + Aws::UnorderedMap> transfer_managers; // Sizes to split objects during multipart upload/download. - Aws::Map multi_part_chunk_sizes; + Aws::UnorderedMap + multi_part_chunk_sizes; bool use_multi_part_download; absl::Mutex initialization_lock; S3File(); diff --git a/tensorflow_io/core/python/api/experimental/image.py b/tensorflow_io/core/python/api/experimental/image.py index 0d9c07e74..643b7b9c0 100644 --- a/tensorflow_io/core/python/api/experimental/image.py +++ b/tensorflow_io/core/python/api/experimental/image.py @@ -27,4 +27,5 @@ decode_yuy2, decode_avif, decode_jp2, + decode_obj, ) diff --git a/tensorflow_io/core/python/api/experimental/serialization.py b/tensorflow_io/core/python/api/experimental/serialization.py index a6fdc246d..4dca3c941 100644 --- a/tensorflow_io/core/python/api/experimental/serialization.py +++ b/tensorflow_io/core/python/api/experimental/serialization.py @@ -19,3 +19,5 @@ decode_avro, encode_avro, ) + +from tensorflow_io.core.python.experimental.serial_ops import save_dataset, load_dataset diff --git a/tensorflow_io/core/python/experimental/avro_record_dataset_ops.py b/tensorflow_io/core/python/experimental/avro_record_dataset_ops.py index b84a97cf4..6429aac10 100644 --- a/tensorflow_io/core/python/experimental/avro_record_dataset_ops.py +++ b/tensorflow_io/core/python/experimental/avro_record_dataset_ops.py @@ -21,6 +21,24 @@ _DEFAULT_READER_SCHEMA = "" # From https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/data/ops/readers.py + +def _require(condition: bool, err_msg: str = None) -> None: + """Checks if the specified condition is true else raises exception + + Args: + condition: The condition to test + err_msg: If specified, it's the error message to use if condition is not true. 
+ + Raises: + ValueError: Raised when the condition is false + + Returns: + None + """ + if not condition: + raise ValueError(err_msg) + + # copied from https://github.com/tensorflow/tensorflow/blob/ # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36 def _create_or_validate_filenames_dataset(filenames): @@ -52,21 +70,62 @@ def _create_or_validate_filenames_dataset(filenames): # copied from https://github.com/tensorflow/tensorflow/blob/ # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67 -def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None): - """create_dataset_reader""" - - def read_one_file(filename): - filename = tf.convert_to_tensor(filename, tf.string, name="filename") - return dataset_creator(filename) - - if num_parallel_reads is None: - return filenames.flat_map(read_one_file) - if num_parallel_reads == tf.data.experimental.AUTOTUNE: - return filenames.interleave( - read_one_file, num_parallel_calls=num_parallel_reads - ) +def _create_dataset_reader( + dataset_creator, + filenames, + cycle_length=None, + num_parallel_calls=None, + deterministic=None, + block_length=1, +): + """ + This creates a dataset reader which reads records from multiple files and interleaves them together +``` +dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] +# NOTE: New lines indicate "block" boundaries. +dataset = dataset.interleave( + lambda x: Dataset.from_tensors(x).repeat(6), + cycle_length=2, block_length=4) +list(dataset.as_numpy_iterator()) +``` +Results in the following output: +[1,1,1,1, + 2,2,2,2, + 1,1, + 2,2, + 3,3,3,3, + 4,4,4,4, + 3,4, + 5,5,5,5, + 5,5, +] + Args: + dataset_creator: Initializer for AvroDatasetRecord + filenames: A `tf.data.Dataset` iterator of filenames to read + cycle_length: The number of files to be processed in parallel. This is used by `Dataset.Interleave`. + We set this equal to `block_length`, so that each time n number of records are returned for each of the n + files. + num_parallel_calls: Number of threads spawned by the interleave call. + deterministic: Sets whether the interleaved records are written in deterministic order. in tf.interleave this is default true + block_length: Sets the number of output on the output tensor. Defaults to 1 + Returns: + A dataset iterator with an interleaved list of parsed avro records. + + """ + + def read_many_files(filenames): + filenames = tf.convert_to_tensor(filenames, tf.string, name="filename") + return dataset_creator(filenames) + + if cycle_length is None: + return filenames.flat_map(read_many_files) + return filenames.interleave( - read_one_file, cycle_length=num_parallel_reads, block_length=1 + read_many_files, + cycle_length=cycle_length, + num_parallel_calls=num_parallel_calls, + block_length=block_length, + deterministic=deterministic, ) @@ -128,10 +187,16 @@ class AvroRecordDataset(tf.data.Dataset): """A `Dataset` comprising records from one or more AvroRecord files.""" def __init__( - self, filenames, buffer_size=None, num_parallel_reads=None, reader_schema=None + self, + filenames, + buffer_size=None, + num_parallel_reads=None, + num_parallel_calls=None, + reader_schema=None, + deterministic=True, + block_length=1, ): """Creates a `AvroRecordDataset` to read one or more AvroRecord files. - Args: filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or more filenames. @@ -144,25 +209,61 @@ def __init__( files read in parallel are outputted in an interleaved order. 
+ + Raises: + ValueError: Raised when the condition is false + + Returns: + None + """ + if not condition: + raise ValueError(err_msg) + + # copied from https://github.com/tensorflow/tensorflow/blob/ # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36 def _create_or_validate_filenames_dataset(filenames): @@ -52,21 +70,62 @@ def _create_or_validate_filenames_dataset(filenames): # copied from https://github.com/tensorflow/tensorflow/blob/ # 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67 -def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None): - """create_dataset_reader""" - - def read_one_file(filename): - filename = tf.convert_to_tensor(filename, tf.string, name="filename") - return dataset_creator(filename) - - if num_parallel_reads is None: - return filenames.flat_map(read_one_file) - if num_parallel_reads == tf.data.experimental.AUTOTUNE: - return filenames.interleave( - read_one_file, num_parallel_calls=num_parallel_reads - ) +def _create_dataset_reader( + dataset_creator, + filenames, + cycle_length=None, + num_parallel_calls=None, + deterministic=None, + block_length=1, +): + """ + Creates a dataset reader that reads records from multiple files and interleaves them together, for example: +``` +dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] +# NOTE: New lines indicate "block" boundaries. +dataset = dataset.interleave( + lambda x: Dataset.from_tensors(x).repeat(6), + cycle_length=2, block_length=4) +list(dataset.as_numpy_iterator()) +``` +Results in the following output: +[1,1,1,1, + 2,2,2,2, + 1,1, + 2,2, + 3,3,3,3, + 4,4,4,4, + 3,3, 4,4, + 5,5,5,5, + 5,5, +] + Args: + dataset_creator: Initializer for the underlying `_AvroRecordDataset` + filenames: A `tf.data.Dataset` iterator of filenames to read + cycle_length: The number of files to be processed in parallel. This is used as the `cycle_length` + of the underlying `Dataset.interleave`, so that records from `cycle_length` files are interleaved + in blocks of `block_length` records. + num_parallel_calls: Number of threads spawned by the interleave call. + deterministic: Whether the interleaved records are produced in deterministic order. In `Dataset.interleave` this defaults to `True`. + block_length: The number of consecutive records taken from each file before cycling to the next. Defaults to 1. + Returns: + A `tf.data.Dataset` with records from the input files interleaved. + + """ + + def read_many_files(filenames): + filenames = tf.convert_to_tensor(filenames, tf.string, name="filename") + return dataset_creator(filenames) + + if cycle_length is None: + return filenames.flat_map(read_many_files) + return filenames.interleave( - read_one_file, cycle_length=num_parallel_reads, block_length=1 + read_many_files, + cycle_length=cycle_length, + num_parallel_calls=num_parallel_calls, + block_length=block_length, + deterministic=deterministic, ) @@ -128,10 +187,16 @@ class AvroRecordDataset(tf.data.Dataset): """A `Dataset` comprising records from one or more AvroRecord files.""" def __init__( - self, filenames, buffer_size=None, num_parallel_reads=None, reader_schema=None + self, + filenames, + buffer_size=None, + num_parallel_reads=None, + num_parallel_calls=None, + reader_schema=None, + deterministic=True, + block_length=1, ): """Creates a `AvroRecordDataset` to read one or more AvroRecord files. - Args: filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or more filenames. @@ -144,25 +209,61 @@ def __init__( files read in parallel are outputted in an interleaved order.
If your input pipeline is I/O bottlenecked, consider setting this parameter to a value greater than one to parallelize the I/O. If `None`, files will be - read sequentially. + read sequentially. This must be set to equal or greater than `num_parallel_calls`. + This constraint exists because `num_parallel_reads` becomes `cycle_length` in the + underlying call to `tf.Dataset.Interleave`, and the `cycle_length` is required to be + equal or higher than the number of threads(`num_parallel_calls`). + `cycle_length` in tf.Dataset.Interleave will dictate how many items it will pick up to process + num_parallel_calls: (Optional.) number of thread to spawn. This must be set to `None` + or greater than 0. Also this must be less than or equal to `num_parallel_reads`. This defines + the degree of parallelism in the underlying Dataset.interleave call. reader_schema: (Optional.) A `tf.string` scalar representing the reader schema or None. - + deterministic: (Optional.) A boolean controlling whether determinism should be traded for performance by + allowing elements to be produced out of order. Defaults to `True` + block_length: Sets the number of output on the output tensor. Defaults to 1 Raises: TypeError: If any argument does not have the expected type. ValueError: If any argument does not have the expected shape. """ + _require( + num_parallel_calls is None + or num_parallel_calls == tf.data.experimental.AUTOTUNE + or num_parallel_calls > 0, + f"num_parallel_calls: {num_parallel_calls} must be set to None, " + f"tf.data.experimental.AUTOTUNE, or greater than 0", + ) + if num_parallel_calls is not None: + _require( + num_parallel_reads is not None + and ( + num_parallel_reads >= num_parallel_calls + or num_parallel_reads == tf.data.experimental.AUTOTUNE + ), + f"num_parallel_reads: {num_parallel_reads} must be greater than or equal to " + f"num_parallel_calls: {num_parallel_calls} or set to tf.data.experimental.AUTOTUNE", + ) + filenames = _create_or_validate_filenames_dataset(filenames) self._filenames = filenames self._buffer_size = buffer_size self._num_parallel_reads = num_parallel_reads + self._num_parallel_calls = num_parallel_calls self._reader_schema = reader_schema + self._block_length = block_length - def creator_fn(filename): - return _AvroRecordDataset(filename, buffer_size, reader_schema) + def read_multiple_files(filenames): + return _AvroRecordDataset(filenames, buffer_size, reader_schema) - self._impl = _create_dataset_reader(creator_fn, filenames, num_parallel_reads) + self._impl = _create_dataset_reader( + read_multiple_files, + filenames, + cycle_length=num_parallel_reads, + num_parallel_calls=num_parallel_calls, + deterministic=deterministic, + block_length=block_length, + ) variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access super().__init__(variant_tensor) @@ -171,13 +272,17 @@ def _clone( filenames=None, buffer_size=None, num_parallel_reads=None, + num_parallel_calls=None, reader_schema=None, + block_length=None, ): return AvroRecordDataset( filenames or self._filenames, buffer_size or self._buffer_size, num_parallel_reads or self._num_parallel_reads, + num_parallel_calls or self._num_parallel_calls, reader_schema or self._reader_schema, + block_length or self._block_length, ) def _inputs(self): diff --git a/tensorflow_io/core/python/experimental/azure_ops.py b/tensorflow_io/core/python/experimental/azure_ops.py deleted file mode 100644 index 29239d104..000000000 --- a/tensorflow_io/core/python/experimental/azure_ops.py +++ /dev/null @@ -1,111 +0,0 @@ 
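Putting the widened `AvroRecordDataset` constructor above together: `num_parallel_reads` becomes the interleave `cycle_length`, `num_parallel_calls` must be `None`, AUTOTUNE, or a positive value no larger than `num_parallel_reads`, and `deterministic`/`block_length` are forwarded to `Dataset.interleave`. A minimal usage sketch with placeholder file names and a placeholder Avro schema:

```python
import tensorflow as tf
from tensorflow_io.core.python.experimental.avro_record_dataset_ops import (
    AvroRecordDataset,
)

reader_schema = """{"type": "record", "name": "Row",
                    "fields": [{"name": "label", "type": "long"}]}"""

dataset = AvroRecordDataset(
    filenames=["part-00000.avro", "part-00001.avro"],  # placeholders
    buffer_size=1024 * 1024,
    num_parallel_reads=4,    # cycle_length of the underlying interleave
    num_parallel_calls=2,    # must be <= num_parallel_reads (or AUTOTUNE)
    reader_schema=reader_schema,
    deterministic=True,
    block_length=2,
)
for serialized_record in dataset.take(1):
    print(serialized_record)
```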
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""tensorflow-io azure file system import""" - - -import tensorflow_io.core.python.ops # pylint: disable=unused-import - - -def authenticate_with_device_code(account_name): - """Setup storage tokens by authenticating with device code - and use management APIs. - - Args: - account_name (str): The storage account name for which to authenticate - """ - - import urllib # pylint: disable=import-outside-toplevel - import json # pylint: disable=import-outside-toplevel - import os # pylint: disable=import-outside-toplevel - from tensorflow.python.platform import ( # pylint: disable=import-outside-toplevel - tf_logging as log, - ) - - try: - from adal import ( # pylint: disable=import-outside-toplevel - AuthenticationContext, - ) - except ModuleNotFoundError: - log.error( - "Please install adal library with `python -m pip install -U adal`" - "to use the device code authentication method" - ) - return - - ctx = AuthenticationContext("https://login.microsoftonline.com/common") - - storage_resource = "https://management.azure.com/" - # Current multi-tenant client registerd in my AzureAD tenant - client_id = "8c375311-7f4c-406c-84f8-03dfe11ba2d3" - - device_code = ctx.acquire_user_code(resource=storage_resource, client_id=client_id) - - # Display authentication message to user to action in their browser - log.warn(device_code["message"]) - - token_response = ctx.acquire_token_with_device_code( - resource=storage_resource, user_code_info=device_code, client_id=client_id - ) - - headers = {"Authorization": "Bearer " + token_response["accessToken"]} - - subscription_list_req = urllib.request.Request( - url="https://management.azure.com/subscriptions?api-version=2016-06-01", - headers=headers, - ) - - with urllib.request.urlopen(subscription_list_req) as f: - subscriptions = json.load(f) - subscriptions = subscriptions["value"] - - storage_account = None - for subscription in subscriptions: - url = "https://management.azure.com/subscriptions/{}/providers/Microsoft.Storage/storageAccounts?api-version=2019-04-01".format( - subscription["subscriptionId"] - ) - storage_account_list_req = urllib.request.Request(url=url, headers=headers) - - with urllib.request.urlopen(storage_account_list_req) as f: - storage_accounts = json.load(f) - - storage_accounts = storage_accounts["value"] - account_by_name = [s for s in storage_accounts if s.get("name") == account_name] - if any(account_by_name): - storage_account = account_by_name[0] - break - - if storage_account is None: - log.error( - "Couldn't find storage account {} in any " - "available subscription".format(account_name) - ) - return - - url = "https://management.azure.com/{}/listKeys?api-version=2019-04-01".format( - storage_account["id"] - ) - storage_list_keys_req = urllib.request.Request( - url=url, headers=headers, method="POST" - ) - - with 
urllib.request.urlopen(storage_list_keys_req) as f: - account_keys = json.load(f) - - os.environ["TF_AZURE_STORAGE_KEY"] = account_keys["keys"][0]["value"] - log.info( - "Successfully set account key environment for {} " - "storage account".format(account_name) - ) diff --git a/tensorflow_io/core/python/experimental/image_ops.py b/tensorflow_io/core/python/experimental/image_ops.py index b399de102..ebde7e6ae 100644 --- a/tensorflow_io/core/python/experimental/image_ops.py +++ b/tensorflow_io/core/python/experimental/image_ops.py @@ -208,3 +208,18 @@ def decode_jp2(contents, dtype=tf.uint8, name=None): A `Tensor` of type `uint8` and shape of `[height, width, 3]` (RGB). """ return core_ops.io_decode_jpeg2k(contents, dtype=dtype, name=name) + + +def decode_obj(contents, name=None): + """ + Decode a Wavefront (obj) file into a float32 tensor. + + Args: + contents: A 0-dimensional Tensor of type string, i.e the + content of the Wavefront (.obj) file. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `float32` and shape of `[n, 3]` for vertices. + """ + return core_ops.io_decode_obj(contents, name=name) diff --git a/tensorflow_io/core/python/experimental/io_dataset_ops.py b/tensorflow_io/core/python/experimental/io_dataset_ops.py index 8ff46c22c..2367c7668 100644 --- a/tensorflow_io/core/python/experimental/io_dataset_ops.py +++ b/tensorflow_io/core/python/experimental/io_dataset_ops.py @@ -171,7 +171,7 @@ def from_numpy_file(cls, filename, spec=None, **kwargs): In case numpy file consists of unnamed elements, a tuple of numpy arrays are returned, otherwise a dict is returned for named elements. - ``` + Args: filename: filename of numpy file (npy or npz). spec: A tuple of tf.TensorSpec or dtype, or a dict of diff --git a/tensorflow_io/core/python/experimental/make_avro_record_dataset.py b/tensorflow_io/core/python/experimental/make_avro_record_dataset.py index 11934175b..af4eefa61 100644 --- a/tensorflow_io/core/python/experimental/make_avro_record_dataset.py +++ b/tensorflow_io/core/python/experimental/make_avro_record_dataset.py @@ -37,60 +37,41 @@ def make_avro_record_dataset( shuffle_seed=None, prefetch_buffer_size=tf.data.experimental.AUTOTUNE, num_parallel_reads=None, - num_parallel_parser_calls=None, drop_final_batch=False, ): """Reads and (optionally) parses avro files into a dataset. - Provides common functionality such as batching, optional parsing, shuffling, and performing defaults. - Args: file_pattern: List of files or patterns of avro file paths. See `tf.io.gfile.glob` for pattern rules. - features: A map of feature names mapped to feature information. - batch_size: An int representing the number of records to combine in a single batch. - reader_schema: The reader schema. - reader_buffer_size: (Optional.) An int specifying the readers buffer size in By. If None (the default) will use the default value from AvroRecordDataset. - num_epochs: (Optional.) An int specifying the number of times this dataset is repeated. If None (the default), cycles through the dataset forever. If set to None drops final batch. - shuffle: (Optional.) A bool that indicates whether the input should be shuffled. Defaults to `True`. - shuffle_buffer_size: (Optional.) Buffer size to use for shuffling. A large buffer size ensures better shuffling, but increases memory usage and startup time. If not provided assumes default value of 10,000 records. Note that the shuffle size is measured in records. - shuffle_seed: (Optional.) Randomization seed to use for shuffling. 
By default uses a pseudo-random seed. - prefetch_buffer_size: (Optional.) An int specifying the number of feature batches to prefetch for performance improvement. Defaults to auto-tune. Set to 0 to disable prefetching. - - num_parallel_reads: (Optional.) Number of threads used to read - records from files. By default or if set to a value >1, the - results will be interleaved. - - num_parallel_parser_calls: (Optional.) Number of parallel - records to parse in parallel. Defaults to an automatic selection. - + num_parallel_reads: (Optional.) Number of parallel + records to parse in parallel. Defaults to None(no parallelization). drop_final_batch: (Optional.) Whether the last batch should be dropped in case its size is smaller than `batch_size`; the default behavior is not to drop the smaller batch. - Returns: A dataset, where each element matches the output of `parser_fn` except it will have an additional leading `batch-size` dimension, @@ -99,20 +80,15 @@ def make_avro_record_dataset( """ files = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle, seed=shuffle_seed) - if num_parallel_reads is None: - # Note: We considered auto-tuning this value, but there is a concern - # that this affects the mixing of records from different files, which - # could affect training convergence/accuracy, so we are defaulting to - # a constant for now. - num_parallel_reads = 24 - if reader_buffer_size is None: reader_buffer_size = 1024 * 1024 - + num_parallel_calls = num_parallel_reads dataset = AvroRecordDataset( files, buffer_size=reader_buffer_size, num_parallel_reads=num_parallel_reads, + num_parallel_calls=num_parallel_calls, + block_length=num_parallel_calls, reader_schema=reader_schema, ) @@ -131,14 +107,11 @@ def make_avro_record_dataset( dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) - if num_parallel_parser_calls is None: - num_parallel_parser_calls = tf.data.experimental.AUTOTUNE - dataset = dataset.map( lambda data: parse_avro( serialized=data, reader_schema=reader_schema, features=features ), - num_parallel_calls=num_parallel_parser_calls, + num_parallel_calls=num_parallel_calls, ) if prefetch_buffer_size == 0: diff --git a/tensorflow_io/core/python/experimental/mongodb_dataset_ops.py b/tensorflow_io/core/python/experimental/mongodb_dataset_ops.py index 7f0b963f3..7331db65f 100644 --- a/tensorflow_io/core/python/experimental/mongodb_dataset_ops.py +++ b/tensorflow_io/core/python/experimental/mongodb_dataset_ops.py @@ -54,9 +54,54 @@ def get_next_batch(self, resource): class MongoDBIODataset(tf.data.Dataset): - """Fetch records from mongoDB""" + """Fetch records from mongoDB + + The dataset aids in faster retrieval of data from MongoDB collections. + + To make a connection and read the documents from the mongo collections, + the `tfio.experimental.mongodb.MongoDBIODataset` API can be used. + + Example: + + >>> URI = "mongodb://mongoadmin:default_password@localhost:27017" + >>> DATABASE = "tfiodb" + >>> COLLECTION = "test" + >>> dataset = tfio.experimental.mongodb.MongoDBIODataset( + uri=URI, database=DATABASE, collection=COLLECTION) + + Perform operations on the dataset as one would with any `tf.data.Dataset` + >>> dataset = dataset.map(transform_func) + >>> dataset = dataset.batch(batch_size) + + Assuming the user has already built a `tf.keras` model, the dataset can be directly + passed for training purposes. 
+ + >>> model.fit(dataset) # to train + >>> model.predict(dataset) # to infer + + """ def __init__(self, uri, database, collection): + """Initialize the dataset with the following parameters + + Args: + uri: The uri of the mongo server or replicaset to connect to. + - To connect to a MongoDB server with username and password + based authentication, the following uri pattern can be used. + Example: `"mongodb://mongoadmin:default_password@localhost:27017"`. + + - Connecting to a replica set is much like connecting to a + standalone MongoDB server. Simply specify the replica set name + using the `?replicaSet=myreplset` URI option. + Example: "mongodb://host01:27017,host02:27017,host03:27017/?replicaSet=myreplset" + + Additional information on writing uri's can be found here: + - [libmongoc uri docs](http://mongoc.org/libmongoc/current/mongoc_uri_t.html) + - [mongodb uri docs](https://docs.mongodb.com/manual/reference/connection-string/) + database: The database in the standalone standalone MongoDB server or a replica set + to connect to. + collection: The collection from which the documents have to be retrieved. + """ handler = _MongoDBHandler(uri=uri, database=database, collection=collection) resource = handler.get_healthy_resource() dataset = tf.data.experimental.Counter() diff --git a/tensorflow_io/core/python/experimental/mongodb_writer_ops.py b/tensorflow_io/core/python/experimental/mongodb_writer_ops.py index 3af8a14bf..cb1fff29e 100644 --- a/tensorflow_io/core/python/experimental/mongodb_writer_ops.py +++ b/tensorflow_io/core/python/experimental/mongodb_writer_ops.py @@ -22,10 +22,53 @@ class MongoDBWriter: - """Write documents to mongoDB""" + """Write documents to mongoDB. + + The writer can be used to store documents in mongoDB while dealing with tensorflow + based models and inference outputs. Without loss of generality, consider an ML + model that is being used for inference. The outputs of inference can be modelled into + a structured record by enriching the schema with additional information( for ex: metadata + about input data and the semantics of the inference etc.) and can be stored in mongo + collections for persistence or future analysis. + + To make a connection and write the documents to the mongo collections, + the `tfio.experimental.mongodb.MongoDBWriter` API can be used. + + Example: + + >>> URI = "mongodb://mongoadmin:default_password@localhost:27017" + >>> DATABASE = "tfiodb" + >>> COLLECTION = "test" + >>> writer = tfio.experimental.mongodb.MongoDBWriter( + uri=URI, database=DATABASE, collection=COLLECTION + ) + >>> for i in range(1000): + ... data = {"key{}".format(i): "value{}".format(i)} + ... writer.write(data) + + """ def __init__(self, uri, database, collection): + """Initialize the dataset with the following parameters + + Args: + uri: The uri of the mongo server or replicaset to connect to. + - To connect to a MongoDB server with username and password + based authentication, the following uri pattern can be used. + Example: `"mongodb://mongoadmin:default_password@localhost:27017"`. + + - Connecting to a replica set is much like connecting to a + standalone MongoDB server. Simply specify the replica set name + using the `?replicaSet=myreplset` URI option. 
+ Example: "mongodb://host01:27017,host02:27017,host03:27017/?replicaSet=myreplset" + Additional information on writing uri's can be found here: + - [libmongoc uri docs](http://mongoc.org/libmongoc/current/mongoc_uri_t.html) + - [mongodb uri docs](https://docs.mongodb.com/manual/reference/connection-string/) + database: The database in the standalone standalone MongoDB server or a replica set + to connect to. + collection: The collection from which the documents have to be retrieved. + """ self.uri = uri self.database = database self.collection = collection diff --git a/tensorflow_io/core/python/experimental/parse_avro_ops.py b/tensorflow_io/core/python/experimental/parse_avro_ops.py index edfbbee79..341548435 100644 --- a/tensorflow_io/core/python/experimental/parse_avro_ops.py +++ b/tensorflow_io/core/python/experimental/parse_avro_ops.py @@ -130,6 +130,7 @@ def _parse_avro( dense_defaults=None, dense_shapes=None, name=None, + avro_num_minibatches=0, ): """Parses Avro records. @@ -196,6 +197,7 @@ def _parse_avro( dense_keys=dense_keys, dense_shapes=dense_shapes, name=name, + avro_num_minibatches=avro_num_minibatches, ) (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs diff --git a/tensorflow_io/core/python/experimental/serial_ops.py b/tensorflow_io/core/python/experimental/serial_ops.py new file mode 100644 index 000000000..461ad63b0 --- /dev/null +++ b/tensorflow_io/core/python/experimental/serial_ops.py @@ -0,0 +1,201 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Easily save tf.data.Datasets as tfrecord files, and restore tfrecords as Datasets. + +The goal of this module is to create a SIMPLE api to tfrecords that can be used without +learning all of the underlying mechanics. + +Users only need to deal with 2 functions: +save_dataset(dataset) +dataset = load_dataset(tfrecord, header) + +It really is that easy! + +To make this work, we create a .header file for each tfrecord which encodes metadata +needed to reconstruct the original dataset. + +Note that PyYAML (yaml) package must be installed to make use of this module. + +Saving must be done in eager mode, but loading is compatible with both eager and +graph execution modes. + +GOTCHAS: +- This module is only compatible with "dictionary-style" datasets {key: val, key2:val2,..., keyN: valN}. +- The restored dataset will have the TFRecord dtypes {float32, int64, string} instead of the original + tensor dtypes. This is always the case with TFRecord datasets, whether you use this module or not. + The original dtypes are stored in the headers if you want to restore them after loading.""" +import functools +import os +import tempfile + +import numpy as np +import tensorflow as tf + + +# The three encoding functions. 
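For reference alongside the new `serial_ops.py` module above: a minimal round-trip sketch of `save_dataset()`/`load_dataset()`, not taken from this PR. It assumes PyYAML is installed and imports the module by its internal path, since the public `tfio` alias is not shown in this diff; the dataset contents and file names are illustrative.

```python
# Hedged sketch: round-trip a "dictionary-style" dataset through the new helpers.
import tensorflow as tf
from tensorflow_io.core.python.experimental import serial_ops  # internal path; public alias not shown in this diff

# Only dictionary-style datasets {key: tensor, ...} are supported.
dataset = tf.data.Dataset.from_tensor_slices(
    {"x": [[1.0, 2.0], [3.0, 4.0]], "y": [0, 1]}
)

# Saving must run in eager mode; a YAML .header file is written next to the
# tfrecord so shapes and dtypes can be reconstructed on load.
serial_ops.save_dataset(
    dataset, tfrecord_path="data.tfrecord", header_path="data.header"
)

# Restored tensors carry the TFRecord dtypes (float32/int64/string), as the
# module docstring notes; original dtypes remain available in the header.
restored = serial_ops.load_dataset(
    tfrecord_path="data.tfrecord", header_path="data.header"
)
for element in restored:
    print(element["x"], element["y"])
```

Per the module docstring, saving requires eager execution, while the loaded dataset can also be consumed from graph code.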
+def _bytes_feature(value): + """value: list""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) + + +def _float_feature(value): + """value: list""" + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + + +def _int64_feature(value): + """value: list""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + + +# TODO use base_type() to ensure consistent conversion. +def np_value_to_feature(value): + """Maps dataset values to tf Features. + Only numpy types are supported since Datasets only contain tensors. + Each datatype should only have one way of being serialized.""" + if isinstance(value, np.ndarray): + # feature = _bytes_feature(value.tostring()) + if np.issubdtype(value.dtype, np.integer): + feature = _int64_feature(value.flatten()) + elif np.issubdtype(value.dtype, np.float): + feature = _float_feature(value.flatten()) + elif np.issubdtype(value.dtype, np.bool): + feature = _int64_feature(value.flatten()) + else: + raise TypeError(f"value dtype: {value.dtype} is not recognized.") + elif isinstance(value, bytes): + feature = _bytes_feature([value]) + elif np.issubdtype(type(value), np.integer): + feature = _int64_feature([value]) + elif np.issubdtype(type(value), np.float): + feature = _float_feature([value]) + + else: + raise TypeError( + f"value type: {type(value)} is not recognized. value must be a valid Numpy object." + ) + + return feature + + +def base_type(dtype): + """Returns the TFRecords allowed type corresponding to dtype.""" + int_types = [ + tf.int8, + tf.int16, + tf.int32, + tf.int64, + tf.uint8, + tf.uint16, + tf.uint32, + tf.uint64, + tf.qint8, + tf.qint16, + tf.qint32, + tf.bool, + ] + float_types = [tf.float16, tf.float32, tf.float64] + byte_types = [tf.string, bytes] + + if dtype in int_types: + new_dtype = tf.int64 + elif dtype in float_types: + new_dtype = tf.float32 + elif dtype in byte_types: + new_dtype = tf.string + else: + raise ValueError(f"dtype {dtype} is not a recognized/supported type!") + + return new_dtype + + +def build_header(dataset): + """Build header dictionary of metadata for the tensors in the dataset. This will be used when loading + the tfrecords file to reconstruct the original tensors from the raw data. Shape is stored as an array + and dtype is stored as an enumerated value (defined by tensorflow).""" + header = {} + for key in dataset.element_spec.keys(): + header[key] = { + "shape": list(dataset.element_spec[key].shape), + "dtype": dataset.element_spec[key].dtype.as_datatype_enum, + } + + return header + + +def build_feature_desc(header): + """Build feature_desc dictionary for the tensors in the dataset. This will be used to reconstruct Examples + from the tfrecords file. + + Assumes FixedLenFeatures. + If you got VarLenFeatures I feel bad for you son, + I got 115 problems but a VarLenFeature ain't one.""" + feature_desc = {} + for key, params in header.items(): + feature_desc[key] = tf.io.FixedLenFeature( + shape=params["shape"], dtype=base_type(int(params["dtype"])) + ) + + return feature_desc + + +def dataset_to_examples(ds): + """Converts a dataset to a dataset of tf.train.Example strings. Each Example is a single observation. + WARNING: Only compatible with "dictionary-style" datasets {key: val, key2:val2,..., keyN, valN}. + WARNING: Must run in eager mode!""" + # TODO handle tuples and flat datasets as well. + for x in ds: + # Each individual tensor is converted to a known serializable type. 
+ features = {key: np_value_to_feature(value.numpy()) for key, value in x.items()} + # All features are then packaged into a single Example object. + example = tf.train.Example(features=tf.train.Features(feature=features)) + + yield example.SerializeToString() + + +def save_dataset(dataset, tfrecord_path, header_path): + """Saves a flat dataset as a tfrecord file, and builds a header file for reloading as dataset. + Must run in eager mode because it depends on dataset iteration and element_spec.""" + import yaml + + if not tf.executing_eagerly(): + raise ValueError("save_dataset() must run in eager mode!") + + # Header + header = build_header(dataset) + header_file = open(header_path, "w") + yaml.dump(header, stream=header_file) + + # Dataset + ds_examples = tf.data.Dataset.from_generator( + lambda: dataset_to_examples(dataset), output_types=tf.string + ) + writer = tf.data.experimental.TFRecordWriter(tfrecord_path) + writer.write(ds_examples) + + +# TODO-DECIDE is this yaml loader safe? +def load_dataset(tfrecord_path, header_path): + """Uses header file to predict the shape and dtypes of tensors for tf.data.""" + import yaml + + header_file = open(header_path) + header = yaml.load(header_file, Loader=yaml.FullLoader) + + feature_desc = build_feature_desc(header) + parse_func = functools.partial(tf.io.parse_single_example, features=feature_desc) + dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_func) + + return dataset diff --git a/tensorflow_io/core/python/ops/__init__.py b/tensorflow_io/core/python/ops/__init__.py index c9f6865a9..ebb0cfd47 100644 --- a/tensorflow_io/core/python/ops/__init__.py +++ b/tensorflow_io/core/python/ops/__init__.py @@ -18,6 +18,7 @@ import ctypes import sys import inspect +import warnings import types import tensorflow as tf @@ -95,5 +96,8 @@ def __dir__(self): plugin_ops = _load_library("libtensorflow_io_plugins.so", "fs") except NotImplementedError as e: # Note: load libtensorflow_io.so imperatively in case of statically linking - core_ops = _load_library("libtensorflow_io.so") - plugin_ops = _load_library("libtensorflow_io.so", "fs") + try: + core_ops = _load_library("libtensorflow_io.so") + plugin_ops = _load_library("libtensorflow_io.so", "fs") + except NotImplementedError as e: + warnings.warn("file system plugins are not loaded: {}".format(e)) diff --git a/tensorflow_io/core/python/ops/version_ops.py b/tensorflow_io/core/python/ops/version_ops.py index d6d6e3122..e61f2470d 100644 --- a/tensorflow_io/core/python/ops/version_ops.py +++ b/tensorflow_io/core/python/ops/version_ops.py @@ -14,5 +14,5 @@ # ============================================================================== """version_ops""" -package = "tensorflow>=2.4.0,<2.5.0" -version = "0.17.0" +package = "tf-nightly" +version = "0.18.0" diff --git a/tensorflow_io/gcs/BUILD b/tensorflow_io/gcs/BUILD new file mode 100644 index 000000000..8ffb13dbd --- /dev/null +++ b/tensorflow_io/gcs/BUILD @@ -0,0 +1,34 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load( + "//:tools/build/tensorflow_io.bzl", + "tf_io_copts", +) + +cc_binary( + name = "_gcs_config_ops.so", + copts = tf_io_copts(), + linkshared = 1, + deps = [ + ":gcs_config_ops", + ], +) + +cc_library( + name = "gcs_config_ops", + srcs = [ + "kernels/gcs_config_op_kernels.cc", + "ops/gcs_config_ops.cc", + ], + copts = tf_io_copts(), + linkstatic = True, + deps = [ + "@curl", + "@jsoncpp_git//:jsoncpp", + "@local_config_tf//:libtensorflow_framework", + 
"@local_config_tf//:tf_header_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow_io/gcs/README.md b/tensorflow_io/gcs/README.md new file mode 100644 index 000000000..99782a341 --- /dev/null +++ b/tensorflow_io/gcs/README.md @@ -0,0 +1,3 @@ +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. diff --git a/tensorflow_io/gcs/__init__.py b/tensorflow_io/gcs/__init__.py new file mode 100644 index 000000000..39f6154b7 --- /dev/null +++ b/tensorflow_io/gcs/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module for cloud ops.""" + + +from tensorflow.python.util.all_util import remove_undocumented + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top +from tensorflow_io.gcs.python.ops.gcs_config_ops import * + +_allowed_symbols = [ + "configure_colab_session", + "configure_gcs", + "BlockCacheParams", + "ConfigureGcsHook", +] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc b/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc new file mode 100644 index 000000000..3fd878a73 --- /dev/null +++ b/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc @@ -0,0 +1,206 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. 
+constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + tstring json_string; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = + tensorflow::MakeUnique(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. 
+ gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ TF_GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ TF_GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow_io/gcs/ops/gcs_config_ops.cc b/tensorflow_io/gcs/ops/gcs_config_ops.cc new file mode 100644 index 000000000..140dbc3a3 --- /dev/null +++ b/tensorflow_io/gcs/ops/gcs_config_ops.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("IO>GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. +The json input can be of the format: +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} +Note the credentials established through this method are shared across all +sessions run on this runtime. +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("IO>GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow_io/gcs/python/__init__.py b/tensorflow_io/gcs/python/__init__.py new file mode 100644 index 000000000..f00d24fd2 --- /dev/null +++ b/tensorflow_io/gcs/python/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains Python API methods for GCS integration.""" diff --git a/tensorflow_io/gcs/python/ops/__init__.py b/tensorflow_io/gcs/python/ops/__init__.py new file mode 100644 index 000000000..568c0e67a --- /dev/null +++ b/tensorflow_io/gcs/python/ops/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains the Python API methods for GCS integration.""" diff --git a/tensorflow_io/gcs/python/ops/gcs_config_ops.py b/tensorflow_io/gcs/python/ops/gcs_config_ops.py new file mode 100644 index 000000000..148602fe1 --- /dev/null +++ b/tensorflow_io/gcs/python/ops/gcs_config_ops.py @@ -0,0 +1,235 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + + +import json +import os + +import tensorflow as tf +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training +from tensorflow_io.core.python.ops import core_ops + +# Some GCS operations may be pre-defined and available via tf.contrib in +# earlier TF versions. Because these ops are pre-registered, they will not be +# visible from the _gcs_config_ops library. In this case we use the tf.contrib +# version instead. +tf_v1 = tf.version.VERSION.startswith("1") + +if not tf_v1: + gcs_configure_credentials = core_ops.io_gcs_configure_credentials + gcs_configure_block_cache = core_ops.io_gcs_configure_block_cache + + +class BlockCacheParams: # pylint: disable=useless-object-inheritance + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. 
+ + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if "refresh_token" in creds_dict or "private_key" in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError( + "credentials was not a well formed JSON string.", e + ) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + "field." + ) + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError( + 'credentials has neither a "refresh_token" nor a ' + '"private_key" field.' + ) + credentials = json.dumps(credentials) + else: + raise ValueError("credentials is of an unknown type") + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError("block_cache must be an instance of BlockCacheParams.") + self._block_cache = block_cache + + def begin(self): + """Called once before using the session. + + When called, the default graph is the one that will be launched in the + session. The hook can modify the graph by adding new operations to it. + After the `begin()` call the graph will be finalized and the other callbacks + can not modify the graph anymore. Second call of `begin()` on the same + graph, should not change the graph. + """ + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_op = gcs_configure_credentials( + self._credentials_placeholder + ) + else: + self._credentials_op = None + + if self._block_cache: + self._block_cache_op = gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness, + ) + else: + self._block_cache_op = None + + def after_create_session(self, session, coord): + """Called when new TensorFlow session is created. + + This is called to signal the hooks that a new session has been created. This + has two essential differences with the situation in which `begin` is called: + + * When this is called, the graph is finalized and ops can no longer be added + to the graph. + * This method will also be called as a result of recovering a wrapped + session, not only at the beginning of the overall session. + + Args: + session: A TensorFlow Session that has been created. + coord: A Coordinator object which keeps track of all threads. + """ + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}, + ) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def _configure_gcs_tfv2(credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. 
+ + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + gcs_configure_credentials(credentials) + + if block_cache: + gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness, + ) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def _configure_colab_session_tfv2(): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + """ + # Read from the application default credentials (adc). + adc_filename = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "/content/adc.json") + with open(adc_filename) as f: + data = json.load(f) + configure_gcs(credentials=data) + + +if tf_v1: + configure_gcs = tf.contrib.cloud.configure_gcs + configure_colab_session = tf.contrib.cloud.configure_colab_session +else: + configure_gcs = _configure_gcs_tfv2 + configure_colab_session = _configure_colab_session_tfv2 diff --git a/tests/test_archive_eager.py b/tests/test_archive.py similarity index 100% rename from tests/test_archive_eager.py rename to tests/test_archive.py diff --git a/tests/test_arrow_eager.py b/tests/test_arrow.py similarity index 99% rename from tests/test_arrow_eager.py rename to tests/test_arrow.py index a744c6696..9c4d38c6b 100644 --- a/tests/test_arrow_eager.py +++ b/tests/test_arrow.py @@ -508,7 +508,7 @@ def test_arrow_feather_dataset(self): # Create a tempfile that is deleted after tests run with tempfile.NamedTemporaryFile(delete=False) as f: - write_feather(df, f) + write_feather(df, f, version=1) # test single file dataset = arrow_io.ArrowFeatherDataset( @@ -1143,7 +1143,7 @@ def test_arrow_list_feather_columns(self): # Create a tempfile that is deleted after tests run with tempfile.NamedTemporaryFile(delete=False) as f: - write_feather(df, f) + write_feather(df, f, version=1) # test single file # prefix "file://" to test scheme file system (e.g., s3, gcs, azfs, ignite) diff --git a/tests/test_audio_eager.py b/tests/test_audio.py similarity index 100% rename from tests/test_audio_eager.py rename to tests/test_audio.py diff --git a/tests/test_audio_ops_eager.py b/tests/test_audio_ops.py similarity index 100% rename from tests/test_audio_ops_eager.py rename to tests/test_audio_ops.py diff --git a/tests/test_avro_eager.py b/tests/test_avro.py similarity index 100% rename from tests/test_avro_eager.py rename to tests/test_avro.py diff --git a/tests/test_bigquery_eager.py b/tests/test_bigquery.py similarity index 100% rename from tests/test_bigquery_eager.py rename to tests/test_bigquery.py diff --git a/tests/test_bigtable.py b/tests/test_bigtable.py new file mode 100644 index 000000000..6ec179f66 --- /dev/null +++ b/tests/test_bigtable.py @@ -0,0 +1,119 @@ +# Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Stub Test""" + +import os +import sys +import time +import shutil +import datetime +import tempfile +import numpy as np +import pytest + +import tensorflow as tf +import tensorflow_io as tfio + + +def bigtable_func(project_id, instance_id, table_id): + from google.cloud import bigtable + from google.cloud.bigtable import column_family + from google.cloud.bigtable import row_filters + from google.auth.credentials import AnonymousCredentials + + os.environ["BIGTABLE_EMULATOR_HOST"] = "localhost:8086" + + # [START bigtable_hw_connect] + # The client must be created with admin=True because it will create a + # table. + client = bigtable.Client( + project=project_id, admin=True, credentials=AnonymousCredentials() + ) + instance = client.instance(instance_id) + # [END bigtable_hw_connect] + + # [START bigtable_hw_create_table] + print("Creating the {} table.".format(table_id)) + table = instance.table(table_id) + + print("Creating column family cf1 with Max Version GC rule...") + # Create a column family with GC policy : most recent N versions + # Define the GC policy to retain only the most recent 2 versions + max_versions_rule = column_family.MaxVersionsGCRule(2) + column_family_id = "cf1" + column_families = {column_family_id: max_versions_rule} + if not table.exists(): + table.create(column_families=column_families) + else: + print("Table {} already exists.".format(table_id)) + # [END bigtable_hw_create_table] + + # [START bigtable_hw_write_rows] + print("Writing some greetings to the table.") + greetings = ["Hello World!", "Hello Cloud Bigtable!", "Hello Python!"] + rows = [] + column = b"greeting" + for i, value in enumerate(greetings): + # Note: This example uses sequential numeric IDs for simplicity, + # but this can result in poor performance in a production + # application. Since rows are stored in sorted order by key, + # sequential keys can result in poor distribution of operations + # across nodes. + # + # For more information about how to design a Bigtable schema for + # the best performance, see the documentation: + # + # https://cloud.google.com/bigtable/docs/schema-design + row_key = "greeting{}".format(i).encode() + row = table.direct_row(row_key) + row.set_cell( + column_family_id, column, value, timestamp=datetime.datetime.utcnow() + ) + rows.append(row) + table.mutate_rows(rows) + # [END bigtable_hw_write_rows] + + # [START bigtable_hw_create_filter] + # Create a filter to only retrieve the most recent version of the cell + # for each column accross entire row. 
+ row_filter = row_filters.CellsColumnLimitFilter(1) + # [END bigtable_hw_create_filter] + + # [START bigtable_hw_get_with_filter] + print("Getting a single greeting by row key.") + key = b"greeting0" + + row = table.read_row(key, row_filter) + cell = row.cells[column_family_id][column][0] + print(cell.value.decode("utf-8")) + # [END bigtable_hw_get_with_filter] + + # [START bigtable_hw_scan_with_filter] + print("Scanning for all greetings:") + partial_rows = table.read_rows(filter_=row_filter) + + for row in partial_rows: + cell = row.cells[column_family_id][column][0] + print(cell.value.decode("utf-8")) + # [END bigtable_hw_scan_with_filter] + + # [START bigtable_hw_delete_table] + print("Deleting the {} table.".format(table_id)) + table.delete() + # [END bigtable_hw_delete_table] + + +def test_bigtable(): + bigtable_func("bigtable_project", "bigtable_instance", "bigtable_table") diff --git a/tests/test_color_eager.py b/tests/test_color.py similarity index 100% rename from tests/test_color_eager.py rename to tests/test_color.py diff --git a/tests/test_csv_eager.py b/tests/test_csv.py similarity index 100% rename from tests/test_csv_eager.py rename to tests/test_csv.py diff --git a/tests/test_dicom.py b/tests/test_dicom.py index c4605e61a..e867465cb 100644 --- a/tests/test_dicom.py +++ b/tests/test_dicom.py @@ -16,6 +16,7 @@ import os +import numpy as np import pytest import tensorflow as tf @@ -35,8 +36,7 @@ def test_dicom_input(): - """test_dicom_input - """ + """test_dicom_input""" _ = tfio.image.decode_dicom_data _ = tfio.image.decode_dicom_image _ = tfio.image.dicom_tags @@ -66,32 +66,26 @@ def test_dicom_input(): ("MR-MONO2-12-shoulder.dcm", (1, 1024, 1024, 1)), ("OT-MONO2-8-a7.dcm", (1, 512, 512, 1)), ("US-PAL-8-10x-echo.dcm", (10, 430, 600, 3)), + ("TOSHIBA_J2K_OpenJPEGv2Regression.dcm", (1, 512, 512, 1)), ], ) def test_decode_dicom_image(fname, exp_shape): - """test_decode_dicom_image - """ + """test_decode_dicom_image""" dcm_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname ) - g1 = tf.compat.v1.Graph() + file_contents = tf.io.read_file(filename=dcm_path) - with g1.as_default(): - file_contents = tf.io.read_file(filename=dcm_path) - dcm_image = tfio.image.decode_dicom_image( - contents=file_contents, - dtype=tf.float32, - on_error="strict", - scale="auto", - color_dim=True, - ) - - sess = tf.compat.v1.Session(graph=g1) - dcm_image_np = sess.run(dcm_image) - - assert dcm_image_np.shape == exp_shape + dcm_image = tfio.image.decode_dicom_image( + contents=file_contents, + dtype=tf.float32, + on_error="strict", + scale="auto", + color_dim=True, + ) + assert dcm_image.numpy().shape == exp_shape @pytest.mark.parametrize( @@ -121,23 +115,108 @@ def test_decode_dicom_image(fname, exp_shape): ], ) def test_decode_dicom_data(fname, tag, exp_value): - """test_decode_dicom_data - """ + """test_decode_dicom_data""" dcm_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname ) - g1 = tf.compat.v1.Graph() + file_contents = tf.io.read_file(filename=dcm_path) + + dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) + + assert dcm_data.numpy() == exp_value + + +def test_dicom_image_shape(): + """test_decode_dicom_image""" + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + + dataset = tf.data.Dataset.from_tensor_slices([dcm_path]) + dataset = dataset.map(tf.io.read_file) + dataset = dataset.map(lambda e: tfio.image.decode_dicom_image(e, 
dtype=tf.uint16)) + dataset = dataset.map(lambda e: tf.image.resize(e, (224, 224))) + + +def test_dicom_image_concurrency(): + """test_decode_dicom_image_currency""" - with g1.as_default(): - file_contents = tf.io.read_file(filename=dcm_path) - dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) + @tf.function + def preprocess(dcm_content): + tags = tfio.image.decode_dicom_data( + dcm_content, tags=[tfio.image.dicom_tags.PatientsName] + ) + tf.print(tags) + image = tfio.image.decode_dicom_image(dcm_content, dtype=tf.float32) + return image + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "TOSHIBA_J2K_OpenJPEGv2Regression.dcm", + ) + + dataset = ( + tf.data.Dataset.from_tensor_slices([dcm_path]) + .repeat() + .map(tf.io.read_file) + .map(preprocess, num_parallel_calls=8) + .take(200) + ) + for i, item in enumerate(dataset): + print(tf.shape(item), i) + assert np.array_equal(tf.shape(item), [1, 512, 512, 1]) + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + + dataset = ( + tf.data.Dataset.from_tensor_slices([dcm_path]) + .repeat() + .map(tf.io.read_file) + .map(preprocess, num_parallel_calls=8) + .take(200) + ) + for i, item in enumerate(dataset): + print(tf.shape(item), i) + assert np.array_equal(tf.shape(item), [10, 430, 600, 3]) + + +def test_dicom_sequence(): + """test_decode_dicom_sequence""" + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "2.25.304589190180579357564631626197663875025.dcm", + ) + dcm_content = tf.io.read_file(filename=dcm_path) + + tags = tfio.image.decode_dicom_data( + dcm_content, tags=["[0x0008,0x1115][0][0x0008,0x1140][0][0x0008,0x1155]"] + ) + assert np.array_equal(tags, [b"2.25.211904290918469145111906856660599393535"]) + + dcm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_dicom", + "US-PAL-8-10x-echo.dcm", + ) + dcm_content = tf.io.read_file(filename=dcm_path) - sess = tf.compat.v1.Session(graph=g1) - dcm_data_np = sess.run(dcm_data) + tags = tfio.image.decode_dicom_data(dcm_content, tags=["[0x0020,0x000E]"]) + assert np.array_equal(tags, [b"999.999.94827453"]) - assert dcm_data_np == exp_value + tags = tfio.image.decode_dicom_data(dcm_content, tags=["0x0020,0x000e"]) + assert np.array_equal(tags, [b"999.999.94827453"]) if __name__ == "__main__": diff --git a/tests/test_dicom_eager.py b/tests/test_dicom_eager.py deleted file mode 100644 index e867465cb..000000000 --- a/tests/test_dicom_eager.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
-# ============================================================================== -"""Tests for DICOM.""" - - -import os -import numpy as np -import pytest - -import tensorflow as tf -import tensorflow_io as tfio - -# The DICOM sample files must be downloaded befor running the tests -# -# To download the DICOM samples: -# $ bash dicom_samples.sh download -# $ bash dicom_samples.sh extract -# -# To remopve the DICOM samples: -# $ bash dicom_samples.sh clean_dcm -# -# To remopve all the downloaded files: -# $ bash dicom_samples.sh clean_all - - -def test_dicom_input(): - """test_dicom_input""" - _ = tfio.image.decode_dicom_data - _ = tfio.image.decode_dicom_image - _ = tfio.image.dicom_tags - - -@pytest.mark.parametrize( - "fname, exp_shape", - [ - ("OT-MONO2-8-colon.dcm", (1, 512, 512, 1)), - ("CR-MONO1-10-chest.dcm", (1, 440, 440, 1)), - ("CT-MONO2-16-ort.dcm", (1, 512, 512, 1)), - ("MR-MONO2-16-head.dcm", (1, 256, 256, 1)), - ("US-RGB-8-epicard.dcm", (1, 480, 640, 3)), - ("CT-MONO2-8-abdo.dcm", (1, 512, 512, 1)), - ("MR-MONO2-16-knee.dcm", (1, 256, 256, 1)), - ("OT-MONO2-8-hip.dcm", (1, 512, 512, 1)), - ("US-RGB-8-esopecho.dcm", (1, 120, 256, 3)), - ("CT-MONO2-16-ankle.dcm", (1, 512, 512, 1)), - ("MR-MONO2-12-an2.dcm", (1, 256, 256, 1)), - ("MR-MONO2-8-16x-heart.dcm", (16, 256, 256, 1)), - ("OT-PAL-8-face.dcm", (1, 480, 640, 3)), - ("XA-MONO2-8-12x-catheter.dcm", (12, 512, 512, 1)), - ("CT-MONO2-16-brain.dcm", (1, 512, 512, 1)), - ("NM-MONO2-16-13x-heart.dcm", (13, 64, 64, 1)), - ("US-MONO2-8-8x-execho.dcm", (8, 120, 128, 1)), - ("CT-MONO2-16-chest.dcm", (1, 400, 512, 1)), - ("MR-MONO2-12-shoulder.dcm", (1, 1024, 1024, 1)), - ("OT-MONO2-8-a7.dcm", (1, 512, 512, 1)), - ("US-PAL-8-10x-echo.dcm", (10, 430, 600, 3)), - ("TOSHIBA_J2K_OpenJPEGv2Regression.dcm", (1, 512, 512, 1)), - ], -) -def test_decode_dicom_image(fname, exp_shape): - """test_decode_dicom_image""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname - ) - - file_contents = tf.io.read_file(filename=dcm_path) - - dcm_image = tfio.image.decode_dicom_image( - contents=file_contents, - dtype=tf.float32, - on_error="strict", - scale="auto", - color_dim=True, - ) - assert dcm_image.numpy().shape == exp_shape - - -@pytest.mark.parametrize( - "fname, tag, exp_value", - [ - ( - "OT-MONO2-8-colon.dcm", - tfio.image.dicom_tags.StudyInstanceUID, - b"1.3.46.670589.17.1.7.1.1.16", - ), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.Rows, b"512"), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.Columns, b"512"), - ("OT-MONO2-8-colon.dcm", tfio.image.dicom_tags.SamplesperPixel, b"1"), - ( - "US-PAL-8-10x-echo.dcm", - tfio.image.dicom_tags.StudyInstanceUID, - b"999.999.3859744", - ), - ( - "US-PAL-8-10x-echo.dcm", - tfio.image.dicom_tags.SeriesInstanceUID, - b"999.999.94827453", - ), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.NumberofFrames, b"10"), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.Rows, b"430"), - ("US-PAL-8-10x-echo.dcm", tfio.image.dicom_tags.Columns, b"600"), - ], -) -def test_decode_dicom_data(fname, tag, exp_value): - """test_decode_dicom_data""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_dicom", fname - ) - - file_contents = tf.io.read_file(filename=dcm_path) - - dcm_data = tfio.image.decode_dicom_data(contents=file_contents, tags=tag) - - assert dcm_data.numpy() == exp_value - - -def test_dicom_image_shape(): - """test_decode_dicom_image""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", 
- "US-PAL-8-10x-echo.dcm", - ) - - dataset = tf.data.Dataset.from_tensor_slices([dcm_path]) - dataset = dataset.map(tf.io.read_file) - dataset = dataset.map(lambda e: tfio.image.decode_dicom_image(e, dtype=tf.uint16)) - dataset = dataset.map(lambda e: tf.image.resize(e, (224, 224))) - - -def test_dicom_image_concurrency(): - """test_decode_dicom_image_currency""" - - @tf.function - def preprocess(dcm_content): - tags = tfio.image.decode_dicom_data( - dcm_content, tags=[tfio.image.dicom_tags.PatientsName] - ) - tf.print(tags) - image = tfio.image.decode_dicom_image(dcm_content, dtype=tf.float32) - return image - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "TOSHIBA_J2K_OpenJPEGv2Regression.dcm", - ) - - dataset = ( - tf.data.Dataset.from_tensor_slices([dcm_path]) - .repeat() - .map(tf.io.read_file) - .map(preprocess, num_parallel_calls=8) - .take(200) - ) - for i, item in enumerate(dataset): - print(tf.shape(item), i) - assert np.array_equal(tf.shape(item), [1, 512, 512, 1]) - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "US-PAL-8-10x-echo.dcm", - ) - - dataset = ( - tf.data.Dataset.from_tensor_slices([dcm_path]) - .repeat() - .map(tf.io.read_file) - .map(preprocess, num_parallel_calls=8) - .take(200) - ) - for i, item in enumerate(dataset): - print(tf.shape(item), i) - assert np.array_equal(tf.shape(item), [10, 430, 600, 3]) - - -def test_dicom_sequence(): - """test_decode_dicom_sequence""" - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "2.25.304589190180579357564631626197663875025.dcm", - ) - dcm_content = tf.io.read_file(filename=dcm_path) - - tags = tfio.image.decode_dicom_data( - dcm_content, tags=["[0x0008,0x1115][0][0x0008,0x1140][0][0x0008,0x1155]"] - ) - assert np.array_equal(tags, [b"2.25.211904290918469145111906856660599393535"]) - - dcm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_dicom", - "US-PAL-8-10x-echo.dcm", - ) - dcm_content = tf.io.read_file(filename=dcm_path) - - tags = tfio.image.decode_dicom_data(dcm_content, tags=["[0x0020,0x000E]"]) - assert np.array_equal(tags, [b"999.999.94827453"]) - - tags = tfio.image.decode_dicom_data(dcm_content, tags=["0x0020,0x000e"]) - assert np.array_equal(tags, [b"999.999.94827453"]) - - -if __name__ == "__main__": - test.main() diff --git a/tests/test_documentation_eager.py b/tests/test_documentation.py similarity index 100% rename from tests/test_documentation_eager.py rename to tests/test_documentation.py diff --git a/tests/test_elasticsearch_eager.py b/tests/test_elasticsearch.py similarity index 100% rename from tests/test_elasticsearch_eager.py rename to tests/test_elasticsearch.py diff --git a/tests/test_feather_eager.py b/tests/test_feather.py similarity index 95% rename from tests/test_feather_eager.py rename to tests/test_feather.py index fae30c5b0..5bbdcd2d3 100644 --- a/tests/test_feather_eager.py +++ b/tests/test_feather.py @@ -20,6 +20,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import tensorflow_io as tfio # pylint: disable=wrong-import-position @@ -37,9 +38,7 @@ def test_feather_format(): } df = pd.DataFrame(data).sort_index(axis=1) with tempfile.NamedTemporaryFile(delete=False) as f: - df.to_feather(f) - - df = pd.read_feather(f.name) + pa.feather.write_feather(df, f, version=1) feather = tfio.IOTensor.from_feather(f.name) for column in df.columns: diff --git a/tests/test_ffmpeg_eager.py b/tests/test_ffmpeg.py similarity index 100% 
rename from tests/test_ffmpeg_eager.py rename to tests/test_ffmpeg.py diff --git a/tests/test_filter_eager.py b/tests/test_filter.py similarity index 100% rename from tests/test_filter_eager.py rename to tests/test_filter.py diff --git a/tests/test_gcloud/test_gcs.sh b/tests/test_gcloud/test_gcs.sh old mode 100644 new mode 100755 index 235ed1a12..57aa0ac85 --- a/tests/test_gcloud/test_gcs.sh +++ b/tests/test_gcloud/test_gcs.sh @@ -4,7 +4,7 @@ set -o pipefail if [ "$#" -eq 1 ]; then container=$1 docker pull python:3.8 - docker run -d --rm --net=host --name=$container -v $PWD:/v -w /v python:3.8 bash -x -c 'python3 -m pip install gcloud-storage-emulator==0.3.0 && gcloud-storage-emulator start --port=9099' + docker run -d --rm --net=host --name=$container -v $PWD:/v -w /v python:3.8 bash -x -c 'python3 -m pip install -r /v/tests/test_gcloud/testbench/requirements.txt && gunicorn --bind "0.0.0.0:9099" --worker-class gevent --chdir "/v/tests/test_gcloud/testbench" testbench:application' echo wait 30 secs until gcs emulator is up and running sleep 30 exit 0 @@ -12,9 +12,11 @@ fi export PATH=$(python3 -m site --user-base)/bin:$PATH -python3 -m pip install gcloud-storage-emulator==0.3.0 - -gcloud-storage-emulator start --port=9099 & - +python3 -m pip install -r tests/test_gcloud/testbench/requirements.txt +echo starting gcs-testbench +gunicorn --bind "0.0.0.0:9099" \ + --worker-class gevent \ + --chdir "tests/test_gcloud/testbench" \ + testbench:application & sleep 30 # Wait for storage emulator to start -echo gcs emulator started successfully +echo gcs-testbench started successfully diff --git a/tests/test_pubsub/pubsub_test.sh b/tests/test_gcloud/test_pubsub_bigtable.sh similarity index 74% rename from tests/test_pubsub/pubsub_test.sh rename to tests/test_gcloud/test_pubsub_bigtable.sh index c8a82f58a..ff504b924 100755 --- a/tests/test_pubsub/pubsub_test.sh +++ b/tests/test_gcloud/test_pubsub_bigtable.sh @@ -23,8 +23,10 @@ if [ "$#" -eq 1 ]; then echo pull google/cloud-sdk docker pull google/cloud-sdk:236.0.0 echo pull google/cloud-sdk successfully - docker run -d --rm --net=host --name=$container -v $base:/v -w /v google/cloud-sdk:236.0.0 bash -x -c 'gcloud beta emulators pubsub start' + docker run -d --rm --net=host --name=$container-pubsub -v $base:/v -w /v google/cloud-sdk:236.0.0 bash -x -c 'gcloud beta emulators pubsub start' echo wait 10 secs until pubsub is up and running + docker run -d --rm --net=host --name=$container-bigtable -v $base:/v -w /v google/cloud-sdk:236.0.0 bash -x -c 'gcloud beta emulators bigtable start' + echo wait 10 secs until bigtable is up and running sleep 10 exit 0 fi @@ -34,7 +36,9 @@ tar -xzf google-cloud-sdk-236.0.0-darwin-x86_64.tar.gz google-cloud-sdk/install.sh -q google-cloud-sdk/bin/gcloud -q components install beta google-cloud-sdk/bin/gcloud -q components install pubsub-emulator +google-cloud-sdk/bin/gcloud -q components update beta google-cloud-sdk/bin/gcloud -q beta emulators pubsub start & +google-cloud-sdk/bin/gcloud -q beta emulators bigtable start & exit 0 diff --git a/tests/test_gcloud/testbench/README.md b/tests/test_gcloud/testbench/README.md new file mode 100644 index 000000000..ddc3971b3 --- /dev/null +++ b/tests/test_gcloud/testbench/README.md @@ -0,0 +1,15 @@ +# GCS Testbench + +This is a minimal testbench for GCS. It only supports data operations and creating/listing/deleting buckets.
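For illustration only, once the testbench is running (see the install and run steps below) a client can be pointed at it instead of the real GCS endpoint. This is a minimal sketch, assuming the `google-cloud-storage` Python client and its `STORAGE_EMULATOR_HOST` environment variable support; port 9099 matches `test_gcs.sh`, `test-bucket` is the magic bucket the testbench creates on first use, and the object name is hypothetical:

```python
import os

from google.auth.credentials import AnonymousCredentials
from google.cloud import storage

# Route client traffic to the local testbench instead of the real GCS endpoint.
os.environ["STORAGE_EMULATOR_HOST"] = "http://localhost:9099"

# Anonymous credentials and a dummy project avoid any real GCP lookups.
client = storage.Client(project="test-project", credentials=AnonymousCredentials())

# "test-bucket" is the magic bucket the testbench creates on first use.
bucket = client.bucket("test-bucket")
blob = bucket.blob("hello.txt")
blob.upload_from_string(b"hello from the testbench")
print(blob.download_as_string())
```

The emulator host is typically read when the `Client` is constructed, so the environment variable has to be set first; the JSON-API paths the library then requests (`/storage/v1`, `/upload/storage/v1`, `/download/storage/v1`) are exactly the ones the testbench mounts.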
+ +## Install Dependencies + +```bash +pip install -r requirements.txt +``` + +## Run Test Bench + +```bash +gunicorn --bind "0.0.0.0:9099" --worker-class gevent --chdir "tests/test_gcloud/testbench" testbench:application +``` diff --git a/tests/test_gcloud/testbench/error_response.py b/tests/test_gcloud/testbench/error_response.py new file mode 100644 index 000000000..8d6a5816f --- /dev/null +++ b/tests/test_gcloud/testbench/error_response.py @@ -0,0 +1,36 @@ +# Copyright 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A helper class to send error responses in the storage client test bench.""" + +import flask + + +class ErrorResponse(Exception): + """Simplify generation of error responses.""" + + status_code = 400 + + def __init__(self, message, status_code=None, payload=None): + Exception.__init__(self) + self.message = message + if status_code is not None: + self.status_code = status_code + self.payload = payload + + def as_response(self): + kv = dict(self.payload or ()) + kv["message"] = self.message + response = flask.jsonify(kv) + response.status_code = self.status_code + return response diff --git a/tests/test_gcloud/testbench/gcs_bucket.py b/tests/test_gcloud/testbench/gcs_bucket.py new file mode 100644 index 000000000..1fcba3992 --- /dev/null +++ b/tests/test_gcloud/testbench/gcs_bucket.py @@ -0,0 +1,258 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implement a class to simulate GCS buckets.""" + +import base64 +import error_response +import flask +import gcs_object +import json +import re +import testbench_utils +import time + + +class GcsBucket: + """Represent a GCS Bucket.""" + + def __init__(self, gcs_url, name): + self.name = name + self.gcs_url = gcs_url + now = time.gmtime(time.time()) + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", now) + self.metadata = { + "timeCreated": timestamp, + "updated": timestamp, + "metageneration": "0", + "name": self.name, + "location": "US", + "storageClass": "STANDARD", + "etag": "XYZ=", + "labels": {"foo": "bar", "baz": "qux"}, + "owner": {"entity": "project-owners-123456789", "entityId": ""}, + } + self.resumable_uploads = {} + + def versioning_enabled(self): + """Return True if versioning is enabled for this Bucket.""" + v = self.metadata.get("versioning", None) + if v is None: + return False + return v.get("enabled", False) + + def check_preconditions(self, request): + """Verify that the preconditions in request are met. + + :param request:flask.Request the contents of the HTTP request.
+ :rtype:NoneType + :raises:ErrorResponse if the request does not pass the preconditions, + for example, the request has a `ifMetagenerationMatch` restriction + that is not met. + """ + + metageneration_match = request.args.get("ifMetagenerationMatch") + metageneration_not_match = request.args.get("ifMetagenerationNotMatch") + metageneration = self.metadata.get("metageneration") + + if ( + metageneration_not_match is not None + and metageneration_not_match == metageneration + ): + raise error_response.ErrorResponse( + "Precondition Failed (metageneration = %s)" % metageneration, + status_code=412, + ) + + if metageneration_match is not None and metageneration_match != metageneration: + raise error_response.ErrorResponse( + "Precondition Failed (metageneration = %s)" % metageneration, + status_code=412, + ) + + def create_resumable_upload(self, upload_url, request): + """Capture the details for a resumable upload. + + :param upload_url: str the base URL for uploads. + :param request: flask.Request the original http request. + :return: the HTTP response to send back. + """ + x_upload_content_type = request.headers.get( + "x-upload-content-type", "application/octet-stream" + ) + x_upload_content_length = request.headers.get("x-upload-content-length") + expected_bytes = None + if x_upload_content_length: + expected_bytes = int(x_upload_content_length) + + if request.args.get("name") is not None and len(request.data): + raise error_response.ErrorResponse( + "The name argument is only supported for empty payloads", + status_code=400, + ) + if len(request.data): + metadata = json.loads(request.data) + else: + metadata = {"name": request.args.get("name")} + + if metadata.get("name") is None: + raise error_response.ErrorResponse( + "Missing object name argument", status_code=400 + ) + metadata.setdefault("contentType", x_upload_content_type) + upload = { + "metadata": metadata, + "instructions": request.headers.get("x-goog-testbench-instructions"), + "fields": request.args.get("fields"), + "next_byte": 0, + "expected_bytes": expected_bytes, + "object_name": metadata.get("name"), + "media": b"", + "transfer": set(), + "done": False, + } + # Capture the preconditions, including those that are None. + for precondition in [ + "ifGenerationMatch", + "ifGenerationNotMatch", + "ifMetagenerationMatch", + "ifMetagenerationNotMatch", + ]: + upload[precondition] = request.args.get(precondition) + upload_id = base64.b64encode(bytearray(metadata.get("name"), "utf-8")).decode( + "utf-8" + ) + self.resumable_uploads[upload_id] = upload + location = "{}?uploadType=resumable&upload_id={}".format(upload_url, upload_id) + response = flask.make_response("") + response.headers["Location"] = location + return response + + def receive_upload_chunk(self, gcs_url, request): + """Receive a new upload chunk. + + :param gcs_url: str the base URL for the service. + :param request: flask.Request the original http request. + :return: the HTTP response. + """ + upload_id = request.args.get("upload_id") + if upload_id is None: + raise error_response.ErrorResponse( + "Missing upload_id in resumable_upload_chunk", status_code=400 + ) + upload = self.resumable_uploads.get(upload_id) + if upload is None: + raise error_response.ErrorResponse( + "Cannot find resumable upload %s" % upload_id, status_code=404 + ) + # Be gracious in what you accept, if the Content-Range header is not + # set we assume it is a good header and it is the end of the file. 
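+ # Content-Range handling below: a "bytes */*" header is a status-only query, + # "bytes */N" marks the upload as final with total size N, and + # "bytes M-N/T" appends one chunk, final only when the total T is numeric.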
+ next_byte = upload["next_byte"] + upload["transfer"].add(request.environ.get("HTTP_TRANSFER_ENCODING", "")) + end = next_byte + len(request.data) + total = end + final_chunk = False + payload = testbench_utils.extract_media(request) + content_range = request.headers.get("content-range") + if content_range is not None: + if content_range.startswith("bytes */*"): + # This is just a query to resume an upload, if it is done, return + # the completed upload payload and an empty range header. + response = flask.make_response(upload.get("payload", "")) + if next_byte > 1 and not upload["done"]: + response.headers["Range"] = "bytes=0-%d" % (next_byte - 1) + response.status_code = 200 if upload["done"] else 308 + return response + match = re.match(r"bytes \*/(\*|[0-9]+)", content_range) + if match: + if match.group(1) == "*": + total = 0 + else: + total = int(match.group(1)) + final_chunk = True + else: + match = re.match(r"bytes ([0-9]+)-([0-9]+)\/(\*|[0-9]+)", content_range) + if not match: + raise error_response.ErrorResponse( + "Invalid Content-Range in upload %s" % content_range, + status_code=400, + ) + begin = int(match.group(1)) + end = int(match.group(2)) + if match.group(3) == "*": + total = 0 + else: + total = int(match.group(3)) + final_chunk = True + + if begin != next_byte: + raise error_response.ErrorResponse( + "Mismatched data range, expected data at %d, got %d" + % (next_byte, begin), + status_code=400, + ) + if len(payload) != end - begin + 1: + raise error_response.ErrorResponse( + "Mismatched data range (%d) vs. received data (%d)" + % (end - begin + 1, len(payload)), + status_code=400, + ) + + upload["media"] = upload.get("media", b"") + payload + next_byte = len(upload.get("media", "")) + upload["next_byte"] = next_byte + response_payload = "" + if final_chunk and next_byte >= total: + expected_bytes = upload["expected_bytes"] + if expected_bytes is not None and expected_bytes != total: + raise error_response.ErrorResponse( + "X-Upload-Content-Length" + "validation failed. Expected=%d, got %d." % (expected_bytes, total) + ) + upload["done"] = True + object_name = upload.get("object_name") + object_path, blob = testbench_utils.get_object( + self.name, object_name, gcs_object.GcsObject(self.name, object_name) + ) + # Release a few resources to control memory usage. 
+ original_metadata = upload.pop("metadata", None) + media = upload.pop("media", None) + blob.check_preconditions_by_value( + upload.get("ifGenerationMatch"), + upload.get("ifGenerationNotMatch"), + upload.get("ifMetagenerationMatch"), + upload.get("ifMetagenerationNotMatch"), + ) + if upload.pop("instructions", None) == "inject-upload-data-error": + media = testbench_utils.corrupt_media(media) + revision = blob.insert_resumable(gcs_url, request, media, original_metadata) + revision.metadata.setdefault("metadata", {}) + revision.metadata["metadata"]["x_testbench_transfer_encoding"] = ":".join( + upload["transfer"] + ) + response_payload = testbench_utils.filter_fields_from_response( + upload.get("fields"), revision.metadata + ) + upload["payload"] = response_payload + testbench_utils.insert_object(object_path, blob) + + response = flask.make_response(response_payload) + if next_byte == 0: + response.headers["Range"] = "bytes=0-0" + else: + response.headers["Range"] = "bytes=0-%d" % (next_byte - 1) + if upload.get("done", False): + response.status_code = 200 + else: + response.status_code = 308 + return response diff --git a/tests/test_gcloud/testbench/gcs_object.py b/tests/test_gcloud/testbench/gcs_object.py new file mode 100644 index 000000000..96acf19de --- /dev/null +++ b/tests/test_gcloud/testbench/gcs_object.py @@ -0,0 +1,770 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implement a class to simulate GCS objects.""" + +import base64 +import crc32c +import error_response +import hashlib +import json +import struct +import testbench_utils +import time + + +class GcsObjectVersion: + """Represent a single revision of a GCS Object.""" + + def __init__(self, gcs_url, bucket_name, name, generation, request, media): + """Initialize a new object revision. + + :param gcs_url:str the base URL for the GCS service. + :param bucket_name:str the name of the bucket that contains the object. + :param name:str the name of the object. + :param generation:int the generation number for this object. + :param request:flask.Request the contents of the HTTP request. + :param media:str the contents of the object. 
+ """ + self.gcs_url = gcs_url + self.bucket_name = bucket_name + self.name = name + self.generation = str(generation) + self.object_id = bucket_name + "/o/" + name + "/" + str(generation) + now = time.gmtime(time.time()) + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", now) + self.media = media + instructions = request.headers.get("x-goog-testbench-instructions") + if instructions == "inject-upload-data-error": + self.media = testbench_utils.corrupt_media(media) + + self.metadata = { + "timeCreated": timestamp, + "updated": timestamp, + "metageneration": "0", + "generation": str(generation), + "location": "US", + "storageClass": "STANDARD", + "size": str(len(self.media)), + "etag": "XYZ=", + "owner": {"entity": "project-owners-123456789", "entityId": ""}, + "md5Hash": base64.b64encode(hashlib.md5(self.media).digest()).decode( + "utf-8" + ), + "crc32c": base64.b64encode( + struct.pack(">I", crc32c.crc32(self.media)) + ).decode("utf-8"), + } + if request.headers.get("content-type") is not None: + self.metadata["contentType"] = request.headers.get("content-type") + + def update_from_metadata(self, metadata): + """Update from a metadata dictionary. + + :param metadata:dict a dictionary with new metadata values. + :rtype:NoneType + """ + tmp = self.metadata.copy() + tmp.update(metadata) + tmp["bucket"] = tmp.get("bucket", self.name) + tmp["name"] = tmp.get("name", self.name) + now = time.gmtime(time.time()) + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", now) + # Some values cannot be changed via updates, so we always reset them. + tmp.update( + { + "kind": "storage#object", + "bucket": self.bucket_name, + "name": self.name, + "id": self.object_id, + "selfLink": self.gcs_url + self.name, + "projectNumber": "123456789", + "updated": timestamp, + } + ) + tmp["metageneration"] = str(int(tmp.get("metageneration", "0")) + 1) + self.metadata = tmp + self._validate_hashes() + + def _validate_hashes(self): + """Validate the md5Hash and crc32c fields against the stored media.""" + self._validate_md5_hash() + self._validate_crc32c() + + def _validate_md5_hash(self): + """Validate the md5Hash field against the stored media.""" + actual = self.metadata.get("md5Hash", "") + expected = base64.b64encode(hashlib.md5(self.media).digest()).decode("utf-8") + if actual != expected: + raise error_response.ErrorResponse( + "Mismatched MD5 hash expected={}, actual={}".format(expected, actual) + ) + + def _validate_crc32c(self): + """Validate the crc32c field against the stored media.""" + actual = self.metadata.get("crc32c", "") + expected = base64.b64encode(struct.pack(">I", crc32c.crc32(self.media))).decode( + "utf-8" + ) + if actual != expected: + raise error_response.ErrorResponse( + "Mismatched CRC32C checksum expected={}, actual={}".format( + expected, actual + ) + ) + + def validate_encryption_for_read(self, request, prefix="x-goog-encryption"): + """Verify that the request includes the correct encryption keys. + + :param request:flask.Request the http request. + :param prefix: str the prefix shared by the encryption headers, + typically 'x-goog-encryption', but for rewrite requests it can be + 'x-goog-copy-source-encryption'. + :rtype:NoneType + """ + key_header = prefix + "-key" + hash_header = prefix + "-key-sha256" + algo_header = prefix + "-algorithm" + encryption = self.metadata.get("customerEncryption") + if encryption is None: + # The object is not encrypted, no key is needed. 
+ if request.headers.get(key_header) is None: + return + else: + # The data is not encrypted, sending an encryption key is an + # error. + testbench_utils.raise_csek_error() + # The data is encrypted, the key must be present, match, and match its + # hash. + key_header_value = request.headers.get(key_header) + hash_header_value = request.headers.get(hash_header) + algo_header_value = request.headers.get(algo_header) + testbench_utils.validate_customer_encryption_headers( + key_header_value, hash_header_value, algo_header_value + ) + if encryption.get("keySha256") != hash_header_value: + testbench_utils.raise_csek_error() + + def _capture_customer_encryption(self, request): + """Capture the customer-supplied encryption key, if any. + + :param request:flask.Request the http request. + :rtype:NoneType + """ + if request.headers.get("x-goog-encryption-key") is None: + return + prefix = "x-goog-encryption" + key_header = prefix + "-key" + hash_header = prefix + "-key-sha256" + algo_header = prefix + "-algorithm" + key_header_value = request.headers.get(key_header) + hash_header_value = request.headers.get(hash_header) + algo_header_value = request.headers.get(algo_header) + testbench_utils.validate_customer_encryption_headers( + key_header_value, hash_header_value, algo_header_value + ) + self.metadata["customerEncryption"] = { + "encryptionAlgorithm": algo_header_value, + "keySha256": hash_header_value, + } + + def x_goog_hash_header(self): + """Return the value for the x-goog-hash header.""" + hashes = { + "md5": self.metadata.get("md5Hash", ""), + "crc32c": self.metadata.get("crc32c", ""), + } + hashes = ["{}={}".format(key, val) for key, val in hashes.items() if val] + return ",".join(hashes) + + +class GcsObject: + """Represent a GCS Object, including all its revisions.""" + + def __init__(self, bucket_name, name): + """Initialize a fake GCS Blob. + + :param bucket_name:str the bucket that will contain the new object. + :param name:str the name of the new object. + """ + self.bucket_name = bucket_name + self.name = name + # A counter to create new generation numbers for the object revisions. + # Note that 0 is an invalid generation number. The application can use + # ifGenerationMatch=0 as a pre-condition that means "object does not + # exist". + self.generation_generator = 0 + self.current_generation = None + self.revisions = {} + self.rewrite_token_generator = 0 + self.rewrite_operations = {} + + def get_revision(self, request, version_field_name="generation"): + """Get the information about a particular object revision or raise. + + :param request:flask.Request the contents of the http request. + :param version_field_name:str the name of the generation + parameter, typically 'generation', but sometimes 'sourceGeneration'. + :return: the object revision. + :rtype: GcsObjectVersion + :raises:ErrorResponse if the request contains an invalid generation + number. + """ + generation = request.args.get(version_field_name) + if generation is None: + return self.get_latest() + version = self.revisions.get(generation) + if version is None: + raise error_response.ErrorResponse( + "Precondition Failed: generation %s not found" % generation + ) + return version + + def del_revision(self, request): + """Delete a version of a fake GCS Blob. + + :param request:flask.Request the contents of the HTTP request. + :return: True if the object entry in the Bucket should be deleted. 
+ :rtype: bool + """ + generation = request.args.get("generation") or self.current_generation + if generation is None: + return True + self.revisions.pop(generation) + if len(self.revisions) == 0: + self.current_generation = None + return True + self.current_generation = sorted(self.revisions.keys())[-1] + return False + + @classmethod + def _remove_non_writable_keys(cls, metadata): + """Remove the keys from metadata (an update or patch) that are not + writable. + + Both `Objects: patch` and `Objects: update` either ignore non-writable + keys or return 400 if the key does not match the current value. In + the testbench we simply always ignore them, to make life easier. + + :param metadata:dict a dictionary representing a patch or + update to the metadata. + :return metadata but with only any non-writable keys removed. + :rtype: dict + """ + writeable_keys = { + "acl", + "cacheControl", + "contentDisposition", + "contentEncoding", + "contentLanguage", + "contentType", + "eventBasedHold", + "metadata", + "temporaryHold", + "storageClass", + "customTime", + } + # Cannot change `metadata` while we are iterating over it, so we make + # a copy + keys = [key for key in metadata.keys()] + for key in keys: + if key not in writeable_keys: + metadata.pop(key, None) + return metadata + + def get_revision_by_generation(self, generation): + """Get object revision by generation or None if not found. + + :param generation:int + :return: the object revision by generation or None. + :rtype:GcsObjectRevision + """ + return self.revisions.get(str(generation), None) + + def get_latest(self): + return self.revisions.get(self.current_generation, None) + + def check_preconditions_by_value( + self, + generation_match, + generation_not_match, + metageneration_match, + metageneration_not_match, + ): + """Verify that the given precondition values are met.""" + current_generation = self.current_generation or "0" + if generation_match is not None and generation_match != current_generation: + raise error_response.ErrorResponse("Precondition Failed", status_code=412) + # This object does not exist (yet), testing in this case is special. + if ( + generation_not_match is not None + and generation_not_match == current_generation + ): + raise error_response.ErrorResponse("Precondition Failed", status_code=412) + + if self.current_generation is None: + if metageneration_match is not None or metageneration_not_match is not None: + raise error_response.ErrorResponse( + "Precondition Failed", status_code=412 + ) + return + + current = self.revisions.get(current_generation) + if current is None: + raise error_response.ErrorResponse("Object not found", status_code=404) + metageneration = current.metadata.get("metageneration") + if ( + metageneration_not_match is not None + and metageneration_not_match == metageneration + ): + raise error_response.ErrorResponse("Precondition Failed", status_code=412) + if metageneration_match is not None and metageneration_match != metageneration: + raise error_response.ErrorResponse("Precondition Failed", status_code=412) + + def check_preconditions( + self, + request, + if_generation_match="ifGenerationMatch", + if_generation_not_match="ifGenerationNotMatch", + if_metageneration_match="ifMetagenerationMatch", + if_metageneration_not_match="ifMetagenerationNotMatch", + ): + """Verify that the preconditions in request are met. + + :param request:flask.Request the http request. 
+ :param if_generation_match:str the name of the generation match + parameter name, typically 'ifGenerationMatch', but sometimes + 'ifSourceGenerationMatch'. + :param if_generation_not_match:str the name of the generation not-match + parameter name, typically 'ifGenerationNotMatch', but sometimes + 'ifSourceGenerationNotMatch'. + :param if_metageneration_match:str the name of the metageneration match + parameter name, typically 'ifMetagenerationMatch', but sometimes + 'ifSourceMetagenerationMatch'. + :param if_metageneration_not_match:str the name of the metageneration + not-match parameter name, typically 'ifMetagenerationNotMatch', but + sometimes 'ifSourceMetagenerationNotMatch'. + :rtype:NoneType + """ + generation_match = request.args.get(if_generation_match) + generation_not_match = request.args.get(if_generation_not_match) + metageneration_match = request.args.get(if_metageneration_match) + metageneration_not_match = request.args.get(if_metageneration_not_match) + self.check_preconditions_by_value( + generation_match, + generation_not_match, + metageneration_match, + metageneration_not_match, + ) + + def _insert_revision(self, revision): + """Insert a new revision that has been initialized and checked. + + :param revision: GcsObjectVersion the new revision to insert. + :rtype:NoneType + """ + update = {str(self.generation_generator): revision} + bucket = testbench_utils.lookup_bucket(self.bucket_name) + if not bucket.versioning_enabled(): + self.revisions = update + else: + self.revisions.update(update) + self.current_generation = str(self.generation_generator) + + def insert(self, gcs_url, request): + """Insert a new revision based on the give flask request. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + media = testbench_utils.extract_media(request) + self.generation_generator += 1 + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + media, + ) + meta = revision.metadata.setdefault("metadata", {}) + meta["x_testbench_upload"] = "simple" + self._insert_revision(revision) + return revision + + def insert_multipart(self, gcs_url, request, resource, media_headers, media_body): + """Insert a new revision based on the give flask request. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param resource:dict JSON resource with object metadata. + :param media_headers:dict media headers in a multi-part upload. + :param media_body:str object data in a multi-part upload. + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + # There are two ways to specify the content-type, the 'content-type' + # header and the resource['contentType'] field. They must be consistent, + # and the service generates an error when they are not. + if ( + resource.get("contentType") is not None + and media_headers.get("content-type") is not None + and resource.get("contentType") != media_headers.get("content-type") + ): + raise error_response.ErrorResponse( + ( + "Content-Type specified in the upload (%s) does not match" + + "contentType specified in the metadata (%s)." + ) + % (media_headers.get("content-type"), resource.get("contentType")), + status_code=400, + ) + # Set the contentType in the resource from the header. Note that if both + # are set they have the same value. 
+ resource.setdefault("contentType", media_headers.get("content-type")) + self.generation_generator += 1 + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + media_body, + ) + meta = revision.metadata.setdefault("metadata", {}) + meta["x_testbench_upload"] = "multipart" + if "md5Hash" in resource: + # We should return `x_testbench_md5` only when the user enables + # `MD5Hash` computations. + meta["x_testbench_md5"] = resource.get("md5Hash") + meta["x_testbench_crc32c"] = resource.get("crc32c", "") + # Apply any overrides from the resource object part. + revision.update_from_metadata(resource) + self._insert_revision(revision) + return revision + + def insert_resumable(self, gcs_url, request, media, resource): + """Implement the final insert for a resumable upload. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param media:str the media for the object. + :param resource:dict the metadata for the object. + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + self.generation_generator += 1 + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + media, + ) + meta = revision.metadata.setdefault("metadata", {}) + meta["x_testbench_upload"] = "resumable" + meta["x_testbench_md5"] = resource.get("md5Hash", "") + meta["x_testbench_crc32c"] = resource.get("crc32c", "") + # Apply any overrides from the resource object part. + revision.update_from_metadata(resource) + self._insert_revision(revision) + return revision + + def insert_xml(self, gcs_url, request): + """Implement the insert operation using the XML API. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + media = testbench_utils.extract_media(request) + self.generation_generator += 1 + goog_hash = request.headers.get("x-goog-hash") + md5hash = None + crc32c = None + if goog_hash is not None: + for hash in goog_hash.split(","): + if hash.startswith("md5="): + md5hash = hash[4:] + if hash.startswith("crc32c="): + crc32c = hash[7:] + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + media, + ) + meta = revision.metadata.setdefault("metadata", {}) + meta["x_testbench_upload"] = "xml" + if md5hash is not None: + meta["x_testbench_md5"] = md5hash + revision.update_from_metadata({"md5Hash": md5hash}) + if crc32c is not None: + meta["x_testbench_crc32c"] = crc32c + revision.update_from_metadata({"crc32c": crc32c}) + self._insert_revision(revision) + return revision + + def copy_from(self, gcs_url, request, source_revision): + """Insert a new revision based on the give flask request. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param source_revision:GcsObjectVersion the source object version to + copy from. + :return: the newly created object version. 
+ :rtype: GcsObjectVersion + """ + self.generation_generator += 1 + source_revision.validate_encryption_for_read(request) + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + source_revision.media, + ) + metadata = json.loads(request.data) + revision.update_from_metadata(metadata) + self._insert_revision(revision) + return revision + + def compose_from(self, gcs_url, request, composed_media): + """Compose a new revision based on the give flask request. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param composed_media:str contents of the composed object + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + self.generation_generator += 1 + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + composed_media, + ) + payload = json.loads(request.data) + if payload.get("destination") is not None: + revision.update_from_metadata(payload.get("destination")) + # The server often discards the MD5 Hash when composing objects, we can + # easily maintain them in the testbench, but dropping them helps us + # detect bugs sooner. + revision.metadata.pop("md5Hash") + self._insert_revision(revision) + return revision + + @classmethod + def rewrite_fixed_args(cls): + """The arguments that should not change between requests for the same + rewrite operation.""" + return [ + "destinationKmsKeyName", + "destinationPredefinedAcl", + "ifGenerationMatch", + "ifGenerationNotMatch", + "ifMetagenerationMatch", + "ifMetagenerationNotMatch", + "ifSourceGenerationMatch", + "ifSourceGenerationNotMatch", + "ifSourceMetagenerationMatch", + "ifSourceMetagenerationNotMatch", + "maxBytesRewrittenPerCall", + "projection", + "sourceGeneration", + "userProject", + ] + + @classmethod + def capture_rewrite_operation_arguments( + cls, request, destination_bucket, destination_object + ): + """Captures the arguments used to validate related rewrite calls. 
+ + :rtype:dict + """ + original_arguments = {} + for arg in GcsObject.rewrite_fixed_args(): + original_arguments[arg] = request.args.get(arg) + original_arguments.update( + { + "destination_bucket": destination_bucket, + "destination_object": destination_object, + } + ) + return original_arguments + + @classmethod + def make_rewrite_token( + cls, operation, destination_bucket, destination_object, generation + ): + """Create a new rewrite token for the given operation.""" + return base64.b64encode( + bytearray( + "/".join( + [ + str(operation.get("id")), + destination_bucket, + destination_object, + str(generation), + str(operation.get("bytes_rewritten")), + ] + ), + "utf-8", + ) + ).decode("utf-8") + + def make_rewrite_operation(self, request, destination_bucket, destination_object): + """Create a new rewrite token for `Objects: rewrite`.""" + generation = request.args.get("sourceGeneration") + if generation is None: + generation = str(self.generation_generator) + else: + generation = generation + + self.rewrite_token_generator = self.rewrite_token_generator + 1 + body = json.loads(request.data) + original_arguments = self.capture_rewrite_operation_arguments( + request, destination_object, destination_object + ) + operation = { + "id": self.rewrite_token_generator, + "original_arguments": original_arguments, + "actual_generation": generation, + "bytes_rewritten": 0, + "body": body, + } + token = GcsObject.make_rewrite_token( + operation, destination_bucket, destination_object, generation + ) + return token, operation + + def rewrite_finish(self, gcs_url, request, body, source): + """Complete a rewrite from `source` into this object. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param body:dict the HTTP payload, parsed via json.loads() + :param source:GcsObjectVersion the source object version. + :return: the newly created object version. + :rtype: GcsObjectVersion + """ + media = source.media + self.check_preconditions(request) + self.generation_generator += 1 + revision = GcsObjectVersion( + gcs_url, + self.bucket_name, + self.name, + self.generation_generator, + request, + media, + ) + revision.update_from_metadata(body) + self._insert_revision(revision) + return revision + + def rewrite_step(self, gcs_url, request, destination_bucket, destination_object): + """Execute an iteration of `Objects: rewrite. + + Objects: rewrite may need to be called multiple times before it + succeeds. Only objects in the same location, with the same encryption, + are guaranteed to complete in a single request. + + The implementation simulates some, but not all, the behaviors of the + server, in particular, only rewrites within the same bucket and smaller + than 1MiB complete immediately. + + :param gcs_url:str the root URL for the fake GCS service. + :param request:flask.Request the contents of the HTTP request. + :param destination_bucket:str where will the object be placed after the + rewrite operation completes. + :param destination_object:str the name of the object when the rewrite + operation completes. + :return: a dictionary prepared for JSON encoding of a + `Objects: rewrite` response. + :rtype:dict + """ + body = json.loads(request.data) + rewrite_token = request.args.get("rewriteToken") + if rewrite_token is not None and rewrite_token != "": + # Note that we remove the rewrite operation, not just look it up. + # That way if the operation completes in this call, and/or fails, + # it is already removed. 
We need to insert it with a new token + # anyway, so this makes sense. + rewrite = self.rewrite_operations.pop(rewrite_token, None) + if rewrite is None: + raise error_response.ErrorResponse( + "Invalid or expired token in rewrite", status_code=410 + ) + else: + rewrite_token, rewrite = self.make_rewrite_operation( + request, destination_bucket, destination_bucket + ) + + # Compare the difference to the original arguments, on the first call + # this is a waste, but the code is easier to follow. + current_arguments = self.capture_rewrite_operation_arguments( + request, destination_bucket, destination_object + ) + diff = set(current_arguments) ^ set(rewrite.get("original_arguments")) + if len(diff) != 0: + raise error_response.ErrorResponse( + "Mismatched arguments to rewrite", status_code=412 + ) + + # This will raise if the version is deleted while the operation is in + # progress. + source = self.get_revision_by_generation(rewrite.get("actual_generation")) + source.validate_encryption_for_read( + request, prefix="x-goog-copy-source-encryption" + ) + bytes_rewritten = rewrite.get("bytes_rewritten") + bytes_rewritten += 1024 * 1024 + result = {"kind": "storage#rewriteResponse", "objectSize": len(source.media)} + if bytes_rewritten >= len(source.media): + bytes_rewritten = len(source.media) + rewrite["bytes_rewritten"] = bytes_rewritten + # Success, the operation completed. Return the new object: + object_path, destination = testbench_utils.get_object( + destination_bucket, + destination_object, + GcsObject(destination_bucket, destination_object), + ) + revision = destination.rewrite_finish(gcs_url, request, body, source) + testbench_utils.insert_object(object_path, destination) + result["done"] = True + result["resource"] = revision.metadata + rewrite_token = "" + else: + rewrite["bytes_rewritten"] = bytes_rewritten + rewrite_token = GcsObject.make_rewrite_token( + rewrite, destination_bucket, destination_object, source.generation + ) + self.rewrite_operations[rewrite_token] = rewrite + result["done"] = False + + result.update( + {"totalBytesRewritten": bytes_rewritten, "rewriteToken": rewrite_token} + ) + return result diff --git a/tests/test_gcloud/testbench/requirements.txt b/tests/test_gcloud/testbench/requirements.txt new file mode 100644 index 000000000..415ce8662 --- /dev/null +++ b/tests/test_gcloud/testbench/requirements.txt @@ -0,0 +1,5 @@ +crc32c==2.1 +flask==1.1.2 +greenlet==0.4.17 +gevent==20.9.0 +gunicorn==20.0.4 diff --git a/tests/test_gcloud/testbench/testbench.py b/tests/test_gcloud/testbench/testbench.py new file mode 100644 index 000000000..290fbb405 --- /dev/null +++ b/tests/test_gcloud/testbench/testbench.py @@ -0,0 +1,599 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""A test bench for the Google Cloud Storage C++ Client Library.""" + +import argparse +import error_response +import flask +import gcs_bucket +import gcs_object +import json +import os +import re +import testbench_utils +import time +import sys +from werkzeug import serving +from werkzeug.middleware.dispatcher import DispatcherMiddleware + + +root = flask.Flask(__name__, subdomain_matching=True) +root.debug = True + + +@root.route("/") +def index(): + """Default handler for the test bench.""" + return "OK" + + +@root.route("/", subdomain="") +def root_get_object(bucket_name, object_name): + return xml_get_object(bucket_name, object_name) + + +@root.route("//", subdomain="") +def root_get_object_with_bucket(bucket_name, object_name): + return xml_get_object(bucket_name, object_name) + + +@root.route("/", subdomain="", methods=["PUT"]) +def root_put_object(bucket_name, object_name): + return xml_put_object(flask.request.host_url, bucket_name, object_name) + + +@root.route("//", subdomain="", methods=["PUT"]) +def root_put_object_with_bucket(bucket_name, object_name): + return xml_put_object(flask.request.host_url, bucket_name, object_name) + + +@root.errorhandler(error_response.ErrorResponse) +def root_error(error): + return error.as_response() + + +# Define the WSGI application to handle bucket requests. +GCS_HANDLER_PATH = "/storage/v1" +gcs = flask.Flask(__name__) +gcs.debug = True + + +def insert_magic_bucket(base_url): + if len(testbench_utils.all_buckets()) == 0: + bucket_name = os.environ.get( + "GOOGLE_CLOUD_CPP_STORAGE_TEST_BUCKET_NAME", "test-bucket" + ) + bucket = gcs_bucket.GcsBucket(base_url, bucket_name) + testbench_utils.insert_bucket(bucket_name, bucket) + + +@gcs.route("/") +def gcs_index(): + """The default handler for GCS requests.""" + return "OK" + + +@gcs.errorhandler(error_response.ErrorResponse) +def gcs_error(error): + return error.as_response() + + +@gcs.route("/b") +def buckets_list(): + """Implement the 'Buckets: list' API: return the Buckets in a project.""" + base_url = flask.url_for("gcs_index", _external=True) + project = flask.request.args.get("project") + if project is None or project.endswith("-"): + raise error_response.ErrorResponse( + "Invalid or missing project id in `Buckets: list`" + ) + insert_magic_bucket(base_url) + result = {"next_page_token": "", "items": []} + for name, b in testbench_utils.all_buckets(): + result["items"].append(b.metadata) + return testbench_utils.filtered_response(flask.request, result) + + +@gcs.route("/b", methods=["POST"]) +def buckets_insert(): + """Implement the 'Buckets: insert' API: create a new Bucket.""" + base_url = flask.url_for("gcs_index", _external=True) + insert_magic_bucket(base_url) + payload = json.loads(flask.request.data) + bucket_name = payload.get("name") + if bucket_name is None: + raise error_response.ErrorResponse( + "Missing bucket name in `Buckets: insert`", status_code=412 + ) + if testbench_utils.has_bucket(bucket_name): + raise error_response.ErrorResponse( + "Bucket %s already exists" % bucket_name, status_code=400 + ) + bucket = gcs_bucket.GcsBucket(base_url, bucket_name) + testbench_utils.insert_bucket(bucket_name, bucket) + return testbench_utils.filtered_response(flask.request, bucket.metadata) + + +@gcs.route("/b/") +def buckets_get(bucket_name): + """Implement the 'Buckets: get' API: return the metadata for a bucket.""" + base_url = flask.url_for("gcs_index", _external=True) + insert_magic_bucket(base_url) + bucket = testbench_utils.lookup_bucket(bucket_name) + 
bucket.check_preconditions(flask.request) + return testbench_utils.filtered_response(flask.request, bucket.metadata) + + +@gcs.route("/b/", methods=["DELETE"]) +def buckets_delete(bucket_name): + """Implement the 'Buckets: delete' API.""" + bucket = testbench_utils.lookup_bucket(bucket_name) + bucket.check_preconditions(flask.request) + testbench_utils.delete_bucket(bucket_name) + return testbench_utils.filtered_response(flask.request, {}) + + +@gcs.route("/b//o") +def objects_list(bucket_name): + """Implement the 'Objects: list' API: return the objects in a bucket.""" + # Lookup the bucket, if this fails the bucket does not exist, and this + # function should return an error. + base_url = flask.url_for("gcs_index", _external=True) + insert_magic_bucket(base_url) + _ = testbench_utils.lookup_bucket(bucket_name) + result = {"next_page_token": "", "items": [], "prefixes:": []} + versions_parameter = flask.request.args.get("versions") + all_versions = versions_parameter is not None and bool(versions_parameter) + prefixes = set() + prefix = flask.request.args.get("prefix", "", type("")) + delimiter = flask.request.args.get("delimiter", "", type("")) + start_offset = flask.request.args.get("startOffset", "", type("")) + end_offset = flask.request.args.get("endOffset", "", type("")) + bucket_link = bucket_name + "/o/" + for name, o in testbench_utils.all_objects(): + if name.find(bucket_link + prefix) != 0: + continue + if o.get_latest() is None: + continue + # We assume `delimiter` has only one character. + if name[len(bucket_link) :] < start_offset: + continue + if end_offset != "" and name[len(bucket_link) :] >= end_offset: + continue + delimiter_index = name.find(delimiter, len(bucket_link + prefix)) + if delimiter != "" and delimiter_index > 0: + # We don't want to include `bucket_link` in the returned prefix. 
+ prefixes.add(name[len(bucket_link) : delimiter_index + 1]) + continue + if all_versions: + for object_version in o.revisions.values(): + result["items"].append(object_version.metadata) + else: + result["items"].append(o.get_latest().metadata) + result["prefixes"] = list(prefixes) + return testbench_utils.filtered_response(flask.request, result) + + +@gcs.route( + "/b//o//copyTo/b//o/", + methods=["POST"], +) +def objects_copy(source_bucket, source_object, destination_bucket, destination_object): + """Implement the 'Objects: copy' API, copy an object.""" + object_path, blob = testbench_utils.lookup_object(source_bucket, source_object) + blob.check_preconditions( + flask.request, + if_generation_match="ifSourceGenerationMatch", + if_generation_not_match="ifSourceGenerationNotMatch", + if_metageneration_match="ifSourceMetagenerationMatch", + if_metageneration_not_match="ifSourceMetagenerationNotMatch", + ) + source_revision = blob.get_revision(flask.request, "sourceGeneration") + if source_revision is None: + raise error_response.ErrorResponse( + "Revision not found %s" % object_path, status_code=404 + ) + + destination_path, destination = testbench_utils.get_object( + destination_bucket, + destination_object, + gcs_object.GcsObject(destination_bucket, destination_object), + ) + base_url = flask.url_for("gcs_index", _external=True) + current_version = destination.copy_from(base_url, flask.request, source_revision) + testbench_utils.insert_object(destination_path, destination) + return testbench_utils.filtered_response(flask.request, current_version.metadata) + + +@gcs.route( + "/b//o//rewriteTo/b//o/", + methods=["POST"], +) +def objects_rewrite( + source_bucket, source_object, destination_bucket, destination_object +): + """Implement the 'Objects: rewrite' API.""" + base_url = flask.url_for("gcs_index", _external=True) + insert_magic_bucket(base_url) + object_path, blob = testbench_utils.lookup_object(source_bucket, source_object) + blob.check_preconditions( + flask.request, + if_generation_match="ifSourceGenerationMatch", + if_generation_not_match="ifSourceGenerationNotMatch", + if_metageneration_match="ifSourceMetagenerationMatch", + if_metageneration_not_match="ifSourceMetagenerationNotMatch", + ) + response = blob.rewrite_step( + base_url, flask.request, destination_bucket, destination_object + ) + return testbench_utils.filtered_response(flask.request, response) + + +def objects_get_common(bucket_name, object_name, revision): + # Respect the Range: header, if present. + range_header = flask.request.headers.get("range") + response_payload = revision.media + begin = 0 + end = len(response_payload) + if range_header is not None: + m = re.match("bytes=([0-9]+)-([0-9]+)", range_header) + if m: + begin = int(m.group(1)) + end = int(m.group(2)) + response_payload = response_payload[begin : end + 1] + m = re.match("bytes=([0-9]+)-$", range_header) + if m: + begin = int(m.group(1)) + response_payload = response_payload[begin:] + m = re.match("bytes=-([0-9]+)$", range_header) + if m: + last = int(m.group(1)) + response_payload = response_payload[-last:] + # Process custom headers to test error conditions. 
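+ # The x-goog-testbench-instructions values recognized below are + # return-broken-stream, return-corrupted-data, stall-always, + # stall-at-256KiB, and return-503-after-256K; anything else falls + # through to the plain 200 response at the end.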
+ instructions = flask.request.headers.get("x-goog-testbench-instructions") + if instructions == "return-broken-stream": + + def streamer(): + chunk_size = 64 * 1024 + for r in range(0, len(response_payload), chunk_size): + if r > 1024 * 1024: + print("\n\n###### EXIT to simulate crash\n") + sys.exit(1) + time.sleep(0.1) + chunk_end = min(r + chunk_size, len(response_payload)) + yield response_payload[r:chunk_end] + + length = len(response_payload) + content_range = "bytes %d-%d/%d" % (begin, end - 1, length) + headers = { + "Content-Range": content_range, + "Content-Length": length, + "x-goog-hash": revision.x_goog_hash_header(), + "x-goog-generation": revision.generation, + } + return flask.Response(streamer(), status=200, headers=headers) + + if instructions == "return-corrupted-data": + response_payload = testbench_utils.corrupt_media(response_payload) + + if instructions is not None and instructions.startswith("stall-always"): + length = len(response_payload) + content_range = "bytes %d-%d/%d" % (begin, end - 1, length) + + def streamer(): + chunk_size = 16 * 1024 + for r in range(begin, end, chunk_size): + chunk_end = min(r + chunk_size, end) + if r == begin: + time.sleep(10) + yield response_payload[r:chunk_end] + + headers = { + "Content-Range": content_range, + "x-goog-hash": revision.x_goog_hash_header(), + "x-goog-generation": revision.generation, + } + return flask.Response(streamer(), status=200, headers=headers) + + if instructions == "stall-at-256KiB" and begin == 0: + length = len(response_payload) + content_range = "bytes %d-%d/%d" % (begin, end - 1, length) + + def streamer(): + chunk_size = 16 * 1024 + for r in range(begin, end, chunk_size): + chunk_end = min(r + chunk_size, end) + if r == 256 * 1024: + time.sleep(10) + yield response_payload[r:chunk_end] + + headers = { + "Content-Range": content_range, + "x-goog-hash": revision.x_goog_hash_header(), + "x-goog-generation": revision.generation, + } + return flask.Response(streamer(), status=200, headers=headers) + + if instructions is not None and instructions.startswith("return-503-after-256K"): + length = len(response_payload) + headers = { + "Content-Range": "bytes %d-%d/%d" % (begin, end - 1, length), + "x-goog-hash": revision.x_goog_hash_header(), + "x-goog-generation": revision.generation, + } + if begin == 0: + + def streamer(): + chunk_size = 4 * 1024 + for r in range(0, len(response_payload), chunk_size): + if r >= 256 * 1024: + print("\n\n###### EXIT to simulate crash\n") + sys.exit(1) + time.sleep(0.01) + chunk_end = min(r + chunk_size, len(response_payload)) + yield response_payload[r:chunk_end] + + return flask.Response(streamer(), status=200, headers=headers) + if instructions.endswith("/retry-1"): + print("## Return error for retry 1") + return flask.Response("Service Unavailable", status=503) + if instructions.endswith("/retry-2"): + print("## Return error for retry 2") + return flask.Response("Service Unavailable", status=503) + print("## Return success for %s" % instructions) + return flask.Response(response_payload, status=200, headers=headers) + + response = flask.make_response(response_payload) + length = len(response_payload) + content_range = "bytes %d-%d/%d" % (begin, end - 1, length) + response.headers["Content-Range"] = content_range + response.headers["x-goog-hash"] = revision.x_goog_hash_header() + response.headers["x-goog-generation"] = revision.generation + return response + + +@gcs.route("/b//o/", methods=["DELETE"]) +def objects_delete(bucket_name, object_name): + """Implement the 
'Objects: delete' API. Delete objects.""" + object_path, blob = testbench_utils.lookup_object(bucket_name, object_name) + blob.check_preconditions(flask.request) + remove = blob.del_revision(flask.request) + if remove: + testbench_utils.delete_object(object_path) + return testbench_utils.filtered_response(flask.request, {}) + + +@gcs.route("/b//o//compose", methods=["POST"]) +def objects_compose(bucket_name, object_name): + """Implement the 'Objects: compose' API: concatenate Objects.""" + payload = json.loads(flask.request.data) + source_objects = payload["sourceObjects"] + if source_objects is None: + raise error_response.ErrorResponse( + "You must provide at least one source component.", status_code=400 + ) + if len(source_objects) > 32: + raise error_response.ErrorResponse( + "The number of source components provided" + " (%d) exceeds the maximum (32)" % len(source_objects), + status_code=400, + ) + composed_media = b"" + for source_object in source_objects: + source_object_name = source_object.get("name") + if source_object_name is None: + raise error_response.ErrorResponse("Required.", status_code=400) + source_object_path, source_blob = testbench_utils.lookup_object( + bucket_name, source_object_name + ) + source_revision = source_blob.get_latest() + generation = source_object.get("generation") + if generation is not None: + source_revision = source_blob.get_revision_by_generation(generation) + if source_revision is None: + raise error_response.ErrorResponse( + "No such object: %s" % source_object_path, status_code=404 + ) + object_preconditions = source_object.get("objectPreconditions") + if object_preconditions is not None: + if_generation_match = object_preconditions.get("ifGenerationMatch") + source_blob.check_preconditions_by_value( + if_generation_match, None, None, None + ) + composed_media += source_revision.media + composed_object_path, composed_object = testbench_utils.get_object( + bucket_name, object_name, gcs_object.GcsObject(bucket_name, object_name) + ) + composed_object.check_preconditions(flask.request) + base_url = flask.url_for("gcs_index", _external=True) + current_version = composed_object.compose_from( + base_url, flask.request, composed_media + ) + testbench_utils.insert_object(composed_object_path, composed_object) + return testbench_utils.filtered_response(flask.request, current_version.metadata) + + +# Define the WSGI application to handle bucket requests. +DOWNLOAD_HANDLER_PATH = "/download/storage/v1" +download = flask.Flask(__name__) +download.debug = True + + +@download.errorhandler(error_response.ErrorResponse) +def download_error(error): + return error.as_response() + + +@gcs.route("/b//o/") +@download.route("/b//o/") +def objects_get(bucket_name, object_name): + """Implement the 'Objects: get' API. Read objects or their metadata.""" + _, blob = testbench_utils.lookup_object(bucket_name, object_name) + blob.check_preconditions(flask.request) + revision = blob.get_revision(flask.request) + + media = flask.request.args.get("alt", None) + if media is None or media == "json": + return testbench_utils.filtered_response(flask.request, revision.metadata) + if media != "media": + raise error_response.ErrorResponse("Invalid alt=%s parameter" % media) + revision.validate_encryption_for_read(flask.request) + return objects_get_common(bucket_name, object_name, revision) + + +# Define the WSGI application to handle bucket requests. 
+UPLOAD_HANDLER_PATH = "/upload/storage/v1" +upload = flask.Flask(__name__) +upload.debug = True + + +@upload.errorhandler(error_response.ErrorResponse) +def upload_error(error): + return error.as_response() + + +@upload.route("/b//o", methods=["POST"]) +def objects_insert(bucket_name): + """Implement the 'Objects: insert' API. Insert a new GCS Object.""" + gcs_url = flask.url_for( + "objects_insert", bucket_name=bucket_name, _external=True + ).replace("/upload/", "/") + insert_magic_bucket(gcs_url) + + upload_type = flask.request.args.get("uploadType") + if upload_type is None: + raise error_response.ErrorResponse( + "uploadType not set in Objects: insert", status_code=400 + ) + if upload_type not in {"multipart", "media", "resumable"}: + raise error_response.ErrorResponse( + "testbench does not support %s uploadType" % upload_type, status_code=400 + ) + + if upload_type == "resumable": + bucket = testbench_utils.lookup_bucket(bucket_name) + upload_url = flask.url_for( + "objects_insert", bucket_name=bucket_name, _external=True + ) + return bucket.create_resumable_upload(upload_url, flask.request) + + object_path = None + blob = None + current_version = None + if upload_type == "media": + object_name = flask.request.args.get("name", None) + if object_name is None: + raise error_response.ErrorResponse( + "name not set in Objects: insert", status_code=412 + ) + object_path, blob = testbench_utils.get_object( + bucket_name, object_name, gcs_object.GcsObject(bucket_name, object_name) + ) + blob.check_preconditions(flask.request) + current_version = blob.insert(gcs_url, flask.request) + else: + resource, media_headers, media_body = testbench_utils.parse_multi_part( + flask.request + ) + object_name = flask.request.args.get("name", resource.get("name", None)) + if object_name is None: + raise error_response.ErrorResponse( + "name not set in Objects: insert", status_code=412 + ) + object_path, blob = testbench_utils.get_object( + bucket_name, object_name, gcs_object.GcsObject(bucket_name, object_name) + ) + blob.check_preconditions(flask.request) + current_version = blob.insert_multipart( + gcs_url, flask.request, resource, media_headers, media_body + ) + testbench_utils.insert_object(object_path, blob) + return testbench_utils.filtered_response(flask.request, current_version.metadata) + + +@upload.route("/b//o", methods=["PUT"]) +def resumable_upload_chunk(bucket_name): + """Receive a chunk for a resumable upload.""" + gcs_url = flask.url_for( + "objects_insert", bucket_name=bucket_name, _external=True + ).replace("/upload/", "/") + bucket = testbench_utils.lookup_bucket(bucket_name) + return bucket.receive_upload_chunk(gcs_url, flask.request) + + +@upload.route("/b//o", methods=["DELETE"]) +def delete_resumable_upload(bucket_name): + upload_type = flask.request.args.get("uploadType") + if upload_type != "resumable": + raise error_response.ErrorResponse( + "testbench can delete resumable uploadType only", status_code=400 + ) + upload_id = flask.request.args.get("upload_id") + if upload_id is None: + raise error_response.ErrorResponse( + "missing upload_id in delete_resumable_upload", status_code=400 + ) + bucket = testbench_utils.lookup_bucket(bucket_name) + if upload_id not in bucket.resumable_uploads: + raise error_response.ErrorResponse("upload_id does not exist", status_code=404) + bucket.resumable_uploads.pop(upload_id) + return testbench_utils.filtered_response(flask.request, {}) + + +def xml_put_object(gcs_url, bucket_name, object_name): + """Implement PUT for the XML API.""" + 
insert_magic_bucket(gcs_url) + object_path, blob = testbench_utils.get_object( + bucket_name, object_name, gcs_object.GcsObject(bucket_name, object_name) + ) + generation_match = flask.request.headers.get("x-goog-if-generation-match") + metageneration_match = flask.request.headers.get("x-goog-if-metageneration-match") + blob.check_preconditions_by_value( + generation_match, None, metageneration_match, None + ) + revision = blob.insert_xml(gcs_url, flask.request) + testbench_utils.insert_object(object_path, blob) + response = flask.make_response("") + response.headers["x-goog-hash"] = revision.x_goog_hash_header() + return response + + +def xml_get_object(bucket_name, object_name): + """Implement the 'Objects: insert' API. Insert a new GCS Object.""" + object_path, blob = testbench_utils.lookup_object(bucket_name, object_name) + if flask.request.args.get("acl") is not None: + raise error_response.ErrorResponse( + "ACL query not supported in XML API", status_code=500 + ) + if flask.request.args.get("encryption") is not None: + raise error_response.ErrorResponse( + "Encryption query not supported in XML API", status_code=500 + ) + generation_match = flask.request.headers.get("if-generation-match") + metageneration_match = flask.request.headers.get("if-metageneration-match") + blob.check_preconditions_by_value( + generation_match, None, metageneration_match, None + ) + revision = blob.get_revision(flask.request) + return objects_get_common(bucket_name, object_name, revision) + + +application = DispatcherMiddleware( + root, + { + GCS_HANDLER_PATH: gcs, + UPLOAD_HANDLER_PATH: upload, + DOWNLOAD_HANDLER_PATH: download, + }, +) diff --git a/tests/test_gcloud/testbench/testbench_utils.py b/tests/test_gcloud/testbench/testbench_utils.py new file mode 100644 index 000000000..8928273c8 --- /dev/null +++ b/tests/test_gcloud/testbench/testbench_utils.py @@ -0,0 +1,324 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Standalone helpers for the Google Cloud Storage test bench.""" + +import base64 +import error_response +import hashlib +import json +import random + + +def filter_fields_from_response(fields, response): + """Format the response as a JSON string, using any filtering included in + the request. + + :param fields:str the value of the `fields` parameter in the original + request. + :param response:dict a dictionary to be formatted as a JSON string. + :return: the response formatted as a string. 
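Stepping back to the upload handlers defined earlier: a rough sketch of how a client could drive objects_insert through the /upload/storage/v1 prefix. The bucket and object names and the upload_id are placeholders, and the testbench address is again assumed to be localhost:9099 as in test_gcs.py:

import requests

base = "http://localhost:9099/upload/storage/v1"

# uploadType=media: the object name comes from the query string and the request
# body becomes the object payload.
requests.post(
    base + "/b/my-bucket/o",
    params={"uploadType": "media", "name": "TEST"},
    data=b"1234567",
)

# uploadType=resumable: objects_insert hands off to bucket.create_resumable_upload();
# the object metadata normally travels in the JSON body of this initiating request.
requests.post(base + "/b/my-bucket/o", params={"uploadType": "resumable"})

# A pending session is cancelled via DELETE with its upload_id; a made-up id like
# this one takes the 404 branch of delete_resumable_upload above.
requests.delete(
    base + "/b/my-bucket/o",
    params={"uploadType": "resumable", "upload_id": "placeholder-id"},
)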
+ :rtype:str + """ + if fields is None: + return json.dumps(response) + tmp = {} + for key in fields.split(","): + key.replace(" ", "") + parentheses_idx = key.find("(") + if parentheses_idx != -1: + main_key = key[:parentheses_idx] + child_key = key[parentheses_idx + 1 : -1] + if main_key in response: + children = response[main_key] + if isinstance(children, list): + tmp_list = [] + for value in children: + tmp_list.append(value[child_key]) + tmp[main_key] = tmp_list + elif isinstance(children, dict): + tmp[main_key] = children[child_key] + elif key in response: + tmp[key] = response[key] + return json.dumps(tmp) + + +def filtered_response(request, response): + """Format the response as a JSON string, using any filtering included in + the request. + + :param request:flask.Request the original HTTP request. + :param response:dict a dictionary to be formatted as a JSON string. + :return: the response formatted as a string. + :rtype:str + """ + fields = request.args.get("fields") + return filter_fields_from_response(fields, response) + + +def raise_csek_error(code=400): + msg = "Missing a SHA256 hash of the encryption key, or it is not" + msg += " base64 encoded, or it does not match the encryption key." + link = "https://cloud.google.com/storage/docs/encryption#customer-supplied_encryption_keys" + error = { + "error": { + "errors": [ + { + "domain": "global", + "reason": "customerEncryptionKeySha256IsInvalid", + "message": msg, + "extendedHelp": link, + } + ], + "code": code, + "message": msg, + } + } + raise error_response.ErrorResponse(json.dumps(error), status_code=code) + + +def validate_customer_encryption_headers( + key_header_value, hash_header_value, algo_header_value +): + """Verify that the encryption headers are internally consistent. + + :param key_header_value: str the value of the x-goog-*-key header + :param hash_header_value: str the value of the x-goog-*-key-sha256 header + :param algo_header_value: str the value of the x-goog-*-key-algorithm header + :rtype: NoneType + """ + try: + if algo_header_value is None or algo_header_value != "AES256": + raise error_response.ErrorResponse( + "Invalid or missing algorithm %s for CSEK" % algo_header_value, + status_code=400, + ) + + key = base64.standard_b64decode(key_header_value) + if key is None or len(key) != 256 / 8: + raise_csek_error() + + h = hashlib.sha256() + h.update(key) + expected = base64.standard_b64encode(h.digest()).decode("utf-8") + if hash_header_value is None or expected != hash_header_value: + raise_csek_error() + except error_response.ErrorResponse: + # error_response.ErrorResponse indicates that the request was invalid, just pass + # that exception through. + raise + except Exception: + # Many of the functions above may raise, convert those to an + # error_response.ErrorResponse with the right format. + raise_csek_error() + + +def extract_media(request): + """Extract the media from a flask Request. + + To avoid race conditions when using greenlets we cannot perform I/O in the + constructor of GcsObjectVersion, or in any of the operations that modify + the state of the service. Because sometimes the media is uploaded with + chunked encoding, we need to do I/O before finishing the GcsObjectVersion + creation. If we do this I/O after the GcsObjectVersion creation started, + the the state of the application may change due to other I/O. + + :param request:flask.Request the HTTP request. + :return: the full media of the request. 
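Returning to the CSEK helpers above: the validation only checks internal consistency, i.e. the algorithm must be AES256, the key must decode to 32 bytes, and the sha256 header must be the base64-encoded SHA-256 of the decoded key; anything else funnels into raise_csek_error() and a 400. A small sketch of headers that pass, using the concrete x-goog-encryption-* names as one instantiation of the x-goog-*-key pattern mentioned in the docstring:

import base64
import hashlib
import os

key = os.urandom(32)  # AES256 implies a 32-byte key
headers = {
    "x-goog-encryption-algorithm": "AES256",
    "x-goog-encryption-key": base64.standard_b64encode(key).decode("utf-8"),
    "x-goog-encryption-key-sha256": base64.standard_b64encode(
        hashlib.sha256(key).digest()
    ).decode("utf-8"),
}
# validate_customer_encryption_headers(
#     headers["x-goog-encryption-key"],
#     headers["x-goog-encryption-key-sha256"],
#     headers["x-goog-encryption-algorithm"],
# ) accepts this triple; changing any one of the three values makes it raise.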
+ :rtype: str + """ + if request.environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked": + return request.environ.get("wsgi.input").read() + return request.data + + +def corrupt_media(media): + """Return a randomly modified version of a string. + + :param media:bytes a string (typically some object media) to be modified. + :return: a string that is slightly different than media. + :rtype: str + """ + # Deal with the boundary condition. + if not media: + return bytearray(random.sample("abcdefghijklmnopqrstuvwxyz", 1), "utf-8") + return b"B" + media[1:] if media[0:1] == b"A" else b"A" + media[1:] + + +# Define the collection of Buckets indexed by +GCS_BUCKETS = dict() + + +def lookup_bucket(bucket_name): + """Lookup a bucket by name in the global collection. + + :param bucket_name:str the name of the Bucket. + :return: the bucket matching the name. + :rtype:GcsBucket + :raises:ErrorResponse if the bucket is not found. + """ + bucket = GCS_BUCKETS.get(bucket_name) + if bucket is None: + raise error_response.ErrorResponse( + "Bucket %s not found" % bucket_name, status_code=404 + ) + return bucket + + +def has_bucket(bucket_name): + """Return True if the bucket already exists in the global collection.""" + return GCS_BUCKETS.get(bucket_name) is not None + + +def insert_bucket(bucket_name, bucket): + """Insert (or replace) a new bucket into the global collection. + + :param bucket_name:str the name of the bucket. + :param bucket:GcsBucket the bucket to insert. + """ + GCS_BUCKETS[bucket_name] = bucket + + +def delete_bucket(bucket_name): + """Delete a bucket from the global collection.""" + GCS_BUCKETS.pop(bucket_name) + + +def all_buckets(): + """Return a key,value iterator for all the buckets in the global collection. + + :rtype:dict[str, GcsBucket] + """ + return GCS_BUCKETS.items() + + +# Define the collection of GcsObjects indexed by /o/ +GCS_OBJECTS = dict() + + +def lookup_object(bucket_name, object_name): + """Lookup an object by name in the global collection. + + :param bucket_name:str the name of the Bucket that contains the object. + :param object_name:str the name of the Object. + :return: tuple the object path and the object. + :rtype: (str,GcsObject) + :raises:ErrorResponse if the object is not found. + """ + object_path, gcs_object = get_object(bucket_name, object_name, None) + if gcs_object is None: + raise error_response.ErrorResponse( + "Object {} in {} not found".format(object_name, bucket_name), + status_code=404, + ) + return object_path, gcs_object + + +def get_object(bucket_name, object_name, default_value): + """Find an object in the global collection, return a default value if not + found. + + :param bucket_name:str the name of the Bucket that contains the object. + :param object_name:str the name of the Object. + :param default_value:GcsObject the default value returned if the object is + not found. + :return: tuple the object path and the object. + :rtype: (str,GcsObject) + """ + object_path = bucket_name + "/o/" + object_name + return object_path, GCS_OBJECTS.get(object_path, default_value) + + +def insert_object(object_path, value): + """Insert an object to the global collection.""" + GCS_OBJECTS[object_path] = value + + +def delete_object(object_path): + """Delete an object from the global collection.""" + GCS_OBJECTS.pop(object_path) + + +def all_objects(): + """Return a key,value iterator for all the objects in the global collection. 
+ + :rtype:dict[str, GcsBucket] + """ + return GCS_OBJECTS.items() + + +def parse_multi_part(request): + """Parse a multi-part request + + :param request:flask.Request multipart request. + :return: a tuple with the resource, media_headers and the media_body. + :rtype: (dict, dict, str) + """ + content_type = request.headers.get("content-type") + if content_type is None or not content_type.startswith("multipart/related"): + raise error_response.ErrorResponse( + "Missing or invalid content-type header in multipart upload" + ) + _, _, boundary = content_type.partition("boundary=") + boundary = boundary.strip('"') + if boundary is None: + raise error_response.ErrorResponse( + "Missing or invalid boundary in content-type header in multipart upload" + ) + + def parse_metadata(part): + result = part.split(b"\r\n") + if result[0] != b"" and result[-1] != b"": + raise error_response.ErrorResponse( + "Missing or invalid multipart %s" % str(part) + ) + result = list(filter(None, result)) + headers = {} + if len(result) < 2: + result.append(b"") + for header in result[:-1]: + key, value = header.split(b": ") + headers[key.decode("utf-8").lower()] = value.decode("utf-8") + return result[-1] + + def parse_body(part): + if part[0:2] != b"\r\n" or part[-2:] != b"\r\n": + raise error_response.ErrorResponse( + "Missing or invalid multipart %s" % str(part) + ) + part = part[2:-2] + part.lstrip(b"\r\n") + content_type_index = part.find(b"\r\n") + if content_type_index == -1: + raise error_response.ErrorResponse( + "Missing or invalid multipart %s" % str(part) + ) + content_type = part[:content_type_index] + _, value = content_type.decode("utf-8").split(": ") + media = part[content_type_index + 2 :] + if media[:2] == b"\r\n": + # It is either `\r\n` or `\r\n\r\n`, we should remove at most 4 characters. 
+ media = media[2:] + return {"content-type": value}, media + + boundary = boundary.encode("utf-8") + body = extract_media(request) + parts = body.split(b"--" + boundary) + if parts[-1] != b"--\r\n" and parts[-1] != b"--": + raise error_response.ErrorResponse( + "Missing end marker (--%s--) in media body" % boundary + ) + resource = parse_metadata(parts[1]) + metadata = json.loads(resource) + content_type, media = parse_body(parts[2]) + return metadata, content_type, media diff --git a/tests/test_gcs_eager.py b/tests/test_gcs.py similarity index 81% rename from tests/test_gcs_eager.py rename to tests/test_gcs.py index 84556d354..cef36047c 100644 --- a/tests/test_gcs_eager.py +++ b/tests/test_gcs.py @@ -43,19 +43,18 @@ def test_read_file(): body = b"1234567" - # Setup the S3 bucket and key + # Setup the GCS bucket and key key_name = "TEST" - bucket_name = "s3e{}e".format(int(time.time())) - + bucket_name = "gs{}e".format(int(time.time())) bucket = client.create_bucket(bucket_name) - print("Project number: {}".format(bucket.project_number)) blob = bucket.blob(key_name) blob.upload_from_string(body) - response = blob.download_as_string() - print("RESPONSE: ", response) + response = blob.download_as_bytes() assert response == body - # content = tf.io.read_file("gs://{}/{}".format(bucket_name, key_name)) - # assert content == body + os.environ["CLOUD_STORAGE_TESTBENCH_ENDPOINT"] = "http://localhost:9099" + + content = tf.io.read_file("gs://{}/{}".format(bucket_name, key_name)) + assert content == body diff --git a/tests/test_gcs_config_ops.py b/tests/test_gcs_config_ops.py new file mode 100644 index 000000000..291986c22 --- /dev/null +++ b/tests/test_gcs_config_ops.py @@ -0,0 +1,46 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
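Circling back to parse_multi_part in testbench_utils.py above: it expects a multipart/related body whose first part carries the JSON resource and whose second part carries the media, each introduced by a content-type line. A hand-rolled payload it accepts (the boundary and values below are made up):

boundary = "testbench-boundary"
body = (
    "--{b}\r\n"
    "content-type: application/json\r\n"
    "\r\n"
    '{{"name": "TEST"}}\r\n'
    "--{b}\r\n"
    "content-type: text/plain\r\n"
    "\r\n"
    "1234567\r\n"
    "--{b}--\r\n"
).format(b=boundary).encode("utf-8")
# Posted with the request header
#   content-type: multipart/related; boundary="testbench-boundary"
# this parses to roughly:
#   resource      == {"name": "TEST"}
#   media_headers == {"content-type": "text/plain"}
#   media_body    == b"1234567"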
+# ============================================================================== +"""Tests for the gcs_config_ops.""" + + +import sys +import pytest + +import tensorflow as tf + +from tensorflow.python.platform import test +from tensorflow_io import gcs + +tf_v1 = tf.version.VERSION.startswith("1") + + +class GcsConfigOpsTest(test.TestCase): + """GCS Config OPS test""" + + @pytest.mark.skipif(sys.platform == "win32", reason="Windows not working yet") + def test_set_block_cache(self): + """test_set_block_cache""" + cfg = gcs.BlockCacheParams(max_bytes=1024 * 1024 * 1024) + if tf_v1: + with tf.Session() as session: + gcs.configure_gcs( + session, credentials=None, block_cache=cfg, device=None + ) + else: + gcs.configure_gcs(block_cache=cfg) + + +if __name__ == "__main__": + test.main() diff --git a/tests/test_genome.py b/tests/test_genome.py index 48b2642b3..1798b8914 100644 --- a/tests/test_genome.py +++ b/tests/test_genome.py @@ -19,8 +19,6 @@ import numpy as np import tensorflow as tf - -tf.compat.v1.disable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position fastq_path = os.path.join( @@ -30,13 +28,8 @@ def test_genome_fastq_reader(): """test_genome_fastq_reader""" - g1 = tf.compat.v1.Graph() - - with g1.as_default(): - data = tfio.genome.read_fastq(filename=fastq_path) - sess = tf.compat.v1.Session(graph=g1) - data_np = sess.run(data) + data = tfio.genome.read_fastq(filename=fastq_path) data_expected = [ b"GATTACA", @@ -52,8 +45,8 @@ def test_genome_fastq_reader(): b"FAD", ] - assert np.all(data_np.sequences == data_expected) - assert np.all(data_np.raw_quality == quality_expected) + assert np.all(data.sequences == data_expected) + assert np.all(data.raw_quality == quality_expected) def test_genome_sequences_to_onehot(): @@ -189,12 +182,10 @@ def test_genome_sequences_to_onehot(): [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], ] - with tf.compat.v1.Session() as sess: - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) - out = sess.run(data) + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) - assert np.all(out.to_list() == expected) + assert np.all(data.to_list() == expected) def test_genome_phred_sequences_to_probability(): @@ -210,28 +201,21 @@ def test_genome_phred_sequences_to_probability(): 0.00019952621369156986, ] - with tf.compat.v1.Session() as sess: - example_quality = tf.constant(example_quality_list) - converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) - out = sess.run(converted_phred) + example_quality = tf.constant(example_quality_list) + converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) # Compare flat values - assert np.allclose(out.flat_values.flatten(), expected_probabilities) + assert np.allclose( + converted_phred.flat_values.numpy().flatten(), expected_probabilities + ) # Ensure nested array lengths are correct assert np.all( - [len(a) == len(b) for a, b in zip(out.to_list(), example_quality_list)] + [ + len(a) == len(b) + for a, b in zip(converted_phred.to_list(), example_quality_list) + ] ) -def test_genome_phred_sequences_to_probability_with_other_genome_ops(): - """Test quality op in graph with read_fastq op, ensure no errors""" - with tf.compat.v1.Session() as sess: - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.phred_sequences_to_probability( - phred_qualities=raw_data.raw_quality - ) - 
sess.run(data) - - if __name__ == "__main__": test.main() diff --git a/tests/test_genome_eager.py b/tests/test_genome_v1.py similarity index 79% rename from tests/test_genome_eager.py rename to tests/test_genome_v1.py index 1798b8914..48b2642b3 100644 --- a/tests/test_genome_eager.py +++ b/tests/test_genome_v1.py @@ -19,6 +19,8 @@ import numpy as np import tensorflow as tf + +tf.compat.v1.disable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position fastq_path = os.path.join( @@ -28,8 +30,13 @@ def test_genome_fastq_reader(): """test_genome_fastq_reader""" + g1 = tf.compat.v1.Graph() + + with g1.as_default(): + data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.read_fastq(filename=fastq_path) + sess = tf.compat.v1.Session(graph=g1) + data_np = sess.run(data) data_expected = [ b"GATTACA", @@ -45,8 +52,8 @@ def test_genome_fastq_reader(): b"FAD", ] - assert np.all(data.sequences == data_expected) - assert np.all(data.raw_quality == quality_expected) + assert np.all(data_np.sequences == data_expected) + assert np.all(data_np.raw_quality == quality_expected) def test_genome_sequences_to_onehot(): @@ -182,10 +189,12 @@ def test_genome_sequences_to_onehot(): [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], ] - raw_data = tfio.genome.read_fastq(filename=fastq_path) - data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) + with tf.compat.v1.Session() as sess: + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.sequences_to_onehot(sequences=raw_data.sequences) + out = sess.run(data) - assert np.all(data.to_list() == expected) + assert np.all(out.to_list() == expected) def test_genome_phred_sequences_to_probability(): @@ -201,21 +210,28 @@ def test_genome_phred_sequences_to_probability(): 0.00019952621369156986, ] - example_quality = tf.constant(example_quality_list) - converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) + with tf.compat.v1.Session() as sess: + example_quality = tf.constant(example_quality_list) + converted_phred = tfio.genome.phred_sequences_to_probability(example_quality) + out = sess.run(converted_phred) # Compare flat values - assert np.allclose( - converted_phred.flat_values.numpy().flatten(), expected_probabilities - ) + assert np.allclose(out.flat_values.flatten(), expected_probabilities) # Ensure nested array lengths are correct assert np.all( - [ - len(a) == len(b) - for a, b in zip(converted_phred.to_list(), example_quality_list) - ] + [len(a) == len(b) for a, b in zip(out.to_list(), example_quality_list)] ) +def test_genome_phred_sequences_to_probability_with_other_genome_ops(): + """Test quality op in graph with read_fastq op, ensure no errors""" + with tf.compat.v1.Session() as sess: + raw_data = tfio.genome.read_fastq(filename=fastq_path) + data = tfio.genome.phred_sequences_to_probability( + phred_qualities=raw_data.raw_quality + ) + sess.run(data) + + if __name__ == "__main__": test.main() diff --git a/tests/test_hdf5_eager.py b/tests/test_hdf5.py similarity index 100% rename from tests/test_hdf5_eager.py rename to tests/test_hdf5.py diff --git a/tests/test_hdfs_eager.py b/tests/test_hdfs.py similarity index 89% rename from tests/test_hdfs_eager.py rename to tests/test_hdfs.py index 468d04b06..1324a597d 100644 --- a/tests/test_hdfs_eager.py +++ b/tests/test_hdfs.py @@ -35,8 +35,8 @@ def test_read_file(): print("ADDRESS: {}".format(address)) body = b"1234567" - tf.io.write_file("hdfse://{}:9000/file.txt".format(address), body) + 
tf.io.write_file("hdfs://{}:9000/file.txt".format(address), body) - content = tf.io.read_file("hdfse://{}:9000/file.txt".format(address)) + content = tf.io.read_file("hdfs://{}:9000/file.txt".format(address)) print("CONTENT: {}".format(content)) assert content == body diff --git a/tests/test_hdfs/hdfs_test.sh b/tests/test_hdfs/hdfs_test.sh index 64eae6921..c46916332 100755 --- a/tests/test_hdfs/hdfs_test.sh +++ b/tests/test_hdfs/hdfs_test.sh @@ -19,7 +19,7 @@ set -o pipefail HADOOP_VERSION=2.7.0 docker pull sequenceiq/hadoop-docker:$HADOOP_VERSION -docker run -d --rm --net=host --name=tensorflow-io-hdfs sequenceiq/hadoop-docker:$HADOOP_VERSION +docker run -d --rm -p 9000:9000 --name=tensorflow-io-hdfs sequenceiq/hadoop-docker:$HADOOP_VERSION echo "Waiting for 30 secs until hadoop is up and running" sleep 30 docker logs tensorflow-io-hdfs diff --git a/tests/test_http_eager.py b/tests/test_http.py similarity index 100% rename from tests/test_http_eager.py rename to tests/test_http.py diff --git a/tests/test_ignite.py b/tests/test_ignite_v1.py similarity index 100% rename from tests/test_ignite.py rename to tests/test_ignite_v1.py diff --git a/tests/test_image_eager.py b/tests/test_image.py similarity index 100% rename from tests/test_image_eager.py rename to tests/test_image.py diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset.py similarity index 100% rename from tests/test_io_dataset_eager.py rename to tests/test_io_dataset.py diff --git a/tests/test_io_layer_eager.py b/tests/test_io_layer.py similarity index 100% rename from tests/test_io_layer_eager.py rename to tests/test_io_layer.py diff --git a/tests/test_io_tensor_eager.py b/tests/test_io_tensor.py similarity index 100% rename from tests/test_io_tensor_eager.py rename to tests/test_io_tensor.py diff --git a/tests/test_json_eager.py b/tests/test_json.py similarity index 100% rename from tests/test_json_eager.py rename to tests/test_json.py diff --git a/tests/test_kafka.py b/tests/test_kafka.py index 8c539473c..06e82b12a 100644 --- a/tests/test_kafka.py +++ b/tests/test_kafka.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -12,500 +12,504 @@ # License for the specific language governing permissions and limitations under # the License. # ============================================================================== -"""Tests for KafkaDataset.""" +"""Tests for Kafka Output Sequence.""" import time import pytest +import numpy as np +import threading import tensorflow as tf - -tf.compat.v1.disable_eager_execution() - -from tensorflow import dtypes # pylint: disable=wrong-import-position -from tensorflow import errors # pylint: disable=wrong-import-position -from tensorflow import test # pylint: disable=wrong-import-position -from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position - +import tensorflow_io as tfio +from tensorflow_io.kafka.python.ops import ( + kafka_ops, +) # pylint: disable=wrong-import-position import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position -class KafkaDatasetTest(test.TestCase): - """Tests for KafkaDataset.""" - - # The Kafka server has to be setup before the test - # and tear down after the test manually. - # The docker engine has to be installed. 
- # - # To setup the Kafka server: - # $ bash kafka_test.sh start kafka - # - # To tear down the Kafka server: - # $ bash kafka_test.sh stop kafka - - def test_kafka_dataset(self): - """Tests for KafkaDataset when reading non-keyed messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset(topics, group="test", eof=True).repeat( - num_epochs +def test_kafka_io_tensor(): + kafka = tfio.IOTensor.from_kafka("test") + assert kafka.dtype == tf.string + assert kafka.shape.as_list() == [None] + assert np.all( + kafka.to_tensor().numpy() == [("D" + str(i)).encode() for i in range(10)] + ) + assert len(kafka.to_tensor()) == 10 + + +@pytest.mark.skip(reason="TODO") +def test_kafka_output_sequence(): + """Test case based on fashion mnist tutorial""" + fashion_mnist = tf.keras.datasets.fashion_mnist + ((train_images, train_labels), (test_images, _)) = fashion_mnist.load_data() + + class_names = [ + "T-shirt/top", + "Trouser", + "Pullover", + "Dress", + "Coat", + "Sandal", + "Shirt", + "Sneaker", + "Bag", + "Ankle boot", + ] + + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation=tf.nn.relu), + tf.keras.layers.Dense(10, activation=tf.nn.softmax), + ] + ) + + model.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + ) + + model.fit(train_images, train_labels, epochs=5) + + class OutputCallback(tf.keras.callbacks.Callback): + """KafkaOutputCallback""" + + def __init__( + self, batch_size, topic, servers + ): # pylint: disable=super-init-not-called + self._sequence = kafka_ops.KafkaOutputSequence(topic=topic, servers=servers) + self._batch_size = batch_size + + def on_predict_batch_end(self, batch, logs=None): + index = batch * self._batch_size + for outputs in logs["outputs"]: + for output in outputs: + self._sequence.setitem(index, class_names[np.argmax(output)]) + index += 1 + + def flush(self): + self._sequence.flush() + + channel = "e{}e".format(time.time()) + topic = "test_" + channel + + # By default batch size is 32 + output = OutputCallback(32, topic, "localhost") + predictions = model.predict(test_images, callbacks=[output]) + output.flush() + + predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] + + # Reading from `test_e(time)e` we should get the same result + dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) + for entry, prediction in zip(dataset, predictions): + assert entry.numpy() == prediction.encode() + + +def test_avro_kafka_dataset(): + """test_avro_kafka_dataset""" + schema = ( + '{"type":"record","name":"myrecord","fields":[' + '{"name":"f1","type":"string"},' + '{"name":"f2","type":"long"},' + '{"name":"f3","type":["null","string"],"default":null}' + "]}" + ) + dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) + # remove kafka framing + dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) + # deserialize avro + dataset = dataset.map( + lambda e: tfio.experimental.serialization.decode_avro(e, schema=schema) + ) + entries = [(e["f1"], e["f2"], e["f3"]) for e in dataset] + np.all(entries == [("value1", 1, ""), ("value2", 2, ""), ("value3", 3, "")]) + + +def test_avro_kafka_dataset_with_resource(): + 
"""test_avro_kafka_dataset_with_resource""" + schema = ( + '{"type":"record","name":"myrecord","fields":[' + '{"name":"f1","type":"string"},' + '{"name":"f2","type":"long"},' + '{"name":"f3","type":["null","string"],"default":null}' + ']}"' + ) + schema_resource = kafka_io.decode_avro_init(schema) + dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) + # remove kafka framing + dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) + # deserialize avro + dataset = dataset.map( + lambda e: kafka_io.decode_avro( + e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string] ) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read a limited number of messages from the topic. - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read all the messages from the topic from offset 5. - sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i + 5)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions of the same topic. - sess.run( - init_op, - feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 1}, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - ("D" + str(i + j * 5)).encode(), sess.run(get_next) - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. - sess.run( - init_op, - feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 10}, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - ("D" + str(i + j * 5)).encode(), sess.run(get_next) - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. 
- sess.run( - init_batch_op, - feed_dict={ - topics: ["test:0:0:4", "test:0:5:-1"], - num_epochs: 10, - batch_size: 5, - }, - ) - for _ in range(10): - self.assertAllEqual( - [("D" + str(i)).encode() for i in range(5)], sess.run(get_next) - ) - self.assertAllEqual( - [("D" + str(i + 5)).encode() for i in range(5)], sess.run(get_next) - ) - - @pytest.mark.skip(reason="TODO") - def test_kafka_dataset_save_and_restore(self): - """Tests for KafkaDataset save and restore.""" - g = tf.Graph() - with g.as_default(): - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True - ).repeat(num_epochs) - iterator = repeat_dataset.make_initializable_iterator() - get_next = iterator.get_next() - - it = tf.data.experimental.make_saveable_from_iterator(iterator) - g.add_to_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, it) - saver = tf.compat.v1.train.Saver() - - model_file = "/tmp/test-kafka-model" - with self.cached_session() as sess: - sess.run( - iterator.initializer, - feed_dict={topics: ["test:0:0:4"], num_epochs: 1}, - ) - for i in range(3): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - # Save current offset which is 2 - saver.save(sess, model_file, global_step=3) - - checkpoint_file = "/tmp/test-kafka-model-3" - with self.cached_session() as sess: - saver.restore(sess, checkpoint_file) - # Restore current offset to 2 - for i in [2, 3]: - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - - def test_kafka_topic_configuration(self): - """Tests for KafkaDataset topic configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - cfg_list = ["auto.offset.reset=earliest"] - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_topic=cfg_list - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Use a wrong offset 100 here to make sure - # configuration 'auto.offset.reset=earliest' works. 
- sess.run(init_op, feed_dict={topics: ["test:0:100:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - - def test_kafka_global_configuration(self): - """Tests for KafkaDataset global configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - cfg_list = ["debug=generic", "enable.auto.commit=false"] - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_global=cfg_list - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_kafka_wrong_global_configuration_failed(self): - """Tests for KafkaDataset worng global configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - # Add wrong configuration - wrong_cfg = ["debug=al"] - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_global=wrong_cfg - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - with self.assertRaises(errors.InternalError): - sess.run(get_next) - - def test_kafka_wrong_topic_configuration_failed(self): - """Tests for KafkaDataset wrong topic configuration properties.""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - # Add wrong configuration - wrong_cfg = ["auto.offset.reset=arliest"] - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, config_topic=wrong_cfg - ).repeat(num_epochs) - - iterator = data.Iterator.from_structure(repeat_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) - with self.assertRaises(errors.InternalError): - sess.run(get_next) - - def test_write_kafka(self): - """test_write_kafka""" - channel = "e{}e".format(time.time()) - - # Start with reading test topic, replace `D` with `e(time)e`, - # and write to test_e(time)e` topic. 
- dataset = kafka_io.KafkaDataset(topics=["test:0:0:4"], group="test", eof=True) - dataset = dataset.map( - lambda x: kafka_io.write_kafka( - tf.strings.regex_replace(x, "D", channel), topic="test_" + channel - ) + ) + entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset] + np.all(entries == [("value1", 1), ("value2", 2), ("value3", 3)]) + + +def test_kafka_stream_dataset(): + dataset = tfio.IODataset.stream().from_kafka("test").batch(2) + assert np.all( + [k.numpy().tolist() for (k, _) in dataset] + == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) + ) + + +def test_kafka_io_dataset(): + dataset = tfio.IODataset.from_kafka( + "test", configuration=["fetch.min.bytes=2"] + ).batch(2) + # repeat multiple times will result in the same result + for _ in range(5): + assert np.all( + [k.numpy().tolist() for (k, _) in dataset] + == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) ) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read from topic 0. - sess.run(init_op) - for i in range(5): - self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Reading from `test_e(time)e` we should get the same result - dataset = kafka_io.KafkaDataset( - topics=["test_" + channel], group="test", eof=True + + +def test_avro_encode_decode(): + """test_avro_encode_decode""" + schema = ( + '{"type":"record","name":"myrecord","fields":' + '[{"name":"f1","type":"string"},{"name":"f2","type":"long"}]}' + ) + value = [("value1", 1), ("value2", 2), ("value3", 3)] + f1 = tf.cast([v[0] for v in value], tf.string) + f2 = tf.cast([v[1] for v in value], tf.int64) + message = tfio.experimental.serialization.encode_avro([f1, f2], schema=schema) + entries = tfio.experimental.serialization.decode_avro(message, schema=schema) + assert np.all(entries["f1"].numpy() == f1.numpy()) + assert np.all(entries["f2"].numpy() == f2.numpy()) + + +def test_kafka_group_io_dataset_primary_cg(): + """Test the functionality of the KafkaGroupIODataset when the consumer group + is being newly created. + + NOTE: After the kafka cluster is setup during the testing phase, 10 messages + are written to the 'key-partition-test' topic with 5 in each partition + (topic created with 2 partitions, the messages are split based on the keys). + And the same 10 messages are written into the 'key-test' topic (topic created + with 1 partition, so no splitting of the messages based on the keys). + + K0:D0, K1:D1, K0:D2, K1:D3, K0:D4, K1:D5, K0:D6, K1:D7, K0:D8, K1:D9. + + Here, messages D0, D2, D4, D6 and D8 are written into partition 0 and the rest are written + into partition 1. + + Also, since the messages are read from different partitions, the order of retrieval may not be + the same as storage. Thus, we sort and compare. 
+ """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10)]) + ) + + +def test_kafka_group_io_dataset_primary_cg_no_lag(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group has read all the messages and committed the offsets. + """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + +def test_kafka_group_io_dataset_primary_cg_new_topic(): + """Test the functionality of the KafkaGroupIODataset when the existing + consumer group reads data from a new topic. + """ + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10)]) + ) + + +def test_kafka_group_io_dataset_resume_primary_cg(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group is yet to catch up with the newly added messages only + (Instead of reading from the beginning). + """ + + # Write new messages to the topic + for i in range(10, 100): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-partition-test") + # Read only the newly sent 90 messages + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10, 100)]) + ) + + +def test_kafka_group_io_dataset_resume_primary_cg_new_topic(): + """Test the functionality of the KafkaGroupIODataset when the + consumer group is yet to catch up with the newly added messages only + (Instead of reading from the beginning) from the new topic. + """ + + # Write new messages to the topic + for i in range(10, 100): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-test") + # Read only the newly sent 90 messages + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-test"], + group_id="cgtestprimary", + servers="localhost:9092", + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(10, 100)]) + ) + + +def test_kafka_group_io_dataset_secondary_cg(): + """Test the functionality of the KafkaGroupIODataset when a + secondary consumer group is created and is yet to catch up all the messages, + from the beginning. 
+ """ + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtestsecondary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + +def test_kafka_group_io_dataset_tertiary_cg_multiple_topics(): + """Test the functionality of the KafkaGroupIODataset when a new + consumer group reads data from multiple topics from the beginning. + """ + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test", "key-test"], + group_id="cgtesttertiary", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)] * 2) + ) + + +def test_kafka_group_io_dataset_auto_offset_reset(): + """Test the functionality of the `auto.offset.reset` configuration + at global and topic level""" + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgglobaloffsetearliest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgglobaloffsetlatest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=latest", + ], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtopicoffsetearliest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "conf.topic.auto.offset.reset=earliest", + ], + ) + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(100)]) + ) + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgtopicoffsetlatest", + servers="localhost:9092", + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "conf.topic.auto.offset.reset=latest", + ], + ) + assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) + + +def test_kafka_group_io_dataset_invalid_stream_timeout(): + """Test the functionality of the KafkaGroupIODataset when the + consumer is configured to have an invalid stream_timeout value which is + less than the message_timeout value. 
+ NOTE: The default value for message_timeout=5000 + """ + + STREAM_TIMEOUT = -20 + try: + tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test", "key-test"], + group_id="cgteststreaminvalid", + servers="localhost:9092", + stream_timeout=STREAM_TIMEOUT, + configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], ) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - for i in range(5): - self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_kafka_dataset_with_key(self): - """Tests for KafkaDataset when reading keyed-messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_key=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read a limited number of keyed messages from the topic. - sess.run(init_op, feed_dict={topics: ["key-test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i)).encode(), ("K" + str(i % 2)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read all the keyed messages from the topic from offset 5. - sess.run(init_op, feed_dict={topics: ["key-test:0:5:-1"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i + 5)).encode(), ("K" + str((i + 5) % 2)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions of the same topic. - sess.run( - init_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 1, - }, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - ( - ("D" + str(i + j * 5)).encode(), - ("K" + str((i + j * 5) % 2)).encode(), - ), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. - sess.run( - init_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 10, - }, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - ( - ("D" + str(i + j * 5)).encode(), - ("K" + str((i + j * 5) % 2)).encode(), - ), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. 
- sess.run( - init_batch_op, - feed_dict={ - topics: ["key-test:0:0:4", "key-test:0:5:-1"], - num_epochs: 10, - batch_size: 5, - }, - ) - for _ in range(10): - self.assertAllEqual( - [ - [("D" + str(i)).encode() for i in range(5)], - [("K" + str(i % 2)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - self.assertAllEqual( - [ - [("D" + str(i + 5)).encode() for i in range(5)], - [("K" + str((i + 5) % 2)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - - def test_kafka_dataset_with_partitioned_key(self): - """Tests for KafkaDataset when reading keyed-messages - from a multi-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_key=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic test: read first 5 messages from the first partition of the topic. - # NOTE: The key-partition mapping occurs based on the order in which the data - # is being stored in kafka. Please check kafka_test.sh for the sample data. - - sess.run( - init_op, - feed_dict={topics: ["key-partition-test:0:0:5"], num_epochs: 1}, - ) - for i in range(5): - self.assertEqual( - (("D" + str(i * 2)).encode(), (b"K0")), sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read first 5 messages from the second partition of the topic. - sess.run( - init_op, - feed_dict={topics: ["key-partition-test:1:0:5"], num_epochs: 1}, - ) - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + 1)).encode(), (b"K1")), sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from different subscriptions to the same topic. - sess.run( - init_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 1, - }, - ) - for j in range(2): - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both subscriptions. - sess.run( - init_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 10, - }, - ) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual( - (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both subscriptions. 
- sess.run( - init_batch_op, - feed_dict={ - topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], - num_epochs: 10, - batch_size: 5, - }, + except ValueError as e: + assert str( + e + ) == "Invalid stream_timeout value: {} ,set it to -1 to block indefinitely.".format( + STREAM_TIMEOUT + ) + + +def test_kafka_group_io_dataset_stream_timeout_check(): + """Test the functionality of the KafkaGroupIODataset when the + consumer is configured to have a valid stream_timeout value and thus waits + for the new messages from kafka. + NOTE: The default value for message_timeout=5000 + """ + + def write_messages_background(): + # Write new messages to the topic in a background thread + time.sleep(6) + for i in range(100, 200): + message = "D{}".format(i) + kafka_io.write_kafka(message=message, topic="key-partition-test") + + dataset = tfio.experimental.streaming.KafkaGroupIODataset( + topics=["key-partition-test"], + group_id="cgteststreamvalid", + servers="localhost:9092", + stream_timeout=20000, + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + + # start writing the new messages to kafka using the background job. + # the job sleeps for some time (< stream_timeout) and then writes the + # messages into the topic. + thread = threading.Thread(target=write_messages_background, args=()) + thread.daemon = True + thread.start() + + # At the end, after the timeout has occurred, we must have the old 100 messages + # along with the new 100 messages + assert np.all( + sorted([k.numpy() for (k, _) in dataset]) + == sorted([("D" + str(i)).encode() for i in range(200)]) + ) + + +def test_kafka_batch_io_dataset(): + """Test the functionality of the KafkaBatchIODataset by training a model + directly on the incoming kafka message batch(of type tf.data.Dataset), in an + online-training fashion. + + NOTE: This kind of dataset is suitable in scenarios where the 'keys' of 'messages' + act as labels. If not, additional transformations are required. 
+ """ + + dataset = tfio.experimental.streaming.KafkaBatchIODataset( + topics=["mini-batch-test"], + group_id="cgminibatch", + servers=None, + stream_timeout=5000, + configuration=[ + "session.timeout.ms=7000", + "max.poll.interval.ms=8000", + "auto.offset.reset=earliest", + ], + ) + + NUM_COLUMNS = 1 + model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=(NUM_COLUMNS,)), + tf.keras.layers.Dense(4, activation="relu"), + tf.keras.layers.Dropout(0.1), + tf.keras.layers.Dense(1, activation="sigmoid"), + ] + ) + model.compile( + optimizer="adam", + loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), + metrics=["accuracy"], + ) + assert issubclass(type(dataset), tf.data.Dataset) + for mini_d in dataset: + mini_d = mini_d.map( + lambda m, k: ( + tf.strings.to_number(m, out_type=tf.float32), + tf.strings.to_number(k, out_type=tf.float32), ) - for _ in range(10): - for j in range(2): - self.assertAllEqual( - [ - [("D" + str(i * 2 + j)).encode() for i in range(5)], - [("K" + str(j)).encode() for i in range(5)], - ], - sess.run(get_next), - ) - - def test_kafka_dataset_with_offset(self): - """Tests for KafkaDataset when reading non-keyed messages - from a single-partitioned topic""" - topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) - num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = kafka_io.KafkaDataset( - topics, group="test", eof=True, message_offset=True - ).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = data.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Basic offset test: read a limited number of messages from the topic. - sess.run(init_op, feed_dict={topics: ["offset-test:0:0:4"], num_epochs: 1}) - for i in range(5): - self.assertEqual( - (("D" + str(i)).encode(), ("0:" + str(i)).encode()), - sess.run(get_next), - ) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - -if __name__ == "__main__": - test.main() + ).batch(2) + assert issubclass(type(mini_d), tf.data.Dataset) + # Fits the model as long as the data keeps on streaming + model.fit(mini_d, epochs=5) diff --git a/tests/test_kafka_eager.py b/tests/test_kafka_eager.py deleted file mode 100644 index 06e82b12a..000000000 --- a/tests/test_kafka_eager.py +++ /dev/null @@ -1,515 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
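One practical note on the stream_timeout checks exercised in the tests above: values smaller than the message timeout (5000 ms by default, per the docstrings) are rejected with the ValueError shown, while -1 is the documented way to block on the stream indefinitely. A minimal sketch with a hypothetical consumer group id:

import tensorflow_io as tfio

dataset = tfio.experimental.streaming.KafkaGroupIODataset(
    topics=["key-partition-test"],
    group_id="cgexampleblocking",  # hypothetical consumer group id
    servers="localhost:9092",
    stream_timeout=-1,             # keep polling for new messages indefinitely
    configuration=[
        "session.timeout.ms=7000",
        "max.poll.interval.ms=8000",
        "auto.offset.reset=earliest",
    ],
)
# Iteration over this dataset only ends when the consumer is shut down, which
# suits long-running jobs rather than the bounded runs used in these tests.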
-# ============================================================================== -"""Tests for Kafka Output Sequence.""" - - -import time -import pytest -import numpy as np -import threading - -import tensorflow as tf -import tensorflow_io as tfio -from tensorflow_io.kafka.python.ops import ( - kafka_ops, -) # pylint: disable=wrong-import-position -import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position - - -def test_kafka_io_tensor(): - kafka = tfio.IOTensor.from_kafka("test") - assert kafka.dtype == tf.string - assert kafka.shape.as_list() == [None] - assert np.all( - kafka.to_tensor().numpy() == [("D" + str(i)).encode() for i in range(10)] - ) - assert len(kafka.to_tensor()) == 10 - - -@pytest.mark.skip(reason="TODO") -def test_kafka_output_sequence(): - """Test case based on fashion mnist tutorial""" - fashion_mnist = tf.keras.datasets.fashion_mnist - ((train_images, train_labels), (test_images, _)) = fashion_mnist.load_data() - - class_names = [ - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ] - - train_images = train_images / 255.0 - test_images = test_images / 255.0 - - model = tf.keras.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation=tf.nn.relu), - tf.keras.layers.Dense(10, activation=tf.nn.softmax), - ] - ) - - model.compile( - optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] - ) - - model.fit(train_images, train_labels, epochs=5) - - class OutputCallback(tf.keras.callbacks.Callback): - """KafkaOutputCallback""" - - def __init__( - self, batch_size, topic, servers - ): # pylint: disable=super-init-not-called - self._sequence = kafka_ops.KafkaOutputSequence(topic=topic, servers=servers) - self._batch_size = batch_size - - def on_predict_batch_end(self, batch, logs=None): - index = batch * self._batch_size - for outputs in logs["outputs"]: - for output in outputs: - self._sequence.setitem(index, class_names[np.argmax(output)]) - index += 1 - - def flush(self): - self._sequence.flush() - - channel = "e{}e".format(time.time()) - topic = "test_" + channel - - # By default batch size is 32 - output = OutputCallback(32, topic, "localhost") - predictions = model.predict(test_images, callbacks=[output]) - output.flush() - - predictions = [class_names[v] for v in np.argmax(predictions, axis=1)] - - # Reading from `test_e(time)e` we should get the same result - dataset = tfio.kafka.KafkaDataset(topics=[topic], group="test", eof=True) - for entry, prediction in zip(dataset, predictions): - assert entry.numpy() == prediction.encode() - - -def test_avro_kafka_dataset(): - """test_avro_kafka_dataset""" - schema = ( - '{"type":"record","name":"myrecord","fields":[' - '{"name":"f1","type":"string"},' - '{"name":"f2","type":"long"},' - '{"name":"f3","type":["null","string"],"default":null}' - "]}" - ) - dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) - # remove kafka framing - dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) - # deserialize avro - dataset = dataset.map( - lambda e: tfio.experimental.serialization.decode_avro(e, schema=schema) - ) - entries = [(e["f1"], e["f2"], e["f3"]) for e in dataset] - np.all(entries == [("value1", 1, ""), ("value2", 2, ""), ("value3", 3, "")]) - - -def test_avro_kafka_dataset_with_resource(): - """test_avro_kafka_dataset_with_resource""" - schema = ( - '{"type":"record","name":"myrecord","fields":[' - 
'{"name":"f1","type":"string"},' - '{"name":"f2","type":"long"},' - '{"name":"f3","type":["null","string"],"default":null}' - ']}"' - ) - schema_resource = kafka_io.decode_avro_init(schema) - dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True) - # remove kafka framing - dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1)) - # deserialize avro - dataset = dataset.map( - lambda e: kafka_io.decode_avro( - e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string] - ) - ) - entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset] - np.all(entries == [("value1", 1), ("value2", 2), ("value3", 3)]) - - -def test_kafka_stream_dataset(): - dataset = tfio.IODataset.stream().from_kafka("test").batch(2) - assert np.all( - [k.numpy().tolist() for (k, _) in dataset] - == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) - ) - - -def test_kafka_io_dataset(): - dataset = tfio.IODataset.from_kafka( - "test", configuration=["fetch.min.bytes=2"] - ).batch(2) - # repeat multiple times will result in the same result - for _ in range(5): - assert np.all( - [k.numpy().tolist() for (k, _) in dataset] - == np.asarray([("D" + str(i)).encode() for i in range(10)]).reshape((5, 2)) - ) - - -def test_avro_encode_decode(): - """test_avro_encode_decode""" - schema = ( - '{"type":"record","name":"myrecord","fields":' - '[{"name":"f1","type":"string"},{"name":"f2","type":"long"}]}' - ) - value = [("value1", 1), ("value2", 2), ("value3", 3)] - f1 = tf.cast([v[0] for v in value], tf.string) - f2 = tf.cast([v[1] for v in value], tf.int64) - message = tfio.experimental.serialization.encode_avro([f1, f2], schema=schema) - entries = tfio.experimental.serialization.decode_avro(message, schema=schema) - assert np.all(entries["f1"].numpy() == f1.numpy()) - assert np.all(entries["f2"].numpy() == f2.numpy()) - - -def test_kafka_group_io_dataset_primary_cg(): - """Test the functionality of the KafkaGroupIODataset when the consumer group - is being newly created. - - NOTE: After the kafka cluster is setup during the testing phase, 10 messages - are written to the 'key-partition-test' topic with 5 in each partition - (topic created with 2 partitions, the messages are split based on the keys). - And the same 10 messages are written into the 'key-test' topic (topic created - with 1 partition, so no splitting of the messages based on the keys). - - K0:D0, K1:D1, K0:D2, K1:D3, K0:D4, K1:D5, K0:D6, K1:D7, K0:D8, K1:D9. - - Here, messages D0, D2, D4, D6 and D8 are written into partition 0 and the rest are written - into partition 1. - - Also, since the messages are read from different partitions, the order of retrieval may not be - the same as storage. Thus, we sort and compare. - """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10)]) - ) - - -def test_kafka_group_io_dataset_primary_cg_no_lag(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group has read all the messages and committed the offsets. 
- """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - -def test_kafka_group_io_dataset_primary_cg_new_topic(): - """Test the functionality of the KafkaGroupIODataset when the existing - consumer group reads data from a new topic. - """ - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10)]) - ) - - -def test_kafka_group_io_dataset_resume_primary_cg(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group is yet to catch up with the newly added messages only - (Instead of reading from the beginning). - """ - - # Write new messages to the topic - for i in range(10, 100): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-partition-test") - # Read only the newly sent 90 messages - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10, 100)]) - ) - - -def test_kafka_group_io_dataset_resume_primary_cg_new_topic(): - """Test the functionality of the KafkaGroupIODataset when the - consumer group is yet to catch up with the newly added messages only - (Instead of reading from the beginning) from the new topic. - """ - - # Write new messages to the topic - for i in range(10, 100): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-test") - # Read only the newly sent 90 messages - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-test"], - group_id="cgtestprimary", - servers="localhost:9092", - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(10, 100)]) - ) - - -def test_kafka_group_io_dataset_secondary_cg(): - """Test the functionality of the KafkaGroupIODataset when a - secondary consumer group is created and is yet to catch up all the messages, - from the beginning. - """ - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtestsecondary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - -def test_kafka_group_io_dataset_tertiary_cg_multiple_topics(): - """Test the functionality of the KafkaGroupIODataset when a new - consumer group reads data from multiple topics from the beginning. 
- """ - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test", "key-test"], - group_id="cgtesttertiary", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)] * 2) - ) - - -def test_kafka_group_io_dataset_auto_offset_reset(): - """Test the functionality of the `auto.offset.reset` configuration - at global and topic level""" - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgglobaloffsetearliest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgglobaloffsetlatest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=latest", - ], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtopicoffsetearliest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "conf.topic.auto.offset.reset=earliest", - ], - ) - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(100)]) - ) - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgtopicoffsetlatest", - servers="localhost:9092", - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "conf.topic.auto.offset.reset=latest", - ], - ) - assert np.all(sorted([k.numpy() for (k, _) in dataset]) == []) - - -def test_kafka_group_io_dataset_invalid_stream_timeout(): - """Test the functionality of the KafkaGroupIODataset when the - consumer is configured to have an invalid stream_timeout value which is - less than the message_timeout value. - NOTE: The default value for message_timeout=5000 - """ - - STREAM_TIMEOUT = -20 - try: - tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test", "key-test"], - group_id="cgteststreaminvalid", - servers="localhost:9092", - stream_timeout=STREAM_TIMEOUT, - configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"], - ) - except ValueError as e: - assert str( - e - ) == "Invalid stream_timeout value: {} ,set it to -1 to block indefinitely.".format( - STREAM_TIMEOUT - ) - - -def test_kafka_group_io_dataset_stream_timeout_check(): - """Test the functionality of the KafkaGroupIODataset when the - consumer is configured to have a valid stream_timeout value and thus waits - for the new messages from kafka. 
- NOTE: The default value for message_timeout=5000 - """ - - def write_messages_background(): - # Write new messages to the topic in a background thread - time.sleep(6) - for i in range(100, 200): - message = "D{}".format(i) - kafka_io.write_kafka(message=message, topic="key-partition-test") - - dataset = tfio.experimental.streaming.KafkaGroupIODataset( - topics=["key-partition-test"], - group_id="cgteststreamvalid", - servers="localhost:9092", - stream_timeout=20000, - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - - # start writing the new messages to kafka using the background job. - # the job sleeps for some time (< stream_timeout) and then writes the - # messages into the topic. - thread = threading.Thread(target=write_messages_background, args=()) - thread.daemon = True - thread.start() - - # At the end, after the timeout has occurred, we must have the old 100 messages - # along with the new 100 messages - assert np.all( - sorted([k.numpy() for (k, _) in dataset]) - == sorted([("D" + str(i)).encode() for i in range(200)]) - ) - - -def test_kafka_batch_io_dataset(): - """Test the functionality of the KafkaBatchIODataset by training a model - directly on the incoming kafka message batch(of type tf.data.Dataset), in an - online-training fashion. - - NOTE: This kind of dataset is suitable in scenarios where the 'keys' of 'messages' - act as labels. If not, additional transformations are required. - """ - - dataset = tfio.experimental.streaming.KafkaBatchIODataset( - topics=["mini-batch-test"], - group_id="cgminibatch", - servers=None, - stream_timeout=5000, - configuration=[ - "session.timeout.ms=7000", - "max.poll.interval.ms=8000", - "auto.offset.reset=earliest", - ], - ) - - NUM_COLUMNS = 1 - model = tf.keras.Sequential( - [ - tf.keras.layers.Input(shape=(NUM_COLUMNS,)), - tf.keras.layers.Dense(4, activation="relu"), - tf.keras.layers.Dropout(0.1), - tf.keras.layers.Dense(1, activation="sigmoid"), - ] - ) - model.compile( - optimizer="adam", - loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), - metrics=["accuracy"], - ) - assert issubclass(type(dataset), tf.data.Dataset) - for mini_d in dataset: - mini_d = mini_d.map( - lambda m, k: ( - tf.strings.to_number(m, out_type=tf.float32), - tf.strings.to_number(k, out_type=tf.float32), - ) - ).batch(2) - assert issubclass(type(mini_d), tf.data.Dataset) - # Fits the model as long as the data keeps on streaming - model.fit(mini_d, epochs=5) diff --git a/tests/test_kafka_v1.py b/tests/test_kafka_v1.py new file mode 100644 index 000000000..74be1a8ab --- /dev/null +++ b/tests/test_kafka_v1.py @@ -0,0 +1,512 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for KafkaDataset.""" + + +import time +import pytest + +import tensorflow as tf + +tf.compat.v1.disable_eager_execution() + +from tensorflow import dtypes # pylint: disable=wrong-import-position +from tensorflow import errors # pylint: disable=wrong-import-position +from tensorflow import test # pylint: disable=wrong-import-position +from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position + +import tensorflow_io.kafka as kafka_io # pylint: disable=wrong-import-position + + +class KafkaDatasetTest(test.TestCase): + """Tests for KafkaDataset.""" + + # The Kafka server has to be setup before the test + # and tear down after the test manually. + # The docker engine has to be installed. + # + # To setup the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To tear down the Kafka server: + # $ bash kafka_test.sh stop kafka + + def test_kafka_dataset(self): + """Tests for KafkaDataset when reading non-keyed messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset(topics, group="test", eof=True).repeat( + num_epochs + ) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read a limited number of messages from the topic. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read all the messages from the topic from offset 5. + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i + 5)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions of the same topic. + sess.run( + init_op, + feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 1}, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + ("D" + str(i + j * 5)).encode(), sess.run(get_next) + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], num_epochs: 10}, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + ("D" + str(i + j * 5)).encode(), sess.run(get_next) + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + self.assertAllEqual( + [("D" + str(i)).encode() for i in range(5)], sess.run(get_next) + ) + self.assertAllEqual( + [("D" + str(i + 5)).encode() for i in range(5)], sess.run(get_next) + ) + + @pytest.mark.skip(reason="TODO") + def test_kafka_dataset_save_and_restore(self): + """Tests for KafkaDataset save and restore.""" + g = tf.Graph() + with g.as_default(): + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True + ).repeat(num_epochs) + iterator = repeat_dataset.make_initializable_iterator() + get_next = iterator.get_next() + + it = tf.data.experimental.make_saveable_from_iterator(iterator) + g.add_to_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, it) + saver = tf.compat.v1.train.Saver() + + model_file = "/tmp/test-kafka-model" + with self.cached_session() as sess: + sess.run( + iterator.initializer, + feed_dict={topics: ["test:0:0:4"], num_epochs: 1}, + ) + for i in range(3): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + # Save current offset which is 2 + saver.save(sess, model_file, global_step=3) + + checkpoint_file = "/tmp/test-kafka-model-3" + with self.cached_session() as sess: + saver.restore(sess, checkpoint_file) + # Restore current offset to 2 + for i in [2, 3]: + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + + def test_kafka_topic_configuration(self): + """Tests for KafkaDataset topic configuration properties.""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + cfg_list = ["auto.offset.reset=earliest"] + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, config_topic=cfg_list + ).repeat(num_epochs) + + iterator = data.Iterator.from_structure(repeat_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Use a wrong offset 100 here to make sure + # configuration 'auto.offset.reset=earliest' works. 
+ sess.run(init_op, feed_dict={topics: ["test:0:100:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + + def test_kafka_global_configuration(self): + """Tests for KafkaDataset global configuration properties.""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + cfg_list = ["debug=generic", "enable.auto.commit=false"] + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, config_global=cfg_list + ).repeat(num_epochs) + + iterator = data.Iterator.from_structure(repeat_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual(("D" + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_kafka_wrong_global_configuration_failed(self): + """Tests for KafkaDataset worng global configuration properties.""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + # Add wrong configuration + wrong_cfg = ["debug=al"] + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, config_global=wrong_cfg + ).repeat(num_epochs) + + iterator = data.Iterator.from_structure(repeat_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + with self.assertRaises(errors.InternalError): + sess.run(get_next) + + def test_kafka_wrong_topic_configuration_failed(self): + """Tests for KafkaDataset wrong topic configuration properties.""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + # Add wrong configuration + wrong_cfg = ["auto.offset.reset=arliest"] + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, config_topic=wrong_cfg + ).repeat(num_epochs) + + iterator = data.Iterator.from_structure(repeat_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + with self.assertRaises(errors.InternalError): + sess.run(get_next) + + @pytest.mark.skip(reason="TODO") + def test_write_kafka(self): + """test_write_kafka""" + channel = "e{}e".format(time.time()) + + # Start with reading test topic, replace `D` with `e(time)e`, + # and write to test_e(time)e` topic. + dataset = kafka_io.KafkaDataset(topics=["test:0:0:4"], group="test", eof=True) + dataset = dataset.map( + lambda x: kafka_io.write_kafka( + tf.strings.regex_replace(x, "D", channel), topic="test_" + channel + ) + ) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read from topic 0. 
+ sess.run(init_op) + for i in range(5): + self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Reading from `test_e(time)e` we should get the same result + dataset = kafka_io.KafkaDataset( + topics=["test_" + channel], group="test", eof=True + ) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(5): + self.assertEqual((channel + str(i)).encode(), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_kafka_dataset_with_key(self): + """Tests for KafkaDataset when reading keyed-messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_key=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read a limited number of keyed messages from the topic. + sess.run(init_op, feed_dict={topics: ["key-test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i)).encode(), ("K" + str(i % 2)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read all the keyed messages from the topic from offset 5. + sess.run(init_op, feed_dict={topics: ["key-test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i + 5)).encode(), ("K" + str((i + 5) % 2)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions of the same topic. + sess.run( + init_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 1, + }, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + ( + ("D" + str(i + j * 5)).encode(), + ("K" + str((i + j * 5) % 2)).encode(), + ), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 10, + }, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + ( + ("D" + str(i + j * 5)).encode(), + ("K" + str((i + j * 5) % 2)).encode(), + ), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["key-test:0:0:4", "key-test:0:5:-1"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + self.assertAllEqual( + [ + [("D" + str(i)).encode() for i in range(5)], + [("K" + str(i % 2)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + self.assertAllEqual( + [ + [("D" + str(i + 5)).encode() for i in range(5)], + [("K" + str((i + 5) % 2)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + + def test_kafka_dataset_with_partitioned_key(self): + """Tests for KafkaDataset when reading keyed-messages + from a multi-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_key=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic test: read first 5 messages from the first partition of the topic. + # NOTE: The key-partition mapping occurs based on the order in which the data + # is being stored in kafka. Please check kafka_test.sh for the sample data. + + sess.run( + init_op, + feed_dict={topics: ["key-partition-test:0:0:5"], num_epochs: 1}, + ) + for i in range(5): + self.assertEqual( + (("D" + str(i * 2)).encode(), (b"K0")), sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read first 5 messages from the second partition of the topic. + sess.run( + init_op, + feed_dict={topics: ["key-partition-test:1:0:5"], num_epochs: 1}, + ) + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + 1)).encode(), (b"K1")), sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from different subscriptions to the same topic. + sess.run( + init_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 1, + }, + ) + for j in range(2): + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both subscriptions. + sess.run( + init_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 10, + }, + ) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual( + (("D" + str(i * 2 + j)).encode(), ("K" + str(j)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both subscriptions. 
+ sess.run( + init_batch_op, + feed_dict={ + topics: ["key-partition-test:0:0:5", "key-partition-test:1:0:5"], + num_epochs: 10, + batch_size: 5, + }, + ) + for _ in range(10): + for j in range(2): + self.assertAllEqual( + [ + [("D" + str(i * 2 + j)).encode() for i in range(5)], + [("K" + str(j)).encode() for i in range(5)], + ], + sess.run(get_next), + ) + + def test_kafka_dataset_with_offset(self): + """Tests for KafkaDataset when reading non-keyed messages + from a single-partitioned topic""" + topics = tf.compat.v1.placeholder(dtypes.string, shape=[None]) + num_epochs = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + batch_size = tf.compat.v1.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_io.KafkaDataset( + topics, group="test", eof=True, message_offset=True + ).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = data.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Basic offset test: read a limited number of messages from the topic. + sess.run(init_op, feed_dict={topics: ["offset-test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual( + (("D" + str(i)).encode(), ("0:" + str(i)).encode()), + sess.run(get_next), + ) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tests/test_libsvm_eager.py b/tests/test_libsvm.py similarity index 100% rename from tests/test_libsvm_eager.py rename to tests/test_libsvm.py diff --git a/tests/test_lmdb_eager.py b/tests/test_lmdb.py similarity index 100% rename from tests/test_lmdb_eager.py rename to tests/test_lmdb.py diff --git a/tests/test_mongodb_eager.py b/tests/test_mongodb.py similarity index 51% rename from tests/test_mongodb_eager.py rename to tests/test_mongodb.py index b8d0688ee..ce1c7c8db 100644 --- a/tests/test_mongodb_eager.py +++ b/tests/test_mongodb.py @@ -15,22 +15,44 @@ """Tests for the mongodb datasets""" -from datetime import datetime -import time -import json -import pytest import socket -import requests +import pytest import tensorflow as tf from tensorflow import feature_column from tensorflow.keras import layers import tensorflow_io as tfio # COMMON VARIABLES -TIMESTAMP_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ" URI = "mongodb://mongoadmin:default_password@localhost:27017" DATABASE = "tfiodb" COLLECTION = "test" +RECORDS = [ + { + "name": "person1", + "gender": "Male", + "age": 20, + "fare": 80.52, + "vip": False, + "survived": 1, + }, + { + "name": "person2", + "gender": "Female", + "age": 20, + "fare": 40.88, + "vip": True, + "survived": 0, + }, +] * 1000 +SPECS = { + "name": tf.TensorSpec(tf.TensorShape([]), tf.string), + "gender": tf.TensorSpec(tf.TensorShape([]), tf.string), + "age": tf.TensorSpec(tf.TensorShape([]), tf.int32), + "fare": tf.TensorSpec(tf.TensorShape([]), tf.float64), + "vip": tf.TensorSpec(tf.TensorShape([]), tf.bool), + "survived": tf.TensorSpec(tf.TensorShape([]), tf.int64), +} +BATCH_SIZE = 32 def is_container_running(): @@ -53,10 +75,9 @@ def test_writer_write(): writer = tfio.experimental.mongodb.MongoDBWriter( uri=URI, database=DATABASE, collection=COLLECTION ) - timestamp = datetime.utcnow().strftime(TIMESTAMP_PATTERN) - for i in range(1000): - data = {"timestamp": timestamp, "key{}".format(i): "value{}".format(i)} - writer.write(data) + + for record in RECORDS: + writer.write(record) @pytest.mark.skipif(not is_container_running(), 
reason="The container is not running") @@ -69,7 +90,65 @@ def test_dataset_read(): count = 0 for d in dataset: count += 1 - assert count == 1000 + assert count == len(RECORDS) + + +@pytest.mark.skipif(not is_container_running(), reason="The container is not running") +def test_train_model(): + """Test the dataset by training a tf.keras model""" + + dataset = tfio.experimental.mongodb.MongoDBIODataset( + uri=URI, database=DATABASE, collection=COLLECTION + ) + dataset = dataset.map( + lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS) + ) + dataset = dataset.map(lambda v: (v, v.pop("survived"))) + dataset = dataset.batch(BATCH_SIZE) + + assert issubclass(type(dataset), tf.data.Dataset) + + feature_columns = [] + + # Numeric column + fare_column = feature_column.numeric_column("fare") + feature_columns.append(fare_column) + + # Bucketized column + age = feature_column.numeric_column("age") + age_buckets = feature_column.bucketized_column(age, boundaries=[10, 30]) + feature_columns.append(age_buckets) + + # Categorical column + gender = feature_column.categorical_column_with_vocabulary_list( + "gender", ["Male", "Female"] + ) + gender_indicator = feature_column.indicator_column(gender) + feature_columns.append(gender_indicator) + + # Convert the feature columns into a tf.keras layer + feature_layer = tf.keras.layers.DenseFeatures(feature_columns) + + # Build the model + model = tf.keras.Sequential( + [ + feature_layer, + layers.Dense(128, activation="relu"), + layers.Dense(128, activation="relu"), + layers.Dropout(0.1), + layers.Dense(1), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), + metrics=["accuracy"], + ) + + # train the model + model.fit(dataset, epochs=5) @pytest.mark.skipif(not is_container_running(), reason="The container is not running") diff --git a/tests/test_obj.py b/tests/test_obj.py new file mode 100644 index 000000000..ff89ef2b8 --- /dev/null +++ b/tests/test_obj.py @@ -0,0 +1,37 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Test Wavefront OBJ""" + +import os +import numpy as np +import pytest + +import tensorflow as tf +import tensorflow_io as tfio + + +def test_decode_obj(): + """Test case for decode obj""" + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_obj", "sample.obj", + ) + filename = "file://" + filename + + obj = tfio.experimental.image.decode_obj(tf.io.read_file(filename)) + expected = np.array( + [[-0.5, 0.0, 0.4], [-0.5, 0.0, -0.8], [-0.5, 1.0, -0.8], [-0.5, 1.0, 0.4]], + dtype=np.float32, + ) + assert np.array_equal(obj, expected) diff --git a/tests/test_obj/sample.obj b/tests/test_obj/sample.obj new file mode 100644 index 000000000..da8b327ff --- /dev/null +++ b/tests/test_obj/sample.obj @@ -0,0 +1,6 @@ +# Simple Wavefront file +v -0.500000 0.000000 0.400000 +v -0.500000 0.000000 -0.800000 +v -0.500000 1.000000 -0.800000 +v -0.500000 1.000000 0.400000 +f -4 -3 -2 -1 diff --git a/tests/test_parquet_eager.py b/tests/test_parquet.py similarity index 89% rename from tests/test_parquet_eager.py rename to tests/test_parquet.py index 2efd2d2d3..a4e15b1a6 100644 --- a/tests/test_parquet_eager.py +++ b/tests/test_parquet.py @@ -22,6 +22,8 @@ import tensorflow as tf import tensorflow_io as tfio +import pandas as pd + filename = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_parquet", @@ -184,5 +186,24 @@ def f(e): assert v7 == p7.numpy() +def test_parquet_data(): + """Test case for parquet GitHub 1254""" + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_parquet", + "part-00000-ca0e89bf-ccd7-47e1-925c-9b42c8716c84-c000.snappy.parquet", + ) + parquet = pd.read_parquet(filename) + dataset = tfio.IODataset.from_parquet(filename) + i = 0 + for columns in dataset: + assert columns[b"user_id"] == parquet["user_id"][i] + assert columns[b"movie_id"] == parquet["movie_id"][i] + assert columns[b"movie_title"] == parquet["movie_title"][i] + assert columns[b"rating"] == parquet["rating"][i] + assert columns[b"timestamp"] == parquet["timestamp"][i] + i += 1 + + if __name__ == "__main__": test.main() diff --git a/tests/test_parquet/part-00000-ca0e89bf-ccd7-47e1-925c-9b42c8716c84-c000.snappy.parquet b/tests/test_parquet/part-00000-ca0e89bf-ccd7-47e1-925c-9b42c8716c84-c000.snappy.parquet new file mode 100644 index 000000000..a1eef8197 Binary files /dev/null and b/tests/test_parquet/part-00000-ca0e89bf-ccd7-47e1-925c-9b42c8716c84-c000.snappy.parquet differ diff --git a/tests/test_parse_avro_eager.py b/tests/test_parse_avro.py similarity index 97% rename from tests/test_parse_avro_eager.py rename to tests/test_parse_avro.py index fc4220ad1..83d39e7da 100644 --- a/tests/test_parse_avro_eager.py +++ b/tests/test_parse_avro.py @@ -246,6 +246,63 @@ def _load_records_as_tensors(filenames, schema): ), ) + def test_inval_num_parallel_calls(self): + """test_inval_num_parallel_calls + This function tests that value errors are raised upon + the passing of invalid values for num_parallel_calls which + includes zero values and values greater than num_parallel_reads + """ + + NUM_PARALLEL_READS = 1 + NUM_PARALLEL_CALLS_ZERO = 0 + NUM_PARALLEL_CALLS_GREATER = 2 + + writer_schema = """{ + "type": "record", + "name": "dataTypes", + "fields": [ + { + "name":"index", + "type":"int" + }, + { + "name":"string_value", + "type":"string" + } + ]}""" + + record_data = [ + {"index": 0, "string_value": ""}, + {"index": 1, "string_value": "SpecialChars@!#$%^&*()-_=+{}[]|/`~\\'?"}, + { + 
"index": 2, + "string_value": "ABCDEFGHIJKLMNOPQRSTUVW" + + "Zabcdefghijklmnopqrstuvwz0123456789", + }, + ] + + filenames = AvroRecordDatasetTest._setup_files( + writer_schema=writer_schema, records=record_data + ) + + with pytest.raises(ValueError): + + dataset_a = tfio.experimental.columnar.AvroRecordDataset( + filenames=filenames, + num_parallel_reads=NUM_PARALLEL_READS, + num_parallel_calls=NUM_PARALLEL_CALLS_ZERO, + reader_schema="reader_schema", + ) + + with pytest.raises(ValueError): + + dataset_b = tfio.experimental.columnar.AvroRecordDataset( + filenames=filenames, + num_parallel_reads=NUM_PARALLEL_READS, + num_parallel_calls=NUM_PARALLEL_CALLS_GREATER, + reader_schema="reader_schema", + ) + def _test_pass_dataset(self, writer_schema, record_data, **kwargs): """test_pass_dataset""" filenames = AvroRecordDatasetTest._setup_files( diff --git a/tests/test_pcap_eager.py b/tests/test_pcap.py similarity index 100% rename from tests/test_pcap_eager.py rename to tests/test_pcap.py diff --git a/tests/test_pulsar_eager.py b/tests/test_pulsar.py similarity index 100% rename from tests/test_pulsar_eager.py rename to tests/test_pulsar.py diff --git a/tests/test_pulsar/pulsar_test.sh b/tests/test_pulsar/pulsar_test.sh index 6abe01b75..ff0bfa357 100644 --- a/tests/test_pulsar/pulsar_test.sh +++ b/tests/test_pulsar/pulsar_test.sh @@ -22,7 +22,7 @@ TAR_FILE="apache-pulsar-${VERSION}-bin.tar.gz" echo "Downloading pulsar ${VERSION}" if [[ ! -f ${TAR_FILE} ]]; then - curl -sSOL "https://downloads.apache.org/pulsar/pulsar-${VERSION}/${TAR_FILE}" + curl -sSOL "https://archive.apache.org/dist/pulsar/pulsar-${VERSION}/${TAR_FILE}" fi tar -xzf ${TAR_FILE} diff --git a/tests/test_s3_eager.py b/tests/test_s3.py similarity index 84% rename from tests/test_s3_eager.py rename to tests/test_s3.py index ad18af1f7..08f928380 100644 --- a/tests/test_s3_eager.py +++ b/tests/test_s3.py @@ -51,11 +51,7 @@ def test_read_file(): response = client.get_object(Bucket=bucket_name, Key=key_name) assert response["Body"].read() == body - os.environ["S3_ENDPOINT"] = "localhost:4566" - os.environ["S3_USE_HTTPS"] = "0" - os.environ["S3_VERIFY_SSL"] = "0" + os.environ["S3_ENDPOINT"] = "http://localhost:4566" - # TODO: The following is not working yet, need update to use - # s3 implementation with module file system - content = tf.io.read_file("s3e://{}/{}".format(bucket_name, key_name)) + content = tf.io.read_file("s3://{}/{}".format(bucket_name, key_name)) assert content == body diff --git a/tests/test_serial_ops.py b/tests/test_serial_ops.py new file mode 100644 index 000000000..ae461963b --- /dev/null +++ b/tests/test_serial_ops.py @@ -0,0 +1,89 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the super_serial.py serialization module.""" +import os +import tempfile + +import numpy as np +import pytest +import tensorflow as tf + +import tensorflow_io as tfio + + +def test_serialization(): + """Test super serial saving and loading. + NOTE- test will only work in eager mode due to list() dataset cast.""" + savefolder = tempfile.TemporaryDirectory() + savepath = os.path.join(savefolder.name, "temp_dataset") + tfrecord_path = savepath + ".tfrecord" + header_path = savepath + ".header" + + # Data + x = np.linspace(1, 3000, num=3000).reshape(10, 10, 10, 3) + y = np.linspace(1, 10, num=10).astype(int) + ds = tf.data.Dataset.from_tensor_slices({"image": x, "label": y}) + + # Run + tfio.experimental.serialization.save_dataset( + ds, tfrecord_path=tfrecord_path, header_path=header_path + ) + new_ds = tfio.experimental.serialization.load_dataset( + tfrecord_path=tfrecord_path, header_path=header_path + ) + + # Test that values were saved and restored + assert ( + list(ds)[0]["image"].numpy()[0, 0, 0] + == list(new_ds)[0]["image"].numpy()[0, 0, 0] + ) + assert list(ds)[0]["label"] == list(new_ds)[0]["label"] + + assert ( + list(ds)[-1]["image"].numpy()[0, 0, 0] + == list(new_ds)[-1]["image"].numpy()[0, 0, 0] + ) + assert list(ds)[-1]["label"] == list(new_ds)[-1]["label"] + + # Clean up- folder will disappear on crash as well. + savefolder.cleanup() + + +@tf.function +def graph_save_fail(): + """Serial ops is expected to raise an exception when + trying to save in graph mode.""" + savefolder = tempfile.TemporaryDirectory() + savepath = os.path.join(savefolder.name, "temp_dataset") + tfrecord_path = savepath + ".tfrecord" + header_path = savepath + ".header" + + # Data + x = np.linspace(1, 3000, num=3000).reshape(10, 10, 10, 3) + y = np.linspace(1, 10, num=10).astype(int) + ds = tf.data.Dataset.from_tensor_slices({"image": x, "label": y}) + + # Run + assert os.path.isdir(savefolder.name) + assert not tf.executing_eagerly() + tfio.experimental.serialization.save_dataset( + ds, tfrecord_path=tfrecord_path, header_path=header_path + ) + + +def test_ensure_graph_fail(): + """Test that super_serial fails in graph mode.""" + with pytest.raises(ValueError): + graph_save_fail() diff --git a/tests/test_serialization_eager.py b/tests/test_serialization.py similarity index 100% rename from tests/test_serialization_eager.py rename to tests/test_serialization.py diff --git a/tests/test_text_eager.py b/tests/test_text.py similarity index 100% rename from tests/test_text_eager.py rename to tests/test_text.py diff --git a/tests/test_version_eager.py b/tests/test_version.py similarity index 100% rename from tests/test_version_eager.py rename to tests/test_version.py diff --git a/tests/test_video_eager.py b/tests/test_video.py similarity index 100% rename from tests/test_video_eager.py rename to tests/test_video.py diff --git a/third_party/BUILD b/third_party/BUILD index f364c8be7..7b8d4c583 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -28,21 +28,3 @@ cc_library( visibility = ["//visibility:public"], deps = [], ) - -# parquet_types.[h|cpp] are generated from apache arrow 0.16.0 -# on Ubuntu 16.04 with default gcc/g++/cmake/flex/bison. 
-cc_library( - name = "parquet", - srcs = [ - "parquet/parquet_types.h", - ], - hdrs = [], - copts = [], - includes = ["."], - visibility = ["//visibility:public"], - deps = [], -) - -exports_files([ - "parquet/parquet_types.cpp", -]) diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD index b60118cee..85a5a5969 100644 --- a/third_party/arrow.BUILD +++ b/third_party/arrow.BUILD @@ -31,9 +31,10 @@ genrule( srcs = ["cpp/src/arrow/util/config.h.cmake"], outs = ["cpp/src/arrow/util/config.h"], cmd = ("sed " + - "-e 's/@ARROW_VERSION_MAJOR@/0/g' " + - "-e 's/@ARROW_VERSION_MINOR@/16/g' " + + "-e 's/@ARROW_VERSION_MAJOR@/3/g' " + + "-e 's/@ARROW_VERSION_MINOR@/0/g' " + "-e 's/@ARROW_VERSION_PATCH@/0/g' " + + "-e 's/cmakedefine ARROW_USE_NATIVE_INT128/undef ARROW_USE_NATIVE_INT128/g' " + "-e 's/cmakedefine/define/g' " + "$< >$@"), ) @@ -59,13 +60,17 @@ cc_library( "cpp/src/arrow/io/*.cc", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/json/*.cc", + "cpp/src/arrow/tensor/*.cc", "cpp/src/arrow/util/*.cc", + "cpp/src/arrow/vendored/musl/strptime.c", "cpp/src/arrow/vendored/optional.hpp", "cpp/src/arrow/vendored/string_view.hpp", "cpp/src/arrow/vendored/variant.hpp", "cpp/src/arrow/**/*.h", "cpp/src/parquet/**/*.h", "cpp/src/parquet/**/*.cc", + "cpp/src/generated/*.h", + "cpp/src/generated/*.cpp", ], exclude = [ "cpp/src/**/*_benchmark.cc", @@ -77,16 +82,16 @@ cc_library( "cpp/src/**/*fuzz*.cc", "cpp/src/**/file_to_stream.cc", "cpp/src/**/stream_to_file.cc", + "cpp/src/arrow/util/bpacking_avx2.cc", + "cpp/src/arrow/util/bpacking_avx512.cc", ], - ) + [ - "@org_tensorflow_io//third_party:parquet/parquet_types.cpp", - ], + ), hdrs = [ # declare header from above genrule "cpp/src/arrow/util/config.h", "cpp/src/parquet/parquet_version.h", ], - copts = ["-std=c++11"], + copts = [], defines = [ "ARROW_WITH_BROTLI", "ARROW_WITH_SNAPPY", @@ -112,7 +117,6 @@ cc_library( "@bzip2", "@double-conversion", "@lz4", - "@org_tensorflow_io//third_party:parquet", "@rapidjson", "@snappy", "@thrift", diff --git a/third_party/aws-sdk-cpp.BUILD b/third_party/aws-sdk-cpp.BUILD index 636d50c95..cd939724e 100644 --- a/third_party/aws-sdk-cpp.BUILD +++ b/third_party/aws-sdk-cpp.BUILD @@ -8,51 +8,73 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) cc_library( - name = "aws-sdk-cpp", + name = "core", srcs = glob([ - "aws-cpp-sdk-core/include/**/*.h", - "aws-cpp-sdk-core/source/*.cpp", - "aws-cpp-sdk-core/source/auth/**/*.cpp", - "aws-cpp-sdk-core/source/client/**/*.cpp", - "aws-cpp-sdk-core/source/config/**/*.cpp", - "aws-cpp-sdk-core/source/external/**/*.cpp", - "aws-cpp-sdk-core/source/http/*.cpp", - "aws-cpp-sdk-core/source/http/curl/*.cpp", - "aws-cpp-sdk-core/source/http/standard/*.cpp", - "aws-cpp-sdk-core/source/internal/**/*.cpp", - "aws-cpp-sdk-core/source/monitoring/**/*.cpp", - "aws-cpp-sdk-core/source/utils/*.cpp", - "aws-cpp-sdk-core/source/utils/base64/**/*.cpp", - "aws-cpp-sdk-core/source/utils/crypto/*.cpp", - "aws-cpp-sdk-core/source/utils/crypto/factory/*.cpp", - "aws-cpp-sdk-core/source/utils/crypto/openssl/CryptoImpl.cpp", - "aws-cpp-sdk-core/source/utils/event/**/*.cpp", - "aws-cpp-sdk-core/source/utils/json/**/*.cpp", - "aws-cpp-sdk-core/source/utils/logging/**/*.cpp", - "aws-cpp-sdk-core/source/utils/memory/**/*.cpp", - "aws-cpp-sdk-core/source/utils/stream/**/*.cpp", - "aws-cpp-sdk-core/source/utils/threading/**/*.cpp", - "aws-cpp-sdk-core/source/utils/xml/**/*.cpp", - "aws-cpp-sdk-kinesis/include/**/*.h", - "aws-cpp-sdk-kinesis/source/**/*.cpp", - "aws-cpp-sdk-s3/include/**/*.h", - 
"aws-cpp-sdk-s3/source/**/*.cpp", - "aws-cpp-sdk-transfer/include/**/*.h", - "aws-cpp-sdk-transfer/source/**/*.cpp", + "aws-cpp-sdk-core/source/*.cpp", # AWS_SOURCE + "aws-cpp-sdk-core/source/external/tinyxml2/*.cpp", # AWS_TINYXML2_SOURCE + "aws-cpp-sdk-core/source/external/cjson/*.cpp", # CJSON_SOURCE + "aws-cpp-sdk-core/source/auth/*.cpp", # AWS_AUTH_SOURCE + "aws-cpp-sdk-core/source/client/*.cpp", # AWS_CLIENT_SOURCE + "aws-cpp-sdk-core/source/internal/*.cpp", # AWS_INTERNAL_SOURCE + "aws-cpp-sdk-core/source/aws/model/*.cpp", # AWS_MODEL_SOURCE + "aws-cpp-sdk-core/source/http/*.cpp", # HTTP_SOURCE + "aws-cpp-sdk-core/source/http/standard/*.cpp", # HTTP_STANDARD_SOURCE + "aws-cpp-sdk-core/source/config/*.cpp", # CONFIG_SOURCE + "aws-cpp-sdk-core/source/monitoring/*.cpp", # MONITORING_SOURCE + "aws-cpp-sdk-core/source/utils/*.cpp", # UTILS_SOURCE + "aws-cpp-sdk-core/source/utils/event/*.cpp", # UTILS_EVENT_SOURCE + "aws-cpp-sdk-core/source/utils/base64/*.cpp", # UTILS_BASE64_SOURCE + "aws-cpp-sdk-core/source/utils/crypto/*.cpp", # UTILS_CRYPTO_SOURCE + "aws-cpp-sdk-core/source/utils/json/*.cpp", # UTILS_JSON_SOURCE + "aws-cpp-sdk-core/source/utils/threading/*.cpp", # UTILS_THREADING_SOURCE + "aws-cpp-sdk-core/source/utils/xml/*.cpp", # UTILS_XML_SOURCE + "aws-cpp-sdk-core/source/utils/logging/*.cpp", # UTILS_LOGGING_SOURCE + "aws-cpp-sdk-core/source/utils/memory/*.cpp", # UTILS_MEMORY_SOURCE + "aws-cpp-sdk-core/source/utils/memory/stl/*.cpp", # UTILS_MEMORY_STL_SOURCE + "aws-cpp-sdk-core/source/utils/stream/*.cpp", # UTILS_STREAM_SOURCE + "aws-cpp-sdk-core/source/utils/crypto/factory/*.cpp", # UTILS_CRYPTO_FACTORY_SOURCE + "aws-cpp-sdk-core/source/http/curl/*.cpp", # HTTP_CURL_CLIENT_SOURCE + "aws-cpp-sdk-core/source/utils/crypto/openssl/*.cpp", # UTILS_CRYPTO_OPENSSL_SOURCE ]) + select({ "@bazel_tools//src/conditions:windows": glob([ - "aws-cpp-sdk-core/source/http/windows/*.cpp", - "aws-cpp-sdk-core/source/net/windows/*.cpp", - "aws-cpp-sdk-core/source/platform/windows/*.cpp", + "aws-cpp-sdk-core/source/net/windows/*.cpp", # NET_SOURCE + "aws-cpp-sdk-core/source/platform/windows/*.cpp", # PLATFORM_WINDOWS_SOURCE ]), "//conditions:default": glob([ - "aws-cpp-sdk-core/source/net/linux-shared/*.cpp", - "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", + "aws-cpp-sdk-core/source/net/linux-shared/*.cpp", # NET_SOURCE + "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", # PLATFORM_LINUX_SHARED_SOURCE ]), }), hdrs = [ "aws-cpp-sdk-core/include/aws/core/SDKConfig.h", - ], + ] + glob([ + "aws-cpp-sdk-core/include/aws/core/*.h", # AWS_HEADERS + "aws-cpp-sdk-core/include/aws/core/auth/*.h", # AWS_AUTH_HEADERS + "aws-cpp-sdk-core/include/aws/core/client/*.h", # AWS_CLIENT_HEADERS + "aws-cpp-sdk-core/include/aws/core/internal/*.h", # AWS_INTERNAL_HEADERS + "aws-cpp-sdk-core/include/aws/core/net/*.h", # NET_HEADERS + "aws-cpp-sdk-core/include/aws/core/http/*.h", # HTTP_HEADERS + "aws-cpp-sdk-core/include/aws/core/http/standard/*.h", # HTTP_STANDARD_HEADERS + "aws-cpp-sdk-core/include/aws/core/config/*.h", # CONFIG_HEADERS + "aws-cpp-sdk-core/include/aws/core/monitoring/*.h", # MONITORING_HEADERS + "aws-cpp-sdk-core/include/aws/core/platform/*.h", # PLATFORM_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/*.h", # UTILS_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/event/*.h", # UTILS_EVENT_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/base64/*.h", # UTILS_BASE64_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/crypto/*.h", # UTILS_CRYPTO_HEADERS + 
"aws-cpp-sdk-core/include/aws/core/utils/json/*.h", # UTILS_JSON_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/threading/*.h", # UTILS_THREADING_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/xml/*.h", # UTILS_XML_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/memory/*.h", # UTILS_MEMORY_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/memory/stl/*.h", # UTILS_STL_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/logging/*.h", # UTILS_LOGGING_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/ratelimiter/*.h", # UTILS_RATE_LIMITER_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/stream/*.h", # UTILS_STREAM_HEADERS + "aws-cpp-sdk-core/include/aws/core/external/cjson/*.h", # CJSON_HEADERS + "aws-cpp-sdk-core/include/aws/core/external/tinyxml2/*.h", # TINYXML2_HEADERS + "aws-cpp-sdk-core/include/aws/core/http/curl/*.h", # HTTP_CURL_CLIENT_HEADERS + "aws-cpp-sdk-core/include/aws/core/utils/crypto/openssl/*.h", # UTILS_CRYPTO_OPENSSL_HEADERS + ]), defines = [ 'AWS_SDK_VERSION_STRING=\\"1.7.366\\"', "AWS_SDK_VERSION_MAJOR=1", @@ -72,15 +94,7 @@ cc_library( }), includes = [ "aws-cpp-sdk-core/include", - "aws-cpp-sdk-kinesis/include", - "aws-cpp-sdk-s3/include", - "aws-cpp-sdk-transfer/include", - ] + select({ - "@bazel_tools//src/conditions:windows": [ - "aws-cpp-sdk-core/include/aws/core/platform/refs", - ], - "//conditions:default": [], - }), + ], linkopts = select({ "@bazel_tools//src/conditions:windows": [ "-DEFAULTLIB:userenv.lib", @@ -89,14 +103,64 @@ cc_library( "//conditions:default": [], }), deps = [ - "@aws-c-common", "@aws-c-event-stream", - "@aws-checksums", - "@boringssl//:crypto", "@curl", ], ) +cc_library( + name = "s3", + srcs = glob([ + "aws-cpp-sdk-s3/source/*.cpp", # AWS_S3_SOURCE + "aws-cpp-sdk-s3/source/model/*.cpp", # AWS_S3_MODEL_SOURCE + ]), + hdrs = glob([ + "aws-cpp-sdk-s3/include/aws/s3/*.h", # AWS_S3_HEADERS + "aws-cpp-sdk-s3/include/aws/s3/model/*.h", # AWS_S3_MODEL_HEADERS + ]), + includes = [ + "aws-cpp-sdk-s3/include", + ], + deps = [ + ":core", + ], +) + +cc_library( + name = "transfer", + srcs = glob([ + "aws-cpp-sdk-transfer/source/transfer/*.cpp", # TRANSFER_SOURCE + ]), + hdrs = glob([ + "aws-cpp-sdk-transfer/include/aws/transfer/*.h", # TRANSFER_HEADERS + ]), + includes = [ + "aws-cpp-sdk-transfer/include", + ], + deps = [ + ":core", + ":s3", + ], +) + +cc_library( + name = "kinesis", + srcs = glob([ + "aws-cpp-sdk-kinesis/source/*.cpp", # AWS_KINESIS_SOURCE + "aws-cpp-sdk-kinesis/source/model/*.cpp", # AWS_KINESIS_MODEL_SOURCE + ]), + hdrs = glob([ + "aws-cpp-sdk-kinesis/include/aws/kinesis/*.h", # AWS_KINESIS_HEADERS + "aws-cpp-sdk-kinesis/include/aws/kinesis/model/*.h", # AWS_KINESIS_MODEL_HEADERS + ]), + includes = [ + "aws-cpp-sdk-kinesis/include", + ], + deps = [ + ":core", + ], +) + genrule( name = "SDKConfig_h", outs = [ diff --git a/third_party/azure.BUILD b/third_party/azure.BUILD index 4c11b5864..b1657b284 100644 --- a/third_party/azure.BUILD +++ b/third_party/azure.BUILD @@ -16,7 +16,6 @@ cc_library( ]), hdrs = [], defines = [ - "azure_storage_lite_EXPORTS", "USE_OPENSSL", ] + select({ "@bazel_tools//src/conditions:windows": [ diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 0bcbee9d1..0fe66bf9d 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -14,7 +14,9 @@ cc_library( "lib/vtls/*.h", "lib/vauth/*.h", "lib/vauth/*.c", - ]), + ]) + [ + "lib/vssh/ssh.h", + ], hdrs = glob([ "include/curl/*.h", ]) + [ @@ -110,7 +112,7 @@ genrule( "# define HAVE_SYS_FILIO_H 1", "# define HAVE_SYS_SOCKIO_H 1", "# 
define OS \"x86_64-apple-darwin15.5.0\"", - "# define USE_DARWINSSL 1", + "# define USE_SECTRANSP 1", "#else", "# define CURL_CA_BUNDLE \"/etc/ssl/certs/ca-certificates.crt\"", "# define GETSERVBYPORT_R_ARGS 6", diff --git a/third_party/parquet/parquet_types.cpp b/third_party/parquet/parquet_types.cpp deleted file mode 100644 index 327f75a1b..000000000 --- a/third_party/parquet/parquet_types.cpp +++ /dev/null @@ -1,7348 +0,0 @@ -/** - * Autogenerated by Thrift Compiler (0.12.0) - * - * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING - * @generated - */ -#include "parquet_types.h" - -#include -#include - -#include - -namespace parquet { namespace format { - -int _kTypeValues[] = { - Type::BOOLEAN, - Type::INT32, - Type::INT64, - Type::INT96, - Type::FLOAT, - Type::DOUBLE, - Type::BYTE_ARRAY, - Type::FIXED_LEN_BYTE_ARRAY -}; -const char* _kTypeNames[] = { - "BOOLEAN", - "INT32", - "INT64", - "INT96", - "FLOAT", - "DOUBLE", - "BYTE_ARRAY", - "FIXED_LEN_BYTE_ARRAY" -}; -const std::map _Type_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(8, _kTypeValues, _kTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const Type::type& val) { - std::map::const_iterator it = _Type_VALUES_TO_NAMES.find(val); - if (it != _Type_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - -int _kConvertedTypeValues[] = { - ConvertedType::UTF8, - ConvertedType::MAP, - ConvertedType::MAP_KEY_VALUE, - ConvertedType::LIST, - ConvertedType::ENUM, - ConvertedType::DECIMAL, - ConvertedType::DATE, - ConvertedType::TIME_MILLIS, - ConvertedType::TIME_MICROS, - ConvertedType::TIMESTAMP_MILLIS, - ConvertedType::TIMESTAMP_MICROS, - ConvertedType::UINT_8, - ConvertedType::UINT_16, - ConvertedType::UINT_32, - ConvertedType::UINT_64, - ConvertedType::INT_8, - ConvertedType::INT_16, - ConvertedType::INT_32, - ConvertedType::INT_64, - ConvertedType::JSON, - ConvertedType::BSON, - ConvertedType::INTERVAL -}; -const char* _kConvertedTypeNames[] = { - "UTF8", - "MAP", - "MAP_KEY_VALUE", - "LIST", - "ENUM", - "DECIMAL", - "DATE", - "TIME_MILLIS", - "TIME_MICROS", - "TIMESTAMP_MILLIS", - "TIMESTAMP_MICROS", - "UINT_8", - "UINT_16", - "UINT_32", - "UINT_64", - "INT_8", - "INT_16", - "INT_32", - "INT_64", - "JSON", - "BSON", - "INTERVAL" -}; -const std::map _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(22, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val) { - std::map::const_iterator it = _ConvertedType_VALUES_TO_NAMES.find(val); - if (it != _ConvertedType_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - -int _kFieldRepetitionTypeValues[] = { - FieldRepetitionType::REQUIRED, - FieldRepetitionType::OPTIONAL, - FieldRepetitionType::REPEATED -}; -const char* _kFieldRepetitionTypeNames[] = { - "REQUIRED", - "OPTIONAL", - "REPEATED" -}; -const std::map _FieldRepetitionType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(3, _kFieldRepetitionTypeValues, _kFieldRepetitionTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const FieldRepetitionType::type& val) { - std::map::const_iterator it = _FieldRepetitionType_VALUES_TO_NAMES.find(val); - if (it != _FieldRepetitionType_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; 
-} - -int _kEncodingValues[] = { - Encoding::PLAIN, - Encoding::PLAIN_DICTIONARY, - Encoding::RLE, - Encoding::BIT_PACKED, - Encoding::DELTA_BINARY_PACKED, - Encoding::DELTA_LENGTH_BYTE_ARRAY, - Encoding::DELTA_BYTE_ARRAY, - Encoding::RLE_DICTIONARY, - Encoding::BYTE_STREAM_SPLIT -}; -const char* _kEncodingNames[] = { - "PLAIN", - "PLAIN_DICTIONARY", - "RLE", - "BIT_PACKED", - "DELTA_BINARY_PACKED", - "DELTA_LENGTH_BYTE_ARRAY", - "DELTA_BYTE_ARRAY", - "RLE_DICTIONARY", - "BYTE_STREAM_SPLIT" -}; -const std::map _Encoding_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(9, _kEncodingValues, _kEncodingNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const Encoding::type& val) { - std::map::const_iterator it = _Encoding_VALUES_TO_NAMES.find(val); - if (it != _Encoding_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - -int _kCompressionCodecValues[] = { - CompressionCodec::UNCOMPRESSED, - CompressionCodec::SNAPPY, - CompressionCodec::GZIP, - CompressionCodec::LZO, - CompressionCodec::BROTLI, - CompressionCodec::LZ4, - CompressionCodec::ZSTD -}; -const char* _kCompressionCodecNames[] = { - "UNCOMPRESSED", - "SNAPPY", - "GZIP", - "LZO", - "BROTLI", - "LZ4", - "ZSTD" -}; -const std::map _CompressionCodec_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(7, _kCompressionCodecValues, _kCompressionCodecNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const CompressionCodec::type& val) { - std::map::const_iterator it = _CompressionCodec_VALUES_TO_NAMES.find(val); - if (it != _CompressionCodec_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - -int _kPageTypeValues[] = { - PageType::DATA_PAGE, - PageType::INDEX_PAGE, - PageType::DICTIONARY_PAGE, - PageType::DATA_PAGE_V2 -}; -const char* _kPageTypeNames[] = { - "DATA_PAGE", - "INDEX_PAGE", - "DICTIONARY_PAGE", - "DATA_PAGE_V2" -}; -const std::map _PageType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(4, _kPageTypeValues, _kPageTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const PageType::type& val) { - std::map::const_iterator it = _PageType_VALUES_TO_NAMES.find(val); - if (it != _PageType_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - -int _kBoundaryOrderValues[] = { - BoundaryOrder::UNORDERED, - BoundaryOrder::ASCENDING, - BoundaryOrder::DESCENDING -}; -const char* _kBoundaryOrderNames[] = { - "UNORDERED", - "ASCENDING", - "DESCENDING" -}; -const std::map _BoundaryOrder_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(3, _kBoundaryOrderValues, _kBoundaryOrderNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL)); - -std::ostream& operator<<(std::ostream& out, const BoundaryOrder::type& val) { - std::map::const_iterator it = _BoundaryOrder_VALUES_TO_NAMES.find(val); - if (it != _BoundaryOrder_VALUES_TO_NAMES.end()) { - out << it->second; - } else { - out << static_cast(val); - } - return out; -} - - -Statistics::~Statistics() throw() { -} - - -void Statistics::__set_max(const std::string& val) { - this->max = val; -__isset.max = true; -} - -void Statistics::__set_min(const std::string& val) { - this->min = val; -__isset.min = true; -} - -void Statistics::__set_null_count(const int64_t val) { - this->null_count = val; -__isset.null_count = true; -} - -void Statistics::__set_distinct_count(const int64_t val) { - 
this->distinct_count = val; -__isset.distinct_count = true; -} - -void Statistics::__set_max_value(const std::string& val) { - this->max_value = val; -__isset.max_value = true; -} - -void Statistics::__set_min_value(const std::string& val) { - this->min_value = val; -__isset.min_value = true; -} -std::ostream& operator<<(std::ostream& out, const Statistics& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t Statistics::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->max); - this->__isset.max = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->min); - this->__isset.min = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->null_count); - this->__isset.null_count = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->distinct_count); - this->__isset.distinct_count = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->max_value); - this->__isset.max_value = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->min_value); - this->__isset.min_value = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t Statistics::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("Statistics"); - - if (this->__isset.max) { - xfer += oprot->writeFieldBegin("max", ::apache::thrift::protocol::T_STRING, 1); - xfer += oprot->writeBinary(this->max); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.min) { - xfer += oprot->writeFieldBegin("min", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeBinary(this->min); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.null_count) { - xfer += oprot->writeFieldBegin("null_count", ::apache::thrift::protocol::T_I64, 3); - xfer += oprot->writeI64(this->null_count); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.distinct_count) { - xfer += oprot->writeFieldBegin("distinct_count", ::apache::thrift::protocol::T_I64, 4); - xfer += oprot->writeI64(this->distinct_count); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.max_value) { - xfer += oprot->writeFieldBegin("max_value", ::apache::thrift::protocol::T_STRING, 5); - xfer += oprot->writeBinary(this->max_value); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.min_value) { - xfer += oprot->writeFieldBegin("min_value", 
::apache::thrift::protocol::T_STRING, 6); - xfer += oprot->writeBinary(this->min_value); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(Statistics &a, Statistics &b) { - using ::std::swap; - swap(a.max, b.max); - swap(a.min, b.min); - swap(a.null_count, b.null_count); - swap(a.distinct_count, b.distinct_count); - swap(a.max_value, b.max_value); - swap(a.min_value, b.min_value); - swap(a.__isset, b.__isset); -} - -Statistics::Statistics(const Statistics& other0) { - max = other0.max; - min = other0.min; - null_count = other0.null_count; - distinct_count = other0.distinct_count; - max_value = other0.max_value; - min_value = other0.min_value; - __isset = other0.__isset; -} -Statistics& Statistics::operator=(const Statistics& other1) { - max = other1.max; - min = other1.min; - null_count = other1.null_count; - distinct_count = other1.distinct_count; - max_value = other1.max_value; - min_value = other1.min_value; - __isset = other1.__isset; - return *this; -} -void Statistics::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "Statistics("; - out << "max="; (__isset.max ? (out << to_string(max)) : (out << "")); - out << ", " << "min="; (__isset.min ? (out << to_string(min)) : (out << "")); - out << ", " << "null_count="; (__isset.null_count ? (out << to_string(null_count)) : (out << "")); - out << ", " << "distinct_count="; (__isset.distinct_count ? (out << to_string(distinct_count)) : (out << "")); - out << ", " << "max_value="; (__isset.max_value ? (out << to_string(max_value)) : (out << "")); - out << ", " << "min_value="; (__isset.min_value ? (out << to_string(min_value)) : (out << "")); - out << ")"; -} - - -StringType::~StringType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const StringType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t StringType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t StringType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("StringType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(StringType &a, StringType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -StringType::StringType(const StringType& other2) { - (void) other2; -} -StringType& StringType::operator=(const StringType& other3) { - (void) other3; - return *this; -} -void StringType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "StringType("; - out << ")"; -} - - -UUIDType::~UUIDType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const UUIDType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t UUIDType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; 
- std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t UUIDType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("UUIDType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(UUIDType &a, UUIDType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -UUIDType::UUIDType(const UUIDType& other4) { - (void) other4; -} -UUIDType& UUIDType::operator=(const UUIDType& other5) { - (void) other5; - return *this; -} -void UUIDType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "UUIDType("; - out << ")"; -} - - -MapType::~MapType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const MapType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t MapType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t MapType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("MapType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(MapType &a, MapType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -MapType::MapType(const MapType& other6) { - (void) other6; -} -MapType& MapType::operator=(const MapType& other7) { - (void) other7; - return *this; -} -void MapType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "MapType("; - out << ")"; -} - - -ListType::~ListType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const ListType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ListType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t ListType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ListType"); - - 
xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ListType &a, ListType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -ListType::ListType(const ListType& other8) { - (void) other8; -} -ListType& ListType::operator=(const ListType& other9) { - (void) other9; - return *this; -} -void ListType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ListType("; - out << ")"; -} - - -EnumType::~EnumType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const EnumType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t EnumType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t EnumType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("EnumType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(EnumType &a, EnumType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -EnumType::EnumType(const EnumType& other10) { - (void) other10; -} -EnumType& EnumType::operator=(const EnumType& other11) { - (void) other11; - return *this; -} -void EnumType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "EnumType("; - out << ")"; -} - - -DateType::~DateType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const DateType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t DateType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t DateType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("DateType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(DateType &a, DateType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -DateType::DateType(const DateType& other12) { - (void) other12; -} -DateType& DateType::operator=(const DateType& other13) { - (void) other13; - return *this; -} -void DateType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "DateType("; - out << ")"; -} - - -NullType::~NullType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const NullType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t 
NullType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t NullType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("NullType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(NullType &a, NullType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -NullType::NullType(const NullType& other14) { - (void) other14; -} -NullType& NullType::operator=(const NullType& other15) { - (void) other15; - return *this; -} -void NullType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "NullType("; - out << ")"; -} - - -DecimalType::~DecimalType() throw() { -} - - -void DecimalType::__set_scale(const int32_t val) { - this->scale = val; -} - -void DecimalType::__set_precision(const int32_t val) { - this->precision = val; -} -std::ostream& operator<<(std::ostream& out, const DecimalType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t DecimalType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_scale = false; - bool isset_precision = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->scale); - isset_scale = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->precision); - isset_precision = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_scale) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_precision) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t DecimalType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("DecimalType"); - - xfer += oprot->writeFieldBegin("scale", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->scale); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("precision", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32(this->precision); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(DecimalType &a, DecimalType &b) { - 
using ::std::swap; - swap(a.scale, b.scale); - swap(a.precision, b.precision); -} - -DecimalType::DecimalType(const DecimalType& other16) { - scale = other16.scale; - precision = other16.precision; -} -DecimalType& DecimalType::operator=(const DecimalType& other17) { - scale = other17.scale; - precision = other17.precision; - return *this; -} -void DecimalType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "DecimalType("; - out << "scale=" << to_string(scale); - out << ", " << "precision=" << to_string(precision); - out << ")"; -} - - -MilliSeconds::~MilliSeconds() throw() { -} - -std::ostream& operator<<(std::ostream& out, const MilliSeconds& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t MilliSeconds::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t MilliSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("MilliSeconds"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(MilliSeconds &a, MilliSeconds &b) { - using ::std::swap; - (void) a; - (void) b; -} - -MilliSeconds::MilliSeconds(const MilliSeconds& other18) { - (void) other18; -} -MilliSeconds& MilliSeconds::operator=(const MilliSeconds& other19) { - (void) other19; - return *this; -} -void MilliSeconds::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "MilliSeconds("; - out << ")"; -} - - -MicroSeconds::~MicroSeconds() throw() { -} - -std::ostream& operator<<(std::ostream& out, const MicroSeconds& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t MicroSeconds::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t MicroSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("MicroSeconds"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(MicroSeconds &a, MicroSeconds &b) { - using ::std::swap; - (void) a; - (void) b; -} - -MicroSeconds::MicroSeconds(const MicroSeconds& other20) { - (void) other20; -} -MicroSeconds& MicroSeconds::operator=(const MicroSeconds& other21) { - (void) other21; - return *this; -} -void MicroSeconds::printTo(std::ostream& out) const { - using 
::apache::thrift::to_string; - out << "MicroSeconds("; - out << ")"; -} - - -NanoSeconds::~NanoSeconds() throw() { -} - -std::ostream& operator<<(std::ostream& out, const NanoSeconds& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t NanoSeconds::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t NanoSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("NanoSeconds"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(NanoSeconds &a, NanoSeconds &b) { - using ::std::swap; - (void) a; - (void) b; -} - -NanoSeconds::NanoSeconds(const NanoSeconds& other22) { - (void) other22; -} -NanoSeconds& NanoSeconds::operator=(const NanoSeconds& other23) { - (void) other23; - return *this; -} -void NanoSeconds::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "NanoSeconds("; - out << ")"; -} - - -TimeUnit::~TimeUnit() throw() { -} - - -void TimeUnit::__set_MILLIS(const MilliSeconds& val) { - this->MILLIS = val; -__isset.MILLIS = true; -} - -void TimeUnit::__set_MICROS(const MicroSeconds& val) { - this->MICROS = val; -__isset.MICROS = true; -} - -void TimeUnit::__set_NANOS(const NanoSeconds& val) { - this->NANOS = val; -__isset.NANOS = true; -} -std::ostream& operator<<(std::ostream& out, const TimeUnit& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t TimeUnit::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->MILLIS.read(iprot); - this->__isset.MILLIS = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->MICROS.read(iprot); - this->__isset.MICROS = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->NANOS.read(iprot); - this->__isset.NANOS = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t TimeUnit::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("TimeUnit"); - - if (this->__isset.MILLIS) { - xfer += oprot->writeFieldBegin("MILLIS", 
::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->MILLIS.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.MICROS) { - xfer += oprot->writeFieldBegin("MICROS", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->MICROS.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.NANOS) { - xfer += oprot->writeFieldBegin("NANOS", ::apache::thrift::protocol::T_STRUCT, 3); - xfer += this->NANOS.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(TimeUnit &a, TimeUnit &b) { - using ::std::swap; - swap(a.MILLIS, b.MILLIS); - swap(a.MICROS, b.MICROS); - swap(a.NANOS, b.NANOS); - swap(a.__isset, b.__isset); -} - -TimeUnit::TimeUnit(const TimeUnit& other24) { - MILLIS = other24.MILLIS; - MICROS = other24.MICROS; - NANOS = other24.NANOS; - __isset = other24.__isset; -} -TimeUnit& TimeUnit::operator=(const TimeUnit& other25) { - MILLIS = other25.MILLIS; - MICROS = other25.MICROS; - NANOS = other25.NANOS; - __isset = other25.__isset; - return *this; -} -void TimeUnit::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "TimeUnit("; - out << "MILLIS="; (__isset.MILLIS ? (out << to_string(MILLIS)) : (out << "")); - out << ", " << "MICROS="; (__isset.MICROS ? (out << to_string(MICROS)) : (out << "")); - out << ", " << "NANOS="; (__isset.NANOS ? (out << to_string(NANOS)) : (out << "")); - out << ")"; -} - - -TimestampType::~TimestampType() throw() { -} - - -void TimestampType::__set_isAdjustedToUTC(const bool val) { - this->isAdjustedToUTC = val; -} - -void TimestampType::__set_unit(const TimeUnit& val) { - this->unit = val; -} -std::ostream& operator<<(std::ostream& out, const TimestampType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t TimestampType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_isAdjustedToUTC = false; - bool isset_unit = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->isAdjustedToUTC); - isset_isAdjustedToUTC = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->unit.read(iprot); - isset_unit = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_isAdjustedToUTC) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_unit) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t TimestampType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("TimestampType"); - - xfer += oprot->writeFieldBegin("isAdjustedToUTC", ::apache::thrift::protocol::T_BOOL, 1); - xfer += oprot->writeBool(this->isAdjustedToUTC); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("unit", 
::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->unit.write(oprot); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(TimestampType &a, TimestampType &b) { - using ::std::swap; - swap(a.isAdjustedToUTC, b.isAdjustedToUTC); - swap(a.unit, b.unit); -} - -TimestampType::TimestampType(const TimestampType& other26) { - isAdjustedToUTC = other26.isAdjustedToUTC; - unit = other26.unit; -} -TimestampType& TimestampType::operator=(const TimestampType& other27) { - isAdjustedToUTC = other27.isAdjustedToUTC; - unit = other27.unit; - return *this; -} -void TimestampType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "TimestampType("; - out << "isAdjustedToUTC=" << to_string(isAdjustedToUTC); - out << ", " << "unit=" << to_string(unit); - out << ")"; -} - - -TimeType::~TimeType() throw() { -} - - -void TimeType::__set_isAdjustedToUTC(const bool val) { - this->isAdjustedToUTC = val; -} - -void TimeType::__set_unit(const TimeUnit& val) { - this->unit = val; -} -std::ostream& operator<<(std::ostream& out, const TimeType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t TimeType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_isAdjustedToUTC = false; - bool isset_unit = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->isAdjustedToUTC); - isset_isAdjustedToUTC = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->unit.read(iprot); - isset_unit = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_isAdjustedToUTC) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_unit) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t TimeType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("TimeType"); - - xfer += oprot->writeFieldBegin("isAdjustedToUTC", ::apache::thrift::protocol::T_BOOL, 1); - xfer += oprot->writeBool(this->isAdjustedToUTC); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("unit", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->unit.write(oprot); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(TimeType &a, TimeType &b) { - using ::std::swap; - swap(a.isAdjustedToUTC, b.isAdjustedToUTC); - swap(a.unit, b.unit); -} - -TimeType::TimeType(const TimeType& other28) { - isAdjustedToUTC = other28.isAdjustedToUTC; - unit = other28.unit; -} -TimeType& TimeType::operator=(const TimeType& other29) { - isAdjustedToUTC = other29.isAdjustedToUTC; - unit = other29.unit; - return *this; -} -void TimeType::printTo(std::ostream& out) 
const { - using ::apache::thrift::to_string; - out << "TimeType("; - out << "isAdjustedToUTC=" << to_string(isAdjustedToUTC); - out << ", " << "unit=" << to_string(unit); - out << ")"; -} - - -IntType::~IntType() throw() { -} - - -void IntType::__set_bitWidth(const int8_t val) { - this->bitWidth = val; -} - -void IntType::__set_isSigned(const bool val) { - this->isSigned = val; -} -std::ostream& operator<<(std::ostream& out, const IntType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t IntType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_bitWidth = false; - bool isset_isSigned = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_BYTE) { - xfer += iprot->readByte(this->bitWidth); - isset_bitWidth = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->isSigned); - isset_isSigned = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_bitWidth) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_isSigned) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t IntType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("IntType"); - - xfer += oprot->writeFieldBegin("bitWidth", ::apache::thrift::protocol::T_BYTE, 1); - xfer += oprot->writeByte(this->bitWidth); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("isSigned", ::apache::thrift::protocol::T_BOOL, 2); - xfer += oprot->writeBool(this->isSigned); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(IntType &a, IntType &b) { - using ::std::swap; - swap(a.bitWidth, b.bitWidth); - swap(a.isSigned, b.isSigned); -} - -IntType::IntType(const IntType& other30) { - bitWidth = other30.bitWidth; - isSigned = other30.isSigned; -} -IntType& IntType::operator=(const IntType& other31) { - bitWidth = other31.bitWidth; - isSigned = other31.isSigned; - return *this; -} -void IntType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "IntType("; - out << "bitWidth=" << to_string(bitWidth); - out << ", " << "isSigned=" << to_string(isSigned); - out << ")"; -} - - -JsonType::~JsonType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const JsonType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t JsonType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, 
fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t JsonType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("JsonType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(JsonType &a, JsonType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -JsonType::JsonType(const JsonType& other32) { - (void) other32; -} -JsonType& JsonType::operator=(const JsonType& other33) { - (void) other33; - return *this; -} -void JsonType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "JsonType("; - out << ")"; -} - - -BsonType::~BsonType() throw() { -} - -std::ostream& operator<<(std::ostream& out, const BsonType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t BsonType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t BsonType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("BsonType"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(BsonType &a, BsonType &b) { - using ::std::swap; - (void) a; - (void) b; -} - -BsonType::BsonType(const BsonType& other34) { - (void) other34; -} -BsonType& BsonType::operator=(const BsonType& other35) { - (void) other35; - return *this; -} -void BsonType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "BsonType("; - out << ")"; -} - - -LogicalType::~LogicalType() throw() { -} - - -void LogicalType::__set_STRING(const StringType& val) { - this->STRING = val; -__isset.STRING = true; -} - -void LogicalType::__set_MAP(const MapType& val) { - this->MAP = val; -__isset.MAP = true; -} - -void LogicalType::__set_LIST(const ListType& val) { - this->LIST = val; -__isset.LIST = true; -} - -void LogicalType::__set_ENUM(const EnumType& val) { - this->ENUM = val; -__isset.ENUM = true; -} - -void LogicalType::__set_DECIMAL(const DecimalType& val) { - this->DECIMAL = val; -__isset.DECIMAL = true; -} - -void LogicalType::__set_DATE(const DateType& val) { - this->DATE = val; -__isset.DATE = true; -} - -void LogicalType::__set_TIME(const TimeType& val) { - this->TIME = val; -__isset.TIME = true; -} - -void LogicalType::__set_TIMESTAMP(const TimestampType& val) { - this->TIMESTAMP = val; -__isset.TIMESTAMP = true; -} - -void LogicalType::__set_INTEGER(const IntType& val) { - this->INTEGER = val; -__isset.INTEGER = true; -} - -void LogicalType::__set_UNKNOWN(const NullType& val) { - this->UNKNOWN = val; -__isset.UNKNOWN = true; -} - -void LogicalType::__set_JSON(const JsonType& val) { - this->JSON = val; -__isset.JSON = true; -} - -void 
LogicalType::__set_BSON(const BsonType& val) { - this->BSON = val; -__isset.BSON = true; -} - -void LogicalType::__set_UUID(const UUIDType& val) { - this->UUID = val; -__isset.UUID = true; -} -std::ostream& operator<<(std::ostream& out, const LogicalType& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t LogicalType::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->STRING.read(iprot); - this->__isset.STRING = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->MAP.read(iprot); - this->__isset.MAP = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->LIST.read(iprot); - this->__isset.LIST = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->ENUM.read(iprot); - this->__isset.ENUM = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->DECIMAL.read(iprot); - this->__isset.DECIMAL = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->DATE.read(iprot); - this->__isset.DATE = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->TIME.read(iprot); - this->__isset.TIME = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->TIMESTAMP.read(iprot); - this->__isset.TIMESTAMP = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 10: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->INTEGER.read(iprot); - this->__isset.INTEGER = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 11: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->UNKNOWN.read(iprot); - this->__isset.UNKNOWN = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 12: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->JSON.read(iprot); - this->__isset.JSON = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 13: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->BSON.read(iprot); - this->__isset.BSON = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 14: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->UUID.read(iprot); - this->__isset.UUID = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t LogicalType::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += 
oprot->writeStructBegin("LogicalType"); - - if (this->__isset.STRING) { - xfer += oprot->writeFieldBegin("STRING", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->STRING.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.MAP) { - xfer += oprot->writeFieldBegin("MAP", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->MAP.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.LIST) { - xfer += oprot->writeFieldBegin("LIST", ::apache::thrift::protocol::T_STRUCT, 3); - xfer += this->LIST.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.ENUM) { - xfer += oprot->writeFieldBegin("ENUM", ::apache::thrift::protocol::T_STRUCT, 4); - xfer += this->ENUM.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.DECIMAL) { - xfer += oprot->writeFieldBegin("DECIMAL", ::apache::thrift::protocol::T_STRUCT, 5); - xfer += this->DECIMAL.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.DATE) { - xfer += oprot->writeFieldBegin("DATE", ::apache::thrift::protocol::T_STRUCT, 6); - xfer += this->DATE.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.TIME) { - xfer += oprot->writeFieldBegin("TIME", ::apache::thrift::protocol::T_STRUCT, 7); - xfer += this->TIME.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.TIMESTAMP) { - xfer += oprot->writeFieldBegin("TIMESTAMP", ::apache::thrift::protocol::T_STRUCT, 8); - xfer += this->TIMESTAMP.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.INTEGER) { - xfer += oprot->writeFieldBegin("INTEGER", ::apache::thrift::protocol::T_STRUCT, 10); - xfer += this->INTEGER.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.UNKNOWN) { - xfer += oprot->writeFieldBegin("UNKNOWN", ::apache::thrift::protocol::T_STRUCT, 11); - xfer += this->UNKNOWN.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.JSON) { - xfer += oprot->writeFieldBegin("JSON", ::apache::thrift::protocol::T_STRUCT, 12); - xfer += this->JSON.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.BSON) { - xfer += oprot->writeFieldBegin("BSON", ::apache::thrift::protocol::T_STRUCT, 13); - xfer += this->BSON.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.UUID) { - xfer += oprot->writeFieldBegin("UUID", ::apache::thrift::protocol::T_STRUCT, 14); - xfer += this->UUID.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(LogicalType &a, LogicalType &b) { - using ::std::swap; - swap(a.STRING, b.STRING); - swap(a.MAP, b.MAP); - swap(a.LIST, b.LIST); - swap(a.ENUM, b.ENUM); - swap(a.DECIMAL, b.DECIMAL); - swap(a.DATE, b.DATE); - swap(a.TIME, b.TIME); - swap(a.TIMESTAMP, b.TIMESTAMP); - swap(a.INTEGER, b.INTEGER); - swap(a.UNKNOWN, b.UNKNOWN); - swap(a.JSON, b.JSON); - swap(a.BSON, b.BSON); - swap(a.UUID, b.UUID); - swap(a.__isset, b.__isset); -} - -LogicalType::LogicalType(const LogicalType& other36) { - STRING = other36.STRING; - MAP = other36.MAP; - LIST = other36.LIST; - ENUM = other36.ENUM; - DECIMAL = other36.DECIMAL; - DATE = other36.DATE; - TIME = other36.TIME; - TIMESTAMP = other36.TIMESTAMP; - INTEGER = other36.INTEGER; - UNKNOWN = other36.UNKNOWN; - JSON = other36.JSON; - BSON = other36.BSON; - UUID = other36.UUID; - __isset = other36.__isset; -} -LogicalType& LogicalType::operator=(const LogicalType& other37) { - STRING = other37.STRING; - MAP = other37.MAP; - 
LIST = other37.LIST; - ENUM = other37.ENUM; - DECIMAL = other37.DECIMAL; - DATE = other37.DATE; - TIME = other37.TIME; - TIMESTAMP = other37.TIMESTAMP; - INTEGER = other37.INTEGER; - UNKNOWN = other37.UNKNOWN; - JSON = other37.JSON; - BSON = other37.BSON; - UUID = other37.UUID; - __isset = other37.__isset; - return *this; -} -void LogicalType::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "LogicalType("; - out << "STRING="; (__isset.STRING ? (out << to_string(STRING)) : (out << "")); - out << ", " << "MAP="; (__isset.MAP ? (out << to_string(MAP)) : (out << "")); - out << ", " << "LIST="; (__isset.LIST ? (out << to_string(LIST)) : (out << "")); - out << ", " << "ENUM="; (__isset.ENUM ? (out << to_string(ENUM)) : (out << "")); - out << ", " << "DECIMAL="; (__isset.DECIMAL ? (out << to_string(DECIMAL)) : (out << "")); - out << ", " << "DATE="; (__isset.DATE ? (out << to_string(DATE)) : (out << "")); - out << ", " << "TIME="; (__isset.TIME ? (out << to_string(TIME)) : (out << "")); - out << ", " << "TIMESTAMP="; (__isset.TIMESTAMP ? (out << to_string(TIMESTAMP)) : (out << "")); - out << ", " << "INTEGER="; (__isset.INTEGER ? (out << to_string(INTEGER)) : (out << "")); - out << ", " << "UNKNOWN="; (__isset.UNKNOWN ? (out << to_string(UNKNOWN)) : (out << "")); - out << ", " << "JSON="; (__isset.JSON ? (out << to_string(JSON)) : (out << "")); - out << ", " << "BSON="; (__isset.BSON ? (out << to_string(BSON)) : (out << "")); - out << ", " << "UUID="; (__isset.UUID ? (out << to_string(UUID)) : (out << "")); - out << ")"; -} - - -SchemaElement::~SchemaElement() throw() { -} - - -void SchemaElement::__set_type(const Type::type val) { - this->type = val; -__isset.type = true; -} - -void SchemaElement::__set_type_length(const int32_t val) { - this->type_length = val; -__isset.type_length = true; -} - -void SchemaElement::__set_repetition_type(const FieldRepetitionType::type val) { - this->repetition_type = val; -__isset.repetition_type = true; -} - -void SchemaElement::__set_name(const std::string& val) { - this->name = val; -} - -void SchemaElement::__set_num_children(const int32_t val) { - this->num_children = val; -__isset.num_children = true; -} - -void SchemaElement::__set_converted_type(const ConvertedType::type val) { - this->converted_type = val; -__isset.converted_type = true; -} - -void SchemaElement::__set_scale(const int32_t val) { - this->scale = val; -__isset.scale = true; -} - -void SchemaElement::__set_precision(const int32_t val) { - this->precision = val; -__isset.precision = true; -} - -void SchemaElement::__set_field_id(const int32_t val) { - this->field_id = val; -__isset.field_id = true; -} - -void SchemaElement::__set_logicalType(const LogicalType& val) { - this->logicalType = val; -__isset.logicalType = true; -} -std::ostream& operator<<(std::ostream& out, const SchemaElement& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_name = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast38; - xfer += 
iprot->readI32(ecast38); - this->type = (Type::type)ecast38; - this->__isset.type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->type_length); - this->__isset.type_length = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast39; - xfer += iprot->readI32(ecast39); - this->repetition_type = (FieldRepetitionType::type)ecast39; - this->__isset.repetition_type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readString(this->name); - isset_name = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_children); - this->__isset.num_children = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast40; - xfer += iprot->readI32(ecast40); - this->converted_type = (ConvertedType::type)ecast40; - this->__isset.converted_type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->scale); - this->__isset.scale = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->precision); - this->__isset.precision = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 9: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->field_id); - this->__isset.field_id = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 10: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->logicalType.read(iprot); - this->__isset.logicalType = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_name) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t SchemaElement::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("SchemaElement"); - - if (this->__isset.type) { - xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32((int32_t)this->type); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.type_length) { - xfer += oprot->writeFieldBegin("type_length", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32(this->type_length); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.repetition_type) { - xfer += oprot->writeFieldBegin("repetition_type", ::apache::thrift::protocol::T_I32, 3); - xfer += oprot->writeI32((int32_t)this->repetition_type); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldBegin("name", ::apache::thrift::protocol::T_STRING, 4); - xfer += oprot->writeString(this->name); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.num_children) { - xfer += oprot->writeFieldBegin("num_children", ::apache::thrift::protocol::T_I32, 5); - xfer += oprot->writeI32(this->num_children); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.converted_type) { - xfer += 
oprot->writeFieldBegin("converted_type", ::apache::thrift::protocol::T_I32, 6); - xfer += oprot->writeI32((int32_t)this->converted_type); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.scale) { - xfer += oprot->writeFieldBegin("scale", ::apache::thrift::protocol::T_I32, 7); - xfer += oprot->writeI32(this->scale); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.precision) { - xfer += oprot->writeFieldBegin("precision", ::apache::thrift::protocol::T_I32, 8); - xfer += oprot->writeI32(this->precision); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.field_id) { - xfer += oprot->writeFieldBegin("field_id", ::apache::thrift::protocol::T_I32, 9); - xfer += oprot->writeI32(this->field_id); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.logicalType) { - xfer += oprot->writeFieldBegin("logicalType", ::apache::thrift::protocol::T_STRUCT, 10); - xfer += this->logicalType.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(SchemaElement &a, SchemaElement &b) { - using ::std::swap; - swap(a.type, b.type); - swap(a.type_length, b.type_length); - swap(a.repetition_type, b.repetition_type); - swap(a.name, b.name); - swap(a.num_children, b.num_children); - swap(a.converted_type, b.converted_type); - swap(a.scale, b.scale); - swap(a.precision, b.precision); - swap(a.field_id, b.field_id); - swap(a.logicalType, b.logicalType); - swap(a.__isset, b.__isset); -} - -SchemaElement::SchemaElement(const SchemaElement& other41) { - type = other41.type; - type_length = other41.type_length; - repetition_type = other41.repetition_type; - name = other41.name; - num_children = other41.num_children; - converted_type = other41.converted_type; - scale = other41.scale; - precision = other41.precision; - field_id = other41.field_id; - logicalType = other41.logicalType; - __isset = other41.__isset; -} -SchemaElement& SchemaElement::operator=(const SchemaElement& other42) { - type = other42.type; - type_length = other42.type_length; - repetition_type = other42.repetition_type; - name = other42.name; - num_children = other42.num_children; - converted_type = other42.converted_type; - scale = other42.scale; - precision = other42.precision; - field_id = other42.field_id; - logicalType = other42.logicalType; - __isset = other42.__isset; - return *this; -} -void SchemaElement::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "SchemaElement("; - out << "type="; (__isset.type ? (out << to_string(type)) : (out << "")); - out << ", " << "type_length="; (__isset.type_length ? (out << to_string(type_length)) : (out << "")); - out << ", " << "repetition_type="; (__isset.repetition_type ? (out << to_string(repetition_type)) : (out << "")); - out << ", " << "name=" << to_string(name); - out << ", " << "num_children="; (__isset.num_children ? (out << to_string(num_children)) : (out << "")); - out << ", " << "converted_type="; (__isset.converted_type ? (out << to_string(converted_type)) : (out << "")); - out << ", " << "scale="; (__isset.scale ? (out << to_string(scale)) : (out << "")); - out << ", " << "precision="; (__isset.precision ? (out << to_string(precision)) : (out << "")); - out << ", " << "field_id="; (__isset.field_id ? (out << to_string(field_id)) : (out << "")); - out << ", " << "logicalType="; (__isset.logicalType ? 
(out << to_string(logicalType)) : (out << "")); - out << ")"; -} - - -DataPageHeader::~DataPageHeader() throw() { -} - - -void DataPageHeader::__set_num_values(const int32_t val) { - this->num_values = val; -} - -void DataPageHeader::__set_encoding(const Encoding::type val) { - this->encoding = val; -} - -void DataPageHeader::__set_definition_level_encoding(const Encoding::type val) { - this->definition_level_encoding = val; -} - -void DataPageHeader::__set_repetition_level_encoding(const Encoding::type val) { - this->repetition_level_encoding = val; -} - -void DataPageHeader::__set_statistics(const Statistics& val) { - this->statistics = val; -__isset.statistics = true; -} -std::ostream& operator<<(std::ostream& out, const DataPageHeader& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_num_values = false; - bool isset_encoding = false; - bool isset_definition_level_encoding = false; - bool isset_repetition_level_encoding = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_values); - isset_num_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast43; - xfer += iprot->readI32(ecast43); - this->encoding = (Encoding::type)ecast43; - isset_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast44; - xfer += iprot->readI32(ecast44); - this->definition_level_encoding = (Encoding::type)ecast44; - isset_definition_level_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast45; - xfer += iprot->readI32(ecast45); - this->repetition_level_encoding = (Encoding::type)ecast45; - isset_repetition_level_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->statistics.read(iprot); - this->__isset.statistics = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_num_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_definition_level_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_repetition_level_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t DataPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("DataPageHeader"); - - xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->num_values); - xfer += 
oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32((int32_t)this->encoding); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("definition_level_encoding", ::apache::thrift::protocol::T_I32, 3); - xfer += oprot->writeI32((int32_t)this->definition_level_encoding); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("repetition_level_encoding", ::apache::thrift::protocol::T_I32, 4); - xfer += oprot->writeI32((int32_t)this->repetition_level_encoding); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.statistics) { - xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 5); - xfer += this->statistics.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(DataPageHeader &a, DataPageHeader &b) { - using ::std::swap; - swap(a.num_values, b.num_values); - swap(a.encoding, b.encoding); - swap(a.definition_level_encoding, b.definition_level_encoding); - swap(a.repetition_level_encoding, b.repetition_level_encoding); - swap(a.statistics, b.statistics); - swap(a.__isset, b.__isset); -} - -DataPageHeader::DataPageHeader(const DataPageHeader& other46) { - num_values = other46.num_values; - encoding = other46.encoding; - definition_level_encoding = other46.definition_level_encoding; - repetition_level_encoding = other46.repetition_level_encoding; - statistics = other46.statistics; - __isset = other46.__isset; -} -DataPageHeader& DataPageHeader::operator=(const DataPageHeader& other47) { - num_values = other47.num_values; - encoding = other47.encoding; - definition_level_encoding = other47.definition_level_encoding; - repetition_level_encoding = other47.repetition_level_encoding; - statistics = other47.statistics; - __isset = other47.__isset; - return *this; -} -void DataPageHeader::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "DataPageHeader("; - out << "num_values=" << to_string(num_values); - out << ", " << "encoding=" << to_string(encoding); - out << ", " << "definition_level_encoding=" << to_string(definition_level_encoding); - out << ", " << "repetition_level_encoding=" << to_string(repetition_level_encoding); - out << ", " << "statistics="; (__isset.statistics ? 
(out << to_string(statistics)) : (out << "")); - out << ")"; -} - - -IndexPageHeader::~IndexPageHeader() throw() { -} - -std::ostream& operator<<(std::ostream& out, const IndexPageHeader& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t IndexPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t IndexPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("IndexPageHeader"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(IndexPageHeader &a, IndexPageHeader &b) { - using ::std::swap; - (void) a; - (void) b; -} - -IndexPageHeader::IndexPageHeader(const IndexPageHeader& other48) { - (void) other48; -} -IndexPageHeader& IndexPageHeader::operator=(const IndexPageHeader& other49) { - (void) other49; - return *this; -} -void IndexPageHeader::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "IndexPageHeader("; - out << ")"; -} - - -DictionaryPageHeader::~DictionaryPageHeader() throw() { -} - - -void DictionaryPageHeader::__set_num_values(const int32_t val) { - this->num_values = val; -} - -void DictionaryPageHeader::__set_encoding(const Encoding::type val) { - this->encoding = val; -} - -void DictionaryPageHeader::__set_is_sorted(const bool val) { - this->is_sorted = val; -__isset.is_sorted = true; -} -std::ostream& operator<<(std::ostream& out, const DictionaryPageHeader& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t DictionaryPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_num_values = false; - bool isset_encoding = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_values); - isset_num_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast50; - xfer += iprot->readI32(ecast50); - this->encoding = (Encoding::type)ecast50; - isset_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->is_sorted); - this->__isset.is_sorted = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_num_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - 
if (!isset_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t DictionaryPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("DictionaryPageHeader"); - - xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->num_values); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32((int32_t)this->encoding); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.is_sorted) { - xfer += oprot->writeFieldBegin("is_sorted", ::apache::thrift::protocol::T_BOOL, 3); - xfer += oprot->writeBool(this->is_sorted); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(DictionaryPageHeader &a, DictionaryPageHeader &b) { - using ::std::swap; - swap(a.num_values, b.num_values); - swap(a.encoding, b.encoding); - swap(a.is_sorted, b.is_sorted); - swap(a.__isset, b.__isset); -} - -DictionaryPageHeader::DictionaryPageHeader(const DictionaryPageHeader& other51) { - num_values = other51.num_values; - encoding = other51.encoding; - is_sorted = other51.is_sorted; - __isset = other51.__isset; -} -DictionaryPageHeader& DictionaryPageHeader::operator=(const DictionaryPageHeader& other52) { - num_values = other52.num_values; - encoding = other52.encoding; - is_sorted = other52.is_sorted; - __isset = other52.__isset; - return *this; -} -void DictionaryPageHeader::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "DictionaryPageHeader("; - out << "num_values=" << to_string(num_values); - out << ", " << "encoding=" << to_string(encoding); - out << ", " << "is_sorted="; (__isset.is_sorted ? 
(out << to_string(is_sorted)) : (out << "")); - out << ")"; -} - - -DataPageHeaderV2::~DataPageHeaderV2() throw() { -} - - -void DataPageHeaderV2::__set_num_values(const int32_t val) { - this->num_values = val; -} - -void DataPageHeaderV2::__set_num_nulls(const int32_t val) { - this->num_nulls = val; -} - -void DataPageHeaderV2::__set_num_rows(const int32_t val) { - this->num_rows = val; -} - -void DataPageHeaderV2::__set_encoding(const Encoding::type val) { - this->encoding = val; -} - -void DataPageHeaderV2::__set_definition_levels_byte_length(const int32_t val) { - this->definition_levels_byte_length = val; -} - -void DataPageHeaderV2::__set_repetition_levels_byte_length(const int32_t val) { - this->repetition_levels_byte_length = val; -} - -void DataPageHeaderV2::__set_is_compressed(const bool val) { - this->is_compressed = val; -__isset.is_compressed = true; -} - -void DataPageHeaderV2::__set_statistics(const Statistics& val) { - this->statistics = val; -__isset.statistics = true; -} -std::ostream& operator<<(std::ostream& out, const DataPageHeaderV2& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t DataPageHeaderV2::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_num_values = false; - bool isset_num_nulls = false; - bool isset_num_rows = false; - bool isset_encoding = false; - bool isset_definition_levels_byte_length = false; - bool isset_repetition_levels_byte_length = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_values); - isset_num_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_nulls); - isset_num_nulls = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->num_rows); - isset_num_rows = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast53; - xfer += iprot->readI32(ecast53); - this->encoding = (Encoding::type)ecast53; - isset_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->definition_levels_byte_length); - isset_definition_levels_byte_length = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->repetition_levels_byte_length); - isset_repetition_levels_byte_length = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->is_compressed); - this->__isset.is_compressed = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->statistics.read(iprot); - this->__isset.statistics = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - 
break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_num_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_num_nulls) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_num_rows) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_definition_levels_byte_length) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_repetition_levels_byte_length) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t DataPageHeaderV2::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("DataPageHeaderV2"); - - xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->num_values); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("num_nulls", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32(this->num_nulls); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I32, 3); - xfer += oprot->writeI32(this->num_rows); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 4); - xfer += oprot->writeI32((int32_t)this->encoding); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("definition_levels_byte_length", ::apache::thrift::protocol::T_I32, 5); - xfer += oprot->writeI32(this->definition_levels_byte_length); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("repetition_levels_byte_length", ::apache::thrift::protocol::T_I32, 6); - xfer += oprot->writeI32(this->repetition_levels_byte_length); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.is_compressed) { - xfer += oprot->writeFieldBegin("is_compressed", ::apache::thrift::protocol::T_BOOL, 7); - xfer += oprot->writeBool(this->is_compressed); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.statistics) { - xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 8); - xfer += this->statistics.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(DataPageHeaderV2 &a, DataPageHeaderV2 &b) { - using ::std::swap; - swap(a.num_values, b.num_values); - swap(a.num_nulls, b.num_nulls); - swap(a.num_rows, b.num_rows); - swap(a.encoding, b.encoding); - swap(a.definition_levels_byte_length, b.definition_levels_byte_length); - swap(a.repetition_levels_byte_length, b.repetition_levels_byte_length); - swap(a.is_compressed, b.is_compressed); - swap(a.statistics, b.statistics); - swap(a.__isset, b.__isset); -} - -DataPageHeaderV2::DataPageHeaderV2(const DataPageHeaderV2& other54) { - num_values = other54.num_values; - num_nulls = other54.num_nulls; - num_rows = other54.num_rows; - encoding = other54.encoding; - definition_levels_byte_length = other54.definition_levels_byte_length; - repetition_levels_byte_length = other54.repetition_levels_byte_length; - is_compressed = other54.is_compressed; - statistics = other54.statistics; - __isset = other54.__isset; -} -DataPageHeaderV2& DataPageHeaderV2::operator=(const DataPageHeaderV2& other55) { - num_values = other55.num_values; - num_nulls = 
other55.num_nulls; - num_rows = other55.num_rows; - encoding = other55.encoding; - definition_levels_byte_length = other55.definition_levels_byte_length; - repetition_levels_byte_length = other55.repetition_levels_byte_length; - is_compressed = other55.is_compressed; - statistics = other55.statistics; - __isset = other55.__isset; - return *this; -} -void DataPageHeaderV2::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "DataPageHeaderV2("; - out << "num_values=" << to_string(num_values); - out << ", " << "num_nulls=" << to_string(num_nulls); - out << ", " << "num_rows=" << to_string(num_rows); - out << ", " << "encoding=" << to_string(encoding); - out << ", " << "definition_levels_byte_length=" << to_string(definition_levels_byte_length); - out << ", " << "repetition_levels_byte_length=" << to_string(repetition_levels_byte_length); - out << ", " << "is_compressed="; (__isset.is_compressed ? (out << to_string(is_compressed)) : (out << "")); - out << ", " << "statistics="; (__isset.statistics ? (out << to_string(statistics)) : (out << "")); - out << ")"; -} - - -SplitBlockAlgorithm::~SplitBlockAlgorithm() throw() { -} - -std::ostream& operator<<(std::ostream& out, const SplitBlockAlgorithm& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t SplitBlockAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t SplitBlockAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("SplitBlockAlgorithm"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(SplitBlockAlgorithm &a, SplitBlockAlgorithm &b) { - using ::std::swap; - (void) a; - (void) b; -} - -SplitBlockAlgorithm::SplitBlockAlgorithm(const SplitBlockAlgorithm& other56) { - (void) other56; -} -SplitBlockAlgorithm& SplitBlockAlgorithm::operator=(const SplitBlockAlgorithm& other57) { - (void) other57; - return *this; -} -void SplitBlockAlgorithm::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "SplitBlockAlgorithm("; - out << ")"; -} - - -BloomFilterAlgorithm::~BloomFilterAlgorithm() throw() { -} - - -void BloomFilterAlgorithm::__set_BLOCK(const SplitBlockAlgorithm& val) { - this->BLOCK = val; -__isset.BLOCK = true; -} -std::ostream& operator<<(std::ostream& out, const BloomFilterAlgorithm& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t BloomFilterAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch 
(fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->BLOCK.read(iprot); - this->__isset.BLOCK = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t BloomFilterAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("BloomFilterAlgorithm"); - - if (this->__isset.BLOCK) { - xfer += oprot->writeFieldBegin("BLOCK", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->BLOCK.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(BloomFilterAlgorithm &a, BloomFilterAlgorithm &b) { - using ::std::swap; - swap(a.BLOCK, b.BLOCK); - swap(a.__isset, b.__isset); -} - -BloomFilterAlgorithm::BloomFilterAlgorithm(const BloomFilterAlgorithm& other58) { - BLOCK = other58.BLOCK; - __isset = other58.__isset; -} -BloomFilterAlgorithm& BloomFilterAlgorithm::operator=(const BloomFilterAlgorithm& other59) { - BLOCK = other59.BLOCK; - __isset = other59.__isset; - return *this; -} -void BloomFilterAlgorithm::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "BloomFilterAlgorithm("; - out << "BLOCK="; (__isset.BLOCK ? (out << to_string(BLOCK)) : (out << "")); - out << ")"; -} - - -XxHash::~XxHash() throw() { -} - -std::ostream& operator<<(std::ostream& out, const XxHash& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t XxHash::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t XxHash::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("XxHash"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(XxHash &a, XxHash &b) { - using ::std::swap; - (void) a; - (void) b; -} - -XxHash::XxHash(const XxHash& other60) { - (void) other60; -} -XxHash& XxHash::operator=(const XxHash& other61) { - (void) other61; - return *this; -} -void XxHash::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "XxHash("; - out << ")"; -} - - -BloomFilterHash::~BloomFilterHash() throw() { -} - - -void BloomFilterHash::__set_XXHASH(const XxHash& val) { - this->XXHASH = val; -__isset.XXHASH = true; -} -std::ostream& operator<<(std::ostream& out, const BloomFilterHash& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t BloomFilterHash::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using 
::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->XXHASH.read(iprot); - this->__isset.XXHASH = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t BloomFilterHash::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("BloomFilterHash"); - - if (this->__isset.XXHASH) { - xfer += oprot->writeFieldBegin("XXHASH", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->XXHASH.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(BloomFilterHash &a, BloomFilterHash &b) { - using ::std::swap; - swap(a.XXHASH, b.XXHASH); - swap(a.__isset, b.__isset); -} - -BloomFilterHash::BloomFilterHash(const BloomFilterHash& other62) { - XXHASH = other62.XXHASH; - __isset = other62.__isset; -} -BloomFilterHash& BloomFilterHash::operator=(const BloomFilterHash& other63) { - XXHASH = other63.XXHASH; - __isset = other63.__isset; - return *this; -} -void BloomFilterHash::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "BloomFilterHash("; - out << "XXHASH="; (__isset.XXHASH ? (out << to_string(XXHASH)) : (out << "")); - out << ")"; -} - - -Uncompressed::~Uncompressed() throw() { -} - -std::ostream& operator<<(std::ostream& out, const Uncompressed& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t Uncompressed::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t Uncompressed::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("Uncompressed"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(Uncompressed &a, Uncompressed &b) { - using ::std::swap; - (void) a; - (void) b; -} - -Uncompressed::Uncompressed(const Uncompressed& other64) { - (void) other64; -} -Uncompressed& Uncompressed::operator=(const Uncompressed& other65) { - (void) other65; - return *this; -} -void Uncompressed::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "Uncompressed("; - out << ")"; -} - - -BloomFilterCompression::~BloomFilterCompression() throw() { -} - - -void BloomFilterCompression::__set_UNCOMPRESSED(const Uncompressed& val) { - this->UNCOMPRESSED = val; -__isset.UNCOMPRESSED = true; -} -std::ostream& operator<<(std::ostream& out, const BloomFilterCompression& obj) -{ - obj.printTo(out); - return out; -} - - 
-uint32_t BloomFilterCompression::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->UNCOMPRESSED.read(iprot); - this->__isset.UNCOMPRESSED = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t BloomFilterCompression::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("BloomFilterCompression"); - - if (this->__isset.UNCOMPRESSED) { - xfer += oprot->writeFieldBegin("UNCOMPRESSED", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->UNCOMPRESSED.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(BloomFilterCompression &a, BloomFilterCompression &b) { - using ::std::swap; - swap(a.UNCOMPRESSED, b.UNCOMPRESSED); - swap(a.__isset, b.__isset); -} - -BloomFilterCompression::BloomFilterCompression(const BloomFilterCompression& other66) { - UNCOMPRESSED = other66.UNCOMPRESSED; - __isset = other66.__isset; -} -BloomFilterCompression& BloomFilterCompression::operator=(const BloomFilterCompression& other67) { - UNCOMPRESSED = other67.UNCOMPRESSED; - __isset = other67.__isset; - return *this; -} -void BloomFilterCompression::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "BloomFilterCompression("; - out << "UNCOMPRESSED="; (__isset.UNCOMPRESSED ? 
(out << to_string(UNCOMPRESSED)) : (out << "")); - out << ")"; -} - - -BloomFilterHeader::~BloomFilterHeader() throw() { -} - - -void BloomFilterHeader::__set_numBytes(const int32_t val) { - this->numBytes = val; -} - -void BloomFilterHeader::__set_algorithm(const BloomFilterAlgorithm& val) { - this->algorithm = val; -} - -void BloomFilterHeader::__set_hash(const BloomFilterHash& val) { - this->hash = val; -} - -void BloomFilterHeader::__set_compression(const BloomFilterCompression& val) { - this->compression = val; -} -std::ostream& operator<<(std::ostream& out, const BloomFilterHeader& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t BloomFilterHeader::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_numBytes = false; - bool isset_algorithm = false; - bool isset_hash = false; - bool isset_compression = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->numBytes); - isset_numBytes = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->algorithm.read(iprot); - isset_algorithm = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->hash.read(iprot); - isset_hash = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->compression.read(iprot); - isset_compression = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_numBytes) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_algorithm) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_hash) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_compression) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t BloomFilterHeader::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("BloomFilterHeader"); - - xfer += oprot->writeFieldBegin("numBytes", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->numBytes); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("algorithm", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->algorithm.write(oprot); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("hash", ::apache::thrift::protocol::T_STRUCT, 3); - xfer += this->hash.write(oprot); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("compression", ::apache::thrift::protocol::T_STRUCT, 4); - xfer += this->compression.write(oprot); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(BloomFilterHeader &a, BloomFilterHeader &b) { - using 
::std::swap; - swap(a.numBytes, b.numBytes); - swap(a.algorithm, b.algorithm); - swap(a.hash, b.hash); - swap(a.compression, b.compression); -} - -BloomFilterHeader::BloomFilterHeader(const BloomFilterHeader& other68) { - numBytes = other68.numBytes; - algorithm = other68.algorithm; - hash = other68.hash; - compression = other68.compression; -} -BloomFilterHeader& BloomFilterHeader::operator=(const BloomFilterHeader& other69) { - numBytes = other69.numBytes; - algorithm = other69.algorithm; - hash = other69.hash; - compression = other69.compression; - return *this; -} -void BloomFilterHeader::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "BloomFilterHeader("; - out << "numBytes=" << to_string(numBytes); - out << ", " << "algorithm=" << to_string(algorithm); - out << ", " << "hash=" << to_string(hash); - out << ", " << "compression=" << to_string(compression); - out << ")"; -} - - -PageHeader::~PageHeader() throw() { -} - - -void PageHeader::__set_type(const PageType::type val) { - this->type = val; -} - -void PageHeader::__set_uncompressed_page_size(const int32_t val) { - this->uncompressed_page_size = val; -} - -void PageHeader::__set_compressed_page_size(const int32_t val) { - this->compressed_page_size = val; -} - -void PageHeader::__set_crc(const int32_t val) { - this->crc = val; -__isset.crc = true; -} - -void PageHeader::__set_data_page_header(const DataPageHeader& val) { - this->data_page_header = val; -__isset.data_page_header = true; -} - -void PageHeader::__set_index_page_header(const IndexPageHeader& val) { - this->index_page_header = val; -__isset.index_page_header = true; -} - -void PageHeader::__set_dictionary_page_header(const DictionaryPageHeader& val) { - this->dictionary_page_header = val; -__isset.dictionary_page_header = true; -} - -void PageHeader::__set_data_page_header_v2(const DataPageHeaderV2& val) { - this->data_page_header_v2 = val; -__isset.data_page_header_v2 = true; -} -std::ostream& operator<<(std::ostream& out, const PageHeader& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t PageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_type = false; - bool isset_uncompressed_page_size = false; - bool isset_compressed_page_size = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast70; - xfer += iprot->readI32(ecast70); - this->type = (PageType::type)ecast70; - isset_type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->uncompressed_page_size); - isset_uncompressed_page_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->compressed_page_size); - isset_compressed_page_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->crc); - this->__isset.crc = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype 
== ::apache::thrift::protocol::T_STRUCT) { - xfer += this->data_page_header.read(iprot); - this->__isset.data_page_header = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->index_page_header.read(iprot); - this->__isset.index_page_header = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->dictionary_page_header.read(iprot); - this->__isset.dictionary_page_header = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->data_page_header_v2.read(iprot); - this->__isset.data_page_header_v2 = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_type) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_uncompressed_page_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_compressed_page_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t PageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("PageHeader"); - - xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32((int32_t)this->type); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("uncompressed_page_size", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32(this->uncompressed_page_size); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("compressed_page_size", ::apache::thrift::protocol::T_I32, 3); - xfer += oprot->writeI32(this->compressed_page_size); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.crc) { - xfer += oprot->writeFieldBegin("crc", ::apache::thrift::protocol::T_I32, 4); - xfer += oprot->writeI32(this->crc); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.data_page_header) { - xfer += oprot->writeFieldBegin("data_page_header", ::apache::thrift::protocol::T_STRUCT, 5); - xfer += this->data_page_header.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.index_page_header) { - xfer += oprot->writeFieldBegin("index_page_header", ::apache::thrift::protocol::T_STRUCT, 6); - xfer += this->index_page_header.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.dictionary_page_header) { - xfer += oprot->writeFieldBegin("dictionary_page_header", ::apache::thrift::protocol::T_STRUCT, 7); - xfer += this->dictionary_page_header.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.data_page_header_v2) { - xfer += oprot->writeFieldBegin("data_page_header_v2", ::apache::thrift::protocol::T_STRUCT, 8); - xfer += this->data_page_header_v2.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(PageHeader &a, PageHeader &b) { - using ::std::swap; - swap(a.type, b.type); - swap(a.uncompressed_page_size, b.uncompressed_page_size); - swap(a.compressed_page_size, b.compressed_page_size); - swap(a.crc, b.crc); - swap(a.data_page_header, b.data_page_header); - swap(a.index_page_header, b.index_page_header); - swap(a.dictionary_page_header, 
b.dictionary_page_header); - swap(a.data_page_header_v2, b.data_page_header_v2); - swap(a.__isset, b.__isset); -} - -PageHeader::PageHeader(const PageHeader& other71) { - type = other71.type; - uncompressed_page_size = other71.uncompressed_page_size; - compressed_page_size = other71.compressed_page_size; - crc = other71.crc; - data_page_header = other71.data_page_header; - index_page_header = other71.index_page_header; - dictionary_page_header = other71.dictionary_page_header; - data_page_header_v2 = other71.data_page_header_v2; - __isset = other71.__isset; -} -PageHeader& PageHeader::operator=(const PageHeader& other72) { - type = other72.type; - uncompressed_page_size = other72.uncompressed_page_size; - compressed_page_size = other72.compressed_page_size; - crc = other72.crc; - data_page_header = other72.data_page_header; - index_page_header = other72.index_page_header; - dictionary_page_header = other72.dictionary_page_header; - data_page_header_v2 = other72.data_page_header_v2; - __isset = other72.__isset; - return *this; -} -void PageHeader::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "PageHeader("; - out << "type=" << to_string(type); - out << ", " << "uncompressed_page_size=" << to_string(uncompressed_page_size); - out << ", " << "compressed_page_size=" << to_string(compressed_page_size); - out << ", " << "crc="; (__isset.crc ? (out << to_string(crc)) : (out << "")); - out << ", " << "data_page_header="; (__isset.data_page_header ? (out << to_string(data_page_header)) : (out << "")); - out << ", " << "index_page_header="; (__isset.index_page_header ? (out << to_string(index_page_header)) : (out << "")); - out << ", " << "dictionary_page_header="; (__isset.dictionary_page_header ? (out << to_string(dictionary_page_header)) : (out << "")); - out << ", " << "data_page_header_v2="; (__isset.data_page_header_v2 ? 
(out << to_string(data_page_header_v2)) : (out << "")); - out << ")"; -} - - -KeyValue::~KeyValue() throw() { -} - - -void KeyValue::__set_key(const std::string& val) { - this->key = val; -} - -void KeyValue::__set_value(const std::string& val) { - this->value = val; -__isset.value = true; -} -std::ostream& operator<<(std::ostream& out, const KeyValue& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t KeyValue::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_key = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readString(this->key); - isset_key = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readString(this->value); - this->__isset.value = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_key) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t KeyValue::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("KeyValue"); - - xfer += oprot->writeFieldBegin("key", ::apache::thrift::protocol::T_STRING, 1); - xfer += oprot->writeString(this->key); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.value) { - xfer += oprot->writeFieldBegin("value", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeString(this->value); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(KeyValue &a, KeyValue &b) { - using ::std::swap; - swap(a.key, b.key); - swap(a.value, b.value); - swap(a.__isset, b.__isset); -} - -KeyValue::KeyValue(const KeyValue& other73) { - key = other73.key; - value = other73.value; - __isset = other73.__isset; -} -KeyValue& KeyValue::operator=(const KeyValue& other74) { - key = other74.key; - value = other74.value; - __isset = other74.__isset; - return *this; -} -void KeyValue::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "KeyValue("; - out << "key=" << to_string(key); - out << ", " << "value="; (__isset.value ? 
(out << to_string(value)) : (out << "")); - out << ")"; -} - - -SortingColumn::~SortingColumn() throw() { -} - - -void SortingColumn::__set_column_idx(const int32_t val) { - this->column_idx = val; -} - -void SortingColumn::__set_descending(const bool val) { - this->descending = val; -} - -void SortingColumn::__set_nulls_first(const bool val) { - this->nulls_first = val; -} -std::ostream& operator<<(std::ostream& out, const SortingColumn& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t SortingColumn::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_column_idx = false; - bool isset_descending = false; - bool isset_nulls_first = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->column_idx); - isset_column_idx = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->descending); - isset_descending = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->nulls_first); - isset_nulls_first = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_column_idx) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_descending) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_nulls_first) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t SortingColumn::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("SortingColumn"); - - xfer += oprot->writeFieldBegin("column_idx", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->column_idx); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("descending", ::apache::thrift::protocol::T_BOOL, 2); - xfer += oprot->writeBool(this->descending); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("nulls_first", ::apache::thrift::protocol::T_BOOL, 3); - xfer += oprot->writeBool(this->nulls_first); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(SortingColumn &a, SortingColumn &b) { - using ::std::swap; - swap(a.column_idx, b.column_idx); - swap(a.descending, b.descending); - swap(a.nulls_first, b.nulls_first); -} - -SortingColumn::SortingColumn(const SortingColumn& other75) { - column_idx = other75.column_idx; - descending = other75.descending; - nulls_first = other75.nulls_first; -} -SortingColumn& SortingColumn::operator=(const SortingColumn& other76) { - column_idx = other76.column_idx; - descending = other76.descending; - nulls_first = other76.nulls_first; - return *this; -} -void SortingColumn::printTo(std::ostream& out) const { - using 
::apache::thrift::to_string; - out << "SortingColumn("; - out << "column_idx=" << to_string(column_idx); - out << ", " << "descending=" << to_string(descending); - out << ", " << "nulls_first=" << to_string(nulls_first); - out << ")"; -} - - -PageEncodingStats::~PageEncodingStats() throw() { -} - - -void PageEncodingStats::__set_page_type(const PageType::type val) { - this->page_type = val; -} - -void PageEncodingStats::__set_encoding(const Encoding::type val) { - this->encoding = val; -} - -void PageEncodingStats::__set_count(const int32_t val) { - this->count = val; -} -std::ostream& operator<<(std::ostream& out, const PageEncodingStats& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_page_type = false; - bool isset_encoding = false; - bool isset_count = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast77; - xfer += iprot->readI32(ecast77); - this->page_type = (PageType::type)ecast77; - isset_page_type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast78; - xfer += iprot->readI32(ecast78); - this->encoding = (Encoding::type)ecast78; - isset_encoding = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->count); - isset_count = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_page_type) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_encoding) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_count) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t PageEncodingStats::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("PageEncodingStats"); - - xfer += oprot->writeFieldBegin("page_type", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32((int32_t)this->page_type); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32((int32_t)this->encoding); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("count", ::apache::thrift::protocol::T_I32, 3); - xfer += oprot->writeI32(this->count); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(PageEncodingStats &a, PageEncodingStats &b) { - using ::std::swap; - swap(a.page_type, b.page_type); - swap(a.encoding, b.encoding); - swap(a.count, b.count); -} - -PageEncodingStats::PageEncodingStats(const PageEncodingStats& other79) { - page_type = other79.page_type; - encoding = other79.encoding; - count = other79.count; -} -PageEncodingStats& 
PageEncodingStats::operator=(const PageEncodingStats& other80) { - page_type = other80.page_type; - encoding = other80.encoding; - count = other80.count; - return *this; -} -void PageEncodingStats::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "PageEncodingStats("; - out << "page_type=" << to_string(page_type); - out << ", " << "encoding=" << to_string(encoding); - out << ", " << "count=" << to_string(count); - out << ")"; -} - - -ColumnMetaData::~ColumnMetaData() throw() { -} - - -void ColumnMetaData::__set_type(const Type::type val) { - this->type = val; -} - -void ColumnMetaData::__set_encodings(const std::vector<Encoding::type> & val) { - this->encodings = val; -} - -void ColumnMetaData::__set_path_in_schema(const std::vector<std::string> & val) { - this->path_in_schema = val; -} - -void ColumnMetaData::__set_codec(const CompressionCodec::type val) { - this->codec = val; -} - -void ColumnMetaData::__set_num_values(const int64_t val) { - this->num_values = val; -} - -void ColumnMetaData::__set_total_uncompressed_size(const int64_t val) { - this->total_uncompressed_size = val; -} - -void ColumnMetaData::__set_total_compressed_size(const int64_t val) { - this->total_compressed_size = val; -} - -void ColumnMetaData::__set_key_value_metadata(const std::vector<KeyValue> & val) { - this->key_value_metadata = val; -__isset.key_value_metadata = true; -} - -void ColumnMetaData::__set_data_page_offset(const int64_t val) { - this->data_page_offset = val; -} - -void ColumnMetaData::__set_index_page_offset(const int64_t val) { - this->index_page_offset = val; -__isset.index_page_offset = true; -} - -void ColumnMetaData::__set_dictionary_page_offset(const int64_t val) { - this->dictionary_page_offset = val; -__isset.dictionary_page_offset = true; -} - -void ColumnMetaData::__set_statistics(const Statistics& val) { - this->statistics = val; -__isset.statistics = true; -} - -void ColumnMetaData::__set_encoding_stats(const std::vector<PageEncodingStats> & val) { - this->encoding_stats = val; -__isset.encoding_stats = true; -} - -void ColumnMetaData::__set_bloom_filter_offset(const int64_t val) { - this->bloom_filter_offset = val; -__isset.bloom_filter_offset = true; -} -std::ostream& operator<<(std::ostream& out, const ColumnMetaData& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_type = false; - bool isset_encodings = false; - bool isset_path_in_schema = false; - bool isset_codec = false; - bool isset_num_values = false; - bool isset_total_uncompressed_size = false; - bool isset_total_compressed_size = false; - bool isset_data_page_offset = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast81; - xfer += iprot->readI32(ecast81); - this->type = (Type::type)ecast81; - isset_type = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->encodings.clear(); - uint32_t _size82; - ::apache::thrift::protocol::TType _etype85; - xfer += iprot->readListBegin(_etype85, _size82); - 
this->encodings.resize(_size82); - uint32_t _i86; - for (_i86 = 0; _i86 < _size82; ++_i86) - { - int32_t ecast87; - xfer += iprot->readI32(ecast87); - this->encodings[_i86] = (Encoding::type)ecast87; - } - xfer += iprot->readListEnd(); - } - isset_encodings = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->path_in_schema.clear(); - uint32_t _size88; - ::apache::thrift::protocol::TType _etype91; - xfer += iprot->readListBegin(_etype91, _size88); - this->path_in_schema.resize(_size88); - uint32_t _i92; - for (_i92 = 0; _i92 < _size88; ++_i92) - { - xfer += iprot->readString(this->path_in_schema[_i92]); - } - xfer += iprot->readListEnd(); - } - isset_path_in_schema = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast93; - xfer += iprot->readI32(ecast93); - this->codec = (CompressionCodec::type)ecast93; - isset_codec = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->num_values); - isset_num_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->total_uncompressed_size); - isset_total_uncompressed_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->total_compressed_size); - isset_total_compressed_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->key_value_metadata.clear(); - uint32_t _size94; - ::apache::thrift::protocol::TType _etype97; - xfer += iprot->readListBegin(_etype97, _size94); - this->key_value_metadata.resize(_size94); - uint32_t _i98; - for (_i98 = 0; _i98 < _size94; ++_i98) - { - xfer += this->key_value_metadata[_i98].read(iprot); - } - xfer += iprot->readListEnd(); - } - this->__isset.key_value_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 9: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->data_page_offset); - isset_data_page_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 10: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->index_page_offset); - this->__isset.index_page_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 11: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->dictionary_page_offset); - this->__isset.dictionary_page_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 12: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->statistics.read(iprot); - this->__isset.statistics = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 13: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->encoding_stats.clear(); - uint32_t _size99; - ::apache::thrift::protocol::TType _etype102; - xfer += iprot->readListBegin(_etype102, _size99); - this->encoding_stats.resize(_size99); - uint32_t _i103; - for (_i103 = 0; _i103 < _size99; ++_i103) - { - xfer += this->encoding_stats[_i103].read(iprot); - } - xfer += iprot->readListEnd(); - } - this->__isset.encoding_stats = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 14: - if (ftype == 
::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->bloom_filter_offset); - this->__isset.bloom_filter_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_type) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_encodings) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_path_in_schema) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_codec) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_num_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_total_uncompressed_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_total_compressed_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_data_page_offset) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t ColumnMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ColumnMetaData"); - - xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32((int32_t)this->type); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("encodings", ::apache::thrift::protocol::T_LIST, 2); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_I32, static_cast(this->encodings.size())); - std::vector ::const_iterator _iter104; - for (_iter104 = this->encodings.begin(); _iter104 != this->encodings.end(); ++_iter104) - { - xfer += oprot->writeI32((int32_t)(*_iter104)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("path_in_schema", ::apache::thrift::protocol::T_LIST, 3); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->path_in_schema.size())); - std::vector ::const_iterator _iter105; - for (_iter105 = this->path_in_schema.begin(); _iter105 != this->path_in_schema.end(); ++_iter105) - { - xfer += oprot->writeString((*_iter105)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("codec", ::apache::thrift::protocol::T_I32, 4); - xfer += oprot->writeI32((int32_t)this->codec); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I64, 5); - xfer += oprot->writeI64(this->num_values); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("total_uncompressed_size", ::apache::thrift::protocol::T_I64, 6); - xfer += oprot->writeI64(this->total_uncompressed_size); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("total_compressed_size", ::apache::thrift::protocol::T_I64, 7); - xfer += oprot->writeI64(this->total_compressed_size); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.key_value_metadata) { - xfer += oprot->writeFieldBegin("key_value_metadata", ::apache::thrift::protocol::T_LIST, 8); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->key_value_metadata.size())); - std::vector ::const_iterator _iter106; - for (_iter106 = this->key_value_metadata.begin(); _iter106 != this->key_value_metadata.end(); ++_iter106) - { - xfer += (*_iter106).write(oprot); - } - xfer 
+= oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldBegin("data_page_offset", ::apache::thrift::protocol::T_I64, 9); - xfer += oprot->writeI64(this->data_page_offset); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.index_page_offset) { - xfer += oprot->writeFieldBegin("index_page_offset", ::apache::thrift::protocol::T_I64, 10); - xfer += oprot->writeI64(this->index_page_offset); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.dictionary_page_offset) { - xfer += oprot->writeFieldBegin("dictionary_page_offset", ::apache::thrift::protocol::T_I64, 11); - xfer += oprot->writeI64(this->dictionary_page_offset); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.statistics) { - xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 12); - xfer += this->statistics.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.encoding_stats) { - xfer += oprot->writeFieldBegin("encoding_stats", ::apache::thrift::protocol::T_LIST, 13); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->encoding_stats.size())); - std::vector ::const_iterator _iter107; - for (_iter107 = this->encoding_stats.begin(); _iter107 != this->encoding_stats.end(); ++_iter107) - { - xfer += (*_iter107).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.bloom_filter_offset) { - xfer += oprot->writeFieldBegin("bloom_filter_offset", ::apache::thrift::protocol::T_I64, 14); - xfer += oprot->writeI64(this->bloom_filter_offset); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ColumnMetaData &a, ColumnMetaData &b) { - using ::std::swap; - swap(a.type, b.type); - swap(a.encodings, b.encodings); - swap(a.path_in_schema, b.path_in_schema); - swap(a.codec, b.codec); - swap(a.num_values, b.num_values); - swap(a.total_uncompressed_size, b.total_uncompressed_size); - swap(a.total_compressed_size, b.total_compressed_size); - swap(a.key_value_metadata, b.key_value_metadata); - swap(a.data_page_offset, b.data_page_offset); - swap(a.index_page_offset, b.index_page_offset); - swap(a.dictionary_page_offset, b.dictionary_page_offset); - swap(a.statistics, b.statistics); - swap(a.encoding_stats, b.encoding_stats); - swap(a.bloom_filter_offset, b.bloom_filter_offset); - swap(a.__isset, b.__isset); -} - -ColumnMetaData::ColumnMetaData(const ColumnMetaData& other108) { - type = other108.type; - encodings = other108.encodings; - path_in_schema = other108.path_in_schema; - codec = other108.codec; - num_values = other108.num_values; - total_uncompressed_size = other108.total_uncompressed_size; - total_compressed_size = other108.total_compressed_size; - key_value_metadata = other108.key_value_metadata; - data_page_offset = other108.data_page_offset; - index_page_offset = other108.index_page_offset; - dictionary_page_offset = other108.dictionary_page_offset; - statistics = other108.statistics; - encoding_stats = other108.encoding_stats; - bloom_filter_offset = other108.bloom_filter_offset; - __isset = other108.__isset; -} -ColumnMetaData& ColumnMetaData::operator=(const ColumnMetaData& other109) { - type = other109.type; - encodings = other109.encodings; - path_in_schema = other109.path_in_schema; - codec = other109.codec; - num_values = other109.num_values; - total_uncompressed_size = other109.total_uncompressed_size; - total_compressed_size = 
other109.total_compressed_size; - key_value_metadata = other109.key_value_metadata; - data_page_offset = other109.data_page_offset; - index_page_offset = other109.index_page_offset; - dictionary_page_offset = other109.dictionary_page_offset; - statistics = other109.statistics; - encoding_stats = other109.encoding_stats; - bloom_filter_offset = other109.bloom_filter_offset; - __isset = other109.__isset; - return *this; -} -void ColumnMetaData::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ColumnMetaData("; - out << "type=" << to_string(type); - out << ", " << "encodings=" << to_string(encodings); - out << ", " << "path_in_schema=" << to_string(path_in_schema); - out << ", " << "codec=" << to_string(codec); - out << ", " << "num_values=" << to_string(num_values); - out << ", " << "total_uncompressed_size=" << to_string(total_uncompressed_size); - out << ", " << "total_compressed_size=" << to_string(total_compressed_size); - out << ", " << "key_value_metadata="; (__isset.key_value_metadata ? (out << to_string(key_value_metadata)) : (out << "")); - out << ", " << "data_page_offset=" << to_string(data_page_offset); - out << ", " << "index_page_offset="; (__isset.index_page_offset ? (out << to_string(index_page_offset)) : (out << "")); - out << ", " << "dictionary_page_offset="; (__isset.dictionary_page_offset ? (out << to_string(dictionary_page_offset)) : (out << "")); - out << ", " << "statistics="; (__isset.statistics ? (out << to_string(statistics)) : (out << "")); - out << ", " << "encoding_stats="; (__isset.encoding_stats ? (out << to_string(encoding_stats)) : (out << "")); - out << ", " << "bloom_filter_offset="; (__isset.bloom_filter_offset ? (out << to_string(bloom_filter_offset)) : (out << "")); - out << ")"; -} - - -EncryptionWithFooterKey::~EncryptionWithFooterKey() throw() { -} - -std::ostream& operator<<(std::ostream& out, const EncryptionWithFooterKey& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t EncryptionWithFooterKey::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t EncryptionWithFooterKey::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("EncryptionWithFooterKey"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(EncryptionWithFooterKey &a, EncryptionWithFooterKey &b) { - using ::std::swap; - (void) a; - (void) b; -} - -EncryptionWithFooterKey::EncryptionWithFooterKey(const EncryptionWithFooterKey& other110) { - (void) other110; -} -EncryptionWithFooterKey& EncryptionWithFooterKey::operator=(const EncryptionWithFooterKey& other111) { - (void) other111; - return *this; -} -void EncryptionWithFooterKey::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "EncryptionWithFooterKey("; - out << ")"; -} - - -EncryptionWithColumnKey::~EncryptionWithColumnKey() throw() { -} - - 
-void EncryptionWithColumnKey::__set_path_in_schema(const std::vector & val) { - this->path_in_schema = val; -} - -void EncryptionWithColumnKey::__set_key_metadata(const std::string& val) { - this->key_metadata = val; -__isset.key_metadata = true; -} -std::ostream& operator<<(std::ostream& out, const EncryptionWithColumnKey& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t EncryptionWithColumnKey::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_path_in_schema = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->path_in_schema.clear(); - uint32_t _size112; - ::apache::thrift::protocol::TType _etype115; - xfer += iprot->readListBegin(_etype115, _size112); - this->path_in_schema.resize(_size112); - uint32_t _i116; - for (_i116 = 0; _i116 < _size112; ++_i116) - { - xfer += iprot->readString(this->path_in_schema[_i116]); - } - xfer += iprot->readListEnd(); - } - isset_path_in_schema = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->key_metadata); - this->__isset.key_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_path_in_schema) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t EncryptionWithColumnKey::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("EncryptionWithColumnKey"); - - xfer += oprot->writeFieldBegin("path_in_schema", ::apache::thrift::protocol::T_LIST, 1); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->path_in_schema.size())); - std::vector ::const_iterator _iter117; - for (_iter117 = this->path_in_schema.begin(); _iter117 != this->path_in_schema.end(); ++_iter117) - { - xfer += oprot->writeString((*_iter117)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - if (this->__isset.key_metadata) { - xfer += oprot->writeFieldBegin("key_metadata", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeBinary(this->key_metadata); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(EncryptionWithColumnKey &a, EncryptionWithColumnKey &b) { - using ::std::swap; - swap(a.path_in_schema, b.path_in_schema); - swap(a.key_metadata, b.key_metadata); - swap(a.__isset, b.__isset); -} - -EncryptionWithColumnKey::EncryptionWithColumnKey(const EncryptionWithColumnKey& other118) { - path_in_schema = other118.path_in_schema; - key_metadata = other118.key_metadata; - __isset = other118.__isset; -} -EncryptionWithColumnKey& EncryptionWithColumnKey::operator=(const EncryptionWithColumnKey& other119) { - path_in_schema = other119.path_in_schema; - key_metadata = other119.key_metadata; - __isset = 
other119.__isset; - return *this; -} -void EncryptionWithColumnKey::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "EncryptionWithColumnKey("; - out << "path_in_schema=" << to_string(path_in_schema); - out << ", " << "key_metadata="; (__isset.key_metadata ? (out << to_string(key_metadata)) : (out << "")); - out << ")"; -} - - -ColumnCryptoMetaData::~ColumnCryptoMetaData() throw() { -} - - -void ColumnCryptoMetaData::__set_ENCRYPTION_WITH_FOOTER_KEY(const EncryptionWithFooterKey& val) { - this->ENCRYPTION_WITH_FOOTER_KEY = val; -__isset.ENCRYPTION_WITH_FOOTER_KEY = true; -} - -void ColumnCryptoMetaData::__set_ENCRYPTION_WITH_COLUMN_KEY(const EncryptionWithColumnKey& val) { - this->ENCRYPTION_WITH_COLUMN_KEY = val; -__isset.ENCRYPTION_WITH_COLUMN_KEY = true; -} -std::ostream& operator<<(std::ostream& out, const ColumnCryptoMetaData& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ColumnCryptoMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->ENCRYPTION_WITH_FOOTER_KEY.read(iprot); - this->__isset.ENCRYPTION_WITH_FOOTER_KEY = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->ENCRYPTION_WITH_COLUMN_KEY.read(iprot); - this->__isset.ENCRYPTION_WITH_COLUMN_KEY = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t ColumnCryptoMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ColumnCryptoMetaData"); - - if (this->__isset.ENCRYPTION_WITH_FOOTER_KEY) { - xfer += oprot->writeFieldBegin("ENCRYPTION_WITH_FOOTER_KEY", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->ENCRYPTION_WITH_FOOTER_KEY.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.ENCRYPTION_WITH_COLUMN_KEY) { - xfer += oprot->writeFieldBegin("ENCRYPTION_WITH_COLUMN_KEY", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->ENCRYPTION_WITH_COLUMN_KEY.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ColumnCryptoMetaData &a, ColumnCryptoMetaData &b) { - using ::std::swap; - swap(a.ENCRYPTION_WITH_FOOTER_KEY, b.ENCRYPTION_WITH_FOOTER_KEY); - swap(a.ENCRYPTION_WITH_COLUMN_KEY, b.ENCRYPTION_WITH_COLUMN_KEY); - swap(a.__isset, b.__isset); -} - -ColumnCryptoMetaData::ColumnCryptoMetaData(const ColumnCryptoMetaData& other120) { - ENCRYPTION_WITH_FOOTER_KEY = other120.ENCRYPTION_WITH_FOOTER_KEY; - ENCRYPTION_WITH_COLUMN_KEY = other120.ENCRYPTION_WITH_COLUMN_KEY; - __isset = other120.__isset; -} -ColumnCryptoMetaData& ColumnCryptoMetaData::operator=(const ColumnCryptoMetaData& other121) { - ENCRYPTION_WITH_FOOTER_KEY = other121.ENCRYPTION_WITH_FOOTER_KEY; - 
ENCRYPTION_WITH_COLUMN_KEY = other121.ENCRYPTION_WITH_COLUMN_KEY; - __isset = other121.__isset; - return *this; -} -void ColumnCryptoMetaData::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ColumnCryptoMetaData("; - out << "ENCRYPTION_WITH_FOOTER_KEY="; (__isset.ENCRYPTION_WITH_FOOTER_KEY ? (out << to_string(ENCRYPTION_WITH_FOOTER_KEY)) : (out << "")); - out << ", " << "ENCRYPTION_WITH_COLUMN_KEY="; (__isset.ENCRYPTION_WITH_COLUMN_KEY ? (out << to_string(ENCRYPTION_WITH_COLUMN_KEY)) : (out << "")); - out << ")"; -} - - -ColumnChunk::~ColumnChunk() throw() { -} - - -void ColumnChunk::__set_file_path(const std::string& val) { - this->file_path = val; -__isset.file_path = true; -} - -void ColumnChunk::__set_file_offset(const int64_t val) { - this->file_offset = val; -} - -void ColumnChunk::__set_meta_data(const ColumnMetaData& val) { - this->meta_data = val; -__isset.meta_data = true; -} - -void ColumnChunk::__set_offset_index_offset(const int64_t val) { - this->offset_index_offset = val; -__isset.offset_index_offset = true; -} - -void ColumnChunk::__set_offset_index_length(const int32_t val) { - this->offset_index_length = val; -__isset.offset_index_length = true; -} - -void ColumnChunk::__set_column_index_offset(const int64_t val) { - this->column_index_offset = val; -__isset.column_index_offset = true; -} - -void ColumnChunk::__set_column_index_length(const int32_t val) { - this->column_index_length = val; -__isset.column_index_length = true; -} - -void ColumnChunk::__set_crypto_metadata(const ColumnCryptoMetaData& val) { - this->crypto_metadata = val; -__isset.crypto_metadata = true; -} - -void ColumnChunk::__set_encrypted_column_metadata(const std::string& val) { - this->encrypted_column_metadata = val; -__isset.encrypted_column_metadata = true; -} -std::ostream& operator<<(std::ostream& out, const ColumnChunk& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ColumnChunk::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_file_offset = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readString(this->file_path); - this->__isset.file_path = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->file_offset); - isset_file_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->meta_data.read(iprot); - this->__isset.meta_data = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->offset_index_offset); - this->__isset.offset_index_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->offset_index_length); - this->__isset.offset_index_length = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += 
iprot->readI64(this->column_index_offset); - this->__isset.column_index_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->column_index_length); - this->__isset.column_index_length = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->crypto_metadata.read(iprot); - this->__isset.crypto_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 9: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->encrypted_column_metadata); - this->__isset.encrypted_column_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_file_offset) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t ColumnChunk::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ColumnChunk"); - - if (this->__isset.file_path) { - xfer += oprot->writeFieldBegin("file_path", ::apache::thrift::protocol::T_STRING, 1); - xfer += oprot->writeString(this->file_path); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldBegin("file_offset", ::apache::thrift::protocol::T_I64, 2); - xfer += oprot->writeI64(this->file_offset); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.meta_data) { - xfer += oprot->writeFieldBegin("meta_data", ::apache::thrift::protocol::T_STRUCT, 3); - xfer += this->meta_data.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.offset_index_offset) { - xfer += oprot->writeFieldBegin("offset_index_offset", ::apache::thrift::protocol::T_I64, 4); - xfer += oprot->writeI64(this->offset_index_offset); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.offset_index_length) { - xfer += oprot->writeFieldBegin("offset_index_length", ::apache::thrift::protocol::T_I32, 5); - xfer += oprot->writeI32(this->offset_index_length); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.column_index_offset) { - xfer += oprot->writeFieldBegin("column_index_offset", ::apache::thrift::protocol::T_I64, 6); - xfer += oprot->writeI64(this->column_index_offset); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.column_index_length) { - xfer += oprot->writeFieldBegin("column_index_length", ::apache::thrift::protocol::T_I32, 7); - xfer += oprot->writeI32(this->column_index_length); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.crypto_metadata) { - xfer += oprot->writeFieldBegin("crypto_metadata", ::apache::thrift::protocol::T_STRUCT, 8); - xfer += this->crypto_metadata.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.encrypted_column_metadata) { - xfer += oprot->writeFieldBegin("encrypted_column_metadata", ::apache::thrift::protocol::T_STRING, 9); - xfer += oprot->writeBinary(this->encrypted_column_metadata); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ColumnChunk &a, ColumnChunk &b) { - using ::std::swap; - swap(a.file_path, b.file_path); - swap(a.file_offset, b.file_offset); - swap(a.meta_data, b.meta_data); - swap(a.offset_index_offset, b.offset_index_offset); - 
swap(a.offset_index_length, b.offset_index_length); - swap(a.column_index_offset, b.column_index_offset); - swap(a.column_index_length, b.column_index_length); - swap(a.crypto_metadata, b.crypto_metadata); - swap(a.encrypted_column_metadata, b.encrypted_column_metadata); - swap(a.__isset, b.__isset); -} - -ColumnChunk::ColumnChunk(const ColumnChunk& other122) { - file_path = other122.file_path; - file_offset = other122.file_offset; - meta_data = other122.meta_data; - offset_index_offset = other122.offset_index_offset; - offset_index_length = other122.offset_index_length; - column_index_offset = other122.column_index_offset; - column_index_length = other122.column_index_length; - crypto_metadata = other122.crypto_metadata; - encrypted_column_metadata = other122.encrypted_column_metadata; - __isset = other122.__isset; -} -ColumnChunk& ColumnChunk::operator=(const ColumnChunk& other123) { - file_path = other123.file_path; - file_offset = other123.file_offset; - meta_data = other123.meta_data; - offset_index_offset = other123.offset_index_offset; - offset_index_length = other123.offset_index_length; - column_index_offset = other123.column_index_offset; - column_index_length = other123.column_index_length; - crypto_metadata = other123.crypto_metadata; - encrypted_column_metadata = other123.encrypted_column_metadata; - __isset = other123.__isset; - return *this; -} -void ColumnChunk::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ColumnChunk("; - out << "file_path="; (__isset.file_path ? (out << to_string(file_path)) : (out << "")); - out << ", " << "file_offset=" << to_string(file_offset); - out << ", " << "meta_data="; (__isset.meta_data ? (out << to_string(meta_data)) : (out << "")); - out << ", " << "offset_index_offset="; (__isset.offset_index_offset ? (out << to_string(offset_index_offset)) : (out << "")); - out << ", " << "offset_index_length="; (__isset.offset_index_length ? (out << to_string(offset_index_length)) : (out << "")); - out << ", " << "column_index_offset="; (__isset.column_index_offset ? (out << to_string(column_index_offset)) : (out << "")); - out << ", " << "column_index_length="; (__isset.column_index_length ? (out << to_string(column_index_length)) : (out << "")); - out << ", " << "crypto_metadata="; (__isset.crypto_metadata ? (out << to_string(crypto_metadata)) : (out << "")); - out << ", " << "encrypted_column_metadata="; (__isset.encrypted_column_metadata ? 
(out << to_string(encrypted_column_metadata)) : (out << "")); - out << ")"; -} - - -RowGroup::~RowGroup() throw() { -} - - -void RowGroup::__set_columns(const std::vector & val) { - this->columns = val; -} - -void RowGroup::__set_total_byte_size(const int64_t val) { - this->total_byte_size = val; -} - -void RowGroup::__set_num_rows(const int64_t val) { - this->num_rows = val; -} - -void RowGroup::__set_sorting_columns(const std::vector & val) { - this->sorting_columns = val; -__isset.sorting_columns = true; -} - -void RowGroup::__set_file_offset(const int64_t val) { - this->file_offset = val; -__isset.file_offset = true; -} - -void RowGroup::__set_total_compressed_size(const int64_t val) { - this->total_compressed_size = val; -__isset.total_compressed_size = true; -} - -void RowGroup::__set_ordinal(const int16_t val) { - this->ordinal = val; -__isset.ordinal = true; -} -std::ostream& operator<<(std::ostream& out, const RowGroup& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t RowGroup::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_columns = false; - bool isset_total_byte_size = false; - bool isset_num_rows = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->columns.clear(); - uint32_t _size124; - ::apache::thrift::protocol::TType _etype127; - xfer += iprot->readListBegin(_etype127, _size124); - this->columns.resize(_size124); - uint32_t _i128; - for (_i128 = 0; _i128 < _size124; ++_i128) - { - xfer += this->columns[_i128].read(iprot); - } - xfer += iprot->readListEnd(); - } - isset_columns = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->total_byte_size); - isset_total_byte_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->num_rows); - isset_num_rows = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->sorting_columns.clear(); - uint32_t _size129; - ::apache::thrift::protocol::TType _etype132; - xfer += iprot->readListBegin(_etype132, _size129); - this->sorting_columns.resize(_size129); - uint32_t _i133; - for (_i133 = 0; _i133 < _size129; ++_i133) - { - xfer += this->sorting_columns[_i133].read(iprot); - } - xfer += iprot->readListEnd(); - } - this->__isset.sorting_columns = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->file_offset); - this->__isset.file_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->total_compressed_size); - this->__isset.total_compressed_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_I16) { - xfer += iprot->readI16(this->ordinal); - this->__isset.ordinal = true; - } else { - xfer += 
iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_columns) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_total_byte_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_num_rows) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t RowGroup::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("RowGroup"); - - xfer += oprot->writeFieldBegin("columns", ::apache::thrift::protocol::T_LIST, 1); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->columns.size())); - std::vector ::const_iterator _iter134; - for (_iter134 = this->columns.begin(); _iter134 != this->columns.end(); ++_iter134) - { - xfer += (*_iter134).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("total_byte_size", ::apache::thrift::protocol::T_I64, 2); - xfer += oprot->writeI64(this->total_byte_size); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I64, 3); - xfer += oprot->writeI64(this->num_rows); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.sorting_columns) { - xfer += oprot->writeFieldBegin("sorting_columns", ::apache::thrift::protocol::T_LIST, 4); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->sorting_columns.size())); - std::vector ::const_iterator _iter135; - for (_iter135 = this->sorting_columns.begin(); _iter135 != this->sorting_columns.end(); ++_iter135) - { - xfer += (*_iter135).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.file_offset) { - xfer += oprot->writeFieldBegin("file_offset", ::apache::thrift::protocol::T_I64, 5); - xfer += oprot->writeI64(this->file_offset); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.total_compressed_size) { - xfer += oprot->writeFieldBegin("total_compressed_size", ::apache::thrift::protocol::T_I64, 6); - xfer += oprot->writeI64(this->total_compressed_size); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.ordinal) { - xfer += oprot->writeFieldBegin("ordinal", ::apache::thrift::protocol::T_I16, 7); - xfer += oprot->writeI16(this->ordinal); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(RowGroup &a, RowGroup &b) { - using ::std::swap; - swap(a.columns, b.columns); - swap(a.total_byte_size, b.total_byte_size); - swap(a.num_rows, b.num_rows); - swap(a.sorting_columns, b.sorting_columns); - swap(a.file_offset, b.file_offset); - swap(a.total_compressed_size, b.total_compressed_size); - swap(a.ordinal, b.ordinal); - swap(a.__isset, b.__isset); -} - -RowGroup::RowGroup(const RowGroup& other136) { - columns = other136.columns; - total_byte_size = other136.total_byte_size; - num_rows = other136.num_rows; - sorting_columns = other136.sorting_columns; - file_offset = other136.file_offset; - total_compressed_size = other136.total_compressed_size; - ordinal = other136.ordinal; - __isset = other136.__isset; -} -RowGroup& RowGroup::operator=(const RowGroup& other137) { - columns = other137.columns; - total_byte_size = other137.total_byte_size; - 
num_rows = other137.num_rows; - sorting_columns = other137.sorting_columns; - file_offset = other137.file_offset; - total_compressed_size = other137.total_compressed_size; - ordinal = other137.ordinal; - __isset = other137.__isset; - return *this; -} -void RowGroup::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "RowGroup("; - out << "columns=" << to_string(columns); - out << ", " << "total_byte_size=" << to_string(total_byte_size); - out << ", " << "num_rows=" << to_string(num_rows); - out << ", " << "sorting_columns="; (__isset.sorting_columns ? (out << to_string(sorting_columns)) : (out << "")); - out << ", " << "file_offset="; (__isset.file_offset ? (out << to_string(file_offset)) : (out << "")); - out << ", " << "total_compressed_size="; (__isset.total_compressed_size ? (out << to_string(total_compressed_size)) : (out << "")); - out << ", " << "ordinal="; (__isset.ordinal ? (out << to_string(ordinal)) : (out << "")); - out << ")"; -} - - -TypeDefinedOrder::~TypeDefinedOrder() throw() { -} - -std::ostream& operator<<(std::ostream& out, const TypeDefinedOrder& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t TypeDefinedOrder::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - xfer += iprot->skip(ftype); - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t TypeDefinedOrder::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("TypeDefinedOrder"); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(TypeDefinedOrder &a, TypeDefinedOrder &b) { - using ::std::swap; - (void) a; - (void) b; -} - -TypeDefinedOrder::TypeDefinedOrder(const TypeDefinedOrder& other138) { - (void) other138; -} -TypeDefinedOrder& TypeDefinedOrder::operator=(const TypeDefinedOrder& other139) { - (void) other139; - return *this; -} -void TypeDefinedOrder::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "TypeDefinedOrder("; - out << ")"; -} - - -ColumnOrder::~ColumnOrder() throw() { -} - - -void ColumnOrder::__set_TYPE_ORDER(const TypeDefinedOrder& val) { - this->TYPE_ORDER = val; -__isset.TYPE_ORDER = true; -} -std::ostream& operator<<(std::ostream& out, const ColumnOrder& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ColumnOrder::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->TYPE_ORDER.read(iprot); - this->__isset.TYPE_ORDER = true; - } else { - xfer += iprot->skip(ftype); - } 
- break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t ColumnOrder::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ColumnOrder"); - - if (this->__isset.TYPE_ORDER) { - xfer += oprot->writeFieldBegin("TYPE_ORDER", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->TYPE_ORDER.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ColumnOrder &a, ColumnOrder &b) { - using ::std::swap; - swap(a.TYPE_ORDER, b.TYPE_ORDER); - swap(a.__isset, b.__isset); -} - -ColumnOrder::ColumnOrder(const ColumnOrder& other140) { - TYPE_ORDER = other140.TYPE_ORDER; - __isset = other140.__isset; -} -ColumnOrder& ColumnOrder::operator=(const ColumnOrder& other141) { - TYPE_ORDER = other141.TYPE_ORDER; - __isset = other141.__isset; - return *this; -} -void ColumnOrder::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ColumnOrder("; - out << "TYPE_ORDER="; (__isset.TYPE_ORDER ? (out << to_string(TYPE_ORDER)) : (out << "")); - out << ")"; -} - - -PageLocation::~PageLocation() throw() { -} - - -void PageLocation::__set_offset(const int64_t val) { - this->offset = val; -} - -void PageLocation::__set_compressed_page_size(const int32_t val) { - this->compressed_page_size = val; -} - -void PageLocation::__set_first_row_index(const int64_t val) { - this->first_row_index = val; -} -std::ostream& operator<<(std::ostream& out, const PageLocation& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t PageLocation::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_offset = false; - bool isset_compressed_page_size = false; - bool isset_first_row_index = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->offset); - isset_offset = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->compressed_page_size); - isset_compressed_page_size = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->first_row_index); - isset_first_row_index = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_offset) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_compressed_page_size) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_first_row_index) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t PageLocation::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker 
tracker(*oprot); - xfer += oprot->writeStructBegin("PageLocation"); - - xfer += oprot->writeFieldBegin("offset", ::apache::thrift::protocol::T_I64, 1); - xfer += oprot->writeI64(this->offset); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("compressed_page_size", ::apache::thrift::protocol::T_I32, 2); - xfer += oprot->writeI32(this->compressed_page_size); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("first_row_index", ::apache::thrift::protocol::T_I64, 3); - xfer += oprot->writeI64(this->first_row_index); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(PageLocation &a, PageLocation &b) { - using ::std::swap; - swap(a.offset, b.offset); - swap(a.compressed_page_size, b.compressed_page_size); - swap(a.first_row_index, b.first_row_index); -} - -PageLocation::PageLocation(const PageLocation& other142) { - offset = other142.offset; - compressed_page_size = other142.compressed_page_size; - first_row_index = other142.first_row_index; -} -PageLocation& PageLocation::operator=(const PageLocation& other143) { - offset = other143.offset; - compressed_page_size = other143.compressed_page_size; - first_row_index = other143.first_row_index; - return *this; -} -void PageLocation::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "PageLocation("; - out << "offset=" << to_string(offset); - out << ", " << "compressed_page_size=" << to_string(compressed_page_size); - out << ", " << "first_row_index=" << to_string(first_row_index); - out << ")"; -} - - -OffsetIndex::~OffsetIndex() throw() { -} - - -void OffsetIndex::__set_page_locations(const std::vector & val) { - this->page_locations = val; -} -std::ostream& operator<<(std::ostream& out, const OffsetIndex& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t OffsetIndex::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_page_locations = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->page_locations.clear(); - uint32_t _size144; - ::apache::thrift::protocol::TType _etype147; - xfer += iprot->readListBegin(_etype147, _size144); - this->page_locations.resize(_size144); - uint32_t _i148; - for (_i148 = 0; _i148 < _size144; ++_i148) - { - xfer += this->page_locations[_i148].read(iprot); - } - xfer += iprot->readListEnd(); - } - isset_page_locations = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_page_locations) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t OffsetIndex::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("OffsetIndex"); - - xfer += oprot->writeFieldBegin("page_locations", ::apache::thrift::protocol::T_LIST, 1); - { - xfer += 
oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->page_locations.size())); - std::vector ::const_iterator _iter149; - for (_iter149 = this->page_locations.begin(); _iter149 != this->page_locations.end(); ++_iter149) - { - xfer += (*_iter149).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(OffsetIndex &a, OffsetIndex &b) { - using ::std::swap; - swap(a.page_locations, b.page_locations); -} - -OffsetIndex::OffsetIndex(const OffsetIndex& other150) { - page_locations = other150.page_locations; -} -OffsetIndex& OffsetIndex::operator=(const OffsetIndex& other151) { - page_locations = other151.page_locations; - return *this; -} -void OffsetIndex::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "OffsetIndex("; - out << "page_locations=" << to_string(page_locations); - out << ")"; -} - - -ColumnIndex::~ColumnIndex() throw() { -} - - -void ColumnIndex::__set_null_pages(const std::vector & val) { - this->null_pages = val; -} - -void ColumnIndex::__set_min_values(const std::vector & val) { - this->min_values = val; -} - -void ColumnIndex::__set_max_values(const std::vector & val) { - this->max_values = val; -} - -void ColumnIndex::__set_boundary_order(const BoundaryOrder::type val) { - this->boundary_order = val; -} - -void ColumnIndex::__set_null_counts(const std::vector & val) { - this->null_counts = val; -__isset.null_counts = true; -} -std::ostream& operator<<(std::ostream& out, const ColumnIndex& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t ColumnIndex::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_null_pages = false; - bool isset_min_values = false; - bool isset_max_values = false; - bool isset_boundary_order = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->null_pages.clear(); - uint32_t _size152; - ::apache::thrift::protocol::TType _etype155; - xfer += iprot->readListBegin(_etype155, _size152); - this->null_pages.resize(_size152); - uint32_t _i156; - for (_i156 = 0; _i156 < _size152; ++_i156) - { - xfer += iprot->readBool(this->null_pages[_i156]); - } - xfer += iprot->readListEnd(); - } - isset_null_pages = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->min_values.clear(); - uint32_t _size157; - ::apache::thrift::protocol::TType _etype160; - xfer += iprot->readListBegin(_etype160, _size157); - this->min_values.resize(_size157); - uint32_t _i161; - for (_i161 = 0; _i161 < _size157; ++_i161) - { - xfer += iprot->readBinary(this->min_values[_i161]); - } - xfer += iprot->readListEnd(); - } - isset_min_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->max_values.clear(); - uint32_t _size162; - ::apache::thrift::protocol::TType _etype165; - xfer += iprot->readListBegin(_etype165, _size162); - this->max_values.resize(_size162); - 
uint32_t _i166; - for (_i166 = 0; _i166 < _size162; ++_i166) - { - xfer += iprot->readBinary(this->max_values[_i166]); - } - xfer += iprot->readListEnd(); - } - isset_max_values = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_I32) { - int32_t ecast167; - xfer += iprot->readI32(ecast167); - this->boundary_order = (BoundaryOrder::type)ecast167; - isset_boundary_order = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->null_counts.clear(); - uint32_t _size168; - ::apache::thrift::protocol::TType _etype171; - xfer += iprot->readListBegin(_etype171, _size168); - this->null_counts.resize(_size168); - uint32_t _i172; - for (_i172 = 0; _i172 < _size168; ++_i172) - { - xfer += iprot->readI64(this->null_counts[_i172]); - } - xfer += iprot->readListEnd(); - } - this->__isset.null_counts = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_null_pages) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_min_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_max_values) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_boundary_order) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t ColumnIndex::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("ColumnIndex"); - - xfer += oprot->writeFieldBegin("null_pages", ::apache::thrift::protocol::T_LIST, 1); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_BOOL, static_cast(this->null_pages.size())); - std::vector ::const_iterator _iter173; - for (_iter173 = this->null_pages.begin(); _iter173 != this->null_pages.end(); ++_iter173) - { - xfer += oprot->writeBool((*_iter173)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("min_values", ::apache::thrift::protocol::T_LIST, 2); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->min_values.size())); - std::vector ::const_iterator _iter174; - for (_iter174 = this->min_values.begin(); _iter174 != this->min_values.end(); ++_iter174) - { - xfer += oprot->writeBinary((*_iter174)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("max_values", ::apache::thrift::protocol::T_LIST, 3); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->max_values.size())); - std::vector ::const_iterator _iter175; - for (_iter175 = this->max_values.begin(); _iter175 != this->max_values.end(); ++_iter175) - { - xfer += oprot->writeBinary((*_iter175)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("boundary_order", ::apache::thrift::protocol::T_I32, 4); - xfer += oprot->writeI32((int32_t)this->boundary_order); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.null_counts) { - xfer += oprot->writeFieldBegin("null_counts", ::apache::thrift::protocol::T_LIST, 5); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_I64, static_cast(this->null_counts.size())); - std::vector ::const_iterator 
_iter176; - for (_iter176 = this->null_counts.begin(); _iter176 != this->null_counts.end(); ++_iter176) - { - xfer += oprot->writeI64((*_iter176)); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(ColumnIndex &a, ColumnIndex &b) { - using ::std::swap; - swap(a.null_pages, b.null_pages); - swap(a.min_values, b.min_values); - swap(a.max_values, b.max_values); - swap(a.boundary_order, b.boundary_order); - swap(a.null_counts, b.null_counts); - swap(a.__isset, b.__isset); -} - -ColumnIndex::ColumnIndex(const ColumnIndex& other177) { - null_pages = other177.null_pages; - min_values = other177.min_values; - max_values = other177.max_values; - boundary_order = other177.boundary_order; - null_counts = other177.null_counts; - __isset = other177.__isset; -} -ColumnIndex& ColumnIndex::operator=(const ColumnIndex& other178) { - null_pages = other178.null_pages; - min_values = other178.min_values; - max_values = other178.max_values; - boundary_order = other178.boundary_order; - null_counts = other178.null_counts; - __isset = other178.__isset; - return *this; -} -void ColumnIndex::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "ColumnIndex("; - out << "null_pages=" << to_string(null_pages); - out << ", " << "min_values=" << to_string(min_values); - out << ", " << "max_values=" << to_string(max_values); - out << ", " << "boundary_order=" << to_string(boundary_order); - out << ", " << "null_counts="; (__isset.null_counts ? (out << to_string(null_counts)) : (out << "")); - out << ")"; -} - - -AesGcmV1::~AesGcmV1() throw() { -} - - -void AesGcmV1::__set_aad_prefix(const std::string& val) { - this->aad_prefix = val; -__isset.aad_prefix = true; -} - -void AesGcmV1::__set_aad_file_unique(const std::string& val) { - this->aad_file_unique = val; -__isset.aad_file_unique = true; -} - -void AesGcmV1::__set_supply_aad_prefix(const bool val) { - this->supply_aad_prefix = val; -__isset.supply_aad_prefix = true; -} -std::ostream& operator<<(std::ostream& out, const AesGcmV1& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t AesGcmV1::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->aad_prefix); - this->__isset.aad_prefix = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->aad_file_unique); - this->__isset.aad_file_unique = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->supply_aad_prefix); - this->__isset.supply_aad_prefix = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t AesGcmV1::write(::apache::thrift::protocol::TProtocol* oprot) const { - 
uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("AesGcmV1"); - - if (this->__isset.aad_prefix) { - xfer += oprot->writeFieldBegin("aad_prefix", ::apache::thrift::protocol::T_STRING, 1); - xfer += oprot->writeBinary(this->aad_prefix); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.aad_file_unique) { - xfer += oprot->writeFieldBegin("aad_file_unique", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeBinary(this->aad_file_unique); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.supply_aad_prefix) { - xfer += oprot->writeFieldBegin("supply_aad_prefix", ::apache::thrift::protocol::T_BOOL, 3); - xfer += oprot->writeBool(this->supply_aad_prefix); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(AesGcmV1 &a, AesGcmV1 &b) { - using ::std::swap; - swap(a.aad_prefix, b.aad_prefix); - swap(a.aad_file_unique, b.aad_file_unique); - swap(a.supply_aad_prefix, b.supply_aad_prefix); - swap(a.__isset, b.__isset); -} - -AesGcmV1::AesGcmV1(const AesGcmV1& other179) { - aad_prefix = other179.aad_prefix; - aad_file_unique = other179.aad_file_unique; - supply_aad_prefix = other179.supply_aad_prefix; - __isset = other179.__isset; -} -AesGcmV1& AesGcmV1::operator=(const AesGcmV1& other180) { - aad_prefix = other180.aad_prefix; - aad_file_unique = other180.aad_file_unique; - supply_aad_prefix = other180.supply_aad_prefix; - __isset = other180.__isset; - return *this; -} -void AesGcmV1::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "AesGcmV1("; - out << "aad_prefix="; (__isset.aad_prefix ? (out << to_string(aad_prefix)) : (out << "")); - out << ", " << "aad_file_unique="; (__isset.aad_file_unique ? (out << to_string(aad_file_unique)) : (out << "")); - out << ", " << "supply_aad_prefix="; (__isset.supply_aad_prefix ? 
(out << to_string(supply_aad_prefix)) : (out << "")); - out << ")"; -} - - -AesGcmCtrV1::~AesGcmCtrV1() throw() { -} - - -void AesGcmCtrV1::__set_aad_prefix(const std::string& val) { - this->aad_prefix = val; -__isset.aad_prefix = true; -} - -void AesGcmCtrV1::__set_aad_file_unique(const std::string& val) { - this->aad_file_unique = val; -__isset.aad_file_unique = true; -} - -void AesGcmCtrV1::__set_supply_aad_prefix(const bool val) { - this->supply_aad_prefix = val; -__isset.supply_aad_prefix = true; -} -std::ostream& operator<<(std::ostream& out, const AesGcmCtrV1& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t AesGcmCtrV1::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->aad_prefix); - this->__isset.aad_prefix = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->aad_file_unique); - this->__isset.aad_file_unique = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_BOOL) { - xfer += iprot->readBool(this->supply_aad_prefix); - this->__isset.supply_aad_prefix = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t AesGcmCtrV1::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("AesGcmCtrV1"); - - if (this->__isset.aad_prefix) { - xfer += oprot->writeFieldBegin("aad_prefix", ::apache::thrift::protocol::T_STRING, 1); - xfer += oprot->writeBinary(this->aad_prefix); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.aad_file_unique) { - xfer += oprot->writeFieldBegin("aad_file_unique", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeBinary(this->aad_file_unique); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.supply_aad_prefix) { - xfer += oprot->writeFieldBegin("supply_aad_prefix", ::apache::thrift::protocol::T_BOOL, 3); - xfer += oprot->writeBool(this->supply_aad_prefix); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(AesGcmCtrV1 &a, AesGcmCtrV1 &b) { - using ::std::swap; - swap(a.aad_prefix, b.aad_prefix); - swap(a.aad_file_unique, b.aad_file_unique); - swap(a.supply_aad_prefix, b.supply_aad_prefix); - swap(a.__isset, b.__isset); -} - -AesGcmCtrV1::AesGcmCtrV1(const AesGcmCtrV1& other181) { - aad_prefix = other181.aad_prefix; - aad_file_unique = other181.aad_file_unique; - supply_aad_prefix = other181.supply_aad_prefix; - __isset = other181.__isset; -} -AesGcmCtrV1& AesGcmCtrV1::operator=(const AesGcmCtrV1& other182) { - aad_prefix = other182.aad_prefix; - aad_file_unique = other182.aad_file_unique; - supply_aad_prefix = other182.supply_aad_prefix; - __isset = 
other182.__isset; - return *this; -} -void AesGcmCtrV1::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "AesGcmCtrV1("; - out << "aad_prefix="; (__isset.aad_prefix ? (out << to_string(aad_prefix)) : (out << "")); - out << ", " << "aad_file_unique="; (__isset.aad_file_unique ? (out << to_string(aad_file_unique)) : (out << "")); - out << ", " << "supply_aad_prefix="; (__isset.supply_aad_prefix ? (out << to_string(supply_aad_prefix)) : (out << "")); - out << ")"; -} - - -EncryptionAlgorithm::~EncryptionAlgorithm() throw() { -} - - -void EncryptionAlgorithm::__set_AES_GCM_V1(const AesGcmV1& val) { - this->AES_GCM_V1 = val; -__isset.AES_GCM_V1 = true; -} - -void EncryptionAlgorithm::__set_AES_GCM_CTR_V1(const AesGcmCtrV1& val) { - this->AES_GCM_CTR_V1 = val; -__isset.AES_GCM_CTR_V1 = true; -} -std::ostream& operator<<(std::ostream& out, const EncryptionAlgorithm& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t EncryptionAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->AES_GCM_V1.read(iprot); - this->__isset.AES_GCM_V1 = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->AES_GCM_CTR_V1.read(iprot); - this->__isset.AES_GCM_CTR_V1 = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - return xfer; -} - -uint32_t EncryptionAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("EncryptionAlgorithm"); - - if (this->__isset.AES_GCM_V1) { - xfer += oprot->writeFieldBegin("AES_GCM_V1", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->AES_GCM_V1.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.AES_GCM_CTR_V1) { - xfer += oprot->writeFieldBegin("AES_GCM_CTR_V1", ::apache::thrift::protocol::T_STRUCT, 2); - xfer += this->AES_GCM_CTR_V1.write(oprot); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(EncryptionAlgorithm &a, EncryptionAlgorithm &b) { - using ::std::swap; - swap(a.AES_GCM_V1, b.AES_GCM_V1); - swap(a.AES_GCM_CTR_V1, b.AES_GCM_CTR_V1); - swap(a.__isset, b.__isset); -} - -EncryptionAlgorithm::EncryptionAlgorithm(const EncryptionAlgorithm& other183) { - AES_GCM_V1 = other183.AES_GCM_V1; - AES_GCM_CTR_V1 = other183.AES_GCM_CTR_V1; - __isset = other183.__isset; -} -EncryptionAlgorithm& EncryptionAlgorithm::operator=(const EncryptionAlgorithm& other184) { - AES_GCM_V1 = other184.AES_GCM_V1; - AES_GCM_CTR_V1 = other184.AES_GCM_CTR_V1; - __isset = other184.__isset; - return *this; -} -void EncryptionAlgorithm::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "EncryptionAlgorithm("; - out << "AES_GCM_V1="; (__isset.AES_GCM_V1 ? 
(out << to_string(AES_GCM_V1)) : (out << "")); - out << ", " << "AES_GCM_CTR_V1="; (__isset.AES_GCM_CTR_V1 ? (out << to_string(AES_GCM_CTR_V1)) : (out << "")); - out << ")"; -} - - -FileMetaData::~FileMetaData() throw() { -} - - -void FileMetaData::__set_version(const int32_t val) { - this->version = val; -} - -void FileMetaData::__set_schema(const std::vector & val) { - this->schema = val; -} - -void FileMetaData::__set_num_rows(const int64_t val) { - this->num_rows = val; -} - -void FileMetaData::__set_row_groups(const std::vector & val) { - this->row_groups = val; -} - -void FileMetaData::__set_key_value_metadata(const std::vector & val) { - this->key_value_metadata = val; -__isset.key_value_metadata = true; -} - -void FileMetaData::__set_created_by(const std::string& val) { - this->created_by = val; -__isset.created_by = true; -} - -void FileMetaData::__set_column_orders(const std::vector & val) { - this->column_orders = val; -__isset.column_orders = true; -} - -void FileMetaData::__set_encryption_algorithm(const EncryptionAlgorithm& val) { - this->encryption_algorithm = val; -__isset.encryption_algorithm = true; -} - -void FileMetaData::__set_footer_signing_key_metadata(const std::string& val) { - this->footer_signing_key_metadata = val; -__isset.footer_signing_key_metadata = true; -} -std::ostream& operator<<(std::ostream& out, const FileMetaData& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t FileMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_version = false; - bool isset_schema = false; - bool isset_num_rows = false; - bool isset_row_groups = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_I32) { - xfer += iprot->readI32(this->version); - isset_version = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->schema.clear(); - uint32_t _size185; - ::apache::thrift::protocol::TType _etype188; - xfer += iprot->readListBegin(_etype188, _size185); - this->schema.resize(_size185); - uint32_t _i189; - for (_i189 = 0; _i189 < _size185; ++_i189) - { - xfer += this->schema[_i189].read(iprot); - } - xfer += iprot->readListEnd(); - } - isset_schema = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 3: - if (ftype == ::apache::thrift::protocol::T_I64) { - xfer += iprot->readI64(this->num_rows); - isset_num_rows = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 4: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->row_groups.clear(); - uint32_t _size190; - ::apache::thrift::protocol::TType _etype193; - xfer += iprot->readListBegin(_etype193, _size190); - this->row_groups.resize(_size190); - uint32_t _i194; - for (_i194 = 0; _i194 < _size190; ++_i194) - { - xfer += this->row_groups[_i194].read(iprot); - } - xfer += iprot->readListEnd(); - } - isset_row_groups = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 5: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->key_value_metadata.clear(); - uint32_t _size195; - ::apache::thrift::protocol::TType _etype198; - 
xfer += iprot->readListBegin(_etype198, _size195); - this->key_value_metadata.resize(_size195); - uint32_t _i199; - for (_i199 = 0; _i199 < _size195; ++_i199) - { - xfer += this->key_value_metadata[_i199].read(iprot); - } - xfer += iprot->readListEnd(); - } - this->__isset.key_value_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 6: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readString(this->created_by); - this->__isset.created_by = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 7: - if (ftype == ::apache::thrift::protocol::T_LIST) { - { - this->column_orders.clear(); - uint32_t _size200; - ::apache::thrift::protocol::TType _etype203; - xfer += iprot->readListBegin(_etype203, _size200); - this->column_orders.resize(_size200); - uint32_t _i204; - for (_i204 = 0; _i204 < _size200; ++_i204) - { - xfer += this->column_orders[_i204].read(iprot); - } - xfer += iprot->readListEnd(); - } - this->__isset.column_orders = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 8: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->encryption_algorithm.read(iprot); - this->__isset.encryption_algorithm = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 9: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->footer_signing_key_metadata); - this->__isset.footer_signing_key_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_version) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_schema) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_num_rows) - throw TProtocolException(TProtocolException::INVALID_DATA); - if (!isset_row_groups) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t FileMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("FileMetaData"); - - xfer += oprot->writeFieldBegin("version", ::apache::thrift::protocol::T_I32, 1); - xfer += oprot->writeI32(this->version); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("schema", ::apache::thrift::protocol::T_LIST, 2); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->schema.size())); - std::vector ::const_iterator _iter205; - for (_iter205 = this->schema.begin(); _iter205 != this->schema.end(); ++_iter205) - { - xfer += (*_iter205).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I64, 3); - xfer += oprot->writeI64(this->num_rows); - xfer += oprot->writeFieldEnd(); - - xfer += oprot->writeFieldBegin("row_groups", ::apache::thrift::protocol::T_LIST, 4); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->row_groups.size())); - std::vector ::const_iterator _iter206; - for (_iter206 = this->row_groups.begin(); _iter206 != this->row_groups.end(); ++_iter206) - { - xfer += (*_iter206).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - - if (this->__isset.key_value_metadata) { - xfer += oprot->writeFieldBegin("key_value_metadata", 
::apache::thrift::protocol::T_LIST, 5); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->key_value_metadata.size())); - std::vector ::const_iterator _iter207; - for (_iter207 = this->key_value_metadata.begin(); _iter207 != this->key_value_metadata.end(); ++_iter207) - { - xfer += (*_iter207).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.created_by) { - xfer += oprot->writeFieldBegin("created_by", ::apache::thrift::protocol::T_STRING, 6); - xfer += oprot->writeString(this->created_by); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.column_orders) { - xfer += oprot->writeFieldBegin("column_orders", ::apache::thrift::protocol::T_LIST, 7); - { - xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->column_orders.size())); - std::vector ::const_iterator _iter208; - for (_iter208 = this->column_orders.begin(); _iter208 != this->column_orders.end(); ++_iter208) - { - xfer += (*_iter208).write(oprot); - } - xfer += oprot->writeListEnd(); - } - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.encryption_algorithm) { - xfer += oprot->writeFieldBegin("encryption_algorithm", ::apache::thrift::protocol::T_STRUCT, 8); - xfer += this->encryption_algorithm.write(oprot); - xfer += oprot->writeFieldEnd(); - } - if (this->__isset.footer_signing_key_metadata) { - xfer += oprot->writeFieldBegin("footer_signing_key_metadata", ::apache::thrift::protocol::T_STRING, 9); - xfer += oprot->writeBinary(this->footer_signing_key_metadata); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(FileMetaData &a, FileMetaData &b) { - using ::std::swap; - swap(a.version, b.version); - swap(a.schema, b.schema); - swap(a.num_rows, b.num_rows); - swap(a.row_groups, b.row_groups); - swap(a.key_value_metadata, b.key_value_metadata); - swap(a.created_by, b.created_by); - swap(a.column_orders, b.column_orders); - swap(a.encryption_algorithm, b.encryption_algorithm); - swap(a.footer_signing_key_metadata, b.footer_signing_key_metadata); - swap(a.__isset, b.__isset); -} - -FileMetaData::FileMetaData(const FileMetaData& other209) { - version = other209.version; - schema = other209.schema; - num_rows = other209.num_rows; - row_groups = other209.row_groups; - key_value_metadata = other209.key_value_metadata; - created_by = other209.created_by; - column_orders = other209.column_orders; - encryption_algorithm = other209.encryption_algorithm; - footer_signing_key_metadata = other209.footer_signing_key_metadata; - __isset = other209.__isset; -} -FileMetaData& FileMetaData::operator=(const FileMetaData& other210) { - version = other210.version; - schema = other210.schema; - num_rows = other210.num_rows; - row_groups = other210.row_groups; - key_value_metadata = other210.key_value_metadata; - created_by = other210.created_by; - column_orders = other210.column_orders; - encryption_algorithm = other210.encryption_algorithm; - footer_signing_key_metadata = other210.footer_signing_key_metadata; - __isset = other210.__isset; - return *this; -} -void FileMetaData::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "FileMetaData("; - out << "version=" << to_string(version); - out << ", " << "schema=" << to_string(schema); - out << ", " << "num_rows=" << to_string(num_rows); - out << ", " << "row_groups=" << to_string(row_groups); - out << ", " << "key_value_metadata="; 
(__isset.key_value_metadata ? (out << to_string(key_value_metadata)) : (out << "")); - out << ", " << "created_by="; (__isset.created_by ? (out << to_string(created_by)) : (out << "")); - out << ", " << "column_orders="; (__isset.column_orders ? (out << to_string(column_orders)) : (out << "")); - out << ", " << "encryption_algorithm="; (__isset.encryption_algorithm ? (out << to_string(encryption_algorithm)) : (out << "")); - out << ", " << "footer_signing_key_metadata="; (__isset.footer_signing_key_metadata ? (out << to_string(footer_signing_key_metadata)) : (out << "")); - out << ")"; -} - - -FileCryptoMetaData::~FileCryptoMetaData() throw() { -} - - -void FileCryptoMetaData::__set_encryption_algorithm(const EncryptionAlgorithm& val) { - this->encryption_algorithm = val; -} - -void FileCryptoMetaData::__set_key_metadata(const std::string& val) { - this->key_metadata = val; -__isset.key_metadata = true; -} -std::ostream& operator<<(std::ostream& out, const FileCryptoMetaData& obj) -{ - obj.printTo(out); - return out; -} - - -uint32_t FileCryptoMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { - - ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot); - uint32_t xfer = 0; - std::string fname; - ::apache::thrift::protocol::TType ftype; - int16_t fid; - - xfer += iprot->readStructBegin(fname); - - using ::apache::thrift::protocol::TProtocolException; - - bool isset_encryption_algorithm = false; - - while (true) - { - xfer += iprot->readFieldBegin(fname, ftype, fid); - if (ftype == ::apache::thrift::protocol::T_STOP) { - break; - } - switch (fid) - { - case 1: - if (ftype == ::apache::thrift::protocol::T_STRUCT) { - xfer += this->encryption_algorithm.read(iprot); - isset_encryption_algorithm = true; - } else { - xfer += iprot->skip(ftype); - } - break; - case 2: - if (ftype == ::apache::thrift::protocol::T_STRING) { - xfer += iprot->readBinary(this->key_metadata); - this->__isset.key_metadata = true; - } else { - xfer += iprot->skip(ftype); - } - break; - default: - xfer += iprot->skip(ftype); - break; - } - xfer += iprot->readFieldEnd(); - } - - xfer += iprot->readStructEnd(); - - if (!isset_encryption_algorithm) - throw TProtocolException(TProtocolException::INVALID_DATA); - return xfer; -} - -uint32_t FileCryptoMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const { - uint32_t xfer = 0; - ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot); - xfer += oprot->writeStructBegin("FileCryptoMetaData"); - - xfer += oprot->writeFieldBegin("encryption_algorithm", ::apache::thrift::protocol::T_STRUCT, 1); - xfer += this->encryption_algorithm.write(oprot); - xfer += oprot->writeFieldEnd(); - - if (this->__isset.key_metadata) { - xfer += oprot->writeFieldBegin("key_metadata", ::apache::thrift::protocol::T_STRING, 2); - xfer += oprot->writeBinary(this->key_metadata); - xfer += oprot->writeFieldEnd(); - } - xfer += oprot->writeFieldStop(); - xfer += oprot->writeStructEnd(); - return xfer; -} - -void swap(FileCryptoMetaData &a, FileCryptoMetaData &b) { - using ::std::swap; - swap(a.encryption_algorithm, b.encryption_algorithm); - swap(a.key_metadata, b.key_metadata); - swap(a.__isset, b.__isset); -} - -FileCryptoMetaData::FileCryptoMetaData(const FileCryptoMetaData& other211) { - encryption_algorithm = other211.encryption_algorithm; - key_metadata = other211.key_metadata; - __isset = other211.__isset; -} -FileCryptoMetaData& FileCryptoMetaData::operator=(const FileCryptoMetaData& other212) { - encryption_algorithm = other212.encryption_algorithm; - 
key_metadata = other212.key_metadata; - __isset = other212.__isset; - return *this; -} -void FileCryptoMetaData::printTo(std::ostream& out) const { - using ::apache::thrift::to_string; - out << "FileCryptoMetaData("; - out << "encryption_algorithm=" << to_string(encryption_algorithm); - out << ", " << "key_metadata="; (__isset.key_metadata ? (out << to_string(key_metadata)) : (out << "")); - out << ")"; -} - -}} // namespace diff --git a/third_party/parquet/parquet_types.h b/third_party/parquet/parquet_types.h deleted file mode 100644 index 78ade1312..000000000 --- a/third_party/parquet/parquet_types.h +++ /dev/null @@ -1,2901 +0,0 @@ -/** - * Autogenerated by Thrift Compiler (0.12.0) - * - * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING - * @generated - */ -#ifndef parquet_TYPES_H -#define parquet_TYPES_H - -#include - -#include -#include -#include -#include -#include - -#include - -#include "parquet/windows_compatibility.h" - -namespace parquet { namespace format { - -struct Type { - enum type { - BOOLEAN = 0, - INT32 = 1, - INT64 = 2, - INT96 = 3, - FLOAT = 4, - DOUBLE = 5, - BYTE_ARRAY = 6, - FIXED_LEN_BYTE_ARRAY = 7 - }; -}; - -extern const std::map _Type_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const Type::type& val); - -struct ConvertedType { - enum type { - UTF8 = 0, - MAP = 1, - MAP_KEY_VALUE = 2, - LIST = 3, - ENUM = 4, - DECIMAL = 5, - DATE = 6, - TIME_MILLIS = 7, - TIME_MICROS = 8, - TIMESTAMP_MILLIS = 9, - TIMESTAMP_MICROS = 10, - UINT_8 = 11, - UINT_16 = 12, - UINT_32 = 13, - UINT_64 = 14, - INT_8 = 15, - INT_16 = 16, - INT_32 = 17, - INT_64 = 18, - JSON = 19, - BSON = 20, - INTERVAL = 21 - }; -}; - -extern const std::map _ConvertedType_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val); - -struct FieldRepetitionType { - enum type { - REQUIRED = 0, - OPTIONAL = 1, - REPEATED = 2 - }; -}; - -extern const std::map _FieldRepetitionType_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const FieldRepetitionType::type& val); - -struct Encoding { - enum type { - PLAIN = 0, - PLAIN_DICTIONARY = 2, - RLE = 3, - BIT_PACKED = 4, - DELTA_BINARY_PACKED = 5, - DELTA_LENGTH_BYTE_ARRAY = 6, - DELTA_BYTE_ARRAY = 7, - RLE_DICTIONARY = 8, - BYTE_STREAM_SPLIT = 9 - }; -}; - -extern const std::map _Encoding_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const Encoding::type& val); - -struct CompressionCodec { - enum type { - UNCOMPRESSED = 0, - SNAPPY = 1, - GZIP = 2, - LZO = 3, - BROTLI = 4, - LZ4 = 5, - ZSTD = 6 - }; -}; - -extern const std::map _CompressionCodec_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const CompressionCodec::type& val); - -struct PageType { - enum type { - DATA_PAGE = 0, - INDEX_PAGE = 1, - DICTIONARY_PAGE = 2, - DATA_PAGE_V2 = 3 - }; -}; - -extern const std::map _PageType_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const PageType::type& val); - -struct BoundaryOrder { - enum type { - UNORDERED = 0, - ASCENDING = 1, - DESCENDING = 2 - }; -}; - -extern const std::map _BoundaryOrder_VALUES_TO_NAMES; - -std::ostream& operator<<(std::ostream& out, const BoundaryOrder::type& val); - -class Statistics; - -class StringType; - -class UUIDType; - -class MapType; - -class ListType; - -class EnumType; - -class DateType; - -class NullType; - -class DecimalType; - -class MilliSeconds; - -class MicroSeconds; - -class NanoSeconds; - -class TimeUnit; - -class TimestampType; - -class TimeType; - -class IntType; - -class JsonType; - 
-class BsonType; - -class LogicalType; - -class SchemaElement; - -class DataPageHeader; - -class IndexPageHeader; - -class DictionaryPageHeader; - -class DataPageHeaderV2; - -class SplitBlockAlgorithm; - -class BloomFilterAlgorithm; - -class XxHash; - -class BloomFilterHash; - -class Uncompressed; - -class BloomFilterCompression; - -class BloomFilterHeader; - -class PageHeader; - -class KeyValue; - -class SortingColumn; - -class PageEncodingStats; - -class ColumnMetaData; - -class EncryptionWithFooterKey; - -class EncryptionWithColumnKey; - -class ColumnCryptoMetaData; - -class ColumnChunk; - -class RowGroup; - -class TypeDefinedOrder; - -class ColumnOrder; - -class PageLocation; - -class OffsetIndex; - -class ColumnIndex; - -class AesGcmV1; - -class AesGcmCtrV1; - -class EncryptionAlgorithm; - -class FileMetaData; - -class FileCryptoMetaData; - -typedef struct _Statistics__isset { - _Statistics__isset() : max(false), min(false), null_count(false), distinct_count(false), max_value(false), min_value(false) {} - bool max :1; - bool min :1; - bool null_count :1; - bool distinct_count :1; - bool max_value :1; - bool min_value :1; -} _Statistics__isset; - -class Statistics : public virtual ::apache::thrift::TBase { - public: - - Statistics(const Statistics&); - Statistics& operator=(const Statistics&); - Statistics() : max(), min(), null_count(0), distinct_count(0), max_value(), min_value() { - } - - virtual ~Statistics() throw(); - std::string max; - std::string min; - int64_t null_count; - int64_t distinct_count; - std::string max_value; - std::string min_value; - - _Statistics__isset __isset; - - void __set_max(const std::string& val); - - void __set_min(const std::string& val); - - void __set_null_count(const int64_t val); - - void __set_distinct_count(const int64_t val); - - void __set_max_value(const std::string& val); - - void __set_min_value(const std::string& val); - - bool operator == (const Statistics & rhs) const - { - if (__isset.max != rhs.__isset.max) - return false; - else if (__isset.max && !(max == rhs.max)) - return false; - if (__isset.min != rhs.__isset.min) - return false; - else if (__isset.min && !(min == rhs.min)) - return false; - if (__isset.null_count != rhs.__isset.null_count) - return false; - else if (__isset.null_count && !(null_count == rhs.null_count)) - return false; - if (__isset.distinct_count != rhs.__isset.distinct_count) - return false; - else if (__isset.distinct_count && !(distinct_count == rhs.distinct_count)) - return false; - if (__isset.max_value != rhs.__isset.max_value) - return false; - else if (__isset.max_value && !(max_value == rhs.max_value)) - return false; - if (__isset.min_value != rhs.__isset.min_value) - return false; - else if (__isset.min_value && !(min_value == rhs.min_value)) - return false; - return true; - } - bool operator != (const Statistics &rhs) const { - return !(*this == rhs); - } - - bool operator < (const Statistics & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(Statistics &a, Statistics &b); - -std::ostream& operator<<(std::ostream& out, const Statistics& obj); - - -class StringType : public virtual ::apache::thrift::TBase { - public: - - StringType(const StringType&); - StringType& operator=(const StringType&); - StringType() { - } - - virtual ~StringType() throw(); - - bool operator == (const StringType & /* rhs */) const - { - return true; - } - bool 
operator != (const StringType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const StringType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(StringType &a, StringType &b); - -std::ostream& operator<<(std::ostream& out, const StringType& obj); - - -class UUIDType : public virtual ::apache::thrift::TBase { - public: - - UUIDType(const UUIDType&); - UUIDType& operator=(const UUIDType&); - UUIDType() { - } - - virtual ~UUIDType() throw(); - - bool operator == (const UUIDType & /* rhs */) const - { - return true; - } - bool operator != (const UUIDType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const UUIDType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(UUIDType &a, UUIDType &b); - -std::ostream& operator<<(std::ostream& out, const UUIDType& obj); - - -class MapType : public virtual ::apache::thrift::TBase { - public: - - MapType(const MapType&); - MapType& operator=(const MapType&); - MapType() { - } - - virtual ~MapType() throw(); - - bool operator == (const MapType & /* rhs */) const - { - return true; - } - bool operator != (const MapType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const MapType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(MapType &a, MapType &b); - -std::ostream& operator<<(std::ostream& out, const MapType& obj); - - -class ListType : public virtual ::apache::thrift::TBase { - public: - - ListType(const ListType&); - ListType& operator=(const ListType&); - ListType() { - } - - virtual ~ListType() throw(); - - bool operator == (const ListType & /* rhs */) const - { - return true; - } - bool operator != (const ListType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ListType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ListType &a, ListType &b); - -std::ostream& operator<<(std::ostream& out, const ListType& obj); - - -class EnumType : public virtual ::apache::thrift::TBase { - public: - - EnumType(const EnumType&); - EnumType& operator=(const EnumType&); - EnumType() { - } - - virtual ~EnumType() throw(); - - bool operator == (const EnumType & /* rhs */) const - { - return true; - } - bool operator != (const EnumType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const EnumType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(EnumType &a, EnumType &b); - -std::ostream& operator<<(std::ostream& out, const EnumType& obj); - - -class DateType : public virtual ::apache::thrift::TBase { - public: - - DateType(const DateType&); - DateType& operator=(const DateType&); - DateType() { - } - - virtual ~DateType() throw(); - - bool operator == (const DateType & /* rhs */) const - { - return true; - } - bool operator != (const DateType &rhs) const { - return 
!(*this == rhs); - } - - bool operator < (const DateType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(DateType &a, DateType &b); - -std::ostream& operator<<(std::ostream& out, const DateType& obj); - - -class NullType : public virtual ::apache::thrift::TBase { - public: - - NullType(const NullType&); - NullType& operator=(const NullType&); - NullType() { - } - - virtual ~NullType() throw(); - - bool operator == (const NullType & /* rhs */) const - { - return true; - } - bool operator != (const NullType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const NullType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(NullType &a, NullType &b); - -std::ostream& operator<<(std::ostream& out, const NullType& obj); - - -class DecimalType : public virtual ::apache::thrift::TBase { - public: - - DecimalType(const DecimalType&); - DecimalType& operator=(const DecimalType&); - DecimalType() : scale(0), precision(0) { - } - - virtual ~DecimalType() throw(); - int32_t scale; - int32_t precision; - - void __set_scale(const int32_t val); - - void __set_precision(const int32_t val); - - bool operator == (const DecimalType & rhs) const - { - if (!(scale == rhs.scale)) - return false; - if (!(precision == rhs.precision)) - return false; - return true; - } - bool operator != (const DecimalType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const DecimalType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(DecimalType &a, DecimalType &b); - -std::ostream& operator<<(std::ostream& out, const DecimalType& obj); - - -class MilliSeconds : public virtual ::apache::thrift::TBase { - public: - - MilliSeconds(const MilliSeconds&); - MilliSeconds& operator=(const MilliSeconds&); - MilliSeconds() { - } - - virtual ~MilliSeconds() throw(); - - bool operator == (const MilliSeconds & /* rhs */) const - { - return true; - } - bool operator != (const MilliSeconds &rhs) const { - return !(*this == rhs); - } - - bool operator < (const MilliSeconds & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(MilliSeconds &a, MilliSeconds &b); - -std::ostream& operator<<(std::ostream& out, const MilliSeconds& obj); - - -class MicroSeconds : public virtual ::apache::thrift::TBase { - public: - - MicroSeconds(const MicroSeconds&); - MicroSeconds& operator=(const MicroSeconds&); - MicroSeconds() { - } - - virtual ~MicroSeconds() throw(); - - bool operator == (const MicroSeconds & /* rhs */) const - { - return true; - } - bool operator != (const MicroSeconds &rhs) const { - return !(*this == rhs); - } - - bool operator < (const MicroSeconds & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(MicroSeconds &a, MicroSeconds &b); - -std::ostream& operator<<(std::ostream& out, const MicroSeconds& obj); - - 
-class NanoSeconds : public virtual ::apache::thrift::TBase { - public: - - NanoSeconds(const NanoSeconds&); - NanoSeconds& operator=(const NanoSeconds&); - NanoSeconds() { - } - - virtual ~NanoSeconds() throw(); - - bool operator == (const NanoSeconds & /* rhs */) const - { - return true; - } - bool operator != (const NanoSeconds &rhs) const { - return !(*this == rhs); - } - - bool operator < (const NanoSeconds & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(NanoSeconds &a, NanoSeconds &b); - -std::ostream& operator<<(std::ostream& out, const NanoSeconds& obj); - -typedef struct _TimeUnit__isset { - _TimeUnit__isset() : MILLIS(false), MICROS(false), NANOS(false) {} - bool MILLIS :1; - bool MICROS :1; - bool NANOS :1; -} _TimeUnit__isset; - -class TimeUnit : public virtual ::apache::thrift::TBase { - public: - - TimeUnit(const TimeUnit&); - TimeUnit& operator=(const TimeUnit&); - TimeUnit() { - } - - virtual ~TimeUnit() throw(); - MilliSeconds MILLIS; - MicroSeconds MICROS; - NanoSeconds NANOS; - - _TimeUnit__isset __isset; - - void __set_MILLIS(const MilliSeconds& val); - - void __set_MICROS(const MicroSeconds& val); - - void __set_NANOS(const NanoSeconds& val); - - bool operator == (const TimeUnit & rhs) const - { - if (__isset.MILLIS != rhs.__isset.MILLIS) - return false; - else if (__isset.MILLIS && !(MILLIS == rhs.MILLIS)) - return false; - if (__isset.MICROS != rhs.__isset.MICROS) - return false; - else if (__isset.MICROS && !(MICROS == rhs.MICROS)) - return false; - if (__isset.NANOS != rhs.__isset.NANOS) - return false; - else if (__isset.NANOS && !(NANOS == rhs.NANOS)) - return false; - return true; - } - bool operator != (const TimeUnit &rhs) const { - return !(*this == rhs); - } - - bool operator < (const TimeUnit & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(TimeUnit &a, TimeUnit &b); - -std::ostream& operator<<(std::ostream& out, const TimeUnit& obj); - - -class TimestampType : public virtual ::apache::thrift::TBase { - public: - - TimestampType(const TimestampType&); - TimestampType& operator=(const TimestampType&); - TimestampType() : isAdjustedToUTC(0) { - } - - virtual ~TimestampType() throw(); - bool isAdjustedToUTC; - TimeUnit unit; - - void __set_isAdjustedToUTC(const bool val); - - void __set_unit(const TimeUnit& val); - - bool operator == (const TimestampType & rhs) const - { - if (!(isAdjustedToUTC == rhs.isAdjustedToUTC)) - return false; - if (!(unit == rhs.unit)) - return false; - return true; - } - bool operator != (const TimestampType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const TimestampType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(TimestampType &a, TimestampType &b); - -std::ostream& operator<<(std::ostream& out, const TimestampType& obj); - - -class TimeType : public virtual ::apache::thrift::TBase { - public: - - TimeType(const TimeType&); - TimeType& operator=(const TimeType&); - TimeType() : isAdjustedToUTC(0) { - } - - virtual ~TimeType() throw(); - bool isAdjustedToUTC; - TimeUnit unit; - - void __set_isAdjustedToUTC(const bool 
val); - - void __set_unit(const TimeUnit& val); - - bool operator == (const TimeType & rhs) const - { - if (!(isAdjustedToUTC == rhs.isAdjustedToUTC)) - return false; - if (!(unit == rhs.unit)) - return false; - return true; - } - bool operator != (const TimeType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const TimeType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(TimeType &a, TimeType &b); - -std::ostream& operator<<(std::ostream& out, const TimeType& obj); - - -class IntType : public virtual ::apache::thrift::TBase { - public: - - IntType(const IntType&); - IntType& operator=(const IntType&); - IntType() : bitWidth(0), isSigned(0) { - } - - virtual ~IntType() throw(); - int8_t bitWidth; - bool isSigned; - - void __set_bitWidth(const int8_t val); - - void __set_isSigned(const bool val); - - bool operator == (const IntType & rhs) const - { - if (!(bitWidth == rhs.bitWidth)) - return false; - if (!(isSigned == rhs.isSigned)) - return false; - return true; - } - bool operator != (const IntType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const IntType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(IntType &a, IntType &b); - -std::ostream& operator<<(std::ostream& out, const IntType& obj); - - -class JsonType : public virtual ::apache::thrift::TBase { - public: - - JsonType(const JsonType&); - JsonType& operator=(const JsonType&); - JsonType() { - } - - virtual ~JsonType() throw(); - - bool operator == (const JsonType & /* rhs */) const - { - return true; - } - bool operator != (const JsonType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const JsonType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(JsonType &a, JsonType &b); - -std::ostream& operator<<(std::ostream& out, const JsonType& obj); - - -class BsonType : public virtual ::apache::thrift::TBase { - public: - - BsonType(const BsonType&); - BsonType& operator=(const BsonType&); - BsonType() { - } - - virtual ~BsonType() throw(); - - bool operator == (const BsonType & /* rhs */) const - { - return true; - } - bool operator != (const BsonType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const BsonType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(BsonType &a, BsonType &b); - -std::ostream& operator<<(std::ostream& out, const BsonType& obj); - -typedef struct _LogicalType__isset { - _LogicalType__isset() : STRING(false), MAP(false), LIST(false), ENUM(false), DECIMAL(false), DATE(false), TIME(false), TIMESTAMP(false), INTEGER(false), UNKNOWN(false), JSON(false), BSON(false), UUID(false) {} - bool STRING :1; - bool MAP :1; - bool LIST :1; - bool ENUM :1; - bool DECIMAL :1; - bool DATE :1; - bool TIME :1; - bool TIMESTAMP :1; - bool INTEGER :1; - bool UNKNOWN :1; - bool JSON :1; - bool BSON :1; - bool UUID :1; -} _LogicalType__isset; - -class LogicalType : public virtual ::apache::thrift::TBase { - 
public: - - LogicalType(const LogicalType&); - LogicalType& operator=(const LogicalType&); - LogicalType() { - } - - virtual ~LogicalType() throw(); - StringType STRING; - MapType MAP; - ListType LIST; - EnumType ENUM; - DecimalType DECIMAL; - DateType DATE; - TimeType TIME; - TimestampType TIMESTAMP; - IntType INTEGER; - NullType UNKNOWN; - JsonType JSON; - BsonType BSON; - UUIDType UUID; - - _LogicalType__isset __isset; - - void __set_STRING(const StringType& val); - - void __set_MAP(const MapType& val); - - void __set_LIST(const ListType& val); - - void __set_ENUM(const EnumType& val); - - void __set_DECIMAL(const DecimalType& val); - - void __set_DATE(const DateType& val); - - void __set_TIME(const TimeType& val); - - void __set_TIMESTAMP(const TimestampType& val); - - void __set_INTEGER(const IntType& val); - - void __set_UNKNOWN(const NullType& val); - - void __set_JSON(const JsonType& val); - - void __set_BSON(const BsonType& val); - - void __set_UUID(const UUIDType& val); - - bool operator == (const LogicalType & rhs) const - { - if (__isset.STRING != rhs.__isset.STRING) - return false; - else if (__isset.STRING && !(STRING == rhs.STRING)) - return false; - if (__isset.MAP != rhs.__isset.MAP) - return false; - else if (__isset.MAP && !(MAP == rhs.MAP)) - return false; - if (__isset.LIST != rhs.__isset.LIST) - return false; - else if (__isset.LIST && !(LIST == rhs.LIST)) - return false; - if (__isset.ENUM != rhs.__isset.ENUM) - return false; - else if (__isset.ENUM && !(ENUM == rhs.ENUM)) - return false; - if (__isset.DECIMAL != rhs.__isset.DECIMAL) - return false; - else if (__isset.DECIMAL && !(DECIMAL == rhs.DECIMAL)) - return false; - if (__isset.DATE != rhs.__isset.DATE) - return false; - else if (__isset.DATE && !(DATE == rhs.DATE)) - return false; - if (__isset.TIME != rhs.__isset.TIME) - return false; - else if (__isset.TIME && !(TIME == rhs.TIME)) - return false; - if (__isset.TIMESTAMP != rhs.__isset.TIMESTAMP) - return false; - else if (__isset.TIMESTAMP && !(TIMESTAMP == rhs.TIMESTAMP)) - return false; - if (__isset.INTEGER != rhs.__isset.INTEGER) - return false; - else if (__isset.INTEGER && !(INTEGER == rhs.INTEGER)) - return false; - if (__isset.UNKNOWN != rhs.__isset.UNKNOWN) - return false; - else if (__isset.UNKNOWN && !(UNKNOWN == rhs.UNKNOWN)) - return false; - if (__isset.JSON != rhs.__isset.JSON) - return false; - else if (__isset.JSON && !(JSON == rhs.JSON)) - return false; - if (__isset.BSON != rhs.__isset.BSON) - return false; - else if (__isset.BSON && !(BSON == rhs.BSON)) - return false; - if (__isset.UUID != rhs.__isset.UUID) - return false; - else if (__isset.UUID && !(UUID == rhs.UUID)) - return false; - return true; - } - bool operator != (const LogicalType &rhs) const { - return !(*this == rhs); - } - - bool operator < (const LogicalType & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(LogicalType &a, LogicalType &b); - -std::ostream& operator<<(std::ostream& out, const LogicalType& obj); - -typedef struct _SchemaElement__isset { - _SchemaElement__isset() : type(false), type_length(false), repetition_type(false), num_children(false), converted_type(false), scale(false), precision(false), field_id(false), logicalType(false) {} - bool type :1; - bool type_length :1; - bool repetition_type :1; - bool num_children :1; - bool converted_type :1; - bool scale :1; - bool precision :1; - bool 
field_id :1; - bool logicalType :1; -} _SchemaElement__isset; - -class SchemaElement : public virtual ::apache::thrift::TBase { - public: - - SchemaElement(const SchemaElement&); - SchemaElement& operator=(const SchemaElement&); - SchemaElement() : type((Type::type)0), type_length(0), repetition_type((FieldRepetitionType::type)0), name(), num_children(0), converted_type((ConvertedType::type)0), scale(0), precision(0), field_id(0) { - } - - virtual ~SchemaElement() throw(); - Type::type type; - int32_t type_length; - FieldRepetitionType::type repetition_type; - std::string name; - int32_t num_children; - ConvertedType::type converted_type; - int32_t scale; - int32_t precision; - int32_t field_id; - LogicalType logicalType; - - _SchemaElement__isset __isset; - - void __set_type(const Type::type val); - - void __set_type_length(const int32_t val); - - void __set_repetition_type(const FieldRepetitionType::type val); - - void __set_name(const std::string& val); - - void __set_num_children(const int32_t val); - - void __set_converted_type(const ConvertedType::type val); - - void __set_scale(const int32_t val); - - void __set_precision(const int32_t val); - - void __set_field_id(const int32_t val); - - void __set_logicalType(const LogicalType& val); - - bool operator == (const SchemaElement & rhs) const - { - if (__isset.type != rhs.__isset.type) - return false; - else if (__isset.type && !(type == rhs.type)) - return false; - if (__isset.type_length != rhs.__isset.type_length) - return false; - else if (__isset.type_length && !(type_length == rhs.type_length)) - return false; - if (__isset.repetition_type != rhs.__isset.repetition_type) - return false; - else if (__isset.repetition_type && !(repetition_type == rhs.repetition_type)) - return false; - if (!(name == rhs.name)) - return false; - if (__isset.num_children != rhs.__isset.num_children) - return false; - else if (__isset.num_children && !(num_children == rhs.num_children)) - return false; - if (__isset.converted_type != rhs.__isset.converted_type) - return false; - else if (__isset.converted_type && !(converted_type == rhs.converted_type)) - return false; - if (__isset.scale != rhs.__isset.scale) - return false; - else if (__isset.scale && !(scale == rhs.scale)) - return false; - if (__isset.precision != rhs.__isset.precision) - return false; - else if (__isset.precision && !(precision == rhs.precision)) - return false; - if (__isset.field_id != rhs.__isset.field_id) - return false; - else if (__isset.field_id && !(field_id == rhs.field_id)) - return false; - if (__isset.logicalType != rhs.__isset.logicalType) - return false; - else if (__isset.logicalType && !(logicalType == rhs.logicalType)) - return false; - return true; - } - bool operator != (const SchemaElement &rhs) const { - return !(*this == rhs); - } - - bool operator < (const SchemaElement & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(SchemaElement &a, SchemaElement &b); - -std::ostream& operator<<(std::ostream& out, const SchemaElement& obj); - -typedef struct _DataPageHeader__isset { - _DataPageHeader__isset() : statistics(false) {} - bool statistics :1; -} _DataPageHeader__isset; - -class DataPageHeader : public virtual ::apache::thrift::TBase { - public: - - DataPageHeader(const DataPageHeader&); - DataPageHeader& operator=(const DataPageHeader&); - DataPageHeader() : num_values(0), 
encoding((Encoding::type)0), definition_level_encoding((Encoding::type)0), repetition_level_encoding((Encoding::type)0) { - } - - virtual ~DataPageHeader() throw(); - int32_t num_values; - Encoding::type encoding; - Encoding::type definition_level_encoding; - Encoding::type repetition_level_encoding; - Statistics statistics; - - _DataPageHeader__isset __isset; - - void __set_num_values(const int32_t val); - - void __set_encoding(const Encoding::type val); - - void __set_definition_level_encoding(const Encoding::type val); - - void __set_repetition_level_encoding(const Encoding::type val); - - void __set_statistics(const Statistics& val); - - bool operator == (const DataPageHeader & rhs) const - { - if (!(num_values == rhs.num_values)) - return false; - if (!(encoding == rhs.encoding)) - return false; - if (!(definition_level_encoding == rhs.definition_level_encoding)) - return false; - if (!(repetition_level_encoding == rhs.repetition_level_encoding)) - return false; - if (__isset.statistics != rhs.__isset.statistics) - return false; - else if (__isset.statistics && !(statistics == rhs.statistics)) - return false; - return true; - } - bool operator != (const DataPageHeader &rhs) const { - return !(*this == rhs); - } - - bool operator < (const DataPageHeader & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(DataPageHeader &a, DataPageHeader &b); - -std::ostream& operator<<(std::ostream& out, const DataPageHeader& obj); - - -class IndexPageHeader : public virtual ::apache::thrift::TBase { - public: - - IndexPageHeader(const IndexPageHeader&); - IndexPageHeader& operator=(const IndexPageHeader&); - IndexPageHeader() { - } - - virtual ~IndexPageHeader() throw(); - - bool operator == (const IndexPageHeader & /* rhs */) const - { - return true; - } - bool operator != (const IndexPageHeader &rhs) const { - return !(*this == rhs); - } - - bool operator < (const IndexPageHeader & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(IndexPageHeader &a, IndexPageHeader &b); - -std::ostream& operator<<(std::ostream& out, const IndexPageHeader& obj); - -typedef struct _DictionaryPageHeader__isset { - _DictionaryPageHeader__isset() : is_sorted(false) {} - bool is_sorted :1; -} _DictionaryPageHeader__isset; - -class DictionaryPageHeader : public virtual ::apache::thrift::TBase { - public: - - DictionaryPageHeader(const DictionaryPageHeader&); - DictionaryPageHeader& operator=(const DictionaryPageHeader&); - DictionaryPageHeader() : num_values(0), encoding((Encoding::type)0), is_sorted(0) { - } - - virtual ~DictionaryPageHeader() throw(); - int32_t num_values; - Encoding::type encoding; - bool is_sorted; - - _DictionaryPageHeader__isset __isset; - - void __set_num_values(const int32_t val); - - void __set_encoding(const Encoding::type val); - - void __set_is_sorted(const bool val); - - bool operator == (const DictionaryPageHeader & rhs) const - { - if (!(num_values == rhs.num_values)) - return false; - if (!(encoding == rhs.encoding)) - return false; - if (__isset.is_sorted != rhs.__isset.is_sorted) - return false; - else if (__isset.is_sorted && !(is_sorted == rhs.is_sorted)) - return false; - return true; - } - bool operator != (const DictionaryPageHeader &rhs) const { - return !(*this 
== rhs); - } - - bool operator < (const DictionaryPageHeader & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(DictionaryPageHeader &a, DictionaryPageHeader &b); - -std::ostream& operator<<(std::ostream& out, const DictionaryPageHeader& obj); - -typedef struct _DataPageHeaderV2__isset { - _DataPageHeaderV2__isset() : is_compressed(true), statistics(false) {} - bool is_compressed :1; - bool statistics :1; -} _DataPageHeaderV2__isset; - -class DataPageHeaderV2 : public virtual ::apache::thrift::TBase { - public: - - DataPageHeaderV2(const DataPageHeaderV2&); - DataPageHeaderV2& operator=(const DataPageHeaderV2&); - DataPageHeaderV2() : num_values(0), num_nulls(0), num_rows(0), encoding((Encoding::type)0), definition_levels_byte_length(0), repetition_levels_byte_length(0), is_compressed(true) { - } - - virtual ~DataPageHeaderV2() throw(); - int32_t num_values; - int32_t num_nulls; - int32_t num_rows; - Encoding::type encoding; - int32_t definition_levels_byte_length; - int32_t repetition_levels_byte_length; - bool is_compressed; - Statistics statistics; - - _DataPageHeaderV2__isset __isset; - - void __set_num_values(const int32_t val); - - void __set_num_nulls(const int32_t val); - - void __set_num_rows(const int32_t val); - - void __set_encoding(const Encoding::type val); - - void __set_definition_levels_byte_length(const int32_t val); - - void __set_repetition_levels_byte_length(const int32_t val); - - void __set_is_compressed(const bool val); - - void __set_statistics(const Statistics& val); - - bool operator == (const DataPageHeaderV2 & rhs) const - { - if (!(num_values == rhs.num_values)) - return false; - if (!(num_nulls == rhs.num_nulls)) - return false; - if (!(num_rows == rhs.num_rows)) - return false; - if (!(encoding == rhs.encoding)) - return false; - if (!(definition_levels_byte_length == rhs.definition_levels_byte_length)) - return false; - if (!(repetition_levels_byte_length == rhs.repetition_levels_byte_length)) - return false; - if (__isset.is_compressed != rhs.__isset.is_compressed) - return false; - else if (__isset.is_compressed && !(is_compressed == rhs.is_compressed)) - return false; - if (__isset.statistics != rhs.__isset.statistics) - return false; - else if (__isset.statistics && !(statistics == rhs.statistics)) - return false; - return true; - } - bool operator != (const DataPageHeaderV2 &rhs) const { - return !(*this == rhs); - } - - bool operator < (const DataPageHeaderV2 & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(DataPageHeaderV2 &a, DataPageHeaderV2 &b); - -std::ostream& operator<<(std::ostream& out, const DataPageHeaderV2& obj); - - -class SplitBlockAlgorithm : public virtual ::apache::thrift::TBase { - public: - - SplitBlockAlgorithm(const SplitBlockAlgorithm&); - SplitBlockAlgorithm& operator=(const SplitBlockAlgorithm&); - SplitBlockAlgorithm() { - } - - virtual ~SplitBlockAlgorithm() throw(); - - bool operator == (const SplitBlockAlgorithm & /* rhs */) const - { - return true; - } - bool operator != (const SplitBlockAlgorithm &rhs) const { - return !(*this == rhs); - } - - bool operator < (const SplitBlockAlgorithm & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t 
write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(SplitBlockAlgorithm &a, SplitBlockAlgorithm &b); - -std::ostream& operator<<(std::ostream& out, const SplitBlockAlgorithm& obj); - -typedef struct _BloomFilterAlgorithm__isset { - _BloomFilterAlgorithm__isset() : BLOCK(false) {} - bool BLOCK :1; -} _BloomFilterAlgorithm__isset; - -class BloomFilterAlgorithm : public virtual ::apache::thrift::TBase { - public: - - BloomFilterAlgorithm(const BloomFilterAlgorithm&); - BloomFilterAlgorithm& operator=(const BloomFilterAlgorithm&); - BloomFilterAlgorithm() { - } - - virtual ~BloomFilterAlgorithm() throw(); - SplitBlockAlgorithm BLOCK; - - _BloomFilterAlgorithm__isset __isset; - - void __set_BLOCK(const SplitBlockAlgorithm& val); - - bool operator == (const BloomFilterAlgorithm & rhs) const - { - if (__isset.BLOCK != rhs.__isset.BLOCK) - return false; - else if (__isset.BLOCK && !(BLOCK == rhs.BLOCK)) - return false; - return true; - } - bool operator != (const BloomFilterAlgorithm &rhs) const { - return !(*this == rhs); - } - - bool operator < (const BloomFilterAlgorithm & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(BloomFilterAlgorithm &a, BloomFilterAlgorithm &b); - -std::ostream& operator<<(std::ostream& out, const BloomFilterAlgorithm& obj); - - -class XxHash : public virtual ::apache::thrift::TBase { - public: - - XxHash(const XxHash&); - XxHash& operator=(const XxHash&); - XxHash() { - } - - virtual ~XxHash() throw(); - - bool operator == (const XxHash & /* rhs */) const - { - return true; - } - bool operator != (const XxHash &rhs) const { - return !(*this == rhs); - } - - bool operator < (const XxHash & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(XxHash &a, XxHash &b); - -std::ostream& operator<<(std::ostream& out, const XxHash& obj); - -typedef struct _BloomFilterHash__isset { - _BloomFilterHash__isset() : XXHASH(false) {} - bool XXHASH :1; -} _BloomFilterHash__isset; - -class BloomFilterHash : public virtual ::apache::thrift::TBase { - public: - - BloomFilterHash(const BloomFilterHash&); - BloomFilterHash& operator=(const BloomFilterHash&); - BloomFilterHash() { - } - - virtual ~BloomFilterHash() throw(); - XxHash XXHASH; - - _BloomFilterHash__isset __isset; - - void __set_XXHASH(const XxHash& val); - - bool operator == (const BloomFilterHash & rhs) const - { - if (__isset.XXHASH != rhs.__isset.XXHASH) - return false; - else if (__isset.XXHASH && !(XXHASH == rhs.XXHASH)) - return false; - return true; - } - bool operator != (const BloomFilterHash &rhs) const { - return !(*this == rhs); - } - - bool operator < (const BloomFilterHash & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(BloomFilterHash &a, BloomFilterHash &b); - -std::ostream& operator<<(std::ostream& out, const BloomFilterHash& obj); - - -class Uncompressed : public virtual ::apache::thrift::TBase { - public: - - Uncompressed(const Uncompressed&); - Uncompressed& operator=(const Uncompressed&); - Uncompressed() { - } - - virtual ~Uncompressed() throw(); - - 
bool operator == (const Uncompressed & /* rhs */) const - { - return true; - } - bool operator != (const Uncompressed &rhs) const { - return !(*this == rhs); - } - - bool operator < (const Uncompressed & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(Uncompressed &a, Uncompressed &b); - -std::ostream& operator<<(std::ostream& out, const Uncompressed& obj); - -typedef struct _BloomFilterCompression__isset { - _BloomFilterCompression__isset() : UNCOMPRESSED(false) {} - bool UNCOMPRESSED :1; -} _BloomFilterCompression__isset; - -class BloomFilterCompression : public virtual ::apache::thrift::TBase { - public: - - BloomFilterCompression(const BloomFilterCompression&); - BloomFilterCompression& operator=(const BloomFilterCompression&); - BloomFilterCompression() { - } - - virtual ~BloomFilterCompression() throw(); - Uncompressed UNCOMPRESSED; - - _BloomFilterCompression__isset __isset; - - void __set_UNCOMPRESSED(const Uncompressed& val); - - bool operator == (const BloomFilterCompression & rhs) const - { - if (__isset.UNCOMPRESSED != rhs.__isset.UNCOMPRESSED) - return false; - else if (__isset.UNCOMPRESSED && !(UNCOMPRESSED == rhs.UNCOMPRESSED)) - return false; - return true; - } - bool operator != (const BloomFilterCompression &rhs) const { - return !(*this == rhs); - } - - bool operator < (const BloomFilterCompression & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(BloomFilterCompression &a, BloomFilterCompression &b); - -std::ostream& operator<<(std::ostream& out, const BloomFilterCompression& obj); - - -class BloomFilterHeader : public virtual ::apache::thrift::TBase { - public: - - BloomFilterHeader(const BloomFilterHeader&); - BloomFilterHeader& operator=(const BloomFilterHeader&); - BloomFilterHeader() : numBytes(0) { - } - - virtual ~BloomFilterHeader() throw(); - int32_t numBytes; - BloomFilterAlgorithm algorithm; - BloomFilterHash hash; - BloomFilterCompression compression; - - void __set_numBytes(const int32_t val); - - void __set_algorithm(const BloomFilterAlgorithm& val); - - void __set_hash(const BloomFilterHash& val); - - void __set_compression(const BloomFilterCompression& val); - - bool operator == (const BloomFilterHeader & rhs) const - { - if (!(numBytes == rhs.numBytes)) - return false; - if (!(algorithm == rhs.algorithm)) - return false; - if (!(hash == rhs.hash)) - return false; - if (!(compression == rhs.compression)) - return false; - return true; - } - bool operator != (const BloomFilterHeader &rhs) const { - return !(*this == rhs); - } - - bool operator < (const BloomFilterHeader & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(BloomFilterHeader &a, BloomFilterHeader &b); - -std::ostream& operator<<(std::ostream& out, const BloomFilterHeader& obj); - -typedef struct _PageHeader__isset { - _PageHeader__isset() : crc(false), data_page_header(false), index_page_header(false), dictionary_page_header(false), data_page_header_v2(false) {} - bool crc :1; - bool data_page_header :1; - bool index_page_header :1; - bool dictionary_page_header :1; - bool data_page_header_v2 :1; -} 
_PageHeader__isset; - -class PageHeader : public virtual ::apache::thrift::TBase { - public: - - PageHeader(const PageHeader&); - PageHeader& operator=(const PageHeader&); - PageHeader() : type((PageType::type)0), uncompressed_page_size(0), compressed_page_size(0), crc(0) { - } - - virtual ~PageHeader() throw(); - PageType::type type; - int32_t uncompressed_page_size; - int32_t compressed_page_size; - int32_t crc; - DataPageHeader data_page_header; - IndexPageHeader index_page_header; - DictionaryPageHeader dictionary_page_header; - DataPageHeaderV2 data_page_header_v2; - - _PageHeader__isset __isset; - - void __set_type(const PageType::type val); - - void __set_uncompressed_page_size(const int32_t val); - - void __set_compressed_page_size(const int32_t val); - - void __set_crc(const int32_t val); - - void __set_data_page_header(const DataPageHeader& val); - - void __set_index_page_header(const IndexPageHeader& val); - - void __set_dictionary_page_header(const DictionaryPageHeader& val); - - void __set_data_page_header_v2(const DataPageHeaderV2& val); - - bool operator == (const PageHeader & rhs) const - { - if (!(type == rhs.type)) - return false; - if (!(uncompressed_page_size == rhs.uncompressed_page_size)) - return false; - if (!(compressed_page_size == rhs.compressed_page_size)) - return false; - if (__isset.crc != rhs.__isset.crc) - return false; - else if (__isset.crc && !(crc == rhs.crc)) - return false; - if (__isset.data_page_header != rhs.__isset.data_page_header) - return false; - else if (__isset.data_page_header && !(data_page_header == rhs.data_page_header)) - return false; - if (__isset.index_page_header != rhs.__isset.index_page_header) - return false; - else if (__isset.index_page_header && !(index_page_header == rhs.index_page_header)) - return false; - if (__isset.dictionary_page_header != rhs.__isset.dictionary_page_header) - return false; - else if (__isset.dictionary_page_header && !(dictionary_page_header == rhs.dictionary_page_header)) - return false; - if (__isset.data_page_header_v2 != rhs.__isset.data_page_header_v2) - return false; - else if (__isset.data_page_header_v2 && !(data_page_header_v2 == rhs.data_page_header_v2)) - return false; - return true; - } - bool operator != (const PageHeader &rhs) const { - return !(*this == rhs); - } - - bool operator < (const PageHeader & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(PageHeader &a, PageHeader &b); - -std::ostream& operator<<(std::ostream& out, const PageHeader& obj); - -typedef struct _KeyValue__isset { - _KeyValue__isset() : value(false) {} - bool value :1; -} _KeyValue__isset; - -class KeyValue : public virtual ::apache::thrift::TBase { - public: - - KeyValue(const KeyValue&); - KeyValue& operator=(const KeyValue&); - KeyValue() : key(), value() { - } - - virtual ~KeyValue() throw(); - std::string key; - std::string value; - - _KeyValue__isset __isset; - - void __set_key(const std::string& val); - - void __set_value(const std::string& val); - - bool operator == (const KeyValue & rhs) const - { - if (!(key == rhs.key)) - return false; - if (__isset.value != rhs.__isset.value) - return false; - else if (__isset.value && !(value == rhs.value)) - return false; - return true; - } - bool operator != (const KeyValue &rhs) const { - return !(*this == rhs); - } - - bool operator < (const KeyValue & ) const; - - uint32_t 
read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(KeyValue &a, KeyValue &b); - -std::ostream& operator<<(std::ostream& out, const KeyValue& obj); - - -class SortingColumn : public virtual ::apache::thrift::TBase { - public: - - SortingColumn(const SortingColumn&); - SortingColumn& operator=(const SortingColumn&); - SortingColumn() : column_idx(0), descending(0), nulls_first(0) { - } - - virtual ~SortingColumn() throw(); - int32_t column_idx; - bool descending; - bool nulls_first; - - void __set_column_idx(const int32_t val); - - void __set_descending(const bool val); - - void __set_nulls_first(const bool val); - - bool operator == (const SortingColumn & rhs) const - { - if (!(column_idx == rhs.column_idx)) - return false; - if (!(descending == rhs.descending)) - return false; - if (!(nulls_first == rhs.nulls_first)) - return false; - return true; - } - bool operator != (const SortingColumn &rhs) const { - return !(*this == rhs); - } - - bool operator < (const SortingColumn & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(SortingColumn &a, SortingColumn &b); - -std::ostream& operator<<(std::ostream& out, const SortingColumn& obj); - - -class PageEncodingStats : public virtual ::apache::thrift::TBase { - public: - - PageEncodingStats(const PageEncodingStats&); - PageEncodingStats& operator=(const PageEncodingStats&); - PageEncodingStats() : page_type((PageType::type)0), encoding((Encoding::type)0), count(0) { - } - - virtual ~PageEncodingStats() throw(); - PageType::type page_type; - Encoding::type encoding; - int32_t count; - - void __set_page_type(const PageType::type val); - - void __set_encoding(const Encoding::type val); - - void __set_count(const int32_t val); - - bool operator == (const PageEncodingStats & rhs) const - { - if (!(page_type == rhs.page_type)) - return false; - if (!(encoding == rhs.encoding)) - return false; - if (!(count == rhs.count)) - return false; - return true; - } - bool operator != (const PageEncodingStats &rhs) const { - return !(*this == rhs); - } - - bool operator < (const PageEncodingStats & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(PageEncodingStats &a, PageEncodingStats &b); - -std::ostream& operator<<(std::ostream& out, const PageEncodingStats& obj); - -typedef struct _ColumnMetaData__isset { - _ColumnMetaData__isset() : key_value_metadata(false), index_page_offset(false), dictionary_page_offset(false), statistics(false), encoding_stats(false), bloom_filter_offset(false) {} - bool key_value_metadata :1; - bool index_page_offset :1; - bool dictionary_page_offset :1; - bool statistics :1; - bool encoding_stats :1; - bool bloom_filter_offset :1; -} _ColumnMetaData__isset; - -class ColumnMetaData : public virtual ::apache::thrift::TBase { - public: - - ColumnMetaData(const ColumnMetaData&); - ColumnMetaData& operator=(const ColumnMetaData&); - ColumnMetaData() : type((Type::type)0), codec((CompressionCodec::type)0), num_values(0), total_uncompressed_size(0), total_compressed_size(0), data_page_offset(0), index_page_offset(0), dictionary_page_offset(0), bloom_filter_offset(0) { - } - - virtual 
~ColumnMetaData() throw(); - Type::type type; - std::vector encodings; - std::vector path_in_schema; - CompressionCodec::type codec; - int64_t num_values; - int64_t total_uncompressed_size; - int64_t total_compressed_size; - std::vector key_value_metadata; - int64_t data_page_offset; - int64_t index_page_offset; - int64_t dictionary_page_offset; - Statistics statistics; - std::vector encoding_stats; - int64_t bloom_filter_offset; - - _ColumnMetaData__isset __isset; - - void __set_type(const Type::type val); - - void __set_encodings(const std::vector & val); - - void __set_path_in_schema(const std::vector & val); - - void __set_codec(const CompressionCodec::type val); - - void __set_num_values(const int64_t val); - - void __set_total_uncompressed_size(const int64_t val); - - void __set_total_compressed_size(const int64_t val); - - void __set_key_value_metadata(const std::vector & val); - - void __set_data_page_offset(const int64_t val); - - void __set_index_page_offset(const int64_t val); - - void __set_dictionary_page_offset(const int64_t val); - - void __set_statistics(const Statistics& val); - - void __set_encoding_stats(const std::vector & val); - - void __set_bloom_filter_offset(const int64_t val); - - bool operator == (const ColumnMetaData & rhs) const - { - if (!(type == rhs.type)) - return false; - if (!(encodings == rhs.encodings)) - return false; - if (!(path_in_schema == rhs.path_in_schema)) - return false; - if (!(codec == rhs.codec)) - return false; - if (!(num_values == rhs.num_values)) - return false; - if (!(total_uncompressed_size == rhs.total_uncompressed_size)) - return false; - if (!(total_compressed_size == rhs.total_compressed_size)) - return false; - if (__isset.key_value_metadata != rhs.__isset.key_value_metadata) - return false; - else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata)) - return false; - if (!(data_page_offset == rhs.data_page_offset)) - return false; - if (__isset.index_page_offset != rhs.__isset.index_page_offset) - return false; - else if (__isset.index_page_offset && !(index_page_offset == rhs.index_page_offset)) - return false; - if (__isset.dictionary_page_offset != rhs.__isset.dictionary_page_offset) - return false; - else if (__isset.dictionary_page_offset && !(dictionary_page_offset == rhs.dictionary_page_offset)) - return false; - if (__isset.statistics != rhs.__isset.statistics) - return false; - else if (__isset.statistics && !(statistics == rhs.statistics)) - return false; - if (__isset.encoding_stats != rhs.__isset.encoding_stats) - return false; - else if (__isset.encoding_stats && !(encoding_stats == rhs.encoding_stats)) - return false; - if (__isset.bloom_filter_offset != rhs.__isset.bloom_filter_offset) - return false; - else if (__isset.bloom_filter_offset && !(bloom_filter_offset == rhs.bloom_filter_offset)) - return false; - return true; - } - bool operator != (const ColumnMetaData &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ColumnMetaData & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ColumnMetaData &a, ColumnMetaData &b); - -std::ostream& operator<<(std::ostream& out, const ColumnMetaData& obj); - - -class EncryptionWithFooterKey : public virtual ::apache::thrift::TBase { - public: - - EncryptionWithFooterKey(const EncryptionWithFooterKey&); - EncryptionWithFooterKey& operator=(const 
EncryptionWithFooterKey&); - EncryptionWithFooterKey() { - } - - virtual ~EncryptionWithFooterKey() throw(); - - bool operator == (const EncryptionWithFooterKey & /* rhs */) const - { - return true; - } - bool operator != (const EncryptionWithFooterKey &rhs) const { - return !(*this == rhs); - } - - bool operator < (const EncryptionWithFooterKey & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(EncryptionWithFooterKey &a, EncryptionWithFooterKey &b); - -std::ostream& operator<<(std::ostream& out, const EncryptionWithFooterKey& obj); - -typedef struct _EncryptionWithColumnKey__isset { - _EncryptionWithColumnKey__isset() : key_metadata(false) {} - bool key_metadata :1; -} _EncryptionWithColumnKey__isset; - -class EncryptionWithColumnKey : public virtual ::apache::thrift::TBase { - public: - - EncryptionWithColumnKey(const EncryptionWithColumnKey&); - EncryptionWithColumnKey& operator=(const EncryptionWithColumnKey&); - EncryptionWithColumnKey() : key_metadata() { - } - - virtual ~EncryptionWithColumnKey() throw(); - std::vector path_in_schema; - std::string key_metadata; - - _EncryptionWithColumnKey__isset __isset; - - void __set_path_in_schema(const std::vector & val); - - void __set_key_metadata(const std::string& val); - - bool operator == (const EncryptionWithColumnKey & rhs) const - { - if (!(path_in_schema == rhs.path_in_schema)) - return false; - if (__isset.key_metadata != rhs.__isset.key_metadata) - return false; - else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata)) - return false; - return true; - } - bool operator != (const EncryptionWithColumnKey &rhs) const { - return !(*this == rhs); - } - - bool operator < (const EncryptionWithColumnKey & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(EncryptionWithColumnKey &a, EncryptionWithColumnKey &b); - -std::ostream& operator<<(std::ostream& out, const EncryptionWithColumnKey& obj); - -typedef struct _ColumnCryptoMetaData__isset { - _ColumnCryptoMetaData__isset() : ENCRYPTION_WITH_FOOTER_KEY(false), ENCRYPTION_WITH_COLUMN_KEY(false) {} - bool ENCRYPTION_WITH_FOOTER_KEY :1; - bool ENCRYPTION_WITH_COLUMN_KEY :1; -} _ColumnCryptoMetaData__isset; - -class ColumnCryptoMetaData : public virtual ::apache::thrift::TBase { - public: - - ColumnCryptoMetaData(const ColumnCryptoMetaData&); - ColumnCryptoMetaData& operator=(const ColumnCryptoMetaData&); - ColumnCryptoMetaData() { - } - - virtual ~ColumnCryptoMetaData() throw(); - EncryptionWithFooterKey ENCRYPTION_WITH_FOOTER_KEY; - EncryptionWithColumnKey ENCRYPTION_WITH_COLUMN_KEY; - - _ColumnCryptoMetaData__isset __isset; - - void __set_ENCRYPTION_WITH_FOOTER_KEY(const EncryptionWithFooterKey& val); - - void __set_ENCRYPTION_WITH_COLUMN_KEY(const EncryptionWithColumnKey& val); - - bool operator == (const ColumnCryptoMetaData & rhs) const - { - if (__isset.ENCRYPTION_WITH_FOOTER_KEY != rhs.__isset.ENCRYPTION_WITH_FOOTER_KEY) - return false; - else if (__isset.ENCRYPTION_WITH_FOOTER_KEY && !(ENCRYPTION_WITH_FOOTER_KEY == rhs.ENCRYPTION_WITH_FOOTER_KEY)) - return false; - if (__isset.ENCRYPTION_WITH_COLUMN_KEY != rhs.__isset.ENCRYPTION_WITH_COLUMN_KEY) - return false; - else if (__isset.ENCRYPTION_WITH_COLUMN_KEY && !(ENCRYPTION_WITH_COLUMN_KEY == 
rhs.ENCRYPTION_WITH_COLUMN_KEY)) - return false; - return true; - } - bool operator != (const ColumnCryptoMetaData &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ColumnCryptoMetaData & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ColumnCryptoMetaData &a, ColumnCryptoMetaData &b); - -std::ostream& operator<<(std::ostream& out, const ColumnCryptoMetaData& obj); - -typedef struct _ColumnChunk__isset { - _ColumnChunk__isset() : file_path(false), meta_data(false), offset_index_offset(false), offset_index_length(false), column_index_offset(false), column_index_length(false), crypto_metadata(false), encrypted_column_metadata(false) {} - bool file_path :1; - bool meta_data :1; - bool offset_index_offset :1; - bool offset_index_length :1; - bool column_index_offset :1; - bool column_index_length :1; - bool crypto_metadata :1; - bool encrypted_column_metadata :1; -} _ColumnChunk__isset; - -class ColumnChunk : public virtual ::apache::thrift::TBase { - public: - - ColumnChunk(const ColumnChunk&); - ColumnChunk& operator=(const ColumnChunk&); - ColumnChunk() : file_path(), file_offset(0), offset_index_offset(0), offset_index_length(0), column_index_offset(0), column_index_length(0), encrypted_column_metadata() { - } - - virtual ~ColumnChunk() throw(); - std::string file_path; - int64_t file_offset; - ColumnMetaData meta_data; - int64_t offset_index_offset; - int32_t offset_index_length; - int64_t column_index_offset; - int32_t column_index_length; - ColumnCryptoMetaData crypto_metadata; - std::string encrypted_column_metadata; - - _ColumnChunk__isset __isset; - - void __set_file_path(const std::string& val); - - void __set_file_offset(const int64_t val); - - void __set_meta_data(const ColumnMetaData& val); - - void __set_offset_index_offset(const int64_t val); - - void __set_offset_index_length(const int32_t val); - - void __set_column_index_offset(const int64_t val); - - void __set_column_index_length(const int32_t val); - - void __set_crypto_metadata(const ColumnCryptoMetaData& val); - - void __set_encrypted_column_metadata(const std::string& val); - - bool operator == (const ColumnChunk & rhs) const - { - if (__isset.file_path != rhs.__isset.file_path) - return false; - else if (__isset.file_path && !(file_path == rhs.file_path)) - return false; - if (!(file_offset == rhs.file_offset)) - return false; - if (__isset.meta_data != rhs.__isset.meta_data) - return false; - else if (__isset.meta_data && !(meta_data == rhs.meta_data)) - return false; - if (__isset.offset_index_offset != rhs.__isset.offset_index_offset) - return false; - else if (__isset.offset_index_offset && !(offset_index_offset == rhs.offset_index_offset)) - return false; - if (__isset.offset_index_length != rhs.__isset.offset_index_length) - return false; - else if (__isset.offset_index_length && !(offset_index_length == rhs.offset_index_length)) - return false; - if (__isset.column_index_offset != rhs.__isset.column_index_offset) - return false; - else if (__isset.column_index_offset && !(column_index_offset == rhs.column_index_offset)) - return false; - if (__isset.column_index_length != rhs.__isset.column_index_length) - return false; - else if (__isset.column_index_length && !(column_index_length == rhs.column_index_length)) - return false; - if (__isset.crypto_metadata != rhs.__isset.crypto_metadata) - return false; - else if 
(__isset.crypto_metadata && !(crypto_metadata == rhs.crypto_metadata)) - return false; - if (__isset.encrypted_column_metadata != rhs.__isset.encrypted_column_metadata) - return false; - else if (__isset.encrypted_column_metadata && !(encrypted_column_metadata == rhs.encrypted_column_metadata)) - return false; - return true; - } - bool operator != (const ColumnChunk &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ColumnChunk & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ColumnChunk &a, ColumnChunk &b); - -std::ostream& operator<<(std::ostream& out, const ColumnChunk& obj); - -typedef struct _RowGroup__isset { - _RowGroup__isset() : sorting_columns(false), file_offset(false), total_compressed_size(false), ordinal(false) {} - bool sorting_columns :1; - bool file_offset :1; - bool total_compressed_size :1; - bool ordinal :1; -} _RowGroup__isset; - -class RowGroup : public virtual ::apache::thrift::TBase { - public: - - RowGroup(const RowGroup&); - RowGroup& operator=(const RowGroup&); - RowGroup() : total_byte_size(0), num_rows(0), file_offset(0), total_compressed_size(0), ordinal(0) { - } - - virtual ~RowGroup() throw(); - std::vector columns; - int64_t total_byte_size; - int64_t num_rows; - std::vector sorting_columns; - int64_t file_offset; - int64_t total_compressed_size; - int16_t ordinal; - - _RowGroup__isset __isset; - - void __set_columns(const std::vector & val); - - void __set_total_byte_size(const int64_t val); - - void __set_num_rows(const int64_t val); - - void __set_sorting_columns(const std::vector & val); - - void __set_file_offset(const int64_t val); - - void __set_total_compressed_size(const int64_t val); - - void __set_ordinal(const int16_t val); - - bool operator == (const RowGroup & rhs) const - { - if (!(columns == rhs.columns)) - return false; - if (!(total_byte_size == rhs.total_byte_size)) - return false; - if (!(num_rows == rhs.num_rows)) - return false; - if (__isset.sorting_columns != rhs.__isset.sorting_columns) - return false; - else if (__isset.sorting_columns && !(sorting_columns == rhs.sorting_columns)) - return false; - if (__isset.file_offset != rhs.__isset.file_offset) - return false; - else if (__isset.file_offset && !(file_offset == rhs.file_offset)) - return false; - if (__isset.total_compressed_size != rhs.__isset.total_compressed_size) - return false; - else if (__isset.total_compressed_size && !(total_compressed_size == rhs.total_compressed_size)) - return false; - if (__isset.ordinal != rhs.__isset.ordinal) - return false; - else if (__isset.ordinal && !(ordinal == rhs.ordinal)) - return false; - return true; - } - bool operator != (const RowGroup &rhs) const { - return !(*this == rhs); - } - - bool operator < (const RowGroup & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(RowGroup &a, RowGroup &b); - -std::ostream& operator<<(std::ostream& out, const RowGroup& obj); - - -class TypeDefinedOrder : public virtual ::apache::thrift::TBase { - public: - - TypeDefinedOrder(const TypeDefinedOrder&); - TypeDefinedOrder& operator=(const TypeDefinedOrder&); - TypeDefinedOrder() { - } - - virtual ~TypeDefinedOrder() throw(); - - bool operator == (const TypeDefinedOrder & /* rhs */) const - { - return true; - } 
- bool operator != (const TypeDefinedOrder &rhs) const { - return !(*this == rhs); - } - - bool operator < (const TypeDefinedOrder & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(TypeDefinedOrder &a, TypeDefinedOrder &b); - -std::ostream& operator<<(std::ostream& out, const TypeDefinedOrder& obj); - -typedef struct _ColumnOrder__isset { - _ColumnOrder__isset() : TYPE_ORDER(false) {} - bool TYPE_ORDER :1; -} _ColumnOrder__isset; - -class ColumnOrder : public virtual ::apache::thrift::TBase { - public: - - ColumnOrder(const ColumnOrder&); - ColumnOrder& operator=(const ColumnOrder&); - ColumnOrder() { - } - - virtual ~ColumnOrder() throw(); - TypeDefinedOrder TYPE_ORDER; - - _ColumnOrder__isset __isset; - - void __set_TYPE_ORDER(const TypeDefinedOrder& val); - - bool operator == (const ColumnOrder & rhs) const - { - if (__isset.TYPE_ORDER != rhs.__isset.TYPE_ORDER) - return false; - else if (__isset.TYPE_ORDER && !(TYPE_ORDER == rhs.TYPE_ORDER)) - return false; - return true; - } - bool operator != (const ColumnOrder &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ColumnOrder & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ColumnOrder &a, ColumnOrder &b); - -std::ostream& operator<<(std::ostream& out, const ColumnOrder& obj); - - -class PageLocation : public virtual ::apache::thrift::TBase { - public: - - PageLocation(const PageLocation&); - PageLocation& operator=(const PageLocation&); - PageLocation() : offset(0), compressed_page_size(0), first_row_index(0) { - } - - virtual ~PageLocation() throw(); - int64_t offset; - int32_t compressed_page_size; - int64_t first_row_index; - - void __set_offset(const int64_t val); - - void __set_compressed_page_size(const int32_t val); - - void __set_first_row_index(const int64_t val); - - bool operator == (const PageLocation & rhs) const - { - if (!(offset == rhs.offset)) - return false; - if (!(compressed_page_size == rhs.compressed_page_size)) - return false; - if (!(first_row_index == rhs.first_row_index)) - return false; - return true; - } - bool operator != (const PageLocation &rhs) const { - return !(*this == rhs); - } - - bool operator < (const PageLocation & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(PageLocation &a, PageLocation &b); - -std::ostream& operator<<(std::ostream& out, const PageLocation& obj); - - -class OffsetIndex : public virtual ::apache::thrift::TBase { - public: - - OffsetIndex(const OffsetIndex&); - OffsetIndex& operator=(const OffsetIndex&); - OffsetIndex() { - } - - virtual ~OffsetIndex() throw(); - std::vector page_locations; - - void __set_page_locations(const std::vector & val); - - bool operator == (const OffsetIndex & rhs) const - { - if (!(page_locations == rhs.page_locations)) - return false; - return true; - } - bool operator != (const OffsetIndex &rhs) const { - return !(*this == rhs); - } - - bool operator < (const OffsetIndex & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void 
printTo(std::ostream& out) const; -}; - -void swap(OffsetIndex &a, OffsetIndex &b); - -std::ostream& operator<<(std::ostream& out, const OffsetIndex& obj); - -typedef struct _ColumnIndex__isset { - _ColumnIndex__isset() : null_counts(false) {} - bool null_counts :1; -} _ColumnIndex__isset; - -class ColumnIndex : public virtual ::apache::thrift::TBase { - public: - - ColumnIndex(const ColumnIndex&); - ColumnIndex& operator=(const ColumnIndex&); - ColumnIndex() : boundary_order((BoundaryOrder::type)0) { - } - - virtual ~ColumnIndex() throw(); - std::vector null_pages; - std::vector min_values; - std::vector max_values; - BoundaryOrder::type boundary_order; - std::vector null_counts; - - _ColumnIndex__isset __isset; - - void __set_null_pages(const std::vector & val); - - void __set_min_values(const std::vector & val); - - void __set_max_values(const std::vector & val); - - void __set_boundary_order(const BoundaryOrder::type val); - - void __set_null_counts(const std::vector & val); - - bool operator == (const ColumnIndex & rhs) const - { - if (!(null_pages == rhs.null_pages)) - return false; - if (!(min_values == rhs.min_values)) - return false; - if (!(max_values == rhs.max_values)) - return false; - if (!(boundary_order == rhs.boundary_order)) - return false; - if (__isset.null_counts != rhs.__isset.null_counts) - return false; - else if (__isset.null_counts && !(null_counts == rhs.null_counts)) - return false; - return true; - } - bool operator != (const ColumnIndex &rhs) const { - return !(*this == rhs); - } - - bool operator < (const ColumnIndex & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(ColumnIndex &a, ColumnIndex &b); - -std::ostream& operator<<(std::ostream& out, const ColumnIndex& obj); - -typedef struct _AesGcmV1__isset { - _AesGcmV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {} - bool aad_prefix :1; - bool aad_file_unique :1; - bool supply_aad_prefix :1; -} _AesGcmV1__isset; - -class AesGcmV1 : public virtual ::apache::thrift::TBase { - public: - - AesGcmV1(const AesGcmV1&); - AesGcmV1& operator=(const AesGcmV1&); - AesGcmV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) { - } - - virtual ~AesGcmV1() throw(); - std::string aad_prefix; - std::string aad_file_unique; - bool supply_aad_prefix; - - _AesGcmV1__isset __isset; - - void __set_aad_prefix(const std::string& val); - - void __set_aad_file_unique(const std::string& val); - - void __set_supply_aad_prefix(const bool val); - - bool operator == (const AesGcmV1 & rhs) const - { - if (__isset.aad_prefix != rhs.__isset.aad_prefix) - return false; - else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix)) - return false; - if (__isset.aad_file_unique != rhs.__isset.aad_file_unique) - return false; - else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique)) - return false; - if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix) - return false; - else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix)) - return false; - return true; - } - bool operator != (const AesGcmV1 &rhs) const { - return !(*this == rhs); - } - - bool operator < (const AesGcmV1 & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void 
swap(AesGcmV1 &a, AesGcmV1 &b); - -std::ostream& operator<<(std::ostream& out, const AesGcmV1& obj); - -typedef struct _AesGcmCtrV1__isset { - _AesGcmCtrV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {} - bool aad_prefix :1; - bool aad_file_unique :1; - bool supply_aad_prefix :1; -} _AesGcmCtrV1__isset; - -class AesGcmCtrV1 : public virtual ::apache::thrift::TBase { - public: - - AesGcmCtrV1(const AesGcmCtrV1&); - AesGcmCtrV1& operator=(const AesGcmCtrV1&); - AesGcmCtrV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) { - } - - virtual ~AesGcmCtrV1() throw(); - std::string aad_prefix; - std::string aad_file_unique; - bool supply_aad_prefix; - - _AesGcmCtrV1__isset __isset; - - void __set_aad_prefix(const std::string& val); - - void __set_aad_file_unique(const std::string& val); - - void __set_supply_aad_prefix(const bool val); - - bool operator == (const AesGcmCtrV1 & rhs) const - { - if (__isset.aad_prefix != rhs.__isset.aad_prefix) - return false; - else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix)) - return false; - if (__isset.aad_file_unique != rhs.__isset.aad_file_unique) - return false; - else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique)) - return false; - if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix) - return false; - else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix)) - return false; - return true; - } - bool operator != (const AesGcmCtrV1 &rhs) const { - return !(*this == rhs); - } - - bool operator < (const AesGcmCtrV1 & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(AesGcmCtrV1 &a, AesGcmCtrV1 &b); - -std::ostream& operator<<(std::ostream& out, const AesGcmCtrV1& obj); - -typedef struct _EncryptionAlgorithm__isset { - _EncryptionAlgorithm__isset() : AES_GCM_V1(false), AES_GCM_CTR_V1(false) {} - bool AES_GCM_V1 :1; - bool AES_GCM_CTR_V1 :1; -} _EncryptionAlgorithm__isset; - -class EncryptionAlgorithm : public virtual ::apache::thrift::TBase { - public: - - EncryptionAlgorithm(const EncryptionAlgorithm&); - EncryptionAlgorithm& operator=(const EncryptionAlgorithm&); - EncryptionAlgorithm() { - } - - virtual ~EncryptionAlgorithm() throw(); - AesGcmV1 AES_GCM_V1; - AesGcmCtrV1 AES_GCM_CTR_V1; - - _EncryptionAlgorithm__isset __isset; - - void __set_AES_GCM_V1(const AesGcmV1& val); - - void __set_AES_GCM_CTR_V1(const AesGcmCtrV1& val); - - bool operator == (const EncryptionAlgorithm & rhs) const - { - if (__isset.AES_GCM_V1 != rhs.__isset.AES_GCM_V1) - return false; - else if (__isset.AES_GCM_V1 && !(AES_GCM_V1 == rhs.AES_GCM_V1)) - return false; - if (__isset.AES_GCM_CTR_V1 != rhs.__isset.AES_GCM_CTR_V1) - return false; - else if (__isset.AES_GCM_CTR_V1 && !(AES_GCM_CTR_V1 == rhs.AES_GCM_CTR_V1)) - return false; - return true; - } - bool operator != (const EncryptionAlgorithm &rhs) const { - return !(*this == rhs); - } - - bool operator < (const EncryptionAlgorithm & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(EncryptionAlgorithm &a, EncryptionAlgorithm &b); - -std::ostream& operator<<(std::ostream& out, const EncryptionAlgorithm& obj); - -typedef struct _FileMetaData__isset { - _FileMetaData__isset() : 
key_value_metadata(false), created_by(false), column_orders(false), encryption_algorithm(false), footer_signing_key_metadata(false) {} - bool key_value_metadata :1; - bool created_by :1; - bool column_orders :1; - bool encryption_algorithm :1; - bool footer_signing_key_metadata :1; -} _FileMetaData__isset; - -class FileMetaData : public virtual ::apache::thrift::TBase { - public: - - FileMetaData(const FileMetaData&); - FileMetaData& operator=(const FileMetaData&); - FileMetaData() : version(0), num_rows(0), created_by(), footer_signing_key_metadata() { - } - - virtual ~FileMetaData() throw(); - int32_t version; - std::vector schema; - int64_t num_rows; - std::vector row_groups; - std::vector key_value_metadata; - std::string created_by; - std::vector column_orders; - EncryptionAlgorithm encryption_algorithm; - std::string footer_signing_key_metadata; - - _FileMetaData__isset __isset; - - void __set_version(const int32_t val); - - void __set_schema(const std::vector & val); - - void __set_num_rows(const int64_t val); - - void __set_row_groups(const std::vector & val); - - void __set_key_value_metadata(const std::vector & val); - - void __set_created_by(const std::string& val); - - void __set_column_orders(const std::vector & val); - - void __set_encryption_algorithm(const EncryptionAlgorithm& val); - - void __set_footer_signing_key_metadata(const std::string& val); - - bool operator == (const FileMetaData & rhs) const - { - if (!(version == rhs.version)) - return false; - if (!(schema == rhs.schema)) - return false; - if (!(num_rows == rhs.num_rows)) - return false; - if (!(row_groups == rhs.row_groups)) - return false; - if (__isset.key_value_metadata != rhs.__isset.key_value_metadata) - return false; - else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata)) - return false; - if (__isset.created_by != rhs.__isset.created_by) - return false; - else if (__isset.created_by && !(created_by == rhs.created_by)) - return false; - if (__isset.column_orders != rhs.__isset.column_orders) - return false; - else if (__isset.column_orders && !(column_orders == rhs.column_orders)) - return false; - if (__isset.encryption_algorithm != rhs.__isset.encryption_algorithm) - return false; - else if (__isset.encryption_algorithm && !(encryption_algorithm == rhs.encryption_algorithm)) - return false; - if (__isset.footer_signing_key_metadata != rhs.__isset.footer_signing_key_metadata) - return false; - else if (__isset.footer_signing_key_metadata && !(footer_signing_key_metadata == rhs.footer_signing_key_metadata)) - return false; - return true; - } - bool operator != (const FileMetaData &rhs) const { - return !(*this == rhs); - } - - bool operator < (const FileMetaData & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(FileMetaData &a, FileMetaData &b); - -std::ostream& operator<<(std::ostream& out, const FileMetaData& obj); - -typedef struct _FileCryptoMetaData__isset { - _FileCryptoMetaData__isset() : key_metadata(false) {} - bool key_metadata :1; -} _FileCryptoMetaData__isset; - -class FileCryptoMetaData : public virtual ::apache::thrift::TBase { - public: - - FileCryptoMetaData(const FileCryptoMetaData&); - FileCryptoMetaData& operator=(const FileCryptoMetaData&); - FileCryptoMetaData() : key_metadata() { - } - - virtual ~FileCryptoMetaData() throw(); - EncryptionAlgorithm encryption_algorithm; - std::string 
key_metadata; - - _FileCryptoMetaData__isset __isset; - - void __set_encryption_algorithm(const EncryptionAlgorithm& val); - - void __set_key_metadata(const std::string& val); - - bool operator == (const FileCryptoMetaData & rhs) const - { - if (!(encryption_algorithm == rhs.encryption_algorithm)) - return false; - if (__isset.key_metadata != rhs.__isset.key_metadata) - return false; - else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata)) - return false; - return true; - } - bool operator != (const FileCryptoMetaData &rhs) const { - return !(*this == rhs); - } - - bool operator < (const FileCryptoMetaData & ) const; - - uint32_t read(::apache::thrift::protocol::TProtocol* iprot); - uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const; - - virtual void printTo(std::ostream& out) const; -}; - -void swap(FileCryptoMetaData &a, FileCryptoMetaData &b); - -std::ostream& operator<<(std::ostream& out, const FileCryptoMetaData& obj); - -}} // namespace - -#endif
diff --git a/third_party/proj.BUILD b/third_party/proj.BUILD
index 4235e3e01..854d730e3 100644
--- a/third_party/proj.BUILD
+++ b/third_party/proj.BUILD
@@ -12,8 +12,11 @@ cc_library(
             "src/*.c",
             "src/*.cpp",
             "src/iso19111/*.cpp",
+            "src/iso19111/operation/*.cpp",
+            "src/iso19111/operation/*.hpp",
             "src/projections/*.cpp",
             "src/transformations/*.cpp",
+            "src/transformations/*.hpp",
             "src/conversions/*.cpp",
         ],
         exclude = [
diff --git a/third_party/tinyobjloader.BUILD b/third_party/tinyobjloader.BUILD
new file mode 100644
index 000000000..0e9f74df4
--- /dev/null
+++ b/third_party/tinyobjloader.BUILD
@@ -0,0 +1,14 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # MIT license
+
+cc_library(
+    name = "tinyobjloader",
+    srcs = [
+        "tiny_obj_loader.cc",
+    ],
+    hdrs = [
+        "tiny_obj_loader.h",
+    ],
+    copts = [],
+)
diff --git a/third_party/toolchains/tf/BUILD.tpl b/third_party/toolchains/tf/BUILD.tpl
index bee021f10..425a96e15 100644
--- a/third_party/toolchains/tf/BUILD.tpl
+++ b/third_party/toolchains/tf/BUILD.tpl
@@ -4,6 +4,22 @@ cc_library(
     name = "tf_header_lib",
     hdrs = [":tf_header_include"],
     includes = ["include"],
+    deps = [
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:cord",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "tf_c_header_lib",
+    hdrs = [":tf_c_header_include"],
+    include_prefix = "tensorflow/c",
+    strip_include_prefix = "include_c",
     visibility = ["//visibility:public"],
 )

@@ -15,4 +31,5 @@ cc_library(
 )

 %{TF_HEADER_GENRULE}
-%{TF_SHARED_LIBRARY_GENRULE}
\ No newline at end of file
+%{TF_C_HEADER_GENRULE}
+%{TF_SHARED_LIBRARY_GENRULE}
diff --git a/third_party/toolchains/tf/tf_configure.bzl b/third_party/toolchains/tf/tf_configure.bzl
index 18e344388..3565f078b 100644
--- a/third_party/toolchains/tf/tf_configure.bzl
+++ b/third_party/toolchains/tf/tf_configure.bzl
@@ -147,7 +147,7 @@ def _symlink_genrule_for_dir(
     if src_dir != None:
         src_dir = _norm_path(src_dir)
         dest_dir = _norm_path(dest_dir)
-        files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+        files = "\n".join(sorted([e for e in _read_dir(repository_ctx, src_dir).splitlines() if ("/external/" not in e) and ("/absl/" not in e)]))

         # Create a list with the src_dir stripped to use for outputs.
         if tf_pip_dir_rename_pair_len:
@@ -176,6 +176,12 @@ def _tf_pip_impl(repository_ctx):
         "tf_header_include",
         tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
     )
+    tf_c_header_rule = _symlink_genrule_for_dir(
+        repository_ctx,
+        tf_header_dir + "/tensorflow/c/",
+        "include_c",
+        "tf_c_header_include",
+    )

     tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR]
     tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME]
@@ -192,6 +198,7 @@ def _tf_pip_impl(repository_ctx):

     _tpl(repository_ctx, "BUILD", {
         "%{TF_HEADER_GENRULE}": tf_header_rule,
+        "%{TF_C_HEADER_GENRULE}": tf_c_header_rule,
         "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
     })

diff --git a/tools/docker/devel.Dockerfile b/tools/docker/devel.Dockerfile
index b31f86933..075dc6138 100644
--- a/tools/docker/devel.Dockerfile
+++ b/tools/docker/devel.Dockerfile
@@ -14,7 +14,7 @@ RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list && apt
         ffmpeg \
         dnsutils

-ARG BAZEL_VERSION=3.1.0
+ARG BAZEL_VERSION=3.7.2
 ARG BAZEL_OS=linux

 RUN curl -sL https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-${BAZEL_OS}-x86_64.sh -o bazel-install.sh && \