diff --git a/README.md b/README.md index 872f9a9..9b79068 100644 --- a/README.md +++ b/README.md @@ -142,9 +142,10 @@ DATACENTER SET name | PRINT EXIT Exit the program gracefully. FARM (ADD name selector backend* | PARTITION name policy | DELETE name | PRINT) - Change farms HELP Print this help text. -LIMIT (CONN_RATE_ALARM | CONN_RATE) (VHOST vhostName numberOfConnections | DEFAULT numberOfConnections) - Configure connection rate limits (normal or alarmonly) for incoming clients connections +LIMIT (CONN_RATE_ALARM | CONN_RATE) (DEFAULT | VHOST vhostName) numberOfConnections - Configure connection rate limits (normal or alarmonly) for incoming clients connections +LIMIT (TOTAL_CONN_ALARM | TOTAL_CONN) (DEFAULT | VHOST vhostName) numberOfConnections - Configure total connection limits or alarms for incoming client connections LIMIT (DATA_RATE_ALARM | DATA_RATE) (DEFAULT | VHOST vhostName) BytesPerSecond - Configure data rate limits or alarms for incoming client data -LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | DATA_RATE_ALARM | DATA_RATE) (VHOST vhostName | DEFAULT) - Disable configured limit thresholds +LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | TOTAL_CONN_ALARM | TOTAL_CONN | DATA_RATE_ALARM | DATA_RATE) (VHOST vhostName | DEFAULT) - Disable configured limit thresholds LIMIT PRINT [vhostName] - Print the configured default or specific connection rate limits for specified vhost LISTEN START port | START_SECURE port | STOP [port] LOG CONSOLE verbosity | FILE verbosity diff --git a/docs/config.md b/docs/config.md index 71a73e9..d1aa11e 100644 --- a/docs/config.md +++ b/docs/config.md @@ -21,9 +21,10 @@ DATACENTER SET name | PRINT EXIT Exit the program gracefully. FARM (ADD name selector backend* | PARTITION name policy | DELETE name | PRINT) - Change farms HELP Print this help text. -LIMIT (CONN_RATE_ALARM | CONN_RATE) (VHOST vhostName numberOfConnections | DEFAULT numberOfConnections) - Configure connection rate limits (normal or alarmonly) for incoming clients connections +LIMIT (CONN_RATE_ALARM | CONN_RATE) (DEFAULT | VHOST vhostName) numberOfConnections - Configure connection rate limits (normal or alarmonly) for incoming clients connections +LIMIT (TOTAL_CONN_ALARM | TOTAL_CONN) (DEFAULT | VHOST vhostName) numberOfConnections - Configure total connection limits or alarms for incoming client connections LIMIT (DATA_RATE_ALARM | DATA_RATE) (DEFAULT | VHOST vhostName) BytesPerSecond - Configure data rate limits or alarms for incoming client data -LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | DATA_RATE_ALARM | DATA_RATE) (VHOST vhostName | DEFAULT) - Disable configured limit thresholds +LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | TOTAL_CONN_ALARM | TOTAL_CONN | DATA_RATE_ALARM | DATA_RATE) (VHOST vhostName | DEFAULT) - Disable configured limit thresholds LIMIT PRINT [vhostName] - Print the configured default or specific connection rate limits for specified vhost LISTEN START port | START_SECURE port | STOP [port] LOG CONSOLE verbosity | FILE verbosity @@ -142,6 +143,21 @@ Apply limit on allowed average number of connections per second for all the vhos #### LIMIT CONN_RATE VHOST vhostName numberOfConnections Apply limit on allowed average number of connections per second for specified vhost. The specific limit takes priority over the default limit for any vhost. +#### LIMIT TOTAL_CONN_ALARM DEFAULT numberOfConnections + +Apply limit on allowed total number of connections in alarm only mode for all the vhosts. So whenever the in-coming connection violates the limit, the proxy will only emit log at warning level with AMQPPROX_CONNECTION_LIMIT as a substring and the relevant limiter details, instead of actively limiting any actual connection. + +#### LIMIT TOTAL_CONN_ALARM VHOST vhostName numberOfConnections + +Apply limit on allowed total number of connections in alarm only mode for specified vhost. The specific limit takes priority over the default limit for any vhost. + +#### LIMIT TOTAL_CONN DEFAULT numberOfConnections + +Apply limit on allowed total number of connections for all the vhosts. So whenever the in-coming connection violates the limit, the proxy will close that connection with appropriate error message and will not allow that client connection to connect to the broker. + +#### LIMIT TOTAL_CONN VHOST vhostName numberOfConnections +Apply limit on allowed total number of connections for specified vhost. The specific limit takes priority over the default limit for any vhost. + #### LIMIT DATA_RATE_ALARM DEFAULT BytesPerSecond Apply limit on allowed max bytes per second in alarm only mode for all the vhosts. So whenever any in-coming connection violates the data rate limit, the proxy will only emit log with Data Rate Alarm as a substring and the relevant limiter details, instead of actively limiting any data. @@ -152,7 +168,8 @@ Apply limit on allowed max bytes per second in alarm only mode for specified vho #### LIMIT DATA_RATE DEFAULT numberOfConnections -Apply limit on allowed max bytes per second for all the vhosts. So the data limit is enforced by counting the number of bytes read from the socket during each read operation, and pausing for one second before starting a read operation if the in-coming client connection violates the data. +Apply limit on allowed max bytes per second for all the vhosts. So the data limit is enforced by counting the number of bytes read from the socket during each read operation, and pausing for one second before +starting a read operation if the in-coming client connection violates the data. #### LIMIT DATA_RATE VHOST vhostName numberOfConnections Apply limit on allowed max bytes per second for specified vhost. The specific limit takes priority over the default limit for any vhost. @@ -165,13 +182,21 @@ Remove default connection rate limit (allowed average number of connections per Remove specific connection rate limit (allowed average number of connections per second) for the specified vhost. The default limit will be applied to the specified vhost, if the default limit is already configured. +#### LIMIT DISABLE TOTAL_CONN_ALARM DEFAULT numberOfConnections + +Remove default total connection limit (allowed total number of connections) in alarm only mode for all the vhosts. + +#### LIMIT DISABLE TOTAL_CONN VHOST vhostName numberOfConnections + +Remove specific total connection limit (allowed total number of connections) for the specified vhost. The default limit will be applied to the specified vhost, if the default limit is already configured. + #### LIMIT DISABLE DATA_RATE_ALARM DEFAULT numberOfConnections -Remove default data rate limit (allowed max bytes per second) in alarm only mode for all the vhosts. +Remove default data rate limit (allowed average bytess per second) in alarm only mode for all the vhosts. #### LIMIT DISABLE DATA_RATE VHOST vhostName numberOfConnections -Remove specific data rate limit (allowed max bytes per second) for the specified vhost. The default data limit will be applied to the specified vhost, if the default data limit is already configured. +Remove specific data rate limit (allowed average bytes per second) for the specified vhost. The default data limit will be applied to the specified vhost, if the default data limit is already configured. #### LIMIT PRINT [vhostName] diff --git a/libamqpprox/CMakeLists.txt b/libamqpprox/CMakeLists.txt index d83d860..117cc4b 100644 --- a/libamqpprox/CMakeLists.txt +++ b/libamqpprox/CMakeLists.txt @@ -90,6 +90,7 @@ add_library(libamqpprox STATIC amqpprox_connectionlimiterinterface.cpp amqpprox_connectionlimitermanager.cpp amqpprox_fixedwindowconnectionratelimiter.cpp + amqpprox_totalconnectionlimiter.cpp amqpprox_limitcontrolcommand.cpp amqpprox_closeerror.cpp) diff --git a/libamqpprox/amqpprox_connectionlimiterinterface.h b/libamqpprox/amqpprox_connectionlimiterinterface.h index 2fa3332..d14288f 100644 --- a/libamqpprox/amqpprox_connectionlimiterinterface.h +++ b/libamqpprox/amqpprox_connectionlimiterinterface.h @@ -35,6 +35,12 @@ class ConnectionLimiterInterface { */ virtual bool allowNewConnection() = 0; + /** + * \brief Called when an aquired connection is closed. Useful for changing + * the state of the limiter based on close connection event. + */ + virtual void connectionClosed() {} + // ACCESSORS /** * \return information about connection limiter as a string diff --git a/libamqpprox/amqpprox_connectionlimitermanager.cpp b/libamqpprox/amqpprox_connectionlimitermanager.cpp index 717dad0..b45b700 100644 --- a/libamqpprox/amqpprox_connectionlimitermanager.cpp +++ b/libamqpprox/amqpprox_connectionlimitermanager.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -47,8 +48,12 @@ void maybePopulateDefaultLimiters( ConnectionLimiterManager::ConnectionLimiterManager() : d_connectionRateLimitersPerVhost() , d_alarmOnlyConnectionRateLimitersPerVhost() +, d_totalConnectionLimitersPerVhost() +, d_alarmOnlyTotalConnectionLimitersPerVhost() , d_defaultConnectionRateLimit() , d_defaultAlarmOnlyConnectionRateLimit() +, d_defaultTotalConnectionLimit() +, d_defaultAlarmOnlyTotalConnectionLimit() , d_mutex() { } @@ -84,6 +89,34 @@ ConnectionLimiterManager::addAlarmOnlyConnectionRateLimiter( return alarmOnlyConnectionRateLimiter; } +std::shared_ptr +ConnectionLimiterManager::addTotalConnectionLimiter( + const std::string &vhostName, + uint32_t numberOfConnections) +{ + std::shared_ptr totalConnectionLimiter = + std::make_shared(numberOfConnections); + + std::lock_guard lg(d_mutex); + d_totalConnectionLimitersPerVhost[vhostName] = {true, + totalConnectionLimiter}; + return totalConnectionLimiter; +} + +std::shared_ptr +ConnectionLimiterManager::addAlarmOnlyTotalConnectionLimiter( + const std::string &vhostName, + uint32_t numberOfConnections) +{ + std::shared_ptr alarmOnlyTotalConnectionLimiter = + std::make_shared(numberOfConnections); + + std::lock_guard lg(d_mutex); + d_alarmOnlyTotalConnectionLimitersPerVhost[vhostName] = { + true, alarmOnlyTotalConnectionLimiter}; + return alarmOnlyTotalConnectionLimiter; +} + void ConnectionLimiterManager::setDefaultConnectionRateLimit( uint32_t defaultConnectionRateLimit) { @@ -119,6 +152,39 @@ void ConnectionLimiterManager::setAlarmOnlyDefaultConnectionRateLimit( } } +void ConnectionLimiterManager::setDefaultTotalConnectionLimit( + uint32_t defaultTotalConnectionLimit) +{ + std::lock_guard lg(d_mutex); + + d_defaultTotalConnectionLimit = defaultTotalConnectionLimit; + // To update new default total connection limit for all the vhosts, + // by removing old already set default total connection limiters + for (auto &limiter : d_totalConnectionLimitersPerVhost) { + if (!limiter.second.first) { + limiter.second.second = std::make_shared( + *d_defaultTotalConnectionLimit); + } + } +} + +void ConnectionLimiterManager::setAlarmOnlyDefaultTotalConnectionLimit( + uint32_t defaultTotalConnectionLimit) +{ + std::lock_guard lg(d_mutex); + + d_defaultAlarmOnlyTotalConnectionLimit = defaultTotalConnectionLimit; + // To update new default alarm only total connection limit for all the + // vhosts, by removing old already set default alarm only total connection + // limiters + for (auto &limiter : d_alarmOnlyConnectionRateLimitersPerVhost) { + if (!limiter.second.first) { + limiter.second.second = std::make_shared( + *d_defaultAlarmOnlyTotalConnectionLimit); + } + } +} + void ConnectionLimiterManager::removeConnectionRateLimiter( const std::string &vhostName) { @@ -151,6 +217,38 @@ void ConnectionLimiterManager::removeAlarmOnlyConnectionRateLimiter( } } +void ConnectionLimiterManager::removeTotalConnectionLimiter( + const std::string &vhostName) +{ + std::lock_guard lg(d_mutex); + + if (d_defaultTotalConnectionLimit) { + d_totalConnectionLimitersPerVhost[vhostName] = { + false, + std::make_shared( + *d_defaultTotalConnectionLimit)}; + } + else { + d_totalConnectionLimitersPerVhost.erase(vhostName); + } +} + +void ConnectionLimiterManager::removeAlarmOnlyTotalConnectionLimiter( + const std::string &vhostName) +{ + std::lock_guard lg(d_mutex); + + if (d_defaultAlarmOnlyTotalConnectionLimit) { + d_alarmOnlyTotalConnectionLimitersPerVhost[vhostName] = { + false, + std::make_shared( + *d_defaultAlarmOnlyTotalConnectionLimit)}; + } + else { + d_alarmOnlyTotalConnectionLimitersPerVhost.erase(vhostName); + } +} + void ConnectionLimiterManager::removeDefaultConnectionRateLimit() { std::lock_guard lg(d_mutex); @@ -183,6 +281,38 @@ void ConnectionLimiterManager::removeAlarmOnlyDefaultConnectionRateLimit() } } +void ConnectionLimiterManager::removeDefaultTotalConnectionLimit() +{ + std::lock_guard lg(d_mutex); + + d_defaultTotalConnectionLimit.reset(); + for (auto it = d_totalConnectionLimitersPerVhost.cbegin(); + it != d_totalConnectionLimitersPerVhost.cend();) { + if (!it->second.first) { + it = d_totalConnectionLimitersPerVhost.erase(it); + } + else { + ++it; + } + } +} + +void ConnectionLimiterManager::removeAlarmOnlyDefaultTotalConnectionLimit() +{ + std::lock_guard lg(d_mutex); + + d_defaultAlarmOnlyTotalConnectionLimit.reset(); + for (auto it = d_alarmOnlyTotalConnectionLimitersPerVhost.cbegin(); + it != d_alarmOnlyTotalConnectionLimitersPerVhost.cend();) { + if (!it->second.first) { + it = d_alarmOnlyTotalConnectionLimitersPerVhost.erase(it); + } + else { + ++it; + } + } +} + bool ConnectionLimiterManager::allowNewConnectionForVhost( const std::string &vhostName) { @@ -218,14 +348,59 @@ bool ConnectionLimiterManager::allowNewConnectionForVhost( if (limiter != d_connectionRateLimitersPerVhost.end()) { if (!(limiter->second.second->allowNewConnection())) { if (limiter->second.first) { - LOG_DEBUG + LOG_INFO << "AMQPPROX_CONNECTION_LIMIT: The connection request for " << vhostName << " is limited by " << limiter->second.second->toString(); return false; } else { - LOG_DEBUG + LOG_INFO + << "AMQPPROX_CONNECTION_LIMIT: The connection request for " + << vhostName << " is limited by default " + << limiter->second.second->toString(); + return false; + } + } + } + + maybePopulateDefaultLimiters(vhostName, + d_defaultAlarmOnlyTotalConnectionLimit, + d_alarmOnlyTotalConnectionLimitersPerVhost); + maybePopulateDefaultLimiters(vhostName, + d_defaultTotalConnectionLimit, + d_totalConnectionLimitersPerVhost); + + alarmLimiter = d_alarmOnlyTotalConnectionLimitersPerVhost.find(vhostName); + if (alarmLimiter != d_alarmOnlyTotalConnectionLimitersPerVhost.end()) { + if (!(alarmLimiter->second.second->allowNewConnection())) { + if (alarmLimiter->second.first) { + LOG_WARN << "AMQPPROX_CONNECTION_LIMIT: The connection " + "request for " + << vhostName << " should be limited by " + << alarmLimiter->second.second->toString(); + } + else { + LOG_WARN << "AMQPPROX_CONNECTION_LIMIT: The connection " + "request for " + << vhostName << " should be limited by default " + << alarmLimiter->second.second->toString(); + } + } + } + + limiter = d_totalConnectionLimitersPerVhost.find(vhostName); + if (limiter != d_totalConnectionLimitersPerVhost.end()) { + if (!(limiter->second.second->allowNewConnection())) { + if (limiter->second.first) { + LOG_INFO + << "AMQPPROX_CONNECTION_LIMIT: The connection request for " + << vhostName << " is limited by " + << limiter->second.second->toString(); + return false; + } + else { + LOG_INFO << "AMQPPROX_CONNECTION_LIMIT: The connection request for " << vhostName << " is limited by default " << limiter->second.second->toString(); @@ -237,6 +412,22 @@ bool ConnectionLimiterManager::allowNewConnectionForVhost( return true; } +void ConnectionLimiterManager::connectionClosed(const std::string &vhostName) +{ + std::lock_guard lg(d_mutex); + + auto alarmLimiter = + d_alarmOnlyTotalConnectionLimitersPerVhost.find(vhostName); + if (alarmLimiter != d_alarmOnlyTotalConnectionLimitersPerVhost.end()) { + alarmLimiter->second.second->connectionClosed(); + } + + auto limiter = d_totalConnectionLimitersPerVhost.find(vhostName); + if (limiter != d_totalConnectionLimitersPerVhost.end()) { + limiter->second.second->connectionClosed(); + } +} + std::shared_ptr ConnectionLimiterManager::getConnectionRateLimiter( const std::string &vhostName) const @@ -264,6 +455,33 @@ ConnectionLimiterManager::getAlarmOnlyConnectionRateLimiter( return nullptr; } +std::shared_ptr +ConnectionLimiterManager::getTotalConnectionLimiter( + const std::string &vhostName) const +{ + std::lock_guard lg(d_mutex); + + auto limiter = d_totalConnectionLimitersPerVhost.find(vhostName); + if (limiter != d_totalConnectionLimitersPerVhost.end()) { + return limiter->second.second; + } + return nullptr; +} + +std::shared_ptr +ConnectionLimiterManager::getAlarmOnlyTotalConnectionLimiter( + const std::string &vhostName) const +{ + std::lock_guard lg(d_mutex); + + auto alarmLimiter = + d_alarmOnlyTotalConnectionLimitersPerVhost.find(vhostName); + if (alarmLimiter != d_alarmOnlyTotalConnectionLimitersPerVhost.end()) { + return alarmLimiter->second.second; + } + return nullptr; +} + std::optional ConnectionLimiterManager::getDefaultConnectionRateLimit() const { @@ -276,5 +494,17 @@ ConnectionLimiterManager::getAlarmOnlyDefaultConnectionRateLimit() const return d_defaultAlarmOnlyConnectionRateLimit; } +std::optional +ConnectionLimiterManager::getDefaultTotalConnectionLimit() const +{ + return d_defaultTotalConnectionLimit; +} + +std::optional +ConnectionLimiterManager::getAlarmOnlyDefaultTotalConnectionLimit() const +{ + return d_defaultAlarmOnlyTotalConnectionLimit; +} + } } diff --git a/libamqpprox/amqpprox_connectionlimitermanager.h b/libamqpprox/amqpprox_connectionlimitermanager.h index 003a3cf..5595b1d 100644 --- a/libamqpprox/amqpprox_connectionlimitermanager.h +++ b/libamqpprox/amqpprox_connectionlimitermanager.h @@ -57,9 +57,13 @@ class ConnectionLimiterManager { // for the vhost ConnectionLimiters d_connectionRateLimitersPerVhost; ConnectionLimiters d_alarmOnlyConnectionRateLimitersPerVhost; + ConnectionLimiters d_totalConnectionLimitersPerVhost; + ConnectionLimiters d_alarmOnlyTotalConnectionLimitersPerVhost; std::optional d_defaultConnectionRateLimit; std::optional d_defaultAlarmOnlyConnectionRateLimit; + std::optional d_defaultTotalConnectionLimit; + std::optional d_defaultAlarmOnlyTotalConnectionLimit; mutable std::mutex d_mutex; public: @@ -92,6 +96,31 @@ class ConnectionLimiterManager { addAlarmOnlyConnectionRateLimiter(const std::string &vhostName, uint32_t numberOfConnections); + /** + * \brief Add new total connection limiter or modify existing total + * connection limiter for specified vhost + * \param vhostName vhost name + * \param numberOfConnections limit number of total connections + * \return the added total connection limiter + */ + std::shared_ptr + addTotalConnectionLimiter(const std::string &vhostName, + uint32_t numberOfConnections); + + /** + * \brief Add new total connection limiter or modify existing total + * connection limiter for specified vhost in alarm only mode. The limiter + * will only emit log at warning level with AMQPPROX_CONNECTION_LIMIT as a + * substring and the relevant limiter details, instead of limiting actual + * connection + * \param vhostName vhost name + * \param numberOfConnections limit number of total connections + * \return the added total connection limiter + */ + std::shared_ptr + addAlarmOnlyTotalConnectionLimiter(const std::string &vhostName, + uint32_t numberOfConnections); + /** * \brief Set default connection rate limit for all connecting vhosts * \param defaultConnectionRateLimit default connection rate (allowed @@ -110,6 +139,24 @@ class ConnectionLimiterManager { void setAlarmOnlyDefaultConnectionRateLimit( uint32_t defaultConnectionRateLimit); + /** + * \brief Set default total connection limit for all connecting vhosts + * \param defaultTotalConnectionLimit default total connection limit + * (allowed total connections) + */ + void setDefaultTotalConnectionLimit(uint32_t defaultTotalConnectionLimit); + + /** + * \brief Set default total connection limit for all connecting vhosts in + * alarm only mode. The limiter will only emit log at warning level with + * AMQPPROX_CONNECTION_LIMIT as a substring and the relevant limiter + * details, instead of limiting actual connection + * \param defaultTotalConnectionLimit default total connection limit + * (allowed total connections) + */ + void setAlarmOnlyDefaultTotalConnectionLimit( + uint32_t defaultTotalConnectionLimit); + /** * \brief Remove specific connection rate limiter for specified vhost * \param vhostName vhost name @@ -123,6 +170,19 @@ class ConnectionLimiterManager { */ void removeAlarmOnlyConnectionRateLimiter(const std::string &vhostName); + /** + * \brief Remove specific total connection limiter for specified vhost + * \param vhostName vhost name + */ + void removeTotalConnectionLimiter(const std::string &vhostName); + + /** + * \brief Remove specific alarm only total connection limiter for specified + * vhost + * \param vhostName vhost name + */ + void removeAlarmOnlyTotalConnectionLimiter(const std::string &vhostName); + /** * \brief Remove default connection rate limit for all the connecting * vhosts @@ -135,6 +195,18 @@ class ConnectionLimiterManager { */ void removeAlarmOnlyDefaultConnectionRateLimit(); + /** + * \brief Remove default total connection limit for all the connecting + * vhosts + */ + void removeDefaultTotalConnectionLimit(); + + /** + * \brief Remove default alarm only total connection rate limit for all the + * connecting vhosts + */ + void removeAlarmOnlyDefaultTotalConnectionLimit(); + /** * \brief Decide whether the current connection request should be allowed * or not based on configured different limiters for the specified vhost @@ -143,9 +215,9 @@ class ConnectionLimiterManager { bool allowNewConnectionForVhost(const std::string &vhostName); /** - * \brief Called when a session is marked as disconnected. + * \brief Called when an aquired connection is closed */ - void sessionClosedForVhost(const std::string &vhostName); + void connectionClosed(const std::string &vhostName); // ACCESSORS /** @@ -163,6 +235,21 @@ class ConnectionLimiterManager { std::shared_ptr getAlarmOnlyConnectionRateLimiter(const std::string &vhostName) const; + /** + * \brief Get particular total connection limiter based on specified vhost + * \param vhostName vhost name + */ + std::shared_ptr + getTotalConnectionLimiter(const std::string &vhostName) const; + + /** + * \brief Get particular alarm only total connection limiter based on + * specified vhost + * \param vhostName vhost name + */ + std::shared_ptr + getAlarmOnlyTotalConnectionLimiter(const std::string &vhostName) const; + /** * \brief Get default connection rate limit (allowed connections per * second) for all the connecting vhosts @@ -174,6 +261,18 @@ class ConnectionLimiterManager { * per second) for all the connecting vhosts */ std::optional getAlarmOnlyDefaultConnectionRateLimit() const; + + /** + * \brief Get default total connection limit (allowed total connections) + * for all the connecting vhosts + */ + std::optional getDefaultTotalConnectionLimit() const; + + /** + * \brief Get alarm only default total connection limit (allowed total + * connections) for all the connecting vhosts + */ + std::optional getAlarmOnlyDefaultTotalConnectionLimit() const; }; } diff --git a/libamqpprox/amqpprox_connectionselector.cpp b/libamqpprox/amqpprox_connectionselector.cpp index e6b1cb0..86124e6 100644 --- a/libamqpprox/amqpprox_connectionselector.cpp +++ b/libamqpprox/amqpprox_connectionselector.cpp @@ -118,6 +118,12 @@ SessionState::ConnectionStatus ConnectionSelector::acquireConnection( return SessionState::ConnectionStatus::SUCCESS; } +void ConnectionSelector::notifyConnectionDisconnect( + const std::string &vhostName) +{ + d_connectionLimiterManager_p->connectionClosed(vhostName); +} + void ConnectionSelector::setDefaultFarm(const std::string &farmName) { std::lock_guard lg(d_mutex); diff --git a/libamqpprox/amqpprox_connectionselector.h b/libamqpprox/amqpprox_connectionselector.h index c61ed7a..bb74f69 100644 --- a/libamqpprox/amqpprox_connectionselector.h +++ b/libamqpprox/amqpprox_connectionselector.h @@ -72,6 +72,12 @@ class ConnectionSelector : public ConnectionSelectorInterface { acquireConnection(std::shared_ptr *connectionOut, const SessionState &state) override; + /** + * \brief Notify connection disconnect event + */ + virtual void + notifyConnectionDisconnect(const std::string &vhostName) override; + /** * \brief Set the default farm if a mapping is not found */ diff --git a/libamqpprox/amqpprox_connectionselectorinterface.h b/libamqpprox/amqpprox_connectionselectorinterface.h index ae4ccce..2a8c7b0 100644 --- a/libamqpprox/amqpprox_connectionselectorinterface.h +++ b/libamqpprox/amqpprox_connectionselectorinterface.h @@ -43,6 +43,11 @@ class ConnectionSelectorInterface { virtual SessionState::ConnectionStatus acquireConnection(std::shared_ptr *connectionOut, const SessionState &state) = 0; + + /** + * \brief Notify connection disconnect event + */ + virtual void notifyConnectionDisconnect(const std::string &vhostName) = 0; }; } diff --git a/libamqpprox/amqpprox_limitcontrolcommand.cpp b/libamqpprox/amqpprox_limitcontrolcommand.cpp index a48c504..c72d16f 100644 --- a/libamqpprox/amqpprox_limitcontrolcommand.cpp +++ b/libamqpprox/amqpprox_limitcontrolcommand.cpp @@ -40,50 +40,78 @@ void handleConnectionLimitAlarm( ConnectionLimiterManager *connectionLimiterManager, bool isDefault, const std::string &vhostName, - bool isDisable) + bool isDisable, + bool isTotalConnLimit) { + std::string limitType = + (isTotalConnLimit ? "total connection" : "connection rate"); if (isDisable) { if (isDefault) { - connectionLimiterManager - ->removeAlarmOnlyDefaultConnectionRateLimit(); - output << "Successfully disabled default alarm only " - "connection rate limit\n "; + isTotalConnLimit + ? connectionLimiterManager + ->removeAlarmOnlyDefaultTotalConnectionLimit() + : connectionLimiterManager + ->removeAlarmOnlyDefaultConnectionRateLimit(); + output << "Successfully disabled default alarm only " << limitType + << " limit\n "; } else { - connectionLimiterManager->removeAlarmOnlyConnectionRateLimiter( - vhostName); - output << "Successfully disabled specific alarm only " - "connection rate limit for vhost " - << vhostName << "\n"; + isTotalConnLimit + ? connectionLimiterManager + ->removeAlarmOnlyTotalConnectionLimiter(vhostName) + : connectionLimiterManager + ->removeAlarmOnlyConnectionRateLimiter(vhostName); + output << "Successfully disabled specific alarm only " << limitType + << " limit for vhost " << vhostName << "\n"; } } else { uint32_t numberOfConnections; if (!(iss >> numberOfConnections)) { - output << "Invalid numberOfConnections provided.\n"; + output << "Invalid numberOfConnections provided for " << limitType + << " limit.\n"; return; } if (isDefault) { - connectionLimiterManager->setAlarmOnlyDefaultConnectionRateLimit( - numberOfConnections); - output << "Default connection rate limit is set to " - << connectionLimiterManager - ->getAlarmOnlyDefaultConnectionRateLimit() - .value() - << " connections per second in alarm only mode.\n"; - output << "The limiter will only log at warning level with " - "AMQPPROX_CONNECTION_LIMIT as a substring and the " - "relevant limit details, when the new incoming " - "connection violates the default limit for all " - "vhosts.\n"; + if (isTotalConnLimit) { + connectionLimiterManager + ->setAlarmOnlyDefaultConnectionRateLimit( + numberOfConnections); + output << "Default " << limitType << " limit is set to " + << connectionLimiterManager + ->getAlarmOnlyDefaultTotalConnectionLimit() + .value() + << " total connections in alarm only mode.\n"; + } + else { + connectionLimiterManager + ->setAlarmOnlyDefaultTotalConnectionLimit( + numberOfConnections); + output << "Default " << limitType << " limit is set to " + << connectionLimiterManager + ->getAlarmOnlyDefaultConnectionRateLimit() + .value() + << " connections per second in alarm only mode.\n"; + } + + output + << "The limiter will only log at warning level with " + "AMQPPROX_CONNECTION_LIMIT as a substring and the " + "relevant limit details, when the new incoming client " + "connection violates the default limit for all vhosts.\n"; } else { output << "For vhost " << vhostName << ", " - << connectionLimiterManager - ->addAlarmOnlyConnectionRateLimiter( - vhostName, numberOfConnections) - ->toString() + << (isTotalConnLimit + ? connectionLimiterManager + ->addAlarmOnlyTotalConnectionLimiter( + vhostName, numberOfConnections) + ->toString() + : connectionLimiterManager + ->addAlarmOnlyConnectionRateLimiter( + vhostName, numberOfConnections) + ->toString()) << " in alarm only mode.\n"; output << "The limiter will only log at warning level with " "AMQPPROX_CONNECTION_LIMIT as a substring and the " @@ -99,42 +127,69 @@ void handleConnectionLimit( ConnectionLimiterManager *connectionLimiterManager, bool isDefault, const std::string &vhostName, - bool isDisable) + bool isDisable, + bool isTotalConnLimit) { + std::string limitType = + (isTotalConnLimit ? "total connection" : "connection rate"); if (isDisable) { if (isDefault) { - connectionLimiterManager->removeDefaultConnectionRateLimit(); - output << "Successfully disabled default connection rate " - "limit\n "; + isTotalConnLimit + ? connectionLimiterManager->removeDefaultConnectionRateLimit() + : connectionLimiterManager + ->removeDefaultTotalConnectionLimit(); + output << "Successfully disabled default " << limitType + << " limit\n "; } else { - connectionLimiterManager->removeConnectionRateLimiter(vhostName); - output << "Successfully disabled specific connection rate " - "limit for vhost " - << vhostName << "\n"; + isTotalConnLimit + ? connectionLimiterManager->removeTotalConnectionLimiter( + vhostName) + : connectionLimiterManager->removeConnectionRateLimiter( + vhostName); + output << "Successfully disabled specific " << limitType + << " limit for vhost " << vhostName << "\n"; } } else { uint32_t numberOfConnections; if (!(iss >> numberOfConnections)) { - output << "Invalid numberOfConnections provided.\n"; + output << "Invalid numberOfConnections provided for " << limitType + << " limit\n"; return; } if (isDefault) { - connectionLimiterManager->setDefaultConnectionRateLimit( - numberOfConnections); - output << "Default connection rate limit is set to " - << connectionLimiterManager->getDefaultConnectionRateLimit() - .value() - << " connections per second.\n"; + if (isTotalConnLimit) { + connectionLimiterManager->setDefaultTotalConnectionLimit( + numberOfConnections); + output << "Default " << limitType << " limit is set to " + << connectionLimiterManager + ->getDefaultTotalConnectionLimit() + .value() + << " total connections.\n"; + } + else { + connectionLimiterManager->setDefaultConnectionRateLimit( + numberOfConnections); + output << "Default " << limitType << " limit is set to " + << connectionLimiterManager + ->getDefaultConnectionRateLimit() + .value() + << " connections per second.\n"; + } } else { output << "For vhost " << vhostName << ", " - << connectionLimiterManager - ->addConnectionRateLimiter(vhostName, - numberOfConnections) - ->toString() + << (isTotalConnLimit + ? connectionLimiterManager + ->addTotalConnectionLimiter( + vhostName, numberOfConnections) + ->toString() + : connectionLimiterManager + ->addConnectionRateLimiter( + vhostName, numberOfConnections) + ->toString()) << "\n"; } } @@ -286,6 +341,40 @@ void printVhostLimits( } } + auto alarmTotalConnLimiter = + connectionLimiterManager->getAlarmOnlyTotalConnectionLimiter( + vhostName); + if (alarmTotalConnLimiter) { + output << "Alarm only limit, for vhost " << vhostName << ", " + << alarmTotalConnLimiter->toString() << ".\n"; + anyConfiguredLimit = true; + } + else { + std::optional alarmTotalConnLimit = + connectionLimiterManager->getAlarmOnlyDefaultConnectionRateLimit(); + if (alarmTotalConnLimit) { + output << "Alarm only limit, for vhost " << vhostName << ", allow " + << *alarmTotalConnLimit << " total connections.\n"; + anyConfiguredLimit = true; + } + } + auto totalConnLimiter = + connectionLimiterManager->getTotalConnectionLimiter(vhostName); + if (totalConnLimiter) { + output << "For vhost " << vhostName << ", " + << totalConnLimiter->toString() << ".\n"; + anyConfiguredLimit = true; + } + else { + std::optional totalConnLimit = + connectionLimiterManager->getDefaultTotalConnectionLimit(); + if (totalConnLimit) { + output << "For vhost " << vhostName << ", allow " + << *totalConnLimit << " total connections.\n"; + anyConfiguredLimit = true; + } + } + std::size_t alarmDataRateLimit = dataRateLimitManager->getDataRateAlarm(vhostName); if (alarmDataRateLimit != std::numeric_limits::max()) { @@ -318,6 +407,11 @@ void printAllLimits( std::optional connectionRateLimit = connectionLimiterManager->getDefaultConnectionRateLimit(); + std::optional alarmOnlyTotalConnectionLimit = + connectionLimiterManager->getAlarmOnlyDefaultTotalConnectionLimit(); + std::optional totalConnectionLimit = + connectionLimiterManager->getDefaultTotalConnectionLimit(); + std::size_t alarmOnlyDataRateLimit = dataRateLimitManager->getDefaultDataRateAlarm(); std::size_t dataRateLimit = @@ -336,6 +430,18 @@ void printAllLimits( anyConfiguredLimit = true; } + if (alarmOnlyTotalConnectionLimit) { + output << "Default limit for any vhost, allow " + << *alarmOnlyTotalConnectionLimit + << " total connections in alarm only mode.\n"; + anyConfiguredLimit = true; + } + if (totalConnectionLimit) { + output << "Default limit for any vhost, allow " + << *totalConnectionLimit << " total connections.\n"; + anyConfiguredLimit = true; + } + if (alarmOnlyDataRateLimit != std::numeric_limits::max()) { output << "Default data limit for any vhost, allow max " << alarmOnlyDataRateLimit @@ -394,18 +500,21 @@ std::string LimitControlCommand::commandVerb() const std::string LimitControlCommand::helpText() const { - return "(CONN_RATE_ALARM | CONN_RATE) (VHOST vhostName " - "numberOfConnections | DEFAULT numberOfConnections) - Configure " - "connection rate limits (normal or alarmonly) for incoming clients " - "connections\n" + return "(CONN_RATE_ALARM | CONN_RATE) (DEFAULT | VHOST vhostName) " + "numberOfConnections - Configure connection rate limits (normal or " + "alarmonly) for incoming clients connections\n" + + "LIMIT (TOTAL_CONN_ALARM | TOTAL_CONN) (DEFAULT | VHOST vhostName) " + "numberOfConnections - Configure total connection limits or alarms " + "for incoming client connections\n" "LIMIT (DATA_RATE_ALARM | DATA_RATE) (DEFAULT | VHOST vhostName) " "BytesPerSecond - Configure data rate limits or alarms for " "incoming client data\n" - "LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | DATA_RATE_ALARM | " - "DATA_RATE) (VHOST vhostName | DEFAULT) - Disable configured limit " - "thresholds\n" + "LIMIT DISABLE (CONN_RATE_ALARM | CONN_RATE | TOTAL_CONN_ALARM | " + "TOTAL_CONN | DATA_RATE_ALARM | DATA_RATE) (VHOST vhostName | " + "DEFAULT) - Disable configured limit thresholds\n" "LIMIT PRINT [vhostName] - Print the configured default limits or " "specific vhost limits"; @@ -468,7 +577,8 @@ void LimitControlCommand::handleCommand(const std::string & /* command */, d_connectionLimiterManager_p, isDefault, vhostName, - isDisable); + isDisable, + false); } else if (subcommand == "CONN_RATE") { handleConnectionLimit(iss, @@ -476,7 +586,26 @@ void LimitControlCommand::handleCommand(const std::string & /* command */, d_connectionLimiterManager_p, isDefault, vhostName, - isDisable); + isDisable, + false); + } + else if (subcommand == "TOTAL_CONN_ALARM") { + handleConnectionLimitAlarm(iss, + output, + d_connectionLimiterManager_p, + isDefault, + vhostName, + isDisable, + true); + } + else if (subcommand == "TOTAL_CONN") { + handleConnectionLimit(iss, + output, + d_connectionLimiterManager_p, + isDefault, + vhostName, + isDisable, + true); } else if (subcommand == "DATA_RATE_ALARM") { handleDataRateAlarmLimit(serverHandle, diff --git a/libamqpprox/amqpprox_session.cpp b/libamqpprox/amqpprox_session.cpp index 8b9cbd6..671a116 100644 --- a/libamqpprox/amqpprox_session.cpp +++ b/libamqpprox/amqpprox_session.cpp @@ -168,6 +168,10 @@ Session::Session(boost::asio::io_context &ioContext, Session::~Session() { + if (d_sessionState.getTotalConnectionIncremented()) { + d_connectionSelector_p->notifyConnectionDisconnect( + d_sessionState.getVirtualHost()); + } } bool Session::finished() @@ -488,6 +492,7 @@ void Session::establishConnection() case SessionState::ConnectionStatus::NO_FARM: case SessionState::ConnectionStatus::ERROR_FARM: case SessionState::ConnectionStatus::NO_BACKEND: + d_sessionState.setTotalConnectionIncremented(); d_connector.synthesizeCustomCloseError( true, Reply::Codes::resource_error, @@ -498,6 +503,7 @@ void Session::establishConnection() break; default: + d_sessionState.setTotalConnectionIncremented(); LOG_INFO << "Failed to acquire connection for vhost " << d_sessionState.getVirtualHost() << ", rc: " << static_cast(rc); @@ -506,6 +512,7 @@ void Session::establishConnection() return; } + d_sessionState.setTotalConnectionIncremented(); auto authResponseCb = [this, self, connectionManager]( const authproto::AuthResponse diff --git a/libamqpprox/amqpprox_sessionstate.cpp b/libamqpprox/amqpprox_sessionstate.cpp index 33c0465..1a4f9ee 100644 --- a/libamqpprox/amqpprox_sessionstate.cpp +++ b/libamqpprox/amqpprox_sessionstate.cpp @@ -48,6 +48,7 @@ SessionState::SessionState( , d_authDeniedConnection(false) , d_ingressSecured(false) , d_limitedConnection(false) +, d_totalConnnectionIncremented(false) , d_virtualHost() , d_disconnectedStatus(DisconnectType::NOT_DISCONNECTED) , d_id(s_nextId++) // This isn't a race because this is only on one thread @@ -161,6 +162,11 @@ void SessionState::setLimitedConnection() d_limitedConnection = true; } +void SessionState::setTotalConnectionIncremented() +{ + d_totalConnnectionIncremented = true; +} + std::string SessionState::hostname(const boost::asio::ip::tcp::endpoint &endpoint) const { diff --git a/libamqpprox/amqpprox_sessionstate.h b/libamqpprox/amqpprox_sessionstate.h index 683f589..d91779d 100644 --- a/libamqpprox/amqpprox_sessionstate.h +++ b/libamqpprox/amqpprox_sessionstate.h @@ -71,6 +71,7 @@ class SessionState { std::atomic d_authDeniedConnection; std::atomic d_ingressSecured; std::atomic d_limitedConnection; + bool d_totalConnnectionIncremented; std::string d_virtualHost; DisconnectType d_disconnectedStatus; uint64_t d_id; @@ -146,6 +147,12 @@ class SessionState { */ void setLimitedConnection(); + /** + * \brief Set the current connection session is counted in total connection + * limit + */ + void setTotalConnectionIncremented(); + /** * \brief Set session as disconnected, along with which type of disconnect * \param disconnectType specifies type of disconnection @@ -237,6 +244,12 @@ class SessionState { */ inline bool getLimitedConnection() const; + /** + * \return the state of the current connection session, whether it is + * counted in total connection limit + */ + inline bool getTotalConnectionIncremented() const; + /** * \return session identifier */ @@ -305,6 +318,11 @@ inline bool SessionState::getLimitedConnection() const return d_limitedConnection; } +inline bool SessionState::getTotalConnectionIncremented() const +{ + return d_totalConnnectionIncremented; +} + inline uint64_t SessionState::id() const { return d_id; diff --git a/libamqpprox/amqpprox_totalconnectionlimiter.cpp b/libamqpprox/amqpprox_totalconnectionlimiter.cpp new file mode 100644 index 0000000..b861383 --- /dev/null +++ b/libamqpprox/amqpprox_totalconnectionlimiter.cpp @@ -0,0 +1,70 @@ +/* +** Copyright 2022 Bloomberg Finance L.P. +** +** Licensed under the Apache License, Version 2.0 (the "License"); +** you may not use this file except in compliance with the License. +** You may obtain a copy of the License at +** +** http://www.apache.org/licenses/LICENSE-2.0 +** +** Unless required by applicable law or agreed to in writing, software +** distributed under the License is distributed on an "AS IS" BASIS, +** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +** See the License for the specific language governing permissions and +** limitations under the License. +*/ + +#include + +#include + +namespace Bloomberg { +namespace amqpprox { + +TotalConnectionLimiter::TotalConnectionLimiter(uint32_t totalConnectionLimit) +: ConnectionLimiterInterface() +, d_totalConnectionLimit(totalConnectionLimit) +, d_connectionCount(0) +{ +} + +bool TotalConnectionLimiter::allowNewConnection() +{ + if (d_connectionCount < d_totalConnectionLimit) { + d_connectionCount++; + return true; + } + return false; +} + +void TotalConnectionLimiter::connectionClosed() +{ + if (d_connectionCount == 0) { + // This is possible when the limiter is set up, while having already + // some active connections + return; + } + + d_connectionCount--; +} + +std::string TotalConnectionLimiter::toString() const +{ + std::stringstream ss; + ss << "Allow total " << d_totalConnectionLimit << " connections"; + + return ss.str(); +} + +uint32_t TotalConnectionLimiter::getTotalConnectionLimit() const +{ + return d_totalConnectionLimit; +} + +uint32_t TotalConnectionLimiter::getConnectionCount() const +{ + return d_connectionCount; +} + +} +} diff --git a/libamqpprox/amqpprox_totalconnectionlimiter.h b/libamqpprox/amqpprox_totalconnectionlimiter.h new file mode 100644 index 0000000..a485a8c --- /dev/null +++ b/libamqpprox/amqpprox_totalconnectionlimiter.h @@ -0,0 +1,78 @@ +/* +** Copyright 2022 Bloomberg Finance L.P. +** +** Licensed under the Apache License, Version 2.0 (the "License"); +** you may not use this file except in compliance with the License. +** You may obtain a copy of the License at +** +** http://www.apache.org/licenses/LICENSE-2.0 +** +** Unless required by applicable law or agreed to in writing, software +** distributed under the License is distributed on an "AS IS" BASIS, +** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +** See the License for the specific language governing permissions and +** limitations under the License. +*/ +#ifndef BLOOMBERG_AMQPPROX_TOTALCONNECTIONLIMITER +#define BLOOMBERG_AMQPPROX_TOTALCONNECTIONLIMITER + +#include + +#include + +namespace Bloomberg { +namespace amqpprox { + +/** + * \brief The class will impose total allowed connection limit based on + * provided connection limit. Implements the ConnectionLimiterInterface + * interface + */ +class TotalConnectionLimiter : public ConnectionLimiterInterface { + // Maximum total connection limit + uint32_t d_totalConnectionLimit; + + // connection count + uint32_t d_connectionCount; + + public: + // CREATORS + explicit TotalConnectionLimiter(uint32_t totalConnectionLimit); + + // MANIPULATORS + /** + * \brief Decide whether the current connection request should be allowed + * based on total connection limit + * + * \note The method should always be called in thread-safe manner/serially, + * otherwise the connection counter value will not be maintained accurately + */ + virtual bool allowNewConnection() override; + + /** + * \brief Called when an aquired connection is closed. Useful for changing + * the state of the limiter based on close connection event. + */ + virtual void connectionClosed() override; + + // ACCESSORS + /** + * \return Information about connection limiter as a string + */ + virtual std::string toString() const override; + + /** + * \return the total connection limit (total allowed connections) + */ + uint32_t getTotalConnectionLimit() const; + + /** + * \return the current connection count + */ + uint32_t getConnectionCount() const; +}; + +} +} + +#endif diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3bfb3d6..253aa95 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,6 +41,7 @@ add_executable(amqpprox_tests amqpprox_defaultauthintercept.t.cpp amqpprox_httpauthintercept.t.cpp amqpprox_fixedwindowconnectionratelimiter.t.cpp + amqpprox_totalconnectionlimiter.t.cpp amqpprox_connectionlimitermanager.t.cpp ) diff --git a/tests/amqpprox_connectionlimitermanager.t.cpp b/tests/amqpprox_connectionlimitermanager.t.cpp index 78e140f..80699d9 100644 --- a/tests/amqpprox_connectionlimitermanager.t.cpp +++ b/tests/amqpprox_connectionlimitermanager.t.cpp @@ -17,6 +17,7 @@ #include #include +#include #include @@ -33,10 +34,16 @@ TEST(ConnectionLimiterManagerTest, Breathing) ConnectionLimiterManager limiterManager; EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); EXPECT_TRUE(limiterManager.getAlarmOnlyConnectionRateLimiter( "test-vhost") == nullptr); EXPECT_TRUE(limiterManager.getConnectionRateLimiter("test-vhost") == nullptr); + EXPECT_TRUE(limiterManager.getAlarmOnlyTotalConnectionLimiter( + "test-vhost") == nullptr); + EXPECT_TRUE(limiterManager.getTotalConnectionLimiter("test-vhost") == + nullptr); } TEST(ConnectionLimiterManagerTest, AddGetRemoveConnectionRateLimiter) @@ -207,10 +214,179 @@ TEST(ConnectionLimiterManagerTest, SetGetRemoveDefaultConnectionRateLimiter) EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); } +TEST(ConnectionLimiterManagerTest, AddGetRemoveTotalConnectionLimiter) +{ + ConnectionLimiterManager limiterManager; + + std::string vhostName1 = "test-vhost1"; + std::string vhostName2 = "test-vhost2"; + uint32_t connectionLimit1 = 100; + uint32_t connectionLimit2 = 200; + + // Adding limiter for vhostName1 + limiterManager.addTotalConnectionLimiter(vhostName1, connectionLimit1); + + // Getting limiter for vhostName1 + std::shared_ptr limiter1 = + std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName1)); + ASSERT_TRUE(limiter1 != nullptr); + EXPECT_EQ(limiter1->getTotalConnectionLimit(), connectionLimit1); + EXPECT_EQ(limiter1->getConnectionCount(), 0); + + // Adding limiter for vhostName2 + std::shared_ptr limiter2 = + std::dynamic_pointer_cast( + limiterManager.addTotalConnectionLimiter(vhostName2, + connectionLimit2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); + + // Modifying limiter for vhostName1 + uint32_t newConnectionLimit = 300; + std::shared_ptr newLimiter = + std::dynamic_pointer_cast( + limiterManager.addTotalConnectionLimiter(vhostName1, + newConnectionLimit)); + ASSERT_TRUE(newLimiter != nullptr); + EXPECT_EQ(newLimiter->getTotalConnectionLimit(), newConnectionLimit); + EXPECT_EQ(newLimiter->getConnectionCount(), 0); + + // Getting limiter for vhostName2 + limiter2 = std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); + + // Removing limiter for vhostName1 + limiterManager.removeTotalConnectionLimiter(vhostName1); + std::shared_ptr removedLimiter = + std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName1)); + ASSERT_TRUE(removedLimiter == nullptr); + + // Getting limiter for vhostName2 + limiter2 = std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); +} + +TEST(ConnectionLimiterManagerTest, AddGetRemoveAlarmOnlyTotalConnectionLimiter) +{ + ConnectionLimiterManager limiterManager; + + std::string vhostName1 = "test-vhost1"; + std::string vhostName2 = "test-vhost2"; + uint32_t connectionLimit1 = 100; + uint32_t connectionLimit2 = 200; + + // Adding alarm only limiter for vhostName1 + limiterManager.addAlarmOnlyTotalConnectionLimiter(vhostName1, + connectionLimit1); + + // Getting alarm only limiter for vhostName1 + std::shared_ptr limiter1 = + std::dynamic_pointer_cast( + limiterManager.getAlarmOnlyTotalConnectionLimiter(vhostName1)); + ASSERT_TRUE(limiter1 != nullptr); + EXPECT_EQ(limiter1->getTotalConnectionLimit(), connectionLimit1); + EXPECT_EQ(limiter1->getConnectionCount(), 0); + + // Adding alarm only limiter for vhostName2 + std::shared_ptr limiter2 = + std::dynamic_pointer_cast( + limiterManager.addAlarmOnlyTotalConnectionLimiter( + vhostName2, connectionLimit2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); + + // Modifying alarm only limiter for vhostName1 + uint32_t newConnectionLimit = 300; + std::shared_ptr newLimiter = + std::dynamic_pointer_cast( + limiterManager.addAlarmOnlyTotalConnectionLimiter( + vhostName1, newConnectionLimit)); + ASSERT_TRUE(newLimiter != nullptr); + EXPECT_EQ(newLimiter->getTotalConnectionLimit(), newConnectionLimit); + EXPECT_EQ(newLimiter->getConnectionCount(), 0); + + // Getting alarm only limiter for vhostName2 + limiter2 = std::dynamic_pointer_cast( + limiterManager.getAlarmOnlyTotalConnectionLimiter(vhostName2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); + + // Removing alarm only limiter for vhostName1 + limiterManager.removeAlarmOnlyTotalConnectionLimiter(vhostName1); + std::shared_ptr removedLimiter = + std::dynamic_pointer_cast( + limiterManager.getAlarmOnlyTotalConnectionLimiter(vhostName1)); + ASSERT_TRUE(removedLimiter == nullptr); + + // Getting alarm only limiter for vhostName2 + limiter2 = std::dynamic_pointer_cast( + limiterManager.getAlarmOnlyTotalConnectionLimiter(vhostName2)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), connectionLimit2); + EXPECT_EQ(limiter2->getConnectionCount(), 0); +} + +TEST(ConnectionLimiterManagerTest, SetGetRemoveDefaultTotalConnectionLimiter) +{ + ConnectionLimiterManager limiterManager; + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + + uint32_t connectionLimit1 = 100; + // Setting default limiter + limiterManager.setDefaultTotalConnectionLimit(connectionLimit1); + + // Getting default limiter + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit1); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + + uint32_t connectionLimit2 = 200; + // Setting alarm only default limiter + limiterManager.setAlarmOnlyDefaultTotalConnectionLimit(connectionLimit2); + + // Getting alarm only default limiter + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultTotalConnectionLimit(), + connectionLimit2); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit1); + + // Removing default limiter + limiterManager.removeDefaultTotalConnectionLimit(); + + // Getting default limiter + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultTotalConnectionLimit(), + connectionLimit2); + + // Removing alarm only default limiter + limiterManager.removeAlarmOnlyDefaultTotalConnectionLimit(); + + // Getting default limiter + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); +} + TEST(ConnectionLimiterManagerTest, AllowNewConnectionForVhostWithoutAnyLimit) { ConnectionLimiterManager limiterManager; EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost("test-vhost")); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost("test-vhost")); } @@ -287,7 +463,6 @@ TEST(ConnectionLimiterManagerTest, TEST(ConnectionLimiterManagerTest, AllowNewConnectionForVhostWithSpecificAndDefaultRateLimit) { - using namespace std::chrono_literals; ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); @@ -317,3 +492,184 @@ TEST(ConnectionLimiterManagerTest, EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); } + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithSpecificTotalConnectionLimit) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 1; + limiterManager.addTotalConnectionLimiter(vhostName, connectionLimit); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithAlarmOnlySpecificTotalConnectionLimit) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 1; + limiterManager.addAlarmOnlyTotalConnectionLimiter(vhostName, + connectionLimit); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithDefaultTotalConnectionLimit) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 1; + limiterManager.setDefaultTotalConnectionLimit(connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithAlarmOnlyDefaultTotalConnectionLimit) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 1; + limiterManager.setAlarmOnlyDefaultTotalConnectionLimit(connectionLimit); + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultTotalConnectionLimit(), + connectionLimit); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithSpecificAndDefaultTotalConnectionLimit) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 1; + limiterManager.setDefaultTotalConnectionLimit(connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t newConnectionLimit = 2; + limiterManager.addTotalConnectionLimiter(vhostName, newConnectionLimit); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + limiterManager.removeTotalConnectionLimiter(vhostName); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + connectionLimit); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST(ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithDefaultRateLimiterAndTotalConnectionLimiter) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 5; + limiterManager.setDefaultConnectionRateLimit(connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), + connectionLimit); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + + uint32_t totalConnectionLimit = 1; + limiterManager.setDefaultTotalConnectionLimit(totalConnectionLimit); + ASSERT_TRUE(limiterManager.getDefaultTotalConnectionLimit()); + EXPECT_EQ(*limiterManager.getDefaultTotalConnectionLimit(), + totalConnectionLimit); + + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); + + limiterManager.removeDefaultTotalConnectionLimit(); + EXPECT_FALSE(limiterManager.getDefaultTotalConnectionLimit()); + + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); +} + +TEST( + ConnectionLimiterManagerTest, + AllowNewConnectionForVhostWithSpecificRateLimiterAndTotalConnectionLimiter) +{ + ConnectionLimiterManager limiterManager; + std::string vhostName = "test-vhost"; + + std::shared_ptr limiter1 = + std::dynamic_pointer_cast( + limiterManager.getConnectionRateLimiter(vhostName)); + EXPECT_FALSE(limiter1 != nullptr); + std::shared_ptr limiter2 = + std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName)); + EXPECT_FALSE(limiter2 != nullptr); + + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + + uint32_t connectionLimit = 5; + limiter1 = std::dynamic_pointer_cast( + limiterManager.addConnectionRateLimiter(vhostName, connectionLimit)); + ASSERT_TRUE(limiter1 != nullptr); + EXPECT_EQ(limiter1->getConnectionLimit(), connectionLimit); + EXPECT_EQ(limiter1->getTimeWindowInSec(), 1); + + uint32_t totalConnectionLimit = 1; + limiter2 = std::dynamic_pointer_cast( + limiterManager.addTotalConnectionLimiter(vhostName, + totalConnectionLimit)); + ASSERT_TRUE(limiter2 != nullptr); + EXPECT_EQ(limiter2->getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter2->getConnectionCount(), 0); + + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); + + limiterManager.removeTotalConnectionLimiter(vhostName); + limiter2 = std::dynamic_pointer_cast( + limiterManager.getTotalConnectionLimiter(vhostName)); + EXPECT_FALSE(limiter2 != nullptr); + + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); + EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); +} diff --git a/tests/amqpprox_connectionselector.t.cpp b/tests/amqpprox_connectionselector.t.cpp index 6062b7b..19f2bc4 100644 --- a/tests/amqpprox_connectionselector.t.cpp +++ b/tests/amqpprox_connectionselector.t.cpp @@ -79,6 +79,30 @@ TEST(ConnectionSelector, Limited_Connection) SessionState::ConnectionStatus::LIMIT); } +TEST(ConnectionSelector, Limited_Total_Connection) +{ + FarmStore farmStore; + BackendStore backendStore; + ResourceMapper resourceMapper; + ConnectionLimiterManager connectionLimiterManager; + uint32_t connectionLimit = 1; + std::string vhostName = "test-vhost"; + connectionLimiterManager.addTotalConnectionLimiter(vhostName, + connectionLimit); + ConnectionSelector connectionSelector( + &farmStore, &backendStore, &resourceMapper, &connectionLimiterManager); + SessionState state; + state.setVirtualHost(vhostName); + std::shared_ptr out; + EXPECT_EQ(connectionSelector.acquireConnection(&out, state), + SessionState::ConnectionStatus::NO_FARM); + + // Acquiring second connection will be limited because of configured + // connection limit + EXPECT_EQ(connectionSelector.acquireConnection(&out, state), + SessionState::ConnectionStatus::LIMIT); +} + TEST(ConnectionSelector, Limited_Connection_Alarm_Only) { FarmStore farmStore; diff --git a/tests/amqpprox_session.t.cpp b/tests/amqpprox_session.t.cpp index 3197cb6..0d1d97f 100644 --- a/tests/amqpprox_session.t.cpp +++ b/tests/amqpprox_session.t.cpp @@ -96,6 +96,8 @@ struct SelectorMock : public ConnectionSelectorInterface { acquireConnection, SessionState::ConnectionStatus(std::shared_ptr *, const SessionState &)); + + MOCK_METHOD1(notifyConnectionDisconnect, void(const std::string &)); }; struct HostnameMapperMock : public HostnameMapper { @@ -643,6 +645,7 @@ TEST_F(SessionTest, Connection_Then_Ping_Then_Disconnect) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -734,6 +737,7 @@ TEST_F(SessionTest, BadServerHandshake) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -930,6 +934,7 @@ TEST_F(SessionTest, Connect_Multiple_Dns) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); EXPECT_CALL(*d_mapper, prime(_, _)).Times(AtLeast(1)); EXPECT_CALL(*d_mapper, mapToHostname(makeEndpoint("2.3.4.5", 2345))) @@ -1064,6 +1069,7 @@ TEST_F(SessionTest, Failover_Dns_Failure) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); EXPECT_CALL(*d_mapper, prime(_, _)).Times(AtLeast(1)); EXPECT_CALL(*d_mapper, mapToHostname(makeEndpoint("2.3.4.5", 2345))) @@ -1152,6 +1158,7 @@ TEST_F(SessionTest, Connection_Then_Ping_Then_Force_Disconnect) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1200,6 +1207,7 @@ TEST_F(SessionTest, Connection_Then_Ping_Then_Backend_Disconnect) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1255,6 +1263,7 @@ TEST_F(SessionTest, Authorized_Client_Test) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); std::string modifiedMechanism = "TEST_MECHANISM"; std::string modifiedCredentials = "credentials"; @@ -1327,6 +1336,7 @@ TEST_F( EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); authproto::AuthResponse authResponseData; authResponseData.set_result(authproto::AuthResponse::DENY); @@ -1373,6 +1383,7 @@ TEST_F(SessionTest, EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); authproto::AuthResponse authResponseData; authResponseData.set_result(authproto::AuthResponse::DENY); @@ -1431,6 +1442,7 @@ TEST_F(SessionTest, Forward_Received_Close_Method_To_Client_During_Handshake) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1488,6 +1500,7 @@ TEST_F(SessionTest, Close_Connection_No_Broker_Mapping) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::NO_BACKEND))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1534,6 +1547,7 @@ TEST_F(SessionTest, Close_Limited_Connection) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::LIMIT))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(0); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1582,6 +1596,7 @@ TEST_F(SessionTest, Printing_Breathing_Test) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1626,6 +1641,7 @@ TEST_F(SessionTest, Pause_Disconnects_Previously_Established_Connection) EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); TestSocketState::State base, clientBase; testSetupHostnameMapperForServerClientBase(base, clientBase); @@ -1725,6 +1741,7 @@ TEST_F(SessionTest, EXPECT_CALL(d_selector, acquireConnection(_, _)) .WillOnce(DoAll(SetArgPointee<0>(d_cm), Return(SessionState::ConnectionStatus::SUCCESS))); + EXPECT_CALL(d_selector, notifyConnectionDisconnect(_)).Times(1); // Run the tests through to completion driveTo(16); diff --git a/tests/amqpprox_totalconnectionlimiter.t.cpp b/tests/amqpprox_totalconnectionlimiter.t.cpp new file mode 100644 index 0000000..cba7a65 --- /dev/null +++ b/tests/amqpprox_totalconnectionlimiter.t.cpp @@ -0,0 +1,65 @@ +/* +** Copyright 2022 Bloomberg Finance L.P. +** +** Licensed under the Apache License, Version 2.0 (the "License"); +** you may not use this file except in compliance with the License. +** You may obtain a copy of the License at +** +** http://www.apache.org/licenses/LICENSE-2.0 +** +** Unless required by applicable law or agreed to in writing, software +** distributed under the License is distributed on an "AS IS" BASIS, +** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +** See the License for the specific language governing permissions and +** limitations under the License. +*/ + +#include + +#include +#include + +#include + +using namespace Bloomberg; +using namespace amqpprox; +using namespace testing; + +TEST(TotalConnectionLimiterTest, Breathing) +{ + uint32_t totalConnectionLimit = 1000; + TotalConnectionLimiter limiter(totalConnectionLimit); + EXPECT_EQ(limiter.getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter.getConnectionCount(), 0); +} + +TEST(TotalConnectionLimiterTest, ToString) +{ + uint32_t totalConnectionLimit = 1000; + TotalConnectionLimiter limiter(totalConnectionLimit); + EXPECT_EQ(limiter.toString(), + "Allow total " + std::to_string(totalConnectionLimit) + + " connections"); +} + +TEST(TotalConnectionLimiterTest, AllowNewConnectionAndCloseConnection) +{ + uint32_t totalConnectionLimit = 1; + TotalConnectionLimiter limiter(totalConnectionLimit); + EXPECT_EQ(limiter.getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter.getConnectionCount(), 0); + + EXPECT_TRUE(limiter.allowNewConnection()); + EXPECT_FALSE(limiter.allowNewConnection()); + EXPECT_EQ(limiter.getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter.getConnectionCount(), totalConnectionLimit); + + limiter.connectionClosed(); + EXPECT_EQ(limiter.getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter.getConnectionCount(), 0); + + EXPECT_TRUE(limiter.allowNewConnection()); + EXPECT_FALSE(limiter.allowNewConnection()); + EXPECT_EQ(limiter.getTotalConnectionLimit(), totalConnectionLimit); + EXPECT_EQ(limiter.getConnectionCount(), totalConnectionLimit); +}