diff --git a/src/snowflake/connector/vendored/urllib3/response.py b/src/snowflake/connector/vendored/urllib3/response.py index cb7a06465..4011df4ec 100644 --- a/src/snowflake/connector/vendored/urllib3/response.py +++ b/src/snowflake/connector/vendored/urllib3/response.py @@ -17,6 +17,23 @@ except ImportError: brotli = None +try: + import zstandard as zstd +except (AttributeError, ImportError, ValueError): # Defensive: + HAS_ZSTD = False +else: + # Extract major and minor version numbers + version_parts = zstd.__version__.split(".") + major_version = int(version_parts[0]) + minor_version = int(version_parts[1]) + + # Check if the version is at least 0.18.0 + if (major_version, minor_version) < (0, 18): # Defensive: + HAS_ZSTD = False + else: + HAS_ZSTD = True + + from . import util from ._collections import HTTPHeaderDict from .connection import BaseSSLError, HTTPException @@ -126,6 +143,29 @@ def flush(self): return b"" +if HAS_ZSTD: + + class ZstdDecoder(object): + def __init__(self) -> None: + self._obj = zstd.ZstdDecompressor().decompressobj() + + def decompress(self, data: bytes) -> bytes: + if not data: + return b"" + data_parts = [self._obj.decompress(data)] + while self._obj.eof and self._obj.unused_data: + unused_data = self._obj.unused_data + self._obj = zstd.ZstdDecompressor().decompressobj() + data_parts.append(self._obj.decompress(unused_data)) + return b"".join(data_parts) + + def flush(self) -> bytes: + ret = self._obj.flush() # note: this is a no-op + if not self._obj.eof: + raise DecodeError("Zstandard data is incomplete") + return ret + + class MultiDecoder(object): """ From RFC7231: @@ -157,6 +197,9 @@ def _get_decoder(mode): if brotli is not None and mode == "br": return BrotliDecoder() + if HAS_ZSTD and mode == "zstd": + return ZstdDecoder() + return DeflateDecoder() @@ -196,6 +239,8 @@ class is also compatible with the Python standard library's :mod:`io` CONTENT_DECODERS = ["gzip", "deflate"] if brotli is not None: CONTENT_DECODERS += ["br"] + if HAS_ZSTD: + CONTENT_DECODERS += ["zstd"] REDIRECT_STATUSES = [301, 302, 303, 307, 308] def __init__( @@ -394,6 +439,8 @@ def _init_decoder(self): DECODER_ERROR_CLASSES = (IOError, zlib.error) if brotli is not None: DECODER_ERROR_CLASSES += (brotli.error,) + if HAS_ZSTD: + DECODER_ERROR_CLASSES += (zstd.ZstdError,) def _decode(self, data, decode_content, flush_decoder): """