From 1cc7cf10ca620d2c906d95070163556ff62a6cbc Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Fri, 9 Feb 2024 11:14:34 -0800
Subject: [PATCH] Add input type check and UT for BloomFilterMightContain

Validate that the inputs are a Binary expression followed by a Long value
(NullType is accepted in place of either one) and add unit tests for the
accepted and rejected type combinations.

Signed-off-by: Chen Dai
---
NOTE(review): the blob ids on the `index` lines below are carried over unchanged
from the original patch and do not reflect the revised content; regenerate the
patch with `git format-patch` after applying.

 .../bloomfilter/BloomFilterMightContain.scala | 19 +++++-
 .../BloomFilterMightContainSuite.scala        | 65 +++++++++++++++++++
 2 files changed, 82 insertions(+), 2 deletions(-)
 create mode 100644 flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContainSuite.scala

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala
index 683430bd1..2e88c9bcf 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala
@@ -16,7 +16,7 @@ import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types._
 
 /**
  * Bloom filter function that returns the membership check result for values of `valueExpression`
@@ -42,7 +42,22 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre
 
   override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"
 
-  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+  /**
+   * Type check: the left input must be Binary and the right input must be Long; NullType is
+   * also accepted in place of either one. Any other pair fails with a message that reports
+   * the catalog strings of both actual types.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (BinaryType | NullType, LongType | NullType) =>
+        TypeCheckResult.TypeCheckSuccess
+      case _ =>
+        TypeCheckResult.TypeCheckFailure(s"""
+           | Input to function $prettyName should be Binary expression followed by a Long value,
+           | but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].
+           | """.stripMargin)
+    }
+  }
 
   override protected def withNewChildrenInternal(
       newBloomFilterExpression: Expression,
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContainSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContainSuite.scala
new file mode 100644
index 000000000..132c79b38
--- /dev/null
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContainSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping.bloomfilter
+
+import org.apache.spark.FlintSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{BinaryType, DoubleType, LongType, NullType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/** Unit tests for the input type check of [[BloomFilterMightContain]]. */
+class BloomFilterMightContainSuite extends FlintSuite {
+
+  test("checkInputDataTypes should succeed for valid input types") {
+    val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
+    val longExpression = Literal(42L, LongType)
+
+    val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
+    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
+  }
+
+  test("checkInputDataTypes should succeed for valid input types with nulls") {
+    // Null values that are still typed as Binary and Long
+    val binaryExpression = Literal.create(null, BinaryType)
+    val longExpression = Literal.create(null, LongType)
+
+    val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
+    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
+  }
+
+  test("checkInputDataTypes should succeed when either input is of NullType") {
+    val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
+    val longExpression = Literal(42L, LongType)
+    val nullExpression = Literal.create(null, NullType)
+
+    // Every accepted combination that involves NullType
+    val nullTypedInputs = Seq(
+      (binaryExpression, nullExpression),
+      (nullExpression, longExpression),
+      (nullExpression, nullExpression))
+
+    nullTypedInputs.foreach { case (bloomFilter, value) =>
+      val typeCheckResult = BloomFilterMightContain(bloomFilter, value).checkInputDataTypes()
+      assert(typeCheckResult == TypeCheckSuccess)
+    }
+  }
+
+  test("checkInputDataTypes should fail for invalid input types") {
+    val stringExpression = Literal(UTF8String.fromString("invalid"), StringType)
+    val doubleExpression = Literal(3.14, DoubleType)
+
+    val bloomFilterMightContain = BloomFilterMightContain(stringExpression, doubleExpression)
+    // Must match the failure message built by checkInputDataTypes exactly
+    val expectedErrorMsg =
+      s"""
+        | Input to function bloom_filter_might_contain should be Binary expression followed by a Long value,
+        | but it's [${stringExpression.dataType.catalogString}, ${doubleExpression.dataType.catalogString}].
+        | """.stripMargin
+
+    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckFailure(expectedErrorMsg))
+  }
+}