From f3028b4c3fb764d1e1489f0566169cf271242c20 Mon Sep 17 00:00:00 2001
From: John Bampton
Date: Fri, 20 Sep 2024 03:00:42 +1000
Subject: [PATCH 1/3] [SEDONA-657] pre-commit add Python black hook; format
 Python with black

Standardizes all the Python code to black format

https://black.readthedocs.io/en/stable/integrations/source_version_control.html
https://www.flake8rules.com/
---
 .github/linters/ruff.toml | 4 +-
 .pre-commit-config.yaml | 4 +
 docs/usecases/ApacheSedonaCore.ipynb | 88 +-
 docs/usecases/ApacheSedonaRaster.ipynb | 118 +-
 docs/usecases/ApacheSedonaSQL.ipynb | 120 +-
 ...naSQL_SpatialJoin_AirportsPerCountry.ipynb | 61 +-
 .../Sedona_OvertureMaps_GeoParquet.ipynb | 116 +-
 .../contrib/ApacheSedonaImageFilter.ipynb | 133 +-
 .../contrib/DownloadImageFromGEE.ipynb | 40 +-
 .../contrib/NdviSentinelApacheSedona.ipynb | 265 ++-
 .../PostgresqlConnectionApacheSedona.ipynb | 57 +-
 .../contrib/VectorAnalisisApacheSedona.ipynb | 291 ++-
 docs/usecases/contrib/foot-traffic.ipynb | 555 +++---
 docs/usecases/utilities.py | 231 ++-
 python/sedona/core/SpatialRDD/__init__.py | 7 +-
 python/sedona/core/SpatialRDD/circle_rdd.py | 17 +-
 .../sedona/core/SpatialRDD/linestring_rdd.py | 60 +-
 python/sedona/core/SpatialRDD/point_rdd.py | 63 +-
 python/sedona/core/SpatialRDD/polygon_rdd.py | 58 +-
 .../sedona/core/SpatialRDD/rectangle_rdd.py | 60 +-
 python/sedona/core/SpatialRDD/spatial_rdd.py | 73 +-
 .../sedona/core/enums/file_data_splitter.py | 6 +-
 python/sedona/core/enums/grid_type.py | 2 +-
 python/sedona/core/enums/index_type.py | 6 +-
 python/sedona/core/enums/spatial.py | 2 +-
 python/sedona/core/formatMapper/__init__.py | 2 +-
 python/sedona/core/formatMapper/disc_utils.py | 6 +-
 .../core/formatMapper/geo_json_reader.py | 30 +-
 python/sedona/core/formatMapper/geo_reader.py | 3 +-
 .../shapefileParser/shape_file_reader.py | 20 +-
 python/sedona/core/formatMapper/wkb_reader.py | 36 +-
 python/sedona/core/formatMapper/wkt_reader.py | 33 +-
 python/sedona/core/geom/circle.py | 2 +-
 python/sedona/core/geom/envelope.py | 2 +-
 python/sedona/core/geom/shapely1/circle.py | 46 +-
 python/sedona/core/geom/shapely1/envelope.py | 48 +-
 python/sedona/core/geom/shapely2/circle.py | 25 +-
 python/sedona/core/geom/shapely2/envelope.py | 18 +-
 python/sedona/core/jvm/config.py | 7 +-
 python/sedona/core/jvm/translate.py | 16 +-
 .../sedona/core/spatialOperator/__init__.py | 4 +-
 .../core/spatialOperator/join_params.py | 15 +-
 .../sedona/core/spatialOperator/join_query.py | 58 +-
 .../core/spatialOperator/join_query_raw.py | 61 +-
 .../sedona/core/spatialOperator/knn_query.py | 20 +-
 .../core/spatialOperator/range_query.py | 13 +-
 .../core/spatialOperator/range_query_raw.py | 21 +-
 python/sedona/core/spatialOperator/rdd.py | 15 +-
 python/sedona/core/utils.py | 1 +
 python/sedona/exceptions.py | 1 +
 python/sedona/maps/SedonaMapUtils.py | 16 +-
 python/sedona/maps/SedonaPyDeck.py | 200 +-
 python/sedona/raster/awt_raster.py | 13 +-
 python/sedona/raster/data_buffer.py | 4 +-
 python/sedona/raster/meta.py | 31 +-
 python/sedona/raster/raster_serde.py | 93 +-
 python/sedona/raster/sample_model.py | 42 +-
 python/sedona/raster/sedona_raster.py | 162 +-
 python/sedona/raster_utils/SedonaUtils.py | 2 +
 python/sedona/register/geo_registrator.py | 2 +
 python/sedona/register/java_libs.py | 22 +-
 python/sedona/spark/SedonaContext.py | 5 +-
 python/sedona/sql/__init__.py | 12 +-
 python/sedona/sql/dataframe_api.py | 42 +-
 python/sedona/sql/exceptions.py | 1 +
 python/sedona/sql/st_aggregates.py | 13 +-
 python/sedona/sql/st_constructors.py | 151 +-
 python/sedona/sql/st_functions.py | 375 +++-
 python/sedona/sql/st_predicates.py | 38 +-
 python/sedona/utils/abstract_parser.py | 4 +-
 python/sedona/utils/adapter.py | 55 +-
 python/sedona/utils/binary_parser.py | 22 +-
 python/sedona/utils/decorators.py | 10 +-
 python/sedona/utils/geometry_serde.py | 33 +-
 python/sedona/utils/geometry_serde_general.py | 174 +-
 python/sedona/utils/jvm.py | 8 +-
 python/sedona/utils/meta.py | 23 +-
 python/sedona/utils/prep.py | 21 +-
 python/sedona/utils/spatial_rdd_parser.py | 40 +-
 python/setup.py | 48 +-
 python/tests/__init__.py | 40 +-
 .../core/test_avoiding_python_jvm_serde_df.py | 140 +-
 .../test_avoiding_python_jvm_serde_to_rdd.py | 45 +-
 .../tests/core/test_core_geom_primitives.py | 4 +-
 python/tests/core/test_core_rdd.py | 15 +-
 python/tests/core/test_rdd.py | 88 +-
 .../tests/core/test_spatial_rdd_from_disc.py | 110 +-
 .../format_mapper/test_geo_json_reader.py | 48 +-
 .../format_mapper/test_shapefile_reader.py | 50 +-
 .../maps/test_sedonakepler_visualization.py | 346 ++--
 python/tests/maps/test_sedonapydeck.py | 135 +-
 .../tests/properties/linestring_properties.py | 16 +-
 python/tests/properties/point_properties.py | 13 +-
 python/tests/properties/polygon_properties.py | 8 +-
 python/tests/raster/test_meta.py | 8 +-
 python/tests/raster/test_pandas_udf.py | 16 +-
 python/tests/raster/test_serde.py | 73 +-
 .../raster_viz_utils/test_sedonautils.py | 16 +-
 .../tests/serialization/test_deserializers.py | 85 +-
 .../test_direct_serialization.py | 27 +-
 .../test_geospark_serializers.py | 7 +-
 .../serialization/test_rdd_serialization.py | 26 +-
 .../tests/serialization/test_serializers.py | 114 +-
 .../serialization/test_with_sc_parellize.py | 18 +-
 .../test_join_query_correctness.py | 231 ++-
 .../spatial_operator/test_linestring_join.py | 45 +-
 .../spatial_operator/test_linestring_knn.py | 8 +-
 .../spatial_operator/test_linestring_range.py | 32 +-
 .../tests/spatial_operator/test_point_join.py | 90 +-
 .../tests/spatial_operator/test_point_knn.py | 32 +-
 .../spatial_operator/test_point_range.py | 17 +-
 .../spatial_operator/test_polygon_join.py | 35 +-
 .../spatial_operator/test_polygon_knn.py | 32 +-
 .../spatial_operator/test_polygon_range.py | 34 +-
 .../spatial_operator/test_rectangle_join.py | 45 +-
 .../spatial_operator/test_rectangle_knn.py | 48 +-
 .../spatial_operator/test_rectangle_range.py | 25 +-
 python/tests/spatial_rdd/test_circle_rdd.py | 19 +-
 .../tests/spatial_rdd/test_linestring_rdd.py | 34 +-
 python/tests/spatial_rdd/test_point_rdd.py | 56 +-
 python/tests/spatial_rdd/test_polygon_rdd.py | 81 +-
 .../tests/spatial_rdd/test_rectangle_rdd.py | 17 +-
 python/tests/spatial_rdd/test_spatial_rdd.py | 21 +-
 .../spatial_rdd/test_spatial_rdd_writer.py | 18 +-
 python/tests/sql/resource/sample_data.py | 28 +-
 python/tests/sql/test_adapter.py | 257 ++-
 python/tests/sql/test_aggregate_functions.py | 38 +-
 python/tests/sql/test_constructor_test.py | 271 ++-
 python/tests/sql/test_dataframe_api.py | 1377 +++++++++++--
 python/tests/sql/test_function.py | 1754 +++++++++++------
 python/tests/sql/test_geoparquet.py | 64 +-
 python/tests/sql/test_predicate.py | 221 ++-
 python/tests/sql/test_predicate_join.py | 479 +++--
 python/tests/sql/test_shapefile.py | 55 +-
 .../test_spatial_rdd_to_spatial_dataframe.py | 13 +-
 python/tests/sql/test_st_function_imports.py | 1 +
 python/tests/streaming/spark/cases_builder.py | 22 +-
 .../spark/test_constructor_functions.py | 774 +++++---
 python/tests/test_assign_raw_spatial_rdd.py | 39 +-
 python/tests/test_base.py | 4 +-
 python/tests/test_circle.py | 102 +-
 python/tests/test_multiple_meta.py | 4 +-
 python/tests/test_scala_example.py | 120 +-
 python/tests/tools.py | 7 +-
 python/tests/utils/test_crs_transformation.py | 20 +-
 python/tests/utils/test_geometry_serde.py | 244 ++-
 python/tests/utils/test_geomserde_speedup.py | 68 +-
 spark-version-converter.py | 87 +-
 148 files changed, 8857 insertions(+), 4162 deletions(-)

diff --git a/.github/linters/ruff.toml b/.github/linters/ruff.toml
index 6302db1e0a..68b176f4e0 100644
--- a/.github/linters/ruff.toml
+++ b/.github/linters/ruff.toml
@@ -39,8 +39,8 @@ target-version = "py38"
 # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
 # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
 # McCabe complexity (`C901`) by default.
-select = ["E4", "E7", "E9", "F"]
-ignore = ["E721", "E722", "E731", "F401", "F402", "F403", "F405", "F811", "F821", "F822", "F841", "F901"]
+select = ["E3", "E4", "E5", "E7", "E9", "F"]
+ignore = ["E501", "E721", "E722", "E731", "F401", "F402", "F403", "F405", "F811", "F821", "F822", "F841", "F901"]

 # Allow fix for all enabled rules (when `--fix`) is provided.
 fixable = ["ALL"]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1f3c8cac01..80c806295d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,6 +10,10 @@ repos:
     hooks:
       - id: identity
       - id: check-hooks-apply
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 24.8.0
+    hooks:
+      - id: black
   - repo: https://github.com/codespell-project/codespell
     rev: v2.3.0
     hooks:
diff --git a/docs/usecases/ApacheSedonaCore.ipynb b/docs/usecases/ApacheSedonaCore.ipynb
index 69e43e5935..e8123163d1 100644
--- a/docs/usecases/ApacheSedonaCore.ipynb
+++ b/docs/usecases/ApacheSedonaCore.ipynb
@@ -40,7 +40,7 @@
 "from shapely.geometry import Polygon\n",
 "\n",
 "from sedona.spark import *\n",
- "from sedona.core.geom.envelope import Envelope\n"
+ "from sedona.core.geom.envelope import Envelope"
 ]
 },
 {
@@ -182,12 +182,16 @@
 }
 ],
 "source": [
- "config = SedonaContext.builder() .\\\n",
- " config('spark.jars.packages',\n",
- " 'org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,'\n",
- " 'org.datasyslab:geotools-wrapper:1.6.0-28.2,'\n",
- " 'uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4'). 
\\\n", - " getOrCreate()\n", + "config = (\n", + " SedonaContext.builder()\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,\"\n", + " \"org.datasyslab:geotools-wrapper:1.6.0-28.2,\"\n", + " \"uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "\n", "sedona = SedonaContext.create(config)" ] @@ -560,7 +564,9 @@ "metadata": {}, "outputs": [], "source": [ - "point_rdd_to_geo = point_rdd.rawSpatialRDD.map(lambda x: [x.geom, *x.getUserData().split(\"\\t\")])" + "point_rdd_to_geo = point_rdd.rawSpatialRDD.map(\n", + " lambda x: [x.geom, *x.getUserData().split(\"\\t\")]\n", + ")" ] }, { @@ -578,7 +584,9 @@ ], "source": [ "point_gdf = gpd.GeoDataFrame(\n", - " point_rdd_to_geo.collect(), columns=[\"geom\", \"attr1\", \"attr2\", \"attr3\"], geometry=\"geom\"\n", + " point_rdd_to_geo.collect(),\n", + " columns=[\"geom\", \"attr1\", \"attr2\", \"attr3\"],\n", + " geometry=\"geom\",\n", ")" ] }, @@ -696,9 +704,9 @@ "metadata": {}, "outputs": [], "source": [ - "spatial_df = Adapter.\\\n", - " toDf(point_rdd, [\"attr1\", \"attr2\", \"attr3\"], sedona).\\\n", - " createOrReplaceTempView(\"spatial_df\")\n", + "spatial_df = Adapter.toDf(\n", + " point_rdd, [\"attr1\", \"attr2\", \"attr3\"], sedona\n", + ").createOrReplaceTempView(\"spatial_df\")\n", "\n", "spatial_gdf = sedona.sql(\"Select attr1, attr2, attr3, geometry as geom from spatial_df\")" ] @@ -995,10 +1003,16 @@ "metadata": {}, "outputs": [], "source": [ - "rectangle_rdd = RectangleRDD(sc, \"data/zcta510-small.csv\", FileDataSplitter.CSV, True, 11)\n", + "rectangle_rdd = RectangleRDD(\n", + " sc, \"data/zcta510-small.csv\", FileDataSplitter.CSV, True, 11\n", + ")\n", "point_rdd = PointRDD(sc, \"data/arealm-small.csv\", 1, FileDataSplitter.CSV, False, 11)\n", - "polygon_rdd = PolygonRDD(sc, \"data/primaryroads-polygon.csv\", FileDataSplitter.CSV, True, 11)\n", - "linestring_rdd = LineStringRDD(sc, \"data/primaryroads-linestring.csv\", FileDataSplitter.CSV, True)" + "polygon_rdd = PolygonRDD(\n", + " sc, \"data/primaryroads-polygon.csv\", FileDataSplitter.CSV, True, 11\n", + ")\n", + "linestring_rdd = LineStringRDD(\n", + " sc, \"data/primaryroads-linestring.csv\", FileDataSplitter.CSV, True\n", + ")" ] }, { @@ -1298,7 +1312,7 @@ "schema = StructType(\n", " [\n", " StructField(\"geom_left\", GeometryType(), False),\n", - " StructField(\"geom_right\", GeometryType(), False)\n", + " StructField(\"geom_right\", GeometryType(), False),\n", " ]\n", ")" ] @@ -1497,7 +1511,9 @@ "metadata": {}, "outputs": [], "source": [ - "spatial_join_result_non_flat = JoinQuery.SpatialJoinQuery(point_rdd, rectangle_rdd, False, True)" + "spatial_join_result_non_flat = JoinQuery.SpatialJoinQuery(\n", + " point_rdd, rectangle_rdd, False, True\n", + ")" ] }, { @@ -1507,7 +1523,9 @@ "outputs": [], "source": [ "# number of point for each polygon\n", - "number_of_points = spatial_join_result_non_flat.map(lambda x: [x[0].geom, x[1].__len__()])" + "number_of_points = spatial_join_result_non_flat.map(\n", + " lambda x: [x[0].geom, x[1].__len__()]\n", + ")" ] }, { @@ -1516,10 +1534,12 @@ "metadata": {}, "outputs": [], "source": [ - "schema = StructType([\n", - " StructField(\"geometry\", GeometryType(), False),\n", - " StructField(\"number_of_points\", LongType(), False)\n", - "])" + "schema = StructType(\n", + " [\n", + " StructField(\"geometry\", GeometryType(), False),\n", + " StructField(\"number_of_points\", LongType(), False),\n", + " ]\n", + ")" ] }, { @@ -1650,10 
+1670,14 @@ ], "source": [ "polygon = Polygon(\n", - " [(-84.237756, 33.904859), (-84.237756, 34.090426),\n", - " (-83.833011, 34.090426), (-83.833011, 33.904859),\n", - " (-84.237756, 33.904859)\n", - " ])\n", + " [\n", + " (-84.237756, 33.904859),\n", + " (-84.237756, 34.090426),\n", + " (-83.833011, 34.090426),\n", + " (-83.833011, 33.904859),\n", + " (-84.237756, 33.904859),\n", + " ]\n", + ")\n", "polygons_nearby = KNNQuery.SpatialKnnQuery(polygon_rdd, polygon, 5, False)" ] }, @@ -1737,7 +1761,9 @@ "source": [ "query_envelope = Envelope(-85.01, -60.01, 34.01, 50.01)\n", "\n", - "result_range_query = RangeQuery.SpatialRangeQuery(linestring_rdd, query_envelope, False, False)" + "result_range_query = RangeQuery.SpatialRangeQuery(\n", + " linestring_rdd, query_envelope, False, False\n", + ")" ] }, { @@ -1835,9 +1861,7 @@ ], "source": [ "sedona.createDataFrame(\n", - " result_range_query.map(lambda x: [x.geom]),\n", - " schema,\n", - " verifySchema=False\n", + " result_range_query.map(lambda x: [x.geom]), schema, verifySchema=False\n", ").show(5, True)" ] }, @@ -2325,7 +2349,9 @@ "source": [ "query_envelope = Envelope(-85.01, -60.01, 34.01, 50.01)\n", "\n", - "result_range_query = RangeQueryRaw.SpatialRangeQuery(linestring_rdd, query_envelope, False, False)" + "result_range_query = RangeQueryRaw.SpatialRangeQuery(\n", + " linestring_rdd, query_envelope, False, False\n", + ")" ] }, { diff --git a/docs/usecases/ApacheSedonaRaster.ipynb b/docs/usecases/ApacheSedonaRaster.ipynb index 3492f1ae59..042e009632 100644 --- a/docs/usecases/ApacheSedonaRaster.ipynb +++ b/docs/usecases/ApacheSedonaRaster.ipynb @@ -149,12 +149,16 @@ } ], "source": [ - "config = SedonaContext.builder() .\\\n", - " config('spark.jars.packages',\n", - " 'org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,'\n", - " 'org.datasyslab:geotools-wrapper:1.6.0-28.2,'\n", - " 'uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4'). 
\\\n", - " getOrCreate()\n", + "config = (\n", + " SedonaContext.builder()\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,\"\n", + " \"org.datasyslab:geotools-wrapper:1.6.0-28.2,\"\n", + " \"uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "\n", "sedona = SedonaContext.create(config)\n", "\n", @@ -360,15 +364,25 @@ } ], "source": [ - "(width, height) = sedona.sql(\"SELECT RS_Width(raster) as width, RS_Height(raster) as height from raster_table\").first()\n", - "(p1X, p1Y) = sedona.sql(f\"SELECT RS_RasterToWorldCoordX(raster, {width / 2}, {height / 2}) \\\n", - " as pX, RS_RasterToWorldCoordY(raster, {width / 2}, {height / 2}) as pY from raster_table\").first()\n", - "(p2X, p2Y) = sedona.sql(f\"SELECT RS_RasterToWorldCoordX(raster, {(width / 2) + 2}, {height / 2}) \\\n", - " as pX, RS_RasterToWorldCoordY(raster, {(width / 2) + 2}, {height / 2}) as pY from raster_table\").first()\n", - "(p3X, p3Y) = sedona.sql(f\"SELECT RS_RasterToWorldCoordX(raster, {width / 2}, {(height / 2) + 2}) \\\n", - " as pX, RS_RasterToWorldCoordY(raster, {width / 2}, {(height / 2) + 2}) as pY from raster_table\").first()\n", - "(p4X, p4Y) = sedona.sql(f\"SELECT RS_RasterToWorldCoordX(raster, {(width / 2) + 2}, {(height / 2) + 2}) \\\n", - " as pX, RS_RasterToWorldCoordY(raster, {(width / 2) + 2}, {(height / 2) + 2}) as pY from raster_table\").first() " + "(width, height) = sedona.sql(\n", + " \"SELECT RS_Width(raster) as width, RS_Height(raster) as height from raster_table\"\n", + ").first()\n", + "(p1X, p1Y) = sedona.sql(\n", + " f\"SELECT RS_RasterToWorldCoordX(raster, {width / 2}, {height / 2}) \\\n", + " as pX, RS_RasterToWorldCoordY(raster, {width / 2}, {height / 2}) as pY from raster_table\"\n", + ").first()\n", + "(p2X, p2Y) = sedona.sql(\n", + " f\"SELECT RS_RasterToWorldCoordX(raster, {(width / 2) + 2}, {height / 2}) \\\n", + " as pX, RS_RasterToWorldCoordY(raster, {(width / 2) + 2}, {height / 2}) as pY from raster_table\"\n", + ").first()\n", + "(p3X, p3Y) = sedona.sql(\n", + " f\"SELECT RS_RasterToWorldCoordX(raster, {width / 2}, {(height / 2) + 2}) \\\n", + " as pX, RS_RasterToWorldCoordY(raster, {width / 2}, {(height / 2) + 2}) as pY from raster_table\"\n", + ").first()\n", + "(p4X, p4Y) = sedona.sql(\n", + " f\"SELECT RS_RasterToWorldCoordX(raster, {(width / 2) + 2}, {(height / 2) + 2}) \\\n", + " as pX, RS_RasterToWorldCoordY(raster, {(width / 2) + 2}, {(height / 2) + 2}) as pY from raster_table\"\n", + ").first()" ] }, { @@ -418,7 +432,9 @@ } ], "source": [ - "joined_df = sedona.sql(\"SELECT g.geom from raster_table r, geom_table g where RS_Intersects(r.raster, g.geom)\")\n", + "joined_df = sedona.sql(\n", + " \"SELECT g.geom from raster_table r, geom_table g where RS_Intersects(r.raster, g.geom)\"\n", + ")\n", "joined_df.show()" ] }, @@ -490,8 +506,12 @@ } ], "source": [ - "raster_convex_hull = sedona.sql(\"SELECT RS_ConvexHull(raster) as convex_hull from raster_table\")\n", - "raster_min_convex_hull = sedona.sql(\"SELECT RS_MinConvexHull(raster) as min_convex_hull from raster_table\")\n", + "raster_convex_hull = sedona.sql(\n", + " \"SELECT RS_ConvexHull(raster) as convex_hull from raster_table\"\n", + ")\n", + "raster_min_convex_hull = sedona.sql(\n", + " \"SELECT RS_MinConvexHull(raster) as min_convex_hull from raster_table\"\n", + ")\n", "raster_convex_hull.show(truncate=False)\n", "raster_min_convex_hull.show(truncate=False)" ] @@ -541,7 +561,9 @@ } ], "source": [ - 
"rasterized_geom_df = sedona.sql(\"SELECT RS_AsRaster(ST_GeomFromWKT('POLYGON((150 150, 220 260, 190 300, 300 220, 150 150))'), r.raster, 'b', 230) as rasterized_geom from raster_table r\")\n", + "rasterized_geom_df = sedona.sql(\n", + " \"SELECT RS_AsRaster(ST_GeomFromWKT('POLYGON((150 150, 220 260, 190 300, 300 220, 150 150))'), r.raster, 'b', 230) as rasterized_geom from raster_table r\"\n", + ")\n", "rasterized_geom_df.show()" ] }, @@ -585,7 +607,9 @@ } ], "source": [ - "SedonaUtils.display_image(rasterized_geom_df.selectExpr(\"RS_AsImage(rasterized_geom, 250) as rasterized_geom\"))" + "SedonaUtils.display_image(\n", + " rasterized_geom_df.selectExpr(\"RS_AsImage(rasterized_geom, 250) as rasterized_geom\")\n", + ")" ] }, { @@ -611,7 +635,9 @@ }, "outputs": [], "source": [ - "raster_white_bg = rasterized_geom_df.selectExpr(\"RS_MapAlgebra(rasterized_geom, NULL, 'out[0] = rast[0] == 0 ? 230 : 0;') as raster\")" + "raster_white_bg = rasterized_geom_df.selectExpr(\n", + " \"RS_MapAlgebra(rasterized_geom, NULL, 'out[0] = rast[0] == 0 ? 230 : 0;') as raster\"\n", + ")" ] }, { @@ -661,7 +687,9 @@ } ], "source": [ - "SedonaUtils.display_image(raster_white_bg.selectExpr(\"RS_AsImage(raster, 250) as resampled_raster\"))" + "SedonaUtils.display_image(\n", + " raster_white_bg.selectExpr(\"RS_AsImage(raster, 250) as resampled_raster\")\n", + ")" ] }, { @@ -686,7 +714,9 @@ }, "outputs": [], "source": [ - "resampled_raster_df = sedona.sql(\"SELECT RS_Resample(raster, 1000, 1000, false, 'NearestNeighbor') as resampled_raster from raster_table\")" + "resampled_raster_df = sedona.sql(\n", + " \"SELECT RS_Resample(raster, 1000, 1000, false, 'NearestNeighbor') as resampled_raster from raster_table\"\n", + ")" ] }, { @@ -731,7 +761,11 @@ } ], "source": [ - "SedonaUtils.display_image(resampled_raster_df.selectExpr(\"RS_AsImage(resampled_raster, 500) as resampled_raster\"))" + "SedonaUtils.display_image(\n", + " resampled_raster_df.selectExpr(\n", + " \"RS_AsImage(resampled_raster, 500) as resampled_raster\"\n", + " )\n", + ")" ] }, { @@ -756,7 +790,9 @@ } ], "source": [ - "resampled_raster_df.selectExpr(\"RS_MetaData(resampled_raster) as resampled_raster_metadata\").show(truncate=False)" + "resampled_raster_df.selectExpr(\n", + " \"RS_MetaData(resampled_raster) as resampled_raster_metadata\"\n", + ").show(truncate=False)" ] }, { @@ -769,7 +805,7 @@ "outputs": [], "source": [ "# Load another raster for some more examples\n", - "elevation_raster_df = sedona.read.format('binaryFile').load('data/raster/test1.tiff')\n", + "elevation_raster_df = sedona.read.format(\"binaryFile\").load(\"data/raster/test1.tiff\")\n", "elevation_raster_df.createOrReplaceTempView(\"elevation_raster_binary\")" ] }, @@ -782,7 +818,9 @@ }, "outputs": [], "source": [ - "elevation_raster_df = sedona.sql(\"SELECT RS_FromGeoTiff(content) as raster from elevation_raster_binary\")\n", + "elevation_raster_df = sedona.sql(\n", + " \"SELECT RS_FromGeoTiff(content) as raster from elevation_raster_binary\"\n", + ")\n", "elevation_raster_df.createOrReplaceTempView(\"elevation_raster\")" ] }, @@ -831,11 +869,17 @@ } ], "source": [ - "point_wkt_1 = 'SRID=3857;POINT (-13095600.809482181 4021100.7487925636)'\n", - "point_wkt_2 = 'SRID=3857;POINT (-13095500.809482181 4021000.7487925636)'\n", - "point_df = sedona.sql(\"SELECT ST_GeomFromEWKT('{}') as point_1, ST_GeomFromEWKT('{}') as point_2\".format(point_wkt_1, point_wkt_2))\n", + "point_wkt_1 = \"SRID=3857;POINT (-13095600.809482181 4021100.7487925636)\"\n", + "point_wkt_2 = \"SRID=3857;POINT 
(-13095500.809482181 4021000.7487925636)\"\n", + "point_df = sedona.sql(\n", + " \"SELECT ST_GeomFromEWKT('{}') as point_1, ST_GeomFromEWKT('{}') as point_2\".format(\n", + " point_wkt_1, point_wkt_2\n", + " )\n", + ")\n", "point_df.createOrReplaceTempView(\"point_table\")\n", - "test_df = sedona.sql(\"SELECT RS_Values(raster, Array(point_1, point_2)) as raster_values from elevation_raster, point_table\")\n", + "test_df = sedona.sql(\n", + " \"SELECT RS_Values(raster, Array(point_1, point_2)) as raster_values from elevation_raster, point_table\"\n", + ")\n", "test_df.show()" ] }, @@ -864,7 +908,9 @@ ], "source": [ "band = elevation_raster_df.selectExpr(\"RS_BandAsArray(raster, 1)\").first()[0]\n", - "print(band[500: 520],) #Print a part of a band as an array horizontally" + "print(\n", + " band[500:520],\n", + ") # Print a part of a band as an array horizontally" ] }, { @@ -883,7 +929,9 @@ "outputs": [], "source": [ "# Convert raster to its convex hull and transform it to EPSG:4326 to be able to visualize\n", - "raster_mbr_df = elevation_raster_df.selectExpr(\"ST_Transform(RS_ConvexHull(raster), 'EPSG:3857', 'EPSG:4326') as raster_mbr\")" + "raster_mbr_df = elevation_raster_df.selectExpr(\n", + " \"ST_Transform(RS_ConvexHull(raster), 'EPSG:3857', 'EPSG:4326') as raster_mbr\"\n", + ")" ] }, { @@ -895,7 +943,9 @@ }, "outputs": [], "source": [ - "sedona_kepler_map_elevation = SedonaKepler.create_map(df=raster_mbr_df, name='RasterMBR')\n", + "sedona_kepler_map_elevation = SedonaKepler.create_map(\n", + " df=raster_mbr_df, name=\"RasterMBR\"\n", + ")\n", "sedona_kepler_map_elevation" ] }, diff --git a/docs/usecases/ApacheSedonaSQL.ipynb b/docs/usecases/ApacheSedonaSQL.ipynb index 5d4ccebf94..f36f1ec60c 100644 --- a/docs/usecases/ApacheSedonaSQL.ipynb +++ b/docs/usecases/ApacheSedonaSQL.ipynb @@ -133,14 +133,18 @@ } ], "source": [ - "config = SedonaContext.builder() .\\\n", - " config('spark.jars.packages',\n", - " 'org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,'\n", - " 'org.datasyslab:geotools-wrapper:1.6.0-28.2,'\n", - " 'uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4'). 
\\\n", - " getOrCreate()\n", + "config = (\n", + " SedonaContext.builder()\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,\"\n", + " \"org.datasyslab:geotools-wrapper:1.6.0-28.2,\"\n", + " \"uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "\n", - "sedona = SedonaContext.create(config)\n" + "sedona = SedonaContext.create(config)" ] }, { @@ -195,14 +199,18 @@ } ], "source": [ - "point_csv_df = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \",\").\\\n", - " option(\"header\", \"false\").\\\n", - " load(\"data/testpoint.csv\")\n", + "point_csv_df = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \",\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/testpoint.csv\")\n", + ")\n", "\n", "point_csv_df.createOrReplaceTempView(\"pointtable\")\n", "\n", - "point_df = sedona.sql(\"select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable\")\n", + "point_df = sedona.sql(\n", + " \"select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable\"\n", + ")\n", "point_df.show(5)" ] }, @@ -237,13 +245,17 @@ } ], "source": [ - "polygon_wkt_df = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \"\\t\").\\\n", - " option(\"header\", \"false\").\\\n", - " load(\"data/county_small.tsv\")\n", + "polygon_wkt_df = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \"\\t\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/county_small.tsv\")\n", + ")\n", "\n", "polygon_wkt_df.createOrReplaceTempView(\"polygontable\")\n", - "polygon_df = sedona.sql(\"select polygontable._c6 as name, ST_GeomFromText(polygontable._c0) as countyshape from polygontable\")\n", + "polygon_df = sedona.sql(\n", + " \"select polygontable._c6 as name, ST_GeomFromText(polygontable._c0) as countyshape from polygontable\"\n", + ")\n", "polygon_df.show(5)" ] }, @@ -278,13 +290,17 @@ } ], "source": [ - "polygon_wkb_df = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \"\\t\").\\\n", - " option(\"header\", \"false\").\\\n", - " load(\"data/county_small_wkb.tsv\")\n", + "polygon_wkb_df = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \"\\t\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/county_small_wkb.tsv\")\n", + ")\n", "\n", "polygon_wkb_df.createOrReplaceTempView(\"polygontable\")\n", - "polygon_df = sedona.sql(\"select polygontable._c6 as name, ST_GeomFromWKB(polygontable._c0) as countyshape from polygontable\")\n", + "polygon_df = sedona.sql(\n", + " \"select polygontable._c6 as name, ST_GeomFromWKB(polygontable._c0) as countyshape from polygontable\"\n", + ")\n", "polygon_df.show(5)" ] }, @@ -319,13 +335,17 @@ } ], "source": [ - "polygon_json_df = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \"\\t\").\\\n", - " option(\"header\", \"false\").\\\n", - " load(\"data/testPolygon.json\")\n", + "polygon_json_df = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \"\\t\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/testPolygon.json\")\n", + ")\n", "\n", "polygon_json_df.createOrReplaceTempView(\"polygontable\")\n", - "polygon_df = sedona.sql(\"select ST_GeomFromGeoJSON(polygontable._c0) as countyshape from polygontable\")\n", + "polygon_df = sedona.sql(\n", + " \"select ST_GeomFromGeoJSON(polygontable._c0) as 
countyshape from polygontable\"\n", + ")\n", "polygon_df.show(5)" ] }, @@ -389,24 +409,36 @@ } ], "source": [ - "point_csv_df_1 = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \",\").\\\n", - " option(\"header\", \"false\").load(\"data/testpoint.csv\")\n", + "point_csv_df_1 = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \",\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/testpoint.csv\")\n", + ")\n", "\n", "point_csv_df_1.createOrReplaceTempView(\"pointtable\")\n", "\n", - "point_df1 = sedona.sql(\"SELECT ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1, \\'abc\\' as name1 from pointtable\")\n", + "point_df1 = sedona.sql(\n", + " \"SELECT ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1, 'abc' as name1 from pointtable\"\n", + ")\n", "point_df1.createOrReplaceTempView(\"pointdf1\")\n", "\n", - "point_csv_df2 = sedona.read.format(\"csv\").\\\n", - " option(\"delimiter\", \",\").\\\n", - " option(\"header\", \"false\").load(\"data/testpoint.csv\")\n", + "point_csv_df2 = (\n", + " sedona.read.format(\"csv\")\n", + " .option(\"delimiter\", \",\")\n", + " .option(\"header\", \"false\")\n", + " .load(\"data/testpoint.csv\")\n", + ")\n", "\n", "point_csv_df2.createOrReplaceTempView(\"pointtable\")\n", - "point_df2 = sedona.sql(\"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2, \\'def\\' as name2 from pointtable\")\n", + "point_df2 = sedona.sql(\n", + " \"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2, 'def' as name2 from pointtable\"\n", + ")\n", "point_df2.createOrReplaceTempView(\"pointdf2\")\n", "\n", - "distance_join_df = sedona.sql(\"select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) < 2\")\n", + "distance_join_df = sedona.sql(\n", + " \"select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) < 2\"\n", + ")\n", "distance_join_df.explain()\n", "distance_join_df.show(5)" ] @@ -434,11 +466,10 @@ "outputs": [], "source": [ "import pandas as pd\n", + "\n", "gdf = gpd.read_file(\"data/gis_osm_pois_free_1.shp\")\n", - "gdf = gdf.replace(pd.NA, '')\n", - "osm_points = sedona.createDataFrame(\n", - " gdf\n", - ")" + "gdf = gdf.replace(pd.NA, \"\")\n", + "osm_points = sedona.createDataFrame(gdf)" ] }, { @@ -528,7 +559,8 @@ " name,\n", " ST_Transform(geometry, 'epsg:4326', 'epsg:2180') as geom \n", " FROM points\n", - " \"\"\")" + " \"\"\"\n", + ")" ] }, { @@ -587,13 +619,15 @@ "metadata": {}, "outputs": [], "source": [ - "neighbours_within_1000m = sedona.sql(\"\"\"\n", + "neighbours_within_1000m = sedona.sql(\n", + " \"\"\"\n", " SELECT a.osm_id AS id_1,\n", " b.osm_id AS id_2,\n", " a.geom \n", " FROM points_2180 AS a, points_2180 AS b \n", " WHERE ST_Distance(a.geom,b.geom) < 50\n", - " \"\"\")" + " \"\"\"\n", + ")" ] }, { diff --git a/docs/usecases/ApacheSedonaSQL_SpatialJoin_AirportsPerCountry.ipynb b/docs/usecases/ApacheSedonaSQL_SpatialJoin_AirportsPerCountry.ipynb index 97eeb96cef..8883b98f87 100644 --- a/docs/usecases/ApacheSedonaSQL_SpatialJoin_AirportsPerCountry.ipynb +++ b/docs/usecases/ApacheSedonaSQL_SpatialJoin_AirportsPerCountry.ipynb @@ -35,7 +35,6 @@ "from pyspark.sql.functions import col, expr, when, explode, hex\n", "\n", "\n", - "\n", "from sedona.spark import *\n", "from utilities import getConfig" ] @@ -99,12 +98,16 @@ 
} ], "source": [ - "config = SedonaContext.builder() .\\\n", - " config('spark.jars.packages',\n", - " 'org.apache.sedona:sedona-spark-shaded-3.4_2.12:1.6.0,'\n", - " 'org.datasyslab:geotools-wrapper:1.6.0-28.2,'\n", - " 'uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4'). \\\n", - " getOrCreate()\n", + "config = (\n", + " SedonaContext.builder()\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-spark-shaded-3.4_2.12:1.6.0,\"\n", + " \"org.datasyslab:geotools-wrapper:1.6.0-28.2,\"\n", + " \"uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "\n", "sedona = SedonaContext.create(config)\n", "sc = sedona.sparkContext\n", @@ -236,7 +239,9 @@ } ], "source": [ - "countries = ShapefileReader.readToGeometryRDD(sc, \"data/ne_50m_admin_0_countries_lakes/\")\n", + "countries = ShapefileReader.readToGeometryRDD(\n", + " sc, \"data/ne_50m_admin_0_countries_lakes/\"\n", + ")\n", "countries_df = Adapter.toDf(countries, sedona)\n", "countries_df.createOrReplaceTempView(\"country\")\n", "countries_df.printSchema()" @@ -297,7 +302,9 @@ "metadata": {}, "outputs": [], "source": [ - "result = sedona.sql(\"SELECT c.geometry as country_geom, c.NAME_EN, a.geometry as airport_geom, a.name FROM country c, airport a WHERE ST_Contains(c.geometry, a.geometry)\")" + "result = sedona.sql(\n", + " \"SELECT c.geometry as country_geom, c.NAME_EN, a.geometry as airport_geom, a.name FROM country c, airport a WHERE ST_Contains(c.geometry, a.geometry)\"\n", + ")" ] }, { @@ -338,13 +345,19 @@ "considerBoundaryIntersection = True\n", "airports_rdd.buildIndex(IndexType.QUADTREE, buildOnSpatialPartitionedRDD)\n", "\n", - "result_pair_rdd = JoinQueryRaw.SpatialJoinQueryFlat(airports_rdd, countries_rdd, usingIndex, considerBoundaryIntersection)\n", + "result_pair_rdd = JoinQueryRaw.SpatialJoinQueryFlat(\n", + " airports_rdd, countries_rdd, usingIndex, considerBoundaryIntersection\n", + ")\n", "\n", - "result2 = Adapter.toDf(result_pair_rdd, countries_rdd.fieldNames, airports.fieldNames, sedona)\n", + "result2 = Adapter.toDf(\n", + " result_pair_rdd, countries_rdd.fieldNames, airports.fieldNames, sedona\n", + ")\n", "\n", "result2.createOrReplaceTempView(\"join_result_with_all_cols\")\n", "# Select the columns needed in the join\n", - "result2 = sedona.sql(\"SELECT leftgeometry as country_geom, NAME_EN, rightgeometry as airport_geom, name FROM join_result_with_all_cols\")" + "result2 = sedona.sql(\n", + " \"SELECT leftgeometry as country_geom, NAME_EN, rightgeometry as airport_geom, name FROM join_result_with_all_cols\"\n", + ")" ] }, { @@ -503,7 +516,9 @@ "source": [ "# result.createOrReplaceTempView(\"result\")\n", "result2.createOrReplaceTempView(\"result\")\n", - "groupedresult = sedona.sql(\"SELECT c.NAME_EN, c.country_geom, count(*) as AirportCount FROM result c GROUP BY c.NAME_EN, c.country_geom\")\n", + "groupedresult = sedona.sql(\n", + " \"SELECT c.NAME_EN, c.country_geom, count(*) as AirportCount FROM result c GROUP BY c.NAME_EN, c.country_geom\"\n", + ")\n", "groupedresult.show()\n", "groupedresult.createOrReplaceTempView(\"grouped_result\")" ] @@ -557,7 +572,9 @@ } ], "source": [ - "sedona_kepler_map = SedonaKepler.create_map(df=groupedresult, name=\"AirportCount\", config=getConfig())\n", + "sedona_kepler_map = SedonaKepler.create_map(\n", + " df=groupedresult, name=\"AirportCount\", config=getConfig()\n", + ")\n", "sedona_kepler_map" ] }, @@ -669373,7 +669390,9 @@ } ], "source": [ - "sedona_pydeck_map = 
SedonaPyDeck.create_choropleth_map(df=groupedresult, plot_col='AirportCount')\n", + "sedona_pydeck_map = SedonaPyDeck.create_choropleth_map(\n", + " df=groupedresult, plot_col=\"AirportCount\"\n", + ")\n", "sedona_pydeck_map" ] }, @@ -669430,7 +669449,9 @@ } ], "source": [ - "h3_df = sedona.sql(\"SELECT g.NAME_EN, g.country_geom, ST_H3CellIDs(g.country_geom, 3, false) as h3_cellID from grouped_result g\")\n", + "h3_df = sedona.sql(\n", + " \"SELECT g.NAME_EN, g.country_geom, ST_H3CellIDs(g.country_geom, 3, false) as h3_cellID from grouped_result g\"\n", + ")\n", "h3_df.show(2)" ] }, @@ -669478,7 +669499,9 @@ } ], "source": [ - "exploded_h3 = h3_df.select(h3_df.NAME_EN, h3_df.country_geom, explode(h3_df.h3_cellID).alias(\"h3\"))\n", + "exploded_h3 = h3_df.select(\n", + " h3_df.NAME_EN, h3_df.country_geom, explode(h3_df.h3_cellID).alias(\"h3\")\n", + ")\n", "exploded_h3.show(2)" ] }, @@ -669533,7 +669556,9 @@ "source": [ "exploded_h3 = exploded_h3.sample(0.3)\n", "exploded_h3.createOrReplaceTempView(\"exploded_h3\")\n", - "hex_exploded_h3 = exploded_h3.select(exploded_h3.NAME_EN, hex(exploded_h3.h3).alias(\"ex_h3\"))\n", + "hex_exploded_h3 = exploded_h3.select(\n", + " exploded_h3.NAME_EN, hex(exploded_h3.h3).alias(\"ex_h3\")\n", + ")\n", "hex_exploded_h3.show(2)\n", "hex_exploded_h3.printSchema()" ] diff --git a/docs/usecases/Sedona_OvertureMaps_GeoParquet.ipynb b/docs/usecases/Sedona_OvertureMaps_GeoParquet.ipynb index 17389d0406..875700faf2 100644 --- a/docs/usecases/Sedona_OvertureMaps_GeoParquet.ipynb +++ b/docs/usecases/Sedona_OvertureMaps_GeoParquet.ipynb @@ -70,7 +70,9 @@ }, "outputs": [], "source": [ - "DATA_LINK = \"s3a://wherobots-examples/data/overturemaps-us-west-2/release/2023-07-26-alpha.0/\"" + "DATA_LINK = (\n", + " \"s3a://wherobots-examples/data/overturemaps-us-west-2/release/2023-07-26-alpha.0/\"\n", + ")" ] }, { @@ -212,14 +214,24 @@ } ], "source": [ - "config = SedonaContext.builder() .\\\n", - " config(\"spark.hadoop.fs.s3a.aws.credentials.provider\", \"org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider\"). \\\n", - " config(\"fs.s3a.aws.credentials.provider\", \"org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider\"). \\\n", - " config('spark.jars.packages',\n", - " 'org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,'\n", - " 'org.datasyslab:geotools-wrapper:1.6.0-28.2,'\n", - " 'uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4'). 
\\\n", - " getOrCreate()\n", + "config = (\n", + " SedonaContext.builder()\n", + " .config(\n", + " \"spark.hadoop.fs.s3a.aws.credentials.provider\",\n", + " \"org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider\",\n", + " )\n", + " .config(\n", + " \"fs.s3a.aws.credentials.provider\",\n", + " \"org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider\",\n", + " )\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-spark-3.4_2.12:1.6.0,\"\n", + " \"org.datasyslab:geotools-wrapper:1.6.0-28.2,\"\n", + " \"uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.4\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "\n", "sedona = SedonaContext.create(config)" ] @@ -254,7 +266,7 @@ "outputs": [], "source": [ "# Washington state boundary\n", - "#spatial_filter = \"POLYGON((-123.3208 49.0023,-123.0338 49.0027,-122.0650 49.0018,-121.7491 48.9973,-121.5912 48.9991,-119.6082 49.0009,-118.0378 49.0005,-117.0319 48.9996,-117.0415 47.9614,-117.0394 46.5060,-117.0394 46.4274,-117.0621 46.3498,-117.0277 46.3384,-116.9879 46.2848,-116.9577 46.2388,-116.9659 46.2022,-116.9254 46.1722,-116.9357 46.1432,-116.9584 46.1009,-116.9762 46.0785,-116.9433 46.0537,-116.9165 45.9960,-118.0330 46.0008,-118.9867 45.9998,-119.1302 45.9320,-119.1708 45.9278,-119.2559 45.9402,-119.3047 45.9354,-119.3644 45.9220,-119.4386 45.9172,-119.4894 45.9067,-119.5724 45.9249,-119.6013 45.9196,-119.6700 45.8565,-119.8052 45.8479,-119.9096 45.8278,-119.9652 45.8245,-120.0710 45.7852,-120.1705 45.7623,-120.2110 45.7258,-120.3628 45.7057,-120.4829 45.6951,-120.5942 45.7469,-120.6340 45.7460,-120.6924 45.7143,-120.8558 45.6721,-120.9142 45.6409,-120.9471 45.6572,-120.9787 45.6419,-121.0645 45.6529,-121.1469 45.6078,-121.1847 45.6083,-121.2177 45.6721,-121.3392 45.7057,-121.4010 45.6932,-121.5328 45.7263,-121.6145 45.7091,-121.7361 45.6947,-121.8095 45.7067,-121.9338 45.6452,-122.0451 45.6088,-122.1089 45.5833,-122.1426 45.5838,-122.2009 45.5660,-122.2641 45.5439,-122.3321 45.5482,-122.3795 45.5756,-122.4392 45.5636,-122.5676 45.6006,-122.6891 45.6236,-122.7647 45.6582,-122.7750 45.6817,-122.7619 45.7613,-122.7962 45.8106,-122.7839 45.8642,-122.8114 45.9120,-122.8148 45.9612,-122.8587 46.0160,-122.8848 46.0604,-122.9034 46.0832,-122.9597 46.1028,-123.0579 46.1556,-123.1210 46.1865,-123.1664 46.1893,-123.2810 46.1446,-123.3703 46.1470,-123.4314 46.1822,-123.4287 46.2293,-123.4946 46.2691,-123.5557 46.2582,-123.6209 46.2573,-123.6875 46.2497,-123.7404 46.2691,-123.8729 46.2350,-123.9292 46.2383,-123.9711 46.2677,-124.0212 46.2924,-124.0329 46.2653,-124.2444 46.2596,-124.2691 46.4312,-124.3529 46.8386,-124.4380 47.1832,-124.5616 47.4689,-124.7566 47.8012,-124.8679 48.0423,-124.8679 48.2457,-124.8486 48.3727,-124.7539 48.4984,-124.4174 48.4096,-124.2389 48.3599,-124.0116 48.2964,-123.9141 48.2795,-123.5413 48.2247,-123.3998 48.2539,-123.2501 48.2841,-123.1169 48.4233,-123.1609 48.4533,-123.2220 48.5548,-123.2336 48.5902,-123.2721 48.6901,-123.0084 48.7675,-123.0084 48.8313,-123.3215 49.0023,-123.3208 49.0023))\"\n", + "# spatial_filter = \"POLYGON((-123.3208 49.0023,-123.0338 49.0027,-122.0650 49.0018,-121.7491 48.9973,-121.5912 48.9991,-119.6082 49.0009,-118.0378 49.0005,-117.0319 48.9996,-117.0415 47.9614,-117.0394 46.5060,-117.0394 46.4274,-117.0621 46.3498,-117.0277 46.3384,-116.9879 46.2848,-116.9577 46.2388,-116.9659 46.2022,-116.9254 46.1722,-116.9357 46.1432,-116.9584 46.1009,-116.9762 46.0785,-116.9433 46.0537,-116.9165 45.9960,-118.0330 46.0008,-118.9867 45.9998,-119.1302 45.9320,-119.1708 
45.9278,-119.2559 45.9402,-119.3047 45.9354,-119.3644 45.9220,-119.4386 45.9172,-119.4894 45.9067,-119.5724 45.9249,-119.6013 45.9196,-119.6700 45.8565,-119.8052 45.8479,-119.9096 45.8278,-119.9652 45.8245,-120.0710 45.7852,-120.1705 45.7623,-120.2110 45.7258,-120.3628 45.7057,-120.4829 45.6951,-120.5942 45.7469,-120.6340 45.7460,-120.6924 45.7143,-120.8558 45.6721,-120.9142 45.6409,-120.9471 45.6572,-120.9787 45.6419,-121.0645 45.6529,-121.1469 45.6078,-121.1847 45.6083,-121.2177 45.6721,-121.3392 45.7057,-121.4010 45.6932,-121.5328 45.7263,-121.6145 45.7091,-121.7361 45.6947,-121.8095 45.7067,-121.9338 45.6452,-122.0451 45.6088,-122.1089 45.5833,-122.1426 45.5838,-122.2009 45.5660,-122.2641 45.5439,-122.3321 45.5482,-122.3795 45.5756,-122.4392 45.5636,-122.5676 45.6006,-122.6891 45.6236,-122.7647 45.6582,-122.7750 45.6817,-122.7619 45.7613,-122.7962 45.8106,-122.7839 45.8642,-122.8114 45.9120,-122.8148 45.9612,-122.8587 46.0160,-122.8848 46.0604,-122.9034 46.0832,-122.9597 46.1028,-123.0579 46.1556,-123.1210 46.1865,-123.1664 46.1893,-123.2810 46.1446,-123.3703 46.1470,-123.4314 46.1822,-123.4287 46.2293,-123.4946 46.2691,-123.5557 46.2582,-123.6209 46.2573,-123.6875 46.2497,-123.7404 46.2691,-123.8729 46.2350,-123.9292 46.2383,-123.9711 46.2677,-124.0212 46.2924,-124.0329 46.2653,-124.2444 46.2596,-124.2691 46.4312,-124.3529 46.8386,-124.4380 47.1832,-124.5616 47.4689,-124.7566 47.8012,-124.8679 48.0423,-124.8679 48.2457,-124.8486 48.3727,-124.7539 48.4984,-124.4174 48.4096,-124.2389 48.3599,-124.0116 48.2964,-123.9141 48.2795,-123.5413 48.2247,-123.3998 48.2539,-123.2501 48.2841,-123.1169 48.4233,-123.1609 48.4533,-123.2220 48.5548,-123.2336 48.5902,-123.2721 48.6901,-123.0084 48.7675,-123.0084 48.8313,-123.3215 49.0023,-123.3208 49.0023))\"\n", "\n", "# Bellevue city boundary\n", "spatial_filter = \"POLYGON ((-122.235128 47.650163, -122.233796 47.65162, -122.231581 47.653287, -122.228514 47.65482, -122.227526 47.655204, -122.226175 47.655729, -122.222039 47.656743999999996, -122.218428 47.657464, -122.217026 47.657506, -122.21437399999999 47.657588, -122.212091 47.657464, -122.212135 47.657320999999996, -122.21092999999999 47.653552, -122.209834 47.650121, -122.209559 47.648976, -122.209642 47.648886, -122.21042 47.648658999999995, -122.210897 47.64864, -122.211005 47.648373, -122.21103099999999 47.648320999999996, -122.211992 47.64644, -122.212457 47.646426, -122.212469 47.646392, -122.212469 47.646088999999996, -122.212471 47.645213, -122.213115 47.645212, -122.213123 47.644576, -122.21352999999999 47.644576, -122.213768 47.644560999999996, -122.21382 47.644560999999996, -122.21382 47.644456999999996, -122.21373299999999 47.644455, -122.213748 47.643102999999996, -122.213751 47.642790999999995, -122.213753 47.642716, -122.213702 47.642697999999996, -122.213679 47.642689999999995, -122.21364 47.642678, -122.213198 47.642541, -122.213065 47.642500000000005, -122.212918 47.642466, -122.21275 47.642441, -122.212656 47.642433, -122.21253899999999 47.642429, -122.212394 47.64243, -122.212182 47.642444999999995, -122.211957 47.642488, -122.211724 47.642551999999995, -122.21143599999999 47.642647, -122.210906 47.642834, -122.210216 47.643099, -122.209858 47.643215, -122.20973000000001 47.643248, -122.20973599999999 47.643105, -122.209267 47.643217, -122.208832 47.643302, -122.208391 47.643347999999996, -122.207797 47.643414, -122.207476 47.643418, -122.20701199999999 47.643397, -122.206795 47.643387999999995, -122.205742 47.643246, -122.20549 47.643201999999995, -122.20500200000001 
47.643119, -122.204802 47.643085, -122.204641 47.643066, -122.204145 47.643012, -122.203547 47.643012, -122.203097 47.643107, -122.20275699999999 47.643283, -122.202507 47.643496999999996, -122.202399 47.643653, -122.202111 47.643771, -122.201668 47.643767, -122.201363 47.643665, -122.20133 47.643648999999996, -122.201096 47.643536, -122.200744 47.64328, -122.200568 47.64309, -122.200391 47.642849, -122.200162 47.642539, -122.199896 47.642500000000005, -122.19980799999999 47.642424, -122.199755 47.642376999999996, -122.199558 47.642227999999996, -122.199439 47.642157, -122.199293 47.642078999999995, -122.199131 47.642004, -122.198928 47.641925, -122.19883 47.641892, -122.19856300000001 47.641811999999994, -122.198203 47.641731, -122.197662 47.641619999999996, -122.196819 47.641436, -122.196294 47.641309, -122.196294 47.642314, -122.19628 47.642855, -122.196282 47.642897999999995, -122.196281 47.643111, -122.196283 47.643415, -122.196283 47.643508999999995, -122.19628399999999 47.643739, -122.196287 47.644203999999995, -122.196287 47.644262999999995, -122.19629 47.644937999999996, -122.19629 47.644954999999996, -122.196292 47.645271, -122.196291 47.645426, -122.19629499999999 47.646315, -122.19629499999999 47.646432, -122.195925 47.646432, -122.195251 47.646432, -122.190853 47.646429999999995, -122.187649 47.646428, -122.187164 47.646426, -122.18683 47.646426, -122.185547 47.646409, -122.185546 47.646316, -122.185537 47.645599, -122.185544 47.644197, -122.185537 47.643294999999995, -122.185544 47.642733, -122.185541 47.641757, -122.185555 47.640681, -122.185561 47.63972, -122.185557 47.638228999999995, -122.185591 47.635419, -122.185611 47.634750999999994, -122.18562299999999 47.634484, -122.18561700000001 47.634375999999996, -122.185592 47.634311, -122.185549 47.634232999999995, -122.185504 47.634181999999996, -122.185426 47.634119, -122.184371 47.633424999999995, -122.18400000000001 47.633198, -122.183896 47.633134, -122.1838 47.633067, -122.18375499999999 47.633019999999995, -122.183724 47.632959, -122.183695 47.632858, -122.183702 47.632675, -122.182757 47.632622999999995, -122.182365 47.63259, -122.18220600000001 47.632562, -122.181984 47.632504999999995, -122.18163799999999 47.632363, -122.18142 47.632262999999995, -122.181229 47.632165, -122.181612 47.632172999999995, -122.18271899999999 47.632151, -122.183138 47.632135, -122.18440000000001 47.632081, -122.184743 47.632065999999995, -122.185312 47.63205, -122.185624 47.632047, -122.185625 47.631873999999996, -122.184618 47.63187, -122.184291 47.631878, -122.184278 47.631817999999996, -122.183882 47.629942, -122.182689 47.623548, -122.182594 47.622789999999995, -122.182654 47.622155, -122.183135 47.622372999999996, -122.183471 47.622506, -122.18360200000001 47.622552, -122.183893 47.622637999999995, -122.184244 47.62272, -122.184618 47.622777, -122.184741 47.622727999999995, -122.184605 47.622679, -122.18424 47.622622, -122.183985 47.622569, -122.183717 47.622501, -122.183506 47.622439, -122.18327 47.622357, -122.18305699999999 47.622271999999995, -122.182669 47.622088999999995, -122.182796 47.621545, -122.18347 47.619628999999996, -122.18365 47.619098, -122.183859 47.6184, -122.183922 47.617793999999996, -122.183956 47.617292, -122.183792 47.616388, -122.183261 47.614391999999995, -122.183202 47.613802, -122.183209 47.613155, -122.183436 47.612384999999996, -122.18395100000001 47.610445999999996, -122.184338 47.60924, -122.184657 47.609116, -122.18481 47.609051, -122.18491900000001 47.608987, -122.184974 47.608942, -122.185047 
47.608846, -122.185082 47.608743999999994, -122.185109 47.608526999999995, -122.185116 47.608359, -122.18513 47.608315999999995, -122.185157 47.608273999999994, -122.185183 47.608247, -122.185246 47.608214, -122.185354 47.608196, -122.185475 47.608191999999995, -122.185472 47.606697, -122.185472 47.606373999999995, -122.185521 47.606272, -122.185528 47.606210999999995, -122.185506 47.606037, -122.185451 47.605872999999995, -122.185411 47.605781, -122.185358 47.605681999999995, -122.185248 47.605509999999995, -122.185127 47.605365, -122.185058 47.605292, -122.184772 47.605038, -122.184428 47.604834, -122.184122 47.604693999999995, -122.183775 47.604574, -122.183644 47.604546, -122.183708 47.604400999999996, -122.183749 47.604223999999995, -122.18376 47.604037, -122.183707 47.603778, -122.183619 47.603556999999995, -122.183559 47.603406, -122.183488 47.603303, -122.183824 47.603167, -122.184108 47.603052, -122.184478 47.602902, -122.18543 47.602495, -122.186669 47.601957, -122.186433 47.601220999999995, -122.186341 47.601127999999996, -122.18874199999999 47.593742999999996, -122.188434 47.592338999999996, -122.188479 47.591786, -122.188217 47.591269999999994, -122.18795399999999 47.590871, -122.186822 47.589228, -122.187421 47.589228999999996, -122.18848299999999 47.589228999999996, -122.188433 47.587922999999996, -122.18990000000001 47.588547, -122.191368 47.589169999999996, -122.19158 47.589222, -122.191779 47.589254999999994, -122.192117 47.589289, -122.191569 47.587478999999995, -122.191323 47.586628999999995, -122.191295 47.586554, -122.191268 47.586479, -122.191192 47.586318, -122.191163 47.586268999999994, -122.1911 47.586164, -122.19099 47.586011, -122.19067 47.585668999999996, -122.1905 47.585515, -122.190301 47.58531, -122.190143 47.585152, -122.189573 47.584576999999996, -122.188702 47.583735999999995, -122.188646 47.583679, -122.188239 47.583258, -122.188037 47.583005, -122.187832 47.582657, -122.187726 47.582164999999996, -122.18769499999999 47.581964, -122.18768299999999 47.581781, -122.187678 47.581592, -122.18766099999999 47.581455, -122.187674 47.581311, -122.18768 47.581146, -122.187722 47.580877, -122.187817 47.580569999999994, -122.187932 47.580301999999996, -122.188047 47.580087, -122.188161 47.579933999999994, -122.188399 47.579660999999994, -122.18851699999999 47.579547, -122.188621 47.579454, -122.188042 47.579493, -122.18762 47.579527, -122.187806 47.579358, -122.188009 47.579175, -122.18814499999999 47.579051, -122.188177 47.579021, -122.18842000000001 47.5788, -122.188638 47.578461, -122.188895 47.57806, -122.189791 47.577281, -122.190008 47.577103, -122.190372 47.576805, -122.19119 47.576358, -122.191877 47.576087, -122.193025 47.57566, -122.194317 47.575185999999995, -122.196061 47.574664, -122.197239 47.574386999999994, -122.197873 47.574267, -122.198286 47.574189999999994, -122.199091 47.574044, -122.199067 47.574574999999996, -122.199007 47.575921, -122.200335 47.578222, -122.20057299999999 47.578345999999996, -122.2009 47.578517999999995, -122.201095 47.578621999999996, -122.20138399999999 47.578776999999995, -122.201465 47.57882, -122.201516 47.578846999999996, -122.205753 47.581112, -122.209515 47.583124, -122.210634 47.583721, -122.21473399999999 47.587021, -122.21538699999999 47.588254, -122.21580399999999 47.589042, -122.216534 47.590421, -122.220092 47.596261, -122.220434 47.596821, -122.22041899999999 47.597837999999996, -122.220289 47.606455, -122.220234 47.610121, -122.22048 47.615221999999996, -122.220359 47.615379, -122.220283 47.615477999999996, 
-122.21999 47.615854999999996, -122.219993 47.61597, -122.22023300000001 47.616634, -122.220356 47.616687999999996, -122.220409 47.616712, -122.221401 47.618538, -122.22142 47.618573, -122.221456 47.618635, -122.221791 47.619222, -122.222492 47.619682999999995, -122.222799 47.619886, -122.222083 47.620368, -122.222046 47.620407, -122.222028 47.620449, -122.222025 47.620483, -122.22203999999999 47.620523999999996, -122.222079 47.620557999999996, -122.222156 47.620594999999994, -122.222458 47.620629, -122.222454 47.620673, -122.222454 47.620711, -122.22244599999999 47.621041999999996, -122.223056 47.621041, -122.223129 47.62104, -122.223153 47.62104, -122.223574 47.621041, -122.22377900000001 47.621041, -122.223857 47.621041, -122.22467499999999 47.621041, -122.224712 47.62104, -122.224958 47.62104, -122.225167 47.621049, -122.226882 47.621037, -122.227565 47.621032, -122.228002 47.621029, -122.22797800000001 47.621300999999995, -122.227919 47.626574999999995, -122.227914 47.627085, -122.227901 47.6283, -122.227881 47.630069, -122.227869 47.631177, -122.227879 47.631952999999996, -122.22789 47.633879, -122.227886 47.63409, -122.227871 47.635534, -122.227918 47.635565, -122.228953 47.635624, -122.22895199999999 47.635571999999996, -122.231018 47.635574999999996, -122.233276 47.635588999999996, -122.233287 47.63617, -122.233273 47.63639, -122.233272 47.636469999999996, -122.23327 47.636578, -122.233266 47.636827, -122.233263 47.636851, -122.233262 47.637014, -122.23322999999999 47.638110999999995, -122.233239 47.638219, -122.233262 47.638279, -122.233313 47.638324999999995, -122.233255 47.638359, -122.233218 47.638380999999995, -122.233153 47.638450999999996, -122.233136 47.638552999999995, -122.233137 47.638692, -122.232715 47.639348999999996, -122.232659 47.640093, -122.232704 47.641375, -122.233821 47.645111, -122.234906 47.648874, -122.234924 47.648938, -122.235128 47.650163))\"" @@ -376,7 +388,7 @@ } ], "source": [ - "sedona.read.parquet_blocks(DATA_LINK+\"theme=places/type=place\").show()" + "sedona.read.parquet_blocks(DATA_LINK + \"theme=places/type=place\").show()" ] }, { @@ -426,7 +438,9 @@ } ], "source": [ - "sedona.read.format(\"geoparquet.metadata\").load(DATA_LINK+\"theme=places/type=place\").drop(\"path\").printSchema()" + "sedona.read.format(\"geoparquet.metadata\").load(\n", + " DATA_LINK + \"theme=places/type=place\"\n", + ").drop(\"path\").printSchema()" ] }, { @@ -485,7 +499,9 @@ } ], "source": [ - "sedona.read.format(\"geoparquet.metadata\").load(DATA_LINK+\"theme=places/type=place\").drop(\"path\").show(truncate = False)" + "sedona.read.format(\"geoparquet.metadata\").load(\n", + " DATA_LINK + \"theme=places/type=place\"\n", + ").drop(\"path\").show(truncate=False)" ] }, { @@ -573,7 +589,9 @@ } ], "source": [ - "sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=places/type=place\").printSchema()" + "sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=places/type=place\"\n", + ").printSchema()" ] }, { @@ -613,9 +631,11 @@ "source": [ "%%time\n", "\n", - "df_place = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=places/type=place\")\n", + "df_place = sedona.read.format(\"geoparquet\").load(DATA_LINK + \"theme=places/type=place\")\n", "\n", - "df_place = df_place.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\").cache()" + "df_place = df_place.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ").cache()" ] }, { @@ -698,8 +718,15 @@ } ], "source": [ - 
"df_place.select(\"id\", \"geometry\", \"categories.main\").limit(1000).repartition(1) \\\n", - " .write.format(\"geoparquet\").option(\"geoparquet.version\", \"1.0.0\").option(\"geoparquet.crs\", \"\").mode('overwrite').save(\"places.parquet\")" + "df_place.select(\"id\", \"geometry\", \"categories.main\").limit(1000).repartition(\n", + " 1\n", + ").write.format(\"geoparquet\").option(\"geoparquet.version\", \"1.0.0\").option(\n", + " \"geoparquet.crs\", \"\"\n", + ").mode(\n", + " \"overwrite\"\n", + ").save(\n", + " \"places.parquet\"\n", + ")" ] }, { @@ -861,9 +888,12 @@ } ], "source": [ - "gdf = gpd.GeoDataFrame(df_place.select(\"id\", \"geometry\", \"categories.main\").limit(1000).toPandas(), geometry=\"geometry\")\n", - "gdf.to_file('places.geojson', driver='GeoJSON')\n", - "gdf.to_file('places.shp')\n", + "gdf = gpd.GeoDataFrame(\n", + " df_place.select(\"id\", \"geometry\", \"categories.main\").limit(1000).toPandas(),\n", + " geometry=\"geometry\",\n", + ")\n", + "gdf.to_file(\"places.geojson\", driver=\"GeoJSON\")\n", + "gdf.to_file(\"places.shp\")\n", "gdf" ] }, @@ -904,9 +934,13 @@ "source": [ "%%time\n", "\n", - "df_building = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=buildings/type=building\")\n", + "df_building = sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=buildings/type=building\"\n", + ")\n", "\n", - "df_building = df_building.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\")\n", + "df_building = df_building.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ")\n", "\n", "df_building = df_building.limit(200_000)" ] @@ -962,7 +996,7 @@ "source": [ "%%time\n", "\n", - "map_building = SedonaKepler.create_map(df_building, 'Building')\n", + "map_building = SedonaKepler.create_map(df_building, \"Building\")\n", "map_building" ] }, @@ -1007,9 +1041,13 @@ "source": [ "%%time\n", "\n", - "df_admin = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=admins/type=administrativeBoundary\")\n", + "df_admin = sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=admins/type=administrativeBoundary\"\n", + ")\n", "\n", - "df_admin = df_admin.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\")" + "df_admin = df_admin.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ")" ] }, { @@ -1103,9 +1141,13 @@ "source": [ "%%time\n", "\n", - "df_locality = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=admins/type=locality\")\n", + "df_locality = sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=admins/type=locality\"\n", + ")\n", "\n", - "df_locality = df_locality.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\")" + "df_locality = df_locality.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ")" ] }, { @@ -1161,7 +1203,7 @@ "source": [ "%%time\n", "\n", - "map_locality = SedonaKepler.create_map(df_locality, 'Locality')\n", + "map_locality = SedonaKepler.create_map(df_locality, \"Locality\")\n", "\n", "map_locality" ] @@ -1207,9 +1249,13 @@ "source": [ "%%time\n", "\n", - "df_connector = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=transportation/type=connector\")\n", + "df_connector = sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=transportation/type=connector\"\n", + ")\n", "\n", - "df_connector = 
df_connector.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\")" + "df_connector = df_connector.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ")" ] }, { @@ -1306,9 +1352,13 @@ "source": [ "%%time\n", "\n", - "df_segment = sedona.read.format(\"geoparquet\").load(DATA_LINK+\"theme=transportation/type=segment\")\n", + "df_segment = sedona.read.format(\"geoparquet\").load(\n", + " DATA_LINK + \"theme=transportation/type=segment\"\n", + ")\n", "\n", - "df_segment = df_segment.filter(\"ST_Contains(ST_GeomFromWKT('\"+spatial_filter+\"'), geometry) = true\")\n", + "df_segment = df_segment.filter(\n", + " \"ST_Contains(ST_GeomFromWKT('\" + spatial_filter + \"'), geometry) = true\"\n", + ")\n", "\n", "df_segment = df_segment.limit(200000)" ] diff --git a/docs/usecases/contrib/ApacheSedonaImageFilter.ipynb b/docs/usecases/contrib/ApacheSedonaImageFilter.ipynb index 8dea56b16e..13d4ea7f72 100644 --- a/docs/usecases/contrib/ApacheSedonaImageFilter.ipynb +++ b/docs/usecases/contrib/ApacheSedonaImageFilter.ipynb @@ -52,7 +52,15 @@ "from pyspark.sql import SparkSession\n", "from pyspark import StorageLevel\n", "import pandas as pd\n", - "from pyspark.sql.types import StructType, StructField,StringType, LongType, IntegerType, DoubleType, ArrayType\n", + "from pyspark.sql.types import (\n", + " StructType,\n", + " StructField,\n", + " StringType,\n", + " LongType,\n", + " IntegerType,\n", + " DoubleType,\n", + " ArrayType,\n", + ")\n", "from pyspark.sql.functions import regexp_replace\n", "from sedona.register import SedonaRegistrator\n", "from sedona.utils import SedonaKryoRegistrator, KryoSerializer\n", @@ -120,18 +128,21 @@ } ], "source": [ - "spark = SparkSession.\\\n", - " builder.\\\n", - " appName(\"Demo-app\").\\\n", - " enableHiveSupport().\\\n", - " master(\"local[*]\").\\\n", - " master(\"spark://spark-master:7077\").\\\n", - " config(\"spark.executor.memory\", \"15G\").\\\n", - " config(\"spark.driver.maxResultSize\", \"15G\").\\\n", - " config(\"spark.serializer\", KryoSerializer.getName).\\\n", - " config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName).\\\n", - " config(\"spark.jars.packages\", \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\") .\\\n", - " getOrCreate()\n", + "spark = (\n", + " SparkSession.builder.appName(\"Demo-app\")\n", + " .enableHiveSupport()\n", + " .master(\"local[*]\")\n", + " .master(\"spark://spark-master:7077\")\n", + " .config(\"spark.executor.memory\", \"15G\")\n", + " .config(\"spark.driver.maxResultSize\", \"15G\")\n", + " .config(\"spark.serializer\", KryoSerializer.getName)\n", + " .config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName)\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\",\n", + " )\n", + " .getOrCreate()\n", + ")\n", "# config(\"spark.rpc.message.maxSize\", 2047).\\\n", "# rdd = spark.sparkContext.parallelize(range(1000))\n", "# rdd.takeSample(False, 5)\n", @@ -147,9 +158,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Path to directory of geotiff images \n", + "# Path to directory of geotiff images\n", "DATA_DIR = \"hdfs://776faf4d6a1e:8020/tmp/\"\n", - "df = spark.read.format(\"geotiff\").option(\"dropInvalid\",True).load(DATA_DIR)" + "df = spark.read.format(\"geotiff\").option(\"dropInvalid\", True).load(DATA_DIR)" ] }, { @@ -218,7 +229,14 @@ "source": 
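One thing the reformat makes easier to spot: the builder below chains `.master("local[*]")` and then `.master("spark://spark-master:7077")`, and builder settings are last-write-wins, so only the second call survives. A minimal sketch keeping a single explicit master, using the same imports as the notebook:

    from pyspark.sql import SparkSession
    from sedona.utils import SedonaKryoRegistrator, KryoSerializer

    spark = (
        SparkSession.builder.appName("Demo-app")
        .master("spark://spark-master:7077")  # or "local[*]" for a local run
        .config("spark.serializer", KryoSerializer.getName)
        .config("spark.kryo.registrator", SedonaKryoRegistrator.getName)
        .getOrCreate()
    )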
[ "# Java Heap Out Of Memory => Ir nas máquinas e aumentar o export _JAVA_OPTIONS=\"-Xmx15g\"\n", "# Java lang Assertion Error image is too large =>\n", - "df = df.selectExpr(\"image.origin as origin\",\"ST_GeomFromWkt(image.wkt) as Geom\", \"image.height as height\", \"image.width as width\", \"image.data as data\", \"image.nBands as bands\").cache()\n", + "df = df.selectExpr(\n", + " \"image.origin as origin\",\n", + " \"ST_GeomFromWkt(image.wkt) as Geom\",\n", + " \"image.height as height\",\n", + " \"image.width as width\",\n", + " \"image.data as data\",\n", + " \"image.nBands as bands\",\n", + ").cache()\n", "df.show(5)\n", "# df.count()" ] @@ -231,7 +249,13 @@ "outputs": [], "source": [ "# ,\"RS_GetBand(data, 2,bands) as Band2\",\"RS_GetBand(data, 3,bands) as Band3\", \"RS_GetBand(data, 4,bands) as Band4\"\n", - "df = df.selectExpr(\"Geom\",\"RS_GetBand(data, 1,bands) as Band1\",\"RS_GetBand(data, 2,bands) as Band2\",\"RS_GetBand(data, 3,bands) as Band3\", \"RS_GetBand(data, 4,bands) as Band4\").cache()\n", + "df = df.selectExpr(\n", + " \"Geom\",\n", + " \"RS_GetBand(data, 1,bands) as Band1\",\n", + " \"RS_GetBand(data, 2,bands) as Band2\",\n", + " \"RS_GetBand(data, 3,bands) as Band3\",\n", + " \"RS_GetBand(data, 4,bands) as Band4\",\n", + ").cache()\n", "df.createOrReplaceTempView(\"allbands\")\n", "df.show(5)" ] @@ -254,7 +278,9 @@ "metadata": {}, "outputs": [], "source": [ - "NomalizedDifference = df.selectExpr(\"RS_NormalizedDifference(Band1, Band2) as normDiff\").cache()\n", + "NomalizedDifference = df.selectExpr(\n", + " \"RS_NormalizedDifference(Band1, Band2) as normDiff\"\n", + ").cache()\n", "NomalizedDifference.show(5)" ] }, @@ -287,7 +313,9 @@ "metadata": {}, "outputs": [], "source": [ - "greaterthanDF = spark.sql(\"Select RS_GreaterThan(Band1,1000.0) as greaterthan from allbands\").cache()\n", + "greaterthanDF = spark.sql(\n", + " \"Select RS_GreaterThan(Band1,1000.0) as greaterthan from allbands\"\n", + ").cache()\n", "greaterthanDF.show()" ] }, @@ -298,7 +326,9 @@ "metadata": {}, "outputs": [], "source": [ - "greaterthanEqualDF = spark.sql(\"Select RS_GreaterThanEqual(Band1,360.0) as greaterthanEqual from allbands\").cache()\n", + "greaterthanEqualDF = spark.sql(\n", + " \"Select RS_GreaterThanEqual(Band1,360.0) as greaterthanEqual from allbands\"\n", + ").cache()\n", "greaterthanEqualDF.show()" ] }, @@ -309,7 +339,9 @@ "metadata": {}, "outputs": [], "source": [ - "lessthanDF = spark.sql(\"Select RS_LessThan(Band1,1000.0) as lessthan from allbands\").cache()\n", + "lessthanDF = spark.sql(\n", + " \"Select RS_LessThan(Band1,1000.0) as lessthan from allbands\"\n", + ").cache()\n", "lessthanDF.show()" ] }, @@ -320,7 +352,9 @@ "metadata": {}, "outputs": [], "source": [ - "lessthanEqualDF = spark.sql(\"Select RS_LessThanEqual(Band1,2890.0) as lessthanequal from allbands\").cache()\n", + "lessthanEqualDF = spark.sql(\n", + " \"Select RS_LessThanEqual(Band1,2890.0) as lessthanequal from allbands\"\n", + ").cache()\n", "lessthanEqualDF.show()" ] }, @@ -463,12 +497,33 @@ "metadata": {}, "outputs": [], "source": [ - "df = spark.read.format(\"geotiff\").option(\"dropInvalid\",True).load(DATA_DIR)\n", - "df = df.selectExpr(\"image.origin as origin\",\"ST_GeomFromWkt(image.wkt) as Geom\", \"image.height as height\", \"image.width as width\", \"image.data as data\", \"image.nBands as bands\").cache()\n", + "df = spark.read.format(\"geotiff\").option(\"dropInvalid\", True).load(DATA_DIR)\n", + "df = df.selectExpr(\n", + " \"image.origin as origin\",\n", + " 
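A tiny worked check of the normalized-difference arithmetic behind `RS_NormalizedDifference`, assuming the usual (b2 - b1) / (b2 + b1) definition (the argument order is worth confirming against the Sedona docs for this release):

    # One NIR/Red pixel pair through the formula.
    nir, red = 0.60, 0.20
    ndvi = (nir - red) / (nir + red)
    assert round(ndvi, 2) == 0.50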
\"ST_GeomFromWkt(image.wkt) as Geom\",\n", + " \"image.height as height\",\n", + " \"image.width as width\",\n", + " \"image.data as data\",\n", + " \"image.nBands as bands\",\n", + ").cache()\n", "\n", - "df = df.selectExpr(\"RS_GetBand(data,1,bands) as targetband\", \"height\", \"width\", \"bands\", \"Geom\")\n", - "df_base64 = df.selectExpr(\"Geom\", \"RS_Base64(height,width,RS_Normalize(targetBand), RS_Array(height*width,0.0), RS_Array(height*width, 0.0)) as red\",\"RS_Base64(height,width,RS_Array(height*width, 0.0), RS_Normalize(targetBand), RS_Array(height*width, 0.0)) as green\", \"RS_Base64(height,width,RS_Array(height*width, 0.0), RS_Array(height*width, 0.0), RS_Normalize(targetBand)) as blue\",\"RS_Base64(height,width,RS_Normalize(targetBand), RS_Normalize(targetBand),RS_Normalize(targetBand)) as RGB\" ).cache()\n", - "df_HTML = df_base64.selectExpr(\"Geom\",\"RS_HTML(red) as RedBand\",\"RS_HTML(blue) as BlueBand\",\"RS_HTML(green) as GreenBand\", \"RS_HTML(RGB) as CombinedBand\").cache()\n", + "df = df.selectExpr(\n", + " \"RS_GetBand(data,1,bands) as targetband\", \"height\", \"width\", \"bands\", \"Geom\"\n", + ")\n", + "df_base64 = df.selectExpr(\n", + " \"Geom\",\n", + " \"RS_Base64(height,width,RS_Normalize(targetBand), RS_Array(height*width,0.0), RS_Array(height*width, 0.0)) as red\",\n", + " \"RS_Base64(height,width,RS_Array(height*width, 0.0), RS_Normalize(targetBand), RS_Array(height*width, 0.0)) as green\",\n", + " \"RS_Base64(height,width,RS_Array(height*width, 0.0), RS_Array(height*width, 0.0), RS_Normalize(targetBand)) as blue\",\n", + " \"RS_Base64(height,width,RS_Normalize(targetBand), RS_Normalize(targetBand),RS_Normalize(targetBand)) as RGB\",\n", + ").cache()\n", + "df_HTML = df_base64.selectExpr(\n", + " \"Geom\",\n", + " \"RS_HTML(red) as RedBand\",\n", + " \"RS_HTML(blue) as BlueBand\",\n", + " \"RS_HTML(green) as GreenBand\",\n", + " \"RS_HTML(RGB) as CombinedBand\",\n", + ").cache()\n", "df_HTML.show(5)" ] }, @@ -489,14 +544,14 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def SumOfValues(band):\n", " total = 0.0\n", " for num in band:\n", - " if num>1000.0:\n", - " total+=1\n", + " if num > 1000.0:\n", + " total += 1\n", " return total\n", - " \n", + "\n", + "\n", "calculateSum = udf(SumOfValues, DoubleType())\n", "spark.udf.register(\"RS_Sum\", calculateSum)\n", "\n", @@ -511,20 +566,26 @@ "metadata": {}, "outputs": [], "source": [ - "def generatemask(band, width,height):\n", - " for (i,val) in enumerate(band):\n", - " if (i%width>=12 and i%width<26) and (i%height>=12 and i%height<26):\n", + "def generatemask(band, width, height):\n", + " for i, val in enumerate(band):\n", + " if (i % width >= 12 and i % width < 26) and (\n", + " i % height >= 12 and i % height < 26\n", + " ):\n", " band[i] = 255.0\n", " else:\n", " band[i] = 0.0\n", " return band\n", "\n", + "\n", "maskValues = udf(generatemask, ArrayType(DoubleType()))\n", "spark.udf.register(\"RS_MaskValues\", maskValues)\n", "\n", "\n", - "df_base64 = df.selectExpr(\"Geom\", \"RS_Base64(height,width,RS_Normalize(targetband), RS_Array(height*width,0.0), RS_Array(height*width, 0.0), RS_MaskValues(targetband,width,height)) as region\" ).cache()\n", - "df_HTML = df_base64.selectExpr(\"Geom\",\"RS_HTML(region) as selectedregion\").cache()\n", + "df_base64 = df.selectExpr(\n", + " \"Geom\",\n", + " \"RS_Base64(height,width,RS_Normalize(targetband), RS_Array(height*width,0.0), RS_Array(height*width, 0.0), RS_MaskValues(targetband,width,height)) as region\",\n", + ").cache()\n", + "df_HTML = 
df_base64.selectExpr(\"Geom\", \"RS_HTML(region) as selectedregion\").cache()\n", "display(HTML(df_HTML.limit(2).toPandas().to_html(escape=False)))" ] }, diff --git a/docs/usecases/contrib/DownloadImageFromGEE.ipynb b/docs/usecases/contrib/DownloadImageFromGEE.ipynb index 9d5a873c0c..b90b630473 100644 --- a/docs/usecases/contrib/DownloadImageFromGEE.ipynb +++ b/docs/usecases/contrib/DownloadImageFromGEE.ipynb @@ -119,11 +119,19 @@ "percent_cloud = 100\n", "# box = [[xmin, ymin], [xmin, ymax], [xmax, ymax], [xmax, ymin], [xmin, ymin]]\n", "# https://boundingbox.klokantech.com/\n", - "box = [[[-54.6306579887,-25.5892766534],[-54.5393341362,-25.5892766534],[-54.5393341362,-25.4046299874],[-54.6306579887,-25.4046299874],[-54.6306579887,-25.5892766534]]]\n", + "box = [\n", + " [\n", + " [-54.6306579887, -25.5892766534],\n", + " [-54.5393341362, -25.5892766534],\n", + " [-54.5393341362, -25.4046299874],\n", + " [-54.6306579887, -25.4046299874],\n", + " [-54.6306579887, -25.5892766534],\n", + " ]\n", + "]\n", "boundary = ee.Geometry.Polygon(box, None, False)\n", "collection_name = \"COPERNICUS/S2_SR\"\n", "scale = 10\n", - "crs = 'EPSG:4326'\n", + "crs = \"EPSG:4326\"\n", "# CRIE UMA PASTA VAZIA NO DIRETORIO RASTER E COLOQUE O CAMINHO DELA AQUI EM BAIXO\n", "out_dir = \"raster/sentinel2_tmp\"\n", "hdfs_dir = \"sentinel2_tmp\"" @@ -237,18 +245,24 @@ "source": [ "for i in range(2, 15, 2):\n", " start_date = str((end_date - timedelta(days=i)))\n", - " collection = ee.ImageCollection(collection_name) \\\n", - " .select(['B2', 'B3', 'B4', 'B8']) \\\n", - " .filterBounds(boundary) \\\n", - " .filterMetadata('CLOUDY_PIXEL_PERCENTAGE', 'less_than', percent_cloud) \\\n", + " collection = (\n", + " ee.ImageCollection(collection_name)\n", + " .select([\"B2\", \"B3\", \"B4\", \"B8\"])\n", + " .filterBounds(boundary)\n", + " .filterMetadata(\"CLOUDY_PIXEL_PERCENTAGE\", \"less_than\", percent_cloud)\n", " .filterDate(start_date, str(end_date))\n", - " geemap.ee_export_image_collection(collection, scale=scale, crs=crs, region=boundary, out_dir=out_dir)\n", - "# NOT SURE WHAT F.. THIS BELOW DOES \n", + " )\n", + " geemap.ee_export_image_collection(\n", + " collection, scale=scale, crs=crs, region=boundary, out_dir=out_dir\n", + " )\n", + " # NOT SURE WHAT F.. 
THIS BELOW DOES\n", " for root, directory, files in os.walk(out_dir):\n", - " if len(files) == 0:\n", - " geemap.ee_export_image_collection(collection, scale=scale, crs=crs, region=boundary, out_dir=out_dir)\n", - " else:\n", - " print('Existem imagens na pasta')\n" + " if len(files) == 0:\n", + " geemap.ee_export_image_collection(\n", + " collection, scale=scale, crs=crs, region=boundary, out_dir=out_dir\n", + " )\n", + " else:\n", + " print(\"Existem imagens na pasta\")" ] }, { @@ -258,7 +272,7 @@ "metadata": {}, "outputs": [], "source": [ - "hdfs = PyWebHdfsClient(host='179.106.229.159',port='50070', user_name='root')\n", + "hdfs = PyWebHdfsClient(host=\"179.106.229.159\", port=\"50070\", user_name=\"root\")\n", "hdfs.delete_file_dir(hdfs_dir, recursive=True)\n", "hdfs.make_dir(hdfs_dir)" ] diff --git a/docs/usecases/contrib/NdviSentinelApacheSedona.ipynb b/docs/usecases/contrib/NdviSentinelApacheSedona.ipynb index 3eb7aa31b0..c1614ae75d 100644 --- a/docs/usecases/contrib/NdviSentinelApacheSedona.ipynb +++ b/docs/usecases/contrib/NdviSentinelApacheSedona.ipynb @@ -30,7 +30,7 @@ "metadata": {}, "outputs": [], "source": [ - "# pip install sklearn \n", + "# pip install sklearn\n", "# pip install pyarrow\n", "# pip install fsspec" ] @@ -46,7 +46,15 @@ "from pyspark.sql import SparkSession\n", "from pyspark import StorageLevel\n", "import pandas as pd\n", - "from pyspark.sql.types import StructType, StructField,StringType, LongType, IntegerType, DoubleType, ArrayType\n", + "from pyspark.sql.types import (\n", + " StructType,\n", + " StructField,\n", + " StringType,\n", + " LongType,\n", + " IntegerType,\n", + " DoubleType,\n", + " ArrayType,\n", + ")\n", "from pyspark.sql.functions import regexp_replace\n", "from sedona.register import SedonaRegistrator\n", "from sedona.utils import SedonaKryoRegistrator, KryoSerializer\n", @@ -78,7 +86,7 @@ ], "source": [ "analise_folder = \"analise_teste_\" + str(date.today())\n", - "hdfs = PyWebHdfsClient(host='179.106.229.159',port='50070', user_name='root')\n", + "hdfs = PyWebHdfsClient(host=\"179.106.229.159\", port=\"50070\", user_name=\"root\")\n", "hdfs.delete_file_dir(analise_folder, recursive=True)" ] }, @@ -143,26 +151,29 @@ ], "source": [ "# spark.scheduler.mode', 'FAIR'\n", - "spark = SparkSession.\\\n", - " builder.\\\n", - " appName(\"Sentinel-app\").\\\n", - " enableHiveSupport().\\\n", - " master(\"local[*]\").\\\n", - " master(\"spark://spark-master:7077\").\\\n", - " config(\"spark.executor.memory\", \"15G\").\\\n", - " config(\"spark.driver.maxResultSize\", \"135G\").\\\n", - " config(\"spark.sql.shuffle.partitions\", \"500\").\\\n", - " config(' spark.sql.adaptive.coalescePartitions.enabled', True).\\\n", - " config('spark.sql.adaptive.enabled', True).\\\n", - " config('spark.sql.adaptive.coalescePartitions.initialPartitionNum', 125).\\\n", - " config(\"spark.sql.execution.arrow.pyspark.enabled\", True).\\\n", - " config(\"spark.sql.execution.arrow.fallback.enabled\", True).\\\n", - " config('spark.kryoserializer.buffer.max', 2047).\\\n", - " config(\"spark.serializer\", KryoSerializer.getName).\\\n", - " config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName).\\\n", - " config(\"spark.jars.packages\", \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\") .\\\n", - " enableHiveSupport().\\\n", - " getOrCreate()\n", + "spark = (\n", + " SparkSession.builder.appName(\"Sentinel-app\")\n", + " .enableHiveSupport()\n", + " .master(\"local[*]\")\n", + " 
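The retry above re-walks the whole output tree inside the date loop. A plainer emptiness check, sketched with the loop's own `collection`, `boundary`, and `out_dir` variables (assumed from the cell above) and the same geemap call:

    import os

    def is_empty(path):
        # True when the output directory has no entries at all.
        return not any(os.scandir(path))

    if is_empty(out_dir):
        geemap.ee_export_image_collection(
            collection, scale=scale, crs=crs, region=boundary, out_dir=out_dir
        )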
.master(\"spark://spark-master:7077\")\n", + " .config(\"spark.executor.memory\", \"15G\")\n", + " .config(\"spark.driver.maxResultSize\", \"135G\")\n", + " .config(\"spark.sql.shuffle.partitions\", \"500\")\n", + " .config(\" spark.sql.adaptive.coalescePartitions.enabled\", True)\n", + " .config(\"spark.sql.adaptive.enabled\", True)\n", + " .config(\"spark.sql.adaptive.coalescePartitions.initialPartitionNum\", 125)\n", + " .config(\"spark.sql.execution.arrow.pyspark.enabled\", True)\n", + " .config(\"spark.sql.execution.arrow.fallback.enabled\", True)\n", + " .config(\"spark.kryoserializer.buffer.max\", 2047)\n", + " .config(\"spark.serializer\", KryoSerializer.getName)\n", + " .config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName)\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\",\n", + " )\n", + " .enableHiveSupport()\n", + " .getOrCreate()\n", + ")\n", "\n", "SedonaRegistrator.registerAll(spark)\n", "sc = spark.sparkContext" @@ -183,9 +194,9 @@ } ], "source": [ - "# Path to directory of geotiff images \n", + "# Path to directory of geotiff images\n", "DATA_DIR = \"hdfs://776faf4d6a1e:8020/sentinel2_tmp/*\"\n", - "df = spark.read.format(\"geotiff\").option(\"dropInvalid\",True).load(DATA_DIR)" + "df = spark.read.format(\"geotiff\").option(\"dropInvalid\", True).load(DATA_DIR)" ] }, { @@ -204,8 +215,8 @@ ], "source": [ "# SUPER IMPORTANT ULTRA MEGA POWER FOR MEMORY PROBLENS SOLVE\n", - "rdd = spark.sparkContext.parallelize((0,20))\n", - "print(\"From local[5]\"+str(rdd.getNumPartitions()))" + "rdd = spark.sparkContext.parallelize((0, 20))\n", + "print(\"From local[5]\" + str(rdd.getNumPartitions()))" ] }, { @@ -295,7 +306,8 @@ } ], "source": [ - "from pyspark.sql.functions import monotonically_increasing_id \n", + "from pyspark.sql.functions import monotonically_increasing_id\n", + "\n", "# add ID\n", "df_index = df.select(\"*\").withColumn(\"id\", monotonically_increasing_id())\n", "df_index.explain()\n", @@ -325,10 +337,15 @@ } ], "source": [ - "# \"image.wkt as Geom\", \n", - "df_export = df_index.selectExpr(\"id\",\"image.origin as origin\",\n", - " \"image.height as height\", \"image.width as width\", \n", - " \"cast(image.data as string) as data\", \"image.nBands as bands\")\n", + "# \"image.wkt as Geom\",\n", + "df_export = df_index.selectExpr(\n", + " \"id\",\n", + " \"image.origin as origin\",\n", + " \"image.height as height\",\n", + " \"image.width as width\",\n", + " \"cast(image.data as string) as data\",\n", + " \"image.nBands as bands\",\n", + ")\n", "print(df_export.dtypes)\n", "df_export.explain()\n", "df_export.createOrReplaceTempView(\"df_export\")" @@ -372,7 +389,7 @@ "# part_df_export.take(3)\n", "part_df_export = df_export.take(1)\n", "# print(part_df_export)\n", - "pd.DataFrame(part_df_export).to_csv(\"teste.csv\", sep=',', encoding='utf-8')" + "pd.DataFrame(part_df_export).to_csv(\"teste.csv\", sep=\",\", encoding=\"utf-8\")" ] }, { @@ -417,9 +434,14 @@ } ], "source": [ - "df = df.selectExpr(\"image.origin as origin\",\"ST_GeomFromWkt(image.wkt) as Geom\", \n", - " \"image.height as height\", \"image.width as width\", \"image.data as data\", \n", - " \"image.nBands as bands\").cache()\n", + "df = df.selectExpr(\n", + " \"image.origin as origin\",\n", + " \"ST_GeomFromWkt(image.wkt) as Geom\",\n", + " \"image.height as height\",\n", + " \"image.width as width\",\n", + " \"image.data as data\",\n", + " \"image.nBands as bands\",\n", + 
").cache()\n", "df.show(5)\n", "print(df.dtypes)\n", "df.explain()" @@ -465,15 +487,20 @@ } ], "source": [ - "df = df.selectExpr(\"origin\", \"Geom\",\"RS_GetBand(data, 1,bands) as B2\",\"RS_GetBand(data, 2,bands) as B3\",\n", - " \"RS_GetBand(data, 3,bands) as B4\", \n", - " \"RS_GetBand(data, 4,bands) as B8\", \n", - " \"RS_Array(height * width, 2.4) as constant_evi_2\",\n", - " \"RS_Array(height * width, 2.5) as constant_evi_1\",\n", - " \"RS_Array(height * width, 1.0) as constant_evi_3\",\n", - " \"RS_Array(height * width, -0.5) as constant_tgi_1\",\n", - " \"RS_Array(height * width, 120.0) as constant_tgi_2\",\n", - " \"RS_Array(height * width, 0.001) as corrector\").cache()\n", + "df = df.selectExpr(\n", + " \"origin\",\n", + " \"Geom\",\n", + " \"RS_GetBand(data, 1,bands) as B2\",\n", + " \"RS_GetBand(data, 2,bands) as B3\",\n", + " \"RS_GetBand(data, 3,bands) as B4\",\n", + " \"RS_GetBand(data, 4,bands) as B8\",\n", + " \"RS_Array(height * width, 2.4) as constant_evi_2\",\n", + " \"RS_Array(height * width, 2.5) as constant_evi_1\",\n", + " \"RS_Array(height * width, 1.0) as constant_evi_3\",\n", + " \"RS_Array(height * width, -0.5) as constant_tgi_1\",\n", + " \"RS_Array(height * width, 120.0) as constant_tgi_2\",\n", + " \"RS_Array(height * width, 0.001) as corrector\",\n", + ").cache()\n", "df.createOrReplaceTempView(\"allbands\")\n", "df.show(5)" ] @@ -517,14 +544,16 @@ } ], "source": [ - "# Não tem data da imagem \n", + "# Não tem data da imagem\n", "# Não tem parte a qual ela se refere\n", "# Necessário adicionar\n", "origin = df.selectExpr(\"origin\")\n", - "split_origin = origin.select(split(col(\"origin\"),\"/\"))\n", + "split_origin = origin.select(split(col(\"origin\"), \"/\"))\n", "split_origin.head()\n", "# 20211226T134212\n", - "split_origin = spark.sql(\"select to_timestamp(REPLACE(SPLIT(SPLIT(origin,'/')[5], '_')[1],'T',' '),'yyyyMMdd HHmmss') as image_date, SPLIT(origin,'/')[4] as feature_name, * from allbands\")\n", + "split_origin = spark.sql(\n", + " \"select to_timestamp(REPLACE(SPLIT(SPLIT(origin,'/')[5], '_')[1],'T',' '),'yyyyMMdd HHmmss') as image_date, SPLIT(origin,'/')[4] as feature_name, * from allbands\"\n", + ")\n", "split_origin.show(5)" ] }, @@ -550,33 +579,43 @@ ], "source": [ "# Fator de correcao da banda para ficar com valores entre 0 e 1\n", - "correct_origin = split_origin.selectExpr(\"RS_MultiplyBands(B2, corrector) as bluen\",\n", - " \"RS_MultiplyBands(B3, corrector) as greenn\", \n", - " \"RS_MultiplyBands(B4, corrector) as redn\",\n", - " \"RS_MultiplyBands(B8, corrector) as nirn\", \n", - " \"*\").cache()\n", - "correct_origin = correct_origin.selectExpr(\"RS_NormalizedDifference(nirn, redn) as gndvi\",\n", - " \"RS_SubtractBands(nirn, redn) as sub_nirn_redn\", \n", - " \"RS_AddBands(nirn,constant_evi_2) as add_nirn_contant_evi_2\",\n", - " \"RS_AddBands(redn, constant_evi_3) as add_redn_contant_evi_3\", \n", - " \"RS_DivideBands(nirn, greenn) as div_nirn_greenn\",\n", - " \"RS_SubtractBands(greenn, redn) as sub_greenn_redn\",\n", - " \"RS_SubtractBands(redn, greenn) as sub_redn_greenn\",\n", - " \"RS_SubtractBands(redn, bluen) as sub_redn_bluen\",\n", - " \"RS_AddBands(greenn, redn) as add_greenn_redn\",\n", - " \"*\").cache()\n", + "correct_origin = split_origin.selectExpr(\n", + " \"RS_MultiplyBands(B2, corrector) as bluen\",\n", + " \"RS_MultiplyBands(B3, corrector) as greenn\",\n", + " \"RS_MultiplyBands(B4, corrector) as redn\",\n", + " \"RS_MultiplyBands(B8, corrector) as nirn\",\n", + " \"*\",\n", + ").cache()\n", + "correct_origin = 
correct_origin.selectExpr(\n", + " \"RS_NormalizedDifference(nirn, redn) as gndvi\",\n", + " \"RS_SubtractBands(nirn, redn) as sub_nirn_redn\",\n", + " \"RS_AddBands(nirn,constant_evi_2) as add_nirn_contant_evi_2\",\n", + " \"RS_AddBands(redn, constant_evi_3) as add_redn_contant_evi_3\",\n", + " \"RS_DivideBands(nirn, greenn) as div_nirn_greenn\",\n", + " \"RS_SubtractBands(greenn, redn) as sub_greenn_redn\",\n", + " \"RS_SubtractBands(redn, greenn) as sub_redn_greenn\",\n", + " \"RS_SubtractBands(redn, bluen) as sub_redn_bluen\",\n", + " \"RS_AddBands(greenn, redn) as add_greenn_redn\",\n", + " \"*\",\n", + ").cache()\n", "\n", - "correct_origin = correct_origin.selectExpr(\"RS_SubtractBands(add_greenn_redn, bluen) as greenn_redn_sub_bluen\",\n", - " \"RS_AddBands(add_greenn_redn, bluen) as greenn_redn_add_bluen\",\n", - " \"RS_SubtractBands(sub_greenn_redn, bluen) as sub_greenn_redn_bluen\",\n", - " \"RS_SubtractBands(sub_redn_greenn, constant_tgi_2) as sub_red_gren_tgi_2\",\n", - " \"*\").cache()\n", - "correct_origin = correct_origin.selectExpr(\"RS_MultiplyFactor(sub_redn_bluen,120) as ms_redn_bluen_120\",\n", - " \"*\").cache()\n", - "correct_origin = correct_origin.selectExpr(\"RS_MultiplyFactor(sub_redn_greenn,190) as ms_redn_greenn_190\",\n", - " \"*\").cache()\n", - "correct_origin = correct_origin.selectExpr(\"RS_SubtractBands(ms_redn_greenn_190,ms_redn_bluen_120) as sub_msrg_190_msrb_120\",\n", - " \"*\").cache()" + "correct_origin = correct_origin.selectExpr(\n", + " \"RS_SubtractBands(add_greenn_redn, bluen) as greenn_redn_sub_bluen\",\n", + " \"RS_AddBands(add_greenn_redn, bluen) as greenn_redn_add_bluen\",\n", + " \"RS_SubtractBands(sub_greenn_redn, bluen) as sub_greenn_redn_bluen\",\n", + " \"RS_SubtractBands(sub_redn_greenn, constant_tgi_2) as sub_red_gren_tgi_2\",\n", + " \"*\",\n", + ").cache()\n", + "correct_origin = correct_origin.selectExpr(\n", + " \"RS_MultiplyFactor(sub_redn_bluen,120) as ms_redn_bluen_120\", \"*\"\n", + ").cache()\n", + "correct_origin = correct_origin.selectExpr(\n", + " \"RS_MultiplyFactor(sub_redn_greenn,190) as ms_redn_greenn_190\", \"*\"\n", + ").cache()\n", + "correct_origin = correct_origin.selectExpr(\n", + " \"RS_SubtractBands(ms_redn_greenn_190,ms_redn_bluen_120) as sub_msrg_190_msrb_120\",\n", + " \"*\",\n", + ").cache()" ] }, { @@ -636,24 +675,29 @@ } ], "source": [ - " # bluen = src.read(1, masked=True) / 10000\n", - " # greenn = src.read(2, masked=True) / 10000\n", - " # redn = src.read(3, masked=True) / 10000\n", - " # nirn = src.read(4, masked=True) / 10000\n", - " # evi = 2.5 * (nirn - redn) / (nirn + 2.4 * redn + 1)\n", - " # gci = (nirn / greenn) - 1\n", - " # gli = (2 * greenn - redn - bluen) / (2 * greenn + redn + bluen)\n", - " # gndvi = (nirn - greenn) / (nirn + greenn)\n", - " # tgi = (-0.5) * (190 * (redn - greenn) - 120 * (redn - bluen))\n", - " # vari = (greenn - redn) / (greenn + redn - bluen)\n", + "# bluen = src.read(1, masked=True) / 10000\n", + "# greenn = src.read(2, masked=True) / 10000\n", + "# redn = src.read(3, masked=True) / 10000\n", + "# nirn = src.read(4, masked=True) / 10000\n", + "# evi = 2.5 * (nirn - redn) / (nirn + 2.4 * redn + 1)\n", + "# gci = (nirn / greenn) - 1\n", + "# gli = (2 * greenn - redn - bluen) / (2 * greenn + redn + bluen)\n", + "# gndvi = (nirn - greenn) / (nirn + greenn)\n", + "# tgi = (-0.5) * (190 * (redn - greenn) - 120 * (redn - bluen))\n", + "# vari = (greenn - redn) / (greenn + redn - bluen)\n", "\n", "\n", - "calculated = correct_origin.selectExpr(\"RS_NormalizedDifference(nirn, 
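One discrepancy the reformat surfaces: the reference formulas above divide the bands by 10000, while the `corrector` band defined earlier is filled with 0.001, i.e. a divide by 1000. If the usual /10000 Sentinel-2 reflectance scaling is what's intended (an assumption), the fix is a one-constant change back in the cell that builds the constant bands, where `height` and `width` are still columns:

    df = df.selectExpr(
        "*",
        "RS_Array(height * width, 0.0001) as corrector",  # /10000 scaling
    )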
redn) as gndvi\",\n", - " \"RS_DivideBands(RS_MultiplyBands(constant_evi_1, sub_nirn_redn), RS_MultiplyBands(add_nirn_contant_evi_2, add_redn_contant_evi_3)) as evi\",\n", - " \"RS_SubtractBands(div_nirn_greenn, constant_evi_3) as gci\",\n", - " \"RS_DivideBands(sub_greenn_redn, greenn_redn_sub_bluen) as vari\",\n", - " \"RS_DivideBands(RS_MultiplyFactor(sub_greenn_redn_bluen,2),RS_MultiplyFactor(greenn_redn_add_bluen, 2)) as gli\",\n", - " \"RS_MultiplyBands(constant_tgi_1,sub_msrg_190_msrb_120) as tgi\", \"origin\", \"image_date\", \"feature_name\").cache()\n", + "calculated = correct_origin.selectExpr(\n", + " \"RS_NormalizedDifference(nirn, redn) as gndvi\",\n", + " \"RS_DivideBands(RS_MultiplyBands(constant_evi_1, sub_nirn_redn), RS_MultiplyBands(add_nirn_contant_evi_2, add_redn_contant_evi_3)) as evi\",\n", + " \"RS_SubtractBands(div_nirn_greenn, constant_evi_3) as gci\",\n", + " \"RS_DivideBands(sub_greenn_redn, greenn_redn_sub_bluen) as vari\",\n", + " \"RS_DivideBands(RS_MultiplyFactor(sub_greenn_redn_bluen,2),RS_MultiplyFactor(greenn_redn_add_bluen, 2)) as gli\",\n", + " \"RS_MultiplyBands(constant_tgi_1,sub_msrg_190_msrb_120) as tgi\",\n", + " \"origin\",\n", + " \"image_date\",\n", + " \"feature_name\",\n", + ").cache()\n", "calculated.show(5)\n", "calculated.printSchema()" ] @@ -709,12 +753,17 @@ } ], "source": [ - "calculated_mean = calculated.selectExpr(\"RS_Mean(gndvi) as gndvi\",\n", - " \"RS_Mean(evi) as evi\",\n", - " \"RS_Mean(gci) as gci\",\n", - " \"RS_Mean(vari) as vari\",\n", - " \"RS_Mean(gli) as gli\",\n", - " \"RS_Mean(tgi) as tgi\", \"origin\", \"image_date\", \"feature_name\").cache()\n", + "calculated_mean = calculated.selectExpr(\n", + " \"RS_Mean(gndvi) as gndvi\",\n", + " \"RS_Mean(evi) as evi\",\n", + " \"RS_Mean(gci) as gci\",\n", + " \"RS_Mean(vari) as vari\",\n", + " \"RS_Mean(gli) as gli\",\n", + " \"RS_Mean(tgi) as tgi\",\n", + " \"origin\",\n", + " \"image_date\",\n", + " \"feature_name\",\n", + ").cache()\n", "calculated_mean.show(5)\n", "calculated_mean.printSchema()\n", "calculated_mean.createOrReplaceTempView(\"all_mean\")" @@ -740,7 +789,7 @@ "# part_df_export.take(3)\n", "part_df_export = calculated_mean.limit(10).collect()\n", "print(part_df_export)\n", - "pd.DataFrame(part_df_export).to_csv(\"teste.csv\", sep=',', encoding='utf-8')" + "pd.DataFrame(part_df_export).to_csv(\"teste.csv\", sep=\",\", encoding=\"utf-8\")" ] }, { @@ -761,6 +810,7 @@ "# SAVE COPY TO HDFS\n", "# dá o mesmo problema de threadshod unsuficiente que ocorre no fit\n", "import gc\n", + "\n", "collected = gc.collect()\n", "print(\"Garbage collector: collected %d objects.\" % collected)" ] @@ -786,7 +836,12 @@ "from pyspark.ml import Pipeline\n", "from pyspark.ml.classification import RandomForestClassifier\n", "from pyspark.ml.linalg import Vectors\n", - "from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, VectorAssembler\n", + "from pyspark.ml.feature import (\n", + " IndexToString,\n", + " StringIndexer,\n", + " VectorIndexer,\n", + " VectorAssembler,\n", + ")\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "from pyspark.ml.tuning import ParamGridBuilder\n", "import numpy as np\n", @@ -834,12 +889,14 @@ } ], "source": [ - "vari = calculated_mean.select('vari')\n", + "vari = calculated_mean.select(\"vari\")\n", "vari.printSchema()\n", "vari.show(5)\n", - "df_rf_assembler = calculated_mean.selectExpr(\"vari\",\"gndvi\",\"evi\",\"tgi\",\"gli\",\"cast(feature_name as long) as labels\")\n", + "df_rf_assembler = 
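A quick numeric check of the EVI expression assembled above, evi = 2.5 * (nir - red) / (nir + 2.4 * red + 1), on one made-up pixel:

    nir, red = 0.5, 0.1
    evi = 2.5 * (nir - red) / (nir + 2.4 * red + 1)
    assert abs(evi - 0.5747) < 1e-3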
calculated_mean.selectExpr(\n", + " \"vari\", \"gndvi\", \"evi\", \"tgi\", \"gli\", \"cast(feature_name as long) as labels\"\n", + ")\n", "# FORMATO NECESSARIO PARA O FIT\n", - "feature_list = [col for col in df_rf_assembler.columns if col != 'labels']\n", + "feature_list = [col for col in df_rf_assembler.columns if col != \"labels\"]\n", "assembler = VectorAssembler(inputCols=feature_list, outputCol=\"features\")\n", "# rf = RandomForestClassifier(labelCol=\"labels\", featuresCol=\"features\")\n", "df_rf_assembler = assembler.transform(df_rf_assembler)\n", @@ -881,7 +938,11 @@ "from numpy import allclose\n", "from pyspark.ml.linalg import Vectors\n", "from pyspark.ml.feature import StringIndexer\n", - "from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel\n", + "from pyspark.ml.classification import (\n", + " RandomForestClassifier,\n", + " RandomForestClassificationModel,\n", + ")\n", + "\n", "# df = spark.createDataFrame([\n", "# (1.0, Vectors.dense(1.0)),\n", "# (0.0, Vectors.sparse(1, [], []))], [\"label\", \"features\"])\n", diff --git a/docs/usecases/contrib/PostgresqlConnectionApacheSedona.ipynb b/docs/usecases/contrib/PostgresqlConnectionApacheSedona.ipynb index f188bee00b..55bdfd52dc 100644 --- a/docs/usecases/contrib/PostgresqlConnectionApacheSedona.ipynb +++ b/docs/usecases/contrib/PostgresqlConnectionApacheSedona.ipynb @@ -46,11 +46,13 @@ "source": [ "from pyspark.sql import SparkSession\n", "\n", - "spark = SparkSession.builder.appName(\"db-connection-2\")\\\n", - " .master(\"spark://spark-master:7077\")\\\n", - " .config(\"spark.executor.memory\", \"10gb\")\\\n", - " .config(\"spark.jars\", \"postgresql-42.2.24.jar\") \\\n", - " .getOrCreate()" + "spark = (\n", + " SparkSession.builder.appName(\"db-connection-2\")\n", + " .master(\"spark://spark-master:7077\")\n", + " .config(\"spark.executor.memory\", \"10gb\")\n", + " .config(\"spark.jars\", \"postgresql-42.2.24.jar\")\n", + " .getOrCreate()\n", + ")" ] }, { @@ -62,8 +64,15 @@ }, "outputs": [], "source": [ - "properties = {\"user\":\"\", \"password\":\"\", \"host\":\"\", \"port\":\"\", \"database\":\"\"}\n", - "properties[\"url\"] = \"jdbc:postgresql://\"+properties[\"host\"]+\":\"+properties[\"port\"]+\"/\"+properties[\"database\"]" + "properties = {\"user\": \"\", \"password\": \"\", \"host\": \"\", \"port\": \"\", \"database\": \"\"}\n", + "properties[\"url\"] = (\n", + " \"jdbc:postgresql://\"\n", + " + properties[\"host\"]\n", + " + \":\"\n", + " + properties[\"port\"]\n", + " + \"/\"\n", + " + properties[\"database\"]\n", + ")" ] }, { @@ -75,14 +84,17 @@ }, "outputs": [], "source": [ - "jdbcDF = spark.read.format(\"jdbc\"). 
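The classifier itself is left commented out above; a minimal sketch of the missing fit step, assuming `df_rf_assembler` with its `features` and `labels` columns as built above:

    from pyspark.ml.classification import RandomForestClassifier

    rf = RandomForestClassifier(labelCol="labels", featuresCol="features")
    model = rf.fit(df_rf_assembler)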
\\\n", - "options(\n", - " url=properties[\"url\"], # jdbc:postgresql://:/\n", - " dbtable='clima.t_indices_prec_cpc',\n", - " user=properties[\"user\"],\n", - " password=properties[\"password\"], \n", - " driver=\"org.postgresql.Driver\") \\\n", - ".load()" + "jdbcDF = (\n", + " spark.read.format(\"jdbc\")\n", + " .options(\n", + " url=properties[\"url\"], # jdbc:postgresql://:/\n", + " dbtable=\"clima.t_indices_prec_cpc\",\n", + " user=properties[\"user\"],\n", + " password=properties[\"password\"],\n", + " driver=\"org.postgresql.Driver\",\n", + " )\n", + " .load()\n", + ")" ] }, { @@ -189,11 +201,13 @@ "import psycopg2\n", "\n", "\n", - "connection = psycopg2.connect(user=properties[\"user\"],\n", - " password=properties[\"password\"],\n", - " host=properties[\"host\"],\n", - " port=properties[\"port\"],\n", - " database=properties[\"database\"])" + "connection = psycopg2.connect(\n", + " user=properties[\"user\"],\n", + " password=properties[\"password\"],\n", + " host=properties[\"host\"],\n", + " port=properties[\"port\"],\n", + " database=properties[\"database\"],\n", + ")" ] }, { @@ -216,9 +230,10 @@ "source": [ "%%timeit\n", "from pandas import DataFrame\n", + "\n", "cursor = connection.cursor()\n", "cursor.execute(\"SELECT * FROM clima.t_indices_prec_cpc t Where r100mm > 2000\")\n", - "names = [ x[0] for x in cursor.description]\n", + "names = [x[0] for x in cursor.description]\n", "result = cursor.fetchall()\n", "df = DataFrame(result, columns=names)" ] diff --git a/docs/usecases/contrib/VectorAnalisisApacheSedona.ipynb b/docs/usecases/contrib/VectorAnalisisApacheSedona.ipynb index 873b49d906..9867987395 100644 --- a/docs/usecases/contrib/VectorAnalisisApacheSedona.ipynb +++ b/docs/usecases/contrib/VectorAnalisisApacheSedona.ipynb @@ -34,7 +34,15 @@ "from pyspark.sql import SparkSession\n", "from pyspark import StorageLevel\n", "import pandas as pd\n", - "from pyspark.sql.types import StructType, StructField,StringType, LongType, IntegerType, DoubleType, ArrayType\n", + "from pyspark.sql.types import (\n", + " StructType,\n", + " StructField,\n", + " StringType,\n", + " LongType,\n", + " IntegerType,\n", + " DoubleType,\n", + " ArrayType,\n", + ")\n", "from pyspark.sql.functions import regexp_replace\n", "from sedona.register import SedonaRegistrator\n", "from sedona.utils import SedonaKryoRegistrator, KryoSerializer\n", @@ -45,7 +53,7 @@ "from pyspark.sql.functions import udf, lit, flatten\n", "from pywebhdfs.webhdfs import PyWebHdfsClient\n", "from datetime import date\n", - "from pyspark.sql.functions import monotonically_increasing_id \n", + "from pyspark.sql.functions import monotonically_increasing_id\n", "import json" ] }, @@ -110,26 +118,29 @@ ], "source": [ "# spark.scheduler.mode', 'FAIR'\n", - "spark = SparkSession.\\\n", - " builder.\\\n", - " appName(\"Overpass-API\").\\\n", - " enableHiveSupport().\\\n", - " master(\"local[*]\").\\\n", - " master(\"spark://spark-master:7077\").\\\n", - " config(\"spark.executor.memory\", \"15G\").\\\n", - " config(\"spark.driver.maxResultSize\", \"135G\").\\\n", - " config(\"spark.sql.shuffle.partitions\", \"500\").\\\n", - " config(' spark.sql.adaptive.coalescePartitions.enabled', True).\\\n", - " config('spark.sql.adaptive.enabled', True).\\\n", - " config('spark.sql.adaptive.coalescePartitions.initialPartitionNum', 125).\\\n", - " config(\"spark.sql.execution.arrow.pyspark.enabled\", True).\\\n", - " config(\"spark.sql.execution.arrow.fallback.enabled\", True).\\\n", - " config('spark.kryoserializer.buffer.max', 2047).\\\n", - 
" config(\"spark.serializer\", KryoSerializer.getName).\\\n", - " config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName).\\\n", - " config(\"spark.jars.packages\", \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\") .\\\n", - " enableHiveSupport().\\\n", - " getOrCreate()\n", + "spark = (\n", + " SparkSession.builder.appName(\"Overpass-API\")\n", + " .enableHiveSupport()\n", + " .master(\"local[*]\")\n", + " .master(\"spark://spark-master:7077\")\n", + " .config(\"spark.executor.memory\", \"15G\")\n", + " .config(\"spark.driver.maxResultSize\", \"135G\")\n", + " .config(\"spark.sql.shuffle.partitions\", \"500\")\n", + " .config(\" spark.sql.adaptive.coalescePartitions.enabled\", True)\n", + " .config(\"spark.sql.adaptive.enabled\", True)\n", + " .config(\"spark.sql.adaptive.coalescePartitions.initialPartitionNum\", 125)\n", + " .config(\"spark.sql.execution.arrow.pyspark.enabled\", True)\n", + " .config(\"spark.sql.execution.arrow.fallback.enabled\", True)\n", + " .config(\"spark.kryoserializer.buffer.max\", 2047)\n", + " .config(\"spark.serializer\", KryoSerializer.getName)\n", + " .config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName)\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.1.0-incubating,org.datasyslab:geotools-wrapper:1.1.0-25.2\",\n", + " )\n", + " .enableHiveSupport()\n", + " .getOrCreate()\n", + ")\n", "\n", "SedonaRegistrator.registerAll(spark)\n", "sc = spark.sparkContext" @@ -174,7 +185,7 @@ "out skel qt;\n", "\"\"\"\n", "\n", - "# response = requests.get(overpass_url, \n", + "# response = requests.get(overpass_url,\n", "# params={'data': overpass_query})\n", "# data = response.json()\n", "# hdfs = PyWebHdfsClient(host='179.106.229.159',port='50070', user_name='root')\n", @@ -198,8 +209,8 @@ } ], "source": [ - "path = \"hdfs://776faf4d6a1e:8020/\"+file_name\n", - "df = spark.read.json(path, multiLine = \"true\")" + "path = \"hdfs://776faf4d6a1e:8020/\" + file_name\n", + "df = spark.read.json(path, multiLine=\"true\")" ] }, { @@ -452,32 +463,46 @@ "ids = pd.DataFrame(isolate_ids[\"id\"].iloc[0]).drop_duplicates()\n", "print(ids[0].iloc[1])\n", "\n", - "formatted_df = tb\\\n", - ".withColumn(\"id\", explode(\"elements.id\"))\n", + "formatted_df = tb.withColumn(\"id\", explode(\"elements.id\"))\n", "\n", "formatted_df.show(5)\n", "\n", - "formatted_df = tb\\\n", - ".withColumn(\"new\", arrays_zip(\"elements.id\", \"elements.geometry\", \"elements.nodes\", \"elements.tags\"))\\\n", - ".withColumn(\"new\", explode(\"new\"))\n", + "formatted_df = tb.withColumn(\n", + " \"new\",\n", + " arrays_zip(\"elements.id\", \"elements.geometry\", \"elements.nodes\", \"elements.tags\"),\n", + ").withColumn(\"new\", explode(\"new\"))\n", "\n", "formatted_df.show(5)\n", "\n", "# formatted_df.printSchema()\n", "\n", - "formatted_df = formatted_df.select(\"new.0\",\"new.1\",\"new.2\",\"new.3.maxspeed\",\"new.3.incline\",\"new.3.surface\", \"new.3.name\", \"total_nodes\")\n", - "formatted_df = formatted_df.withColumnRenamed(\"0\",\"id\").withColumnRenamed(\"1\",\"geom\").withColumnRenamed(\"2\",\"nodes\").withColumnRenamed(\"3\",\"tags\")\n", + "formatted_df = formatted_df.select(\n", + " \"new.0\",\n", + " \"new.1\",\n", + " \"new.2\",\n", + " \"new.3.maxspeed\",\n", + " \"new.3.incline\",\n", + " \"new.3.surface\",\n", + " \"new.3.name\",\n", + " \"total_nodes\",\n", + ")\n", + "formatted_df = (\n", + " formatted_df.withColumnRenamed(\"0\", 
\"id\")\n", + " .withColumnRenamed(\"1\", \"geom\")\n", + " .withColumnRenamed(\"2\", \"nodes\")\n", + " .withColumnRenamed(\"3\", \"tags\")\n", + ")\n", "formatted_df.createOrReplaceTempView(\"formatted_df\")\n", "formatted_df.show(5)\n", "# TODO atualizar daqui para baixo para considerar a linha inteira na lógica\n", "points_tb = spark.sql(\"select geom, id from formatted_df where geom IS NOT NULL\")\n", - "points_tb = points_tb\\\n", - ".withColumn(\"new\", arrays_zip(\"geom.lat\", \"geom.lon\"))\\\n", - ".withColumn(\"new\", explode(\"new\"))\n", + "points_tb = points_tb.withColumn(\"new\", arrays_zip(\"geom.lat\", \"geom.lon\")).withColumn(\n", + " \"new\", explode(\"new\")\n", + ")\n", "\n", - "points_tb = points_tb.select(\"new.0\",\"new.1\", \"id\")\n", + "points_tb = points_tb.select(\"new.0\", \"new.1\", \"id\")\n", "\n", - "points_tb = points_tb.withColumnRenamed(\"0\",\"lat\").withColumnRenamed(\"1\",\"lon\")\n", + "points_tb = points_tb.withColumnRenamed(\"0\", \"lat\").withColumnRenamed(\"1\", \"lon\")\n", "points_tb.printSchema()\n", "\n", "points_tb.createOrReplaceTempView(\"points_tb\")\n", @@ -496,11 +521,15 @@ "# \t\t\t\tST_Point(20, 80))\n", "# \t\t\t\t)) As wktenv;\n", "\n", - "coordinates_tb = spark.sql(\"select (select collect_list(CONCAT(p1.lat,',',p1.lon)) from points_tb p1 where p1.id = p2.id group by p1.id) as coordinates, p2.id, p2.maxspeed, p2.incline, p2.surface, p2.name, p2.nodes, p2.total_nodes from formatted_df p2\")\n", + "coordinates_tb = spark.sql(\n", + " \"select (select collect_list(CONCAT(p1.lat,',',p1.lon)) from points_tb p1 where p1.id = p2.id group by p1.id) as coordinates, p2.id, p2.maxspeed, p2.incline, p2.surface, p2.name, p2.nodes, p2.total_nodes from formatted_df p2\"\n", + ")\n", "coordinates_tb.createOrReplaceTempView(\"coordinates_tb\")\n", "coordinates_tb.show(5)\n", "\n", - "roads_tb = spark.sql(\"SELECT ST_LineStringFromText(REPLACE(REPLACE(CAST(coordinates as string),'[',''),']',''), ',') as geom, id, maxspeed, incline, surface, name, nodes, total_nodes FROM coordinates_tb WHERE coordinates IS NOT NULL\")\n", + "roads_tb = spark.sql(\n", + " \"SELECT ST_LineStringFromText(REPLACE(REPLACE(CAST(coordinates as string),'[',''),']',''), ',') as geom, id, maxspeed, incline, surface, name, nodes, total_nodes FROM coordinates_tb WHERE coordinates IS NOT NULL\"\n", + ")\n", "roads_tb.createOrReplaceTempView(\"roads_tb\")\n", "roads_tb.show(5)" ] @@ -520,8 +549,7 @@ } ], "source": [ - "\n", - "print(roads_tb.select('geom').take(1))" + "print(roads_tb.select(\"geom\").take(1))" ] }, { @@ -598,37 +626,47 @@ "# Não foi considerado que um caminha pode necessitar mais de 1 rua\n", "\n", "start_point = \"-25.4695946,-54.5909028\"\n", - "end_point = \"-25.4786993,-54.57938\" \n", + "end_point = \"-25.4786993,-54.57938\"\n", "\n", - "distance_tb = spark.sql(\"select nodes, st_distance(geom, st_point(\"+end_point+\")) as distance_toend, st_distance(st_point(\"+start_point+\"), geom) as distance, st_length(geom) * 1000 as geomsize, geom, id, maxspeed, incline, surface, name , total_nodes from roads_tb\")\n", + "distance_tb = spark.sql(\n", + " \"select nodes, st_distance(geom, st_point(\"\n", + " + end_point\n", + " + \")) as distance_toend, st_distance(st_point(\"\n", + " + start_point\n", + " + \"), geom) as distance, st_length(geom) * 1000 as geomsize, geom, id, maxspeed, incline, surface, name , total_nodes from roads_tb\"\n", + ")\n", "distance_tb.createOrReplaceTempView(\"distance_tb\")\n", "distance_tb.show(5)\n", "\n", "# considerar distância, 
direcao(ex: 0 180 e etc), inclinacao(up, down, 0%), superficie(asphalt,paved, concrete), velocidade(60 80 50 40 e etc)\n", "fill_null_tb = spark.sql(\n", - " \"select nodes, IFNULL(maxspeed, 20) as maxspeed, IFNULL(incline, '0%') as incline, IFNULL(surface, 'soil') as surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from distance_tb\")\n", + " \"select nodes, IFNULL(maxspeed, 20) as maxspeed, IFNULL(incline, '0%') as incline, IFNULL(surface, 'soil') as surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from distance_tb\"\n", + ")\n", "fill_null_tb.createOrReplaceTempView(\"fill_null_tb\")\n", "fill_null_tb.show(5)\n", "\n", "surface_index_tb = spark.sql(\n", - " \"select nodes, case surface when 'asphalt'\" +\n", - " \"then 0.01 when 'concrete'\" + \n", - " \"then 0.02 when 'paved'\" +\n", - " \"then 0.03 when 'soil'\" +\n", - " \"then 0.04 when 'unpaved'\" + \n", - " \"then 0.04 when 'sett'\" +\n", - " \"then 0.03 ELSE 0.05 end as surface_index,\"+ \n", - " \"maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from fill_null_tb\")\n", + " \"select nodes, case surface when 'asphalt'\"\n", + " + \"then 0.01 when 'concrete'\"\n", + " + \"then 0.02 when 'paved'\"\n", + " + \"then 0.03 when 'soil'\"\n", + " + \"then 0.04 when 'unpaved'\"\n", + " + \"then 0.04 when 'sett'\"\n", + " + \"then 0.03 ELSE 0.05 end as surface_index,\"\n", + " + \"maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from fill_null_tb\"\n", + ")\n", "surface_index_tb.createOrReplaceTempView(\"surface_index_tb\")\n", "surface_index_tb.show(5)\n", "\n", "incline_index_tb = spark.sql(\n", - " \"select nodes, case incline when 'top' then -0.10 when 'down' then 0.10 when '0%' then 0 end as incline_index, surface_index, maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from surface_index_tb\")\n", + " \"select nodes, case incline when 'top' then -0.10 when 'down' then 0.10 when '0%' then 0 end as incline_index, surface_index, maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from surface_index_tb\"\n", + ")\n", "incline_index_tb.createOrReplaceTempView(\"incline_index_tb\")\n", "incline_index_tb.show(5)\n", - " \n", + "\n", "weight_index_tb = spark.sql(\n", - " \"select nodes, (maxspeed - (maxspeed * surface_index)) + (maxspeed +(maxspeed * incline_index)) as weight, incline_index, surface_index, maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from incline_index_tb WHERE geomsize IS NOT NULL\")\n", + " \"select nodes, (maxspeed - (maxspeed * surface_index)) + (maxspeed +(maxspeed * incline_index)) as weight, incline_index, surface_index, maxspeed, incline, surface, name, id, geom, geomsize, distance, distance_toend, total_nodes from incline_index_tb WHERE geomsize IS NOT NULL\"\n", + ")\n", "weight_index_tb.createOrReplaceTempView(\"weight_index_tb\")\n", "weight_index_tb.show(5)" ] @@ -655,8 +693,7 @@ } ], "source": [ - "teste = spark.sql(\n", - " \"select min(weight) from weight_index_tb\")\n", + "teste = spark.sql(\"select min(weight) from weight_index_tb\")\n", "teste.show(5)" ] }, @@ -819,23 +856,30 @@ ], "source": [ "closestoend_tb = spark.sql(\n", - " \"select w1.id, w1.distance_toend from weight_index_tb w1 group by w1.id, w1.distance_toend having (select min(w2.distance_toend) as distance_toend from weight_index_tb w2) = w1.distance_toend\")\n", + " \"select w1.id, 
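A worked instance of the weight expression above, for a hypothetical 60 km/h asphalt segment on flat ground (surface_index 0.01, incline_index 0):

    maxspeed, surface_index, incline_index = 60.0, 0.01, 0.0
    weight = (maxspeed - maxspeed * surface_index) + (maxspeed + maxspeed * incline_index)
    assert round(weight, 1) == 119.4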
w1.distance_toend from weight_index_tb w1 group by w1.id, w1.distance_toend having (select min(w2.distance_toend) as distance_toend from weight_index_tb w2) = w1.distance_toend\"\n", + ")\n", "closestoend_tb.createOrReplaceTempView(\"closestoend_tb\")\n", - "closestoend = closestoend_tb.take(1)[0]['id']\n", + "closestoend = closestoend_tb.take(1)[0][\"id\"]\n", "print(closestoend)\n", "\n", "closestostart_tb = spark.sql(\n", - " \"select w1.id, w1.distance from weight_index_tb w1 group by w1.id, w1.distance having (select min(w2.distance) as distance from weight_index_tb w2) = w1.distance\")\n", + " \"select w1.id, w1.distance from weight_index_tb w1 group by w1.id, w1.distance having (select min(w2.distance) as distance from weight_index_tb w2) = w1.distance\"\n", + ")\n", "closestostart_tb.createOrReplaceTempView(\"closestostart_tb\")\n", "closestostart_tb.show(5)\n", "\n", - "closestostart = closestostart_tb.take(1)[0]['id']\n", + "closestostart = closestostart_tb.take(1)[0][\"id\"]\n", "\n", "# FOLIUM EM 3857 dado em 4326 st_transform(st_union_aggr(geom),'epsg:3857','epsg:4326')\n", "json_lines = spark.sql(\n", - " \"select ST_AsGeoJSON(st_envelope_aggr(geom)) AS json from weight_index_tb where id in (\"+str(closestostart)+\",\"+str(closestoend)+\")\")\n", - "json_lines_string_teste = json_lines.take(1)[0]['json']\n", - "coordinates_teste = json.loads(json_lines_string_teste)['coordinates']\n", + " \"select ST_AsGeoJSON(st_envelope_aggr(geom)) AS json from weight_index_tb where id in (\"\n", + " + str(closestostart)\n", + " + \",\"\n", + " + str(closestoend)\n", + " + \")\"\n", + ")\n", + "json_lines_string_teste = json_lines.take(1)[0][\"json\"]\n", + "coordinates_teste = json.loads(json_lines_string_teste)[\"coordinates\"]\n", "\n", "\n", "# st_boundary st_contains\n", @@ -843,21 +887,33 @@ "# Pegar o limite entre a uniao da geom inicial e final\n", "# select st_boundary(st_union_aggr(geom)) AS boundary from weight_index_tb where id in (\"+str(closestostart)+\",\"+str(closestoend)+\")\n", "\n", - "boundary_tb = spark.sql(\"select st_envelope_aggr(geom) as boundary from weight_index_tb where id in (\"+str(closestostart)+\",\"+str(closestoend)+\")\")\n", + "boundary_tb = spark.sql(\n", + " \"select st_envelope_aggr(geom) as boundary from weight_index_tb where id in (\"\n", + " + str(closestostart)\n", + " + \",\"\n", + " + str(closestoend)\n", + " + \")\"\n", + ")\n", "boundary_tb.createOrReplaceTempView(\"boundary_tb\")\n", "boundary_tb.show(5)\n", "\n", - "contains_tb = spark.sql(\"select st_intersects(boundary,geom) as contains, id from weight_index_tb, boundary_tb\")\n", + "contains_tb = spark.sql(\n", + " \"select st_intersects(boundary,geom) as contains, id from weight_index_tb, boundary_tb\"\n", + ")\n", "contains_tb.createOrReplaceTempView(\"contains_tb\")\n", "contains_tb.show(5)\n", "\n", - "possible_paths = spark.sql(\"select id, geom from weight_index_tb group by id, geom having id in (select id from contains_tb where contains = true)\")\n", + "possible_paths = spark.sql(\n", + " \"select id, geom from weight_index_tb group by id, geom having id in (select id from contains_tb where contains = true)\"\n", + ")\n", "possible_paths.createOrReplaceTempView(\"possible_paths\")\n", "possible_paths.show(5)\n", "\n", - "paths_collection = spark.sql(\"select ST_AsGeoJSON(st_union_aggr(geom)) AS json from possible_paths\")\n", - "json_lines_string = paths_collection.take(1)[0]['json']\n", - "coordinates = json.loads(json_lines_string)['coordinates']" + "paths_collection = 
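The recursive walk defined next is a greedy heuristic: at each step it takes the touching segment with the best weight, which need not yield the best overall route. For comparison, a weighted shortest path over the same touching-segment graph is what Dijkstra computes, sketched here with networkx and made-up (id_a, id_b, cost) edges; lower cost is better, so the notebook's higher-is-better weight would be inverted first:

    import networkx as nx

    g = nx.Graph()
    g.add_weighted_edges_from([(1, 2, 1.0), (2, 3, 1.2), (1, 3, 2.5)])
    print(nx.shortest_path(g, source=1, target=3, weight="weight"))  # [1, 2, 3]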
spark.sql(\n", + " \"select ST_AsGeoJSON(st_union_aggr(geom)) AS json from possible_paths\"\n", + ")\n", + "json_lines_string = paths_collection.take(1)[0][\"json\"]\n", + "coordinates = json.loads(json_lines_string)[\"coordinates\"]" ] }, { @@ -1010,47 +1066,60 @@ "source": [ "path = [closestostart]\n", "visited = [closestostart]\n", - "current_nodes = spark.sql(\"select geom from weight_index_tb where id = \"+str(closestostart))\n", + "current_nodes = spark.sql(\n", + " \"select geom from weight_index_tb where id = \" + str(closestostart)\n", + ")\n", "row = current_nodes.rdd.collect()[0][\"geom\"]\n", "id_current = closestostart\n", "\n", + "\n", "def choose_path(row, path, id_current, visited, copy_row):\n", "\n", - " visited_frm = str(visited).replace(\"[\",\"(\").replace(\"]\",\")\")\n", - " \n", - " touches_tb = spark.sql(\"select st_touches(st_geomfromwkt('\"+str(row)+\"'),geom) as touches, * from weight_index_tb where geom IS NOT NULL and distance_toend IS NOT NULL\")\n", + " visited_frm = str(visited).replace(\"[\", \"(\").replace(\"]\", \")\")\n", + "\n", + " touches_tb = spark.sql(\n", + " \"select st_touches(st_geomfromwkt('\"\n", + " + str(row)\n", + " + \"'),geom) as touches, * from weight_index_tb where geom IS NOT NULL and distance_toend IS NOT NULL\"\n", + " )\n", " touches_tb.createOrReplaceTempView(\"touches_tb\")\n", - "# st_distance(st_geomfromwkt('\"+str(row)+\"'),geom)\n", - " fim_distance = spark.sql(\"select distance_toend from touches_tb where id = \"+ str(id_current))\n", + " # st_distance(st_geomfromwkt('\"+str(row)+\"'),geom)\n", + " fim_distance = spark.sql(\n", + " \"select distance_toend from touches_tb where id = \" + str(id_current)\n", + " )\n", " fim_distance.show(5)\n", " fim_distance_value = fim_distance.rdd.collect()[0][\"distance_toend\"]\n", - " \n", - " \n", - " current_distance = spark.sql(\"select distance from touches_tb where id = \" + str(id_current))\n", + "\n", + " current_distance = spark.sql(\n", + " \"select distance from touches_tb where id = \" + str(id_current)\n", + " )\n", " current_distance_value = current_distance.rdd.collect()[0][\"distance\"]\n", - " \n", - "# st_distance(st_geomfromwkt('\"+str(row)+\"'),geom) = \n", "\n", - " sql = \"select geom, id, weight from touches_tb where \" \\\n", - " +\"touches = true\" \\\n", - " +\" and \" \\\n", - " +\"distance_toend < \" \\\n", - " +str(fim_distance_value) \\\n", - " +\" and \" \\\n", - " +\"distance > \" \\\n", - " +str(current_distance_value) \\\n", - " +\" and \" \\\n", - " +\"id NOT IN \" \\\n", - " +visited_frm \\\n", - " \n", + " # st_distance(st_geomfromwkt('\"+str(row)+\"'),geom) =\n", + "\n", + " sql = (\n", + " \"select geom, id, weight from touches_tb where \"\n", + " + \"touches = true\"\n", + " + \" and \"\n", + " + \"distance_toend < \"\n", + " + str(fim_distance_value)\n", + " + \" and \"\n", + " + \"distance > \"\n", + " + str(current_distance_value)\n", + " + \" and \"\n", + " + \"id NOT IN \"\n", + " + visited_frm\n", + " )\n", " print(sql)\n", " current_nodes = spark.sql(sql)\n", " current_nodes.createOrReplaceTempView(\"current_nodes\")\n", " current_nodes.show(5)\n", - " \n", - " current_node = spark.sql(\"select id, geom, max(weight) from current_nodes group by id, geom, weight having max(weight) = weight \")\n", + "\n", + " current_node = spark.sql(\n", + " \"select id, geom, max(weight) from current_nodes group by id, geom, weight having max(weight) = weight \"\n", + " )\n", " current_node.show(5)\n", - " \n", + "\n", " if len(current_nodes.rdd.collect()) == 
0:\n", " return path\n", " else:\n", @@ -1059,15 +1128,19 @@ " path.append(id_current)\n", " visited.append(id_current)\n", " return choose_path(row, path, id_current, visited, copy_row)\n", - " \n", + "\n", + "\n", "path_ids = choose_path(row, path, id_current, visited, row)\n", "path_ids.append(closestoend)\n", - "path_ids_frm = str(path_ids).replace(\"[\",\"(\").replace(\"]\",\")\")\n", + "path_ids_frm = str(path_ids).replace(\"[\", \"(\").replace(\"]\", \")\")\n", "print(path_ids_frm)\n", "\n", - "short_path = spark.sql(\"select ST_AsGeoJSON(st_union_aggr(geom)) AS json from weight_index_tb where id in \"+path_ids_frm)\n", - "short_path_string = short_path.take(1)[0]['json']\n", - "short_path_coordinates = json.loads(short_path_string)['coordinates']" + "short_path = spark.sql(\n", + " \"select ST_AsGeoJSON(st_union_aggr(geom)) AS json from weight_index_tb where id in \"\n", + " + path_ids_frm\n", + ")\n", + "short_path_string = short_path.take(1)[0][\"json\"]\n", + "short_path_coordinates = json.loads(short_path_string)[\"coordinates\"]" ] }, { @@ -1095,13 +1168,21 @@ "\n", "import folium\n", "\n", - "start_point_arr = [-25.4695946,-54.5909028]\n", - "end_point_arr = [-25.4786993,-54.57938] \n", + "start_point_arr = [-25.4695946, -54.5909028]\n", + "end_point_arr = [-25.4786993, -54.57938]\n", "tooltip = \"Click me!\"\n", "# 3857\n", - "m = folium.Map(location=[-25.5172662,-54.6170038], zoom_start=12, tiles='OpenStreetMap', crs='EPSG3857' )\n", + "m = folium.Map(\n", + " location=[-25.5172662, -54.6170038],\n", + " zoom_start=12,\n", + " tiles=\"OpenStreetMap\",\n", + " crs=\"EPSG3857\",\n", + ")\n", "folium.Marker(\n", - " start_point_arr, popup=\"Inicio\", tooltip=tooltip, icon=folium.Icon(color=\"green\")\n", + " start_point_arr,\n", + " popup=\"Inicio\",\n", + " tooltip=tooltip,\n", + " icon=folium.Icon(color=\"green\"),\n", ").add_to(m)\n", "folium.Marker(\n", " end_point_arr, popup=\"Fim\", tooltip=tooltip, icon=folium.Icon(color=\"red\")\n", diff --git a/docs/usecases/contrib/foot-traffic.ipynb b/docs/usecases/contrib/foot-traffic.ipynb index 0367463c25..200da0dc72 100644 --- a/docs/usecases/contrib/foot-traffic.ipynb +++ b/docs/usecases/contrib/foot-traffic.ipynb @@ -114,16 +114,18 @@ "from sedona.register import SedonaRegistrator\n", "from sedona.utils import SedonaKryoRegistrator, KryoSerializer\n", "\n", - "spark = SparkSession. \\\n", - " builder. \\\n", - " appName('sigspatial2021'). \\\n", - " master(\"spark://data-ocean-lab-1:7077\").\\\n", - " config(\"spark.serializer\", KryoSerializer.getName).\\\n", - " config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName) .\\\n", - " config('spark.jars.packages',\n", - " 'org.apache.sedona:sedona-python-adapter-3.0_2.12:1.2.0-incubating,'\n", - " 'org.datasyslab:geotools-wrapper:1.1.0-25.2'). 
\\\n", - " getOrCreate()" + "spark = (\n", + " SparkSession.builder.appName(\"sigspatial2021\")\n", + " .master(\"spark://data-ocean-lab-1:7077\")\n", + " .config(\"spark.serializer\", KryoSerializer.getName)\n", + " .config(\"spark.kryo.registrator\", SedonaKryoRegistrator.getName)\n", + " .config(\n", + " \"spark.jars.packages\",\n", + " \"org.apache.sedona:sedona-python-adapter-3.0_2.12:1.2.0-incubating,\"\n", + " \"org.datasyslab:geotools-wrapper:1.1.0-25.2\",\n", + " )\n", + " .getOrCreate()\n", + ")" ] }, { @@ -183,13 +185,18 @@ } ], "source": [ - "sample_csv_path = 'file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/seattle_coffee_monthly_patterns/'\n", + "sample_csv_path = \"file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/seattle_coffee_monthly_patterns/\"\n", "sample = (\n", - " spark.read.option(\"header\", \"true\").option(\"escape\", \"\\\"\").csv(sample_csv_path)\n", - " .withColumn('date_range_start', f.to_date(f.col('date_range_start')))\n", - " .withColumn('date_range_end', f.to_date(f.col('date_range_end')))\n", - " .withColumn('visitor_home_cbgs', f.from_json('visitor_home_cbgs', schema = MapType(StringType(), IntegerType())))\n", - " .withColumn(\"distance_from_home\", f.col(\"distance_from_home\"))\n", + " spark.read.option(\"header\", \"true\")\n", + " .option(\"escape\", '\"')\n", + " .csv(sample_csv_path)\n", + " .withColumn(\"date_range_start\", f.to_date(f.col(\"date_range_start\")))\n", + " .withColumn(\"date_range_end\", f.to_date(f.col(\"date_range_end\")))\n", + " .withColumn(\n", + " \"visitor_home_cbgs\",\n", + " f.from_json(\"visitor_home_cbgs\", schema=MapType(StringType(), IntegerType())),\n", + " )\n", + " .withColumn(\"distance_from_home\", f.col(\"distance_from_home\"))\n", ")" ] }, @@ -561,16 +568,26 @@ }, "outputs": [], "source": [ - "w = Window().partitionBy('placekey').orderBy(f.col('date_range_start').desc())\n", + "w = Window().partitionBy(\"placekey\").orderBy(f.col(\"date_range_start\").desc())\n", "\n", "cafes_latest = (\n", - " sample\n", - " # as our data improves, addresses or geocodes for a given location may change over time\n", - " # use a window function to keep only the most recent appearance of the given cafe\n", - " .withColumn('row_num', f.row_number().over(w))\n", - " .filter(f.col('row_num') == 1)\n", - " # select the columns we need for mapping\n", - " .select('placekey', 'location_name', 'brands', 'street_address', 'city', 'region', 'postal_code', 'latitude', 'longitude', 'open_hours')\n", + " sample\n", + " # as our data improves, addresses or geocodes for a given location may change over time\n", + " # use a window function to keep only the most recent appearance of the given cafe\n", + " .withColumn(\"row_num\", f.row_number().over(w)).filter(f.col(\"row_num\") == 1)\n", + " # select the columns we need for mapping\n", + " .select(\n", + " \"placekey\",\n", + " \"location_name\",\n", + " \"brands\",\n", + " \"street_address\",\n", + " \"city\",\n", + " \"region\",\n", + " \"postal_code\",\n", + " \"latitude\",\n", + " \"longitude\",\n", + " \"open_hours\",\n", + " )\n", ")" ] }, @@ -597,7 +614,11 @@ "source": [ "# create a geopandas geodataframe\n", "cafes_gdf = cafes_latest.toPandas()\n", - "cafes_gdf = gpd.GeoDataFrame(cafes_gdf, geometry = gpd.points_from_xy(cafes_gdf['longitude'], cafes_gdf['latitude']), crs = 'EPSG:4326')" + "cafes_gdf = gpd.GeoDataFrame(\n", + " cafes_gdf,\n", + " geometry=gpd.points_from_xy(cafes_gdf[\"longitude\"], cafes_gdf[\"latitude\"]),\n", + " crs=\"EPSG:4326\",\n", + ")" ] }, { 
@@ -614,37 +635,34 @@ "outputs": [], "source": [ "def map_cafes(gdf):\n", - " \n", - " # map bounds\n", - " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", - " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", - " folium_bounds = [sw, ne]\n", - " \n", - " # map\n", - " x = gdf.centroid.x[0]\n", - " y = gdf.centroid.y[0]\n", - " \n", - " map_ = folium.Map(\n", - " location = [y, x],\n", - " tiles = \"OpenStreetMap\"\n", - " )\n", - " \n", - " for i, point in gdf.iterrows():\n", - " \n", - " tooltip = f\"placekey: {point['placekey']}
<br>location_name: {point['location_name']}<br>brands: {point['brands']}<br>street_address: {point['street_address']}<br>city: {point['city']}<br>region: {point['region']}<br>postal_code: {point['postal_code']}<br>
open_hours: {point['open_hours']}\"\n", - " \n", - " folium.Circle(\n", - " [point['geometry'].y, point['geometry'].x],\n", - " radius = 40,\n", - " fill_color = 'blue',\n", - " color = 'blue',\n", - " fill_opacity = 1,\n", - " tooltip = tooltip\n", - " ).add_to(map_)\n", "\n", - " map_.fit_bounds(folium_bounds) \n", - " \n", - " return map_" + " # map bounds\n", + " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", + " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", + " folium_bounds = [sw, ne]\n", + "\n", + " # map\n", + " x = gdf.centroid.x[0]\n", + " y = gdf.centroid.y[0]\n", + "\n", + " map_ = folium.Map(location=[y, x], tiles=\"OpenStreetMap\")\n", + "\n", + " for i, point in gdf.iterrows():\n", + "\n", + " tooltip = f\"placekey: {point['placekey']}
<br>location_name: {point['location_name']}<br>brands: {point['brands']}<br>street_address: {point['street_address']}<br>city: {point['city']}<br>region: {point['region']}<br>postal_code: {point['postal_code']}<br>
open_hours: {point['open_hours']}\"\n", + "\n", + " folium.Circle(\n", + " [point[\"geometry\"].y, point[\"geometry\"].x],\n", + " radius=40,\n", + " fill_color=\"blue\",\n", + " color=\"blue\",\n", + " fill_opacity=1,\n", + " tooltip=tooltip,\n", + " ).add_to(map_)\n", + "\n", + " map_.fit_bounds(folium_bounds)\n", + "\n", + " return map_" ] }, { @@ -791,21 +809,22 @@ } ], "source": [ - "# the `distance_from_home` column tells us the median distance (as the crow flies), in meters, between the coffee shop and the visitors' homes \n", + "# the `distance_from_home` column tells us the median distance (as the crow flies), in meters, between the coffee shop and the visitors' homes\n", "# which coffee shop's visitors had the highest average median distance traveled since Jan 2018?\n", "\n", "# outlier values in this column distort the histogram\n", "# these outliers are likely due to a combination of (1) coffee shops in downtown areas that receive high numbers of out-of-town visitors and (2) quirks in the underlying GPS data\n", "furthest_traveled = (\n", - " sample\n", - " .groupBy('placekey', 'location_name', 'street_address')\n", - " .agg(f.mean('distance_from_home').alias('avg_median_dist_from_home'))\n", - " .orderBy('avg_median_dist_from_home', ascending = False)\n", + " sample.groupBy(\"placekey\", \"location_name\", \"street_address\")\n", + " .agg(f.mean(\"distance_from_home\").alias(\"avg_median_dist_from_home\"))\n", + " .orderBy(\"avg_median_dist_from_home\", ascending=False)\n", ")\n", "\n", "display(furthest_traveled)\n", - "furthest_traveled.filter(f.col(\"avg_median_dist_from_home\").isNotNull()).withColumn(\"avg_median_dist_from_home\", f.col(\"avg_median_dist_from_home\")/1000.0).show(truncate=False)\n", - "furthest_traveled = furthest_traveled.drop('street_address')" + "furthest_traveled.filter(f.col(\"avg_median_dist_from_home\").isNotNull()).withColumn(\n", + " \"avg_median_dist_from_home\", f.col(\"avg_median_dist_from_home\") / 1000.0\n", + ").show(truncate=False)\n", + "furthest_traveled = furthest_traveled.drop(\"street_address\")" ] }, { @@ -854,8 +873,13 @@ ], "source": [ "# most coffee shops' visitors' homes are <10km away\n", - "display(furthest_traveled.filter(f.col('avg_median_dist_from_home') < 10000))\n", - "print(furthest_traveled.filter(f.col('avg_median_dist_from_home') < 10000).count(), \" coffee shops have vistors' home <10 km away, out of \", furthest_traveled.count(), \"coffee shops.\")" + "display(furthest_traveled.filter(f.col(\"avg_median_dist_from_home\") < 10000))\n", + "print(\n", + " furthest_traveled.filter(f.col(\"avg_median_dist_from_home\") < 10000).count(),\n", + " \" coffee shops have vistors' home <10 km away, out of \",\n", + " furthest_traveled.count(),\n", + " \"coffee shops.\",\n", + ")" ] }, { @@ -905,10 +929,12 @@ ], "source": [ "# load the census block groups for Washington State\n", - "# filter it down to our three counties of interest in Seattle \n", + "# filter it down to our three counties of interest in Seattle\n", "WA_cbgs = (\n", - " spark.read.option('header', 'true').option('escape', \"\\\"\").csv('file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/wa_cbg.csv')\n", - " .filter(f.col('GEOID').rlike('^(53033|53053|53061)'))\n", + " spark.read.option(\"header\", \"true\")\n", + " .option(\"escape\", '\"')\n", + " .csv(\"file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/wa_cbg.csv\")\n", + " .filter(f.col(\"GEOID\").rlike(\"^(53033|53053|53061)\"))\n", ")\n", "WA_cbgs.head()" ] @@ -928,14 +954,18 @@ "source": [ 
"# transform the geometry column into a Geometry-type\n", "WA_cbgs = (\n", - " WA_cbgs\n", - " .withColumn('cbg_geometry', f.expr(\"ST_GeomFromWkt(geometry)\"))\n", - " # we'll just use the CBG centroid\n", - " .withColumn('cbg_geometry', f.expr(\"ST_Centroid(cbg_geometry)\"))\n", - " # since we'll be doing a distance calculation, let's also use a projected CRS - epsg:3857\n", - " .withColumn('cbg_geometry', f.expr(\"ST_Transform(ST_FlipCoordinates(cbg_geometry), 'epsg:4326','epsg:3857', false)\")) # ST_FlipCoordinates() necessary due to this bug: https://issues.apache.org/jira/browse/SEDONA-39\n", - " .withColumnRenamed('GEOID', 'cbg')\n", - " .withColumnRenamed('geometry', 'cbg_polygon_geometry')\n", + " WA_cbgs.withColumn(\"cbg_geometry\", f.expr(\"ST_GeomFromWkt(geometry)\"))\n", + " # we'll just use the CBG centroid\n", + " .withColumn(\"cbg_geometry\", f.expr(\"ST_Centroid(cbg_geometry)\"))\n", + " # since we'll be doing a distance calculation, let's also use a projected CRS - epsg:3857\n", + " .withColumn(\n", + " \"cbg_geometry\",\n", + " f.expr(\n", + " \"ST_Transform(ST_FlipCoordinates(cbg_geometry), 'epsg:4326','epsg:3857', false)\"\n", + " ),\n", + " ) # ST_FlipCoordinates() necessary due to this bug: https://issues.apache.org/jira/browse/SEDONA-39\n", + " .withColumnRenamed(\"GEOID\", \"cbg\")\n", + " .withColumnRenamed(\"geometry\", \"cbg_polygon_geometry\")\n", ")" ] }, @@ -1052,29 +1082,33 @@ "source": [ "# Next let's prep our sample data\n", "sample_seattle_visitors = (\n", - " sample\n", - " .select('placekey', f.explode('visitor_home_cbgs'))\n", - " .withColumnRenamed('key', 'cbg')\n", - " .withColumnRenamed('value', 'visitors')\n", - "# # filter out CBGs with low visitor counts\n", - " .filter(f.col('visitors') > 4)\n", - " # filter down to only the visitors from Seattle CBGs\n", - " .filter(f.col('cbg').rlike('^(53033|53053|53061)'))\n", - " # aggregate up all the visitors over time from each CBG to each Cafe\n", - " .groupBy('placekey', 'cbg')\n", - " .agg(\n", - " f.sum('visitors').alias('visitors')\n", - " )\n", - " # join back with most up-to-date POI information\n", - " .join(\n", - " cafes_latest.select('placekey', 'latitude', 'longitude'),\n", - " 'placekey'\n", - " )\n", - " # transform geometry column\n", - " .withColumn('cafe_geometry', f.expr(\"ST_Point(CAST(longitude AS Decimal(24, 20)), CAST(latitude AS Decimal(24, 20)))\"))\n", - " .withColumn('cafe_geometry', f.expr(\"ST_Transform(ST_FlipCoordinates(cafe_geometry), 'epsg:4326','epsg:3857', false)\"))\n", - " # join with CBG geometries\n", - " .join(WA_cbgs, 'cbg')\n", + " sample.select(\"placekey\", f.explode(\"visitor_home_cbgs\"))\n", + " .withColumnRenamed(\"key\", \"cbg\")\n", + " .withColumnRenamed(\"value\", \"visitors\")\n", + " # # filter out CBGs with low visitor counts\n", + " .filter(f.col(\"visitors\") > 4)\n", + " # filter down to only the visitors from Seattle CBGs\n", + " .filter(f.col(\"cbg\").rlike(\"^(53033|53053|53061)\"))\n", + " # aggregate up all the visitors over time from each CBG to each Cafe\n", + " .groupBy(\"placekey\", \"cbg\")\n", + " .agg(f.sum(\"visitors\").alias(\"visitors\"))\n", + " # join back with most up-to-date POI information\n", + " .join(cafes_latest.select(\"placekey\", \"latitude\", \"longitude\"), \"placekey\")\n", + " # transform geometry column\n", + " .withColumn(\n", + " \"cafe_geometry\",\n", + " f.expr(\n", + " \"ST_Point(CAST(longitude AS Decimal(24, 20)), CAST(latitude AS Decimal(24, 20)))\"\n", + " ),\n", + " )\n", + " .withColumn(\n", + " 
\"cafe_geometry\",\n", + " f.expr(\n", + " \"ST_Transform(ST_FlipCoordinates(cafe_geometry), 'epsg:4326','epsg:3857', false)\"\n", + " ),\n", + " )\n", + " # join with CBG geometries\n", + " .join(WA_cbgs, \"cbg\")\n", ")" ] }, @@ -1241,11 +1275,13 @@ "outputs": [], "source": [ "distance_traveled_SEA = (\n", - " sample_seattle_visitors\n", - " # calculate the distance from home in meters\n", - " .withColumn('distance_from_home', f.expr(\"ST_Distance(cafe_geometry, cbg_geometry)\"))\n", + " sample_seattle_visitors\n", + " # calculate the distance from home in meters\n", + " .withColumn(\n", + " \"distance_from_home\", f.expr(\"ST_Distance(cafe_geometry, cbg_geometry)\")\n", + " )\n", ")\n", - "distance_traveled_SEA.createOrReplaceTempView('distance_traveled_SEA')" + "distance_traveled_SEA.createOrReplaceTempView(\"distance_traveled_SEA\")" ] }, { @@ -1261,7 +1297,7 @@ }, "outputs": [], "source": [ - "q = '''\n", + "q = \"\"\"\n", "SELECT *\n", "FROM (\n", " SELECT DISTINCT \n", @@ -1273,7 +1309,7 @@ " FROM distance_traveled_SEA\n", ")\n", "WHERE pos > 0\n", - "'''\n", + "\"\"\"\n", "\n", "weighted_median_tmp = spark.sql(q)" ] @@ -1291,8 +1327,8 @@ }, "outputs": [], "source": [ - "grp_window = Window.partitionBy('placekey')\n", - "median_percentile = f.expr('percentile_approx(distance_from_home, 0.5)')" + "grp_window = Window.partitionBy(\"placekey\")\n", + "median_percentile = f.expr(\"percentile_approx(distance_from_home, 0.5)\")" ] }, { @@ -1424,10 +1460,8 @@ }, "outputs": [], "source": [ - "median_dist_traveled_SEA = (\n", - " weighted_median_tmp\n", - " .groupBy('placekey')\n", - " .agg(median_percentile.alias('median_dist_traveled_SEA'))\n", + "median_dist_traveled_SEA = weighted_median_tmp.groupBy(\"placekey\").agg(\n", + " median_percentile.alias(\"median_dist_traveled_SEA\")\n", ")" ] }, @@ -1520,7 +1554,9 @@ } ], "source": [ - "median_dist_traveled_SEA.filter(f.col('median_dist_traveled_SEA') < 25000).limit(10).toPandas().head()\n" + "median_dist_traveled_SEA.filter(f.col(\"median_dist_traveled_SEA\") < 25000).limit(\n", + " 10\n", + ").toPandas().head()" ] }, { @@ -1536,13 +1572,9 @@ }, "outputs": [], "source": [ - "total_visits = (\n", - " sample\n", - " .groupBy('placekey')\n", - " .agg(\n", - " f.sum('raw_visit_counts').alias('total_visits'),\n", - " f.sum('raw_visitor_counts').alias('total_visitors')\n", - " )\n", + "total_visits = sample.groupBy(\"placekey\").agg(\n", + " f.sum(\"raw_visit_counts\").alias(\"total_visits\"),\n", + " f.sum(\"raw_visitor_counts\").alias(\"total_visitors\"),\n", ")" ] }, @@ -1560,14 +1592,16 @@ "outputs": [], "source": [ "distance_traveled_final = (\n", - " furthest_traveled\n", - " .join(median_dist_traveled_SEA, 'placekey')\n", - " .join(cafes_latest, ['placekey', 'location_name'])\n", - " .withColumn('distance_traveled_diff', f.col('avg_median_dist_from_home') - f.col('median_dist_traveled_SEA'))\n", - " # keep only the cafes with a meaningful sample - at least 1000 visits since 2018\n", - " .join(total_visits, 'placekey')\n", - " .filter(f.col('total_visits') > 1000)\n", - " .orderBy('distance_traveled_diff', ascending = False)\n", + " furthest_traveled.join(median_dist_traveled_SEA, \"placekey\")\n", + " .join(cafes_latest, [\"placekey\", \"location_name\"])\n", + " .withColumn(\n", + " \"distance_traveled_diff\",\n", + " f.col(\"avg_median_dist_from_home\") - f.col(\"median_dist_traveled_SEA\"),\n", + " )\n", + " # keep only the cafes with a meaningful sample - at least 1000 visits since 2018\n", + " .join(total_visits, \"placekey\")\n", + 
" .filter(f.col(\"total_visits\") > 1000)\n", + " .orderBy(\"distance_traveled_diff\", ascending=False)\n", ")" ] }, @@ -1800,7 +1834,7 @@ } ], "source": [ - "distance_traveled_final.select('placekey').distinct().count()" + "distance_traveled_final.select(\"placekey\").distinct().count()" ] }, { @@ -1824,14 +1858,26 @@ } ], "source": [ - "most_tourists = distance_traveled_final.limit(500).withColumn('visitor_type', f.lit('tourist'))\n", - "most_locals = distance_traveled_final.orderBy('distance_traveled_diff').limit(500).withColumn('visitor_type', f.lit('local'))\n", + "most_tourists = distance_traveled_final.limit(500).withColumn(\n", + " \"visitor_type\", f.lit(\"tourist\")\n", + ")\n", + "most_locals = (\n", + " distance_traveled_final.orderBy(\"distance_traveled_diff\")\n", + " .limit(500)\n", + " .withColumn(\"visitor_type\", f.lit(\"local\"))\n", + ")\n", "\n", "visitor_type = most_tourists.unionByName(most_locals)\n", "\n", "# create a geopandas geodataframe\n", "visitor_type_gdf = visitor_type.toPandas()\n", - "visitor_type_gdf = gpd.GeoDataFrame(visitor_type_gdf, geometry = gpd.points_from_xy(visitor_type_gdf['longitude'], visitor_type_gdf['latitude']), crs = 'EPSG:4326')" + "visitor_type_gdf = gpd.GeoDataFrame(\n", + " visitor_type_gdf,\n", + " geometry=gpd.points_from_xy(\n", + " visitor_type_gdf[\"longitude\"], visitor_type_gdf[\"latitude\"]\n", + " ),\n", + " crs=\"EPSG:4326\",\n", + ")" ] }, { @@ -1848,37 +1894,34 @@ "outputs": [], "source": [ "def map_cafe_visitor_type(gdf):\n", - " \n", - " # map bounds\n", - " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", - " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", - " folium_bounds = [sw, ne]\n", - " \n", - " # map\n", - " x = gdf.centroid.x[0]\n", - " y = gdf.centroid.y[0]\n", - " \n", - " map_ = folium.Map(\n", - " location = [y, x],\n", - " tiles = \"OpenStreetMap\"\n", - " )\n", - " \n", - " for i, point in gdf.iterrows():\n", - " \n", - " tooltip = f\"placekey: {point['placekey']}
<br>location_name: {point['location_name']}<br>brands: {point['brands']}<br>street_address: {point['street_address']}<br>city: {point['city']}<br>region: {point['region']}<br>postal_code: {point['postal_code']}<br>visitor_type: {point['visitor_type']}<br>
avg_median_dist_from_home: {point['avg_median_dist_from_home']}\"\n", - " \n", - " folium.Circle(\n", - " [point['geometry'].y, point['geometry'].x],\n", - " radius = 40,\n", - " fill_color = 'blue' if point['visitor_type'] == 'tourist' else 'red',\n", - " color = 'blue' if point['visitor_type'] == 'tourist' else 'red',\n", - " fill_opacity = 1,\n", - " tooltip = tooltip\n", - " ).add_to(map_)\n", "\n", - " map_.fit_bounds(folium_bounds) \n", - " \n", - " return map_" + " # map bounds\n", + " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", + " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", + " folium_bounds = [sw, ne]\n", + "\n", + " # map\n", + " x = gdf.centroid.x[0]\n", + " y = gdf.centroid.y[0]\n", + "\n", + " map_ = folium.Map(location=[y, x], tiles=\"OpenStreetMap\")\n", + "\n", + " for i, point in gdf.iterrows():\n", + "\n", + " tooltip = f\"placekey: {point['placekey']}
<br>location_name: {point['location_name']}<br>brands: {point['brands']}<br>street_address: {point['street_address']}<br>city: {point['city']}<br>region: {point['region']}<br>postal_code: {point['postal_code']}<br>visitor_type: {point['visitor_type']}<br>
avg_median_dist_from_home: {point['avg_median_dist_from_home']}\"\n", + "\n", + " folium.Circle(\n", + " [point[\"geometry\"].y, point[\"geometry\"].x],\n", + " radius=40,\n", + " fill_color=\"blue\" if point[\"visitor_type\"] == \"tourist\" else \"red\",\n", + " color=\"blue\" if point[\"visitor_type\"] == \"tourist\" else \"red\",\n", + " fill_opacity=1,\n", + " tooltip=tooltip,\n", + " ).add_to(map_)\n", + "\n", + " map_.fit_bounds(folium_bounds)\n", + "\n", + " return map_" ] }, { @@ -2029,11 +2072,15 @@ ], "source": [ "WA_neighbs = (\n", - " spark.read.option('header', 'true').option('escape', \"\\\"\").csv('file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/seattle_neighborhoods.csv')\n", - " # transform the geometry column into a Geometry-type\n", - " .withColumn('geometry', f.expr(\"ST_GeomFromWkt(geometry)\"))\n", + " spark.read.option(\"header\", \"true\")\n", + " .option(\"escape\", '\"')\n", + " .csv(\n", + " \"file:///media/hdd1/code/sigspatial-2021-cafe-analysis/data/seattle_neighborhoods.csv\"\n", + " )\n", + " # transform the geometry column into a Geometry-type\n", + " .withColumn(\"geometry\", f.expr(\"ST_GeomFromWkt(geometry)\"))\n", ")\n", - "WA_neighbs.createOrReplaceTempView('WA_neighbs')\n", + "WA_neighbs.createOrReplaceTempView(\"WA_neighbs\")\n", "\n", "WA_neighbs.limit(10).toPandas().head()" ] @@ -2051,8 +2098,13 @@ }, "outputs": [], "source": [ - "cafes_geo = cafes_latest.withColumn('cafe_geometry', f.expr(\"ST_Point(CAST(longitude AS Decimal(24,20)), CAST(latitude AS Decimal(24,20)))\")).select('placekey', 'cafe_geometry')\n", - "cafes_geo.createOrReplaceTempView('cafes_geo')" + "cafes_geo = cafes_latest.withColumn(\n", + " \"cafe_geometry\",\n", + " f.expr(\n", + " \"ST_Point(CAST(longitude AS Decimal(24,20)), CAST(latitude AS Decimal(24,20)))\"\n", + " ),\n", + ").select(\"placekey\", \"cafe_geometry\")\n", + "cafes_geo.createOrReplaceTempView(\"cafes_geo\")" ] }, { @@ -2069,11 +2121,11 @@ "outputs": [], "source": [ "# perform a spatial join\n", - "q = '''\n", + "q = \"\"\"\n", "SELECT cafes_geo.placekey, WA_neighbs.S_HOOD as neighborhood, WA_neighbs.geometry\n", "FROM WA_neighbs, cafes_geo\n", "WHERE ST_Intersects(WA_neighbs.geometry, cafes_geo.cafe_geometry)\n", - "'''\n", + "\"\"\"\n", "\n", "cafe_neighb_join = spark.sql(q)" ] @@ -2199,13 +2251,12 @@ "source": [ "# add the visit and visitor counts and aggregate up to the neighborhood\n", "neighborhood_agg = (\n", - " cafe_neighb_join\n", - " .join(total_visits, 'placekey')\n", - " .groupBy('neighborhood', f.col('geometry').cast('string').alias('geometry'))\n", - " .agg(\n", - " f.sum('total_visits').alias('total_visits'),\n", - " f.sum('total_visitors').alias('total_visitors')\n", - " )\n", + " cafe_neighb_join.join(total_visits, \"placekey\")\n", + " .groupBy(\"neighborhood\", f.col(\"geometry\").cast(\"string\").alias(\"geometry\"))\n", + " .agg(\n", + " f.sum(\"total_visits\").alias(\"total_visits\"),\n", + " f.sum(\"total_visitors\").alias(\"total_visitors\"),\n", + " )\n", ")" ] }, @@ -2231,11 +2282,12 @@ } ], "source": [ - "neighbs_gdf = (\n", - " neighborhood_agg\n", - " .toPandas()\n", - ")\n", - "neighbs_gdf = gpd.GeoDataFrame(neighbs_gdf, geometry = gpd.GeoSeries.from_wkt(neighbs_gdf['geometry']), crs = 'EPSG:4326')" + "neighbs_gdf = neighborhood_agg.toPandas()\n", + "neighbs_gdf = gpd.GeoDataFrame(\n", + " neighbs_gdf,\n", + " geometry=gpd.GeoSeries.from_wkt(neighbs_gdf[\"geometry\"]),\n", + " crs=\"EPSG:4326\",\n", + ")" ] }, { @@ -2252,38 +2304,37 @@ "outputs": [], "source": [ "def 
map_neighbs(gdf):\n", - " \n", - " # map bounds\n", - " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", - " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", - " folium_bounds = [sw, ne]\n", - " \n", - " # map\n", - " x = gdf.centroid.x[0]\n", - " y = gdf.centroid.y[0]\n", - " \n", - " map_ = folium.Map(\n", - " location = [y, x],\n", - " tiles = \"OpenStreetMap\"\n", - " )\n", - " \n", - " gdf['percentile'] = pd.qcut(gdf['total_visits'], 100, labels=False) / 100\n", - " \n", - " folium.GeoJson(\n", - " gdf[['neighborhood', 'total_visits', 'total_visitors', 'percentile', 'geometry']],\n", - " style_function = lambda x: {\n", - " 'weight':0,\n", - " 'color':'blue',\n", - " 'fillOpacity': x['properties']['percentile']\n", - " },\n", - " tooltip = folium.features.GeoJsonTooltip(\n", - " fields = ['neighborhood', 'total_visits', 'total_visitors', 'percentile']\n", - " )\n", + "\n", + " # map bounds\n", + " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", + " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", + " folium_bounds = [sw, ne]\n", + "\n", + " # map\n", + " x = gdf.centroid.x[0]\n", + " y = gdf.centroid.y[0]\n", + "\n", + " map_ = folium.Map(location=[y, x], tiles=\"OpenStreetMap\")\n", + "\n", + " gdf[\"percentile\"] = pd.qcut(gdf[\"total_visits\"], 100, labels=False) / 100\n", + "\n", + " folium.GeoJson(\n", + " gdf[\n", + " [\"neighborhood\", \"total_visits\", \"total_visitors\", \"percentile\", \"geometry\"]\n", + " ],\n", + " style_function=lambda x: {\n", + " \"weight\": 0,\n", + " \"color\": \"blue\",\n", + " \"fillOpacity\": x[\"properties\"][\"percentile\"],\n", + " },\n", + " tooltip=folium.features.GeoJsonTooltip(\n", + " fields=[\"neighborhood\", \"total_visits\", \"total_visitors\", \"percentile\"]\n", + " ),\n", " ).add_to(map_)\n", "\n", - " map_.fit_bounds(folium_bounds) \n", - " \n", - " return map_" + " map_.fit_bounds(folium_bounds)\n", + "\n", + " return map_" ] }, { @@ -2356,13 +2407,12 @@ "outputs": [], "source": [ "home_loc_most_cafe_visitors = (\n", - " sample\n", - " .select(f.explode('visitor_home_cbgs'))\n", - " .withColumnRenamed('key', 'cbg')\n", - " .withColumnRenamed('value', 'visitors')\n", - " .groupBy('cbg')\n", - " .agg(f.sum('visitors').alias('visitors'))\n", - " .orderBy('visitors', ascending = False)\n", + " sample.select(f.explode(\"visitor_home_cbgs\"))\n", + " .withColumnRenamed(\"key\", \"cbg\")\n", + " .withColumnRenamed(\"value\", \"visitors\")\n", + " .groupBy(\"cbg\")\n", + " .agg(f.sum(\"visitors\").alias(\"visitors\"))\n", + " .orderBy(\"visitors\", ascending=False)\n", ")" ] }, @@ -2498,15 +2548,15 @@ "source": [ "# Map of the top 1000 CBGs in terms of visitors' origins\n", "home_loc_gdf = (\n", - " home_loc_most_cafe_visitors\n", - " .limit(1000)\n", - " .join(\n", - " WA_cbgs.select('cbg', 'cbg_polygon_geometry'),\n", - " 'cbg'\n", - " )\n", - " .toPandas()\n", + " home_loc_most_cafe_visitors.limit(1000)\n", + " .join(WA_cbgs.select(\"cbg\", \"cbg_polygon_geometry\"), \"cbg\")\n", + " .toPandas()\n", ")\n", - "home_loc_gdf = gpd.GeoDataFrame(home_loc_gdf, geometry = gpd.GeoSeries.from_wkt(home_loc_gdf['cbg_polygon_geometry']), crs = 'EPSG:4326')" + "home_loc_gdf = gpd.GeoDataFrame(\n", + " home_loc_gdf,\n", + " geometry=gpd.GeoSeries.from_wkt(home_loc_gdf[\"cbg_polygon_geometry\"]),\n", + " crs=\"EPSG:4326\",\n", + ")" ] }, { @@ -2523,38 +2573,33 @@ "outputs": [], "source": [ "def map_cbgs(gdf):\n", - " \n", - " # map bounds\n", - " sw = [gdf.unary_union.bounds[1], 
gdf.unary_union.bounds[0]]\n", - " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", - " folium_bounds = [sw, ne]\n", - " \n", - " # map\n", - " x = gdf.centroid.x[0]\n", - " y = gdf.centroid.y[0]\n", - " \n", - " map_ = folium.Map(\n", - " location = [y, x],\n", - " tiles = \"OpenStreetMap\"\n", - " )\n", - " \n", - " gdf['quantile'] = pd.qcut(gdf['visitors'], 100, labels=False) / 100\n", - " \n", - " folium.GeoJson(\n", - " gdf[['cbg', 'visitors', 'geometry', 'quantile']],\n", - " style_function = lambda x: {\n", - " 'weight':0,\n", - " 'color':'blue',\n", - " 'fillOpacity': x['properties']['quantile']\n", - " },\n", - " tooltip = folium.features.GeoJsonTooltip(\n", - " fields = ['cbg', 'visitors', 'quantile']\n", - " )\n", + "\n", + " # map bounds\n", + " sw = [gdf.unary_union.bounds[1], gdf.unary_union.bounds[0]]\n", + " ne = [gdf.unary_union.bounds[3], gdf.unary_union.bounds[2]]\n", + " folium_bounds = [sw, ne]\n", + "\n", + " # map\n", + " x = gdf.centroid.x[0]\n", + " y = gdf.centroid.y[0]\n", + "\n", + " map_ = folium.Map(location=[y, x], tiles=\"OpenStreetMap\")\n", + "\n", + " gdf[\"quantile\"] = pd.qcut(gdf[\"visitors\"], 100, labels=False) / 100\n", + "\n", + " folium.GeoJson(\n", + " gdf[[\"cbg\", \"visitors\", \"geometry\", \"quantile\"]],\n", + " style_function=lambda x: {\n", + " \"weight\": 0,\n", + " \"color\": \"blue\",\n", + " \"fillOpacity\": x[\"properties\"][\"quantile\"],\n", + " },\n", + " tooltip=folium.features.GeoJsonTooltip(fields=[\"cbg\", \"visitors\", \"quantile\"]),\n", " ).add_to(map_)\n", "\n", - " map_.fit_bounds(folium_bounds) \n", - " \n", - " return map_" + " map_.fit_bounds(folium_bounds)\n", + "\n", + " return map_" ] }, { diff --git a/docs/usecases/utilities.py b/docs/usecases/utilities.py index 6e06bbb778..047e75f3af 100644 --- a/docs/usecases/utilities.py +++ b/docs/usecases/utilities.py @@ -15,99 +15,142 @@ # specific language governing permissions and limitations # under the License. 
+ def getConfig(): - config = {'version': 'v1', - 'config': {'visState': {'filters': [], - 'layers': [{'id': 'ikzru0t', - 'type': 'geojson', - 'config': {'dataId': 'AirportCount', - 'label': 'AirportCount', - 'color': [218, 112, 191], - 'highlightColor': [252, 242, 26, 255], - 'columns': {'geojson': 'geometry'}, - 'isVisible': True, - 'visConfig': {'opacity': 0.8, - 'strokeOpacity': 0.8, - 'thickness': 0.5, - 'strokeColor': [18, 92, 119], - 'colorRange': {'name': 'Uber Viz Sequential 6', - 'type': 'sequential', - 'category': 'Uber', - 'colors': ['#E6FAFA', - '#C1E5E6', - '#9DD0D4', - '#75BBC1', - '#4BA7AF', - '#00939C', - '#108188', - '#0E7077']}, - 'strokeColorRange': {'name': 'Global Warming', - 'type': 'sequential', - 'category': 'Uber', - 'colors': ['#5A1846', - '#900C3F', - '#C70039', - '#E3611C', - '#F1920E', - '#FFC300']}, - 'radius': 10, - 'sizeRange': [0, 10], - 'radiusRange': [0, 50], - 'heightRange': [0, 500], - 'elevationScale': 5, - 'enableElevationZoomFactor': True, - 'stroked': False, - 'filled': True, - 'enable3d': False, - 'wireframe': False}, - 'hidden': False, - 'textLabel': [{'field': None, - 'color': [255, 255, 255], - 'size': 18, - 'offset': [0, 0], - 'anchor': 'start', - 'alignment': 'center'}]}, - 'visualChannels': {'colorField': {'name': 'AirportCount', - 'type': 'integer'}, - 'colorScale': 'quantize', - 'strokeColorField': None, - 'strokeColorScale': 'quantile', - 'sizeField': None, - 'sizeScale': 'linear', - 'heightField': None, - 'heightScale': 'linear', - 'radiusField': None, - 'radiusScale': 'linear'}}], - 'interactionConfig': {'tooltip': {'fieldsToShow': {'AirportCount': [{'name': 'NAME_EN', - 'format': None}, - {'name': 'AirportCount', 'format': None}]}, - 'compareMode': False, - 'compareType': 'absolute', - 'enabled': True}, - 'brush': {'size': 0.5, 'enabled': False}, - 'geocoder': {'enabled': False}, - 'coordinate': {'enabled': False}}, - 'layerBlending': 'normal', - 'splitMaps': [], - 'animationConfig': {'currentTime': None, 'speed': 1}}, - 'mapState': {'bearing': 0, - 'dragRotate': False, - 'latitude': 56.422456606624316, - 'longitude': 9.778836615231771, - 'pitch': 0, - 'zoom': 0.4214991225736964, - 'isSplit': False}, - 'mapStyle': {'styleType': 'dark', - 'topLayerGroups': {}, - 'visibleLayerGroups': {'label': True, - 'road': True, - 'border': False, - 'building': True, - 'water': True, - 'land': True, - '3d building': False}, - 'threeDBuildingColor': [9.665468314072013, - 17.18305478057247, - 31.1442867897876], - 'mapStyles': {}}}} + config = { + "version": "v1", + "config": { + "visState": { + "filters": [], + "layers": [ + { + "id": "ikzru0t", + "type": "geojson", + "config": { + "dataId": "AirportCount", + "label": "AirportCount", + "color": [218, 112, 191], + "highlightColor": [252, 242, 26, 255], + "columns": {"geojson": "geometry"}, + "isVisible": True, + "visConfig": { + "opacity": 0.8, + "strokeOpacity": 0.8, + "thickness": 0.5, + "strokeColor": [18, 92, 119], + "colorRange": { + "name": "Uber Viz Sequential 6", + "type": "sequential", + "category": "Uber", + "colors": [ + "#E6FAFA", + "#C1E5E6", + "#9DD0D4", + "#75BBC1", + "#4BA7AF", + "#00939C", + "#108188", + "#0E7077", + ], + }, + "strokeColorRange": { + "name": "Global Warming", + "type": "sequential", + "category": "Uber", + "colors": [ + "#5A1846", + "#900C3F", + "#C70039", + "#E3611C", + "#F1920E", + "#FFC300", + ], + }, + "radius": 10, + "sizeRange": [0, 10], + "radiusRange": [0, 50], + "heightRange": [0, 500], + "elevationScale": 5, + "enableElevationZoomFactor": True, + "stroked": 
False, + "filled": True, + "enable3d": False, + "wireframe": False, + }, + "hidden": False, + "textLabel": [ + { + "field": None, + "color": [255, 255, 255], + "size": 18, + "offset": [0, 0], + "anchor": "start", + "alignment": "center", + } + ], + }, + "visualChannels": { + "colorField": {"name": "AirportCount", "type": "integer"}, + "colorScale": "quantize", + "strokeColorField": None, + "strokeColorScale": "quantile", + "sizeField": None, + "sizeScale": "linear", + "heightField": None, + "heightScale": "linear", + "radiusField": None, + "radiusScale": "linear", + }, + } + ], + "interactionConfig": { + "tooltip": { + "fieldsToShow": { + "AirportCount": [ + {"name": "NAME_EN", "format": None}, + {"name": "AirportCount", "format": None}, + ] + }, + "compareMode": False, + "compareType": "absolute", + "enabled": True, + }, + "brush": {"size": 0.5, "enabled": False}, + "geocoder": {"enabled": False}, + "coordinate": {"enabled": False}, + }, + "layerBlending": "normal", + "splitMaps": [], + "animationConfig": {"currentTime": None, "speed": 1}, + }, + "mapState": { + "bearing": 0, + "dragRotate": False, + "latitude": 56.422456606624316, + "longitude": 9.778836615231771, + "pitch": 0, + "zoom": 0.4214991225736964, + "isSplit": False, + }, + "mapStyle": { + "styleType": "dark", + "topLayerGroups": {}, + "visibleLayerGroups": { + "label": True, + "road": True, + "border": False, + "building": True, + "water": True, + "land": True, + "3d building": False, + }, + "threeDBuildingColor": [ + 9.665468314072013, + 17.18305478057247, + 31.1442867897876, + ], + "mapStyles": {}, + }, + }, + } return config diff --git a/python/sedona/core/SpatialRDD/__init__.py b/python/sedona/core/SpatialRDD/__init__.py index e51e5470e7..c39df78353 100644 --- a/python/sedona/core/SpatialRDD/__init__.py +++ b/python/sedona/core/SpatialRDD/__init__.py @@ -24,5 +24,10 @@ __all__ = [ - "PolygonRDD", "PointRDD", "CircleRDD", "LineStringRDD", "RectangleRDD", "SpatialRDD" + "PolygonRDD", + "PointRDD", + "CircleRDD", + "LineStringRDD", + "RectangleRDD", + "SpatialRDD", ] diff --git a/python/sedona/core/SpatialRDD/circle_rdd.py b/python/sedona/core/SpatialRDD/circle_rdd.py index ea98534d6f..48142a52a8 100644 --- a/python/sedona/core/SpatialRDD/circle_rdd.py +++ b/python/sedona/core/SpatialRDD/circle_rdd.py @@ -29,34 +29,35 @@ def __init__(self, spatialRDD: SpatialRDD, Radius: float): :param Radius: float """ super()._do_init(spatialRDD._sc) - self._srdd = self._jvm_spatial_rdd( - spatialRDD._srdd, - Radius - ) + self._srdd = self._jvm_spatial_rdd(spatialRDD._srdd, Radius) - def getCenterPointAsSpatialRDD(self) -> 'PointRDD': + def getCenterPointAsSpatialRDD(self) -> "PointRDD": from sedona.core.SpatialRDD import PointRDD + srdd = self._srdd.getCenterPointAsSpatialRDD() point_rdd = PointRDD() point_rdd.set_srdd(srdd) return point_rdd - def getCenterPolygonAsSpatialRDD(self) -> 'PolygonRDD': + def getCenterPolygonAsSpatialRDD(self) -> "PolygonRDD": from sedona.core.SpatialRDD import PolygonRDD + srdd = self._srdd.getCenterPolygonAsSpatialRDD() polygon_rdd = PolygonRDD() polygon_rdd.set_srdd(srdd) return polygon_rdd - def getCenterLineStringRDDAsSpatialRDD(self) -> 'LineStringRDD': + def getCenterLineStringRDDAsSpatialRDD(self) -> "LineStringRDD": from sedona.core.SpatialRDD import LineStringRDD + srdd = self._srdd.getCenterPolygonAsSpatialRDD() linestring_rdd = LineStringRDD() linestring_rdd.set_srdd(srdd) return linestring_rdd - def getCenterRectangleRDDAsSpatialRDD(self) -> 'RectangleRDD': + def 
getCenterRectangleRDDAsSpatialRDD(self) -> "RectangleRDD": from sedona.core.SpatialRDD import RectangleRDD + srdd = self._srdd.getCenterLineStringRDDAsSpatialRDD() rectangle_rdd = RectangleRDD() rectangle_rdd.set_srdd(srdd) diff --git a/python/sedona/core/SpatialRDD/linestring_rdd.py b/python/sedona/core/SpatialRDD/linestring_rdd.py index 0bd720f56f..63df1587cb 100644 --- a/python/sedona/core/SpatialRDD/linestring_rdd.py +++ b/python/sedona/core/SpatialRDD/linestring_rdd.py @@ -30,7 +30,9 @@ class LineStringRDD(SpatialRDD, metaclass=MultipleMeta): def __init__(self, rdd: RDD): super().__init__(rdd.ctx) - spatial_rdd = PythonRddToJavaRDDAdapter(self._jvm).deserialize_to_linestring_raw_rdd(rdd._jrdd) + spatial_rdd = PythonRddToJavaRDDAdapter( + self._jvm + ).deserialize_to_linestring_raw_rdd(rdd._jrdd) srdd = self._jvm_spatial_rdd(spatial_rdd) self._srdd = srdd @@ -48,8 +50,16 @@ def __init__(self, rawSpatialRDD: JvmSpatialRDD): jsrdd = rawSpatialRDD.jsrdd self._srdd = self._jvm_spatial_rdd(jsrdd) - def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: int, endOffset: int, - splitter: FileDataSplitter, carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + startOffset: int, + endOffset: int, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext instance @@ -70,11 +80,18 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: endOffset, jvm_splitter, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: int, endOffset: int, - splitter: FileDataSplitter, carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + startOffset: int, + endOffset: int, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext instance @@ -96,8 +113,14 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: carryInputData, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, carryInputData: bool, - partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext instance @@ -110,15 +133,16 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil super().__init__(sparkContext) jvm_splitter = FileSplitterJvm(self._jvm, splitter).jvm_instance self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter, - carryInputData, - partitions + self._jsc, InputLocation, jvm_splitter, carryInputData, partitions ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext instance @@ -130,10 +154,7 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil super().__init__(sparkContext) jvm_splitter = FileSplitterJvm(self._jvm, splitter).jvm_instance self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter, - carryInputData + self._jsc, InputLocation, jvm_splitter, carryInputData ) @property @@ -146,6 +167,7 @@ def _jvm_spatial_rdd(self): def 
MinimumBoundingRectangle(self): from sedona.core.SpatialRDD import RectangleRDD + rectangle_rdd = RectangleRDD() srdd = self._srdd.MinimumBoundingRectangle() diff --git a/python/sedona/core/SpatialRDD/point_rdd.py b/python/sedona/core/SpatialRDD/point_rdd.py index 658e59217b..dbe401db6d 100644 --- a/python/sedona/core/SpatialRDD/point_rdd.py +++ b/python/sedona/core/SpatialRDD/point_rdd.py @@ -33,7 +33,9 @@ def __init__(self, rdd: RDD): """ super().__init__(rdd.ctx) - spatial_rdd = PythonRddToJavaRDDAdapter(self._jvm).deserialize_to_point_raw_rdd(rdd._jrdd) + spatial_rdd = PythonRddToJavaRDDAdapter(self._jvm).deserialize_to_point_raw_rdd( + rdd._jrdd + ) srdd = self._jvm_spatial_rdd(spatial_rdd) self._srdd = srdd @@ -51,8 +53,15 @@ def __init__(self, rawSpatialRDD: JvmSpatialRDD): jsrdd = rawSpatialRDD.jsrdd self._srdd = self._jvm_spatial_rdd(jsrdd) - def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, splitter: FileDataSplitter, - carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + Offset: int, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext instance @@ -71,11 +80,17 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, Offset, jvm_splitter, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, splitter: FileDataSplitter, - carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + Offset: int, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext instance @@ -87,15 +102,17 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, super().__init__(sparkContext) jvm_splitter = FileSplitterJvm(self._jvm, splitter).jvm_instance self._srdd = self._jvm_spatial_rdd( - sparkContext._jsc, - InputLocation, - Offset, - jvm_splitter, - carryInputData + sparkContext._jsc, InputLocation, Offset, jvm_splitter, carryInputData ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, carryInputData: bool, - partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext instance @@ -107,15 +124,16 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil super().__init__(sparkContext) jvm_splitter = FileSplitterJvm(self._jvm, splitter).jvm_instance self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter, - carryInputData, - partitions + self._jsc, InputLocation, jvm_splitter, carryInputData, partitions ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext instance @@ -127,10 +145,7 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil super().__init__(sparkContext) jvm_splitter = FileSplitterJvm(self._jvm, splitter).jvm_instance self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter, - carryInputData + self._jsc, InputLocation, jvm_splitter, carryInputData ) def MinimumBoundingRectangle(self): diff --git 
a/python/sedona/core/SpatialRDD/polygon_rdd.py b/python/sedona/core/SpatialRDD/polygon_rdd.py index c72ba9c72b..072776980c 100644 --- a/python/sedona/core/SpatialRDD/polygon_rdd.py +++ b/python/sedona/core/SpatialRDD/polygon_rdd.py @@ -28,7 +28,9 @@ class PolygonRDD(SpatialRDD, metaclass=MultipleMeta): def __init__(self, rdd: RDD): super().__init__(rdd.ctx) - spatial_rdd = PythonRddToJavaRDDAdapter(self._jvm).deserialize_to_polygon_raw_rdd(rdd._jrdd) + spatial_rdd = PythonRddToJavaRDDAdapter( + self._jvm + ).deserialize_to_polygon_raw_rdd(rdd._jrdd) srdd = self._jvm_spatial_rdd(spatial_rdd) self._srdd = srdd @@ -45,8 +47,16 @@ def __init__(self, rawSpatialRDD: JvmSpatialRDD): jsrdd = rawSpatialRDD.jsrdd self._srdd = self._jvm_spatial_rdd(jsrdd) - def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: int, endOffset: int, - splitter: FileDataSplitter, carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + startOffset: int, + endOffset: int, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext, the spark context @@ -67,11 +77,18 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: endOffset, jvm_splitter.jvm_instance, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: int, endOffset: int, - splitter: FileDataSplitter, carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + startOffset: int, + endOffset: int, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext, the spark context @@ -90,11 +107,17 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, startOffset: startOffset, endOffset, jvm_splitter.jvm_instance, - carryInputData + carryInputData, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext, the spark context @@ -112,11 +135,16 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil InputLocation, jvm_splitter.jvm_instance, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext, the spark context @@ -129,14 +157,12 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil jvm_splitter = FileSplitterJvm(self._jvm, splitter) self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter.jvm_instance, - carryInputData + self._jsc, InputLocation, jvm_splitter.jvm_instance, carryInputData ) def MinimumBoundingRectangle(self): from sedona.core.SpatialRDD import RectangleRDD + rectangle_rdd = RectangleRDD() srdd = self._srdd.MinimumBoundingRectangle() diff --git a/python/sedona/core/SpatialRDD/rectangle_rdd.py b/python/sedona/core/SpatialRDD/rectangle_rdd.py index 48d9a222f5..d5de40d5be 100644 --- a/python/sedona/core/SpatialRDD/rectangle_rdd.py +++ b/python/sedona/core/SpatialRDD/rectangle_rdd.py @@ -39,8 
+39,15 @@ def __init__(self, rawSpatialRDD: JvmSpatialRDD): jsrdd = rawSpatialRDD.jsrdd self._srdd = self._jvm_spatial_rdd(jsrdd) - def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, - splitter: FileDataSplitter, carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + Offset: int, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext, the spark context @@ -59,11 +66,17 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, Offset, jvm_splitter.jvm_instance, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, - splitter: FileDataSplitter, carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + Offset: int, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext, the spark context @@ -76,15 +89,17 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, Offset: int, jvm_splitter = FileSplitterJvm(self._jvm, splitter) self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - Offset, - jvm_splitter.jvm_instance, - carryInputData + self._jsc, InputLocation, Offset, jvm_splitter.jvm_instance, carryInputData ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool, partitions: int): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + partitions: int, + ): """ :param sparkContext: SparkContext, the spark context @@ -102,11 +117,16 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil InputLocation, jvm_splitter.jvm_instance, carryInputData, - partitions + partitions, ) - def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: FileDataSplitter, - carryInputData: bool): + def __init__( + self, + sparkContext: SparkContext, + InputLocation: str, + splitter: FileDataSplitter, + carryInputData: bool, + ): """ :param sparkContext: SparkContext, the spark context @@ -119,11 +139,9 @@ def __init__(self, sparkContext: SparkContext, InputLocation: str, splitter: Fil jvm_splitter = FileSplitterJvm(self._jvm, splitter) self._srdd = self._jvm_spatial_rdd( - self._jsc, - InputLocation, - jvm_splitter.jvm_instance, - carryInputData + self._jsc, InputLocation, jvm_splitter.jvm_instance, carryInputData ) + @property def _jvm_spatial_rdd(self): spatial_factory = SpatialRDDFactory(self._sc) @@ -133,4 +151,6 @@ def _jvm_spatial_rdd(self): return jvm_polygon_rdd def MinimumBoundingRectangle(self): - raise NotImplementedError("RectangleRDD has not MinimumBoundingRectangle method.") + raise NotImplementedError( + "RectangleRDD has not MinimumBoundingRectangle method." 
+ ) diff --git a/python/sedona/core/SpatialRDD/spatial_rdd.py b/python/sedona/core/SpatialRDD/spatial_rdd.py index cc0fbe6e41..5b3376437a 100644 --- a/python/sedona/core/SpatialRDD/spatial_rdd.py +++ b/python/sedona/core/SpatialRDD/spatial_rdd.py @@ -42,7 +42,7 @@ class SpatialPartitioner: jvm_partitioner = attr.ib() @classmethod - def from_java_class_name(cls, jvm_partitioner) -> 'SpatialPartitioner': + def from_java_class_name(cls, jvm_partitioner) -> "SpatialPartitioner": if jvm_partitioner is not None: jvm_full_name = jvm_partitioner.toString() full_class_name = jvm_full_name.split("@")[0] @@ -63,7 +63,9 @@ def saveAsObjectFile(self, location: str): self.jsrdd.saveAsObjectFile(location) def persist(self, storage_level: StorageLevel): - new_jsrdd = self.jsrdd.persist(JvmStorageLevel(self.sc._jvm, storage_level).jvm_instance) + new_jsrdd = self.jsrdd.persist( + JvmStorageLevel(self.sc._jvm, storage_level).jvm_instance + ) self.jsrdd = new_jsrdd def count(self): @@ -155,7 +157,9 @@ def boundaryEnvelope(self) -> Envelope: java_boundary_envelope = get_field(self._srdd, "boundaryEnvelope") return Envelope.from_jvm_instance(java_boundary_envelope) - def buildIndex(self, indexType: Union[str, IndexType], buildIndexOnSpatialPartitionedRDD: bool) -> bool: + def buildIndex( + self, indexType: Union[str, IndexType], buildIndexOnSpatialPartitionedRDD: bool + ) -> bool: """ :param indexType: @@ -171,8 +175,7 @@ def buildIndex(self, indexType: Union[str, IndexType], buildIndexOnSpatialPartit else: raise TypeError("indexType should be str or IndexType") return self._srdd.buildIndex( - index_type.jvm_instance, - buildIndexOnSpatialPartitionedRDD + index_type.jvm_instance, buildIndexOnSpatialPartitionedRDD ) else: raise AttributeError("Please run spatial partitioning before") @@ -224,12 +227,17 @@ def getRawSpatialRDD(self): :return: """ - serialized_spatial_rdd = SedonaPythonConverter(self._jvm).translate_spatial_rdd_to_python( - self._srdd.getRawSpatialRDD()) + serialized_spatial_rdd = SedonaPythonConverter( + self._jvm + ).translate_spatial_rdd_to_python(self._srdd.getRawSpatialRDD()) if not hasattr(self, "_raw_spatial_rdd"): RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path) - setattr(self, "_raw_spatial_rdd", RDD(serialized_spatial_rdd, self._sc, SedonaPickler())) + setattr( + self, + "_raw_spatial_rdd", + RDD(serialized_spatial_rdd, self._sc, SedonaPickler()), + ) else: self._raw_spatial_rdd._jrdd = serialized_spatial_rdd @@ -269,7 +277,10 @@ def grids(self) -> Optional[List[Envelope]]: if jvm_grids: number_of_grids = jvm_grids.size() - envelopes = [Envelope.from_jvm_instance(jvm_grids[index]) for index in range(number_of_grids)] + envelopes = [ + Envelope.from_jvm_instance(jvm_grids[index]) + for index in range(number_of_grids) + ] return envelopes else: @@ -338,7 +349,9 @@ def rawSpatialRDD(self, spatial_rdd): self._jvm = spatial_rdd._jvm self._spatial_partitioned = spatial_rdd._spatial_partitioned elif isinstance(spatial_rdd, RDD): - jrdd = JvmSedonaPythonConverter(self._jvm).translate_python_rdd_to_java(spatial_rdd._jrdd) + jrdd = JvmSedonaPythonConverter(self._jvm).translate_python_rdd_to_java( + spatial_rdd._jrdd + ) self._srdd.setRawSpatialRDD(jrdd) else: self._srdd.setRawSpatialRDD(spatial_rdd) @@ -387,18 +400,28 @@ def spatialPartitionedRDD(self): :return: """ - serialized_spatial_rdd = SedonaPythonConverter(self._jvm).translate_spatial_rdd_to_python( - get_field(self._srdd, "spatialPartitionedRDD")) + serialized_spatial_rdd = SedonaPythonConverter( + self._jvm + 
).translate_spatial_rdd_to_python( + get_field(self._srdd, "spatialPartitionedRDD") + ) if not hasattr(self, "_spatial_partitioned_rdd"): - setattr(self, "_spatial_partitioned_rdd", RDD(serialized_spatial_rdd, self._sc, SedonaPickler())) + setattr( + self, + "_spatial_partitioned_rdd", + RDD(serialized_spatial_rdd, self._sc, SedonaPickler()), + ) else: self._spatial_partitioned_rdd._jrdd = serialized_spatial_rdd return getattr(self, "_spatial_partitioned_rdd") - def spatialPartitioning(self, partitioning: Union[str, GridType, SpatialPartitioner, List[Envelope]], - num_partitions: Optional[int] = None) -> bool: + def spatialPartitioning( + self, + partitioning: Union[str, GridType, SpatialPartitioner, List[Envelope]], + num_partitions: Optional[int] = None, + ) -> bool: """ :param partitioning: partitioning type @@ -435,7 +458,11 @@ def get_srdd(self): return self._srdd def getRawJvmSpatialRDD(self) -> JvmSpatialRDD: - return JvmSpatialRDD(jsrdd=self._srdd.getRawSpatialRDD(), sc=self._sc, tp=SpatialType.from_str(self.name)) + return JvmSpatialRDD( + jsrdd=self._srdd.getRawSpatialRDD(), + sc=self._sc, + tp=SpatialType.from_str(self.name), + ) @property def rawJvmSpatialRDD(self) -> JvmSpatialRDD: @@ -444,7 +471,9 @@ def rawJvmSpatialRDD(self) -> JvmSpatialRDD: @rawJvmSpatialRDD.setter def rawJvmSpatialRDD(self, jsrdd_p: JvmSpatialRDD): if jsrdd_p.tp.value.lower() != self.name: - raise TypeError(f"value should be type {self.name} but {jsrdd_p.tp} was found") + raise TypeError( + f"value should be type {self.name} but {jsrdd_p.tp} was found" + ) self._sc = jsrdd_p.sc self._jvm = self._sc._jvm @@ -452,8 +481,10 @@ def rawJvmSpatialRDD(self, jsrdd_p: JvmSpatialRDD): self.setRawSpatialRDD(jsrdd_p.jsrdd) def getJvmSpatialPartitionedRDD(self) -> JvmSpatialRDD: - return JvmSpatialRDD(jsrdd=get_field( - self._srdd, "spatialPartitionedRDD"), sc=self._sc, tp=SpatialType.from_str(self.name) + return JvmSpatialRDD( + jsrdd=get_field(self._srdd, "spatialPartitionedRDD"), + sc=self._sc, + tp=SpatialType.from_str(self.name), ) @property @@ -463,7 +494,9 @@ def jvmSpatialPartitionedRDD(self) -> JvmSpatialRDD: @jvmSpatialPartitionedRDD.setter def jvmSpatialPartitionedRDD(self, jsrdd_p: JvmSpatialRDD): if jsrdd_p.tp.value.lower() != self.name: - raise TypeError(f"value should be type {self.name} but {jsrdd_p.tp} was found") + raise TypeError( + f"value should be type {self.name} but {jsrdd_p.tp} was found" + ) self._sc = jsrdd_p.sc self._jvm = self._sc._jvm diff --git a/python/sedona/core/enums/file_data_splitter.py b/python/sedona/core/enums/file_data_splitter.py index d0ecdbe41b..81b402b3f3 100644 --- a/python/sedona/core/enums/file_data_splitter.py +++ b/python/sedona/core/enums/file_data_splitter.py @@ -47,7 +47,11 @@ class FileSplitterJvm(JvmObject): splitter = attr.ib(type=FileDataSplitter) def _create_jvm_instance(self): - return self.jvm_splitter(self.splitter.value) if self.splitter.value is not None else None + return ( + self.jvm_splitter(self.splitter.value) + if self.splitter.value is not None + else None + ) @property @require(["FileDataSplitter"]) diff --git a/python/sedona/core/enums/grid_type.py b/python/sedona/core/enums/grid_type.py index b5d816b9e8..89deeee964 100644 --- a/python/sedona/core/enums/grid_type.py +++ b/python/sedona/core/enums/grid_type.py @@ -28,7 +28,7 @@ class GridType(Enum): KDBTREE = "KDBTREE" @classmethod - def from_str(cls, grid: str) -> 'GridType': + def from_str(cls, grid: str) -> "GridType": try: grid = getattr(cls, grid.upper()) except AttributeError: diff --git 
a/python/sedona/core/enums/index_type.py b/python/sedona/core/enums/index_type.py index 290f43a02b..3a00004d84 100644 --- a/python/sedona/core/enums/index_type.py +++ b/python/sedona/core/enums/index_type.py @@ -41,7 +41,11 @@ class IndexTypeJvm(JvmObject): index_type = attr.ib(type=IndexType) def _create_jvm_instance(self): - return self.jvm_index(self.index_type.value) if self.index_type.value is not None else None + return ( + self.jvm_index(self.index_type.value) + if self.index_type.value is not None + else None + ) @property @require(["FileDataSplitter"]) diff --git a/python/sedona/core/enums/spatial.py b/python/sedona/core/enums/spatial.py index f8c6fd52cd..da92e374fd 100644 --- a/python/sedona/core/enums/spatial.py +++ b/python/sedona/core/enums/spatial.py @@ -27,7 +27,7 @@ class SpatialType(Enum): CIRCLE = "CIRCLE" @classmethod - def from_str(cls, spatial: str) -> 'SpatialType': + def from_str(cls, spatial: str) -> "SpatialType": try: spatial = getattr(cls, spatial.upper()) except AttributeError: diff --git a/python/sedona/core/formatMapper/__init__.py b/python/sedona/core/formatMapper/__init__.py index 1292e27a94..0daadc2177 100644 --- a/python/sedona/core/formatMapper/__init__.py +++ b/python/sedona/core/formatMapper/__init__.py @@ -19,4 +19,4 @@ from .wkt_reader import WktReader from .wkb_reader import WkbReader -__all__ = ["GeoJsonReader", 'WktReader', 'WkbReader'] +__all__ = ["GeoJsonReader", "WktReader", "WkbReader"] diff --git a/python/sedona/core/formatMapper/disc_utils.py b/python/sedona/core/formatMapper/disc_utils.py index 2fd3872030..34d6d17482 100644 --- a/python/sedona/core/formatMapper/disc_utils.py +++ b/python/sedona/core/formatMapper/disc_utils.py @@ -71,7 +71,9 @@ class LineStringRDDDiscLoader(DiscLoader): def load(cls, sc: SparkContext, path: str) -> SpatialRDD: jvm = sc._jvm line_string_rdd = LineStringRDD() - srdd = SpatialObjectLoaderAdapter(jvm).load_line_string_spatial_rdd(sc._jsc, path) + srdd = SpatialObjectLoaderAdapter(jvm).load_line_string_spatial_rdd( + sc._jsc, path + ) line_string_rdd.set_srdd(srdd) return line_string_rdd @@ -99,7 +101,7 @@ class GeoType(Enum): GeoType.POINT: PointRDDDiscLoader, GeoType.POLYGON: PolygonRDDDiscLoader, GeoType.LINESTRING: LineStringRDDDiscLoader, - GeoType.GEOMETRY: SpatialRDDDiscLoader + GeoType.GEOMETRY: SpatialRDDDiscLoader, } diff --git a/python/sedona/core/formatMapper/geo_json_reader.py b/python/sedona/core/formatMapper/geo_json_reader.py index 8296b2b698..e0eeec695d 100644 --- a/python/sedona/core/formatMapper/geo_json_reader.py +++ b/python/sedona/core/formatMapper/geo_json_reader.py @@ -32,17 +32,20 @@ def readToGeometryRDD(cls, sc: SparkContext, inputPath: str) -> SpatialRDD: :return: SpatialRDD """ jvm = sc._jvm - srdd = jvm.GeoJsonReader.readToGeometryRDD( - sc._jsc, inputPath - ) + srdd = jvm.GeoJsonReader.readToGeometryRDD(sc._jsc, inputPath) spatial_rdd = SpatialRDD(sc) spatial_rdd.set_srdd(srdd) return spatial_rdd @classmethod - def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + def readToGeometryRDD( + cls, + sc: SparkContext, + inputPath: str, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param sc: SparkContext @@ -53,7 +56,10 @@ def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, allowInvalidGeometr """ jvm = sc._jvm srdd = jvm.GeoJsonReader.readToGeometryRDD( - sc._jsc, inputPath, allowInvalidGeometries, 
skipSyntacticallyInvalidGeometries + sc._jsc, + inputPath, + allowInvalidGeometries, + skipSyntacticallyInvalidGeometries, ) spatial_rdd = SpatialRDD(sc) @@ -70,17 +76,19 @@ def readToGeometryRDD(cls, rawTextRDD: RDD) -> SpatialRDD: sc = rawTextRDD.ctx jvm = sc._jvm - srdd = jvm.GeoJsonReader.readToGeometryRDD( - rawTextRDD._jrdd - ) + srdd = jvm.GeoJsonReader.readToGeometryRDD(rawTextRDD._jrdd) spatial_rdd = SpatialRDD(sc) spatial_rdd.set_srdd(srdd) return spatial_rdd @classmethod - def readToGeometryRDD(cls, rawTextRDD: RDD, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + def readToGeometryRDD( + cls, + rawTextRDD: RDD, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param rawTextRDD: RDD diff --git a/python/sedona/core/formatMapper/geo_reader.py b/python/sedona/core/formatMapper/geo_reader.py index d60b8053d0..1cfb373a97 100644 --- a/python/sedona/core/formatMapper/geo_reader.py +++ b/python/sedona/core/formatMapper/geo_reader.py @@ -28,4 +28,5 @@ class GeoDataReader(metaclass=MultipleMeta): @abc.abstractmethod def readToGeometryRDD(cls, *args, **kwargs): raise NotImplementedError( - f"Instance of the class {cls.__class__.__name__} has to implement method readToGeometryRDD") + f"Instance of the class {cls.__class__.__name__} has to implement method readToGeometryRDD" + ) diff --git a/python/sedona/core/formatMapper/shapefileParser/shape_file_reader.py b/python/sedona/core/formatMapper/shapefileParser/shape_file_reader.py index 18d807b7a3..48efe0eba0 100644 --- a/python/sedona/core/formatMapper/shapefileParser/shape_file_reader.py +++ b/python/sedona/core/formatMapper/shapefileParser/shape_file_reader.py @@ -37,10 +37,7 @@ def readToGeometryRDD(cls, sc: SparkContext, inputPath: str) -> SpatialRDD: """ jvm = sc._jvm jsc = sc._jsc - srdd = jvm.ShapefileReader.readToGeometryRDD( - jsc, - inputPath - ) + srdd = jvm.ShapefileReader.readToGeometryRDD(jsc, inputPath) spatial_rdd = SpatialRDD(sc=sc) spatial_rdd.set_srdd(srdd) @@ -56,10 +53,7 @@ def readToPolygonRDD(cls, sc: SparkContext, inputPath: str) -> PolygonRDD: """ jvm = sc._jvm jsc = sc._jsc - srdd = jvm.ShapefileReader.readToPolygonRDD( - jsc, - inputPath - ) + srdd = jvm.ShapefileReader.readToPolygonRDD(jsc, inputPath) spatial_rdd = PolygonRDD() spatial_rdd.set_srdd(srdd) return spatial_rdd @@ -74,10 +68,7 @@ def readToPointRDD(cls, sc: SparkContext, inputPath: str) -> PointRDD: """ jvm = sc._jvm jsc = sc._jsc - srdd = jvm.ShapefileReader.readToPointRDD( - jsc, - inputPath - ) + srdd = jvm.ShapefileReader.readToPointRDD(jsc, inputPath) spatial_rdd = PointRDD() spatial_rdd.set_srdd(srdd) return spatial_rdd @@ -92,10 +83,7 @@ def readToLineStringRDD(cls, sc: SparkContext, inputPath: str) -> LineStringRDD: """ jvm = sc._jvm jsc = sc._jsc - srdd = jvm.ShapefileReader.readToLineStringRDD( - jsc, - inputPath - ) + srdd = jvm.ShapefileReader.readToLineStringRDD(jsc, inputPath) spatial_rdd = LineStringRDD() spatial_rdd.set_srdd(srdd) return spatial_rdd diff --git a/python/sedona/core/formatMapper/wkb_reader.py b/python/sedona/core/formatMapper/wkb_reader.py index 3e9c9f67d1..23cc27c9dd 100644 --- a/python/sedona/core/formatMapper/wkb_reader.py +++ b/python/sedona/core/formatMapper/wkb_reader.py @@ -25,8 +25,14 @@ class WkbReader(GeoDataReader, metaclass=MultipleMeta): @classmethod - def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, wkbColumn: int, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + 
def readToGeometryRDD( + cls, + sc: SparkContext, + inputPath: str, + wkbColumn: int, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param sc: @@ -38,14 +44,24 @@ def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, wkbColumn: int, all """ jvm = sc._jvm spatial_rdd = SpatialRDD(sc) - srdd = jvm.WkbReader.readToGeometryRDD(sc._jsc, inputPath, wkbColumn, allowInvalidGeometries, - skipSyntacticallyInvalidGeometries) + srdd = jvm.WkbReader.readToGeometryRDD( + sc._jsc, + inputPath, + wkbColumn, + allowInvalidGeometries, + skipSyntacticallyInvalidGeometries, + ) spatial_rdd.set_srdd(srdd) return spatial_rdd @classmethod - def readToGeometryRDD(cls, rawTextRDD: RDD, wkbColumn: int, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + def readToGeometryRDD( + cls, + rawTextRDD: RDD, + wkbColumn: int, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param rawTextRDD: @@ -58,8 +74,12 @@ def readToGeometryRDD(cls, rawTextRDD: RDD, wkbColumn: int, allowInvalidGeometri jvm = sc._jvm spatial_rdd = SpatialRDD(sc) - srdd = jvm.WkbReader.readToGeometryRDD(rawTextRDD._jrdd, wkbColumn, allowInvalidGeometries, - skipSyntacticallyInvalidGeometries) + srdd = jvm.WkbReader.readToGeometryRDD( + rawTextRDD._jrdd, + wkbColumn, + allowInvalidGeometries, + skipSyntacticallyInvalidGeometries, + ) spatial_rdd.set_srdd(srdd) return spatial_rdd diff --git a/python/sedona/core/formatMapper/wkt_reader.py b/python/sedona/core/formatMapper/wkt_reader.py index e920ec682c..499bbbda59 100644 --- a/python/sedona/core/formatMapper/wkt_reader.py +++ b/python/sedona/core/formatMapper/wkt_reader.py @@ -25,8 +25,14 @@ class WktReader(GeoDataReader, metaclass=MultipleMeta): @classmethod - def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, wktColumn: int, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + def readToGeometryRDD( + cls, + sc: SparkContext, + inputPath: str, + wktColumn: int, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param sc: SparkContext @@ -37,16 +43,26 @@ def readToGeometryRDD(cls, sc: SparkContext, inputPath: str, wktColumn: int, all :return: """ jvm = sc._jvm - srdd = jvm.WktReader.readToGeometryRDD(sc._jsc, inputPath, wktColumn, allowInvalidGeometries, - skipSyntacticallyInvalidGeometries) + srdd = jvm.WktReader.readToGeometryRDD( + sc._jsc, + inputPath, + wktColumn, + allowInvalidGeometries, + skipSyntacticallyInvalidGeometries, + ) spatial_rdd = SpatialRDD(sc) spatial_rdd.set_srdd(srdd) return spatial_rdd @classmethod - def readToGeometryRDD(cls, rawTextRDD: RDD, wktColumn: int, allowInvalidGeometries: bool, - skipSyntacticallyInvalidGeometries: bool) -> SpatialRDD: + def readToGeometryRDD( + cls, + rawTextRDD: RDD, + wktColumn: int, + allowInvalidGeometries: bool, + skipSyntacticallyInvalidGeometries: bool, + ) -> SpatialRDD: """ :param rawTextRDD: RDD @@ -58,7 +74,10 @@ def readToGeometryRDD(cls, rawTextRDD: RDD, wktColumn: int, allowInvalidGeometri sc = rawTextRDD.ctx jvm = sc._jvm srdd = jvm.WktReader.readToGeometryRDD( - rawTextRDD._jrdd, wktColumn, allowInvalidGeometries, skipSyntacticallyInvalidGeometries + rawTextRDD._jrdd, + wktColumn, + allowInvalidGeometries, + skipSyntacticallyInvalidGeometries, ) spatial_rdd = SpatialRDD(sc) spatial_rdd.set_srdd(srdd) diff --git a/python/sedona/core/geom/circle.py b/python/sedona/core/geom/circle.py index 
ed091ef38f..669259da53 100644 --- a/python/sedona/core/geom/circle.py +++ b/python/sedona/core/geom/circle.py @@ -18,7 +18,7 @@ import shapely -if shapely.__version__.startswith('2.'): +if shapely.__version__.startswith("2."): from .shapely2.circle import Circle else: from .shapely1.circle import Circle diff --git a/python/sedona/core/geom/envelope.py b/python/sedona/core/geom/envelope.py index 402770394a..db2ec5c4c7 100644 --- a/python/sedona/core/geom/envelope.py +++ b/python/sedona/core/geom/envelope.py @@ -18,7 +18,7 @@ import shapely -if shapely.__version__.startswith('2.'): +if shapely.__version__.startswith("2."): from .shapely2.envelope import Envelope else: from .shapely1.envelope import Envelope diff --git a/python/sedona/core/geom/shapely1/circle.py b/python/sedona/core/geom/shapely1/circle.py index d29b692559..787468f526 100644 --- a/python/sedona/core/geom/shapely1/circle.py +++ b/python/sedona/core/geom/shapely1/circle.py @@ -17,7 +17,14 @@ from math import sqrt -from shapely.geometry import Polygon, Point, LineString, MultiPoint, MultiPolygon, MultiLineString +from shapely.geometry import ( + Polygon, + Point, + LineString, + MultiPoint, + MultiPolygon, + MultiLineString, +) from shapely.geometry.base import BaseGeometry from sedona.core.geom.envelope import Envelope @@ -32,19 +39,23 @@ def __init__(self, centerGeometry: BaseGeometry, givenRadius: float): center_geometry_mbr = Envelope.from_shapely_geom(self.centerGeometry) self.centerPoint = self.centerPoint = Point( (center_geometry_mbr.minx + center_geometry_mbr.maxx) / 2.0, - (center_geometry_mbr.miny + center_geometry_mbr.maxy) / 2.0 + (center_geometry_mbr.miny + center_geometry_mbr.maxy) / 2.0, ) width = center_geometry_mbr.maxx - center_geometry_mbr.minx length = center_geometry_mbr.maxy - center_geometry_mbr.miny - center_geometry_internal_radius = sqrt(width ** 2 + length ** 2) / 2.0 - self.radius = givenRadius if givenRadius > center_geometry_internal_radius else center_geometry_internal_radius + center_geometry_internal_radius = sqrt(width**2 + length**2) / 2.0 + self.radius = ( + givenRadius + if givenRadius > center_geometry_internal_radius + else center_geometry_internal_radius + ) self.MBR = Envelope( self.centerPoint.x - self.radius, self.centerPoint.x + self.radius, self.centerPoint.y - self.radius, - self.centerPoint.y + self.radius + self.centerPoint.y + self.radius, ) super().__init__(self.centerPoint.buffer(self.radius)) @@ -61,13 +72,17 @@ def setRadius(self, givenRadius: float): center_geometry_mbr = Envelope.from_shapely_geom(self.centerGeometry) width = center_geometry_mbr.maxx - center_geometry_mbr.minx length = center_geometry_mbr.maxy - center_geometry_mbr.miny - center_geometry_internal_radius = sqrt(width ** 2 + length ** 2) / 2 - self.radius = givenRadius if givenRadius > center_geometry_internal_radius else center_geometry_internal_radius + center_geometry_internal_radius = sqrt(width**2 + length**2) / 2 + self.radius = ( + givenRadius + if givenRadius > center_geometry_internal_radius + else center_geometry_internal_radius + ) self.MBR = Envelope( self.centerPoint.x - self.radius, self.centerPoint.x + self.radius, self.centerPoint.y - self.radius, - self.centerPoint.y + self.radius + self.centerPoint.y + self.radius, ) def covers(self, other: BaseGeometry) -> bool: @@ -80,9 +95,13 @@ def covers(self, other: BaseGeometry) -> bool: elif isinstance(other, MultiPoint): return all([self.covers_point(point) for point in other]) elif isinstance(other, MultiPolygon): - return 
all([self.covers_linestring(polygon.exterior) for polygon in other.geoms]) + return all( + [self.covers_linestring(polygon.exterior) for polygon in other.geoms] + ) elif isinstance(other, MultiLineString): - return all([self.covers_linestring(linestring) for linestring in other.geoms]) + return all( + [self.covers_linestring(linestring) for linestring in other.geoms] + ) else: raise TypeError("Not supported") @@ -114,7 +133,12 @@ def _compute_envelope_internal(self): return self.MBR def __str__(self): - return "Circle of radius " + str(self.radius) + " around " + str(self.centerGeometry) + return ( + "Circle of radius " + + str(self.radius) + + " around " + + str(self.centerGeometry) + ) @property def __array_interface__(self): diff --git a/python/sedona/core/geom/shapely1/envelope.py b/python/sedona/core/geom/shapely1/envelope.py index a214c82771..bb16cbc30c 100644 --- a/python/sedona/core/geom/shapely1/envelope.py +++ b/python/sedona/core/geom/shapely1/envelope.py @@ -22,6 +22,7 @@ import math import pickle + class Envelope(Polygon): def __init__(self, minx=0, maxx=1, miny=0, maxy=1): @@ -29,27 +30,29 @@ def __init__(self, minx=0, maxx=1, miny=0, maxy=1): self.maxx = maxx self.miny = miny self.maxy = maxy - super().__init__([ - [self.minx, self.miny], - [self.minx, self.maxy], - [self.maxx, self.maxy], - [self.maxx, self.miny] - ]) + super().__init__( + [ + [self.minx, self.miny], + [self.minx, self.maxy], + [self.maxx, self.maxy], + [self.maxx, self.miny], + ] + ) def isClose(self, a, b) -> bool: return math.isclose(a, b, rel_tol=1e-9) def __eq__(self, other) -> bool: - return self.isClose(self.minx, other.minx) and\ - self.isClose(self.miny, other.miny) and\ - self.isClose(self.maxx, other.maxx) and\ - self.isClose(self.maxy, other.maxy) + return ( + self.isClose(self.minx, other.minx) + and self.isClose(self.miny, other.miny) + and self.isClose(self.maxx, other.maxx) + and self.isClose(self.maxy, other.maxy) + ) @require(["Envelope"]) def create_jvm_instance(self, jvm): - return jvm.Envelope( - self.minx, self.maxx, self.miny, self.maxy - ) + return jvm.Envelope(self.minx, self.maxx, self.miny, self.maxy) @classmethod def from_jvm_instance(cls, java_obj): @@ -62,6 +65,7 @@ def from_jvm_instance(cls, java_obj): def to_bytes(self): from sedona.utils.binary_parser import BinaryBuffer + bin_buffer = BinaryBuffer() bin_buffer.put_double(self.minx) bin_buffer.put_double(self.maxx) @@ -83,13 +87,16 @@ def from_shapely_geom(cls, geometry: BaseGeometry): return cls(min(x_coord), max(x_coord), min(y_coord), max(y_coord)) def __reduce__(self): - return (self.__class__, (), dict( - minx=self.minx, - maxx=self.maxx, - miny=self.miny, - maxy=self.maxy, - - )) + return ( + self.__class__, + (), + dict( + minx=self.minx, + maxx=self.maxx, + miny=self.miny, + maxy=self.maxy, + ), + ) def __getstate__(self): return dict( @@ -97,7 +104,6 @@ def __getstate__(self): maxx=self.maxx, miny=self.miny, maxy=self.maxy, - ) def __setstate__(self, state): diff --git a/python/sedona/core/geom/shapely2/circle.py b/python/sedona/core/geom/shapely2/circle.py index 6b8b3a082a..3147fc1f99 100644 --- a/python/sedona/core/geom/shapely2/circle.py +++ b/python/sedona/core/geom/shapely2/circle.py @@ -17,7 +17,15 @@ from math import sqrt -from shapely.geometry import Polygon, Point, LineString, MultiPoint, MultiPolygon, MultiLineString, box +from shapely.geometry import ( + Polygon, + Point, + LineString, + MultiPoint, + MultiPolygon, + MultiLineString, + box, +) from shapely.geometry.base import BaseGeometry from 
sedona.core.geom.envelope import Envelope @@ -67,9 +75,13 @@ def covers(self, other: BaseGeometry) -> bool: elif isinstance(other, MultiPoint): return all([self.covers_point(point) for point in other.geoms]) elif isinstance(other, MultiPolygon): - return all([self.covers_linestring(polygon.exterior) for polygon in other.geoms]) + return all( + [self.covers_linestring(polygon.exterior) for polygon in other.geoms] + ) elif isinstance(other, MultiLineString): - return all([self.covers_linestring(linestring) for linestring in other.geoms]) + return all( + [self.covers_linestring(linestring) for linestring in other.geoms] + ) else: raise TypeError("Not supported") @@ -103,4 +115,9 @@ def _compute_envelope_internal(self): return Envelope(minx, maxx, miny, maxy) def __str__(self): - return "Circle of radius " + str(self.getRadius()) + " around " + str(self.centerGeometry) + return ( + "Circle of radius " + + str(self.getRadius()) + + " around " + + str(self.centerGeometry) + ) diff --git a/python/sedona/core/geom/shapely2/envelope.py b/python/sedona/core/geom/shapely2/envelope.py index bf0c9dd30e..fb95434164 100644 --- a/python/sedona/core/geom/shapely2/envelope.py +++ b/python/sedona/core/geom/shapely2/envelope.py @@ -54,10 +54,12 @@ def isClose(self, a, b) -> bool: def __eq__(self, other) -> bool: minx, miny, maxx, maxy = self.bounds other_minx, other_miny, other_maxx, other_maxy = other.bounds - return self.isClose(minx, other_minx) and \ - self.isClose(miny, other_miny) and \ - self.isClose(maxx, other_maxx) and \ - self.isClose(maxy, other_maxy) + return ( + self.isClose(minx, other_minx) + and self.isClose(miny, other_miny) + and self.isClose(maxx, other_maxx) + and self.isClose(maxy, other_maxy) + ) @require(["Envelope"]) def create_jvm_instance(self, jvm): @@ -75,6 +77,7 @@ def from_jvm_instance(cls, java_obj): def to_bytes(self): from sedona.utils.binary_parser import BinaryBuffer + minx, miny, maxx, maxy = self.bounds bin_buffer = BinaryBuffer() bin_buffer.put_double(minx) @@ -108,6 +111,7 @@ class TmpEnvelopeForPickle: immutable. 
""" + def __init__(self, minx, maxx, miny, maxy): self.minx = minx self.maxx = maxx @@ -115,11 +119,7 @@ def __init__(self, minx, maxx, miny, maxy): self.maxy = maxy def __getstate__(self): - return dict( - minx=self.minx, - maxx=self.maxx, - miny=self.miny, - maxy=self.maxy) + return dict(minx=self.minx, maxx=self.maxx, miny=self.miny, maxy=self.maxy) def __setstate__(self, state): minx = state.get("minx", 0) diff --git a/python/sedona/core/jvm/config.py b/python/sedona/core/jvm/config.py index 96bc22ebea..d1fbc918d0 100644 --- a/python/sedona/core/jvm/config.py +++ b/python/sedona/core/jvm/config.py @@ -27,7 +27,8 @@ import inspect import warnings -string_types = (type(b''), type(u'')) +string_types = (type(b""), type("")) + def is_greater_or_equal_version(version_a: str, version_b: str) -> bool: if all([version_b, version_a]): @@ -92,7 +93,7 @@ def new_func1(*args, **kwargs): warnings.warn( fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return func1(*args, **kwargs) @@ -122,7 +123,7 @@ def new_func2(*args, **kwargs): warnings.warn( fmt2.format(name=func2.__name__), category=DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return func2(*args, **kwargs) diff --git a/python/sedona/core/jvm/translate.py b/python/sedona/core/jvm/translate.py index bddc5963dc..d291f8d824 100644 --- a/python/sedona/core/jvm/translate.py +++ b/python/sedona/core/jvm/translate.py @@ -30,7 +30,9 @@ def translate_spatial_pair_rdd_to_python(self, spatial_rdd): return self._jvm.PythonConverter.translateSpatialPairRDDToPython(spatial_rdd) def translate_spatial_pair_rdd_with_list_to_python(self, spatial_rdd): - return self._jvm.PythonConverter.translateSpatialPairRDDWithListToPython(spatial_rdd) + return self._jvm.PythonConverter.translateSpatialPairRDDWithListToPython( + spatial_rdd + ) def translate_python_rdd_to_java(self, java_rdd): return self._jvm.PythonConverter.translatePythonRDDToJava(java_rdd) @@ -103,13 +105,19 @@ def __init__(self, jvm): self._jvm = jvm def deserialize_to_point_raw_rdd(self, java_spatial_rdd): - return self._jvm.PythonRddToJavaRDDAdapter.deserializeToPointRawRDD(java_spatial_rdd) + return self._jvm.PythonRddToJavaRDDAdapter.deserializeToPointRawRDD( + java_spatial_rdd + ) def deserialize_to_polygon_raw_rdd(self, java_spatial_rdd): - return self._jvm.PythonRddToJavaRDDAdapter.deserializeToPolygonRawRDD(java_spatial_rdd) + return self._jvm.PythonRddToJavaRDDAdapter.deserializeToPolygonRawRDD( + java_spatial_rdd + ) def deserialize_to_linestring_raw_rdd(self, java_spatial_rdd): - return self._jvm.PythonRddToJavaRDDAdapter.deserializeToLineStringRawRDD(java_spatial_rdd) + return self._jvm.PythonRddToJavaRDDAdapter.deserializeToLineStringRawRDD( + java_spatial_rdd + ) class SpatialObjectLoaderAdapter: diff --git a/python/sedona/core/spatialOperator/__init__.py b/python/sedona/core/spatialOperator/__init__.py index 1022e395bb..fc9e5f6783 100644 --- a/python/sedona/core/spatialOperator/__init__.py +++ b/python/sedona/core/spatialOperator/__init__.py @@ -21,6 +21,4 @@ from .join_query_raw import JoinQueryRaw from .range_query_raw import RangeQueryRaw -__all__ = [ - "JoinQuery", "RangeQuery", "KNNQuery", "JoinQueryRaw", "RangeQueryRaw" -] +__all__ = ["JoinQuery", "RangeQuery", "KNNQuery", "JoinQueryRaw", "RangeQueryRaw"] diff --git a/python/sedona/core/spatialOperator/join_params.py b/python/sedona/core/spatialOperator/join_params.py index 232e6b279e..c77224af7f 100644 --- a/python/sedona/core/spatialOperator/join_params.py +++ 
b/python/sedona/core/spatialOperator/join_params.py @@ -30,7 +30,13 @@ class JoinParams: joinBuildSide = attr.ib(type=str, default=JoinBuildSide.LEFT) def jvm_instance(self, jvm): - return JvmJoinParams(jvm, self.useIndex, self.considerBoundaryIntersection, self.indexType, self.joinBuildSide).jvm_instance + return JvmJoinParams( + jvm, + self.useIndex, + self.considerBoundaryIntersection, + self.indexType, + self.joinBuildSide, + ).jvm_instance @attr.s @@ -41,7 +47,12 @@ class JvmJoinParams(JvmObject): joinBuildSide = attr.ib(type=str, default=JoinBuildSide.LEFT) def _create_jvm_instance(self): - return self.jvm_reference(self.useIndex, self.considerBoundaryIntersection, self.indexType.value, self.joinBuildSide) + return self.jvm_reference( + self.useIndex, + self.considerBoundaryIntersection, + self.indexType.value, + self.joinBuildSide, + ) @property def jvm_reference(self): diff --git a/python/sedona/core/spatialOperator/join_query.py b/python/sedona/core/spatialOperator/join_query.py index d4483e7c24..d6d5f55bc4 100644 --- a/python/sedona/core/spatialOperator/join_query.py +++ b/python/sedona/core/spatialOperator/join_query.py @@ -27,8 +27,13 @@ class JoinQuery: @classmethod @require(["JoinQuery"]) - def SpatialJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, - considerBoundaryIntersection: bool) -> RDD: + def SpatialJoinQuery( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> RDD: """ :param spatialRDD: SpatialRDD @@ -38,13 +43,20 @@ def SpatialJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex :return: """ - pair_rdd = JoinQueryRaw.SpatialJoinQuery(spatialRDD, queryRDD, useIndex, considerBoundaryIntersection) + pair_rdd = JoinQueryRaw.SpatialJoinQuery( + spatialRDD, queryRDD, useIndex, considerBoundaryIntersection + ) return pair_rdd.to_rdd() @classmethod @require(["JoinQuery"]) - def DistanceJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, - considerBoundaryIntersection: bool) -> RDD: + def DistanceJoinQuery( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> RDD: """ :param spatialRDD: SpatialRDD @@ -54,12 +66,16 @@ def DistanceJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useInde :return: """ - pair_rdd = JoinQueryRaw.DistanceJoinQuery(spatialRDD, queryRDD, useIndex, considerBoundaryIntersection) + pair_rdd = JoinQueryRaw.DistanceJoinQuery( + spatialRDD, queryRDD, useIndex, considerBoundaryIntersection + ) return pair_rdd.to_rdd() @classmethod @require(["JoinQuery"]) - def spatialJoin(cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinParams: JoinParams) -> RDD: + def spatialJoin( + cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinParams: JoinParams + ) -> RDD: """ :param queryWindowRDD: SpatialRDD @@ -73,8 +89,13 @@ def spatialJoin(cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinPara @classmethod @require(["JoinQuery"]) - def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, - considerBoundaryIntersection: bool) -> RDD: + def DistanceJoinQueryFlat( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> RDD: """ :param spatialRDD: SpatialRDD @@ -90,14 +111,20 @@ def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, use :return: """ - pair_rdd = JoinQueryRaw.DistanceJoinQueryFlat(spatialRDD, queryRDD, 
useIndex, - considerBoundaryIntersection) + pair_rdd = JoinQueryRaw.DistanceJoinQueryFlat( + spatialRDD, queryRDD, useIndex, considerBoundaryIntersection + ) return pair_rdd.to_rdd() @classmethod @require(["JoinQuery"]) - def SpatialJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, - considerBoundaryIntersection: bool) -> RDD: + def SpatialJoinQueryFlat( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> RDD: """ Function takes SpatialRDD and other SpatialRDD and based on two parameters - useIndex @@ -117,6 +144,7 @@ def SpatialJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useI [[GeoData(Polygon, ), GeoData()], [GeoData(), GeoData()], [GeoData(), GeoData()]] """ - pair_rdd = JoinQueryRaw.SpatialJoinQueryFlat(spatialRDD, queryRDD, useIndex, - considerBoundaryIntersection) + pair_rdd = JoinQueryRaw.SpatialJoinQueryFlat( + spatialRDD, queryRDD, useIndex, considerBoundaryIntersection + ) return pair_rdd.to_rdd() diff --git a/python/sedona/core/spatialOperator/join_query_raw.py b/python/sedona/core/spatialOperator/join_query_raw.py index 4e1b54e210..d1b0f10a9e 100644 --- a/python/sedona/core/spatialOperator/join_query_raw.py +++ b/python/sedona/core/spatialOperator/join_query_raw.py @@ -25,75 +25,90 @@ class JoinQueryRaw: @classmethod @require(["JoinQuery"]) - def SpatialJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDDList: + def SpatialJoinQuery( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> SedonaPairRDDList: jvm = spatialRDD._jvm sc = spatialRDD._sc srdd = jvm.JoinQuery.SpatialJoinQuery( - spatialRDD._srdd, - queryRDD._srdd, - useIndex, - considerBoundaryIntersection + spatialRDD._srdd, queryRDD._srdd, useIndex, considerBoundaryIntersection ) return SedonaPairRDDList(srdd, sc) @classmethod @require(["JoinQuery"]) - def DistanceJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDDList: + def DistanceJoinQuery( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> SedonaPairRDDList: jvm = spatialRDD._jvm sc = spatialRDD._sc srdd = jvm.JoinQuery.DistanceJoinQuery( - spatialRDD._srdd, - queryRDD._srdd, - useIndex, - considerBoundaryIntersection + spatialRDD._srdd, queryRDD._srdd, useIndex, considerBoundaryIntersection ) return SedonaPairRDDList(srdd, sc) @classmethod @require(["JoinQuery"]) - def spatialJoin(cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinParams: JoinParams) -> SedonaPairRDD: + def spatialJoin( + cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinParams: JoinParams + ) -> SedonaPairRDD: jvm = queryWindowRDD._jvm sc = queryWindowRDD._sc jvm_join_params = joinParams.jvm_instance(jvm) - srdd = jvm.JoinQuery.spatialJoin(queryWindowRDD._srdd, objectRDD._srdd, jvm_join_params) + srdd = jvm.JoinQuery.spatialJoin( + queryWindowRDD._srdd, objectRDD._srdd, jvm_join_params + ) return SedonaPairRDD(srdd, sc) @classmethod @require(["JoinQuery"]) - def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDD: + def DistanceJoinQueryFlat( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> SedonaPairRDD: jvm = 
spatialRDD._jvm sc = spatialRDD._sc spatial_join = jvm.JoinQuery.DistanceJoinQueryFlat srdd = spatial_join( - spatialRDD._srdd, - queryRDD._srdd, - useIndex, - considerBoundaryIntersection + spatialRDD._srdd, queryRDD._srdd, useIndex, considerBoundaryIntersection ) return SedonaPairRDD(srdd, sc) @classmethod @require(["JoinQuery"]) - def SpatialJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, - considerBoundaryIntersection: bool) -> SedonaPairRDD: + def SpatialJoinQueryFlat( + cls, + spatialRDD: SpatialRDD, + queryRDD: SpatialRDD, + useIndex: bool, + considerBoundaryIntersection: bool, + ) -> SedonaPairRDD: jvm = spatialRDD._jvm sc = spatialRDD._sc spatial_join = jvm.JoinQuery.SpatialJoinQueryFlat srdd = spatial_join( - spatialRDD._srdd, - queryRDD._srdd, - useIndex, - considerBoundaryIntersection + spatialRDD._srdd, queryRDD._srdd, useIndex, considerBoundaryIntersection ) return SedonaPairRDD(srdd, sc) diff --git a/python/sedona/core/spatialOperator/knn_query.py b/python/sedona/core/spatialOperator/knn_query.py index 1d4c1224ad..9e6905ce66 100644 --- a/python/sedona/core/spatialOperator/knn_query.py +++ b/python/sedona/core/spatialOperator/knn_query.py @@ -31,7 +31,13 @@ class KNNQuery: @classmethod @require(["KNNQuery", "GeometryAdapter"]) - def SpatialKnnQuery(self, spatialRDD: SpatialRDD, originalQueryPoint: BaseGeometry, k: int, useIndex: bool): + def SpatialKnnQuery( + self, + spatialRDD: SpatialRDD, + originalQueryPoint: BaseGeometry, + k: int, + useIndex: bool, + ): """ :param spatialRDD: spatialRDD @@ -41,11 +47,17 @@ def SpatialKnnQuery(self, spatialRDD: SpatialRDD, originalQueryPoint: BaseGeomet :return: pyspark.RDD """ jvm = spatialRDD._jvm - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(jvm, originalQueryPoint) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + jvm, originalQueryPoint + ) - knn_neighbours = jvm.KNNQuery.SpatialKnnQuery(spatialRDD._srdd, jvm_geom, k, useIndex) + knn_neighbours = jvm.KNNQuery.SpatialKnnQuery( + spatialRDD._srdd, jvm_geom, k, useIndex + ) - srdd = JvmSedonaPythonConverter(jvm).translate_geometry_seq_to_python(knn_neighbours) + srdd = JvmSedonaPythonConverter(jvm).translate_geometry_seq_to_python( + knn_neighbours + ) geoms_data = [] for arr in srdd: diff --git a/python/sedona/core/spatialOperator/range_query.py b/python/sedona/core/spatialOperator/range_query.py index df44ee3744..b76ab8a77a 100644 --- a/python/sedona/core/spatialOperator/range_query.py +++ b/python/sedona/core/spatialOperator/range_query.py @@ -26,8 +26,13 @@ class RangeQuery: @classmethod @require(["RangeQuery", "GeometryAdapter", "GeoSerializerData"]) - def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeometry, - considerBoundaryIntersection: bool, usingIndex: bool): + def SpatialRangeQuery( + self, + spatialRDD: SpatialRDD, + rangeQueryWindow: BaseGeometry, + considerBoundaryIntersection: bool, + usingIndex: bool, + ): """ :param spatialRDD: @@ -36,5 +41,7 @@ def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeomet :param usingIndex: :return: """ - j_srdd = RangeQueryRaw.SpatialRangeQuery(spatialRDD, rangeQueryWindow, considerBoundaryIntersection, usingIndex) + j_srdd = RangeQueryRaw.SpatialRangeQuery( + spatialRDD, rangeQueryWindow, considerBoundaryIntersection, usingIndex + ) return j_srdd.to_rdd() diff --git a/python/sedona/core/spatialOperator/range_query_raw.py b/python/sedona/core/spatialOperator/range_query_raw.py index dced4cd469..e20646ccb8 100644 
--- a/python/sedona/core/spatialOperator/range_query_raw.py +++ b/python/sedona/core/spatialOperator/range_query_raw.py @@ -27,8 +27,13 @@ class RangeQueryRaw: @classmethod @require(["RangeQuery", "GeometryAdapter", "GeoSerializerData"]) - def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeometry, - considerBoundaryIntersection: bool, usingIndex: bool) -> SedonaRDD: + def SpatialRangeQuery( + self, + spatialRDD: SpatialRDD, + rangeQueryWindow: BaseGeometry, + considerBoundaryIntersection: bool, + usingIndex: bool, + ) -> SedonaRDD: """ :param spatialRDD: @@ -41,14 +46,12 @@ def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeomet jvm = spatialRDD._jvm sc = spatialRDD._sc - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(jvm, rangeQueryWindow) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + jvm, rangeQueryWindow + ) - srdd = jvm. \ - RangeQuery.SpatialRangeQuery( - spatialRDD._srdd, - jvm_geom, - considerBoundaryIntersection, - usingIndex + srdd = jvm.RangeQuery.SpatialRangeQuery( + spatialRDD._srdd, jvm_geom, considerBoundaryIntersection, usingIndex ) return SedonaRDD(srdd, sc) diff --git a/python/sedona/core/spatialOperator/rdd.py b/python/sedona/core/spatialOperator/rdd.py index 61969a12d3..d96c65724a 100644 --- a/python/sedona/core/spatialOperator/rdd.py +++ b/python/sedona/core/spatialOperator/rdd.py @@ -29,8 +29,9 @@ def __init__(self, jsrdd, sc: SparkContext): def to_rdd(self) -> RDD: jvm = self.sc._jvm - serialized = JvmSedonaPythonConverter(jvm). \ - translate_spatial_rdd_to_python(self.jsrdd) + serialized = JvmSedonaPythonConverter(jvm).translate_spatial_rdd_to_python( + self.jsrdd + ) return RDD(serialized, self.sc, SedonaPickler()) @@ -43,8 +44,9 @@ def __init__(self, jsrdd, sc: SparkContext): def to_rdd(self) -> RDD: jvm = self.sc._jvm - serialized = JvmSedonaPythonConverter(jvm). \ - translate_spatial_pair_rdd_to_python(self.jsrdd) + serialized = JvmSedonaPythonConverter(jvm).translate_spatial_pair_rdd_to_python( + self.jsrdd + ) return RDD(serialized, self.sc, SedonaPickler()) @@ -57,7 +59,8 @@ def __init__(self, jsrdd, sc: SparkContext): def to_rdd(self): jvm = self.sc._jvm - serialized = JvmSedonaPythonConverter(jvm). \ - translate_spatial_pair_rdd_with_list_to_python(self.jsrdd) + serialized = JvmSedonaPythonConverter( + jvm + ).translate_spatial_pair_rdd_with_list_to_python(self.jsrdd) return RDD(serialized, self.sc, SedonaPickler()) diff --git a/python/sedona/core/utils.py b/python/sedona/core/utils.py index b9990ec38e..f5ebada83b 100644 --- a/python/sedona/core/utils.py +++ b/python/sedona/core/utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. + class ImportedJvmLib: _imported_libs = [] diff --git a/python/sedona/exceptions.py b/python/sedona/exceptions.py index 18788608f0..256c1ab9db 100644 --- a/python/sedona/exceptions.py +++ b/python/sedona/exceptions.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+ class InvalidParametersException(Exception): """ Exception added to handle invalid constructor parameters diff --git a/python/sedona/maps/SedonaMapUtils.py b/python/sedona/maps/SedonaMapUtils.py index 52cc47402e..b4f68fb4cf 100644 --- a/python/sedona/maps/SedonaMapUtils.py +++ b/python/sedona/maps/SedonaMapUtils.py @@ -34,7 +34,9 @@ def __convert_to_gdf_or_pdf__(cls, df, rename=True, geometry_col=None): if geometry_col is None: geometry_col = SedonaMapUtils.__get_geometry_col__(df) pandas_df = df.toPandas() - if geometry_col is None: # No geometry column found even after searching schema, return Pandas Dataframe + if ( + geometry_col is None + ): # No geometry column found even after searching schema, return Pandas Dataframe return pandas_df try: import geopandas as gpd @@ -73,14 +75,14 @@ def __extract_coordinate__(cls, geom, type_list): geom_type = geom.geom_type if geom_type not in type_list: type_list.append(geom_type) - if geom_type == 'Polygon': + if geom_type == "Polygon": return geom.exterior.coords[0] else: return geom.coords[0] @classmethod def __extract_point_coordinate__(cls, geom): - if geom.geom_type == 'Point': + if geom.geom_type == "Point": return geom.coords[0] @classmethod @@ -91,5 +93,9 @@ def _extract_first_sub_geometry_(cls, geom): @classmethod def __is_geom_collection__(cls, geom_type): - return geom_type == 'MultiPolygon' or geom_type == 'MultiLineString' or geom_type == 'MultiPoint' \ - or geom_type == 'GeometryCollection' + return ( + geom_type == "MultiPolygon" + or geom_type == "MultiLineString" + or geom_type == "MultiPoint" + or geom_type == "GeometryCollection" + ) diff --git a/python/sedona/maps/SedonaPyDeck.py b/python/sedona/maps/SedonaPyDeck.py index 7c2d8a763f..367fffc9c5 100644 --- a/python/sedona/maps/SedonaPyDeck.py +++ b/python/sedona/maps/SedonaPyDeck.py @@ -17,7 +17,15 @@ from types import ModuleType -from pyspark.sql.types import FloatType, DoubleType, IntegerType, LongType, DecimalType, ShortType, ByteType +from pyspark.sql.types import ( + FloatType, + DoubleType, + IntegerType, + LongType, + DecimalType, + ShortType, + ByteType, +) from sedona.maps.SedonaMapUtils import SedonaMapUtils @@ -26,8 +34,18 @@ class SedonaPyDeck: # User Facing APIs @classmethod - def create_choropleth_map(cls, df, fill_color=None, plot_col=None, initial_view_state=None, map_style=None, - map_provider=None, elevation_col=0, api_keys=None, stroked=True): + def create_choropleth_map( + cls, + df, + fill_color=None, + plot_col=None, + initial_view_state=None, + map_style=None, + map_provider=None, + elevation_col=0, + api_keys=None, + stroked=True, + ): """ Create a pydeck map with a choropleth layer added :param elevation_col: Optional elevation for the polygons @@ -43,17 +61,28 @@ def create_choropleth_map(cls, df, fill_color=None, plot_col=None, initial_view_ """ for field in df.schema.fields: - if field.name == plot_col and field.dataType not in [IntegerType(), FloatType(), DoubleType(), LongType(), - DecimalType(), ShortType(), ByteType()]: - message = (f"'{field.name}' must be of numeric type. \nIf you are importing data from csv set " - f"'inferSchema' option as true") + if field.name == plot_col and field.dataType not in [ + IntegerType(), + FloatType(), + DoubleType(), + LongType(), + DecimalType(), + ShortType(), + ByteType(), + ]: + message = ( + f"'{field.name}' must be of numeric type. 
\nIf you are importing data from CSV, set the " +                f"'inferSchema' option to true" +            )             raise TypeError(message) from None          pdk = _try_import_pydeck()          if initial_view_state is None:             gdf = SedonaPyDeck._prepare_df_(df, add_coords=True) -            initial_view_state = pdk.data_utils.compute_view(gdf['coordinate_array_sedona']) +            initial_view_state = pdk.data_utils.compute_view( +                gdf["coordinate_array_sedona"] +            )         else:             gdf = SedonaPyDeck._prepare_df_(df) @@ -61,7 +90,7 @@ def create_choropleth_map(cls, df, fill_color=None, plot_col=None, initial_view_             fill_color = SedonaPyDeck._create_default_fill_color_(gdf, plot_col)          choropleth_layer = pdk.Layer( -            'GeoJsonLayer',  # `type` positional argument is here +            "GeoJsonLayer",  # `type` positional argument is here             data=gdf,             auto_highlight=True,             get_fill_color=fill_color, @@ -70,16 +99,30 @@ def create_choropleth_map(cls, df, fill_color=None, plot_col=None, initial_view_             stroked=stroked,             extruded=True,             wireframe=True, -            pickable=True +            pickable=True,         ) -        return SedonaPyDeck._create_map_obj(layer=choropleth_layer, initial_view_state=initial_view_state, map_style=map_style, -                                            map_provider=map_provider, api_keys=api_keys) +        return SedonaPyDeck._create_map_obj( +            layer=choropleth_layer, +            initial_view_state=initial_view_state, +            map_style=map_style, +            map_provider=map_provider, +            api_keys=api_keys, +        )      @classmethod -    def create_geometry_map(cls, df, fill_color="[85, 183, 177, 255]", line_color="[85, 183, 177, 255]", -                            elevation_col=0, initial_view_state=None, -                            map_style=None, map_provider=None, api_keys=None, stroked=True): +    def create_geometry_map( +        cls, +        df, +        fill_color="[85, 183, 177, 255]", +        line_color="[85, 183, 177, 255]", +        elevation_col=0, +        initial_view_state=None, +        map_style=None, +        map_provider=None, +        api_keys=None, +        stroked=True, +    ):         """         Create a pydeck map with a GeoJsonLayer added for plotting given geometries. 
:param line_color: @@ -102,22 +145,44 @@ def create_geometry_map(cls, df, fill_color="[85, 183, 177, 255]", line_color="[ # if line_color == "[85, 183, 177, 255]": # line_color = "[237, 119, 79]" - layer = SedonaPyDeck._create_fat_layer_(gdf, fill_color=fill_color, elevation_col=elevation_col, - line_color=line_color, stroked=stroked) + layer = SedonaPyDeck._create_fat_layer_( + gdf, + fill_color=fill_color, + elevation_col=elevation_col, + line_color=line_color, + stroked=stroked, + ) if initial_view_state is None: - initial_view_state = pdk.data_utils.compute_view(gdf['coordinate_array_sedona']) + initial_view_state = pdk.data_utils.compute_view( + gdf["coordinate_array_sedona"] + ) if elevation_col != 0 and geom_type == "Polygon": initial_view_state.pitch = 45 # If polygons are elevated, change the pitch to visualize the elevation better - return SedonaPyDeck._create_map_obj(layer=layer, initial_view_state=initial_view_state, map_style=map_style, - map_provider=map_provider, api_keys=api_keys) + return SedonaPyDeck._create_map_obj( + layer=layer, + initial_view_state=initial_view_state, + map_style=map_style, + map_provider=map_provider, + api_keys=api_keys, + ) @classmethod - def create_scatterplot_map(cls, df, fill_color="[255, 140, 0]", radius_col=1, radius_min_pixels=1, - radius_max_pixels=10, radius_scale=1, initial_view_state=None, map_style=None, - map_provider=None, api_keys=None): + def create_scatterplot_map( + cls, + df, + fill_color="[255, 140, 0]", + radius_col=1, + radius_min_pixels=1, + radius_max_pixels=10, + radius_scale=1, + initial_view_state=None, + map_style=None, + map_provider=None, + api_keys=None, + ): """ Create a pydeck map with a scatterplot layer :param radius_scale: @@ -145,19 +210,35 @@ def create_scatterplot_map(cls, df, fill_color="[255, 140, 0]", radius_col=1, ra radius_min_pixels=radius_min_pixels, radius_max_pixels=radius_max_pixels, radius_scale=radius_scale, - get_position='coordinate_array_sedona', + get_position="coordinate_array_sedona", get_fill_color=fill_color, ) if initial_view_state is None: - initial_view_state = pdk.data_utils.compute_view(gdf['coordinate_array_sedona']) - - return SedonaPyDeck._create_map_obj(layer=layer, initial_view_state=initial_view_state, map_style=map_style, - map_provider=map_provider, api_keys=api_keys) + initial_view_state = pdk.data_utils.compute_view( + gdf["coordinate_array_sedona"] + ) + + return SedonaPyDeck._create_map_obj( + layer=layer, + initial_view_state=initial_view_state, + map_style=map_style, + map_provider=map_provider, + api_keys=api_keys, + ) @classmethod - def create_heatmap(cls, df, color_range=None, weight=1, aggregation="SUM", initial_view_state=None, map_style=None, - map_provider=None, api_keys=None): + def create_heatmap( + cls, + df, + color_range=None, + weight=1, + aggregation="SUM", + initial_view_state=None, + map_style=None, + map_provider=None, + api_keys=None, + ): """ Create a pydeck map with a heatmap layer added :param df: SedonaDataFrame to be used to plot the heatmap @@ -181,7 +262,7 @@ def create_heatmap(cls, df, color_range=None, weight=1, aggregation="SUM", initi [254, 178, 76], [253, 141, 60], [240, 59, 32], - [240, 59, 32] + [240, 59, 32], ] layer = pdk.Layer( @@ -190,17 +271,24 @@ def create_heatmap(cls, df, color_range=None, weight=1, aggregation="SUM", initi pickable=True, opacity=0.8, filled=True, - get_position='coordinate_array_sedona', + get_position="coordinate_array_sedona", aggregation=pdk.types.String(aggregation), color_range=color_range, - get_weight=weight 
+ get_weight=weight, ) if initial_view_state is None: - initial_view_state = pdk.data_utils.compute_view(gdf['coordinate_array_sedona']) - - return SedonaPyDeck._create_map_obj(layer=layer, initial_view_state=initial_view_state, map_style=map_style, - map_provider=map_provider, api_keys=api_keys) + initial_view_state = pdk.data_utils.compute_view( + gdf["coordinate_array_sedona"] + ) + + return SedonaPyDeck._create_map_obj( + layer=layer, + initial_view_state=initial_view_state, + map_style=map_style, + map_provider=map_provider, + api_keys=api_keys, + ) @classmethod def _prepare_df_(cls, df, add_coords=False, geometry_col=None): @@ -212,12 +300,13 @@ def _prepare_df_(cls, df, add_coords=False, geometry_col=None): """ if geometry_col is None: geometry_col = SedonaMapUtils.__get_geometry_col__(df=df) - gdf = SedonaMapUtils.__convert_to_gdf_or_pdf__(df, rename=False, geometry_col=geometry_col) + gdf = SedonaMapUtils.__convert_to_gdf_or_pdf__( + df, rename=False, geometry_col=geometry_col + ) if add_coords is True: SedonaPyDeck._create_coord_column_(gdf=gdf, geometry_col=geometry_col) return gdf - @classmethod def _create_default_fill_color_(cls, gdf, plot_col): """ @@ -236,14 +325,18 @@ def _create_coord_column_(cls, gdf, geometry_col, add_points=False): :param geometry_col: column with ST_Points """ type_list = [] - gdf['coordinate_array_sedona'] = gdf.apply( - lambda val: list(SedonaMapUtils.__extract_coordinate__(val[geometry_col], type_list)), axis=1) + gdf["coordinate_array_sedona"] = gdf.apply( + lambda val: list( + SedonaMapUtils.__extract_coordinate__(val[geometry_col], type_list) + ), + axis=1, + ) @classmethod def _create_fat_layer_(cls, gdf, fill_color, line_color, elevation_col, stroked): pdk = _try_import_pydeck() layer = pdk.Layer( - 'GeoJsonLayer', # `type` positional argument is here + "GeoJsonLayer", # `type` positional argument is here data=gdf, auto_highlight=True, get_fill_color=fill_color, @@ -253,31 +346,38 @@ def _create_fat_layer_(cls, gdf, fill_color, line_color, elevation_col, stroked) get_elevation=elevation_col, get_line_color=line_color, pickable=True, - get_line_width=3 + get_line_width=3, ) return layer # creates the final map object and handles the parameters for pydeck @classmethod - def _create_map_obj(cls, layer, initial_view_state, map_style, map_provider, api_keys): + def _create_map_obj( + cls, layer, initial_view_state, map_style, map_provider, api_keys + ): pdk = _try_import_pydeck() if map_provider is None: - map_provider = 'carto' + map_provider = "carto" if map_style is None: - map_style = 'dark' + map_style = "dark" # Default to mapbox if user selects 'satellite' map_style as pydeck uses mapbox and google maps for satellite basemap - if map_style == 'satellite' and map_provider != 'google_maps': - map_provider = 'mapbox' + if map_style == "satellite" and map_provider != "google_maps": + map_provider = "mapbox" if api_keys is not None: api_keys = api_keys - return pdk.Deck(layers=[layer], initial_view_state=initial_view_state, map_style=map_style, - map_provider=map_provider, api_keys=api_keys) + return pdk.Deck( + layers=[layer], + initial_view_state=initial_view_state, + map_style=map_style, + map_provider=map_provider, + api_keys=api_keys, + ) def _try_import_pydeck() -> ModuleType: diff --git a/python/sedona/raster/awt_raster.py b/python/sedona/raster/awt_raster.py index d951359436..e95935e3f6 100644 --- a/python/sedona/raster/awt_raster.py +++ b/python/sedona/raster/awt_raster.py @@ -20,9 +20,8 @@ class AWTRaster: - """Raster data structure of 
Java AWT Raster used by GeoTools GridCoverage2D. + """Raster data structure of Java AWT Raster used by GeoTools GridCoverage2D.""" - """ min_x: int min_y: int width: int @@ -30,7 +29,15 @@ class AWTRaster: sample_model: SampleModel data_buffer: DataBuffer - def __init__(self, min_x, min_y, width, height, sample_model: SampleModel, data_buffer: DataBuffer): + def __init__( + self, + min_x, + min_y, + width, + height, + sample_model: SampleModel, + data_buffer: DataBuffer, + ): if sample_model.width != width or sample_model.height != height: raise RuntimeError("Size of the image does not match with the sample model") self.min_x = min_x diff --git a/python/sedona/raster/data_buffer.py b/python/sedona/raster/data_buffer.py index 8826e26bdc..28b28bbe3f 100644 --- a/python/sedona/raster/data_buffer.py +++ b/python/sedona/raster/data_buffer.py @@ -32,7 +32,9 @@ class DataBuffer: size: int offsets: List[int] - def __init__(self, data_type: int, bank_data: List[np.ndarray], size: int, offsets: List[int]): + def __init__( + self, data_type: int, bank_data: List[np.ndarray], size: int, offsets: List[int] + ): self.data_type = data_type self.bank_data = bank_data self.size = size diff --git a/python/sedona/raster/meta.py b/python/sedona/raster/meta.py index b0013359dd..a8ed687dc6 100644 --- a/python/sedona/raster/meta.py +++ b/python/sedona/raster/meta.py @@ -26,6 +26,7 @@ class PixelAnchor(Enum): transformation between these conventions. """ + CENTER = 1 UPPER_LEFT = 2 @@ -39,7 +40,9 @@ class AffineTransform: ip_y: float pixel_anchor: PixelAnchor - def __init__(self, scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, pixel_anchor: PixelAnchor): + def __init__( + self, scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, pixel_anchor: PixelAnchor + ): self.scale_x = scale_x self.skew_y = skew_y self.skew_x = skew_x @@ -56,8 +59,15 @@ def with_anchor(self, pixel_anchor: PixelAnchor): def translate(self, offset_x: float, offset_y: float): new_ipx = self.ip_x + offset_x * self.scale_x + offset_y * self.skew_x new_ipy = self.ip_y + offset_x * self.skew_y + offset_y * self.scale_y - return AffineTransform(self.scale_x, self.skew_y, self.skew_x, self.scale_y, - new_ipx, new_ipy, self.pixel_anchor) + return AffineTransform( + self.scale_x, + self.skew_y, + self.skew_x, + self.scale_y, + new_ipx, + new_ipy, + self.pixel_anchor, + ) def _do_change_pixel_anchor(self, from_anchor: PixelAnchor, to_anchor: PixelAnchor): assert from_anchor != to_anchor @@ -88,18 +98,21 @@ def _do_change_pixel_anchor(self, from_anchor: PixelAnchor, to_anchor: PixelAnch new_m10 = old_m10 * m00 + old_m11 * m10 new_m11 = old_m10 * m01 + old_m11 * m11 new_m12 = old_m10 * m02 + old_m11 * m12 + old_m12 - return AffineTransform(new_m00, new_m10, new_m01, new_m11, new_m02, new_m12, to_anchor) + return AffineTransform( + new_m00, new_m10, new_m01, new_m11, new_m02, new_m12, to_anchor + ) def __repr__(self): - return ("[ {} {} {}\n".format(self.scale_x, self.skew_x, self.ip_x) + - " {} {} {}\n".format(self.skew_y, self.scale_y, self.ip_y) + - " 0 0 1 ]") + return ( + "[ {} {} {}\n".format(self.scale_x, self.skew_x, self.ip_x) + + " {} {} {}\n".format(self.skew_y, self.scale_y, self.ip_y) + + " 0 0 1 ]" + ) class SampleDimension: - """Raster band metadata. 
+ """Raster band metadata.""" - """ description: str offset: float scale: float diff --git a/python/sedona/raster/raster_serde.py b/python/sedona/raster/raster_serde.py index 63b740c5a3..2bc3968458 100644 --- a/python/sedona/raster/raster_serde.py +++ b/python/sedona/raster/raster_serde.py @@ -21,7 +21,13 @@ import zlib import numpy as np -from .sample_model import SampleModel, ComponentSampleModel, PixelInterleavedSampleModel, MultiPixelPackedSampleModel, SinglePixelPackedSampleModel +from .sample_model import ( + SampleModel, + ComponentSampleModel, + PixelInterleavedSampleModel, + MultiPixelPackedSampleModel, + SinglePixelPackedSampleModel, +) from .data_buffer import DataBuffer from .awt_raster import AWTRaster from .meta import AffineTransform, PixelAnchor, SampleDimension @@ -52,7 +58,9 @@ def _deserialize(bio: BytesIO, raster_type: int) -> SedonaRaster: if raster_type == RasterTypes.IN_DB: # In-DB raster awt_raster = _read_awt_raster(bio) - return InDbSedonaRaster(width, height, bands_meta, affine_trans, crs_wkt, awt_raster) + return InDbSedonaRaster( + width, height, bands_meta, affine_trans, crs_wkt, awt_raster + ) else: raise ValueError("unsupported raster_type: {}".format(raster_type)) @@ -63,19 +71,23 @@ def _read_grid_envelope(bio: BytesIO) -> Tuple[int, int, int, int]: def _read_affine_transformation(bio: BytesIO) -> AffineTransform: - scale_x, skew_y, skew_x, scale_y, ip_x, ip_y = struct.unpack("=dddddd", bio.read(8 * 6)) - return AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER) + scale_x, skew_y, skew_x, scale_y, ip_x, ip_y = struct.unpack( + "=dddddd", bio.read(8 * 6) + ) + return AffineTransform( + scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER + ) def _read_crs_wkt(bio: BytesIO) -> str: - size, = struct.unpack("=i", bio.read(4)) + (size,) = struct.unpack("=i", bio.read(4)) compressed_wkt = bio.read(size) crs_wkt = zlib.decompress(compressed_wkt) - return crs_wkt.decode('utf-8') + return crs_wkt.decode("utf-8") def _read_sample_dimensions(bio: BytesIO) -> List[SampleDimension]: - num_bands, = struct.unpack("=i", bio.read(4)) + (num_bands,) = struct.unpack("=i", bio.read(4)) bands_meta = [] for i in range(num_bands): description = _read_utf8_string(bio) @@ -87,51 +99,76 @@ def _read_sample_dimensions(bio: BytesIO) -> List[SampleDimension]: def _read_awt_raster(bio: BytesIO) -> AWTRaster: min_x, min_y, width, height = struct.unpack("=iiii", bio.read(4 * 4)) - _ignore_java_object(bio) # image properties - _ignore_java_object(bio) # color model + _ignore_java_object(bio) # image properties + _ignore_java_object(bio) # color model min_x_1, min_y_1 = struct.unpack("=ii", bio.read(4 * 2)) if min_x_1 != min_x or min_y_1 != min_y: - raise RuntimeError("malformed serialized raster: minx/miny of the image cannot match with minx/miny of the AWT raster") + raise RuntimeError( + "malformed serialized raster: minx/miny of the image cannot match with minx/miny of the AWT raster" + ) sample_model = _read_sample_model(bio) data_buffer = _read_data_buffer(bio) return AWTRaster(min_x, min_y, width, height, sample_model, data_buffer) def _read_sample_model(bio: BytesIO) -> SampleModel: - sample_model_type, data_type, width, height = struct.unpack("=iiii", bio.read(4 * 4)) + sample_model_type, data_type, width, height = struct.unpack( + "=iiii", bio.read(4 * 4) + ) if sample_model_type == SampleModel.TYPE_BANDED: bank_indices = _read_int_array(bio) band_offsets = _read_int_array(bio) - return ComponentSampleModel(data_type, width, height, 1, 
width, bank_indices, band_offsets) + return ComponentSampleModel( + data_type, width, height, 1, width, bank_indices, band_offsets + ) elif sample_model_type == SampleModel.TYPE_PIXEL_INTERLEAVED: pixel_stride, scanline_stride = struct.unpack("=ii", bio.read(4 * 2)) band_offsets = _read_int_array(bio) - return PixelInterleavedSampleModel(data_type, width, height, pixel_stride, scanline_stride, band_offsets) - elif sample_model_type in [SampleModel.TYPE_COMPONENT, SampleModel.TYPE_COMPONENT_JAI]: + return PixelInterleavedSampleModel( + data_type, width, height, pixel_stride, scanline_stride, band_offsets + ) + elif sample_model_type in [ + SampleModel.TYPE_COMPONENT, + SampleModel.TYPE_COMPONENT_JAI, + ]: pixel_stride, scanline_stride = struct.unpack("=ii", bio.read(4 * 2)) bank_indices = _read_int_array(bio) band_offsets = _read_int_array(bio) - return ComponentSampleModel(data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets) + return ComponentSampleModel( + data_type, + width, + height, + pixel_stride, + scanline_stride, + bank_indices, + band_offsets, + ) elif sample_model_type == SampleModel.TYPE_SINGLE_PIXEL_PACKED: - scanline_stride, = struct.unpack("=i", bio.read(4)) + (scanline_stride,) = struct.unpack("=i", bio.read(4)) bit_masks = _read_int_array(bio) - return SinglePixelPackedSampleModel(data_type, width, height, scanline_stride, bit_masks) + return SinglePixelPackedSampleModel( + data_type, width, height, scanline_stride, bit_masks + ) elif sample_model_type == SampleModel.TYPE_MULTI_PIXEL_PACKED: - num_bits, scanline_stride, data_bit_offset = struct.unpack("=iii", bio.read(4 * 3)) - return MultiPixelPackedSampleModel(data_type, width, height, num_bits, scanline_stride, data_bit_offset) + num_bits, scanline_stride, data_bit_offset = struct.unpack( + "=iii", bio.read(4 * 3) + ) + return MultiPixelPackedSampleModel( + data_type, width, height, num_bits, scanline_stride, data_bit_offset + ) else: raise RuntimeError(f"Unsupported SampleModel type: {sample_model_type}") def _read_data_buffer(bio: BytesIO) -> DataBuffer: - data_type, = struct.unpack("=i", bio.read(4)) + (data_type,) = struct.unpack("=i", bio.read(4)) offsets = _read_int_array(bio) - size, = struct.unpack("=i", bio.read(4)) + (size,) = struct.unpack("=i", bio.read(4)) - num_banks, = struct.unpack("=i", bio.read(4)) + (num_banks,) = struct.unpack("=i", bio.read(4)) banks = [] for i in range(num_banks): - bank_size, = struct.unpack("=i", bio.read(4)) + (bank_size,) = struct.unpack("=i", bio.read(4)) if data_type == DataBuffer.TYPE_BYTE: np_array = np.frombuffer(bio.read(bank_size), dtype=np.uint8) elif data_type == DataBuffer.TYPE_SHORT: @@ -153,23 +190,23 @@ def _read_data_buffer(bio: BytesIO) -> DataBuffer: def _read_utf8_string(bio: BytesIO) -> str: - size, = struct.unpack("=i", bio.read(4)) + (size,) = struct.unpack("=i", bio.read(4)) utf8_bytes = bio.read(size) - return utf8_bytes.decode('utf-8') + return utf8_bytes.decode("utf-8") def _ignore_java_object(bio: BytesIO): - size, = struct.unpack("=i", bio.read(4)) + (size,) = struct.unpack("=i", bio.read(4)) bio.read(size) def _read_int_array(bio: BytesIO) -> List[int]: - length, = struct.unpack("=i", bio.read(4)) + (length,) = struct.unpack("=i", bio.read(4)) return [struct.unpack("=i", bio.read(4))[0] for _ in range(length)] def _read_utf8_string_map(bio: BytesIO) -> Optional[Dict[str, str]]: - size, = struct.unpack("=i", bio.read(4)) + (size,) = struct.unpack("=i", bio.read(4)) if size == -1: return None params = {} diff --git 
a/python/sedona/raster/sample_model.py b/python/sedona/raster/sample_model.py index 4c5ac193e7..2959b599f8 100644 --- a/python/sedona/raster/sample_model.py +++ b/python/sedona/raster/sample_model.py @@ -27,6 +27,7 @@ class SampleModel(ABC): SampleModel class in Java AWT. """ + TYPE_BANDED = 1 TYPE_PIXEL_INTERLEAVED = 2 TYPE_SINGLE_PIXEL_PACKED = 3 @@ -47,7 +48,9 @@ def __init__(self, sample_model_type, data_type, width, height): @abstractmethod def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: - raise NotImplementedError("Abstract method as_numpy was not implemented by subclass") + raise NotImplementedError( + "Abstract method as_numpy was not implemented by subclass" + ) class ComponentSampleModel(SampleModel): @@ -56,7 +59,16 @@ class ComponentSampleModel(SampleModel): bank_indices: List[int] band_offsets: List[int] - def __init__(self, data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets): + def __init__( + self, + data_type, + width, + height, + pixel_stride, + scanline_stride, + bank_indices, + band_offsets, + ): super().__init__(SampleModel.TYPE_COMPONENT, data_type, width, height) self.pixel_stride = pixel_stride self.scanline_stride = scanline_stride @@ -71,7 +83,7 @@ def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: bank_data = data_buffer.bank_data[bank_index] offset = self.band_offsets[bank_index] if offset != 0: - bank_data = bank_data[offset:(offset + self.width * self.height)] + bank_data = bank_data[offset : (offset + self.width * self.height)] band_arr = bank_data.reshape(self.height, self.width) band_arrs.append(band_arr) return np.array(band_arrs) @@ -98,7 +110,9 @@ class PixelInterleavedSampleModel(SampleModel): scanline_stride: int band_offsets: List[int] - def __init__(self, data_type, width, height, pixel_stride, scanline_stride, band_offsets): + def __init__( + self, data_type, width, height, pixel_stride, scanline_stride, band_offsets + ): super().__init__(SampleModel.TYPE_PIXEL_INTERLEAVED, data_type, width, height) self.pixel_stride = pixel_stride self.scanline_stride = scanline_stride @@ -107,9 +121,11 @@ def __init__(self, data_type, width, height, pixel_stride, scanline_stride, band def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: num_bands = len(self.band_offsets) bank_data = data_buffer.bank_data[0] - if self.pixel_stride == num_bands and \ - self.scanline_stride == self.width * num_bands and \ - self.band_offsets == list(range(0, num_bands)): + if ( + self.pixel_stride == num_bands + and self.scanline_stride == self.width * num_bands + and self.band_offsets == list(range(0, num_bands)) + ): # Fast path: no gapping in between band data, no band reordering arr = bank_data.reshape(self.height, self.width, num_bands) return np.transpose(arr, [2, 0, 1]) @@ -151,7 +167,9 @@ def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: for mask, bit_offset in zip(self.bit_masks, self.bit_offsets): pixel.append((value & mask) >> bit_offset) pixel_data.append(pixel) - arr = np.array(pixel_data, dtype=bank_data.dtype).reshape(self.height, self.width, num_bands) + arr = np.array(pixel_data, dtype=bank_data.dtype).reshape( + self.height, self.width, num_bands + ) return np.transpose(arr, [2, 0, 1]) @@ -160,7 +178,9 @@ class MultiPixelPackedSampleModel(SampleModel): scanline_stride: int data_bit_offset: int - def __init__(self, data_type, width, height, num_bits, scanline_stride, data_bit_offset): + def __init__( + self, data_type, width, height, num_bits, scanline_stride, data_bit_offset + ): 
super().__init__(SampleModel.TYPE_MULTI_PIXEL_PACKED, data_type, width, height) self.num_bits = num_bits self.scanline_stride = scanline_stride @@ -178,12 +198,12 @@ def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: pos = y * self.scanline_stride + self.data_bit_offset // bits_per_value value = bank_data[pos] shift = self.data_bit_offset % bits_per_value - value = (value << shift) + value = value << shift pixels: List[int] = [] while len(pixels) < self.width: while shift < bits_per_value and len(pixels) < self.width: pixels.append((value & mask) >> shift_right) - value = (value << self.num_bits) + value = value << self.num_bits shift += self.num_bits pos += 1 value = bank_data[pos] diff --git a/python/sedona/raster/sedona_raster.py b/python/sedona/raster/sedona_raster.py index e5ecb3723b..4520950edc 100644 --- a/python/sedona/raster/sedona_raster.py +++ b/python/sedona/raster/sedona_raster.py @@ -20,18 +20,18 @@ from xml.etree.ElementTree import Element, SubElement, tostring import numpy as np -import rasterio # type: ignore -import rasterio.env # type: ignore -from rasterio.transform import Affine # type: ignore -from rasterio.io import MemoryFile # type: ignore -from rasterio.io import DatasetReader # type: ignore +import rasterio # type: ignore +import rasterio.env # type: ignore +from rasterio.transform import Affine # type: ignore +from rasterio.io import MemoryFile # type: ignore +from rasterio.io import DatasetReader # type: ignore try: # for rasterio >= 1.3.0 - from rasterio._path import _parse_path as parse_path # type: ignore + from rasterio._path import _parse_path as parse_path # type: ignore except: # for rasterio >= 1.2.0 - from rasterio.path import parse_path # type: ignore + from rasterio.path import parse_path # type: ignore from .awt_raster import AWTRaster from .data_buffer import DataBuffer @@ -53,39 +53,62 @@ def _rasterio_open(fp, driver=None): return rasterio.open(fp, mode="r", driver=driver) -def _generate_vrt_xml(src_path, data_type, width, height, geo_transform, crs_wkt, off_x, off_y, band_indices) -> bytes: +def _generate_vrt_xml( + src_path, + data_type, + width, + height, + geo_transform, + crs_wkt, + off_x, + off_y, + band_indices, +) -> bytes: # Create root element - root = Element('VRTDataset') - root.set('rasterXSize', str(width)) - root.set('rasterYSize', str(height)) + root = Element("VRTDataset") + root.set("rasterXSize", str(width)) + root.set("rasterYSize", str(height)) # Add CRS - if crs_wkt is not None and crs_wkt != '': - srs = SubElement(root, 'SRS') + if crs_wkt is not None and crs_wkt != "": + srs = SubElement(root, "SRS") srs.text = crs_wkt # Add GeoTransform - gt = SubElement(root, 'GeoTransform') + gt = SubElement(root, "GeoTransform") gt.text = geo_transform # Add bands for i, band_index in enumerate(band_indices, start=1): - band = SubElement(root, 'VRTRasterBand') - band.set('dataType', data_type) - band.set('band', str(i)) + band = SubElement(root, "VRTRasterBand") + band.set("dataType", data_type) + band.set("band", str(i)) # Add source - source = SubElement(band, 'SimpleSource') - src_prop = SubElement(source, 'SourceFilename') + source = SubElement(band, "SimpleSource") + src_prop = SubElement(source, "SourceFilename") src_prop.text = src_path # Set source properties - SubElement(source, 'SourceBand').text = str(band_index + 1) - SubElement(source, 'SrcRect', {'xOff': str(off_x), 'yOff': str(off_y), 'xSize': str(width), 'ySize': str(height)}) - SubElement(source, 'DstRect', {'xOff': '0', 'yOff': '0', 'xSize': str(width), 
'ySize': str(height)})
+        SubElement(source, "SourceBand").text = str(band_index + 1)
+        SubElement(
+            source,
+            "SrcRect",
+            {
+                "xOff": str(off_x),
+                "yOff": str(off_y),
+                "xSize": str(width),
+                "ySize": str(height),
+            },
+        )
+        SubElement(
+            source,
+            "DstRect",
+            {"xOff": "0", "yOff": "0", "xSize": str(width), "ySize": str(height)},
+        )
 
     # Generate pretty XML
-    xml_bytes = tostring(root, encoding='utf-8')
+    xml_bytes = tostring(root, encoding="utf-8")
 
     return xml_bytes
 
@@ -96,8 +119,14 @@ class SedonaRaster(ABC):
     _affine_trans: AffineTransform
     _crs_wkt: str
 
-    def __init__(self, width: int, height: int, bands_meta: List[SampleDimension],
-                 affine_trans: AffineTransform, crs_wkt: str):
+    def __init__(
+        self,
+        width: int,
+        height: int,
+        bands_meta: List[SampleDimension],
+        affine_trans: AffineTransform,
+        crs_wkt: str,
+    ):
         self._width = width
         self._height = height
         self._bands_meta = bands_meta
@@ -131,9 +160,7 @@ def affine_trans(self) -> AffineTransform:
 
     @abstractmethod
     def as_numpy(self) -> np.ndarray:
-        """Get the bands data as an numpy array in CHW layout
-
-        """
+        """Get the bands data as a numpy array in CHW layout"""
         raise NotImplementedError()
 
     def as_numpy_masked(self) -> np.ndarray:
@@ -150,9 +177,7 @@ def as_numpy_masked(self) -> np.ndarray:
 
     @abstractmethod
     def as_rasterio(self) -> DatasetReader:
-        """Retrieve the raster as an rasterio DatasetReader
-
-        """
+        """Retrieve the raster as a rasterio DatasetReader"""
         raise NotImplementedError()
 
     @abstractmethod
@@ -178,9 +203,15 @@ class InDbSedonaRaster(SedonaRaster):
     rasterio_memfile: Optional[MemoryFile]
     rasterio_dataset_reader: Optional[DatasetReader]
 
-    def __init__(self, width: int, height: int, bands_meta: List[SampleDimension],
-                 affine_trans: AffineTransform, crs_wkt: str,
-                 awt_raster: AWTRaster):
+    def __init__(
+        self,
+        width: int,
+        height: int,
+        bands_meta: List[SampleDimension],
+        affine_trans: AffineTransform,
+        crs_wkt: str,
+        awt_raster: AWTRaster,
+    ):
         super().__init__(width, height, bands_meta, affine_trans, crs_wkt)
         self.awt_raster = awt_raster
         self.rasterio_memfile = None
@@ -195,55 +226,72 @@ def as_rasterio(self) -> DatasetReader:
             return self.rasterio_dataset_reader
 
         affine = Affine.from_gdal(
-            self._affine_trans.ip_x, self._affine_trans.scale_x, self._affine_trans.skew_x,
-            self._affine_trans.ip_y, self._affine_trans.skew_y, self._affine_trans.scale_y)
+            self._affine_trans.ip_x,
+            self._affine_trans.scale_x,
+            self._affine_trans.skew_x,
+            self._affine_trans.ip_y,
+            self._affine_trans.skew_y,
+            self._affine_trans.scale_y,
+        )
 
         num_bands = len(self._bands_meta)
         data_array = np.ascontiguousarray(self.as_numpy())
         dtype = data_array.dtype
         if dtype == np.uint8:
-            data_type = 'Byte'
+            data_type = "Byte"
         elif dtype == np.int8:
-            data_type = 'Int8'
+            data_type = "Int8"
         elif dtype == np.uint16:
-            data_type = 'Uint16'
+            data_type = "Uint16"
        elif dtype == np.int16:
-            data_type = 'Int16'
+            data_type = "Int16"
         elif dtype == np.uint32:
-            data_type = 'UInt32'
+            data_type = "UInt32"
         elif dtype == np.int32:
-            data_type = 'Int32'
+            data_type = "Int32"
         elif dtype == np.float32:
-            data_type = 'Float32'
+            data_type = "Float32"
         elif dtype == np.float64:
-            data_type = 'Float64'
+            data_type = "Float64"
         elif dtype == np.int64:
-            data_type = 'Int64'
+            data_type = "Int64"
         elif dtype == np.uint64:
-            data_type = 'Uint64'
+            data_type = "Uint64"
         else:
             raise RuntimeError("unknown dtype: " + str(dtype))
 
         arr_if = data_array.__array_interface__
-        data_pointer = arr_if['data'][0]
-        geotransform = 
(f"{self._affine_trans.ip_x}/{self._affine_trans.scale_x}/{self._affine_trans.skew_x}/" + - f"{self._affine_trans.ip_y}/{self._affine_trans.skew_y}/{self._affine_trans.scale_y}") + data_pointer = arr_if["data"][0] + geotransform = ( + f"{self._affine_trans.ip_x}/{self._affine_trans.scale_x}/{self._affine_trans.skew_x}/" + + f"{self._affine_trans.ip_y}/{self._affine_trans.skew_y}/{self._affine_trans.scale_y}" + ) # FIXME: GDAL 3.6 shipped with rasterio does not support # SPATIALREFERENCE parameter, so we have to workaround this issue in a # hacky way. If newer versions of rasterio bundle GDAL 3.7 then this # won't be a problem. See https://gdal.org/drivers/raster/mem.html - desc = (f"MEM:::DATAPOINTER={data_pointer},PIXELS={self._width},LINES={self._height},BANDS={num_bands}," + - f"DATATYPE={data_type},GEOTRANSFORM={geotransform}") + desc = ( + f"MEM:::DATAPOINTER={data_pointer},PIXELS={self._width},LINES={self._height},BANDS={num_bands}," + + f"DATATYPE={data_type},GEOTRANSFORM={geotransform}" + ) # construct a VRT to wrap this MEM dataset, with SRS set up properly vrt_xml = _generate_vrt_xml( - desc, data_type, self._width, self._height, geotransform.replace('/', ','), self._crs_wkt, - 0, 0, list(range(num_bands))) + desc, + data_type, + self._width, + self._height, + geotransform.replace("/", ","), + self._crs_wkt, + 0, + 0, + list(range(num_bands)), + ) # dataset = _rasterio_open(desc, driver="MEM") - self.rasterio_memfile = MemoryFile(vrt_xml, ext='.vrt') - dataset = self.rasterio_memfile.open(driver='VRT') + self.rasterio_memfile = MemoryFile(vrt_xml, ext=".vrt") + dataset = self.rasterio_memfile.open(driver="VRT") # XXX: dataset does not copy the data held by data_array, so we set # data_array as a property of dataset to make sure that the lifetime of @@ -254,8 +302,8 @@ def as_rasterio(self) -> DatasetReader: def close(self): if self.rasterio_dataset_reader is not None: - self.rasterio_dataset_reader.close() - self.rasterio_dataset_reader = None + self.rasterio_dataset_reader.close() + self.rasterio_dataset_reader = None if self.rasterio_memfile is not None: self.rasterio_memfile.close() self.rasterio_memfile = None diff --git a/python/sedona/raster_utils/SedonaUtils.py b/python/sedona/raster_utils/SedonaUtils.py index 4dbf9e3999..ef0926ab6d 100644 --- a/python/sedona/raster_utils/SedonaUtils.py +++ b/python/sedona/raster_utils/SedonaUtils.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. 
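The `(size,) = struct.unpack(...)` rewrites in raster_serde.py above are purely mechanical: black parenthesizes single-element tuple targets so the trailing comma cannot be overlooked, but the unpacking semantics are identical. A minimal standalone sketch of the equivalence (illustrative only, not part of the patch):

    import struct

    buf = struct.pack("=i", 42)  # four bytes holding one native-endian int32

    size, = struct.unpack("=i", buf)     # old style: bare one-element tuple target
    (size2,) = struct.unpack("=i", buf)  # black's style: parentheses make the tuple explicit

    assert size == size2 == 42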
+ class SedonaUtils: @classmethod def display_image(cls, df): from IPython.display import display, HTML + display(HTML(df.toPandas().to_html(escape=False))) diff --git a/python/sedona/register/geo_registrator.py b/python/sedona/register/geo_registrator.py index 2df8474e6a..8b80f3b17d 100644 --- a/python/sedona/register/geo_registrator.py +++ b/python/sedona/register/geo_registrator.py @@ -50,11 +50,13 @@ def registerAll(cls, spark: SparkSession) -> bool: def register(cls, spark: SparkSession): return spark._jvm.SedonaSQLRegistrator.registerAll(spark._jsparkSession) + class PackageImporter: @staticmethod def import_jvm_lib(jvm) -> bool: from sedona.core.utils import ImportedJvmLib + """ Imports all the specified methods and functions in jvm :param jvm: Jvm gateway from py4j diff --git a/python/sedona/register/java_libs.py b/python/sedona/register/java_libs.py index 6489cc4f39..931488d917 100644 --- a/python/sedona/register/java_libs.py +++ b/python/sedona/register/java_libs.py @@ -25,7 +25,9 @@ class SedonaJvmLib(Enum): KNNQuery = "org.apache.sedona.core.spatialOperator.KNNQuery" RangeQuery = "org.apache.sedona.core.spatialOperator.RangeQuery" Envelope = "org.locationtech.jts.geom.Envelope" - GeoSerializerData = "org.apache.sedona.python.wrapper.adapters.GeoSparkPythonConverter" + GeoSerializerData = ( + "org.apache.sedona.python.wrapper.adapters.GeoSparkPythonConverter" + ) GeometryAdapter = "org.apache.sedona.python.wrapper.adapters.GeometryAdapter" PointRDD = "org.apache.sedona.core.spatialRDD.PointRDD" PolygonRDD = "org.apache.sedona.core.spatialRDD.PolygonRDD" @@ -35,19 +37,27 @@ class SedonaJvmLib(Enum): SpatialRDD = "org.apache.sedona.core.spatialRDD.SpatialRDD" FileDataSplitter = "org.apache.sedona.common.enums.FileDataSplitter" GeoJsonReader = "org.apache.sedona.core.formatMapper.GeoJsonReader" - ShapeFileReader = "org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader" + ShapeFileReader = ( + "org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader" + ) SedonaSQLRegistrator = "org.apache.sedona.sql.utils.SedonaSQLRegistrator" StorageLevel = "org.apache.spark.storage.StorageLevel" GridType = "org.apache.sedona.core.enums.GridType" IndexType = "org.apache.sedona.core.enums.IndexType" AdapterWrapper = "org.apache.sedona.python.wrapper.utils.PythonAdapterWrapper" WktReader = "org.apache.sedona.core.formatMapper.WktReader" - RawJvmIndexRDDSetter = "org.apache.sedona.python.wrapper.adapters.RawJvmIndexRDDSetter" - SpatialObjectLoaderAdapter = "org.apache.sedona.python.wrapper.adapters.SpatialObjectLoaderAdapter" + RawJvmIndexRDDSetter = ( + "org.apache.sedona.python.wrapper.adapters.RawJvmIndexRDDSetter" + ) + SpatialObjectLoaderAdapter = ( + "org.apache.sedona.python.wrapper.adapters.SpatialObjectLoaderAdapter" + ) WkbReader = "org.apache.sedona.core.formatMapper.WkbReader" EnvelopeAdapter = "org.apache.sedona.python.wrapper.adapters.EnvelopeAdapter" PythonConverter = "org.apache.sedona.python.wrapper.adapters.PythonConverter" - PythonRddToJavaRDDAdapter = "org.apache.sedona.python.wrapper.adapters.PythonRddToJavaRDDAdapter" + PythonRddToJavaRDDAdapter = ( + "org.apache.sedona.python.wrapper.adapters.PythonRddToJavaRDDAdapter" + ) st_constructors = "org.apache.spark.sql.sedona_sql.expressions.st_constructors" st_functions = "org.apache.spark.sql.sedona_sql.expressions.st_functions" st_predicates = "org.apache.spark.sql.sedona_sql.expressions.st_predicates" @@ -55,7 +65,7 @@ class SedonaJvmLib(Enum): SedonaContext = "org.apache.sedona.spark.SedonaContext" 
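The parenthesized values in the SedonaJvmLib enum above are black's line-length wrapping at work: an assignment whose right-hand side would exceed the default 88-column limit gets its value wrapped in parentheses. The parentheses only group the expression and do not change it, as this small sketch (independent of the patch) demonstrates:

    PLAIN = "org.apache.sedona.python.wrapper.adapters.SpatialObjectLoaderAdapter"
    WRAPPED = (
        "org.apache.sedona.python.wrapper.adapters.SpatialObjectLoaderAdapter"
    )

    assert PLAIN == WRAPPED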
@classmethod - def from_str(cls, geo_lib: str) -> 'SedonaJvmLib': + def from_str(cls, geo_lib: str) -> "SedonaJvmLib": try: lib = getattr(cls, geo_lib.upper()) except AttributeError: diff --git a/python/sedona/spark/SedonaContext.py b/python/sedona/spark/SedonaContext.py index cda98a60f5..5cba5df624 100644 --- a/python/sedona/spark/SedonaContext.py +++ b/python/sedona/spark/SedonaContext.py @@ -46,5 +46,6 @@ def builder(cls) -> SparkSession.builder: This method is needed when the user wants to manually configure Sedona :return: SparkSession.builder """ - return SparkSession.builder.config("spark.serializer", KryoSerializer.getName).\ - config("spark.kryo.registrator", SedonaKryoRegistrator.getName) + return SparkSession.builder.config( + "spark.serializer", KryoSerializer.getName + ).config("spark.kryo.registrator", SedonaKryoRegistrator.getName) diff --git a/python/sedona/sql/__init__.py b/python/sedona/sql/__init__.py index 5211b7fd22..c8e002fb0a 100644 --- a/python/sedona/sql/__init__.py +++ b/python/sedona/sql/__init__.py @@ -30,9 +30,11 @@ from sedona.sql.st_predicates import * __all__ = ( - [name for name, obj in inspect.getmembers(sys.modules[__name__])] # get expected values from the modules - + st_predicates.__all__ - + st_constructors.__all__ - + st_functions.__all__ - + st_aggregates.__all__ + [ + name for name, obj in inspect.getmembers(sys.modules[__name__]) + ] # get expected values from the modules + + st_predicates.__all__ + + st_constructors.__all__ + + st_functions.__all__ + + st_aggregates.__all__ ) diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py index aa78563a2b..9d7c2f47b5 100644 --- a/python/sedona/sql/dataframe_api.py +++ b/python/sedona/sql/dataframe_api.py @@ -48,13 +48,21 @@ def _convert_argument_to_java_column(arg: Any) -> Column: return f.lit(arg)._jc -def call_sedona_function(object_name: str, function_name: str, args: Union[Any, Tuple[Any]]) -> Column: +def call_sedona_function( + object_name: str, function_name: str, args: Union[Any, Tuple[Any]] +) -> Column: spark = SparkSession.getActiveSession() if spark is None: - raise ValueError("No active spark session was detected. Unable to call sedona function.") + raise ValueError( + "No active spark session was detected. Unable to call sedona function." 
+    )
 
     # apparently a Column is an Iterable so we need to check for it explicitly
-    if (not isinstance(args, Iterable)) or isinstance(args, str) or isinstance(args, Column):
+    if (
+        (not isinstance(args, Iterable))
+        or isinstance(args, str)
+        or isinstance(args, Column)
+    ):
         args = [args]
 
     args = map(_convert_argument_to_java_column, args)
@@ -78,7 +86,10 @@ def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
     """
     # in 3.8 there is a much nicer way to do this with typing.get_origin
     # we have to be a bit messy until we drop support for 3.7
-    if isinstance(annotated_type, typing._GenericAlias) and annotated_type.__origin__._name == "Union":
+    if (
+        isinstance(annotated_type, typing._GenericAlias)
+        and annotated_type.__origin__._name == "Union"
+    ):
         # again, there is a really nice method for this in 3.8: typing.get_args
         valid_types = annotated_type.__args__
     else:
@@ -88,7 +99,7 @@
 
 
 def _strip_extra_from_class_name(class_name):
-    return class_name[len("<class '"):-2].split(".")[-1]
+    return class_name[len("<class '") : -2].split(".")[-1]
 
 
 def _get_readable_name_for_type(type: Type) -> str:
@@ -118,7 +129,9 @@ def _get_bound_arguments(f: Callable, *args, **kwargs) -> Mapping[str, Any]:
     return bound_args
 
 
-def _check_bound_arguments(bound_args: Mapping[str, Any], type_annotations: List[Type], function_name: str) -> None:
+def _check_bound_arguments(
+    bound_args: Mapping[str, Any], type_annotations: List[Type], function_name: str
+) -> None:
     """Check bound arguments against type annotations and raise a ValueError if any do not match.
 
     :param bound_args: Bound arguments to check.
@@ -132,8 +145,12 @@ def _check_bound_arguments(bound_args: Mapping[str, Any], type_annotations: List
     for bound_arg_name, bound_arg_value in bound_args.arguments.items():
         annotated_type = type_annotations[bound_arg_name]
         valid_type_list = _get_type_list(annotated_type)
-        if not any([isinstance(bound_arg_value, valid_type) for valid_type in valid_type_list]):
-            raise ValueError(f"Incorrect argument type: {bound_arg_name} for {function_name} should be {_get_readable_name_for_type(annotated_type)} but received {_strip_extra_from_class_name(str(type(bound_arg_value)))}.")
+        if not any(
+            [isinstance(bound_arg_value, valid_type) for valid_type in valid_type_list]
+        ):
+            raise ValueError(
+                f"Incorrect argument type: {bound_arg_name} for {function_name} should be {_get_readable_name_for_type(annotated_type)} but received {_strip_extra_from_class_name(str(type(bound_arg_value)))}."
+            )
 
 
 def validate_argument_types(f: Callable) -> Callable:
@@ -146,12 +163,19 @@ def validate_argument_types(f: Callable) -> Callable:
 
     :return: f wrapped with type validation checks.
:rtype: Callable """ + def validated_function(*args, **kwargs) -> Column: # all arguments are Columns or strings are always legal, so only check types when one of the arguments is not a column - if not all([isinstance(x, Column) or isinstance(x, str) for x in itertools.chain(args, kwargs.values())]): + if not all( + [ + isinstance(x, Column) or isinstance(x, str) + for x in itertools.chain(args, kwargs.values()) + ] + ): bound_args = _get_bound_arguments(f, *args, **kwargs) type_annotations = typing.get_type_hints(f) _check_bound_arguments(bound_args, type_annotations, f.__name__) return f(*args, **kwargs) + return functools.update_wrapper(validated_function, f) diff --git a/python/sedona/sql/exceptions.py b/python/sedona/sql/exceptions.py index 3b79970666..6bbbe30a7a 100644 --- a/python/sedona/sql/exceptions.py +++ b/python/sedona/sql/exceptions.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. + class GeometryUnavailableException(Exception): def __init__(self, message): diff --git a/python/sedona/sql/st_aggregates.py b/python/sedona/sql/st_aggregates.py index 67315ecf8e..5632594b48 100644 --- a/python/sedona/sql/st_aggregates.py +++ b/python/sedona/sql/st_aggregates.py @@ -21,7 +21,11 @@ from pyspark.sql import Column -from sedona.sql.dataframe_api import ColumnOrName, call_sedona_function, validate_argument_types +from sedona.sql.dataframe_api import ( + ColumnOrName, + call_sedona_function, + validate_argument_types, +) _call_aggregate_function = partial(call_sedona_function, "st_aggregates") @@ -63,5 +67,8 @@ def ST_Union_Aggr(geometry: ColumnOrName) -> Column: # Automatically populate __all__ -__all__ = [name for name, obj in inspect.getmembers(sys.modules[__name__]) - if inspect.isfunction(obj)] +__all__ = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isfunction(obj) +] diff --git a/python/sedona/sql/st_constructors.py b/python/sedona/sql/st_constructors.py index 594a36c662..b4e9aa6865 100644 --- a/python/sedona/sql/st_constructors.py +++ b/python/sedona/sql/st_constructors.py @@ -22,14 +22,21 @@ from pyspark.sql import Column -from sedona.sql.dataframe_api import ColumnOrName, ColumnOrNameOrNumber, call_sedona_function, validate_argument_types +from sedona.sql.dataframe_api import ( + ColumnOrName, + ColumnOrNameOrNumber, + call_sedona_function, + validate_argument_types, +) _call_constructor_function = partial(call_sedona_function, "st_constructors") @validate_argument_types -def ST_GeomFromGeoHash(geohash: ColumnOrName, precision: Union[ColumnOrName, int]) -> Column: +def ST_GeomFromGeoHash( + geohash: ColumnOrName, precision: Union[ColumnOrName, int] +) -> Column: """Generate a geometry column from a geohash column at a specified precision. :param geohash: Geohash string column to generate from. @@ -41,8 +48,11 @@ def ST_GeomFromGeoHash(geohash: ColumnOrName, precision: Union[ColumnOrName, int """ return _call_constructor_function("ST_GeomFromGeoHash", (geohash, precision)) + @validate_argument_types -def ST_PointFromGeoHash(geohash: ColumnOrName, precision: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_PointFromGeoHash( + geohash: ColumnOrName, precision: Optional[Union[ColumnOrName, int]] = None +) -> Column: """Generate a point column from a geohash column at a specified precision. :param geohash: Geohash string column to generate from. 
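To make the behavior of the reformatted validate_argument_types decorator concrete: the wrapper skips validation entirely when every argument is a Column or a str (so column names stay cheap), and otherwise checks each bound argument against the wrapped function's type hints, raising the ValueError assembled in _check_bound_arguments. A hedged sketch with a toy function (the function below is hypothetical, not part of Sedona; assumes the sedona package and pyspark are installed):

    from typing import Union

    from sedona.sql.dataframe_api import validate_argument_types

    @validate_argument_types
    def double_radius(radius: Union[int, float]) -> float:
        # Toy body; real Sedona wrappers return pyspark Columns instead.
        return float(radius) * 2.0

    print(double_radius(3.5))  # 7.0 -- a float satisfies Union[int, float]
    try:
        double_radius([1, 2])  # a list fails the type check before the body runs
    except ValueError as exc:
        print(exc)  # Incorrect argument type: radius for double_radius should be ...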
@@ -94,7 +104,9 @@ def ST_GeomFromKML(kml_string: ColumnOrName) -> Column: @validate_argument_types -def ST_GeomFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_GeomFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate a geometry column from a Well-Known Text (WKT) string column. This is an alias of ST_GeomFromWKT. @@ -107,8 +119,11 @@ def ST_GeomFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = No return _call_constructor_function("ST_GeomFromText", args) + @validate_argument_types -def ST_GeometryFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_GeometryFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate a geometry column from a Well-Known Text (WKT) string column. This is an alias of ST_GeomFromWKT. @@ -133,6 +148,7 @@ def ST_GeomFromWKB(wkb: ColumnOrName) -> Column: """ return _call_constructor_function("ST_GeomFromWKB", wkb) + @validate_argument_types def ST_GeomFromEWKB(wkb: ColumnOrName) -> Column: """Generate a geometry column from a Well-Known Binary (WKB) binary column. @@ -146,7 +162,9 @@ def ST_GeomFromEWKB(wkb: ColumnOrName) -> Column: @validate_argument_types -def ST_GeomFromWKT(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_GeomFromWKT( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate a geometry column from a Well-Known Text (WKT) string column. This is an alias of ST_GeomFromText. @@ -159,6 +177,7 @@ def ST_GeomFromWKT(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = Non return _call_constructor_function("ST_GeomFromWKT", args) + @validate_argument_types def ST_GeomFromEWKT(ewkt: ColumnOrName) -> Column: """Generate a geometry column from a OGC Extended Well-Known Text (WKT) string column. @@ -211,8 +230,14 @@ def ST_Point(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber) -> Column: """ return _call_constructor_function("ST_Point", (x, y)) + @validate_argument_types -def ST_PointZ(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: ColumnOrNameOrNumber, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_PointZ( + x: ColumnOrNameOrNumber, + y: ColumnOrNameOrNumber, + z: ColumnOrNameOrNumber, + srid: Optional[ColumnOrNameOrNumber] = None, +) -> Column: """Generates a 3D point geometry column from numeric values. :param x: Either a number or numeric column representing the X coordinate of a point. @@ -229,8 +254,14 @@ def ST_PointZ(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: ColumnOrNameO args = (x, y, z) if srid is None else (x, y, z, srid) return _call_constructor_function("ST_PointZ", args) + @validate_argument_types -def ST_PointM(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, m: ColumnOrNameOrNumber, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_PointM( + x: ColumnOrNameOrNumber, + y: ColumnOrNameOrNumber, + m: ColumnOrNameOrNumber, + srid: Optional[ColumnOrNameOrNumber] = None, +) -> Column: """Generates a 3D point geometry column from numeric values. :param x: Either a number or numeric column representing the X coordinate of a point. 
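A subtlety worth keeping in mind while reading these constructor hunks: several wrappers build their arguments as `args = (wkt) if srid is None else (wkt, srid)`, and `(wkt)` is not a one-element tuple; without a trailing comma the parentheses are plain grouping, so args is just the bare value. That is harmless here only because call_sedona_function (reformatted earlier in this patch) wraps any argument that is not a non-string iterable in a list itself. A short sketch of the distinction:

    wkt = "POINT (1 2)"

    grouped = (wkt)   # parentheses only group; still a str
    tupled = (wkt,)   # the trailing comma is what makes the tuple

    assert grouped == "POINT (1 2)"
    assert tupled == ("POINT (1 2)",)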
@@ -247,8 +278,15 @@ def ST_PointM(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, m: ColumnOrNameO
     args = (x, y, m) if srid is None else (x, y, m, srid)
     return _call_constructor_function("ST_PointM", args)
 
+
 @validate_argument_types
-def ST_PointZM(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: ColumnOrNameOrNumber, m: ColumnOrNameOrNumber,srid: Optional[ColumnOrNameOrNumber] = None) -> Column:
+def ST_PointZM(
+    x: ColumnOrNameOrNumber,
+    y: ColumnOrNameOrNumber,
+    z: ColumnOrNameOrNumber,
+    m: ColumnOrNameOrNumber,
+    srid: Optional[ColumnOrNameOrNumber] = None,
+) -> Column:
     """Generates a 4D point geometry column from numeric values.
 
     :param x: Either a number or numeric column representing the X coordinate of a point.
@@ -267,6 +305,7 @@ def ST_PointZM(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: ColumnOrName
     args = (x, y, z, m) if srid is None else (x, y, z, m, srid)
     return _call_constructor_function("ST_PointZM", args)
 
+
 @validate_argument_types
 def ST_PointFromText(coords: ColumnOrName, delimiter: ColumnOrName) -> Column:
     """Generate a point geometry column from coordinates separated by a delimiter and stored
@@ -281,8 +320,11 @@ def ST_PointFromText(coords: ColumnOrName, delimiter: ColumnOrName) -> Column:
     """
     return _call_constructor_function("ST_PointFromText", (coords, delimiter))
 
+
 @validate_argument_types
-def ST_PointFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column:
+def ST_PointFromWKB(
+    wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None
+) -> Column:
     """Generate a Point geometry column from a Well-Known Binary (WKB) binary column.
 
     :param wkb: WKB binary column to generate from.
@@ -295,8 +337,11 @@ def ST_PointFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = No
     args = (wkb) if srid is None else (wkb, srid)
     return _call_constructor_function("ST_PointFromWKB", args)
 
+
 @validate_argument_types
-def ST_LineFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column:
+def ST_LineFromWKB(
+    wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None
+) -> Column:
     """Generate a Line geometry column from a Well-Known Binary (WKB) binary column.
 
     :param wkb: WKB binary column to generate from.
@@ -309,8 +354,11 @@ def ST_LineFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = Non
     args = (wkb) if srid is None else (wkb, srid)
     return _call_constructor_function("ST_LineFromWKB", args)
 
+
 @validate_argument_types
-def ST_LinestringFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column:
+def ST_LinestringFromWKB(
+    wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None
+) -> Column:
     """Generate a Line geometry column from a Well-Known Binary (WKB) binary column.
 
     :param wkb: WKB binary column to generate from.
@@ -323,23 +371,32 @@ def ST_LinestringFromWKB(wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber]
     args = (wkb) if srid is None else (wkb, srid)
     return _call_constructor_function("ST_LinestringFromWKB", args)
 
+
 @validate_argument_types
-def ST_MakePointM(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, m: ColumnOrNameOrNumber) -> Column:
+def ST_MakePointM(
+    x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, m: ColumnOrNameOrNumber
+) -> Column:
     """Generate 3D M Point geometry.
 
-    :param x: Either a number or numeric column representing the X coordinate of a point.
-    :type x: ColumnOrNameOrNumber
-    :param y: Either a number or numeric column representing the Y coordinate of a point.
- :type y: ColumnOrNameOrNumber - :param m: Either a number or numeric column representing the M coordinate of a point - :type m: ColumnOrNameOrNumber - :return: Point geometry column generated from the coordinate values. - :rtype: Column - """ + :param x: Either a number or numeric column representing the X coordinate of a point. + :type x: ColumnOrNameOrNumber + :param y: Either a number or numeric column representing the Y coordinate of a point. + :type y: ColumnOrNameOrNumber + :param m: Either a number or numeric column representing the M coordinate of a point + :type m: ColumnOrNameOrNumber + :return: Point geometry column generated from the coordinate values. + :rtype: Column + """ return _call_constructor_function("ST_MakePointM", (x, y, m)) + @validate_argument_types -def ST_MakePoint(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: Optional[ColumnOrNameOrNumber] = None, m: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_MakePoint( + x: ColumnOrNameOrNumber, + y: ColumnOrNameOrNumber, + z: Optional[ColumnOrNameOrNumber] = None, + m: Optional[ColumnOrNameOrNumber] = None, +) -> Column: """Generate a 2D, 3D Z or 4D ZM Point geometry. If z is None then a 2D point is generated. This function doesn't support M coordinates for creating a 4D ZM Point in Dataframe API. @@ -361,8 +418,15 @@ def ST_MakePoint(x: ColumnOrNameOrNumber, y: ColumnOrNameOrNumber, z: Optional[C args = args + (m,) return _call_constructor_function("ST_MakePoint", (args)) + @validate_argument_types -def ST_MakeEnvelope(min_x: ColumnOrNameOrNumber, min_y: ColumnOrNameOrNumber, max_x: ColumnOrNameOrNumber, max_y: ColumnOrNameOrNumber, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_MakeEnvelope( + min_x: ColumnOrNameOrNumber, + min_y: ColumnOrNameOrNumber, + max_x: ColumnOrNameOrNumber, + max_y: ColumnOrNameOrNumber, + srid: Optional[ColumnOrNameOrNumber] = None, +) -> Column: """Generate a polygon geometry column from the minimum and maximum coordinates of an envelope with an option to add SRID :param min_x: Minimum X coordinate for the envelope. @@ -384,8 +448,14 @@ def ST_MakeEnvelope(min_x: ColumnOrNameOrNumber, min_y: ColumnOrNameOrNumber, ma return _call_constructor_function("ST_MakeEnvelope", args) + @validate_argument_types -def ST_PolygonFromEnvelope(min_x: ColumnOrNameOrNumber, min_y: ColumnOrNameOrNumber, max_x: ColumnOrNameOrNumber, max_y: ColumnOrNameOrNumber) -> Column: +def ST_PolygonFromEnvelope( + min_x: ColumnOrNameOrNumber, + min_y: ColumnOrNameOrNumber, + max_x: ColumnOrNameOrNumber, + max_y: ColumnOrNameOrNumber, +) -> Column: """Generate a polygon geometry column from the minimum and maximum coordinates of an envelope. :param min_x: Minimum X coordinate for the envelope. @@ -399,7 +469,9 @@ def ST_PolygonFromEnvelope(min_x: ColumnOrNameOrNumber, min_y: ColumnOrNameOrNum :return: Polygon geometry column representing the envelope described by the coordinate bounds. 
:rtype: Column """ - return _call_constructor_function("ST_PolygonFromEnvelope", (min_x, min_y, max_x, max_y)) + return _call_constructor_function( + "ST_PolygonFromEnvelope", (min_x, min_y, max_x, max_y) + ) @validate_argument_types @@ -416,8 +488,11 @@ def ST_PolygonFromText(coords: ColumnOrName, delimiter: ColumnOrName) -> Column: """ return _call_constructor_function("ST_PolygonFromText", (coords, delimiter)) + @validate_argument_types -def ST_MPolyFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_MPolyFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate multiPolygon geometry from a multiPolygon WKT representation. :param wkt: multiPolygon WKT string column to generate from. @@ -429,8 +504,11 @@ def ST_MPolyFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = N return _call_constructor_function("ST_MPolyFromText", args) + @validate_argument_types -def ST_MLineFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_MLineFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate multiLineString geometry from a multiLineString WKT representation. :param wkt: multiLineString WKT string column to generate from. @@ -442,8 +520,11 @@ def ST_MLineFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = N return _call_constructor_function("ST_MLineFromText", args) + @validate_argument_types -def ST_MPointFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_MPointFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate MultiPoint geometry from a MultiPoint WKT representation. :param wkt: MultiPoint WKT string column to generate from. @@ -457,8 +538,11 @@ def ST_MPointFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = return _call_constructor_function("ST_MPointFromText", args) + @validate_argument_types -def ST_GeomCollFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None) -> Column: +def ST_GeomCollFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: """Generate GeometryCollection geometry from a GeometryCollection WKT representation. :param wkt: GeometryCollection WKT string column to generate from. @@ -474,5 +558,8 @@ def ST_GeomCollFromText(wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] # Automatically populate __all__ -__all__ = [name for name, obj in inspect.getmembers(sys.modules[__name__]) - if inspect.isfunction(obj)] +__all__ = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isfunction(obj) +] diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index fccc1c5653..05d90c2e80 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -22,11 +22,17 @@ from pyspark.sql import Column -from sedona.sql.dataframe_api import call_sedona_function, ColumnOrName, ColumnOrNameOrNumber, validate_argument_types +from sedona.sql.dataframe_api import ( + call_sedona_function, + ColumnOrName, + ColumnOrNameOrNumber, + validate_argument_types, +) _call_st_function = partial(call_sedona_function, "st_functions") + @validate_argument_types def GeometryType(geometry: ColumnOrName): """Return the type of the geometry as a string. 
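For readers skimming the st_functions.py hunks below: every wrapper simply forwards through _call_st_function to the JVM expression of the same name, so the reformatting never touches call semantics. An illustrative, untested sketch of composing two of them, assuming a Sedona-enabled SparkSession is already available as `sedona`:

    from sedona.sql.st_constructors import ST_Point
    from sedona.sql.st_functions import ST_Buffer

    df = sedona.createDataFrame([(1.0, 2.0)], ["x", "y"])

    # Build a point from the two numeric columns, then buffer it by 0.5.
    df.select(ST_Buffer(ST_Point("x", "y"), 0.5).alias("geom")).show(truncate=False)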
@@ -39,6 +45,7 @@ def GeometryType(geometry: ColumnOrName): """ return _call_st_function("GeometryType", geometry) + @validate_argument_types def ST_3DDistance(a: ColumnOrName, b: ColumnOrName) -> Column: """Calculate the 3-dimensional minimum Cartesian distance between two geometry columns. @@ -52,8 +59,13 @@ def ST_3DDistance(a: ColumnOrName, b: ColumnOrName) -> Column: """ return _call_st_function("ST_3DDistance", (a, b)) + @validate_argument_types -def ST_AddMeasure(geom: ColumnOrName, measureStart: Union[ColumnOrName, float], measureEnd: Union[ColumnOrName, float]) -> Column: +def ST_AddMeasure( + geom: ColumnOrName, + measureStart: Union[ColumnOrName, float], + measureEnd: Union[ColumnOrName, float], +) -> Column: """Interpolate measure values with the provided start and end points and return the result geometry. :param geom: Geometry column to use in the calculation. @@ -67,8 +79,13 @@ def ST_AddMeasure(geom: ColumnOrName, measureStart: Union[ColumnOrName, float], """ return _call_st_function("ST_AddMeasure", (geom, measureStart, measureEnd)) + @validate_argument_types -def ST_AddPoint(line_string: ColumnOrName, point: ColumnOrName, index: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_AddPoint( + line_string: ColumnOrName, + point: ColumnOrName, + index: Optional[Union[ColumnOrName, int]] = None, +) -> Column: """Add a point to either the end of a linestring or a specified index. If index is not provided then point will be added to the end of line_string. @@ -96,6 +113,7 @@ def ST_Area(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_Area", geometry) + @validate_argument_types def ST_AreaSpheroid(geometry: ColumnOrName) -> Column: """Calculate the area of a geometry using WGS84 spheroid. @@ -132,8 +150,11 @@ def ST_AsEWKB(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_AsEWKB", geometry) + @validate_argument_types -def ST_AsHEXEWKB(geometry: ColumnOrName, endian: Optional[ColumnOrName] = None) -> Column: +def ST_AsHEXEWKB( + geometry: ColumnOrName, endian: Optional[ColumnOrName] = None +) -> Column: """Generate the Extended Well-Known Binary representation of a geometry as Hex string. :param geometry: Geometry to generate EWKB for. @@ -160,7 +181,9 @@ def ST_AsEWKT(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_AsGeoJSON(geometry: ColumnOrName, type: Optional[Union[ColumnOrName, str]] = None) -> Column: +def ST_AsGeoJSON( + geometry: ColumnOrName, type: Optional[Union[ColumnOrName, str]] = None +) -> Column: """Generate the GeoJSON style representation of a geometry column. :param geometry: Geometry column to generate GeoJSON for. @@ -234,6 +257,7 @@ def ST_BestSRID(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_BestSRID", geometry) + @validate_argument_types def ST_ShiftLongitude(geometry: ColumnOrName) -> Column: """Shifts longitudes between -180..0 degrees to 180..360 degrees and vice versa. @@ -245,6 +269,7 @@ def ST_ShiftLongitude(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_ShiftLongitude", geometry) + @validate_argument_types def ST_Boundary(geometry: ColumnOrName) -> Column: """Calculate the closure of the combinatorial boundary of a geometry column. 
@@ -258,7 +283,12 @@ def ST_Boundary(geometry: ColumnOrName) -> Column:
 
 
 @validate_argument_types
-def ST_Buffer(geometry: ColumnOrName, buffer: ColumnOrNameOrNumber, useSpheroid: Optional[Union[ColumnOrName, bool]] = None, parameters: Optional[Union[ColumnOrName, str]] = None) -> Column:
+def ST_Buffer(
+    geometry: ColumnOrName,
+    buffer: ColumnOrNameOrNumber,
+    useSpheroid: Optional[Union[ColumnOrName, bool]] = None,
+    parameters: Optional[Union[ColumnOrName, str]] = None,
+) -> Column:
     """Calculate a geometry that represents all points whose distance from the
     input geometry column is equal to or less than a given amount.
 
@@ -320,7 +350,9 @@ def ST_Collect(*geometries: ColumnOrName) -> Column:
 
 
 @validate_argument_types
-def ST_CollectionExtract(collection: ColumnOrName, geom_type: Optional[Union[ColumnOrName, int]] = None) -> Column:
+def ST_CollectionExtract(
+    collection: ColumnOrName, geom_type: Optional[Union[ColumnOrName, int]] = None
+) -> Column:
     """Extract a specific type of geometry from a geometry collection column
     as a multi-geometry column.
 
@@ -351,7 +383,11 @@ def ST_ClosestPoint(a: ColumnOrName, b: ColumnOrName) -> Column:
 
 
 @validate_argument_types
-def ST_ConcaveHull(geometry: ColumnOrName, pctConvex: Union[ColumnOrName, float], allowHoles: Optional[Union[ColumnOrName, bool]] = None) -> Column:
+def ST_ConcaveHull(
+    geometry: ColumnOrName,
+    pctConvex: Union[ColumnOrName, float],
+    allowHoles: Optional[Union[ColumnOrName, bool]] = None,
+) -> Column:
     """Generate the concave hull of a geometry column.
 
     :param geometry: Geometry column to generate a concave hull for.
@@ -363,9 +399,14 @@ def ST_ConcaveHull(geometry: ColumnOrName, pctConvex: Union[ColumnOrName, float]
     :return: Concave hull of geometry as a geometry column.
     :rtype: Column
     """
-    args = (geometry, pctConvex) if allowHoles is None else (geometry, pctConvex, allowHoles)
+    args = (
+        (geometry, pctConvex)
+        if allowHoles is None
+        else (geometry, pctConvex, allowHoles)
+    )
     return _call_st_function("ST_ConcaveHull", args)
 
+
 @validate_argument_types
 def ST_ConvexHull(geometry: ColumnOrName) -> Column:
     """Generate the convex hull of a geometry column.
@@ -377,6 +418,7 @@ def ST_ConvexHull(geometry: ColumnOrName) -> Column:
     """
     return _call_st_function("ST_ConvexHull", geometry)
 
+
 @validate_argument_types
 def ST_CrossesDateLine(a: ColumnOrName) -> Column:
     """Check whether geometry a crosses the International Date Line.
@@ -388,6 +430,7 @@ def ST_CrossesDateLine(a: ColumnOrName) -> Column:
     """
     return _call_st_function("ST_CrossesDateLine", (a))
 
+
 @validate_argument_types
 def ST_Dimension(geometry: ColumnOrName):
     """Calculate the inherent dimension of a geometry column.
@@ -399,6 +442,7 @@ def ST_Dimension(geometry: ColumnOrName):
     """
     return _call_st_function("ST_Dimension", geometry)
 
+
 @validate_argument_types
 def ST_Difference(a: ColumnOrName, b: ColumnOrName) -> Column:
     """Calculate the difference of two geometry columns. This difference
@@ -428,6 +472,7 @@ def ST_Distance(a: ColumnOrName, b: ColumnOrName) -> Column:
     """
     return _call_st_function("ST_Distance", (a, b))
 
+
 @validate_argument_types
 def ST_DistanceSpheroid(a: ColumnOrName, b: ColumnOrName) -> Column:
     """Calculate the geodesic distance between two geometry columns using WGS84 spheroid.
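The multi-line `args = (...)` rewrite in ST_ConcaveHull above is the same rule black applies to long signatures, extended to conditional expressions: when the one-liner would overflow, the expression is wrapped in parentheses and its branches stacked. Grouping parentheses around a conditional expression never change its value, as this standalone sketch shows:

    pct_convex, allow_holes = 0.8, None

    inline = ("geom", pct_convex) if allow_holes is None else ("geom", pct_convex, allow_holes)
    wrapped = (
        ("geom", pct_convex)
        if allow_holes is None
        else ("geom", pct_convex, allow_holes)
    )

    assert inline == wrapped == ("geom", 0.8)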
@@ -441,8 +486,13 @@ def ST_DistanceSpheroid(a: ColumnOrName, b: ColumnOrName) -> Column:
     """
     return _call_st_function("ST_DistanceSpheroid", (a, b))
 
+
 @validate_argument_types
-def ST_DistanceSphere(a: ColumnOrName, b: ColumnOrName, radius: Optional[Union[ColumnOrName, float]] = 6371008.0) -> Column:
+def ST_DistanceSphere(
+    a: ColumnOrName,
+    b: ColumnOrName,
+    radius: Optional[Union[ColumnOrName, float]] = 6371008.0,
+) -> Column:
     """Calculate the haversine/great-circle distance between two geometry columns using a given radius.
 
     :param a: Geometry column to use in the calculation.
@@ -456,6 +506,7 @@ def ST_DistanceSphere(a: ColumnOrName, b: ColumnOrName, radius: Optional[Union[C
     """
     return _call_st_function("ST_DistanceSphere", (a, b, radius))
 
+
 @validate_argument_types
 def ST_Dump(geometry: ColumnOrName) -> Column:
     """Returns an array of geometries that are members of a multi-geometry
@@ -508,7 +559,12 @@ def ST_Envelope(geometry: ColumnOrName) -> Column:
 
 
 @validate_argument_types
-def ST_Expand(geometry: ColumnOrName, deltaX_uniformDelta: Union[ColumnOrName, float], deltaY: Optional[Union[ColumnOrName, float]] = None, deltaZ: Optional[Union[ColumnOrName, float]] = None) -> Column:
+def ST_Expand(
+    geometry: ColumnOrName,
+    deltaX_uniformDelta: Union[ColumnOrName, float],
+    deltaY: Optional[Union[ColumnOrName, float]] = None,
+    deltaZ: Optional[Union[ColumnOrName, float]] = None,
+) -> Column:
     """Expand the given geometry column by a constant unit in each direction
 
     :param geometry: Geometry column to expand.
@@ -570,7 +626,11 @@ def ST_Force_2D(geometry: ColumnOrName) -> Column:
 
 
 @validate_argument_types
-def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int], seed:Optional[Union[ColumnOrName, int]] = None) -> Column:
+def ST_GeneratePoints(
+    geometry: ColumnOrName,
+    numPoints: Union[ColumnOrName, int],
+    seed: Optional[Union[ColumnOrName, int]] = None,
+) -> Column:
     """Generate random points in given geometry.
 
     :param geometry: Geometry column to generate random points in.
@@ -587,6 +647,7 @@ def ST_GeneratePoints(geometry: ColumnOrName, numPoints: Union[ColumnOrName, int
 
     return _call_st_function("ST_GeneratePoints", args)
 
+
 @validate_argument_types
 def ST_GeoHash(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> Column:
     """Return the geohash of a geometry column at a given precision level.
@@ -600,10 +661,14 @@ def ST_GeoHash(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> C
     """
     return _call_st_function("ST_GeoHash", (geometry, precision))
 
+
 @validate_argument_types
-def ST_GeometricMedian(geometry: ColumnOrName, tolerance: Optional[Union[ColumnOrName, float]] = 1e-6,
-                       max_iter: Optional[Union[ColumnOrName, int]] = 1000,
-                       fail_if_not_converged: Optional[Union[ColumnOrName, bool]] = False) -> Column:
+def ST_GeometricMedian(
+    geometry: ColumnOrName,
+    tolerance: Optional[Union[ColumnOrName, float]] = 1e-6,
+    max_iter: Optional[Union[ColumnOrName, int]] = 1000,
+    fail_if_not_converged: Optional[Union[ColumnOrName, bool]] = False,
+) -> Column:
     """Computes the approximate geometric median of a MultiPoint geometry using the Weiszfeld algorithm.
     The geometric median provides a centrality measure that is less sensitive to outlier points than the centroid.
The algorithm will iterate until the distance change between successive iterations is less than the @@ -656,7 +721,9 @@ def ST_GeometryType(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_H3CellDistance(cell1: Union[ColumnOrName, int], cell2: Union[ColumnOrName, int]) -> Column: +def ST_H3CellDistance( + cell1: Union[ColumnOrName, int], cell2: Union[ColumnOrName, int] +) -> Column: """Cover Geometry with H3 Cells and return a List of Long type cell IDs :param cell: start cell :type cell: long @@ -670,7 +737,11 @@ def ST_H3CellDistance(cell1: Union[ColumnOrName, int], cell2: Union[ColumnOrName @validate_argument_types -def ST_H3CellIDs(geometry: ColumnOrName, level: Union[ColumnOrName, int], full_cover: Union[ColumnOrName, bool]) -> Column: +def ST_H3CellIDs( + geometry: ColumnOrName, + level: Union[ColumnOrName, int], + full_cover: Union[ColumnOrName, bool], +) -> Column: """Cover Geometry with H3 Cells and return a List of Long type cell IDs :param geometry: Geometry column to generate cell IDs :type geometry: ColumnOrName @@ -686,7 +757,11 @@ def ST_H3CellIDs(geometry: ColumnOrName, level: Union[ColumnOrName, int], full_c @validate_argument_types -def ST_H3KRing(cell: Union[ColumnOrName, int], k: Union[ColumnOrName, int], exact_ring: Union[ColumnOrName, bool]) -> Column: +def ST_H3KRing( + cell: Union[ColumnOrName, int], + k: Union[ColumnOrName, int], + exact_ring: Union[ColumnOrName, bool], +) -> Column: """Cover Geometry with H3 Cells and return a List of Long type cell IDs :param cell: original cell :type cell: long @@ -804,7 +879,9 @@ def ST_IsSimple(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_IsValid(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_IsValid( + geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None +) -> Column: """Check if a geometry is well formed. :param geometry: Geometry column to check in. @@ -818,8 +895,11 @@ def ST_IsValid(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] args = (geometry,) if flag is None else (geometry, flag) return _call_st_function("ST_IsValid", args) + @validate_argument_types -def ST_IsValidDetail(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_IsValidDetail( + geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None +) -> Column: """ Return a row of valid, reason and location. valid defines the validity of geometry, reason defines the reason why it is not valid and location defines the location where it is not valid @@ -835,6 +915,7 @@ def ST_IsValidDetail(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, args = (geometry,) if flag is None else (geometry, flag) return _call_st_function("ST_IsValidDetail", args) + @validate_argument_types def ST_IsValidTrajectory(geometry: ColumnOrName) -> Column: """ @@ -848,8 +929,11 @@ def ST_IsValidTrajectory(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_IsValidTrajectory", (geometry)) + @validate_argument_types -def ST_IsValidReason(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_IsValidReason( + geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, int]] = None +) -> Column: """ Provides a text description of why a geometry is not valid or states that it is valid. An optional flag parameter can be provided for additional options. 
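ST_IsValid and ST_IsValidReason (both reformatted above) pair naturally when auditing geometry quality. An illustrative, untested sketch, again assuming a Sedona-enabled SparkSession named `sedona`; the bowtie polygon is a classic self-intersection:

    from sedona.sql.st_constructors import ST_GeomFromWKT
    from sedona.sql.st_functions import ST_IsValid, ST_IsValidReason

    bowtie = "POLYGON ((0 0, 2 2, 2 0, 0 2, 0 0))"
    df = sedona.createDataFrame([(bowtie,)], ["wkt"])

    df.select(
        ST_IsValid(ST_GeomFromWKT("wkt")).alias("valid"),
        ST_IsValidReason(ST_GeomFromWKT("wkt")).alias("reason"),
    ).show(truncate=False)
    # Expect valid = false with a reason pointing at the self-intersection.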
@@ -864,6 +948,7 @@ def ST_IsValidReason(geometry: ColumnOrName, flag: Optional[Union[ColumnOrName, args = (geometry,) if flag is None else (geometry, flag) return _call_st_function("ST_IsValidReason", args) + @validate_argument_types def ST_Length(geometry: ColumnOrName) -> Column: """Calculate the length of a linestring geometry. @@ -875,6 +960,7 @@ def ST_Length(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_Length", geometry) + @validate_argument_types def ST_Length2D(geometry: ColumnOrName) -> Column: """Calculate the length of a linestring geometry. @@ -886,6 +972,7 @@ def ST_Length2D(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_Length2D", geometry) + @validate_argument_types def ST_LengthSpheroid(geometry: ColumnOrName) -> Column: """Calculate the perimeter of a geometry using WGS84 spheroid. @@ -897,6 +984,7 @@ def ST_LengthSpheroid(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_LengthSpheroid", geometry) + @validate_argument_types def ST_LineFromMultiPoint(geometry: ColumnOrName) -> Column: """Creates a LineString from a MultiPoint geometry. @@ -910,7 +998,9 @@ def ST_LineFromMultiPoint(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_LineInterpolatePoint(geometry: ColumnOrName, fraction: ColumnOrNameOrNumber) -> Column: +def ST_LineInterpolatePoint( + geometry: ColumnOrName, fraction: ColumnOrNameOrNumber +) -> Column: """Calculate a point that is interpolated along a linestring. :param geometry: Linestring geometry column to interpolate from. @@ -922,6 +1012,7 @@ def ST_LineInterpolatePoint(geometry: ColumnOrName, fraction: ColumnOrNameOrNumb """ return _call_st_function("ST_LineInterpolatePoint", (geometry, fraction)) + @validate_argument_types def ST_LineLocatePoint(linestring: ColumnOrName, point: ColumnOrName) -> Column: """Returns a double between 0 and 1 representing the location of the closest point on a LineString to the given Point, as a fraction of 2d line length. @@ -950,7 +1041,11 @@ def ST_LineMerge(multi_line_string: ColumnOrName) -> Column: @validate_argument_types -def ST_LineSubstring(line_string: ColumnOrName, start_fraction: ColumnOrNameOrNumber, end_fraction: ColumnOrNameOrNumber) -> Column: +def ST_LineSubstring( + line_string: ColumnOrName, + start_fraction: ColumnOrNameOrNumber, + end_fraction: ColumnOrNameOrNumber, +) -> Column: """Generate a substring of a linestring geometry column. :param line_string: Linestring geometry column to generate from. @@ -962,10 +1057,17 @@ def ST_LineSubstring(line_string: ColumnOrName, start_fraction: ColumnOrNameOrNu :return: Smaller linestring that runs from start_fraction to end_fraction of line_string as a linestring geometry column, will be null if either start_fraction or end_fraction are outside the interval [0, 1]. :rtype: Column """ - return _call_st_function("ST_LineSubstring", (line_string, start_fraction, end_fraction)) + return _call_st_function( + "ST_LineSubstring", (line_string, start_fraction, end_fraction) + ) + @validate_argument_types -def ST_LocateAlong(geom: ColumnOrName, measure: Union[ColumnOrName, float], offset: Optional[Union[ColumnOrName, float]] = None) -> Column: +def ST_LocateAlong( + geom: ColumnOrName, + measure: Union[ColumnOrName, float], + offset: Optional[Union[ColumnOrName, float]] = None, +) -> Column: """return locations along a measure geometry that have the given measure value. 
:param geom: @@ -980,6 +1082,7 @@ def ST_LocateAlong(geom: ColumnOrName, measure: Union[ColumnOrName, float], offs args = (geom, measure) if offset is None else (geom, measure, offset) return _call_st_function("ST_LocateAlong", args) + @validate_argument_types def ST_LongestLine(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: """Compute the longest line between the two geometries @@ -993,6 +1096,7 @@ def ST_LongestLine(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: """ return _call_st_function("ST_LongestLine", (geom1, geom2)) + @validate_argument_types def ST_HasZ(geom: ColumnOrName) -> Column: """Check whether geometry has Z coordinate @@ -1004,6 +1108,7 @@ def ST_HasZ(geom: ColumnOrName) -> Column: """ return _call_st_function("ST_HasZ", geom) + @validate_argument_types def ST_HasM(geom: ColumnOrName) -> Column: """Check whether geometry has M coordinate @@ -1027,6 +1132,7 @@ def ST_M(geom: ColumnOrName) -> Column: """ return _call_st_function("ST_M", geom) + @validate_argument_types def ST_MMin(geom: ColumnOrName) -> Column: """Return the minimum M coordinate of a geometry. @@ -1038,6 +1144,7 @@ def ST_MMin(geom: ColumnOrName) -> Column: """ return _call_st_function("ST_MMin", geom) + @validate_argument_types def ST_MMax(geom: ColumnOrName) -> Column: """Return the maximum M coordinate of a geometry. @@ -1064,6 +1171,7 @@ def ST_MakeLine(geom1: ColumnOrName, geom2: Optional[ColumnOrName] = None) -> Co args = (geom1,) if geom2 is None else (geom1, geom2) return _call_st_function("ST_MakeLine", args) + @validate_argument_types def ST_Points(geometry: ColumnOrName) -> Column: """Creates a MultiPoint geometry consisting of all the coordinates of the input geometry @@ -1075,6 +1183,7 @@ def ST_Points(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_Points", (geometry)) + @validate_argument_types def ST_Polygon(line_string: ColumnOrName, srid: ColumnOrNameOrNumber) -> Column: """Create a polygon built from the given LineString and sets the spatial reference system from the srid. @@ -1088,6 +1197,7 @@ def ST_Polygon(line_string: ColumnOrName, srid: ColumnOrNameOrNumber) -> Column: """ return _call_st_function("ST_Polygon", (line_string, srid)) + @validate_argument_types def ST_Polygonize(geometry: ColumnOrName) -> Column: """Generates a GeometryCollection composed of polygons that are formed from the linework of a set of input geometries. @@ -1099,8 +1209,11 @@ def ST_Polygonize(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_Polygonize", (geometry)) + @validate_argument_types -def ST_MakePolygon(line_string: ColumnOrName, holes: Optional[ColumnOrName] = None) -> Column: +def ST_MakePolygon( + line_string: ColumnOrName, holes: Optional[ColumnOrName] = None +) -> Column: """Create a polygon geometry from a linestring describing the exterior ring as well as an array of linestrings describing holes. :param line_string: Closed linestring geometry column that describes the exterior ring of the polygon. @@ -1115,7 +1228,9 @@ def ST_MakePolygon(line_string: ColumnOrName, holes: Optional[ColumnOrName] = No @validate_argument_types -def ST_MakeValid(geometry: ColumnOrName, keep_collapsed: Optional[Union[ColumnOrName, bool]] = None) -> Column: +def ST_MakeValid( + geometry: ColumnOrName, keep_collapsed: Optional[Union[ColumnOrName, bool]] = None +) -> Column: """Convert an invalid geometry in a geometry column into a valid geometry. :param geometry: Geometry column that contains the invalid geometry. 
@@ -1133,13 +1248,14 @@ def ST_MakeValid(geometry: ColumnOrName, keep_collapsed: Optional[Union[ColumnOr def ST_MaximumInscribedCircle(geometry: ColumnOrName) -> Column: """Finds the largest circle that is contained within a geometry, or which does not overlap any lines and points - :param geometry: - :type geometry: ColumnOrName - :return: Row of center point, nearest point and radius - :rtype: Column - """ + :param geometry: + :type geometry: ColumnOrName + :return: Row of center point, nearest point and radius + :rtype: Column + """ return _call_st_function("ST_MaximumInscribedCircle", geometry) + @validate_argument_types def ST_MaxDistance(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: """Calculate the maximum distance between two furthest points in the geometries @@ -1153,6 +1269,7 @@ def ST_MaxDistance(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: """ return _call_st_function("ST_MaxDistance", (geom1, geom2)) + @validate_argument_types def ST_MinimumClearance(geometry: ColumnOrName) -> Column: """Calculate the minimum clearance between two vertices @@ -1164,6 +1281,7 @@ def ST_MinimumClearance(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_MinimumClearance", geometry) + @validate_argument_types def ST_MinimumClearanceLine(geometry: ColumnOrName) -> Column: """Calculate the minimum clearance Linestring between two vertices @@ -1175,8 +1293,11 @@ def ST_MinimumClearanceLine(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_MinimumClearanceLine", geometry) + @validate_argument_types -def ST_MinimumBoundingCircle(geometry: ColumnOrName, quadrant_segments: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_MinimumBoundingCircle( + geometry: ColumnOrName, quadrant_segments: Optional[Union[ColumnOrName, int]] = None +) -> Column: """Generate the minimum bounding circle that contains a geometry. :param geometry: Geometry column to generate minimum bounding circles for. @@ -1273,6 +1394,7 @@ def ST_NumInteriorRings(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_NumInteriorRings", geometry) + @validate_argument_types def ST_NumInteriorRing(geometry: ColumnOrName) -> Column: """Return the number of interior rings contained in a polygon geometry. @@ -1312,7 +1434,9 @@ def ST_PointOnSurface(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_ReducePrecision(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> Column: +def ST_ReducePrecision( + geometry: ColumnOrName, precision: Union[ColumnOrName, int] +) -> Column: """Reduce the precision of the coordinates in geometry to a specified number of decimal places. :param geometry: Geometry to reduce the precision of. @@ -1326,7 +1450,9 @@ def ST_ReducePrecision(geometry: ColumnOrName, precision: Union[ColumnOrName, in @validate_argument_types -def ST_RemovePoint(line_string: ColumnOrName, index: Union[ColumnOrName, int]) -> Column: +def ST_RemovePoint( + line_string: ColumnOrName, index: Union[ColumnOrName, int] +) -> Column: """Remove the specified point (0-th based) for a linestring geometry column. :param line_string: Linestring geometry column to remove the point from. 
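A semantic footnote to two functions touched just below: ST_SetSRID only stamps an SRID onto the geometry's metadata, while ST_Transform actually reprojects the coordinates. Note also that plain strings passed to these wrappers are interpreted as column names (see _convert_argument_to_java_column earlier in the patch), so CRS codes must be wrapped in lit(). An illustrative, untested sketch assuming a Sedona-enabled SparkSession named `sedona`:

    from pyspark.sql.functions import lit

    from sedona.sql.st_constructors import ST_GeomFromWKT
    from sedona.sql.st_functions import ST_SetSRID, ST_Transform

    df = sedona.createDataFrame([("POINT (12.49 41.89)",)], ["wkt"])

    df.select(
        # Same coordinates, now tagged with SRID 4326.
        ST_SetSRID(ST_GeomFromWKT("wkt"), 4326).alias("tagged"),
        # Coordinates actually reprojected from lon/lat to Web Mercator meters.
        ST_Transform(
            ST_SetSRID(ST_GeomFromWKT("wkt"), 4326), lit("EPSG:4326"), lit("EPSG:3857")
        ).alias("reprojected"),
    ).show(truncate=False)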
@@ -1340,7 +1466,9 @@ def ST_RemovePoint(line_string: ColumnOrName, index: Union[ColumnOrName, int]) - @validate_argument_types -def ST_RemoveRepeatedPoints(geom: ColumnOrName, tolerance: Optional[Union[ColumnOrName, float]] = None) -> Column: +def ST_RemoveRepeatedPoints( + geom: ColumnOrName, tolerance: Optional[Union[ColumnOrName, float]] = None +) -> Column: """Removes duplicate coordinates from a geometry, optionally removing those within a specified distance tolerance. @param geom: Geometry with repeated points @@ -1394,7 +1522,9 @@ def ST_S2ToGeom(cells: Union[ColumnOrName, list]) -> Column: @validate_argument_types -def ST_SetPoint(line_string: ColumnOrName, index: Union[ColumnOrName, int], point: ColumnOrName) -> Column: +def ST_SetPoint( + line_string: ColumnOrName, index: Union[ColumnOrName, int], point: ColumnOrName +) -> Column: """Replace a point in a linestring. :param line_string: Linestring geometry column which contains the point to be replaced. @@ -1422,8 +1552,11 @@ def ST_SetSRID(geometry: ColumnOrName, srid: Union[ColumnOrName, int]) -> Column """ return _call_st_function("ST_SetSRID", (geometry, srid)) + @validate_argument_types -def ST_Snap(input: ColumnOrName, reference: ColumnOrName, tolerance: Union[ColumnOrName, float]) -> Column: +def ST_Snap( + input: ColumnOrName, reference: ColumnOrName, tolerance: Union[ColumnOrName, float] +) -> Column: """Snaps input Geometry to reference Geometry controlled by distance tolerance. :param input: Geometry @@ -1481,7 +1614,9 @@ def ST_StartPoint(line_string: ColumnOrName) -> Column: @validate_argument_types -def ST_SubDivide(geometry: ColumnOrName, max_vertices: Union[ColumnOrName, int]) -> Column: +def ST_SubDivide( + geometry: ColumnOrName, max_vertices: Union[ColumnOrName, int] +) -> Column: """Subdivide a geometry into an array of geometries with at maximum number of vertices in each. :param geometry: Geometry column to subdivide. @@ -1495,7 +1630,9 @@ def ST_SubDivide(geometry: ColumnOrName, max_vertices: Union[ColumnOrName, int]) @validate_argument_types -def ST_SubDivideExplode(geometry: ColumnOrName, max_vertices: Union[ColumnOrName, int]) -> Column: +def ST_SubDivideExplode( + geometry: ColumnOrName, max_vertices: Union[ColumnOrName, int] +) -> Column: """Same as ST_SubDivide except also explode the generated array into multiple rows. :param geometry: Geometry column to subdivide. @@ -1509,7 +1646,9 @@ def ST_SubDivideExplode(geometry: ColumnOrName, max_vertices: Union[ColumnOrName @validate_argument_types -def ST_SimplifyPreserveTopology(geometry: ColumnOrName, distance_tolerance: ColumnOrNameOrNumber) -> Column: +def ST_SimplifyPreserveTopology( + geometry: ColumnOrName, distance_tolerance: ColumnOrNameOrNumber +) -> Column: """Simplify a geometry within a specified tolerance while preserving topological relationships. :param geometry: Geometry column to simplify. @@ -1519,10 +1658,15 @@ def ST_SimplifyPreserveTopology(geometry: ColumnOrName, distance_tolerance: Colu :return: Simplified geometry as a geometry column. 
:rtype: Column """ - return _call_st_function("ST_SimplifyPreserveTopology", (geometry, distance_tolerance)) + return _call_st_function( + "ST_SimplifyPreserveTopology", (geometry, distance_tolerance) + ) + @validate_argument_types -def ST_SimplifyVW(geometry: ColumnOrName, distance_tolerance: ColumnOrNameOrNumber) -> Column: +def ST_SimplifyVW( + geometry: ColumnOrName, distance_tolerance: ColumnOrNameOrNumber +) -> Column: """Simplify a geometry using Visvalingam-Whyatt algorithm within a specified tolerance while preserving topological relationships. :param geometry: Geometry column to simplify. @@ -1536,7 +1680,11 @@ def ST_SimplifyVW(geometry: ColumnOrName, distance_tolerance: ColumnOrNameOrNumb @validate_argument_types -def ST_SimplifyPolygonHull(geometry: ColumnOrName, vertexFactor: ColumnOrNameOrNumber, isOuter: Optional[Union[ColumnOrName, bool]] = None) -> Column: +def ST_SimplifyPolygonHull( + geometry: ColumnOrName, + vertexFactor: ColumnOrNameOrNumber, + isOuter: Optional[Union[ColumnOrName, bool]] = None, +) -> Column: """Simplify a geometry using Visvalingam-Whyatt algorithm within a specified tolerance while preserving topological relationships. :param geometry: Geometry column to simplify. @@ -1546,10 +1694,15 @@ def ST_SimplifyPolygonHull(geometry: ColumnOrName, vertexFactor: ColumnOrNameOrN :return: Simplified geometry as a geometry column. :rtype: Column """ - args = (geometry, vertexFactor) if isOuter is None else (geometry, vertexFactor, isOuter) + args = ( + (geometry, vertexFactor) + if isOuter is None + else (geometry, vertexFactor, isOuter) + ) return _call_st_function("ST_SimplifyPolygonHull", args) + @validate_argument_types def ST_Split(input: ColumnOrName, blade: ColumnOrName) -> Column: """Split input geometry by the blade geometry. @@ -1579,7 +1732,12 @@ def ST_SymDifference(a: ColumnOrName, b: ColumnOrName) -> Column: @validate_argument_types -def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: Optional[Union[ColumnOrName, str]] = None, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column: +def ST_Transform( + geometry: ColumnOrName, + source_crs: ColumnOrName, + target_crs: Optional[Union[ColumnOrName, str]] = None, + disable_error: Optional[Union[ColumnOrName, bool]] = None, +) -> Column: """Convert a geometry from one coordinate system to another coordinate system. :param geometry: Geometry column to convert. @@ -1607,6 +1765,7 @@ def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: O args = (geometry, source_crs, target_crs, disable_error) return _call_st_function("ST_Transform", args) + @validate_argument_types def ST_TriangulatePolygon(geom: ColumnOrName) -> Column: """Computes the constrained Delaunay triangulation of polygons. Holes and Multipolygons are supported. @@ -1618,17 +1777,19 @@ def ST_TriangulatePolygon(geom: ColumnOrName) -> Column: """ return _call_st_function("ST_TriangulatePolygon", geom) + @validate_argument_types def ST_UnaryUnion(geom: ColumnOrName) -> Column: """Calculate the unary union of a geometry - :param geom: Geometry to do union - :type geom: ColumnOrName - :return: Geometry representing the unary union of geom as a geometry column. - :rtype: Column - """ + :param geom: Geometry to do union + :type geom: ColumnOrName + :return: Geometry representing the unary union of geom as a geometry column. 
+ :rtype: Column + """ return _call_st_function("ST_UnaryUnion", geom) + @validate_argument_types def ST_Union(a: ColumnOrName, b: Optional[ColumnOrName] = None) -> Column: """Calculate the union of two geometries. @@ -1732,6 +1893,7 @@ def ST_Z(point: ColumnOrName) -> Column: """ return _call_st_function("ST_Z", point) + @validate_argument_types def ST_Zmflag(geom: ColumnOrName) -> Column: """Return the code indicating the ZM coordinate dimension of a geometry @@ -1744,6 +1906,7 @@ def ST_Zmflag(geom: ColumnOrName) -> Column: """ return _call_st_function("ST_Zmflag", geom) + @validate_argument_types def ST_ZMax(geometry: ColumnOrName) -> Column: """Return the maximum Z coordinate of a geometry. @@ -1755,6 +1918,7 @@ def ST_ZMax(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_ZMax", geometry) + @validate_argument_types def ST_ZMin(geometry: ColumnOrName) -> Column: """Return the minimum Z coordinate of a geometry. @@ -1765,6 +1929,8 @@ def ST_ZMin(geometry: ColumnOrName) -> Column: :rtype: Column """ return _call_st_function("ST_ZMin", geometry) + + @validate_argument_types def ST_NumPoints(geometry: ColumnOrName) -> Column: """Return the number of points in a LineString @@ -1775,8 +1941,11 @@ def ST_NumPoints(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_NumPoints", geometry) + @validate_argument_types -def ST_Force3D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column: +def ST_Force3D( + geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0 +) -> Column: """ Return a geometry with a 3D coordinate of value 'zValue' forced upon it. No change happens if the geometry is already 3D :param zValue: Optional value of z coordinate to be potentially added, default value is 0.0 @@ -1786,8 +1955,11 @@ def ST_Force3D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, floa args = (geometry, zValue) return _call_st_function("ST_Force3D", args) + @validate_argument_types -def ST_Force3DM(geometry: ColumnOrName, mValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column: +def ST_Force3DM( + geometry: ColumnOrName, mValue: Optional[Union[ColumnOrName, float]] = 0.0 +) -> Column: """ Return a geometry with a 3D coordinate of value 'mValue' forced upon it. No change happens if the geometry is already 3D :param mValue: Optional value of m coordinate to be potentially added, default value is 0.0 @@ -1797,8 +1969,11 @@ def ST_Force3DM(geometry: ColumnOrName, mValue: Optional[Union[ColumnOrName, flo args = (geometry, mValue) return _call_st_function("ST_Force3DM", args) + @validate_argument_types -def ST_Force3DZ(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column: +def ST_Force3DZ( + geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0 +) -> Column: """ Return a geometry with a 3D coordinate of value 'zValue' forced upon it. 
No change happens if the geometry is already 3D :param zValue: Optional value of z coordinate to be potentially added, default value is 0.0 @@ -1808,9 +1983,13 @@ def ST_Force3DZ(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, flo args = (geometry, zValue) return _call_st_function("ST_Force3DZ", args) + @validate_argument_types -def ST_Force4D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0, - mValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column: +def ST_Force4D( + geometry: ColumnOrName, + zValue: Optional[Union[ColumnOrName, float]] = 0.0, + mValue: Optional[Union[ColumnOrName, float]] = 0.0, +) -> Column: """ Return a geometry with a 4D coordinate of value 'zValue' and mValue forced upon it. No change happens if the geometry is already 4D, if geometry either has z or m, it will not change the existing z or m value. @@ -1822,6 +2001,7 @@ def ST_Force4D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, floa args = (geometry, zValue, mValue) return _call_st_function("ST_Force4D", args) + @validate_argument_types def ST_ForceCollection(geometry: ColumnOrName) -> Column: """ @@ -1832,6 +2012,7 @@ def ST_ForceCollection(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_ForceCollection", geometry) + @validate_argument_types def ST_ForcePolygonCW(geometry: ColumnOrName) -> Column: """ @@ -1841,6 +2022,7 @@ def ST_ForcePolygonCW(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_ForcePolygonCW", geometry) + @validate_argument_types def ST_ForceRHR(geometry: ColumnOrName) -> Column: """ @@ -1850,6 +2032,7 @@ def ST_ForceRHR(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_ForceRHR", geometry) + @validate_argument_types def ST_NRings(geometry: ColumnOrName) -> Column: """ @@ -1858,8 +2041,15 @@ def ST_NRings(geometry: ColumnOrName) -> Column: :return: Number of exterior rings + interior rings (if any) for the given Polygon or MultiPolygon """ return _call_st_function("ST_NRings", geometry) + + @validate_argument_types -def ST_Translate(geometry: ColumnOrName, deltaX: Union[ColumnOrName, float], deltaY: Union[ColumnOrName, float], deltaZ: Optional[Union[ColumnOrName, float]] = 0.0) -> Column: +def ST_Translate( + geometry: ColumnOrName, + deltaX: Union[ColumnOrName, float], + deltaY: Union[ColumnOrName, float], + deltaZ: Optional[Union[ColumnOrName, float]] = 0.0, +) -> Column: """ Returns the geometry with x, y and z (if present) coordinates offset by given deltaX, deltaY, and deltaZ values. :param geometry: Geometry column whose coordinates are to be translated. @@ -1871,8 +2061,13 @@ def ST_Translate(geometry: ColumnOrName, deltaX: Union[ColumnOrName, float], del args = (geometry, deltaX, deltaY, deltaZ) return _call_st_function("ST_Translate", args) + @validate_argument_types -def ST_VoronoiPolygons(geometry: ColumnOrName, tolerance: Optional[Union[ColumnOrName, float]] = 0.0, extendTo: Optional[ColumnOrName] = None) -> Column: +def ST_VoronoiPolygons( + geometry: ColumnOrName, + tolerance: Optional[Union[ColumnOrName, float]] = 0.0, + extendTo: Optional[ColumnOrName] = None, +) -> Column: """ ST_VoronoiPolygons computes a two-dimensional Voronoi diagram from the vertices of the supplied geometry. The result is a GeometryCollection of Polygons that covers an envelope larger than the extent of the input vertices. 
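A short caller-side sketch of the transform and force-dimension wrappers above, under the same assumptions (a Sedona-enabled session and a DataFrame `df` whose "geom" column is here taken to be stored in EPSG:4326):

from pyspark.sql import functions as F
from sedona.sql import st_functions as stf

projected = df.select(
    # The CRS arguments are ColumnOrName, so literal strings must be wrapped
    # in F.lit(); a bare Python string would be read as a column name.
    stf.ST_Transform("geom", F.lit("EPSG:4326"), F.lit("EPSG:3857")).alias("merc"),
    stf.ST_Force3D("geom", 10.0).alias("geom_z10"),  # zValue defaults to 0.0
)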
@@ -1887,6 +2082,7 @@ def ST_VoronoiPolygons(geometry: ColumnOrName, tolerance: Optional[Union[ColumnO args = (geometry, tolerance, extendTo) return _call_st_function("ST_VoronoiPolygons", args) + @validate_argument_types def ST_FrechetDistance(g1: ColumnOrName, g2: ColumnOrName) -> Column: """ @@ -1900,11 +2096,23 @@ def ST_FrechetDistance(g1: ColumnOrName, g2: ColumnOrName) -> Column: args = (g1, g2) return _call_st_function("ST_FrechetDistance", args) + @validate_argument_types -def ST_Affine(geometry: ColumnOrName, a: Union[ColumnOrName, float], b: Union[ColumnOrName, float], d: Union[ColumnOrName, float], - e: Union[ColumnOrName, float], xOff: Union[ColumnOrName, float], yOff: Union[ColumnOrName, float], c: Optional[Union[ColumnOrName, float]] = None, f: Optional[Union[ColumnOrName, float]] = None, - g: Optional[Union[ColumnOrName, float]] = None, h: Optional[Union[ColumnOrName, float]] = None, - i: Optional[Union[ColumnOrName, float]] = None, zOff: Optional[Union[ColumnOrName, float]] = None) -> Column: +def ST_Affine( + geometry: ColumnOrName, + a: Union[ColumnOrName, float], + b: Union[ColumnOrName, float], + d: Union[ColumnOrName, float], + e: Union[ColumnOrName, float], + xOff: Union[ColumnOrName, float], + yOff: Union[ColumnOrName, float], + c: Optional[Union[ColumnOrName, float]] = None, + f: Optional[Union[ColumnOrName, float]] = None, + g: Optional[Union[ColumnOrName, float]] = None, + h: Optional[Union[ColumnOrName, float]] = None, + i: Optional[Union[ColumnOrName, float]] = None, + zOff: Optional[Union[ColumnOrName, float]] = None, +) -> Column: """ Apply a 3D/2D affine transformation to the given geometry x = a * x + b * y + c * z + xOff | x = a * x + b * y + xOff @@ -1930,6 +2138,7 @@ def ST_Affine(geometry: ColumnOrName, a: Union[ColumnOrName, float], b: Union[Co args = (geometry, a, b, c, d, e, f, g, h, i, xOff, yOff, zOff) return _call_st_function("ST_Affine", args) + @validate_argument_types def ST_BoundingDiagonal(geometry: ColumnOrName) -> Column: """ @@ -1943,7 +2152,12 @@ def ST_BoundingDiagonal(geometry: ColumnOrName) -> Column: @validate_argument_types -def ST_Angle(g1: ColumnOrName, g2: ColumnOrName, g3: Optional[ColumnOrName] = None, g4: Optional[ColumnOrName] = None) -> Column: +def ST_Angle( + g1: ColumnOrName, + g2: ColumnOrName, + g3: Optional[ColumnOrName] = None, + g4: Optional[ColumnOrName] = None, +) -> Column: """ Returns the computed angle between vectors formed by given geometries in radian. Range of result is between 0 and 2 * pi. 3 Variants: @@ -1968,6 +2182,7 @@ def ST_Angle(g1: ColumnOrName, g2: ColumnOrName, g3: Optional[ColumnOrName] = No # args = (g1, g2, g3, g4) return _call_st_function("ST_Angle", args) + @validate_argument_types def ST_Degrees(angleInRadian: Union[ColumnOrName, float]) -> Column: """ @@ -1977,8 +2192,13 @@ def ST_Degrees(angleInRadian: Union[ColumnOrName, float]) -> Column: """ return _call_st_function("ST_Degrees", angleInRadian) + @validate_argument_types -def ST_DelaunayTriangles(geometry: ColumnOrName, tolerance: Optional[Union[ColumnOrName, float]] = None, flag: Optional[Union[ColumnOrName, int]] = None) -> Column: +def ST_DelaunayTriangles( + geometry: ColumnOrName, + tolerance: Optional[Union[ColumnOrName, float]] = None, + flag: Optional[Union[ColumnOrName, int]] = None, +) -> Column: """ Computes the Delaunay Triangles of the vertices of the input geometry. 
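ST_Angle returns radians and pairs naturally with ST_Degrees; a sketch under the same assumptions, with a hypothetical helper `pt` for building point literals:

from pyspark.sql import functions as F
from sedona.sql import st_constructors as stc, st_functions as stf

def pt(wkt):
    return stc.ST_GeomFromWKT(F.lit(wkt))

angles = spark.range(1).select(
    stf.ST_Degrees(
        stf.ST_Angle(pt("POINT (1 0)"), pt("POINT (0 0)"), pt("POINT (0 1)"))
    ).alias("angle_deg")  # three-point variant, converted from radians
)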
@@ -1993,15 +2213,20 @@ def ST_DelaunayTriangles(geometry: ColumnOrName, tolerance: Optional[Union[Colum """ if flag is None and tolerance is None: - args = (geometry) + args = geometry elif flag is None: args = (geometry, tolerance) else: args = (geometry, tolerance, flag) return _call_st_function("ST_DelaunayTriangles", args) + @validate_argument_types -def ST_HausdorffDistance(g1: ColumnOrName, g2: ColumnOrName, densityFrac: Optional[Union[ColumnOrName, float]] = -1) -> Column: +def ST_HausdorffDistance( + g1: ColumnOrName, + g2: ColumnOrName, + densityFrac: Optional[Union[ColumnOrName, float]] = -1, +) -> Column: """ Returns discretized (and hence approximate) hausdorff distance between two given geometries. Optionally, a distance fraction can also be provided which decreases the gap between actual and discretized hausforff distance @@ -2013,6 +2238,7 @@ def ST_HausdorffDistance(g1: ColumnOrName, g2: ColumnOrName, densityFrac: Option args = (g1, g2, densityFrac) return _call_st_function("ST_HausdorffDistance", args) + @validate_argument_types def ST_CoordDim(geometry: ColumnOrName) -> Column: """Return the number of dimensions contained in a coordinate @@ -2024,6 +2250,7 @@ def ST_CoordDim(geometry: ColumnOrName) -> Column: """ return _call_st_function("ST_CoordDim", geometry) + @validate_argument_types def ST_IsCollection(geometry: ColumnOrName) -> Column: """Check if the geometry is of GeometryCollection type. @@ -2065,8 +2292,13 @@ def ST_RotateY(geometry: ColumnOrName, angle: Union[ColumnOrName, float]) -> Col @validate_argument_types -def ST_Rotate(geometry: ColumnOrName, angle: Union[ColumnOrName, float], originX: Union[ColumnOrName, float] = None, - originY: Union[ColumnOrName, float] = None, pointOrigin: ColumnOrName = None) -> Column: +def ST_Rotate( + geometry: ColumnOrName, + angle: Union[ColumnOrName, float], + originX: Union[ColumnOrName, float] = None, + originY: Union[ColumnOrName, float] = None, + pointOrigin: ColumnOrName = None, +) -> Column: """Return a counter-clockwise rotated geometry along the specified origin. :param geometry: Geometry column or name. @@ -2093,5 +2325,8 @@ def ST_Rotate(geometry: ColumnOrName, angle: Union[ColumnOrName, float], originX # Automatically populate __all__ -__all__ = [name for name, obj in inspect.getmembers(sys.modules[__name__]) - if inspect.isfunction(obj) and name != 'GeometryType'] +__all__ = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isfunction(obj) and name != "GeometryType" +] diff --git a/python/sedona/sql/st_predicates.py b/python/sedona/sql/st_predicates.py index 32efd7c5c9..b91edf8eec 100644 --- a/python/sedona/sql/st_predicates.py +++ b/python/sedona/sql/st_predicates.py @@ -22,7 +22,11 @@ from pyspark.sql import Column from typing import Union, Optional -from sedona.sql.dataframe_api import ColumnOrName, call_sedona_function, validate_argument_types +from sedona.sql.dataframe_api import ( + ColumnOrName, + call_sedona_function, + validate_argument_types, +) _call_predicate_function = partial(call_sedona_function, "st_predicates") @@ -55,6 +59,7 @@ def ST_Crosses(a: ColumnOrName, b: ColumnOrName) -> Column: """ return _call_predicate_function("ST_Crosses", (a, b)) + @validate_argument_types def ST_Disjoint(a: ColumnOrName, b: ColumnOrName) -> Column: """Check whether two geometries are disjoint. 
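The predicates return boolean Columns, so they compose with the usual Column operators in filters and join conditions; a sketch assuming `df` carries two geometry columns "a" and "b":

from sedona.sql import st_predicates as stp

# ~ST_Disjoint(a, b) is equivalent to "a intersects b".
interacting = df.where(stp.ST_Crosses("a", "b") | ~stp.ST_Disjoint("a", "b"))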
@@ -138,8 +143,11 @@ def ST_Touches(a: ColumnOrName, b: ColumnOrName) -> Column: """ return _call_predicate_function("ST_Touches", (a, b)) + @validate_argument_types -def ST_Relate(a: ColumnOrName, b: ColumnOrName, intersectionMatrix: Optional[ColumnOrName] = None) -> Column: +def ST_Relate( + a: ColumnOrName, b: ColumnOrName, intersectionMatrix: Optional[ColumnOrName] = None +) -> Column: """Check whether two geometries are related to each other. :param a: One geometry column to check. @@ -155,6 +163,7 @@ def ST_Relate(a: ColumnOrName, b: ColumnOrName, intersectionMatrix: Optional[Col return _call_predicate_function("ST_Relate", args) + @validate_argument_types def ST_RelateMatch(matrix1: ColumnOrName, matrix2: ColumnOrName) -> Column: """Check whether two DE-9IM are related to each other. @@ -210,8 +219,14 @@ def ST_CoveredBy(a: ColumnOrName, b: ColumnOrName) -> Column: """ return _call_predicate_function("ST_CoveredBy", (a, b)) + @validate_argument_types -def ST_DWithin(a: ColumnOrName, b: ColumnOrName, distance: Union[ColumnOrName, float], use_sphere: Optional[Union[ColumnOrName, bool]] = None): +def ST_DWithin( + a: ColumnOrName, + b: ColumnOrName, + distance: Union[ColumnOrName, float], + use_sphere: Optional[Union[ColumnOrName, bool]] = None, +): """ Check if geometry a is within 'distance' units of geometry b :param a: Geometry column to check @@ -220,10 +235,21 @@ def ST_DWithin(a: ColumnOrName, b: ColumnOrName, distance: Union[ColumnOrName, f :param use_sphere: whether to use spheroid distance or euclidean distance :return: True if a is within distance units of Geometry b """ - args = (a, b, distance, use_sphere) if use_sphere is not None else (a, b, distance,) + args = ( + (a, b, distance, use_sphere) + if use_sphere is not None + else ( + a, + b, + distance, + ) + ) return _call_predicate_function("ST_DWithin", args) # Automatically populate __all__ -__all__ = [name for name, obj in inspect.getmembers(sys.modules[__name__]) - if inspect.isfunction(obj)] +__all__ = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__]) + if inspect.isfunction(obj) +] diff --git a/python/sedona/utils/abstract_parser.py b/python/sedona/utils/abstract_parser.py index 2679a26f67..fdf3a067c5 100644 --- a/python/sedona/utils/abstract_parser.py +++ b/python/sedona/utils/abstract_parser.py @@ -29,9 +29,9 @@ def name(self): raise NotImplementedError @classmethod - def serialize(cls, obj: BaseGeometry, binary_buffer: 'BinaryBuffer'): + def serialize(cls, obj: BaseGeometry, binary_buffer: "BinaryBuffer"): raise NotImplementedError("Parser has to implement serialize method") @classmethod - def deserialize(cls, bin_parser: 'BinaryParser') -> BaseGeometry: + def deserialize(cls, bin_parser: "BinaryParser") -> BaseGeometry: raise NotImplementedError("Parser has to implement deserialize method") diff --git a/python/sedona/utils/adapter.py b/python/sedona/utils/adapter.py index 5fab3de792..a7d786e5ed 100644 --- a/python/sedona/utils/adapter.py +++ b/python/sedona/utils/adapter.py @@ -33,17 +33,17 @@ class Adapter(metaclass=MultipleMeta): @staticmethod def _create_dataframe(jdf, sparkSession: SparkSession) -> DataFrame: - if hasattr(sparkSession, '_wrapped'): + if hasattr(sparkSession, "_wrapped"): # In Spark < 3.3, use the _wrapped SQLContext return DataFrame(jdf, sparkSession._wrapped) else: # In Spark >= 3.3, use the session directly return DataFrame(jdf, sparkSession) - @classmethod - def toRdd(cls, dataFrame: DataFrame) -> 'JvmSpatialRDD': + def toRdd(cls, dataFrame: DataFrame) -> 
"JvmSpatialRDD": from sedona.core.SpatialRDD.spatial_rdd import JvmSpatialRDD + sc = dataFrame._sc jvm = sc._jvm @@ -70,7 +70,9 @@ def toSpatialRdd(cls, dataFrame: DataFrame, geometryFieldName: str) -> SpatialRD return spatial_rdd @classmethod - def toSpatialRdd(cls, dataFrame: DataFrame, geometryFieldName: str, fieldNames: List) -> SpatialRDD: + def toSpatialRdd( + cls, dataFrame: DataFrame, geometryFieldName: str, fieldNames: List + ) -> SpatialRDD: """ :param dataFrame: @@ -81,7 +83,9 @@ def toSpatialRdd(cls, dataFrame: DataFrame, geometryFieldName: str, fieldNames: sc = dataFrame._sc jvm = sc._jvm - srdd = jvm.PythonAdapterWrapper.toSpatialRdd(dataFrame._jdf, geometryFieldName, fieldNames) + srdd = jvm.PythonAdapterWrapper.toSpatialRdd( + dataFrame._jdf, geometryFieldName, fieldNames + ) spatial_rdd = SpatialRDD(sc) spatial_rdd.set_srdd(srdd) @@ -89,7 +93,9 @@ def toSpatialRdd(cls, dataFrame: DataFrame, geometryFieldName: str, fieldNames: return spatial_rdd @classmethod - def toDf(cls, spatialRDD: SpatialRDD, fieldNames: List, sparkSession: SparkSession) -> DataFrame: + def toDf( + cls, spatialRDD: SpatialRDD, fieldNames: List, sparkSession: SparkSession + ) -> DataFrame: """ :param spatialRDD: @@ -100,7 +106,9 @@ def toDf(cls, spatialRDD: SpatialRDD, fieldNames: List, sparkSession: SparkSessi sc = spatialRDD._sc jvm = sc._jvm - jdf = jvm.PythonAdapterWrapper.toDf(spatialRDD._srdd, fieldNames, sparkSession._jsparkSession) + jdf = jvm.PythonAdapterWrapper.toDf( + spatialRDD._srdd, fieldNames, sparkSession._jsparkSession + ) df = Adapter._create_dataframe(jdf, sparkSession) @@ -132,13 +140,24 @@ def toDf(cls, spatialPairRDD: RDD, sparkSession: SparkSession): :return: """ spatial_pair_rdd_mapped = spatialPairRDD.map( - lambda x: [x[0].geom, *x[0].getUserData().split("\t"), x[1].geom, *x[1].getUserData().split("\t")] + lambda x: [ + x[0].geom, + *x[0].getUserData().split("\t"), + x[1].geom, + *x[1].getUserData().split("\t"), + ] ) df = sparkSession.createDataFrame(spatial_pair_rdd_mapped) return df @classmethod - def toDf(cls, spatialPairRDD: RDD, leftFieldnames: List, rightFieldNames: List, sparkSession: SparkSession): + def toDf( + cls, + spatialPairRDD: RDD, + leftFieldnames: List, + rightFieldNames: List, + sparkSession: SparkSession, + ): """ :param spatialPairRDD: @@ -164,15 +183,27 @@ def toDf(cls, rawPairRDD: SedonaPairRDD, sparkSession: SparkSession): return df @classmethod - def toDf(cls, rawPairRDD: SedonaPairRDD, leftFieldnames: List, rightFieldNames: List, sparkSession: SparkSession): + def toDf( + cls, + rawPairRDD: SedonaPairRDD, + leftFieldnames: List, + rightFieldNames: List, + sparkSession: SparkSession, + ): jvm = sparkSession._jvm jdf = jvm.PythonAdapterWrapper.toDf( - rawPairRDD.jsrdd, leftFieldnames, rightFieldNames, sparkSession._jsparkSession) + rawPairRDD.jsrdd, + leftFieldnames, + rightFieldNames, + sparkSession._jsparkSession, + ) df = Adapter._create_dataframe(jdf, sparkSession) return df @classmethod - def toDf(cls, spatialRDD: SedonaRDD, spark: SparkSession, fieldNames: List = None) -> DataFrame: + def toDf( + cls, spatialRDD: SedonaRDD, spark: SparkSession, fieldNames: List = None + ) -> DataFrame: srdd = SpatialRDD(spatialRDD.sc) srdd.setRawSpatialRDD(spatialRDD.jsrdd) if fieldNames: diff --git a/python/sedona/utils/binary_parser.py b/python/sedona/utils/binary_parser.py index d680421b15..473a634303 100644 --- a/python/sedona/utils/binary_parser.py +++ b/python/sedona/utils/binary_parser.py @@ -33,7 +33,7 @@ "i": INT_SIZE, "b": BYTE_SIZE, "s": CHAR_SIZE, - 
"?": BOOLEAN_SIZE + "?": BOOLEAN_SIZE, } @@ -47,8 +47,12 @@ def __attrs_post_init__(self): self.bytes = self._convert_to_binary_array(no_negatives) def read_geometry(self, length: int): - geom_bytes = b"".join([struct.pack("b", el) if el < 128 else struct.pack("b", el - 256) for el in - self.bytes[self.current_index: self.current_index + length]]) + geom_bytes = b"".join( + [ + struct.pack("b", el) if el < 128 else struct.pack("b", el - 256) + for el in self.bytes[self.current_index : self.current_index + length] + ] + ) geom = loads(geom_bytes) self.current_index += length return geom @@ -84,7 +88,7 @@ def read_boolean(self): return data def read_string(self, length: int, encoding: str = "utf8"): - string = self.bytes[self.current_index: self.current_index + length] + string = self.bytes[self.current_index : self.current_index + length] self.current_index += length try: @@ -97,7 +101,7 @@ def read_kryo_string(self, length: int, sc: SparkContext) -> str: array_length = length - self.current_index byte_array = sc._gateway.new_array(sc._jvm.Byte, array_length) - for index, bt in enumerate(self.bytes[self.current_index: length]): + for index, bt in enumerate(self.bytes[self.current_index : length]): byte_array[index] = self.bytes[self.current_index + index] decoded_string = sc._jvm.org.imbruced.geo_pyspark.serializers.GeoSerializerData.deserializeUserData( byte_array @@ -107,12 +111,16 @@ def read_kryo_string(self, length: int, sc: SparkContext) -> str: def unpack(self, tp: str, bytes: bytearray): max_index = self.current_index + size_dict[tp] - bytes = self._convert_to_binary_array(bytes[self.current_index: max_index]) + bytes = self._convert_to_binary_array(bytes[self.current_index : max_index]) return struct.unpack(tp, bytes)[0] def unpack_reverse(self, tp: str, bytes: bytearray): max_index = self.current_index + size_dict[tp] - bytes = bytearray(reversed(self._convert_to_binary_array(bytes[self.current_index: max_index]))) + bytes = bytearray( + reversed( + self._convert_to_binary_array(bytes[self.current_index : max_index]) + ) + ) return struct.unpack(tp, bytes)[0] @classmethod diff --git a/python/sedona/utils/decorators.py b/python/sedona/utils/decorators.py index 1b59db3ae7..9d10212b31 100644 --- a/python/sedona/utils/decorators.py +++ b/python/sedona/utils/decorators.py @@ -17,7 +17,7 @@ from typing import List, Iterable, Callable, TypeVar -T = TypeVar('T') +T = TypeVar("T") class classproperty(object): @@ -32,7 +32,9 @@ def __set__(self, instance, value): return self.f() -def get_first_meet_criteria_element_from_iterable(iterable: Iterable[T], criteria: Callable[[T], int]) -> int: +def get_first_meet_criteria_element_from_iterable( + iterable: Iterable[T], criteria: Callable[[T], int] +) -> int: for index, element in enumerate(iterable): if criteria(element): return index @@ -43,6 +45,7 @@ def require(library_names: List[str]): def wrapper(func): def run_function(*args, **kwargs): from sedona.core.utils import ImportedJvmLib + has_all_libs = [lib for lib in library_names] first_not_fulfill_value = get_first_meet_criteria_element_from_iterable( has_all_libs, lambda x: not ImportedJvmLib.has_library(x) @@ -53,7 +56,8 @@ def run_function(*args, **kwargs): else: raise ModuleNotFoundError( f"Did not found {has_all_libs[first_not_fulfill_value]}, make sure that was correctly imported via py4j" - f"Did you use SedonaRegistrator.registerAll, Your jars were properly copied to $SPARK_HOME/jars ? 
") + f"Did you use SedonaRegistrator.registerAll, Your jars were properly copied to $SPARK_HOME/jars ? " + ) return run_function diff --git a/python/sedona/utils/geometry_serde.py b/python/sedona/utils/geometry_serde.py index 32a646cbe8..cf872dfb2e 100644 --- a/python/sedona/utils/geometry_serde.py +++ b/python/sedona/utils/geometry_serde.py @@ -34,19 +34,24 @@ def find_geos_c_dll(): packages_dir = os.path.dirname(os.path.dirname(shapely.__file__)) - for lib_dirname in ['shapely.libs', 'Shapely.libs']: + for lib_dirname in ["shapely.libs", "Shapely.libs"]: lib_dirpath = os.path.join(packages_dir, lib_dirname) if not os.path.exists(lib_dirpath): continue for filename in os.listdir(lib_dirpath): - if filename.lower().startswith('geos_c') and filename.lower().endswith('.dll'): + if filename.lower().startswith("geos_c") and filename.lower().endswith( + ".dll" + ): return os.path.join(lib_dirpath, filename) - raise RuntimeError('geos_c DLL not found in {}\\[S|s]hapely.libs'.format(packages_dir)) + raise RuntimeError( + "geos_c DLL not found in {}\\[S|s]hapely.libs".format(packages_dir) + ) - if shapely.__version__.startswith('2.'): - if sys.platform != 'win32': + if shapely.__version__.startswith("2."): + if sys.platform != "win32": # We load geos_c library indirectly by loading shapely.lib import shapely.lib + geomserde_speedup.load_libgeos_c(shapely.lib.__file__) else: # Find geos_c library and load it @@ -62,7 +67,7 @@ def deserialize(buf: bytearray) -> Optional[BaseGeometry]: speedup_enabled = True - elif shapely.__version__.startswith('1.'): + elif shapely.__version__.startswith("1."): # Shapely 1.x uses ctypes.CDLL to load geos_c library. We can obtain the # handle of geos_c library from `shapely.geos._lgeos._handle` import shapely.geos @@ -75,7 +80,7 @@ def deserialize(buf: bytearray) -> Optional[BaseGeometry]: MultiPoint, MultiLineString, MultiPolygon, - GeometryCollection + GeometryCollection, ) lgeos_handle = shapely.geos._lgeos._handle @@ -112,13 +117,13 @@ def deserialize(buf: bytearray) -> Optional[BaseGeometry]: ob = BaseGeometry() geom_type = shapely.geometry.base.GEOMETRY_TYPES[geom_type_id] ob.__class__ = GEOMETRY_CLASSES[geom_type_id] - ob.__dict__['__geom__'] = g - ob.__dict__['__p__'] = None + ob.__dict__["__geom__"] = g + ob.__dict__["__p__"] = None if has_z != 0: - ob.__dict__['_ndim'] = 3 + ob.__dict__["_ndim"] = 3 else: - ob.__dict__['_ndim'] = 2 - ob.__dict__['_is_empty'] = False + ob.__dict__["_ndim"] = 2 + ob.__dict__["_is_empty"] = False return ob, bytes_read speedup_enabled = True @@ -128,5 +133,7 @@ def deserialize(buf: bytearray) -> Optional[BaseGeometry]: from .geometry_serde_general import serialize, deserialize except Exception as e: - warn(f'Cannot load geomserde_speedup, fallback to general python implementation. Reason: {e}') + warn( + f"Cannot load geomserde_speedup, fallback to general python implementation. 
Reason: {e}" + ) from .geometry_serde_general import serialize, deserialize diff --git a/python/sedona/utils/geometry_serde_general.py b/python/sedona/utils/geometry_serde_general.py index 243a1f1f65..212a302b68 100644 --- a/python/sedona/utils/geometry_serde_general.py +++ b/python/sedona/utils/geometry_serde_general.py @@ -36,8 +36,14 @@ from shapely.wkt import loads as wkt_loads -CoordType = Union[Tuple[float, float], Tuple[float, float, float], Tuple[float, float, float, float]] -ListCoordType = Union[List[Tuple[float, float]], List[Tuple[float, float, float]], List[Tuple[float, float, float, float]]] +CoordType = Union[ + Tuple[float, float], Tuple[float, float, float], Tuple[float, float, float, float] +] +ListCoordType = Union[ + List[Tuple[float, float]], + List[Tuple[float, float, float]], + List[Tuple[float, float, float, float]], +] GET_COORDS_NUMPY_THRESHOLD = 50 @@ -46,6 +52,7 @@ class GeometryTypeID: """ Constants used to identify the geometry type in the serialized bytearray of geometry. """ + POINT = 1 LINESTRING = 2 POLYGON = 3 @@ -59,6 +66,7 @@ class CoordinateType: """ Constants used to identify geometry dimensions in the serialized bytearray of geometry. """ + XY = 1 XYZ = 2 XYM = 3 @@ -66,7 +74,7 @@ class CoordinateType: BYTES_PER_COORDINATE = [16, 24, 24, 32] NUM_COORD_COMPONENTS = [2, 3, 3, 4] - UNPACK_FORMAT = ['dd', 'ddd', 'ddxxxxxxxx', 'dddxxxxxxxx'] + UNPACK_FORMAT = ["dd", "ddd", "ddxxxxxxxx", "dddxxxxxxxx"] @staticmethod def type_of(geom) -> int: @@ -99,12 +107,8 @@ class GeometryBuffer: ints_offset: int def __init__( - self, - buffer: bytearray, - coord_type: int, - coords_offset: int, - num_coords: int - ) -> None: + self, buffer: bytearray, coord_type: int, coords_offset: int, num_coords: int + ) -> None: self.buffer = buffer self.coord_type = coord_type self.bytes_per_coord = CoordinateType.bytes_per_coord(coord_type) @@ -126,17 +130,16 @@ def read_polygon(self) -> Polygon: if num_rings == 0: return Polygon() - rings = [ - self.read_coordinates(self.read_int()) - for _ in range(num_rings) - ] + rings = [self.read_coordinates(self.read_int()) for _ in range(num_rings)] return Polygon(rings[0], rings[1:]) def write_linestring(self, line: LineString) -> None: coords = [tuple(c) for c in line.coords] self.write_int(len(coords)) - self.coords_offset = put_coordinates(self.buffer, self.coords_offset, self.coord_type, coords) + self.coords_offset = put_coordinates( + self.buffer, self.coords_offset, self.coord_type, coords + ) def write_polygon(self, polygon: Polygon) -> None: exterior = polygon.exterior @@ -149,7 +152,9 @@ def write_polygon(self, polygon: Polygon) -> None: self.write_linestring(interior) def read_coordinates(self, num_coords: int) -> ListCoordType: - coords = get_coordinates(self.buffer, self.coords_offset, self.coord_type, num_coords) + coords = get_coordinates( + self.buffer, self.coords_offset, self.coord_type, num_coords + ) self.coords_offset += num_coords * self.bytes_per_coord return coords @@ -161,7 +166,7 @@ def read_coordinate(self) -> CoordType: def read_int(self) -> int: value = struct.unpack_from("i", self.buffer, self.ints_offset)[0] if value > len(self.buffer): - raise ValueError('Unexpected large integer in structural data') + raise ValueError("Unexpected large integer in structural data") self.ints_offset += 4 return value @@ -169,6 +174,7 @@ def write_int(self, value: int) -> None: struct.pack_into("i", self.buffer, self.ints_offset, value) self.ints_offset += 4 + def serialize(geom: BaseGeometry) -> Optional[Union[bytes, 
bytearray]]: """ Serialize a shapely geometry object to the internal representation of GeometryUDT. @@ -197,6 +203,7 @@ def serialize(geom: BaseGeometry) -> Optional[Union[bytes, bytearray]]: else: raise ValueError(f"Unsupported geometry type: {type(geom)}") + def deserialize(buffer: bytes) -> Optional[BaseGeometry]: """ Deserialize a shapely geometry object from the internal representation of GeometryUDT. @@ -208,9 +215,9 @@ def deserialize(buffer: bytes) -> Optional[BaseGeometry]: preamble_byte = buffer[0] geom_type = (preamble_byte >> 4) & 0x0F coord_type = (preamble_byte >> 1) & 0x07 - num_coords = struct.unpack_from('i', buffer, 4)[0] + num_coords = struct.unpack_from("i", buffer, 4)[0] if num_coords > len(buffer): - raise ValueError('num_coords cannot be larger than buffer size') + raise ValueError("num_coords cannot be larger than buffer size") geom_buffer = GeometryBuffer(buffer, coord_type, 8, num_coords) if geom_type == GeometryTypeID.POINT: geom = deserialize_point(geom_buffer) @@ -231,47 +238,57 @@ def deserialize(buffer: bytes) -> Optional[BaseGeometry]: return geom, geom_buffer.ints_offset -def create_buffer_for_geom(geom_type: int, coord_type: int, size: int, num_coords: int) -> bytearray: +def create_buffer_for_geom( + geom_type: int, coord_type: int, size: int, num_coords: int +) -> bytearray: buffer = bytearray(size) preamble_byte = (geom_type << 4) | (coord_type << 1) buffer[0] = preamble_byte - struct.pack_into('i', buffer, 4, num_coords) + struct.pack_into("i", buffer, 4, num_coords) return buffer + def generate_header_bytes(geom_type: int, coord_type: int, num_coords: int) -> bytes: preamble_byte = (geom_type << 4) | (coord_type << 1) - return struct.pack( - 'BBBBi', - preamble_byte, - 0, - 0, - 0, - num_coords - ) + return struct.pack("BBBBi", preamble_byte, 0, 0, 0, num_coords) -def put_coordinates(buffer: bytearray, offset: int, coord_type: int, coords: ListCoordType) -> int: +def put_coordinates( + buffer: bytearray, offset: int, coord_type: int, coords: ListCoordType +) -> int: for coord in coords: - struct.pack_into(CoordinateType.unpack_format(coord_type, buffer, offset, *coord)) + struct.pack_into( + CoordinateType.unpack_format(coord_type, buffer, offset, *coord) + ) offset += CoordinateType.bytes_per_coord(coord_type) return offset -def put_coordinate(buffer: bytearray, offset: int, coord_type: int, coord: CoordType) -> int: +def put_coordinate( + buffer: bytearray, offset: int, coord_type: int, coord: CoordType +) -> int: struct.pack_into(CoordinateType.unpack_format(coord_type, buffer, offset, *coord)) offset += CoordinateType.bytes_per_coord(coord_type) return offset -def get_coordinates(buffer: bytearray, offset: int, coord_type: int, num_coords: int) -> Union[np.ndarray, ListCoordType]: +def get_coordinates( + buffer: bytearray, offset: int, coord_type: int, num_coords: int +) -> Union[np.ndarray, ListCoordType]: if num_coords < GET_COORDS_NUMPY_THRESHOLD: coords = [ - struct.unpack_from(CoordinateType.unpack_format(coord_type), buffer, offset + (i * CoordinateType.bytes_per_coord(coord_type))) + struct.unpack_from( + CoordinateType.unpack_format(coord_type), + buffer, + offset + (i * CoordinateType.bytes_per_coord(coord_type)), + ) for i in range(num_coords) ] else: nums_per_coord = CoordinateType.components_per_coord(coord_type) - coords = np.frombuffer(buffer, np.float64, num_coords * nums_per_coord, offset).reshape((num_coords, nums_per_coord)) + coords = np.frombuffer( + buffer, np.float64, num_coords * nums_per_coord, offset + 
).reshape((num_coords, nums_per_coord)) return coords @@ -290,26 +307,12 @@ def serialize_point(geom: Point) -> bytes: # FIXME this does not handle M yet, but geom.has_z is extremely slow pack_format = "BBBBi" + "d" * geom._ndim coord_type = CoordinateType.type_of(geom) - preamble_byte = ((GeometryTypeID.POINT << 4) | (coord_type << 1)) + preamble_byte = (GeometryTypeID.POINT << 4) | (coord_type << 1) coords = coords[0] - return struct.pack( - pack_format, - preamble_byte, - 0, - 0, - 0, - 1, - *coords - ) + return struct.pack(pack_format, preamble_byte, 0, 0, 0, 1, *coords) else: - return struct.pack( - 'BBBBi', - 18, - 0, - 0, - 0, - 0 - ) + return struct.pack("BBBBi", 18, 0, 0, 0, 0) + def deserialize_point(geom_buffer: GeometryBuffer) -> Point: if geom_buffer.num_coords == 0: @@ -338,7 +341,7 @@ def serialize_multi_point(geom: MultiPoint) -> bytes: for k in range(geom._ndim): coords.append(math.nan) - body = array.array('d', coords).tobytes() + body = array.array("d", coords).tobytes() return header + body @@ -361,8 +364,10 @@ def serialize_linestring(geom: LineString) -> bytes: coords = [tuple(c) for c in geom.coords] if coords: coord_type = CoordinateType.type_of(geom) - header = generate_header_bytes(GeometryTypeID.LINESTRING, coord_type, len(coords)) - return header + array.array('d', [x for c in coords for x in c]).tobytes() + header = generate_header_bytes( + GeometryTypeID.LINESTRING, coord_type, len(coords) + ) + return header + array.array("d", [x for c in coords for x in c]).tobytes() else: return generate_header_bytes(GeometryTypeID.LINESTRING, 1, 0) @@ -382,10 +387,14 @@ def serialize_multi_linestring(geom: MultiLineString) -> bytes: line_lengths = [len(line) for line in lines] num_coords = sum(line_lengths) - header = generate_header_bytes(GeometryTypeID.MULTILINESTRING, coord_type, num_coords) - coord_data = array.array('d', [c for line in lines for coord in line for c in coord]).tobytes() - num_lines = struct.pack('i', len(lines)) - structure_data = array.array('i', line_lengths).tobytes() + header = generate_header_bytes( + GeometryTypeID.MULTILINESTRING, coord_type, num_coords + ) + coord_data = array.array( + "d", [c for line in lines for coord in line for c in coord] + ).tobytes() + num_lines = struct.pack("i", len(lines)) + structure_data = array.array("i", line_lengths).tobytes() result = header + coord_data + num_lines + structure_data @@ -403,17 +412,18 @@ def deserialize_multi_linestring(geom_buffer: GeometryBuffer) -> MultiLineString return wkt_loads("MULTILINESTRING EMPTY") return MultiLineString(linestrings) + def serialize_polygon(geom: Polygon) -> bytes: # it may seem odd, but dumping to wkb and parsing proved to be the fastest here wkb_string = wkb_dumps(geom) - int_format = ">i" if struct.unpack_from('B', wkb_string) == 0 else " bytes: num_coords += ring_len offset += 4 - coord_bytes += wkb_string[offset: (offset + bytes_per_coord * ring_len)] + coord_bytes += wkb_string[offset : (offset + bytes_per_coord * ring_len)] offset += bytes_per_coord * ring_len coord_type = CoordinateType.type_of(geom) header = generate_header_bytes(GeometryTypeID.POLYGON, coord_type, num_coords) - structure_data_bytes = array.array('i', [num_rings] + ring_lengths).tobytes() + structure_data_bytes = array.array("i", [num_rings] + ring_lengths).tobytes() return header + coord_bytes + structure_data_bytes @@ -444,17 +454,37 @@ def deserialize_polygon(geom_buffer: GeometryBuffer) -> Polygon: def serialize_multi_polygon(geom: MultiPolygon) -> bytes: coords_for = lambda x: [y for y 
in list(x)] - polygons = [[coords_for(polygon.exterior.coords)] + [coords_for(ring.coords) for ring in polygon.interiors] for polygon in list(geom.geoms)] + polygons = [ + [coords_for(polygon.exterior.coords)] + + [coords_for(ring.coords) for ring in polygon.interiors] + for polygon in list(geom.geoms) + ] coord_type = CoordinateType.type_of(geom) - structure_data = array.array('i', [val for polygon in polygons for val in [len(polygon)] + [len(ring) for ring in polygon]]).tobytes() - coords = array.array('d', [component for polygon in polygons for ring in polygon for coord in ring for component in coord]).tobytes() + structure_data = array.array( + "i", + [ + val + for polygon in polygons + for val in [len(polygon)] + [len(ring) for ring in polygon] + ], + ).tobytes() + coords = array.array( + "d", + [ + component + for polygon in polygons + for ring in polygon + for coord in ring + for component in coord + ], + ).tobytes() num_coords = len(coords) // CoordinateType.bytes_per_coord(coord_type) header = generate_header_bytes(GeometryTypeID.MULTIPOLYGON, coord_type, num_coords) - num_polygons = struct.pack('i', len(polygons)) + num_polygons = struct.pack("i", len(polygons)) result = header + coords + num_polygons + structure_data return result @@ -480,7 +510,9 @@ def serialize_geometry_collection(geom: GeometryCollection) -> bytearray: return serialize_shapely_1_empty_geom(geom) geometries = geom.geoms if not geometries: - return create_buffer_for_geom(GeometryTypeID.GEOMETRYCOLLECTION, CoordinateType.XY, 8, 0) + return create_buffer_for_geom( + GeometryTypeID.GEOMETRYCOLLECTION, CoordinateType.XY, 8, 0 + ) num_geometries = len(geometries) total_size = 8 buffers = [] @@ -488,10 +520,12 @@ def serialize_geometry_collection(geom: GeometryCollection) -> bytearray: buf = serialize(geom) buffers.append(buf) total_size += aligned_offset(len(buf)) - buffer = create_buffer_for_geom(GeometryTypeID.GEOMETRYCOLLECTION, CoordinateType.XY, total_size, num_geometries) + buffer = create_buffer_for_geom( + GeometryTypeID.GEOMETRYCOLLECTION, CoordinateType.XY, total_size, num_geometries + ) offset = 8 for buf in buffers: - buffer[offset:(offset + len(buf))] = buf + buffer[offset : (offset + len(buf))] = buf offset += aligned_offset(len(buf)) return buffer diff --git a/python/sedona/utils/jvm.py b/python/sedona/utils/jvm.py index d66b565962..1b843a001f 100644 --- a/python/sedona/utils/jvm.py +++ b/python/sedona/utils/jvm.py @@ -28,7 +28,9 @@ class JvmStorageLevel(JvmObject): @require(["StorageLevel"]) def _create_jvm_instance(self): return self.jvm.StorageLevel.apply( - self.storage_level.useDisk, self.storage_level.useMemory, - self.storage_level.useOffHeap, self.storage_level.deserialized, - self.storage_level.replication + self.storage_level.useDisk, + self.storage_level.useMemory, + self.storage_level.useOffHeap, + self.storage_level.deserialized, + self.storage_level.replication, ) diff --git a/python/sedona/utils/meta.py b/python/sedona/utils/meta.py index dc8a723436..079b8acb5f 100644 --- a/python/sedona/utils/meta.py +++ b/python/sedona/utils/meta.py @@ -25,6 +25,7 @@ try: from typing import GenericMeta except ImportError: + class GenericMeta(type): pass @@ -58,7 +59,9 @@ def register(self, meth): :param meth: :return: """ - if str(meth).startswith(" bool: - geoms = [Point, MultiPoint, Polygon, MultiPolygon, LineString, MultiLineString, GeometryCollection] + geoms = [ + Point, + MultiPoint, + Polygon, + MultiPolygon, + LineString, + MultiLineString, + GeometryCollection, + ] 
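For reference, the 8-byte header written by generate_header_bytes() in geometry_serde_general.py above can be decoded exactly the way deserialize() does; a self-contained sketch:

import struct

# POINT (geom_type 1) with XY coordinates (coord_type 1) and one coordinate:
header = struct.pack("BBBBi", (1 << 4) | (1 << 1), 0, 0, 0, 1)
preamble = header[0]
geom_type = (preamble >> 4) & 0x0F   # 1 == GeometryTypeID.POINT
coord_type = (preamble >> 1) & 0x07  # 1 == CoordinateType.XY
num_coords = struct.unpack_from("i", header, 4)[0]
assert (geom_type, coord_type, num_coords) == (1, 1, 1)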
assign_udt_shapely_objects(geoms=geoms) assign_user_data_to_shapely_objects(geoms=geoms) return True @@ -30,6 +46,7 @@ def assign_all() -> bool: def assign_udt_shapely_objects(geoms: List[type(BaseGeometry)]) -> bool: from sedona.sql.types import GeometryType + for geom in geoms: geom.__UDT__ = GeometryType() return True diff --git a/python/sedona/utils/spatial_rdd_parser.py b/python/sedona/utils/spatial_rdd_parser.py index 7f963444b2..c18ac9cd8c 100644 --- a/python/sedona/utils/spatial_rdd_parser.py +++ b/python/sedona/utils/spatial_rdd_parser.py @@ -22,6 +22,7 @@ import attr from shapely.geometry.base import BaseGeometry + try: from pyspark import CPickleSerializer except ImportError: @@ -57,7 +58,7 @@ def __getstate__(self): return dict( geom=bytearray([el if el >= 0 else el + 256 for el in geom_bytes]), - userData=getattr(self, attributes[1]) + userData=getattr(self, attributes[1]), ) def __setstate__(self, attributes): @@ -85,7 +86,9 @@ def userData(self): __slots__ = ("_geom", "_userData") def __repr__(self): - return f"Geometry: {str(self.geom.__class__.__name__)} userData: {self.userData}" + return ( + f"Geometry: {str(self.geom.__class__.__name__)} userData: {self.userData}" + ) def __eq__(self, other): return self.geom == other.geom and self.userData == other.userData @@ -98,15 +101,15 @@ def __ne__(self, other): class AbstractSpatialRDDParser(ABC): @classmethod - def serialize(cls, obj: List[Any], binary_buffer: 'BinaryBuffer') -> bytearray: + def serialize(cls, obj: List[Any], binary_buffer: "BinaryBuffer") -> bytearray: raise NotImplemented() @classmethod - def deserialize(cls, bin_parser: 'BinaryParser') -> BaseGeometry: + def deserialize(cls, bin_parser: "BinaryParser") -> BaseGeometry: raise NotImplementedError("Parser has to implement deserialize method") @classmethod - def _deserialize_geom(cls, bin_parser: 'BinaryParser') -> GeoData: + def _deserialize_geom(cls, bin_parser: "BinaryParser") -> GeoData: is_circle = bin_parser.read_byte() return geom_deserializers[is_circle].geometry_from_bytes(bin_parser) @@ -116,7 +119,7 @@ class SpatialPairRDDParserData(AbstractSpatialRDDParser): name = "SpatialPairRDDParserData" @classmethod - def deserialize(cls, bin_parser: 'BinaryParser'): + def deserialize(cls, bin_parser: "BinaryParser"): left_geom_data = cls._deserialize_geom(bin_parser) _ = bin_parser.read_int() @@ -128,7 +131,7 @@ def deserialize(cls, bin_parser: 'BinaryParser'): return deserialized_data @classmethod - def serialize(cls, obj: BaseGeometry, binary_buffer: 'BinaryBuffer'): + def serialize(cls, obj: BaseGeometry, binary_buffer: "BinaryBuffer"): raise NotImplementedError("Currently this operation is not supported") @@ -137,14 +140,14 @@ class SpatialRDDParserData(AbstractSpatialRDDParser): name = "SpatialRDDParser" @classmethod - def deserialize(cls, bin_parser: 'BinaryParser'): + def deserialize(cls, bin_parser: "BinaryParser"): left_geom_data = cls._deserialize_geom(bin_parser) _ = bin_parser.read_int() return left_geom_data @classmethod - def serialize(cls, obj: BaseGeometry, binary_buffer: 'BinaryBuffer'): + def serialize(cls, obj: BaseGeometry, binary_buffer: "BinaryBuffer"): raise NotImplementedError("Currently this operation is not supported") @@ -153,7 +156,7 @@ class SpatialRDDParserDataMultipleRightGeom(AbstractSpatialRDDParser): name = "SpatialRDDParser" @classmethod - def deserialize(cls, bin_parser: 'BinaryParser'): + def deserialize(cls, bin_parser: "BinaryParser"): left_geom_data = cls._deserialize_geom(bin_parser) geometry_numbers = 
         bin_parser.read_int()
@@ -164,12 +167,14 @@ def deserialize(cls, bin_parser: 'BinaryParser'):
             right_geom_data = cls._deserialize_geom(bin_parser)
             right_geoms.append(right_geom_data)
 
-        deserialized_data = [left_geom_data, right_geoms] if right_geoms else left_geom_data
+        deserialized_data = (
+            [left_geom_data, right_geoms] if right_geoms else left_geom_data
+        )
 
         return deserialized_data
 
     @classmethod
-    def serialize(cls, obj: BaseGeometry, binary_buffer: 'BinaryBuffer'):
+    def serialize(cls, obj: BaseGeometry, binary_buffer: "BinaryBuffer"):
         raise NotImplementedError("Currently this operation is not supported")
 
@@ -235,10 +240,11 @@ def geometry_from_bytes(cls, bin_parser: BinaryParser) -> GeoData:
 
     @classmethod
     def to_bytes(cls, geom: Circle) -> List[int]:
-        return struct.pack("b", 1) + struct.pack("d", geom.radius) + dumps(geom.centerGeometry)
+        return (
+            struct.pack("b", 1)
+            + struct.pack("d", geom.radius)
+            + dumps(geom.centerGeometry)
+        )
 
 
-geom_deserializers = {
-    1: CircleGeometryFactory,
-    0: GeometryFactory
-}
+geom_deserializers = {1: CircleGeometryFactory, 0: GeometryFactory}
diff --git a/python/setup.py b/python/setup.py
index 7fc01b4e87..fd38f34e21 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -24,48 +24,52 @@
 
 extension_args = {}
 
-if os.getenv('ENABLE_ASAN'):
+if os.getenv("ENABLE_ASAN"):
     extension_args = {
-        'extra_compile_args': ["-fsanitize=address"],
-        'extra_link_args': ["-fsanitize=address"]
+        "extra_compile_args": ["-fsanitize=address"],
+        "extra_link_args": ["-fsanitize=address"],
     }
 
 ext_modules = [
-    Extension('sedona.utils.geomserde_speedup', sources=[
-        'src/geomserde_speedup_module.c',
-        'src/geomserde.c',
-        'src/geom_buf.c',
-        'src/geos_c_dyn.c'
-    ], **extension_args)
+    Extension(
+        "sedona.utils.geomserde_speedup",
+        sources=[
+            "src/geomserde_speedup_module.c",
+            "src/geomserde.c",
+            "src/geom_buf.c",
+            "src/geos_c_dyn.c",
+        ],
+        **extension_args
+    )
 ]
 
 setup(
-    name='apache-sedona',
+    name="apache-sedona",
     version=version,
-    description='Apache Sedona is a cluster computing system for processing large-scale spatial data',
-    url='https://sedona.apache.org',
+    description="Apache Sedona is a cluster computing system for processing large-scale spatial data",
+    url="https://sedona.apache.org",
     license="Apache License v2.0",
-    author='Apache Sedona',
-    author_email='dev@sedona.apache.org',
+    author="Apache Sedona",
+    author_email="dev@sedona.apache.org",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
     ext_modules=ext_modules,
     long_description=long_description,
     long_description_content_type="text/markdown",
-    python_requires='>=3.6',
-    install_requires=['attrs', "shapely>=1.7.0", "rasterio>=1.2.10"],
+    python_requires=">=3.6",
+    install_requires=["attrs", "shapely>=1.7.0", "rasterio>=1.2.10"],
    extras_require={
         "spark": ["pyspark>=2.3.0"],
         "pydeck-map": ["geopandas", "pydeck==0.8.0"],
         "kepler-map": ["geopandas", "keplergl==0.3.2"],
-        "all": ["pyspark>=2.3.0", "geopandas","pydeck==0.8.0", "keplergl==0.3.2"],
+        "all": ["pyspark>=2.3.0", "geopandas", "pydeck==0.8.0", "keplergl==0.3.2"],
     },
     project_urls={
-        'Documentation': 'https://sedona.apache.org',
-        'Source code': 'https://github.com/apache/sedona',
-        'Bug Reports': 'https://issues.apache.org/jira/projects/SEDONA'
+        "Documentation": "https://sedona.apache.org",
+        "Source code": "https://github.com/apache/sedona",
+        "Bug Reports": "https://issues.apache.org/jira/projects/SEDONA",
     },
     classifiers=[
         "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: Apache Software License"
-    ]
+        "License :: OSI Approved :: Apache Software License",
+    ],
 )
diff --git a/python/tests/__init__.py b/python/tests/__init__.py
index 96f6a2b860..b63fc7061e 100644
--- a/python/tests/__init__.py
+++ b/python/tests/__init__.py
@@ -23,23 +23,43 @@
 mixed_wkb_geometry_input_location = os.path.join(tests_resource, "county_small_wkb.tsv")
 mixed_wkt_geometry_input_location = os.path.join(tests_resource, "county_small.tsv")
 shape_file_input_location = os.path.join(tests_resource, "shapefiles/dbf")
-shape_file_with_missing_trailing_input_location = os.path.join(tests_resource, "shapefiles/missing")
+shape_file_with_missing_trailing_input_location = os.path.join(
+    tests_resource, "shapefiles/missing"
+)
 geojson_input_location = os.path.join(tests_resource, "testPolygon.json")
 area_lm_point_input_location = os.path.join(tests_resource, "arealm.csv")
 csv_point_input_location = os.path.join(tests_resource, "testpoint.csv")
 csv_polygon_input_location = os.path.join(tests_resource, "testenvelope.csv")
-csv_polygon1_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_envelope1.csv")
-csv_polygon2_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_envelope2.csv")
-csv_polygon1_random_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_envelope1_random.csv")
-csv_polygon2_random_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_envelope2_random.csv")
-overlap_polygon_input_location = os.path.join(tests_resource, "testenvelope_overlap.csv")
+csv_polygon1_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_envelope1.csv"
+)
+csv_polygon2_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_envelope2.csv"
+)
+csv_polygon1_random_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_envelope1_random.csv"
+)
+csv_polygon2_random_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_envelope2_random.csv"
+)
+overlap_polygon_input_location = os.path.join(
+    tests_resource, "testenvelope_overlap.csv"
+)
 union_polygon_input_location = os.path.join(tests_resource, "testunion.csv")
-csv_point1_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_point1.csv")
-csv_point2_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_point2.csv")
+csv_point1_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_point1.csv"
+)
+csv_point2_input_location = os.path.join(
+    tests_resource, "equalitycheckfiles/testequals_point2.csv"
+)
 geojson_id_input_location = os.path.join(tests_resource, "testContainsId.json")
 geoparquet_input_location = os.path.join(tests_resource, "geoparquet/example1.parquet")
 plain_parquet_input_location = os.path.join(tests_resource, "geoparquet/plain.parquet")
-legacy_parquet_input_location = os.path.join(tests_resource, "parquet/legacy-parquet-nested-columns.snappy.parquet")
+legacy_parquet_input_location = os.path.join(
+    tests_resource, "parquet/legacy-parquet-nested-columns.snappy.parquet"
+)
 google_buildings_input_location = os.path.join(tests_resource, "813_buildings_test.csv")
 chicago_crimes_input_location = os.path.join(tests_resource, "Chicago_Crimes.csv")
-world_map_raster_input_location = os.path.join(tests_resource, "raster/raster_with_no_data/test5.tiff")
+world_map_raster_input_location = os.path.join(
+    tests_resource, "raster/raster_with_no_data/test5.tiff"
+)
diff --git a/python/tests/core/test_avoiding_python_jvm_serde_df.py b/python/tests/core/test_avoiding_python_jvm_serde_df.py
index d86a965f15..4811f9f513 100644
--- a/python/tests/core/test_avoiding_python_jvm_serde_df.py
+++ b/python/tests/core/test_avoiding_python_jvm_serde_df.py
@@ -36,26 +36,50 @@ class TestOmitPythonJvmSerdeToDf(TestBase):
-    expected_pois_within_areas_ids = [['4', '4'], ['1', '6'], ['2', '1'], ['3', '3'], ['3', '7']]
+    expected_pois_within_areas_ids = [
+        ["4", "4"],
+        ["1", "6"],
+        ["2", "1"],
+        ["3", "3"],
+        ["3", "7"],
+    ]
 
     def test_spatial_join_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
-        areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
+        areas_polygon_rdd = WktReader.readToGeometryRDD(
+            self.sc, areas_csv_path, 1, False, False
+        )
         poi_point_rdd.analyze()
         areas_polygon_rdd.analyze()
 
         poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
         areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
-        jvm_sedona_rdd = JoinQueryRaw.spatialJoin(poi_point_rdd, areas_polygon_rdd, JoinParams(considerBoundaryIntersection=True))
-        sedona_df = Adapter.toDf(jvm_sedona_rdd, ["area_id", "area_name"], ["poi_id", "poi_name"], self.spark)
+        jvm_sedona_rdd = JoinQueryRaw.spatialJoin(
+            poi_point_rdd,
+            areas_polygon_rdd,
+            JoinParams(considerBoundaryIntersection=True),
+        )
+        sedona_df = Adapter.toDf(
+            jvm_sedona_rdd, ["area_id", "area_name"], ["poi_id", "poi_name"], self.spark
+        )
 
         assert sedona_df.count() == 5
-        assert sedona_df.columns == ["leftgeometry", "area_id", "area_name", "rightgeometry",
-                                     "poi_id", "poi_name"]
+        assert sedona_df.columns == [
+            "leftgeometry",
+            "area_id",
+            "area_name",
+            "rightgeometry",
+            "poi_id",
+            "poi_name",
+        ]
 
     def test_distance_join_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
         circle_rdd = CircleRDD(poi_point_rdd, 2.0)
 
         circle_rdd.analyze()
@@ -64,12 +88,14 @@ def test_distance_join_query_flat_to_df(self):
         poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
         circle_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
-        jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(poi_point_rdd, circle_rdd, False, True)
+        jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(
+            poi_point_rdd, circle_rdd, False, True
+        )
         df_sedona_rdd = Adapter.toDf(
             jvm_sedona_rdd,
             ["poi_from_id", "poi_from_name"],
             ["poi_to_id", "poi_to_name"],
-            self.spark
+            self.spark,
         )
 
         assert df_sedona_rdd.count() == 10
@@ -79,12 +105,16 @@ def test_distance_join_query_flat_to_df(self):
             "poi_from_name",
             "rightgeometry",
             "poi_to_id",
-            "poi_to_name"
+            "poi_to_name",
         ]
 
     def test_spatial_join_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
-        areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
+        areas_polygon_rdd = WktReader.readToGeometryRDD(
+            self.sc, areas_csv_path, 1, False, False
+        )
 
         poi_point_rdd.analyze()
         areas_polygon_rdd.analyze()
@@ -92,24 +122,29 @@ def test_spatial_join_query_flat_to_df(self):
         areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
         jvm_sedona_rdd = JoinQueryRaw.SpatialJoinQueryFlat(
-            poi_point_rdd, areas_polygon_rdd, False, True)
+            poi_point_rdd, areas_polygon_rdd, False, True
+        )
 
-        pois_within_areas_with_default_column_names = Adapter.toDf(jvm_sedona_rdd, self.spark)
+        pois_within_areas_with_default_column_names = Adapter.toDf(
+            jvm_sedona_rdd, self.spark
+        )
 
         assert pois_within_areas_with_default_column_names.count() == 5
 
         pois_within_areas_with_passed_column_names = Adapter.toDf(
-            jvm_sedona_rdd,
-            ["area_id", "area_name"],
-            ["poi_id", "poi_name"],
-            self.spark
+            jvm_sedona_rdd, ["area_id", "area_name"], ["poi_id", "poi_name"], self.spark
         )
 
         assert pois_within_areas_with_passed_column_names.count() == 5
-        assert pois_within_areas_with_passed_column_names.columns == ["leftgeometry", "area_id", "area_name",
-                                                                      "rightgeometry",
-                                                                      "poi_id", "poi_name"]
+        assert pois_within_areas_with_passed_column_names.columns == [
+            "leftgeometry",
+            "area_id",
+            "area_name",
+            "rightgeometry",
+            "poi_id",
+            "poi_name",
+        ]
 
         assert pois_within_areas_with_default_column_names.schema == StructType(
             [
@@ -118,36 +153,42 @@ def test_spatial_join_query_flat_to_df(self):
             ]
         )
 
-        left_geometries_raw = pois_within_areas_with_default_column_names. \
-            selectExpr("ST_AsText(leftgeometry)"). \
-            collect()
+        left_geometries_raw = pois_within_areas_with_default_column_names.selectExpr(
+            "ST_AsText(leftgeometry)"
+        ).collect()
 
         left_geometries = self.__row_to_list(left_geometries_raw)
 
-        right_geometries_raw = pois_within_areas_with_default_column_names. \
-            selectExpr("ST_AsText(rightgeometry)"). \
-            collect()
+        right_geometries_raw = pois_within_areas_with_default_column_names.selectExpr(
+            "ST_AsText(rightgeometry)"
+        ).collect()
 
         right_geometries = self.__row_to_list(right_geometries_raw)
 
         # Ignore the ordering of these
-        assert set(geom[0] for geom in left_geometries) == set([
-            'POLYGON ((0 4, -3 3, -8 6, -6 8, -2 9, 0 4))',
-            'POLYGON ((10 3, 10 6, 14 6, 14 3, 10 3))',
-            'POLYGON ((2 2, 2 4, 3 5, 7 5, 9 3, 8 1, 4 1, 2 2))',
-            'POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))',
-            'POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))'
-        ])
-        assert set(geom[0] for geom in right_geometries) == set([
-            'POINT (-3 5)',
-            'POINT (11 5)',
-            'POINT (4 3)',
-            'POINT (-1 -1)',
-            'POINT (-4 -5)'
-        ])
+        assert set(geom[0] for geom in left_geometries) == set(
+            [
+                "POLYGON ((0 4, -3 3, -8 6, -6 8, -2 9, 0 4))",
+                "POLYGON ((10 3, 10 6, 14 6, 14 3, 10 3))",
+                "POLYGON ((2 2, 2 4, 3 5, 7 5, 9 3, 8 1, 4 1, 2 2))",
+                "POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))",
+                "POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))",
+            ]
+        )
+        assert set(geom[0] for geom in right_geometries) == set(
+            [
+                "POINT (-3 5)",
+                "POINT (11 5)",
+                "POINT (4 3)",
+                "POINT (-1 -1)",
+                "POINT (-4 -5)",
+            ]
+        )
 
     def test_range_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
 
         poi_point_rdd.analyze()
 
@@ -164,15 +205,18 @@ def test_range_query_flat_to_df(self):
 
         df_without_column_names = Adapter.toDf(result, self.spark)
 
-        raw_geometries = self.__row_to_list(
-            df_without_column_names.collect()
-        )
+        raw_geometries = self.__row_to_list(df_without_column_names.collect())
 
         assert [point[0].wkt for point in raw_geometries] == [
-            'POINT (9 8)', 'POINT (4 3)', 'POINT (12 1)', 'POINT (11 5)'
+            "POINT (9 8)",
+            "POINT (4 3)",
+            "POINT (12 1)",
+            "POINT (11 5)",
         ]
         assert df_without_column_names.count() == 4
-        assert df_without_column_names.schema == StructType([StructField("geometry", GeometryType())])
+        assert df_without_column_names.schema == StructType(
+            [StructField("geometry", GeometryType())]
+        )
 
         df = Adapter.toDf(result, self.spark, ["poi_id", "poi_name"])
diff --git a/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py b/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py
index bfbd1de1d2..3829342597 100644
--- a/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py
+++ b/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py
@@ -33,23 +33,39 @@ class TestOmitPythonJvmSerdeToRDD(TestBase):
-    expected_pois_within_areas_ids = [['4', '4'], ['1', '6'], ['2', '1'], ['3', '3'], ['3', '7']]
+    expected_pois_within_areas_ids = [
+        ["4", "4"],
+        ["1", "6"],
+        ["2", "1"],
+        ["3", "3"],
+        ["3", "7"],
+    ]
 
     def test_spatial_join_to_spatial_rdd(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
-        areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
+        areas_polygon_rdd = WktReader.readToGeometryRDD(
+            self.sc, areas_csv_path, 1, False, False
+        )
         poi_point_rdd.analyze()
         areas_polygon_rdd.analyze()
 
         poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
         areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
-        jvm_sedona_rdd = JoinQueryRaw.spatialJoin(poi_point_rdd, areas_polygon_rdd, JoinParams(considerBoundaryIntersection=True))
+        jvm_sedona_rdd = JoinQueryRaw.spatialJoin(
+            poi_point_rdd,
+            areas_polygon_rdd,
+            JoinParams(considerBoundaryIntersection=True),
+        )
         sedona_rdd = jvm_sedona_rdd.to_rdd().collect()
         assert sedona_rdd.__len__() == 5
 
     def test_distance_join_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
         circle_rdd = CircleRDD(poi_point_rdd, 2.0)
 
         circle_rdd.analyze()
@@ -58,13 +74,19 @@ def test_distance_join_query_flat_to_df(self):
         poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
         circle_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
-        jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(poi_point_rdd, circle_rdd, False, True)
+        jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(
+            poi_point_rdd, circle_rdd, False, True
+        )
 
         assert jvm_sedona_rdd.to_rdd().collect().__len__() == 10
 
     def test_spatial_join_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
-        areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
+        areas_polygon_rdd = WktReader.readToGeometryRDD(
+            self.sc, areas_csv_path, 1, False, False
+        )
 
         poi_point_rdd.analyze()
         areas_polygon_rdd.analyze()
@@ -72,12 +94,15 @@ def test_spatial_join_query_flat_to_df(self):
         areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
 
         jvm_sedona_rdd = JoinQueryRaw.SpatialJoinQueryFlat(
-            poi_point_rdd, areas_polygon_rdd, False, True)
+            poi_point_rdd, areas_polygon_rdd, False, True
+        )
 
         assert jvm_sedona_rdd.to_rdd().collect().__len__() == 5
 
     def test_range_query_flat_to_df(self):
-        poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+        poi_point_rdd = WktReader.readToGeometryRDD(
+            self.sc, bank_csv_path, 1, False, False
+        )
 
         poi_point_rdd.analyze()
diff --git a/python/tests/core/test_core_geom_primitives.py b/python/tests/core/test_core_geom_primitives.py
index 6794d1f901..3fa20a369e 100644
--- a/python/tests/core/test_core_geom_primitives.py
+++ b/python/tests/core/test_core_geom_primitives.py
@@ -25,4 +25,6 @@ def test_jvm_envelope(self):
         envelope = Envelope(0.0, 5.0, 0.0, 5.0)
         jvm_instance = envelope.create_jvm_instance(self.spark.sparkContext._jvm)
         envelope_area = jvm_instance.getArea()
-        assert envelope_area == 25.0, f"Expected area to be equal 25 but {envelope_area} was found"
+        assert (
+            envelope_area == 25.0
+        ), f"Expected area to be equal 25 but {envelope_area} was found"
diff --git a/python/tests/core/test_core_rdd.py b/python/tests/core/test_core_rdd.py
index 0473164f19..769c0a1ebb 100644
--- a/python/tests/core/test_core_rdd.py
+++ b/python/tests/core/test_core_rdd.py
@@ -30,13 +30,7 @@ class TestSpatialRDD(TestBase):
 
     def test_creating_point_rdd(self):
-        point_rdd = PointRDD(
-            self.spark._sc,
-            point_path,
-            4,
-            FileDataSplitter.WKT,
-            True
-        )
+        point_rdd = PointRDD(self.spark._sc, point_path, 4, FileDataSplitter.WKT, True)
 
         point_rdd.analyze()
         cnt = point_rdd.countWithoutDuplicates()
@@ -44,12 +38,7 @@ def test_creating_point_rdd(self):
 
     def test_creating_polygon_rdd(self):
         polygon_rdd = PolygonRDD(
-            self.spark._sc,
-            counties_path,
-            0,
-            1,
-            FileDataSplitter.WKT,
-            True
+            self.spark._sc, counties_path, 0, 1, FileDataSplitter.WKT, True
         )
 
         polygon_rdd.analyze()
diff --git a/python/tests/core/test_rdd.py b/python/tests/core/test_rdd.py
index 201d442928..a002c7db29 100644
--- a/python/tests/core/test_rdd.py
+++ b/python/tests/core/test_rdd.py
@@ -27,8 +27,13 @@
 from sedona.core.spatialOperator.join_params import JoinParams
 
 import os
-from tests.properties.polygon_properties import polygon_rdd_input_location, polygon_rdd_start_offset, polygon_rdd_end_offset, \
-    polygon_rdd_splitter, polygon_rdd_index_type
+from tests.properties.polygon_properties import (
+    polygon_rdd_input_location,
+    polygon_rdd_start_offset,
+    polygon_rdd_end_offset,
+    polygon_rdd_splitter,
+    polygon_rdd_index_type,
+)
 from tests.test_base import TestBase
 from tests.tools import tests_resource
 
@@ -56,7 +61,7 @@ def test_empty_constructor_test(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd_copy = PointRDD()
         object_rdd_copy.rawJvmSpatialRDD = object_rdd.rawJvmSpatialRDD
@@ -68,7 +73,8 @@ def test_spatial_range_query(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False)
+            carryInputData=False,
+        )
         for i in range(each_query_loop_times):
             result_size = RangeQuery.SpatialRangeQuery(
@@ -82,12 +88,13 @@ def test_range_query_using_index(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.buildIndex(point_rdd_index_type, False)
         for i in range(each_query_loop_times):
             result_size = RangeQuery.SpatialRangeQuery(
-                object_rdd, range_query_window, False, True).count
+                object_rdd, range_query_window, False, True
+            ).count
 
     def test_knn_query(self):
         object_rdd = PointRDD(
@@ -95,7 +102,7 @@ def test_knn_query(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         for i in range(each_query_loop_times):
             result = KNNQuery.SpatialKnnQuery(object_rdd, knn_query_point, 1000, False)
@@ -106,10 +113,10 @@ def test_knn_query_with_index(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.buildIndex(point_rdd_index_type, False)
-        for i in range(each_query_loop_times):
+        for i in range(each_query_loop_times):
             result = KNNQuery.SpatialKnnQuery(object_rdd, knn_query_point, 1000, True)
 
     def test_spaltial_join(self):
@@ -119,7 +126,7 @@ def test_spaltial_join(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
 
         object_rdd = PointRDD(
@@ -127,7 +134,7 @@ def test_spaltial_join(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
@@ -135,7 +142,8 @@ def test_spaltial_join(self):
 
         for x in range(each_query_loop_times):
             result_size = JoinQuery.SpatialJoinQuery(
-                object_rdd, query_window_rdd, False, True).count
+                object_rdd, query_window_rdd, False, True
+            ).count
 
     def test_spatial_join_using_index(self):
         query_window = PolygonRDD(
@@ -144,14 +152,14 @@ def test_spatial_join_using_index(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
         object_rdd = PointRDD(
             sparkContext=self.sc,
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
@@ -161,7 +169,8 @@ def test_spatial_join_using_index(self):
 
         for i in range(each_query_loop_times):
             result_size = JoinQuery.SpatialJoinQuery(
-                object_rdd, query_window, True, False).count()
+                object_rdd, query_window, True, False
+            ).count()
 
     def test_spatial_join_using_index_on_polygons(self):
         query_window = PolygonRDD(
@@ -170,14 +179,14 @@ def test_spatial_join_using_index_on_polygons(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
         object_rdd = PointRDD(
             sparkContext=self.sc,
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
@@ -187,10 +196,7 @@ def test_spatial_join_using_index_on_polygons(self):
 
         for i in range(each_query_loop_times):
             result_size = JoinQuery.SpatialJoinQuery(
-                object_rdd,
-                query_window,
-                True,
-                False
+                object_rdd, query_window, True, False
             ).count()
 
     def test_spatial_join_query_using_index_on_polygons(self):
@@ -200,14 +206,14 @@ def test_spatial_join_query_using_index_on_polygons(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
         object_rdd = PointRDD(
             sparkContext=self.sc,
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
@@ -225,14 +231,14 @@ def test_spatial_join_query_and_build_index_on_points_on_the_fly(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
         object_rdd = PointRDD(
             sparkContext=self.sc,
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
@@ -240,10 +246,7 @@ def test_spatial_join_query_and_build_index_on_points_on_the_fly(self):
 
         for i in range(each_query_loop_times):
             result_size = JoinQuery.SpatialJoinQuery(
-                object_rdd,
-                query_window,
-                True,
-                False
+                object_rdd, query_window, True, False
             ).count()
 
     def test_spatial_join_query_and_build_index_on_polygons_on_the_fly(self):
@@ -253,7 +256,7 @@ def test_spatial_join_query_and_build_index_on_polygons_on_the_fly(self):
             polygon_rdd_start_offset,
             polygon_rdd_end_offset,
             polygon_rdd_splitter,
-            True
+            True,
         )
 
         object_rdd = PointRDD(
@@ -261,18 +264,18 @@ def test_spatial_join_query_and_build_index_on_polygons_on_the_fly(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         object_rdd.analyze()
         object_rdd.spatialPartitioning(join_query_partitionin_type)
         query_window_rdd.spatialPartitioning(object_rdd.getPartitioner())
 
         for i in range(each_query_loop_times):
-            join_params = JoinParams(True, False, polygon_rdd_index_type, JoinBuildSide.LEFT)
+            join_params = JoinParams(
+                True, False, polygon_rdd_index_type, JoinBuildSide.LEFT
+            )
             resultSize = JoinQuery.spatialJoin(
-                query_window_rdd,
-                object_rdd,
-                join_params
+                query_window_rdd, object_rdd, join_params
             ).count()
 
     def test_distance_join_query(self):
@@ -281,7 +284,7 @@ def test_distance_join_query(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         query_window_rdd = CircleRDD(object_rdd, 0.1)
         object_rdd.analyze()
@@ -290,10 +293,8 @@ def test_distance_join_query(self):
 
         for i in range(each_query_loop_times):
             result_size = JoinQuery.DistanceJoinQuery(
-                object_rdd,
-                query_window_rdd,
-                False,
-                True).count()
+                object_rdd, query_window_rdd, False, True
+            ).count()
 
     def test_distance_join_query_using_index(self):
         object_rdd = PointRDD(
@@ -301,7 +302,7 @@ def test_distance_join_query_using_index(self):
             InputLocation=point_rdd_input_location,
             Offset=point_rdd_offset,
             splitter=point_rdd_splitter,
-            carryInputData=False
+            carryInputData=False,
         )
         query_window_rdd = CircleRDD(object_rdd, 0.1)
         object_rdd.analyze()
@@ -312,8 +313,5 @@ def test_distance_join_query_using_index(self):
 
         for i in range(each_query_loop_times):
             result_size = JoinQuery.DistanceJoinQuery(
-                object_rdd,
-                query_window_rdd,
-                True,
-                True
+                object_rdd, query_window_rdd, True, True
             ).count
diff --git a/python/tests/core/test_spatial_rdd_from_disc.py b/python/tests/core/test_spatial_rdd_from_disc.py
index b202418125..b8bfc9d946 100644
--- a/python/tests/core/test_spatial_rdd_from_disc.py
+++ b/python/tests/core/test_spatial_rdd_from_disc.py
@@ -22,8 +22,11 @@
 
 from sedona.core.SpatialRDD import PointRDD, PolygonRDD, LineStringRDD
 from sedona.core.enums import IndexType, GridType
-from sedona.core.formatMapper.disc_utils import load_spatial_rdd_from_disc, \
-    load_spatial_index_rdd_from_disc, GeoType
+from sedona.core.formatMapper.disc_utils import (
+    load_spatial_rdd_from_disc,
+    load_spatial_index_rdd_from_disc,
+    GeoType,
+)
 from sedona.core.spatialOperator import JoinQuery
 from tests.test_base import TestBase
 from tests.tools import tests_resource
@@ -36,77 +39,109 @@ def remove_directory(path: str) -> bool:
         return False
     return True
 
+
 disc_location = os.path.join(tests_resource, "spatial_objects/temp")
 
+
 class TestDiscUtils(TestBase):
     def test_saving_to_disc_spatial_rdd_point(self):
-        from tests.properties.point_properties import input_location, offset, splitter, num_partitions
+        from tests.properties.point_properties import (
+            input_location,
+            offset,
+            splitter,
+            num_partitions,
+        )
 
         point_rdd = PointRDD(
             self.sc, input_location, offset, splitter, True, num_partitions
         )
 
-        point_rdd.rawJvmSpatialRDD.saveAsObjectFile(os.path.join(disc_location, "point"))
+        point_rdd.rawJvmSpatialRDD.saveAsObjectFile(
+            os.path.join(disc_location, "point")
+        )
 
     def test_saving_to_disc_spatial_rdd_polygon(self):
-        from tests.properties.polygon_properties import input_location, splitter, num_partitions
-        polygon_rdd = PolygonRDD(
-            self.sc,
+        from tests.properties.polygon_properties import (
             input_location,
             splitter,
-            True,
-            num_partitions
+            num_partitions,
+        )
+
+        polygon_rdd = PolygonRDD(
+            self.sc, input_location, splitter, True, num_partitions
+        )
+        polygon_rdd.rawJvmSpatialRDD.saveAsObjectFile(
+            os.path.join(disc_location, "polygon")
         )
-        polygon_rdd.rawJvmSpatialRDD.saveAsObjectFile(os.path.join(disc_location, "polygon"))
 
     def test_saving_to_disc_spatial_rdd_linestring(self):
-        from tests.properties.linestring_properties import input_location, splitter, num_partitions
-        linestring_rdd = LineStringRDD(
-            self.sc,
+        from tests.properties.linestring_properties import (
             input_location,
             splitter,
-            True,
-            num_partitions
+            num_partitions,
         )
-        linestring_rdd.rawJvmSpatialRDD.saveAsObjectFile(os.path.join(disc_location, "line_string"))
 
-    def test_saving_to_disc_index_linestring(self):
-        from tests.properties.linestring_properties import input_location, splitter, num_partitions
         linestring_rdd = LineStringRDD(
-            self.sc,
+            self.sc, input_location, splitter, True, num_partitions
+        )
+        linestring_rdd.rawJvmSpatialRDD.saveAsObjectFile(
+            os.path.join(disc_location, "line_string")
+        )
+
+    def test_saving_to_disc_index_linestring(self):
+        from tests.properties.linestring_properties import (
             input_location,
             splitter,
-            True,
-            num_partitions
+            num_partitions,
+        )
+
+        linestring_rdd = LineStringRDD(
+            self.sc, input_location, splitter, True, num_partitions
         )
         linestring_rdd.buildIndex(IndexType.RTREE, False)
-        linestring_rdd.indexedRawRDD.saveAsObjectFile(os.path.join(disc_location, "line_string_index"))
+        linestring_rdd.indexedRawRDD.saveAsObjectFile(
+            os.path.join(disc_location, "line_string_index")
+        )
 
     def test_saving_to_disc_index_polygon(self):
-        from tests.properties.polygon_properties import input_location, splitter, num_partitions
-        polygon_rdd = PolygonRDD(
-            self.sc,
+        from tests.properties.polygon_properties import (
             input_location,
             splitter,
-            True,
-            num_partitions
+            num_partitions,
+        )
+
+        polygon_rdd = PolygonRDD(
+            self.sc, input_location, splitter, True, num_partitions
         )
         polygon_rdd.buildIndex(IndexType.RTREE, False)
-        polygon_rdd.indexedRawRDD.saveAsObjectFile(os.path.join(disc_location, "polygon_index"))
+        polygon_rdd.indexedRawRDD.saveAsObjectFile(
+            os.path.join(disc_location, "polygon_index")
+        )
 
     def test_saving_to_disc_index_point(self):
-        from tests.properties.point_properties import input_location, offset, splitter, num_partitions
+        from tests.properties.point_properties import (
+            input_location,
+            offset,
+            splitter,
+            num_partitions,
+        )
+
         point_rdd = PointRDD(
-            self.sc, input_location, offset, splitter, True, num_partitions)
+            self.sc, input_location, offset, splitter, True, num_partitions
+        )
         point_rdd.buildIndex(IndexType.RTREE, False)
-        point_rdd.indexedRawRDD.saveAsObjectFile(os.path.join(disc_location, "point_index"))
+        point_rdd.indexedRawRDD.saveAsObjectFile(
+            os.path.join(disc_location, "point_index")
+        )
 
     def test_loading_spatial_rdd_from_disc(self):
         point_rdd = load_spatial_rdd_from_disc(
             self.sc, os.path.join(disc_location, "point"), GeoType.POINT
         )
-        point_index_rdd = load_spatial_index_rdd_from_disc(self.sc, os.path.join(disc_location, "point_index"))
+        point_index_rdd = load_spatial_index_rdd_from_disc(
+            self.sc, os.path.join(disc_location, "point_index")
+        )
         point_rdd.indexedRawRDD = point_index_rdd
 
         assert point_rdd.indexedRawRDD is not None
@@ -117,7 +152,9 @@ def test_loading_spatial_rdd_from_disc(self):
         polygon_rdd = load_spatial_rdd_from_disc(
             self.sc, os.path.join(disc_location, "polygon"), GeoType.POLYGON
         )
-        polygon_index_rdd = load_spatial_index_rdd_from_disc(self.sc, os.path.join(disc_location, "polygon_index"))
+        polygon_index_rdd = load_spatial_index_rdd_from_disc(
+            self.sc, os.path.join(disc_location, "polygon_index")
+        )
         polygon_rdd.indexedRawRDD = polygon_index_rdd
         polygon_rdd.analyze()
 
@@ -129,7 +166,9 @@ def test_loading_spatial_rdd_from_disc(self):
         linestring_rdd = load_spatial_rdd_from_disc(
             self.sc, os.path.join(disc_location, "line_string"), GeoType.LINESTRING
         )
-        linestring_index_rdd = load_spatial_index_rdd_from_disc(self.sc, os.path.join(disc_location, "line_string_index"))
+        linestring_index_rdd = load_spatial_index_rdd_from_disc(
+            self.sc, os.path.join(disc_location, "line_string_index")
+        )
         linestring_rdd.indexedRawRDD = linestring_index_rdd
 
         assert linestring_rdd.indexedRawRDD is not None
@@ -144,7 +183,8 @@ def test_loading_spatial_rdd_from_disc(self):
         linestring_rdd.buildIndex(IndexType.RTREE, True)
 
         result = JoinQuery.SpatialJoinQuery(
-            linestring_rdd, polygon_rdd, True, True).collect()
+            linestring_rdd, polygon_rdd, True, True
+        ).collect()
         print(result)
         remove_directory(disc_location)
diff --git a/python/tests/format_mapper/test_geo_json_reader.py b/python/tests/format_mapper/test_geo_json_reader.py
index b5ed20c53b..73b4ade7a8 100644
--- a/python/tests/format_mapper/test_geo_json_reader.py
+++ b/python/tests/format_mapper/test_geo_json_reader.py
@@ -27,9 +27,15 @@
 
 geo_json_contains_id = os.path.join(tests_resource, "testContainsId.json")
 geo_json_geom_with_feature_property = os.path.join(tests_resource, "testPolygon.json")
-geo_json_geom_without_feature_property = os.path.join(tests_resource, "testpolygon-no-property.json")
-geo_json_with_invalid_geometries = os.path.join(tests_resource, "testInvalidPolygon.json")
-geo_json_with_invalid_geom_with_feature_property = os.path.join(tests_resource, "invalidSyntaxGeometriesJson.json")
+geo_json_geom_without_feature_property = os.path.join(
+    tests_resource, "testpolygon-no-property.json"
+)
+geo_json_with_invalid_geometries = os.path.join(
+    tests_resource, "testInvalidPolygon.json"
+)
+geo_json_with_invalid_geom_with_feature_property = os.path.join(
+    tests_resource, "invalidSyntaxGeometriesJson.json"
+)
 
 
 class TestGeoJsonReader(TestBase):
@@ -37,15 +43,13 @@ class TestGeoJsonReader(TestBase):
     def test_read_to_geometry_rdd(self):
         if is_greater_or_equal_version(SedonaMeta.version, "1.0.0"):
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_geom_with_feature_property
+                self.sc, geo_json_geom_with_feature_property
             )
 
             assert geo_json_rdd.rawSpatialRDD.count() == 1001
 
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_geom_without_feature_property
+                self.sc, geo_json_geom_without_feature_property
             )
 
             assert geo_json_rdd.rawSpatialRDD.count() == 10
@@ -53,52 +57,39 @@ def test_read_to_geometry_rdd(self):
     def test_read_to_valid_geometry_rdd(self):
         if is_greater_or_equal_version(SedonaMeta.version, "1.0.0"):
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_geom_with_feature_property,
-                True,
-                False
+                self.sc, geo_json_geom_with_feature_property, True, False
             )
 
             assert geo_json_rdd.rawSpatialRDD.count() == 1001
 
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_geom_without_feature_property,
-                True,
-                False
+                self.sc, geo_json_geom_without_feature_property, True, False
             )
 
             assert geo_json_rdd.rawSpatialRDD.count() == 10
 
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_with_invalid_geometries,
-                False,
-                False
+                self.sc, geo_json_with_invalid_geometries, False, False
             )
 
             assert geo_json_rdd.rawSpatialRDD.count() == 2
 
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_with_invalid_geometries
+                self.sc, geo_json_with_invalid_geometries
             )
             assert geo_json_rdd.rawSpatialRDD.count() == 3
 
     def test_read_to_include_id_rdd(self):
         if is_greater_or_equal_version(SedonaMeta.version, "1.0.0"):
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_contains_id,
-                True,
-                False
+                self.sc, geo_json_contains_id, True, False
             )
 
             geo_json_rdd = GeoJsonReader.readToGeometryRDD(
                 sc=self.sc,
                 inputPath=geo_json_contains_id,
                 allowInvalidGeometries=True,
-                skipSyntacticallyInvalidGeometries=False
+                skipSyntacticallyInvalidGeometries=False,
             )
             assert geo_json_rdd.rawSpatialRDD.count() == 1
             try:
@@ -109,10 +100,7 @@ def test_read_to_include_id_rdd(self):
     def test_read_to_geometry_rdd_invalid_syntax(self):
         if is_greater_or_equal_version(SedonaMeta.version, "1.0.0"):
             geojson_rdd = GeoJsonReader.readToGeometryRDD(
-                self.sc,
-                geo_json_with_invalid_geom_with_feature_property,
-                False,
-                True
+                self.sc, geo_json_with_invalid_geom_with_feature_property, False, True
             )
 
             assert geojson_rdd.rawSpatialRDD.count() == 1
diff --git a/python/tests/format_mapper/test_shapefile_reader.py b/python/tests/format_mapper/test_shapefile_reader.py
index 5134b911d1..73e7fd9ace 100644
--- a/python/tests/format_mapper/test_shapefile_reader.py
+++ b/python/tests/format_mapper/test_shapefile_reader.py
@@ -38,13 +38,17 @@ def test_shape_file_end_with_undefined_type(self):
             sc=self.sc, inputPath=undefined_type_shape_location
         )
 
-        assert shape_rdd.fieldNames == ['LGA_CODE16', 'LGA_NAME16', 'STE_CODE16', 'STE_NAME16', 'AREASQKM16']
+        assert shape_rdd.fieldNames == [
+            "LGA_CODE16",
+            "LGA_NAME16",
+            "STE_CODE16",
+            "STE_NAME16",
+            "AREASQKM16",
+        ]
         assert shape_rdd.getRawSpatialRDD().count() == 545
 
     def test_read_geometry_rdd(self):
-        shape_rdd = ShapefileReader.readToGeometryRDD(
-            self.sc, polygon_shape_location
-        )
+        shape_rdd = ShapefileReader.readToGeometryRDD(self.sc, polygon_shape_location)
 
         assert shape_rdd.fieldNames == []
         assert shape_rdd.rawSpatialRDD.collect().__len__() == 10000
@@ -56,8 +60,14 @@ def test_read_to_polygon_rdd(self):
         count = RangeQuery.SpatialRangeQuery(spatial_rdd, window, False, False).count()
 
         assert spatial_rdd.rawSpatialRDD.count() == count
-        assert 'org.apache.sedona.core.spatialRDD.SpatialRDD' in geometry_rdd._srdd.toString()
-        assert 'org.apache.sedona.core.spatialRDD.PolygonRDD' in spatial_rdd._srdd.toString()
+        assert (
+            "org.apache.sedona.core.spatialRDD.SpatialRDD"
+            in geometry_rdd._srdd.toString()
+        )
+        assert (
+            "org.apache.sedona.core.spatialRDD.PolygonRDD"
+            in spatial_rdd._srdd.toString()
+        )
 
     def test_read_to_linestring_rdd(self):
         input_location = os.path.join(tests_resource, "shapefiles/polyline")
@@ -66,8 +76,14 @@ def test_read_to_linestring_rdd(self):
         window = Envelope(-180.0, 180.0, -90.0, 90.0)
         count = RangeQuery.SpatialRangeQuery(spatial_rdd, window, False, False).count()
         assert spatial_rdd.rawSpatialRDD.count() == count
-        assert 'org.apache.sedona.core.spatialRDD.SpatialRDD' in geometry_rdd._srdd.toString()
-        assert 'org.apache.sedona.core.spatialRDD.LineStringRDD' in spatial_rdd._srdd.toString()
+        assert (
+            "org.apache.sedona.core.spatialRDD.SpatialRDD"
+            in geometry_rdd._srdd.toString()
+        )
+        assert (
+            "org.apache.sedona.core.spatialRDD.LineStringRDD"
+            in spatial_rdd._srdd.toString()
+        )
 
     def test_read_to_point_rdd(self):
         input_location = os.path.join(tests_resource, "shapefiles/point")
@@ -76,8 +92,13 @@ def test_read_to_point_rdd(self):
         window = Envelope(-180.0, 180.0, -90.0, 90.0)
         count = RangeQuery.SpatialRangeQuery(spatial_rdd, window, False, False).count()
         assert spatial_rdd.rawSpatialRDD.count() == count
-        assert 'org.apache.sedona.core.spatialRDD.SpatialRDD' in geometry_rdd._srdd.toString()
-        assert 'org.apache.sedona.core.spatialRDD.PointRDD' in spatial_rdd._srdd.toString()
+        assert (
+            "org.apache.sedona.core.spatialRDD.SpatialRDD"
+            in geometry_rdd._srdd.toString()
+        )
+        assert (
+            "org.apache.sedona.core.spatialRDD.PointRDD" in spatial_rdd._srdd.toString()
+        )
 
     def test_read_to_point_rdd_multipoint(self):
         input_location = os.path.join(tests_resource, "shapefiles/multipoint")
@@ -86,5 +107,10 @@ def test_read_to_point_rdd_multipoint(self):
         window = Envelope(-180.0, 180.0, -90.0, 90.0)
         count = RangeQuery.SpatialRangeQuery(spatial_rdd, window, False, False).count()
         assert spatial_rdd.rawSpatialRDD.count() == count
-        assert 'org.apache.sedona.core.spatialRDD.SpatialRDD' in geometry_rdd._srdd.toString()
-        assert 'org.apache.sedona.core.spatialRDD.PointRDD' in spatial_rdd._srdd.toString()
+        assert (
+            "org.apache.sedona.core.spatialRDD.SpatialRDD"
+            in geometry_rdd._srdd.toString()
+        )
+        assert (
+            "org.apache.sedona.core.spatialRDD.PointRDD" in spatial_rdd._srdd.toString()
+        )
diff --git a/python/tests/maps/test_sedonakepler_visualization.py b/python/tests/maps/test_sedonakepler_visualization.py
index e681873a97..7f9ad2195d 100644
--- a/python/tests/maps/test_sedonakepler_visualization.py
+++ b/python/tests/maps/test_sedonakepler_visualization.py
@@ -25,7 +25,7 @@
 
 
 class TestVisualization(TestBase):
-    """ _repr_html() creates a html encoded string of the current map data, can be used to assert data equality """
+    """_repr_html() creates a html encoded string of the current map data, can be used to assert data equality"""
 
     def test_basic_map_creation(self):
         sedona_kepler_map = SedonaKepler.create_map()
@@ -33,14 +33,20 @@ def test_basic_map_creation(self):
         assert sedona_kepler_map.config == kepler_map.config
 
     def test_map_creation_with_df(self):
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
-        polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
-        polygon_gdf = gpd.GeoDataFrame(data=polygon_df.toPandas(), geometry="countyshape")
+        polygon_df = self.spark.sql(
+            "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable"
+        )
+        polygon_gdf = gpd.GeoDataFrame(
+            data=polygon_df.toPandas(), geometry="countyshape"
+        )
         polygon_gdf_renamed = polygon_gdf.rename_geometry("geometry")
 
         sedona_kepler_map = SedonaKepler.create_map(df=polygon_df, name="data_1")
@@ -52,23 +58,38 @@ def test_map_creation_with_df(self):
 
     def test_df_with_raster(self):
         df = self.spark.read.format("binaryFile").load(world_map_raster_input_location)
-        df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POINT (1 1)') as point")
+        df = df.selectExpr(
+            "RS_FromGeoTiff(content) as raster",
+            "ST_GeomFromWKT('POINT (1 1)') as point",
+        )
 
         sedona_kepler = SedonaKepler.create_map(df=df, name="data_1")
         actual = sedona_kepler.data
-        expected = {'data_1': {'index': [0], 'columns': ['geometry'], 'data': [['POINT (1.0000000000000000 1.0000000000000000)']]}}
+        expected = {
+            "data_1": {
+                "index": [0],
+                "columns": ["geometry"],
+                "data": [["POINT (1.0000000000000000 1.0000000000000000)"]],
+            }
+        }
         assert actual == expected
 
     def test_df_addition(self):
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
-        polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
-        polygon_gdf = gpd.GeoDataFrame(data=polygon_df.toPandas(), geometry="countyshape")
+        polygon_df = self.spark.sql(
+            "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable"
+        )
+        polygon_gdf = gpd.GeoDataFrame(
+            data=polygon_df.toPandas(), geometry="countyshape"
+        )
         polygon_gdf_renamed = polygon_gdf.rename_geometry("geometry")
 
         sedona_kepler_empty_map = SedonaKepler.create_map()
@@ -81,140 +102,199 @@ def test_df_addition(self):
         assert sedona_kepler_empty_map.config == kepler_map.config
 
     def test_pandas_df_addition(self):
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
         polygon_h3_df = self.spark.sql(
-            "select ST_H3CellIDs(ST_GeomFromWKT(polygontable._c0), 3, false) as h3_cellID from polygontable")
-        polygon_exploded_h3 = polygon_h3_df.select(explode(polygon_h3_df.h3_cellID).alias("h3"))
-        polygon_hex_exploded_h3 = polygon_exploded_h3.select(hex(polygon_exploded_h3.h3).alias("hex_h3"))
+            "select ST_H3CellIDs(ST_GeomFromWKT(polygontable._c0), 3, false) as h3_cellID from polygontable"
+        )
+        polygon_exploded_h3 = polygon_h3_df.select(
+            explode(polygon_h3_df.h3_cellID).alias("h3")
+        )
+        polygon_hex_exploded_h3 = polygon_exploded_h3.select(
+            hex(polygon_exploded_h3.h3).alias("hex_h3")
+        )
 
         kepler_map = SedonaKepler.create_map(df=polygon_hex_exploded_h3, name="h3")
 
         # just test if the map creation is successful.
         assert kepler_map is not None
 
     def test_adding_multiple_datasets(self):
-        config = {'version': 'v1',
-                  'config': {'visState': {'filters': [],
-                                          'layers': [{'id': 'ikzru0t',
-                                                      'type': 'geojson',
-                                                      'config': {'dataId': 'AirportCount',
-                                                                 'label': 'AirportCount',
-                                                                 'color': [218, 112, 191],
-                                                                 'highlightColor': [252, 242, 26, 255],
-                                                                 'columns': {'geojson': 'geometry'},
-                                                                 'isVisible': True,
-                                                                 'visConfig': {'opacity': 0.8,
-                                                                               'strokeOpacity': 0.8,
-                                                                               'thickness': 0.5,
-                                                                               'strokeColor': [18, 92, 119],
-                                                                               'colorRange': {
-                                                                                   'name': 'Uber Viz Sequential 6',
-                                                                                   'type': 'sequential',
-                                                                                   'category': 'Uber',
-                                                                                   'colors': ['#E6FAFA',
-                                                                                              '#C1E5E6',
-                                                                                              '#9DD0D4',
-                                                                                              '#75BBC1',
-                                                                                              '#4BA7AF',
-                                                                                              '#00939C',
-                                                                                              '#108188',
-                                                                                              '#0E7077']},
-                                                                               'strokeColorRange': {
-                                                                                   'name': 'Global Warming',
-                                                                                   'type': 'sequential',
-                                                                                   'category': 'Uber',
-                                                                                   'colors': ['#5A1846',
-                                                                                              '#900C3F',
-                                                                                              '#C70039',
-                                                                                              '#E3611C',
-                                                                                              '#F1920E',
-                                                                                              '#FFC300']},
-                                                                               'radius': 10,
-                                                                               'sizeRange': [0, 10],
-                                                                               'radiusRange': [0, 50],
-                                                                               'heightRange': [0, 500],
-                                                                               'elevationScale': 5,
-                                                                               'enableElevationZoomFactor': True,
-                                                                               'stroked': False,
-                                                                               'filled': True,
-                                                                               'enable3d': False,
-                                                                               'wireframe': False},
-                                                                 'hidden': False,
-                                                                 'textLabel': [{'field': None,
-                                                                                'color': [255, 255, 255],
-                                                                                'size': 18,
-                                                                                'offset': [0, 0],
-                                                                                'anchor': 'start',
-                                                                                'alignment': 'center'}]},
-                                                      'visualChannels': {'colorField': {'name': 'AirportCount',
-                                                                                        'type': 'integer'},
-                                                                         'colorScale': 'quantize',
-                                                                         'strokeColorField': None,
-                                                                         'strokeColorScale': 'quantile',
-                                                                         'sizeField': None,
-                                                                         'sizeScale': 'linear',
-                                                                         'heightField': None,
-                                                                         'heightScale': 'linear',
-                                                                         'radiusField': None,
-                                                                         'radiusScale': 'linear'}}],
-                                          'interactionConfig': {
-                                              'tooltip': {'fieldsToShow': {'AirportCount': [{'name': 'NAME_EN',
-                                                                                             'format': None},
-                                                                                            {'name': 'AirportCount',
-                                                                                             'format': None}]},
-                                                          'compareMode': False,
-                                                          'compareType': 'absolute',
-                                                          'enabled': True},
-                                              'brush': {'size': 0.5, 'enabled': False},
-                                              'geocoder': {'enabled': False},
-                                              'coordinate': {'enabled': False}},
-                                          'layerBlending': 'normal',
-                                          'splitMaps': [],
-                                          'animationConfig': {'currentTime': None, 'speed': 1}},
-                             'mapState': {'bearing': 0,
-                                          'dragRotate': False,
-                                          'latitude': 56.422456606624316,
-                                          'longitude': 9.778836615231771,
-                                          'pitch': 0,
-                                          'zoom': 0.4214991225736964,
-                                          'isSplit': False},
-                             'mapStyle': {'styleType': 'dark',
-                                          'topLayerGroups': {},
-                                          'visibleLayerGroups': {'label': True,
-                                                                 'road': True,
-                                                                 'border': False,
-                                                                 'building': True,
-                                                                 'water': True,
-                                                                 'land': True,
-                                                                 '3d building': False},
-                                          'threeDBuildingColor': [9.665468314072013,
-                                                                  17.18305478057247,
-                                                                  31.1442867897876],
-                                          'mapStyles': {}}}}
+        config = {
+            "version": "v1",
+            "config": {
+                "visState": {
+                    "filters": [],
+                    "layers": [
+                        {
+                            "id": "ikzru0t",
+                            "type": "geojson",
+                            "config": {
+                                "dataId": "AirportCount",
+                                "label": "AirportCount",
+                                "color": [218, 112, 191],
+                                "highlightColor": [252, 242, 26, 255],
+                                "columns": {"geojson": "geometry"},
+                                "isVisible": True,
+                                "visConfig": {
+                                    "opacity": 0.8,
+                                    "strokeOpacity": 0.8,
+                                    "thickness": 0.5,
+                                    "strokeColor": [18, 92, 119],
+                                    "colorRange": {
+                                        "name": "Uber Viz Sequential 6",
+                                        "type": "sequential",
+                                        "category": "Uber",
+                                        "colors": [
+                                            "#E6FAFA",
+                                            "#C1E5E6",
+                                            "#9DD0D4",
+                                            "#75BBC1",
+                                            "#4BA7AF",
+                                            "#00939C",
+                                            "#108188",
+                                            "#0E7077",
+                                        ],
+                                    },
+                                    "strokeColorRange": {
+                                        "name": "Global Warming",
+                                        "type": "sequential",
+                                        "category": "Uber",
+                                        "colors": [
+                                            "#5A1846",
+                                            "#900C3F",
+                                            "#C70039",
+                                            "#E3611C",
+                                            "#F1920E",
+                                            "#FFC300",
+                                        ],
+                                    },
+                                    "radius": 10,
+                                    "sizeRange": [0, 10],
+                                    "radiusRange": [0, 50],
+                                    "heightRange": [0, 500],
+                                    "elevationScale": 5,
+                                    "enableElevationZoomFactor": True,
+                                    "stroked": False,
+                                    "filled": True,
+                                    "enable3d": False,
+                                    "wireframe": False,
+                                },
+                                "hidden": False,
+                                "textLabel": [
+                                    {
+                                        "field": None,
+                                        "color": [255, 255, 255],
+                                        "size": 18,
+                                        "offset": [0, 0],
+                                        "anchor": "start",
+                                        "alignment": "center",
+                                    }
+                                ],
+                            },
+                            "visualChannels": {
+                                "colorField": {
+                                    "name": "AirportCount",
+                                    "type": "integer",
+                                },
+                                "colorScale": "quantize",
+                                "strokeColorField": None,
+                                "strokeColorScale": "quantile",
+                                "sizeField": None,
+                                "sizeScale": "linear",
+                                "heightField": None,
+                                "heightScale": "linear",
+                                "radiusField": None,
+                                "radiusScale": "linear",
+                            },
+                        }
+                    ],
+                    "interactionConfig": {
+                        "tooltip": {
+                            "fieldsToShow": {
+                                "AirportCount": [
+                                    {"name": "NAME_EN", "format": None},
+                                    {"name": "AirportCount", "format": None},
+                                ]
+                            },
+                            "compareMode": False,
+                            "compareType": "absolute",
+                            "enabled": True,
+                        },
+                        "brush": {"size": 0.5, "enabled": False},
+                        "geocoder": {"enabled": False},
+                        "coordinate": {"enabled": False},
+                    },
+                    "layerBlending": "normal",
+                    "splitMaps": [],
+                    "animationConfig": {"currentTime": None, "speed": 1},
+                },
+                "mapState": {
+                    "bearing": 0,
+                    "dragRotate": False,
+                    "latitude": 56.422456606624316,
+                    "longitude": 9.778836615231771,
+                    "pitch": 0,
+                    "zoom": 0.4214991225736964,
+                    "isSplit": False,
+                },
+                "mapStyle": {
+                    "styleType": "dark",
+                    "topLayerGroups": {},
+                    "visibleLayerGroups": {
+                        "label": True,
+                        "road": True,
+                        "border": False,
+                        "building": True,
+                        "water": True,
+                        "land": True,
+                        "3d building": False,
+                    },
+                    "threeDBuildingColor": [
+                        9.665468314072013,
+                        17.18305478057247,
+                        31.1442867897876,
+                    ],
+                    "mapStyles": {},
+                },
+            },
+        }
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
-
-        point_csv_df = self.spark.read.format("csv"). \
-            option("delimiter", ","). \
-            option("header", "false"). \
-            load(csv_point_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
+
+        point_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_point_input_location)
+        )
 
         point_csv_df.createOrReplaceTempView("pointtable")
 
         point_df = self.spark.sql(
-            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
+            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
-        polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
+        polygon_df = self.spark.sql(
+            "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable"
+        )
 
-        sedona_kepler_map = SedonaKepler.create_map(df=polygon_df, name="data_1", config=config)
+        sedona_kepler_map = SedonaKepler.create_map(
+            df=polygon_df, name="data_1", config=config
+        )
         SedonaKepler.add_df(sedona_kepler_map, point_df, name="data_2")
 
-        polygon_gdf = gpd.GeoDataFrame(data=polygon_df.toPandas(), geometry="countyshape")
+        polygon_gdf = gpd.GeoDataFrame(
+            data=polygon_df.toPandas(), geometry="countyshape"
+        )
         polygon_gdf_renamed = polygon_gdf.rename_geometry("geometry")
         point_gdf = gpd.GeoDataFrame(data=point_df.toPandas(), geometry="arealandmark")
         point_gdf_renamed = point_gdf.rename_geometry("geometry")
diff --git a/python/tests/maps/test_sedonapydeck.py b/python/tests/maps/test_sedonapydeck.py
index b2631bb963..329f2e4f4c 100644
--- a/python/tests/maps/test_sedonapydeck.py
+++ b/python/tests/maps/test_sedonapydeck.py
@@ -27,19 +27,25 @@
 
 class TestVisualization(TestBase):
     def testChoroplethMap(self):
-        buildings_csv_df = self.spark.read.format("csv"). \
-            option("delimiter", ","). \
-            option("header", "true"). \
-            option("inferSchema", "true"). \
-            load(google_buildings_input_location)
+        buildings_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "true")
+            .option("inferSchema", "true")
+            .load(google_buildings_input_location)
+        )
         buildings_csv_df.createOrReplaceTempView("buildings_table")
         buildings_df = self.spark.sql(
             "SELECT confidence, latitude, longitude, ST_GeomFromWKT(geometry) as geometry from buildings_table"
         )
-        buildings_gdf = gpd.GeoDataFrame(data=buildings_df.toPandas(), geometry="geometry")
-        fill_color = SedonaPyDeck._create_default_fill_color_(gdf=buildings_gdf, plot_col='confidence')
+        buildings_gdf = gpd.GeoDataFrame(
+            data=buildings_df.toPandas(), geometry="geometry"
+        )
+        fill_color = SedonaPyDeck._create_default_fill_color_(
+            gdf=buildings_gdf, plot_col="confidence"
+        )
         choropleth_layer = pdk.Layer(
-            'GeoJsonLayer',  # `type` positional argument is here
+            "GeoJsonLayer",  # `type` positional argument is here
             data=buildings_gdf,
             auto_highlight=True,
             get_fill_color=fill_color,
@@ -48,87 +54,112 @@ def testChoroplethMap(self):
             extruded=True,
             wireframe=True,
             get_elevation=0,
-            pickable=True
+            pickable=True,
+        )
+        p_map = pdk.Deck(
+            layers=[choropleth_layer], map_style="satellite", map_provider="mapbox"
+        )
+        sedona_pydeck_map = SedonaPyDeck.create_choropleth_map(
+            df=buildings_df, plot_col="confidence", map_style="satellite"
         )
-        p_map = pdk.Deck(layers=[choropleth_layer], map_style='satellite', map_provider='mapbox')
-        sedona_pydeck_map = SedonaPyDeck.create_choropleth_map(df=buildings_df, plot_col='confidence', map_style='satellite')
         res = self.isMapEqual(sedona_map=sedona_pydeck_map, pydeck_map=p_map)
         assert res is True
 
     def testPolygonMap(self):
-        buildings_csv_df = self.spark.read.format("csv"). \
-            option("delimiter", ","). \
-            option("header", "true"). \
-            load(google_buildings_input_location)
+        buildings_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "true")
+            .load(google_buildings_input_location)
+        )
         buildings_csv_df.createOrReplaceTempView("buildings_table")
         buildings_df = self.spark.sql(
             "SELECT confidence, latitude, longitude, ST_GeomFromWKT(geometry) as geometry from buildings_table"
         )
-        buildings_gdf = gpd.GeoDataFrame(data=buildings_df.toPandas(), geometry="geometry")
+        buildings_gdf = gpd.GeoDataFrame(
+            data=buildings_df.toPandas(), geometry="geometry"
+        )
         polygon_layer = pdk.Layer(
-            'GeoJsonLayer',  # `type` positional argument is here
+            "GeoJsonLayer",  # `type` positional argument is here
             data=buildings_gdf,
             auto_highlight=True,
             get_fill_color="[85, 183, 177, 255]",
             opacity=0.4,
             stroked=True,
             extruded=True,
-            get_elevation='confidence * 10',
+            get_elevation="confidence * 10",
             pickable=True,
             get_line_color="[85, 183, 177, 255]",
-            get_line_width=3
+            get_line_width=3,
         )
         p_map = pdk.Deck(layers=[polygon_layer])
-        sedona_pydeck_map = SedonaPyDeck.create_geometry_map(df=buildings_df, elevation_col='confidence * 10')
+        sedona_pydeck_map = SedonaPyDeck.create_geometry_map(
+            df=buildings_df, elevation_col="confidence * 10"
+        )
         res = self.isMapEqual(sedona_map=sedona_pydeck_map, pydeck_map=p_map)
         assert res is True
 
     def testScatterplotMap(self):
-        chicago_crimes_csv_df = self.spark.read.format("csv"). \
-            option("delimiter", ","). \
-            option("header", "true"). \
-            load(chicago_crimes_input_location)
+        chicago_crimes_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "true")
+            .load(chicago_crimes_input_location)
+        )
         chicago_crimes_csv_df.createOrReplaceTempView("crimes_table")
         chicago_crimes_df = self.spark.sql(
             "SELECT ST_POINT(CAST(x as DECIMAL(24, 20)), CAST (y as DECIMAL(24, 20))) as geometry, Description, "
-            "Year from crimes_table")
-        chicago_crimes_gdf = gpd.GeoDataFrame(data=chicago_crimes_df.toPandas(), geometry='geometry')
-        SedonaPyDeck._create_coord_column_(chicago_crimes_gdf, geometry_col='geometry')
+            "Year from crimes_table"
+        )
+        chicago_crimes_gdf = gpd.GeoDataFrame(
+            data=chicago_crimes_df.toPandas(), geometry="geometry"
+        )
+        SedonaPyDeck._create_coord_column_(chicago_crimes_gdf, geometry_col="geometry")
         layer = pdk.Layer(
             "ScatterplotLayer",
             data=chicago_crimes_gdf,
             pickable=True,
             opacity=0.8,
             filled=True,
-            get_position='coordinate_array_sedona',
+            get_position="coordinate_array_sedona",
             get_fill_color="[255, 140, 0]",
             get_radius=1,
             radius_min_pixels=1,
             radius_max_pixels=10,
-            radius_scale=1
+            radius_scale=1,
         )
-        p_map = pdk.Deck(layers=[layer], map_style='satellite', map_provider='google_maps')
-        sedona_pydeck_map = SedonaPyDeck.create_scatterplot_map(df=chicago_crimes_df, map_style='satellite', map_provider='google_maps')
+        p_map = pdk.Deck(
+            layers=[layer], map_style="satellite", map_provider="google_maps"
+        )
+        sedona_pydeck_map = SedonaPyDeck.create_scatterplot_map(
+            df=chicago_crimes_df, map_style="satellite", map_provider="google_maps"
+        )
         assert self.isMapEqual(sedona_map=sedona_pydeck_map, pydeck_map=p_map)
 
     def testHeatmap(self):
-        chicago_crimes_csv_df = self.spark.read.format("csv"). \
-            option("delimiter", ","). \
-            option("header", "true"). \
-            load(chicago_crimes_input_location)
+        chicago_crimes_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "true")
+            .load(chicago_crimes_input_location)
+        )
         chicago_crimes_csv_df.createOrReplaceTempView("crimes_table")
-        chicago_crimes_df = self.spark.sql("SELECT ST_POINT(CAST(x as DECIMAL(24, 20)), CAST (y as DECIMAL(24, "
-                                           "20))) as geometry, Description, Year from crimes_table")
-        chicago_crimes_gdf = gpd.GeoDataFrame(data=chicago_crimes_df.toPandas(), geometry='geometry')
-        SedonaPyDeck._create_coord_column_(chicago_crimes_gdf, geometry_col='geometry')
+        chicago_crimes_df = self.spark.sql(
+            "SELECT ST_POINT(CAST(x as DECIMAL(24, 20)), CAST (y as DECIMAL(24, "
+            "20))) as geometry, Description, Year from crimes_table"
+        )
+        chicago_crimes_gdf = gpd.GeoDataFrame(
+            data=chicago_crimes_df.toPandas(), geometry="geometry"
+        )
+        SedonaPyDeck._create_coord_column_(chicago_crimes_gdf, geometry_col="geometry")
         color_range = [
             [255, 255, 178],
             [254, 217, 118],
             [254, 178, 76],
             [253, 141, 60],
             [240, 59, 32],
-            [240, 59, 32]
+            [240, 59, 32],
         ]
         aggregation = pdk.types.String("SUM")
         layer = pdk.Layer(
@@ -137,14 +168,16 @@ def testHeatmap(self):
             pickable=True,
             opacity=0.8,
             filled=True,
-            get_position='coordinate_array_sedona',
+            get_position="coordinate_array_sedona",
             aggregation=aggregation,
             color_range=color_range,
-            get_weight=1
+            get_weight=1,
         )
-        p_map = pdk.Deck(layers=[layer], map_style='satellite', map_provider='mapbox')
-        sedona_pydeck_map = SedonaPyDeck.create_heatmap(df=chicago_crimes_df, map_style='satellite')
+        p_map = pdk.Deck(layers=[layer], map_style="satellite", map_provider="mapbox")
+        sedona_pydeck_map = SedonaPyDeck.create_heatmap(
+            df=chicago_crimes_df, map_style="satellite"
+        )
         assert self.isMapEqual(sedona_map=sedona_pydeck_map, pydeck_map=p_map)
 
     def isMapEqual(self, pydeck_map, sedona_map):
@@ -155,14 +188,16 @@ def isMapEqual(self, pydeck_map, sedona_map):
         res = True
         for key in sedona_dict:
             try:
-                if key == 'initialViewState':
+                if key == "initialViewState":
                     continue
-                if key == 'layers':
-                    res &= self.isLayerEqual(pydeck_layer=pydeck_dict[key], sedona_layer=sedona_dict[key])
+                if key == "layers":
+                    res &= self.isLayerEqual(
+                        pydeck_layer=pydeck_dict[key], sedona_layer=sedona_dict[key]
+                    )
                     if res is False:
                         return False
                     continue
-                res &= (str(pydeck_dict[key]) == str(sedona_dict[key]))
+                res &= str(pydeck_dict[key]) == str(sedona_dict[key])
                 if res is False:
                     return False
             except KeyError:
@@ -179,9 +214,9 @@ def isLayerEqual(self, pydeck_layer, sedona_layer):
                 return False
             try:
                 for key in pydeck_layer[i]:
-                    if key == 'data' or key == 'id':
+                    if key == "data" or key == "id":
                         continue
-                    res &= (str(pydeck_layer[i][key]) == str(sedona_layer[i][key]))
+                    res &= str(pydeck_layer[i][key]) == str(sedona_layer[i][key])
                     if res is False:
                         return False
             except KeyError:
diff --git a/python/tests/properties/linestring_properties.py b/python/tests/properties/linestring_properties.py
index d5d1305ea5..fac63c8c69 100644
--- a/python/tests/properties/linestring_properties.py
+++ b/python/tests/properties/linestring_properties.py
@@ -31,10 +31,18 @@
 distance = 0.01
 query_polygon_set = os.path.join(tests_resource, "primaryroads-polygon.csv")
 input_count = 3000
-input_boundary = Envelope(minx=-123.393766, maxx=-65.648659, miny=17.982169, maxy=49.002374)
-input_boundary_2 = Envelope(minx=-123.393766, maxx=-65.649956, miny=17.982169, maxy=49.002374)
+input_boundary = Envelope(
+    minx=-123.393766, maxx=-65.648659, miny=17.982169, maxy=49.002374
+)
+input_boundary_2 = Envelope(
+    minx=-123.393766, maxx=-65.649956, miny=17.982169, maxy=49.002374
+)
 match_count = 535
 match_with_origin_with_duplicates_count = 875
-transformed_envelope = Envelope(14313844.294334238, 16791709.853587367, 942450.5989896103, 8474779.278028358)
-transformed_envelope_2 = Envelope(14313844.294334238, 16791709.853587367, 942450.5989896103, 8474645.488977432)
+transformed_envelope = Envelope(
+    14313844.294334238, 16791709.853587367, 942450.5989896103, 8474779.278028358
+)
+transformed_envelope_2 = Envelope(
+    14313844.294334238, 16791709.853587367, 942450.5989896103, 8474645.488977432
+)
diff --git a/python/tests/properties/point_properties.py b/python/tests/properties/point_properties.py
index 86c10c9e21..f8af949618 100644
--- a/python/tests/properties/point_properties.py
+++ b/python/tests/properties/point_properties.py
@@ -32,17 +32,18 @@
 query_polygon_set = "primaryroads-polygon.csv"
 input_count = 3000
 input_boundary = Envelope(
-    minx=-173.120769,
-    maxx=-84.965961,
-    miny=30.244859,
-    maxy=71.355134
+    minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134
 )
 rectangle_match_count = 103
 rectangle_with_original_duplicates_count = 103
 polygon_match_count = 472
 polygon_match_with_original_duplicates_count = 562
-transformed_envelope = Envelope(14313844.294334238, 16587207.463797076, 942450.5989896103, 6697987.652517772)
+transformed_envelope = Envelope(
+    14313844.294334238, 16587207.463797076, 942450.5989896103, 6697987.652517772
+)
 crs_point_test = os.path.join(tests_resource, "crs-test-point.csv")
 crs_envelope = Envelope(26.992172, 71.35513400000001, -179.147236, 179.475569)
-crs_envelope_transformed = Envelope(-5446655.086752236, 1983668.3828524568, 534241.8622328975, 6143259.025545624)
+crs_envelope_transformed = Envelope(
+    -5446655.086752236, 1983668.3828524568, 534241.8622328975, 6143259.025545624
+)
diff --git a/python/tests/properties/polygon_properties.py b/python/tests/properties/polygon_properties.py
index d2638b3c59..f4ca942149 100644
--- a/python/tests/properties/polygon_properties.py
+++ b/python/tests/properties/polygon_properties.py
@@ -32,13 +32,17 @@
 input_location_query_polygon = os.path.join(tests_resource, "crs-test-polygon.csv")
 query_polygon_count = 13361
 
-query_envelope = Envelope(14313844.294334238, 16802290.853830762, 942450.5989896103, 8631908.270651892)
+query_envelope = Envelope(
+    14313844.294334238, 16802290.853830762, 942450.5989896103, 8631908.270651892
+)
 query_polygon_set = os.path.join(tests_resource, "primaryroads-polygon.csv")
 input_location_geo_json = os.path.join(tests_resource, "testPolygon.json")
 input_location_wkt = os.path.join(tests_resource, "county_small.tsv")
 input_location_wkb = os.path.join(tests_resource, "county_small_wkb.tsv")
 input_count = 3000
-input_boundary = Envelope(minx=-158.104182, maxx=-66.03575, miny=17.986328, maxy=48.645133)
+input_boundary = Envelope(
+    minx=-158.104182, maxx=-66.03575, miny=17.986328, maxy=48.645133
+)
 contains_match_count = 6941
 contains_match_with_original_duplicates_count = 9334
 intersects_match_count = 24323
diff --git a/python/tests/raster/test_meta.py b/python/tests/raster/test_meta.py
index 68135ba25d..7372cecf7d 100644
--- a/python/tests/raster/test_meta.py
+++ b/python/tests/raster/test_meta.py
@@ -32,7 +32,9 @@ def test_change_anchor_to_upper_left(self):
         ip_x = 100
         ip_y = 200
 
-        trans = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER)
+        trans = AffineTransform(
+            scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER
+        )
         trans_gdal = trans.with_anchor(PixelAnchor.UPPER_LEFT)
         assert trans_gdal.scale_x == approx(scale_x)
         assert trans_gdal.scale_y == approx(scale_y)
@@ -49,7 +51,9 @@ def test_change_anchor_to_center(self):
         ip_x = 100
         ip_y = 200
 
-        trans_gdal = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.UPPER_LEFT)
+        trans_gdal = AffineTransform(
+            scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.UPPER_LEFT
+        )
         trans = trans_gdal.with_anchor(PixelAnchor.CENTER)
         assert trans.scale_x == approx(scale_x)
         assert trans.scale_y == approx(scale_y)
diff --git a/python/tests/raster/test_pandas_udf.py b/python/tests/raster/test_pandas_udf.py
index 8e7304941f..b6a7d4682b 100644
--- a/python/tests/raster/test_pandas_udf.py
+++ b/python/tests/raster/test_pandas_udf.py
@@ -27,11 +27,19 @@
 
 from tests import world_map_raster_input_location
 
+
 class TestRasterPandasUDF(TestBase):
-    @pytest.mark.skipif(pyspark.__version__ < '3.4', reason="requires Spark 3.4 or higher")
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.4", reason="requires Spark 3.4 or higher"
+    )
     def test_raster_as_param(self):
         spark = TestRasterPandasUDF.spark
-        df = spark.range(10).withColumn("rast", expr("RS_MakeRasterForTesting(1, 'I', 'PixelInterleavedSampleModel', 4, 3, 100, 100, 10, -10, 0, 0, 3857)"))
+        df = spark.range(10).withColumn(
+            "rast",
+            expr(
+                "RS_MakeRasterForTesting(1, 'I', 'PixelInterleavedSampleModel', 4, 3, 100, 100, 10, -10, 0, 0, 3857)"
+            ),
+        )
 
         # A Python Pandas UDF that takes a raster as input
         @pandas_udf(IntegerType())
@@ -67,10 +75,10 @@ def func(x):
         rows = df_result.collect()
         assert len(rows) == 10
         for row in rows:
-            assert row['res'] == 66
+            assert row["res"] == 66
 
         df_result = df.selectExpr("pandas_udf_raster_as_param_2(rast) as res")
         rows = df_result.collect()
         assert len(rows) == 10
         for row in rows:
-            assert row['res'] == 66
+            assert row["res"] == 66
diff --git a/python/tests/raster/test_serde.py b/python/tests/raster/test_serde.py
index dc94b01099..b50ddfd337 100644
--- a/python/tests/raster/test_serde.py
+++ b/python/tests/raster/test_serde.py
@@ -25,47 +25,76 @@
 
 from tests import world_map_raster_input_location
 
+
 class TestRasterSerde(TestBase):
     def test_empty_raster(self):
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeEmptyRaster(2, 100, 200, 1000, 2000, 1) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeEmptyRaster(2, 100, 200, 1000, 2000, 1) as raster"
+        )
         raster = df.first()[0]
-        assert raster.width == 100 and raster.height == 200 and len(raster.bands_meta) == 2
+        assert (
+            raster.width == 100 and raster.height == 200 and len(raster.bands_meta) == 2
+        )
         assert raster.affine_trans.ip_x == 1000 and raster.affine_trans.ip_y == 2000
         assert raster.affine_trans.scale_x == 1 and raster.affine_trans.scale_y == -1
 
     def test_banded_sample_model(self):
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster"
+        )
         raster = df.first()[0]
         assert raster.width == 10 and raster.height == 8 and len(raster.bands_meta) == 3
         self.validate_test_raster(raster)
 
     def test_pixel_interleaved_sample_model(self):
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'PixelInterleavedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeRasterForTesting(3, 'I', 'PixelInterleavedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster"
+        )
         raster = df.first()[0]
-        assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 3
+        assert (
+            raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 3
+        )
         self.validate_test_raster(raster)
 
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'PixelInterleavedSampleModelComplex', 8, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeRasterForTesting(4, 'I', 'PixelInterleavedSampleModelComplex', 8, 10, 100, 100, 10, -10, 0, 0, 3857) as raster"
+        )
         raster = df.first()[0]
         assert raster.width == 8 and raster.height == 10 and len(raster.bands_meta) == 4
         self.validate_test_raster(raster)
 
     def test_component_sample_model(self):
-        for pixel_type in ['B', 'S', 'US', 'I', 'F', 'D']:
-            df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, '{}', 'ComponentSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster".format(pixel_type))
+        for pixel_type in ["B", "S", "US", "I", "F", "D"]:
+            df = TestRasterSerde.spark.sql(
+                "SELECT RS_MakeRasterForTesting(4, '{}', 'ComponentSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster".format(
+                    pixel_type
+                )
+            )
             raster = df.first()[0]
-            assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4
+            assert (
+                raster.width == 10
+                and raster.height == 10
+                and len(raster.bands_meta) == 4
+            )
             self.validate_test_raster(raster)
 
     def test_multi_pixel_packed_sample_model(self):
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(1, 'B', 'MultiPixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeRasterForTesting(1, 'B', 'MultiPixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster"
+        )
         raster = df.first()[0]
-        assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 1
+        assert (
+            raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 1
+        )
         self.validate_test_raster(raster, packed=True)
 
     def test_single_pixel_packed_sample_model(self):
-        df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+        df = TestRasterSerde.spark.sql(
+            "SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster"
+        )
         raster = df.first()[0]
-        assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4
+        assert (
+            raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4
+        )
         self.validate_test_raster(raster, packed=True)
 
     def test_raster_read_from_geotiff(self):
@@ -73,7 +102,11 @@ def test_raster_read_from_geotiff(self):
         r_orig = rasterio.open(raster_path)
         band = r_orig.read(1)
         band_masked = np.where(band == 0, np.nan, band)
-        df = TestRasterSerde.spark.read.format("binaryFile").load(raster_path).selectExpr("RS_FromGeoTiff(content) as raster")
+        df = (
+            TestRasterSerde.spark.read.format("binaryFile")
+            .load(raster_path)
+            .selectExpr("RS_FromGeoTiff(content) as raster")
+        )
         raster = df.first()[0]
         assert raster.width == r_orig.width
         assert raster.height == r_orig.height
@@ -92,19 +125,23 @@ def test_raster_read_from_geotiff(self):
 
         # test as_numpy_masked
         arr = raster.as_numpy_masked()[0, :, :]
-        assert np.array_equal(arr,
band_masked) or np.array_equal( + np.isnan(arr), np.isnan(band_masked) + ) raster.close() r_orig.close() def test_to_pandas(self): spark = TestRasterSerde.spark - df = spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster") + df = spark.sql( + "SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster" + ) pandas_df = df.toPandas() - raster = pandas_df.iloc[0]['raster'] + raster = pandas_df.iloc[0]["raster"] self.validate_test_raster(raster) - def validate_test_raster(self, raster, packed = False): + def validate_test_raster(self, raster, packed=False): arr = raster.as_numpy() ds = raster.as_rasterio() bands, height, width = arr.shape diff --git a/python/tests/raster_viz_utils/test_sedonautils.py b/python/tests/raster_viz_utils/test_sedonautils.py index d82679634b..02ae06cdf2 100644 --- a/python/tests/raster_viz_utils/test_sedonautils.py +++ b/python/tests/raster_viz_utils/test_sedonautils.py @@ -23,9 +23,15 @@ class TestSedonaUtils(TestBase): def test_display_image(self): - raster_bin_df = self.spark.read.format('binaryFile').load(world_map_raster_input_location) - raster_bin_df.createOrReplaceTempView('raster_binary_table') - raster_df = self.spark.sql('SELECT RS_FromGeotiff(content) as raster from raster_binary_table') - raster_image_df = raster_df.selectExpr('RS_AsImage(raster) as rast_img') + raster_bin_df = self.spark.read.format("binaryFile").load( + world_map_raster_input_location + ) + raster_bin_df.createOrReplaceTempView("raster_binary_table") + raster_df = self.spark.sql( + "SELECT RS_FromGeotiff(content) as raster from raster_binary_table" + ) + raster_image_df = raster_df.selectExpr("RS_AsImage(raster) as rast_img") html_call = SedonaUtils.display_image(raster_image_df) - assert html_call is None # just test that this function was called and returned no output + assert ( + html_call is None + ) # just test that this function was called and returned no output diff --git a/python/tests/serialization/test_deserializers.py b/python/tests/serialization/test_deserializers.py index e6b1e383a2..8495a7374c 100644 --- a/python/tests/serialization/test_deserializers.py +++ b/python/tests/serialization/test_deserializers.py @@ -17,7 +17,15 @@ import os -from shapely.geometry import MultiPoint, Point, MultiLineString, LineString, Polygon, MultiPolygon, GeometryCollection +from shapely.geometry import ( + MultiPoint, + Point, + MultiLineString, + LineString, + Polygon, + MultiPolygon, + GeometryCollection, +) import geopandas as gpd import pandas as pd @@ -37,14 +45,16 @@ def test_collect(self): df.collect() def test_loading_from_file_deserialization(self): - self.spark.read.\ - options(delimiter="\t", header=False).\ - csv(os.path.join(tests_resource, "county_small.tsv")).\ - limit(1).\ - createOrReplaceTempView("counties") - - geom_area = self.spark.sql("SELECT st_area(st_geomFromWKT(_c0)) as area from counties").collect()[0][0] - polygon_shapely = self.spark.sql("SELECT st_geomFromWKT(_c0) from counties").collect()[0][0] + self.spark.read.options(delimiter="\t", header=False).csv( + os.path.join(tests_resource, "county_small.tsv") + ).limit(1).createOrReplaceTempView("counties") + + geom_area = self.spark.sql( + "SELECT st_area(st_geomFromWKT(_c0)) as area from counties" + ).collect()[0][0] + polygon_shapely = self.spark.sql( + "SELECT st_geomFromWKT(_c0) from counties" + ).collect()[0][0] assert geom_area == polygon_shapely.area def 
test_polygon_with_holes_deserialization(self): @@ -67,11 +77,15 @@ def test_multipolygon_with_holes_deserialization(self): assert geom.area == 712.5 def test_point_deserialization(self): - geom = self.spark.sql("""SELECT st_geomfromtext('POINT(-6.0 52.0)') as geom""").collect()[0][0] + geom = self.spark.sql( + """SELECT st_geomfromtext('POINT(-6.0 52.0)') as geom""" + ).collect()[0][0] assert geom.wkt == Point(-6.0, 52.0).wkt def test_multipoint_deserialization(self): - geom = self.spark.sql("""select st_geomFromWKT('MULTIPOINT(1 2, -2 3)') as geom""").collect()[0][0] + geom = self.spark.sql( + """select st_geomFromWKT('MULTIPOINT(1 2, -2 3)') as geom""" + ).collect()[0][0] assert geom.wkt == MultiPoint([(1, 2), (-2, 3)]).wkt @@ -91,10 +105,15 @@ def test_multilinestring_deserialization(self): ).collect()[0][0] assert type(geom) == MultiLineString - assert geom.wkt == MultiLineString([ - ((10, 10), (20, 20), (10, 40)), - ((40, 40), (30, 30), (40, 20), (30, 10)) - ]).wkt + assert ( + geom.wkt + == MultiLineString( + [ + ((10, 10), (20, 20), (10, 40)), + ((40, 40), (30, 30), (40, 20), (30, 10)), + ] + ).wkt + ) def test_geometry_collection_deserialization(self): geom = self.spark.sql( @@ -105,25 +124,33 @@ def test_geometry_collection_deserialization(self): ).collect()[0][0] assert type(geom) == GeometryCollection - assert geom.wkt == GeometryCollection([ - MultiLineString([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]), - MultiLineString([[(1, 2), (3, 4)], [(5, 6), (7, 8)], [(9, 10), (11, 12)]]), - Point(10, 20) - ]).wkt + assert ( + geom.wkt + == GeometryCollection( + [ + MultiLineString([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]), + MultiLineString( + [[(1, 2), (3, 4)], [(5, 6), (7, 8)], [(9, 10), (11, 12)]] + ), + Point(10, 20), + ] + ).wkt + ) def test_from_geopandas_convert(self): - gdf = gpd.read_file(os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/")) - gdf = gdf.replace(pd.NA, '') + gdf = gpd.read_file( + os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/") + ) + gdf = gdf.replace(pd.NA, "") - self.spark.createDataFrame( - gdf - ).show() + self.spark.createDataFrame(gdf).show() def test_to_geopandas(self): - counties = self.spark.read.\ - options(delimiter="\t", header=False).\ - csv(os.path.join(tests_resource, "county_small.tsv")).\ - limit(1) + counties = ( + self.spark.read.options(delimiter="\t", header=False) + .csv(os.path.join(tests_resource, "county_small.tsv")) + .limit(1) + ) counties.createOrReplaceTempView("county") diff --git a/python/tests/serialization/test_direct_serialization.py b/python/tests/serialization/test_direct_serialization.py index f51d25b27d..07df1f5dda 100644 --- a/python/tests/serialization/test_direct_serialization.py +++ b/python/tests/serialization/test_direct_serialization.py @@ -26,7 +26,9 @@ class TestDirectSerialization(TestBase): def test_polygon(self): polygon = Polygon([(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]) - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(self.sc._jvm, polygon) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + self.sc._jvm, polygon + ) assert jvm_geom.toString() == "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))" @@ -34,24 +36,35 @@ def test_polygon(self): int = [(1, 1), (1, 1.5), (1.5, 1.5), (1.5, 1), (1, 1)] polygon = Polygon(ext, [int]) - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(self.sc._jvm, polygon) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + self.sc._jvm, polygon + ) - assert jvm_geom.toString() == "POLYGON ((0 0, 0 2, 2 2, 2 0, 
0 0), (1 1, 1 1.5, 1.5 1.5, 1.5 1, 1 1))" + assert ( + jvm_geom.toString() + == "POLYGON ((0 0, 0 2, 2 2, 2 0, 0 0), (1 1, 1 1.5, 1.5 1.5, 1.5 1, 1 1))" + ) wkt = "POLYGON ((-71.1776585052917 42.3902909739571, -71.1776820268866 42.3903701743239, -71.1776063012595 42.3903825660754, -71.1775826583081 42.3903033653531, -71.1776585052917 42.3902909739571))" polygon = loads(wkt) - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(self.sc._jvm, polygon) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + self.sc._jvm, polygon + ) assert jvm_geom.toString() == wkt def test_point(self): wkt = "POINT (-71.064544 42.28787)" point = loads(wkt) - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(self.sc._jvm, point) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + self.sc._jvm, point + ) assert jvm_geom.toString() == wkt def test_linestring(self): - wkt = 'LINESTRING (-71.160281 42.258729, -71.160837 42.259113, -71.161144 42.25932)' + wkt = "LINESTRING (-71.160281 42.258729, -71.160837 42.259113, -71.161144 42.25932)" linestring = loads(wkt) - jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(self.sc._jvm, linestring) + jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry( + self.sc._jvm, linestring + ) assert jvm_geom.toString() == wkt diff --git a/python/tests/serialization/test_geospark_serializers.py b/python/tests/serialization/test_geospark_serializers.py index 65c3e82cb9..18640f4b10 100644 --- a/python/tests/serialization/test_geospark_serializers.py +++ b/python/tests/serialization/test_geospark_serializers.py @@ -24,8 +24,11 @@ def test_creating_point(self): self.spark.sql("SELECT st_GeomFromWKT('Point(21.0 52.0)')").show() def test_spark_config(self): - kryo_reg = ('spark.kryo.registrator', 'org.apache.sedona.core.serde.SedonaKryoRegistrator') - serializer = ('spark.serializer', 'org.apache.spark.serializer.KryoSerializer') + kryo_reg = ( + "spark.kryo.registrator", + "org.apache.sedona.core.serde.SedonaKryoRegistrator", + ) + serializer = ("spark.serializer", "org.apache.spark.serializer.KryoSerializer") spark_config = self.spark.sparkContext._conf.getAll() assert kryo_reg in spark_config assert serializer in spark_config diff --git a/python/tests/serialization/test_rdd_serialization.py b/python/tests/serialization/test_rdd_serialization.py index 137dcc0ca0..56900a17e3 100644 --- a/python/tests/serialization/test_rdd_serialization.py +++ b/python/tests/serialization/test_rdd_serialization.py @@ -24,7 +24,9 @@ point_rdd_input_location = os.path.join(tests_resource, "arealm-small.csv") polygon_rdd_input_location = os.path.join(tests_resource, "primaryroads-polygon.csv") -linestring_rdd_input_location = os.path.join(tests_resource, "primaryroads-linestring.csv") +linestring_rdd_input_location = os.path.join( + tests_resource, "primaryroads-linestring.csv" +) linestring_rdd_splittter = FileDataSplitter.CSV polygon_rdd_splitter = FileDataSplitter.CSV @@ -45,17 +47,21 @@ def test_point_rdd(self): InputLocation=point_rdd_input_location, Offset=point_rdd_offset, splitter=point_rdd_splitter, - carryInputData=False + carryInputData=False, ) collected_points = point_rdd.getRawSpatialRDD().collect() points_coordinates = [ - [-88.331492, 32.324142], [-88.175933, 32.360763], - [-88.388954, 32.357073], [-88.221102, 32.35078] + [-88.331492, 32.324142], + [-88.175933, 32.360763], + [-88.388954, 32.357073], + [-88.221102, 32.35078], ] - assert [[geo_data.geom.x, geo_data.geom.y] for geo_data in 
collected_points[:4]] == points_coordinates[:4] + assert [ + [geo_data.geom.x, geo_data.geom.y] for geo_data in collected_points[:4] + ] == points_coordinates[:4] def test_polygon_rdd(self): polygon_rdd = PolygonRDD( @@ -64,7 +70,7 @@ def test_polygon_rdd(self): startOffset=polygon_rdd_start_offset, endOffset=polygon_rdd_end_offset, splitter=polygon_rdd_splitter, - carryInputData=True + carryInputData=True, ) collected_polygon_rdd = polygon_rdd.getRawSpatialRDD().collect() @@ -72,10 +78,12 @@ def test_polygon_rdd(self): input_wkt_polygons = [ "POLYGON ((-74.020753 40.836454, -74.020753 40.843768, -74.018162 40.843768, -74.018162 40.836454, -74.020753 40.836454))", "POLYGON ((-74.018978 40.837712, -74.018978 40.852181, -74.014938 40.852181, -74.014938 40.837712, -74.018978 40.837712))", - "POLYGON ((-74.021683 40.833253, -74.021683 40.834288, -74.021368 40.834288, -74.021368 40.833253, -74.021683 40.833253))" + "POLYGON ((-74.021683 40.833253, -74.021683 40.834288, -74.021368 40.834288, -74.021368 40.833253, -74.021683 40.833253))", ] - assert [geo_data.geom.wkt for geo_data in collected_polygon_rdd][:3] == input_wkt_polygons + assert [geo_data.geom.wkt for geo_data in collected_polygon_rdd][ + :3 + ] == input_wkt_polygons # def test_circle_rdd(self): # object_rdd = PointRDD( @@ -96,7 +104,7 @@ def test_linestring_rdd(self): startOffset=0, endOffset=7, splitter=FileDataSplitter.CSV, - carryInputData=True + carryInputData=True, ) wkt = "LINESTRING (-112.506968 45.98186, -112.506968 45.983586, -112.504872 45.983586, -112.504872 45.98186)" diff --git a/python/tests/serialization/test_serializers.py b/python/tests/serialization/test_serializers.py index 2bc2186931..8c7153683e 100644 --- a/python/tests/serialization/test_serializers.py +++ b/python/tests/serialization/test_serializers.py @@ -23,7 +23,14 @@ from tests import tests_resource from sedona.sql.types import GeometryType -from shapely.geometry import Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon +from shapely.geometry import ( + Point, + MultiPoint, + LineString, + MultiLineString, + Polygon, + MultiPolygon, +) from pyspark.sql import types as t from tests.test_base import TestBase @@ -32,22 +39,16 @@ class TestsSerializers(TestBase): def test_point_serializer(self): - data = [ - [1, Point(21.0, 56.0), Point(21.0, 59.0)] - - ] + data = [[1, Point(21.0, 56.0), Point(21.0, 59.0)]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), t.StructField("geom_from", GeometryType(), True), - t.StructField("geom_to", GeometryType(), True) + t.StructField("geom_to", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("points") + self.spark.createDataFrame(data, schema).createOrReplaceTempView("points") distance = self.spark.sql( "select st_distance(geom_from, geom_to) from points" @@ -56,67 +57,53 @@ def test_point_serializer(self): def test_multipoint_serializer(self): - multipoint = MultiPoint([ - [21.0, 56.0], - [21.0, 57.0] - ]) - data = [ - [1, multipoint] - ] + multipoint = MultiPoint([[21.0, 56.0], [21.0, 57.0]]) + data = [[1, multipoint]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), - t.StructField("geom", GeometryType(), True) + t.StructField("geom", GeometryType(), True), ] ) - m_point_out = self.spark.createDataFrame( - data, - schema - ).collect()[0][1] + m_point_out = self.spark.createDataFrame(data, schema).collect()[0][1] assert m_point_out == multipoint def test_linestring_serialization(self): linestring = LineString([(0.0, 
1.0), (1, 1), (12.0, 1.0)]) - data = [ - [1, linestring] - ] + data = [[1, linestring]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), - t.StructField("geom", GeometryType(), True) + t.StructField("geom", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("line") + self.spark.createDataFrame(data, schema).createOrReplaceTempView("line") length = self.spark.sql("select st_length(geom) from line").collect()[0][0] assert length == 12.0 def test_multilinestring_serialization(self): multilinestring = MultiLineString([[[0, 1], [1, 1]], [[2, 2], [3, 2]]]) - data = [ - [1, multilinestring] - ] + data = [[1, multilinestring]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), - t.StructField("geom", GeometryType(), True) + t.StructField("geom", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("multilinestring") + self.spark.createDataFrame(data, schema).createOrReplaceTempView( + "multilinestring" + ) - length = self.spark.sql("select st_length(geom) from multilinestring").collect()[0][0] + length = self.spark.sql( + "select st_length(geom) from multilinestring" + ).collect()[0][0] assert length == 2.0 def test_polygon_serialization(self): @@ -125,31 +112,26 @@ def test_polygon_serialization(self): polygon = Polygon(ext, [int]) - data = [ - [1, polygon] - ] + data = [[1, polygon]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), - t.StructField("geom", GeometryType(), True) + t.StructField("geom", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("polygon") + self.spark.createDataFrame(data, schema).createOrReplaceTempView("polygon") length = self.spark.sql("select st_area(geom) from polygon").collect()[0][0] assert length == 3.75 def test_geopandas_convertion(self): - gdf = gpd.read_file(os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/")) - gdf = gdf.replace(pd.NA, '') - print(self.spark.createDataFrame( - gdf - ).toPandas()) + gdf = gpd.read_file( + os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/") + ) + gdf = gdf.replace(pd.NA, "") + print(self.spark.createDataFrame(gdf).toPandas()) def test_multipolygon_serialization(self): exterior = [(0, 0), (0, 2), (2, 2), (2, 0), (0, 0)] @@ -157,36 +139,25 @@ def test_multipolygon_serialization(self): polygons = [ Polygon(exterior, [interior]), - Polygon([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]]) + Polygon([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]]), ] multipolygon = MultiPolygon(polygons) - data = [ - [1, multipolygon] - ] + data = [[1, multipolygon]] schema = t.StructType( [ t.StructField("id", IntegerType(), True), - t.StructField("geom", GeometryType(), True) + t.StructField("geom", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).show(1, False) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("polygon") + self.spark.createDataFrame(data, schema).show(1, False) + self.spark.createDataFrame(data, schema).createOrReplaceTempView("polygon") length = self.spark.sql("select st_area(geom) from polygon").collect()[0][0] assert length == 4.75 def test_null_serializer(self): - data = [ - [1, None] - - ] + data = [[1, None]] schema = t.StructType( [ @@ -194,10 +165,7 @@ def test_null_serializer(self): t.StructField("geom", GeometryType(), True), ] ) - self.spark.createDataFrame( - data, - schema - ).createOrReplaceTempView("points") + self.spark.createDataFrame(data, 
schema).createOrReplaceTempView("points") count = self.spark.sql("select count(*) from points").collect()[0][0] assert count == 1 diff --git a/python/tests/serialization/test_with_sc_parellize.py b/python/tests/serialization/test_with_sc_parellize.py index 3db854a43f..b0ad0acf7b 100644 --- a/python/tests/serialization/test_with_sc_parellize.py +++ b/python/tests/serialization/test_with_sc_parellize.py @@ -29,7 +29,7 @@ def test_geo_data_convert_to_point_rdd(self): points = [ GeoData(geom=Point(52.0, -21.0), userData="a"), GeoData(geom=Point(-152.4546, -23.1423), userData="b"), - GeoData(geom=Point(62.253456, 221.2145), userData="c") + GeoData(geom=Point(62.253456, 221.2145), userData="c"), ] rdd_data = self.sc.parallelize(points) @@ -41,7 +41,7 @@ def test_geo_data_convert_to_point_rdd(self): def test_geo_data_convert_polygon_rdd(self): linestring = LineString([(0.0, 1.0), (1, 1), (12.0, 1.0)]) - wkt = 'LINESTRING (-71.160281 42.258729, -71.160837 42.259113, -71.161144 42.25932)' + wkt = "LINESTRING (-71.160281 42.258729, -71.160837 42.259113, -71.161144 42.25932)" linestring2 = loads(wkt) linestrings = [ @@ -54,7 +54,9 @@ def test_geo_data_convert_polygon_rdd(self): linestring_rdd = LineStringRDD(rdd_data) collected_data = linestring_rdd.rawSpatialRDD.collect() sorted_collected_data = sorted(collected_data, key=lambda x: x.userData) - assert all([geo1 == geo2 for geo1, geo2 in zip(linestrings, sorted_collected_data)]) + assert all( + [geo1 == geo2 for geo1, geo2 in zip(linestrings, sorted_collected_data)] + ) def test_geo_data_convert_linestring_rdd(self): polygon = Polygon([(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]) @@ -68,9 +70,9 @@ def test_geo_data_convert_linestring_rdd(self): polygon3 = loads(wkt) polygons = [ - GeoData(geom=polygon, userData="a"), - GeoData(geom=polygon2, userData="b"), - GeoData(geom=polygon3, userData="c"), + GeoData(geom=polygon, userData="a"), + GeoData(geom=polygon2, userData="b"), + GeoData(geom=polygon3, userData="c"), ] rdd_data = self.sc.parallelize(polygons) @@ -78,4 +80,6 @@ def test_geo_data_convert_linestring_rdd(self): polygon_rdd = PolygonRDD(rdd_data) collected_data = polygon_rdd.rawSpatialRDD.collect() sorted_collected_data = sorted(collected_data, key=lambda x: x.userData) - assert all([geo1 == geo2 for geo1, geo2 in zip(polygons, sorted_collected_data)]) + assert all( + [geo1 == geo2 for geo1, geo2 in zip(polygons, sorted_collected_data)] + ) diff --git a/python/tests/spatial_operator/test_join_query_correctness.py b/python/tests/spatial_operator/test_join_query_correctness.py index e2390b28c9..7b2f743527 100644 --- a/python/tests/spatial_operator/test_join_query_correctness.py +++ b/python/tests/spatial_operator/test_join_query_correctness.py @@ -39,10 +39,14 @@ def test_inside_point_join_correctness(self): object_rdd = PointRDD(self.sc.parallelize(self.test_inside_point_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() self.verify_join_result(result_no_index) def test_on_boundary_point_join_correctness(self): @@ -50,10 +54,14 @@ def test_on_boundary_point_join_correctness(self): object_rdd = 
PointRDD(self.sc.parallelize(self.test_on_boundary_point_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() self.verify_join_result(result_no_index) def test_outside_point_join_correctness(self): @@ -62,46 +70,64 @@ def test_outside_point_join_correctness(self): object_rdd = PointRDD(self.sc.parallelize(self.test_outside_point_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() assert 0 == result.__len__() - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() assert 0 == result_no_index.__len__() def test_inside_linestring_join_correctness(self): - window_rdd = PolygonRDD( - self.sc.parallelize(self.test_polygon_window_set) - ) + window_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) object_rdd = LineStringRDD(self.sc.parallelize(self.test_inside_linestring_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() self.verify_join_result(result_no_index) def test_overlapped_linestring_join_correctness(self): window_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) - object_rdd = LineStringRDD(self.sc.parallelize(self.test_overlapped_linestring_set)) + object_rdd = LineStringRDD( + self.sc.parallelize(self.test_overlapped_linestring_set) + ) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, True).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, True + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, True).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, True + ).collect() self.verify_join_result(result_no_index) def test_outside_line_string_join_correctness(self): window_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) - object_rdd = LineStringRDD(self.sc.parallelize(self.test_outside_linestring_set)) + object_rdd = LineStringRDD( + self.sc.parallelize(self.test_outside_linestring_set) + ) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() assert 0 == result.__len__() - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = 
JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() assert 0 == result_no_index.__len__() def test_inside_polygon_join_correctness(self): @@ -110,10 +136,14 @@ def test_inside_polygon_join_correctness(self): object_rdd = PolygonRDD(self.sc.parallelize(self.test_inside_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() self.verify_join_result(result_no_index) def test_overlapped_polygon_join_correctness(self): @@ -121,10 +151,14 @@ def test_overlapped_polygon_join_correctness(self): object_rdd = PolygonRDD(self.sc.parallelize(self.test_overlapped_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, True).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, True + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, True).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, True + ).collect() self.verify_join_result(result_no_index) def test_outside_polygon_join_correctness(self): @@ -132,49 +166,73 @@ def test_outside_polygon_join_correctness(self): object_rdd = PolygonRDD(self.sc.parallelize(self.test_outside_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, True, False + ).collect() assert 0 == result.__len__() - result_no_index = JoinQuery.SpatialJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.SpatialJoinQuery( + object_rdd, window_rdd, False, False + ).collect() assert 0 == result_no_index.__len__() def test_inside_polygon_distance_join_correctness(self): - center_geometry_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) + center_geometry_rdd = PolygonRDD( + self.sc.parallelize(self.test_polygon_window_set) + ) window_rdd = CircleRDD(center_geometry_rdd, 0.1) object_rdd = PolygonRDD(self.sc.parallelize(self.test_inside_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, True, False).collect() + result = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, True, False + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, False, False).collect() + result_no_index = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, False, False + ).collect() self.verify_join_result(result_no_index) def test_overlapped_polygon_distance_join_correctness(self): - center_geometry_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) + center_geometry_rdd = PolygonRDD( + self.sc.parallelize(self.test_polygon_window_set) + ) window_rdd = CircleRDD(center_geometry_rdd, 0.1) object_rdd = PolygonRDD(self.sc.parallelize(self.test_overlapped_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = 
JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, True, True).collect() + result = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, True, True + ).collect() self.verify_join_result(result) - result_no_index = JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, False, True).collect() + result_no_index = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, False, True + ).collect() self.verify_join_result(result_no_index) def test_outside_polygon_distance_join_correctness(self): - center_geometry_rdd = PolygonRDD(self.sc.parallelize(self.test_polygon_window_set)) + center_geometry_rdd = PolygonRDD( + self.sc.parallelize(self.test_polygon_window_set) + ) window_rdd = CircleRDD(center_geometry_rdd, 0.1) object_rdd = PolygonRDD(self.sc.parallelize(self.test_outside_polygon_set)) self.prepare_rdd(object_rdd, window_rdd, GridType.QUADTREE) - result = JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, True, True).collect() + result = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, True, True + ).collect() assert 0 == result.__len__() - result_no_index = JoinQuery.DistanceJoinQuery(object_rdd, window_rdd, False, True).collect() + result_no_index = JoinQuery.DistanceJoinQuery( + object_rdd, window_rdd, False, True + ).collect() assert 0 == result_no_index.__len__() - def prepare_rdd(self, object_rdd: SpatialRDD, window_rdd: SpatialRDD, grid_type: GridType): + def prepare_rdd( + self, object_rdd: SpatialRDD, window_rdd: SpatialRDD, grid_type: GridType + ): object_rdd.analyze() window_rdd.analyze() object_rdd.rawSpatialRDD.repartition(4) @@ -188,7 +246,12 @@ def verify_join_result(cls, result): @classmethod def make_square(cls, minx: float, miny: float, side: float) -> Polygon: - coordinates = [(minx, miny), (minx + side, miny), (minx + side, miny + side), (minx, miny + side)] + coordinates = [ + (minx, miny), + (minx + side, miny), + (minx + side, miny + side), + (minx, miny + side), + ] polygon = Polygon(coordinates) @@ -196,7 +259,7 @@ def make_square(cls, minx: float, miny: float, side: float) -> Polygon: @classmethod def make_square_line(cls, minx: float, miny: float, side: float): - coordinates = [(minx, miny), (minx+side, miny), (minx + side, miny+side)] + coordinates = [(minx, miny), (minx + side, miny), (minx + side, miny + side)] return LineString(coordinates) @classmethod @@ -226,32 +289,72 @@ def once_before_all(cls): a = "a:" + id b = "b:" + id - cls.test_polygon_window_set.append(cls.wrap(cls.make_square(base_x, base_y, 5), a)) - cls.test_polygon_window_set.append(cls.wrap(cls.make_square(base_x, base_y, 5), b)) - - cls.test_inside_polygon_set.append(cls.wrap(cls.make_square(base_x + 2, base_y + 2, 2), a)) - cls.test_inside_polygon_set.append(cls.wrap(cls.make_square(base_x + 2, base_y + 2, 2), b)) - - cls.test_overlapped_polygon_set.append(cls.wrap(cls.make_square(base_x + 3, base_y + 3, 3), a)) - cls.test_overlapped_polygon_set.append(cls.wrap(cls.make_square(base_x + 3, base_y + 3, 3), b)) - - cls.test_outside_polygon_set.append(cls.wrap(cls.make_square(base_x + 6, base_y + 6, 3), a)) - cls.test_outside_polygon_set.append(cls.wrap(cls.make_square(base_x + 6, base_y + 6, 3), b)) - - cls.test_inside_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 2, base_y + 2, 2), a)) - cls.test_inside_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 2, base_y + 2, 2), b)) - - cls.test_overlapped_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 3, base_y + 3, 3), a)) - cls.test_overlapped_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 
3, base_y + 3, 3), b)) - - cls.test_outside_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 6, base_y + 6, 3), a)) - cls.test_outside_linestring_set.append(cls.wrap(cls.make_square_line(base_x + 6, base_y + 6, 3), b)) - - cls.test_inside_point_set.append(cls.wrap(cls.make_point(base_x + 2.5, base_y + 2.5), a)) - cls.test_inside_point_set.append(cls.wrap(cls.make_point(base_x + 2.5, base_y + 2.5), b)) - - cls.test_on_boundary_point_set.append(cls.wrap(cls.make_point(base_x + 5, base_y + 5), a)) - cls.test_on_boundary_point_set.append(cls.wrap(cls.make_point(base_x + 5, base_y + 5), b)) - - cls.test_outside_point_set.append(cls.wrap(cls.make_point(base_x + 6, base_y + 6), a)) - cls.test_outside_point_set.append(cls.wrap(cls.make_point(base_x + 6, base_y + 6), b)) + cls.test_polygon_window_set.append( + cls.wrap(cls.make_square(base_x, base_y, 5), a) + ) + cls.test_polygon_window_set.append( + cls.wrap(cls.make_square(base_x, base_y, 5), b) + ) + + cls.test_inside_polygon_set.append( + cls.wrap(cls.make_square(base_x + 2, base_y + 2, 2), a) + ) + cls.test_inside_polygon_set.append( + cls.wrap(cls.make_square(base_x + 2, base_y + 2, 2), b) + ) + + cls.test_overlapped_polygon_set.append( + cls.wrap(cls.make_square(base_x + 3, base_y + 3, 3), a) + ) + cls.test_overlapped_polygon_set.append( + cls.wrap(cls.make_square(base_x + 3, base_y + 3, 3), b) + ) + + cls.test_outside_polygon_set.append( + cls.wrap(cls.make_square(base_x + 6, base_y + 6, 3), a) + ) + cls.test_outside_polygon_set.append( + cls.wrap(cls.make_square(base_x + 6, base_y + 6, 3), b) + ) + + cls.test_inside_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 2, base_y + 2, 2), a) + ) + cls.test_inside_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 2, base_y + 2, 2), b) + ) + + cls.test_overlapped_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 3, base_y + 3, 3), a) + ) + cls.test_overlapped_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 3, base_y + 3, 3), b) + ) + + cls.test_outside_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 6, base_y + 6, 3), a) + ) + cls.test_outside_linestring_set.append( + cls.wrap(cls.make_square_line(base_x + 6, base_y + 6, 3), b) + ) + + cls.test_inside_point_set.append( + cls.wrap(cls.make_point(base_x + 2.5, base_y + 2.5), a) + ) + cls.test_inside_point_set.append( + cls.wrap(cls.make_point(base_x + 2.5, base_y + 2.5), b) + ) + + cls.test_on_boundary_point_set.append( + cls.wrap(cls.make_point(base_x + 5, base_y + 5), a) + ) + cls.test_on_boundary_point_set.append( + cls.wrap(cls.make_point(base_x + 5, base_y + 5), b) + ) + + cls.test_outside_point_set.append( + cls.wrap(cls.make_point(base_x + 6, base_y + 6), a) + ) + cls.test_outside_point_set.append( + cls.wrap(cls.make_point(base_x + 6, base_y + 6), b) + ) diff --git a/python/tests/spatial_operator/test_linestring_join.py b/python/tests/spatial_operator/test_linestring_join.py index 14afa15e4e..13c41329d6 100644 --- a/python/tests/spatial_operator/test_linestring_join.py +++ b/python/tests/spatial_operator/test_linestring_join.py @@ -47,34 +47,44 @@ def pytest_generate_tests(metafunc): dict(num_partitions=11, grid_type=GridType.KDBTREE), ] params_dyn = [{**param, **{"index_type": IndexType.QUADTREE}} for param in parameters] -params_dyn.extend([{**param, **{"index_type": IndexType.RTREE}} for param in parameters]) +params_dyn.extend( + [{**param, **{"index_type": IndexType.RTREE}} for param in parameters] +) class TestRectangleJoin(TestJoinBase): params = { 
"test_nested_loop": parameters, "test_dynamic_index_int": params_dyn, - "test_index_int": params_dyn + "test_index_int": params_dyn, } def test_nested_loop(self, num_partitions, grid_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - spatial_rdd = self.create_linestring_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_linestring_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() self.sanity_check_join_results(result) - expected_count = match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == self.count_join_results(result) def test_dynamic_index_int(self, num_partitions, grid_type, index_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - spatial_rdd = self.create_linestring_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_linestring_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) @@ -83,22 +93,31 @@ def test_dynamic_index_int(self, num_partitions, grid_type, index_type): self.sanity_check_flat_join_results(result) - expected_count = match_with_original_duplicates_count \ - if self.expect_to_preserve_original_duplicates(grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == result.__len__() def test_index_int(self, num_partitions, grid_type, index_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - spatial_rdd = self.create_linestring_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_linestring_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) spatial_rdd.buildIndex(index_type, True) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() self.sanity_check_join_results(result) - expected_count = match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == self.count_join_results(result) diff --git a/python/tests/spatial_operator/test_linestring_knn.py b/python/tests/spatial_operator/test_linestring_knn.py index 36b3a948aa..41bbaca246 100644 --- a/python/tests/spatial_operator/test_linestring_knn.py +++ b/python/tests/spatial_operator/test_linestring_knn.py @@ -40,7 +40,9 @@ class TestLineStringKnn(TestBase): def test_spatial_knn_query(self): line_string_rdd = LineStringRDD(self.sc, input_location, splitter, True) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(line_string_rdd, self.query_point, 5, False) + result = KNNQuery.SpatialKnnQuery( + line_string_rdd, self.query_point, 5, False + ) assert result.__len__() > -1 assert result[0].getUserData() is not None @@ -48,6 +50,8 @@ def 
test_spatial_knn_query_using_index(self): line_string_rdd = LineStringRDD(self.sc, input_location, splitter, True) line_string_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(line_string_rdd, self.query_point, 5, False) + result = KNNQuery.SpatialKnnQuery( + line_string_rdd, self.query_point, 5, False + ) assert result.__len__() > -1 assert result[0].getUserData() is not None diff --git a/python/tests/spatial_operator/test_linestring_range.py b/python/tests/spatial_operator/test_linestring_range.py index d9041622e4..2d3da73d4c 100644 --- a/python/tests/spatial_operator/test_linestring_range.py +++ b/python/tests/spatial_operator/test_linestring_range.py @@ -37,25 +37,33 @@ class TestLineStringRange(TestBase): query_envelope = Envelope(-85.01, -60.01, 34.01, 50.01) def test_spatial_range_query(self): - spatial_rdd = LineStringRDD( - self.sc, input_location, splitter, True - ) + spatial_rdd = LineStringRDD(self.sc, input_location, splitter, True) for i in range(self.loop_times): - result_size = RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 999 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).take(10)[1].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False) + .take(10)[1] + .getUserData() + is not None + ) def test_spatial_range_query_using_index(self): - spatial_rdd = LineStringRDD( - self.sc, input_location, splitter, True - ) + spatial_rdd = LineStringRDD(self.sc, input_location, splitter, True) spatial_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result_size = RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 999 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).take(10)[1].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False) + .take(10)[1] + .getUserData() + is not None + ) diff --git a/python/tests/spatial_operator/test_point_join.py b/python/tests/spatial_operator/test_point_join.py index 5da2161842..2363df281b 100644 --- a/python/tests/spatial_operator/test_point_join.py +++ b/python/tests/spatial_operator/test_point_join.py @@ -65,90 +65,122 @@ class TestRectangleJoin(TestJoinBase): "test_quad_tree_with_rectangles": parameters, "test_quad_tree_with_polygons": parameters, "test_dynamic_r_tree_with_rectangles": parameters, - "test_dynamic_r_tree_with_polygons": parameters + "test_dynamic_r_tree_with_polygons": parameters, } def test_nested_loop_with_rectangles(self, num_partitions, grid_type): - query_rdd = self.create_rectangle_rdd(input_location_query_window, splitter, num_partitions) + query_rdd = self.create_rectangle_rdd( + input_location_query_window, splitter, num_partitions + ) self.nested_loop(query_rdd, num_partitions, grid_type, rectangle_match_count) def test_nested_loop_with_polygons(self, num_partitions, grid_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - expected_count = polygon_match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else polygon_match_count + expected_count 
= ( + polygon_match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else polygon_match_count + ) self.nested_loop(query_rdd, num_partitions, grid_type, expected_count) def nested_loop(self, query_rdd, num_partitions, grid_type, expected_count): spatial_rdd = self.create_point_rdd(input_location, splitter, num_partitions) - self.partition_rdds( - query_rdd, spatial_rdd, grid_type) + self.partition_rdds(query_rdd, spatial_rdd, grid_type) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() self.sanity_check_join_results(result) assert expected_count == self.count_join_results(result) def test_rtree_with_rectangles(self, num_partitions, grid_type): - query_rdd = self.create_rectangle_rdd(input_location_query_window, splitter, num_partitions) + query_rdd = self.create_rectangle_rdd( + input_location_query_window, splitter, num_partitions + ) self.index_int( query_rdd, num_partitions, grid_type, IndexType.RTREE, polygon_match_count - ) def test_r_tree_with_polygons(self, num_partitions, grid_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - expected_count = polygon_match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else polygon_match_count + expected_count = ( + polygon_match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else polygon_match_count + ) self.index_int( query_rdd, num_partitions, grid_type, IndexType.RTREE, expected_count - ) def test_quad_tree_with_rectangles(self, num_partitions, grid_type): - query_rdd = self.create_rectangle_rdd(input_location_query_window, splitter, num_partitions) + query_rdd = self.create_rectangle_rdd( + input_location_query_window, splitter, num_partitions + ) self.index_int( - query_rdd, num_partitions, grid_type, IndexType.QUADTREE, polygon_match_count - + query_rdd, + num_partitions, + grid_type, + IndexType.QUADTREE, + polygon_match_count, ) def test_quad_tree_with_polygons(self, num_partitions, grid_type): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - expected_count = polygon_match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else polygon_match_count + expected_count = ( + polygon_match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else polygon_match_count + ) self.index_int( query_rdd, num_partitions, grid_type, IndexType.QUADTREE, expected_count - ) - def index_int(self, query_rdd, num_partitions, grid_type, index_type, expected_count): + def index_int( + self, query_rdd, num_partitions, grid_type, index_type, expected_count + ): spatial_rdd = self.create_point_rdd(input_location, splitter, num_partitions) self.partition_rdds(query_rdd, spatial_rdd, grid_type) spatial_rdd.buildIndex(index_type, True) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() self.sanity_check_join_results(result) assert expected_count, self.count_join_results(result) def test_dynamic_r_tree_with_rectangles(self, grid_type, num_partitions): - polygon_rdd = self.create_rectangle_rdd(input_location_query_window, splitter, num_partitions) - expected_count = rectangle_match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) \ + polygon_rdd = 
self.create_rectangle_rdd( + input_location_query_window, splitter, num_partitions + ) + expected_count = ( + rectangle_match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) else rectangle_match_count - self.dynamic_rtree_int(polygon_rdd, num_partitions, grid_type, IndexType.RTREE, expected_count) + ) + self.dynamic_rtree_int( + polygon_rdd, num_partitions, grid_type, IndexType.RTREE, expected_count + ) def test_dynamic_r_tree_with_polygons(self, grid_type, num_partitions): - polygon_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) - expected_count = polygon_match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) \ + polygon_rdd = self.create_polygon_rdd( + query_polygon_set, splitter, num_partitions + ) + expected_count = ( + polygon_match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) else polygon_match_count - self.dynamic_rtree_int(polygon_rdd, num_partitions, grid_type, IndexType.RTREE, expected_count) + ) + self.dynamic_rtree_int( + polygon_rdd, num_partitions, grid_type, IndexType.RTREE, expected_count + ) - def dynamic_rtree_int(self, query_rdd, num_partitions, grid_type, index_type, expected_count): + def dynamic_rtree_int( + self, query_rdd, num_partitions, grid_type, index_type, expected_count + ): spatial_rdd = self.create_point_rdd(input_location, splitter, num_partitions) self.partition_rdds(query_rdd, spatial_rdd, grid_type) diff --git a/python/tests/spatial_operator/test_point_knn.py b/python/tests/spatial_operator/test_point_knn.py index 8d6d0274df..017b8b7fee 100644 --- a/python/tests/spatial_operator/test_point_knn.py +++ b/python/tests/spatial_operator/test_point_knn.py @@ -42,7 +42,9 @@ def test_spatial_knn_query(self): point_rdd = PointRDD(self.sc, input_location, offset, splitter, False) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(point_rdd, self.query_point, self.top_k, False) + result = KNNQuery.SpatialKnnQuery( + point_rdd, self.query_point, self.top_k, False + ) assert result.__len__() > -1 def test_spatial_knn_query_using_index(self): @@ -50,23 +52,35 @@ def test_spatial_knn_query_using_index(self): point_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(point_rdd, self.query_point, self.top_k, False) + result = KNNQuery.SpatialKnnQuery( + point_rdd, self.query_point, self.top_k, False + ) assert result.__len__() > -1 def test_spatial_knn_correctness(self): point_rdd = PointRDD(self.sc, input_location, offset, splitter, False) - result_no_index = KNNQuery.SpatialKnnQuery(point_rdd, self.query_point, self.top_k, False) + result_no_index = KNNQuery.SpatialKnnQuery( + point_rdd, self.query_point, self.top_k, False + ) point_rdd.buildIndex(IndexType.RTREE, False) - result_with_index = KNNQuery.SpatialKnnQuery(point_rdd, self.query_point, self.top_k, True) + result_with_index = KNNQuery.SpatialKnnQuery( + point_rdd, self.query_point, self.top_k, True + ) - sorted_result_no_index = sorted(result_no_index, key=lambda geo_data: distance_sorting_functions( - geo_data, self.query_point)) + sorted_result_no_index = sorted( + result_no_index, + key=lambda geo_data: distance_sorting_functions(geo_data, self.query_point), + ) - sorted_result_with_index = sorted(result_with_index, key=lambda geo_data: distance_sorting_functions( - geo_data, self.query_point)) + sorted_result_with_index = sorted( + result_with_index, + key=lambda geo_data: 
distance_sorting_functions(geo_data, self.query_point), + ) difference = 0 for x in range(self.top_k): - difference += sorted_result_no_index[x].geom.distance(sorted_result_with_index[x].geom) + difference += sorted_result_no_index[x].geom.distance( + sorted_result_with_index[x].geom + ) assert difference == 0 diff --git a/python/tests/spatial_operator/test_point_range.py b/python/tests/spatial_operator/test_point_range.py index ef7c62afe0..86a837fe5b 100644 --- a/python/tests/spatial_operator/test_point_range.py +++ b/python/tests/spatial_operator/test_point_range.py @@ -35,10 +35,7 @@ queryPolygonSet = "primaryroads-polygon.csv" inputCount = 3000 inputBoundary = Envelope( - minx=-173.120769, - maxx=-84.965961, - miny=30.244859, - maxy=71.355134 + minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134 ) rectangleMatchCount = 103 rectangleMatchWithOriginalDuplicatesCount = 103 @@ -53,9 +50,9 @@ class TestPointRange(TestBase): def test_spatial_range_query(self): spatial_rdd = PointRDD(self.sc, input_location, offset, splitter, False) for i in range(self.loop_times): - result_size = RangeQuery.\ - SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False)\ - .count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 2830 def test_spatial_range_query_using_index(self): @@ -64,7 +61,7 @@ def test_spatial_range_query_using_index(self): spatial_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result_size = RangeQuery.\ - SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False)\ - .count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 2830 diff --git a/python/tests/spatial_operator/test_polygon_join.py b/python/tests/spatial_operator/test_polygon_join.py index 77137092e8..e22cf1a60f 100644 --- a/python/tests/spatial_operator/test_polygon_join.py +++ b/python/tests/spatial_operator/test_polygon_join.py @@ -48,17 +48,19 @@ def pytest_generate_tests(metafunc): dict(num_partitions=11, grid_type=GridType.QUADTREE, intersects=False), dict(num_partitions=11, grid_type=GridType.QUADTREE, intersects=True), dict(num_partitions=11, grid_type=GridType.QUADTREE, intersects=True), - dict(num_partitions=11, grid_type=GridType.KDBTREE, intersects=True) + dict(num_partitions=11, grid_type=GridType.KDBTREE, intersects=True), ] params_dyn = [{**param, **{"index_type": IndexType.QUADTREE}} for param in parameters] -params_dyn.extend([{**param, **{"index_type": IndexType.RTREE}} for param in parameters]) +params_dyn.extend( + [{**param, **{"index_type": IndexType.RTREE}} for param in parameters] +) class TestRectangleJoin(TestJoinBase): params = { "test_nested_loop": parameters, "test_dynamic_index_int": params_dyn, - "test_index_int": params_dyn + "test_index_int": params_dyn, } def test_nested_loop(self, num_partitions, grid_type, intersects): @@ -68,10 +70,13 @@ def test_nested_loop(self, num_partitions, grid_type, intersects): self.partition_rdds(query_rdd, spatial_rdd, grid_type) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, intersects).collect() + spatial_rdd, query_rdd, False, intersects + ).collect() self.sanity_check_join_results(result) - assert self.get_expected_with_original_duplicates_count(intersects) == self.count_join_results(result) + assert self.get_expected_with_original_duplicates_count( + intersects + ) == self.count_join_results(result) def 
test_dynamic_index_int(self, num_partitions, grid_type, index_type, intersects): query_rdd = self.create_polygon_rdd(query_polygon_set, splitter, num_partitions) @@ -84,8 +89,11 @@ def test_dynamic_index_int(self, num_partitions, grid_type, index_type, intersec self.sanity_check_flat_join_results(result) - expected_count = self.get_expected_with_original_duplicates_count(intersects) \ - if self.expect_to_preserve_original_duplicates(grid_type) else self.get_expected_count(intersects) + expected_count = ( + self.get_expected_with_original_duplicates_count(intersects) + if self.expect_to_preserve_original_duplicates(grid_type) + else self.get_expected_count(intersects) + ) assert expected_count == result.__len__() def test_index_int(self, num_partitions, grid_type, index_type, intersects): @@ -96,13 +104,20 @@ def test_index_int(self, num_partitions, grid_type, index_type, intersects): spatial_rdd.buildIndex(index_type, True) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, True, intersects).collect() + spatial_rdd, query_rdd, True, intersects + ).collect() self.sanity_check_join_results(result) - assert self.get_expected_with_original_duplicates_count(intersects) == self.count_join_results(result) + assert self.get_expected_with_original_duplicates_count( + intersects + ) == self.count_join_results(result) def get_expected_count(self, intersects): return intersects_match_count if intersects else contains_match_count def get_expected_with_original_duplicates_count(self, intersects): - return intersects_match_count_with_original_duplicates if intersects else contains_match_with_original_duplicates + return ( + intersects_match_count_with_original_duplicates + if intersects + else contains_match_with_original_duplicates + ) diff --git a/python/tests/spatial_operator/test_polygon_knn.py b/python/tests/spatial_operator/test_polygon_knn.py index 2a341e9761..5b8ed2ab6e 100644 --- a/python/tests/spatial_operator/test_polygon_knn.py +++ b/python/tests/spatial_operator/test_polygon_knn.py @@ -41,7 +41,9 @@ def test_spatial_knn_query(self): polygon_rdd = PolygonRDD(self.sc, input_location, splitter, True) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(polygon_rdd, self.query_point, self.top_k, False) + result = KNNQuery.SpatialKnnQuery( + polygon_rdd, self.query_point, self.top_k, False + ) assert result.__len__() > -1 assert result[0].getUserData() is not None @@ -49,24 +51,36 @@ def test_spatial_knn_query_using_index(self): polygon_rdd = PolygonRDD(self.sc, input_location, splitter, True) polygon_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(polygon_rdd, self.query_point, self.top_k, True) + result = KNNQuery.SpatialKnnQuery( + polygon_rdd, self.query_point, self.top_k, True + ) assert result.__len__() > -1 assert result[0].getUserData() is not None def test_spatial_knn_correctness(self): polygon_rdd = PolygonRDD(self.sc, input_location, splitter, True) - result_no_index = KNNQuery.SpatialKnnQuery(polygon_rdd, self.query_point, self.top_k, False) + result_no_index = KNNQuery.SpatialKnnQuery( + polygon_rdd, self.query_point, self.top_k, False + ) polygon_rdd.buildIndex(IndexType.RTREE, False) - result_with_index = KNNQuery.SpatialKnnQuery(polygon_rdd, self.query_point, self.top_k, True) + result_with_index = KNNQuery.SpatialKnnQuery( + polygon_rdd, self.query_point, self.top_k, True + ) - sorted_result_no_index = sorted(result_no_index, key=lambda geo_data: distance_sorting_functions( - geo_data, 
self.query_point)) + sorted_result_no_index = sorted( + result_no_index, + key=lambda geo_data: distance_sorting_functions(geo_data, self.query_point), + ) - sorted_result_with_index = sorted(result_with_index, key=lambda geo_data: distance_sorting_functions( - geo_data, self.query_point)) + sorted_result_with_index = sorted( + result_with_index, + key=lambda geo_data: distance_sorting_functions(geo_data, self.query_point), + ) difference = 0 for x in range(self.top_k): - difference += sorted_result_no_index[x].geom.distance(sorted_result_with_index[x].geom) + difference += sorted_result_no_index[x].geom.distance( + sorted_result_with_index[x].geom + ) assert difference == 0 diff --git a/python/tests/spatial_operator/test_polygon_range.py b/python/tests/spatial_operator/test_polygon_range.py index 4537ffc06b..1173414458 100644 --- a/python/tests/spatial_operator/test_polygon_range.py +++ b/python/tests/spatial_operator/test_polygon_range.py @@ -35,26 +35,32 @@ class TestPolygonRange(TestBase): query_envelope = Envelope(-85.01, -60.01, 34.01, 50.01) def test_spatial_range_query(self): - spatial_rdd = PolygonRDD( - self.sc, input_location, splitter, True - ) + spatial_rdd = PolygonRDD(self.sc, input_location, splitter, True) for i in range(self.loop_times): - result_size = RangeQuery.\ - SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 704 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).take(10)[0].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False) + .take(10)[0] + .getUserData() + is not None + ) def test_spatial_range_query_using_index(self): - spatial_rdd = PolygonRDD( - self.sc, input_location, splitter, True - ) + spatial_rdd = PolygonRDD(self.sc, input_location, splitter, True) spatial_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result_size = RangeQuery.\ - SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 704 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).take(10)[0].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False) + .take(10)[0] + .getUserData() + is not None + ) diff --git a/python/tests/spatial_operator/test_rectangle_join.py b/python/tests/spatial_operator/test_rectangle_join.py index 2c101d8c8a..6c68955346 100644 --- a/python/tests/spatial_operator/test_rectangle_join.py +++ b/python/tests/spatial_operator/test_rectangle_join.py @@ -51,36 +51,46 @@ def pytest_generate_tests(metafunc): dict(num_partitions=11, grid_type=GridType.KDBTREE), ] params_dyn = [{**param, **{"index_type": IndexType.QUADTREE}} for param in parameters] -params_dyn.extend([{**param, **{"index_type": IndexType.RTREE}} for param in parameters]) +params_dyn.extend( + [{**param, **{"index_type": IndexType.RTREE}} for param in parameters] +) class TestRectangleJoin(TestJoinBase): params = { "test_nested_loop": parameters, "test_dynamic_index_int": params_dyn, - "test_index_int": params_dyn + "test_index_int": params_dyn, } def test_nested_loop(self, num_partitions, grid_type): query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions) - spatial_rdd = 
self.create_rectangle_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_rectangle_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() count = 0 for el in result: count += el[1].__len__() self.sanity_check_join_results(result) - expected_count = match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == self.count_join_results(result) def test_dynamic_index_int(self, num_partitions, grid_type, index_type): query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions) - spatial_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_rectangle_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) @@ -89,23 +99,32 @@ def test_dynamic_index_int(self, num_partitions, grid_type, index_type): self.sanity_check_flat_join_results(result) - expected_count = match_with_original_duplicates_count \ - if self.expect_to_preserve_original_duplicates(grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == result.__len__() def test_index_int(self, num_partitions, grid_type, index_type): query_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions) - spatial_rdd = self.create_rectangle_rdd(input_location, splitter, num_partitions) + spatial_rdd = self.create_rectangle_rdd( + input_location, splitter, num_partitions + ) self.partition_rdds(query_rdd, spatial_rdd, grid_type) spatial_rdd.buildIndex(index_type, True) result = JoinQuery.SpatialJoinQuery( - spatial_rdd, query_rdd, False, True).collect() + spatial_rdd, query_rdd, False, True + ).collect() self.sanity_check_join_results(result) - expected_count = match_with_original_duplicates_count if self.expect_to_preserve_original_duplicates( - grid_type) else match_count + expected_count = ( + match_with_original_duplicates_count + if self.expect_to_preserve_original_duplicates(grid_type) + else match_count + ) assert expected_count == self.count_join_results(result) diff --git a/python/tests/spatial_operator/test_rectangle_knn.py b/python/tests/spatial_operator/test_rectangle_knn.py index f791fb3d92..09d14cda0b 100644 --- a/python/tests/spatial_operator/test_rectangle_knn.py +++ b/python/tests/spatial_operator/test_rectangle_knn.py @@ -47,7 +47,13 @@ class TestRectangleKNN(TestBase): query_point = Point(-84.01, 34.01) top_k = 100 query_polygon = Polygon( - [(-84.01, 34.01), (-84.01, 34.11), (-83.91, 34.11), (-83.91, 34.01), (-84.01, 34.01)] + [ + (-84.01, 34.01), + (-84.01, 34.11), + (-83.91, 34.11), + (-83.91, 34.01), + (-84.01, 34.01), + ] ) query_line = LineString( [(-84.01, 34.01), (-84.01, 34.11), (-83.91, 34.11), (-83.91, 34.01)] @@ -57,7 +63,9 @@ def test_spatial_knn_query(self): rectangle_rdd = RectangleRDD(self.sc, inputLocation, offset, splitter, True) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_point, self.top_k, False) + result = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_point, 
self.top_k, False + ) assert result.__len__() > -1 assert result[0].getUserData() is not None @@ -67,7 +75,9 @@ def test_spatial_knn_query_using_index(self): rectangle_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): - result = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_point, self.top_k, False) + result = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_point, self.top_k, False + ) assert result.__len__() > -1 assert result[0].getUserData() is not None @@ -75,33 +85,47 @@ def test_spatial_knn_query_using_index(self): def test_spatial_knn_query_correctness(self): rectangle_rdd = RectangleRDD(self.sc, inputLocation, offset, splitter, True) - result_no_index = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_point, self.top_k, False) + result_no_index = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_point, self.top_k, False + ) rectangle_rdd.buildIndex(IndexType.RTREE, False) - result_with_index = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_point, self.top_k, True) + result_with_index = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_point, self.top_k, True + ) - sorted_result_no_index = sorted(result_no_index, key=lambda geo_data: distance_sorting_functions( - geo_data, self.query_point)) + sorted_result_no_index = sorted( + result_no_index, + key=lambda geo_data: distance_sorting_functions(geo_data, self.query_point), + ) - sorted_result_with_index = sorted(result_with_index, key=lambda geo_data: distance_sorting_functions( - geo_data, self.query_point)) + sorted_result_with_index = sorted( + result_with_index, + key=lambda geo_data: distance_sorting_functions(geo_data, self.query_point), + ) difference = 0 for x in range(self.top_k): - difference += sorted_result_no_index[x].geom.distance(sorted_result_with_index[x].geom) + difference += sorted_result_no_index[x].geom.distance( + sorted_result_with_index[x].geom + ) assert difference == 0 def test_spatial_knn_using_polygon(self): rectangle_rdd = RectangleRDD(self.sc, inputLocation, offset, splitter, True) - result_no_index = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_polygon, self.top_k, False) + result_no_index = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_polygon, self.top_k, False + ) print(result_no_index) def test_spatial_knn_using_linestring(self): rectangle_rdd = RectangleRDD(self.sc, inputLocation, offset, splitter, True) - result_no_index = KNNQuery.SpatialKnnQuery(rectangle_rdd, self.query_line, self.top_k, False) + result_no_index = KNNQuery.SpatialKnnQuery( + rectangle_rdd, self.query_line, self.top_k, False + ) print(result_no_index) diff --git a/python/tests/spatial_operator/test_rectangle_range.py b/python/tests/spatial_operator/test_rectangle_range.py index dd25f3ac31..d3e79397bf 100644 --- a/python/tests/spatial_operator/test_rectangle_range.py +++ b/python/tests/spatial_operator/test_rectangle_range.py @@ -48,21 +48,30 @@ def test_spatial_range_query(self): for i in range(self.loop_times): result_size = RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).count() + spatial_rdd, self.query_envelope, False, False + ).count() assert result_size == 193 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, False).take(10)[1].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, False) + .take(10)[1] + .getUserData() + is not None + ) def test_spatial_range_query_using_index(self): - spatial_rdd = RectangleRDD( - self.sc, inputLocation, offset, 
splitter, True) + spatial_rdd = RectangleRDD(self.sc, inputLocation, offset, splitter, True) spatial_rdd.buildIndex(IndexType.RTREE, False) for i in range(self.loop_times): result_size = RangeQuery.SpatialRangeQuery( - spatial_rdd, self.query_envelope, False, True).count() + spatial_rdd, self.query_envelope, False, True + ).count() assert result_size == 193 - assert RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, True).take(10)[1].getUserData()\ - is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, self.query_envelope, False, True) + .take(10)[1] + .getUserData() + is not None + ) diff --git a/python/tests/spatial_rdd/test_circle_rdd.py b/python/tests/spatial_rdd/test_circle_rdd.py index 746c5d206a..6117c2bb18 100644 --- a/python/tests/spatial_rdd/test_circle_rdd.py +++ b/python/tests/spatial_rdd/test_circle_rdd.py @@ -17,19 +17,19 @@ from sedona.core.SpatialRDD import PointRDD, CircleRDD from tests.test_base import TestBase -from tests.properties.point_properties import input_location, offset, splitter, num_partitions +from tests.properties.point_properties import ( + input_location, + offset, + splitter, + num_partitions, +) class TestCircleRDD(TestBase): def test_circle_rdd(self): spatial_rdd = PointRDD( - self.sc, - input_location, - offset, - splitter, - True, - num_partitions + self.sc, input_location, offset, splitter, True, num_partitions ) circle_rdd = CircleRDD(spatial_rdd, 0.5) @@ -38,5 +38,8 @@ def test_circle_rdd(self): assert circle_rdd.approximateTotalCount == 3000 - assert circle_rdd.rawSpatialRDD.take(1)[0].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" + assert ( + circle_rdd.rawSpatialRDD.take(1)[0].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) assert circle_rdd.rawSpatialRDD.take(1)[0].geom.radius == 0.5 diff --git a/python/tests/spatial_rdd/test_linestring_rdd.py b/python/tests/spatial_rdd/test_linestring_rdd.py index ea6d4e7342..edfb38e7cc 100644 --- a/python/tests/spatial_rdd/test_linestring_rdd.py +++ b/python/tests/spatial_rdd/test_linestring_rdd.py @@ -18,8 +18,17 @@ from sedona.core.SpatialRDD import LineStringRDD from sedona.core.enums import IndexType, GridType from sedona.core.geom.envelope import Envelope -from tests.properties.linestring_properties import input_count, input_boundary, input_location, splitter, num_partitions, \ - grid_type, transformed_envelope, input_boundary_2, transformed_envelope_2 +from tests.properties.linestring_properties import ( + input_count, + input_boundary, + input_location, + splitter, + num_partitions, + grid_type, + transformed_envelope, + input_boundary_2, + transformed_envelope_2, +) from tests.test_base import TestBase @@ -38,7 +47,7 @@ def test_constructor(self): InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) self.compare_count(spatial_rdd_core, input_boundary, input_count) @@ -49,7 +58,9 @@ def test_constructor(self): self.compare_count(spatial_rdd, input_boundary, input_count) - spatial_rdd = LineStringRDD(self.sc, input_location, 0, 3, splitter, True, num_partitions) + spatial_rdd = LineStringRDD( + self.sc, input_location, 0, 3, splitter, True, num_partitions + ) self.compare_count(spatial_rdd, input_boundary_2, input_count) @@ -57,7 +68,9 @@ def test_constructor(self): self.compare_count(spatial_rdd, input_boundary_2, input_count) - spatial_rdd = LineStringRDD(self.sc, input_location, splitter, True, num_partitions) + spatial_rdd = LineStringRDD( + 
self.sc, input_location, splitter, True, num_partitions + ) self.compare_count(spatial_rdd, input_boundary, input_count) @@ -69,18 +82,19 @@ def test_constructor(self): self.compare_count(spatial_rdd, input_boundary, input_count) - spatial_rdd = LineStringRDD(self.sc, input_location, 0, 3, splitter, True, num_partitions) + spatial_rdd = LineStringRDD( + self.sc, input_location, 0, 3, splitter, True, num_partitions + ) self.compare_count(spatial_rdd, input_boundary_2, input_count) - def test_empty_constructor(self): spatial_rdd = LineStringRDD( sparkContext=self.sc, InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) spatial_rdd.analyze() @@ -96,7 +110,7 @@ def test_build_index_without_set_grid(self): InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) spatial_rdd.analyze() @@ -108,7 +122,7 @@ def test_mbr(self): InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) rectangle_rdd = linestring_rdd.MinimumBoundingRectangle() diff --git a/python/tests/spatial_rdd/test_point_rdd.py b/python/tests/spatial_rdd/test_point_rdd.py index 9cb1545465..f7a8a1499c 100644 --- a/python/tests/spatial_rdd/test_point_rdd.py +++ b/python/tests/spatial_rdd/test_point_rdd.py @@ -19,8 +19,18 @@ from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD from sedona.core.enums import IndexType, GridType from sedona.core.geom.envelope import Envelope -from tests.properties.point_properties import input_location, offset, splitter, num_partitions, input_count, input_boundary, \ - transformed_envelope, crs_point_test, crs_envelope, crs_envelope_transformed +from tests.properties.point_properties import ( + input_location, + offset, + splitter, + num_partitions, + input_count, + input_boundary, + transformed_envelope, + crs_point_test, + crs_envelope, + crs_envelope_transformed, +) from tests.test_base import TestBase @@ -33,24 +43,33 @@ def compare_count(self, spatial_rdd: SpatialRDD, cnt: int, envelope: Envelope): def test_constructor(self): spatial_rdd = PointRDD( - self.sc, - input_location, - offset, - splitter, - True, - num_partitions + self.sc, input_location, offset, splitter, True, num_partitions ) spatial_rdd.rawSpatialRDD.take(9)[0].getUserData() - assert spatial_rdd.rawSpatialRDD.take(9)[0].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" - assert spatial_rdd.rawSpatialRDD.take(9)[2].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" - assert spatial_rdd.rawSpatialRDD.take(9)[4].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" - assert spatial_rdd.rawSpatialRDD.take(9)[8].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" + assert ( + spatial_rdd.rawSpatialRDD.take(9)[0].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) + assert ( + spatial_rdd.rawSpatialRDD.take(9)[2].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) + assert ( + spatial_rdd.rawSpatialRDD.take(9)[4].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) + assert ( + spatial_rdd.rawSpatialRDD.take(9)[8].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) spatial_rdd_copy = PointRDD(spatial_rdd.rawJvmSpatialRDD) self.compare_count(spatial_rdd_copy, input_count, input_boundary) spatial_rdd_copy = PointRDD(spatial_rdd.rawJvmSpatialRDD) self.compare_count(spatial_rdd_copy, 
input_count, input_boundary) - spatial_rdd_copy = PointRDD(self.sc, input_location, offset, splitter, True, num_partitions) + spatial_rdd_copy = PointRDD( + self.sc, input_location, offset, splitter, True, num_partitions + ) self.compare_count(spatial_rdd_copy, input_count, input_boundary) spatial_rdd_copy = PointRDD(self.sc, crs_point_test, splitter, True) self.compare_count(spatial_rdd_copy, 20000, crs_envelope) @@ -62,7 +81,7 @@ def test_empty_constructor(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) spatial_rdd.buildIndex(IndexType.RTREE, False) spatial_rdd_copy = PointRDD() @@ -76,12 +95,15 @@ def test_equal_partitioning(self): Offset=offset, splitter=splitter, carryInputData=False, - partitions=10 + partitions=10, ) spatial_rdd.analyze() spatial_rdd.spatialPartitioning(GridType.QUADTREE) - assert spatial_rdd.countWithoutDuplicates() == spatial_rdd.countWithoutDuplicatesSPRDD() + assert ( + spatial_rdd.countWithoutDuplicates() + == spatial_rdd.countWithoutDuplicatesSPRDD() + ) def test_build_index_without_set_grid(self): spatial_rdd = PointRDD( @@ -90,6 +112,6 @@ def test_build_index_without_set_grid(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) spatial_rdd.buildIndex(IndexType.RTREE, False) diff --git a/python/tests/spatial_rdd/test_polygon_rdd.py b/python/tests/spatial_rdd/test_polygon_rdd.py index 3897f78565..d2be163048 100644 --- a/python/tests/spatial_rdd/test_polygon_rdd.py +++ b/python/tests/spatial_rdd/test_polygon_rdd.py @@ -19,9 +19,22 @@ from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD from sedona.core.enums import IndexType, FileDataSplitter, GridType from sedona.core.geom.envelope import Envelope -from tests.properties.polygon_properties import input_location, splitter, num_partitions, input_count, input_boundary, grid_type, \ - input_location_geo_json, input_location_wkt, input_location_wkb, query_envelope, polygon_rdd_input_location, \ - polygon_rdd_start_offset, polygon_rdd_end_offset, polygon_rdd_splitter +from tests.properties.polygon_properties import ( + input_location, + splitter, + num_partitions, + input_count, + input_boundary, + grid_type, + input_location_geo_json, + input_location_wkt, + input_location_wkb, + query_envelope, + polygon_rdd_input_location, + polygon_rdd_start_offset, + polygon_rdd_end_offset, + polygon_rdd_splitter, +) from tests.test_base import TestBase @@ -42,7 +55,7 @@ def test_constructor(self): InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) self.compare_spatial_rdd(spatial_rdd_core, input_boundary) @@ -57,7 +70,7 @@ def test_constructor(self): polygon_rdd_end_offset, polygon_rdd_splitter, True, - 2 + 2, ) assert query_window_rdd.analyze() assert query_window_rdd.approximateTotalCount == 3000 @@ -68,27 +81,18 @@ def test_constructor(self): polygon_rdd_start_offset, polygon_rdd_end_offset, polygon_rdd_splitter, - True + True, ) assert query_window_rdd.analyze() assert query_window_rdd.approximateTotalCount == 3000 spatial_rdd_core = PolygonRDD( - self.sc, - input_location, - splitter, - True, - num_partitions + self.sc, input_location, splitter, True, num_partitions ) self.compare_spatial_rdd(spatial_rdd_core, input_boundary) - spatial_rdd_core = PolygonRDD( - self.sc, - input_location, - splitter, - True - ) + spatial_rdd_core = PolygonRDD(self.sc, input_location, splitter, True) 
self.compare_spatial_rdd(spatial_rdd_core, input_boundary) @@ -98,7 +102,7 @@ def test_empty_constructor(self): InputLocation=input_location, splitter=splitter, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) spatial_rdd.analyze() spatial_rdd.spatialPartitioning(grid_type) @@ -113,39 +117,62 @@ def test_geojson_constructor(self): InputLocation=input_location_geo_json, splitter=FileDataSplitter.GEOJSON, carryInputData=True, - partitions=4 + partitions=4, ) spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 1001 assert spatial_rdd.boundaryEnvelope is not None - assert spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() == "01\t077\t011501\t5\t1500000US010770115015\t010770115015\t5\tBG\t6844991\t32636" - assert spatial_rdd.rawSpatialRDD.take(2)[1].getUserData() == "01\t045\t021102\t4\t1500000US010450211024\t010450211024\t4\tBG\t11360854\t0" - assert spatial_rdd.fieldNames == ["STATEFP", "COUNTYFP", "TRACTCE", "BLKGRPCE", "AFFGEOID", "GEOID", "NAME", "LSAD", "ALAND", "AWATER"] + assert ( + spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() + == "01\t077\t011501\t5\t1500000US010770115015\t010770115015\t5\tBG\t6844991\t32636" + ) + assert ( + spatial_rdd.rawSpatialRDD.take(2)[1].getUserData() + == "01\t045\t021102\t4\t1500000US010450211024\t010450211024\t4\tBG\t11360854\t0" + ) + assert spatial_rdd.fieldNames == [ + "STATEFP", + "COUNTYFP", + "TRACTCE", + "BLKGRPCE", + "AFFGEOID", + "GEOID", + "NAME", + "LSAD", + "ALAND", + "AWATER", + ] def test_wkt_constructor(self): spatial_rdd = PolygonRDD( sparkContext=self.sc, InputLocation=input_location_wkt, splitter=FileDataSplitter.WKT, - carryInputData=True + carryInputData=True, ) spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 103 assert spatial_rdd.boundaryEnvelope is not None - assert spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() == "31\t039\t00835841\t31039\tCuming\tCuming County\t06\tH1\tG4020\t\t\t\tA\t1477895811\t10447360\t+41.9158651\t-096.7885168" + assert ( + spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() + == "31\t039\t00835841\t31039\tCuming\tCuming County\t06\tH1\tG4020\t\t\t\tA\t1477895811\t10447360\t+41.9158651\t-096.7885168" + ) def test_wkb_constructor(self): spatial_rdd = PolygonRDD( sparkContext=self.sc, InputLocation=input_location_wkb, splitter=FileDataSplitter.WKB, - carryInputData=True + carryInputData=True, ) spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 103 assert spatial_rdd.boundaryEnvelope is not None - assert spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() == "31\t039\t00835841\t31039\tCuming\tCuming County\t06\tH1\tG4020\t\t\t\tA\t1477895811\t10447360\t+41.9158651\t-096.7885168" + assert ( + spatial_rdd.rawSpatialRDD.take(1)[0].getUserData() + == "31\t039\t00835841\t31039\tCuming\tCuming County\t06\tH1\tG4020\t\t\t\tA\t1477895811\t10447360\t+41.9158651\t-096.7885168" + ) def test_mbr(self): polygon_rdd = PolygonRDD( @@ -153,7 +180,7 @@ def test_mbr(self): InputLocation=input_location, splitter=FileDataSplitter.CSV, carryInputData=True, - partitions=num_partitions + partitions=num_partitions, ) rectangle_rdd = polygon_rdd.MinimumBoundingRectangle() diff --git a/python/tests/spatial_rdd/test_rectangle_rdd.py b/python/tests/spatial_rdd/test_rectangle_rdd.py index 570c09977b..29b406c06d 100644 --- a/python/tests/spatial_rdd/test_rectangle_rdd.py +++ b/python/tests/spatial_rdd/test_rectangle_rdd.py @@ -35,7 +35,9 @@ distance = 0.001 queryPolygonSet = os.path.join(tests_resource, "primaryroads-polygon.csv") inputCount = 3000 
-inputBoundary = Envelope(minx=-171.090042, maxx=145.830505, miny=-14.373765, maxy=49.00127) +inputBoundary = Envelope( + minx=-171.090042, maxx=145.830505, miny=-14.373765, maxy=49.00127 +) matchCount = 17599 matchWithOriginalDuplicatesCount = 17738 @@ -49,7 +51,7 @@ def test_constructor(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) spatial_rdd.analyze() @@ -58,12 +60,7 @@ def test_constructor(self): assert inputBoundary == spatial_rdd.boundaryEnvelope spatial_rdd = RectangleRDD( - self.sc, - inputLocation, - offset, - splitter, - True, - numPartitions + self.sc, inputLocation, offset, splitter, True, numPartitions ) spatial_rdd.analyze() @@ -78,7 +75,7 @@ def test_empty_constructor(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) spatial_rdd.analyze() @@ -95,7 +92,7 @@ def test_build_index_without_set_grid(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) spatial_rdd.analyze() diff --git a/python/tests/spatial_rdd/test_spatial_rdd.py b/python/tests/spatial_rdd/test_spatial_rdd.py index 37fe3819f6..6866d6b221 100644 --- a/python/tests/spatial_rdd/test_spatial_rdd.py +++ b/python/tests/spatial_rdd/test_spatial_rdd.py @@ -49,7 +49,7 @@ def create_spatial_rdd(self): Offset=offset, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) return spatial_rdd @@ -71,13 +71,17 @@ def test_boundary(self): spatial_rdd = self.create_spatial_rdd() envelope = spatial_rdd.boundary() - assert envelope == Envelope(minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134) + assert envelope == Envelope( + minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134 + ) def test_boundary_envelope(self): spatial_rdd = self.create_spatial_rdd() spatial_rdd.analyze() - assert Envelope( - minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134) == spatial_rdd.boundaryEnvelope + assert ( + Envelope(minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134) + == spatial_rdd.boundaryEnvelope + ) def test_build_index(self): for grid_type in GridType: @@ -101,15 +105,12 @@ def test_field_names(self): spatial_rdd = self.create_spatial_rdd() assert spatial_rdd.fieldNames == [] geo_json_rdd = GeoJsonReader.readToGeometryRDD( - self.sc, - geo_json_contains_id, - True, - False + self.sc, geo_json_contains_id, True, False ) try: - assert geo_json_rdd.fieldNames == ['zipcode', 'name'] + assert geo_json_rdd.fieldNames == ["zipcode", "name"] except AssertionError: - assert geo_json_rdd.fieldNames == ['id', 'zipcode', 'name'] + assert geo_json_rdd.fieldNames == ["id", "zipcode", "name"] def test_get_partitioner(self): spatial_rdd = self.create_spatial_rdd() diff --git a/python/tests/spatial_rdd/test_spatial_rdd_writer.py b/python/tests/spatial_rdd/test_spatial_rdd_writer.py index 1685ded8df..334a7d1485 100644 --- a/python/tests/spatial_rdd/test_spatial_rdd_writer.py +++ b/python/tests/spatial_rdd/test_spatial_rdd_writer.py @@ -29,11 +29,15 @@ wkb_folder = "wkb" wkt_folder = "wkt" -test_save_as_wkb_with_data = os.path.join(tests_resource, wkb_folder, "testSaveAsWKBWithData") +test_save_as_wkb_with_data = os.path.join( + tests_resource, wkb_folder, "testSaveAsWKBWithData" +) test_save_as_wkb = os.path.join(tests_resource, wkb_folder, "testSaveAsWKB") test_save_as_empty_wkb = os.path.join(tests_resource, wkb_folder, "testSaveAsEmptyWKB") 
test_save_as_wkt = os.path.join(tests_resource, wkt_folder, "testSaveAsWKT") -test_save_as_wkt_with_data = os.path.join(tests_resource, wkt_folder, "testSaveAsWKTWithData") +test_save_as_wkt_with_data = os.path.join( + tests_resource, wkt_folder, "testSaveAsWKTWithData" +) inputLocation = os.path.join(tests_resource, "arealm-small.csv") queryWindowSet = os.path.join(tests_resource, "zcta510-small.csv") @@ -46,16 +50,14 @@ queryPolygonSet = "primaryroads-polygon.csv" inputCount = 3000 inputBoundary = Envelope( - minx=-173.120769, - maxx=-84.965961, - miny=30.244859, - maxy=71.355134 + minx=-173.120769, maxx=-84.965961, miny=30.244859, maxy=71.355134 ) rectangleMatchCount = 103 rectangleMatchWithOriginalDuplicatesCount = 103 polygonMatchCount = 472 polygonMatchWithOriginalDuplicatesCount = 562 + ## todo add missing tests def remove_directory(path: str) -> bool: try: @@ -79,7 +81,7 @@ def test_save_as_geo_json_with_data(self, remove_wkb_directory): Offset=offset, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) spatial_rdd.saveAsGeoJSON(test_save_as_wkb_with_data) @@ -89,7 +91,7 @@ def test_save_as_geo_json_with_data(self, remove_wkb_directory): InputLocation=test_save_as_wkb_with_data, splitter=FileDataSplitter.GEOJSON, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) assert result_wkb.rawSpatialRDD.count() == spatial_rdd.rawSpatialRDD.count() diff --git a/python/tests/sql/resource/sample_data.py b/python/tests/sql/resource/sample_data.py index 6540cdd7d1..d907f66c48 100644 --- a/python/tests/sql/resource/sample_data.py +++ b/python/tests/sql/resource/sample_data.py @@ -28,16 +28,24 @@ data_path = path.abspath(path.dirname(__file__)) -def create_sample_polygons_df(spark: SparkSession, number_of_polygons: int) -> DataFrame: - return resource_file_to_dataframe(spark, "sample_polygons").limit(number_of_polygons) +def create_sample_polygons_df( + spark: SparkSession, number_of_polygons: int +) -> DataFrame: + return resource_file_to_dataframe(spark, "sample_polygons").limit( + number_of_polygons + ) def create_sample_points_df(spark: SparkSession, number_of_points: int) -> DataFrame: return resource_file_to_dataframe(spark, "sample_points").limit(number_of_points) -def create_simple_polygons_df(spark: SparkSession, number_of_polygons: int) -> DataFrame: - return resource_file_to_dataframe(spark, "simple_polygons").limit(number_of_polygons) +def create_simple_polygons_df( + spark: SparkSession, number_of_polygons: int +) -> DataFrame: + return resource_file_to_dataframe(spark, "simple_polygons").limit( + number_of_polygons + ) def create_sample_lines_df(spark: SparkSession, number_of_lines: int) -> DataFrame: @@ -45,26 +53,24 @@ def create_sample_lines_df(spark: SparkSession, number_of_lines: int) -> DataFra def create_sample_polygons(number_of_polygons: int) -> List: - return load_from_resources(data_path, "sample_polygons")[: number_of_polygons] + return load_from_resources(data_path, "sample_polygons")[:number_of_polygons] def create_sample_points(number_of_points: int) -> List: - return load_from_resources(data_path, "sample_points")[: number_of_points] + return load_from_resources(data_path, "sample_points")[:number_of_points] def create_simple_polygons(number_of_polygons: int) -> List: - return load_from_resources(data_path, "simple_polygons")[: number_of_polygons] + return load_from_resources(data_path, "simple_polygons")[:number_of_polygons] def create_sample_lines(number_of_lines: int) -> List: - return 
load_from_resources(data_path, "sample_lines")[: number_of_lines] + return load_from_resources(data_path, "sample_lines")[:number_of_lines] def resource_file_to_dataframe(spark: SparkSession, file_path: str) -> DataFrame: geometries = load_from_resources(data_path, file_path) - schema = StructType([ - StructField("geom", GeometryType(), True) - ]) + schema = StructType([StructField("geom", GeometryType(), True)]) return spark.createDataFrame([[el] for el in geometries], schema=schema) diff --git a/python/tests/sql/test_adapter.py b/python/tests/sql/test_adapter.py index a225669a18..c69e330daf 100644 --- a/python/tests/sql/test_adapter.py +++ b/python/tests/sql/test_adapter.py @@ -31,8 +31,11 @@ from sedona.core.jvm.config import is_greater_or_equal_version from sedona.core.spatialOperator import JoinQuery from sedona.utils.adapter import Adapter -from tests import geojson_input_location, shape_file_with_missing_trailing_input_location, \ - geojson_id_input_location +from tests import ( + geojson_input_location, + shape_file_with_missing_trailing_input_location, + geojson_id_input_location, +) from tests import shape_file_input_location, area_lm_point_input_location from tests import mixed_wkt_geometry_input_location from tests.test_base import TestBase @@ -41,16 +44,19 @@ class TestAdapter(TestBase): def test_read_csv_point_into_spatial_rdd(self): - df = self.spark.read.\ - format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(area_lm_point_input_location) + df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(area_lm_point_input_location) + ) df.show() df.createOrReplaceTempView("inputtable") - spatial_df = self.spark.sql("select ST_PointFromText(inputtable._c0,\",\") as arealandmark from inputtable") + spatial_df = self.spark.sql( + 'select ST_PointFromText(inputtable._c0,",") as arealandmark from inputtable' + ) spatial_df.show() spatial_df.printSchema() @@ -59,10 +65,12 @@ def test_read_csv_point_into_spatial_rdd(self): Adapter.toDf(spatial_rdd, self.spark).show() def test_read_csv_point_into_spatial_rdd_by_passing_coordinates(self): - df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(area_lm_point_input_location) + df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(area_lm_point_input_location) + ) df.show() df.createOrReplaceTempView("inputtable") @@ -74,67 +82,82 @@ def test_read_csv_point_into_spatial_rdd_by_passing_coordinates(self): spatial_df.show() spatial_df.printSchema() - def test_read_csv_point_into_spatial_rdd_with_unique_id_by_passing_coordinates(self): - df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(area_lm_point_input_location) + def test_read_csv_point_into_spatial_rdd_with_unique_id_by_passing_coordinates( + self, + ): + df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(area_lm_point_input_location) + ) df.show() df.createOrReplaceTempView("inputtable") spatial_df = self.spark.sql( - "select ST_Point(cast(inputtable._c0 as Decimal(24,20)),cast(inputtable._c1 as Decimal(24,20))) as arealandmark from inputtable") + "select ST_Point(cast(inputtable._c0 as Decimal(24,20)),cast(inputtable._c1 as Decimal(24,20))) as arealandmark from inputtable" + ) spatial_df.show() spatial_df.printSchema() def test_read_mixed_wkt_geometries_into_spatial_rdd(self): - df = 
self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").load(mixed_wkt_geometry_input_location) + df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) df.show() df.createOrReplaceTempView("inputtable") - spatial_df = self.spark.sql("select ST_GeomFromWKT(inputtable._c0) as usacounty from inputtable") + spatial_df = self.spark.sql( + "select ST_GeomFromWKT(inputtable._c0) as usacounty from inputtable" + ) spatial_df.show() spatial_df.printSchema() spatial_rdd = Adapter.toSpatialRdd(spatial_df, "usacounty") spatial_rdd.analyze() Adapter.toDf(spatial_rdd, self.spark).show() - assert (Adapter.toDf(spatial_rdd, self.spark).columns.__len__() == 1) + assert Adapter.toDf(spatial_rdd, self.spark).columns.__len__() == 1 Adapter.toDf(spatial_rdd, self.spark).show() def test_read_mixed_wkt_geometries_into_spatial_rdd_with_unique_id(self): - df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkt_geometry_input_location) + df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) df.show() df.createOrReplaceTempView("inputtable") spatial_df = self.spark.sql( - "select ST_GeomFromWKT(inputtable._c0) as usacounty, inputtable._c3, inputtable._c5 from inputtable") + "select ST_GeomFromWKT(inputtable._c0) as usacounty, inputtable._c3, inputtable._c5 from inputtable" + ) spatial_df.show() spatial_df.printSchema() spatial_rdd = Adapter.toSpatialRdd(spatial_df, "usacounty") spatial_rdd.analyze() - assert (Adapter.toDf(spatial_rdd, self.spark).columns.__len__() == 3) + assert Adapter.toDf(spatial_rdd, self.spark).columns.__len__() == 3 Adapter.toDf(spatial_rdd, self.spark).show() def test_read_shapefile_to_dataframe(self): spatial_rdd = ShapefileReader.readToGeometryRDD( - self.spark.sparkContext, shape_file_input_location) + self.spark.sparkContext, shape_file_input_location + ) spatial_rdd.analyze() logging.info(spatial_rdd.fieldNames) df = Adapter.toDf(spatial_rdd, self.spark) df.show() def test_read_shapefile_with_missing_to_dataframe(self): - spatial_rdd = ShapefileReader.\ - readToGeometryRDD(self.spark.sparkContext, shape_file_with_missing_trailing_input_location) + spatial_rdd = ShapefileReader.readToGeometryRDD( + self.spark.sparkContext, shape_file_with_missing_trailing_input_location + ) spatial_rdd.analyze() logging.info(spatial_rdd.fieldNames) @@ -144,32 +167,45 @@ def test_read_shapefile_with_missing_to_dataframe(self): def test_geojson_to_dataframe(self): spatial_rdd = PolygonRDD( - self.spark.sparkContext, geojson_input_location, FileDataSplitter.GEOJSON, True + self.spark.sparkContext, + geojson_input_location, + FileDataSplitter.GEOJSON, + True, ) spatial_rdd.analyze() Adapter.toDf(spatial_rdd, self.spark).show() df = Adapter.toDf(spatial_rdd, self.spark) - assert (df.columns[1] == "STATEFP") + assert df.columns[1] == "STATEFP" def test_convert_spatial_join_result_to_dataframe(self): - polygon_wkt_df = self.spark.read.format("csv").option("delimiter", "\t").option("header", "false").load( - mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_df = self.spark.sql( - "select ST_GeomFromWKT(polygontable._c0) as usacounty from 
polygontable") + "select ST_GeomFromWKT(polygontable._c0) as usacounty from polygontable" + ) polygon_rdd = Adapter.toSpatialRdd(polygon_df, "usacounty") polygon_rdd.analyze() - point_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load( - area_lm_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(area_lm_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) point_rdd = Adapter.toSpatialRdd(point_df, "arealandmark") point_rdd.analyze() @@ -179,40 +215,44 @@ def test_convert_spatial_join_result_to_dataframe(self): point_rdd.buildIndex(IndexType.QUADTREE, True) - join_result_point_rdd = JoinQuery.\ - SpatialJoinQueryFlat(point_rdd, polygon_rdd, True, True) + join_result_point_rdd = JoinQuery.SpatialJoinQueryFlat( + point_rdd, polygon_rdd, True, True + ) join_result_df = Adapter.toDf(join_result_point_rdd, self.spark) join_result_df.show() - join_result_df2 = Adapter.toDf(join_result_point_rdd, ["abc", "def"], list(), self.spark) + join_result_df2 = Adapter.toDf( + join_result_point_rdd, ["abc", "def"], list(), self.spark + ) join_result_df2.show() def test_distance_join_result_to_dataframe(self): - point_csv_df = self.spark.\ - read.\ - format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - area_lm_point_input_location + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(area_lm_point_input_location) ) point_csv_df.createOrReplaceTempView("pointtable") point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) point_rdd = Adapter.toSpatialRdd(point_df, "arealandmark") point_rdd.analyze() - polygon_wkt_df = self.spark.read.\ - format("csv").\ - option("delimiter", "\t").\ - option("header", "false").load( - mixed_wkt_geometry_input_location + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) ) polygon_wkt_df.createOrReplaceTempView("polygontable") - polygon_df = self.spark.\ - sql("select ST_GeomFromWKT(polygontable._c0) as usacounty from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as usacounty from polygontable" + ) polygon_rdd = Adapter.toSpatialRdd(polygon_df, "usacounty") polygon_rdd.analyze() @@ -223,15 +263,21 @@ def test_distance_join_result_to_dataframe(self): point_rdd.buildIndex(IndexType.QUADTREE, True) - join_result_pair_rdd = JoinQuery.\ - DistanceJoinQueryFlat(point_rdd, circle_rdd, True, True) + join_result_pair_rdd = JoinQuery.DistanceJoinQueryFlat( + point_rdd, circle_rdd, True, True + ) join_result_df = Adapter.toDf(join_result_pair_rdd, self.spark) join_result_df.printSchema() join_result_df.show() def test_load_id_column_data_check(self): - spatial_rdd = PolygonRDD(self.spark.sparkContext, geojson_id_input_location, FileDataSplitter.GEOJSON, True) + spatial_rdd = 
PolygonRDD( + self.spark.sparkContext, + geojson_id_input_location, + FileDataSplitter.GEOJSON, + True, + ) spatial_rdd.analyze() df = Adapter.toDf(spatial_rdd, self.spark) df.show() @@ -242,15 +288,18 @@ def test_load_id_column_data_check(self): assert df.count() == 1 def _create_spatial_point_table(self) -> DataFrame: - df = self.spark.read.\ - format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(area_lm_point_input_location) + df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(area_lm_point_input_location) + ) df.createOrReplaceTempView("inputtable") - spatial_df = self.spark.sql("select ST_PointFromText(inputtable._c0,\",\") as geom from inputtable") + spatial_df = self.spark.sql( + 'select ST_PointFromText(inputtable._c0,",") as geom from inputtable' + ) return spatial_df @@ -262,25 +311,35 @@ def test_to_spatial_rdd_df_and_geom_field_name(self): spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 121960 - assert spatial_rdd.boundaryEnvelope == Envelope(-179.147236, 179.475569, -14.548699, 71.35513400000001) + assert spatial_rdd.boundaryEnvelope == Envelope( + -179.147236, 179.475569, -14.548699, 71.35513400000001 + ) def test_to_spatial_rdd_df_with_non_geom_fields(self): spatial_df = self._create_spatial_point_table() - spatial_df = spatial_df.withColumn("i", expr("10")).withColumn("s", expr("'20'")) + spatial_df = spatial_df.withColumn("i", expr("10")).withColumn( + "s", expr("'20'") + ) spatial_rdd = Adapter.toSpatialRdd(spatial_df, "geom") - assert spatial_rdd.fieldNames == ['i', 's'] + assert spatial_rdd.fieldNames == ["i", "s"] spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 121960 - assert spatial_rdd.boundaryEnvelope == Envelope(-179.147236, 179.475569, -14.548699, 71.35513400000001) + assert spatial_rdd.boundaryEnvelope == Envelope( + -179.147236, 179.475569, -14.548699, 71.35513400000001 + ) def test_to_spatial_rdd_df_with_custom_user_data_field_names(self): spatial_df = self._create_spatial_point_table() - spatial_df = spatial_df.withColumn("i", expr("10")).withColumn("s", expr("'20'")) + spatial_df = spatial_df.withColumn("i", expr("10")).withColumn( + "s", expr("'20'") + ) spatial_rdd = Adapter.toSpatialRdd(spatial_df, "geom", ["i2", "s2"]) - assert spatial_rdd.fieldNames == ['i2', 's2'] + assert spatial_rdd.fieldNames == ["i2", "s2"] spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 121960 - assert spatial_rdd.boundaryEnvelope == Envelope(-179.147236, 179.475569, -14.548699, 71.35513400000001) + assert spatial_rdd.boundaryEnvelope == Envelope( + -179.147236, 179.475569, -14.548699, 71.35513400000001 + ) def test_to_spatial_rdd_df(self): spatial_df = self._create_spatial_point_table() @@ -290,38 +349,54 @@ def test_to_spatial_rdd_df(self): spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 121960 - assert spatial_rdd.boundaryEnvelope == Envelope(-179.147236, 179.475569, -14.548699, 71.35513400000001) + assert spatial_rdd.boundaryEnvelope == Envelope( + -179.147236, 179.475569, -14.548699, 71.35513400000001 + ) - @pytest.mark.skipif(is_greater_or_equal_version(version, "1.0.0"), reason="Deprecated in Sedona") + @pytest.mark.skipif( + is_greater_or_equal_version(version, "1.0.0"), reason="Deprecated in Sedona" + ) def test_to_spatial_rdd_df_geom_column_id(self): - df = self.spark.read.\ - format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkt_geometry_input_location) + df = ( + 
self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) - df_shorter = df.select(col("_c0").alias("geom"), col("_c6").alias("county_name")) + df_shorter = df.select( + col("_c0").alias("geom"), col("_c6").alias("county_name") + ) df_shorter.createOrReplaceTempView("county_data") - spatial_df = self.spark.sql("SELECT ST_GeomFromWKT(geom) as geom, county_name FROM county_data") + spatial_df = self.spark.sql( + "SELECT ST_GeomFromWKT(geom) as geom, county_name FROM county_data" + ) spatial_df.show() def test_to_df_srdd_fn_spark(self): spatial_rdd = PolygonRDD( - self.spark.sparkContext, geojson_input_location, FileDataSplitter.GEOJSON, True + self.spark.sparkContext, + geojson_input_location, + FileDataSplitter.GEOJSON, + True, ) spatial_rdd.analyze() assert spatial_rdd.approximateTotalCount == 1001 spatial_columns = [ - "state_id", "county_id", "tract_id", "bg_id", - "fips", "fips_short", "bg_nr", "type", "code1", "code2" - ] - spatial_df = Adapter.toDf( - spatial_rdd, - spatial_columns, - self.spark - ) + "state_id", + "county_id", + "tract_id", + "bg_id", + "fips", + "fips_short", + "bg_nr", + "type", + "code1", + "code2", + ] + spatial_df = Adapter.toDf(spatial_rdd, spatial_columns, self.spark) spatial_df.show() diff --git a/python/tests/sql/test_aggregate_functions.py b/python/tests/sql/test_aggregate_functions.py index a12b141bce..32aa054aec 100644 --- a/python/tests/sql/test_aggregate_functions.py +++ b/python/tests/sql/test_aggregate_functions.py @@ -24,22 +24,28 @@ class TestConstructors(TestBase): def test_st_envelope_aggr(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") - point_df = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + point_df = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) point_df.createOrReplaceTempView("pointdf") - boundary = self.spark.sql("select ST_Envelope_Aggr(pointdf.arealandmark) from pointdf") + boundary = self.spark.sql( + "select ST_Envelope_Aggr(pointdf.arealandmark) from pointdf" + ) coordinates = [ (1.1, 101.1), (1.1, 1100.1), (1000.1, 1100.1), (1000.1, 101.1), - (1.1, 101.1) + (1.1, 101.1), ] polygon = Polygon(coordinates) @@ -47,16 +53,22 @@ def test_st_envelope_aggr(self): assert boundary.take(1)[0][0] == polygon def test_st_union_aggr(self): - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(union_polygon_input_location) + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(union_polygon_input_location) + ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() - polygon_df = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + polygon_df = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as 
Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - union = self.spark.sql("select ST_Union_Aggr(polygondf.polygonshape) from polygondf") + union = self.spark.sql( + "select ST_Union_Aggr(polygondf.polygonshape) from polygondf" + ) assert union.take(1)[0][0].area == 10100 diff --git a/python/tests/sql/test_constructor_test.py b/python/tests/sql/test_constructor_test.py index 345862d4cf..3b2c1ba792 100644 --- a/python/tests/sql/test_constructor_test.py +++ b/python/tests/sql/test_constructor_test.py @@ -15,22 +15,31 @@ # specific language governing permissions and limitations # under the License. -from tests import csv_point_input_location, area_lm_point_input_location, mixed_wkt_geometry_input_location, \ - mixed_wkb_geometry_input_location, geojson_input_location +from tests import ( + csv_point_input_location, + area_lm_point_input_location, + mixed_wkt_geometry_input_location, + mixed_wkb_geometry_input_location, + geojson_input_location, +) from tests.test_base import TestBase class TestConstructors(TestBase): def test_st_point(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") - point_df = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + point_df = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) assert point_df.count() == 1000 def test_st_point_z(self): @@ -42,194 +51,278 @@ def test_st_point_m(self): assert point_df.count() == 1 def test_st_makepointm(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") - point_df = self.spark.sql("select ST_MakePointM(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20)), 2.0) as arealandmark from pointtable") + point_df = self.spark.sql( + "select ST_MakePointM(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20)), 2.0) as arealandmark from pointtable" + ) assert point_df.count() == 1000 - point_df = self.spark.sql("SELECT ST_AsText(ST_MakePointM(1.2345, 2.3456, 3.4567))") + point_df = self.spark.sql( + "SELECT ST_AsText(ST_MakePointM(1.2345, 2.3456, 3.4567))" + ) assert point_df.take(1)[0][0] == "POINT M(1.2345 2.3456 3.4567)" def test_st_makepoint(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") - point_df = self.spark.sql("select ST_MakePoint(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark 
from pointtable") + point_df = self.spark.sql( + "select ST_MakePoint(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) assert point_df.count() == 1000 - point_df = self.spark.sql("SELECT ST_AsText(ST_MakePoint(1.2345, 2.3456, 3.4567))") + point_df = self.spark.sql( + "SELECT ST_AsText(ST_MakePoint(1.2345, 2.3456, 3.4567))" + ) assert point_df.take(1)[0][0] == "POINT Z(1.2345 2.3456 3.4567)" - point_df = self.spark.sql("SELECT ST_AsText(ST_MakePoint(1.2345, 2.3456, 3.4567, 4))") + point_df = self.spark.sql( + "SELECT ST_AsText(ST_MakePoint(1.2345, 2.3456, 3.4567, 4))" + ) assert point_df.take(1)[0][0] == "POINT ZM(1.2345 2.3456 3.4567 4)" def test_st_point_from_text(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load(area_lm_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(area_lm_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show(truncate=False) - point_df = self.spark.sql("select ST_PointFromText(concat(_c0,',',_c1),',') as arealandmark from pointtable") + point_df = self.spark.sql( + "select ST_PointFromText(concat(_c0,',',_c1),',') as arealandmark from pointtable" + ) assert point_df.count() == 121960 def test_st_geom_from_wkt(self): - polygon_wkt_df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWkt(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWkt(polygontable._c0) as countyshape from polygontable" + ) polygon_df.show(10) assert polygon_df.count() == 100 def test_st_geom_from_ewkt(self): - input_df = self.spark.createDataFrame([("SRID=4269;LineString(1 2, 3 4)",)], ["ewkt"]) + input_df = self.spark.createDataFrame( + [("SRID=4269;LineString(1 2, 3 4)",)], ["ewkt"] + ) input_df.createOrReplaceTempView("input_ewkt") line_df = self.spark.sql("select ST_GeomFromEWKT(ewkt) as geom from input_ewkt") assert line_df.count() == 1 def test_st_geom_from_wkt_3d(self): - input_df = self.spark.createDataFrame([ - ("Point(21 52 87)",), - ("Polygon((0 0 1, 0 1 1, 1 1 1, 1 0 1, 0 0 1))",), - ("Linestring(0 0 1, 1 1 2, 1 0 3)",), - ("MULTIPOINT ((10 40 66), (40 30 77), (20 20 88), (30 10 99))",), - ("MULTIPOLYGON (((30 20 11, 45 40 11, 10 40 11, 30 20 11)), ((15 5 11, 40 10 11, 10 20 11, 5 10 11, 15 5 11)))",), - ("MULTILINESTRING ((10 10 11, 20 20 11, 10 40 11), (40 40 11, 30 30 11, 40 20 11, 30 10 11))",), - ("MULTIPOLYGON (((40 40 11, 20 45 11, 45 30 11, 40 40 11)), ((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 11), (30 20 11, 20 15 11, 20 25 11, 30 20 11)))",), - ("POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), (1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))",), - ], ["wkt"]) + input_df = self.spark.createDataFrame( + [ + ("Point(21 52 87)",), + ("Polygon((0 0 1, 0 1 1, 1 1 1, 1 0 1, 0 0 1))",), + ("Linestring(0 0 1, 1 1 2, 1 0 3)",), + ("MULTIPOINT ((10 40 66), (40 30 77), (20 20 88), (30 10 99))",), + ( + "MULTIPOLYGON (((30 20 11, 45 40 11, 10 40 11, 30 20 11)), ((15 5 11, 40 10 11, 10 20 11, 5 10 11, 
15 5 11)))", + ), + ( + "MULTILINESTRING ((10 10 11, 20 20 11, 10 40 11), (40 40 11, 30 30 11, 40 20 11, 30 10 11))", + ), + ( + "MULTIPOLYGON (((40 40 11, 20 45 11, 45 30 11, 40 40 11)), ((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 11), (30 20 11, 20 15 11, 20 25 11, 30 20 11)))", + ), + ( + "POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), (1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))", + ), + ], + ["wkt"], + ) input_df.createOrReplaceTempView("input_wkt") - polygon_df = self.spark.sql("select ST_GeomFromWkt(wkt) as geomn from input_wkt") + polygon_df = self.spark.sql( + "select ST_GeomFromWkt(wkt) as geomn from input_wkt" + ) assert polygon_df.count() == 8 def test_st_make_envelope(self): polygonDF = self.spark.sql( - "select ST_MakeEnvelope(double(1.234),double(2.234),double(3.345),double(3.345), 1111) as geom") - assert (polygonDF.count() == 1) - assert (1111 == polygonDF.selectExpr("ST_SRID(geom)").first()[0]) + "select ST_MakeEnvelope(double(1.234),double(2.234),double(3.345),double(3.345), 1111) as geom" + ) + assert polygonDF.count() == 1 + assert 1111 == polygonDF.selectExpr("ST_SRID(geom)").first()[0] polygonDF = self.spark.sql( - "select ST_MakeEnvelope(double(1.234),double(2.234),double(3.345),double(3.345))") - assert (polygonDF.count() == 1) + "select ST_MakeEnvelope(double(1.234),double(2.234),double(3.345),double(3.345))" + ) + assert polygonDF.count() == 1 def test_st_geom_from_text(self): - polygon_wkt_df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromText(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromText(polygontable._c0) as countyshape from polygontable" + ) polygon_df.show(10) assert polygon_df.count() == 100 def test_st_point_from_geohash(self): - actual = self.spark.sql("select ST_AsText(ST_PointFromGeohash('9qqj7nmxncgyy4d0dbxqz0', 4))").take(1)[0][0] + actual = self.spark.sql( + "select ST_AsText(ST_PointFromGeohash('9qqj7nmxncgyy4d0dbxqz0', 4))" + ).take(1)[0][0] expected = "POINT (-115.13671875 36.123046875)" assert actual == expected - actual = self.spark.sql("select ST_AsText(ST_PointFromGeohash('9qqj7nmxncgyy4d0dbxqz0'))").take(1)[0][0] + actual = self.spark.sql( + "select ST_AsText(ST_PointFromGeohash('9qqj7nmxncgyy4d0dbxqz0'))" + ).take(1)[0][0] expected = "POINT (-115.17281600000001 36.11464599999999)" assert actual == expected def test_st_geometry_from_text(self): - polygon_wkt_df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") - polygon_df = self.spark.sql("select ST_GeometryFromText(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeometryFromText(polygontable._c0) as countyshape from polygontable" + ) assert polygon_df.count() == 100 - polygon_df = self.spark.sql("select ST_GeomFromText(polygontable._c0, 4326) as countyshape from polygontable") + polygon_df = self.spark.sql( + 
"select ST_GeomFromText(polygontable._c0, 4326) as countyshape from polygontable" + ) assert polygon_df.count() == 100 def test_st_geom_from_wkb(self): - polygon_wkb_df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(mixed_wkb_geometry_input_location) + polygon_wkb_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkb_geometry_input_location) + ) polygon_wkb_df.createOrReplaceTempView("polygontable") polygon_wkb_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKB(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKB(polygontable._c0) as countyshape from polygontable" + ) polygon_df.show(10) assert polygon_df.count() == 100 def test_st_geom_from_ewkb(self): - polygon_wkb_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkb_geometry_input_location) + polygon_wkb_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkb_geometry_input_location) + ) polygon_wkb_df.createOrReplaceTempView("polygontable") polygon_wkb_df.show() - polygon_df = self.spark.sql("select ST_GeomFromEWKB(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromEWKB(polygontable._c0) as countyshape from polygontable" + ) polygon_df.show(10) assert polygon_df.count() == 100 def test_st_linestring_from_wkb(self): - linestring_ba = self.spark.sql("select unhex('0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf') as wkb") - actual = linestring_ba.selectExpr("ST_AsText(ST_LineStringFromWKB(wkb))").take(1)[0][0] + linestring_ba = self.spark.sql( + "select unhex('0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf') as wkb" + ) + actual = linestring_ba.selectExpr("ST_AsText(ST_LineStringFromWKB(wkb))").take( + 1 + )[0][0] expected = "LINESTRING (-2.1047439575195312 -0.354827880859375, -1.49606454372406 -0.6676061153411865)" assert actual == expected - linestring_s = self.spark.sql("select '0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf' as wkb") - actual = linestring_s.selectExpr("ST_AsText(ST_LinestringFromWKB(wkb))").take(1)[0][0] + linestring_s = self.spark.sql( + "select '0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf' as wkb" + ) + actual = linestring_s.selectExpr("ST_AsText(ST_LinestringFromWKB(wkb))").take( + 1 + )[0][0] assert actual == expected - def test_st_geom_from_geojson(self): - polygon_json_df = self.spark.read.format("csv").\ - option("delimiter", "\t").\ - option("header", "false").\ - load(geojson_input_location) + polygon_json_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(geojson_input_location) + ) polygon_json_df.createOrReplaceTempView("polygontable") polygon_json_df.show() - polygon_df = self.spark.sql("select ST_GeomFromGeoJSON(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromGeoJSON(polygontable._c0) as countyshape from polygontable" + ) polygon_df.show() assert polygon_df.count() == 1000 - def test_line_from_text(self) : + def test_line_from_text(self): input_df = self.spark.createDataFrame([("LineString(1 2, 3 4)",)], ["wkt"]) input_df.createOrReplaceTempView("input_wkt") line_df = self.spark.sql("select 
ST_LineFromText(wkt) as geom from input_wkt") assert line_df.count() == 1 - def test_Mline_from_text(self) : - input_df = self.spark.createDataFrame([("MULTILINESTRING((1 2, 3 4), (4 5, 6 7))",)], ["wkt"]) + def test_Mline_from_text(self): + input_df = self.spark.createDataFrame( + [("MULTILINESTRING((1 2, 3 4), (4 5, 6 7))",)], ["wkt"] + ) input_df.createOrReplaceTempView("input_wkt") line_df = self.spark.sql("select ST_MLineFromText(wkt) as geom from input_wkt") assert line_df.count() == 1 - def test_MPoly_from_text(self) : - input_df = self.spark.createDataFrame([("MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 5 7, 7 7, 7 5, 5 5)))",)], ["wkt"]) + def test_MPoly_from_text(self): + input_df = self.spark.createDataFrame( + [ + ( + "MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 5 7, 7 7, 7 5, 5 5)))", + ) + ], + ["wkt"], + ) input_df.createOrReplaceTempView("input_wkt") line_df = self.spark.sql("select ST_MPolyFromText(wkt) as geom from input_wkt") assert line_df.count() == 1 def test_mpoint_from_text(self): - baseDf = self.spark.sql("SELECT 'MULTIPOINT ((10 10), (20 20), (30 30))' as geom, 4326 as srid") + baseDf = self.spark.sql( + "SELECT 'MULTIPOINT ((10 10), (20 20), (30 30))' as geom, 4326 as srid" + ) actual = baseDf.selectExpr("ST_AsText(ST_MPointFromText(geom))").take(1)[0][0] - expected = 'MULTIPOINT ((10 10), (20 20), (30 30))' + expected = "MULTIPOINT ((10 10), (20 20), (30 30))" assert expected == actual actualGeom = baseDf.selectExpr("ST_MPointFromText(geom, srid) as geom") @@ -240,9 +333,11 @@ def test_mpoint_from_text(self): assert actualSrid == 4326 def test_geom_coll_from_text(self): - baseDf = self.spark.sql("SELECT 'GEOMETRYCOLLECTION (POINT (50 50), LINESTRING (20 30, 40 60, 80 90), POLYGON ((30 10, 40 20, 30 20, 30 10), (35 15, 45 15, 40 25, 35 15)))' as geom, 4326 as srid") + baseDf = self.spark.sql( + "SELECT 'GEOMETRYCOLLECTION (POINT (50 50), LINESTRING (20 30, 40 60, 80 90), POLYGON ((30 10, 40 20, 30 20, 30 10), (35 15, 45 15, 40 25, 35 15)))' as geom, 4326 as srid" + ) actual = baseDf.selectExpr("ST_AsText(ST_GeomCollFromText(geom))").take(1)[0][0] - expected = 'GEOMETRYCOLLECTION (POINT (50 50), LINESTRING (20 30, 40 60, 80 90), POLYGON ((30 10, 40 20, 30 20, 30 10), (35 15, 45 15, 40 25, 35 15)))' + expected = "GEOMETRYCOLLECTION (POINT (50 50), LINESTRING (20 30, 40 60, 80 90), POLYGON ((30 10, 40 20, 30 20, 30 10), (35 15, 45 15, 40 25, 35 15)))" assert expected == actual actualGeom = baseDf.selectExpr("ST_GeomCollFromText(geom, srid) as geom") diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 08fa955f74..5508e62c06 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -38,119 +38,583 @@ test_configurations = [ # constructors - (stc.ST_GeomFromGeoHash, ("geohash", 4), "constructor", "ST_ReducePrecision(geom, 2)", "POLYGON ((0.7 1.05, 1.05 1.05, 1.05 0.88, 0.7 0.88, 0.7 1.05))"), + ( + stc.ST_GeomFromGeoHash, + ("geohash", 4), + "constructor", + "ST_ReducePrecision(geom, 2)", + "POLYGON ((0.7 1.05, 1.05 1.05, 1.05 0.88, 0.7 0.88, 0.7 1.05))", + ), (stc.ST_GeomFromGeoJSON, ("geojson",), "constructor", "", "POINT (0 1)"), - (stc.ST_GeomFromGML, ("gml",), "constructor", "", "LINESTRING (-71.16 42.25, -71.17 42.25, -71.18 42.25)"), - (stc.ST_GeomFromKML, ("kml",), "constructor", "", "LINESTRING (-71.16 42.26, -71.17 42.26)"), + ( + stc.ST_GeomFromGML, + ("gml",), + "constructor", + "", + "LINESTRING (-71.16 42.25, -71.17 42.25, -71.18 42.25)", + ), + ( + 
stc.ST_GeomFromKML, + ("kml",), + "constructor", + "", + "LINESTRING (-71.16 42.26, -71.17 42.26)", + ), (stc.ST_GeomFromText, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), - (stc.ST_GeomFromText, ("wkt",4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), - (stc.ST_GeometryFromText, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), - (stc.ST_GeomFromWKB, ("wkbLine",), "constructor", "ST_ReducePrecision(geom, 2)", "LINESTRING (-2.1 -0.35, -1.5 -0.67)"), - (stc.ST_GeomFromEWKB, ("wkbLine",), "constructor", "ST_ReducePrecision(geom, 2)", "LINESTRING (-2.1 -0.35, -1.5 -0.67)"), + (stc.ST_GeomFromText, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), + ( + stc.ST_GeometryFromText, + ("wkt", 4326), + "linestring_wkt", + "", + "LINESTRING (1 2, 3 4)", + ), + ( + stc.ST_GeomFromWKB, + ("wkbLine",), + "constructor", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (-2.1 -0.35, -1.5 -0.67)", + ), + ( + stc.ST_GeomFromEWKB, + ("wkbLine",), + "constructor", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (-2.1 -0.35, -1.5 -0.67)", + ), (stc.ST_GeomFromWKT, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), - (stc.ST_GeomFromWKT, ("wkt",4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), + (stc.ST_GeomFromWKT, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), (stc.ST_GeomFromEWKT, ("ewkt",), "linestring_ewkt", "", "LINESTRING (1 2, 3 4)"), (stc.ST_LineFromText, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"), - (stc.ST_LineFromWKB, ("wkbLine",), "constructor", "ST_ReducePrecision(geom, 2)", "LINESTRING (-2.1 -0.35, -1.5 -0.67)"), - (stc.ST_LinestringFromWKB, ("wkbLine",), "constructor", "ST_ReducePrecision(geom, 2)", "LINESTRING (-2.1 -0.35, -1.5 -0.67)"), - (stc.ST_LineStringFromText, ("multiple_point", lambda: f.lit(',')), "constructor", "", "LINESTRING (0 0, 1 0, 1 1, 0 0)"), + ( + stc.ST_LineFromWKB, + ("wkbLine",), + "constructor", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (-2.1 -0.35, -1.5 -0.67)", + ), + ( + stc.ST_LinestringFromWKB, + ("wkbLine",), + "constructor", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (-2.1 -0.35, -1.5 -0.67)", + ), + ( + stc.ST_LineStringFromText, + ("multiple_point", lambda: f.lit(",")), + "constructor", + "", + "LINESTRING (0 0, 1 0, 1 1, 0 0)", + ), (stc.ST_Point, ("x", "y"), "constructor", "", "POINT (0 1)"), (stc.ST_PointZ, ("x", "y", "z", 4326), "constructor", "", "POINT Z (0 1 2)"), (stc.ST_PointZ, ("x", "y", "z"), "constructor", "", "POINT Z (0 1 2)"), - (stc.ST_MPolyFromText, ("mpoly",), "constructor", "" , "MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 5 7, 7 7, 7 5, 5 5)))"), - (stc.ST_MPolyFromText, ("mpoly", 4326), "constructor", "" , "MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 5 7, 7 7, 7 5, 5 5)))"), - (stc.ST_MLineFromText, ("mline", ), "constructor", "" , "MULTILINESTRING ((1 2, 3 4), (4 5, 6 7))"), - (stc.ST_MLineFromText, ("mline", 4326), "constructor", "" , "MULTILINESTRING ((1 2, 3 4), (4 5, 6 7))"), - (stc.ST_MPointFromText, ("mpoint", ), "constructor", "" , "MULTIPOINT (10 10, 20 20, 30 30)"), - (stc.ST_MPointFromText, ("mpoint", 4326), "constructor", "" , "MULTIPOINT (10 10, 20 20, 30 30)"), - (stc.ST_PointFromText, ("single_point", lambda: f.lit(',')), "constructor", "", "POINT (0 1)"), + ( + stc.ST_MPolyFromText, + ("mpoly",), + "constructor", + "", + "MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 5 7, 7 7, 7 5, 5 5)))", + ), + ( + stc.ST_MPolyFromText, + ("mpoly", 4326), + "constructor", + "", + "MULTIPOLYGON (((0 0, 20 0, 20 20, 0 20, 0 0), (5 
5, 5 7, 7 7, 7 5, 5 5)))", + ), + ( + stc.ST_MLineFromText, + ("mline",), + "constructor", + "", + "MULTILINESTRING ((1 2, 3 4), (4 5, 6 7))", + ), + ( + stc.ST_MLineFromText, + ("mline", 4326), + "constructor", + "", + "MULTILINESTRING ((1 2, 3 4), (4 5, 6 7))", + ), + ( + stc.ST_MPointFromText, + ("mpoint",), + "constructor", + "", + "MULTIPOINT (10 10, 20 20, 30 30)", + ), + ( + stc.ST_MPointFromText, + ("mpoint", 4326), + "constructor", + "", + "MULTIPOINT (10 10, 20 20, 30 30)", + ), + ( + stc.ST_PointFromText, + ("single_point", lambda: f.lit(",")), + "constructor", + "", + "POINT (0 1)", + ), (stc.ST_PointFromWKB, ("wkbPoint",), "constructor", "", "POINT (10 15)"), (stc.ST_MakePoint, ("x", "y", "z"), "constructor", "", "POINT Z (0 1 2)"), - (stc.ST_MakePointM, ("x", "y", "z"), "constructor", "ST_AsText(geom)", "POINT M(0 1 2)"), - (stc.ST_MakeEnvelope, ("minx", "miny", "maxx", "maxy"), "min_max_x_y", "", "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))"), - (stc.ST_MakeEnvelope, ("minx", "miny", "maxx", "maxy", lambda: f.lit(1111)), "min_max_x_y", "", "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))"), - (stc.ST_MakeEnvelope, ("minx", "miny", "maxx", "maxy", lambda: f.lit(1111)), "min_max_x_y", "ST_SRID(geom)", 1111), - (stc.ST_PolygonFromEnvelope, ("minx", "miny", "maxx", "maxy"), "min_max_x_y", "", "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))"), - (stc.ST_PolygonFromEnvelope, (0.0, 1.0, 2.0, 3.0), "null", "", "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))"), - (stc.ST_PolygonFromText, ("multiple_point", lambda: f.lit(',')), "constructor", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), - (stc.ST_GeomCollFromText, ("collection",), "constructor", "", "GEOMETRYCOLLECTION (POINT (1 1), LINESTRING (0 0, 1 1))"), - (stc.ST_GeomCollFromText, ("collection", 4326), "constructor", "ST_SRID(geom)", 4326), - + ( + stc.ST_MakePointM, + ("x", "y", "z"), + "constructor", + "ST_AsText(geom)", + "POINT M(0 1 2)", + ), + ( + stc.ST_MakeEnvelope, + ("minx", "miny", "maxx", "maxy"), + "min_max_x_y", + "", + "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))", + ), + ( + stc.ST_MakeEnvelope, + ("minx", "miny", "maxx", "maxy", lambda: f.lit(1111)), + "min_max_x_y", + "", + "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))", + ), + ( + stc.ST_MakeEnvelope, + ("minx", "miny", "maxx", "maxy", lambda: f.lit(1111)), + "min_max_x_y", + "ST_SRID(geom)", + 1111, + ), + ( + stc.ST_PolygonFromEnvelope, + ("minx", "miny", "maxx", "maxy"), + "min_max_x_y", + "", + "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))", + ), + ( + stc.ST_PolygonFromEnvelope, + (0.0, 1.0, 2.0, 3.0), + "null", + "", + "POLYGON ((0 1, 0 3, 2 3, 2 1, 0 1))", + ), + ( + stc.ST_PolygonFromText, + ("multiple_point", lambda: f.lit(",")), + "constructor", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), + ( + stc.ST_GeomCollFromText, + ("collection",), + "constructor", + "", + "GEOMETRYCOLLECTION (POINT (1 1), LINESTRING (0 0, 1 1))", + ), + ( + stc.ST_GeomCollFromText, + ("collection", 4326), + "constructor", + "ST_SRID(geom)", + 4326, + ), # functions (stf.GeometryType, ("line",), "linestring_geom", "", "LINESTRING"), (stf.ST_3DDistance, ("a", "b"), "two_points", "", 5.0), - (stf.ST_Affine, ("geom", 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0), "square_geom", "", "POLYGON ((2 3, 4 5, 5 6, 3 4, 2 3))"), - (stf.ST_Affine, ("geom", 1.0, 2.0, 1.0, 2.0, 1.0, 2.0,), "square_geom", "", "POLYGON ((2 3, 4 5, 5 6, 3 4, 2 3))"), - (stf.ST_AddMeasure, ("line", 10.0, 40.0), "linestring_geom", "ST_AsText(geom)", "LINESTRING M(0 0 10, 1 0 16, 2 0 22, 3 0 28, 4 0 34, 5 0 40)"), - (stf.ST_AddPoint, ("line", lambda: 
f.expr("ST_Point(1.0, 1.0)")), "linestring_geom", "", "LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0, 1 1)"), - (stf.ST_AddPoint, ("line", lambda: f.expr("ST_Point(1.0, 1.0)"), 1), "linestring_geom", "", "LINESTRING (0 0, 1 1, 1 0, 2 0, 3 0, 4 0, 5 0)"), - (stf.ST_Angle, ("p1", "p2", "p3", "p4", ), "four_points", "", 0.4048917862850834), - (stf.ST_Angle, ("p1", "p2", "p3",), "three_points", "", 0.19739555984988078), + ( + stf.ST_Affine, + ("geom", 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + "square_geom", + "", + "POLYGON ((2 3, 4 5, 5 6, 3 4, 2 3))", + ), + ( + stf.ST_Affine, + ( + "geom", + 1.0, + 2.0, + 1.0, + 2.0, + 1.0, + 2.0, + ), + "square_geom", + "", + "POLYGON ((2 3, 4 5, 5 6, 3 4, 2 3))", + ), + ( + stf.ST_AddMeasure, + ("line", 10.0, 40.0), + "linestring_geom", + "ST_AsText(geom)", + "LINESTRING M(0 0 10, 1 0 16, 2 0 22, 3 0 28, 4 0 34, 5 0 40)", + ), + ( + stf.ST_AddPoint, + ("line", lambda: f.expr("ST_Point(1.0, 1.0)")), + "linestring_geom", + "", + "LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0, 1 1)", + ), + ( + stf.ST_AddPoint, + ("line", lambda: f.expr("ST_Point(1.0, 1.0)"), 1), + "linestring_geom", + "", + "LINESTRING (0 0, 1 1, 1 0, 2 0, 3 0, 4 0, 5 0)", + ), + ( + stf.ST_Angle, + ( + "p1", + "p2", + "p3", + "p4", + ), + "four_points", + "", + 0.4048917862850834, + ), + ( + stf.ST_Angle, + ( + "p1", + "p2", + "p3", + ), + "three_points", + "", + 0.19739555984988078, + ), (stf.ST_Angle, ("line1", "line2"), "two_lines", "", 0.19739555984988078), (stf.ST_Degrees, ("angleRad",), "two_lines_angle_rad", "", 11.309932474020213), (stf.ST_Area, ("geom",), "triangle_geom", "", 0.5), (stf.ST_AreaSpheroid, ("point",), "point_geom", "", 0.0), - (stf.ST_AsBinary, ("point",), "point_geom", "", "01010000000000000000000000000000000000f03f"), - (stf.ST_AsEWKB, (lambda: f.expr("ST_SetSRID(point, 3021)"),), "point_geom", "", "0101000020cd0b00000000000000000000000000000000f03f"), - (stf.ST_AsHEXEWKB, ("point",), "point_geom", "", "01010000000000000000000000000000000000F03F"), - (stf.ST_AsEWKT, (lambda: f.expr("ST_SetSRID(point, 4326)"),), "point_geom", "", "SRID=4326;POINT (0 1)"), - (stf.ST_AsGeoJSON, ("point",), "point_geom", "", "{\"type\":\"Point\",\"coordinates\":[0.0,1.0]}"), - (stf.ST_AsGeoJSON, ("point", lambda: f.lit("feature")), "point_geom", "", "{\"type\":\"Feature\",\"geometry\":{\"type\":\"Point\",\"coordinates\":[0.0,1.0]},\"properties\":{}}"), - (stf.ST_AsGeoJSON, ("point", lambda: f.lit("featurecollection")), "point_geom", "", "{\"type\":\"FeatureCollection\",\"features\":[{\"type\":\"Feature\",\"geometry\":{\"type\":\"Point\",\"coordinates\":[0.0,1.0]},\"properties\":{}}]}"), - (stf.ST_AsGML, ("point",), "point_geom", "", "\n \n 0.0,1.0 \n \n\n"), - (stf.ST_AsKML, ("point",), "point_geom", "", "\n 0.0,1.0\n\n"), + ( + stf.ST_AsBinary, + ("point",), + "point_geom", + "", + "01010000000000000000000000000000000000f03f", + ), + ( + stf.ST_AsEWKB, + (lambda: f.expr("ST_SetSRID(point, 3021)"),), + "point_geom", + "", + "0101000020cd0b00000000000000000000000000000000f03f", + ), + ( + stf.ST_AsHEXEWKB, + ("point",), + "point_geom", + "", + "01010000000000000000000000000000000000F03F", + ), + ( + stf.ST_AsEWKT, + (lambda: f.expr("ST_SetSRID(point, 4326)"),), + "point_geom", + "", + "SRID=4326;POINT (0 1)", + ), + ( + stf.ST_AsGeoJSON, + ("point",), + "point_geom", + "", + '{"type":"Point","coordinates":[0.0,1.0]}', + ), + ( + stf.ST_AsGeoJSON, + ("point", lambda: f.lit("feature")), + "point_geom", + "", + 
'{"type":"Feature","geometry":{"type":"Point","coordinates":[0.0,1.0]},"properties":{}}', + ), + ( + stf.ST_AsGeoJSON, + ("point", lambda: f.lit("featurecollection")), + "point_geom", + "", + '{"type":"FeatureCollection","features":[{"type":"Feature","geometry":{"type":"Point","coordinates":[0.0,1.0]},"properties":{}}]}', + ), + ( + stf.ST_AsGML, + ("point",), + "point_geom", + "", + "\n \n 0.0,1.0 \n \n\n", + ), + ( + stf.ST_AsKML, + ("point",), + "point_geom", + "", + "\n 0.0,1.0\n\n", + ), (stf.ST_AsText, ("point",), "point_geom", "", "POINT (0 1)"), (stf.ST_Azimuth, ("a", "b"), "two_points", "geom * 180.0 / pi()", 90.0), (stf.ST_BestSRID, ("geom",), "triangle_geom", "", 3395), - (stf.ST_Boundary, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1 0, 1 1, 0 0)"), - (stf.ST_Buffer, ("point", 1.0), "point_geom", "ST_ReducePrecision(geom, 2)", "POLYGON ((0.98 0.8, 0.92 0.62, 0.83 0.44, 0.71 0.29, 0.56 0.17, 0.38 0.08, 0.2 0.02, 0 0, -0.2 0.02, -0.38 0.08, -0.56 0.17, -0.71 0.29, -0.83 0.44, -0.92 0.62, -0.98 0.8, -1 1, -0.98 1.2, -0.92 1.38, -0.83 1.56, -0.71 1.71, -0.56 1.83, -0.38 1.92, -0.2 1.98, 0 2, 0.2 1.98, 0.38 1.92, 0.56 1.83, 0.71 1.71, 0.83 1.56, 0.92 1.38, 0.98 1.2, 1 1, 0.98 0.8))"), - (stf.ST_Buffer, ("point", 1.0, True), "point_geom", "", "POLYGON ((0.0000089758113634 1.0000000082631704, 0.0000088049473096 0.9999982455016537, 0.0000082957180969 0.9999965501645043, 0.0000074676931201 0.9999949874025787, 0.0000063526929175 0.9999936172719434, 0.0000049935663218 0.9999924924259491, 0.000003442543808 0.9999916560917964, 0.0000017592303026 0.9999911404093366, 0.0000000083146005 0.999990965195956, -0.0000017429165916 0.9999911371850047, -0.0000034271644398 0.9999916497670388, -0.0000049797042467 0.9999924832438165, -0.0000063408727792 0.9999936055852924, -0.0000074583610959 0.9999949736605123, -0.0000082892247492 0.9999965348951139, -0.0000088015341156 0.999998229291728, -0.0000089756014341 0.9999999917356441, -0.0000088047373942 1.0000017544971334, -0.0000082955082053 1.000003449834261, -0.0000074674832591 1.000005012596174, -0.000006352483086 1.0000063827268086, -0.0000049933565169 1.000007507572813, -0.0000034423340222 1.0000083439069856, -0.0000017590205255 1.0000088595894723, -0.0000000081048183 1.0000090348028825, 0.0000017431263877 1.0000088628138615, 0.0000034273742602 1.0000083502318489, 0.000004979914097 1.0000075167550835, 0.0000063410826597 1.000006394413609, 0.000007458571003 1.0000050263383786, 0.000008289434675 1.000003465103757, 0.0000088017440489 1.0000017707071163, 0.0000089758113634 1.0000000082631704))"), - (stf.ST_BuildArea, ("geom",), "multiline_geom", "ST_Normalize(geom)", "POLYGON ((0 0, 1 1, 1 0, 0 0))"), - (stf.ST_BoundingDiagonal, ("geom",), "square_geom", "ST_BoundingDiagonal(geom)", "LINESTRING (1 0, 2 1)"), - (stf.ST_Centroid, ("geom",), "triangle_geom", "ST_ReducePrecision(geom, 2)", "POINT (0.67 0.33)"), - (stf.ST_Collect, (lambda: f.expr("array(a, b)"),), "two_points", "", "MULTIPOINT Z (0 0 0, 3 0 4)"), + ( + stf.ST_Boundary, + ("geom",), + "triangle_geom", + "", + "LINESTRING (0 0, 1 0, 1 1, 0 0)", + ), + ( + stf.ST_Buffer, + ("point", 1.0), + "point_geom", + "ST_ReducePrecision(geom, 2)", + "POLYGON ((0.98 0.8, 0.92 0.62, 0.83 0.44, 0.71 0.29, 0.56 0.17, 0.38 0.08, 0.2 0.02, 0 0, -0.2 0.02, -0.38 0.08, -0.56 0.17, -0.71 0.29, -0.83 0.44, -0.92 0.62, -0.98 0.8, -1 1, -0.98 1.2, -0.92 1.38, -0.83 1.56, -0.71 1.71, -0.56 1.83, -0.38 1.92, -0.2 1.98, 0 2, 0.2 1.98, 0.38 1.92, 0.56 1.83, 0.71 1.71, 0.83 1.56, 0.92 1.38, 0.98 1.2, 1 1, 0.98 0.8))", + 
), + ( + stf.ST_Buffer, + ("point", 1.0, True), + "point_geom", + "", + "POLYGON ((0.0000089758113634 1.0000000082631704, 0.0000088049473096 0.9999982455016537, 0.0000082957180969 0.9999965501645043, 0.0000074676931201 0.9999949874025787, 0.0000063526929175 0.9999936172719434, 0.0000049935663218 0.9999924924259491, 0.000003442543808 0.9999916560917964, 0.0000017592303026 0.9999911404093366, 0.0000000083146005 0.999990965195956, -0.0000017429165916 0.9999911371850047, -0.0000034271644398 0.9999916497670388, -0.0000049797042467 0.9999924832438165, -0.0000063408727792 0.9999936055852924, -0.0000074583610959 0.9999949736605123, -0.0000082892247492 0.9999965348951139, -0.0000088015341156 0.999998229291728, -0.0000089756014341 0.9999999917356441, -0.0000088047373942 1.0000017544971334, -0.0000082955082053 1.000003449834261, -0.0000074674832591 1.000005012596174, -0.000006352483086 1.0000063827268086, -0.0000049933565169 1.000007507572813, -0.0000034423340222 1.0000083439069856, -0.0000017590205255 1.0000088595894723, -0.0000000081048183 1.0000090348028825, 0.0000017431263877 1.0000088628138615, 0.0000034273742602 1.0000083502318489, 0.000004979914097 1.0000075167550835, 0.0000063410826597 1.000006394413609, 0.000007458571003 1.0000050263383786, 0.000008289434675 1.000003465103757, 0.0000088017440489 1.0000017707071163, 0.0000089758113634 1.0000000082631704))", + ), + ( + stf.ST_BuildArea, + ("geom",), + "multiline_geom", + "ST_Normalize(geom)", + "POLYGON ((0 0, 1 1, 1 0, 0 0))", + ), + ( + stf.ST_BoundingDiagonal, + ("geom",), + "square_geom", + "ST_BoundingDiagonal(geom)", + "LINESTRING (1 0, 2 1)", + ), + ( + stf.ST_Centroid, + ("geom",), + "triangle_geom", + "ST_ReducePrecision(geom, 2)", + "POINT (0.67 0.33)", + ), + ( + stf.ST_Collect, + (lambda: f.expr("array(a, b)"),), + "two_points", + "", + "MULTIPOINT Z (0 0 0, 3 0 4)", + ), (stf.ST_Collect, ("a", "b"), "two_points", "", "MULTIPOINT Z (0 0 0, 3 0 4)"), - (stf.ST_ClosestPoint, ("point", "line",), "point_and_line", "", "POINT (0 1)"), - (stf.ST_CollectionExtract, ("geom",), "geom_collection", "", "MULTILINESTRING ((0 0, 1 0))"), + ( + stf.ST_ClosestPoint, + ( + "point", + "line", + ), + "point_and_line", + "", + "POINT (0 1)", + ), + ( + stf.ST_CollectionExtract, + ("geom",), + "geom_collection", + "", + "MULTILINESTRING ((0 0, 1 0))", + ), (stf.ST_CollectionExtract, ("geom", 1), "geom_collection", "", "MULTIPOINT (0 0)"), - (stf.ST_ConcaveHull, ("geom", 1.0), "triangle_geom", "", "POLYGON ((0 0, 1 1, 1 0, 0 0))"), - (stf.ST_ConcaveHull, ("geom", 1.0, True), "triangle_geom", "", "POLYGON ((1 1, 1 0, 0 0, 1 1))"), - (stf.ST_ConvexHull, ("geom",), "triangle_geom", "", "POLYGON ((0 0, 1 1, 1 0, 0 0))"), + ( + stf.ST_ConcaveHull, + ("geom", 1.0), + "triangle_geom", + "", + "POLYGON ((0 0, 1 1, 1 0, 0 0))", + ), + ( + stf.ST_ConcaveHull, + ("geom", 1.0, True), + "triangle_geom", + "", + "POLYGON ((1 1, 1 0, 0 0, 1 1))", + ), + ( + stf.ST_ConvexHull, + ("geom",), + "triangle_geom", + "", + "POLYGON ((0 0, 1 1, 1 0, 0 0))", + ), (stf.ST_CoordDim, ("point",), "point_geom", "", 2), (stf.ST_CrossesDateLine, ("line",), "line_crossing_dateline", "", True), - (stf.ST_Difference, ("a", "b"), "overlapping_polys", "", "POLYGON ((1 0, 0 0, 0 1, 1 1, 1 0))"), + ( + stf.ST_Difference, + ("a", "b"), + "overlapping_polys", + "", + "POLYGON ((1 0, 0 0, 0 1, 1 1, 1 0))", + ), (stf.ST_Dimension, ("geom",), "geometry_geom_collection", "", 1), (stf.ST_Distance, ("a", "b"), "two_points", "", 3.0), (stf.ST_DistanceSpheroid, ("point", "point"), "point_geom", "", 
0.0), (stf.ST_DistanceSphere, ("point", "point"), "point_geom", "", 0.0), (stf.ST_DistanceSphere, ("point", "point", 6378137.0), "point_geom", "", 0.0), - (stf.ST_DelaunayTriangles, ("multipoint", ), "multipoint_geom", "", "GEOMETRYCOLLECTION (POLYGON ((10 40, 20 20, 40 30, 10 40)), POLYGON ((40 30, 20 20, 30 10, 40 30)))"), + ( + stf.ST_DelaunayTriangles, + ("multipoint",), + "multipoint_geom", + "", + "GEOMETRYCOLLECTION (POLYGON ((10 40, 20 20, 40 30, 10 40)), POLYGON ((40 30, 20 20, 30 10, 40 30)))", + ), (stf.ST_Dump, ("geom",), "multipoint", "", ["POINT (0 0)", "POINT (1 1)"]), - (stf.ST_DumpPoints, ("line",), "linestring_geom", "", ["POINT (0 0)", "POINT (1 0)", "POINT (2 0)", "POINT (3 0)", "POINT (4 0)", "POINT (5 0)"]), + ( + stf.ST_DumpPoints, + ("line",), + "linestring_geom", + "", + [ + "POINT (0 0)", + "POINT (1 0)", + "POINT (2 0)", + "POINT (3 0)", + "POINT (4 0)", + "POINT (5 0)", + ], + ), (stf.ST_EndPoint, ("line",), "linestring_geom", "", "POINT (5 0)"), - (stf.ST_Envelope, ("geom",), "triangle_geom", "", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"), - (stf.ST_Expand, ("geom", 2.0), "triangle_geom", "", "POLYGON ((-2 -2, -2 3, 3 3, 3 -2, -2 -2))"), - (stf.ST_Expand, ("geom", 2.0, 2.0), "triangle_geom", "", "POLYGON ((-2 -2, -2 3, 3 3, 3 -2, -2 -2))"), - (stf.ST_ExteriorRing, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1 0, 1 1, 0 0)"), + ( + stf.ST_Envelope, + ("geom",), + "triangle_geom", + "", + "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", + ), + ( + stf.ST_Expand, + ("geom", 2.0), + "triangle_geom", + "", + "POLYGON ((-2 -2, -2 3, 3 3, 3 -2, -2 -2))", + ), + ( + stf.ST_Expand, + ("geom", 2.0, 2.0), + "triangle_geom", + "", + "POLYGON ((-2 -2, -2 3, 3 3, 3 -2, -2 -2))", + ), + ( + stf.ST_ExteriorRing, + ("geom",), + "triangle_geom", + "", + "LINESTRING (0 0, 1 0, 1 1, 0 0)", + ), (stf.ST_FlipCoordinates, ("point",), "point_geom", "", "POINT (1 0)"), (stf.ST_Force_2D, ("point",), "point_geom", "", "POINT (0 1)"), (stf.ST_Force3D, ("point", 1.0), "point_geom", "", "POINT Z (0 1 1)"), - (stf.ST_Force3DM, ("point", 1.0), "point_geom", "ST_AsText(geom)", "POINT M(0 1 1)"), + ( + stf.ST_Force3DM, + ("point", 1.0), + "point_geom", + "ST_AsText(geom)", + "POINT M(0 1 1)", + ), (stf.ST_Force3DZ, ("point", 1.0), "point_geom", "", "POINT Z (0 1 1)"), - (stf.ST_Force4D, ("point", 1.0, 1.0), "point_geom", "ST_AsText(geom)", "POINT ZM(0 1 1 1)"), - (stf.ST_ForceCollection, ("multipoint",), "multipoint_geom", "ST_NumGeometries(geom)", 4), - (stf.ST_ForcePolygonCW, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 3 3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))"), - (stf.ST_ForcePolygonCCW, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 3 0, 3 3, 0 0), (1 1, 2 2, 2 1, 1 1))"), - (stf.ST_ForceRHR, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 3 3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))"), - (stf.ST_FrechetDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), - (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT (22.500002656424286 21.250001168173426)"), + ( + stf.ST_Force4D, + ("point", 1.0, 1.0), + "point_geom", + "ST_AsText(geom)", + "POINT ZM(0 1 1 1)", + ), + ( + stf.ST_ForceCollection, + ("multipoint",), + "multipoint_geom", + "ST_NumGeometries(geom)", + 4, + ), + ( + stf.ST_ForcePolygonCW, + ("geom",), + "geom_with_hole", + "", + "POLYGON ((0 0, 3 3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))", + ), + ( + stf.ST_ForcePolygonCCW, + ("geom",), + "geom_with_hole", + "", + "POLYGON ((0 0, 3 0, 3 3, 0 0), (1 1, 2 2, 2 1, 1 1))", + ), + ( + stf.ST_ForceRHR, + ("geom",), + 
"geom_with_hole", + "", + "POLYGON ((0 0, 3 3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))", + ), + ( + stf.ST_FrechetDistance, + ( + "point", + "line", + ), + "point_and_line", + "", + 5.0990195135927845, + ), + ( + stf.ST_GeometricMedian, + ("multipoint",), + "multipoint_geom", + "", + "POINT (22.500002656424286 21.250001168173426)", + ), (stf.ST_GeneratePoints, ("geom", 15), "square_geom", "ST_NumGeometries(geom)", 15), - (stf.ST_GeneratePoints, ("geom", 15, 100), "square_geom", "ST_NumGeometries(geom)", 15), + ( + stf.ST_GeneratePoints, + ("geom", 15, 100), + "square_geom", + "ST_NumGeometries(geom)", + 15, + ), (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"), (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"), - (stf.ST_HausdorffDistance, ("point", "line",), "point_and_line", "", 5.0990195135927845), - (stf.ST_InteriorRingN, ("geom", 0), "geom_with_hole", "", "LINESTRING (1 1, 2 2, 2 1, 1 1)"), - (stf.ST_Intersection, ("a", "b"), "overlapping_polys", "", "POLYGON ((2 0, 1 0, 1 1, 2 1, 2 0))"), + ( + stf.ST_HausdorffDistance, + ( + "point", + "line", + ), + "point_and_line", + "", + 5.0990195135927845, + ), + ( + stf.ST_InteriorRingN, + ("geom", 0), + "geom_with_hole", + "", + "LINESTRING (1 1, 2 2, 2 1, 1 1)", + ), + ( + stf.ST_Intersection, + ("a", "b"), + "overlapping_polys", + "", + "POLYGON ((2 0, 1 0, 1 1, 2 1, 2 0))", + ), (stf.ST_IsCollection, ("geom",), "geom_collection", "", True), (stf.ST_IsClosed, ("geom",), "closed_linestring_geom", "", True), (stf.ST_IsEmpty, ("geom",), "empty_geom", "", True), @@ -162,41 +626,179 @@ (stf.ST_IsValid, ("geom",), "triangle_geom", "", True), (stf.ST_IsValid, ("geom", 1), "triangle_geom", "", True), (stf.ST_IsValid, ("geom", 0), "triangle_geom", "", True), - (stf.ST_IsValidDetail, ("geom",), "triangle_geom", "", Row(valid=True, reason=None, location=None).asDict()), - (stf.ST_IsValidDetail, ("geom", 1), "triangle_geom", "", Row(valid=True, reason=None, location=None).asDict()), + ( + stf.ST_IsValidDetail, + ("geom",), + "triangle_geom", + "", + Row(valid=True, reason=None, location=None).asDict(), + ), + ( + stf.ST_IsValidDetail, + ("geom", 1), + "triangle_geom", + "", + Row(valid=True, reason=None, location=None).asDict(), + ), (stf.ST_Length, ("line",), "linestring_geom", "", 5.0), (stf.ST_Length2D, ("line",), "linestring_geom", "", 5.0), (stf.ST_LengthSpheroid, ("point",), "point_geom", "", 0.0), - (stf.ST_LineFromMultiPoint, ("multipoint",), "multipoint_geom", "", "LINESTRING (10 40, 40 30, 20 20, 30 10)"), - (stf.ST_LineInterpolatePoint, ("line", 0.5), "linestring_geom", "", "POINT (2.5 0)"), + ( + stf.ST_LineFromMultiPoint, + ("multipoint",), + "multipoint_geom", + "", + "LINESTRING (10 40, 40 30, 20 20, 30 10)", + ), + ( + stf.ST_LineInterpolatePoint, + ("line", 0.5), + "linestring_geom", + "", + "POINT (2.5 0)", + ), (stf.ST_LineLocatePoint, ("line", "point"), "line_and_point", "", 0.5), - (stf.ST_LineMerge, ("geom",), "multiline_geom", "", "LINESTRING (0 0, 1 0, 1 1, 0 0)"), - (stf.ST_LineSubstring, ("line", 0.5, 1.0), "linestring_geom", "", "LINESTRING (2.5 0, 3 0, 4 0, 5 0)"), - (stf.ST_LongestLine, ("geom", "geom"), "geom_collection", "", "LINESTRING (0 0, 1 0)"), - (stf.ST_LocateAlong, ("line", 1.0), "4D_line", "ST_AsText(geom)", "MULTIPOINT ZM((1 1 1 1))"), - (stf.ST_LocateAlong, ("line", 1.0, 2.0), "4D_line", "ST_AsText(geom)", "MULTIPOINT ZM((-0.4142135623730949 2.414213562373095 1 1), (2.414213562373095 -0.4142135623730949 1 1))"), + ( + stf.ST_LineMerge, + ("geom",), + "multiline_geom", + "", + "LINESTRING 
(0 0, 1 0, 1 1, 0 0)", + ), + ( + stf.ST_LineSubstring, + ("line", 0.5, 1.0), + "linestring_geom", + "", + "LINESTRING (2.5 0, 3 0, 4 0, 5 0)", + ), + ( + stf.ST_LongestLine, + ("geom", "geom"), + "geom_collection", + "", + "LINESTRING (0 0, 1 0)", + ), + ( + stf.ST_LocateAlong, + ("line", 1.0), + "4D_line", + "ST_AsText(geom)", + "MULTIPOINT ZM((1 1 1 1))", + ), + ( + stf.ST_LocateAlong, + ("line", 1.0, 2.0), + "4D_line", + "ST_AsText(geom)", + "MULTIPOINT ZM((-0.4142135623730949 2.414213562373095 1 1), (2.414213562373095 -0.4142135623730949 1 1))", + ), (stf.ST_HasZ, ("a",), "two_points", "", True), (stf.ST_HasM, ("point",), "4D_point", "", True), (stf.ST_M, ("point",), "4D_point", "", 4.0), (stf.ST_MMin, ("line",), "4D_line", "", -1.0), (stf.ST_MMax, ("line",), "4D_line", "", 3.0), - (stf.ST_MakeValid, ("geom",), "invalid_geom", "", "MULTIPOLYGON (((1 5, 3 3, 1 1, 1 5)), ((5 3, 7 5, 7 1, 5 3)))"), - (stf.ST_MakeLine, ("line1", "line2"), "two_lines", "", "LINESTRING (0 0, 1 1, 0 0, 3 2)"), - (stf.ST_MaximumInscribedCircle, ("geom",), "triangle_geom", "ST_AsText(geom.center)", "POINT (0.70703125 0.29296875)"), - (stf.ST_MaximumInscribedCircle, ("geom",), "triangle_geom", "ST_AsText(geom.nearest)", "POINT (0.5 0.5)"), - (stf.ST_MaximumInscribedCircle, ("geom",), "triangle_geom", "geom.radius", 0.2927864015850548), + ( + stf.ST_MakeValid, + ("geom",), + "invalid_geom", + "", + "MULTIPOLYGON (((1 5, 3 3, 1 1, 1 5)), ((5 3, 7 5, 7 1, 5 3)))", + ), + ( + stf.ST_MakeLine, + ("line1", "line2"), + "two_lines", + "", + "LINESTRING (0 0, 1 1, 0 0, 3 2)", + ), + ( + stf.ST_MaximumInscribedCircle, + ("geom",), + "triangle_geom", + "ST_AsText(geom.center)", + "POINT (0.70703125 0.29296875)", + ), + ( + stf.ST_MaximumInscribedCircle, + ("geom",), + "triangle_geom", + "ST_AsText(geom.nearest)", + "POINT (0.5 0.5)", + ), + ( + stf.ST_MaximumInscribedCircle, + ("geom",), + "triangle_geom", + "geom.radius", + 0.2927864015850548, + ), (stf.ST_MaxDistance, ("a", "b"), "overlapping_polys", "", 3.1622776601683795), - (stf.ST_Points, ("line",), "linestring_geom", "ST_Normalize(geom)", "MULTIPOINT (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)"), - (stf.ST_Polygon, ("geom", 4236), "closed_linestring_geom", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), - (stf.ST_Polygonize, ("geom",), "noded_linework", "ST_Normalize(geom)", "GEOMETRYCOLLECTION (POLYGON ((0 2, 1 3, 2 4, 2 3, 2 2, 1 2, 0 2)), POLYGON ((2 2, 2 3, 2 4, 3 3, 4 2, 3 2, 2 2)))"), - (stf.ST_MakePolygon, ("geom",), "closed_linestring_geom", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), + ( + stf.ST_Points, + ("line",), + "linestring_geom", + "ST_Normalize(geom)", + "MULTIPOINT (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)", + ), + ( + stf.ST_Polygon, + ("geom", 4236), + "closed_linestring_geom", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), + ( + stf.ST_Polygonize, + ("geom",), + "noded_linework", + "ST_Normalize(geom)", + "GEOMETRYCOLLECTION (POLYGON ((0 2, 1 3, 2 4, 2 3, 2 2, 1 2, 0 2)), POLYGON ((2 2, 2 3, 2 4, 3 3, 4 2, 3 2, 2 2)))", + ), + ( + stf.ST_MakePolygon, + ("geom",), + "closed_linestring_geom", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), (stf.ST_MinimumClearance, ("geom",), "invalid_geom", "", 2.0), - (stf.ST_MinimumClearanceLine, ("geom",), "invalid_geom", "", "LINESTRING (5 3, 3 3)"), - (stf.ST_MinimumBoundingCircle, ("line", 8), "linestring_geom", "ST_ReducePrecision(geom, 2)", "POLYGON ((4.95 -0.49, 4.81 -0.96, 4.58 -1.39, 4.27 -1.77, 3.89 -2.08, 3.46 -2.31, 2.99 -2.45, 2.5 -2.5, 2.01 -2.45, 1.54 -2.31, 1.11 -2.08, 0.73 -1.77, 0.42 -1.39, 0.19 -0.96, 0.05 -0.49, 0 0, 0.05 0.49, 
0.19 0.96, 0.42 1.39, 0.73 1.77, 1.11 2.08, 1.54 2.31, 2.01 2.45, 2.5 2.5, 2.99 2.45, 3.46 2.31, 3.89 2.08, 4.27 1.77, 4.58 1.39, 4.81 0.96, 4.95 0.49, 5 0, 4.95 -0.49))"), - (stf.ST_MinimumBoundingCircle, ("line", 2), "linestring_geom", "ST_ReducePrecision(geom, 2)", "POLYGON ((4.27 -1.77, 2.5 -2.5, 0.73 -1.77, 0 0, 0.73 1.77, 2.5 2.5, 4.27 1.77, 5 0, 4.27 -1.77))"), - (stf.ST_MinimumBoundingRadius, ("line",), "linestring_geom", "", {"center": "POINT (2.5 0)", "radius": 2.5}), + ( + stf.ST_MinimumClearanceLine, + ("geom",), + "invalid_geom", + "", + "LINESTRING (5 3, 3 3)", + ), + ( + stf.ST_MinimumBoundingCircle, + ("line", 8), + "linestring_geom", + "ST_ReducePrecision(geom, 2)", + "POLYGON ((4.95 -0.49, 4.81 -0.96, 4.58 -1.39, 4.27 -1.77, 3.89 -2.08, 3.46 -2.31, 2.99 -2.45, 2.5 -2.5, 2.01 -2.45, 1.54 -2.31, 1.11 -2.08, 0.73 -1.77, 0.42 -1.39, 0.19 -0.96, 0.05 -0.49, 0 0, 0.05 0.49, 0.19 0.96, 0.42 1.39, 0.73 1.77, 1.11 2.08, 1.54 2.31, 2.01 2.45, 2.5 2.5, 2.99 2.45, 3.46 2.31, 3.89 2.08, 4.27 1.77, 4.58 1.39, 4.81 0.96, 4.95 0.49, 5 0, 4.95 -0.49))", + ), + ( + stf.ST_MinimumBoundingCircle, + ("line", 2), + "linestring_geom", + "ST_ReducePrecision(geom, 2)", + "POLYGON ((4.27 -1.77, 2.5 -2.5, 0.73 -1.77, 0 0, 0.73 1.77, 2.5 2.5, 4.27 1.77, 5 0, 4.27 -1.77))", + ), + ( + stf.ST_MinimumBoundingRadius, + ("line",), + "linestring_geom", + "", + {"center": "POINT (2.5 0)", "radius": 2.5}, + ), (stf.ST_Multi, ("point",), "point_geom", "", "MULTIPOINT (0 1)"), - (stf.ST_Normalize, ("geom",), "triangle_geom", "", "POLYGON ((0 0, 1 1, 1 0, 0 0))"), + ( + stf.ST_Normalize, + ("geom",), + "triangle_geom", + "", + "POLYGON ((0 0, 1 1, 1 0, 0 0))", + ), (stf.ST_NPoints, ("line",), "linestring_geom", "", 6), (stf.ST_NRings, ("geom",), "square_geom", "", 1), (stf.ST_NumGeometries, ("geom",), "multipoint", "", 2), @@ -205,37 +807,203 @@ (stf.ST_NumPoints, ("line",), "linestring_geom", "", 6), (stf.ST_PointN, ("line", 2), "linestring_geom", "", "POINT (1 0)"), (stf.ST_PointOnSurface, ("line",), "linestring_geom", "", "POINT (2 0)"), - (stf.ST_ReducePrecision, ("geom", 1), "precision_reduce_point", "", "POINT (0.1 0.2)"), - (stf.ST_RemovePoint, ("line", 1), "linestring_geom", "", "LINESTRING (0 0, 2 0, 3 0, 4 0, 5 0)"), - (stf.ST_RemoveRepeatedPoints, ("geom",), "repeated_multipoint", "", "MULTIPOINT (1 1, 2 2, 3 3, 4 4)"), - (stf.ST_Reverse, ("line",), "linestring_geom", "", "LINESTRING (5 0, 4 0, 3 0, 2 0, 1 0, 0 0)"), - (stf.ST_RotateX, ("line", 10.0), "4D_line", "ST_ReducePrecision(geom, 2)", "LINESTRING Z (1 -0.3 -1.383092639965822, 2 -0.59 -2.766185279931644, 3 -0.89 -4.149277919897466, -1 0.3 1.383092639965822)"), - (stf.ST_RotateY, ("line", 10.0), "4D_line", "ST_AsText(ST_ReducePrecision(geom, 2))", "LINESTRING ZM(-1.38 1 -0.2950504181870827 1, -2.77 2 -0.5901008363741653 2, -4.15 3 -0.8851512545612479 3, 1.38 -1 0.2950504181870827 -1)"), - (stf.ST_Rotate, ("line", 10.0), "linestring_geom", "ST_ReducePrecision(geom, 2)", "LINESTRING (0 0, -0.84 -0.54, -1.68 -1.09, -2.52 -1.63, -3.36 -2.18, -4.2 -2.72)"), - (stf.ST_Rotate, ("line", 10.0, 0.0, 0.0), "linestring_geom", "ST_ReducePrecision(geom, 2)", "LINESTRING (0 0, -0.84 -0.54, -1.68 -1.09, -2.52 -1.63, -3.36 -2.18, -4.2 -2.72)"), + ( + stf.ST_ReducePrecision, + ("geom", 1), + "precision_reduce_point", + "", + "POINT (0.1 0.2)", + ), + ( + stf.ST_RemovePoint, + ("line", 1), + "linestring_geom", + "", + "LINESTRING (0 0, 2 0, 3 0, 4 0, 5 0)", + ), + ( + stf.ST_RemoveRepeatedPoints, + ("geom",), + "repeated_multipoint", + "", + "MULTIPOINT (1 1, 2 
2, 3 3, 4 4)", + ), + ( + stf.ST_Reverse, + ("line",), + "linestring_geom", + "", + "LINESTRING (5 0, 4 0, 3 0, 2 0, 1 0, 0 0)", + ), + ( + stf.ST_RotateX, + ("line", 10.0), + "4D_line", + "ST_ReducePrecision(geom, 2)", + "LINESTRING Z (1 -0.3 -1.383092639965822, 2 -0.59 -2.766185279931644, 3 -0.89 -4.149277919897466, -1 0.3 1.383092639965822)", + ), + ( + stf.ST_RotateY, + ("line", 10.0), + "4D_line", + "ST_AsText(ST_ReducePrecision(geom, 2))", + "LINESTRING ZM(-1.38 1 -0.2950504181870827 1, -2.77 2 -0.5901008363741653 2, -4.15 3 -0.8851512545612479 3, 1.38 -1 0.2950504181870827 -1)", + ), + ( + stf.ST_Rotate, + ("line", 10.0), + "linestring_geom", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (0 0, -0.84 -0.54, -1.68 -1.09, -2.52 -1.63, -3.36 -2.18, -4.2 -2.72)", + ), + ( + stf.ST_Rotate, + ("line", 10.0, 0.0, 0.0), + "linestring_geom", + "ST_ReducePrecision(geom, 2)", + "LINESTRING (0 0, -0.84 -0.54, -1.68 -1.09, -2.52 -1.63, -3.36 -2.18, -4.2 -2.72)", + ), (stf.ST_S2CellIDs, ("point", 30), "point_geom", "", [1153451514845492609]), - (stf.ST_S2ToGeom, (lambda: f.expr("array(1154047404513689600)"),), "null", "ST_ReducePrecision(geom[0], 5)", "POLYGON ((0 2.46041, 2.46041 2.46041, 2.46041 0, 0 0, 0 2.46041))"), - (stf.ST_SetPoint, ("line", 1, lambda: f.expr("ST_Point(1.0, 1.0)")), "linestring_geom", "", "LINESTRING (0 0, 1 1, 2 0, 3 0, 4 0, 5 0)"), + ( + stf.ST_S2ToGeom, + (lambda: f.expr("array(1154047404513689600)"),), + "null", + "ST_ReducePrecision(geom[0], 5)", + "POLYGON ((0 2.46041, 2.46041 2.46041, 2.46041 0, 0 0, 0 2.46041))", + ), + ( + stf.ST_SetPoint, + ("line", 1, lambda: f.expr("ST_Point(1.0, 1.0)")), + "linestring_geom", + "", + "LINESTRING (0 0, 1 1, 2 0, 3 0, 4 0, 5 0)", + ), (stf.ST_SetSRID, ("point", 3021), "point_geom", "ST_SRID(geom)", 3021), - (stf.ST_ShiftLongitude, ("geom",), "triangle_geom", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), - (stf.ST_SimplifyPreserveTopology, ("geom", 0.2), "0.9_poly", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), - (stf.ST_SimplifyVW, ("geom", 0.1), "0.9_poly", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"), - (stf.ST_SimplifyPolygonHull, ("geom", 0.3, False), "polygon_unsimplified", "", "POLYGON ((30 10, 40 40, 10 20, 30 10))"), - (stf.ST_SimplifyPolygonHull, ("geom", 0.3), "polygon_unsimplified", "", "POLYGON ((30 10, 15 15, 10 20, 20 40, 45 45, 30 10))"), - (stf.ST_Snap, ("poly", "line", 2.525), "poly_and_line", "" ,"POLYGON ((2.6 12.5, 2.6 20, 12.6 20, 12.6 12.5, 10.1 10, 2.6 12.5))"), - (stf.ST_Split, ("line", "points"), "multipoint_splitting_line", "", "MULTILINESTRING ((0 0, 0.5 0.5), (0.5 0.5, 1 1), (1 1, 1.5 1.5, 2 2))"), + ( + stf.ST_ShiftLongitude, + ("geom",), + "triangle_geom", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), + ( + stf.ST_SimplifyPreserveTopology, + ("geom", 0.2), + "0.9_poly", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), + ( + stf.ST_SimplifyVW, + ("geom", 0.1), + "0.9_poly", + "", + "POLYGON ((0 0, 1 0, 1 1, 0 0))", + ), + ( + stf.ST_SimplifyPolygonHull, + ("geom", 0.3, False), + "polygon_unsimplified", + "", + "POLYGON ((30 10, 40 40, 10 20, 30 10))", + ), + ( + stf.ST_SimplifyPolygonHull, + ("geom", 0.3), + "polygon_unsimplified", + "", + "POLYGON ((30 10, 15 15, 10 20, 20 40, 45 45, 30 10))", + ), + ( + stf.ST_Snap, + ("poly", "line", 2.525), + "poly_and_line", + "", + "POLYGON ((2.6 12.5, 2.6 20, 12.6 20, 12.6 12.5, 10.1 10, 2.6 12.5))", + ), + ( + stf.ST_Split, + ("line", "points"), + "multipoint_splitting_line", + "", + "MULTILINESTRING ((0 0, 0.5 0.5), (0.5 0.5, 1 1), (1 1, 1.5 1.5, 2 2))", + ), (stf.ST_SRID, 
("point",), "point_geom", "", 0), (stf.ST_StartPoint, ("line",), "linestring_geom", "", "POINT (0 0)"), - (stf.ST_SubDivide, ("line", 5), "linestring_geom", "", ["LINESTRING (0 0, 2.5 0)", "LINESTRING (2.5 0, 5 0)"]), - (stf.ST_SubDivideExplode, ("line", 5), "linestring_geom", "collect_list(geom)", ["LINESTRING (0 0, 2.5 0)", "LINESTRING (2.5 0, 5 0)"]), - (stf.ST_SymDifference, ("a", "b"), "overlapping_polys", "", "MULTIPOLYGON (((1 0, 0 0, 0 1, 1 1, 1 0)), ((2 0, 2 1, 3 1, 3 0, 2 0)))"), - (stf.ST_Transform, ("point", lambda: f.lit("EPSG:4326"), lambda: f.lit("EPSG:32649")), "point_geom", "ST_ReducePrecision(geom, 2)", "POINT (-34870890.91 1919456.06)"), - (stf.ST_Translate, ("geom", 1.0, 1.0,), "square_geom", "", "POLYGON ((2 1, 2 2, 3 2, 3 1, 2 1))"), - (stf.ST_TriangulatePolygon, ("geom",), "square_geom", "", "GEOMETRYCOLLECTION (POLYGON ((1 0, 1 1, 2 1, 1 0)), POLYGON ((2 1, 2 0, 1 0, 2 1)))"), - (stf.ST_Union, ("a", "b"), "overlapping_polys", "", "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 3 1, 3 0, 2 0, 1 0))"), - (stf.ST_Union, ("polys",), "array_polygons", "", "POLYGON ((2 3, 3 3, 3 -3, -3 -3, -3 3, -2 3, -2 4, 2 4, 2 3))"), - (stf.ST_UnaryUnion, ("geom",), "overlapping_mPolys", "", "POLYGON ((10 0, 10 10, 0 10, 0 30, 20 30, 20 20, 30 20, 30 0, 10 0))"), - (stf.ST_VoronoiPolygons, ("geom",), "multipoint", "", "GEOMETRYCOLLECTION (POLYGON ((-1 -1, -1 2, 2 -1, -1 -1)), POLYGON ((-1 2, 2 2, 2 -1, -1 2)))"), + ( + stf.ST_SubDivide, + ("line", 5), + "linestring_geom", + "", + ["LINESTRING (0 0, 2.5 0)", "LINESTRING (2.5 0, 5 0)"], + ), + ( + stf.ST_SubDivideExplode, + ("line", 5), + "linestring_geom", + "collect_list(geom)", + ["LINESTRING (0 0, 2.5 0)", "LINESTRING (2.5 0, 5 0)"], + ), + ( + stf.ST_SymDifference, + ("a", "b"), + "overlapping_polys", + "", + "MULTIPOLYGON (((1 0, 0 0, 0 1, 1 1, 1 0)), ((2 0, 2 1, 3 1, 3 0, 2 0)))", + ), + ( + stf.ST_Transform, + ("point", lambda: f.lit("EPSG:4326"), lambda: f.lit("EPSG:32649")), + "point_geom", + "ST_ReducePrecision(geom, 2)", + "POINT (-34870890.91 1919456.06)", + ), + ( + stf.ST_Translate, + ( + "geom", + 1.0, + 1.0, + ), + "square_geom", + "", + "POLYGON ((2 1, 2 2, 3 2, 3 1, 2 1))", + ), + ( + stf.ST_TriangulatePolygon, + ("geom",), + "square_geom", + "", + "GEOMETRYCOLLECTION (POLYGON ((1 0, 1 1, 2 1, 1 0)), POLYGON ((2 1, 2 0, 1 0, 2 1)))", + ), + ( + stf.ST_Union, + ("a", "b"), + "overlapping_polys", + "", + "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 3 1, 3 0, 2 0, 1 0))", + ), + ( + stf.ST_Union, + ("polys",), + "array_polygons", + "", + "POLYGON ((2 3, 3 3, 3 -3, -3 -3, -3 3, -2 3, -2 4, 2 4, 2 3))", + ), + ( + stf.ST_UnaryUnion, + ("geom",), + "overlapping_mPolys", + "", + "POLYGON ((10 0, 10 10, 0 10, 0 30, 20 30, 20 20, 30 20, 30 0, 10 0))", + ), + ( + stf.ST_VoronoiPolygons, + ("geom",), + "multipoint", + "", + "GEOMETRYCOLLECTION (POLYGON ((-1 -1, -1 2, 2 -1, -1 -1)), POLYGON ((-1 2, 2 2, 2 -1, -1 2)))", + ), (stf.ST_X, ("b",), "two_points", "", 3.0), (stf.ST_XMax, ("line",), "linestring_geom", "", 5.0), (stf.ST_XMin, ("line",), "linestring_geom", "", 0.0), @@ -246,33 +1014,125 @@ (stf.ST_Zmflag, ("b",), "two_points", "", 2), (stf.ST_IsValidReason, ("geom",), "triangle_geom", "", "Valid Geometry"), (stf.ST_IsValidReason, ("geom", 1), "triangle_geom", "", "Valid Geometry"), - # predicates - (stp.ST_Contains, ("geom", lambda: f.expr("ST_Point(0.5, 0.25)")), "triangle_geom", "", True), + ( + stp.ST_Contains, + ("geom", lambda: f.expr("ST_Point(0.5, 0.25)")), + "triangle_geom", + "", + True, + ), (stp.ST_Crosses, ("line", "poly"), 
"line_crossing_poly", "", True), (stp.ST_Disjoint, ("a", "b"), "two_points", "", True), - (stp.ST_Equals, ("line", lambda: f.expr("ST_Reverse(line)")), "linestring_geom", "", True), + ( + stp.ST_Equals, + ("line", lambda: f.expr("ST_Reverse(line)")), + "linestring_geom", + "", + True, + ), (stp.ST_Intersects, ("a", "b"), "overlapping_polys", "", True), - (stp.ST_OrderingEquals, ("line", lambda: f.expr("ST_Reverse(line)")), "linestring_geom", "", False), + ( + stp.ST_OrderingEquals, + ("line", lambda: f.expr("ST_Reverse(line)")), + "linestring_geom", + "", + False, + ), (stp.ST_Overlaps, ("a", "b"), "overlapping_polys", "", True), (stp.ST_Touches, ("a", "b"), "touching_polys", "", True), (stp.ST_Relate, ("a", "b"), "touching_polys", "", "FF2F11212"), (stp.ST_Relate, ("a", "b", lambda: f.lit("FF2F11212")), "touching_polys", "", True), - (stp.ST_RelateMatch, (lambda: f.lit("101202FFF"), lambda: f.lit("TTTTTTFFF")), "touching_polys", "", True), - (stp.ST_Within, (lambda: f.expr("ST_Point(0.5, 0.25)"), "geom"), "triangle_geom", "", True), - (stp.ST_Covers, ("geom", lambda: f.expr("ST_Point(0.5, 0.25)")), "triangle_geom", "", True), - (stp.ST_CoveredBy, (lambda: f.expr("ST_Point(0.5, 0.25)"), "geom"), "triangle_geom", "", True), - (stp.ST_Contains, ("geom", lambda: f.expr("ST_Point(0.0, 0.0)")), "triangle_geom", "", False), - (stp.ST_Within, (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), "triangle_geom", "", False), - (stp.ST_Covers, ("geom", lambda: f.expr("ST_Point(0.0, 0.0)")), "triangle_geom", "", True), - (stp.ST_CoveredBy, (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), "triangle_geom", "", True), - (stp.ST_DWithin, ("origin", "point", 5.0,), "origin_and_point", "", True), + ( + stp.ST_RelateMatch, + (lambda: f.lit("101202FFF"), lambda: f.lit("TTTTTTFFF")), + "touching_polys", + "", + True, + ), + ( + stp.ST_Within, + (lambda: f.expr("ST_Point(0.5, 0.25)"), "geom"), + "triangle_geom", + "", + True, + ), + ( + stp.ST_Covers, + ("geom", lambda: f.expr("ST_Point(0.5, 0.25)")), + "triangle_geom", + "", + True, + ), + ( + stp.ST_CoveredBy, + (lambda: f.expr("ST_Point(0.5, 0.25)"), "geom"), + "triangle_geom", + "", + True, + ), + ( + stp.ST_Contains, + ("geom", lambda: f.expr("ST_Point(0.0, 0.0)")), + "triangle_geom", + "", + False, + ), + ( + stp.ST_Within, + (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), + "triangle_geom", + "", + False, + ), + ( + stp.ST_Covers, + ("geom", lambda: f.expr("ST_Point(0.0, 0.0)")), + "triangle_geom", + "", + True, + ), + ( + stp.ST_CoveredBy, + (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), + "triangle_geom", + "", + True, + ), + ( + stp.ST_DWithin, + ( + "origin", + "point", + 5.0, + ), + "origin_and_point", + "", + True, + ), (stp.ST_DWithin, ("ny", "seattle", 4000000.0, True), "ny_seattle", "", True), - # aggregates - (sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"), - (sta.ST_Intersection_Aggr, ("geom",), "exploded_polys", "", "LINESTRING (1 0, 1 1)"), - (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"), + ( + sta.ST_Envelope_Aggr, + ("geom",), + "exploded_points", + "", + "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", + ), + ( + sta.ST_Intersection_Aggr, + ("geom",), + "exploded_polys", + "", + "LINESTRING (1 0, 1 1)", + ), + ( + sta.ST_Union_Aggr, + ("geom",), + "exploded_polys", + "", + "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))", + ), ] wrong_type_configurations = [ @@ -314,7 +1174,6 @@ (stc.ST_PolygonFromText, ("", None)), (stc.ST_MakePointM, (None, None, 
None)),
     (stc.ST_MakePointM, (None, "", "")),
-
     # functions
     (stf.ST_3DDistance, (None, "")),
     (stf.ST_3DDistance, ("", None)),
@@ -353,9 +1212,9 @@
     (stf.ST_DelaunayTriangles, (None,)),
     (stf.ST_EndPoint, (None,)),
     (stf.ST_Envelope, (None,)),
-    (stf.ST_Expand, (None,"")),
-    (stf.ST_Expand, (None,None)),
-    (stf.ST_Expand, ("",None)),
+    (stf.ST_Expand, (None, "")),
+    (stf.ST_Expand, (None, None)),
+    (stf.ST_Expand, ("", None)),
     (stf.ST_ExteriorRing, (None,)),
     (stf.ST_FlipCoordinates, (None,)),
     (stf.ST_Force_2D, (None,)),
@@ -437,8 +1296,14 @@
     (stf.ST_RemovePoint, ("", 1.0)),
     (stf.ST_RemoveRepeatedPoints, (None, None)),
     (stf.ST_Reverse, (None,)),
-    (stf.ST_Rotate, (None,None,)),
-    (stf.ST_Rotate, (None,None)),
+    (
+        stf.ST_Rotate,
+        (
+            None,
+            None,
+        ),
+    ),
+    (stf.ST_Rotate, (None, None)),
     (stf.ST_S2CellIDs, (None, 2)),
     (stf.ST_S2ToGeom, (None,)),
     (stf.ST_SetPoint, (None, 1, "")),
@@ -479,7 +1344,6 @@
     (stf.ST_YMin, (None,)),
     (stf.ST_Z, (None,)),
     (stf.ST_Zmflag, (None,)),
-
     # predicates
     (stp.ST_Contains, (None, "")),
     (stp.ST_Contains, ("", None)),
@@ -503,27 +1367,29 @@
     (stp.ST_RelateMatch, ("", None)),
     (stp.ST_Within, (None, "")),
     (stp.ST_Within, ("", None)),
-
     # aggregates
     (sta.ST_Envelope_Aggr, (None,)),
     (sta.ST_Intersection_Aggr, (None,)),
     (sta.ST_Union_Aggr, (None,)),
 ]
 
+
 class TestDataFrameAPI(TestBase):
     @pytest.fixture
     def base_df(self, request):
-        wkbLine = '0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf'
-        wkbPoint = '010100000000000000000024400000000000002e40'
-        wkb = '0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf'
-        mpoly = 'MULTIPOLYGON(((0 0 ,20 0 ,20 20 ,0 20 ,0 0 ),(5 5 ,5 7 ,7 7 ,7 5 ,5 5)))'
-        mline = 'MULTILINESTRING((1 2, 3 4), (4 5, 6 7))'
-        mpoint = 'MULTIPOINT ((10 10), (20 20), (30 30))'
-        geojson = "{ \"type\": \"Feature\", \"properties\": { \"prop\": \"01\" }, \"geometry\": { \"type\": \"Point\", \"coordinates\": [ 0.0, 1.0 ] }},"
-        gml_string = "<gml:LineString srsName=\"EPSG:4269\"><gml:coordinates>-71.16,42.25 -71.17,42.25 -71.18,42.25</gml:coordinates></gml:LineString>"
+        wkbLine = "0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf"
+        wkbPoint = "010100000000000000000024400000000000002e40"
+        wkb = "0102000000020000000000000084d600c00000000080b5d6bf00000060e1eff7bf00000080075de5bf"
+        mpoly = (
+            "MULTIPOLYGON(((0 0 ,20 0 ,20 20 ,0 20 ,0 0 ),(5 5 ,5 7 ,7 7 ,7 5 ,5 5)))"
+        )
+        mline = "MULTILINESTRING((1 2, 3 4), (4 5, 6 7))"
+        mpoint = "MULTIPOINT ((10 10), (20 20), (30 30))"
+        geojson = '{ "type": "Feature", "properties": { "prop": "01" }, "geometry": { "type": "Point", "coordinates": [ 0.0, 1.0 ] }},'
+        gml_string = '<gml:LineString srsName="EPSG:4269"><gml:coordinates>-71.16,42.25 -71.17,42.25 -71.18,42.25</gml:coordinates></gml:LineString>'
         kml_string = "<LineString><coordinates>-71.16,42.26 -71.17,42.26</coordinates></LineString>"
-        wktCollection = 'GEOMETRYCOLLECTION(POINT(1 1), LINESTRING(0 0, 1 1)))'
+        wktCollection = "GEOMETRYCOLLECTION(POINT(1 1), LINESTRING(0 0, 1 1)))"
 
         if request.param == "constructor":
             return TestDataFrameAPI.spark.sql("SELECT null").selectExpr(
@@ -542,96 +1408,176 @@ def base_df(self, request):
                 "'s00twy01mt' AS geohash",
                 f"'{gml_string}' AS gml",
                 f"'{kml_string}' AS kml",
-                f"'{wktCollection}' AS collection"
+                f"'{wktCollection}' AS collection",
             )
         elif request.param == "point_geom":
             return TestDataFrameAPI.spark.sql("SELECT ST_Point(0.0, 1.0) AS point")
         elif request.param == "linestring_geom":
-            return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)') AS line")
+            return TestDataFrameAPI.spark.sql(
+                "SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)') AS line"
+            )
         elif request.param == 
"linestring_wkt": return TestDataFrameAPI.spark.sql("SELECT 'LINESTRING (1 2, 3 4)' AS wkt") elif request.param == "linestring_ewkt": - return TestDataFrameAPI.spark.sql("SELECT 'SRID=4269;LINESTRING (1 2, 3 4)' AS ewkt") + return TestDataFrameAPI.spark.sql( + "SELECT 'SRID=4269;LINESTRING (1 2, 3 4)' AS ewkt" + ) elif request.param == "min_max_x_y": - return TestDataFrameAPI.spark.sql("SELECT 0.0 AS minx, 1.0 AS miny, 2.0 AS maxx, 3.0 AS maxy") + return TestDataFrameAPI.spark.sql( + "SELECT 0.0 AS minx, 1.0 AS miny, 2.0 AS maxx, 3.0 AS maxy" + ) elif request.param == "x_y_z_m_srid": - return TestDataFrameAPI.spark.sql("SELECT 1.0 AS x, 2.0 AS y, 3.0 AS z, 100.9 AS m, 4326 AS srid") + return TestDataFrameAPI.spark.sql( + "SELECT 1.0 AS x, 2.0 AS y, 3.0 AS z, 100.9 AS m, 4326 AS srid" + ) elif request.param == "multipoint_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30 10))') AS multipoint") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30 10))') AS multipoint" + ) elif request.param == "null": return TestDataFrameAPI.spark.sql("SELECT null") elif request.param == "triangle_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 0))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 0))') AS geom" + ) elif request.param == "two_points": - return TestDataFrameAPI.spark.sql("SELECT ST_PointZ(0.0, 0.0, 0.0) AS a, ST_PointZ(3.0, 0.0, 4.0) AS b") + return TestDataFrameAPI.spark.sql( + "SELECT ST_PointZ(0.0, 0.0, 0.0) AS a, ST_PointZ(3.0, 0.0, 4.0) AS b" + ) elif request.param == "4D_point": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT ZM(1 2 3 4)') AS point") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT ZM(1 2 3 4)') AS point" + ) elif request.param == "4D_line": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line" + ) elif request.param == "invalid_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((1 5, 1 1, 3 3, 5 3, 7 1, 7 5, 5 3, 3 3, 1 5))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((1 5, 1 1, 3 3, 5 3, 7 1, 7 5, 5 3, 3 3, 1 5))') AS geom" + ) elif request.param == "overlapping_polys": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON((0 0, 2 0, 2 1, 0 1, 0 0))') AS a, ST_GeomFromWKT('POLYGON((1 0, 3 0, 3 1, 1 1, 1 0))') AS b") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON((0 0, 2 0, 2 1, 0 1, 0 0))') AS a, ST_GeomFromWKT('POLYGON((1 0, 3 0, 3 1, 1 1, 1 0))') AS b" + ) elif request.param == "overlapping_mPolys": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('MULTIPOLYGON(((0 10,0 30,20 30,20 10,0 10)),((10 0,10 20,30 20,30 0,10 0)))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOLYGON(((0 10,0 30,20 30,20 10,0 10)),((10 0,10 20,30 20,30 0,10 0)))') AS geom" + ) elif request.param == "multipoint": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1))') AS geom" + ) elif request.param == "geom_with_hole": - return TestDataFrameAPI.spark.sql("SELECT 
ST_GeomFromWKT('POLYGON ((0 0, 3 0, 3 3, 0 0), (1 1, 2 2, 2 1, 1 1))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 3 0, 3 3, 0 0), (1 1, 2 2, 2 1, 1 1))') AS geom" + ) elif request.param == "0.9_poly": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 0.9, 1 1, 0 0))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 0.9, 1 1, 0 0))') AS geom" + ) elif request.param == "polygon_unsimplified": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((30 10, 40 40, 45 45, 20 40, 25 35, 10 20, 15 15, 30 10))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((30 10, 40 40, 45 45, 20 40, 25 35, 10 20, 15 15, 30 10))') AS geom" + ) elif request.param == "precision_reduce_point": return TestDataFrameAPI.spark.sql("SELECT ST_Point(0.12, 0.23) AS geom") elif request.param == "closed_linestring_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0, 1 1, 0 0)') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0, 1 1, 0 0)') AS geom" + ) elif request.param == "empty_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_Difference(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0)) AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_Difference(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0)) AS geom" + ) elif request.param == "repeated_multipoint": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('MULTIPOINT (1 1, 1 1, 2 2, 3 3, 3 3, 4 4)') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOINT (1 1, 1 1, 2 2, 3 3, 3 3, 4 4)') AS geom" + ) elif request.param == "multiline_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('MULTILINESTRING ((0 0, 1 0), (1 0, 1 1), (1 1, 0 0))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('MULTILINESTRING ((0 0, 1 0), (1 0, 1 1), (1 1, 0 0))') AS geom" + ) elif request.param == "geom_collection": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION(POINT(0 0), LINESTRING(0 0, 1 0))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION(POINT(0 0), LINESTRING(0 0, 1 0))') AS geom" + ) elif request.param == "exploded_points": - return TestDataFrameAPI.spark.sql("SELECT explode(array(ST_Point(0.0, 0.0), ST_Point(1.0, 1.0))) AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT explode(array(ST_Point(0.0, 0.0), ST_Point(1.0, 1.0))) AS geom" + ) elif request.param == "exploded_polys": - return TestDataFrameAPI.spark.sql("SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom" + ) elif request.param == "touching_polys": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))') AS a, ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))') AS b") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))') AS a, ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))') AS b" + ) elif request.param == "line_crossing_poly": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 2 1)') AS line, ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 2, 1 2, 1 0))') AS 
poly") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (0 0, 2 1)') AS line, ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 2, 1 2, 1 0))') AS poly" + ) elif request.param == "square_geom": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom" + ) elif request.param == "four_points": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (0 0)') AS p1, ST_GeomFromWKT('POINT (1 1)') AS p2, ST_GeomFromWKT('POINT (1 0)') AS p3, ST_GeomFromWKT('POINT (6 2)') AS p4") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT (0 0)') AS p1, ST_GeomFromWKT('POINT (1 1)') AS p2, ST_GeomFromWKT('POINT (1 0)') AS p3, ST_GeomFromWKT('POINT (6 2)') AS p4" + ) elif request.param == "three_points": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (1 1)') AS p1, ST_GeomFromWKT('POINT (0 0)') AS p2, ST_GeomFromWKT('POINT (3 2)') AS p3") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT (1 1)') AS p1, ST_GeomFromWKT('POINT (0 0)') AS p2, ST_GeomFromWKT('POINT (3 2)') AS p3" + ) elif request.param == "two_lines": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 1)') AS line1, ST_GeomFromWKT('LINESTRING (0 0, 3 2)') AS line2") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 1)') AS line1, ST_GeomFromWKT('LINESTRING (0 0, 3 2)') AS line2" + ) elif request.param == "two_lines_angle_rad": - return TestDataFrameAPI.spark.sql("SELECT ST_Angle(ST_GeomFromWKT('LINESTRING (0 0, 1 1)'), ST_GeomFromWKT('LINESTRING (0 0, 3 2)')) AS angleRad") + return TestDataFrameAPI.spark.sql( + "SELECT ST_Angle(ST_GeomFromWKT('LINESTRING (0 0, 1 1)'), ST_GeomFromWKT('LINESTRING (0 0, 3 2)')) AS angleRad" + ) elif request.param == "geometry_geom_collection": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION(POINT(1 1), LINESTRING(0 0, 1 1, 2 2))') AS geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION(POINT(1 1), LINESTRING(0 0, 1 1, 2 2))') AS geom" + ) elif request.param == "point_and_line": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (0.0 1.0)') AS point, ST_GeomFromWKT('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)') AS line") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT (0.0 1.0)') AS point, ST_GeomFromWKT('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)') AS line" + ) elif request.param == "line_and_point": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 2, 1 1, 2 0)') AS line, ST_GeomFromWKT('POINT (0 0)') AS point") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (0 2, 1 1, 2 0)') AS line, ST_GeomFromWKT('POINT (0 0)') AS point" + ) elif request.param == "multipoint_splitting_line": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1.5 1.5, 2 2)') AS line, ST_GeomFromWKT('MULTIPOINT (0.5 0.5, 1 1)') AS points") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (0 0, 1.5 1.5, 2 2)') AS line, ST_GeomFromWKT('MULTIPOINT (0.5 0.5, 1 1)') AS points" + ) elif request.param == "origin_and_point": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (0 0)') AS origin, ST_GeomFromWKT('POINT (1 0)') as point") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT (0 0)') AS origin, ST_GeomFromWKT('POINT (1 0)') 
as point" + ) elif request.param == "ny_seattle": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (-122.335167 47.608013)') AS seattle, ST_GeomFromWKT('POINT (-73.935242 40.730610)') as ny") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POINT (-122.335167 47.608013)') AS seattle, ST_GeomFromWKT('POINT (-73.935242 40.730610)') as ny" + ) elif request.param == "line_crossing_dateline": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (179.95 30, -179.95 30)') AS line") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (179.95 30, -179.95 30)') AS line" + ) elif request.param == "array_polygons": - return TestDataFrameAPI.spark.sql("SELECT array(ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))'), ST_GeomFromWKT('POLYGON ((-2 1, 2 1, 2 4, -2 4, -2 1))')) as polys") + return TestDataFrameAPI.spark.sql( + "SELECT array(ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))'), ST_GeomFromWKT('POLYGON ((-2 1, 2 1, 2 4, -2 4, -2 1))')) as polys" + ) elif request.param == "poly_and_line": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POLYGON((2.6 12.5, 2.6 20.0, 12.6 20.0, 12.6 12.5, 2.6 12.5 ))') as poly, ST_GeomFromWKT('LINESTRING (0.5 10.7, 5.4 8.4, 10.1 10.0)') as line") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON((2.6 12.5, 2.6 20.0, 12.6 20.0, 12.6 12.5, 2.6 12.5 ))') as poly, ST_GeomFromWKT('LINESTRING (0.5 10.7, 5.4 8.4, 10.1 10.0)') as line" + ) elif request.param == "noded_linework": - return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION (LINESTRING (2 0, 2 1, 2 2), LINESTRING (2 2, 2 3, 2 4), LINESTRING (0 2, 1 2, 2 2), LINESTRING (2 2, 3 2, 4 2), LINESTRING (0 2, 1 3, 2 4), LINESTRING (2 4, 3 3, 4 2))') as geom") + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION (LINESTRING (2 0, 2 1, 2 2), LINESTRING (2 2, 2 3, 2 4), LINESTRING (0 2, 1 2, 2 2), LINESTRING (2 2, 3 2, 4 2), LINESTRING (0 2, 1 3, 2 4), LINESTRING (2 4, 3 3, 4 2))') as geom" + ) raise ValueError(f"Invalid base_df name passed: {request.param}") def _id_test_configuration(val): @@ -643,8 +1589,15 @@ def _id_test_configuration(val): return f"{val}" return val - @pytest.mark.parametrize("func,args,base_df,post_process,expected_result", test_configurations, ids=_id_test_configuration, indirect=["base_df"]) - def test_dataframe_function(self, func, args, base_df, post_process, expected_result): + @pytest.mark.parametrize( + "func,args,base_df,post_process,expected_result", + test_configurations, + ids=_id_test_configuration, + indirect=["base_df"], + ) + def test_dataframe_function( + self, func, args, base_df, post_process, expected_result + ): args = [arg() if isinstance(arg, Callable) else arg for arg in args] if len(args) == 1: @@ -662,13 +1615,23 @@ def test_dataframe_function(self, func, args, base_df, post_process, expected_re elif isinstance(actual_result, bytearray): actual_result = actual_result.hex() elif isinstance(actual_result, Row): - actual_result = {k: v.wkt if isinstance(v, BaseGeometry) else v for k, v in actual_result.asDict().items()} + actual_result = { + k: v.wkt if isinstance(v, BaseGeometry) else v + for k, v in actual_result.asDict().items() + } elif isinstance(actual_result, list): - actual_result = sorted([x.wkt if isinstance(x, BaseGeometry) else x for x in actual_result]) + actual_result = sorted( + [x.wkt if isinstance(x, BaseGeometry) else x for x in actual_result] + ) - assert(actual_result == expected_result) 
+ assert actual_result == expected_result - @pytest.mark.parametrize("func,args", wrong_type_configurations, ids=_id_test_configuration) + @pytest.mark.parametrize( + "func,args", wrong_type_configurations, ids=_id_test_configuration + ) def test_call_function_with_wrong_type(self, func, args): - with pytest.raises(ValueError, match=f"Incorrect argument type: [A-Za-z_0-9]+ for {func.__name__} should be [A-Za-z0-9\\[\\]_, ]+ but received [A-Za-z0-9_]+."): + with pytest.raises( + ValueError, + match=f"Incorrect argument type: [A-Za-z_0-9]+ for {func.__name__} should be [A-Za-z0-9\\[\\]_, ]+ but received [A-Za-z0-9_]+.", + ): func(*args) diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index f139c6963e..9100a04f81 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -24,156 +24,220 @@ from shapely import wkt from shapely.wkt import loads from tests import mixed_wkt_geometry_input_location -from tests.sql.resource.sample_data import create_sample_points, create_simple_polygons_df, \ - create_sample_points_df, create_sample_polygons_df, create_sample_lines_df +from tests.sql.resource.sample_data import ( + create_sample_points, + create_simple_polygons_df, + create_sample_points_df, + create_sample_polygons_df, + create_sample_lines_df, +) from tests.test_base import TestBase from typing import List class TestPredicateJoin(TestBase): - geo_schema = StructType( - [StructField("geom", GeometryType(), False)] - ) + geo_schema = StructType([StructField("geom", GeometryType(), False)]) geo_schema_with_index = StructType( [ StructField("index", IntegerType(), False), - StructField("geom", GeometryType(), False) + StructField("geom", GeometryType(), False), ] ) geo_pair_schema = StructType( [ StructField("geomA", GeometryType(), False), - StructField("geomB", GeometryType(), False) + StructField("geomB", GeometryType(), False), ] ) def test_st_concave_hull(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_ConcaveHull(polygondf.countyshape, 1.0) from polygondf") + function_df = self.spark.sql( + "select ST_ConcaveHull(polygondf.countyshape, 1.0) from polygondf" + ) function_df.show() def test_st_convex_hull(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). 
\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_ConvexHull(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_ConvexHull(polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_buffer(self): - polygon_from_wkt = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_from_wkt = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_from_wkt.createOrReplaceTempView("polygontable") polygon_from_wkt.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 1), 2) from polygondf") + function_df = self.spark.sql( + "select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 1), 2) from polygondf" + ) actual = function_df.take(1)[0][0].wkt - assert actual == "POLYGON ((-98.02 41.77, -98.02 41.78, -98.02 41.8, -98.02 41.81, -98.02 41.82, -98.02 41.83, -98.02 41.84, -98.02 41.85, -98.02 41.86, -98.02 41.87, -98.02 41.89, -98.02 41.9, -98.02 41.91, -98.02 41.92, -98.02 41.93, -98.02 41.95, -98.02 41.98, -98.02 42, -98.02 42.01, -98.02 42.02, -98.02 42.04, -98.02 42.05, -98.02 42.07, -98.02 42.09, -98.02 42.11, -98.02 42.12, -97.99 42.31, -97.93 42.5, -97.84 42.66, -97.72 42.81, -97.57 42.93, -97.4 43.02, -97.21 43.07, -97.02 43.09, -97 43.09, -96.98 43.09, -96.97 43.09, -96.96 43.09, -96.95 43.09, -96.94 43.09, -96.92 43.09, -96.91 43.09, -96.89 43.09, -96.88 43.09, -96.86 43.09, -96.84 43.09, -96.83 43.09, -96.82 43.09, -96.81 43.09, -96.8 43.09, -96.79 43.09, -96.78 43.09, -96.76 43.09, -96.74 43.09, -96.73 43.09, -96.71 43.09, -96.7 43.09, -96.69 43.09, -96.68 43.09, -96.66 43.09, -96.65 43.09, -96.64 43.09, -96.63 43.09, -96.62 43.09, -96.61 43.09, -96.6 43.09, -96.59 43.09, -96.58 43.09, -96.57 43.09, -96.56 43.09, -96.55 43.09, -96.36 43.07, -96.17 43.01, -96 42.92, -95.86 42.8, -95.73 42.66, -95.64 42.49, -95.58 42.31, -95.56 42.12, -95.56 42.1, -95.56 42.09, -95.56 42.08, -95.56 42.07, -95.56 42.06, -95.56 42.04, -95.56 42, -95.56 41.99, -95.56 41.98, -95.56 41.97, -95.56 41.96, -95.56 41.95, -95.56 41.94, -95.56 41.93, -95.56 41.92, -95.56 41.91, -95.56 41.9, -95.56 41.89, -95.56 41.88, -95.56 41.87, -95.56 41.86, -95.56 41.85, -95.56 41.83, -95.56 41.82, -95.56 41.81, -95.56 41.8, -95.56 41.79, -95.56 41.78, -95.56 41.77, -95.56 41.76, -95.56 41.75, -95.56 41.74, -95.58 41.54, -95.63 41.36, -95.72 41.19, -95.85 41.03, -96 40.91, -96.17 40.82, -96.36 40.76, -96.55 40.74, -96.56 40.74, -96.57 40.74, -96.58 40.74, -96.59 40.74, -96.6 40.74, -96.62 40.74, -96.63 40.74, -96.64 40.74, -96.65 40.74, 
-96.67 40.74, -96.68 40.74, -96.69 40.74, -96.7 40.74, -96.71 40.74, -96.72 40.74, -96.73 40.74, -96.74 40.74, -96.75 40.74, -96.76 40.74, -96.77 40.74, -96.78 40.74, -96.79 40.74, -96.8 40.74, -96.81 40.74, -96.82 40.74, -96.83 40.74, -96.85 40.74, -96.86 40.74, -96.88 40.74, -96.9 40.74, -96.91 40.74, -96.92 40.74, -96.93 40.74, -96.94 40.74, -96.95 40.74, -96.97 40.74, -96.98 40.74, -96.99 40.74, -97.01 40.74, -97.02 40.74, -97.22 40.76, -97.4 40.82, -97.57 40.91, -97.72 41.03, -97.85 41.18, -97.94 41.35, -98 41.54, -98.02 41.73, -98.02 41.75, -98.02 41.76, -98.02 41.77))" + assert ( + actual + == "POLYGON ((-98.02 41.77, -98.02 41.78, -98.02 41.8, -98.02 41.81, -98.02 41.82, -98.02 41.83, -98.02 41.84, -98.02 41.85, -98.02 41.86, -98.02 41.87, -98.02 41.89, -98.02 41.9, -98.02 41.91, -98.02 41.92, -98.02 41.93, -98.02 41.95, -98.02 41.98, -98.02 42, -98.02 42.01, -98.02 42.02, -98.02 42.04, -98.02 42.05, -98.02 42.07, -98.02 42.09, -98.02 42.11, -98.02 42.12, -97.99 42.31, -97.93 42.5, -97.84 42.66, -97.72 42.81, -97.57 42.93, -97.4 43.02, -97.21 43.07, -97.02 43.09, -97 43.09, -96.98 43.09, -96.97 43.09, -96.96 43.09, -96.95 43.09, -96.94 43.09, -96.92 43.09, -96.91 43.09, -96.89 43.09, -96.88 43.09, -96.86 43.09, -96.84 43.09, -96.83 43.09, -96.82 43.09, -96.81 43.09, -96.8 43.09, -96.79 43.09, -96.78 43.09, -96.76 43.09, -96.74 43.09, -96.73 43.09, -96.71 43.09, -96.7 43.09, -96.69 43.09, -96.68 43.09, -96.66 43.09, -96.65 43.09, -96.64 43.09, -96.63 43.09, -96.62 43.09, -96.61 43.09, -96.6 43.09, -96.59 43.09, -96.58 43.09, -96.57 43.09, -96.56 43.09, -96.55 43.09, -96.36 43.07, -96.17 43.01, -96 42.92, -95.86 42.8, -95.73 42.66, -95.64 42.49, -95.58 42.31, -95.56 42.12, -95.56 42.1, -95.56 42.09, -95.56 42.08, -95.56 42.07, -95.56 42.06, -95.56 42.04, -95.56 42, -95.56 41.99, -95.56 41.98, -95.56 41.97, -95.56 41.96, -95.56 41.95, -95.56 41.94, -95.56 41.93, -95.56 41.92, -95.56 41.91, -95.56 41.9, -95.56 41.89, -95.56 41.88, -95.56 41.87, -95.56 41.86, -95.56 41.85, -95.56 41.83, -95.56 41.82, -95.56 41.81, -95.56 41.8, -95.56 41.79, -95.56 41.78, -95.56 41.77, -95.56 41.76, -95.56 41.75, -95.56 41.74, -95.58 41.54, -95.63 41.36, -95.72 41.19, -95.85 41.03, -96 40.91, -96.17 40.82, -96.36 40.76, -96.55 40.74, -96.56 40.74, -96.57 40.74, -96.58 40.74, -96.59 40.74, -96.6 40.74, -96.62 40.74, -96.63 40.74, -96.64 40.74, -96.65 40.74, -96.67 40.74, -96.68 40.74, -96.69 40.74, -96.7 40.74, -96.71 40.74, -96.72 40.74, -96.73 40.74, -96.74 40.74, -96.75 40.74, -96.76 40.74, -96.77 40.74, -96.78 40.74, -96.79 40.74, -96.8 40.74, -96.81 40.74, -96.82 40.74, -96.83 40.74, -96.85 40.74, -96.86 40.74, -96.88 40.74, -96.9 40.74, -96.91 40.74, -96.92 40.74, -96.93 40.74, -96.94 40.74, -96.95 40.74, -96.97 40.74, -96.98 40.74, -96.99 40.74, -97.01 40.74, -97.02 40.74, -97.22 40.76, -97.4 40.82, -97.57 40.91, -97.72 41.03, -97.85 41.18, -97.94 41.35, -98 41.54, -98.02 41.73, -98.02 41.75, -98.02 41.76, -98.02 41.77))" + ) - function_df = self.spark.sql("select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 1, true), 2) from polygondf") + function_df = self.spark.sql( + "select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 1, true), 2) from polygondf" + ) actual = function_df.take(1)[0][0].wkt - assert actual == "POLYGON ((-97.02 42.01, -97.02 42.02, -97.02 42.03, -97.02 42.04, -97.02 42.05, -97.02 42.06, -97.02 42.07, -97.02 42.08, -97.02 42.09, -97.01 42.09, -97 42.09, -96.99 42.09, -96.98 42.09, -96.97 42.09, -96.96 42.09, -96.95 42.09, -96.94 42.09, -96.93 42.09, -96.92 42.09, 
-96.91 42.09, -96.9 42.09, -96.89 42.09, -96.88 42.09, -96.87 42.09, -96.86 42.09, -96.85 42.09, -96.84 42.09, -96.83 42.09, -96.82 42.09, -96.81 42.09, -96.8 42.09, -96.79 42.09, -96.78 42.09, -96.77 42.09, -96.76 42.09, -96.75 42.09, -96.74 42.09, -96.73 42.09, -96.72 42.09, -96.71 42.09, -96.7 42.09, -96.69 42.09, -96.68 42.09, -96.67 42.09, -96.66 42.09, -96.65 42.09, -96.64 42.09, -96.63 42.09, -96.62 42.09, -96.61 42.09, -96.6 42.09, -96.59 42.09, -96.58 42.09, -96.57 42.09, -96.56 42.09, -96.56 42.08, -96.56 42.07, -96.56 42.06, -96.56 42.05, -96.56 42.04, -96.56 42.03, -96.56 42.02, -96.55 42.02, -96.56 42, -96.56 41.99, -96.56 41.98, -96.56 41.97, -96.56 41.96, -96.56 41.95, -96.56 41.94, -96.56 41.93, -96.56 41.92, -96.56 41.91, -96.56 41.9, -96.56 41.89, -96.56 41.88, -96.56 41.87, -96.56 41.86, -96.56 41.85, -96.56 41.84, -96.56 41.83, -96.56 41.82, -96.56 41.81, -96.56 41.8, -96.56 41.79, -96.56 41.78, -96.56 41.77, -96.56 41.76, -96.56 41.75, -96.56 41.74, -96.57 41.74, -96.58 41.74, -96.59 41.74, -96.6 41.74, -96.61 41.74, -96.62 41.74, -96.63 41.74, -96.64 41.74, -96.65 41.74, -96.66 41.74, -96.67 41.74, -96.68 41.74, -96.69 41.74, -96.7 41.74, -96.71 41.74, -96.72 41.74, -96.73 41.74, -96.74 41.74, -96.75 41.74, -96.76 41.74, -96.77 41.74, -96.78 41.74, -96.79 41.74, -96.8 41.74, -96.81 41.74, -96.82 41.74, -96.83 41.74, -96.84 41.74, -96.85 41.74, -96.86 41.74, -96.87 41.74, -96.88 41.74, -96.89 41.74, -96.9 41.74, -96.91 41.74, -96.92 41.74, -96.93 41.74, -96.94 41.74, -96.95 41.74, -96.96 41.74, -96.97 41.74, -96.98 41.74, -96.99 41.74, -97 41.74, -97.01 41.74, -97.02 41.74, -97.02 41.75, -97.02 41.76, -97.02 41.77, -97.02 41.78, -97.02 41.79, -97.02 41.8, -97.02 41.81, -97.02 41.82, -97.02 41.83, -97.02 41.84, -97.02 41.85, -97.02 41.86, -97.02 41.87, -97.02 41.88, -97.02 41.89, -97.02 41.9, -97.02 41.91, -97.02 41.92, -97.02 41.93, -97.02 41.94, -97.02 41.95, -97.02 41.96, -97.02 41.97, -97.02 41.98, -97.02 41.99, -97.02 42, -97.02 42.01))" + assert ( + actual + == "POLYGON ((-97.02 42.01, -97.02 42.02, -97.02 42.03, -97.02 42.04, -97.02 42.05, -97.02 42.06, -97.02 42.07, -97.02 42.08, -97.02 42.09, -97.01 42.09, -97 42.09, -96.99 42.09, -96.98 42.09, -96.97 42.09, -96.96 42.09, -96.95 42.09, -96.94 42.09, -96.93 42.09, -96.92 42.09, -96.91 42.09, -96.9 42.09, -96.89 42.09, -96.88 42.09, -96.87 42.09, -96.86 42.09, -96.85 42.09, -96.84 42.09, -96.83 42.09, -96.82 42.09, -96.81 42.09, -96.8 42.09, -96.79 42.09, -96.78 42.09, -96.77 42.09, -96.76 42.09, -96.75 42.09, -96.74 42.09, -96.73 42.09, -96.72 42.09, -96.71 42.09, -96.7 42.09, -96.69 42.09, -96.68 42.09, -96.67 42.09, -96.66 42.09, -96.65 42.09, -96.64 42.09, -96.63 42.09, -96.62 42.09, -96.61 42.09, -96.6 42.09, -96.59 42.09, -96.58 42.09, -96.57 42.09, -96.56 42.09, -96.56 42.08, -96.56 42.07, -96.56 42.06, -96.56 42.05, -96.56 42.04, -96.56 42.03, -96.56 42.02, -96.55 42.02, -96.56 42, -96.56 41.99, -96.56 41.98, -96.56 41.97, -96.56 41.96, -96.56 41.95, -96.56 41.94, -96.56 41.93, -96.56 41.92, -96.56 41.91, -96.56 41.9, -96.56 41.89, -96.56 41.88, -96.56 41.87, -96.56 41.86, -96.56 41.85, -96.56 41.84, -96.56 41.83, -96.56 41.82, -96.56 41.81, -96.56 41.8, -96.56 41.79, -96.56 41.78, -96.56 41.77, -96.56 41.76, -96.56 41.75, -96.56 41.74, -96.57 41.74, -96.58 41.74, -96.59 41.74, -96.6 41.74, -96.61 41.74, -96.62 41.74, -96.63 41.74, -96.64 41.74, -96.65 41.74, -96.66 41.74, -96.67 41.74, -96.68 41.74, -96.69 41.74, -96.7 41.74, -96.71 41.74, -96.72 41.74, -96.73 41.74, -96.74 41.74, -96.75 41.74, -96.76 
41.74, -96.77 41.74, -96.78 41.74, -96.79 41.74, -96.8 41.74, -96.81 41.74, -96.82 41.74, -96.83 41.74, -96.84 41.74, -96.85 41.74, -96.86 41.74, -96.87 41.74, -96.88 41.74, -96.89 41.74, -96.9 41.74, -96.91 41.74, -96.92 41.74, -96.93 41.74, -96.94 41.74, -96.95 41.74, -96.96 41.74, -96.97 41.74, -96.98 41.74, -96.99 41.74, -97 41.74, -97.01 41.74, -97.02 41.74, -97.02 41.75, -97.02 41.76, -97.02 41.77, -97.02 41.78, -97.02 41.79, -97.02 41.8, -97.02 41.81, -97.02 41.82, -97.02 41.83, -97.02 41.84, -97.02 41.85, -97.02 41.86, -97.02 41.87, -97.02 41.88, -97.02 41.89, -97.02 41.9, -97.02 41.91, -97.02 41.92, -97.02 41.93, -97.02 41.94, -97.02 41.95, -97.02 41.96, -97.02 41.97, -97.02 41.98, -97.02 41.99, -97.02 42, -97.02 42.01))" + ) - function_df = self.spark.sql("select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 10, false, 'endcap=square'), 2) from polygondf") + function_df = self.spark.sql( + "select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 10, false, 'endcap=square'), 2) from polygondf" + ) actual = function_df.take(1)[0][0].wkt - assert actual == "POLYGON ((-107.02 42.06, -107.02 42.07, -107.02 42.09, -107.02 42.11, -107.02 42.32, -107.02 42.33, -107.01 42.42, -107.01 42.43, -106.77 44.33, -106.16 46.15, -105.22 47.82, -103.98 49.27, -102.48 50.47, -100.78 51.36, -98.94 51.9, -97.04 52.09, -97.03 52.09, -97.01 52.09, -96.95 52.09, -96.9 52.09, -96.81 52.09, -96.7 52.09, -96.68 52.09, -96.65 52.09, -96.55 52.09, -96.54 52.09, -96.49 52.09, -96.48 52.09, -94.58 51.89, -92.74 51.33, -91.04 50.43, -89.55 49.23, -88.32 47.76, -87.39 46.08, -86.79 44.25, -86.56 42.35, -86.56 42.18, -86.56 42.17, -86.56 42.1, -86.55 41.99, -86.56 41.9, -86.56 41.78, -86.56 41.77, -86.56 41.75, -86.56 41.73, -86.56 41.7, -86.75 39.76, -87.33 37.89, -88.25 36.17, -89.49 34.66, -91 33.43, -92.72 32.51, -94.59 31.94, -96.53 31.74, -96.55 31.74, -96.56 31.74, -96.57 31.74, -96.58 31.74, -96.6 31.74, -96.69 31.74, -96.72 31.74, -96.75 31.74, -96.94 31.74, -97.02 31.74, -97.04 31.74, -97.06 31.74, -98.99 31.94, -100.85 32.5, -102.56 33.42, -104.07 34.65, -105.31 36.14, -106.23 37.85, -106.81 39.71, -107.02 41.64, -107.02 41.75, -107.02 41.94, -107.02 41.96, -107.02 41.99, -107.02 42.01, -107.02 42.02, -107.02 42.03, -107.02 42.06))" + assert ( + actual + == "POLYGON ((-107.02 42.06, -107.02 42.07, -107.02 42.09, -107.02 42.11, -107.02 42.32, -107.02 42.33, -107.01 42.42, -107.01 42.43, -106.77 44.33, -106.16 46.15, -105.22 47.82, -103.98 49.27, -102.48 50.47, -100.78 51.36, -98.94 51.9, -97.04 52.09, -97.03 52.09, -97.01 52.09, -96.95 52.09, -96.9 52.09, -96.81 52.09, -96.7 52.09, -96.68 52.09, -96.65 52.09, -96.55 52.09, -96.54 52.09, -96.49 52.09, -96.48 52.09, -94.58 51.89, -92.74 51.33, -91.04 50.43, -89.55 49.23, -88.32 47.76, -87.39 46.08, -86.79 44.25, -86.56 42.35, -86.56 42.18, -86.56 42.17, -86.56 42.1, -86.55 41.99, -86.56 41.9, -86.56 41.78, -86.56 41.77, -86.56 41.75, -86.56 41.73, -86.56 41.7, -86.75 39.76, -87.33 37.89, -88.25 36.17, -89.49 34.66, -91 33.43, -92.72 32.51, -94.59 31.94, -96.53 31.74, -96.55 31.74, -96.56 31.74, -96.57 31.74, -96.58 31.74, -96.6 31.74, -96.69 31.74, -96.72 31.74, -96.75 31.74, -96.94 31.74, -97.02 31.74, -97.04 31.74, -97.06 31.74, -98.99 31.94, -100.85 32.5, -102.56 33.42, -104.07 34.65, -105.31 36.14, -106.23 37.85, -106.81 39.71, -107.02 41.64, -107.02 41.75, -107.02 41.94, -107.02 41.96, -107.02 41.99, -107.02 42.01, -107.02 42.02, -107.02 42.03, -107.02 42.06))" + ) - function_df = self.spark.sql("select 
ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 10, true, 'endcap=square'), 2) from polygondf") + function_df = self.spark.sql( + "select ST_ReducePrecision(ST_Buffer(polygondf.countyshape, 10, true, 'endcap=square'), 2) from polygondf" + ) actual = function_df.take(1)[0][0].wkt - assert actual == "POLYGON ((-97.02 42.01, -97.02 42.02, -97.02 42.03, -97.02 42.04, -97.02 42.05, -97.02 42.06, -97.02 42.07, -97.02 42.08, -97.02 42.09, -97.01 42.09, -97 42.09, -96.99 42.09, -96.98 42.09, -96.97 42.09, -96.96 42.09, -96.95 42.09, -96.94 42.09, -96.93 42.09, -96.92 42.09, -96.91 42.09, -96.9 42.09, -96.89 42.09, -96.88 42.09, -96.87 42.09, -96.86 42.09, -96.85 42.09, -96.84 42.09, -96.83 42.09, -96.82 42.09, -96.81 42.09, -96.8 42.09, -96.79 42.09, -96.78 42.09, -96.77 42.09, -96.76 42.09, -96.75 42.09, -96.74 42.09, -96.73 42.09, -96.72 42.09, -96.71 42.09, -96.7 42.09, -96.69 42.09, -96.68 42.09, -96.67 42.09, -96.66 42.09, -96.65 42.09, -96.64 42.09, -96.63 42.09, -96.62 42.09, -96.61 42.09, -96.6 42.09, -96.59 42.09, -96.58 42.09, -96.57 42.09, -96.56 42.09, -96.56 42.08, -96.56 42.07, -96.56 42.06, -96.56 42.05, -96.56 42.04, -96.56 42.03, -96.55 42.03, -96.55 42.02, -96.56 42, -96.56 41.99, -96.56 41.98, -96.56 41.97, -96.56 41.96, -96.56 41.95, -96.56 41.94, -96.56 41.93, -96.56 41.92, -96.56 41.91, -96.56 41.9, -96.56 41.89, -96.56 41.88, -96.56 41.87, -96.56 41.86, -96.56 41.85, -96.56 41.84, -96.56 41.83, -96.56 41.82, -96.56 41.81, -96.56 41.8, -96.56 41.79, -96.56 41.78, -96.56 41.77, -96.56 41.76, -96.56 41.75, -96.56 41.74, -96.57 41.74, -96.58 41.74, -96.59 41.74, -96.6 41.74, -96.61 41.74, -96.62 41.74, -96.63 41.74, -96.64 41.74, -96.65 41.74, -96.66 41.74, -96.67 41.74, -96.68 41.74, -96.69 41.74, -96.7 41.74, -96.71 41.74, -96.72 41.74, -96.73 41.74, -96.74 41.74, -96.75 41.74, -96.76 41.74, -96.77 41.74, -96.78 41.74, -96.79 41.74, -96.8 41.74, -96.81 41.74, -96.82 41.74, -96.83 41.74, -96.84 41.74, -96.85 41.74, -96.86 41.74, -96.87 41.74, -96.88 41.74, -96.89 41.74, -96.9 41.74, -96.91 41.74, -96.92 41.74, -96.93 41.74, -96.94 41.74, -96.95 41.74, -96.96 41.74, -96.97 41.74, -96.98 41.74, -96.99 41.74, -97 41.74, -97.01 41.74, -97.02 41.74, -97.02 41.75, -97.02 41.76, -97.02 41.77, -97.02 41.78, -97.02 41.79, -97.02 41.8, -97.02 41.81, -97.02 41.82, -97.02 41.83, -97.02 41.84, -97.02 41.85, -97.02 41.86, -97.02 41.87, -97.02 41.88, -97.02 41.89, -97.02 41.9, -97.02 41.91, -97.02 41.92, -97.02 41.93, -97.02 41.94, -97.02 41.95, -97.02 41.96, -97.02 41.97, -97.02 41.98, -97.02 41.99, -97.02 42, -97.02 42.01))" + assert ( + actual + == "POLYGON ((-97.02 42.01, -97.02 42.02, -97.02 42.03, -97.02 42.04, -97.02 42.05, -97.02 42.06, -97.02 42.07, -97.02 42.08, -97.02 42.09, -97.01 42.09, -97 42.09, -96.99 42.09, -96.98 42.09, -96.97 42.09, -96.96 42.09, -96.95 42.09, -96.94 42.09, -96.93 42.09, -96.92 42.09, -96.91 42.09, -96.9 42.09, -96.89 42.09, -96.88 42.09, -96.87 42.09, -96.86 42.09, -96.85 42.09, -96.84 42.09, -96.83 42.09, -96.82 42.09, -96.81 42.09, -96.8 42.09, -96.79 42.09, -96.78 42.09, -96.77 42.09, -96.76 42.09, -96.75 42.09, -96.74 42.09, -96.73 42.09, -96.72 42.09, -96.71 42.09, -96.7 42.09, -96.69 42.09, -96.68 42.09, -96.67 42.09, -96.66 42.09, -96.65 42.09, -96.64 42.09, -96.63 42.09, -96.62 42.09, -96.61 42.09, -96.6 42.09, -96.59 42.09, -96.58 42.09, -96.57 42.09, -96.56 42.09, -96.56 42.08, -96.56 42.07, -96.56 42.06, -96.56 42.05, -96.56 42.04, -96.56 42.03, -96.55 42.03, -96.55 42.02, -96.56 42, -96.56 41.99, -96.56 41.98, -96.56 41.97, -96.56 41.96, 
-96.56 41.95, -96.56 41.94, -96.56 41.93, -96.56 41.92, -96.56 41.91, -96.56 41.9, -96.56 41.89, -96.56 41.88, -96.56 41.87, -96.56 41.86, -96.56 41.85, -96.56 41.84, -96.56 41.83, -96.56 41.82, -96.56 41.81, -96.56 41.8, -96.56 41.79, -96.56 41.78, -96.56 41.77, -96.56 41.76, -96.56 41.75, -96.56 41.74, -96.57 41.74, -96.58 41.74, -96.59 41.74, -96.6 41.74, -96.61 41.74, -96.62 41.74, -96.63 41.74, -96.64 41.74, -96.65 41.74, -96.66 41.74, -96.67 41.74, -96.68 41.74, -96.69 41.74, -96.7 41.74, -96.71 41.74, -96.72 41.74, -96.73 41.74, -96.74 41.74, -96.75 41.74, -96.76 41.74, -96.77 41.74, -96.78 41.74, -96.79 41.74, -96.8 41.74, -96.81 41.74, -96.82 41.74, -96.83 41.74, -96.84 41.74, -96.85 41.74, -96.86 41.74, -96.87 41.74, -96.88 41.74, -96.89 41.74, -96.9 41.74, -96.91 41.74, -96.92 41.74, -96.93 41.74, -96.94 41.74, -96.95 41.74, -96.96 41.74, -96.97 41.74, -96.98 41.74, -96.99 41.74, -97 41.74, -97.01 41.74, -97.02 41.74, -97.02 41.75, -97.02 41.76, -97.02 41.77, -97.02 41.78, -97.02 41.79, -97.02 41.8, -97.02 41.81, -97.02 41.82, -97.02 41.83, -97.02 41.84, -97.02 41.85, -97.02 41.86, -97.02 41.87, -97.02 41.88, -97.02 41.89, -97.02 41.9, -97.02 41.91, -97.02 41.92, -97.02 41.93, -97.02 41.94, -97.02 41.95, -97.02 41.96, -97.02 41.97, -97.02 41.98, -97.02 41.99, -97.02 42, -97.02 42.01))" + ) def test_st_bestsrid(self): - polygon_from_wkt = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_from_wkt = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_from_wkt.createOrReplaceTempView("polgontable") polygon_from_wkt.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_BestSRID(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_BestSRID(polygondf.countyshape) from polygondf" + ) function_df.show() actual = function_df.take(1)[0][0] assert actual == 3395 def test_st_bestsrid(self): - polygon_from_wkt = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). 
\ - load(mixed_wkt_geometry_input_location) + polygon_from_wkt = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_from_wkt.createOrReplaceTempView("polgontable") polygon_from_wkt.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_BestSRID(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_BestSRID(polygondf.countyshape) from polygondf" + ) function_df.show() actual = function_df.take(1)[0][0] assert actual == 3395 def test_st_shiftlongitude(self): - function_df = self.spark.sql("select ST_ShiftLongitude(ST_GeomFromWKT('POLYGON((179 10, -179 10, -179 20, 179 20, 179 10))'))") + function_df = self.spark.sql( + "select ST_ShiftLongitude(ST_GeomFromWKT('POLYGON((179 10, -179 10, -179 20, 179 20, 179 10))'))" + ) actual = function_df.take(1)[0][0].wkt assert actual == "POLYGON ((179 10, 181 10, 181 20, 179 20, 179 10))" - function_df = self.spark.sql("select ST_ShiftLongitude(ST_GeomFromWKT('POINT(-179 10)'))") + function_df = self.spark.sql( + "select ST_ShiftLongitude(ST_GeomFromWKT('POINT(-179 10)'))" + ) actual = function_df.take(1)[0][0].wkt assert actual == "POINT (181 10)" - function_df = self.spark.sql("select ST_ShiftLongitude(ST_GeomFromWKT('LINESTRING(179 10, 181 10)'))") + function_df = self.spark.sql( + "select ST_ShiftLongitude(ST_GeomFromWKT('LINESTRING(179 10, 181 10)'))" + ) actual = function_df.take(1)[0][0].wkt assert actual == "LINESTRING (179 10, -179 10)" def test_st_envelope(self): - polygon_from_wkt = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_from_wkt = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_from_wkt.createOrReplaceTempView("polygontable") polygon_from_wkt.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_Envelope(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Envelope(polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_expand(self): baseDf = self.spark.sql( - "SELECT ST_GeomFromWKT('POLYGON ((50 50 1, 50 80 2, 80 80 3, 80 50 2, 50 50 1))') as geom") + "SELECT ST_GeomFromWKT('POLYGON ((50 50 1, 50 80 2, 80 80 3, 80 50 2, 50 50 1))') as geom" + ) actual = baseDf.selectExpr("ST_AsText(ST_Expand(geom, 10))").first()[0] expected = "POLYGON Z((40 40 -9, 40 90 -9, 90 90 13, 90 40 13, 40 40 -9))" assert expected == actual @@ -187,146 +251,203 @@ def test_st_expand(self): assert expected == actual def test_st_centroid(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). 
\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_Centroid(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Centroid(polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_crossesdateline(self): crosses_test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON((175 10, -175 10, -175 -10, 175 -10, 175 10))') as geom") + "select ST_GeomFromWKT('POLYGON((175 10, -175 10, -175 -10, 175 -10, 175 10))') as geom" + ) crosses_test_table.createOrReplaceTempView("crossesTesttable") - crosses = self.spark.sql("select(ST_CrossesDateLine(geom)) from crossesTesttable") + crosses = self.spark.sql( + "select(ST_CrossesDateLine(geom)) from crossesTesttable" + ) not_crosses_test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as geom") + "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as geom" + ) not_crosses_test_table.createOrReplaceTempView("notCrossesTesttable") - not_crosses = self.spark.sql("select(ST_CrossesDateLine(geom)) from notCrossesTesttable") + not_crosses = self.spark.sql( + "select(ST_CrossesDateLine(geom)) from notCrossesTesttable" + ) assert crosses.take(1)[0][0] assert not not_crosses.take(1)[0][0] def test_st_length(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false").load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_Length(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Length(polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_length2d(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). 
\ - option("header", "false").load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") - function_df = self.spark.sql("select ST_Length2D(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Length2D(polygondf.countyshape) from polygondf" + ) assert function_df.take(1)[0][0] == 1.6244272911181594 def test_st_area(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_Area(polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Area(polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_distance(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). \ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - function_df = self.spark.sql("select ST_Distance(polygondf.countyshape, polygondf.countyshape) from polygondf") + function_df = self.spark.sql( + "select ST_Distance(polygondf.countyshape, polygondf.countyshape) from polygondf" + ) function_df.show() def test_st_3ddistance(self): - function_df = self.spark.sql("select ST_3DDistance(ST_PointZ(0.0, 0.0, 5.0), ST_PointZ(1.0, 1.0, -6.0))") + function_df = self.spark.sql( + "select ST_3DDistance(ST_PointZ(0.0, 0.0, 5.0), ST_PointZ(1.0, 1.0, -6.0))" + ) assert function_df.count() == 1 def test_st_transform(self): - polygon_wkt_df = self.spark.read.format("csv"). \ - option("delimiter", "\t"). \ - option("header", "false"). 
\ - load(mixed_wkt_geometry_input_location) + polygon_wkt_df = ( + self.spark.read.format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixed_wkt_geometry_input_location) + ) polygon_wkt_df.createOrReplaceTempView("polygontable") polygon_wkt_df.show() - polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable") + polygon_df = self.spark.sql( + "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() function_df = self.spark.sql( - "select ST_ReducePrecision(ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false), 2) from polygondf") + "select ST_ReducePrecision(ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false), 2) from polygondf" + ) actual = function_df.take(1)[0][0].wkt - assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" + assert ( + actual[:300] + == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" + ) function_df = self.spark.sql( "select ST_ReducePrecision(ST_Transform(ST_SetSRID(polygondf.countyshape, 4326), 'epsg:3857'), 2) from polygondf" ) actual = function_df.take(1)[0][0].wkt - assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" + assert ( + actual[:300] + == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" + ) def test_st_intersection_intersects_but_not_contains(self): test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON((1 1, 8 1, 8 8, 1 8, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 9 2, 9 9, 2 9, 2 2))') as b") + "select ST_GeomFromWKT('POLYGON((1 1, 8 1, 8 8, 1 8, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 9 2, 9 9, 2 9, 2 2))') as b" + ) test_table.createOrReplaceTempView("testtable") intersect = self.spark.sql("select ST_Intersection(a,b) from testtable") assert intersect.take(1)[0][0].wkt == "POLYGON ((2 8, 8 8, 8 2, 2 2, 2 8))" def test_st_intersection_intersects_but_left_contains_right(self): test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON((1 1, 1 5, 5 5, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 2 3, 3 3, 2 2))') as b") + "select ST_GeomFromWKT('POLYGON((1 1, 1 5, 5 5, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 2 3, 3 3, 2 2))') as b" + ) test_table.createOrReplaceTempView("testtable") intersects = self.spark.sql("select ST_Intersection(a,b) from testtable") assert intersects.take(1)[0][0].wkt == "POLYGON ((2 2, 2 3, 3 3, 2 2))" def test_st_intersection_intersects_but_right_contains_left(self): test_table = 
self.spark.sql( - "select ST_GeomFromWKT('POLYGON((2 2, 2 3, 3 3, 2 2))') as a,ST_GeomFromWKT('POLYGON((1 1, 1 5, 5 5, 1 1))') as b") + "select ST_GeomFromWKT('POLYGON((2 2, 2 3, 3 3, 2 2))') as a,ST_GeomFromWKT('POLYGON((1 1, 1 5, 5 5, 1 1))') as b" + ) test_table.createOrReplaceTempView("testtable") intersects = self.spark.sql("select ST_Intersection(a,b) from testtable") assert intersects.take(1)[0][0].wkt == "POLYGON ((2 2, 2 3, 3 3, 2 2))" def test_st_intersection_not_intersects(self): test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON((40 21, 40 22, 40 23, 40 21))') as a,ST_GeomFromWKT('POLYGON((2 2, 9 2, 9 9, 2 9, 2 2))') as b") + "select ST_GeomFromWKT('POLYGON((40 21, 40 22, 40 23, 40 21))') as a,ST_GeomFromWKT('POLYGON((2 2, 9 2, 9 9, 2 9, 2 2))') as b" + ) test_table.createOrReplaceTempView("testtable") intersects = self.spark.sql("select ST_Intersection(a,b) from testtable") assert intersects.take(1)[0][0].wkt == "POLYGON EMPTY" def test_st_maximum_inscribed_circle(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 50 90, 90 140, 60 140))') AS geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 50 90, 90 140, 60 140))') AS geom" + ) actual = baseDf.selectExpr("ST_MaximumInscribedCircle(geom)").take(1)[0][0] center = actual.center.wkt assert center == "POINT (96.953125 76.328125)" @@ -336,19 +457,28 @@ def test_st_maximum_inscribed_circle(self): assert radius == 45.165845650018 def test_st_is_valid_detail(self): - baseDf = self.spark.sql("SELECT ST_GeomFromText('POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 0 0))') AS geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromText('POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 0 0))') AS geom" + ) actual = baseDf.selectExpr("ST_IsValidDetail(geom)").first()[0] expected = Row(valid=True, reason=None, location=None) assert expected == actual - baseDf = self.spark.sql("SELECT ST_GeomFromText('POLYGON ((0 0, 2 0, 1 1, 2 2, 0 2, 1 1, 0 0))') AS geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromText('POLYGON ((0 0, 2 0, 1 1, 2 2, 0 2, 1 1, 0 0))') AS geom" + ) actual = baseDf.selectExpr("ST_IsValidDetail(geom)").first()[0] - expected = Row(valid=False, reason="Ring Self-intersection at or near point (1.0, 1.0, NaN)", location= - self.spark.sql("SELECT ST_GeomFromText('POINT (1 1)')").first()[0]) + expected = Row( + valid=False, + reason="Ring Self-intersection at or near point (1.0, 1.0, NaN)", + location=self.spark.sql("SELECT ST_GeomFromText('POINT (1 1)')").first()[0], + ) assert expected == actual def test_st_is_valid_trajectory(self): - baseDf = self.spark.sql("SELECT ST_GeomFromText('LINESTRING M (0 0 1, 0 1 2)') as geom1, ST_GeomFromText('LINESTRING M (0 0 1, 0 1 1)') as geom2") + baseDf = self.spark.sql( + "SELECT ST_GeomFromText('LINESTRING M (0 0 1, 0 1 2)') as geom1, ST_GeomFromText('LINESTRING M (0 0 1, 0 1 1)') as geom2" + ) actual = baseDf.selectExpr("ST_IsValidTrajectory(geom1)").first()[0] assert actual @@ -357,8 +487,8 @@ def test_st_is_valid_trajectory(self): def test_st_is_valid(self): test_table = self.spark.sql( - "SELECT ST_IsValid(ST_GeomFromWKT('POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (15 15, 15 20, 20 20, 20 15, 15 15))')) AS a, " + - "ST_IsValid(ST_GeomFromWKT('POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))')) as b" + "SELECT ST_IsValid(ST_GeomFromWKT('POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (15 15, 15 20, 20 20, 
20 15, 15 15))')) AS a, "
+            + "ST_IsValid(ST_GeomFromWKT('POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))')) as b"
         )
         assert not test_table.take(1)[0][0]
 
@@ -370,173 +500,255 @@ def test_fixed_null_pointer_exception_in_st_valid(self):
 
     def test_st_precision_reduce(self):
         test_table = self.spark.sql(
-            """SELECT ST_ReducePrecision(ST_GeomFromWKT('Point(0.1234567890123456789 0.1234567890123456789)'), 8)""")
+            """SELECT ST_ReducePrecision(ST_GeomFromWKT('Point(0.1234567890123456789 0.1234567890123456789)'), 8)"""
+        )
         test_table.show(truncate=False)
         assert test_table.take(1)[0][0].x == 0.12345679
         test_table = self.spark.sql(
-            """SELECT ST_ReducePrecision(ST_GeomFromWKT('Point(0.1234567890123456789 0.1234567890123456789)'), 11)""")
+            """SELECT ST_ReducePrecision(ST_GeomFromWKT('Point(0.1234567890123456789 0.1234567890123456789)'), 11)"""
+        )
         test_table.show(truncate=False)
         assert test_table.take(1)[0][0].x == 0.12345678901
 
     def test_st_is_simple(self):
         test_table = self.spark.sql(
-            "SELECT ST_IsSimple(ST_GeomFromText('POLYGON((1 1, 3 1, 3 3, 1 3, 1 1))')) AS a, " +
-            "ST_IsSimple(ST_GeomFromText('POLYGON((1 1,3 1,3 3,2 0,1 1))')) as b"
+            "SELECT ST_IsSimple(ST_GeomFromText('POLYGON((1 1, 3 1, 3 3, 1 3, 1 1))')) AS a, "
+            + "ST_IsSimple(ST_GeomFromText('POLYGON((1 1,3 1,3 3,2 0,1 1))')) as b"
         )
         assert test_table.take(1)[0][0]
         assert not test_table.take(1)[0][1]
 
     def test_st_as_text(self):
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
-        polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
+        polygon_df = self.spark.sql(
+            "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable"
+        )
         polygon_df.createOrReplaceTempView("polygondf")
         wkt_df = self.spark.sql("select ST_AsText(countyshape) as wkt from polygondf")
-        assert polygon_df.take(1)[0]["countyshape"].wkt == loads(wkt_df.take(1)[0]["wkt"]).wkt
+        assert (
+            polygon_df.take(1)[0]["countyshape"].wkt
+            == loads(wkt_df.take(1)[0]["wkt"]).wkt
+        )
 
     def test_st_astext_3d(self):
-        input_df = self.spark.createDataFrame([
-            ("Point(21 52 87)",),
-            ("Polygon((0 0 1, 0 1 1, 1 1 1, 1 0 1, 0 0 1))",),
-            ("Linestring(0 0 1, 1 1 2, 1 0 3)",),
-            ("MULTIPOINT ((10 40 66), (40 30 77), (20 20 88), (30 10 99))",),
-            (
-                "MULTIPOLYGON (((30 20 11, 45 40 11, 10 40 11, 30 20 11)), ((15 5 11, 40 10 11, 10 20 11, 5 10 11, 15 5 11)))",),
-            ("MULTILINESTRING ((10 10 11, 20 20 11, 10 40 11), (40 40 11, 30 30 11, 40 20 11, 30 10 11))",),
-            (
-                "MULTIPOLYGON (((40 40 11, 20 45 11, 45 30 11, 40 40 11)), ((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 11), (30 20 11, 20 15 11, 20 25 11, 30 20 11)))",),
-            ("POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), (1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))",),
-        ], ["wkt"])
+        input_df = self.spark.createDataFrame(
+            [
+                ("Point(21 52 87)",),
+                ("Polygon((0 0 1, 0 1 1, 1 1 1, 1 0 1, 0 0 1))",),
+                ("Linestring(0 0 1, 1 1 2, 1 0 3)",),
+                ("MULTIPOINT ((10 40 66), (40 30 77), (20 20 88), (30 10 99))",),
+                (
+                    "MULTIPOLYGON (((30 20 11, 45 40 11, 10 40 11, 30 20 11)), ((15 5 11, 40 10 11, 10 20 11, 5 10 11, 15 5 11)))",
+                ),
+                (
+                    "MULTILINESTRING ((10 10 11, 20 20 11, 10 40 11), (40 40 11, 30 30 11, 40 20 11, 30 10 11))",
+                ),
+                (
+                    "MULTIPOLYGON (((40 40 11, 20 45 11, 45 30 11, 40 40 11)), ((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 11), (30 20 11, 20 15 11, 20 25 11, 30 20 11)))",
+                ),
+                (
+                    "POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), (1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))",
+                ),
+            ],
+            ["wkt"],
+        )
         input_df.createOrReplaceTempView("input_wkt")
-        polygon_df = self.spark.sql("select ST_AsText(ST_GeomFromWkt(wkt)) as wkt from input_wkt")
+        polygon_df = self.spark.sql(
+            "select ST_AsText(ST_GeomFromWkt(wkt)) as wkt from input_wkt"
+        )
         assert polygon_df.count() == 8
 
     def test_st_as_text_3d(self):
-        polygon_wkt_df = self.spark.read.format("csv"). \
-            option("delimiter", "\t"). \
-            option("header", "false"). \
-            load(mixed_wkt_geometry_input_location)
+        polygon_wkt_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", "\t")
+            .option("header", "false")
+            .load(mixed_wkt_geometry_input_location)
+        )
 
         polygon_wkt_df.createOrReplaceTempView("polygontable")
-        polygon_df = self.spark.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
+        polygon_df = self.spark.sql(
+            "select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable"
+        )
        polygon_df.createOrReplaceTempView("polygondf")
         wkt_df = self.spark.sql("select ST_AsText(countyshape) as wkt from polygondf")
-        assert polygon_df.take(1)[0]["countyshape"].wkt == loads(wkt_df.take(1)[0]["wkt"]).wkt
+        assert (
+            polygon_df.take(1)[0]["countyshape"].wkt
+            == loads(wkt_df.take(1)[0]["wkt"]).wkt
+        )
 
     def test_st_n_points(self):
         test = self.spark.sql(
-            "SELECT ST_NPoints(ST_GeomFromText('LINESTRING(77.29 29.07,77.42 29.26,77.27 29.31,77.29 29.07)'))")
+            "SELECT ST_NPoints(ST_GeomFromText('LINESTRING(77.29 29.07,77.42 29.26,77.27 29.31,77.29 29.07)'))"
+        )
 
     def test_st_geometry_type(self):
         test = self.spark.sql(
-            "SELECT ST_GeometryType(ST_GeomFromText('LINESTRING(77.29 29.07,77.42 29.26,77.27 29.31,77.29 29.07)'))")
+            "SELECT ST_GeometryType(ST_GeomFromText('LINESTRING(77.29 29.07,77.42 29.26,77.27 29.31,77.29 29.07)'))"
+        )
 
     def test_st_difference_right_overlaps_left(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((0 -4, 4 -4, 4 4, 0 4, 0 -4))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((0 -4, 4 -4, 4 4, 0 4, 0 -4))') as b"
+        )
         test_table.createOrReplaceTempView("test_diff")
         diff = self.spark.sql("select ST_Difference(a,b) from test_diff")
         assert diff.take(1)[0][0].wkt == "POLYGON ((0 -3, -3 -3, -3 3, 0 3, 0 -3))"
 
     def test_st_difference_right_not_overlaps_left(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b"
+        )
         test_table.createOrReplaceTempView("test_diff")
         diff = self.spark.sql("select ST_Difference(a,b) from test_diff")
         assert diff.take(1)[0][0].wkt == "POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))"
 
     def test_st_difference_left_contains_right(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as b"
+        )
         test_table.createOrReplaceTempView("test_diff")
         diff = self.spark.sql("select ST_Difference(a,b) from test_diff")
-        assert diff.take(1)[0][0].wkt == "POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3), (-1 -1, 1 -1, 1 1, -1 1, -1 -1))"
+        assert (
+            diff.take(1)[0][0].wkt
+            == "POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3), (-1 -1, 1 -1, 1 1, -1 1, -1 -1))"
+        )
 
     def test_st_difference_right_not_overlaps_left(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as b"
+        )
         test_table.createOrReplaceTempView("test_diff")
         diff = self.spark.sql("select ST_Difference(a,b) from test_diff")
         assert diff.take(1)[0][0].wkt == "POLYGON EMPTY"
 
     def test_st_delaunay_triangles(self):
-        baseDf = self.spark.sql("SELECT ST_GeomFromWKT('MULTIPOLYGON (((10 10, 10 20, 20 20, 20 10, 10 10)),((25 10, 25 20, 35 20, 35 10, 25 10)))') AS geom")
+        baseDf = self.spark.sql(
+            "SELECT ST_GeomFromWKT('MULTIPOLYGON (((10 10, 10 20, 20 20, 20 10, 10 10)),((25 10, 25 20, 35 20, 35 10, 25 10)))') AS geom"
+        )
         actual = baseDf.selectExpr("ST_DelaunayTriangles(geom)").take(1)[0][0].wkt
         expected = "GEOMETRYCOLLECTION (POLYGON ((10 20, 10 10, 20 10, 10 20)), POLYGON ((10 20, 20 10, 20 20, 10 20)), POLYGON ((20 20, 20 10, 25 10, 20 20)), POLYGON ((20 20, 25 10, 25 20, 20 20)), POLYGON ((25 20, 25 10, 35 10, 25 20)), POLYGON ((25 20, 35 10, 35 20, 25 20)))"
         assert expected == actual
 
     def test_st_sym_difference_part_of_right_overlaps_left(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((0 -2, 2 -2, 2 0, 0 0, 0 -2))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((0 -2, 2 -2, 2 0, 0 0, 0 -2))') as b"
+        )
         test_table.createOrReplaceTempView("test_sym_diff")
         diff = self.spark.sql("select ST_SymDifference(a,b) from test_sym_diff")
-        assert diff.take(1)[0][
-            0].wkt == "MULTIPOLYGON (((0 -1, -1 -1, -1 1, 1 1, 1 0, 0 0, 0 -1)), ((0 -1, 1 -1, 1 0, 2 0, 2 -2, 0 -2, 0 -1)))"
+        assert (
+            diff.take(1)[0][0].wkt
+            == "MULTIPOLYGON (((0 -1, -1 -1, -1 1, 1 1, 1 0, 0 0, 0 -1)), ((0 -1, 1 -1, 1 0, 2 0, 2 -2, 0 -2, 0 -1)))"
+        )
 
     def test_st_sym_difference_not_overlaps_left(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b"
+        )
         test_table.createOrReplaceTempView("test_sym_diff")
         diff = self.spark.sql("select ST_SymDifference(a,b) from test_sym_diff")
-        assert diff.take(1)[0][
-            0].wkt == "MULTIPOLYGON (((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))"
+        assert (
+            diff.take(1)[0][0].wkt
+            == "MULTIPOLYGON (((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))"
+        )
 
     def test_st_sym_difference_contains(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as b")
+            "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as b"
+        )
         test_table.createOrReplaceTempView("test_sym_diff")
         diff = self.spark.sql("select ST_SymDifference(a,b) from test_sym_diff")
test_sym_diff") - assert diff.take(1)[0][0].wkt == "POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3), (-1 -1, 1 -1, 1 1, -1 1, -1 -1))" + assert ( + diff.take(1)[0][0].wkt + == "POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3), (-1 -1, 1 -1, 1 1, -1 1, -1 -1))" + ) def test_st_union_part_of_right_overlaps_left(self): test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a, ST_GeomFromWKT('POLYGON ((-2 1, 2 1, 2 4, -2 4, -2 1))') as b") + "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a, ST_GeomFromWKT('POLYGON ((-2 1, 2 1, 2 4, -2 4, -2 1))') as b" + ) test_table.createOrReplaceTempView("test_union") union = self.spark.sql("select ST_Union(a,b) from test_union") - assert union.take(1)[0][0].wkt == "POLYGON ((2 3, 3 3, 3 -3, -3 -3, -3 3, -2 3, -2 4, 2 4, 2 3))" + assert ( + union.take(1)[0][0].wkt + == "POLYGON ((2 3, 3 3, 3 -3, -3 -3, -3 3, -2 3, -2 4, 2 4, 2 3))" + ) def test_st_union_not_overlaps_left(self): test_table = self.spark.sql( - "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b") + "select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b" + ) test_table.createOrReplaceTempView("test_union") union = self.spark.sql("select ST_Union(a,b) from test_union") - assert union.take(1)[0][ - 0].wkt == "MULTIPOLYGON (((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))" + assert ( + union.take(1)[0][0].wkt + == "MULTIPOLYGON (((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))" + ) def test_st_union_array_variant(self): - test_table = self.spark.sql("select array(ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))'),ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))'), ST_GeomFromWKT('POLYGON((4 4, 4 6, 6 6, 6 4, 4 4))')) as polys") + test_table = self.spark.sql( + "select array(ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))'),ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))'), ST_GeomFromWKT('POLYGON((4 4, 4 6, 6 6, 6 4, 4 4))')) as polys" + ) actual = test_table.selectExpr("ST_Union(polys)").take(1)[0][0].wkt expected = "MULTIPOLYGON (((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)), ((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((4 4, 4 6, 6 6, 6 4, 4 4)))" assert expected == actual def test_st_unary_union(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('MULTIPOLYGON(((0 10,0 30,20 30,20 10,0 10)),((10 0,10 20,30 20,30 0,10 0)))') AS geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOLYGON(((0 10,0 30,20 30,20 10,0 10)),((10 0,10 20,30 20,30 0,10 0)))') AS geom" + ) actual = baseDf.selectExpr("ST_UnaryUnion(geom)").take(1)[0][0].wkt - expected = "POLYGON ((10 0, 10 10, 0 10, 0 30, 20 30, 20 20, 30 20, 30 0, 10 0))" + expected = ( + "POLYGON ((10 0, 10 10, 0 10, 0 30, 20 30, 20 20, 30 20, 30 0, 10 0))" + ) assert expected == actual def test_st_azimuth(self): sample_points = create_sample_points(20) sample_pair_points = [[el, sample_points[1]] for el in sample_points] - schema = StructType([ - StructField("geomA", GeometryType(), True), - StructField("geomB", GeometryType(), True) - ]) + schema = StructType( + [ + StructField("geomA", GeometryType(), True), + StructField("geomB", GeometryType(), True), + ] + ) df = self.spark.createDataFrame(sample_pair_points, schema) - st_azimuth_result = [el[0] * 180 / math.pi for el in df.selectExpr("ST_Azimuth(geomA, geomB)").collect()] + st_azimuth_result = [ + el[0] * 
180 / math.pi + for el in df.selectExpr("ST_Azimuth(geomA, geomB)").collect() + ] expected_result = [ - 240.0133139011053, 0.0, 270.0, 286.8042682202057, 315.0, 314.9543472191815, 315.0058223408927, - 245.14762725688198, 314.84984546897755, 314.8868529256147, 314.9510567053395, 314.95443984912936, - 314.89925480835245, 314.60187991438806, 314.6834083423315, 314.80689827870725, 314.90290827689506, - 314.90336326341765, 314.7510398533675, 314.73608518601935 + 240.0133139011053, + 0.0, + 270.0, + 286.8042682202057, + 315.0, + 314.9543472191815, + 315.0058223408927, + 245.14762725688198, + 314.84984546897755, + 314.8868529256147, + 314.9510567053395, + 314.95443984912936, + 314.89925480835245, + 314.60187991438806, + 314.6834083423315, + 314.80689827870725, + 314.90290827689506, + 314.90336326341765, + 314.7510398533675, + 314.73608518601935, ] assert st_azimuth_result == expected_result @@ -547,7 +759,10 @@ def test_st_azimuth(self): """ ).collect() - azimuths = [[azimuth1 * 180 / math.pi, azimuth2 * 180 / math.pi] for azimuth1, azimuth2 in azimuth] + azimuths = [ + [azimuth1 * 180 / math.pi, azimuth2 * 180 / math.pi] + for azimuth1, azimuth2 in azimuth + ] assert azimuths[0] == [42.27368900609373, 222.27368900609372] def test_st_x(self): @@ -555,36 +770,50 @@ def test_st_x(self): polygon_df = create_sample_polygons_df(self.spark, 5) linestring_df = create_sample_lines_df(self.spark, 5) - points = point_df \ - .selectExpr("ST_X(geom)").collect() + points = point_df.selectExpr("ST_X(geom)").collect() polygons = polygon_df.selectExpr("ST_X(geom) as x").filter("x IS NOT NULL") - linestrings = linestring_df.selectExpr("ST_X(geom) as x").filter("x IS NOT NULL") + linestrings = linestring_df.selectExpr("ST_X(geom) as x").filter( + "x IS NOT NULL" + ) - assert ([point[0] for point in points] == [-71.064544, -88.331492, 88.331492, 1.0453, 32.324142]) + assert [point[0] for point in points] == [ + -71.064544, + -88.331492, + 88.331492, + 1.0453, + 32.324142, + ] - assert (not linestrings.count()) + assert not linestrings.count() - assert (not polygons.count()) + assert not polygons.count() def test_st_y(self): point_df = create_sample_points_df(self.spark, 5) polygon_df = create_sample_polygons_df(self.spark, 5) linestring_df = create_sample_lines_df(self.spark, 5) - points = point_df \ - .selectExpr("ST_Y(geom)").collect() + points = point_df.selectExpr("ST_Y(geom)").collect() polygons = polygon_df.selectExpr("ST_Y(geom) as y").filter("y IS NOT NULL") - linestrings = linestring_df.selectExpr("ST_Y(geom) as y").filter("y IS NOT NULL") + linestrings = linestring_df.selectExpr("ST_Y(geom) as y").filter( + "y IS NOT NULL" + ) - assert ([point[0] for point in points] == [42.28787, 32.324142, 32.324142, 5.3324324, -88.331492]) + assert [point[0] for point in points] == [ + 42.28787, + 32.324142, + 32.324142, + 5.3324324, + -88.331492, + ] - assert (not linestrings.count()) + assert not linestrings.count() - assert (not polygons.count()) + assert not polygons.count() def test_st_z(self): point_df = self.spark.sql( @@ -597,47 +826,67 @@ def test_st_z(self): "select ST_GeomFromWKT('LINESTRING Z (0 0 1, 0 1 2)') as geom" ) - points = point_df \ - .selectExpr("ST_Z(geom)").collect() + points = point_df.selectExpr("ST_Z(geom)").collect() polygons = polygon_df.selectExpr("ST_Z(geom) as z").filter("z IS NOT NULL") - linestrings = linestring_df.selectExpr("ST_Z(geom) as z").filter("z IS NOT NULL") + linestrings = linestring_df.selectExpr("ST_Z(geom) as z").filter( + "z IS NOT NULL" + ) - assert ([point[0] for point 
in points] == [3.3]) + assert [point[0] for point in points] == [3.3] - assert (not linestrings.count()) + assert not linestrings.count() - assert (not polygons.count()) + assert not polygons.count() def test_st_zmflag(self): - actual = self.spark.sql("SELECT ST_Zmflag(ST_GeomFromWKT('POINT (1 2)'))").take(1)[0][0] + actual = self.spark.sql("SELECT ST_Zmflag(ST_GeomFromWKT('POINT (1 2)'))").take( + 1 + )[0][0] assert actual == 0 - actual = self.spark.sql("SELECT ST_Zmflag(ST_GeomFromWKT('LINESTRING (1 2 3, 4 5 6)'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_Zmflag(ST_GeomFromWKT('LINESTRING (1 2 3, 4 5 6)'))" + ).take(1)[0][0] assert actual == 2 - actual = self.spark.sql("SELECT ST_Zmflag(ST_GeomFromWKT('POLYGON M((1 2 3, 3 4 3, 5 6 3, 3 4 3, 1 2 3))'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_Zmflag(ST_GeomFromWKT('POLYGON M((1 2 3, 3 4 3, 5 6 3, 3 4 3, 1 2 3))'))" + ).take(1)[0][0] assert actual == 1 - actual = self.spark.sql("SELECT ST_Zmflag(ST_GeomFromWKT('MULTIPOLYGON ZM (((30 10 5 1, 40 40 10 2, 20 40 15 3, 10 20 20 4, 30 10 5 1)), ((15 5 3 1, 20 10 6 2, 10 10 7 3, 15 5 3 1)))'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_Zmflag(ST_GeomFromWKT('MULTIPOLYGON ZM (((30 10 5 1, 40 40 10 2, 20 40 15 3, 10 20 20 4, 30 10 5 1)), ((15 5 3 1, 20 10 6 2, 10 10 7 3, 15 5 3 1)))'))" + ).take(1)[0][0] assert actual == 3 def test_st_z_max(self): - linestring_df = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING Z (0 0 1, 0 1 2)') as geom") - linestring_row = [lnstr_row[0] for lnstr_row in linestring_df.selectExpr("ST_ZMax(geom)").collect()] - assert (linestring_row == [2.0]) + linestring_df = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING Z (0 0 1, 0 1 2)') as geom" + ) + linestring_row = [ + lnstr_row[0] + for lnstr_row in linestring_df.selectExpr("ST_ZMax(geom)").collect() + ] + assert linestring_row == [2.0] def test_st_z_min(self): linestring_df = self.spark.sql( - "SELECT ST_GeomFromWKT('POLYGON Z ((0 0 2, 0 1 1, 1 1 2, 1 0 2, 0 0 2))') as geom") - linestring_row = [lnstr_row[0] for lnstr_row in linestring_df.selectExpr("ST_ZMin(geom)").collect()] - assert (linestring_row == [1.0]) + "SELECT ST_GeomFromWKT('POLYGON Z ((0 0 2, 0 1 1, 1 1 2, 1 0 2, 0 0 2))') as geom" + ) + linestring_row = [ + lnstr_row[0] + for lnstr_row in linestring_df.selectExpr("ST_ZMin(geom)").collect() + ] + assert linestring_row == [1.0] def test_st_n_dims(self): point_df = self.spark.sql("SELECT ST_GeomFromWKT('POINT(1 1 2)') as geom") - point_row = [pt_row[0] for pt_row in point_df.selectExpr("ST_NDims(geom)").collect()] - assert (point_row == [3]) + point_row = [ + pt_row[0] for pt_row in point_df.selectExpr("ST_NDims(geom)").collect() + ] + assert point_row == [3] def test_st_start_point(self): @@ -650,62 +899,85 @@ def test_st_start_point(self): "POINT (-112.519856 45.983586)", "POINT (-112.504872 45.919281)", "POINT (-112.574945 45.987772)", - "POINT (-112.520691 42.912313)" + "POINT (-112.520691 42.912313)", ] - points = point_df.selectExpr("ST_StartPoint(geom) as geom").filter("geom IS NOT NULL") + points = point_df.selectExpr("ST_StartPoint(geom) as geom").filter( + "geom IS NOT NULL" + ) - polygons = polygon_df.selectExpr("ST_StartPoint(geom) as geom").filter("geom IS NOT NULL") + polygons = polygon_df.selectExpr("ST_StartPoint(geom) as geom").filter( + "geom IS NOT NULL" + ) - linestrings = linestring_df.selectExpr("ST_StartPoint(geom) as geom").filter("geom IS NOT NULL") + linestrings = linestring_df.selectExpr("ST_StartPoint(geom) as geom").filter( + "geom 
IS NOT NULL" + ) - assert ([line[0] for line in linestrings.collect()] == [wkt.loads(el) for el in expected_points]) + assert [line[0] for line in linestrings.collect()] == [ + wkt.loads(el) for el in expected_points + ] - assert (not points.count()) + assert not points.count() - assert (not polygons.count()) + assert not polygons.count() def test_st_snap(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON((2.6 12.5, 2.6 20.0, 12.6 20.0, 12.6 12.5, 2.6 12.5 " - "))') AS poly, ST_GeomFromWKT('LINESTRING (0.5 10.7, 5.4 8.4, 10.1 10.0)') AS line") - actual = baseDf.selectExpr("ST_AsText(ST_Snap(poly, line, 2.525))").take(1)[0][0] + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON((2.6 12.5, 2.6 20.0, 12.6 20.0, 12.6 12.5, 2.6 12.5 " + "))') AS poly, ST_GeomFromWKT('LINESTRING (0.5 10.7, 5.4 8.4, 10.1 10.0)') AS line" + ) + actual = baseDf.selectExpr("ST_AsText(ST_Snap(poly, line, 2.525))").take(1)[0][ + 0 + ] expected = "POLYGON ((2.6 12.5, 2.6 20, 12.6 20, 12.6 12.5, 10.1 10, 2.6 12.5))" - assert (expected == actual) + assert expected == actual - actual = baseDf.selectExpr("ST_AsText(ST_Snap(poly, line, 3.125))").take(1)[0][0] + actual = baseDf.selectExpr("ST_AsText(ST_Snap(poly, line, 3.125))").take(1)[0][ + 0 + ] expected = "POLYGON ((0.5 10.7, 2.6 20, 12.6 20, 12.6 12.5, 10.1 10, 5.4 8.4, 0.5 10.7))" - assert (expected == actual) + assert expected == actual def test_st_end_point(self): linestring_dataframe = create_sample_lines_df(self.spark, 5) - other_geometry_dataframe = create_sample_points_df(self.spark, 5). \ - union(create_sample_points_df(self.spark, 5)) + other_geometry_dataframe = create_sample_points_df(self.spark, 5).union( + create_sample_points_df(self.spark, 5) + ) - point_data_frame = linestring_dataframe.selectExpr("ST_EndPoint(geom) as geom"). \ - filter("geom IS NOT NULL") + point_data_frame = linestring_dataframe.selectExpr( + "ST_EndPoint(geom) as geom" + ).filter("geom IS NOT NULL") expected_ending_points = [ "POINT (-112.504872 45.98186)", "POINT (-112.506968 45.983586)", "POINT (-112.41643 45.919281)", "POINT (-112.519856 45.987772)", - "POINT (-112.442664 42.912313)" + "POINT (-112.442664 42.912313)", ] - empty_dataframe = other_geometry_dataframe.selectExpr("ST_EndPoint(geom) as geom"). 
\ - filter("geom IS NOT NULL") + empty_dataframe = other_geometry_dataframe.selectExpr( + "ST_EndPoint(geom) as geom" + ).filter("geom IS NOT NULL") - assert ([wkt_row[0] - for wkt_row in point_data_frame.selectExpr("ST_AsText(geom)").collect()] == expected_ending_points) + assert [ + wkt_row[0] + for wkt_row in point_data_frame.selectExpr("ST_AsText(geom)").collect() + ] == expected_ending_points - assert (empty_dataframe.count() == 0) + assert empty_dataframe.count() == 0 def test_st_minimum_clearance(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((65 18, 62 16, 64.5 16, 62 14, 65 14, 65 18))') as geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((65 18, 62 16, 64.5 16, 62 14, 65 14, 65 18))') as geom" + ) actual = baseDf.selectExpr("ST_MinimumClearance(geom)").take(1)[0][0] assert actual == 0.5 def test_st_minimum_clearance_line(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((65 18, 62 16, 64.5 16, 62 14, 65 14, 65 18))') as geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((65 18, 62 16, 64.5 16, 62 14, 65 14, 65 18))') as geom" + ) actual = baseDf.selectExpr("ST_MinimumClearanceLine(geom)").take(1)[0][0].wkt assert actual == "LINESTRING (64.5 16, 65 16)" @@ -714,14 +986,12 @@ def test_st_boundary(self): "LINESTRING(1 1,0 0, -1 1)", "LINESTRING(100 150,50 60, 70 80, 160 170)", "POLYGON (( 10 130, 50 190, 110 190, 140 150, 150 80, 100 10, 20 40, 10 130 ),( 70 40, 100 50, 120 80, 80 110, 50 90, 70 40 ))", - "POLYGON((1 1,0 0, -1 1, 1 1))" + "POLYGON((1 1,0 0, -1 1, 1 1))", ] geometries = [[wkt.loads(wkt_data)] for wkt_data in wkt_list] - schema = StructType( - [StructField("geom", GeometryType(), False)] - ) + schema = StructType([StructField("geom", GeometryType(), False)]) geometry_table = self.spark.createDataFrame(geometries, schema) @@ -729,70 +999,107 @@ def test_st_boundary(self): boundary_table = geometry_table.selectExpr("ST_Boundary(geom) as geom") - boundary_wkt = [wkt_row[0] for wkt_row in boundary_table.selectExpr("ST_AsText(geom)").collect()] - assert (boundary_wkt == [ + boundary_wkt = [ + wkt_row[0] + for wkt_row in boundary_table.selectExpr("ST_AsText(geom)").collect() + ] + assert boundary_wkt == [ "MULTIPOINT ((1 1), (-1 1))", "MULTIPOINT ((100 150), (160 170))", "MULTILINESTRING ((10 130, 50 190, 110 190, 140 150, 150 80, 100 10, 20 40, 10 130), (70 40, 100 50, 120 80, 80 110, 50 90, 70 40))", - "LINESTRING (1 1, 0 0, -1 1, 1 1)" - ]) + "LINESTRING (1 1, 0 0, -1 1, 1 1)", + ] def test_st_exterior_ring(self): polygon_df = create_simple_polygons_df(self.spark, 5) additional_wkt = "POLYGON((0 0, 1 1, 1 2, 1 1, 0 0))" - additional_wkt_df = self.spark.createDataFrame([[wkt.loads(additional_wkt)]], self.geo_schema) + additional_wkt_df = self.spark.createDataFrame( + [[wkt.loads(additional_wkt)]], self.geo_schema + ) polygons_df = polygon_df.union(additional_wkt_df) other_geometry_df = create_sample_lines_df(self.spark, 5).union( - create_sample_points_df(self.spark, 5)) + create_sample_points_df(self.spark, 5) + ) - linestring_df = polygons_df.selectExpr("ST_ExteriorRing(geom) as geom").filter("geom IS NOT NULL") + linestring_df = polygons_df.selectExpr("ST_ExteriorRing(geom) as geom").filter( + "geom IS NOT NULL" + ) - empty_df = other_geometry_df.selectExpr("ST_ExteriorRing(geom) as geom").filter("geom IS NOT NULL") + empty_df = other_geometry_df.selectExpr("ST_ExteriorRing(geom) as geom").filter( + "geom IS NOT NULL" + ) - linestring_wkt = [wkt_row[0] for wkt_row in 
linestring_df.selectExpr("ST_AsText(geom)").collect()] + linestring_wkt = [ + wkt_row[0] + for wkt_row in linestring_df.selectExpr("ST_AsText(geom)").collect() + ] - assert (linestring_wkt == ["LINESTRING (0 0, 0 1, 1 1, 1 0, 0 0)", "LINESTRING (0 0, 1 1, 1 2, 1 1, 0 0)"]) + assert linestring_wkt == [ + "LINESTRING (0 0, 0 1, 1 1, 1 0, 0 0)", + "LINESTRING (0 0, 1 1, 1 2, 1 1, 0 0)", + ] - assert (not empty_df.count()) + assert not empty_df.count() def test_st_geometry_n(self): - data_frame = self.__wkt_list_to_data_frame(["MULTIPOINT((1 2), (3 4), (5 6), (8 9))"]) - wkts = [data_frame.selectExpr(f"ST_GeometryN(geom, {i}) as geom").selectExpr("st_asText(geom)").collect()[0][0] - for i in range(0, 4)] + data_frame = self.__wkt_list_to_data_frame( + ["MULTIPOINT((1 2), (3 4), (5 6), (8 9))"] + ) + wkts = [ + data_frame.selectExpr(f"ST_GeometryN(geom, {i}) as geom") + .selectExpr("st_asText(geom)") + .collect()[0][0] + for i in range(0, 4) + ] - assert (wkts == ["POINT (1 2)", "POINT (3 4)", "POINT (5 6)", "POINT (8 9)"]) + assert wkts == ["POINT (1 2)", "POINT (3 4)", "POINT (5 6)", "POINT (8 9)"] def test_st_interior_ring_n(self): polygon_df = self.__wkt_list_to_data_frame( [ - "POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1), (1 3, 2 3, 2 4, 1 4, 1 3), (3 3, 4 3, 4 4, 3 4, 3 3))"] + "POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1), (1 3, 2 3, 2 4, 1 4, 1 3), (3 3, 4 3, 4 4, 3 4, 3 3))" + ] ) - other_geometry = create_sample_points_df(self.spark, 5).union(create_sample_lines_df(self.spark, 5)) - wholes = [polygon_df.selectExpr(f"ST_InteriorRingN(geom, {i}) as geom"). - selectExpr("ST_AsText(geom)").collect()[0][0] - for i in range(3)] + other_geometry = create_sample_points_df(self.spark, 5).union( + create_sample_lines_df(self.spark, 5) + ) + wholes = [ + polygon_df.selectExpr(f"ST_InteriorRingN(geom, {i}) as geom") + .selectExpr("ST_AsText(geom)") + .collect()[0][0] + for i in range(3) + ] - empty_df = other_geometry.selectExpr("ST_InteriorRingN(geom, 1) as geom").filter("geom IS NOT NULL") + empty_df = other_geometry.selectExpr( + "ST_InteriorRingN(geom, 1) as geom" + ).filter("geom IS NOT NULL") - assert (not empty_df.count()) - assert (wholes == ["LINESTRING (1 1, 2 1, 2 2, 1 2, 1 1)", - "LINESTRING (1 3, 2 3, 2 4, 1 4, 1 3)", - "LINESTRING (3 3, 4 3, 4 4, 3 4, 3 3)"]) + assert not empty_df.count() + assert wholes == [ + "LINESTRING (1 1, 2 1, 2 2, 1 2, 1 1)", + "LINESTRING (1 3, 2 3, 2 4, 1 4, 1 3)", + "LINESTRING (3 3, 4 3, 4 4, 3 4, 3 3)", + ] def test_st_dumps(self): expected_geometries = [ - "POINT (21 52)", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", + "POINT (21 52)", + "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", "LINESTRING (0 0, 1 1, 1 0)", - "POINT (10 40)", "POINT (40 30)", "POINT (20 20)", "POINT (30 10)", + "POINT (10 40)", + "POINT (40 30)", + "POINT (20 20)", + "POINT (30 10)", "POLYGON ((30 20, 45 40, 10 40, 30 20))", - "POLYGON ((15 5, 40 10, 10 20, 5 10, 15 5))", "LINESTRING (10 10, 20 20, 10 40)", + "POLYGON ((15 5, 40 10, 10 20, 5 10, 15 5))", + "LINESTRING (10 10, 20 20, 10 40)", "LINESTRING (40 40, 30 30, 40 20, 30 10)", "POLYGON ((40 40, 20 45, 45 30, 40 40))", "POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35), (30 20, 20 15, 20 25, 30 20))", - "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))" + "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", ] geometry_df = self.__wkt_list_to_data_frame( @@ -804,20 +1111,21 @@ def test_st_dumps(self): "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 
5)))", "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", "MULTIPOLYGON (((40 40, 20 45, 45 30, 40 40)), ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35), (30 20, 20 15, 20 25, 30 20)))", - "POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))" + "POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", ] ) dumped_geometries = geometry_df.selectExpr("ST_Dump(geom) as geom") - assert (dumped_geometries.select(explode(col("geom"))).count() == 14) + assert dumped_geometries.select(explode(col("geom"))).count() == 14 - collected_geometries = dumped_geometries \ - .select(explode(col("geom")).alias("geom")) \ - .selectExpr("ST_AsText(geom) as geom") \ + collected_geometries = ( + dumped_geometries.select(explode(col("geom")).alias("geom")) + .selectExpr("ST_AsText(geom) as geom") .collect() + ) - assert ([geom_row[0] for geom_row in collected_geometries] == expected_geometries) + assert [geom_row[0] for geom_row in collected_geometries] == expected_geometries def test_st_dump_points(self): expected_points = [ @@ -826,21 +1134,29 @@ def test_st_dump_points(self): "POINT (-112.504872 45.983586)", "POINT (-112.504872 45.98186)", "POINT (-71.064544 42.28787)", - "POINT (0 0)", "POINT (0 1)", - "POINT (1 1)", "POINT (1 0)", - "POINT (0 0)" + "POINT (0 0)", + "POINT (0 1)", + "POINT (1 1)", + "POINT (1 0)", + "POINT (0 0)", ] - geometry_df = create_sample_lines_df(self.spark, 1) \ - .union(create_sample_points_df(self.spark, 1)) \ + geometry_df = ( + create_sample_lines_df(self.spark, 1) + .union(create_sample_points_df(self.spark, 1)) .union(create_simple_polygons_df(self.spark, 1)) + ) - dumped_points = geometry_df.selectExpr("ST_DumpPoints(geom) as geom") \ - .select(explode(col("geom")).alias("geom")) + dumped_points = geometry_df.selectExpr("ST_DumpPoints(geom) as geom").select( + explode(col("geom")).alias("geom") + ) - assert (dumped_points.count() == 10) + assert dumped_points.count() == 10 - collected_points = [geom_row[0] for geom_row in dumped_points.selectExpr("ST_AsText(geom)").collect()] - assert (collected_points == expected_points) + collected_points = [ + geom_row[0] + for geom_row in dumped_points.selectExpr("ST_AsText(geom)").collect() + ] + assert collected_points == expected_points def test_st_is_closed(self): expected_result = [ @@ -853,7 +1169,7 @@ def test_st_is_closed(self): [7, True], [8, False], [9, False], - [10, False] + [10, False], ] geometry_list = [ (1, "Point(21 52)"), @@ -861,18 +1177,32 @@ def test_st_is_closed(self): (3, "Linestring(0 0, 1 1, 1 0)"), (4, "Linestring(0 0, 1 1, 1 0, 0 0)"), (5, "MULTIPOINT ((10 40), (40 30), (20 20), (30 10))"), - (6, "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))"), - (7, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))"), - (8, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))"), - (9, "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))"), - (10, - "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))") + ( + 6, + "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", + ), + ( + 7, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", + ), + ( + 8, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))", + ), + ( + 9, + "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", + ), + ( + 10, + "GEOMETRYCOLLECTION (POINT (40 
10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))", + ), ] geometry_df = self.__wkt_pair_list_with_index_to_data_frame(geometry_list) is_closed = geometry_df.selectExpr("index", "ST_IsClosed(geom)").collect() is_closed_collected = [[*row] for row in is_closed] - assert (is_closed_collected == expected_result) + assert is_closed_collected == expected_result def test_num_interior_rings(self): geometries = [ @@ -881,19 +1211,39 @@ def test_num_interior_rings(self): (3, "Linestring(0 0, 1 1, 1 0)"), (4, "Linestring(0 0, 1 1, 1 0, 0 0)"), (5, "MULTIPOINT ((10 40), (40 30), (20 20), (30 10))"), - (6, "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))"), - (7, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))"), - (8, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))"), - (9, "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))"), - (10, - "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))"), - (11, "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))")] + ( + 6, + "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", + ), + ( + 7, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", + ), + ( + 8, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))", + ), + ( + 9, + "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", + ), + ( + 10, + "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))", + ), + (11, "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))"), + ] geometry_df = self.__wkt_pair_list_with_index_to_data_frame(geometries) - number_of_interior_rings = geometry_df.selectExpr("index", "ST_NumInteriorRings(geom) as num") - collected_interior_rings = [[*row] for row in number_of_interior_rings.filter("num is not null").collect()] - assert (collected_interior_rings == [[2, 0], [11, 1]]) + number_of_interior_rings = geometry_df.selectExpr( + "index", "ST_NumInteriorRings(geom) as num" + ) + collected_interior_rings = [ + [*row] + for row in number_of_interior_rings.filter("num is not null").collect() + ] + assert collected_interior_rings == [[2, 0], [11, 1]] def test_num_interior_ring(self): geometries = [ @@ -902,28 +1252,52 @@ def test_num_interior_ring(self): (3, "Linestring(0 0, 1 1, 1 0)"), (4, "Linestring(0 0, 1 1, 1 0, 0 0)"), (5, "MULTIPOINT ((10 40), (40 30), (20 20), (30 10))"), - (6, "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))"), - (7, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))"), - (8, "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))"), - (9, "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))"), - (10, - "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))"), - (11, "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))")] + ( + 6, + "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", + ), + ( + 7, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", + ), + ( + 8, + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))", + ), + ( + 9, + "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", + ), 
+ ( + 10, + "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))", + ), + (11, "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))"), + ] geometry_df = self.__wkt_pair_list_with_index_to_data_frame(geometries) - number_of_interior_rings = geometry_df.selectExpr("index", "ST_NumInteriorRing(geom) as num") - collected_interior_rings = [[*row] for row in number_of_interior_rings.filter("num is not null").collect()] - assert (collected_interior_rings == [[2, 0], [11, 1]]) + number_of_interior_rings = geometry_df.selectExpr( + "index", "ST_NumInteriorRing(geom) as num" + ) + collected_interior_rings = [ + [*row] + for row in number_of_interior_rings.filter("num is not null").collect() + ] + assert collected_interior_rings == [[2, 0], [11, 1]] def test_st_add_measure(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (1 1, 2 2, 2 2, 3 3)') as line, ST_GeomFromWKT('MULTILINESTRING M((1 0 4, 2 0 4, 4 0 4),(1 0 4, 2 0 4, 4 0 4))') as mline") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (1 1, 2 2, 2 2, 3 3)') as line, ST_GeomFromWKT('MULTILINESTRING M((1 0 4, 2 0 4, 4 0 4),(1 0 4, 2 0 4, 4 0 4))') as mline" + ) actual = baseDf.selectExpr("ST_AsText(ST_AddMeasure(line, 1, 70))").first()[0] expected = "LINESTRING M(1 1 1, 2 2 35.5, 2 2 35.5, 3 3 70)" assert expected == actual actual = baseDf.selectExpr("ST_AsText(ST_AddMeasure(mline, 10, 70))").first()[0] - expected = "MULTILINESTRING M((1 0 10, 2 0 20, 4 0 40), (1 0 40, 2 0 50, 4 0 70))" + expected = ( + "MULTILINESTRING M((1 0 10, 2 0 20, 4 0 40), (1 0 40, 2 0 50, 4 0 70))" + ) assert expected == actual def test_st_add_point(self): @@ -933,24 +1307,47 @@ def test_st_add_point(self): ("Linestring(0 0, 1 1, 1 0)", "Point(21 52)"), ("Linestring(0 0, 1 1, 1 0, 0 0)", "Linestring(0 0, 1 1, 1 0, 0 0)"), ("Point(21 52)", "MULTIPOINT ((10 40), (40 30), (20 20), (30 10))"), - ("MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", "Point(21 52)"), - ("MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", "Point(21 52)"), - ("MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))", "Point(21 52)"), - ("MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", "Point(21 52)"), + ( + "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", + "Point(21 52)", + ), + ( + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", + "Point(21 52)", + ), + ( + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10))", + "Point(21 52)", + ), + ( + "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))", + "Point(21 52)", + ), ( "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40), POLYGON ((40 40, 20 45, 45 30, 40 40)))", - "Point(21 52)"), - ("POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", "Point(21 52)") + "Point(21 52)", + ), + ( + "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", + "Point(21 52)", + ), ] geometry_df = self.__wkt_pairs_to_data_frame(geometry) - modified_geometries = geometry_df.selectExpr("ST_AddPoint(geomA, geomB) as geom") + modified_geometries = geometry_df.selectExpr( + "ST_AddPoint(geomA, geomB) as geom" + ) collected_geometries = [ - row[0] for row in modified_geometries.filter("geom is not null").selectExpr("ST_AsText(geom)").collect() + row[0] + for row in modified_geometries.filter("geom is not 
null") + .selectExpr("ST_AsText(geom)") + .collect() ] - assert (collected_geometries[0] == "LINESTRING (0 0, 1 1, 1 0, 21 52)") + assert collected_geometries[0] == "LINESTRING (0 0, 1 1, 1 0, 21 52)" def test_st_rotate_x(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') as geom1, ST_GeomFromWKT('LINESTRING(1 2 3, 1 1 1)') AS geom2") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') as geom1, ST_GeomFromWKT('LINESTRING(1 2 3, 1 1 1)') AS geom2" + ) actual = baseDf.selectExpr("ST_RotateX(geom1, PI())").first()[0].wkt expected = "LINESTRING (50 -160, 50 -50, 100 -50)" @@ -961,7 +1358,9 @@ def test_st_rotate_x(self): assert expected == actual def test_st_rotate_y(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') as geom1, ST_GeomFromWKT('LINESTRING(1 2 3, 1 1 1)') AS geom2") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') as geom1, ST_GeomFromWKT('LINESTRING(1 2 3, 1 1 1)') AS geom2" + ) actual = baseDf.selectExpr("ST_RotateY(geom1, PI())").first()[0].wkt expected = "LINESTRING (-50 160, -50 50, -100 50)" @@ -973,23 +1372,58 @@ def test_st_rotate_y(self): def test_st_remove_point(self): result_and_expected = [ - [self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 0), "LINESTRING (1 1, 1 0, 0 0)"], - [self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 1), "LINESTRING (0 0, 1 0, 0 0)"], - [self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 2), "LINESTRING (0 0, 1 1, 0 0)"], - [self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 3), "LINESTRING (0 0, 1 1, 1 0)"], + [ + self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 0), + "LINESTRING (1 1, 1 0, 0 0)", + ], + [ + self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 1), + "LINESTRING (0 0, 1 0, 0 0)", + ], + [ + self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 2), + "LINESTRING (0 0, 1 1, 0 0)", + ], + [ + self.calculate_st_remove("Linestring(0 0, 1 1, 1 0, 0 0)", 3), + "LINESTRING (0 0, 1 1, 1 0)", + ], [self.calculate_st_remove("POINT(0 1)", 3), None], - [self.calculate_st_remove("POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", 3), None], - [self.calculate_st_remove("GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40))", 0), None], - [self.calculate_st_remove( - "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", 3), None], - [self.calculate_st_remove( - "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", 3), None] + [ + self.calculate_st_remove( + "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", 3 + ), + None, + ], + [ + self.calculate_st_remove( + "GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40))", + 0, + ), + None, + ], + [ + self.calculate_st_remove( + "MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", + 3, + ), + None, + ], + [ + self.calculate_st_remove( + "MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", + 3, + ), + None, + ], ] for actual, expected in result_and_expected: - assert (actual == expected) + assert actual == expected def test_st_remove_repeated_points(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION (POINT (10 10),LINESTRING (20 20, 20 20, 30 30, 30 30),POLYGON ((40 40, 50 50, 50 50, 60 60, 60 60, 70 70, 70 70, 40 40)), MULTIPOINT ((80 80), (90 90), (90 90), (100 
100)))', 1000) AS geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION (POINT (10 10),LINESTRING (20 20, 20 20, 30 30, 30 30),POLYGON ((40 40, 50 50, 50 50, 60 60, 60 60, 70 70, 70 70, 40 40)), MULTIPOINT ((80 80), (90 90), (90 90), (100 100)))', 1000) AS geom" + ) actualDf = baseDf.selectExpr("ST_RemoveRepeatedPoints(geom, 1000) as geom") actual = actualDf.selectExpr("ST_AsText(geom)").first()[0] expected = "GEOMETRYCOLLECTION (POINT (10 10), LINESTRING (20 20, 30 30), POLYGON ((40 40, 70 70, 70 70, 40 40)), MULTIPOINT ((80 80)))" @@ -997,52 +1431,74 @@ def test_st_remove_repeated_points(self): actualSRID = actualDf.selectExpr("ST_SRID(geom)").first()[0] assert 1000 == actualSRID - def test_isPolygonCW(self): - actual = self.spark.sql("SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))'))" + ).take(1)[0][0] assert not actual - actual = self.spark.sql("SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))'))" + ).take(1)[0][0] assert actual def test_st_simplify_vw(self): - basedf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING(5 2, 3 8, 6 20, 7 25, 10 10)') as geom") + basedf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING(5 2, 3 8, 6 20, 7 25, 10 10)') as geom" + ) actual = basedf.selectExpr("ST_SimplifyVW(geom, 30)").take(1)[0][0].wkt expected = "LINESTRING (5 2, 7 25, 10 10)" assert expected == actual def test_st_simplify_polygon_hull(self): - basedf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((30 10, 40 40, 45 45, 20 40, 25 35, 10 20, 15 15, 30 10))') as geom") - actual = basedf.selectExpr("ST_SimplifyPolygonHull(geom, 0.3, false)").take(1)[0][0].wkt + basedf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((30 10, 40 40, 45 45, 20 40, 25 35, 10 20, 15 15, 30 10))') as geom" + ) + actual = ( + basedf.selectExpr("ST_SimplifyPolygonHull(geom, 0.3, false)") + .take(1)[0][0] + .wkt + ) expected = "POLYGON ((30 10, 40 40, 10 20, 30 10))" assert expected == actual - actual = basedf.selectExpr("ST_SimplifyPolygonHull(geom, 0.3)").take(1)[0][0].wkt + actual = ( + basedf.selectExpr("ST_SimplifyPolygonHull(geom, 0.3)").take(1)[0][0].wkt + ) expected = "POLYGON ((30 10, 15 15, 10 20, 20 40, 45 45, 30 10))" assert expected == actual - def test_st_is_ring(self): result_and_expected = [ [self.calculate_st_is_ring("LINESTRING(0 0, 0 1, 1 0, 1 1, 0 0)"), False], [self.calculate_st_is_ring("LINESTRING(2 0, 2 2, 3 3)"), False], [self.calculate_st_is_ring("LINESTRING(0 0, 0 1, 1 1, 1 0, 0 0)"), True], [self.calculate_st_is_ring("POINT (21 52)"), None], - [self.calculate_st_is_ring("POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))"), None], + [ + self.calculate_st_is_ring( + "POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))" + ), + None, + ], ] for actual, expected in result_and_expected: - assert (actual == expected) + assert actual == expected def test_isPolygonCCW(self): - actual = self.spark.sql("SELECT ST_IsPolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))'))").take(1)[0][0] + actual = self.spark.sql( + 
"SELECT ST_IsPolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))'))" + ).take(1)[0][0] assert actual - actual = self.spark.sql("SELECT ST_IsPolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_IsPolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))'))" + ).take(1)[0][0] assert not actual def test_forcePolygonCCW(self): actualDf = self.spark.sql( - "SELECT ST_ForcePolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))')) AS polyCW") + "SELECT ST_ForcePolygonCCW(ST_GeomFromWKT('POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))')) AS polyCW" + ) actual = actualDf.selectExpr("ST_AsText(polyCW)").take(1)[0][0] expected = "POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35), (30 20, 20 15, 20 25, 30 20))" assert expected == actual @@ -1053,7 +1509,7 @@ def test_st_subdivide(self): [ "POINT(21 52)", "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))", - "LINESTRING (0 0, 1 1, 2 2)" + "LINESTRING (0 0, 1 1, 2 2)", ] ) geometry_df.createOrReplaceTempView("geometry") @@ -1064,7 +1520,9 @@ def test_st_subdivide(self): # Then assert subdivided.count() == 3 - assert sum([geometries[0].__len__() for geometries in subdivided.collect()]) == 16 + assert ( + sum([geometries[0].__len__() for geometries in subdivided.collect()]) == 16 + ) def test_st_subdivide_explode(self): # Given @@ -1072,7 +1530,7 @@ def test_st_subdivide_explode(self): [ "POINT(21 52)", "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))", - "LINESTRING (0 0, 1 1, 2 2)" + "LINESTRING (0 0, 1 1, 2 2)", ] ) geometry_df.createOrReplaceTempView("geometry") @@ -1084,12 +1542,16 @@ def test_st_subdivide_explode(self): assert subdivided.count() == 16 def test_st_has_z(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON Z ((30 10 5, 40 40 10, 20 40 15, 10 20 20, 30 10 5))') as poly") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON Z ((30 10 5, 40 40 10, 20 40 15, 10 20 20, 30 10 5))') as poly" + ) actual = baseDf.selectExpr("ST_HasZ(poly)") assert actual def test_st_has_m(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ZM ((30 10 5 1, 40 40 10 2, 20 40 15 3, 10 20 20 4, 30 10 5 1))') as poly") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ZM ((30 10 5 1, 40 40 10 2, 20 40 15 3, 10 20 20 4, 30 10 5 1))') as poly" + ) actual = baseDf.selectExpr("ST_HasM(poly)") assert actual @@ -1099,20 +1561,28 @@ def test_st_m(self): assert actual == 4.0 def test_st_m_min(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line" + ) actual = baseDf.selectExpr("ST_MMin(line)").take(1)[0][0] assert actual == -1.0 - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING(1 1, 2 2, 3 3, -1 -1)') AS line") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING(1 1, 2 2, 3 3, -1 -1)') AS line" + ) actual = baseDf.selectExpr("ST_MMin(line)").take(1)[0][0] assert actual is None def test_st_m_max(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line") + baseDf = self.spark.sql( + "SELECT 
+        baseDf = self.spark.sql(
+            "SELECT ST_GeomFromWKT('LINESTRING ZM(1 1 1 1, 2 2 2 2, 3 3 3 3, -1 -1 -1 -1)') AS line"
+        )
         actual = baseDf.selectExpr("ST_MMax(line)").take(1)[0][0]
         assert actual == 3.0
 
-        baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING(1 1, 2 2, 3 3, -1 -1)') AS line")
+        baseDf = self.spark.sql(
+            "SELECT ST_GeomFromWKT('LINESTRING(1 1, 2 2, 3 3, -1 -1)') AS line"
+        )
         actual = baseDf.selectExpr("ST_MMax(line)").take(1)[0][0]
         assert actual is None
@@ -1122,15 +1592,16 @@ def test_st_subdivide_explode_lateral(self):
             [
                 "POINT(21 52)",
                 "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))",
-                "LINESTRING (0 0, 1 1, 2 2)"
+                "LINESTRING (0 0, 1 1, 2 2)",
            ]
         )
 
         geometry_df.selectExpr("geom as geometry").createOrReplaceTempView("geometries")
 
         # When
-        lateral_view_result = self.spark. \
-            sql("""select geom from geometries LATERAL VIEW ST_SubdivideExplode(geometry, 5) AS geom""")
+        lateral_view_result = self.spark.sql(
+            """select geom from geometries LATERAL VIEW ST_SubdivideExplode(geometry, 5) AS geom"""
+        )
 
         # Then
         assert lateral_view_result.count() == 16
@@ -1139,18 +1610,31 @@ def test_st_make_line(self):
         # Given
         geometry_df = self.spark.createDataFrame(
             [
-                ["POINT(0 0)", "POINT(1 1)" , "LINESTRING (0 0, 1 1)"],
-                ["MULTIPOINT ((0 0), (1 1))", "MULTIPOINT ((2 2), (2 3))", "LINESTRING (0 0, 1 1, 2 2, 2 3)"],
-                ["LINESTRING (0 0, 1 1)", "LINESTRING(2 2, 3 3)", "LINESTRING (0 0, 1 1, 2 2, 3 3)"]
+                ["POINT(0 0)", "POINT(1 1)", "LINESTRING (0 0, 1 1)"],
+                [
+                    "MULTIPOINT ((0 0), (1 1))",
+                    "MULTIPOINT ((2 2), (2 3))",
+                    "LINESTRING (0 0, 1 1, 2 2, 2 3)",
+                ],
+                [
+                    "LINESTRING (0 0, 1 1)",
+                    "LINESTRING(2 2, 3 3)",
+                    "LINESTRING (0 0, 1 1, 2 2, 3 3)",
+                ],
            ]
-        ).selectExpr("ST_GeomFromText(_1) AS geom1", "ST_GeomFromText(_2) AS geom2", "_3 AS expected")
+        ).selectExpr(
+            "ST_GeomFromText(_1) AS geom1",
+            "ST_GeomFromText(_2) AS geom2",
+            "_3 AS expected",
+        )
 
         # When calling st_MakeLine
-        geom_lines = geometry_df.withColumn("linestring", expr("ST_MakeLine(geom1, geom2)"))
+        geom_lines = geometry_df.withColumn(
+            "linestring", expr("ST_MakeLine(geom1, geom2)")
+        )
 
         # Then
-        result = geom_lines.selectExpr("ST_AsText(linestring)", "expected"). \
-            collect()
+        result = geom_lines.selectExpr("ST_AsText(linestring)", "expected").collect()
 
         for actual, expected in result:
             assert actual == expected
@@ -1160,15 +1644,23 @@ def test_st_points(self):
         geometry_df = self.spark.createDataFrame(
             [
                 # Adding only the input that will result in a non-null polygon
-                ["MULTILINESTRING ((0 0, 1 1), (2 2, 3 3))", "MULTIPOINT ((0 0), (1 1), (2 2), (3 3))"]
+                [
+                    "MULTILINESTRING ((0 0, 1 1), (2 2, 3 3))",
+                    "MULTIPOINT ((0 0), (1 1), (2 2), (3 3))",
+                ]
            ]
         ).selectExpr("ST_GeomFromText(_1) AS geom", "_2 AS expected")
 
         # When calling st_points
-        geom_poly = geometry_df.withColumn("actual", expr("st_normalize(st_points(geom))"))
+        geom_poly = geometry_df.withColumn(
+            "actual", expr("st_normalize(st_points(geom))")
+        )
 
-        result = geom_poly.filter("actual IS NOT NULL").selectExpr("ST_AsText(actual)", "expected"). \
-            collect()
+        result = (
+            geom_poly.filter("actual IS NOT NULL")
+            .selectExpr("ST_AsText(actual)", "expected")
+            .collect()
+        )
 
         assert result.__len__() == 1
@@ -1180,8 +1672,16 @@
         geometry_df = self.spark.createDataFrame(
             [
                 ["POINT(21 52)", 4238, None],
-                ["POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))", 4237, None],
-                ["LINESTRING (0 0, 0 1, 1 0, 0 0)", 4236, "POLYGON ((0 0, 0 1, 1 0, 0 0))"]
+                [
+                    "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))",
+                    4237,
+                    None,
+                ],
+                [
+                    "LINESTRING (0 0, 0 1, 1 0, 0 0)",
+                    4236,
+                    "POLYGON ((0 0, 0 1, 1 0, 0 0))",
+                ],
            ]
         ).selectExpr("ST_GeomFromText(_1) AS geom", "_2 AS srid", "_3 AS expected")
@@ -1189,10 +1689,14 @@
         geom_poly = geometry_df.withColumn("polygon", expr("ST_Polygon(geom, srid)"))
 
         # Then only based on closed linestring geom is created
-        geom_poly.filter("polygon IS NOT NULL").selectExpr("ST_AsText(polygon)", "expected"). \
-            show()
-        result = geom_poly.filter("polygon IS NOT NULL").selectExpr("ST_AsText(polygon)", "expected"). \
-            collect()
+        geom_poly.filter("polygon IS NOT NULL").selectExpr(
+            "ST_AsText(polygon)", "expected"
+        ).show()
+        result = (
+            geom_poly.filter("polygon IS NOT NULL")
+            .selectExpr("ST_AsText(polygon)", "expected")
+            .collect()
+        )
 
         assert result.__len__() == 1
@@ -1204,32 +1708,43 @@ def test_st_polygonize(self):
         geometry_df = self.spark.createDataFrame(
             [
                 # Adding only the input that will result in a non-null polygon
-                ["GEOMETRYCOLLECTION (LINESTRING (2 0, 2 1, 2 2), LINESTRING (2 2, 2 3, 2 4), LINESTRING (0 2, 1 2, 2 2), LINESTRING (2 2, 3 2, 4 2), LINESTRING (0 2, 1 3, 2 4), LINESTRING (2 4, 3 3, 4 2))", "GEOMETRYCOLLECTION (POLYGON ((0 2, 1 3, 2 4, 2 3, 2 2, 1 2, 0 2)), POLYGON ((2 2, 2 3, 2 4, 3 3, 4 2, 3 2, 2 2)))"]
+                [
+                    "GEOMETRYCOLLECTION (LINESTRING (2 0, 2 1, 2 2), LINESTRING (2 2, 2 3, 2 4), LINESTRING (0 2, 1 2, 2 2), LINESTRING (2 2, 3 2, 4 2), LINESTRING (0 2, 1 3, 2 4), LINESTRING (2 4, 3 3, 4 2))",
+                    "GEOMETRYCOLLECTION (POLYGON ((0 2, 1 3, 2 4, 2 3, 2 2, 1 2, 0 2)), POLYGON ((2 2, 2 3, 2 4, 3 3, 4 2, 3 2, 2 2)))",
+                ]
            ]
         ).selectExpr("ST_GeomFromText(_1) AS geom", "_2 AS expected")
 
         # When calling st_polygonize
-        geom_poly = geometry_df.withColumn("actual", expr("st_normalize(st_polygonize(geom))"))
+        geom_poly = geometry_df.withColumn(
+            "actual", expr("st_normalize(st_polygonize(geom))")
+        )
 
         # Then only based on closed linestring geom is created
-        geom_poly.filter("actual IS NOT NULL").selectExpr("ST_AsText(actual)", "expected"). \
-            show()
-        result = geom_poly.filter("actual IS NOT NULL").selectExpr("ST_AsText(actual)", "expected"). \
-            collect()
+        geom_poly.filter("actual IS NOT NULL").selectExpr(
+            "ST_AsText(actual)", "expected"
+        ).show()
+        result = (
+            geom_poly.filter("actual IS NOT NULL")
+            .selectExpr("ST_AsText(actual)", "expected")
+            .collect()
+        )
 
         assert result.__len__() == 1
 
         for actual, expected in result:
             assert actual == expected
 
-
     def test_st_make_polygon(self):
         # Given
         geometry_df = self.spark.createDataFrame(
             [
                 ["POINT(21 52)", None],
-                ["POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))", None],
-                ["LINESTRING (0 0, 0 1, 1 0, 0 0)", "POLYGON ((0 0, 0 1, 1 0, 0 0))"]
+                [
+                    "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))",
+                    None,
+                ],
+                ["LINESTRING (0 0, 0 1, 1 0, 0 0)", "POLYGON ((0 0, 0 1, 1 0, 0 0))"],
            ]
         ).selectExpr("ST_GeomFromText(_1) AS geom", "_2 AS expected")
@@ -1237,10 +1752,14 @@
         geom_poly = geometry_df.withColumn("polygon", expr("ST_MakePolygon(geom)"))
 
         # Then only based on closed linestring geom is created
-        geom_poly.filter("polygon IS NOT NULL").selectExpr("ST_AsText(polygon)", "expected"). \
-            show()
-        result = geom_poly.filter("polygon IS NOT NULL").selectExpr("ST_AsText(polygon)", "expected"). \
-            collect()
+        geom_poly.filter("polygon IS NOT NULL").selectExpr(
+            "ST_AsText(polygon)", "expected"
+        ).show()
+        result = (
+            geom_poly.filter("polygon IS NOT NULL")
+            .selectExpr("ST_AsText(polygon)", "expected")
+            .collect()
+        )
 
         assert result.__len__() == 1
@@ -1252,14 +1771,20 @@ def test_st_geohash(self):
         geometry_df = self.spark.createDataFrame(
             [
                 ["POINT(21 52)", "u3nzvf79zq"],
-                ["POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))", "ssgs3y0zh7"],
-                ["LINESTRING (0 0, 1 1, 2 2)", "s00twy01mt"]
+                [
+                    "POLYGON ((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))",
+                    "ssgs3y0zh7",
+                ],
+                ["LINESTRING (0 0, 1 1, 2 2)", "s00twy01mt"],
            ]
-        ).select(expr("St_GeomFromText(_1)").alias("geom"), col("_2").alias("expected_hash"))
+        ).select(
+            expr("St_GeomFromText(_1)").alias("geom"), col("_2").alias("expected_hash")
+        )
 
         # When
-        geohash_df = geometry_df.withColumn("geohash", expr("ST_GeoHash(geom, 10)")). \
-            select("geohash", "expected_hash")
+        geohash_df = geometry_df.withColumn(
+            "geohash", expr("ST_GeoHash(geom, 10)")
+        ).select("geohash", "expected_hash")
 
         # Then
         geohash = geohash_df.collect()
@@ -1273,22 +1798,28 @@ def test_geom_from_geohash(self):
         geometry_df = self.spark.createDataFrame(
             [
                 [
                     "POLYGON ((20.999990701675415 51.99999690055847, 20.999990701675415 52.0000022649765, 21.000001430511475 52.0000022649765, 21.000001430511475 51.99999690055847, 20.999990701675415 51.99999690055847))",
-                    "u3nzvf79zq"
+                    "u3nzvf79zq",
                ],
                 [
                     "POLYGON ((26.71875 26.71875, 26.71875 28.125, 28.125 28.125, 28.125 26.71875, 26.71875 26.71875))",
-                    "ssg"
+                    "ssg",
                ],
                 [
                     "POLYGON ((0.9999918937683105 0.9999972581863403, 0.9999918937683105 1.0000026226043701, 1.0000026226043701 1.0000026226043701, 1.0000026226043701 0.9999972581863403, 0.9999918937683105 0.9999972581863403))",
-                    "s00twy01mt"
-                ]
+                    "s00twy01mt",
+                ],
            ]
-        ).select(expr("ST_GeomFromGeoHash(_2)").alias("geom"), col("_1").alias("expected_polygon"))
+        ).select(
+            expr("ST_GeomFromGeoHash(_2)").alias("geom"),
+            col("_1").alias("expected_polygon"),
+        )
 
         # When
-        wkt_df = geometry_df.withColumn("wkt", expr("ST_AsText(geom)")). \
\ - select("wkt", "expected_polygon").collect() + wkt_df = ( + geometry_df.withColumn("wkt", expr("ST_AsText(geom)")) + .select("wkt", "expected_polygon") + .collect() + ) for wkt, expected_polygon in wkt_df: assert wkt == expected_polygon @@ -1297,15 +1828,25 @@ def test_geom_from_geohash_precision(self): # Given geometry_df = self.spark.createDataFrame( [ - ["POLYGON ((11.25 50.625, 11.25 56.25, 22.5 56.25, 22.5 50.625, 11.25 50.625))", "u3nzvf79zq"], - ["POLYGON ((22.5 22.5, 22.5 28.125, 33.75 28.125, 33.75 22.5, 22.5 22.5))", "ssgs3y0zh7"], - ["POLYGON ((0 0, 0 5.625, 11.25 5.625, 11.25 0, 0 0))", "s00twy01mt"] + [ + "POLYGON ((11.25 50.625, 11.25 56.25, 22.5 56.25, 22.5 50.625, 11.25 50.625))", + "u3nzvf79zq", + ], + [ + "POLYGON ((22.5 22.5, 22.5 28.125, 33.75 28.125, 33.75 22.5, 22.5 22.5))", + "ssgs3y0zh7", + ], + ["POLYGON ((0 0, 0 5.625, 11.25 5.625, 11.25 0, 0 0))", "s00twy01mt"], ] - ).select(expr("ST_GeomFromGeoHash(_2, 2)").alias("geom"), col("_1").alias("expected_polygon")) + ).select( + expr("ST_GeomFromGeoHash(_2, 2)").alias("geom"), + col("_1").alias("expected_polygon"), + ) # When - wkt_df = geometry_df.withColumn("wkt", expr("ST_ASText(geom)")). \ - select("wkt", "expected_polygon") + wkt_df = geometry_df.withColumn("wkt", expr("ST_ASText(geom)")).select( + "wkt", "expected_polygon" + ) # Then geohash = wkt_df.collect() @@ -1315,124 +1856,170 @@ def test_geom_from_geohash_precision(self): def test_st_closest_point(self): expected = "POINT (0 1)" - actual_df = self.spark.sql("select ST_AsText(ST_ClosestPoint(ST_GeomFromText('POINT (0 1)'), " - "ST_GeomFromText('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)')))") + actual_df = self.spark.sql( + "select ST_AsText(ST_ClosestPoint(ST_GeomFromText('POINT (0 1)'), " + "ST_GeomFromText('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)')))" + ) actual = actual_df.take(1)[0][0] assert expected == actual def test_st_collect_on_array_type(self): # given - geometry_df = self.spark.createDataFrame([ - [1, [loads("POLYGON((1 2,1 4,3 4,3 2,1 2))"), loads("POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5))")]], - [2, [loads("LINESTRING(1 2, 3 4)"), loads("LINESTRING(3 4, 4 5)")]], - [3, [loads("POINT(1 2)"), loads("POINT(-2 3)")]] - ]).toDF("id", "geom") + geometry_df = self.spark.createDataFrame( + [ + [ + 1, + [ + loads("POLYGON((1 2,1 4,3 4,3 2,1 2))"), + loads("POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5))"), + ], + ], + [2, [loads("LINESTRING(1 2, 3 4)"), loads("LINESTRING(3 4, 4 5)")]], + [3, [loads("POINT(1 2)"), loads("POINT(-2 3)")]], + ] + ).toDF("id", "geom") # when calculating st collect - geometry_df_collected = geometry_df.withColumn("collected", expr("ST_Collect(geom)")) + geometry_df_collected = geometry_df.withColumn( + "collected", expr("ST_Collect(geom)") + ) # then result should be as expected - assert (set([el[0] for el in geometry_df_collected.selectExpr("ST_AsText(collected)").collect()]) == { + assert set( + [ + el[0] + for el in geometry_df_collected.selectExpr( + "ST_AsText(collected)" + ).collect() + ] + ) == { "MULTILINESTRING ((1 2, 3 4), (3 4, 4 5))", "MULTIPOINT ((1 2), (-2 3))", - "MULTIPOLYGON (((1 2, 1 4, 3 4, 3 2, 1 2)), ((0.5 0.5, 5 0, 5 5, 0 5, 0.5 0.5)))" - }) + "MULTIPOLYGON (((1 2, 1 4, 3 4, 3 2, 1 2)), ((0.5 0.5, 5 0, 5 5, 0 5, 0.5 0.5)))", + } def test_st_collect_on_multiple_columns(self): # given geometry df with multiple geometry columns - geometry_df = self.spark.createDataFrame([ - [1, loads("POLYGON((1 2,1 4,3 4,3 2,1 2))"), loads("POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5))")], - [2, loads("LINESTRING(1 2, 3 4)"), 
loads("LINESTRING(3 4, 4 5)")], - [3, loads("POINT(1 2)"), loads("POINT(-2 3)")] - ]).toDF("id", "geom_left", "geom_right") + geometry_df = self.spark.createDataFrame( + [ + [ + 1, + loads("POLYGON((1 2,1 4,3 4,3 2,1 2))"), + loads("POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5))"), + ], + [2, loads("LINESTRING(1 2, 3 4)"), loads("LINESTRING(3 4, 4 5)")], + [3, loads("POINT(1 2)"), loads("POINT(-2 3)")], + ] + ).toDF("id", "geom_left", "geom_right") # when calculating st collect on multiple columns - geometry_df_collected = geometry_df.withColumn("collected", expr("ST_Collect(geom_left, geom_right)")) + geometry_df_collected = geometry_df.withColumn( + "collected", expr("ST_Collect(geom_left, geom_right)") + ) # then result should be calculated - assert (set([el[0] for el in geometry_df_collected.selectExpr("ST_AsText(collected)").collect()]) == { + assert set( + [ + el[0] + for el in geometry_df_collected.selectExpr( + "ST_AsText(collected)" + ).collect() + ] + ) == { "MULTILINESTRING ((1 2, 3 4), (3 4, 4 5))", "MULTIPOINT ((1 2), (-2 3))", - "MULTIPOLYGON (((1 2, 1 4, 3 4, 3 2, 1 2)), ((0.5 0.5, 5 0, 5 5, 0 5, 0.5 0.5)))" - }) + "MULTIPOLYGON (((1 2, 1 4, 3 4, 3 2, 1 2)), ((0.5 0.5, 5 0, 5 5, 0 5, 0.5 0.5)))", + } def test_st_reverse(self): test_cases = { - "'POLYGON((-1 0 0, 1 0 0, 0 0 1, 0 1 0, -1 0 0))'": - "POLYGON Z((-1 0 0, 0 1 0, 0 0 1, 1 0 0, -1 0 0))", - "'LINESTRING(0 0, 1 2, 2 4, 3 6)'": - "LINESTRING (3 6, 2 4, 1 2, 0 0)", - "'POINT(1 2)'": - "POINT (1 2)", - "'MULTIPOINT((10 40 66), (40 30 77), (20 20 88), (30 10 99))'": - "MULTIPOINT Z((10 40 66), (40 30 77), (20 20 88), (30 10 99))", - "'MULTIPOLYGON(((30 20 11, 45 40 11, 10 40 11, 30 20 11)), " \ - "((15 5 11, 40 10 11, 10 20 11, 5 10 11, 15 5 11)))'": - "MULTIPOLYGON Z(((30 20 11, 10 40 11, 45 40 11, 30 20 11)), " \ - "((15 5 11, 5 10 11, 10 20 11, 40 10 11, 15 5 11)))", - "'MULTILINESTRING((10 10 11, 20 20 11, 10 40 11), " \ - "(40 40 11, 30 30 11, 40 20 11, 30 10 11))'": - "MULTILINESTRING Z((10 40 11, 20 20 11, 10 10 11), " \ - "(30 10 11, 40 20 11, 30 30 11, 40 40 11))", - "'MULTIPOLYGON(((40 40 11, 20 45 11, 45 30 11, 40 40 11)), " \ - "((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 11)," \ - "(30 20 11, 20 15 11, 20 25 11, 30 20 11)))'": - "MULTIPOLYGON Z(((40 40 11, 45 30 11, 20 45 11, 40 40 11)), " \ - "((20 35 11, 45 20 11, 30 5 11, 10 10 11, 10 30 11, 20 35 11), " \ - "(30 20 11, 20 25 11, 20 15 11, 30 20 11)))", - "'POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), " \ - "(1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))'": - "POLYGON Z((0 0 11, 5 0 11, 5 5 11, 0 5 11, 0 0 11), " \ - "(1 1 11, 1 2 11, 2 2 11, 2 1 11, 1 1 11))" + "'POLYGON((-1 0 0, 1 0 0, 0 0 1, 0 1 0, -1 0 0))'": "POLYGON Z((-1 0 0, 0 1 0, 0 0 1, 1 0 0, -1 0 0))", + "'LINESTRING(0 0, 1 2, 2 4, 3 6)'": "LINESTRING (3 6, 2 4, 1 2, 0 0)", + "'POINT(1 2)'": "POINT (1 2)", + "'MULTIPOINT((10 40 66), (40 30 77), (20 20 88), (30 10 99))'": "MULTIPOINT Z((10 40 66), (40 30 77), (20 20 88), (30 10 99))", + "'MULTIPOLYGON(((30 20 11, 45 40 11, 10 40 11, 30 20 11)), " + "((15 5 11, 40 10 11, 10 20 11, 5 10 11, 15 5 11)))'": "MULTIPOLYGON Z(((30 20 11, 10 40 11, 45 40 11, 30 20 11)), " + "((15 5 11, 5 10 11, 10 20 11, 40 10 11, 15 5 11)))", + "'MULTILINESTRING((10 10 11, 20 20 11, 10 40 11), " + "(40 40 11, 30 30 11, 40 20 11, 30 10 11))'": "MULTILINESTRING Z((10 40 11, 20 20 11, 10 10 11), " + "(30 10 11, 40 20 11, 30 30 11, 40 40 11))", + "'MULTIPOLYGON(((40 40 11, 20 45 11, 45 30 11, 40 40 11)), " + "((20 35 11, 10 30 11, 10 10 11, 30 5 11, 45 20 11, 20 35 
11)," + "(30 20 11, 20 15 11, 20 25 11, 30 20 11)))'": "MULTIPOLYGON Z(((40 40 11, 45 30 11, 20 45 11, 40 40 11)), " + "((20 35 11, 45 20 11, 30 5 11, 10 10 11, 10 30 11, 20 35 11), " + "(30 20 11, 20 25 11, 20 15 11, 30 20 11)))", + "'POLYGON((0 0 11, 0 5 11, 5 5 11, 5 0 11, 0 0 11), " + "(1 1 11, 2 1 11, 2 2 11, 1 2 11, 1 1 11))'": "POLYGON Z((0 0 11, 5 0 11, 5 5 11, 0 5 11, 0 0 11), " + "(1 1 11, 1 2 11, 2 2 11, 2 1 11, 1 1 11))", } for input_geom, expected_geom in test_cases.items(): - reversed_geometry = self.spark.sql("select ST_AsText(ST_Reverse(ST_GeomFromText({})))".format(input_geom)) + reversed_geometry = self.spark.sql( + "select ST_AsText(ST_Reverse(ST_GeomFromText({})))".format(input_geom) + ) assert reversed_geometry.take(1)[0][0] == expected_geom def calculate_st_is_ring(self, wkt): - geometry_collected = self.__wkt_list_to_data_frame([wkt]). \ - selectExpr("ST_IsRing(geom) as is_ring") \ - .filter("is_ring is not null").collect() + geometry_collected = ( + self.__wkt_list_to_data_frame([wkt]) + .selectExpr("ST_IsRing(geom) as is_ring") + .filter("is_ring is not null") + .collect() + ) return geometry_collected[0][0] if geometry_collected.__len__() != 0 else None def calculate_st_remove(self, wkt, index): - geometry_collected = self.__wkt_list_to_data_frame([wkt]). \ - selectExpr(f"ST_RemovePoint(geom, {index}) as geom"). \ - filter("geom is not null"). \ - selectExpr("ST_AsText(geom)").collect() + geometry_collected = ( + self.__wkt_list_to_data_frame([wkt]) + .selectExpr(f"ST_RemovePoint(geom, {index}) as geom") + .filter("geom is not null") + .selectExpr("ST_AsText(geom)") + .collect() + ) return geometry_collected[0][0] if geometry_collected.__len__() != 0 else None def __wkt_pairs_to_data_frame(self, wkt_list: List) -> DataFrame: - return self.spark.createDataFrame([[wkt.loads(wkt_a), wkt.loads(wkt_b)] for wkt_a, wkt_b in wkt_list], - self.geo_pair_schema) + return self.spark.createDataFrame( + [[wkt.loads(wkt_a), wkt.loads(wkt_b)] for wkt_a, wkt_b in wkt_list], + self.geo_pair_schema, + ) def __wkt_list_to_data_frame(self, wkt_list: List) -> DataFrame: - return self.spark.createDataFrame([[wkt.loads(given_wkt)] for given_wkt in wkt_list], self.geo_schema) + return self.spark.createDataFrame( + [[wkt.loads(given_wkt)] for given_wkt in wkt_list], self.geo_schema + ) def __wkt_pair_list_with_index_to_data_frame(self, wkt_list: List) -> DataFrame: - return self.spark.createDataFrame([[index, wkt.loads(given_wkt)] for index, given_wkt in wkt_list], - self.geo_schema_with_index) + return self.spark.createDataFrame( + [[index, wkt.loads(given_wkt)] for index, given_wkt in wkt_list], + self.geo_schema_with_index, + ) def test_st_pointonsurface(self): tests1 = { "'POINT(0 5)'": "POINT (0 5)", "'LINESTRING(0 5, 0 10)'": "POINT (0 5)", "'POLYGON((0 0, 0 5, 5 5, 5 0, 0 0))'": "POINT (2.5 2.5)", - "'LINESTRING(0 5 1, 0 0 1, 0 10 2)'": "POINT Z(0 0 1)" + "'LINESTRING(0 5 1, 0 0 1, 0 10 2)'": "POINT Z(0 0 1)", } for input_geom, expected_geom in tests1.items(): pointOnSurface = self.spark.sql( - "select ST_AsText(ST_PointOnSurface(ST_GeomFromText({})))".format(input_geom)) + "select ST_AsText(ST_PointOnSurface(ST_GeomFromText({})))".format( + input_geom + ) + ) assert pointOnSurface.take(1)[0][0] == expected_geom tests2 = {"'LINESTRING(0 5 1, 0 0 1, 0 10 2)'": "POINT Z(0 0 1)"} for input_geom, expected_geom in tests2.items(): pointOnSurface = self.spark.sql( - "select ST_AsEWKT(ST_PointOnSurface(ST_GeomFromWKT({})))".format(input_geom)) + "select 
ST_AsEWKT(ST_PointOnSurface(ST_GeomFromWKT({})))".format( + input_geom + ) + ) assert pointOnSurface.take(1)[0][0] == expected_geom def test_st_pointn(self): @@ -1447,54 +2034,50 @@ def test_st_pointn(self): [linestring, 5, None], [linestring, -5, None], ["'POLYGON((1 1, 3 1, 3 3, 1 3, 1 1))'", 2, None], - ["'POINT(1 2)'", 1, None] + ["'POINT(1 2)'", 1, None], ] for test in tests: - point = self.spark.sql(f"select ST_AsText(ST_PointN(ST_GeomFromText({test[0]}), {test[1]}))") + point = self.spark.sql( + f"select ST_AsText(ST_PointN(ST_GeomFromText({test[0]}), {test[1]}))" + ) assert point.take(1)[0][0] == test[2] def test_st_force2d(self): tests1 = { "'POINT(0 5)'": "POINT (0 5)", - "'POLYGON((0 0 2, 0 5 2, 5 0 2, 0 0 2), (1 1 2, 3 1 2, 1 3 2, 1 1 2))'": - "POLYGON ((0 0, 0 5, 5 0, 0 0), (1 1, 3 1, 1 3, 1 1))", - "'LINESTRING(0 5 1, 0 0 1, 0 10 2)'": "LINESTRING (0 5, 0 0, 0 10)" + "'POLYGON((0 0 2, 0 5 2, 5 0 2, 0 0 2), (1 1 2, 3 1 2, 1 3 2, 1 1 2))'": "POLYGON ((0 0, 0 5, 5 0, 0 0), (1 1, 3 1, 1 3, 1 1))", + "'LINESTRING(0 5 1, 0 0 1, 0 10 2)'": "LINESTRING (0 5, 0 0, 0 10)", } for input_geom, expected_geom in tests1.items(): geom_2d = self.spark.sql( - "select ST_AsText(ST_Force_2D(ST_GeomFromText({})))".format(input_geom)) + "select ST_AsText(ST_Force_2D(ST_GeomFromText({})))".format(input_geom) + ) assert geom_2d.take(1)[0][0] == expected_geom def test_st_buildarea(self): tests = { - "'MULTILINESTRING((0 0, 10 0, 10 10, 0 10, 0 0),(10 10, 20 10, 20 20, 10 20, 10 10))'": - "MULTIPOLYGON (((0 0, 0 10, 10 10, 10 0, 0 0)), ((10 10, 10 20, 20 20, 20 10, 10 10)))", - "'MULTILINESTRING((0 0, 10 0, 10 10, 0 10, 0 0),(10 10, 20 10, 20 0, 10 0, 10 10))'": - "POLYGON ((0 0, 0 10, 10 10, 20 10, 20 0, 10 0, 0 0))", - "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2))'": - "POLYGON ((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2))", - "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2), (8 8, 8 12, 12 12, 12 8, 8 8))'": - "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), ((8 8, 8 12, 12 12, 12 8, 8 8)))", - "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2), " \ - "(8 8, 8 9, 8 10, 8 11, 8 12, 9 12, 10 12, 11 12, 12 12, 12 11, 12 10, 12 9, 12 8, 11 8, 10 8, 9 8, 8 8))'": - "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), " \ - "((8 8, 8 9, 8 10, 8 11, 8 12, 9 12, 10 12, 11 12, 12 12, 12 11, 12 10, 12 9, 12 8, 11 8, 10 8, 9 8, 8 8)))", - "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2),(8 8, 8 12, 12 12, 12 8, 8 8),(10 8, 10 12))'": - "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), ((8 8, 8 12, 12 12, 12 8, 8 8)))", - "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2),(10 2, 10 18))'": - "POLYGON ((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2))", - "'MULTILINESTRING( (0 0, 70 0, 70 70, 0 70, 0 0), (10 10, 10 60, 40 60, 40 10, 10 10), " \ - "(20 20, 20 30, 30 30, 30 20, 20 20), (20 30, 30 30, 30 50, 20 50, 20 30), (50 20, 60 20, 60 40, 50 40, 50 20), " \ - "(50 40, 60 40, 60 60, 50 60, 50 40), (80 0, 110 0, 110 70, 80 70, 80 0), (90 60, 100 60, 100 50, 90 50, 90 60))'": - "MULTIPOLYGON (((0 0, 0 70, 70 70, 70 0, 0 0), (10 10, 40 10, 40 60, 10 60, 10 10), (50 20, 60 20, 60 40, 60 60, 50 60, 50 40, 50 20)), " \ - "((20 20, 20 30, 20 50, 30 50, 30 30, 30 20, 20 20)), " \ - "((80 0, 80 70, 110 70, 110 0, 80 0), (90 50, 100 50, 100 60, 90 60, 90 50)))" + 
"'MULTILINESTRING((0 0, 10 0, 10 10, 0 10, 0 0),(10 10, 20 10, 20 20, 10 20, 10 10))'": "MULTIPOLYGON (((0 0, 0 10, 10 10, 10 0, 0 0)), ((10 10, 10 20, 20 20, 20 10, 10 10)))", + "'MULTILINESTRING((0 0, 10 0, 10 10, 0 10, 0 0),(10 10, 20 10, 20 0, 10 0, 10 10))'": "POLYGON ((0 0, 0 10, 10 10, 20 10, 20 0, 10 0, 0 0))", + "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2))'": "POLYGON ((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2))", + "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2), (8 8, 8 12, 12 12, 12 8, 8 8))'": "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), ((8 8, 8 12, 12 12, 12 8, 8 8)))", + "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2), " + "(8 8, 8 9, 8 10, 8 11, 8 12, 9 12, 10 12, 11 12, 12 12, 12 11, 12 10, 12 9, 12 8, 11 8, 10 8, 9 8, 8 8))'": "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), " + "((8 8, 8 9, 8 10, 8 11, 8 12, 9 12, 10 12, 11 12, 12 12, 12 11, 12 10, 12 9, 12 8, 11 8, 10 8, 9 8, 8 8)))", + "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2),(8 8, 8 12, 12 12, 12 8, 8 8),(10 8, 10 12))'": "MULTIPOLYGON (((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2)), ((8 8, 8 12, 12 12, 12 8, 8 8)))", + "'MULTILINESTRING((0 0, 20 0, 20 20, 0 20, 0 0),(2 2, 18 2, 18 18, 2 18, 2 2),(10 2, 10 18))'": "POLYGON ((0 0, 0 20, 20 20, 20 0, 0 0), (2 2, 18 2, 18 18, 2 18, 2 2))", + "'MULTILINESTRING( (0 0, 70 0, 70 70, 0 70, 0 0), (10 10, 10 60, 40 60, 40 10, 10 10), " + "(20 20, 20 30, 30 30, 30 20, 20 20), (20 30, 30 30, 30 50, 20 50, 20 30), (50 20, 60 20, 60 40, 50 40, 50 20), " + "(50 40, 60 40, 60 60, 50 60, 50 40), (80 0, 110 0, 110 70, 80 70, 80 0), (90 60, 100 60, 100 50, 90 50, 90 60))'": "MULTIPOLYGON (((0 0, 0 70, 70 70, 70 0, 0 0), (10 10, 40 10, 40 60, 10 60, 10 10), (50 20, 60 20, 60 40, 60 60, 50 60, 50 40, 50 20)), " + "((20 20, 20 30, 20 50, 30 50, 30 30, 30 20, 20 20)), " + "((80 0, 80 70, 110 70, 110 0, 80 0), (90 50, 100 50, 100 60, 90 60, 90 50)))", } for input_geom, expected_geom in tests.items(): - areal_geom = self.spark.sql("select ST_AsText(ST_BuildArea(ST_GeomFromText({})))".format(input_geom)) + areal_geom = self.spark.sql( + "select ST_AsText(ST_BuildArea(ST_GeomFromText({})))".format(input_geom) + ) assert areal_geom.take(1)[0][0] == expected_geom def test_st_line_from_multi_point(self): @@ -1502,34 +2085,45 @@ def test_st_line_from_multi_point(self): "'POLYGON((-1 0 0, 1 0 0, 0 0 1, 0 1 0, -1 0 0))'": None, "'LINESTRING(0 0, 1 2, 2 4, 3 6)'": None, "'POINT(1 2)'": None, - "'MULTIPOINT((10 40), (40 30), (20 20), (30 10))'": - "LINESTRING (10 40, 40 30, 20 20, 30 10)", - "'MULTIPOINT((10 40 66), (40 30 77), (20 20 88), (30 10 99))'": - "LINESTRING Z(10 40 66, 40 30 77, 20 20 88, 30 10 99)" + "'MULTIPOINT((10 40), (40 30), (20 20), (30 10))'": "LINESTRING (10 40, 40 30, 20 20, 30 10)", + "'MULTIPOINT((10 40 66), (40 30 77), (20 20 88), (30 10 99))'": "LINESTRING Z(10 40 66, 40 30 77, 20 20 88, 30 10 99)", } for input_geom, expected_geom in test_cases.items(): line_geometry = self.spark.sql( - "select ST_AsText(ST_LineFromMultiPoint(ST_GeomFromText({})))".format(input_geom)) + "select ST_AsText(ST_LineFromMultiPoint(ST_GeomFromText({})))".format( + input_geom + ) + ) assert line_geometry.take(1)[0][0] == expected_geom def test_st_locate_along(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('MULTILINESTRING M((1 2 3, 3 4 2, 9 4 3),(1 2 3, 5 4 5))') 
as geom") + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('MULTILINESTRING M((1 2 3, 3 4 2, 9 4 3),(1 2 3, 5 4 5))') as geom" + ) actual = baseDf.selectExpr("ST_AsText(ST_LocateAlong(geom, 2))").take(1)[0][0] expected = "MULTIPOINT M((3 4 2))" assert expected == actual - actual = baseDf.selectExpr("ST_AsText(ST_LocateAlong(geom, 2, -3))").take(1)[0][0] + actual = baseDf.selectExpr("ST_AsText(ST_LocateAlong(geom, 2, -3))").take(1)[0][ + 0 + ] expected = "MULTIPOINT M((5.121320343559642 1.8786796564403572 2), (3 1 2))" assert expected == actual def test_st_longest_line(self): - basedf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 99 77.5, 90 140, 60 140))') as geom") - actual = basedf.selectExpr("ST_AsText(ST_LongestLine(geom, geom))").take(1)[0][0] + basedf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 99 77.5, 90 140, 60 140))') as geom" + ) + actual = basedf.selectExpr("ST_AsText(ST_LongestLine(geom, geom))").take(1)[0][ + 0 + ] expected = "LINESTRING (180 180, 20 50)" assert expected == actual def test_st_max_distance(self): - basedf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 99 77.5, 90 140, 60 140))') as geom") + basedf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((40 180, 110 160, 180 180, 180 120, 140 90, 160 40, 80 10, 70 40, 20 50, 40 180),(60 140, 99 77.5, 90 140, 60 140))') as geom" + ) actual = basedf.selectExpr("ST_MaxDistance(geom, geom)").take(1)[0][0] expected = 206.15528128088303 assert expected == actual @@ -1538,10 +2132,12 @@ def test_st_s2_cell_ids(self): test_cases = [ "'POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))'", "'LINESTRING(0 0, 1 2, 2 4, 3 6)'", - "'POINT(1 2)'" + "'POINT(1 2)'", ] for input_geom in test_cases: - cell_ids = self.spark.sql("select ST_S2CellIDs(ST_GeomFromText({}), 6)".format(input_geom)).take(1)[0][0] + cell_ids = self.spark.sql( + "select ST_S2CellIDs(ST_GeomFromText({}), 6)".format(input_geom) + ).take(1)[0][0] assert isinstance(cell_ids, list) assert isinstance(cell_ids[0], int) # test null case @@ -1549,12 +2145,14 @@ def test_st_s2_cell_ids(self): assert cell_ids is None def test_st_s2_to_geom(self): - df = self.spark.sql(""" + df = self.spark.sql( + """ SELECT ST_Intersects(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), ST_S2ToGeom(ST_S2CellIDs(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), 10))[0]), ST_Intersects(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), ST_S2ToGeom(ST_S2CellIDs(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), 10))[1]), ST_Intersects(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), ST_S2ToGeom(ST_S2CellIDs(ST_GeomFromWKT('POLYGON ((0.1 0.1, 0.5 0.1, 1 0.3, 1 1, 0.1 1, 0.1 0.1))'), 10))[2]) - """) + """ + ) res1, res2, res3 = df.take(1)[0] assert res1 and res2 and res3 @@ -1562,10 +2160,12 @@ def test_st_h3_cell_ids(self): test_cases = [ "'POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))'", "'LINESTRING(0 0, 1 2, 2 4, 3 6)'", - "'POINT(1 2)'" + "'POINT(1 2)'", ] for input_geom in test_cases: - cell_ids = self.spark.sql("select ST_H3CellIDs(ST_GeomFromText({}), 6, true)".format(input_geom)).take(1)[0][0] + cell_ids = self.spark.sql( + "select ST_H3CellIDs(ST_GeomFromText({}), 6, true)".format(input_geom) + 
).take(1)[0][0] assert isinstance(cell_ids, list) assert isinstance(cell_ids[0], int) # test null case @@ -1573,21 +2173,26 @@ def test_st_h3_cell_ids(self): assert cell_ids is None def test_st_h3_cell_distance(self): - df = self.spark.sql("select ST_H3CellDistance(ST_H3CellIDs(ST_GeomFromWKT('POINT(1 2)'), 8, true)[0], ST_H3CellIDs(ST_GeomFromWKT('POINT(1.23 1.59)'), 8, true)[0])") + df = self.spark.sql( + "select ST_H3CellDistance(ST_H3CellIDs(ST_GeomFromWKT('POINT(1 2)'), 8, true)[0], ST_H3CellIDs(ST_GeomFromWKT('POINT(1.23 1.59)'), 8, true)[0])" + ) assert df.take(1)[0][0] == 78 def test_st_h3_kring(self): - df = self.spark.sql(""" + df = self.spark.sql( + """ SELECT ST_H3KRing(ST_H3CellIDs(ST_GeomFromWKT('POINT(1 2)'), 8, true)[0], 1, true) exactRings, ST_H3KRing(ST_H3CellIDs(ST_GeomFromWKT('POINT(1 2)'), 8, true)[0], 1, false) allRings, ST_H3CellIDs(ST_GeomFromWKT('POINT(1 2)'), 8, true) original_cells - """) + """ + ) exact_rings, all_rings, original_cells = df.take(1)[0] assert set(exact_rings + original_cells) == set(all_rings) def test_st_h3_togeom(self): - df = self.spark.sql(""" + df = self.spark.sql( + """ SELECT ST_Intersects( ST_H3ToGeom(ST_H3CellIDs(ST_GeomFromText('POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))'), 6, true))[10], @@ -1601,114 +2206,150 @@ def test_st_h3_togeom(self): ST_H3ToGeom(ST_H3CellIDs(ST_GeomFromText('POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))'), 6, false))[50], ST_GeomFromText('POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))') ) - """) + """ + ) res1, res2, res3 = df.take(1)[0] assert res1 and res2 and res3 def test_st_numPoints(self): - actual = self.spark.sql("SELECT ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'))").take(1)[0][0] + actual = self.spark.sql( + "SELECT ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'))" + ).take(1)[0][0] expected = 3 assert expected == actual def test_force3D(self): expected = 3 - actualDf = self.spark.sql("SELECT ST_Force3D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom") + actualDf = self.spark.sql( + "SELECT ST_Force3D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom" + ) actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0] assert expected == actual def test_force3DM(self): - actualDf = self.spark.sql("SELECT ST_Force3DM(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom") + actualDf = self.spark.sql( + "SELECT ST_Force3DM(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom" + ) actual = actualDf.selectExpr("ST_HasM(geom)").take(1)[0][0] assert actual def test_force3DZ(self): expected = 3 - actualDf = self.spark.sql("SELECT ST_Force3DZ(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom") + actualDf = self.spark.sql( + "SELECT ST_Force3DZ(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom" + ) actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0] assert expected == actual def test_force4D(self): expected = 4 - actualDf = self.spark.sql("SELECT ST_Force4D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1, 1.1) AS geom") + actualDf = self.spark.sql( + "SELECT ST_Force4D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1, 1.1) AS geom" + ) actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0] assert expected == actual def test_st_force_collection(self): - basedf = self.spark.sql("SELECT ST_GeomFromWKT('MULTIPOINT (30 10, 40 40, 20 20, 10 30, 10 10, 20 50)') AS mpoint, ST_GeomFromWKT('POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))') AS poly") - actual = basedf.selectExpr("ST_NumGeometries(ST_ForceCollection(mpoint))").take(1)[0][0] + basedf = 
self.spark.sql( + "SELECT ST_GeomFromWKT('MULTIPOINT (30 10, 40 40, 20 20, 10 30, 10 10, 20 50)') AS mpoint, ST_GeomFromWKT('POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))') AS poly" + ) + actual = basedf.selectExpr("ST_NumGeometries(ST_ForceCollection(mpoint))").take( + 1 + )[0][0] assert actual == 6 - actual = basedf.selectExpr("ST_NumGeometries(ST_ForceCollection(poly))").take(1)[0][0] + actual = basedf.selectExpr("ST_NumGeometries(ST_ForceCollection(poly))").take( + 1 + )[0][0] assert actual == 1 def test_forcePolygonCW(self): - actualDf = self.spark.sql("SELECT ST_ForcePolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))')) AS polyCW") + actualDf = self.spark.sql( + "SELECT ST_ForcePolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))')) AS polyCW" + ) actual = actualDf.selectExpr("ST_AsText(polyCW)").take(1)[0][0] expected = "POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))" assert expected == actual def test_forceRHR(self): - actualDf = self.spark.sql("SELECT ST_ForceRHR(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))')) AS polyCW") + actualDf = self.spark.sql( + "SELECT ST_ForceRHR(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 35),(30 20, 20 15, 20 25, 30 20))')) AS polyCW" + ) actual = actualDf.selectExpr("ST_AsText(polyCW)").take(1)[0][0] expected = "POLYGON ((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 20))" assert expected == actual def test_generate_points(self): - actual = self.spark.sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15))")\ - .first()[0] + actual = self.spark.sql( + "SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15))" + ).first()[0] assert actual == 15 actual = self.spark.sql( - "SELECT ST_AsText(ST_ReducePrecision(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 5, 10), 5))") \ - .first()[0] + "SELECT ST_AsText(ST_ReducePrecision(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 5, 10), 5))" + ).first()[0] expected = "MULTIPOINT ((53.82582 2.57803), (13.55212 2.44117), (59.12854 3.70611), (61.37698 7.14985), (10.49657 4.40622))" assert expected == actual actual = self.spark.sql( - "SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15, 100))") \ - .first()[0] + "SELECT ST_NumGeometries(ST_GeneratePoints(ST_Buffer(ST_GeomFromWKT('LINESTRING(50 50,150 150,150 50)'), 10, false, 'endcap=round join=round'), 15, 100))" + ).first()[0] assert actual == 15 - actual = self.spark.sql("SELECT ST_NumGeometries(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 30))")\ - .first()[0] + actual = self.spark.sql( + "SELECT ST_NumGeometries(ST_GeneratePoints(ST_GeomFromWKT('MULTIPOLYGON (((10 0, 10 10, 20 10, 20 0, 10 0)), ((50 0, 50 10, 70 10, 70 0, 50 0)))'), 30))" + ).first()[0] assert actual == 30 def test_nRings(self): expected = 1 - actualDf = self.spark.sql("SELECT ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom") + actualDf = 
self.spark.sql( + "SELECT ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))') AS geom" + ) actual = actualDf.selectExpr("ST_NRings(geom)").take(1)[0][0] assert expected == actual def test_trangulatePolygon(self): - baseDf = self.spark.sql("SELECT ST_GeomFromWKT('POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0), (5 5, 5 8, 8 8, 8 5, 5 5))') as poly") - actual = baseDf.selectExpr("ST_AsText(ST_TriangulatePolygon(poly))").take(1)[0][0] + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0), (5 5, 5 8, 8 8, 8 5, 5 5))') as poly" + ) + actual = baseDf.selectExpr("ST_AsText(ST_TriangulatePolygon(poly))").take(1)[0][ + 0 + ] expected = "GEOMETRYCOLLECTION (POLYGON ((0 0, 0 10, 5 5, 0 0)), POLYGON ((5 8, 5 5, 0 10, 5 8)), POLYGON ((10 0, 0 0, 5 5, 10 0)), POLYGON ((10 10, 5 8, 0 10, 10 10)), POLYGON ((10 0, 5 5, 8 5, 10 0)), POLYGON ((5 8, 10 10, 8 8, 5 8)), POLYGON ((10 10, 10 0, 8 5, 10 10)), POLYGON ((8 5, 8 8, 10 10, 8 5)))" assert actual == expected def test_translate(self): expected = "POLYGON ((3 5, 3 6, 4 6, 4 5, 3 5))" actual_df = self.spark.sql( - "SELECT ST_Translate(ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 2, 5) AS geom") + "SELECT ST_Translate(ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, 1 0))'), 2, 5) AS geom" + ) actual = actual_df.selectExpr("ST_AsText(geom)").take(1)[0][0] assert expected == actual def test_voronoiPolygons(self): expected = "GEOMETRYCOLLECTION (POLYGON ((-2 -2, -2 4, 4 -2, -2 -2)), POLYGON ((-2 4, 4 4, 4 -2, -2 4)))" - actual_df = self.spark.sql("SELECT ST_VoronoiPolygons(ST_GeomFromText('MULTIPOINT (0 0, 2 2)')) AS geom") + actual_df = self.spark.sql( + "SELECT ST_VoronoiPolygons(ST_GeomFromText('MULTIPOINT (0 0, 2 2)')) AS geom" + ) actual = actual_df.selectExpr("ST_AsText(geom)").take(1)[0][0] assert expected == actual def test_frechetDistance(self): expected = 5.0990195135927845 - actual_df = self.spark.sql("SELECT ST_FrechetDistance(ST_GeomFromText('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, " - "5 0)'), ST_GeomFromText('POINT (0 1)'))") + actual_df = self.spark.sql( + "SELECT ST_FrechetDistance(ST_GeomFromText('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, " + "5 0)'), ST_GeomFromText('POINT (0 1)'))" + ) actual = actual_df.take(1)[0][0] assert expected == actual def test_affine(self): expected = "POLYGON Z((2 3 1, 4 5 1, 7 8 2, 2 3 1))" - actual_df = self.spark.sql("SELECT ST_Affine(ST_GeomFromText('POLYGON ((1 0 1, 1 1 1, 2 2 2, 1 0 1))'), 1, 2, " - "1, 2, 1, 2) AS geom") + actual_df = self.spark.sql( + "SELECT ST_Affine(ST_GeomFromText('POLYGON ((1 0 1, 1 1 1, 2 2 2, 1 0 1))'), 1, 2, " + "1, 2, 1, 2) AS geom" + ) actual = actual_df.selectExpr("ST_AsText(geom)").take(1)[0][0] assert expected == actual @@ -1724,28 +2365,37 @@ def test_st_ashexewkb(self): def test_boundingDiagonal(self): expected = "LINESTRING (1 0, 2 1)" - actual_df = self.spark.sql("SELECT ST_BoundingDiagonal(ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, " - "1 0))')) AS geom") + actual_df = self.spark.sql( + "SELECT ST_BoundingDiagonal(ST_GeomFromText('POLYGON ((1 0, 1 1, 2 1, 2 0, " + "1 0))')) AS geom" + ) actual = actual_df.selectExpr("ST_AsText(geom)").take(1)[0][0] assert expected == actual def test_angle(self): expectedDegrees = 11.309932474020195 expectedRad = 0.19739555984988044 - actual_df = self.spark.sql("SELECT ST_Angle(ST_GeomFromText('LINESTRING (0 0, 1 1)'), ST_GeomFromText('LINESTRING (0 0, 3 2)')) AS angleRad") + actual_df = self.spark.sql( + "SELECT ST_Angle(ST_GeomFromText('LINESTRING (0 0, 1 1)'), ST_GeomFromText('LINESTRING (0 0, 3 
2)')) AS angleRad" + ) actualRad = actual_df.take(1)[0][0] actualDegrees = actual_df.selectExpr("ST_Degrees(angleRad)").take(1)[0][0] assert math.isclose(expectedRad, actualRad, rel_tol=1e-9) assert math.isclose(expectedDegrees, actualDegrees, rel_tol=1e-9) + def test_hausdorffDistance(self): expected = 5.0 - actual_df = self.spark.sql("SELECT ST_HausdorffDistance(ST_GeomFromText('POLYGON ((1 0 1, 1 1 2, 2 1 5, " - "2 0 1, 1 0 1))'), ST_GeomFromText('POLYGON ((4 0 4, 6 1 4, 6 4 9, 6 1 3, " - "4 0 4))'), 0.5)") - actual_df_default = self.spark.sql("SELECT ST_HausdorffDistance(ST_GeomFromText('POLYGON ((1 0 1, 1 1 2, " - "2 1 5, " - "2 0 1, 1 0 1))'), ST_GeomFromText('POLYGON ((4 0 4, 6 1 4, 6 4 9, 6 1 3, " - "4 0 4))'))") + actual_df = self.spark.sql( + "SELECT ST_HausdorffDistance(ST_GeomFromText('POLYGON ((1 0 1, 1 1 2, 2 1 5, " + "2 0 1, 1 0 1))'), ST_GeomFromText('POLYGON ((4 0 4, 6 1 4, 6 4 9, 6 1 3, " + "4 0 4))'), 0.5)" + ) + actual_df_default = self.spark.sql( + "SELECT ST_HausdorffDistance(ST_GeomFromText('POLYGON ((1 0 1, 1 1 2, " + "2 1 5, " + "2 0 1, 1 0 1))'), ST_GeomFromText('POLYGON ((4 0 4, 6 1 4, 6 4 9, 6 1 3, " + "4 0 4))'))" + ) actual = actual_df.take(1)[0][0] actual_default = actual_df_default.take(1)[0][0] assert expected == actual @@ -1754,5 +2404,7 @@ def test_hausdorffDistance(self): def test_st_coord_dim(self): point_df = self.spark.sql("SELECT ST_GeomFromWkt('POINT(7 8 6)') AS geom") - point_row = [pt_row[0] for pt_row in point_df.selectExpr("ST_CoordDim(geom)").collect()] - assert(point_row == [3]) + point_row = [ + pt_row[0] for pt_row in point_df.selectExpr("ST_CoordDim(geom)").collect() + ] + assert point_row == [3] diff --git a/python/tests/sql/test_geoparquet.py b/python/tests/sql/test_geoparquet.py index 969d9cfdc0..0878e20fb1 100644 --- a/python/tests/sql/test_geoparquet.py +++ b/python/tests/sql/test_geoparquet.py @@ -36,35 +36,45 @@ def test_interoperability_with_geopandas(self, tmp_path): test_data = [ [1, Point(0, 0), LineString([(1, 2), (3, 4), (5, 6)])], [2, LineString([(1, 2), (3, 4), (5, 6)]), Point(1, 1)], - [3, Point(1, 1), LineString([(1, 2), (3, 4), (5, 6)])] + [3, Point(1, 1), LineString([(1, 2), (3, 4), (5, 6)])], ] - df = self.spark.createDataFrame(data=test_data, schema=["id", "g0", "g1"]).repartition(1) + df = self.spark.createDataFrame( + data=test_data, schema=["id", "g0", "g1"] + ).repartition(1) geoparquet_save_path = os.path.join(tmp_path, "test.parquet") df.write.format("geoparquet").save(geoparquet_save_path) # Load geoparquet file written by sedona using geopandas gdf = geopandas.read_parquet(geoparquet_save_path) - assert gdf.dtypes['g0'].name == 'geometry' - assert gdf.dtypes['g1'].name == 'geometry' + assert gdf.dtypes["g0"].name == "geometry" + assert gdf.dtypes["g1"].name == "geometry" # Load geoparquet file written by geopandas using sedona - gdf = geopandas.GeoDataFrame([ - {'g': wkt_loads('POINT (1 2)'), 'i': 10}, - {'g': wkt_loads('LINESTRING (1 2, 3 4)'), 'i': 20} - ], geometry='g') + gdf = geopandas.GeoDataFrame( + [ + {"g": wkt_loads("POINT (1 2)"), "i": 10}, + {"g": wkt_loads("LINESTRING (1 2, 3 4)"), "i": 20}, + ], + geometry="g", + ) geoparquet_save_path2 = os.path.join(tmp_path, "test_2.parquet") gdf.to_parquet(geoparquet_save_path2) df2 = self.spark.read.format("geoparquet").load(geoparquet_save_path2) assert df2.count() == 2 row = df2.collect()[0] - assert isinstance(row['g'], BaseGeometry) + assert isinstance(row["g"], BaseGeometry) def test_load_geoparquet_with_spatial_filter(self): - df = 
self.spark.read.format("geoparquet").load(geoparquet_input_location)\ - .where("ST_Contains(geometry, ST_GeomFromText('POINT (35.174722 -6.552465)'))") + df = ( + self.spark.read.format("geoparquet") + .load(geoparquet_input_location) + .where( + "ST_Contains(geometry, ST_GeomFromText('POINT (35.174722 -6.552465)'))" + ) + ) rows = df.collect() assert len(rows) == 1 - assert rows[0]['name'] == 'Tanzania' + assert rows[0]["name"] == "Tanzania" def test_load_plain_parquet_file(self): with pytest.raises(Exception) as excinfo: @@ -72,23 +82,29 @@ def test_load_plain_parquet_file(self): assert "does not contain valid geo metadata" in str(excinfo.value) def test_inspect_geoparquet_metadata(self): - df = self.spark.read.format("geoparquet.metadata").load(geoparquet_input_location) + df = self.spark.read.format("geoparquet.metadata").load( + geoparquet_input_location + ) rows = df.collect() assert len(rows) == 1 row = rows[0] - assert row['path'].endswith('.parquet') - assert len(row['version'].split('.')) == 3 - assert row['primary_column'] == 'geometry' - column_metadata = row['columns']['geometry'] - assert column_metadata['encoding'] == 'WKB' - assert len(column_metadata['bbox']) == 4 - assert isinstance(json.loads(column_metadata['crs']), dict) + assert row["path"].endswith(".parquet") + assert len(row["version"].split(".")) == 3 + assert row["primary_column"] == "geometry" + column_metadata = row["columns"]["geometry"] + assert column_metadata["encoding"] == "WKB" + assert len(column_metadata["bbox"]) == 4 + assert isinstance(json.loads(column_metadata["crs"]), dict) def test_reading_legacy_parquet_files(self): - df = self.spark.read.format("geoparquet").option("legacyMode", "true").load(legacy_parquet_input_location) + df = ( + self.spark.read.format("geoparquet") + .option("legacyMode", "true") + .load(legacy_parquet_input_location) + ) rows = df.collect() assert len(rows) > 0 for row in rows: - assert isinstance(row['geom'], BaseGeometry) - assert isinstance(row['struct_geom']['g0'], BaseGeometry) - assert isinstance(row['struct_geom']['g1'], BaseGeometry) + assert isinstance(row["geom"], BaseGeometry) + assert isinstance(row["struct_geom"]["g0"], BaseGeometry) + assert isinstance(row["struct_geom"]["g1"], BaseGeometry) diff --git a/python/tests/sql/test_predicate.py b/python/tests/sql/test_predicate.py index 504ef81f1d..539ed73584 100644 --- a/python/tests/sql/test_predicate.py +++ b/python/tests/sql/test_predicate.py @@ -15,7 +15,11 @@ # specific language governing permissions and limitations # under the License. -from tests import csv_point_input_location, csv_point1_input_location, csv_polygon1_input_location +from tests import ( + csv_point_input_location, + csv_point1_input_location, + csv_polygon1_input_location, +) from tests.test_base import TestBase from pyspark.sql.functions import expr @@ -23,184 +27,224 @@ class TestPredicate(TestBase): def test_st_contains(self): - point_csv_df = self.spark.read. \ - format("csv"). \ - option("delimiter", ","). 
\ - option("header", "false").load( - csv_point_input_location - ) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) point_df.createOrReplaceTempView("pointdf") result_df = self.spark.sql( - "select * from pointdf where ST_Contains(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)") + "select * from pointdf where ST_Contains(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)" + ) result_df.show() assert result_df.count() == 999 def test_st_intersects(self): - point_csv_df = self.spark.read. \ - format("csv"). \ - option("delimiter", ","). \ - option("header", "false").load( - csv_point_input_location - ) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable" + ) point_df.createOrReplaceTempView("pointdf") result_df = self.spark.sql( - "select * from pointdf where ST_Intersects(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)") + "select * from pointdf where ST_Intersects(ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0), pointdf.arealandmark)" + ) result_df.show() assert result_df.count() == 999 def test_st_within(self): - point_csv_df = self.spark.read. \ - format("csv"). \ - option("delimiter", ","). 
\
-            option("header", "false").load(
-            csv_point_input_location
-        )
+        point_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_point_input_location)
+        )
         point_csv_df.createOrReplaceTempView("pointtable")

         point_df = self.spark.sql(
-            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
+            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
+        )
         point_df.createOrReplaceTempView("pointdf")

         result_df = self.spark.sql(
-            "select * from pointdf where ST_Within(pointdf.arealandmark, ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0))")
+            "select * from pointdf where ST_Within(pointdf.arealandmark, ST_PolygonFromEnvelope(1.0,100.0,1000.0,1100.0))"
+        )
         result_df.show()
         assert result_df.count() == 999

     def test_st_equals_for_st_point(self):
-        point_df_csv = self.spark.read.\
-            format("csv").\
-            option("delimiter", ",").\
-            option("header", "false").load(
-            csv_point1_input_location
-        )
+        point_df_csv = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_point1_input_location)
+        )
         point_df_csv.createOrReplaceTempView("pointtable")
         point_df = self.spark.sql(
-            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as point from pointtable")
+            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as point from pointtable"
+        )
         point_df.createOrReplaceTempView("pointdf")
-        equal_df = self.spark.sql("select * from pointdf where ST_Equals(pointdf.point, ST_Point(100.1, 200.1)) ")
+        equal_df = self.spark.sql(
+            "select * from pointdf where ST_Equals(pointdf.point, ST_Point(100.1, 200.1)) "
+        )
         equal_df.show()
         assert equal_df.count() == 5, f"Expected 5 values but got {equal_df.count()}"

     def test_st_equals_for_polygon(self):
-        polygon_csv_df = self.spark.read.format("csv").\
-            option("delimiter", ",").\
-            option("header", "false").load(
-            csv_polygon1_input_location
+        polygon_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_polygon1_input_location)
         )
         polygon_csv_df.createOrReplaceTempView("polygontable")
         polygon_df = self.spark.sql(
-            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
+        )
         polygon_df.createOrReplaceTempView("polygondf")
         polygon_df.show()

         equal_df1 = self.spark.sql(
-            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromEnvelope(100.01,200.01,100.5,200.5)) ")
+            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromEnvelope(100.01,200.01,100.5,200.5)) "
+        )
         equal_df1.show()
         assert equal_df1.count() == 5, f"Expected 5 values but got {equal_df1.count()}"

         equal_df_2 = self.spark.sql(
-            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromEnvelope(100.5,200.5,100.01,200.01)) ")
+            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromEnvelope(100.5,200.5,100.01,200.01)) "
+        )
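        # A minimal sketch of the reader idiom black emits throughout these
        # tests, shown for reference only; `spark` and `some_csv_path` are
        # assumed names, not part of this patch:
        #
        #     df = (
        #         spark.read.format("csv")
        #         .option("delimiter", ",")
        #         .option("header", "false")
        #         .load(some_csv_path)
        #     )
        #
        # Black replaces trailing-backslash continuations with one
        # parenthesized fluent chain, one `.method()` call per line.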
         equal_df_2.show()
         assert equal_df_2.count() == 5, f"Expected 5 values but got {equal_df_2.count()}"

     def test_st_equals_for_st_point_and_st_polygon(self):
-        polygon_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load(
-            csv_polygon1_input_location)
+        polygon_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_polygon1_input_location)
+        )
         polygon_csv_df.createOrReplaceTempView("polygontable")
         polygon_df = self.spark.sql(
-            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
+        )
         polygon_df.createOrReplaceTempView("polygondf")
         polygon_df.show()

         equal_df = self.spark.sql(
-            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_Point(91.01,191.01)) ")
+            "select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_Point(91.01,191.01)) "
+        )
         equal_df.show()
         assert equal_df.count() == 0, f"Expected 0 values but got {equal_df.count()}"

     def test_st_equals_for_st_linestring_and_st_polygon(self):
-        polygon_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load(
-            csv_polygon1_input_location)
+        polygon_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_polygon1_input_location)
+        )
         polygon_csv_df.createOrReplaceTempView("polygontable")
         polygon_df = self.spark.sql(
-            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
+        )
         polygon_df.createOrReplaceTempView("polygondf")
         polygon_df.show()

         string = "100.01,200.01,100.5,200.01,100.5,200.5,100.01,200.5,100.01,200.01"
-        equal_df = self.spark.sql(f"select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_LineStringFromText(\'{string}\', \',\')) ")
+        equal_df = self.spark.sql(
+            f"select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_LineStringFromText('{string}', ',')) "
+        )
         equal_df.show()
         assert equal_df.count() == 0, f"Expected 0 values but got {equal_df.count()}"

     def test_st_equals_for_st_polygon_from_envelope_and_st_polygon_from_text(self):
-        polygon_csv_df = self.spark.read.format("csv").\
-            option("delimiter", ",").\
-            option("header", "false").load(
-            csv_polygon1_input_location
+        polygon_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_polygon1_input_location)
         )
         polygon_csv_df.createOrReplaceTempView("polygontable")
         polygon_df = self.spark.sql(
-            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable"
+        )
         polygon_df.createOrReplaceTempView("polygondf")
         polygon_df.show()

         string = "100.01,200.01,100.5,200.01,100.5,200.5,100.01,200.5,100.01,200.01"
         equal_df = self.spark.sql(
-            f"select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromText(\'{string}\', \',\')) ")
+            f"select * from polygonDf where ST_Equals(polygonDf.polygonshape, ST_PolygonFromText('{string}', ',')) "
+        )
         equal_df.show()
         assert equal_df.count() == 5, f"Expected 5 values but got {equal_df.count()}"

     def test_st_crosses(self):
         crosses_test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('LINESTRING(1 5, 5 1)') as b")
+            "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('LINESTRING(1 5, 5 1)') as b"
+        )
         crosses_test_table.createOrReplaceTempView("crossesTesttable")
         crosses = self.spark.sql("select(ST_Crosses(a, b)) from crossesTesttable")
         not_crosses_test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 5 2, 5 5, 2 5, 2 2))') as b")
+            "select ST_GeomFromWKT('POLYGON((1 1, 4 1, 4 4, 1 4, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 5 2, 5 5, 2 5, 2 2))') as b"
+        )
         not_crosses_test_table.createOrReplaceTempView("notCrossesTesttable")
-        not_crosses = self.spark.sql("select(ST_Crosses(a, b)) from notCrossesTesttable")
+        not_crosses = self.spark.sql(
+            "select(ST_Crosses(a, b)) from notCrossesTesttable"
+        )
         assert crosses.take(1)[0][0]
         assert not not_crosses.take(1)[0][0]

     def test_st_touches(self):
-        point_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load(
-            csv_point_input_location
+        point_csv_df = (
+            self.spark.read.format("csv")
+            .option("delimiter", ",")
+            .option("header", "false")
+            .load(csv_point_input_location)
         )
         point_csv_df.createOrReplaceTempView("pointtable")
         point_df = self.spark.sql(
-            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
+            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable"
+        )
         point_df.createOrReplaceTempView("pointdf")
         result_df = self.spark.sql(
-            "select * from pointdf where ST_Touches(pointdf.arealandmark, ST_PolygonFromEnvelope(0.0,99.0,1.1,101.1))")
+            "select * from pointdf where ST_Touches(pointdf.arealandmark, ST_PolygonFromEnvelope(0.0,99.0,1.1,101.1))"
+        )
         result_df.show()
         assert result_df.count() == 1

     def test_st_relate(self):
-        baseDf = self.spark.sql("SELECT ST_GeomFromWKT('LINESTRING (1 1, 5 5)') AS g1, ST_GeomFromWKT('POLYGON ((3 3, 3 7, 7 7, 7 3, 3 3))') as g2, '1010F0212' as im")
+        baseDf = self.spark.sql(
+            "SELECT ST_GeomFromWKT('LINESTRING (1 1, 5 5)') AS g1, ST_GeomFromWKT('POLYGON ((3 3, 3 7, 7 7, 7 3, 3 3))') as g2, '1010F0212' as im"
+        )
         actual = baseDf.selectExpr("ST_Relate(g1, g2)").take(1)[0][0]
         assert actual == "1010F0212"

@@ -208,12 +252,15 @@ def test_st_relate(self):
         assert actual

     def test_st_relate_match(self):
-        actual = self.spark.sql("SELECT ST_RelateMatch('101202FFF', 'TTTTTTFFF') ").take(1)[0][0]
+        actual = self.spark.sql(
+            "SELECT ST_RelateMatch('101202FFF', 'TTTTTTFFF') "
+        ).take(1)[0][0]
         assert actual

     def test_st_overlaps(self):
         test_table = self.spark.sql(
-            "select ST_GeomFromWKT('POLYGON((2.5 2.5, 2.5 
4.5, 4.5 4.5, 4.5 2.5, 2.5 2.5))') as a,ST_GeomFromWKT('POLYGON((4 4, 4 6, 6 6, 6 4, 4 4))') as b, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as c, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as d") + "select ST_GeomFromWKT('POLYGON((2.5 2.5, 2.5 4.5, 4.5 4.5, 4.5 2.5, 2.5 2.5))') as a,ST_GeomFromWKT('POLYGON((4 4, 4 6, 6 6, 6 4, 4 4))') as b, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as c, ST_GeomFromWKT('POLYGON((5 5, 4 6, 6 6, 6 4, 5 5))') as d" + ) test_table.createOrReplaceTempView("testtable") overlaps = self.spark.sql("select ST_Overlaps(a,b) from testtable") not_overlaps = self.spark.sql("select ST_Overlaps(c,d) from testtable") @@ -221,31 +268,55 @@ def test_st_overlaps(self): assert not not_overlaps.take(1)[0][0] def test_st_ordering_equals_ok(self): - test_table = self.spark.sql("select ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as a, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as b, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 0 -2, 2 0))') as c, ST_GeomFromWKT('POLYGON((0 2, -2 0, 2 0, 0 2))') as d") + test_table = self.spark.sql( + "select ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as a, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 2 0))') as b, ST_GeomFromWKT('POLYGON((2 0, 0 2, -2 0, 0 -2, 2 0))') as c, ST_GeomFromWKT('POLYGON((0 2, -2 0, 2 0, 0 2))') as d" + ) test_table.createOrReplaceTempView("testorderingequals") - order_equals = self.spark.sql("select ST_OrderingEquals(a,b) from testorderingequals") - not_order_equals_diff_geom = self.spark.sql("select ST_OrderingEquals(a,c) from testorderingequals") - not_order_equals_diff_order = self.spark.sql("select ST_OrderingEquals(a,d) from testorderingequals") + order_equals = self.spark.sql( + "select ST_OrderingEquals(a,b) from testorderingequals" + ) + not_order_equals_diff_geom = self.spark.sql( + "select ST_OrderingEquals(a,c) from testorderingequals" + ) + not_order_equals_diff_order = self.spark.sql( + "select ST_OrderingEquals(a,d) from testorderingequals" + ) assert order_equals.take(1)[0][0] assert not not_order_equals_diff_geom.take(1)[0][0] assert not not_order_equals_diff_order.take(1)[0][0] def test_st_dwithin(self): - test_table = self.spark.sql("select ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (2 0)') as point_1") + test_table = self.spark.sql( + "select ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (2 0)') as point_1" + ) test_table.createOrReplaceTempView("test_table") - isWithin = self.spark.sql("select ST_DWithin(origin, point_1, 3) from test_table").head()[0] + isWithin = self.spark.sql( + "select ST_DWithin(origin, point_1, 3) from test_table" + ).head()[0] assert isWithin is True def test_dwithin_use_sphere(self): - test_table = self.spark.sql("select ST_GeomFromWKT('POINT (-122.335167 47.608013)') as seattle, ST_GeomFromWKT('POINT (-73.935242 40.730610)') as ny") + test_table = self.spark.sql( + "select ST_GeomFromWKT('POINT (-122.335167 47.608013)') as seattle, ST_GeomFromWKT('POINT (-73.935242 40.730610)') as ny" + ) test_table.createOrReplaceTempView("test_table") - isWithin = self.spark.sql("select ST_DWithin(seattle, ny, 2000000, true) from test_table").head()[0] + isWithin = self.spark.sql( + "select ST_DWithin(seattle, ny, 2000000, true) from test_table" + ).head()[0] assert isWithin is False - def test_dwithin_use_sphere_complex_boolean_expression(self): expected = 55 df_point = self.spark.range(10).withColumn("pt", expr("ST_Point(id, id)")) - df_polygon = self.spark.range(10).withColumn("poly", expr("ST_Point(id, id + 
0.01)")) - actual = df_point.alias("a").join(df_polygon.alias("b"), expr("ST_DWithin(pt, poly, 10000, a.`id` % 2 = 0)")).count() + df_polygon = self.spark.range(10).withColumn( + "poly", expr("ST_Point(id, id + 0.01)") + ) + actual = ( + df_point.alias("a") + .join( + df_polygon.alias("b"), + expr("ST_DWithin(pt, poly, 10000, a.`id` % 2 = 0)"), + ) + .count() + ) assert expected == actual diff --git a/python/tests/sql/test_predicate_join.py b/python/tests/sql/test_predicate_join.py index 0a14c5886b..f8a35831a8 100644 --- a/python/tests/sql/test_predicate_join.py +++ b/python/tests/sql/test_predicate_join.py @@ -17,272 +17,368 @@ from pyspark import Row from pyspark.sql.functions import broadcast, expr -from pyspark.sql.types import StructType, StringType, IntegerType, StructField, DoubleType - -from tests import csv_polygon_input_location, csv_point_input_location, overlap_polygon_input_location, \ - csv_point1_input_location, csv_point2_input_location, csv_polygon1_input_location, csv_polygon2_input_location, \ - csv_polygon1_random_input_location, csv_polygon2_random_input_location +from pyspark.sql.types import ( + StructType, + StringType, + IntegerType, + StructField, + DoubleType, +) + +from tests import ( + csv_polygon_input_location, + csv_point_input_location, + overlap_polygon_input_location, + csv_point1_input_location, + csv_point2_input_location, + csv_polygon1_input_location, + csv_polygon2_input_location, + csv_polygon1_random_input_location, + csv_polygon2_random_input_location, +) from tests.test_base import TestBase class TestPredicateJoin(TestBase): def test_st_contains_in_join(self): - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_polygon_input_location + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_point_input_location + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() range_join_df = self.spark.sql( - "select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) ") + "select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) " + ) range_join_df.explain() range_join_df.show(3) 
assert range_join_df.count() == 1000 def test_st_intersects_in_a_join(self): - polygon_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load( - csv_polygon_input_location + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load( - csv_point_input_location + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() range_join_df = self.spark.sql( - "select * from polygondf, pointdf where ST_Intersects(polygondf.polygonshape,pointdf.pointshape) ") + "select * from polygondf, pointdf where ST_Intersects(polygondf.polygonshape,pointdf.pointshape) " + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.count() == 1000 def test_st_touches_in_a_join(self): - polygon_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load(csv_polygon_input_location) + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) + ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() - polygon_df = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + polygon_df = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() - point_df = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + point_df = self.spark.sql( + 
"select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() - range_join_df = self.spark.sql("select * from polygondf, pointdf where ST_Touches(polygondf.polygonshape,pointdf.pointshape) ") + range_join_df = self.spark.sql( + "select * from polygondf, pointdf where ST_Touches(polygondf.polygonshape,pointdf.pointshape) " + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.count() == 0 def test_st_within_in_a_join(self): - polygon_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load( - csv_polygon_input_location) + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) + ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").option("delimiter", ",").option("header", "false").load( - csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() range_join_df = self.spark.sql( - "select * from polygondf, pointdf where ST_Within(pointdf.pointshape, polygondf.polygonshape) ") + "select * from polygondf, pointdf where ST_Within(pointdf.pointshape, polygondf.polygonshape) " + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.count() == 1000 def test_st_overlaps_in_a_join(self): - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_polygon_input_location + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") - polygon_csv_overlap_df = self.spark.read.format("csv").option("delimiter", 
",").option("header", "false").load( - overlap_polygon_input_location) + polygon_csv_overlap_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(overlap_polygon_input_location) + ) polygon_csv_overlap_df.createOrReplaceTempView("polygonoverlaptable") polygon_overlap_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygonoverlaptable._c0 as Decimal(24,20)),cast(polygonoverlaptable._c1 as Decimal(24,20)), cast(polygonoverlaptable._c2 as Decimal(24,20)), cast(polygonoverlaptable._c3 as Decimal(24,20))) as polygonshape from polygonoverlaptable") + "select ST_PolygonFromEnvelope(cast(polygonoverlaptable._c0 as Decimal(24,20)),cast(polygonoverlaptable._c1 as Decimal(24,20)), cast(polygonoverlaptable._c2 as Decimal(24,20)), cast(polygonoverlaptable._c3 as Decimal(24,20))) as polygonshape from polygonoverlaptable" + ) polygon_overlap_df.createOrReplaceTempView("polygonodf") range_join_df = self.spark.sql( - "select * from polygondf, polygonodf where ST_Overlaps(polygondf.polygonshape, polygonodf.polygonshape)") + "select * from polygondf, polygonodf where ST_Overlaps(polygondf.polygonshape, polygonodf.polygonshape)" + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.count() == 15 def test_st_crosses_in_a_join(self): - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_polygon_input_location + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_point_input_location + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() range_join_df = self.spark.sql( - "select * from polygondf, pointdf where ST_Crosses(pointdf.pointshape, polygondf.polygonshape) ") + "select * from polygondf, pointdf where ST_Crosses(pointdf.pointshape, polygondf.polygonshape) " + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.count() == 0 def test_st_distance_radius_in_a_join(self): - point_csv_df_1 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_point_input_location + point_csv_df_1 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", 
"false") + .load(csv_point_input_location) ) point_csv_df_1.createOrReplaceTempView("pointtable") point_csv_df_1.show() point_df_1 = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable" + ) point_df_1.createOrReplaceTempView("pointdf1") point_df_1.show() - point_csv_df_2 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_point_input_location) + point_csv_df_2 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df_2.createOrReplaceTempView("pointtable") point_csv_df_2.show() point_df2 = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable" + ) point_df2.createOrReplaceTempView("pointdf2") point_df2.show() distance_join_df = self.spark.sql( - "select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) <= 2") + "select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) <= 2" + ) distance_join_df.explain() distance_join_df.show(10) assert distance_join_df.count() == 2998 def test_st_distance_less_radius_in_a_join(self): - point_csv_df_1 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load(csv_point_input_location) + point_csv_df_1 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df_1.createOrReplaceTempView("pointtable") point_csv_df_1.show() - point_df1 = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable") + point_df1 = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable" + ) point_df1.createOrReplaceTempView("pointdf1") point_df1.show() - point_csv_df2 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load(csv_point_input_location) + point_csv_df2 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df2.createOrReplaceTempView("pointtable") point_csv_df2.show() - point_df2 = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable") + point_df2 = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape2 from pointtable" + ) point_df2.createOrReplaceTempView("pointdf2") point_df2.show() - distance_join_df = self.spark.sql("select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) < 2") + distance_join_df = self.spark.sql( + "select * from pointdf1, pointdf2 where ST_Distance(pointdf1.pointshape1,pointdf2.pointshape2) < 2" + ) distance_join_df.explain() distance_join_df.show(10) assert distance_join_df.count() == 2998 def test_st_contains_in_a_range_and_join(self): - polygon_csv_df = self.spark.read.format("csv").\ - 
option("delimiter", ",").\ - option("header", "false").\ - load(csv_polygon_input_location) + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) + ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() - polygon_df = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + polygon_df = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() - point_df = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + point_df = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() - range_join_df = self.spark.sql("select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) " + - "and ST_Contains(ST_PolygonFromEnvelope(1.0,101.0,501.0,601.0), polygondf.polygonshape)") + range_join_df = self.spark.sql( + "select * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) " + + "and ST_Contains(ST_PolygonFromEnvelope(1.0,101.0,501.0,601.0), polygondf.polygonshape)" + ) range_join_df.explain() range_join_df.show(3) @@ -291,204 +387,269 @@ def test_st_contains_in_a_range_and_join(self): def test_super_small_data_join(self): raw_point_df = self.spark.createDataFrame( self.spark.sparkContext.parallelize( - [Row(1, "40.0", "-120.0"), Row(2, "30.0", "-110.0"), Row(3, "20.0", "-100.0")] + [ + Row(1, "40.0", "-120.0"), + Row(2, "30.0", "-110.0"), + Row(3, "20.0", "-100.0"), + ] ), StructType( [ StructField("id", IntegerType(), True), StructField("lat", StringType(), True), - StructField("lon", StringType(), True) + StructField("lon", StringType(), True), ] - ) + ), ) raw_point_df.createOrReplaceTempView("rawPointDf") pointDF = self.spark.sql( - "select id, ST_Point(cast(lat as Decimal(24,20)), cast(lon as Decimal(24,20))) AS latlon_point FROM rawPointDf") + "select id, ST_Point(cast(lat as Decimal(24,20)), cast(lon as Decimal(24,20))) AS latlon_point FROM rawPointDf" + ) pointDF.createOrReplaceTempView("pointDf") pointDF.show(truncate=False) raw_polygon_df = self.spark.createDataFrame( self.spark.sparkContext.parallelize( [ - Row("A", 25.0, -115.0, 35.0, -105.0), Row("B", 25.0, -135.0, 35.0, -125.0) - ]), + Row("A", 25.0, -115.0, 35.0, -105.0), + Row("B", 25.0, -135.0, 35.0, -125.0), + ] + ), StructType( [ - StructField("id", StringType(), True), StructField("latmin", DoubleType(), True), - StructField("lonmin", DoubleType(), True), StructField("latmax", DoubleType(), True), - 
StructField("lonmax", DoubleType(), True) + StructField("id", StringType(), True), + StructField("latmin", DoubleType(), True), + StructField("lonmin", DoubleType(), True), + StructField("latmax", DoubleType(), True), + StructField("lonmax", DoubleType(), True), ] - )) + ), + ) raw_polygon_df.createOrReplaceTempView("rawPolygonDf") - polygon_envelope_df = self.spark.sql("select id, ST_PolygonFromEnvelope(" + - "cast(latmin as Decimal(24,20)), cast(lonmin as Decimal(24,20)), " + - "cast(latmax as Decimal(24,20)), cast(lonmax as Decimal(24,20))) AS polygon FROM rawPolygonDf") + polygon_envelope_df = self.spark.sql( + "select id, ST_PolygonFromEnvelope(" + + "cast(latmin as Decimal(24,20)), cast(lonmin as Decimal(24,20)), " + + "cast(latmax as Decimal(24,20)), cast(lonmax as Decimal(24,20))) AS polygon FROM rawPolygonDf" + ) polygon_envelope_df.createOrReplaceTempView("polygonDf") within_envelope_df = self.spark.sql( - "select * FROM pointDf, polygonDf WHERE ST_Within(pointDf.latlon_point, polygonDf.polygon)") + "select * FROM pointDf, polygonDf WHERE ST_Within(pointDf.latlon_point, polygonDf.polygon)" + ) assert within_envelope_df.count() == 1 def test_st_equals_in_a_join_for_st_point(self): - point_csv_df_1 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point1_input_location) + point_csv_df_1 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point1_input_location) + ) point_csv_df_1.createOrReplaceTempView("pointtable1") point_csv_df_1.show() - point_df1 = self.spark.sql("select ST_Point(cast(pointtable1._c0 as Decimal(24,20)),cast(pointtable1._c1 as Decimal(24,20)) ) as pointshape1 from pointtable1") + point_df1 = self.spark.sql( + "select ST_Point(cast(pointtable1._c0 as Decimal(24,20)),cast(pointtable1._c1 as Decimal(24,20)) ) as pointshape1 from pointtable1" + ) point_df1.createOrReplaceTempView("pointdf1") point_df1.show() - point_csv_df2 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point2_input_location) + point_csv_df2 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point2_input_location) + ) point_csv_df2.createOrReplaceTempView("pointtable2") point_csv_df2.show() - point_df2 = self.spark.sql("select ST_Point(cast(pointtable2._c0 as Decimal(24,20)),cast(pointtable2._c1 as Decimal(24,20))) as pointshape2 from pointtable2") + point_df2 = self.spark.sql( + "select ST_Point(cast(pointtable2._c0 as Decimal(24,20)),cast(pointtable2._c1 as Decimal(24,20))) as pointshape2 from pointtable2" + ) point_df2.createOrReplaceTempView("pointdf2") point_df2.show() - equal_join_df = self.spark.sql("select * from pointdf1, pointdf2 where ST_Equals(pointdf1.pointshape1,pointdf2.pointshape2) ") + equal_join_df = self.spark.sql( + "select * from pointdf1, pointdf2 where ST_Equals(pointdf1.pointshape1,pointdf2.pointshape2) " + ) equal_join_df.explain() equal_join_df.show(3) - assert equal_join_df.count() == 100, f"Expected 100 but got {equal_join_df.count()}" + assert ( + equal_join_df.count() == 100 + ), f"Expected 100 but got {equal_join_df.count()}" def test_st_equals_in_a_join_for_st_polygon(self): - polygon_csv_df1 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_polygon1_input_location) + polygon_csv_df1 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + 
.load(csv_polygon1_input_location) + ) polygon_csv_df1.createOrReplaceTempView("polygontable1") polygon_csv_df1.show() polygon_df1 = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1") + "select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1" + ) polygon_df1.createOrReplaceTempView("polygondf1") polygon_df1.show() - polygon_csv_df2 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load(csv_polygon2_input_location) + polygon_csv_df2 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon2_input_location) + ) polygon_csv_df2.createOrReplaceTempView("polygontable2") polygon_csv_df2.show() polygon_df2 = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2") + "select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2" + ) polygon_df2.createOrReplaceTempView("polygondf2") polygon_df2.show() equal_join_df = self.spark.sql( - "select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) ") + "select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) " + ) equal_join_df.explain() equal_join_df.show(3) - assert equal_join_df.count() == 100, f"Expected 100 but got {equal_join_df.count()}" + assert ( + equal_join_df.count() == 100 + ), f"Expected 100 but got {equal_join_df.count()}" def test_st_equals_in_a_join_for_st_polygon_random_shuffle(self): - polygon_csv_df1 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_polygon1_random_input_location) + polygon_csv_df1 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon1_random_input_location) + ) polygon_csv_df1.createOrReplaceTempView("polygontable1") polygon_csv_df1.show() - polygon_df1 = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1") + polygon_df1 = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable1._c0 as Decimal(24,20)),cast(polygontable1._c1 as Decimal(24,20)), cast(polygontable1._c2 as Decimal(24,20)), cast(polygontable1._c3 as Decimal(24,20))) as polygonshape1 from polygontable1" + ) polygon_df1.createOrReplaceTempView("polygondf1") polygon_df1.show() - polygon_csv_df2 = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_polygon2_random_input_location) + polygon_csv_df2 = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon2_random_input_location) + ) 
polygon_csv_df2.createOrReplaceTempView("polygontable2") polygon_csv_df2.show() - polygon_df2 = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2") + polygon_df2 = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable2._c0 as Decimal(24,20)),cast(polygontable2._c1 as Decimal(24,20)), cast(polygontable2._c2 as Decimal(24,20)), cast(polygontable2._c3 as Decimal(24,20))) as polygonshape2 from polygontable2" + ) polygon_df2.createOrReplaceTempView("polygondf2") polygon_df2.show() - equal_join_df = self.spark.sql("select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) ") + equal_join_df = self.spark.sql( + "select * from polygondf1, polygondf2 where ST_Equals(polygondf1.polygonshape1,polygondf2.polygonshape2) " + ) equal_join_df.explain() equal_join_df.show(3) - assert equal_join_df.count() == 100, f"Expected 100 but got {equal_join_df.count()}" + assert ( + equal_join_df.count() == 100 + ), f"Expected 100 but got {equal_join_df.count()}" def test_st_equals_in_a_join_for_st_point_and_st_polygon(self): - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_point1_input_location) + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point1_input_location) + ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() - point_df = self.spark.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20)) ) as pointshape from pointtable") + point_df = self.spark.sql( + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20)) ) as pointshape from pointtable" + ) point_df.createOrReplaceTempView("pointdf") point_df.show() - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").\ - load(csv_polygon1_input_location) + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon1_input_location) + ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() - polygon_df = self.spark.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + polygon_df = self.spark.sql( + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - equal_join_df = self.spark.sql("select * from pointdf, polygondf where ST_Equals(pointdf.pointshape,polygondf.polygonshape) ") + equal_join_df = self.spark.sql( + "select * from pointdf, polygondf where ST_Equals(pointdf.pointshape,polygondf.polygonshape) " + ) equal_join_df.explain() equal_join_df.show(3) assert equal_join_df.count() == 0, f"Expected 0 but got {equal_join_df.count()}" def test_st_contains_in_broadcast_join(self): - polygon_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - 
csv_polygon_input_location + polygon_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_polygon_input_location) ) polygon_csv_df.createOrReplaceTempView("polygontable") polygon_csv_df.show() polygon_df = self.spark.sql( - "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable") + "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable" + ) polygon_df = polygon_df.repartition(7) polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() - point_csv_df = self.spark.read.format("csv").\ - option("delimiter", ",").\ - option("header", "false").load( - csv_point_input_location + point_csv_df = ( + self.spark.read.format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csv_point_input_location) ) point_csv_df.createOrReplaceTempView("pointtable") point_csv_df.show() point_df = self.spark.sql( - "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable") + "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable" + ) point_df = point_df.repartition(9) point_df.createOrReplaceTempView("pointdf") point_df.show() range_join_df = self.spark.sql( - "select /*+ BROADCAST(polygondf) */ * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) ") + "select /*+ BROADCAST(polygondf) */ * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) " + ) range_join_df.explain() range_join_df.show(3) assert range_join_df.rdd.getNumPartitions() == 9 assert range_join_df.count() == 1000 - range_join_df = point_df.alias("pointdf").join(broadcast(polygon_df).alias("polygondf"), on=expr("ST_Contains(polygondf.polygonshape, pointdf.pointshape)")) + range_join_df = point_df.alias("pointdf").join( + broadcast(polygon_df).alias("polygondf"), + on=expr("ST_Contains(polygondf.polygonshape, pointdf.pointshape)"), + ) range_join_df.explain() range_join_df.show(3) diff --git a/python/tests/sql/test_shapefile.py b/python/tests/sql/test_shapefile.py index 1565ee6ea8..b02cbc04d6 100644 --- a/python/tests/sql/test_shapefile.py +++ b/python/tests/sql/test_shapefile.py @@ -34,31 +34,38 @@ def test_read_simple(self): assert row["geometry"].geom_type in ("Polygon", "MultiPolygon") def test_read_osm_pois(self): - input_location = os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + input_location = os.path.join( + tests_resource, "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp" + ) df = self.spark.read.format("shapefile").load(input_location) assert df.count() == 12873 rows = df.take(100) for row in rows: assert len(row) == 5 assert row["geometry"].geom_type == "Point" - assert isinstance(row['osm_id'], str) - assert isinstance(row['fclass'], str) - assert isinstance(row['name'], str) - assert isinstance(row['code'], int) + assert isinstance(row["osm_id"], str) + assert isinstance(row["fclass"], str) + assert isinstance(row["name"], str) + assert isinstance(row["code"], int) def test_customize_geom_and_key_columns(self): input_location = os.path.join(tests_resource, 
"shapefiles/gis_osm_pois_free_1") - df = self.spark.read.format("shapefile").option("geometry.name", "geom").option("key.name", "fid").load(input_location) + df = ( + self.spark.read.format("shapefile") + .option("geometry.name", "geom") + .option("key.name", "fid") + .load(input_location) + ) assert df.count() == 12873 rows = df.take(100) for row in rows: assert len(row) == 6 assert row["geom"].geom_type == "Point" - assert isinstance(row['fid'], int) - assert isinstance(row['osm_id'], str) - assert isinstance(row['fclass'], str) - assert isinstance(row['name'], str) - assert isinstance(row['code'], int) + assert isinstance(row["fid"], int) + assert isinstance(row["osm_id"], str) + assert isinstance(row["fclass"], str) + assert isinstance(row["name"], str) + assert isinstance(row["code"], int) def test_read_multiple_shapefiles(self): input_location = os.path.join(tests_resource, "shapefiles/datatypes") @@ -66,20 +73,20 @@ def test_read_multiple_shapefiles(self): rows = df.collect() assert len(rows) == 9 for row in rows: - id = row['id'] - assert row['aInt'] == id + id = row["id"] + assert row["aInt"] == id if id is not None: - assert row['aUnicode'] == "测试" + str(id) + assert row["aUnicode"] == "测试" + str(id) if id < 10: - assert row['aDecimal'] * 10 == id * 10 + id - assert row['aDecimal2'] is None - assert row['aDate'] == datetime.date(2020 + id, id, id) + assert row["aDecimal"] * 10 == id * 10 + id + assert row["aDecimal2"] is None + assert row["aDate"] == datetime.date(2020 + id, id, id) else: - assert row['aDecimal'] is None - assert row['aDecimal2'] * 100 == id * 100 + id - assert row['aDate'] is None + assert row["aDecimal"] is None + assert row["aDecimal2"] * 100 == id * 100 + id + assert row["aDate"] is None else: - assert row['aUnicode'] == '' - assert row['aDecimal'] is None - assert row['aDecimal2'] is None - assert row['aDate'] is None + assert row["aUnicode"] == "" + assert row["aDecimal"] is None + assert row["aDecimal2"] is None + assert row["aDate"] is None diff --git a/python/tests/sql/test_spatial_rdd_to_spatial_dataframe.py b/python/tests/sql/test_spatial_rdd_to_spatial_dataframe.py index 41682137be..45570396db 100644 --- a/python/tests/sql/test_spatial_rdd_to_spatial_dataframe.py +++ b/python/tests/sql/test_spatial_rdd_to_spatial_dataframe.py @@ -46,7 +46,7 @@ def test_list_to_rdd_and_df(self): [Point(22, 52.0), "2", 2], [Point(23.0, 52), "3", 3], [Point(23, 54), "4", 4], - [Point(24.0, 56.0), "5", 5] + [Point(24.0, 56.0), "5", 5], ] schema = StructType( [ @@ -68,7 +68,7 @@ def test_point_rdd(self): Offset=0, splitter=splitter, carryInputData=True, - partitions=numPartitions + partitions=numPartitions, ) raw_spatial_rdd = spatial_rdd.rawSpatialRDD.map( @@ -78,10 +78,7 @@ def test_point_rdd(self): self.spark.createDataFrame(raw_spatial_rdd).show() schema = StructType( - [ - StructField("geom", GeometryType()), - StructField("name", StringType()) - ] + [StructField("geom", GeometryType()), StructField("name", StringType())] ) spatial_rdd_with_schema = self.spark.createDataFrame( @@ -90,4 +87,6 @@ def test_point_rdd(self): spatial_rdd_with_schema.show() - assert spatial_rdd_with_schema.take(1)[0][0].wkt == "POINT (32.324142 -88.331492)" + assert ( + spatial_rdd_with_schema.take(1)[0][0].wkt == "POINT (32.324142 -88.331492)" + ) diff --git a/python/tests/sql/test_st_function_imports.py b/python/tests/sql/test_st_function_imports.py index 81bc768f1c..79cf29d816 100644 --- a/python/tests/sql/test_st_function_imports.py +++ b/python/tests/sql/test_st_function_imports.py @@ 
-37,4 +37,5 @@ def test_import(self): def test_geometry_type_should_be_a_sql_type(self): from sedona.spark import GeometryType from pyspark.sql.types import UserDefinedType + assert isinstance(GeometryType(), UserDefinedType) diff --git a/python/tests/streaming/spark/cases_builder.py b/python/tests/streaming/spark/cases_builder.py index 5dd1275675..cfb25e4c67 100644 --- a/python/tests/streaming/spark/cases_builder.py +++ b/python/tests/streaming/spark/cases_builder.py @@ -25,31 +25,27 @@ def __init__(self, container: Dict[str, Any]): @classmethod def empty(cls): - return cls(container=dict(function_name=None, arguments=None, expected_result=None, transform=None)) + return cls( + container=dict( + function_name=None, arguments=None, expected_result=None, transform=None + ) + ) def with_function_name(self, function_name: str): self.container["function_name"] = function_name - return self.__class__( - container=self.container - ) + return self.__class__(container=self.container) def with_expected_result(self, expected_result: Any): self.container["expected_result"] = expected_result - return self.__class__( - container=self.container - ) + return self.__class__(container=self.container) def with_arguments(self, arguments: List[str]): self.container["arguments"] = arguments - return self.__class__( - container=self.container - ) + return self.__class__(container=self.container) def with_transform(self, transform: str): self.container["transform"] = transform - return self.__class__( - container=self.container - ) + return self.__class__(container=self.container) def __iter__(self): return self.container.__iter__() diff --git a/python/tests/streaming/spark/test_constructor_functions.py b/python/tests/streaming/spark/test_constructor_functions.py index 1be6d961b4..a6ae01be21 100644 --- a/python/tests/streaming/spark/test_constructor_functions.py +++ b/python/tests/streaming/spark/test_constructor_functions.py @@ -29,289 +29,450 @@ from tests.test_base import TestBase import math -SCHEMA = StructType( - [ - StructField("geom", GeometryType()) - ] -) +SCHEMA = StructType([StructField("geom", GeometryType())]) SEDONA_LISTED_SQL_FUNCTIONS = [ - (SuiteContainer.empty() - .with_function_name("ST_AsText") - .with_arguments(["ST_GeomFromText('POINT (21 52)')"]) - .with_expected_result("POINT (21 52)")), - (SuiteContainer.empty() - .with_function_name("ST_Buffer") - .with_arguments(["ST_GeomFromText('POINT (21 52)')", "1.0"]) - .with_expected_result(3.1214451522580533) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_Buffer") - .with_arguments(["ST_GeomFromText('POINT (21 52)')", "100000", "true"]) - .with_expected_result(4.088135158017784) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_Distance") - .with_arguments(["ST_GeomFromText('POINT (21 52)')", "ST_GeomFromText('POINT (21 53)')"]) - .with_expected_result(1.0)), - (SuiteContainer.empty() - .with_function_name("ST_ConcaveHull") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')", "1.0"]) - .with_expected_result(1.0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_ConvexHull") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"]) - .with_expected_result(1.0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_Envelope") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"]) - .with_expected_result(1.0) - 
.with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_LENGTH") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"]) - .with_expected_result(4.0)), - (SuiteContainer.empty() - .with_function_name("ST_Area") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"]) - .with_expected_result(1.0)), - (SuiteContainer.empty() - .with_function_name("ST_Centroid") - .with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')"]) - .with_expected_result("POINT (21.5 52.5)") - .with_transform("ST_ASText")), - (SuiteContainer.empty() - .with_function_name("ST_Transform") - .with_arguments(["ST_GeomFromText('POINT(52.5 21.5)')", "'epsg:4326'", "'epsg:2180'"]) - .with_expected_result(-2501415.806893427) - #.with_expected_result("POINT (-2501415.806893427 4119952.52325666)") - .with_transform("ST_Y")), - #.with_transform("ST_ASText")), - (SuiteContainer.empty() - .with_function_name("ST_Intersection") - .with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')", "ST_GeomFromText('POINT(21.5 52.5)')"]) - .with_expected_result(0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_IsValid") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result(True)), - (SuiteContainer.empty() - .with_function_name("ST_MakeValid") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "false"]) - .with_expected_result(1.0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_ReducePrecision") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "9"]) - .with_expected_result(1.0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_IsSimple") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result(True)), - (SuiteContainer.empty() - .with_function_name("ST_Buffer") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "0.9"]) - .with_expected_result(7.128370573329018) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_AsText") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result("POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))")), - (SuiteContainer.empty() + ( + SuiteContainer.empty() + .with_function_name("ST_AsText") + .with_arguments(["ST_GeomFromText('POINT (21 52)')"]) + .with_expected_result("POINT (21 52)") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Buffer") + .with_arguments(["ST_GeomFromText('POINT (21 52)')", "1.0"]) + .with_expected_result(3.1214451522580533) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Buffer") + .with_arguments(["ST_GeomFromText('POINT (21 52)')", "100000", "true"]) + .with_expected_result(4.088135158017784) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Distance") + .with_arguments( + ["ST_GeomFromText('POINT (21 52)')", "ST_GeomFromText('POINT (21 53)')"] + ) + .with_expected_result(1.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_ConcaveHull") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')", "1.0"] + ) + .with_expected_result(1.0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_ConvexHull") + .with_arguments( + 
["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"] + ) + .with_expected_result(1.0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Envelope") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"] + ) + .with_expected_result(1.0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_LENGTH") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"] + ) + .with_expected_result(4.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Area") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"] + ) + .with_expected_result(1.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Centroid") + .with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')"]) + .with_expected_result("POINT (21.5 52.5)") + .with_transform("ST_ASText") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Transform") + .with_arguments( + ["ST_GeomFromText('POINT(52.5 21.5)')", "'epsg:4326'", "'epsg:2180'"] + ) + .with_expected_result(-2501415.806893427) + # .with_expected_result("POINT (-2501415.806893427 4119952.52325666)") + .with_transform("ST_Y") + ), + # .with_transform("ST_ASText")), + ( + SuiteContainer.empty() + .with_function_name("ST_Intersection") + .with_arguments( + [ + "ST_GeomFromText('POINT(21.5 52.5)')", + "ST_GeomFromText('POINT(21.5 52.5)')", + ] + ) + .with_expected_result(0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_IsValid") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result(True) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_MakeValid") + .with_arguments( + [ + "ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", + "false", + ] + ) + .with_expected_result(1.0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_ReducePrecision") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "9"] + ) + .with_expected_result(1.0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_IsSimple") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result(True) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Buffer") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "0.9"] + ) + .with_expected_result(7.128370573329018) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_AsText") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result("POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))") + ), + ( + SuiteContainer.empty() .with_function_name("ST_AsGeoJSON") - .with_arguments(["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"]) + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 52, 21 53, 22 53, 22 52, 21 52))')"] + ) .with_expected_result( - """{"type":"Polygon","coordinates":[[[21.0,52.0],[21.0,53.0],[22.0,53.0],[22.0,52.0],[21.0,52.0]]]}""")), - (SuiteContainer.empty() - .with_function_name("ST_AsBinary") - .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) - .with_expected_result(wkb.dumps(wkt.loads("POINT(21 52)")))), - (SuiteContainer.empty() - .with_function_name("ST_AsEWKB") - .with_arguments(["ST_GeomFromText('POINT(21 
52)')"]) - .with_expected_result(wkb.dumps(wkt.loads("POINT(21 52)")))), - (SuiteContainer.empty() - .with_function_name("ST_SRID") - .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) - .with_expected_result(0)), - (SuiteContainer.empty() - .with_function_name("ST_SetSRID") - .with_arguments(["ST_GeomFromText('POINT(21 52)')", "4326"]) - .with_expected_result(0) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_NPoints") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result(5)), - (SuiteContainer.empty() - .with_function_name("ST_SimplifyPreserveTopology") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "1.0"]) - .with_expected_result(1) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_GeometryType") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result("ST_Polygon")), - (SuiteContainer.empty() - .with_function_name("ST_LineMerge") - .with_arguments(["ST_GeomFromText('LINESTRING(-29 -27,-30 -29.7,-36 -31,-45 -33,-46 -32)')"]) - .with_expected_result(0.0) - .with_transform("ST_LENGTH")), - (SuiteContainer.empty() - .with_function_name("ST_Azimuth") - .with_arguments(["ST_GeomFromText('POINT(21 52)')", "ST_GeomFromText('POINT(21 53)')"]) - .with_expected_result(0.0)), - (SuiteContainer.empty() - .with_function_name("ST_X") - .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) - .with_expected_result(21.0)), - (SuiteContainer.empty() - .with_function_name("ST_Y") - .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) - .with_expected_result(52.0)), - (SuiteContainer.empty() - .with_function_name("ST_StartPoint") - .with_arguments(["ST_GeomFromText('LINESTRING(100 150,50 60, 70 80, 160 170)')"]) - .with_expected_result("POINT (100 150)") - .with_transform("ST_ASText")), - (SuiteContainer.empty() - .with_function_name("ST_Endpoint") - .with_arguments(["ST_GeomFromText('LINESTRING(100 150,50 60, 70 80, 160 170)')"]) - .with_expected_result("POINT (160 170)") - .with_transform("ST_ASText")), - (SuiteContainer.empty() - .with_function_name("ST_Boundary") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result(4) - .with_transform("ST_LENGTH")), - (SuiteContainer.empty() - .with_function_name("ST_ExteriorRing") - .with_arguments(["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"]) - .with_expected_result(4) - .with_transform("ST_LENGTH")), - (SuiteContainer.empty() - .with_function_name("ST_GeometryN") - .with_arguments(["ST_GeomFromText('MULTIPOINT((1 2), (3 4), (5 6), (8 9))')", "0"]) - .with_expected_result(1) - .with_transform("ST_X")), - (SuiteContainer.empty() - .with_function_name("ST_InteriorRingN") - .with_arguments([ - "ST_GeomFromText('POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))')", - "0"]) - .with_expected_result(4.0) - .with_transform("ST_LENGTH")), - (SuiteContainer.empty() - .with_function_name("ST_Dump") - .with_arguments([ - "ST_GeomFromText('MULTIPOINT ((10 40), (40 30), (20 20), (30 10))')"]) - .with_expected_result(4) - .with_transform("SIZE")), - (SuiteContainer.empty() - .with_function_name("ST_DumpPoints") - .with_arguments([ - "ST_GeomFromTEXT('LINESTRING (0 0, 1 1, 1 0)')"]) - .with_expected_result(3) - .with_transform("SIZE")), - (SuiteContainer.empty() - .with_function_name("ST_IsClosed") - .with_arguments([ - "ST_GeomFROMTEXT('LINESTRING(0 0, 1 1, 1 0)')"]) - 
.with_expected_result(False)), - (SuiteContainer.empty() - .with_function_name("ST_NumInteriorRings") - .with_arguments([ - "ST_GeomFROMTEXT('POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))')"]) - .with_expected_result(1)), - (SuiteContainer.empty() - .with_function_name("ST_AddPoint") - .with_arguments([ - "ST_GeomFromText('LINESTRING(0 0, 1 1, 1 0)')", - "ST_GeomFromText('Point(21 52)')", - "1"]) - .with_expected_result(111.86168327044916) - .with_transform("ST_Length")), - (SuiteContainer.empty() - .with_function_name("ST_RemovePoint") - .with_arguments([ - "ST_GeomFromText('LINESTRING(0 0, 1 1, 1 0)')", - "1" - ]) - .with_expected_result("LINESTRING (0 0, 1 0)") - .with_transform("ST_AsText")), - (SuiteContainer.empty() - .with_function_name("ST_IsRing") - .with_arguments([ - "ST_GeomFromText('LINESTRING(0 0, 0 1, 1 1, 1 0, 0 0)')" - ]) - .with_expected_result(True)), - (SuiteContainer.empty() - .with_function_name("ST_NumGeometries") - .with_arguments([ - "ST_GeomFromText('LINESTRING(0 0, 0 1, 1 1, 1 0, 0 0)')" - ]) - .with_expected_result(1)), - (SuiteContainer.empty() - .with_function_name("ST_FlipCoordinates") - .with_arguments([ - "ST_GeomFromText('POINT(21 52)')" - ]) - .with_expected_result(52.0) - .with_transform("ST_X")), - (SuiteContainer.empty() - .with_function_name("ST_MinimumBoundingRadius") - .with_arguments([ - "ST_GeomFromText('POLYGON((1 1,0 0, -1 1, 1 1))')" - ]) - .with_expected_result(Row(center=wkt.loads("POINT(0 1)"), radius=1.0))), - (SuiteContainer.empty() - .with_function_name("ST_MinimumBoundingCircle") - .with_arguments([ - "ST_GeomFromText('POLYGON((1 1,0 0, -1 1, 1 1))')", - "8"]) - .with_expected_result(3.121445152258052) - .with_transform("ST_AREA")), - (SuiteContainer.empty() - .with_function_name("ST_SubDivide") - .with_arguments([ - "ST_GeomFromText('POLYGON((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))')", - "5"]) - .with_expected_result(14) - .with_transform("SIZE")), - (SuiteContainer.empty() - .with_function_name("ST_SubDivideExplode") - .with_arguments([ - "ST_GeomFromText('LINESTRING(0 0, 85 85, 100 100, 120 120, 21 21, 10 10, 5 5)')", - "5"]) - .with_expected_result(2) - .with_transform("ST_NPoints")), - (SuiteContainer.empty() - .with_function_name("ST_GeoHash") - .with_arguments([ - "ST_GeomFromText('POINT(21.427834 52.042576573)')", - "5"]) - .with_expected_result("u3r0p")), - (SuiteContainer.empty() - .with_function_name("ST_Collect") - .with_arguments([ - "ST_GeomFromText('POINT(21.427834 52.042576573)')", - "ST_GeomFromText('POINT(45.342524 56.342354355)')"]) - .with_expected_result(0.0) - .with_transform("ST_LENGTH")), - (SuiteContainer.empty() - .with_function_name("ST_BestSRID") - .with_arguments(["ST_GeomFromText('POINT (-177 60)')"]) - .with_expected_result(32601)), - (SuiteContainer.empty() - .with_function_name("ST_ShiftLongitude") - .with_arguments(["ST_GeomFromText('POINT (-177 60)')"]) - .with_expected_result("POINT (183 60)") - .with_transform("ST_AsText")) + """{"type":"Polygon","coordinates":[[[21.0,52.0],[21.0,53.0],[22.0,53.0],[22.0,52.0],[21.0,52.0]]]}""" + ) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_AsBinary") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(wkb.dumps(wkt.loads("POINT(21 52)"))) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_AsEWKB") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(wkb.dumps(wkt.loads("POINT(21 52)"))) + ), + ( + SuiteContainer.empty() + 
.with_function_name("ST_SRID") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_SetSRID") + .with_arguments(["ST_GeomFromText('POINT(21 52)')", "4326"]) + .with_expected_result(0) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_NPoints") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result(5) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_SimplifyPreserveTopology") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')", "1.0"] + ) + .with_expected_result(1) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_GeometryType") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result("ST_Polygon") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_LineMerge") + .with_arguments( + ["ST_GeomFromText('LINESTRING(-29 -27,-30 -29.7,-36 -31,-45 -33,-46 -32)')"] + ) + .with_expected_result(0.0) + .with_transform("ST_LENGTH") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Azimuth") + .with_arguments( + ["ST_GeomFromText('POINT(21 52)')", "ST_GeomFromText('POINT(21 53)')"] + ) + .with_expected_result(0.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_X") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(21.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Y") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(52.0) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_StartPoint") + .with_arguments( + ["ST_GeomFromText('LINESTRING(100 150,50 60, 70 80, 160 170)')"] + ) + .with_expected_result("POINT (100 150)") + .with_transform("ST_ASText") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Endpoint") + .with_arguments( + ["ST_GeomFromText('LINESTRING(100 150,50 60, 70 80, 160 170)')"] + ) + .with_expected_result("POINT (160 170)") + .with_transform("ST_ASText") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Boundary") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result(4) + .with_transform("ST_LENGTH") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_ExteriorRing") + .with_arguments( + ["ST_GeomFromText('POLYGON ((21 53, 22 53, 22 52, 21 52, 21 53))')"] + ) + .with_expected_result(4) + .with_transform("ST_LENGTH") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_GeometryN") + .with_arguments( + ["ST_GeomFromText('MULTIPOINT((1 2), (3 4), (5 6), (8 9))')", "0"] + ) + .with_expected_result(1) + .with_transform("ST_X") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_InteriorRingN") + .with_arguments( + [ + "ST_GeomFromText('POLYGON((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))')", + "0", + ] + ) + .with_expected_result(4.0) + .with_transform("ST_LENGTH") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Dump") + .with_arguments( + ["ST_GeomFromText('MULTIPOINT ((10 40), (40 30), (20 20), (30 10))')"] + ) + .with_expected_result(4) + .with_transform("SIZE") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_DumpPoints") + .with_arguments(["ST_GeomFromTEXT('LINESTRING (0 0, 1 1, 1 0)')"]) + .with_expected_result(3) + .with_transform("SIZE") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_IsClosed") + 
.with_arguments(["ST_GeomFROMTEXT('LINESTRING(0 0, 1 1, 1 0)')"]) + .with_expected_result(False) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_NumInteriorRings") + .with_arguments( + [ + "ST_GeomFROMTEXT('POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))')" + ] + ) + .with_expected_result(1) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_AddPoint") + .with_arguments( + [ + "ST_GeomFromText('LINESTRING(0 0, 1 1, 1 0)')", + "ST_GeomFromText('Point(21 52)')", + "1", + ] + ) + .with_expected_result(111.86168327044916) + .with_transform("ST_Length") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_RemovePoint") + .with_arguments(["ST_GeomFromText('LINESTRING(0 0, 1 1, 1 0)')", "1"]) + .with_expected_result("LINESTRING (0 0, 1 0)") + .with_transform("ST_AsText") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_IsRing") + .with_arguments(["ST_GeomFromText('LINESTRING(0 0, 0 1, 1 1, 1 0, 0 0)')"]) + .with_expected_result(True) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_NumGeometries") + .with_arguments(["ST_GeomFromText('LINESTRING(0 0, 0 1, 1 1, 1 0, 0 0)')"]) + .with_expected_result(1) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_FlipCoordinates") + .with_arguments(["ST_GeomFromText('POINT(21 52)')"]) + .with_expected_result(52.0) + .with_transform("ST_X") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_MinimumBoundingRadius") + .with_arguments(["ST_GeomFromText('POLYGON((1 1,0 0, -1 1, 1 1))')"]) + .with_expected_result(Row(center=wkt.loads("POINT(0 1)"), radius=1.0)) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_MinimumBoundingCircle") + .with_arguments(["ST_GeomFromText('POLYGON((1 1,0 0, -1 1, 1 1))')", "8"]) + .with_expected_result(3.121445152258052) + .with_transform("ST_AREA") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_SubDivide") + .with_arguments( + [ + "ST_GeomFromText('POLYGON((35 10, 45 45, 15 40, 10 20, 35 10), (20 30, 35 35, 30 20, 20 30))')", + "5", + ] + ) + .with_expected_result(14) + .with_transform("SIZE") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_SubDivideExplode") + .with_arguments( + [ + "ST_GeomFromText('LINESTRING(0 0, 85 85, 100 100, 120 120, 21 21, 10 10, 5 5)')", + "5", + ] + ) + .with_expected_result(2) + .with_transform("ST_NPoints") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_GeoHash") + .with_arguments(["ST_GeomFromText('POINT(21.427834 52.042576573)')", "5"]) + .with_expected_result("u3r0p") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_Collect") + .with_arguments( + [ + "ST_GeomFromText('POINT(21.427834 52.042576573)')", + "ST_GeomFromText('POINT(45.342524 56.342354355)')", + ] + ) + .with_expected_result(0.0) + .with_transform("ST_LENGTH") + ), + ( + SuiteContainer.empty() + .with_function_name("ST_BestSRID") + .with_arguments(["ST_GeomFromText('POINT (-177 60)')"]) + .with_expected_result(32601) + ), + ( + SuiteContainer.empty() + .with_function_name("ST_ShiftLongitude") + .with_arguments(["ST_GeomFromText('POINT (-177 60)')"]) + .with_expected_result("POINT (183 60)") + .with_transform("ST_AsText") + ), ] @@ -324,34 +485,43 @@ def pytest_generate_tests(metafunc): class TestConstructorFunctions(TestBase): - params = { - "test_geospatial_function_on_stream": SEDONA_LISTED_SQL_FUNCTIONS - } + params = {"test_geospatial_function_on_stream": SEDONA_LISTED_SQL_FUNCTIONS} @pytest.mark.sparkstreaming - def test_geospatial_function_on_stream(self, function_name: str, arguments: 
List[str], - expected_result: Any, transform: Optional[str]): - # given input stream + def test_geospatial_function_on_stream( + self, + function_name: str, + arguments: List[str], + expected_result: Any, + transform: Optional[str], + ): + # given input stream - input_stream = self.spark.readStream.schema(SCHEMA).parquet(os.path.join( - tests_resource, - "streaming/geometry_example") - ).selectExpr(f"{function_name}({', '.join(arguments)}) AS result") + input_stream = ( + self.spark.readStream.schema(SCHEMA) + .parquet(os.path.join(tests_resource, "streaming/geometry_example")) + .selectExpr(f"{function_name}({', '.join(arguments)}) AS result") + ) - # and target table - random_table_name = f"view_{uuid.uuid4().hex}" + # and target table + random_table_name = f"view_{uuid.uuid4().hex}" - # when saving stream to memory - streaming_query = input_stream.writeStream.format("memory") \ - .queryName(random_table_name) \ - .outputMode("append").start() + # when saving stream to memory + streaming_query = ( + input_stream.writeStream.format("memory") + .queryName(random_table_name) + .outputMode("append") + .start() + ) - streaming_query.processAllAvailable() + streaming_query.processAllAvailable() - # then result should be as expected - transform_query = "result" if not transform else f"{transform}(result)" - queryResult = self.spark.sql(f"select {transform_query} from {random_table_name}").collect()[0][0] - if (type(queryResult) is float and type(expected_result) is float): - assert math.isclose(queryResult, expected_result, rel_tol=1e-9) - else: - assert queryResult == expected_result + # then result should be as expected + transform_query = "result" if not transform else f"{transform}(result)" + queryResult = self.spark.sql( + f"select {transform_query} from {random_table_name}" + ).collect()[0][0] + if type(queryResult) is float and type(expected_result) is float: + assert math.isclose(queryResult, expected_result, rel_tol=1e-9) + else: + assert queryResult == expected_result diff --git a/python/tests/test_assign_raw_spatial_rdd.py b/python/tests/test_assign_raw_spatial_rdd.py index 9011e93d1d..b367a08528 100644 --- a/python/tests/test_assign_raw_spatial_rdd.py +++ b/python/tests/test_assign_raw_spatial_rdd.py @@ -16,7 +16,12 @@ # under the License. 
from sedona.core.SpatialRDD import PointRDD, CircleRDD -from tests.properties.point_properties import input_location, offset, splitter, num_partitions +from tests.properties.point_properties import ( + input_location, + offset, + splitter, + num_partitions, +) from tests.test_base import TestBase @@ -24,32 +29,30 @@ class TestSpatialRddAssignment(TestBase): def test_raw_spatial_rdd_assignment(self): spatial_rdd = PointRDD( - self.sc, - input_location, - offset, - splitter, - True, - num_partitions + self.sc, input_location, offset, splitter, True, num_partitions ) spatial_rdd.analyze() empty_point_rdd = PointRDD() empty_point_rdd.rawSpatialRDD = spatial_rdd.rawSpatialRDD empty_point_rdd.analyze() - assert empty_point_rdd.countWithoutDuplicates() == spatial_rdd.countWithoutDuplicates() + assert ( + empty_point_rdd.countWithoutDuplicates() + == spatial_rdd.countWithoutDuplicates() + ) assert empty_point_rdd.boundaryEnvelope == spatial_rdd.boundaryEnvelope - assert empty_point_rdd.rawSpatialRDD.map(lambda x: x.geom.area).collect()[0] == 0.0 - assert empty_point_rdd.rawSpatialRDD.take(9)[4].getUserData() == "testattribute0\ttestattribute1\ttestattribute2" + assert ( + empty_point_rdd.rawSpatialRDD.map(lambda x: x.geom.area).collect()[0] == 0.0 + ) + assert ( + empty_point_rdd.rawSpatialRDD.take(9)[4].getUserData() + == "testattribute0\ttestattribute1\ttestattribute2" + ) def test_raw_circle_rdd_assignment(self): point_rdd = PointRDD( - self.sc, - input_location, - offset, - splitter, - True, - num_partitions + self.sc, input_location, offset, splitter, True, num_partitions ) circle_rdd = CircleRDD(point_rdd, 1.0) circle_rdd.analyze() @@ -58,5 +61,7 @@ def test_raw_circle_rdd_assignment(self): circle_rdd_2.rawSpatialRDD = circle_rdd.rawSpatialRDD circle_rdd_2.analyze() - assert circle_rdd_2.countWithoutDuplicates() == circle_rdd.countWithoutDuplicates() + assert ( + circle_rdd_2.countWithoutDuplicates() == circle_rdd.countWithoutDuplicates() + ) assert circle_rdd_2.boundaryEnvelope == circle_rdd.boundaryEnvelope diff --git a/python/tests/test_base.py b/python/tests/test_base.py index fa63fea750..7742d31464 100644 --- a/python/tests/test_base.py +++ b/python/tests/test_base.py @@ -24,7 +24,9 @@ class TestBase: @classproperty def spark(self): if not hasattr(self, "__spark"): - spark = SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate()) + spark = SedonaContext.create( + SedonaContext.builder().master("local[*]").getOrCreate() + ) setattr(self, "__spark", spark) return getattr(self, "__spark") diff --git a/python/tests/test_circle.py b/python/tests/test_circle.py index bb79bf89a4..91d8184823 100644 --- a/python/tests/test_circle.py +++ b/python/tests/test_circle.py @@ -30,14 +30,19 @@ class TestCircle: def test_get_center(self): point = Point(0.0, 0.0) circle = Circle(point, 0.1) - assert circle.centerGeometry.x == pytest.approx(point.x, 1e-6) and circle.centerGeometry.y == pytest.approx(point.y, 1e-6) + assert circle.centerGeometry.x == pytest.approx( + point.x, 1e-6 + ) and circle.centerGeometry.y == pytest.approx(point.y, 1e-6) def test_get_radius(self): point = Point(0.0, 0.0) circle = Circle(point, 0.1) assert circle.getRadius() == pytest.approx(0.1, 0.01) - @pytest.mark.skipif(shapely.__version__.startswith('2.'), reason="Circle is immutable when working with Shapely 2.0") + @pytest.mark.skipif( + shapely.__version__.startswith("2."), + reason="Circle is immutable when working with Shapely 2.0", + ) def test_set_radius(self): point = Point(0.0, 0.0) circle = 
Circle(point, 0.1) @@ -63,20 +68,30 @@ def test_covers(self): assert not circle.covers(wkt.loads("MULTIPOINT ((0.1 0.1), (1.2 0.4))")) assert not circle.covers(wkt.loads("MULTIPOINT ((1.1 0.1), (0.2 1.4))")) - assert circle.covers(wkt.loads("POLYGON ((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1))")) + assert circle.covers( + wkt.loads("POLYGON ((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1))") + ) assert circle.covers(wkt.loads("POLYGON ((-0.5 0, 0 0.5, 0.5 0, -0.5 0))")) assert not circle.covers(wkt.loads("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))")) - assert not circle.covers(wkt.loads("POLYGON ((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4))")) + assert not circle.covers( + wkt.loads("POLYGON ((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4))") + ) assert circle.covers( - wkt.loads("MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((-0.5 0, 0 0.5, 0.5 0, -0.5 0)))") + wkt.loads( + "MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((-0.5 0, 0 0.5, 0.5 0, -0.5 0)))" + ) ) assert not circle.covers( - wkt.loads("MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((0 0, 0 1, 1 1, 1 0, 0 0)))") + wkt.loads( + "MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((0 0, 0 1, 1 1, 1 0, 0 0)))" + ) ) assert not circle.covers( - wkt.loads("MULTIPOLYGON (((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4)),((0 0, 0 1, 1 1, 1 0, 0 0)))") + wkt.loads( + "MULTIPOLYGON (((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4)),((0 0, 0 1, 1 1, 1 0, 0 0)))" + ) ) assert circle.covers(wkt.loads("LINESTRING (-0.1 0, 0.2 0.3)")) @@ -87,14 +102,20 @@ def test_covers(self): assert not circle.covers(wkt.loads("LINESTRING (0.4 0.4, 0.45 0.45)")) - assert circle.covers(wkt.loads("MULTILINESTRING ((-0.1 0, 0.2 0.3), (-0.5 0, 0 0.5, 0.5 0))")) - assert not circle.covers(wkt.loads("MULTILINESTRING ((-0.1 0, 0.2 0.3), (-0.1 0, 0 1))")) - assert not circle.covers(wkt.loads("MULTILINESTRING ((0.4 0.4, 0.45 0.45), (-0.1 0, 0 1))")) + assert circle.covers( + wkt.loads("MULTILINESTRING ((-0.1 0, 0.2 0.3), (-0.5 0, 0 0.5, 0.5 0))") + ) + assert not circle.covers( + wkt.loads("MULTILINESTRING ((-0.1 0, 0.2 0.3), (-0.1 0, 0 1))") + ) + assert not circle.covers( + wkt.loads("MULTILINESTRING ((0.4 0.4, 0.45 0.45), (-0.1 0, 0 1))") + ) def test_intersects(self): circle = Circle(Point(0.0, 0.0), 0.5) - assert (circle.intersects(Point(0, 0))) - assert (circle.intersects(Point(0.1, 0.2))) + assert circle.intersects(Point(0, 0)) + assert circle.intersects(Point(0.1, 0.2)) assert not (circle.intersects(Point(0.4, 0.4))) assert not (circle.intersects(Point(-1, 0.4))) @@ -102,33 +123,50 @@ def test_intersects(self): assert circle.intersects(wkt.loads("MULTIPOINT ((0.1 0.1), (1.2 0.4))")) assert not circle.intersects(wkt.loads("MULTIPOINT ((1.1 0.1), (0.2 1.4))")) - assert circle.intersects(wkt.loads("POLYGON ((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1))")) + assert circle.intersects( + wkt.loads("POLYGON ((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1))") + ) assert circle.intersects(wkt.loads("POLYGON ((-0.5 0, 0 0.5, 0.5 0, -0.5 0))")) assert circle.intersects(wkt.loads("POLYGON ((0 0, 1 1, 1 0, 0 0))")) - assert circle.intersects(wkt.loads("POLYGON ((-1 -1, -1 1, 1 1, 1.5 0.5, 1 -1, -1 -1))")) + assert circle.intersects( + wkt.loads("POLYGON ((-1 -1, -1 1, 1 1, 1.5 0.5, 1 -1, -1 -1))") + ) assert circle.intersects( - wkt.loads("POLYGON ((-1 -1, -1 1, 1 1, 1 -1, -1 -1),(-0.1 -0.1, 0.1 -0.1, 0.1 0.1, -0.1 0.1, -0.1 -0.1))") + wkt.loads( + "POLYGON ((-1 -1, -1 1, 1 1, 1 -1, -1 -1),(-0.1 -0.1, 0.1 -0.1, 0.1 0.1, -0.1 0.1, -0.1 -0.1))" + ) ) - assert not 
circle.intersects(wkt.loads("POLYGON ((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4))")) - assert not circle.intersects(wkt.loads("POLYGON ((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0))")) assert not circle.intersects( - wkt.loads("POLYGON ((-1 -1, -1 1, 1 1, 1 -1, -1 -1),(-0.6 -0.6, 0.6 -0.6, 0.6 0.6, -0.6 0.6, -0.6 -0.6))") + wkt.loads("POLYGON ((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4))") + ) + assert not circle.intersects( + wkt.loads("POLYGON ((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0))") + ) + assert not circle.intersects( + wkt.loads( + "POLYGON ((-1 -1, -1 1, 1 1, 1 -1, -1 -1),(-0.6 -0.6, 0.6 -0.6, 0.6 0.6, -0.6 0.6, -0.6 -0.6))" + ) ) assert circle.intersects( - wkt.loads("MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((-0.5 0, 0 0.5, 0.5 0, -0.5 0)))") + wkt.loads( + "MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)),((-0.5 0, 0 0.5, 0.5 0, -0.5 0)))" + ) ) assert circle.intersects( - wkt.loads("MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)), ((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0)))") + wkt.loads( + "MULTIPOLYGON (((-0.1 0.1, 0 0.4, 0.1 0.2, -0.1 0.1)), ((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0)))" + ) ) assert not circle.intersects( wkt.loads( - "MULTIPOLYGON (((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4)),((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0)))" - )) + "MULTIPOLYGON (((0.4 0.4, 0.4 0.45, 0.45 0.45, 0.45 0.4, 0.4 0.4)),((-1 0, -1 1, 0 1, 0 2, -1 2, -1 0)))" + ) + ) assert circle.intersects(wkt.loads("LINESTRING (-1 -1, 1 1)")) @@ -140,14 +178,24 @@ def test_intersects(self): assert not circle.intersects(wkt.loads("LINESTRING (-0.4 -0.4, -2 -3.2)")) assert not circle.intersects(wkt.loads("LINESTRING (0.1 0.5, 1 0.5)")) - assert circle.intersects(wkt.loads("MULTILINESTRING ((-1 -1, 1 1), (-1 0.5, 1 0.5))")) - assert circle.intersects(wkt.loads("MULTILINESTRING ((-1 -1, 1 1), (0.4 0.4, 1 1))")) - assert not circle.intersects(wkt.loads("MULTILINESTRING ((0.1 0.5, 1 0.5), (0.4 0.4, 1 1))")) + assert circle.intersects( + wkt.loads("MULTILINESTRING ((-1 -1, 1 1), (-1 0.5, 1 0.5))") + ) + assert circle.intersects( + wkt.loads("MULTILINESTRING ((-1 -1, 1 1), (0.4 0.4, 1 1))") + ) + assert not circle.intersects( + wkt.loads("MULTILINESTRING ((0.1 0.5, 1 0.5), (0.4 0.4, 1 1))") + ) def test_equality(self): - assert Circle(Point(-112.574945, 45.987772), 0.01) == Circle(Point(-112.574945, 45.987772), 0.01) + assert Circle(Point(-112.574945, 45.987772), 0.01) == Circle( + Point(-112.574945, 45.987772), 0.01 + ) - assert Circle(Point(-112.574945, 45.987772), 0.01) == Circle(Point(-112.574945, 45.987772), 0.01) + assert Circle(Point(-112.574945, 45.987772), 0.01) == Circle( + Point(-112.574945, 45.987772), 0.01 + ) def test_radius(self): polygon = wkt.loads( diff --git a/python/tests/test_multiple_meta.py b/python/tests/test_multiple_meta.py index 6fbeb7786d..c37cc72e82 100644 --- a/python/tests/test_multiple_meta.py +++ b/python/tests/test_multiple_meta.py @@ -45,7 +45,7 @@ def wget(cls, a: int) -> int: assert A.help_function() == 5 assert A.get("s") == "s" * 5 assert A.get(1, 2) == 8 - assert A.wget(4, "s") == 4*"s" + assert A.wget(4, "s") == 4 * "s" assert A.wget(4) == 4 def test_static_methods(self): @@ -61,6 +61,7 @@ def get(a: str) -> str: @classmethod def help_function(cls) -> int: return A.get(1, 2) * A.get("s") + assert A.help_function() == "sss" def test_basic_methods(self): @@ -90,6 +91,7 @@ def multiply_get(self, a: int): def multiply_get(self, c: str): return A.wget(1, 2) * A.wget(4) * c + assert A().multiply_get() == 9 assert A().multiply_get(10) == 120 assert 
A().multiply_get("c") == 12 * "c" diff --git a/python/tests/test_scala_example.py b/python/tests/test_scala_example.py index adcf53e0df..909bb0212c 100644 --- a/python/tests/test_scala_example.py +++ b/python/tests/test_scala_example.py @@ -52,32 +52,55 @@ class TestScalaExample(TestBase): def test_spatial_range_query(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, ) object_rdd.rawJvmSpatialRDD.persist(StorageLevel.MEMORY_ONLY) for _ in range(each_query_loop_times): - result_size = RangeQuery.SpatialRangeQuery(object_rdd, range_query_window, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + object_rdd, range_query_window, False, False + ).count() object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, ) object_rdd.rawJvmSpatialRDD.persist(StorageLevel.MEMORY_ONLY) for _ in range(each_query_loop_times): - result_size = RangeQuery.SpatialRangeQuery(object_rdd, range_query_window, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + object_rdd, range_query_window, False, False + ).count() def test_spatial_range_query_using_index(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) object_rdd.buildIndex(point_rdd_index_type, False) object_rdd.indexedRawRDD.persist(StorageLevel.MEMORY_ONLY) assert object_rdd.indexedRawRDD.is_cached for _ in range(each_query_loop_times): - result_size = RangeQuery.SpatialRangeQuery(object_rdd, range_query_window, False, True).count + result_size = RangeQuery.SpatialRangeQuery( + object_rdd, range_query_window, False, True + ).count def test_spatial_knn_query(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, ) object_rdd.rawJvmSpatialRDD.persist(StorageLevel.MEMORY_ONLY) @@ -86,7 +109,11 @@ def test_spatial_knn_query(self): def test_spatial_knn_query_using_index(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, ) object_rdd.buildIndex(point_rdd_index_type, False) object_rdd.indexedRawRDD.persist(StorageLevel.MEMORY_ONLY) @@ -96,11 +123,20 @@ def test_spatial_knn_query_using_index(self): def test_spatial_join_query(self): query_window_rdd = PolygonRDD( - self.sc, polygon_rdd_input_location, polygon_rdd_start_offset, polygon_rdd_end_offset, - polygon_rdd_splitter, True + self.sc, + polygon_rdd_input_location, + polygon_rdd_start_offset, + polygon_rdd_end_offset, + polygon_rdd_splitter, + True, ) object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) object_rdd.spatialPartitioning(join_query_partitioning_type) query_window_rdd.spatialPartitioning(object_rdd.getPartitioner()) @@ -109,15 +145,25 @@ def test_spatial_join_query(self): query_window_rdd.jvmSpatialPartitionedRDD.persist(StorageLevel.MEMORY_ONLY) for _ in range(each_query_loop_times): - result_size = 
JoinQuery.SpatialJoinQuery(object_rdd, query_window_rdd, False, True).count() + result_size = JoinQuery.SpatialJoinQuery( + object_rdd, query_window_rdd, False, True + ).count() def test_spatial_join_using_index(self): query_window_rdd = PolygonRDD( - self.sc, polygon_rdd_input_location, polygon_rdd_start_offset, - polygon_rdd_end_offset, polygon_rdd_splitter, True + self.sc, + polygon_rdd_input_location, + polygon_rdd_start_offset, + polygon_rdd_end_offset, + polygon_rdd_splitter, + True, ) object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, ) object_rdd.spatialPartitioning(join_query_partitioning_type) @@ -135,7 +181,12 @@ def test_spatial_join_using_index(self): def test_distance_join_query(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) query_window_rdd = CircleRDD(object_rdd, 0.1) object_rdd.spatialPartitioning(GridType.QUADTREE) @@ -147,11 +198,18 @@ def test_distance_join_query(self): query_window_rdd.spatialPartitionedRDD.persist(StorageLevel.MEMORY_ONLY) for _ in range(each_query_loop_times): - result_size = JoinQuery.DistanceJoinQuery(object_rdd, query_window_rdd, False, True).count() + result_size = JoinQuery.DistanceJoinQuery( + object_rdd, query_window_rdd, False, True + ).count() def test_distance_join_using_index(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) query_window_rdd = CircleRDD(object_rdd, 0.1) @@ -166,11 +224,18 @@ def test_distance_join_using_index(self): assert query_window_rdd.spatialPartitionedRDD.is_cached for _ in range(each_query_loop_times): - result_size = JoinQuery.DistanceJoinQuery(object_rdd, query_window_rdd, True, True).count() + result_size = JoinQuery.DistanceJoinQuery( + object_rdd, query_window_rdd, True, True + ).count() def test_indexed_rdd_assignment(self): object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) query_window_rdd = CircleRDD(object_rdd, 0.1) object_rdd.analyze() object_rdd.spatialPartitioning(GridType.QUADTREE) @@ -189,11 +254,18 @@ def test_indexed_rdd_assignment(self): start = time.time() for _ in range(each_query_loop_times): - result_size = JoinQuery.DistanceJoinQuery(object_rdd, query_window_rdd, True, True).count() + result_size = JoinQuery.DistanceJoinQuery( + object_rdd, query_window_rdd, True, True + ).count() diff = time.time() - start object_rdd = PointRDD( - self.sc, point_rdd_input_location, point_rdd_offset, point_rdd_splitter, True) + self.sc, + point_rdd_input_location, + point_rdd_offset, + point_rdd_splitter, + True, + ) query_window_rdd = CircleRDD(object_rdd, 0.1) object_rdd.analyze() @@ -206,4 +278,6 @@ def test_indexed_rdd_assignment(self): start1 = time.time() for _ in range(each_query_loop_times): - result_size = JoinQuery.DistanceJoinQuery(object_rdd, query_window_rdd, True, True).count() + result_size = JoinQuery.DistanceJoinQuery( + object_rdd, query_window_rdd, True, True + ).count() diff --git a/python/tests/tools.py b/python/tests/tools.py index 7c7335013a..0e9be54aec 100644 --- a/python/tests/tools.py +++ 
b/python/tests/tools.py @@ -21,8 +21,11 @@ from sedona.utils.spatial_rdd_parser import GeoData -tests_path = path.abspath(path.join(__file__ ,"../../../spark/common/src/test/")) -tests_resource = path.abspath(path.join(__file__ ,"../../../spark/common/src/test/resources/")) +tests_path = path.abspath(path.join(__file__, "../../../spark/common/src/test/")) +tests_resource = path.abspath( + path.join(__file__, "../../../spark/common/src/test/resources/") +) + def distance_sorting_functions(geo_data: GeoData, query_point: Point): return geo_data.geom.distance(query_point) diff --git a/python/tests/utils/test_crs_transformation.py b/python/tests/utils/test_crs_transformation.py index 46d63463b1..ebf21eabd2 100644 --- a/python/tests/utils/test_crs_transformation.py +++ b/python/tests/utils/test_crs_transformation.py @@ -28,19 +28,19 @@ class TestCrsTransformation(TestBase): def test_spatial_range_query(self): - spatial_rdd = PointRDD( - self.sc, - input_location, - offset, - splitter, - True - ) + spatial_rdd = PointRDD(self.sc, input_location, offset, splitter, True) spatial_rdd.flipCoordinates() spatial_rdd.CRSTransform("epsg:4326", "epsg:3005") for i in range(loop_times): - result_size = RangeQuery.SpatialRangeQuery(spatial_rdd, query_envelope, False, False).count() + result_size = RangeQuery.SpatialRangeQuery( + spatial_rdd, query_envelope, False, False + ).count() assert result_size == 3127 - assert RangeQuery.SpatialRangeQuery( - spatial_rdd, query_envelope, False, False).take(10)[1].getUserData() is not None + assert ( + RangeQuery.SpatialRangeQuery(spatial_rdd, query_envelope, False, False) + .take(10)[1] + .getUserData() + is not None + ) diff --git a/python/tests/utils/test_geometry_serde.py b/python/tests/utils/test_geometry_serde.py index adb12202bd..79981460cf 100644 --- a/python/tests/utils/test_geometry_serde.py +++ b/python/tests/utils/test_geometry_serde.py @@ -16,7 +16,7 @@ # under the License. 
import pytest -from pyspark.sql.types import (StructType, StringType) +from pyspark.sql.types import StructType, StringType from sedona.sql.types import GeometryType from pyspark.sql.functions import expr @@ -33,117 +33,191 @@ from tests.test_base import TestBase + class TestGeometrySerde(TestBase): - @pytest.mark.parametrize("geom", [ - GeometryCollection([Point([10.0, 20.0]), Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0)])]), - LineString([(10.0, 20.0), (30.0, 40.0)]), - LineString([(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)]), - MultiLineString([[(10.0, 20.0), (30.0, 40.0)], [(50.0, 60.0), (70.0, 80.0)]]), - MultiLineString([[(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)], [(70.0, 80.0, 90.0), (100.0, 110.0, 120.0)]]), - MultiPoint([(10.0, 20.0), (30.0, 40.0)]), - MultiPoint([(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)]), - MultiPolygon([Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0), (10.0, 10.0)]), Polygon([(-10.0, -10.0), (-20.0, -20.0), (-20.0, -10.0), (-10.0, -10.0)])]), - MultiPolygon([Polygon([(10.0, 10.0, 10.0), (20.0, 20.0, 10.0), (20.0, 10.0, 10.0), (10.0, 10.0, 10.0)]), Polygon([(-10.0, -10.0, -10.0), (-20.0, -20.0, -10.0), (-20.0, -10.0, -10.0), (-10.0, -10.0, -10.0)])]), - Point((10.0, 20.0)), - Point((10.0, 20.0, 30.0)), - Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0), (10.0, 10.0)]), - Polygon([(10.0, 10.0, 10.0), (20.0, 20.0, 10.0), (20.0, 10.0, 10.0), (10.0, 10.0, 10.0)]), - ]) + @pytest.mark.parametrize( + "geom", + [ + GeometryCollection( + [ + Point([10.0, 20.0]), + Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0)]), + ] + ), + LineString([(10.0, 20.0), (30.0, 40.0)]), + LineString([(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)]), + MultiLineString( + [[(10.0, 20.0), (30.0, 40.0)], [(50.0, 60.0), (70.0, 80.0)]] + ), + MultiLineString( + [ + [(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)], + [(70.0, 80.0, 90.0), (100.0, 110.0, 120.0)], + ] + ), + MultiPoint([(10.0, 20.0), (30.0, 40.0)]), + MultiPoint([(10.0, 20.0, 30.0), (40.0, 50.0, 60.0)]), + MultiPolygon( + [ + Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0), (10.0, 10.0)]), + Polygon( + [(-10.0, -10.0), (-20.0, -20.0), (-20.0, -10.0), (-10.0, -10.0)] + ), + ] + ), + MultiPolygon( + [ + Polygon( + [ + (10.0, 10.0, 10.0), + (20.0, 20.0, 10.0), + (20.0, 10.0, 10.0), + (10.0, 10.0, 10.0), + ] + ), + Polygon( + [ + (-10.0, -10.0, -10.0), + (-20.0, -20.0, -10.0), + (-20.0, -10.0, -10.0), + (-10.0, -10.0, -10.0), + ] + ), + ] + ), + Point((10.0, 20.0)), + Point((10.0, 20.0, 30.0)), + Polygon([(10.0, 10.0), (20.0, 20.0), (20.0, 10.0), (10.0, 10.0)]), + Polygon( + [ + (10.0, 10.0, 10.0), + (20.0, 20.0, 10.0), + (20.0, 10.0, 10.0), + (10.0, 10.0, 10.0), + ] + ), + ], + ) def test_spark_serde(self, geom): - returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], StructType().add("geom", GeometryType())).take(1)[0][0] + returned_geom = TestGeometrySerde.spark.createDataFrame( + [(geom,)], StructType().add("geom", GeometryType()) + ).take(1)[0][0] assert geom.equals_exact(returned_geom, 1e-6) - @pytest.mark.parametrize("wkt", [ - # empty geometries - 'POINT EMPTY', - 'LINESTRING EMPTY', - 'POLYGON EMPTY', - 'MULTIPOINT EMPTY', - 'MULTILINESTRING EMPTY', - 'MULTIPOLYGON EMPTY', - 'GEOMETRYCOLLECTION EMPTY', - # non-empty geometries - 'POINT (10 20)', - 'POINT (10 20 30)', - 'LINESTRING (10 20, 30 40)', - 'LINESTRING (10 20 30, 40 50 60)', - 'POLYGON ((10 10, 20 20, 20 10, 10 10))', - 'POLYGON ((10 10 10, 20 20 10, 20 10 10, 10 10 10))', - 'POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1))', - # non-empty multi 
geometries - 'MULTIPOINT ((10 20), (30 40))', - 'MULTIPOINT ((10 20 30), (40 50 60))', - 'MULTILINESTRING ((10 20, 30 40), (50 60, 70 80))', - 'MULTILINESTRING ((10 20 30, 40 50 60), (70 80 90, 100 110 120))', - 'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))', - 'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1)))', - 'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40))', - 'GEOMETRYCOLLECTION (POINT (10 20 30), LINESTRING (10 20 30, 40 50 60))', - 'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40), POLYGON ((10 10, 20 20, 20 10, 10 10)))', - # nested geometry collection - 'GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))', - 'GEOMETRYCOLLECTION (POINT (1 2), GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))', - # multi geometries containing empty geometries - 'MULTIPOINT (EMPTY, (10 20))', - 'MULTIPOINT (EMPTY, EMPTY)', - 'MULTILINESTRING (EMPTY, (10 20, 30 40))', - 'MULTILINESTRING (EMPTY, EMPTY)', - 'MULTIPOLYGON (EMPTY, ((10 10, 20 20, 20 10, 10 10)))', - 'MULTIPOLYGON (EMPTY, EMPTY)', - 'GEOMETRYCOLLECTION (POINT (10 20), POINT EMPTY, LINESTRING (10 20, 30 40))', - 'GEOMETRYCOLLECTION (MULTIPOINT EMPTY, MULTILINESTRING EMPTY, MULTIPOLYGON EMPTY, GEOMETRYCOLLECTION EMPTY)', - ]) + @pytest.mark.parametrize( + "wkt", + [ + # empty geometries + "POINT EMPTY", + "LINESTRING EMPTY", + "POLYGON EMPTY", + "MULTIPOINT EMPTY", + "MULTILINESTRING EMPTY", + "MULTIPOLYGON EMPTY", + "GEOMETRYCOLLECTION EMPTY", + # non-empty geometries + "POINT (10 20)", + "POINT (10 20 30)", + "LINESTRING (10 20, 30 40)", + "LINESTRING (10 20 30, 40 50 60)", + "POLYGON ((10 10, 20 20, 20 10, 10 10))", + "POLYGON ((10 10 10, 20 20 10, 20 10 10, 10 10 10))", + "POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1))", + # non-empty multi geometries + "MULTIPOINT ((10 20), (30 40))", + "MULTIPOINT ((10 20 30), (40 50 60))", + "MULTILINESTRING ((10 20, 30 40), (50 60, 70 80))", + "MULTILINESTRING ((10 20 30, 40 50 60), (70 80 90, 100 110 120))", + "MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))", + "MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1)))", + "GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40))", + "GEOMETRYCOLLECTION (POINT (10 20 30), LINESTRING (10 20 30, 40 50 60))", + "GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40), POLYGON ((10 10, 20 20, 20 10, 10 10)))", + # nested geometry collection + "GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))", + "GEOMETRYCOLLECTION (POINT (1 2), GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))", + # multi geometries containing empty geometries + "MULTIPOINT (EMPTY, (10 20))", + "MULTIPOINT (EMPTY, EMPTY)", + "MULTILINESTRING (EMPTY, (10 20, 30 40))", + "MULTILINESTRING (EMPTY, EMPTY)", + "MULTIPOLYGON (EMPTY, ((10 10, 20 20, 20 10, 10 10)))", + "MULTIPOLYGON (EMPTY, EMPTY)", + "GEOMETRYCOLLECTION (POINT (10 20), POINT EMPTY, LINESTRING (10 20, 30 40))", + "GEOMETRYCOLLECTION (MULTIPOINT EMPTY, MULTILINESTRING EMPTY, MULTIPOLYGON EMPTY, GEOMETRYCOLLECTION EMPTY)", + ], + ) def test_spark_serde_compatibility_with_scala(self, wkt): geom = wkt_loads(wkt) schema = StructType().add("geom", GeometryType()) - returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).take(1)[0][0] + returned_geom = 
TestGeometrySerde.spark.createDataFrame([(geom,)], schema).take( + 1 + )[0][0] assert geom.equals(returned_geom) # serialized by python, deserialized by scala - returned_wkt = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).selectExpr("ST_AsText(geom)").take(1)[0][0] + returned_wkt = ( + TestGeometrySerde.spark.createDataFrame([(geom,)], schema) + .selectExpr("ST_AsText(geom)") + .take(1)[0][0] + ) assert wkt_loads(returned_wkt).equals(geom) # serialized by scala, deserialized by python schema = StructType().add("wkt", StringType()) - returned_geom = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema).selectExpr("ST_GeomFromText(wkt)").take(1)[0][0] + returned_geom = ( + TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) + .selectExpr("ST_GeomFromText(wkt)") + .take(1)[0][0] + ) assert geom.equals(returned_geom) - @pytest.mark.parametrize("wkt", [ - 'POINT ZM (1 2 3 4)', - 'LINESTRING ZM (1 2 3 4, 5 6 7 8)', - 'POLYGON ZM ((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1))', - 'MULTIPOINT ZM ((10 20 30 1), (40 50 60 1))', - 'MULTILINESTRING ZM ((10 20 30 1, 40 50 60 1), (70 80 90 1, 100 110 120 1))', - 'MULTIPOLYGON ZM (((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1)), ' + - '((0 0 0 1, 0 10 0 1, 10 10 0 1, 10 0 0 1, 0 0 0 1), (1 1 0 1, 1 2 0 1, 2 2 0 1, 2 1 0 1, 1 1 0 1)))', - 'GEOMETRYCOLLECTION (POINT ZM (10 20 30 1), LINESTRING ZM (10 20 30 1, 40 50 60 1))', - ]) + @pytest.mark.parametrize( + "wkt", + [ + "POINT ZM (1 2 3 4)", + "LINESTRING ZM (1 2 3 4, 5 6 7 8)", + "POLYGON ZM ((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1))", + "MULTIPOINT ZM ((10 20 30 1), (40 50 60 1))", + "MULTILINESTRING ZM ((10 20 30 1, 40 50 60 1), (70 80 90 1, 100 110 120 1))", + "MULTIPOLYGON ZM (((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1)), " + + "((0 0 0 1, 0 10 0 1, 10 10 0 1, 10 0 0 1, 0 0 0 1), (1 1 0 1, 1 2 0 1, 2 2 0 1, 2 1 0 1, 1 1 0 1)))", + "GEOMETRYCOLLECTION (POINT ZM (10 20 30 1), LINESTRING ZM (10 20 30 1, 40 50 60 1))", + ], + ) def test_spark_serde_on_4d_geoms(self, wkt): geom = wkt_loads(wkt) schema = StructType().add("wkt", StringType()) - returned_geom, n_dims = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema)\ - .selectExpr("ST_GeomFromText(wkt)", "ST_NDims(ST_GeomFromText(wkt))")\ + returned_geom, n_dims = ( + TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) + .selectExpr("ST_GeomFromText(wkt)", "ST_NDims(ST_GeomFromText(wkt))") .take(1)[0] + ) assert n_dims == 4 assert geom.equals(returned_geom) - @pytest.mark.parametrize("wkt", [ - 'POINT M (1 2 3)', - 'LINESTRING M (1 2 3, 5 6 7)', - 'POLYGON M ((10 10 10, 20 20 10, 20 10 10, 10 10 10))', - 'MULTIPOINT M ((10 20 30), (40 50 60))', - 'MULTILINESTRING M ((10 20 30, 40 50 60), (70 80 90, 100 110 120))', - 'MULTIPOLYGON M (((10 10 10, 20 20 10, 20 10 10, 10 10 10)), ' + - '((0 0 0, 0 10 0, 10 10 0, 10 0 0, 0 0 0), (1 1 0, 1 2 0, 2 2 0, 2 1 0, 1 1 0)))', - 'GEOMETRYCOLLECTION (POINT M (10 20 30), LINESTRING M (10 20 30, 40 50 60))', - ]) + @pytest.mark.parametrize( + "wkt", + [ + "POINT M (1 2 3)", + "LINESTRING M (1 2 3, 5 6 7)", + "POLYGON M ((10 10 10, 20 20 10, 20 10 10, 10 10 10))", + "MULTIPOINT M ((10 20 30), (40 50 60))", + "MULTILINESTRING M ((10 20 30, 40 50 60), (70 80 90, 100 110 120))", + "MULTIPOLYGON M (((10 10 10, 20 20 10, 20 10 10, 10 10 10)), " + + "((0 0 0, 0 10 0, 10 10 0, 10 0 0, 0 0 0), (1 1 0, 1 2 0, 2 2 0, 2 1 0, 1 1 0)))", + "GEOMETRYCOLLECTION (POINT M (10 20 30), LINESTRING M (10 20 30, 40 50 60))", + ], + ) def test_spark_serde_on_xym_geoms(self, wkt): geom = 
wkt_loads(wkt) schema = StructType().add("wkt", StringType()) - returned_geom, n_dims, z_min = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) \ - .withColumn("geom", expr("ST_GeomFromText(wkt)")) \ - .selectExpr("geom", "ST_NDims(geom)", "ST_ZMin(geom)") \ + returned_geom, n_dims, z_min = ( + TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) + .withColumn("geom", expr("ST_GeomFromText(wkt)")) + .selectExpr("geom", "ST_NDims(geom)", "ST_ZMin(geom)") .take(1)[0] + ) assert n_dims == 3 assert z_min is None assert geom.equals(returned_geom) diff --git a/python/tests/utils/test_geomserde_speedup.py b/python/tests/utils/test_geomserde_speedup.py index 4dcf638332..a7478e0c75 100644 --- a/python/tests/utils/test_geomserde_speedup.py +++ b/python/tests/utils/test_geomserde_speedup.py @@ -29,16 +29,13 @@ ) from shapely.wkt import loads as wkt_loads + class TestGeomSerdeSpeedup: def test_speedup_enabled(self): assert geometry_serde.speedup_enabled def test_point(self): - points = [ - wkt_loads("POINT EMPTY"), - Point(10, 20), - Point(10, 20, 30) - ] + points = [wkt_loads("POINT EMPTY"), Point(10, 20), Point(10, 20, 30)] self._test_serde_roundtrip(points) def test_linestring(self): @@ -65,7 +62,9 @@ def test_multi_linestring(self): wkt_loads("MULTILINESTRING EMPTY"), MultiLineString([[(10, 20), (30, 40)]]), MultiLineString([[(10, 20), (30, 40)], [(50, 60), (70, 80)]]), - MultiLineString([[(10, 20, 30), (30, 40, 50)], [(50, 60, 70), (70, 80, 90)]]), + MultiLineString( + [[(10, 20, 30), (30, 40, 50)], [(50, 60, 70), (70, 80, 90)]] + ), ] self._test_serde_roundtrip(multi_linestrings) @@ -90,32 +89,49 @@ def test_multi_polygon(self): MultiPolygon([Polygon(ext)]), MultiPolygon([Polygon(ext), Polygon(ext, [int0])]), MultiPolygon([Polygon(ext), Polygon(ext, [int0, int1])]), - MultiPolygon([Polygon(ext, [int1]), Polygon(ext), Polygon(ext, [int0, int1])]), + MultiPolygon( + [Polygon(ext, [int1]), Polygon(ext), Polygon(ext, [int0, int1])] + ), ] self._test_serde_roundtrip(multi_polygons) def test_geometry_collection(self): geometry_collections = [ wkt_loads("GEOMETRYCOLLECTION EMPTY"), - GeometryCollection([Point(10, 20), LineString([(10, 20), (30, 40)]), Point(30, 40)]), - GeometryCollection([ - MultiPoint([(10, 20), (30, 40)]), - MultiLineString([[(10, 20), (30, 40)], [(50, 60), (70, 80)]]), - MultiPolygon([ - Polygon( - [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)], - [[(10, 10), (10, 15), (15, 15), (15, 10), (10, 10)]]) - ]), - Point(100, 200) - ]), - GeometryCollection([ - GeometryCollection([Point(10, 20), LineString([(10, 20), (30, 40)]), Point(30, 40)]), - GeometryCollection([ + GeometryCollection( + [Point(10, 20), LineString([(10, 20), (30, 40)]), Point(30, 40)] + ), + GeometryCollection( + [ MultiPoint([(10, 20), (30, 40)]), MultiLineString([[(10, 20), (30, 40)], [(50, 60), (70, 80)]]), - Point(10, 20) - ]) - ]) + MultiPolygon( + [ + Polygon( + [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)], + [[(10, 10), (10, 15), (15, 15), (15, 10), (10, 10)]], + ) + ] + ), + Point(100, 200), + ] + ), + GeometryCollection( + [ + GeometryCollection( + [Point(10, 20), LineString([(10, 20), (30, 40)]), Point(30, 40)] + ), + GeometryCollection( + [ + MultiPoint([(10, 20), (30, 40)]), + MultiLineString( + [[(10, 20), (30, 40)], [(50, 60), (70, 80)]] + ), + Point(10, 20), + ] + ), + ] + ), ] self._test_serde_roundtrip(geometry_collections) @@ -127,7 +143,9 @@ def _test_serde_roundtrip(geoms): # GEOSGeom_createEmptyLineString in libgeos creates LineString with # Z dimension, This bug has been fixed 
by # https://github.com/libgeos/geos/pull/745 - geom_actual_wkt = geom_actual.wkt.replace('LINESTRING Z EMPTY', 'LINESTRING EMPTY') + geom_actual_wkt = geom_actual.wkt.replace( + "LINESTRING Z EMPTY", "LINESTRING EMPTY" + ) assert geom.wkt == geom_actual_wkt @staticmethod diff --git a/spark-version-converter.py b/spark-version-converter.py index a46800b5ee..8da0d7c2e8 100644 --- a/spark-version-converter.py +++ b/spark-version-converter.py @@ -19,63 +19,78 @@ import fileinput import sys -spark2_anchor = 'SPARK2 anchor' -spark3_anchor = 'SPARK3 anchor' -files = ['sql/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala', - 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala', - 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala', - 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala', - 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala'] +spark2_anchor = "SPARK2 anchor" +spark3_anchor = "SPARK3 anchor" +files = [ + "sql/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala", + "sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala", + "sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala", + "sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala", + "sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala", +] + def switch_version(line): - if line[:2] == '//': - print(line[2:], end='') # enable code - return 'enabled' + if line[:2] == "//": + print(line[2:], end="") # enable code + return "enabled" else: - print('//' + line, end='') # disable code - return 'disabled' + print("//" + line, end="") # disable code + return "disabled" + def enable_version(line): - if line[:2] == '//': - print(line[2:], end='') # enable code - return 'enabled' + if line[:2] == "//": + print(line[2:], end="") # enable code + return "enabled" else: - print(line, end='') - return 'enabled before' + print(line, end="") + return "enabled before" + def disable_version(line): - if line[:2] == '//': - print(line, end='') - return 'disabled before' + if line[:2] == "//": + print(line, end="") + return "disabled before" else: - print('//' + line, end='') # disable code - return 'disabled' + print("//" + line, end="") # disable code + return "disabled" + def parse_file(filepath, argv): - conversion_result_spark2 = '' - conversion_result_spark3 = '' - if argv[1] == 'spark2': + conversion_result_spark2 = "" + conversion_result_spark3 = "" + if argv[1] == "spark2": with fileinput.FileInput(filepath, inplace=True) as file: for line in file: if spark2_anchor in line: - conversion_result_spark2 = spark2_anchor + ' ' + enable_version(line) + conversion_result_spark2 = ( + spark2_anchor + " " + enable_version(line) + ) elif spark3_anchor in line: - conversion_result_spark3 = spark3_anchor + ' ' + disable_version(line) + conversion_result_spark3 = ( + spark3_anchor + " " + disable_version(line) + ) else: - print(line, end='') - return conversion_result_spark2 + ' and ' + conversion_result_spark3 - elif argv[1] == 'spark3': + print(line, end="") + return conversion_result_spark2 + " and " + conversion_result_spark3 + elif argv[1] == "spark3": with fileinput.FileInput(filepath, inplace=True) as file: for line in file: if spark2_anchor in line: - conversion_result_spark2 = spark2_anchor + ' ' + disable_version(line) + 
conversion_result_spark2 = ( + spark2_anchor + " " + disable_version(line) + ) elif spark3_anchor in line: - conversion_result_spark3 = spark3_anchor + ' ' + enable_version(line) + conversion_result_spark3 = ( + spark3_anchor + " " + enable_version(line) + ) else: - print(line, end='') - return conversion_result_spark2 + ' and ' + conversion_result_spark3 + print(line, end="") + return conversion_result_spark2 + " and " + conversion_result_spark3 else: - return 'wrong spark version' + return "wrong spark version" + for filepath in files: - print(filepath + ': ' + parse_file(filepath, sys.argv)) + print(filepath + ": " + parse_file(filepath, sys.argv)) From b446d77340e1c94788214ec731d99536652fbeb6 Mon Sep 17 00:00:00 2001 From: John Bampton Date: Mon, 23 Sep 2024 18:29:04 +1000 Subject: [PATCH 2/3] Add black jupyter --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80c806295d..93be4b5792 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.8.0 hooks: - - id: black + - id: black-jupyter - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: From 70dc53733eeffc8b3df85d64cb74626f4581a2f3 Mon Sep 17 00:00:00 2001 From: John Bampton Date: Mon, 23 Sep 2024 19:19:32 +1000 Subject: [PATCH 3/3] Remove ruff --- .github/linters/ruff.toml | 77 --------------------------------------- .pre-commit-config.yaml | 5 --- 2 files changed, 82 deletions(-) delete mode 100644 .github/linters/ruff.toml diff --git a/.github/linters/ruff.toml b/.github/linters/ruff.toml deleted file mode 100644 index 68b176f4e0..0000000000 --- a/.github/linters/ruff.toml +++ /dev/null @@ -1,77 +0,0 @@ -# Exclude a variety of commonly ignored directories. -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".ipynb_checkpoints", - ".mypy_cache", - ".nox", - ".pants.d", - ".pyenv", - ".pytest_cache", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - ".vscode", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "site-packages", - "venv", -] - -# Same as Black. -line-length = 88 -indent-width = 4 - -# Assume Python 3.8 -target-version = "py38" - -[lint] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. -# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or -# McCabe complexity (`C901`) by default. -select = ["E3", "E4", "E5", "E7", "E9", "F"] -ignore = ["E501", "E721", "E722", "E731", "F401", "F402", "F403", "F405", "F811", "F821", "F822", "F841", "F901"] - -# Allow fix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL"] -unfixable = [] - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -[format] -# Like Black, use double quotes for strings. -quote-style = "double" - -# Like Black, indent with spaces, rather than tabs. -indent-style = "space" - -# Like Black, respect magic trailing commas. -skip-magic-trailing-comma = false - -# Like Black, automatically detect the appropriate line ending. -line-ending = "auto" - -# Enable auto-formatting of code examples in docstrings. Markdown, -# reStructuredText code/literal blocks and doctests are all supported. -# -# This is currently disabled by default, but it is planned for this -# to be opt-out in the future. 
-docstring-code-format = false - -# Set the line length limit used when formatting code snippets in -# docstrings. -# -# This only has an effect when the `docstring-code-format` setting is -# enabled. -docstring-code-line-length = "dynamic" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93be4b5792..972d26e9c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,11 +22,6 @@ repos: description: Check spelling with codespell args: [--ignore-words=.github/linters/codespell.txt] exclude: ^docs/image|^spark/common/src/test/resources|^docs/usecases|^tools/maven/scalafmt - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.5 - hooks: - - id: ruff - args: [--config=.github/linters/ruff.toml, --fix] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: