Skip to content

Commit

Permalink
PPL cidr function (#706)
Browse files Browse the repository at this point in the history
* Initial commit for cidr function

Signed-off-by: Hendrik Saly <[email protected]>

* Add inet.ipaddr library and implement basic cidr check

Signed-off-by: Hendrik Saly <[email protected]>

* Refactor for using lombok

Signed-off-by: Hendrik Saly <[email protected]>

* Add unittests

Signed-off-by: Hendrik Saly <[email protected]>

* add check for mixed ip address types like ipv4 and ipv6; add test expecting a runtime exception in case of a mixed address type

Signed-off-by: Jens Schmidt <[email protected]>

* Fix antlr

Signed-off-by: Hendrik Saly <[email protected]>

* rename cidr function to cidrmatch

Signed-off-by: Jens Schmidt <[email protected]>

* Fix imports

Signed-off-by: Hendrik Saly <[email protected]>

* Prepare integ tests

Signed-off-by: Hendrik Saly <[email protected]>

* Added IT tests

Signed-off-by: Hendrik Saly <[email protected]>

* Refactor SerializableUdf to be an interface

Signed-off-by: Hendrik Saly <[email protected]>

* Fix imports

Signed-off-by: Hendrik Saly <[email protected]>

* Added docs

Signed-off-by: Hendrik Saly <[email protected]>

---------

Signed-off-by: Hendrik Saly <[email protected]>
Signed-off-by: Jens Schmidt <[email protected]>
Co-authored-by: Jens Schmidt <[email protected]>
  • Loading branch information
salyh and dr-lilienthal authored Oct 30, 2024
1 parent d03cc8c commit d2213c5
Show file tree
Hide file tree
Showing 15 changed files with 380 additions and 21 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
"com.github.sbt" % "junit-interface" % "0.13.3" % "test",
"org.projectlombok" % "lombok" % "1.18.30",
"com.github.seancfoley" % "ipaddress" % "5.5.1",
),
libraryDependencies ++= deps(sparkVersion),
// ANTLR settings
Expand Down
2 changes: 2 additions & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where a not in (1, 2, 3) | fields a,b,c`
- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4]
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10'
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`

