Merge pull request #38 from civitaspo/support-nanos
Support timestamp-nanos
civitaspo authored Apr 30, 2020
2 parents 5290a64 + ee21b3b commit 95df81c
Showing 7 changed files with 286 additions and 53 deletions.
@@ -203,6 +203,8 @@ class CatalogRegistrator(
       case "timestamp-millis" => "timestamp"
       case "timestamp-micros" =>
         "bigint" // Glue cannot recognize timestamp-micros.
+      case "timestamp-nanos" =>
+        "bigint" // Glue cannot recognize timestamp-nanos.
       case "int8" => "tinyint"
       case "int16" => "smallint"
       case "int32" => "int"
@@ -47,70 +47,62 @@ object EmbulkMessageType {
         LogicalTypeHandlerStore.empty
   ) extends ColumnVisitor {
 
-    override def booleanColumn(column: Column): Unit = {
-      builder.add(
-        Types.optional(PrimitiveTypeName.BOOLEAN).named(column.getName)
-      )
-    }
-
-    override def longColumn(column: Column): Unit = {
-      val name = column.getName
-      val et = column.getType
-
-      val t = logicalTypeHandlers.get(name, et) match {
-        case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name)
-        case _ =>
-          Types.optional(PrimitiveTypeName.INT64).named(column.getName)
-      }
-
-      builder.add(t)
-    }
-
-    override def doubleColumn(column: Column): Unit = {
-      builder.add(
-        Types.optional(PrimitiveTypeName.DOUBLE).named(column.getName)
-      )
-    }
-
-    override def stringColumn(column: Column): Unit = {
-      builder.add(
-        Types
-          .optional(PrimitiveTypeName.BINARY)
-          .as(LogicalTypeAnnotation.stringType())
-          .named(column.getName)
-      )
-    }
-
-    override def timestampColumn(column: Column): Unit = {
-      val name = column.getName
-      val et = column.getType
-
-      val t = logicalTypeHandlers.get(name, et) match {
-        case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name)
-        case _ =>
-          Types
-            .optional(PrimitiveTypeName.BINARY)
-            .as(LogicalTypeAnnotation.stringType())
-            .named(name)
-      }
-
-      builder.add(t)
-    }
-
-    override def jsonColumn(column: Column): Unit = {
-      val name = column.getName
-      val et = column.getType
-
-      val t = logicalTypeHandlers.get(name, et) match {
-        case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name)
-        case _ =>
-          Types
-            .optional(PrimitiveTypeName.BINARY)
-            .as(LogicalTypeAnnotation.stringType())
-            .named(name)
-      }
-
-      builder.add(t)
-    }
+    private def addTypeByLogicalTypeHandlerOrDefault(
+        column: Column,
+        default: => Type
+    ): Unit = {
+      builder.add(
+        logicalTypeHandlers.get(column.getName, column.getType) match {
+          case Some(handler) if handler.isConvertible(column.getType) =>
+            handler.newSchemaFieldType(column.getName)
+          case _ => default
+        }
+      )
+    }
+
+    override def booleanColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types.optional(PrimitiveTypeName.BOOLEAN).named(column.getName)
+      })
+    }
+
+    override def longColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types.optional(PrimitiveTypeName.INT64).named(column.getName)
+      })
+    }
+
+    override def doubleColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types.optional(PrimitiveTypeName.DOUBLE).named(column.getName)
+      })
+    }
+
+    override def stringColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types
+          .optional(PrimitiveTypeName.BINARY)
+          .as(LogicalTypeAnnotation.stringType())
+          .named(column.getName)
+      })
+    }
+
+    override def timestampColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types
+          .optional(PrimitiveTypeName.BINARY)
+          .as(LogicalTypeAnnotation.stringType())
+          .named(column.getName)
+      })
+    }
+
+    override def jsonColumn(column: Column): Unit = {
+      addTypeByLogicalTypeHandlerOrDefault(column, default = {
+        Types
+          .optional(PrimitiveTypeName.BINARY)
+          .as(LogicalTypeAnnotation.stringType())
+          .named(column.getName)
+      })
+    }
   }
 
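A note on the helper introduced above: `default` is declared as a by-name parameter (`default: => Type`), so the fallback Parquet field is only constructed when no logical-type handler matches the column. A minimal, self-contained sketch of that evaluation behavior (illustrative names only, not part of this change):

```scala
// By-name evaluation, analogous to `default: => Type` in addTypeByLogicalTypeHandlerOrDefault.
def pick(handler: Option[String], default: => String): String =
  handler match {
    case Some(h) => h       // `default` is never evaluated on this branch
    case None    => default // evaluated only when no handler matches
  }

// Returns "from-handler" and prints nothing, because the by-name argument
// is left unevaluated when a handler is present.
pick(Some("from-handler"), { println("building default"); "fallback" })
```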

@@ -103,6 +103,36 @@ object TimestampMicrosLogicalTypeHandler extends LogicalTypeHandler {
   }
 }
 
+object TimestampNanosLogicalTypeHandler extends LogicalTypeHandler {
+
+  override def isConvertible(t: EmbulkType): Boolean = {
+    t == EmbulkTypes.TIMESTAMP
+  }
+
+  override def newSchemaFieldType(name: String): PrimitiveType = {
+    Types
+      .optional(PrimitiveTypeName.INT64)
+      .as(
+        LogicalTypeAnnotation
+          .timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)
+      )
+      .named(name)
+  }
+
+  override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = {
+    orig match {
+      case ts: Timestamp =>
+        val v =
+          (ts.getEpochSecond * 1_000_000_000L) + ts.getNano.asInstanceOf[Long]
+        recordConsumer.addLong(v)
+      case _ =>
+        throw new DataException(
+          "given mismatched type value; expected type is timestamp"
+        )
+    }
+  }
+}
+
 object Int8LogicalTypeHandler
     extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(8, true))
 
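For reference, `consume` above folds an Embulk timestamp into a single INT64 of nanoseconds since the epoch, matching the `timestampType(true, TimeUnit.NANOS)` annotation declared by `newSchemaFieldType`. A small worked example with hypothetical values (not taken from this change):

```scala
// Hypothetical stand-ins for Timestamp.getEpochSecond and Timestamp.getNano.
val epochSecond = 1588230000L
val nano = 123456789L

// Same arithmetic as consume(): seconds scaled to nanoseconds plus the sub-second part.
val encoded = epochSecond * 1_000_000_000L + nano
// encoded == 1588230000123456789L, the value handed to recordConsumer.addLong.
```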
@@ -51,6 +51,7 @@ object LogicalTypeHandlerStore {
   private val STRING_TO_LOGICAL_TYPE = Map[String, LogicalTypeHandler](
     "timestamp-millis" -> TimestampMillisLogicalTypeHandler,
     "timestamp-micros" -> TimestampMicrosLogicalTypeHandler,
+    "timestamp-nanos" -> TimestampNanosLogicalTypeHandler,
     "int8" -> Int8LogicalTypeHandler,
     "int16" -> Int16LogicalTypeHandler,
     "int32" -> Int32LogicalTypeHandler,
@@ -13,8 +13,11 @@ import com.amazonaws.services.s3.transfer.{
   TransferManagerBuilder
 }
 import com.google.inject.{Binder, Guice, Module, Stage}
-import org.apache.parquet.hadoop.ParquetReader
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path => HadoopPath}
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetReader}
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.parquet.schema.MessageType
 import org.apache.parquet.tools.read.{SimpleReadSupport, SimpleRecord}
 import org.embulk.{TestPluginSourceModule, TestUtilityModule}
 import org.embulk.config.{
@@ -116,7 +119,8 @@ abstract class EmbulkPluginTestHelper
   def runOutput(
       outConfig: ConfigSource,
       schema: Schema,
-      data: Seq[Seq[Any]]
+      data: Seq[Seq[Any]],
+      messageTypeTest: MessageType => Unit = { _ => }
   ): Seq[Seq[AnyRef]] = {
     try {
       Exec.doWith(
@@ -157,7 +161,7 @@
       case ex: ExecutionException => throw ex.getCause
     }
 
-    readS3Parquet(TEST_BUCKET_NAME, TEST_PATH_PREFIX)
+    readS3Parquet(TEST_BUCKET_NAME, TEST_PATH_PREFIX, messageTypeTest)
   }
 
   private def withLocalStackS3Client[A](f: AmazonS3 => A): A = {
@@ -180,7 +184,11 @@
     finally client.shutdown()
   }
 
-  def readS3Parquet(bucket: String, prefix: String): Seq[Seq[AnyRef]] = {
+  private def readS3Parquet(
+      bucket: String,
+      prefix: String,
+      messageTypeTest: MessageType => Unit = { _ => }
+  ): Seq[Seq[AnyRef]] = {
     val tmpDir: Path = Files.createTempDirectory("embulk-output-parquet")
     withLocalStackS3Client { s3 =>
       val xfer: TransferManager = TransferManagerBuilder
@@ -207,11 +215,21 @@
       .map(_.getAbsolutePath)
       .foldLeft(Seq[Seq[AnyRef]]()) {
         (result: Seq[Seq[AnyRef]], path: String) =>
-          result ++ readParquetFile(path)
+          result ++ readParquetFile(path, messageTypeTest)
       }
   }
 
-  private def readParquetFile(pathString: String): Seq[Seq[AnyRef]] = {
+  private def readParquetFile(
+      pathString: String,
+      messageTypeTest: MessageType => Unit = { _ => }
+  ): Seq[Seq[AnyRef]] = {
+    Using.resource(
+      ParquetFileReader.open(
+        HadoopInputFile
+          .fromPath(new HadoopPath(pathString), new Configuration())
+      )
+    ) { reader => messageTypeTest(reader.getFileMetaData.getSchema) }
+
     val reader: ParquetReader[SimpleRecord] = ParquetReader
       .builder(
         new SimpleReadSupport(),
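
The new `messageTypeTest` hook lets a spec assert on the schema of the written Parquet file rather than only on the rows read back. A hedged sketch of how a test might use it; `outConfig`, `schema`, `data`, and the column name `"ts"` are placeholders for whatever the existing specs already construct:

```scala
runOutput(outConfig, schema, data, messageTypeTest = { messageType =>
  // Assert the file declares an INT64 timestamp column with NANOS precision
  // for the hypothetical "ts" column.
  val annotation = messageType
    .getType("ts")
    .asPrimitiveType()
    .getLogicalTypeAnnotation
  assert(
    annotation == LogicalTypeAnnotation
      .timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)
  )
})
```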