diff --git a/docs/index.md b/docs/index.md index 82c147de2..e0211d8fa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -577,6 +577,10 @@ The following table define the data type mapping between Flint data type and Spa * Spark data types VarcharType(length) and CharType(length) are both currently mapped to Flint data type *keyword*, dropping their length property. On the other hand, Flint data type *keyword* only maps to StringType. +* Spark data type MapType is mapped to an empty OpenSearch object. The inner fields then rely on + dynamic mapping. On the other hand, Flint data type *object* only maps to StructType. +* Spark data type DecimalType is mapped to an OpenSearch double. On the other hand, Flint data type + *double* only maps to DoubleType. Unsupported Spark data types: * DecimalType diff --git a/docs/load_geoip_data.scala b/docs/load_geoip_data.scala new file mode 100644 index 000000000..1540dbfb1 --- /dev/null +++ b/docs/load_geoip_data.scala @@ -0,0 +1,440 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import java.io.BufferedReader +import java.io.FileReader +import java.io.PrintStream +import java.math.BigInteger +import scala.collection.mutable.ListBuffer + +var ipv4NodeCount = 0 +var ipv6NodeCount = 0 +var ipv4NodeOutputCount = 0 +var ipv6NodeOutputCount = 0 + +/* Create a binary tree based on the bits of the start IP address of the subnets. Only use the + first bits needed for the netmask. For example with a subnet of "192.168.2.0/24", only use the + first 24 bits. + + If a node for a subnet has children, then there is an overlap that must be corrected. To correct + an overlap, make sure that both children of the node exist and remove the subnet for the current + node. Finally check the child nodes for overlapping subnets and continue. + */ +class TreeNode(var ipAddressBytes: Array[Byte], var netmask: Int, var isIPv4: Boolean, var lineRemainder: String) { + var falseChild: TreeNode = null + var trueChild: TreeNode = null + + def maxNetmask: Integer = if (isIPv4) 32 else 128 + + // Add a new node to the tree in the correct position + def addNode(nodeToAdd: TreeNode): Unit = { + if (netmask >= nodeToAdd.netmask || netmask == maxNetmask) { + return + } + + var byteIndex = netmask / 8 + var bitValue = (nodeToAdd.ipAddressBytes(byteIndex) & (1 << (7 - (netmask % 8)))) > 0 + + if (netmask + 1 == nodeToAdd.netmask) { + if (bitValue) { + trueChild = nodeToAdd + } else { + falseChild = nodeToAdd + } + } else { + var nextChild: TreeNode = null + if (bitValue) { + nextChild = trueChild + if (trueChild == null) { + nextChild = new TreeNode(null, netmask + 1, isIPv4, null) + trueChild = nextChild + } + } else { + nextChild = falseChild + if (falseChild == null) { + nextChild = new TreeNode(null, netmask + 1, isIPv4, null) + falseChild = nextChild + } + } + + nextChild.addNode(nodeToAdd) + } + + return + } + + def haveOverlap(): Boolean = falseChild != null || trueChild != null + + // Convert the IP address to a string. For IPv6, this is more complicated, since it may + // need to be reduced. 
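+  // For example (illustrative): the address 2001:0db8:0000:0000:0000:0000:0000:0001 is
+  // rendered as "2001:db8::1"; leading zeroes are dropped from each 16-bit group and the
+  // longest run of all-zero groups is collapsed into "::".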
+  def ipAddressString(): String = {
+    if (isIPv4) {
+      return ipAddressBytes.map(v => 255 & v).mkString(".")
+    } else {
+      var allZeroes = true
+      for (b <- ipAddressBytes) {
+        if (b != 0) {
+          allZeroes = false
+        }
+      }
+
+      if (allZeroes) {
+        return "::"
+      }
+
+      var zeroes: ListBuffer[(Int, Int)] = ListBuffer()
+      var zeroesStart = -1
+      var zeroesStartIndex = -1
+      for (i <- 0 to 7) {
+        if (ipAddressBytes(i * 2) == 0 && ipAddressBytes(i * 2 + 1) == 0) {
+          if (zeroesStart == -1) {
+            zeroesStart = i
+            zeroesStartIndex = zeroes.length
+            zeroes = zeroes :+ (i, 1)
+          } else {
+            var existingTuple = zeroes(zeroesStartIndex)
+            zeroes.update(zeroesStartIndex, (existingTuple._1, 1 + existingTuple._2))
+          }
+        } else {
+          zeroesStart = -1
+          zeroesStartIndex = -1
+        }
+      }
+
+      var longestZeroesIndex = -1
+      var longestZeroesLength = 0
+      for (v <- zeroes) {
+        if (v._2 >= longestZeroesLength) {
+          longestZeroesLength = v._2
+          longestZeroesIndex = v._1
+        }
+      }
+
+      var fullIpAddress: Array[String] = Array.fill(8){null}
+      for (i <- 0 to 7) {
+        var strValue = (((255 & ipAddressBytes(i * 2)) << 8) + (255 & ipAddressBytes(i * 2 + 1))).toHexString
+        fullIpAddress(i) = strValue
+      }
+
+      if (longestZeroesIndex == -1) {
+        return fullIpAddress.mkString(":")
+      } else {
+        var ipPartsStart = fullIpAddress.slice(0, longestZeroesIndex)
+        var ipPartsEnd = fullIpAddress.slice(longestZeroesIndex + longestZeroesLength, 8)
+        return ipPartsStart.mkString(":") + "::" + ipPartsEnd.mkString(":")
+      }
+    }
+  }
+
+  def getStart(): BigInteger = new BigInteger(ipAddressBytes)
+
+  def getEnd(): BigInteger = {
+    var valueToAdd = new BigInteger(Array.fill(maxNetmask / 8){0.toByte})
+    if (netmask < maxNetmask) {
+      valueToAdd = valueToAdd.flipBit(maxNetmask - netmask)
+      valueToAdd = valueToAdd.subtract(new BigInteger("1"))
+    }
+    return getStart().add(valueToAdd)
+  }
+
+  def valueToByteArray(value: BigInteger): Array[Byte] = {
+    var fullArray = Array.fill(maxNetmask / 8){0.toByte}
+    var valueArray = value.toByteArray()
+    valueArray.copyToArray(fullArray, (maxNetmask / 8) - valueArray.length, valueArray.length)
+    return fullArray
+  }
+
+  def incrementNodeCount(): Unit = {
+    if (isIPv4) {
+      ipv4NodeCount += 1
+    } else {
+      ipv6NodeCount += 1
+    }
+  }
+
+  // Split a node. Make sure that both children exist and remove the subnet for the current node.
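+  // For example (illustrative): splitting a node for 10.0.0.0/8 produces the children
+  // 10.0.0.0/9 and 10.128.0.0/9, both inheriting the original line data, while the /8 entry
+  // itself is discarded; fixTree() then re-checks the children for further overlaps.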
+ def split(): Unit = { + if (ipAddressBytes == null) { + return + } + + var ipAddressStr = ipAddressString() + println(s">>> Splitting IP: $ipAddressStr") + + if (falseChild == null) { + falseChild = new TreeNode(ipAddressBytes, netmask + 1, isIPv4, lineRemainder) + } else if (falseChild.ipAddressBytes == null) { + falseChild.ipAddressBytes = ipAddressBytes + falseChild.lineRemainder = lineRemainder + } + + if (trueChild == null) { + var valueStart = falseChild.getEnd().add(new BigInteger("1")) + var startArray = valueToByteArray(valueStart) + trueChild = new TreeNode(startArray, netmask + 1, isIPv4, lineRemainder) + } else if (trueChild.ipAddressBytes == null) { + var valueStart = falseChild.getEnd().add(new BigInteger("1")) + var startArray = valueToByteArray(valueStart) + trueChild.ipAddressBytes = startArray + trueChild.lineRemainder = lineRemainder + } + + ipAddressBytes = null + lineRemainder = null + + return + } + + def fixTree(): Unit = { + if (haveOverlap()) { + split() + } + + if (falseChild != null) { + falseChild.fixTree() + } + + if (trueChild != null) { + trueChild.fixTree() + } + } + + def printTree(outStream: PrintStream, tenPercentCount: Int): Unit = { + if (ipAddressBytes != null) { + outStream.print(ipAddressString()) + outStream.print("/") + outStream.print(netmask.toString) + outStream.print(",") + outStream.print(lineRemainder) + outStream.print(",") + outStream.print(getStart().toString()) + outStream.print(",") + outStream.print(getEnd().toString()) + outStream.print(",") + outStream.println(isIPv4.toString) + + var currentNodeCount = if (isIPv4) ipv4NodeOutputCount else ipv6NodeOutputCount + if (currentNodeCount % tenPercentCount == 0) { + print((currentNodeCount * 10 / tenPercentCount).toString + "%..") + } + + if (isIPv4) { + ipv4NodeOutputCount += 1 + } else { + ipv6NodeOutputCount += 1 + } + } + + if (falseChild != null) { + falseChild.printTree(outStream, tenPercentCount) + } + if (trueChild != null) { + trueChild.printTree(outStream, tenPercentCount) + } + } +} + +// Create a node for an IPv4 entry +def createIPv4TreeNode(fullLine: String): TreeNode = { + var charIndex = fullLine.indexOf(",") + var subnet = fullLine.substring(0, charIndex) + var lineRemainder = fullLine.substring(charIndex + 1) + + charIndex = subnet.indexOf("/") + var ipAddressStr = subnet.substring(0, charIndex) + var netmask = subnet.substring(charIndex + 1).toInt + + var addrParts = ipAddressStr.split("\\.") + var bytes = Array[Byte]( + addrParts(0).toInt.toByte, + addrParts(1).toInt.toByte, + addrParts(2).toInt.toByte, + addrParts(3).toInt.toByte + ) + + return new TreeNode(bytes, netmask, true, lineRemainder) +} + +// Create a node for an IPv6 entry +def createIPv6TreeNode(fullLine: String): TreeNode = { + var charIndex = fullLine.indexOf(",") + var subnet = fullLine.substring(0, charIndex) + var lineRemainder = fullLine.substring(charIndex + 1) + + charIndex = subnet.indexOf("/") + var ipAddressStr = subnet.substring(0, charIndex) + var netmask = subnet.substring(charIndex + 1).toInt + + var bytes: Array[Byte] = null + charIndex = ipAddressStr.indexOf("::") + + if (charIndex == -1) { + var values = ipAddressStr.split(":").map(x => Integer.parseInt(x, 16)) + bytes = Array.fill(16){0.toByte} + for (i <- 0 to 7) { + bytes(i * 2) = (values(i) >> 8).toByte + bytes(i * 2 + 1) = (values(i) & 255).toByte + } + } else if ("::" == ipAddressStr) { + bytes = Array.fill(16){0.toByte} + } else { + if (charIndex == 0) { + var values = ipAddressStr.substring(2).split(":").map(x => 
Integer.parseInt(x, 16)) + bytes = Array.fill(16){0.toByte} + for (i <- 8 - values.length to 7) { + var valuesIndex = i - 8 + values.length + bytes(i * 2) = (values(valuesIndex) >> 8).toByte + bytes(i * 2 + 1) = (values(valuesIndex) & 255).toByte + } + } else if (charIndex == ipAddressStr.length - 2) { + var values = ipAddressStr.substring(0, ipAddressStr.length - 2).split(":").map(x => Integer.parseInt(x, 16)) + bytes = Array.fill(16){0.toByte} + for (i <- 0 to values.length - 1) { + bytes(i * 2) = (values(i) >> 8).toByte + bytes(i * 2 + 1) = (values(i) & 255).toByte + } + } else { + var startValues = ipAddressStr.substring(0, charIndex).split(":").map(x => Integer.parseInt(x, 16)) + var endValues = ipAddressStr.substring(charIndex + 2).split(":").map(x => Integer.parseInt(x, 16)) + bytes = Array.fill(16){0.toByte} + for (i <- 0 to startValues.length - 1) { + bytes(i * 2) = (startValues(i) >> 8).toByte + bytes(i * 2 + 1) = (startValues(i) & 255).toByte + } + for (i <- 8 - endValues.length to 7) { + var valuesIndex = i - 8 + endValues.length + bytes(i * 2) = (endValues(valuesIndex) >> 8).toByte + bytes(i * 2 + 1) = (endValues(valuesIndex) & 255).toByte + } + } + } + + return new TreeNode(bytes, netmask, false, lineRemainder) +} + +def createTreeNode(fullLine: String): TreeNode = { + var charIndex = fullLine.indexOf(",") + var subnet = fullLine.substring(0, charIndex) + if (subnet.indexOf(':') > -1) { + return createIPv6TreeNode(fullLine) + } else { + return createIPv4TreeNode(fullLine) + } +} + +var header: String = null +def readSubnets(fileName: String, ipv4Root: TreeNode, ipv6Root: TreeNode): Unit = { + var reader = new BufferedReader(new FileReader(fileName)) + header = reader.readLine() + + var line = reader.readLine() + while (line != null) { + var newNode = createTreeNode(line) + if (newNode.isIPv4) { + ipv4Root.addNode(newNode) + ipv4NodeCount += 1 + } else { + ipv6Root.addNode(newNode) + ipv6NodeCount += 1 + } + + line = reader.readLine() + } + + reader.close() +} + +def writeSubnets(fileName: String, ipv4Root: TreeNode, ipv6Root: TreeNode): Unit = { + var outStream = new PrintStream(fileName) + outStream.print(header) + outStream.print(",ip_range_start,ip_range_end,ipv4") + outStream.print("\r\n") + + println("Writing IPv4 data") + ipv4NodeOutputCount = 0 + ipv4Root.printTree(outStream, (ipv4NodeCount / 10).floor.toInt) + println() + + println("Writing IPv6 data") + ipv6NodeOutputCount = 0 + ipv6Root.printTree(outStream, (ipv6NodeCount / 10).floor.toInt) + println() + + outStream.close() +} + +// Create the table in Spark +def createTable(fileName: String, tableName: String): Unit = { + try { + var sparkSessionClass = Class.forName("org.apache.spark.sql.SparkSession") + var activeSessionMethod = sparkSessionClass.getMethod("active") + var sparkSession = activeSessionMethod.invoke(sparkSessionClass) + + var readMethod = sparkSessionClass.getMethod("read") + var dataFrameReader = readMethod.invoke(sparkSession) + + var dataFrameReaderClass = Class.forName("org.apache.spark.sql.DataFrameReader") + var formatMethod = dataFrameReaderClass.getMethod("format", classOf[java.lang.String]) + dataFrameReader = formatMethod.invoke(dataFrameReader, "csv") + + var optionMethod = dataFrameReaderClass.getMethod("option", classOf[java.lang.String], classOf[java.lang.String]) + dataFrameReader = optionMethod.invoke(dataFrameReader, "inferSchema", "true") + dataFrameReader = optionMethod.invoke(dataFrameReader, "header", "true") + + var loadMethod = dataFrameReaderClass.getMethod("load", 
classOf[java.lang.String])
+    var dataset = loadMethod.invoke(dataFrameReader, fileName)
+
+    var datasetClass = Class.forName("org.apache.spark.sql.Dataset")
+    var writeMethod = datasetClass.getMethod("write")
+    var dataFrameWriter = writeMethod.invoke(dataset)
+
+    var dataFrameWriterClass = Class.forName("org.apache.spark.sql.DataFrameWriter")
+    var saveAsTableMethod = dataFrameWriterClass.getMethod("saveAsTable", classOf[java.lang.String])
+    saveAsTableMethod.invoke(dataFrameWriter, tableName)
+  } catch {
+    case e: Exception => {
+      println("Unable to load data into table")
+      e.printStackTrace()
+    }
+  }
+}
+
+// Sanitize the data and import it into a Spark table
+def cleanAndImport(inputFile: String, outputFile: String, tableName: String): Unit = {
+  if (tableName != null) {
+    try {
+      Class.forName("org.apache.spark.sql.SparkSession")
+    } catch {
+      case e: ClassNotFoundException => {
+        println("Must run in Spark CLI to create the Spark table")
+        return
+      }
+    }
+  }
+
+  println("Loading data")
+  var ipv4Root = new TreeNode(null, 0, true, null)
+  var ipv6Root = new TreeNode(null, 0, false, null)
+  readSubnets(inputFile, ipv4Root, ipv6Root)
+
+  println("Fixing overlapping subnets")
+  ipv4Root.fixTree()
+  ipv6Root.fixTree()
+
+  println("Writing data to file")
+  writeSubnets(outputFile, ipv4Root, ipv6Root)
+
+  if (tableName != null) {
+    println("Creating and populating Spark table")
+    createTable(outputFile, tableName)
+  }
+
+  println("Done")
+}
+
+var FILE_PATH_TO_INPUT_CSV: String = "/replace/this/value"
+var FILE_PATH_TO_OUTPUT_CSV: String = "/replace/this/value"
+var TABLE_NAME: String = null
+var result = cleanAndImport(FILE_PATH_TO_INPUT_CSV, FILE_PATH_TO_OUTPUT_CSV, TABLE_NAME)
diff --git a/docs/opensearch-geoip.md b/docs/opensearch-geoip.md
new file mode 100644
index 000000000..cd262e187
--- /dev/null
+++ b/docs/opensearch-geoip.md
@@ -0,0 +1,90 @@
+# OpenSearch Geographic IP Location Data
+
+## Overview
+
+OpenSearch has PPL functions for looking up the geographic location of IP addresses. In order
+to use these functions, a table containing the geographic location information needs to be
+created.
+
+## How to Create a Geographic Location Index
+
+A script has been created that can clean up and augment a CSV file that contains geographic
+location information for IP address ranges. The CSV file is expected to have the following
+columns:
+
+| Column Name      | Description                                                                                                     |
+|------------------|-----------------------------------------------------------------------------------------------------------------|
+| cidr             | IP address subnet in the format `IP_ADDRESS/NETMASK` (e.g. `192.168.0.0/24`). The IP address can be IPv4 or IPv6 |
+| country_iso_code | ISO code of the country where the IP address subnet is located                                                   |
+| country_name     | Name of the country where the IP address subnet is located                                                       |
+| continent_name   | Name of the continent where the IP address subnet is located                                                     |
+| region_iso_code  | ISO code of the region where the IP address subnet is located                                                    |
+| region_name      | Name of the region where the IP address subnet is located                                                        |
+| city_name        | Name of the city where the IP address subnet is located                                                          |
+| time_zone        | Time zone where the IP address subnet is located                                                                 |
+| location         | Latitude and longitude where the IP address subnet is located                                                    |
+
+The script cleans up the data by splitting IP address subnets so that any IP address falls in at
+most one subnet.
+
+The data is augmented with three additional fields.
+
+| Column Name    | Description                                                                                             |
+|----------------|---------------------------------------------------------------------------------------------------------|
+| ip_range_start | An integer value (the first address of the subnet) used to determine if an IP address is in the subnet  |
+| ip_range_end   | An integer value (the last address of the subnet) used to determine if an IP address is in the subnet   |
+| ipv4           | A boolean value, `true` if the IP address subnet is in IPv4 format                                       |
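+
+For example, once the sanitized file has been loaded into a Spark table (named `geo_ip_data`
+below, as in the End-to-End section), the added range columns can be used to find the subnet
+containing a given address. This is an illustrative sketch only; the integer form of the
+address must be computed from the address bytes in the same way the script computes
+`ip_range_start`.
+
+```scala
+// Illustrative lookup: 8.8.8.8 corresponds to the integer 134744072.
+// Depending on how the CSV schema was inferred, the range columns may need an explicit CAST.
+spark.sql("""
+  SELECT cidr, country_name, city_name
+  FROM geo_ip_data
+  WHERE ipv4 = true
+    AND ip_range_start <= 134744072
+    AND ip_range_end >= 134744072
+""").show()
+```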
+
+## Run the Script
+
+1. Create a copy of the Scala file `load_geoip_data.scala`
+2. Edit the copy of `load_geoip_data.scala`. There are three variables that need to be updated:
+   1. `FILE_PATH_TO_INPUT_CSV` - the full path to the CSV file to load
+   2. `FILE_PATH_TO_OUTPUT_CSV` - the full path of the CSV file to write the sanitized data to
+   3. `TABLE_NAME` - the name of the Spark table to create. No table is created if this is `null`
+3. Save the file
+4. Run the Apache Spark CLI and connect to the database
+5. Load the Scala script
+   ```scala
+   :load FILENAME
+   ```
+   Replace `FILENAME` with the full path to the Scala script.
+
+## Notes for EMR
+
+With EMR it is necessary to load the data from an S3 object. Follow the instructions in
+**Run the Script**, but make sure that `TABLE_NAME` is set to `null`. Then upload the file at
+`FILE_PATH_TO_OUTPUT_CSV` to S3.
+
+## End-to-End
+
+How to download a sample GeoIP location data set, clean it up, and import it into a
+Spark table:
+
+1. Use a web browser to download the [data set Zip file](https://geoip.maps.opensearch.org/v1/geolite2-city/data/geolite2-city_1732905911000.zip)
+2. Extract the Zip file
+3. Copy the file `geolite2-City.csv` to the computer where you run `spark-shell`
+4. Copy the file `load_geoip_data.scala` to the computer where you run `spark-shell`
+5. Connect to the computer where you run `spark-shell`
+6. Change to the directory containing `geolite2-City.csv` and `load_geoip_data.scala`
+7. Update the `load_geoip_data.scala` file to specify the CSV files to read and write. Also update
+   it to specify the Spark table to create (`geo_ip_data` in this case).
+   ```
+   sed -i \
+     -e "s#^var FILE_PATH_TO_INPUT_CSV: String =.*#var FILE_PATH_TO_INPUT_CSV: String = \"${PWD}/geolite2-City.csv\"#" \
+     load_geoip_data.scala
+   sed -i \
+     -e "s#^var FILE_PATH_TO_OUTPUT_CSV: String = .*#var FILE_PATH_TO_OUTPUT_CSV: String = \"${PWD}/geolite2-City-fixed.csv\"#" \
+     load_geoip_data.scala
+   sed -i \
+     -e 's#^var TABLE_NAME: String = .*#var TABLE_NAME: String = "geo_ip_data"#' \
+     load_geoip_data.scala
+   ```
+8. Run `spark-shell`
+   ```
+   spark-shell
+   ```
+9.
Load and run the `load_geoip_data.scala` script + ``` + :load load_geoip_data.scala + ``` diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/datatype/FlintDataType.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/datatype/FlintDataType.scala index 5d920a07e..19fe28a2d 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/datatype/FlintDataType.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/datatype/FlintDataType.scala @@ -142,6 +142,7 @@ object FlintDataType { case ByteType => JObject("type" -> JString("byte")) case DoubleType => JObject("type" -> JString("double")) case FloatType => JObject("type" -> JString("float")) + case DecimalType() => JObject("type" -> JString("double")) // Date case TimestampType | _: TimestampNTZType => diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala index 44e8158d8..312f3a5a1 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/datatype/FlintDataTypeSuite.scala @@ -143,6 +143,20 @@ class FlintDataTypeSuite extends FlintSuite with Matchers { |}""".stripMargin) } + test("spark decimal type serialize") { + val sparkStructType = StructType( + StructField("decimalField", DecimalType(1, 1), true) :: + Nil) + + FlintDataType.serialize(sparkStructType) shouldBe compactJson("""{ + | "properties": { + | "decimalField": { + | "type": "double" + | } + | } + |}""".stripMargin) + } + test("spark varchar and char type serialize") { val flintDataType = """{ | "properties": { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index ae2e53090..bf5e6309e 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -523,5 +523,45 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("create materialized view with decimal and map types") { + val decimalAndMapTable = s"$catalogName.default.mv_test_decimal_map" + val decimalAndMapMv = s"$catalogName.default.mv_test_decimal_map_ser" + withTable(decimalAndMapTable) { + createMapAndDecimalTimeSeriesTable(decimalAndMapTable) + + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW $decimalAndMapMv + | AS + | SELECT + | base_score, mymap + | FROM $decimalAndMapTable + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}' + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val flintIndex = getFlintIndexName(decimalAndMapMv) + val job = spark.streams.active.find(_.name == flintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + flint.describeIndex(flintIndex) shouldBe defined + checkAnswer( + flint.queryIndex(flintIndex).select("base_score", "mymap"), + Seq( + Row(3.1415926, Row(null, null, null, null, "mapvalue1")), + Row(4.1415926, Row("mapvalue2", null, null, null, null)), + Row(5.1415926, Row(null, null, "mapvalue3", null, null)), + 
Row(6.1415926, Row(null, null, null, "mapvalue4", null)), + Row(7.1415926, Row(null, "mapvalue5", null, null, null)))) + } + } + } + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 68d370791..7c19cab12 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -445,6 +445,34 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 03:00:00', 'E', 15, 'Vancouver')") } + protected def createMapAndDecimalTimeSeriesTable(testTable: String): Unit = { + // CSV tables do not support MAP types so we use JSON instead + val finalTableType = if (tableType == "CSV") "JSON" else tableType + + sql(s""" + | CREATE TABLE $testTable + | ( + | time TIMESTAMP, + | name STRING, + | age INT, + | base_score DECIMAL(8, 7), + | mymap MAP + | ) + | USING $finalTableType $tableOptions + |""".stripMargin) + + sql( + s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:01:00', 'A', 30, 3.1415926, Map('mapkey1', 'mapvalue1'))") + sql( + s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:10:00', 'B', 20, 4.1415926, Map('mapkey2', 'mapvalue2'))") + sql( + s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:15:00', 'C', 35, 5.1415926, Map('mapkey3', 'mapvalue3'))") + sql( + s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 01:00:00', 'D', 40, 6.1415926, Map('mapkey4', 'mapvalue4'))") + sql( + s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 03:00:00', 'E', 15, 7.1415926, Map('mapkey5', 'mapvalue5'))") + } + protected def createTimeSeriesTransactionTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala index 4788aa23f..ca96c126f 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala @@ -277,6 +277,26 @@ class FlintSparkPPLFillnullITSuite assert(ex.getMessage().contains("Syntax error ")) } + test("test fillnull with null_replacement type mismatch") { + val frame = sql(s""" + | source = $testTable | fillnull with cast(0 as long) in status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, "/home", null, 200), + Row(2, "/about", "2023-10-01 10:05:00", 0), + Row(3, "/contact", "2023-10-01 10:10:00", 0), + Row(4, null, "2023-10-01 10:15:00", 301), + Row(5, null, "2023-10-01 10:20:00", 200), + Row(6, "/home", null, 403)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + private def fillNullExpectedPlan( nullReplacements: Seq[(String, Expression)], addDefaultProject: Boolean = true): LogicalPlan = { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 1ed5f9059..1c99f657d 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -25,6 +25,7 @@
 import org.apache.spark.sql.catalyst.plans.logical.Generate;
 import org.apache.spark.sql.catalyst.plans.logical.Limit;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$;
 import org.apache.spark.sql.catalyst.plans.logical.Project$;
 import org.apache.spark.sql.catalyst.plans.logical.UnresolvedTableSpec;
 import org.apache.spark.sql.connector.expressions.FieldReference;
@@ -465,10 +466,30 @@ public LogicalPlan visitFillNull(FillNull fillNull, CatalystPlanContext context)
         Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
         // build the plan with the projection step
         context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
-        LogicalPlan resultWithoutDuplicatedColumns = context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(toDrop), logicalPlan));
+        LogicalPlan resultWithoutDuplicatedColumns = context.apply(dropOriginalColumns(p -> p.children().head(), toDrop));
         return Objects.requireNonNull(resultWithoutDuplicatedColumns, "FillNull operation failed");
     }
 
+    /**
+     * Generates a DataFrameDropColumns operator that drops the duplicated columns of the
+     * original plan, effectively achieving the same result as updating those columns in place.
+     *
+     * PLAN_ID_TAG is an internal Spark mechanism that explicitly designates the plan against
+     * which UnresolvedAttributes are resolved. The toDrop expressions are tagged with the same
+     * PLAN_ID_TAG value as the original plan, so Spark resolves them against that plan rather
+     * than against the child of the DataFrameDropColumns operator.
+     */
+    private java.util.function.Function<LogicalPlan, LogicalPlan> dropOriginalColumns(
+            java.util.function.Function<LogicalPlan, LogicalPlan> findOriginalPlan,
+            List<Expression> toDrop) {
+        return logicalPlan -> {
+            LogicalPlan originalPlan = findOriginalPlan.apply(logicalPlan);
+            long planId = logicalPlan.hashCode();
+            originalPlan.setTagValue(LogicalPlan$.MODULE$.PLAN_ID_TAG(), planId);
+            toDrop.forEach(e -> e.setTagValue(LogicalPlan$.MODULE$.PLAN_ID_TAG(), planId));
+            return DataFrameDropColumns$.MODULE$.apply(seq(toDrop), logicalPlan);
+        };
+    }
+
     @Override
     public LogicalPlan visitFlatten(Flatten flatten, CatalystPlanContext context) {
         visitFirstChild(flatten, context);
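    // Illustrative example (not part of the patch): for a PPL query like
    //   source = t | fillnull with 0 in status_code
    // visitFillNull projects coalesce(status_code, 0) alongside the existing columns and then
    // applies dropOriginalColumns(...). Tagging the original plan and the status_code attribute
    // in toDrop with the same PLAN_ID_TAG value makes the analyzer resolve that attribute against
    // the pre-projection plan, so DataFrameDropColumns removes the original column rather than
    // the newly computed replacement that shares its name.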