```sql
source = table | eval status_category =
Expand Down
1 change: 1 addition & 0 deletions docs/ppl-lang/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).

- [`Cryptographic Functions`](functions/ppl-cryptographic.md)

- [`IP Address Functions`](functions/ppl-ip.md)

---
### PPL On Spark
Expand Down
35 changes: 35 additions & 0 deletions docs/ppl-lang/functions/ppl-ip.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
## PPL IP Address Functions

### `CIDRMATCH`

**Description**

`CIDRMATCH(ip, cidr)` checks if ip is within the specified cidr range.

**Argument type:**
- STRING, STRING
- Return type: **BOOLEAN**

Example:

os> source=ips | where cidrmatch(ip, '192.169.1.0/24') | fields ip
fetched rows / total rows = 1/1
+--------------+
| ip |
|--------------|
| 192.169.1.5 |
+--------------+

os> source=ipsv6 | where cidrmatch(ip, '2003:db8::/32') | fields ip
fetched rows / total rows = 1/1
+-----------------------------------------+
| ip |
|-----------------------------------------|
| 2003:0db8:0000:0000:0000:0000:0000:0000 |
+-----------------------------------------+

Note:
- `ip` can be an IPv4 or an IPv6 address
- `cidr` can be an IPv4 or an IPv6 block
- `ip` and `cidr` must be either both IPv4 or both IPv6
- `ip` and `cidr` must both be valid and non-empty/non-null
2 changes: 2 additions & 0 deletions docs/ppl-lang/ppl-where-command.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ PPL query:
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4]
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b >= '2024-09-10' and b <= '2025-09-10'
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`
- `source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,30 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| (11, null, false)
| """.stripMargin)
}

protected def createIpAddressTable(testTable: String): Unit = {
  // Two-step setup: create the table, then seed it with IPv4 and IPv6 rows,
  // including deliberately malformed addresses flagged with isValid = false.
  val createStmt = s"""
       | CREATE TABLE $testTable
       | (
       | id INT,
       | ipAddress STRING,
       | isV6 BOOLEAN,
       | isValid BOOLEAN
       | )
       | USING $tableType $tableOptions
       |""".stripMargin

  val insertStmt = s"""
       | INSERT INTO $testTable
       | VALUES (1, '127.0.0.1', false, true),
       | (2, '192.168.1.0', false, true),
       | (3, '192.168.1.1', false, true),
       | (4, '192.168.2.1', false, true),
       | (5, '192.168.2.', false, false),
       | (6, '2001:db8::ff00:12:3455', true, true),
       | (7, '2001:db8::ff00:12:3456', true, true),
       | (8, '2001:db8::ff00:13:3457', true, true),
       | (9, '2001:db8::ff00:12:', true, false)
       | """.stripMargin

  Seq(createStmt, insertStmt).foreach(sql)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCidrmatchITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Seed the table with a mix of valid and invalid IPv4/IPv6 addresses.
    createIpAddressTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  /**
   * Builds the PPL query used by every test: filter rows by address family and
   * validity flags, then apply a cidrmatch predicate against the given block.
   */
  private def cidrMatchQuery(isV6: Boolean, isValid: Boolean, cidr: String) =
    sql(s"""
         | source = $testTable | where isV6 = $isV6 and isValid = $isValid and cidrmatch(ipAddress, '$cidr')
         | """.stripMargin)

  test("test cidrmatch for ipv4 for 192.168.1.0/24") {
    val results = cidrMatchQuery(isV6 = false, isValid = true, "192.168.1.0/24").collect()
    assert(results.length == 2)
  }

  test("test cidrmatch for ipv4 for 192.169.1.0/24") {
    val results = cidrMatchQuery(isV6 = false, isValid = true, "192.169.1.0/24").collect()
    assert(results.length == 0)
  }

  test("test cidrmatch for ipv6 for 2001:db8::/32") {
    val results = cidrMatchQuery(isV6 = true, isValid = true, "2001:db8::/32").collect()
    assert(results.length == 3)
  }

  test("test cidrmatch for ipv6 for 2003:db8::/32") {
    val results = cidrMatchQuery(isV6 = true, isValid = true, "2003:db8::/32").collect()
    assert(results.length == 0)
  }

  test("test cidrmatch for ipv6 with ipv4 cidr") {
    // Mixed address families must fail at execution time.
    val frame = cidrMatchQuery(isV6 = true, isValid = true, "192.169.1.0/24")
    assertThrows[SparkException](frame.collect())
  }

  test("test cidrmatch for invalid ipv4 addresses") {
    val frame = cidrMatchQuery(isV6 = false, isValid = false, "192.169.1.0/24")
    assertThrows[SparkException](frame.collect())
  }

  test("test cidrmatch for invalid ipv6 addresses") {
    val frame = cidrMatchQuery(isV6 = true, isValid = false, "2003:db8::/32")
    assertThrows[SparkException](frame.collect())
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ ISNULL: 'ISNULL';
ISNOTNULL: 'ISNOTNULL';
ISPRESENT: 'ISPRESENT';
BETWEEN: 'BETWEEN';
CIDRMATCH: 'CIDRMATCH';

// FLOWCONTROL FUNCTIONS
IFNULL: 'IFNULL';
Expand Down
6 changes: 6 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ booleanExpression
| isEmptyExpression # isEmptyExpr
| valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr
| EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr
| cidrMatchFunctionCall # cidrFunctionCallExpr
;

isEmptyExpression
Expand Down Expand Up @@ -519,6 +520,10 @@ booleanFunctionCall
: conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS
;

cidrMatchFunctionCall
: CIDRMATCH LT_PRTHS ipAddress = functionArg COMMA cidrBlock = functionArg RT_PRTHS
;

convertedDataType
: typeName = DATE
| typeName = TIME
Expand Down Expand Up @@ -1116,4 +1121,5 @@ keywordsCanBeId
| SEMI
| ANTI
| BETWEEN
| CIDRMATCH
;
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -322,4 +323,7 @@ public T visitExistsSubquery(ExistsSubquery node, C context) {
/** Visits a window AST node; default implementation delegates to the node's children. */
public T visitWindow(Window node, C context) {
    return visitChildren(node, context);
}
/**
 * Visits a {@code cidrmatch(ip, cidr)} function AST node; default implementation
 * delegates to the node's children (the two function arguments).
 */
public T visitCidr(Cidr node, C context) {
    return visitChildren(node, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

import java.util.Arrays;
import java.util.List;

/** AST node that represents CIDR function. */
@AllArgsConstructor
@Getter
@EqualsAndHashCode(callSuper = false)
@ToString
public class Cidr extends UnresolvedExpression {
    // First argument: the IP address to test (e.g. a field reference or string literal).
    private UnresolvedExpression ipAddress;
    // Second argument: the CIDR block to test against (e.g. '192.168.1.0/24').
    private UnresolvedExpression cidrBlock;

    /** Children are the two function arguments, in call order. */
    @Override
    public List<UnresolvedExpression> getChild() {
        return Arrays.asList(ipAddress, cidrBlock);
    }

    @Override
    public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
        return nodeVisitor.visitCidr(this, context);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import inet.ipaddr.AddressStringException;
import inet.ipaddr.IPAddressString;
import inet.ipaddr.IPAddressStringParameters;
import scala.Function2;
import scala.Serializable;
import scala.runtime.AbstractFunction2;


public interface SerializableUdf {

    /**
     * Serializable binary function backing the PPL {@code cidrmatch(ip, cidr)} function.
     * Returns {@code true} when {@code ipAddress} lies within {@code cidrBlock}.
     *
     * Throws {@link RuntimeException} when either argument fails validation or when the
     * two arguments are of different address families (IPv4 vs IPv6).
     */
    Function2<String, String, Boolean> cidrFunction = new SerializableAbstractFunction2<>() {

        // Strict parsing: reject empty strings, inet_aton shorthand and single-segment forms
        // so malformed addresses fail validation instead of being silently coerced.
        IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder()
                .allowEmpty(false)
                .setEmptyAsLoopback(false)
                .allow_inet_aton(false)
                .allowSingleSegment(false)
                .toParams();

        @Override
        public Boolean apply(String ipAddress, String cidrBlock) {
            IPAddressString parsedIpAddress = validated(
                    new IPAddressString(ipAddress, valOptions),
                    "ipAddress", ipAddress, "a valid IPv4 or IPv6 address");
            IPAddressString parsedCidrBlock = validated(
                    new IPAddressString(cidrBlock, valOptions),
                    "cidrBlock", cidrBlock, "a valid CIDR or netmask");

            // Mixed address families cannot be compared; fail loudly rather than return false.
            if (parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
                throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6.");
            }

            return parsedCidrBlock.contains(parsedIpAddress);
        }
    };

    /**
     * Validates a parsed address/block, rethrowing parse failures as a RuntimeException
     * whose message names the offending argument. Centralizes the validation logic that
     * was previously duplicated for both arguments; messages are unchanged.
     */
    static IPAddressString validated(IPAddressString parsed, String argName, String rawValue, String requirement) {
        try {
            parsed.validate();
        } catch (AddressStringException e) {
            throw new RuntimeException("The given "+argName+" '"+rawValue+"' is invalid. It must be "+requirement+". Error details: "+e.getMessage());
        }
        return parsed;
    }

    // NOTE(review): scala.Serializable is a deprecated alias of java.io.Serializable in
    // Scala 2.13+ — consider implementing java.io.Serializable directly; confirm the
    // Scala version used by the Spark build before changing.
    abstract class SerializableAbstractFunction2<T1,T2,R> extends AbstractFunction2<T1,T2,R>
            implements Serializable {
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.*;
Expand Down Expand Up @@ -88,6 +89,7 @@
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.expression.function.SerializableUdf;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand All @@ -100,7 +102,11 @@
import scala.collection.IterableLike;
import scala.collection.Seq;

import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -879,5 +885,24 @@ public Expression visitBetween(Between node, CatalystPlanContext context) {
context.retainAllNamedParseExpressions(p -> p);
return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper)));
}

/**
 * Translates a PPL {@code cidrmatch(ip, cidr)} AST node into a Spark ScalaUDF
 * wrapping {@link SerializableUdf#cidrFunction}, and pushes it onto the context's
 * expression stack as the result.
 */
@Override
public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) {
    // Resolve each argument, then pop its Catalyst expression off the stack in order.
    analyze(node.getIpAddress(), context);
    Expression ipAddressExpression = context.getNamedParseExpressions().pop();
    analyze(node.getCidrBlock(), context);
    Expression cidrBlockExpression = context.getNamedParseExpressions().pop();

    // Build the BOOLEAN-typed UDF over both arguments; udfName is "cidr".
    // NOTE(review): the trailing booleans appear to be nullable=false and
    // udfDeterministic=true — confirm against the ScalaUDF constructor signature
    // for the targeted Spark version.
    ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction,
            DataTypes.BooleanType,
            seq(ipAddressExpression,cidrBlockExpression),
            seq(),
            Option.empty(),
            Option.apply("cidr"),
            false,
            true);

    return context.getNamedParseExpressions().push(udf);
}
}
}
Loading

0 comments on commit d2213c5

Please sign in to comment.