diff --git a/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverter.java
index 788de87e0..f61e9954d 100644
--- a/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverter.java
+++ b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverter.java
@@ -7,22 +7,8 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
-import java.util.Map;
import java.util.TimeZone;
-import net.snowflake.client.core.DataConversionContext;
-import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
-import net.snowflake.client.jdbc.ErrorCode;
-import net.snowflake.client.jdbc.SnowflakeSQLException;
-import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
-import net.snowflake.client.jdbc.SnowflakeType;
-import net.snowflake.common.core.SqlState;
-import org.apache.arrow.vector.ValueVector;
-import org.apache.arrow.vector.complex.FixedSizeListVector;
-import org.apache.arrow.vector.complex.ListVector;
-import org.apache.arrow.vector.complex.MapVector;
-import org.apache.arrow.vector.complex.StructVector;
-import org.apache.arrow.vector.types.Types;
/** Interface to convert from arrow vector values into java data types. */
public interface ArrowVectorConverter {
@@ -177,201 +163,4 @@ public interface ArrowVectorConverter {
* @param isUTC true or false value of whether NTZ timestamp should be set to UTC
*/
void setTreatNTZAsUTC(boolean isUTC);
-
- /**
- * Given an arrow vector (a single column in a single record batch), return an arrow vector
- * converter. Note, converter is built on top of arrow vector, so that arrow data can be converted
- * back to java data
- *
- * <p>Arrow converter mappings for Snowflake fixed-point numbers
- * ----------------------------------------------------------------------------------------- Max
- * position and scale Converter
- * -----------------------------------------------------------------------------------------
- * number(3,0) {@link TinyIntToFixedConverter} number(3,2) {@link TinyIntToScaledFixedConverter}
- * number(5,0) {@link SmallIntToFixedConverter} number(5,4) {@link SmallIntToScaledFixedConverter}
- * number(10,0) {@link IntToFixedConverter} number(10,9) {@link IntToScaledFixedConverter}
- * number(19,0) {@link BigIntToFixedConverter} number(19,18) {@link BigIntToFixedConverter}
- * number(38,37) {@link DecimalToScaledFixedConverter}
- * ------------------------------------------------------------------------------------------
- *
- * @param vector an arrow vector
- * @param context data conversion context
- * @param session SFBaseSession for purposes of logging
- * @param idx the index of the vector in its batch
- * @return A converter on top og the vector
- */
- static ArrowVectorConverter initConverter(
- ValueVector vector, DataConversionContext context, SFBaseSession session, int idx)
- throws SnowflakeSQLException {
- // arrow minor type
- Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());
-
- // each column's metadata
- Map<String, String> customMeta = vector.getField().getMetadata();
- if (type == Types.MinorType.DECIMAL) {
- // Note: Decimal vector is different from others
- return new DecimalToScaledFixedConverter(vector, idx, context);
- } else if (!customMeta.isEmpty()) {
- SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType"));
- switch (st) {
- case ANY:
- case CHAR:
- case TEXT:
- case VARIANT:
- return new VarCharConverter(vector, idx, context);
-
- case MAP:
- if (vector instanceof MapVector) {
- return new MapConverter((MapVector) vector, idx, context);
- } else {
- return new VarCharConverter(vector, idx, context);
- }
-
- case VECTOR:
- return new VectorTypeConverter((FixedSizeListVector) vector, idx, context);
-
- case ARRAY:
- if (vector instanceof ListVector) {
- return new ArrayConverter((ListVector) vector, idx, context);
- } else {
- return new VarCharConverter(vector, idx, context);
- }
-
- case OBJECT:
- if (vector instanceof StructVector) {
- return new StructConverter((StructVector) vector, idx, context);
- } else {
- return new VarCharConverter(vector, idx, context);
- }
-
- case BINARY:
- return new VarBinaryToBinaryConverter(vector, idx, context);
-
- case BOOLEAN:
- return new BitToBooleanConverter(vector, idx, context);
-
- case DATE:
- boolean getFormatDateWithTimeZone = false;
- if (context.getSession() != null) {
- getFormatDateWithTimeZone = context.getSession().getFormatDateWithTimezone();
- }
- return new DateConverter(vector, idx, context, getFormatDateWithTimeZone);
-
- case FIXED:
- String scaleStr = vector.getField().getMetadata().get("scale");
- int sfScale = Integer.parseInt(scaleStr);
- switch (type) {
- case TINYINT:
- if (sfScale == 0) {
- return new TinyIntToFixedConverter(vector, idx, context);
- } else {
- return new TinyIntToScaledFixedConverter(vector, idx, context, sfScale);
- }
- case SMALLINT:
- if (sfScale == 0) {
- return new SmallIntToFixedConverter(vector, idx, context);
- } else {
- return new SmallIntToScaledFixedConverter(vector, idx, context, sfScale);
- }
- case INT:
- if (sfScale == 0) {
- return new IntToFixedConverter(vector, idx, context);
- } else {
- return new IntToScaledFixedConverter(vector, idx, context, sfScale);
- }
- case BIGINT:
- if (sfScale == 0) {
- return new BigIntToFixedConverter(vector, idx, context);
- } else {
- return new BigIntToScaledFixedConverter(vector, idx, context, sfScale);
- }
- }
- break;
-
- case REAL:
- return new DoubleToRealConverter(vector, idx, context);
-
- case TIME:
- switch (type) {
- case INT:
- return new IntToTimeConverter(vector, idx, context);
- case BIGINT:
- return new BigIntToTimeConverter(vector, idx, context);
- default:
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected Arrow Field for ",
- st.name());
- }
-
- case TIMESTAMP_LTZ:
- if (vector.getField().getChildren().isEmpty()) {
- // case when the scale of the timestamp is equal or smaller than millisecs since epoch
- return new BigIntToTimestampLTZConverter(vector, idx, context);
- } else if (vector.getField().getChildren().size() == 2) {
- // case when the scale of the timestamp is larger than millisecs since epoch, e.g.,
- // nanosecs
- return new TwoFieldStructToTimestampLTZConverter(vector, idx, context);
- } else {
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected Arrow Field for ",
- st.name());
- }
-
- case TIMESTAMP_NTZ:
- if (vector.getField().getChildren().isEmpty()) {
- // case when the scale of the timestamp is equal or smaller than 7
- return new BigIntToTimestampNTZConverter(vector, idx, context);
- } else if (vector.getField().getChildren().size() == 2) {
- // when the timestamp is represent in two-field struct
- return new TwoFieldStructToTimestampNTZConverter(vector, idx, context);
- } else {
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected Arrow Field for ",
- st.name());
- }
-
- case TIMESTAMP_TZ:
- if (vector.getField().getChildren().size() == 2) {
- // case when the scale of the timestamp is equal or smaller than millisecs since epoch
- return new TwoFieldStructToTimestampTZConverter(vector, idx, context);
- } else if (vector.getField().getChildren().size() == 3) {
- // case when the scale of the timestamp is larger than millisecs since epoch, e.g.,
- // nanosecs
- return new ThreeFieldStructToTimestampTZConverter(vector, idx, context);
- } else {
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected SnowflakeType ",
- st.name());
- }
-
- default:
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected Arrow Field for ",
- st.name());
- }
- }
- throw new SnowflakeSQLLoggedException(
- session,
- ErrorCode.INTERNAL_ERROR.getMessageCode(),
- SqlState.INTERNAL_ERROR,
- "Unexpected Arrow Field for ",
- type.toString());
- }
}
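
With `initConverter` gone from the interface, call sites go through the new `ArrowVectorConverterUtil` (added below). A minimal sketch of that call path for a NUMBER(10,0) column; the `context` and `session` parameters stand in for whatever `DataConversionContext`/`SFBaseSession` the caller already holds, and are assumptions for illustration, not part of this patch:

```java
import java.util.HashMap;
import java.util.Map;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.ArrowVectorConverterUtil;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;

public class InitConverterSketch {
  // Hypothetical helper: resolves and uses a converter for one column.
  static int readFirstValue(DataConversionContext context, SFBaseSession session)
      throws SnowflakeSQLException, SFException {
    // Snowflake encodes the logical type and scale in the Arrow field metadata;
    // FIXED with scale 0 on a 32-bit int dispatches to IntToFixedConverter.
    Map<String, String> meta = new HashMap<>();
    meta.put("logicalType", "FIXED");
    meta.put("scale", "0");
    Field field =
        new Field("C1", new FieldType(true, new ArrowType.Int(32, true), null, meta), null);
    try (RootAllocator allocator = new RootAllocator();
        IntVector vector = (IntVector) field.createVector(allocator)) {
      vector.allocateNew(1);
      vector.setSafe(0, 42);
      vector.setValueCount(1);
      ArrowVectorConverter converter =
          ArrowVectorConverterUtil.initConverter(vector, context, session, 0);
      return converter.toInt(0); // 42
    }
  }
}
```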
diff --git a/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java
new file mode 100644
index 000000000..1aa84db8f
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java
@@ -0,0 +1,219 @@
+package net.snowflake.client.core.arrow;
+
+import java.util.Map;
+import net.snowflake.client.core.DataConversionContext;
+import net.snowflake.client.core.SFBaseSession;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.ErrorCode;
+import net.snowflake.client.jdbc.SnowflakeSQLException;
+import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
+import net.snowflake.client.jdbc.SnowflakeType;
+import net.snowflake.common.core.SqlState;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.types.Types;
+
+@SnowflakeJdbcInternalApi
+public final class ArrowVectorConverterUtil {
+ private ArrowVectorConverterUtil() {}
+
+ /**
+ * Given an Arrow vector (a single column in a single record batch), return an Arrow vector
+ * converter. Note: the converter is built on top of the Arrow vector, so that Arrow data can be
+ * converted back to Java data types.
+ *
+ * <p>Arrow converter mappings for Snowflake fixed-point numbers:
+ *
+ * <pre>
+ * Max precision and scale     Converter
+ * -----------------------     ---------------------------------------
+ * number(3,0)                 {@link TinyIntToFixedConverter}
+ * number(3,2)                 {@link TinyIntToScaledFixedConverter}
+ * number(5,0)                 {@link SmallIntToFixedConverter}
+ * number(5,4)                 {@link SmallIntToScaledFixedConverter}
+ * number(10,0)                {@link IntToFixedConverter}
+ * number(10,9)                {@link IntToScaledFixedConverter}
+ * number(19,0)                {@link BigIntToFixedConverter}
+ * number(19,18)               {@link BigIntToScaledFixedConverter}
+ * number(38,37)               {@link DecimalToScaledFixedConverter}
+ * </pre>
+ *
+ * @param vector an arrow vector
+ * @param context data conversion context
+ * @param session SFBaseSession for purposes of logging
+ * @param idx the index of the vector in its batch
+ * @return A converter built on top of the vector
+ * @throws SnowflakeSQLException if no converter can be created for the column type
+ */
+ public static ArrowVectorConverter initConverter(
+ ValueVector vector, DataConversionContext context, SFBaseSession session, int idx)
+ throws SnowflakeSQLException {
+ // arrow minor type
+ Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());
+
+ // each column's metadata
+ Map<String, String> customMeta = vector.getField().getMetadata();
+ if (type == Types.MinorType.DECIMAL) {
+ // Note: Decimal vector is different from others
+ return new DecimalToScaledFixedConverter(vector, idx, context);
+ } else if (!customMeta.isEmpty()) {
+ SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType"));
+ switch (st) {
+ case ANY:
+ case CHAR:
+ case TEXT:
+ case VARIANT:
+ return new VarCharConverter(vector, idx, context);
+
+ case MAP:
+ if (vector instanceof MapVector) {
+ return new MapConverter((MapVector) vector, idx, context);
+ } else {
+ return new VarCharConverter(vector, idx, context);
+ }
+
+ case VECTOR:
+ return new VectorTypeConverter((FixedSizeListVector) vector, idx, context);
+
+ case ARRAY:
+ if (vector instanceof ListVector) {
+ return new ArrayConverter((ListVector) vector, idx, context);
+ } else {
+ return new VarCharConverter(vector, idx, context);
+ }
+
+ case OBJECT:
+ if (vector instanceof StructVector) {
+ return new StructConverter((StructVector) vector, idx, context);
+ } else {
+ return new VarCharConverter(vector, idx, context);
+ }
+
+ case BINARY:
+ return new VarBinaryToBinaryConverter(vector, idx, context);
+
+ case BOOLEAN:
+ return new BitToBooleanConverter(vector, idx, context);
+
+ case DATE:
+ boolean getFormatDateWithTimeZone = false;
+ if (context.getSession() != null) {
+ getFormatDateWithTimeZone = context.getSession().getFormatDateWithTimezone();
+ }
+ return new DateConverter(vector, idx, context, getFormatDateWithTimeZone);
+
+ case FIXED:
+ String scaleStr = vector.getField().getMetadata().get("scale");
+ int sfScale = Integer.parseInt(scaleStr);
+ switch (type) {
+ case TINYINT:
+ if (sfScale == 0) {
+ return new TinyIntToFixedConverter(vector, idx, context);
+ } else {
+ return new TinyIntToScaledFixedConverter(vector, idx, context, sfScale);
+ }
+ case SMALLINT:
+ if (sfScale == 0) {
+ return new SmallIntToFixedConverter(vector, idx, context);
+ } else {
+ return new SmallIntToScaledFixedConverter(vector, idx, context, sfScale);
+ }
+ case INT:
+ if (sfScale == 0) {
+ return new IntToFixedConverter(vector, idx, context);
+ } else {
+ return new IntToScaledFixedConverter(vector, idx, context, sfScale);
+ }
+ case BIGINT:
+ if (sfScale == 0) {
+ return new BigIntToFixedConverter(vector, idx, context);
+ } else {
+ return new BigIntToScaledFixedConverter(vector, idx, context, sfScale);
+ }
+ }
+ break;
+
+ case REAL:
+ return new DoubleToRealConverter(vector, idx, context);
+
+ case TIME:
+ switch (type) {
+ case INT:
+ return new IntToTimeConverter(vector, idx, context);
+ case BIGINT:
+ return new BigIntToTimeConverter(vector, idx, context);
+ default:
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected Arrow Field for ",
+ st.name());
+ }
+
+ case TIMESTAMP_LTZ:
+ if (vector.getField().getChildren().isEmpty()) {
+ // case when the scale of the timestamp is equal to or smaller than milliseconds since the epoch
+ return new BigIntToTimestampLTZConverter(vector, idx, context);
+ } else if (vector.getField().getChildren().size() == 2) {
+ // case when the scale of the timestamp is larger than milliseconds since the epoch, e.g.,
+ // nanoseconds
+ return new TwoFieldStructToTimestampLTZConverter(vector, idx, context);
+ } else {
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected Arrow Field for ",
+ st.name());
+ }
+
+ case TIMESTAMP_NTZ:
+ if (vector.getField().getChildren().isEmpty()) {
+ // case when the scale of the timestamp is equal to or smaller than 7
+ return new BigIntToTimestampNTZConverter(vector, idx, context);
+ } else if (vector.getField().getChildren().size() == 2) {
+ // case when the timestamp is represented as a two-field struct
+ return new TwoFieldStructToTimestampNTZConverter(vector, idx, context);
+ } else {
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected Arrow Field for ",
+ st.name());
+ }
+
+ case TIMESTAMP_TZ:
+ if (vector.getField().getChildren().size() == 2) {
+ // case when the scale of the timestamp is equal to or smaller than milliseconds since the epoch
+ return new TwoFieldStructToTimestampTZConverter(vector, idx, context);
+ } else if (vector.getField().getChildren().size() == 3) {
+ // case when the scale of the timestamp is larger than milliseconds since the epoch, e.g.,
+ // nanoseconds
+ return new ThreeFieldStructToTimestampTZConverter(vector, idx, context);
+ } else {
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected SnowflakeType ",
+ st.name());
+ }
+
+ default:
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected Arrow Field for ",
+ st.name());
+ }
+ }
+ throw new SnowflakeSQLLoggedException(
+ session,
+ ErrorCode.INTERNAL_ERROR.getMessageCode(),
+ SqlState.INTERNAL_ERROR,
+ "Unexpected Arrow Field for ",
+ type.toString());
+ }
+}
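
One detail of the FIXED branch worth making concrete: a non-zero `scale` means the wire value is a scaled integer, and the `*ScaledFixedConverter`s rebuild the decimal from it. A self-contained sketch of that arithmetic (the literals are illustrative, not taken from the converters):

```java
import java.math.BigDecimal;
import java.math.BigInteger;

public class ScaledFixedSketch {
  public static void main(String[] args) {
    // A NUMBER(10,2) column arrives as a plain Arrow INT vector whose field
    // metadata carries scale=2, so a stored 12345 denotes 123.45.
    long raw = 12345L;
    int sfScale = 2; // Integer.parseInt(metadata.get("scale")) in initConverter
    BigDecimal value = new BigDecimal(BigInteger.valueOf(raw), sfScale);
    System.out.println(value); // prints 123.45
  }
}
```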
diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java
index dca895464..f69469542 100644
--- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java
+++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java
@@ -3,7 +3,7 @@
*/
package net.snowflake.client.jdbc;
-import static net.snowflake.client.core.arrow.ArrowVectorConverter.initConverter;
+import static net.snowflake.client.core.arrow.ArrowVectorConverterUtil.initConverter;
import java.io.IOException;
import java.io.InputStream;
diff --git a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java
index 3ee556bb4..567db8fa1 100644
--- a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java
+++ b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java
@@ -2,6 +2,7 @@
import static net.snowflake.client.core.Constants.MB;
+import com.github.luben.zstd.ZstdInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
@@ -153,6 +154,8 @@ private InputStream detectContentEncodingAndGetInputStream(HttpResponse response
if ("gzip".equalsIgnoreCase(encoding.getValue())) {
/* specify buffer size for GZIPInputStream */
inputStream = new GZIPInputStream(is, STREAM_BUFFER_SIZE);
+ } else if ("zstd".equalsIgnoreCase(encoding.getValue())) {
+ inputStream = new ZstdInputStream(is);
} else {
throw new SnowflakeSQLException(
SqlState.INTERNAL_ERROR,
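
The new branch mirrors the gzip one but lets zstd-jni handle the framing. A hedged round-trip sketch using the same two classes the new unit test below exercises:

```java
import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;

public class ZstdRoundTripSketch {
  public static void main(String[] args) throws Exception {
    byte[] payload = "result chunk".getBytes(StandardCharsets.UTF_8);

    // Compress as a server sending Content-Encoding: zstd would.
    ByteArrayOutputStream compressed = new ByteArrayOutputStream();
    try (ZstdOutputStream zos = new ZstdOutputStream(compressed)) {
      zos.write(payload);
    }

    // Decompress the way the new branch does: wrap the raw response stream.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try (ZstdInputStream zis =
        new ZstdInputStream(new ByteArrayInputStream(compressed.toByteArray()))) {
      byte[] buf = new byte[8192];
      for (int n; (n = zis.read(buf)) != -1; ) {
        out.write(buf, 0, n);
      }
    }
    System.out.println(new String(out.toByteArray(), StandardCharsets.UTF_8)); // result chunk
  }
}
```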
diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeDriver.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeDriver.java
index 05566da82..060ac977e 100644
--- a/src/main/java/net/snowflake/client/jdbc/SnowflakeDriver.java
+++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeDriver.java
@@ -37,7 +37,7 @@ public class SnowflakeDriver implements Driver {
static SnowflakeDriver INSTANCE;
public static final Properties EMPTY_PROPERTIES = new Properties();
- public static String implementVersion = "3.19.0";
+ public static String implementVersion = "3.19.1";
static int majorVersion = 0;
static int minorVersion = 0;
diff --git a/src/test/java/net/snowflake/client/core/arrow/BaseConverterTest.java b/src/test/java/net/snowflake/client/core/arrow/BaseConverterTest.java
index 20a07a655..e669ac006 100644
--- a/src/test/java/net/snowflake/client/core/arrow/BaseConverterTest.java
+++ b/src/test/java/net/snowflake/client/core/arrow/BaseConverterTest.java
@@ -3,6 +3,7 @@
*/
package net.snowflake.client.core.arrow;
+import java.nio.ByteOrder;
import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFSession;
@@ -10,6 +11,8 @@
import net.snowflake.common.core.SFBinaryFormat;
import net.snowflake.common.core.SnowflakeDateTimeFormat;
import org.junit.After;
+import org.junit.Assume;
+import org.junit.Before;
public class BaseConverterTest implements DataConversionContext {
private SnowflakeDateTimeFormat dateTimeFormat =
@@ -32,6 +35,13 @@ public void clearTimeZone() {
System.clearProperty("user.timezone");
}
+ @Before
+ public void assumeLittleEndian() {
+ Assume.assumeTrue(
+ "Arrow doesn't support cross endianness",
+ ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN));
+ }
+
@Override
public SnowflakeDateTimeFormat getTimestampLTZFormatter() {
return timestampLTZFormat;
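
The guard relies on JUnit 4 assumptions, which mark a test as skipped rather than failed when the condition does not hold. A standalone illustration (class and method names are hypothetical):

```java
import static org.junit.Assume.assumeTrue;

import java.nio.ByteOrder;
import org.junit.Test;

public class LittleEndianOnlyTest {
  @Test
  public void skipsOnBigEndianJvms() {
    // A failed assumption throws AssumptionViolatedException, so the runner
    // reports the test as skipped -- the effect the @Before hook above has
    // on every converter test.
    assumeTrue(
        "Arrow doesn't support cross endianness",
        ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN));
    // ...assertions that rely on little-endian Arrow buffers...
  }
}
```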
diff --git a/src/test/java/net/snowflake/client/jdbc/DefaultResultStreamProviderTest.java b/src/test/java/net/snowflake/client/jdbc/DefaultResultStreamProviderTest.java
new file mode 100644
index 000000000..b78bf3a5e
--- /dev/null
+++ b/src/test/java/net/snowflake/client/jdbc/DefaultResultStreamProviderTest.java
@@ -0,0 +1,120 @@
+package net.snowflake.client.jdbc;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.github.luben.zstd.ZstdInputStream;
+import com.github.luben.zstd.ZstdOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import java.lang.reflect.Method;
+import java.nio.charset.StandardCharsets;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+import org.apache.http.Header;
+import org.apache.http.HttpResponse;
+import org.junit.Before;
+import org.junit.Test;
+
+public class DefaultResultStreamProviderTest {
+
+ private DefaultResultStreamProvider resultStreamProvider;
+ private HttpResponse mockResponse;
+
+ @Before
+ public void setUp() {
+ resultStreamProvider = new DefaultResultStreamProvider();
+ mockResponse = mock(HttpResponse.class);
+ }
+
+ private InputStream invokeDetectContentEncodingAndGetInputStream(
+ HttpResponse response, InputStream inputStream) throws Exception {
+ Method method =
+ DefaultResultStreamProvider.class.getDeclaredMethod(
+ "detectContentEncodingAndGetInputStream", HttpResponse.class, InputStream.class);
+ method.setAccessible(true);
+ return (InputStream) method.invoke(resultStreamProvider, response, inputStream);
+ }
+
+ @Test
+ public void testDetectContentEncodingAndGetInputStream_Gzip() throws Exception {
+ // Mocking gzip content encoding
+ Header encodingHeader = mock(Header.class);
+ when(encodingHeader.getValue()).thenReturn("gzip");
+ when(mockResponse.getFirstHeader("Content-Encoding")).thenReturn(encodingHeader);
+
+ // Original data to compress and validate
+ String originalData = "Some data in GZIP";
+
+ // Creating a gzip byte array using GZIPOutputStream
+ byte[] gzipData;
+ try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) {
+ gzipOutputStream.write(originalData.getBytes(StandardCharsets.UTF_8));
+ gzipOutputStream.close(); // close to flush and finish the compression
+ gzipData = byteArrayOutputStream.toByteArray();
+ }
+
+ // Mocking input stream with the gzip data
+ InputStream gzipStream = new ByteArrayInputStream(gzipData);
+
+ // Call the private method using reflection
+ InputStream resultStream =
+ invokeDetectContentEncodingAndGetInputStream(mockResponse, gzipStream);
+
+ // Decompress and validate the data matches original
+ ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream();
+ byte[] buffer = new byte[1024];
+ int bytesRead;
+ try (GZIPInputStream gzipInputStream = (GZIPInputStream) resultStream) {
+ while ((bytesRead = gzipInputStream.read(buffer)) != -1) {
+ decompressedOutput.write(buffer, 0, bytesRead);
+ }
+ }
+ String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8);
+
+ assertEquals(originalData, decompressedData);
+ }
+
+ @Test
+ public void testDetectContentEncodingAndGetInputStream_Zstd() throws Exception {
+ // Mocking zstd content encoding
+ Header encodingHeader = mock(Header.class);
+ when(encodingHeader.getValue()).thenReturn("zstd");
+ when(mockResponse.getFirstHeader("Content-Encoding")).thenReturn(encodingHeader);
+
+ // Original data to compress and validate
+ String originalData = "Some data in ZSTD";
+
+ // Creating a zstd byte array using ZstdOutputStream
+ byte[] zstdData;
+ try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ ZstdOutputStream zstdOutputStream = new ZstdOutputStream(byteArrayOutputStream)) {
+ zstdOutputStream.write(originalData.getBytes(StandardCharsets.UTF_8));
+ zstdOutputStream.close(); // close to flush and finish the compression
+ zstdData = byteArrayOutputStream.toByteArray();
+ }
+
+ // Mocking input stream with the zstd data
+ InputStream zstdStream = new ByteArrayInputStream(zstdData);
+
+ // Call the private method using reflection
+ InputStream resultStream =
+ invokeDetectContentEncodingAndGetInputStream(mockResponse, zstdStream);
+
+ // Decompress and validate the data matches original
+ ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream();
+ byte[] buffer = new byte[1024];
+ int bytesRead;
+ try (ZstdInputStream zstdInputStream = (ZstdInputStream) resultStream) {
+ while ((bytesRead = zstdInputStream.read(buffer)) != -1) {
+ decompressedOutput.write(buffer, 0, bytesRead);
+ }
+ }
+ String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8);
+
+ assertEquals(originalData, decompressedData);
+ }
+}
diff --git a/thin_public_pom.xml b/thin_public_pom.xml
index 31a1aedee..a781a376d 100644
--- a/thin_public_pom.xml
+++ b/thin_public_pom.xml
@@ -64,6 +64,7 @@
UTF-8
2.0.13
1.6.9
+ <zstd-jni.version>1.5.6-5</zstd-jni.version>
@@ -262,6 +263,11 @@
<artifactId>jsoup</artifactId>
<version>${jsoup.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.github.luben</groupId>
+ <artifactId>zstd-jni</artifactId>
+ <version>${zstd-jni.version}</version>
+ </dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>