[DOP-18631] - update test
maxim-lixakov committed Aug 5, 2024
1 parent 3bd7515 commit 3e3fc6c
Showing 4 changed files with 251 additions and 35 deletions.
@@ -11,17 +11,12 @@ private object ClickhouseDialectExtension extends JdbcDialect {

private val logger = LoggerFactory.getLogger(getClass)

- private val arrayTypePattern: Regex = "^Array\\((.*)\\)$".r
- private val nullableTypePattern: Regex = "^Nullable\\((.*)\\)$".r
- private val dateTypePattern: Regex = "^[dD][aA][tT][eE]$".r
- private val dateTimeTypePattern: Regex =
-   "^[dD][aA][tT][eE][tT][iI][mM][eE](64)?(\\((.*)\\))?$".r
- private val decimalTypePattern: Regex =
-   "^[dD][eE][cC][iI][mM][aA][lL]\\((\\d+),\\s*(\\d+)\\)$".r
- private val decimalTypePattern2: Regex =
-   "^[dD][eE][cC][iI][mM][aA][lL](32|64|128|256)\\((\\d+)\\)$".r
- private val enumTypePattern: Regex = "^Enum(8|16)$".r
- private val fixedStringTypePattern: Regex = "^FixedString\\((\\d+)\\)$".r
+ private val arrayTypePattern: Regex = """^Array\((.*)\)$""".r
+ private val nullableTypePattern: Regex = """^Nullable\((.*)\)$""".r
+ private val dateTypePattern: Regex = """(?i)^Date$""".r
+ private val dateTimeTypePattern: Regex = """(?i)^DateTime(64)?(\((.*)\))?$""".r
+ private val decimalTypePattern: Regex = """(?i)^Decimal\((\d+),\s*(\d+)\)$""".r
+ private val decimalTypePattern2: Regex = """(?i)^Decimal(32|64|128|256)\((\d+)\)$""".r

override def canHandle(url: String): Boolean = {
url.startsWith("jdbc:clickhouse")
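For reference, the rewritten patterns swap the verbose [dD][aA]... character classes for the (?i) inline case-insensitivity flag, and triple-quoted strings remove the double escaping. A minimal sketch of the matching behavior (illustrative only, not part of the commit; the object and value names are hypothetical):

    import scala.util.matching.Regex

    object RegexSketch {
      val dateTimeTypePattern: Regex = """(?i)^DateTime(64)?(\((.*)\))?$""".r

      def main(args: Array[String]): Unit =
        Seq("DateTime", "datetime64(3)", "DATETIME64(6, 'UTC')", "Date").foreach { t =>
          // Pattern matching against a Regex works the same in Scala 2.12 and 2.13.
          val matched = t match {
            case dateTimeTypePattern(_*) => true
            case _                       => false
          }
          println(s"$t -> $matched") // Date -> false: only DateTime variants match
        }
    }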
@@ -51,6 +46,8 @@ private object ClickhouseDialectExtension extends JdbcDialect {
case Types.ARRAY =>
unwrapNullable(typeName) match {
case (_, arrayTypePattern(nestType)) =>
+       // Due to https://github.com/ClickHouse/clickhouse-java/issues/1754, Spark cannot
+       // read Arrays of any type except Decimal(...) and String
toCatalystType(Types.ARRAY, nestType, size, scale, md).map {
case (nullable, dataType) => ArrayType(dataType, nullable)
}
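The Array branch above peels type wrappers apart before recursing on the element type. A self-contained sketch of that unwrapping; the unwrapNullable implementation shown here is an assumption mirroring the dialect's helper, which this diff does not include:

    import scala.util.matching.Regex

    object UnwrapSketch {
      private val arrayTypePattern: Regex    = """^Array\((.*)\)$""".r
      private val nullableTypePattern: Regex = """^Nullable\((.*)\)$""".r

      // Assumed behavior: strip one Nullable(...) layer and report whether it was present.
      def unwrapNullable(typeName: String): (Boolean, String) = typeName match {
        case nullableTypePattern(inner) => (true, inner)
        case other                      => (false, other)
      }

      def main(args: Array[String]): Unit =
        // Match the outer Array first, then unwrap the element's Nullable marker,
        // mirroring the recursion through toCatalystType.
        "Array(Nullable(Int32))" match {
          case arrayTypePattern(elem) =>
            val (nullable, inner) = unwrapNullable(elem)
            println(s"element=$inner, nullable=$nullable") // element=Int32, nullable=true
          case _ => println("not an array")
        }
    }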
@@ -68,7 +65,7 @@
md: MetadataBuilder): Option[(Boolean, DataType)] = {
val (nullable, _typeName) = unwrapNullable(typeName)
val dataType = _typeName match {
case "String" | "UUID" | fixedStringTypePattern() | enumTypePattern(_) =>
case "String" =>
logger.debug(s"Custom mapping applied: StringType for '${_typeName}'")
Some(StringType)
case "Int8" =>
Expand All @@ -80,10 +77,10 @@ private object ClickhouseDialectExtension extends JdbcDialect {
case "UInt16" | "Int32" =>
logger.debug(s"Custom mapping applied: IntegerType for '${_typeName}'")
Some(IntegerType)
case "UInt32" | "Int64" | "UInt64" | "IPv4" =>
case "UInt32" | "Int64" =>
logger.debug(s"Custom mapping applied: LongType for '${_typeName}'")
Some(LongType)
case "Int128" | "Int256" | "UInt256" =>
case "UInt64" | "Int128" | "Int256" | "UInt256" =>
logger.debug(s"Type '${_typeName}' is not supported")
None
case "Float32" =>
@@ -172,6 +172,196 @@ class ClickhouseDialectTest
statement.close()
}

test("read ClickHouse UInt8 as Spark ShortType") {
setupTable("uByteColumn UInt8")
insertTestData(Seq("(0)", "(255)")) // min and max values for unsigned byte

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == ShortType)

val data = df.collect().map(_.getShort(0)).sorted
assert(data sameElements Array(0, 255))
}

test("read ClickHouse UInt16 as Spark IntegerType") {
setupTable("uShortColumn UInt16")
insertTestData(Seq("(0)", "(65535)")) // min and max values for unsigned short

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == IntegerType)

val data = df.collect().map(_.getInt(0)).sorted
assert(data sameElements Array(0, 65535))
}

test("read ClickHouse UInt32 as Spark LongType") {
setupTable("uIntColumn UInt32")
insertTestData(Seq("(0)", "(4294967295)")) // min and max values for unsigned int

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == LongType)

val data = df.collect().map(_.getLong(0)).sorted
assert(data sameElements Array(0L, 4294967295L))
}

test("read ClickHouse Float32 as Spark FloatType") {
setupTable("floatColumn Float32")
insertTestData(Seq("(-1.23)", "(4.56)"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == FloatType)
val data = df.collect().map(_.getFloat(0)).sorted
assert(data sameElements Array(-1.23f, 4.56f))
}

test("read ClickHouse Float64 as Spark DoubleType") {
setupTable("doubleColumn Float64")
insertTestData(Seq("(-1.7976931348623157E308)", "(1.7976931348623157E308)")) // min and max values for double

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DoubleType)

val data = df.collect().map(_.getDouble(0)).sorted
assert(data sameElements Array(-1.7976931348623157E308, 1.7976931348623157E308))
}

test("read ClickHouse Decimal(9,2) as Spark DecimalType") {
setupTable("decimalColumn Decimal(9,2)")
insertTestData(Seq("(12345.67)", "(89012.34)"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DecimalType(9, 2))

val data = df.collect().map(_.getDecimal(0)).sorted
assert(data sameElements Array(
new java.math.BigDecimal("12345.67"),
new java.math.BigDecimal("89012.34")
))
}

test("read ClickHouse Decimal32 as Spark DecimalType(9, scale)") {
setupTable("decimalColumn Decimal32(2)")
insertTestData(Seq("(12345.67)", "(89012.34)"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DecimalType(9, 2))

val data = df.collect().map(_.getDecimal(0)).sorted
assert(data sameElements Array(
new java.math.BigDecimal("12345.67"),
new java.math.BigDecimal("89012.34")
))
}

test("read ClickHouse Decimal64 as Spark DecimalType(18, scale)") {
setupTable("decimalColumn Decimal64(2)")
insertTestData(Seq("(123456789.12)", "(987654321.34)"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DecimalType(18, 2))

val data = df.collect().map(_.getDecimal(0)).sorted
assert(data sameElements Array(
new java.math.BigDecimal("123456789.12"),
new java.math.BigDecimal("987654321.34")
))
}

test("read ClickHouse Decimal128 as Spark DecimalType(38, scale)") {
setupTable("decimalColumn Decimal128(2)")
insertTestData(Seq("(123456789012345678901234567890.12)", "(987654321098765432109876543210.34)"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DecimalType(38, 2))

val data = df.collect().map(_.getDecimal(0)).sorted
assert(data sameElements Array(
new java.math.BigDecimal("123456789012345678901234567890.12"),
new java.math.BigDecimal("987654321098765432109876543210.34")
))
}

test("read ClickHouse Date as Spark DateType") {
setupTable("dateColumn Date")
insertTestData(Seq("('2023-01-01')", "('2023-12-31')"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == DateType)

val data = df.collect().map(_.getDate(0))
assert(data sameElements Array(java.sql.Date.valueOf("2023-01-01"), java.sql.Date.valueOf("2023-12-31")))
}

test("read ClickHouse DateTime as Spark TimestampType") {
setupTable("datetimeColumn DateTime")
insertTestData(Seq("('2023-01-01 12:34:56')", "('2023-12-31 23:59:59')"))

val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.load()

assert(df.schema.fields.head.dataType == TimestampType)

val data = df.collect().map(_.getTimestamp(0))
assert(data sameElements Array(
java.sql.Timestamp.valueOf("2023-01-01 12:34:56"),
java.sql.Timestamp.valueOf("2023-12-31 23:59:59")
))
}

test("write Spark TimestampType as ClickHouse Datetime64(6)") {
val schema = StructType(Seq(StructField("timestampColumn", TimestampType, nullable = true)))
val currentTime = new java.sql.Timestamp(System.currentTimeMillis())
@@ -257,53 +447,63 @@ class ClickhouseDialectTest
}
}

- val testWriteArrayCases: TableFor3[String, Seq[Row], DataType] = Table(
-   ("columnDefinition", "insertedData", "expectedType"),
+ val testWriteArrayCases = Table(
+   ("columnName", "insertedData", "expectedType", "expectedClickhouseType"),
    (
-     "charArrayColumn Array(String)",
+     "charArrayColumn",
      Seq(Row(Array("a", "b", "c", "d", "e"))),
-     ArrayType(StringType, containsNull = false)),
+     ArrayType(StringType, containsNull = false),
+     "Array(String)"),
    (
-     "byteArrayColumn Array(Int8)",
+     "byteArrayColumn",
      Seq(Row(Array(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte))),
-     ArrayType(ByteType, containsNull = false)),
+     ArrayType(ByteType, containsNull = false),
+     "Array(Int8)"),
    (
-     "intArrayColumn Array(Int32)",
+     "intArrayColumn",
      Seq(Row(Array(1, 2, 3, 4, 5))),
-     ArrayType(IntegerType, containsNull = false)),
+     ArrayType(IntegerType, containsNull = false),
+     "Array(Int32)"),
    (
-     "longArrayColumn Array(Int64)",
+     "longArrayColumn",
      Seq(Row(Array(1L, 2L, 3L, 4L, 5L))),
-     ArrayType(LongType, containsNull = false)),
+     ArrayType(LongType, containsNull = false),
+     "Array(Int64)"),
    (
-     "floatArrayColumn Array(Float32)",
+     "floatArrayColumn",
      Seq(Row(Array(1.0f, 2.0f, 3.0f, 4.0f, 5.0f))),
-     ArrayType(FloatType, containsNull = false)),
+     ArrayType(FloatType, containsNull = false),
+     "Array(Float32)"),
    (
-     "dateArrayColumn Array(Date)",
+     "dateArrayColumn",
      Seq(
        Row(
          Array(
            java.sql.Date.valueOf("2022-01-01"),
            java.sql.Date.valueOf("2022-01-02"),
            java.sql.Date.valueOf("2022-01-03")))),
-     ArrayType(DateType, containsNull = false)),
+     ArrayType(DateType, containsNull = false),
+     "Array(Date)"),
    (
-     "decimalArrayColumn Array(Decimal(9,2))",
+     "decimalArrayColumn",
      Seq(
        Row(Array(
          new java.math.BigDecimal("1.23"),
          new java.math.BigDecimal("2.34"),
          new java.math.BigDecimal("3.45"),
          new java.math.BigDecimal("4.56"),
          new java.math.BigDecimal("5.67")))),
-     ArrayType(DecimalType(9, 2), containsNull = false)))
+     ArrayType(DecimalType(9, 2), containsNull = false),
+     "Array(Decimal(9, 2))"))

forAll(testWriteArrayCases) {
-   (columnDefinition: String, insertedData: Seq[Row], expectedType: DataType) =>
-     test(s"write ClickHouse Array for $columnDefinition column") {
+   (
+       columnName: String,
+       insertedData: Seq[Row],
+       expectedType: DataType,
+       expectedClickhouseType: String) =>
+     test(s"write ClickHouse Array for $columnName column") {

-       val columnName = columnDefinition.split(" ")(0)
val schema = StructType(Array(StructField(columnName, expectedType, nullable = false)))
val df = spark.createDataFrame(spark.sparkContext.parallelize(insertedData), schema)

@@ -317,7 +517,8 @@ class ClickhouseDialectTest
.mode("errorIfExists")
.save()

-       assert(df.count() == 1)
+       val actualColumnType = getColumnType(columnName)
+       assert(actualColumnType == expectedClickhouseType)
}
}
}
@@ -27,6 +27,7 @@ trait SharedSparkSession extends BeforeAndAfterAll { self: Suite =>
.appName("Spark Test Session")
.config("spark.ui.enabled", "false") // disable UI to reduce overhead
.config("spark.jars", jarPaths) // include the JAR file containing the custom dialect
.config("spark.driver.bindAddress", "127.0.0.1")
.getOrCreate()

// register custom Clickhouse dialect
@@ -47,9 +47,26 @@ trait ClickhouseFixture extends BeforeAndAfterEach { self: Suite =>

override def afterEach(): Unit = {
val statement = connection.createStatement()
-   // statement.executeUpdate(s"DROP TABLE IF EXISTS $tableName")
+   statement.executeUpdate(s"DROP TABLE IF EXISTS $tableName")
statement.close()
connection.close()
super.afterEach()
}

def getColumnType(columnName: String): String = {
val query = s"DESCRIBE TABLE $tableName"
val statement = connection.createStatement()
val resultSet = statement.executeQuery(query)

var columnType = ""
while (resultSet.next()) {
if (resultSet.getString("name") == columnName) {
columnType = resultSet.getString("type")
}
}

resultSet.close()
statement.close()
columnType
}
}
