Skip to content

Commit

Permalink
Apply CR changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Oct 25, 2024
1 parent 6e57be9 commit e924d32
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.github.luben.zstd.ZstdInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.util.zip.GZIPInputStream;
import net.snowflake.common.core.SqlState;
import org.apache.http.Header;
Expand All @@ -16,39 +15,24 @@ class CompressedStreamFactory {

private static final int STREAM_BUFFER_SIZE = MB;

/**
* Determine the format of the response, if it is not either plain text or gzip, raise an error.
*/
public InputStream createBasedOnEncodingHeader(InputStream is, Header encoding)
throws IOException, SnowflakeSQLException {
InputStream inputStream = is; // Determine the format of the response, if it is not
// either plain text or gzip, raise an error.
if (encoding != null) {
if (GZIP.name().equalsIgnoreCase(encoding.getValue())) {
/* specify buffer size for GZIPInputStream */
inputStream = new GZIPInputStream(is, STREAM_BUFFER_SIZE);
return new GZIPInputStream(is, STREAM_BUFFER_SIZE);
} else if (ZSTD.name().equalsIgnoreCase(encoding.getValue())) {
inputStream = new ZstdInputStream(is);
return new ZstdInputStream(is);
} else {
throw new SnowflakeSQLException(
SqlState.INTERNAL_ERROR,
ErrorCode.INTERNAL_ERROR.getMessageCode(),
"Exception: unexpected compression got " + encoding.getValue());
}
} else {
inputStream = detectGzipAndGetStream(is);
}

return inputStream;
}

private InputStream detectGzipAndGetStream(InputStream is) throws IOException {
PushbackInputStream pb = new PushbackInputStream(is, 2);
byte[] signature = new byte[2];
int len = pb.read(signature);
pb.unread(signature, 0, len);
// https://tools.ietf.org/html/rfc1952
if (signature[0] == (byte) 0x1f && signature[1] == (byte) 0x8b) {
return new GZIPInputStream(pb);
} else {
return pb;
return DefaultResultStreamProvider.detectGzipAndGetStream(is);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import net.snowflake.client.core.ExecTimeTelemetryData;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.log.ArgSupplier;
Expand Down Expand Up @@ -145,4 +147,17 @@ else if (context.getQrmk() != null) {
response);
return response;
}

public static InputStream detectGzipAndGetStream(InputStream is) throws IOException {
PushbackInputStream pb = new PushbackInputStream(is, 2);
byte[] signature = new byte[2];
int len = pb.read(signature);
pb.unread(signature, 0, len);
// https://tools.ietf.org/html/rfc1952
if (signature[0] == (byte) 0x1f && signature[1] == (byte) 0x8b) {
return new GZIPInputStream(pb);
} else {
return pb;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package net.snowflake.client.jdbc;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStream;
Expand All @@ -10,6 +11,7 @@
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.io.IOUtils;
import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
import org.junit.Test;
Expand Down Expand Up @@ -42,16 +44,8 @@ public void testDetectContentEncodingAndGetInputStream_Gzip() throws Exception {
InputStream resultStream = factory.createBasedOnEncodingHeader(gzipStream, encodingHeader);

// Decompress and validate the data matches original
ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int bytesRead;
try (GZIPInputStream gzipInputStream = (GZIPInputStream) resultStream) {
while ((bytesRead = gzipInputStream.read(buffer)) != -1) {
decompressedOutput.write(buffer, 0, bytesRead);
}
}
String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8);

assertTrue(resultStream instanceof GZIPInputStream);
String decompressedData = IOUtils.toString(resultStream, StandardCharsets.UTF_8);
assertEquals(originalData, decompressedData);
}

Expand Down Expand Up @@ -79,16 +73,8 @@ public void testDetectContentEncodingAndGetInputStream_Zstd() throws Exception {
InputStream resultStream = factory.createBasedOnEncodingHeader(zstdStream, encodingHeader);

// Decompress and validate the data matches original
ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int bytesRead;
try (ZstdInputStream zstdInputStream = (ZstdInputStream) resultStream) {
while ((bytesRead = zstdInputStream.read(buffer)) != -1) {
decompressedOutput.write(buffer, 0, bytesRead);
}
}
String decompressedData = new String(decompressedOutput.toByteArray(), StandardCharsets.UTF_8);

assertTrue(resultStream instanceof ZstdInputStream);
String decompressedData = IOUtils.toString(resultStream, StandardCharsets.UTF_8);
assertEquals(originalData, decompressedData);
}
}

0 comments on commit e924d32

Please sign in to comment.