diff --git a/include/dpp/discordclient.h b/include/dpp/discordclient.h index 3ca87dd8b1..9e0222466a 100644 --- a/include/dpp/discordclient.h +++ b/include/dpp/discordclient.h @@ -36,6 +36,7 @@ #include #include #include +#include namespace dpp { @@ -52,18 +53,6 @@ namespace dpp { /* Forward declarations */ class cluster; -/** - * @brief This is an opaque class containing zlib library specific structures. - * We define it this way so that the public facing D++ library doesn't require - * the zlib headers be available to build against it. - */ -class zlibcontext; - -/** - * @brief Size of decompression buffer for zlib compressed traffic - */ -constexpr size_t DECOMP_BUFFER_SIZE = 512 * 1024; - /** * @brief How many seconds to wait between (re)connections. DO NOT change this. * It is mandated by the Discord API spec! @@ -308,15 +297,6 @@ class DPP_EXPORT discord_client : public websocket_client */ bool compressed; - /** - * @brief ZLib decompression buffer - * - * If compression is not in use, this remains set to - * a vector of length zero, but when compression is - * enabled it will be resized to a DECOMP_BUFFER_SIZE buffer. - */ - std::vector decomp_buffer; - /** * @brief Decompressed string */ @@ -330,11 +310,6 @@ class DPP_EXPORT discord_client : public websocket_client */ std::unique_ptr zlib{}; - /** - * @brief Total decompressed received bytes - */ - uint64_t decompressed_total; - /** * @brief Last connect time of cluster */ @@ -573,11 +548,13 @@ class DPP_EXPORT discord_client : public websocket_client /** * @brief Destroy the discord client object */ - virtual ~discord_client() override; + virtual ~discord_client() = default; /** - * @brief Get the decompressed bytes in objectGet decompressed total bytes received - * @return uint64_t bytes received + * @brief Get decompressed total bytes received + * + * This will always return 0 if the connection is not compressed + * @return uint64_t compressed bytes received */ uint64_t get_decompressed_bytes_in(); diff --git a/include/dpp/zlibcontext.h b/include/dpp/zlibcontext.h new file mode 100644 index 0000000000..fa8c6092f7 --- /dev/null +++ b/include/dpp/zlibcontext.h @@ -0,0 +1,85 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ +#pragma once +#include +#include +#include +#include +#include + +/** + * @brief Forward declaration for zlib stream type + */ +typedef struct z_stream_s z_stream; + +namespace dpp { + +/** + * @brief Size of decompression buffer for zlib compressed traffic + */ +constexpr size_t DECOMP_BUFFER_SIZE = 512 * 1024; + +/** + * @brief This is an opaque class containing zlib library specific structures. + * This wraps the C pointers needed for zlib with unique_ptr and gives us a nice + * buffer abstraction so we don't need to wrestle with raw pointers. + */ +class zlibcontext { +public: + /** + * @brief Zlib stream struct. The actual type is defined in zlib.h + * so is only defined in the implementation file. + */ + std::unique_ptr d_stream{}; + + /** + * @brief ZLib decompression buffer. + * This is automatically set to DECOMP_BUFFER_SIZE bytes when + * the class is constructed. + */ + std::vector decomp_buffer{}; + + /** + * @brief Total decompressed received bytes counter + */ + uint64_t decompressed_total{}; + + /** + * @brief Initialise zlib struct via inflateInit() + * and size the buffer + */ + zlibcontext(); + + /** + * @brief Destroy zlib struct via inflateEnd() + */ + ~zlibcontext(); + + /** + * @brief Decompress zlib deflated buffer + * @param buffer input compressed stream + * @param decompressed output decompressed content + * @return an error code on error, or err_no_code_specified (0) on success + */ + exception_error_code decompress(const std::string& buffer, std::string& decompressed); +}; + +} \ No newline at end of file diff --git a/src/dpp/discordclient.cpp b/src/dpp/discordclient.cpp index 9c13fde470..75f8686c6c 100644 --- a/src/dpp/discordclient.cpp +++ b/src/dpp/discordclient.cpp @@ -3,7 +3,7 @@ * D++, A Lightweight C++ library for Discord * * SPDX-License-Identifier: Apache-2.0 - * Copyright 2021 Craig Edwards and D++ contributors +#include * Copyright 2021 Craig Edwards and D++ contributors * (https://github.com/brainboxdotcc/DPP/graphs/contributors) * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,6 @@ #include #include #include -#include #define PATH_UNCOMPRESSED_JSON "/?v=" DISCORD_API_VERSION "&encoding=json" #define PATH_COMPRESSED_JSON "/?v=" DISCORD_API_VERSION "&encoding=json&compress=zlib-stream" @@ -48,45 +47,12 @@ namespace dpp { */ constexpr int LARGE_THRESHOLD = 250; -/** - * @brief This is an opaque class containing zlib library specific structures. - * We define it this way so that the public facing D++ library doesn't require - * the zlib headers be available to build against it. - */ -class zlibcontext { -public: - /** - * @brief Zlib stream struct - */ - z_stream d_stream{}; - - /** - * @brief Initialise zlib struct via inflateInit() - */ - zlibcontext() { - int error = inflateInit(&d_stream); - if (error != Z_OK) { - throw dpp::connection_exception((exception_error_code)error, "Can't initialise stream compression!"); - } - } - - /** - * @brief Destroy zlib struct via inflateEnd() - */ - ~zlibcontext() { - inflateEnd(&d_stream); - } -}; - - /** * @brief Resume constructor for websocket client */ discord_client::discord_client(discord_client &old, uint64_t sequence, const std::string& session_id) : websocket_client(old.owner, old.resume_gateway_url, "443", old.compressed ? (old.protocol == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (old.protocol == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)), compressed(old.compressed), - zlib(nullptr), - decompressed_total(old.decompressed_total), connect_time(0), ping_start(0.0), etf(nullptr), @@ -113,8 +79,6 @@ discord_client::discord_client(discord_client &old, uint64_t sequence, const std discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint32_t _max_shards, const std::string &_token, uint32_t _intents, bool comp, websocket_protocol_t ws_proto) : websocket_client(_cluster, _cluster->default_gateway, "443", comp ? (ws_proto == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (ws_proto == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)), compressed(comp), - zlib(nullptr), - decompressed_total(0), connect_time(0), ping_start(0.0), etf(nullptr), @@ -138,10 +102,9 @@ discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint3 } void discord_client::start_connecting() { - etf = std::make_unique(etf_parser()); + etf = std::make_unique(); if (compressed) { zlib = std::make_unique(); - decomp_buffer.resize(DECOMP_BUFFER_SIZE); } websocket_client::connect(); } @@ -150,10 +113,6 @@ void discord_client::cleanup() { } -discord_client::~discord_client() -{ -} - void discord_client::on_disconnect() { log(ll_trace, "discord_client::on_disconnect()"); @@ -167,7 +126,7 @@ void discord_client::on_disconnect() uint64_t discord_client::get_decompressed_bytes_in() { - return decompressed_total; + return zlib ? zlib->decompressed_total : 0; } void discord_client::set_resume_hostname() @@ -191,40 +150,12 @@ bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode) /* Check that we have a complete compressed frame */ if ((uint8_t)buffer[buffer.size() - 4] == 0x00 && (uint8_t)buffer[buffer.size() - 3] == 0x00 && (uint8_t)buffer[buffer.size() - 2] == 0xFF && (uint8_t)buffer[buffer.size() - 1] == 0xFF) { - /* Decompress buffer */ - decompressed.clear(); - /* This is safe; zlib requires us to cast away the const. The underlying buffer is unchanged. */ - zlib->d_stream.next_in = reinterpret_cast(const_cast(buffer.data())); - zlib->d_stream.avail_in = static_cast(buffer.size()); - do { - zlib->d_stream.next_out = static_cast(decomp_buffer.data()); - zlib->d_stream.avail_out = DECOMP_BUFFER_SIZE; - int ret = inflate(&(zlib->d_stream), Z_NO_FLUSH); - size_t have = DECOMP_BUFFER_SIZE - zlib->d_stream.avail_out; - switch (ret) - { - case Z_NEED_DICT: - case Z_STREAM_ERROR: - this->error(err_compression_stream); - this->close(); - return false; - case Z_DATA_ERROR: - this->error(err_compression_data); - this->close(); - return false; - case Z_MEM_ERROR: - this->error(err_compression_memory); - this->close(); - return false; - case Z_OK: - this->decompressed.append(decomp_buffer.begin(), decomp_buffer.begin() + have); - this->decompressed_total += have; - break; - default: - /* Stub */ - break; - } - } while (zlib->d_stream.avail_out == 0); + auto result = zlib->decompress(buffer, decompressed); + if (result != err_no_code_specified) { + this->error(result); + this->close(); + return false; + } data = decompressed; } else { /* No complete compressed frame yet */ diff --git a/src/dpp/zlibcontext.cpp b/src/dpp/zlibcontext.cpp new file mode 100644 index 0000000000..6631c6b223 --- /dev/null +++ b/src/dpp/zlibcontext.cpp @@ -0,0 +1,73 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ +#include +#include +#include +#include + +namespace dpp { + +zlibcontext::zlibcontext() { + d_stream = std::make_unique(); + std::memset(d_stream.get(), 0, sizeof(z_stream)); + int error = inflateInit(d_stream.get()); + if (error != Z_OK) { + throw dpp::connection_exception((exception_error_code)error, "Can't initialise stream compression!"); + } + decomp_buffer.resize(DECOMP_BUFFER_SIZE); +} + +zlibcontext::~zlibcontext() { + inflateEnd(d_stream.get()); +} + +exception_error_code zlibcontext::decompress(const std::string& buffer, std::string& decompressed) { + decompressed.clear(); + /* This is safe; zlib requires us to cast away the const. The underlying buffer is unchanged. */ + d_stream->next_in = reinterpret_cast(const_cast(buffer.data())); + d_stream->avail_in = static_cast(buffer.size()); + do { + d_stream->next_out = static_cast(decomp_buffer.data()); + d_stream->avail_out = DECOMP_BUFFER_SIZE; + int ret = inflate(d_stream.get(), Z_NO_FLUSH); + size_t have = DECOMP_BUFFER_SIZE - d_stream->avail_out; + switch (ret) + { + case Z_NEED_DICT: + case Z_STREAM_ERROR: + return err_compression_stream; + case Z_DATA_ERROR: + return err_compression_data; + case Z_MEM_ERROR: + return err_compression_memory; + case Z_OK: + decompressed.append(decomp_buffer.begin(), decomp_buffer.begin() + have); + decompressed_total += have; + break; + default: + /* Stub */ + break; + } + } while (d_stream->avail_out == 0); + return err_no_code_specified; +} + +}; \ No newline at end of file