From 3f447311a3f41bc409a55273d38c08b18e450c77 Mon Sep 17 00:00:00 2001 From: Kenrick Yap <14yapkc1@gmail.com> Date: Fri, 6 Dec 2024 19:00:39 -0800 Subject: [PATCH] Added unit tests --- .../expression/function/SerializableUdf.java | 11 - .../sql/ppl/CatalystQueryPlanVisitor.java | 23 +- .../sql/ppl/parser/AstExpressionBuilder.java | 2 + .../sql/ppl/utils/GeoipCatalystUtils.java | 53 ++-- ...PlanGeoipFunctionTranslatorTestSuite.scala | 288 ++++++++++++++++++ 5 files changed, 337 insertions(+), 40 deletions(-) create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java index 53bff1106..4fb8929a9 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -65,7 +65,6 @@ public Boolean apply(String ipAddress, String cidrBlock) { @Override public Boolean apply(String ipAddress) { - IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); try { @@ -78,21 +77,11 @@ public Boolean apply(String ipAddress) { }}; Function1 ipToInt = new SerializableAbstractFunction1<>() { - - IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() - .allowEmpty(false) - .setEmptyAsLoopback(false) - .allow_inet_aton(false) - .allowSingleSegment(false) - .toParams(); - @Override public BigInteger apply(String ipAddress) { try { InetAddress inetAddress = InetAddress.getByName(ipAddress); byte[] addressBytes = inetAddress.getAddress(); - - // Convert the byte array to a BigInteger return new BigInteger(1, addressBytes); } catch (UnknownHostException e) { System.err.println("Invalid IP address: " + e.getMessage()); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 32942987a..17a59dc1c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -552,15 +552,20 @@ public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { throw new SyntaxCheckException("Unexpected node type when visiting EVAL"); } } - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - } - List expressionList = visitExpressionList(aliases, context); - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + if (!aliases.isEmpty()) { + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + + visitExpressionList(aliases, context); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } else { + return context.apply(p -> p); + } } @Override @@ -579,7 +584,7 @@ public LogicalPlan visitGeoIp(GeoIp node, CatalystPlanContext context) { attributeList.add(0, attributeName); } - Expression fieldExpression = visitExpression(node.getField(), context); + String fieldExpression = node.getField().getField().toString(); Expression ipAddressExpression = visitExpression(node.getIpAddress(), context); return GeoipCatalystUtils.getGeoipLogicalPlan( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index d70a93498..1c80b6514 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -458,6 +458,8 @@ public UnresolvedExpression visitGeoIpPropertyList(OpenSearchPPLParser.GeoIpProp propertyName = "CONTINENT_NAME"; } else if (property.REGION_ISO_CODE() != null) { propertyName = "REGION_ISO_CODE"; + } else if (property.REGION_NAME() != null) { + propertyName = "REGION_NAME"; } else if (property.CITY_NAME() != null) { propertyName = "CITY_NAME"; } else if (property.TIME_ZONE() != null) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoipCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoipCatalystUtils.java index 1e0c2a41e..cb41d01a5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoipCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoipCatalystUtils.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; +import org.apache.spark.SparkEnv; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; @@ -34,25 +35,36 @@ import java.util.List; import java.util.Locale; import java.util.Optional; +import java.util.logging.Level; import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; public interface GeoipCatalystUtils { - + String SPARK_CONF_KEY = "spark.geoip.tablename"; String DEFAULT_GEOIP_TABLE_NAME = "geoip"; String SOURCE_TABLE_ALIAS = "t1"; - String GEOIP_TABLE_ALIAS= "t2"; + String GEOIP_TABLE_ALIAS = "t2"; + List GEOIP_TABLE_COLUMNS = List.of( + "country_iso_code", + "country_name", + "continent_name", + "region_iso_code", + "region_name", + "city_name", + "time_zone", + "location" + ); static LogicalPlan getGeoipLogicalPlan(GeoIpParameters parameters, CatalystPlanContext context) { applyJoin(parameters.getIpAddress(), context); return applyProjection(parameters.getField(), parameters.getProperties(), context); } - static LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext context) { + static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext context) { return context.apply(left -> { - LogicalPlan right = new UnresolvedRelation(seq(DEFAULT_GEOIP_TABLE_NAME), CaseInsensitiveStringMap.empty(), false); + LogicalPlan right = new UnresolvedRelation(seq(getGeoipTableName()), CaseInsensitiveStringMap.empty(), false); LogicalPlan leftAlias = SubqueryAlias$.MODULE$.apply(SOURCE_TABLE_ALIAS, left); LogicalPlan rightAlias = SubqueryAlias$.MODULE$.apply(GEOIP_TABLE_ALIAS, right); Optional joinCondition = Optional.of(new And( @@ -75,13 +87,13 @@ static LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext context) context.retainAllPlans(p -> p); return join(leftAlias, rightAlias, - Join.JoinType.INNER, + Join.JoinType.LEFT, joinCondition, new Join.JoinHint()); }); } - static private LogicalPlan applyProjection(Expression field, List properties, CatalystPlanContext context) { + static private LogicalPlan applyProjection(String field, List properties, CatalystPlanContext context) { List projectExpressions = new ArrayList<>(); projectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); @@ -91,11 +103,11 @@ static private LogicalPlan applyProjection(Expression field, List proper NamedExpression geoCol = Alias$.MODULE$.apply( columnValue, - field.toString(), + field, NamedExpression.newExprId(), - seq(new java.util.ArrayList<>()), + seq(new ArrayList<>()), Option.empty(), - seq(new java.util.ArrayList<>())); + seq(new ArrayList<>())); projectExpressions.add(geoCol); @@ -114,16 +126,7 @@ static private LogicalPlan applyProjection(Expression field, List proper static private List createGeoIpStructFields(List attributeList) { List attributeListToUse; if (attributeList == null || attributeList.isEmpty()) { - attributeListToUse = List.of( - "country_iso_code", - "country_name", - "continent_name", - "region_iso_code", - "region_name", - "city_name", - "time_zone", - "location" - ); + attributeListToUse = GEOIP_TABLE_COLUMNS; } else { attributeListToUse = attributeList; } @@ -159,10 +162,20 @@ static private Expression getIsIpv4(Expression ipAddress) { ); } + static private String getGeoipTableName() { + String tableName = DEFAULT_GEOIP_TABLE_NAME; + + if (SparkEnv.get() != null && SparkEnv.get().conf() != null) { + tableName = SparkEnv.get().conf().get(SPARK_CONF_KEY, DEFAULT_GEOIP_TABLE_NAME); + } + + return tableName; + } + @Getter @AllArgsConstructor class GeoIpParameters { - private final Expression field; + private final String field; private final Expression ipAddress; private final List properties; } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala new file mode 100644 index 000000000..f6357d211 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanGeoipFunctionTranslatorTestSuite.scala @@ -0,0 +1,288 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.expression.function.SerializableUdf +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, CreateNamedStruct, Descending, EqualTo, ExprId, GreaterThanOrEqual, In, LessThan, Literal, NamedExpression, ScalaUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.{LeftOuter, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.types.DataTypes + + +class PPLLogicalPlanGeoipFunctionTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + private def getGeoIpQueryPlan( + ipAddress: UnresolvedAttribute, + left : LogicalPlan, + right : LogicalPlan, + projectionProperties : Alias + ) : LogicalPlan = { + val joinPlan = getJoinPlan(ipAddress, left, right) + getProjection(joinPlan, projectionProperties) + } + + private def getJoinPlan( + ipAddress: UnresolvedAttribute, + left : LogicalPlan, + right : LogicalPlan + ) : LogicalPlan = { + val is_ipv4 = ScalaUDF( + SerializableUdf.isIpv4, + DataTypes.BooleanType, + seq(ipAddress), + seq(), + Option.empty, + Option.apply("is_ipv4"), + false, + true + ) + val ip_to_int = ScalaUDF( + SerializableUdf.ipToInt, + DataTypes.createDecimalType(38, 0), + seq(ipAddress), + seq(), + Option.empty, + Option.apply("ip_to_int"), + false, + true + ) + + val t1 = SubqueryAlias("t1", left) + val t2 = SubqueryAlias("t2", right) + + val joinCondition = And( + And( + GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.start")), + LessThan(ip_to_int, UnresolvedAttribute("t2.end")) + ), + EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4")) + ) + Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE) + } + + private def getProjection(joinPlan : LogicalPlan, projectionProperties : Alias) : LogicalPlan = { + val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan) + val dropList = Seq( + "t2.country_iso_code", "t2.country_name", "t2.continent_name", + "t2.region_iso_code", "t2.region_name", "t2.city_name", + "t2.time_zone", "t2.location", "t2.cidr", "t2.start", "t2.end", "t2.ipv4" + ).map(UnresolvedAttribute(_)) + DataFrameDropColumns(dropList, projection) + } + + test("test geoip function - only ip_address provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source = users | eval a = geoip(ip_address)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val projectionStruct = CreateNamedStruct(Seq( + Literal("country_iso_code"), UnresolvedAttribute("t2.country_iso_code"), + Literal("country_name"), UnresolvedAttribute("t2.country_name"), + Literal("continent_name"), UnresolvedAttribute("t2.continent_name"), + Literal("region_iso_code"), UnresolvedAttribute("t2.region_iso_code"), + Literal("region_name"), UnresolvedAttribute("t2.region_name"), + Literal("city_name"), UnresolvedAttribute("t2.city_name"), + Literal("time_zone"), UnresolvedAttribute("t2.time_zone"), + Literal("location"), UnresolvedAttribute("t2.location") + )) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - source has same name as join alias") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + + test("test geoip function - ipAddress col exist in geoip table") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("cidr") + val sourceTable = UnresolvedRelation(seq("t1")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - duplicate parameters") { + val context = new CatalystPlanContext + + val exception = intercept[IllegalStateException]{ + planTransformer.visit( + plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"), + context) + } + + assert(exception.getMessage.contains("Duplicate attribute in GEOIP attribute list")) + } + + test("test geoip function - one property provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address, country_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val structProjection = Alias(UnresolvedAttribute("t2.country_name"), "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple properties provided") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=users | eval a = geoip(ip_address,country_name,location)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("users")) + val geoTable = UnresolvedRelation(seq("geoip")) + val projectionStruct = CreateNamedStruct(Seq( + Literal("country_name"), UnresolvedAttribute("t2.country_name"), + Literal("location"), UnresolvedAttribute("t2.location") + )) + val structProjection = Alias(projectionStruct, "a")() + + val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection) + val expectedPlan = Project(Seq(UnresolvedStar(None)), geoIpPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - multiple geoip calls") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.country_iso_code"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionB = Alias(UnresolvedAttribute("t2.region_iso_code"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionB) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used between geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionA = Alias(UnresolvedAttribute("t2.time_zone"), "a")() + val colAPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionA) + + val structProjectionC = Alias(UnresolvedAttribute("t2.region_name"), "c")() + val colCPlan = getGeoIpQueryPlan(ipAddress, colAPlan, geoTable, structProjectionC) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "b")()) + val colBPlan = Project(randProjectList, colCPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colBPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test geoip function - other eval function used before geoip") { + val context = new CatalystPlanContext + + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | eval a = rand(), b = geoip(ip_address, city_name)"), + context) + + val ipAddress = UnresolvedAttribute("ip_address") + val sourceTable = UnresolvedRelation(seq("t")) + val geoTable = UnresolvedRelation(seq("geoip")) + + val structProjectionB = Alias(UnresolvedAttribute("t2.city_name"), "b")() + val colBPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjectionB) + + val randProjectList: Seq[NamedExpression] = Seq( + UnresolvedStar(None), + Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")()) + val colAPlan = Project(randProjectList, colBPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), colAPlan) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}