diff --git a/build.sbt b/build.sbt index 30858e8d6..66b06d6be 100644 --- a/build.sbt +++ b/build.sbt @@ -154,6 +154,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % "test", "org.projectlombok" % "lombok" % "1.18.30", + "com.github.seancfoley" % "ipaddress" % "5.5.1", ), libraryDependencies ++= deps(sparkVersion), // ANTLR settings diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 01c3f1619..3efca3205 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -58,6 +58,8 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where a not in (1, 2, 3) | fields a,b,c` - `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4] - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` ```sql source = table | eval status_category = diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index 9cb5f118e..8d9b86eda 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -87,6 +87,7 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`Cryptographic Functions`](functions/ppl-cryptographic.md) + - [`IP Address Functions`](functions/ppl-ip.md) --- ### PPL On Spark diff --git a/docs/ppl-lang/functions/ppl-ip.md b/docs/ppl-lang/functions/ppl-ip.md new file mode 100644 index 000000000..fb0b468ba --- /dev/null +++ b/docs/ppl-lang/functions/ppl-ip.md @@ -0,0 +1,35 @@ +## PPL IP Address Functions + +### `CIDRMATCH` + +**Description** + +`CIDRMATCH(ip, cidr)` checks if ip is within the specified cidr range. 
+ +**Argument type:** + - STRING, STRING + - Return type: **BOOLEAN** + +Example: + + os> source=ips | where cidrmatch(ip, '192.169.1.0/24') | fields ip + fetched rows / total rows = 1/1 + +--------------+ + | ip | + |--------------| + | 192.169.1.5 | + +--------------+ + + os> source=ipsv6 | where cidrmatch(ip, '2003:db8::/32') | fields ip + fetched rows / total rows = 1/1 + +-----------------------------------------+ + | ip | + |-----------------------------------------| + | 2003:0db8:0000:0000:0000:0000:0000:0000 | + +-----------------------------------------+ + +Note: + - `ip` can be an IPv4 or an IPv6 address + - `cidr` can be an IPv4 or an IPv6 block + - `ip` and `cidr` must be either both IPv4 or both IPv6 + - `ip` and `cidr` must both be valid and non-empty/non-null \ No newline at end of file diff --git a/docs/ppl-lang/ppl-where-command.md b/docs/ppl-lang/ppl-where-command.md index 89a7e61fa..c954623c3 100644 --- a/docs/ppl-lang/ppl-where-command.md +++ b/docs/ppl-lang/ppl-where-command.md @@ -43,6 +43,8 @@ PPL query: - `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` - `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. 
[1, 4] - `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10' +- `source = table | where cidrmatch(ip, '192.169.1.0/24')` +- `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | eval status_category = case(a >= 200 AND a < 300, 'Success', diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 23a336b4c..c8c902294 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -669,4 +669,30 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | (11, null, false) | """.stripMargin) } + + protected def createIpAddressTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | id INT, + | ipAddress STRING, + | isV6 BOOLEAN, + | isValid BOOLEAN + | ) + | USING $tableType $tableOptions + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES (1, '127.0.0.1', false, true), + | (2, '192.168.1.0', false, true), + | (3, '192.168.1.1', false, true), + | (4, '192.168.2.1', false, true), + | (5, '192.168.2.', false, false), + | (6, '2001:db8::ff00:12:3455', true, true), + | (7, '2001:db8::ff00:12:3456', true, true), + | (8, '2001:db8::ff00:13:3457', true, true), + | (9, '2001:db8::ff00:12:', true, false) + | """.stripMargin) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala new file mode 100644 index 000000000..d9cf8968b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCidrmatchITSuite.scala @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.SparkException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCidrmatchITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createIpAddressTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test cidrmatch for ipv4 for 192.168.1.0/24") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 2) + } + + test("test cidrmatch for ipv4 for 192.169.1.0/24") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 for 2001:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2001:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 3) + } + + test("test cidrmatch for ipv6 for 2003:db8::/32") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + val results = frame.collect() + assert(results.length == 0) + } + + test("test cidrmatch for ipv6 with ipv4 cidr") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and 
isValid = true and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv4 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = false and isValid = false and cidrmatch(ipAddress, '192.169.1.0/24') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } + + test("test cidrmatch for invalid ipv6 addresses") { + val frame = sql(s""" + | source = $testTable | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32') + | """.stripMargin) + + assertThrows[SparkException](frame.collect()) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 82fdeb42f..a00e161c7 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -393,6 +393,7 @@ ISNULL: 'ISNULL'; ISNOTNULL: 'ISNOTNULL'; ISPRESENT: 'ISPRESENT'; BETWEEN: 'BETWEEN'; +CIDRMATCH: 'CIDRMATCH'; // FLOWCONTROL FUNCTIONS IFNULL: 'IFNULL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 48984b3a5..4164843ef 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -440,6 +440,7 @@ booleanExpression | isEmptyExpression # isEmptyExpr | valueExpressionList NOT? 
IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr + | cidrMatchFunctionCall # cidrFunctionCallExpr ; isEmptyExpression @@ -519,6 +520,10 @@ booleanFunctionCall : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS ; +cidrMatchFunctionCall + : CIDRMATCH LT_PRTHS ipAddress = functionArg COMMA cidrBlock = functionArg RT_PRTHS + ; + convertedDataType : typeName = DATE | typeName = TIME @@ -1116,4 +1121,5 @@ keywordsCanBeId | SEMI | ANTI | BETWEEN + | CIDRMATCH ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e1397a754..03c40fcd2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -13,6 +13,7 @@ import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cidr; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; @@ -322,4 +323,7 @@ public T visitExistsSubquery(ExistsSubquery node, C context) { public T visitWindow(Window node, C context) { return visitChildren(node, context); } + public T visitCidr(Cidr node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java new file mode 100644 index 000000000..fdbb3ef65 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Cidr.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.ast.expression; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** AST node that represents CIDR function. */ +@AllArgsConstructor +@Getter +@EqualsAndHashCode(callSuper = false) +@ToString +public class Cidr extends UnresolvedExpression { + private UnresolvedExpression ipAddress; + private UnresolvedExpression cidrBlock; + + @Override + public List getChild() { + return Arrays.asList(ipAddress, cidrBlock); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCidr(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java new file mode 100644 index 000000000..2541b3743 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/SerializableUdf.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import inet.ipaddr.AddressStringException; +import inet.ipaddr.IPAddressString; +import inet.ipaddr.IPAddressStringParameters; +import scala.Function2; +import scala.Serializable; +import scala.runtime.AbstractFunction2; + + +public interface SerializableUdf { + + Function2 cidrFunction = new SerializableAbstractFunction2<>() { + + IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + @Override + public Boolean apply(String ipAddress, String cidrBlock) { + + IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions); + + try { + 
parsedIpAddress.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage()); + } + + IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions); + + try { + parsedCidrBlock.validate(); + } catch (AddressStringException e) { + throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage()); + } + + if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) { + throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6."); + } + + return parsedCidrBlock.contains(parsedIpAddress); + } + }; + + abstract class SerializableAbstractFunction2 extends AbstractFunction2 + implements Serializable { + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 441287ddb..87010f231 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; +import org.apache.spark.sql.catalyst.expressions.ScalaUDF; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.*; @@ -88,6 +89,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Window; import 
org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.expression.function.SerializableUdf; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; @@ -100,7 +102,11 @@ import scala.collection.IterableLike; import scala.collection.Seq; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Stack; import java.util.function.BiFunction; import java.util.stream.Collectors; @@ -879,5 +885,24 @@ public Expression visitBetween(Between node, CatalystPlanContext context) { context.retainAllNamedParseExpressions(p -> p); return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper))); } + + @Override + public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) { + analyze(node.getIpAddress(), context); + Expression ipAddressExpression = context.getNamedParseExpressions().pop(); + analyze(node.getCidrBlock(), context); + Expression cidrBlockExpression = context.getNamedParseExpressions().pop(); + + ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction, + DataTypes.BooleanType, + seq(ipAddressExpression,cidrBlockExpression), + seq(), + Option.empty(), + Option.apply("cidr"), + false, + true); + + return context.getNamedParseExpressions().push(udf); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6a0c80c16..0c7f6a9d4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -19,14 
+19,13 @@ import org.opensearch.sql.ast.expression.AttributeList; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Cidr; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; -import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; @@ -35,13 +34,15 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -67,14 +68,6 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; - - private AstBuilder astBuilder; - - /** Set AstBuilder back to AstExpressionBuilder for resolving the subquery plan in subquery expression */ - public void setAstBuilder(AstBuilder astBuilder) { 
- this.astBuilder = astBuilder; - } - /** * The function name mapping between fronted and core engine. */ @@ -84,6 +77,17 @@ public void setAstBuilder(AstBuilder astBuilder) { .put("isnotnull", IS_NOT_NULL.getName().getFunctionName()) .put("ispresent", IS_NOT_NULL.getName().getFunctionName()) .build(); + private AstBuilder astBuilder; + + public AstExpressionBuilder() { + } + + /** + * Set AstBuilder back to AstExpressionBuilder for resolving the subquery plan in subquery expression + */ + public void setAstBuilder(AstBuilder astBuilder) { + this.astBuilder = astBuilder; + } @Override public UnresolvedExpression visitMappingCompareExpr(OpenSearchPPLParser.MappingCompareExprContext ctx) { @@ -154,7 +158,7 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont @Override public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { return new Function( - ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); } @Override @@ -245,7 +249,7 @@ public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ct }) .collect(Collectors.toList()); UnresolvedExpression elseValue = new Literal(null, DataType.NULL); - if(ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { + if (ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) { // else value is present elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1)); } @@ -290,9 +294,6 @@ private Function buildFunction( functionName, args.stream().map(this::visitFunctionArg).collect(Collectors.toList())); } - public AstExpressionBuilder() { - } - @Override public UnresolvedExpression visitMultiFieldRelevanceFunction( OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { @@ -306,7 +307,7 @@ public 
UnresolvedExpression visitTableSource(OpenSearchPPLParser.TableSourceCont if (ctx.getChild(0) instanceof OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) { return visitIdentsAsTableQualifiedName((OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) ctx.getChild(0)); } else { - return visitIdentifiers(Arrays.asList(ctx)); + return visitIdentifiers(List.of(ctx)); } } @@ -398,9 +399,9 @@ public UnresolvedExpression visitRightHint(OpenSearchPPLParser.RightHintContext @Override public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryExprContext ctx) { UnresolvedExpression expr = new InSubquery( - ctx.valueExpressionList().valueExpression().stream() - .map(this::visit).collect(Collectors.toList()), - astBuilder.visitSubSearch(ctx.subSearch())); + ctx.valueExpressionList().valueExpression().stream() + .map(this::visit).collect(Collectors.toList()), + astBuilder.visitSubSearch(ctx.subSearch())); return ctx.NOT() != null ? new Not(expr) : expr; } @@ -421,6 +422,12 @@ public UnresolvedExpression visitInExpr(OpenSearchPPLParser.InExprContext ctx) { return ctx.NOT() != null ? 
new Not(expr) : expr; } + + @Override + public UnresolvedExpression visitCidrMatchFunctionCall(OpenSearchPPLParser.CidrMatchFunctionCallContext ctx) { + return new Cidr(visit(ctx.ipAddress), visit(ctx.cidrBlock)); + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java new file mode 100644 index 000000000..3d3940730 --- /dev/null +++ b/ppl-spark-integration/src/test/java/org/opensearch/sql/expression/function/SerializableUdfTest.java @@ -0,0 +1,61 @@ +package org.opensearch.sql.expression.function; + +import org.junit.Assert; +import org.junit.Test; + +public class SerializableUdfTest { + + @Test(expected = RuntimeException.class) + public void cidrNullIpTest() { + SerializableUdf.cidrFunction.apply(null, "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyIpTest() { + SerializableUdf.cidrFunction.apply("", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrNullCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", null); + } + + @Test(expected = RuntimeException.class) + public void cidrEmptyCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", ""); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidIpTest() { + SerializableUdf.cidrFunction.apply("xxx", "192.168.0.0/24"); + } + + @Test(expected = RuntimeException.class) + public void cidrInvalidCidrTest() { + SerializableUdf.cidrFunction.apply("192.168.0.0", "xxx"); + } + + @Test(expected = RuntimeException.class) + public void cirdMixedIpVersionTest() { + SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24"); + SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/324"); + } + + @Test(expected = 
RuntimeException.class)
+    public void cidrMixedIpVersionTestV6V4() {
+        SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "192.168.0.0/24");
+    }
+
+    @Test(expected = RuntimeException.class)
+    public void cidrMixedIpVersionTestV4V6() {
+        SerializableUdf.cidrFunction.apply("192.168.0.0", "2001:db8::/32");
+    }
+
+    @Test
+    public void cidrBasicTest() {
+        Assert.assertTrue(SerializableUdf.cidrFunction.apply("192.168.0.0", "192.168.0.0/24"));
+        Assert.assertFalse(SerializableUdf.cidrFunction.apply("10.10.0.0", "192.168.0.0/24"));
+        Assert.assertTrue(SerializableUdf.cidrFunction.apply("2001:0db8:85a3:0000:0000:8a2e:0370:7334", "2001:db8::/32"));
+        Assert.assertFalse(SerializableUdf.cidrFunction.apply("2001:0db7:85a3:0000:0000:8a2e:0370:7334", "2001:0db8::/32"));
+    }
+}