diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md
index fb0b468ba..65cc9dac9 100644
--- a/docs/ppl-lang/functions/ppl-ip.md
+++ b/docs/ppl-lang/functions/ppl-ip.md
@@ -32,4 +32,67 @@ Note:
 - `ip` can be an IPv4 or an IPv6 address
 - `cidr` can be an IPv4 or an IPv6 block
 - `ip` and `cidr` must be either both IPv4 or both IPv6
- - `ip` and `cidr` must both be valid and non-empty/non-null
\ No newline at end of file
+ - `ip` and `cidr` must both be valid and non-empty/non-null
+
+### `GEOIP`
+
+**Description**
+
+`GEOIP(ip[, property]...)` retrieves geospatial data corresponding to the provided `ip`.
+
+**Argument type:**
+- `ip` is a **STRING** representing an IPv4 or an IPv6 address.
+- `property` is a **STRING** and must be one of the following:
+  - `COUNTRY_ISO_CODE`
+  - `COUNTRY_NAME`
+  - `CONTINENT_NAME`
+  - `REGION_ISO_CODE`
+  - `REGION_NAME`
+  - `CITY_NAME`
+  - `TIME_ZONE`
+  - `LOCATION`
+- Return type:
+  - **STRING** if one property is given
+  - **STRUCT_TYPE** if more than one or no property is given
+
+Example:
+
+_Without properties:_
+
+    os> source=ips | eval a = geoip(ip) | fields ip, a
+    fetched rows / total rows = 2/2
+    +---------------------+-------------------------------------------------------------------------------------------------------+
+    |ip                   |a                                                                                                      |
+    +---------------------+-------------------------------------------------------------------------------------------------------+
+    |66.249.157.90        |{JM, Jamaica, North America, 14, Saint Catherine Parish, Portmore, America/Jamaica, 17.9686,-76.8827}  |
+    |2a09:bac2:19f8:2ac3::|{CA, Canada, North America, PE, Prince Edward Island, Charlottetown, America/Halifax, 46.2396,-63.1355}|
+    +---------------------+-------------------------------------------------------------------------------------------------------+
+
+_With one property:_
+
+    os> source=users | eval a = geoip(ip, COUNTRY_NAME) | fields ip, a
+    fetched rows / total rows = 2/2
+    +---------------------+-------+
+    |ip                   |a      |
+    +---------------------+-------+
+    |66.249.157.90        |Jamaica|
+    |2a09:bac2:19f8:2ac3::|Canada |
+    +---------------------+-------+
+
+_With multiple properties:_
+
+    os> source=users | eval a = geoip(ip, COUNTRY_NAME, REGION_NAME, CITY_NAME) | fields ip, a
+    fetched rows / total rows = 2/2
+    +---------------------+---------------------------------------------+
+    |ip                   |a                                            |
+    +---------------------+---------------------------------------------+
+    |66.249.157.90        |{Jamaica, Saint Catherine Parish, Portmore}  |
+    |2a09:bac2:19f8:2ac3::|{Canada, Prince Edward Island, Charlottetown}|
+    +---------------------+---------------------------------------------+
+
+Note:
+- To use `geoip`, the user must first create a Spark table containing geo IP location data. Instructions for creating this table can be found [here](../../opensearch-geoip.md).
+  - By default, the `geoip` command expects the lookup table to be named `geoip`.
+  - If a different table name is desired, set the `spark.geoip.tablename` Spark config to the new table name.
+- `ip` can be an IPv4 or an IPv6 address.
+- `geoip` functions are always evaluated first when used with other `eval` functions.
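The note above assumes a pre-populated geo IP lookup table but does not show that setup. Below is a minimal sketch of what it might look like with plain Spark SQL: the table name `my_geoip_data`, the `USING PARQUET` clause, and the single seed row are illustrative only; the schema and sample values follow the integration-test fixture (`createGeoIpTable`) further down, and `spark.geoip.tablename` is the config key the translator reads.

```scala
import org.apache.spark.sql.SparkSession

object GeoIpTableSetupSketch {
  def main(args: Array[String]): Unit = {
    // Point the geoip command at a custom lookup table; without this config it
    // falls back to a table named "geoip".
    val spark = SparkSession.builder()
      .appName("geoip-table-setup-sketch")
      .config("spark.geoip.tablename", "my_geoip_data")
      .getOrCreate()

    // Schema mirrors the geoip fixture: CIDR block, geo properties, and the
    // integer range plus IPv4 flag that the geoip join condition relies on.
    spark.sql("""
      CREATE TABLE IF NOT EXISTS my_geoip_data (
        cidr STRING,
        country_iso_code STRING,
        country_name STRING,
        continent_name STRING,
        region_iso_code STRING,
        region_name STRING,
        city_name STRING,
        time_zone STRING,
        location STRING,
        ip_range_start DECIMAL(38,0),
        ip_range_end DECIMAL(38,0),
        ipv4 BOOLEAN
      ) USING PARQUET
    """)

    // One sample row, taken from the integration-test fixture (66.249.157.0/24).
    spark.sql("""
      INSERT INTO my_geoip_data VALUES (
        '66.249.157.0/24', 'JM', 'Jamaica', 'North America', '14',
        'Saint Catherine Parish', 'Portmore', 'America/Jamaica',
        '17.9686,-76.8827', 1123654912, 1123655167, true
      )
    """)

    spark.stop()
  }
}
```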
diff --git a/docs/ppl-lang/planning/ppl-geoip-command.md b/docs/ppl-lang/planning/ppl-geoip-command.md
new file mode 100644
index 000000000..aaed6c156
--- /dev/null
+++ b/docs/ppl-lang/planning/ppl-geoip-command.md
@@ -0,0 +1,59 @@
+## geoip syntax proposal
+
+The `geoip` function adds information about the geographical location of an IPv4 or IPv6 address.
+
+**Implementation syntax**
+- `... | eval geoinfo = geoip(ipAddress *[,properties])`
+- generic syntax
+- `... | eval geoinfo = geoip(ipAddress)`
+- retrieves all geo data
+- `... | eval geoinfo = geoip(ipAddress, city, location)`
+- retrieves only city and location
+
+**Implementation details**
+- The current implementation requires the user to have created a geoip table with the following schema:
+
+  ```SQL
+  CREATE TABLE geoip (
+    cidr STRING,
+    country_iso_code STRING,
+    country_name STRING,
+    continent_name STRING,
+    region_iso_code STRING,
+    region_name STRING,
+    city_name STRING,
+    time_zone STRING,
+    location STRING,
+    ip_range_start BIGINT,
+    ip_range_end BIGINT,
+    ipv4 BOOLEAN
+  )
+  ```
+
+- `geoip` is resolved by performing a join on said table and projecting the resulting geoip data as a struct.
+- a call to `geoip` is equivalent to running the following SQL query:
+
+  ```SQL
+  SELECT source.*, struct(geoip.country_name, geoip.city_name) AS a
+  FROM source, geoip
+  WHERE geoip.ip_range_start <= ip_to_int(source.ip)
+    AND geoip.ip_range_end > ip_to_int(source.ip)
+    AND geoip.ipv4 = is_ipv4(source.ip);
+  ```
+- when only one property is provided in the function call, `geoip` returns a string with the specified property instead:
+
+  ```SQL
+  SELECT source.*, geoip.country_name AS a
+  FROM source, geoip
+  WHERE geoip.ip_range_start <= ip_to_int(source.ip)
+    AND geoip.ip_range_end > ip_to_int(source.ip)
+    AND geoip.ipv4 = is_ipv4(source.ip);
+  ```
+
+**Future plan for additional data-sources**
+
+- Currently, only a pre-existing geoip table defined within Spark can be used.
+- There are future plans to allow users to specify data sources:
+  - API data sources - if users have their own geoip provider, the ability to configure and call those endpoints will be added
+  - OpenSearch geospatial client - once the geospatial client is published, we can leverage it to use the OpenSearch geo2ip functionality.
+- Additional data source connection params will be provided through Spark config options.
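The join sketched in the SQL above hinges on `ip_to_int` and `is_ipv4`: an address's raw bytes (4 for IPv4, 16 for IPv6) are read as an unsigned integer and compared against the precomputed `ip_range_start`/`ip_range_end` columns. A small standalone sketch of that mapping follows; the object and method names are illustrative, while the PR ships this logic as the `ip_to_int` and `is_ipv4` ScalaUDFs in `SerializableUdf`.

```scala
import java.math.BigInteger
import java.net.{Inet4Address, InetAddress}

object IpToIntSketch {
  // Read the address bytes as a non-negative integer, as ip_to_int does.
  def ipToInt(ip: String): BigInteger =
    new BigInteger(1, InetAddress.getByName(ip).getAddress)

  // Rough stand-in for is_ipv4 (the real UDF validates the string more strictly).
  def isIpv4(ip: String): Boolean =
    InetAddress.getByName(ip).isInstanceOf[Inet4Address]

  def main(args: Array[String]): Unit = {
    // For the block 66.249.157.0/24 used as sample data in the tests:
    // 66*2^24 + 249*2^16 + 157*2^8 = 1123654912, and .255 maps to 1123655167,
    // which are exactly the ip_range_start / ip_range_end values seeded there.
    println(ipToInt("66.249.157.0"))   // 1123654912
    println(ipToInt("66.249.157.255")) // 1123655167

    // The address queried in the examples falls inside that range.
    println(ipToInt("66.249.157.90"))        // 1123655002
    println(isIpv4("66.249.157.90"))         // true
    println(isIpv4("2a09:bac2:19f8:2ac3::")) // false
  }
}
```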
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7c19cab12..5ea123c9d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -771,6 +771,79 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createGeoIpTestTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | ip STRING, + | ipv4 STRING, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES ('66.249.157.90', '66.249.157.90', true), + | ('2a09:bac2:19f8:2ac3::', 'Given IPv6 is not mapped to IPv4', true), + | ('192.168.2.', '192.168.2.', false), + | ('2001:db8::ff00:12:', 'Given IPv6 is not mapped to IPv4', false) + | """.stripMargin) + } + + protected def createGeoIpTable(): Unit = { + sql(s""" + | CREATE TABLE geoip + | ( + | cidr STRING, + | country_iso_code STRING, + | country_name STRING, + | continent_name STRING, + | region_iso_code STRING, + | region_name STRING, + | city_name STRING, + | time_zone STRING, + | location STRING, + | ip_range_start DECIMAL(38,0), + | ip_range_end DECIMAL(38,0), + | ipv4 BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO geoip + | VALUES ( + | '66.249.157.0/24', + | 'JM', + | 'Jamaica', + | 'North America', + | '14', + | 'Saint Catherine Parish', + | 'Portmore', + | 'America/Jamaica', + | '17.9686,-76.8827', + | 1123654912, + | 1123655167, + | true + | ), + | ( + | '2a09:bac2:19f8::/45', + | 'CA', + | 'Canada', + | 'North America', + | 'PE', + | 'Prince Edward Island', + | 'Charlottetown', + | 'America/Halifax', + | '46.2396,-63.1355', + | 55878094401180025937395073088449675264, + | 55878094401189697343951990121847324671, + | false + | ) + | """.stripMargin) + } + protected def createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { val json = """ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala new file mode 100644 index 000000000..7031ab067 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -0,0 +1,314 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.util + +import org.opensearch.sql.expression.function.SerializableUdf.visit +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, EqualTo, Expression, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Filter, Join, JoinHint, LogicalPlan, Project, SubqueryAlias} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLGeoipITSuite + extends QueryTest + with 
LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createGeoIpTestTable(testTable) + createGeoIpTable() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + private def getGeoIpQueryPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan, + projectionProperties: Alias): LogicalPlan = { + val joinPlan = getJoinPlan(ipAddress, left, right) + getProjection(joinPlan, projectionProperties) + } + + private def getJoinPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan): LogicalPlan = { + val is_ipv4 = visit("is_ipv4", util.List.of[Expression](ipAddress)) + val ip_to_int = visit("ip_to_int", util.List.of[Expression](ipAddress)) + + val t1 = SubqueryAlias("t1", left) + val t2 = SubqueryAlias("t2", right) + + val joinCondition = And( + And( + GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")), + LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))) + Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) + } + + private def getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = { + val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) + val dropList = Seq( + "t2.country_iso_code", + "t2.country_name", + "t2.continent_name", + "t2.region_iso_code", + "t2.region_name", + "t2.city_name", + "t2.time_zone", + "t2.location", + "t2.cidr", + "t2.ip_range_start", + "t2.ip_range_end", + "t2.ipv4").map(UnresolvedAttribute(_)) + DataFrameDropColumns(dropList, projection) + } + + test("test geoip with no parameters") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + + // Define the expected results + val expectedResults: Array[Row] = Array( + Row( + "66.249.157.90", + Row( + "JM", + "Jamaica", + "North America", + "14", + "Saint Catherine Parish", + "Portmore", + "America/Jamaica", + "17.9686,-76.8827")), + Row( + "2a09:bac2:19f8:2ac3::", + Row( + "CA", + "Canada", + "North America", + "PE", + "Prince Edward Island", + "Charlottetown", + "America/Halifax", + "46.2396,-63.1355"))) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(testTable.split("\\.").toSeq)) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_iso_code"), + UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), + UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), + UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), + UnresolvedAttribute("t2.region_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"), + 
Literal("time_zone"), + UnresolvedAttribute("t2.time_zone"), + Literal("location"), + UnresolvedAttribute("t2.location"))) + val structProjection = Alias(projectionStruct, "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with one parameters") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(testTable.split("\\.").toSeq)) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with multiple parameters") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("66.249.157.90", Row("Jamaica", "Portmore")), + Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown"))) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(testTable.split("\\.").toSeq)) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"))) + val structProjection = Alias(projectionStruct, "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with partial projection on evaluated fields") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, city_name), b = geoip(ip, country_name) | fields ip, b + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected 
results + val expectedResults: Array[Row] = + Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(testTable.split("\\.").toSeq)) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.city_name"), "a")() + val geoIpPlanA = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.country_name"), "b")() + val geoIpPlanB = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), geoIpPlanA, geoTable, structProjectionB) + + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("b")), geoIpPlanB) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with projection on field that exists in both source and geoip table") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name) | fields ipv4, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("66.249.157.90", "Jamaica"), Row("Given IPv6 is not mapped to IPv4", "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(testTable.split("\\.").toSeq)) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ipv4"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with invalid parameter") { + assertThrows[ParseException](sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, invalid_param) | fields ip, a + | """.stripMargin)) + } + + test("test geoip with invalid ip address provided") { + val frame = sql(s""" + | source = $testTable | eval a = geoip(ip) | fields ip, a + | """.stripMargin) + + // Retrieve the results + assertThrows[SparkException](frame.collect()) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index b7d615980..a6ab4f1de 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -416,9 +416,6 @@ ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; CIDRMATCH: 'CIDRMATCH'; -// Geo Loction -GEOIP: 'GEOIP'; - // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; NULLIF: 'NULLIF'; @@ -428,6 +425,18 @@ TYPEOF: 'TYPEOF'; //OTHER CONDITIONAL EXPRESSIONS COALESCE: 'COALESCE'; +//GEOLOCATION 
FUNCTIONS +GEOIP: 'GEOIP'; + +//GEOLOCATION PROPERTIES +COUNTRY_ISO_CODE: 'COUNTRY_ISO_CODE'; +COUNTRY_NAME: 'COUNTRY_NAME'; +CONTINENT_NAME: 'CONTINENT_NAME'; +REGION_ISO_CODE: 'REGION_ISO_CODE'; +REGION_NAME: 'REGION_NAME'; +CITY_NAME: 'CITY_NAME'; +LOCATION: 'LOCATION'; + // RELEVANCE FUNCTIONS AND PARAMETERS MATCH: 'MATCH'; MATCH_PHRASE: 'MATCH_PHRASE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index b990fd549..0a2cdf1a0 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -387,6 +387,11 @@ sortbyClause evalClause : fieldExpression EQUAL expression + | geoipCommand + ; + +geoipCommand + : fieldExpression EQUAL GEOIP LT_PRTHS ipAddress = functionArg (COMMA properties = geoIpPropertyList)? RT_PRTHS ; // aggregation terms @@ -446,7 +451,6 @@ valueExpression | positionFunction # positionFunctionCall | caseFunction # caseExpr | timestampFunction # timestampFunctionCall - | geoipFunction # geoFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr | ident ARROW expression # lambda @@ -544,11 +548,6 @@ dataTypeFunctionCall : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS ; -// geoip function -geoipFunction - : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? RT_PRTHS - ; - // boolean functions booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS @@ -582,7 +581,6 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName - | geoipFunctionName | lambdaFunctionName ; @@ -900,10 +898,6 @@ lambdaFunctionName | TRANSFORM | REDUCE ; - -geoipFunctionName - : GEOIP - ; positionFunctionName : POSITION @@ -913,6 +907,21 @@ coalesceFunctionName : COALESCE ; +geoIpPropertyList + : geoIpProperty (COMMA geoIpProperty)* + ; + +geoIpProperty + : COUNTRY_ISO_CODE + | COUNTRY_NAME + | CONTINENT_NAME + | REGION_ISO_CODE + | REGION_NAME + | CITY_NAME + | TIME_ZONE + | LOCATION + ; + // operators comparisonOperator : EQUAL diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index dadf6b968..f9b333b26 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -341,10 +341,15 @@ public T visitExistsSubquery(ExistsSubquery node, C context) { public T visitWindow(Window node, C context) { return visitChildren(node, context); } + public T visitCidr(Cidr node, C context) { return visitChildren(node, context); } + public T visitGeoIp(GeoIp node, C context) { + return visitChildren(node, context); + } + public T visitFlatten(Flatten flatten, C context) { return visitChildren(flatten, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java index 0cc27b6a9..c8482a4ff 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java @@ -12,7 +12,7 @@ import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; -import 
org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.Node; import java.util.List; @@ -23,7 +23,7 @@ @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Eval extends UnresolvedPlan { - private final List expressionList; + private final List expressionList; private UnresolvedPlan child; @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java new file mode 100644 index 000000000..feefa6929 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class GeoIp extends UnresolvedPlan { + private UnresolvedPlan child; + private final Field field; + private final UnresolvedExpression ipAddress; + private final AttributeList properties; + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitGeoIp(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } +} \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java index e80a26bc4..e931175ff 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -11,13 +11,18 @@ import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.types.DataTypes; +import scala.Function1; import scala.Function2; import scala.Option; import scala.Serializable; +import scala.runtime.AbstractFunction1; +import scala.runtime.AbstractFunction2; import scala.collection.JavaConverters; import scala.collection.mutable.WrappedArray; -import scala.runtime.AbstractFunction2; +import java.math.BigInteger; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.util.Collection; import java.util.List; import java.util.Map; @@ -28,7 +33,6 @@ import static org.opensearch.sql.expression.function.JsonUtils.removeNestedKey; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; - public interface SerializableUdf { @@ -142,11 +146,66 @@ public Boolean apply(String ipAddress, String cidrBlock) { } }; + class geoIpUtils { + /** + * Checks if provided ip string is ipv4 or ipv6. + * + * @param ipAddress To input ip string. + * @return true if ipAddress is ipv4, false if ipaddress is ipv6, AddressString Exception if invalid ip. 
+ */ + public static Function1 isIpv4 = new SerializableAbstractFunction1<>() { + + IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + @Override + public Boolean apply(String ipAddress) { + IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); + + try { + parsedIpAddress.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage()); + } + + return parsedIpAddress.isIPv4(); + } + }; + + /** + * Convert ipAddress string to interger representation + * + * @param ipAddress The input ip string. + * @return converted BigInteger from ipAddress string. + */ + public static Function1 ipToInt = new SerializableAbstractFunction1<>() { + @Override + public BigInteger apply(String ipAddress) { + try { + InetAddress inetAddress = InetAddress.getByName(ipAddress); + byte[] addressBytes = inetAddress.getAddress(); + return new BigInteger(1, addressBytes); + } catch (UnknownHostException e) { + System.err.println("Invalid IP address: " + e.getMessage()); + } + return null; + } + }; + } + + abstract class SerializableAbstractFunction1 extends AbstractFunction1 + implements Serializable { + } + /** - * get the function reference according to its name + * Get the function reference according to its name * - * @param funcName - * @return + * @param funcName string representing function to retrieve. + * @return relevant ScalaUDF for given function name. */ static ScalaUDF visit(String funcName, List expressions) { switch (funcName) { @@ -177,6 +236,24 @@ static ScalaUDF visit(String funcName, List expressions) { Option.apply("json_append"), false, true); + case "is_ipv4": + return new ScalaUDF(geoIpUtils.isIpv4, + DataTypes.BooleanType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("is_ipv4"), + false, + true); + case "ip_to_int": + return new ScalaUDF(geoIpUtils.ipToInt, + DataTypes.createDecimalType(38,0), + seq(expressions), + seq(), + Option.empty(), + Option.apply("ip_to_int"), + false, + true); default: return null; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index d7f59bae3..0a6e869ba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -56,6 +56,7 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Flatten; +import org.opensearch.sql.ast.tree.GeoIp; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -69,9 +70,11 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; +import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator; import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; 
import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; @@ -562,19 +565,63 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { visitFirstChild(node, context); List aliases = new ArrayList<>(); - List letExpressions = node.getExpressionList(); - for (Let let : letExpressions) { - Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); - aliases.add(alias); + List expressions = node.getExpressionList(); + + // Geoip function modifies logical plan and is treated as QueryPlanVisitor instead of ExpressionVisitor + for (Node expr : expressions) { + if (expr instanceof Let) { + Let let = (Let) expr; + Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); + aliases.add(alias); + } else if (expr instanceof UnresolvedPlan) { + expr.accept(this, context); + } else { + throw new SyntaxCheckException("Unexpected node type when visiting EVAL"); + } } - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + + if (!aliases.isEmpty()) { + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + + visitExpressionList(aliases, context); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } else { + return context.getPlan(); } - List expressionList = visitExpressionList(aliases, context); - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } + + @Override + public LogicalPlan visitGeoIp(GeoIp node, CatalystPlanContext context) { + visitExpression(node.getProperties(), context); + List attributeList = new ArrayList<>(); + + while (!context.getNamedParseExpressions().isEmpty()) { + Expression nextExpression = context.getNamedParseExpressions().pop(); + String attributeName = nextExpression.toString(); + + if (attributeList.contains(attributeName)) { + throw new IllegalStateException("Duplicate attribute in GEOIP attribute list"); + } + + attributeList.add(0, attributeName); + } + + String fieldExpression = node.getField().getField().toString(); + Expression ipAddressExpression = visitExpression(node.getIpAddress(), context); + + return GeoIpCatalystLogicalPlanTranslator.getGeoipLogicalPlan( + new GeoIpCatalystLogicalPlanTranslator.GeoIpParameters( + fieldExpression, + ipAddressExpression, + attributeList + ), + context + ); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d4f9ece87..2ea23babf 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -336,10 +336,18 @@ public UnresolvedPlan visitSortCommand(OpenSearchPPLParser.SortCommandContext ct public UnresolvedPlan 
visitEvalCommand(OpenSearchPPLParser.EvalCommandContext ctx) { return new Eval( ctx.evalClause().stream() - .map(ct -> (Let) internalVisitExpression(ct)) + .map(ct -> (ct.geoipCommand() != null) ? visit(ct.geoipCommand()) : (Let) internalVisitExpression(ct)) .collect(Collectors.toList())); } + @Override + public UnresolvedPlan visitGeoipCommand(OpenSearchPPLParser.GeoipCommandContext ctx) { + Field field = (Field) internalVisitExpression(ctx.fieldExpression()); + UnresolvedExpression ipAddress = internalVisitExpression(ctx.ipAddress); + AttributeList properties = ctx.properties == null ? new AttributeList(Collections.emptyList()) : (AttributeList) internalVisitExpression(ctx.properties); + return new GeoIp(field, ipAddress, properties); + } + private List getGroupByList(OpenSearchPPLParser.ByClauseContext ctx) { return ctx.fieldList().fieldExpression().stream() .map(this::internalVisitExpression) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 1fe57d13e..a73c593fe 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -49,6 +49,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; +import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator; import java.util.Arrays; import java.util.Collections; @@ -450,6 +451,20 @@ public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { return new LambdaFunction(function, arguments); } + @Override + public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpPropertyListContext ctx) { + ImmutableList.Builder properties = ImmutableList.builder(); + if (ctx != null) { + for (OpenSearchPPLParser.GeoIpPropertyContext property : ctx.geoIpProperty()) { + String propertyName = property.getText().toUpperCase(); + GeoIpCatalystLogicalPlanTranslator.validateGeoIpProperty(propertyName); + properties.add(new Literal(propertyName, DataType.STRING)); + } + } + + return new AttributeList(properties.build()); + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java new file mode 100644 index 000000000..cedc00846 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java @@ -0,0 +1,222 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.spark.SparkEnv; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.And; +import org.apache.spark.sql.catalyst.expressions.CreateStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import 
org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.stream.Collectors; + +import static java.util.List.of; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; + +public interface GeoIpCatalystLogicalPlanTranslator { + String SPARK_CONF_KEY = "spark.geoip.tablename"; + String DEFAULT_GEOIP_TABLE_NAME = "geoip"; + String GEOIP_CIDR_COLUMN_NAME = "cidr"; + String GEOIP_IP_RANGE_START_COLUMN_NAME = "ip_range_start"; + String GEOIP_IP_RANGE_END_COLUMN_NAME = "ip_range_end"; + String GEOIP_IPV4_COLUMN_NAME = "ipv4"; + String SOURCE_TABLE_ALIAS = "t1"; + String GEOIP_TABLE_ALIAS = "t2"; + List GEOIP_TABLE_COLUMNS = Arrays.stream(GeoIpProperty.values()) + .map(Enum::name) + .collect(Collectors.toList()); + + /** + * Responsible to produce a Spark Logical Plan with given GeoIp command arguments, below is the sample logical plan + * with configuration [source=users, field=a, ipAddress=ip, properties=[country_name, city_name]] + * +- 'DataFrameDropColumns ['t2.country_iso_code, 't2.country_name, 't2.continent_name, 't2.region_iso_code, 't2.region_name, 't2.city_name, 't2.time_zone, 't2.location, 't2.cidr, 't2.start, 't2.end, 't2.ipv4] + * -- +- 'Project [*, named_struct(country_name, 't2.country_name, city_name, 't2.city_name) AS a#0] + * -- -- +- 'Join LeftOuter, (((ip_to_int('ip) >= 't2.start) AND (ip_to_int('ip) < 't2.end)) AND (is_ipv4('ip) = 't2.ipv4)) + * -- -- -- :- 'SubqueryAlias t1 + * -- -- -- -- : +- 'UnresolvedRelation [users], [], false + * -- -- -- +- 'SubqueryAlias t2 + * -- -- -- -- -- +- 'UnresolvedRelation [geoip], [], false + * . + * And the corresponded SQL query: + * . + * SELECT users.*, struct(geoip.country_name, geoip.city_name) AS a + * FROM users, geoip + * WHERE geoip.ip_range_start <= ip_to_int(users.ip) + * AND geoip.ip_range_end > ip_to_int(users.ip) + * AND geoip.ip_type = is_ipv4(users.ip); + * + * @param parameters GeoIp function parameters. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will project new col with geoip location based on given ipAddresses. 
+ */ + static LogicalPlan getGeoipLogicalPlan(GeoIpParameters parameters, CatalystPlanContext context) { + applyJoin(parameters.getIpAddress(), context); + return applyProjection(parameters.getField(), parameters.getProperties(), context); + } + + /** + * Responsible to produce join plan for GeoIp command, below is the sample logical plan + * with configuration [source=users, ipAddress=ip] + * +- 'Join LeftOuter, (((ip_to_int('ip) >= 't2.start) AND (ip_to_int('ip) < 't2.end)) AND (is_ipv4('ip) = 't2.ipv4)) + * -- :- 'SubqueryAlias t1 + * -- -- : +- 'UnresolvedRelation [users], [], false + * -- +- 'SubqueryAlias t2 + * -- -- -- +- 'UnresolvedRelation [geoip], [], false + * + * @param ipAddress Expression representing ip addresses to be queried. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will perform join based on ip within cidr range in geoip table. + */ + static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext context) { + return context.apply(left -> { + LogicalPlan right = new UnresolvedRelation(seq(getGeoipTableName()), CaseInsensitiveStringMap.empty(), false); + LogicalPlan leftAlias = SubqueryAlias$.MODULE$.apply(SOURCE_TABLE_ALIAS, left); + LogicalPlan rightAlias = SubqueryAlias$.MODULE$.apply(GEOIP_TABLE_ALIAS, right); + Optional joinCondition = Optional.of(new And( + new And( + new GreaterThanOrEqual( + SerializableUdf.visit("ip_to_int", of(ipAddress)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IP_RANGE_START_COLUMN_NAME)) + ), + new LessThan( + SerializableUdf.visit("ip_to_int", of(ipAddress)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IP_RANGE_END_COLUMN_NAME)) + ) + ), + new EqualTo( + SerializableUdf.visit("is_ipv4", of(ipAddress)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IPV4_COLUMN_NAME)) + ) + )); + context.retainAllNamedParseExpressions(p -> p); + context.retainAllPlans(p -> p); + return join(leftAlias, + rightAlias, + Join.JoinType.LEFT, + joinCondition, + new Join.JoinHint()); + }); + } + + /** + * Responsible to produce a Spark Logical Plan with given GeoIp command arguments, below is the sample logical plan + * with configuration [source=users, field=a, properties=[country_name, city_name]] + * +- 'DataFrameDropColumns ['t2.country_iso_code, 't2.country_name, 't2.continent_name, 't2.region_iso_code, 't2.region_name, 't2.city_name, 't2.time_zone, 't2.location, 't2.cidr, 't2.start, 't2.end, 't2.ipv4] + * -- +- 'Project [*, named_struct(country_name, 't2.country_name, city_name, 't2.city_name) AS a#0] + * + * @param field Name of new eval geoip column. + * @param properties List of geo properties to be returned. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will return source table and new eval geoip column. + */ + static private LogicalPlan applyProjection(String field, List properties, CatalystPlanContext context) { + List projectExpressions = new ArrayList<>(); + projectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + + List geoIpStructFields = createGeoIpStructFields(properties); + Expression columnValue = (geoIpStructFields.size() == 1)? 
+ geoIpStructFields.get(0) : CreateStruct.apply(seq(geoIpStructFields)); + + NamedExpression geoCol = Alias$.MODULE$.apply( + columnValue, + field, + NamedExpression.newExprId(), + seq(new ArrayList<>()), + Option.empty(), + seq(new ArrayList<>())); + + projectExpressions.add(geoCol); + + List dropList = createGeoIpStructFields(new ArrayList<>()); + dropList.addAll(List.of( + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_CIDR_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IP_RANGE_START_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IP_RANGE_END_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IPV4_COLUMN_NAME)) + )); + + context.apply(p -> new Project(seq(projectExpressions), p)); + return context.apply(p -> new DataFrameDropColumns(seq(dropList), p)); + } + + static private List createGeoIpStructFields(List attributeList) { + List attributeListToUse; + if (attributeList == null || attributeList.isEmpty()) { + attributeListToUse = GEOIP_TABLE_COLUMNS; + } else { + attributeListToUse = attributeList; + } + + return attributeListToUse.stream() + .map(a -> UnresolvedAttribute$.MODULE$.apply(seq( + GEOIP_TABLE_ALIAS, + a.toLowerCase(Locale.ROOT) + ))) + .collect(Collectors.toList()); + } + + static private String getGeoipTableName() { + String tableName = DEFAULT_GEOIP_TABLE_NAME; + + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + tableName = SparkEnv.get().conf().get(SPARK_CONF_KEY, DEFAULT_GEOIP_TABLE_NAME); + } + + return tableName; + } + + @Getter + @AllArgsConstructor + class GeoIpParameters { + private final String field; + private final Expression ipAddress; + private final List properties; + } + + enum GeoIpProperty { + COUNTRY_ISO_CODE, + COUNTRY_NAME, + CONTINENT_NAME, + REGION_ISO_CODE, + REGION_NAME, + CITY_NAME, + TIME_ZONE, + LOCATION + } + + public static void validateGeoIpProperty(String propertyName) { + try { + GeoIpProperty.valueOf(propertyName); + } catch (NullPointerException | IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid properties used."); + } + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala new file mode 100644 index 000000000..460b9769c --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala @@ -0,0 +1,332 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.util + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf.visit +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, Descending, EqualTo, Expression, ExprId, GreaterThanOrEqual, In, LessThan, Literal, NamedExpression, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.{LeftOuter, PlanTest} +import 
org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanGeoipFunctionTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + private def getGeoIpQueryPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan, + projectionProperties: Alias): LogicalPlan = { + val joinPlan = getJoinPlan(ipAddress, left, right) + getProjection(joinPlan, projectionProperties) + } + + private def getJoinPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan): LogicalPlan = { + val is_ipv4 = visit("is_ipv4", util.List.of[Expression](ipAddress)) + val ip_to_int = visit("ip_to_int", util.List.of[Expression](ipAddress)) + + val t1 = SubqueryAlias("t1", left) + val t2 = SubqueryAlias("t2", right) + + val joinCondition = And( + And( + GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")), + LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))) + Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) + } + + private def getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = { + val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) + val dropList = Seq( + "t2.country_iso_code", + "t2.country_name", + "t2.continent_name", + "t2.region_iso_code", + "t2.region_name", + "t2.city_name", + "t2.time_zone", + "t2.location", + "t2.cidr", + "t2.ip_range_start", + "t2.ip_range_end", + "t2.ipv4").map(UnresolvedAttribute(_)) + DataFrameDropColumns(dropList, projection) + } + + test("test geoip function - only ip_address provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source = users | eval a = geoip(ip_address)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_iso_code"), + UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), + UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), + UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), + UnresolvedAttribute("t2.region_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"), + Literal("time_zone"), + UnresolvedAttribute("t2.time_zone"), + Literal("location"), + UnresolvedAttribute("t2.location"))) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - source has same name as join alias") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = 
Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - ipAddress col exist in geoip table") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("cidr") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - duplicate parameters") { + val context = new CatalystPlanContext + + val exception = intercept[IllegalStateException] { + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"), + context) + } + + assert(exception.getMessage.contains("Duplicate attribute in GEOIP attribute list")) + } + + test("test geoip function - one property provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple properties provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address,country_name,location)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("location"), + UnresolvedAttribute("t2.location"))) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple geoip calls") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.country_iso_code"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionB = 
Alias(UnresolvedAttribute("t2.region_iso_code"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionB) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used between geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.time_zone"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionC = Alias(UnresolvedAttribute("t2.region_name"), "c")() + val colCPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionC) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "b")()) + val colBPlan = Project(randProjectList, colCPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used before geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = rand(), b = geoip(ip_address, city_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionB = Alias(UnresolvedAttribute("t2.city_name"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionB) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")()) + val colAPlan = Project(randProjectList, colBPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colAPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - projection on evaluated field") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address, country_name) | fields a"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip with partial projection on evaluated fields") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code) | fields b"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.country_iso_code"), "a")() + val 
colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.region_iso_code"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionB) + + val expectedPlan = Project(Seq(UnresolvedAttribute("b")), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}
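At the DataFrame level, the plan these suites assert corresponds roughly to the sketch below: a left-outer join of the source (`t1`) against the geoip table (`t2`) on the integer range and IPv4 flag, a struct projection for the requested properties, and a final pruning of the raw geoip columns. It assumes the `flint_ppl_test` and `geoip` fixture tables exist; the inline UDFs are simplified stand-ins for `SerializableUdf.geoIpUtils` and, unlike the real `ip_to_int`, only cover the IPv4 range comfortably because of the default decimal precision of Scala UDFs.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct, udf}

object GeoIpJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("geoip-join-sketch")
      .master("local[*]")
      .getOrCreate()

    // Simplified stand-ins for the is_ipv4 / ip_to_int ScalaUDFs.
    val isIpv4 = udf((ip: String) =>
      java.net.InetAddress.getByName(ip).isInstanceOf[java.net.Inet4Address])
    val ipToInt = udf((ip: String) =>
      new java.math.BigDecimal(
        new java.math.BigInteger(1, java.net.InetAddress.getByName(ip).getAddress)))

    val t1 = spark.table("flint_ppl_test").alias("t1")
    val t2 = spark.table("geoip").alias("t2")

    // Join condition mirrors the translated plan:
    // ip_to_int(ip) >= ip_range_start AND ip_to_int(ip) < ip_range_end AND is_ipv4(ip) = ipv4
    val joined = t1.join(
      t2,
      ipToInt(col("t1.ip")) >= col("t2.ip_range_start") &&
        ipToInt(col("t1.ip")) < col("t2.ip_range_end") &&
        isIpv4(col("t1.ip")) === col("t2.ipv4"),
      "left_outer")

    // eval a = geoip(ip, country_name, city_name): keep the source columns,
    // project the requested properties as a struct, and leave out the
    // remaining geoip columns (the translator drops them explicitly).
    val result = joined.select(
      col("t1.*"),
      struct(col("t2.country_name"), col("t2.city_name")).alias("a"))

    result.show(truncate = false)
    spark.stop()
  }
}
```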