Skip to content

Commit

Permalink
Support different compression formats for CSV files
Browse files Browse the repository at this point in the history
  • Loading branch information
Jolanrensen committed Oct 4, 2024
1 parent 843be8e commit 2f51ebf
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.deephaven.csv.CsvSpecs
import org.apache.commons.csv.CSVFormat
import org.jetbrains.kotlinx.dataframe.api.ParserOptions
import org.jetbrains.kotlinx.dataframe.io.ColType
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
import org.jetbrains.kotlinx.dataframe.io.DEFAULT_COL_TYPE
import org.jetbrains.kotlinx.dataframe.io.QuoteMode

Expand Down Expand Up @@ -31,10 +32,11 @@ internal object CsvTsvParams {
val HEADER: List<String> = emptyList()

/**
* @param isCompressed If `true`, the input stream is compressed and will be decompressed before reading.
* The default is `false`.
* @param compression Determines the compression of the CSV file.
* If a ZIP file contains multiple files, an [IllegalArgumentException] is thrown.
* The default is [CsvCompression.None].
*/
const val IS_COMPRESSED: Boolean = false
val COMPRESSION: CsvCompression<*> = CsvCompression.None

/**
* @param colTypes A map of column names to their expected [ColType]s. Can be supplied to force
Expand Down Expand Up @@ -70,7 +72,7 @@ internal object CsvTsvParams {
)

/**
* @param ignoreEmptyLines If `true`, empty lines will be skipped.
* @param ignoreEmptyLines If `true`, intermediate empty lines will be skipped.
* The default is `false`.
*/
const val IGNORE_EMPTY_LINES: Boolean = false
Expand All @@ -79,9 +81,9 @@ internal object CsvTsvParams {
* @param allowMissingColumns If this is set to `true`, then rows that are too short
* (that have fewer columns than the header row) will be interpreted as if the missing columns contained
* the empty string.
* The default is `false`.
* The default is `true`.
*/
const val ALLOW_MISSING_COLUMNS: Boolean = false
const val ALLOW_MISSING_COLUMNS: Boolean = true

/**
* @param ignoreExcessColumns If this is set to `true`, then rows that are too long
Expand Down Expand Up @@ -158,7 +160,7 @@ internal object CsvTsvParams {
* @param recordSeparator The character that separates records in a CSV/TSV file.
* The default is `'\n'`.
*/
const val RECORD_SEPARATOR: Char = '\n'
const val RECORD_SEPARATOR: String = "\n"

/**
* @param headerComments A list of comments to include at the beginning of the CSV/TSV file.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
package org.jetbrains.kotlinx.dataframe.impl.io

import org.apache.commons.io.input.BOMInputStream
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Custom
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Gzip
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.None
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Zip
import org.jetbrains.kotlinx.dataframe.io.isURL
import org.jetbrains.kotlinx.dataframe.io.readJson
import java.io.File
import java.io.InputStream
import java.net.HttpURLConnection
import java.net.URL
import java.util.zip.ZipInputStream

internal fun isCompressed(fileOrUrl: String) = listOf("gz", "zip").contains(fileOrUrl.split(".").last())
/**
 * Infers the [CsvCompression] state from the extension of [fileOrUrl].
 *
 * Matching is case-insensitive, so e.g. `"data.GZ"` is detected as [CsvCompression.Gzip].
 * A path without a recognized extension (or without any `.` at all) falls back to
 * [CsvCompression.None], matching the previous behavior.
 */
internal fun compressionStateOf(fileOrUrl: String): CsvCompression<*> =
    // substringAfterLast('.') returns the whole string when there is no '.',
    // which, like split(".").last(), ends up in the `else` branch
    when (fileOrUrl.substringAfterLast('.').lowercase()) {
        "gz" -> CsvCompression.Gzip
        "zip" -> CsvCompression.Zip
        else -> CsvCompression.None
    }

internal fun isCompressed(file: File) = listOf("gz", "zip").contains(file.extension)
/**
 * Infers the [CsvCompression] state from the extension of [file].
 *
 * Matching is case-insensitive (e.g. `"data.ZIP"` is detected as [CsvCompression.Zip]);
 * any other or missing extension maps to [CsvCompression.None].
 */
internal fun compressionStateOf(file: File): CsvCompression<*> =
    when (file.extension.lowercase()) {
        "gz" -> CsvCompression.Gzip
        "zip" -> CsvCompression.Zip
        else -> CsvCompression.None
    }

internal fun isCompressed(url: URL) = isCompressed(url.path)
/** Infers the [CsvCompression] state of [url] from the extension of its path component. */
internal fun compressionStateOf(url: URL): CsvCompression<*> = compressionStateOf(fileOrUrl = url.path)

internal fun catchHttpResponse(url: URL, body: (InputStream) -> AnyFrame): AnyFrame {
val connection = url.openConnection()
Expand Down Expand Up @@ -42,5 +59,41 @@ public fun asURL(fileOrUrl: String): URL =
if (isURL(fileOrUrl)) {
URL(fileOrUrl).toURI()
} else {
File(fileOrUrl).toURI()
File(fileOrUrl).also {
require(it.exists()) { "File not found: \"$fileOrUrl\"" }
require(it.isFile) { "Not a file: \"$fileOrUrl\"" }
}.toURI()
}.toURL()

/**
 * Runs [block] on this [InputStream] after wrapping it according to [compression],
 * and with any leading byte-order mark stripped, returning [block]'s result.
 *
 * For [Zip] input, exactly one zip entry is expected: an empty archive fails
 * immediately, and an archive with more than one entry causes an
 * [IllegalArgumentException] after [block] completes.
 *
 * NOTE(review): the check in `finally` throws even when [block] itself threw,
 * which would mask the original exception — confirm this is intended.
 *
 * @param compression how to decompress this stream before reading ([None], [Gzip], [Zip], or [Custom]).
 * @param block the reader to run on the decompressed, BOM-free stream.
 * @throws IllegalStateException if a [Zip] stream contains no entries.
 * @throws IllegalArgumentException if a [Zip] stream contains more than one entry.
 */
internal inline fun <T> InputStream.useSafely(compression: CsvCompression<*>, block: (InputStream) -> T): T {
    // only non-null when compression == Zip; used for the single-entry check below
    var zipInputStream: ZipInputStream? = null

    // first wrap the stream in the compression algorithm
    val unpackedStream = when (compression) {
        None -> this

        Zip -> compression(this).also {
            it as ZipInputStream
            // make sure to call nextEntry once to prepare the stream
            if (it.nextEntry == null) error("No entries in zip file")

            zipInputStream = it
        }

        Gzip -> compression(this)

        is Custom<*> -> compression(this)
    }

    // strip a UTF BOM, if present, so it doesn't leak into the first column name
    val bomSafeStream = BOMInputStream.builder().setInputStream(unpackedStream).get()

    try {
        return block(bomSafeStream)
    } finally {
        // if we were reading from a ZIP, make sure there was only one entry, as to
        // warn the user of potential issues
        if (compression == Zip && zipInputStream!!.nextEntry != null) {
            throw IllegalArgumentException("Zip file contains more than one entry")
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import kotlinx.datetime.Instant
import kotlinx.datetime.LocalDate
import kotlinx.datetime.LocalDateTime
import kotlinx.datetime.LocalTime
import org.apache.commons.io.input.BOMInputStream
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
Expand All @@ -34,11 +33,11 @@ import org.jetbrains.kotlinx.dataframe.api.tryParse
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
import org.jetbrains.kotlinx.dataframe.io.ColType
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
import org.jetbrains.kotlinx.dataframe.io.DEFAULT_COL_TYPE
import java.io.InputStream
import java.math.BigDecimal
import java.net.URL
import java.util.zip.GZIPInputStream
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf
Expand All @@ -49,7 +48,7 @@ import kotlin.time.Duration
* @include [CsvTsvParams.INPUT_STREAM]
* @param delimiter The field delimiter character. The default is ',' for CSV, '\t' for TSV.
* @include [CsvTsvParams.HEADER]
* @include [CsvTsvParams.IS_COMPRESSED]
* @include [CsvTsvParams.COMPRESSION]
* @include [CsvTsvParams.COL_TYPES]
* @include [CsvTsvParams.SKIP_LINES]
* @include [CsvTsvParams.READ_LINES]
Expand All @@ -67,7 +66,7 @@ internal fun readCsvOrTsvImpl(
inputStream: InputStream,
delimiter: Char,
header: List<String> = CsvTsvParams.HEADER,
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand Down Expand Up @@ -115,32 +114,30 @@ internal fun readCsvOrTsvImpl(
colTypes(colTypes, useDeepHavenLocalDateTime) // this function must be last, so the return value is used
}.build()

val adjustedInputStream = inputStream
.let { if (isCompressed) GZIPInputStream(it) else it }
.let { BOMInputStream.builder().setInputStream(it).get() }

if (adjustedInputStream.available() <= 0) {
return if (header.isEmpty()) {
DataFrame.empty()
} else {
dataFrameOf(
header.map {
DataColumn.createValueColumn(
name = it,
values = emptyList<String>(),
type = typeOf<String>(),
)
},
)
val csvReaderResult = inputStream.useSafely(compression) { safeInputStream ->
if (safeInputStream.available() <= 0) {
return if (header.isEmpty()) {
DataFrame.empty()
} else {
dataFrameOf(
header.map {
DataColumn.createValueColumn(
name = it,
values = emptyList<String>(),
type = typeOf<String>(),
)
},
)
}
}
}

// read the csv
val csvReaderResult = CsvReader.read(
csvSpecs,
adjustedInputStream,
ListSink.SINK_FACTORY,
)
// read the csv
CsvReader.read(
csvSpecs,
safeInputStream,
ListSink.SINK_FACTORY,
)
}

val defaultColType = colTypes[DEFAULT_COL_TYPE]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal fun writeCsvOrTsvImpl(
escapeChar: Char? = CsvTsvParams.ESCAPE_CHAR,
commentChar: Char? = CsvTsvParams.COMMENT_CHAR,
headerComments: List<String> = CsvTsvParams.HEADER_COMMENTS,
recordSeparator: Char = CsvTsvParams.RECORD_SEPARATOR,
recordSeparator: String = CsvTsvParams.RECORD_SEPARATOR,
additionalCsvFormat: CSVFormat = CsvTsvParams.ADDITIONAL_CSV_FORMAT,
) {
val format = with(CSVFormat.Builder.create(additionalCsvFormat)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.jetbrains.kotlinx.dataframe.io

import java.io.InputStream
import java.util.zip.GZIPInputStream
import java.util.zip.ZipInputStream

/**
 * Compression algorithm to use when reading csv files.
 * We support GZIP and ZIP compression out of the box.
 *
 * Custom compression algorithms can be added by creating an instance of [Custom].
 *
 * Each instance is itself callable as `(InputStream) -> I`, wrapping the given
 * raw stream in the matching decompressing stream via [wrapStream].
 */
public sealed class CsvCompression<I : InputStream>(public open val wrapStream: (InputStream) -> I) :
    (InputStream) -> I by wrapStream {

    /** GZIP compression; wraps the input in a [GZIPInputStream]. */
    public data object Gzip : CsvCompression<GZIPInputStream>(::GZIPInputStream)

    /** ZIP compression; wraps the input in a [ZipInputStream]. */
    public data object Zip : CsvCompression<ZipInputStream>(::ZipInputStream)

    /** No compression; the input stream is passed through unchanged. */
    public data object None : CsvCompression<InputStream>({ it })

    /** User-supplied compression: [wrapStream] produces the decompressing stream. */
    public data class Custom<I : InputStream>(override val wrapStream: (InputStream) -> I) :
        CsvCompression<I>(wrapStream)
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import org.jetbrains.kotlinx.dataframe.api.ParserOptions
import org.jetbrains.kotlinx.dataframe.impl.io.CsvTsvParams
import org.jetbrains.kotlinx.dataframe.impl.io.asURL
import org.jetbrains.kotlinx.dataframe.impl.io.catchHttpResponse
import org.jetbrains.kotlinx.dataframe.impl.io.isCompressed
import org.jetbrains.kotlinx.dataframe.impl.io.compressionStateOf
import org.jetbrains.kotlinx.dataframe.impl.io.readCsvOrTsvImpl
import java.io.File
import java.io.FileInputStream
Expand All @@ -24,6 +24,7 @@ public fun DataFrame.Companion.readCsv(
file: File,
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
header: List<String> = CsvTsvParams.HEADER,
compression: CsvCompression<*> = compressionStateOf(file),
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand All @@ -41,7 +42,7 @@ public fun DataFrame.Companion.readCsv(
inputStream = it,
delimiter = delimiter,
header = header,
isCompressed = isCompressed(file),
compression = compression,
colTypes = colTypes,
skipLines = skipLines,
readLines = readLines,
Expand All @@ -61,6 +62,7 @@ public fun DataFrame.Companion.readCsv(
url: URL,
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
header: List<String> = CsvTsvParams.HEADER,
compression: CsvCompression<*> = compressionStateOf(url),
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand All @@ -78,7 +80,7 @@ public fun DataFrame.Companion.readCsv(
inputStream = it,
delimiter = delimiter,
header = header,
isCompressed = isCompressed(url),
compression = compression,
colTypes = colTypes,
skipLines = skipLines,
readLines = readLines,
Expand All @@ -98,6 +100,7 @@ public fun DataFrame.Companion.readCsv(
fileOrUrl: String,
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
header: List<String> = CsvTsvParams.HEADER,
compression: CsvCompression<*> = compressionStateOf(fileOrUrl),
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand All @@ -115,7 +118,7 @@ public fun DataFrame.Companion.readCsv(
inputStream = it,
delimiter = delimiter,
header = header,
isCompressed = isCompressed(fileOrUrl),
compression = compression,
colTypes = colTypes,
skipLines = skipLines,
readLines = readLines,
Expand All @@ -136,7 +139,7 @@ public fun DataFrame.Companion.readCsv(
inputStream: InputStream,
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
header: List<String> = CsvTsvParams.HEADER,
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand All @@ -154,7 +157,7 @@ public fun DataFrame.Companion.readCsv(
inputStream = inputStream,
delimiter = delimiter,
header = header,
isCompressed = isCompressed,
compression = compression,
colTypes = colTypes,
skipLines = skipLines,
readLines = readLines,
Expand All @@ -174,7 +177,7 @@ public fun DataFrame.Companion.readCsvStr(
text: String,
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
header: List<String> = CsvTsvParams.HEADER,
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
skipLines: Long = CsvTsvParams.SKIP_LINES,
readLines: Long? = CsvTsvParams.READ_LINES,
Expand All @@ -191,7 +194,7 @@ public fun DataFrame.Companion.readCsvStr(
inputStream = text.byteInputStream(),
delimiter = delimiter,
header = header,
isCompressed = isCompressed,
compression = compression,
colTypes = colTypes,
skipLines = skipLines,
readLines = readLines,
Expand Down
Loading

0 comments on commit 2f51ebf

Please sign in to comment.