From f8d243d640891f18dcfbd662f1fc8fcf95f9700d Mon Sep 17 00:00:00 2001 From: Kenrick Yap <14yapkc1@gmail.com> Date: Wed, 11 Dec 2024 13:06:34 -0800 Subject: [PATCH 1/9] geoip function implementation Signed-off-by: Kenrick Yap <14yapkc1@gmail.com> --- docs/ppl-lang/functions/ppl-ip.md | 65 +++- docs/ppl-lang/planning/ppl-geoip-command.md | 68 +++++ .../flint/spark/FlintSparkSuite.scala | 76 +++++ .../spark/ppl/FlintSparkPPLGeoipITSuite.scala | 92 ++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 15 +- .../src/main/antlr4/OpenSearchPPLParser.g4 | 31 +- .../sql/ast/AbstractNodeVisitor.java | 5 + .../org/opensearch/sql/ast/tree/Eval.java | 4 +- .../org/opensearch/sql/ast/tree/GeoIp.java | 47 +++ .../expression/function/SerializableUdf.java | 52 +++- .../sql/ppl/CatalystQueryPlanVisitor.java | 69 ++++- .../opensearch/sql/ppl/parser/AstBuilder.java | 10 +- .../sql/ppl/parser/AstExpressionBuilder.java | 21 ++ .../GeoIpCatalystLogicalPlanTranslator.java | 239 +++++++++++++++ ...PlanGeoipFunctionTranslatorTestSuite.scala | 287 ++++++++++++++++++ 15 files changed, 1050 insertions(+), 31 deletions(-) create mode 100644 docs/ppl-lang/planning/ppl-geoip-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md index fb0b468ba..1a11f8b45 100644 --- a/docs/ppl-lang/functions/ppl-ip.md +++ b/docs/ppl-lang/functions/ppl-ip.md @@ -32,4 +32,67 @@ Note: - `ip` can be an IPv4 or an IPv6 address - `cidr` can be an IPv4 or an IPv6 block - `ip` and `cidr` must be either both IPv4 or both 
IPv6 - - `ip` and `cidr` must both be valid and non-empty/non-null \ No newline at end of file + - `ip` and `cidr` must both be valid and non-empty/non-null + +### `GEOIP` + +**Description** + +`GEOIP(ip[, property]...)` retrieves geospatial data corresponding to the provided `ip`. + +**Argument type:** +- `ip` must be a **STRING**. +- `property` is **STRING** and must be one of the following: + - `COUNTRY_ISO_CODE` + - `COUNTRY_NAME` + - `CONTINENT_NAME` + - `REGION_ISO_CODE` + - `REGION_NAME` + - `CITY_NAME` + - `TIME_ZONE` + - `LOCATION` +- Return type: + - **STRING** if one property given + - **STRUCT_TYPE** if more than one or no property is given + +Example: + +_Without properties:_ + + os> source=ips | eval a = geoip(ip) | fields ip, a + fetched rows / total rows = 2/2 + +---------------------+-------------------------------------------------------------------------------------------------------+ + |ip |a | + +---------------------+-------------------------------------------------------------------------------------------------------+ + |66.249.157.90 |{JM, Jamaica, North America, 14, Saint Catherine Parish, Portmore, America/Jamaica, 17.9686,-76.8827} | + |2a09:bac2:19f8:2ac3::|{CA, Canada, North America, PE, Prince Edward Island, Charlottetown, America/Halifax, 46.2396,-63.1355}| + +---------------------+-------------------------------------------------------------------------------------------------------+ + +_With one property:_ + + os> source=users | eval a = geoip(ip, COUNTRY_NAME) | fields ip, a + fetched rows / total rows = 2/2 + +---------------------+-------+ + |ip |a | + +---------------------+-------+ + |66.249.157.90 |Jamaica| + |2a09:bac2:19f8:2ac3::|Canada | + +---------------------+-------+ + +_With multiple properties:_ + + os> source=users | eval a = geoip(ip, COUNTRY_NAME, REGION_NAME, CITY_NAME) | fields ip, a + fetched rows / total rows = 2/2 + +---------------------+---------------------------------------------+ + 
|ip                   |a                                            | + +---------------------+---------------------------------------------+ + |66.249.157.90 |{Jamaica, Saint Catherine Parish, Portmore} | + |2a09:bac2:19f8:2ac3::|{Canada, Prince Edward Island, Charlottetown}| + +---------------------+---------------------------------------------+ + +Note: +- To use `geoip`, the user must create a spark table containing geo ip location data. Instructions to create the table can be found [here](../../opensearch-geoip.md). + - `geoip` command by default expects the created table to be called `geoip_ip_data`. + - if a different table name is desired, set the `spark.geoip.tablename` spark config to the new table name. +- `ip` can be an IPv4 or an IPv6 address. +- `geoip` commands will always be calculated first if used with other eval functions. diff --git a/docs/ppl-lang/planning/ppl-geoip-command.md b/docs/ppl-lang/planning/ppl-geoip-command.md new file mode 100644 index 000000000..142516ca3 --- /dev/null +++ b/docs/ppl-lang/planning/ppl-geoip-command.md @@ -0,0 +1,68 @@ +## geoip syntax proposal + +geoip function to add information about the geographical location of an IPv4 or IPv6 address + +**Implementation syntax** +- `... | eval geoinfo = geoip([datasource,] ipAddress *[,properties])` +- generic syntax +- `... | eval geoinfo = geoip(ipAddress)` +- use the default geoip datasource +- `... | eval geoinfo = geoip("abc", ipAddress)` +- use the "abc" geoip datasource +- `... | eval geoinfo = geoip(ipAddress, city, location)` +- use the default geoip datasource, retrieve only city, and location +- `... | eval geoinfo = geoip("abc", ipAddress, city, location)` +- use the "abc" geoip datasource, retrieve only city, and location + +**Implementation details** +- Current implementation requires user to have created a geoip table. 
Geoip table has the following schema: + + ```SQL + CREATE TABLE geoip ( + cidr STRING, + country_iso_code STRING, + country_name STRING, + continent_name STRING, + region_iso_code STRING, + region_name STRING, + city_name STRING, + time_zone STRING, + location STRING, + ip_range_start BIGINT, + ip_range_end BIGINT, + ipv4 BOOLEAN + ) + ``` + +- `geoip` is resolved by performing a join on said table and projecting the resulting geoip data as a struct. +- an example of using `geoip` is equivalent to running the following SQL query: + + ```SQL + SELECT source.*, struct(geoip.country_name, geoip.city_name) AS a + FROM source, geoip + WHERE geoip.ip_range_start <= ip_to_int(source.ip) + AND geoip.ip_range_end > ip_to_int(source.ip) + AND geoip.ip_type = is_ipv4(source.ip); + ``` + +**Future plan for additional data-sources** + +- Currently only using pre-existing geoip table defined within spark is possible. +- There is future plans to allow users to specify data sources: + - API data sources - if users have their own geoip provided will create ability for users to configure and call said endpoints + - OpenSearch geospatial client - once geospatial client is published we can leverage client to utilize opensearch geo2ip functionality. + +### New syntax definition in ANTLR + +```ANTLR + +// functions +evalFunctionCall + : evalFunctionName LT_PRTHS functionArgs RT_PRTHS + | geoipFunction + ; + +geoipFunction + : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? 
RT_PRTHS + ; +``` diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7c19cab12..36d0c6a7d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -771,6 +771,82 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | """.stripMargin) } + protected def createGeoIpTestTable(testTable: String): Unit = { + sql( + s""" + | CREATE TABLE $testTable + | ( + | ip STRING, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql( + s""" + | INSERT INTO $testTable + | VALUES ('66.249.157.90', true), + | ('2a09:bac2:19f8:2ac3::', true), + | ('192.168.2.', false), + | ('2001:db8::ff00:12:', false) + | """.stripMargin) + } + + protected def createGeoIpTable(): Unit = { + sql( + s""" + | CREATE TABLE geoip + | ( + | cidr STRING, + | country_iso_code STRING, + | country_name STRING, + | continent_name STRING, + | region_iso_code STRING, + | region_name STRING, + | city_name STRING, + | time_zone STRING, + | location STRING, + | ip_range_start DECIMAL(38,0), + | ip_range_end DECIMAL(38,0), + | ipv4 BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql( + s""" + | INSERT INTO geoip + | VALUES ( + | '66.249.157.0/24', + | 'JM', + | 'Jamaica', + | 'North America', + | '14', + | 'Saint Catherine Parish', + | 'Portmore', + | 'America/Jamaica', + | '17.9686,-76.8827', + | 1123654912, + | 1123655167, + | true + | ), + | ( + | '2a09:bac2:19f8::/45', + | 'CA', + | 'Canada', + | 'North America', + | 'PE', + | 'Prince Edward Island', + | 'Charlottetown', + | 'America/Halifax', + | '46.2396,-63.1355', + | 55878094401180025937395073088449675264, + | 55878094401189697343951990121847324671, + | false + | ) + | """.stripMargin) + } + + protected def 
createNestedJsonContentTable(tempFile: Path, testTable: String): Unit = { val json = """ diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala new file mode 100644 index 000000000..3e2728553 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLGeoipITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createGeoIpTestTable(testTable) + createGeoIpTable() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test geoip with no parameters") { + val frame = sql( + s""" + | source = $testTable| where isValid = true | eval a = geoip(ip) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("66.249.157.90", Row("JM", "Jamaica", "North America", 14, "Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686, -76.8827")), + Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396, -63.1355")) + ) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, 
String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test geoip with one parameters") { + val frame = sql( + s""" + | source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("66.249.157.90", "Jamaica"), + Row("2a09:bac2:19f8:2ac3::", "Canada") + ) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test geoip with multiple parameters") { + val frame = sql( + s""" + | source = $testTable| where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("66.249.157.90", Row("Jamaica", "Portmore")), + Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown")) + ) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index d15f5c8e3..b224fb772 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -416,9 +416,6 @@ ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; CIDRMATCH: 'CIDRMATCH'; -// Geo Loction -GEOIP: 'GEOIP'; - // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; NULLIF: 'NULLIF'; @@ -428,6 +425,18 @@ TYPEOF: 'TYPEOF'; //OTHER CONDITIONAL EXPRESSIONS COALESCE: 'COALESCE'; +//GEOLOCATION FUNCTIONS +GEOIP: 'GEOIP'; + +//GEOLOCATION PROPERTIES 
+COUNTRY_ISO_CODE: 'COUNTRY_ISO_CODE'; +COUNTRY_NAME: 'COUNTRY_NAME'; +CONTINENT_NAME: 'CONTINENT_NAME'; +REGION_ISO_CODE: 'REGION_ISO_CODE'; +REGION_NAME: 'REGION_NAME'; +CITY_NAME: 'CITY_NAME'; +LOCATION: 'LOCATION'; + // RELEVANCE FUNCTIONS AND PARAMETERS MATCH: 'MATCH'; MATCH_PHRASE: 'MATCH_PHRASE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 2466a3d23..ed1640d2f 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -387,6 +387,11 @@ sortbyClause evalClause : fieldExpression EQUAL expression + | geoipCommand + ; + +geoipCommand + : fieldExpression EQUAL GEOIP LT_PRTHS ipAddress = functionArg (COMMA properties = geoIpPropertyList)? RT_PRTHS ; // aggregation terms @@ -446,7 +451,6 @@ valueExpression | positionFunction # positionFunctionCall | caseFunction # caseExpr | timestampFunction # timestampFunctionCall - | geoipFunction # geoFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr | ident ARROW expression # lambda @@ -544,11 +548,6 @@ dataTypeFunctionCall : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS ; -// geoip function -geoipFunction - : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? 
RT_PRTHS - ; - // boolean functions booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS @@ -582,7 +581,6 @@ evalFunctionName | cryptographicFunctionName | jsonFunctionName | collectionFunctionName - | geoipFunctionName | lambdaFunctionName ; @@ -900,10 +898,6 @@ lambdaFunctionName | TRANSFORM | REDUCE ; - -geoipFunctionName - : GEOIP - ; positionFunctionName : POSITION @@ -913,6 +907,21 @@ coalesceFunctionName : COALESCE ; +geoIpPropertyList + : geoIpProperty (COMMA geoIpProperty)* + ; + +geoIpProperty + : COUNTRY_ISO_CODE + | COUNTRY_NAME + | CONTINENT_NAME + | REGION_ISO_CODE + | REGION_NAME + | CITY_NAME + | TIME_ZONE + | LOCATION + ; + // operators comparisonOperator : EQUAL diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index dadf6b968..f9b333b26 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -341,10 +341,15 @@ public T visitExistsSubquery(ExistsSubquery node, C context) { public T visitWindow(Window node, C context) { return visitChildren(node, context); } + public T visitCidr(Cidr node, C context) { return visitChildren(node, context); } + public T visitGeoIp(GeoIp node, C context) { + return visitChildren(node, context); + } + public T visitFlatten(Flatten flatten, C context) { return visitChildren(flatten, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java index 0cc27b6a9..c8482a4ff 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java @@ -12,7 +12,7 @@ import lombok.Setter; import lombok.ToString; import 
org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.Node; import java.util.List; @@ -23,7 +23,7 @@ @EqualsAndHashCode(callSuper = false) @RequiredArgsConstructor public class Eval extends UnresolvedPlan { - private final List expressionList; + private final List expressionList; private UnresolvedPlan child; @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java new file mode 100644 index 000000000..feefa6929 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/GeoIp.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class GeoIp extends UnresolvedPlan { + private UnresolvedPlan child; + private final Field field; + private final UnresolvedExpression ipAddress; + private final AttributeList properties; + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitGeoIp(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } +} \ No newline at end of file diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java index 2541b3743..e9fe0b2b8 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -8,10 +8,15 @@ import inet.ipaddr.AddressStringException; import inet.ipaddr.IPAddressString; import inet.ipaddr.IPAddressStringParameters; +import scala.Function1; import scala.Function2; import scala.Serializable; +import scala.runtime.AbstractFunction1; import scala.runtime.AbstractFunction2; +import java.math.BigInteger; +import java.net.InetAddress; +import java.net.UnknownHostException; public interface SerializableUdf { @@ -48,8 +53,51 @@ public Boolean apply(String ipAddress, String cidrBlock) { } return parsedCidrBlock.contains(parsedIpAddress); - } - }; + + }}; + + class geoIpUtils { + public static Function1 isIpv4 = new SerializableAbstractFunction1<>() { + + IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + @Override + public Boolean apply(String ipAddress) { + IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); + + try { + parsedIpAddress.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. 
Error details: "+e.getMessage()); + } + + return parsedIpAddress.isIPv4(); + } + }; + + public static Function1 ipToInt = new SerializableAbstractFunction1<>() { + @Override + public BigInteger apply(String ipAddress) { + try { + InetAddress inetAddress = InetAddress.getByName(ipAddress); + byte[] addressBytes = inetAddress.getAddress(); + return new BigInteger(1, addressBytes); + } catch (UnknownHostException e) { + System.err.println("Invalid IP address: " + e.getMessage()); + } + return null; + } + }; + } + + abstract class SerializableAbstractFunction1 extends AbstractFunction1 + implements Serializable { + } abstract class SerializableAbstractFunction2 extends AbstractFunction2 implements Serializable { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index d7f59bae3..0a6e869ba 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -56,6 +56,7 @@ import org.opensearch.sql.ast.tree.FillNull; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Flatten; +import org.opensearch.sql.ast.tree.GeoIp; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; @@ -69,9 +70,11 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.FieldSummaryTransformer; +import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator; import org.opensearch.sql.ppl.utils.ParseTransformer; import org.opensearch.sql.ppl.utils.SortUtils; 
import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils; @@ -562,19 +565,63 @@ public LogicalPlan visitRename(Rename node, CatalystPlanContext context) { public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { visitFirstChild(node, context); List aliases = new ArrayList<>(); - List letExpressions = node.getExpressionList(); - for (Let let : letExpressions) { - Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); - aliases.add(alias); + List expressions = node.getExpressionList(); + + // Geoip function modifies logical plan and is treated as QueryPlanVisitor instead of ExpressionVisitor + for (Node expr : expressions) { + if (expr instanceof Let) { + Let let = (Let) expr; + Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); + aliases.add(alias); + } else if (expr instanceof UnresolvedPlan) { + expr.accept(this, context); + } else { + throw new SyntaxCheckException("Unexpected node type when visiting EVAL"); + } } - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + + if (!aliases.isEmpty()) { + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + + visitExpressionList(aliases, context); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } else { + return context.getPlan(); } - List expressionList = visitExpressionList(aliases, context); - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - return context.apply(p -> 
new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } + + @Override + public LogicalPlan visitGeoIp(GeoIp node, CatalystPlanContext context) { + visitExpression(node.getProperties(), context); + List attributeList = new ArrayList<>(); + + while (!context.getNamedParseExpressions().isEmpty()) { + Expression nextExpression = context.getNamedParseExpressions().pop(); + String attributeName = nextExpression.toString(); + + if (attributeList.contains(attributeName)) { + throw new IllegalStateException("Duplicate attribute in GEOIP attribute list"); + } + + attributeList.add(0, attributeName); + } + + String fieldExpression = node.getField().getField().toString(); + Expression ipAddressExpression = visitExpression(node.getIpAddress(), context); + + return GeoIpCatalystLogicalPlanTranslator.getGeoipLogicalPlan( + new GeoIpCatalystLogicalPlanTranslator.GeoIpParameters( + fieldExpression, + ipAddressExpression, + attributeList + ), + context + ); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d4f9ece87..2ea23babf 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -336,10 +336,18 @@ public UnresolvedPlan visitSortCommand(OpenSearchPPLParser.SortCommandContext ct public UnresolvedPlan visitEvalCommand(OpenSearchPPLParser.EvalCommandContext ctx) { return new Eval( ctx.evalClause().stream() - .map(ct -> (Let) internalVisitExpression(ct)) + .map(ct -> (ct.geoipCommand() != null) ? 
visit(ct.geoipCommand()) : (Let) internalVisitExpression(ct)) .collect(Collectors.toList())); } + @Override + public UnresolvedPlan visitGeoipCommand(OpenSearchPPLParser.GeoipCommandContext ctx) { + Field field = (Field) internalVisitExpression(ctx.fieldExpression()); + UnresolvedExpression ipAddress = internalVisitExpression(ctx.ipAddress); + AttributeList properties = ctx.properties == null ? new AttributeList(Collections.emptyList()) : (AttributeList) internalVisitExpression(ctx.properties); + return new GeoIp(field, ipAddress, properties); + } + private List getGroupByList(OpenSearchPPLParser.ByClauseContext ctx) { return ctx.fieldList().fieldExpression().stream() .map(this::internalVisitExpression) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 1fe57d13e..f376c23ae 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -49,6 +49,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; +import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator; import java.util.Arrays; import java.util.Collections; @@ -450,6 +451,26 @@ public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { return new LambdaFunction(function, arguments); } + @Override + public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpPropertyListContext ctx) { + ImmutableList.Builder properties = ImmutableList.builder(); + if (ctx != null) { + for (OpenSearchPPLParser.GeoIpPropertyContext property : ctx.geoIpProperty()) { + String propertyName = property.getText().toUpperCase(); + + try { + 
GeoIpCatalystLogicalPlanTranslator.GeoIpProperty.valueOf(propertyName); + } catch (NullPointerException | IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid properties used."); + } + + properties.add(new Literal(propertyName, DataType.STRING)); + } + } + + return new AttributeList(properties.build()); + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java new file mode 100644 index 000000000..2192c3485 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java @@ -0,0 +1,239 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.spark.SparkEnv; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.Alias$; +import org.apache.spark.sql.catalyst.expressions.And; +import org.apache.spark.sql.catalyst.expressions.CreateStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import 
org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.expression.function.SerializableUdf; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; + +public interface GeoIpCatalystLogicalPlanTranslator { + String SPARK_CONF_KEY = "spark.geoip.tablename"; + String DEFAULT_GEOIP_TABLE_NAME = "geoip"; + String SOURCE_TABLE_ALIAS = "t1"; + String GEOIP_TABLE_ALIAS = "t2"; + List GEOIP_TABLE_COLUMNS = List.of( + "country_iso_code", + "country_name", + "continent_name", + "region_iso_code", + "region_name", + "city_name", + "time_zone", + "location" + ); + + /** + * Responsible to produce a Spark Logical Plan with given GeoIp command arguments, below is the sample logical plan + * with configuration [source=users, field=a, ipAddress=ip, properties=[country_name, city_name]] + * +- 'DataFrameDropColumns ['t2.country_iso_code, 't2.country_name, 't2.continent_name, 't2.region_iso_code, 't2.region_name, 't2.city_name, 't2.time_zone, 't2.location, 't2.cidr, 't2.start, 't2.end, 't2.ipv4] + * -- +- 'Project [*, named_struct(country_name, 't2.country_name, city_name, 't2.city_name) AS a#0] + * -- -- +- 'Join LeftOuter, (((ip_to_int('ip) >= 't2.start) AND (ip_to_int('ip) < 't2.end)) AND (is_ipv4('ip) = 't2.ipv4)) + * -- -- -- :- 'SubqueryAlias t1 + * -- -- -- -- : +- 'UnresolvedRelation [users], [], false + * -- -- -- +- 'SubqueryAlias t2 + * -- -- -- -- -- +- 'UnresolvedRelation [geoip], [], false + * . 
+ * And the corresponded SQL query: + * . + * SELECT users.*, struct(geoip.country_name, geoip.city_name) AS a + * FROM users, geoip + * WHERE geoip.ip_range_start <= ip_to_int(users.ip) + * AND geoip.ip_range_end > ip_to_int(users.ip) + * AND geoip.ip_type = is_ipv4(users.ip); + * + * @param parameters GeoIp function parameters. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will project new col with geoip location based on given ipAddresses. + */ + static LogicalPlan getGeoipLogicalPlan(GeoIpParameters parameters, CatalystPlanContext context) { + applyJoin(parameters.getIpAddress(), context); + return applyProjection(parameters.getField(), parameters.getProperties(), context); + } + + /** + * Responsible to produce join plan for GeoIp command, below is the sample logical plan + * with configuration [source=users, ipAddress=ip] + * +- 'Join LeftOuter, (((ip_to_int('ip) >= 't2.start) AND (ip_to_int('ip) < 't2.end)) AND (is_ipv4('ip) = 't2.ipv4)) + * -- :- 'SubqueryAlias t1 + * -- -- : +- 'UnresolvedRelation [users], [], false + * -- +- 'SubqueryAlias t2 + * -- -- -- +- 'UnresolvedRelation [geoip], [], false + * + * @param ipAddress Expression representing ip addresses to be queried. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will perform join based on ip within cidr range in geoip table. 
+ */ + static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext context) { + return context.apply(left -> { + LogicalPlan right = new UnresolvedRelation(seq(getGeoipTableName()), CaseInsensitiveStringMap.empty(), false); + LogicalPlan leftAlias = SubqueryAlias$.MODULE$.apply(SOURCE_TABLE_ALIAS, left); + LogicalPlan rightAlias = SubqueryAlias$.MODULE$.apply(GEOIP_TABLE_ALIAS, right); + Optional joinCondition = Optional.of(new And( + new And( + new GreaterThanOrEqual( + getIpInt(ipAddress), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_start")) + ), + new LessThan( + getIpInt(ipAddress), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_end")) + ) + ), + new EqualTo( + getIsIpv4(ipAddress), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ipv4")) + ) + )); + context.retainAllNamedParseExpressions(p -> p); + context.retainAllPlans(p -> p); + return join(leftAlias, + rightAlias, + Join.JoinType.LEFT, + joinCondition, + new Join.JoinHint()); + }); + } + + /** + * Responsible to produce a Spark Logical Plan with given GeoIp command arguments, below is the sample logical plan + * with configuration [source=users, field=a, properties=[country_name, city_name]] + * +- 'DataFrameDropColumns ['t2.country_iso_code, 't2.country_name, 't2.continent_name, 't2.region_iso_code, 't2.region_name, 't2.city_name, 't2.time_zone, 't2.location, 't2.cidr, 't2.start, 't2.end, 't2.ipv4] + * -- +- 'Project [*, named_struct(country_name, 't2.country_name, city_name, 't2.city_name) AS a#0] + * + * @param field Name of new eval geoip column. + * @param properties List of geo properties to be returned. + * @param context Context instance to retrieved Expression in resolved form. + * @return a LogicalPlan which will return source table and new eval geoip column. 
+ */ + static private LogicalPlan applyProjection(String field, List properties, CatalystPlanContext context) { + List projectExpressions = new ArrayList<>(); + projectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); + + List geoIpStructFields = createGeoIpStructFields(properties); + Expression columnValue = (geoIpStructFields.size() == 1)? + geoIpStructFields.get(0) : CreateStruct.apply(seq(geoIpStructFields)); + + NamedExpression geoCol = Alias$.MODULE$.apply( + columnValue, + field, + NamedExpression.newExprId(), + seq(new ArrayList<>()), + Option.empty(), + seq(new ArrayList<>())); + + projectExpressions.add(geoCol); + + List dropList = createGeoIpStructFields(new ArrayList<>()); + dropList.addAll(List.of( + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"cidr")), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_start")), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_end")), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ipv4")) + )); + + context.apply(p -> new Project(seq(projectExpressions), p)); + return context.apply(p -> new DataFrameDropColumns(seq(dropList), p)); + } + + static private List createGeoIpStructFields(List attributeList) { + List attributeListToUse; + if (attributeList == null || attributeList.isEmpty()) { + attributeListToUse = GEOIP_TABLE_COLUMNS; + } else { + attributeListToUse = attributeList; + } + + return attributeListToUse.stream() + .map(a -> UnresolvedAttribute$.MODULE$.apply(seq( + GEOIP_TABLE_ALIAS, + a.toLowerCase(Locale.ROOT) + ))) + .collect(Collectors.toList()); + } + + static private Expression getIpInt(Expression ipAddress) { + return new ScalaUDF(SerializableUdf.geoIpUtils.ipToInt, + DataTypes.createDecimalType(38,0), + seq(ipAddress), + seq(), + Option.empty(), + Option.apply("ip_to_int"), + false, + true + ); + } + + static private Expression getIsIpv4(Expression ipAddress) { + return new ScalaUDF(SerializableUdf.geoIpUtils.isIpv4, + 
DataTypes.BooleanType, + seq(ipAddress), + seq(), Option.empty(), + Option.apply("is_ipv4"), + false, + true + ); + } + + static private String getGeoipTableName() { + String tableName = DEFAULT_GEOIP_TABLE_NAME; + + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + tableName = SparkEnv.get().conf().get(SPARK_CONF_KEY, DEFAULT_GEOIP_TABLE_NAME); + } + + return tableName; + } + + @Getter + @AllArgsConstructor + class GeoIpParameters { + private final String field; + private final Expression ipAddress; + private final List properties; + } + + enum GeoIpProperty { + COUNTRY_ISO_CODE, + COUNTRY_NAME, + CONTINENT_NAME, + REGION_ISO_CODE, + REGION_NAME, + CITY_NAME, + TIME_ZONE, + LOCATION + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala new file mode 100644 index 000000000..2fd961312 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala @@ -0,0 +1,287 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, Descending, EqualTo, ExprId, GreaterThanOrEqual, In, LessThan, Literal, NamedExpression, ScalaUDF, SortOrder} +import 
org.apache.spark.sql.catalyst.plans.{LeftOuter, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.types.DataTypes + +class PPLLogicalPlanGeoipFunctionTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + private def getGeoIpQueryPlan( + ipAddress: UnresolvedAttribute, + left : LogicalPlan, + right : LogicalPlan, + projectionProperties : Alias + ) : LogicalPlan = { + val joinPlan = getJoinPlan(ipAddress, left, right) + getProjection(joinPlan, projectionProperties) + } + + private def getJoinPlan( + ipAddress: UnresolvedAttribute, + left : LogicalPlan, + right : LogicalPlan + ) : LogicalPlan = { + val is_ipv4 = ScalaUDF( + SerializableUdf.geoIpUtils.isIpv4, + DataTypes.BooleanType, + seq(ipAddress), + seq(), + Option.empty, + Option.apply("is_ipv4"), + false, + true + ) + val ip_to_int = ScalaUDF( + SerializableUdf.geoIpUtils.ipToInt, + DataTypes.createDecimalType(38, 0), + seq(ipAddress), + seq(), + Option.empty, + Option.apply("ip_to_int"), + false, + true + ) + + val t1 = SubqueryAlias("t1", left) + val t2 = SubqueryAlias("t2", right) + + val joinCondition = And( + And( + GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")), + LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end")) + ), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4")) + ) + Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) + } + + private def getProjection(joinPlan : LogicalPlan, projectionProperties : Alias) : LogicalPlan = { + val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) + val dropList = Seq( + "t2.country_iso_code", "t2.country_name", "t2.continent_name", + "t2.region_iso_code", "t2.region_name", "t2.city_name", + "t2.time_zone", 
"t2.location", "t2.cidr", "t2.ip_range_start", "t2.ip_range_end", "t2.ipv4" + ).map(UnresolvedAttribute(_)) + DataFrameDropColumns(dropList, projection) + } + + test("test geoip function - only ip_address provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source = users | eval a = geoip(ip_address)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val projectionStruct = CreateNamedStruct(Seq( + Literal("country_iso_code"), UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), UnresolvedAttribute("t2.region_name"), + Literal("city_name"), UnresolvedAttribute("t2.city_name"), + Literal("time_zone"), UnresolvedAttribute("t2.time_zone"), + Literal("location"), UnresolvedAttribute("t2.location") + )) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - source has same name as join alias") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = 
Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + + test("test geoip function - ipAddress col exist in geoip table") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("cidr") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - duplicate parameters") { + val context = new CatalystPlanContext + + val exception = intercept[IllegalStateException]{ + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"), + context) + } + + assert(exception.getMessage.contains("Duplicate attribute in GEOIP attribute list")) + } + + test("test geoip function - one property provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple properties provided") { + val context = new CatalystPlanContext + + val logPlan = + 
planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address,country_name,location)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct(Seq( + Literal("country_name"), UnresolvedAttribute("t2.country_name"), + Literal("location"), UnresolvedAttribute("t2.location") + )) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple geoip calls") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.country_iso_code"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.region_iso_code"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionB) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used between geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = 
UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.time_zone"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionC = Alias(UnresolvedAttribute("t2.region_name"), "c")() + val colCPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionC) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "b")()) + val colBPlan = Project(randProjectList, colCPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used before geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = rand(), b = geoip(ip_address, city_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionB = Alias(UnresolvedAttribute("t2.city_name"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionB) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")()) + val colAPlan = Project(randProjectList, colBPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colAPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} From daf11fa84a403056a3dbdb7b024587dec1f5229b Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Wed, 18 Dec 2024 00:17:07 +0000 Subject: [PATCH 2/9] Fixed integration tests Signed-off-by: Kenrick Yap --- .../org/opensearch/flint/spark/FlintSparkSuite.scala | 8 ++++---- .../flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala | 5 +++-- 2 files 
changed, 7 insertions(+), 6 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 36d0c6a7d..b9b9e6f1c 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -806,8 +806,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | city_name STRING, | time_zone STRING, | location STRING, - | ip_range_start BIGINT, - | ip_range_end BIGINT, + | ip_range_start DECIMAL(38,0), + | ip_range_end DECIMAL(38,0), | ipv4 BOOLEAN | ) | USING $tableType $tableOptions @@ -830,9 +830,9 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | 1123655167, | true | ), - | VALUES ( + | ( | '2a09:bac2:19f8::/45', - | `'CA', + | 'CA', | 'Canada', | 'North America', | 'PE', diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala index 3e2728553..9eac24223 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -41,10 +41,11 @@ class FlintSparkPPLGeoipITSuite // Retrieve the results val results: Array[Row] = frame.collect() + // Define the expected results val expectedResults: Array[Row] = Array( - Row("66.249.157.90", Row("JM", "Jamaica", "North America", 14, "Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686, -76.8827")), - Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396, -63.1355")) + Row("66.249.157.90", Row("JM", "Jamaica", "North America", "14", 
"Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686,-76.8827")), + Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396,-63.1355")) ) // Compare the results From 543740d3ad5d39de529bc68b9ce7be4341596819 Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Wed, 18 Dec 2024 18:30:53 +0000 Subject: [PATCH 3/9] linting Signed-off-by: Kenrick Yap --- .../flint/spark/FlintSparkSuite.scala | 12 +-- .../spark/ppl/FlintSparkPPLGeoipITSuite.scala | 45 ++++++--- ...PlanGeoipFunctionTranslatorTestSuite.scala | 97 +++++++++++-------- 3 files changed, 89 insertions(+), 65 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index b9b9e6f1c..071c6ba1b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -772,8 +772,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } protected def createGeoIpTestTable(testTable: String): Unit = { - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | ip STRING, @@ -782,8 +781,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | USING $tableType $tableOptions |""".stripMargin) - sql( - s""" + sql(s""" | INSERT INTO $testTable | VALUES ('66.249.157.90', true), | ('2a09:bac2:19f8:2ac3::', true), @@ -793,8 +791,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } protected def createGeoIpTable(): Unit = { - sql( - s""" + sql(s""" | CREATE TABLE geoip | ( | cidr STRING, @@ -813,8 +810,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | USING $tableType $tableOptions |""".stripMargin) - sql( - s""" + sql(s""" | INSERT INTO geoip | VALUES ( | 
'66.249.157.0/24', diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala index 9eac24223..5dd4e77b7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -9,7 +9,7 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLGeoipITSuite - extends QueryTest + extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite with StreamTest { @@ -34,8 +34,7 @@ class FlintSparkPPLGeoipITSuite } test("test geoip with no parameters") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| where isValid = true | eval a = geoip(ip) | fields ip, a | """.stripMargin) @@ -44,9 +43,28 @@ class FlintSparkPPLGeoipITSuite // Define the expected results val expectedResults: Array[Row] = Array( - Row("66.249.157.90", Row("JM", "Jamaica", "North America", "14", "Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686,-76.8827")), - Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396,-63.1355")) - ) + Row( + "66.249.157.90", + Row( + "JM", + "Jamaica", + "North America", + "14", + "Saint Catherine Parish", + "Portmore", + "America/Jamaica", + "17.9686,-76.8827")), + Row( + "2a09:bac2:19f8:2ac3::", + Row( + "CA", + "Canada", + "North America", + "PE", + "Prince Edward Island", + "Charlottetown", + "America/Halifax", + "46.2396,-63.1355"))) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -54,18 +72,15 @@ class FlintSparkPPLGeoipITSuite } test("test geoip with one parameters") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| 
where isValid = true | eval a = geoip(ip, country_name) | fields ip, a | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("66.249.157.90", "Jamaica"), - Row("2a09:bac2:19f8:2ac3::", "Canada") - ) + val expectedResults: Array[Row] = + Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -73,8 +88,7 @@ class FlintSparkPPLGeoipITSuite } test("test geoip with multiple parameters") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a | """.stripMargin) @@ -83,8 +97,7 @@ class FlintSparkPPLGeoipITSuite // Define the expected results val expectedResults: Array[Row] = Array( Row("66.249.157.90", Row("Jamaica", "Portmore")), - Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown")) - ) + Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown"))) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala index 2fd961312..a20bc4ecd 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala @@ -19,7 +19,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join, import org.apache.spark.sql.types.DataTypes class PPLLogicalPlanGeoipFunctionTranslatorTestSuite - extends SparkFunSuite + extends 
SparkFunSuite with PlanTest with LogicalPlanTestUtils with Matchers { @@ -28,20 +28,18 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite private val pplParser = new PPLSyntaxParser() private def getGeoIpQueryPlan( - ipAddress: UnresolvedAttribute, - left : LogicalPlan, - right : LogicalPlan, - projectionProperties : Alias - ) : LogicalPlan = { + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan, + projectionProperties: Alias): LogicalPlan = { val joinPlan = getJoinPlan(ipAddress, left, right) getProjection(joinPlan, projectionProperties) } private def getJoinPlan( - ipAddress: UnresolvedAttribute, - left : LogicalPlan, - right : LogicalPlan - ) : LogicalPlan = { + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan): LogicalPlan = { val is_ipv4 = ScalaUDF( SerializableUdf.geoIpUtils.isIpv4, DataTypes.BooleanType, @@ -50,8 +48,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite Option.empty, Option.apply("is_ipv4"), false, - true - ) + true) val ip_to_int = ScalaUDF( SerializableUdf.geoIpUtils.ipToInt, DataTypes.createDecimalType(38, 0), @@ -60,8 +57,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite Option.empty, Option.apply("ip_to_int"), false, - true - ) + true) val t1 = SubqueryAlias("t1", left) val t2 = SubqueryAlias("t2", right) @@ -69,20 +65,26 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite val joinCondition = And( And( GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")), - LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end")) - ), - EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4")) - ) + LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))) Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) } - private def getProjection(joinPlan : LogicalPlan, projectionProperties : Alias) : LogicalPlan = { + private def getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = { val 
projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) val dropList = Seq( - "t2.country_iso_code", "t2.country_name", "t2.continent_name", - "t2.region_iso_code", "t2.region_name", "t2.city_name", - "t2.time_zone", "t2.location", "t2.cidr", "t2.ip_range_start", "t2.ip_range_end", "t2.ipv4" - ).map(UnresolvedAttribute(_)) + "t2.country_iso_code", + "t2.country_name", + "t2.continent_name", + "t2.region_iso_code", + "t2.region_name", + "t2.city_name", + "t2.time_zone", + "t2.location", + "t2.cidr", + "t2.ip_range_start", + "t2.ip_range_end", + "t2.ipv4").map(UnresolvedAttribute(_)) DataFrameDropColumns(dropList, projection) } @@ -98,16 +100,24 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite val sourceTable = UnresolvedRelation(seq("users")) val geoTable = UnresolvedRelation(seq("geoip")) - val projectionStruct = CreateNamedStruct(Seq( - Literal("country_iso_code"), UnresolvedAttribute("t2.country_iso_code"), - Literal("country_name"), UnresolvedAttribute("t2.country_name"), - Literal("continent_name"), UnresolvedAttribute("t2.continent_name"), - Literal("region_iso_code"), UnresolvedAttribute("t2.region_iso_code"), - Literal("region_name"), UnresolvedAttribute("t2.region_name"), - Literal("city_name"), UnresolvedAttribute("t2.city_name"), - Literal("time_zone"), UnresolvedAttribute("t2.time_zone"), - Literal("location"), UnresolvedAttribute("t2.location") - )) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_iso_code"), + UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), + UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), + UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), + UnresolvedAttribute("t2.region_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"), + Literal("time_zone"), + UnresolvedAttribute("t2.time_zone"), + Literal("location"), + 
UnresolvedAttribute("t2.location"))) val structProjection = Alias(projectionStruct, "a")() val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) @@ -135,7 +145,6 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - test("test geoip function - ipAddress col exist in geoip table") { val context = new CatalystPlanContext @@ -158,7 +167,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite test("test geoip function - duplicate parameters") { val context = new CatalystPlanContext - val exception = intercept[IllegalStateException]{ + val exception = intercept[IllegalStateException] { planTransformer.visit( plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"), context) @@ -197,10 +206,12 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite val ipAddress = UnresolvedAttribute("ip_address") val sourceTable = UnresolvedRelation(seq("users")) val geoTable = UnresolvedRelation(seq("geoip")) - val projectionStruct = CreateNamedStruct(Seq( - Literal("country_name"), UnresolvedAttribute("t2.country_name"), - Literal("location"), UnresolvedAttribute("t2.location") - )) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("location"), + UnresolvedAttribute("t2.location"))) val structProjection = Alias(projectionStruct, "a")() val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) @@ -214,7 +225,9 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite val logPlan = planTransformer.visit( - plan(pplParser, "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"), + plan( + pplParser, + "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"), context) val ipAddress = UnresolvedAttribute("ip_address") @@ -237,7 +250,9 @@ class 
PPLLogicalPlanGeoipFunctionTranslatorTestSuite val logPlan = planTransformer.visit( - plan(pplParser, "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"), + plan( + pplParser, + "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"), context) val ipAddress = UnresolvedAttribute("ip_address") From 91ac8e5e96fa5bcd43eb40e7f05bc92325a058bd Mon Sep 17 00:00:00 2001 From: Kenrick Yap <14yapkc1@gmail.com> Date: Thu, 19 Dec 2024 00:35:46 -0800 Subject: [PATCH 4/9] addressing PR comments (added addtional integ tests, doc changes) Signed-off-by: Kenrick Yap <14yapkc1@gmail.com> --- docs/ppl-lang/functions/ppl-ip.md | 2 +- docs/ppl-lang/planning/ppl-geoip-command.md | 22 +- .../flint/spark/FlintSparkSuite.scala | 9 +- .../spark/ppl/FlintSparkPPLGeoipITSuite.scala | 218 +++++++++++++++++- .../expression/function/SerializableUdf.java | 36 ++- .../GeoIpCatalystLogicalPlanTranslator.java | 59 ++--- ...PlanGeoipFunctionTranslatorTestSuite.scala | 70 ++++-- 7 files changed, 336 insertions(+), 80 deletions(-) diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md index 1a11f8b45..65cc9dac9 100644 --- a/docs/ppl-lang/functions/ppl-ip.md +++ b/docs/ppl-lang/functions/ppl-ip.md @@ -41,7 +41,7 @@ Note: `GEOIP(ip[, property]...)` retrieves geospatial data corresponding to the provided `ip`. **Argument type:** -- `ip` is string be **STRING**. +- `ip` is string be **STRING** representing an IPv4 or an IPv6 address. 
- `property` is **STRING** and must be one of the following: - `COUNTRY_ISO_CODE` - `COUNTRY_NAME` diff --git a/docs/ppl-lang/planning/ppl-geoip-command.md b/docs/ppl-lang/planning/ppl-geoip-command.md index 142516ca3..59ca2df08 100644 --- a/docs/ppl-lang/planning/ppl-geoip-command.md +++ b/docs/ppl-lang/planning/ppl-geoip-command.md @@ -3,16 +3,12 @@ geoip function to add information about the geographical location of an IPv4 or IPv6 address **Implementation syntax** -- `... | eval geoinfo = geoip([datasource,] ipAddress *[,properties])` +- `... | eval geoinfo = geoip(ipAddress *[,properties])` - generic syntax - `... | eval geoinfo = geoip(ipAddress)` -- use the default geoip datasource -- `... | eval geoinfo = geoip("abc", ipAddress)` -- use the "abc" geoip datasource +- retrieves all geo data - `... | eval geoinfo = geoip(ipAddress, city, location)` -- use the default geoip datasource, retrieve only city, and location -- `... | eval geoinfo = geoip("abc", ipAddress, city, location")` -- use the "abc" geoip datasource, retrieve only city, and location +- retrieve only city, and location **Implementation details** - Current implementation requires user to have created a geoip table. Geoip table has the following schema: @@ -44,13 +40,23 @@ geoip function to add information about the geographical location of an IPv4 or AND geoip.ip_range_end > ip_to_int(source.ip) AND geoip.ip_type = is_ipv4(source.ip); ``` +- in the case that only one property is provided in function call, `geoip` returns string of specified property instead: + + ```SQL + SELECT source.*, geoip.country_name AS a + FROM source, geoip + WHERE geoip.ip_range_start <= ip_to_int(source.ip) + AND geoip.ip_range_end > ip_to_int(source.ip) + AND geoip.ip_type = is_ipv4(source.ip); + ``` **Future plan for additional data-sources** - Currently only using pre-existing geoip table defined within spark is possible. 
- There is future plans to allow users to specify data sources: - API data sources - if users have their own geoip provided will create ability for users to configure and call said endpoints - - OpenSearch geospatial client - once geospatial client is published we can leverage client to utilize opensearch geo2ip functionality. + - OpenSearch geospatial client - once geospatial client is published we can leverage client to utilize OpenSearch geo2ip functionality. +- Additional datasource connection params will be provided through spark config options. ### New syntax definition in ANTLR diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 071c6ba1b..5ea123c9d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -776,6 +776,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | CREATE TABLE $testTable | ( | ip STRING, + | ipv4 STRING, | isValid BOOLEAN | ) | USING $tableType $tableOptions @@ -783,10 +784,10 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit sql(s""" | INSERT INTO $testTable - | VALUES ('66.249.157.90', true), - | ('2a09:bac2:19f8:2ac3::', true), - | ('192.168.2.', false), - | ('2001:db8::ff00:12:', false) + | VALUES ('66.249.157.90', '66.249.157.90', true), + | ('2a09:bac2:19f8:2ac3::', 'Given IPv6 is not mapped to IPv4', true), + | ('192.168.2.', '192.168.2.', false), + | ('2001:db8::ff00:12:', 'Given IPv6 is not mapped to IPv4', false) | """.stripMargin) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala index 5dd4e77b7..9c51f69ee 100644 --- 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -5,7 +5,17 @@ package org.opensearch.flint.spark.ppl +import java.util + +import org.opensearch.sql.expression.function.SerializableUdf.visit +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, EqualTo, Expression, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Filter, Join, JoinHint, LogicalPlan, Project, SubqueryAlias} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLGeoipITSuite @@ -33,9 +43,54 @@ class FlintSparkPPLGeoipITSuite } } + private def getGeoIpQueryPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan, + projectionProperties: Alias): LogicalPlan = { + val joinPlan = getJoinPlan(ipAddress, left, right) + getProjection(joinPlan, projectionProperties) + } + + private def getJoinPlan( + ipAddress: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan): LogicalPlan = { + val is_ipv4 = visit("is_ipv4", util.List.of[Expression](ipAddress)) + val ip_to_int = visit("ip_to_int", util.List.of[Expression](ipAddress)) + + val t1 = SubqueryAlias("t1", left) + val t2 = SubqueryAlias("t2", right) + + val joinCondition = And( + And( + GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")), + LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))) + Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) + } + + private def 
getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = { + val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) + val dropList = Seq( + "t2.country_iso_code", + "t2.country_name", + "t2.continent_name", + "t2.region_iso_code", + "t2.region_name", + "t2.city_name", + "t2.time_zone", + "t2.location", + "t2.cidr", + "t2.ip_range_start", + "t2.ip_range_end", + "t2.ipv4").map(UnresolvedAttribute(_)) + DataFrameDropColumns(dropList, projection) + } + test("test geoip with no parameters") { val frame = sql(s""" - | source = $testTable| where isValid = true | eval a = geoip(ip) | fields ip, a + | source = $testTable | where isValid = true | eval a = geoip(ip) | fields ip, a | """.stripMargin) // Retrieve the results @@ -69,11 +124,44 @@ class FlintSparkPPLGeoipITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(seq(testTable))) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_iso_code"), + UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), + UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), + UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), + UnresolvedAttribute("t2.region_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"), + Literal("time_zone"), + UnresolvedAttribute("t2.time_zone"), + Literal("location"), + UnresolvedAttribute("t2.location"))) + val structProjection = Alias(projectionStruct, "a")() + val geoIpPlan = + 
getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("test geoip with one parameters") { val frame = sql(s""" - | source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ip, a + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name) | fields ip, a | """.stripMargin) // Retrieve the results @@ -85,11 +173,26 @@ class FlintSparkPPLGeoipITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(seq(testTable))) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("test geoip with multiple parameters") { val frame = sql(s""" - | source = $testTable| where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a | """.stripMargin) // Retrieve the results @@ -102,5 +205,114 @@ class FlintSparkPPLGeoipITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) 
assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(seq(testTable))) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct( + Seq( + Literal("country_name"), + UnresolvedAttribute("t2.country_name"), + Literal("city_name"), + UnresolvedAttribute("t2.city_name"))) + val structProjection = Alias(projectionStruct, "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with partial projection on evaluated fields") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, city_name), b = geoip(ip, country_name) | fields ip, b + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("66.249.157.90", "Portmore"), Row("2a09:bac2:19f8:2ac3::", "Charlottetown")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(seq(testTable))) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.city_name"), "a")() + val geoIpPlanA = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, 
structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.country_name"), "b")() + val geoIpPlanB = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), geoIpPlanA, geoTable, structProjectionB) + + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ip"), UnresolvedAttribute("b")), geoIpPlanB) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with projection on field that exists in both source and geoip table") { + val frame = sql(s""" + | source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ipv4, a + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array( + Row("66.249.157.90", "Portmore"), + Row("Given IPv6 is not mapped to IPv4", "Charlottetown")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Compare the logical plans + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sourceTable: LogicalPlan = Filter( + EqualTo(UnresolvedAttribute("isValid"), Literal(true)), + UnresolvedRelation(seq(testTable))) + val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + val geoIpPlan = + getGeoIpQueryPlan(UnresolvedAttribute("ip"), sourceTable, geoTable, structProjection) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("ipv4"), UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test geoip with invalid parameter") { + val frame = sql(s""" + | source = $testTable | where isValid = true | eval a = geoip(ip, invalid_param) | fields ip, a + | """.stripMargin) + + // Retrieve the results + assertThrows[SparkException](frame.collect()) + } + + 
test("test geoip with invalid ip address provided") { + val frame = sql(s""" + | source = $testTable | eval a = geoip(ip) | fields ip, a + | """.stripMargin) + + // Retrieve the results + assertThrows[SparkException](frame.collect()) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java index 175ada0e2..596a6a295 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -143,10 +143,17 @@ public Boolean apply(String ipAddress, String cidrBlock) { } return parsedCidrBlock.contains(parsedIpAddress); - - }}; + } + }; class geoIpUtils { + /** + * Append values to JSON arrays based on specified path-values. + * + * @param jsonStr The input JSON string. + * @param elements A list of path-values where the first item is the path and subsequent items are values to append. + * @return The updated JSON string. + */ public static Function1 isIpv4 = new SerializableAbstractFunction1<>() { IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() @@ -170,6 +177,13 @@ public Boolean apply(String ipAddress) { } }; + /** + * Append values to JSON arrays based on specified path-values. + * + * @param jsonStr The input JSON string. + * @param elements A list of path-values where the first item is the path and subsequent items are values to append. + * @return The updated JSON string. 
+ */ public static Function1 ipToInt = new SerializableAbstractFunction1<>() { @Override public BigInteger apply(String ipAddress) { @@ -224,6 +238,24 @@ static ScalaUDF visit(String funcName, List expressions) { Option.apply("json_append"), false, true); + case "is_ipv4": + return new ScalaUDF(geoIpUtils.isIpv4, + DataTypes.BooleanType, + seq(expressions), + seq(), + Option.empty(), + Option.apply("is_ipv4"), + false, + true); + case "ip_to_int": + return new ScalaUDF(geoIpUtils.ipToInt, + DataTypes.createDecimalType(38,0), + seq(expressions), + seq(), + Option.empty(), + Option.apply("ip_to_int"), + false, + true); default: return null; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java index 2192c3485..153c46dbc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java @@ -19,12 +19,10 @@ import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.LessThan; import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Project; import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$; -import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.expression.function.SerializableUdf; @@ -32,29 +30,29 @@ import scala.Option; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; 
import java.util.Locale; import java.util.Optional; import java.util.stream.Collectors; +import static java.util.List.of; + import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; public interface GeoIpCatalystLogicalPlanTranslator { String SPARK_CONF_KEY = "spark.geoip.tablename"; String DEFAULT_GEOIP_TABLE_NAME = "geoip"; + String GEOIP_CIDR_COLUMN_NAME = "cidr"; + String GEOIP_IP_RANGE_START_COLUMN_NAME = "ip_range_start"; + String GEOIP_IP_RANGE_END_COLUMN_NAME = "ip_range_end"; + String GEOIP_IPV4_COLUMN_NAME = "ipv4"; String SOURCE_TABLE_ALIAS = "t1"; String GEOIP_TABLE_ALIAS = "t2"; - List GEOIP_TABLE_COLUMNS = List.of( - "country_iso_code", - "country_name", - "continent_name", - "region_iso_code", - "region_name", - "city_name", - "time_zone", - "location" - ); + List GEOIP_TABLE_COLUMNS = Arrays.stream(GeoIpProperty.values()) + .map(Enum::name) + .collect(Collectors.toList()); /** * Responsible to produce a Spark Logical Plan with given GeoIp command arguments, below is the sample logical plan @@ -105,16 +103,16 @@ static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext c Optional joinCondition = Optional.of(new And( new And( new GreaterThanOrEqual( - getIpInt(ipAddress), + SerializableUdf.visit("ip_to_int", of(ipAddress)), UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_start")) ), new LessThan( - getIpInt(ipAddress), + SerializableUdf.visit("ip_to_int", of(ipAddress)), UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_end")) ) ), new EqualTo( - getIsIpv4(ipAddress), + SerializableUdf.visit("is_ipv4", of(ipAddress)), UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ipv4")) ) )); @@ -159,10 +157,10 @@ static private LogicalPlan applyProjection(String field, List properties List dropList = createGeoIpStructFields(new ArrayList<>()); dropList.addAll(List.of( - 
UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"cidr")), - UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_start")), - UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_end")), - UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ipv4")) + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_CIDR_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IP_RANGE_START_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IP_RANGE_END_COLUMN_NAME)), + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS, GEOIP_IPV4_COLUMN_NAME)) )); context.apply(p -> new Project(seq(projectExpressions), p)); @@ -185,29 +183,6 @@ static private List createGeoIpStructFields(List attributeLi .collect(Collectors.toList()); } - static private Expression getIpInt(Expression ipAddress) { - return new ScalaUDF(SerializableUdf.geoIpUtils.ipToInt, - DataTypes.createDecimalType(38,0), - seq(ipAddress), - seq(), - Option.empty(), - Option.apply("ip_to_int"), - false, - true - ); - } - - static private Expression getIsIpv4(Expression ipAddress) { - return new ScalaUDF(SerializableUdf.geoIpUtils.isIpv4, - DataTypes.BooleanType, - seq(ipAddress), - seq(), Option.empty(), - Option.apply("is_ipv4"), - false, - true - ); - } - static private String getGeoipTableName() { String tableName = DEFAULT_GEOIP_TABLE_NAME; diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala index a20bc4ecd..460b9769c 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala @@ -5,15 +5,17 @@ package 
org.opensearch.flint.spark.ppl +import java.util + import org.opensearch.flint.spark.ppl.PlaneUtils.plan -import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.expression.function.SerializableUdf.visit import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, Descending, EqualTo, ExprId, GreaterThanOrEqual, In, LessThan, Literal, NamedExpression, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, Descending, EqualTo, Expression, ExprId, GreaterThanOrEqual, In, LessThan, Literal, NamedExpression, ScalaUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.{LeftOuter, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} import org.apache.spark.sql.types.DataTypes @@ -40,24 +42,8 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite ipAddress: UnresolvedAttribute, left: LogicalPlan, right: LogicalPlan): LogicalPlan = { - val is_ipv4 = ScalaUDF( - SerializableUdf.geoIpUtils.isIpv4, - DataTypes.BooleanType, - seq(ipAddress), - seq(), - Option.empty, - Option.apply("is_ipv4"), - false, - true) - val ip_to_int = ScalaUDF( - SerializableUdf.geoIpUtils.ipToInt, - DataTypes.createDecimalType(38, 0), - seq(ipAddress), - seq(), - Option.empty, - Option.apply("ip_to_int"), - false, - true) + val is_ipv4 = visit("is_ipv4", util.List.of[Expression](ipAddress)) + val ip_to_int = visit("ip_to_int", util.List.of[Expression](ipAddress)) val t1 = SubqueryAlias("t1", left) val t2 = SubqueryAlias("t2", right) @@ -299,4 +285,48 @@ class 
PPLLogicalPlanGeoipFunctionTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + + test("test geoip function - projection on evaluated field") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address, country_name) | fields a"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedAttribute("a")), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip with partial projection on evaluated fields") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code) | fields b"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.country_iso_code"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.region_iso_code"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionB) + + val expectedPlan = Project(Seq(UnresolvedAttribute("b")), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } From 89cdb4bf9498de94f4106ff482f77f420f808410 Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Thu, 19 Dec 2024 09:49:01 +0000 Subject: [PATCH 5/9] fixed new integ tests Signed-off-by: Kenrick Yap --- .../spark/ppl/FlintSparkPPLGeoipITSuite.scala 
| 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala index 9c51f69ee..7031ab067 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLGeoipITSuite.scala @@ -14,6 +14,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, EqualTo, Expression, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Filter, Join, JoinHint, LogicalPlan, Project, SubqueryAlias} import org.apache.spark.sql.streaming.StreamTest @@ -130,7 +131,7 @@ class FlintSparkPPLGeoipITSuite val sourceTable: LogicalPlan = Filter( EqualTo(UnresolvedAttribute("isValid"), Literal(true)), - UnresolvedRelation(seq(testTable))) + UnresolvedRelation(testTable.split("\\.").toSeq)) val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) val projectionStruct = CreateNamedStruct( Seq( @@ -179,7 +180,7 @@ class FlintSparkPPLGeoipITSuite val sourceTable: LogicalPlan = Filter( EqualTo(UnresolvedAttribute("isValid"), Literal(true)), - UnresolvedRelation(seq(testTable))) + UnresolvedRelation(testTable.split("\\.").toSeq)) val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() val geoIpPlan = @@ -211,7 +212,7 @@ class FlintSparkPPLGeoipITSuite val sourceTable: LogicalPlan = Filter( 
EqualTo(UnresolvedAttribute("isValid"), Literal(true)), - UnresolvedRelation(seq(testTable))) + UnresolvedRelation(testTable.split("\\.").toSeq)) val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) val projectionStruct = CreateNamedStruct( Seq( @@ -237,7 +238,7 @@ class FlintSparkPPLGeoipITSuite val results: Array[Row] = frame.collect() // Define the expected results val expectedResults: Array[Row] = - Array(Row("66.249.157.90", "Portmore"), Row("2a09:bac2:19f8:2ac3::", "Charlottetown")) + Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -248,7 +249,7 @@ class FlintSparkPPLGeoipITSuite val sourceTable: LogicalPlan = Filter( EqualTo(UnresolvedAttribute("isValid"), Literal(true)), - UnresolvedRelation(seq(testTable))) + UnresolvedRelation(testTable.split("\\.").toSeq)) val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) val structProjectionA = Alias(UnresolvedAttribute("t2.city_name"), "a")() @@ -267,16 +268,14 @@ class FlintSparkPPLGeoipITSuite test("test geoip with projection on field that exists in both source and geoip table") { val frame = sql(s""" - | source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ipv4, a + | source = $testTable | where isValid = true | eval a = geoip(ip, country_name) | fields ipv4, a | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results val expectedResults: Array[Row] = - Array( - Row("66.249.157.90", "Portmore"), - Row("Given IPv6 is not mapped to IPv4", "Charlottetown")) + Array(Row("66.249.157.90", "Jamaica"), Row("Given IPv6 is not mapped to IPv4", "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -287,7 +286,7 @@ class FlintSparkPPLGeoipITSuite val sourceTable: LogicalPlan = Filter( 
EqualTo(UnresolvedAttribute("isValid"), Literal(true)), - UnresolvedRelation(seq(testTable))) + UnresolvedRelation(testTable.split("\\.").toSeq)) val geoTable: LogicalPlan = UnresolvedRelation(seq("geoip")) val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() val geoIpPlan = @@ -299,12 +298,9 @@ class FlintSparkPPLGeoipITSuite } test("test geoip with invalid parameter") { - val frame = sql(s""" + assertThrows[ParseException](sql(s""" | source = $testTable | where isValid = true | eval a = geoip(ip, invalid_param) | fields ip, a - | """.stripMargin) - - // Retrieve the results - assertThrows[SparkException](frame.collect()) + | """.stripMargin)) } test("test geoip with invalid ip address provided") { From 26ab5b0fe4d6f75ce79907eb820fca9edea1eefc Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Thu, 19 Dec 2024 10:16:15 +0000 Subject: [PATCH 6/9] addressing pr comments Signed-off-by: Kenrick Yap --- .../sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java index 153c46dbc..a57b26799 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java @@ -104,16 +104,16 @@ static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext c new And( new GreaterThanOrEqual( SerializableUdf.visit("ip_to_int", of(ipAddress)), - UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_start")) + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IP_RANGE_START_COLUMN_NAME)) ), new LessThan( SerializableUdf.visit("ip_to_int", of(ipAddress)), - 
UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ip_range_end")) + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IP_RANGE_END_COLUMN_NAME)) ) ), new EqualTo( SerializableUdf.visit("is_ipv4", of(ipAddress)), - UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,"ipv4")) + UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IPV4_COLUMN_NAME)) ) )); context.retainAllNamedParseExpressions(p -> p); From e037672b3d995ae68f91d48fc31321c7e926502c Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Thu, 19 Dec 2024 10:08:02 -0800 Subject: [PATCH 7/9] address review comments Signed-off-by: Kenrick Yap --- docs/ppl-lang/planning/ppl-geoip-command.md | 15 --------------- .../sql/ppl/parser/AstExpressionBuilder.java | 16 +++++++++------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/docs/ppl-lang/planning/ppl-geoip-command.md b/docs/ppl-lang/planning/ppl-geoip-command.md index 59ca2df08..aaed6c156 100644 --- a/docs/ppl-lang/planning/ppl-geoip-command.md +++ b/docs/ppl-lang/planning/ppl-geoip-command.md @@ -57,18 +57,3 @@ geoip function to add information about the geographical location of an IPv4 or - API data sources - if users have their own geoip provided will create ability for users to configure and call said endpoints - OpenSearch geospatial client - once geospatial client is published we can leverage client to utilize OpenSearch geo2ip functionality. - Additional datasource connection params will be provided through spark config options. - -### New syntax definition in ANTLR - -```ANTLR - -// functions -evalFunctionCall - : evalFunctionName LT_PRTHS functionArgs RT_PRTHS - | geoipFunction - ; - -geoipFunction - : GEOIP LT_PRTHS (datasource = functionArg COMMA)? ipAddress = functionArg (COMMA properties = stringLiteral)? 
RT_PRTHS - ; -``` diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f376c23ae..4b2c90b71 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -457,13 +457,7 @@ public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpProp if (ctx != null) { for (OpenSearchPPLParser.GeoIpPropertyContext property : ctx.geoIpProperty()) { String propertyName = property.getText().toUpperCase(); - - try { - GeoIpCatalystLogicalPlanTranslator.GeoIpProperty.valueOf(propertyName); - } catch (NullPointerException | IllegalArgumentException e) { - throw new IllegalArgumentException("Invalid properties used."); - } - + validateGeoIpProperty(propertyName); properties.add(new Literal(propertyName, DataType.STRING)); } } @@ -471,6 +465,14 @@ public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpProp return new AttributeList(properties.build()); } + private static void validateGeoIpProperty(String propertyName) { + try { + GeoIpCatalystLogicalPlanTranslator.GeoIpProperty.valueOf(propertyName); + } catch (NullPointerException | IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid properties used."); + } + } + private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = From 406d2c93807eb85652f9c1ebed39d061208824d5 Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Thu, 19 Dec 2024 10:15:57 -0800 Subject: [PATCH 8/9] moved validateGeoIpProperty to relevant class Signed-off-by: Kenrick Yap --- .../sql/ppl/parser/AstExpressionBuilder.java | 10 +--------- .../ppl/utils/GeoIpCatalystLogicalPlanTranslator.java | 8 ++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4b2c90b71..a73c593fe 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -457,7 +457,7 @@ public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpProp if (ctx != null) { for (OpenSearchPPLParser.GeoIpPropertyContext property : ctx.geoIpProperty()) { String propertyName = property.getText().toUpperCase(); - validateGeoIpProperty(propertyName); + GeoIpCatalystLogicalPlanTranslator.validateGeoIpProperty(propertyName); properties.add(new Literal(propertyName, DataType.STRING)); } } @@ -465,14 +465,6 @@ public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpProp return new AttributeList(properties.build()); } - private static void validateGeoIpProperty(String propertyName) { - try { - GeoIpCatalystLogicalPlanTranslator.GeoIpProperty.valueOf(propertyName); - } catch (NullPointerException | IllegalArgumentException e) { - throw new IllegalArgumentException("Invalid properties used."); - } - } - private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java index a57b26799..cedc00846 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java @@ -211,4 +211,12 @@ enum GeoIpProperty { TIME_ZONE, LOCATION } + + public static void validateGeoIpProperty(String 
propertyName) { + try { + GeoIpProperty.valueOf(propertyName); + } catch (NullPointerException | IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid properties used."); + } + } } From 3f4fe5717c3efcc8a0ada630d2c709c08923d763 Mon Sep 17 00:00:00 2001 From: Kenrick Yap Date: Thu, 19 Dec 2024 11:34:06 +0000 Subject: [PATCH 9/9] updated scalaudf function descriptions Signed-off-by: Kenrick Yap --- .../expression/function/SerializableUdf.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java index 596a6a295..e931175ff 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -148,11 +148,10 @@ public Boolean apply(String ipAddress, String cidrBlock) { class geoIpUtils { /** - * Append values to JSON arrays based on specified path-values. + * Checks if the provided ip string is ipv4 or ipv6. * - * @param jsonStr The input JSON string. - * @param elements A list of path-values where the first item is the path and subsequent items are values to append. - * @return The updated JSON string. + * @param ipAddress The input ip string. + * @return true if ipAddress is ipv4, false if ipAddress is ipv6; throws AddressStringException if the ip is invalid. */ public static Function1 isIpv4 = new SerializableAbstractFunction1<>() { @@ -178,11 +177,10 @@ public Boolean apply(String ipAddress) { }; /** - * Append values to JSON arrays based on specified path-values. + * Converts an ipAddress string to its integer representation. * - * @param jsonStr The input JSON string. - * @param elements A list of path-values where the first item is the path and subsequent items are values to append. 
- * @return The updated JSON string. + * @param ipAddress The input ip string. + * @return converted BigInteger from ipAddress string. */ public static Function1 ipToInt = new SerializableAbstractFunction1<>() { @Override @@ -204,10 +202,10 @@ abstract class SerializableAbstractFunction1 extends AbstractFunction1 expressions) { switch (funcName) {