From 6a33b426e547903455761d10b6aeceb1dbcab3fd Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Mon, 25 Nov 2024 13:16:10 -0800 Subject: [PATCH 1/5] feat: custom endpoint support --- driver/CMakeLists.txt | 9 + driver/allowed_and_blocked_hosts.h | 74 ++++++++ driver/auth_util.cc | 2 +- driver/cache_map.cc | 94 ++++++++++ driver/cache_map.h | 73 ++++++++ driver/custom_endpoint_info.cc | 84 +++++++++ driver/custom_endpoint_info.h | 137 ++++++++++++++ driver/custom_endpoint_monitor.cc | 175 ++++++++++++++++++ driver/custom_endpoint_monitor.h | 71 +++++++ driver/custom_endpoint_proxy.cc | 134 ++++++++++++++ driver/custom_endpoint_proxy.h | 91 +++++++++ driver/handle.cc | 8 + driver/rds_utils.cc | 33 ++++ driver/rds_utils.h | 2 + driver/sliding_expiration_cache.cc | 9 +- driver/sliding_expiration_cache.h | 31 ++-- ...g_expiration_cache_with_clean_up_thread.cc | 15 +- ...ng_expiration_cache_with_clean_up_thread.h | 9 +- util/installer.cc | 24 ++- util/installer.h | 76 ++++---- 20 files changed, 1085 insertions(+), 66 deletions(-) create mode 100644 driver/allowed_and_blocked_hosts.h create mode 100644 driver/cache_map.cc create mode 100644 driver/cache_map.h create mode 100644 driver/custom_endpoint_info.cc create mode 100644 driver/custom_endpoint_info.h create mode 100644 driver/custom_endpoint_monitor.cc create mode 100644 driver/custom_endpoint_monitor.h create mode 100644 driver/custom_endpoint_proxy.cc create mode 100644 driver/custom_endpoint_proxy.h diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 1c9a090f7..cb172efbe 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -62,6 +62,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) auth_util.cc aws_sdk_helper.cc base_metrics_holder.cc + cache_map.cc catalog.cc catalog_no_i_s.cc cluster_topology_info.cc @@ -72,6 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) connect.cc connection_handler.cc connection_proxy.cc + custom_endpoint_proxy.cc + custom_endpoint_info.cc + custom_endpoint_monitor.cc cursor.cc desc.cc dll.cc @@ -131,9 +135,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY) SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc adfs_proxy.h + allowed_and_blocked_hosts.h auth_util.h aws_sdk_helper.h base_metrics_holder.h + cache_map.h catalog.h cluster_aware_hit_metrics_holder.h cluster_aware_metrics_container.h @@ -142,6 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) cluster_topology_info.h connection_handler.h connection_proxy.h + custom_endpoint_proxy.h + custom_endpoint_info.h + custom_endpoint_monitor.h driver.h efm_proxy.h error.h diff --git a/driver/allowed_and_blocked_hosts.h b/driver/allowed_and_blocked_hosts.h new file mode 100644 index 000000000..f32036478 --- /dev/null +++ b/driver/allowed_and_blocked_hosts.h @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __ALLOWED_AND_BLOCKED_HOSTS__ +#define __ALLOWED_AND_BLOCKED_HOSTS__ + +#include +#include + +/** + * Represents the allowed and blocked hosts for connections. + */ +class ALLOWED_AND_BLOCKED_HOSTS { + public: + /** + * Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs. + * @param allowed_host_ids The set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * `blocked_host_ids` are allowed. + * @param blocked_host_ids The set of blocked host IDs for connections. If null or empty, all host IDs in + * `allowed_host_ids` are allowed. If `allowed_host_ids` is also null or empty, there + * are no restrictions on which hosts are allowed. + */ + ALLOWED_AND_BLOCKED_HOSTS(const std::set& allowed_host_ids, + const std::set& blocked_host_ids) + : allowed_host_ids(allowed_host_ids), blocked_host_ids(blocked_host_ids){}; + + /** + * Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * `blocked_host_ids` are allowed. + * + * @return the set of allowed host IDs for connections. + */ + std::set get_allowed_host_ids() { return this->allowed_host_ids; }; + + /** + * Returns the set of blocked host IDs for connections. If null or empty, all host IDs in `allowed_host_ids` + * are allowed. If `allowed_host_ids` is also null or empty, there are no restrictions on which hosts are allowed. + * + * @return the set of blocked host IDs for connections. + */ + std::set get_blocked_host_ids() { return this->blocked_host_ids; }; + + private: + std::set allowed_host_ids; + std::set blocked_host_ids; +}; + +#endif diff --git a/driver/auth_util.cc b/driver/auth_util.cc index 7ff6aac47..0b5fe1f57 100644 --- a/driver/auth_util.cc +++ b/driver/auth_util.cc @@ -74,7 +74,7 @@ std::pair AUTH_UTIL::get_auth_token(std::unordered_mapbuild_cache_key(host, region, port, user); + const std::string cache_key = build_cache_key(host, region, port, user); bool using_cached_token = false; { diff --git a/driver/cache_map.cc b/driver/cache_map.cc new file mode 100644 index 000000000..a20e97e88 --- /dev/null +++ b/driver/cache_map.cc @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "cache_map.h" + +#include + +#include "custom_endpoint_info.h" + +template +void CACHE_MAP::put(K key, V value, long long item_expiration_nanos) { + this->cache[key] = std::make_shared( + value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); +} + +template +V CACHE_MAP::get(K key, V default_value) { + if (cache.count(key) > 0 && !cache[key]->is_expired()) { + return this->cache[key]->item; + } + return default_value; +} + +template +V CACHE_MAP::get(K key, V default_value, long long item_expiration_nanos) { + if (cache.count(key) == 0 || this->cache[key]->is_expired()) { + this->put(key, std::move(default_value), item_expiration_nanos); + } + return this->cache[key]->item; +} + +template +void CACHE_MAP::remove(K key) { + if (this->cache.count(key)) { + this->cache.erase(key); + } +} + +template +int CACHE_MAP::size() { + return this->cache.size(); +} + +template +void CACHE_MAP::clear() { + this->cache.clear(); + this->clean_up(); +} + +template +void CACHE_MAP::clean_up() { + if (std::chrono::steady_clock::now() > this->clean_up_time_nanos.load()) { + this->clean_up_time_nanos = + std::chrono::steady_clock::now() + std::chrono::nanoseconds(this->clean_up_time_interval_nanos); + std::vector keys; + keys.reserve(this->cache.size()); + for (auto& [key, cache_item] : this->cache) { + keys.push_back(key); + } + for (const auto& key : keys) { + if (this->cache[key]->is_expired()) { + this->cache.erase(key); + } + } + } +} + +template class CACHE_MAP>; diff --git a/driver/cache_map.h b/driver/cache_map.h new file mode 100644 index 000000000..82e7fed89 --- /dev/null +++ b/driver/cache_map.h @@ -0,0 +1,73 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CACHE_MAP__ +#define __CACHE_MAP__ + +#include +#include +#include + +template +class CACHE_MAP { + public: + class CACHE_ITEM { + public: + CACHE_ITEM() = default; + CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time) + : item(item), expiration_time(expiration_time){}; + ~CACHE_ITEM() = default; + V item; + + bool is_expired() { return std::chrono::steady_clock::now() > this->expiration_time; } + + private: + std::chrono::steady_clock::time_point expiration_time; + }; + + CACHE_MAP() = default; + ~CACHE_MAP() = default; + + void put(K key, V value, long long item_expiration_nanos); + V get(K key, V default_value); + V get(K key, V default_value, long long item_expiration_nanos); + void remove(K key); + int size(); + void clear(); + + protected: + void clean_up(); + const long long clean_up_time_interval_nanos = 60000000000; // 10 minute + std::atomic clean_up_time_nanos; + + private: + std::unordered_map> cache; +}; + +#endif diff --git a/driver/custom_endpoint_info.cc b/driver/custom_endpoint_info.cc new file mode 100644 index 000000000..bb8bc0c9b --- /dev/null +++ b/driver/custom_endpoint_info.cc @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software {} you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY {} without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "custom_endpoint_info.h" + +std::shared_ptr CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint( + const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info) { + std::vector members; + MEMBERS_LIST_TYPE members_list_type; + + if (response_endpoint_info.StaticMembersHasBeenSet()) { + members = response_endpoint_info.GetStaticMembers(); + members_list_type = STATIC_LIST; + } else { + members = response_endpoint_info.GetExcludedMembers(); + members_list_type = EXCLUSION_LIST; + } + + std::set members_set(members.begin(), members.end()); + + return std::make_shared( + response_endpoint_info.GetDBClusterEndpointIdentifier(), response_endpoint_info.GetDBClusterIdentifier(), + response_endpoint_info.GetEndpoint(), + CUSTOM_ENDPOINT_INFO::get_role_type(response_endpoint_info.GetCustomEndpointType()), members_set, + members_list_type); +} + +std::set CUSTOM_ENDPOINT_INFO::get_excluded_members() const { + if (this->member_list_type == EXCLUSION_LIST) { + return members; + } + + return std::set(); +} + +std::set CUSTOM_ENDPOINT_INFO::get_static_members() const { + if (this->member_list_type == STATIC_LIST) { + return members; + } + + return std::set(); +} + +bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other) { + return current.endpoint_identifier == other.endpoint_identifier && + current.cluster_identifier == other.cluster_identifier && current.url == other.url && + current.role_type == other.role_type && + current.member_list_type == other.member_list_type; +} + +CUSTOM_ENDPOINT_ROLE_TYPE CUSTOM_ENDPOINT_INFO::get_role_type(const Aws::String& role_type) { + auto it = CUSTOM_ENDPOINT_ROLE_TYPE_MAP.find(role_type); + if (it != CUSTOM_ENDPOINT_ROLE_TYPE_MAP.end()) { + return it->second; + } + + throw std::invalid_argument("Invalid role type for custom endpoint, this should not have happened."); +} diff --git a/driver/custom_endpoint_info.h b/driver/custom_endpoint_info.h new file mode 100644 index 000000000..fe67706f1 --- /dev/null +++ b/driver/custom_endpoint_info.h @@ -0,0 +1,137 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_INFO_H__ +#define __CUSTOM_ENDPOINT_INFO_H__ + +#include + +#include +#include +#include + +#include "MYODBC_MYSQL.h" +#include "mylog.h" + +/** + * Enum representing the possible roles of instances specified by a custom endpoint. Note that, currently, it is not + * possible to create a WRITER custom endpoint. + */ +enum CUSTOM_ENDPOINT_ROLE_TYPE { + ANY, // Instances in the custom endpoint may be either a writer or a reader. + WRITER, // Instance in the custom endpoint is always the writer. + READER // Instances in the custom endpoint are always readers. +}; + +static std::unordered_map const CUSTOM_ENDPOINT_ROLE_TYPE_MAP = { + {"ANY", ANY}, {"WRITER", WRITER}, {"READER", READER}}; + +static std::unordered_map const CUSTOM_ENDPOINT_ROLE_TYPE_STR_MAP = { + {ANY, "ANY"}, {WRITER, "WRITER"}, {READER, "READER"}}; + +/** + * Enum representing the member list type of a custom endpoint. This information can be used together with a member list + * to determine which instances are included or excluded from a custom endpoint. + */ +enum MEMBERS_LIST_TYPE { + /** + * The member list for the custom endpoint specifies which instances are included in the custom endpoint. If new + * instances are added to the cluster, they will not be automatically added to the custom endpoint. + */ + STATIC_LIST, + /** + * The member list for the custom endpoint specifies which instances are excluded from the custom endpoint. If new + * instances are added to the cluster, they will be automatically added to the custom endpoint. + */ + EXCLUSION_LIST +}; + +static std::unordered_map const MEMBERS_LIST_TYPE_MAP = { + {STATIC_LIST, "STATIC_LIST0"}, {EXCLUSION_LIST, "EXCLUSION_LIST"}}; + +class CUSTOM_ENDPOINT_INFO { + public: + CUSTOM_ENDPOINT_INFO(std::string endpoint_identifier, std::string cluster_identifier, std::string url, + CUSTOM_ENDPOINT_ROLE_TYPE role_type, std::set members, + MEMBERS_LIST_TYPE member_list_type) + : endpoint_identifier(std::move(endpoint_identifier)), + cluster_identifier(std::move(cluster_identifier)), + url(std::move(url)), + role_type(role_type), + members(std::move(members)), + member_list_type(member_list_type){}; + ~CUSTOM_ENDPOINT_INFO() = default; + + static std::shared_ptr from_db_cluster_endpoint( + const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info); + std::string get_endpoint_identifier() const { return this->endpoint_identifier; }; + std::string get_cluster_identifier() const { return this->cluster_identifier; }; + std::string get_url() const { return this->url; }; + CUSTOM_ENDPOINT_ROLE_TYPE get_custom_endpoint_type() const { return this->role_type; }; + MEMBERS_LIST_TYPE get_member_list_type() const { return this->member_list_type; }; + std::set get_excluded_members() const; + std::set get_static_members() const; + + std::string to_string() const { + char buf[4096]; + std::string members_list; + + for (auto const& m : members) { + members_list += m; + members_list += ","; + } + if (members_list.empty()) { + members_list = ""; + } else { + members_list.pop_back(); + } + + myodbc_snprintf( + buf, sizeof(buf), + "CustomEndpointInfo[url=%s, cluster_identifier=%s, custom_endpoint_type=%s, member_list_type=%s, members=[%s]", + this->url.c_str(), this->cluster_identifier.c_str(), + CUSTOM_ENDPOINT_ROLE_TYPE_STR_MAP.at(this->role_type).c_str(), + MEMBERS_LIST_TYPE_MAP.at(this->member_list_type).c_str(), members_list.c_str()); + + return std::string(buf); + } + + friend bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other); + + private: + const std::string endpoint_identifier; + const std::string cluster_identifier; + const std::string url; + const CUSTOM_ENDPOINT_ROLE_TYPE role_type; + const std::set members; + const MEMBERS_LIST_TYPE member_list_type; + static CUSTOM_ENDPOINT_ROLE_TYPE get_role_type(const Aws::String& role_type); +}; + +#endif diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc new file mode 100644 index 000000000..5afda5c6c --- /dev/null +++ b/driver/custom_endpoint_monitor.cc @@ -0,0 +1,175 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "custom_endpoint_monitor.h" + +#include + +#include "allowed_and_blocked_hosts.h" +#include "aws_sdk_helper.h" +#include "driver.h" +#include "monitor_service.h" +#include "mylog.h" + +namespace { +AWS_SDK_HELPER SDK_HELPER; +} + +CACHE_MAP> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache; + +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, + const std::string& endpoint_identifier, const std::string& region, + DataSource* ds, bool enable_logging) + : custom_endpoint_host_info(custom_endpoint_host_info), + endpoint_identifier(endpoint_identifier), + region(region), + enable_logging(enable_logging) { + if (enable_logging) { + this->logger = init_log_file(); + } + + ++SDK_HELPER; + + Aws::RDS::RDSClientConfiguration client_config; + if (!region.empty()) { + client_config.region = region; + } + + this->rds_client = std::make_shared( + Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); + + this->run(); +} + +bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; } + +bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const { + auto default_val = std::shared_ptr(nullptr); + return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val; +} + +void CUSTOM_ENDPOINT_MONITOR::run() { + this->thread_pool.resize(1); + this->thread_pool.push([=](int id) { + MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", + this->custom_endpoint_host_info->get_host().c_str()); + + try { + while (!this->should_stop.load()) { + const std::chrono::time_point start = std::chrono::steady_clock::now(); + Aws::RDS::Model::Filter filter; + filter.SetName("db-cluster-endpoint-type"); + filter.SetValues({"custom"}); + + Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; + request.SetDBClusterIdentifier(this->endpoint_identifier); + request.SetFilters({filter}); + const auto response = this->rds_client->DescribeDBClusterEndpoints(request); + + const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints(); + if (custom_endpoints.size() != 1) { + MYLOG_TRACE(this->logger, 0, + "Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 " + "custom endpoint, but found %d. Endpoints: %s", + endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(), + this->get_endpoints_as_string(custom_endpoints).c_str()); + + std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos)); + continue; + } + const std::shared_ptr endpoint_info = + CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); + const std::shared_ptr cache_endpoint_info = + custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr); + + if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) { + const long long elapsed_time = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::this_thread::sleep_for( + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + continue; + } + + MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}", + custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str()); + + // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. + std::shared_ptr allowed_and_blocked_hosts; + if (endpoint_info->get_member_list_type() == STATIC_LIST) { + allowed_and_blocked_hosts = + std::make_shared(endpoint_info->get_static_members(), std::set()); + } else { + allowed_and_blocked_hosts = + std::make_shared(std::set(), endpoint_info->get_excluded_members()); + } + + custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info, + CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); + const long long elapsed_time = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::this_thread::sleep_for( + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + } + + } catch (const std::exception& e) { + // Log and continue monitoring. + if (this->enable_logging) { + MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what()); + } + } + }); +} + +std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string( + const std::vector& custom_endpoints) { + std::string endpoints("["); + + for (auto const& e : custom_endpoints) { + endpoints += e.GetDBClusterEndpointIdentifier(); + endpoints += ","; + } + if (endpoints.empty()) { + endpoints = ""; + } else { + endpoints.pop_back(); + endpoints += "]"; + } + return endpoints; +} + +void CUSTOM_ENDPOINT_MONITOR::stop() { + this->should_stop.store(true); + this->thread_pool.stop(true); + this->thread_pool.resize(0); + custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host()); + --SDK_HELPER; + MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host()); +} + +void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); } diff --git a/driver/custom_endpoint_monitor.h b/driver/custom_endpoint_monitor.h new file mode 100644 index 000000000..e824b3d72 --- /dev/null +++ b/driver/custom_endpoint_monitor.h @@ -0,0 +1,71 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_MONITOR_H__ +#define __CUSTOM_ENDPOINT_MONITOR_H__ + +#include + +#include +#include "cache_map.h" +#include "connection_handler.h" +#include "connection_proxy.h" +#include "custom_endpoint_info.h" +#include "host_info.h" + +class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this { + public: + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, const std::string& endpoint_identifier, + const std::string& region, DataSource* ds, bool enable_logging = false); + ~CUSTOM_ENDPOINT_MONITOR() = default; + + static bool should_dispose(); + bool has_custom_endpoint_info() const; + void stop(); + void run(); + static void clear_cache(); + + protected: + static CACHE_MAP> custom_endpoint_cache; + static constexpr long long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS = 300000000000; // 5 minutes + std::shared_ptr custom_endpoint_host_info; + std::string endpoint_identifier; + std::string region; + long long refresh_rate_nanos; + bool enable_logging; + std::shared_ptr logger; + ctpl::thread_pool thread_pool; + std::atomic_bool should_stop{false}; + std::shared_ptr rds_client; + + private: + static std::string get_endpoints_as_string(const std::vector& custom_endpoints); +}; + +#endif diff --git a/driver/custom_endpoint_proxy.cc b/driver/custom_endpoint_proxy.cc new file mode 100644 index 000000000..cbb337c23 --- /dev/null +++ b/driver/custom_endpoint_proxy.cc @@ -0,0 +1,134 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "custom_endpoint_proxy.h" +#include "custom_endpoint_monitor.h" +#include "installer.h" +#include "mylog.h" +#include "rds_utils.h" + +SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> + CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), + std::make_shared(), CACHE_CLEANUP_RATE_NANO); +std::mutex CUSTOM_ENDPOINT_PROXY::monitor_cache_mutex; + +CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds) : CUSTOM_ENDPOINT_PROXY(dbc, ds, nullptr) {} +CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) + : CONNECTION_PROXY(dbc, ds) { + this->next_proxy = next_proxy; + + if (ds->opt_LOG_QUERY) { + this->logger = init_log_file(); + } + + this->should_wait_for_info = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO; + this->wait_on_cached_info_duration_ms = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; + this->idle_monitor_expiration_ms = ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS; +} + +bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const char* password, const char* database, + unsigned int port, const char* socket, unsigned long flags) { + if (!RDS_UTILS::is_rds_custom_cluster_dns(host)) { + return this->next_proxy->connect(host, user, password, database, port, socket, flags); + } + + this->custom_endpoint_host = host; + MYLOG_TRACE(this->logger, 0, "Detected a connection request to a custom endpoint URL: '%s'", host); + + this->custom_endpoint_id = RDS_UTILS::get_rds_cluster_id(host); + + if (this->custom_endpoint_id.empty()) { + this->set_custom_error_message("Unable to parse custom endpoint identifier from URL."); + return false; + } + + this->region = ds->opt_CUSTOM_ENDPOINT_REGION ? static_cast(ds->opt_CUSTOM_ENDPOINT_REGION) + : RDS_UTILS::get_rds_region(host); + if (this->region.empty()) { + this->set_custom_error_message( + "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " + "'custom_endpoint_region' property"); + return false; + } + + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + + return this->next_proxy->connect(host, user, password, database, port, socket, flags); +} + +void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptr monitor) { + bool has_custom_endpoint_info = monitor->has_custom_endpoint_info(); + + if (has_custom_endpoint_info) { + return; + } + + // Wait for the monitor to place the custom endpoint info in the cache. This ensures other plugins get accurate + // custom endpoint info. + MYLOG_TRACE(this->logger, 0, + "Custom endpoint info for '%s' was not found. Waiting %dms for the endpoint monitor to fetch info...", + this->custom_endpoint_host_info->get_host().c_str(), this->wait_on_cached_info_duration_ms) + + const auto wait_for_endpoint_info_timeout_nanos = + std::chrono::steady_clock::now() + std::chrono::duration_cast( + std::chrono::milliseconds(this->wait_on_cached_info_duration_ms)); + + while (!has_custom_endpoint_info && std::chrono::steady_clock::now() < wait_for_endpoint_info_timeout_nanos) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + has_custom_endpoint_info = monitor->has_custom_endpoint_info(); + } + + if (!has_custom_endpoint_info) { + char buf[1024]; + myodbc_snprintf( + buf, sizeof(buf), + "The custom endpoint plugin timed out after %ld ms while waiting for custom endpoint info for host %s.", + this->wait_on_cached_info_duration_ms, this->custom_endpoint_host_info->get_host().c_str()); + + set_custom_error_message(buf); + } +} + +std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_monitor_if_absent(DataSource* ds) { + const auto refresh_rate_nanos = std::chrono::duration_cast( + std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) + .count(); + + return monitors.compute_if_absent( + this->custom_endpoint_host_info->get_host(), + [=](std::string key) { + return std::make_shared(this->custom_endpoint_host_info, this->custom_endpoint_id, + this->region, ds); + }, + refresh_rate_nanos); +} diff --git a/driver/custom_endpoint_proxy.h b/driver/custom_endpoint_proxy.h new file mode 100644 index 000000000..be7338bbb --- /dev/null +++ b/driver/custom_endpoint_proxy.h @@ -0,0 +1,91 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a scopy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_PROXY__ +#define __CUSTOM_ENDPOINT_PROXY__ + +#include +#include +#include "connection_proxy.h" +#include "custom_endpoint_monitor.h" +#include "sliding_expiration_cache_with_clean_up_thread.h" + +class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { + public: + CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds); + CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); + + bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, + const char* socket, unsigned long flags) override; + + class CUSTOM_ENDPOINTS_SHOULD_DISPOSE_FUNC : public SHOULD_DISPOSE_FUNC> { + public: + bool should_dispose(std::shared_ptr item) override { return true; } + }; + + class CUSTOM_ENDPOINTS_ITEM_DISPOSAL_FUNC : public ITEM_DISPOSAL_FUNC> { + public: + void dispose(const std::shared_ptr monitor) override { + try { + monitor->stop(); + } catch (const std::exception& e) { + // Ignore + } + } + }; + static constexpr long long CACHE_CLEANUP_RATE_NANO = 60000000000; + + protected: + std::string custom_endpoint_id; + std::string region; + std::string custom_endpoint_host; + std::shared_ptr custom_endpoint_host_info; + std::shared_ptr rds_client; + bool should_wait_for_info; + long wait_on_cached_info_duration_ms; + long idle_monitor_expiration_ms; + + static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> monitors; + static std::mutex monitor_cache_mutex; + + std::shared_ptr create_monitor_if_absent(DataSource* ds); + + /** + * If custom endpoint info does not exist for the current custom endpoint, waits a short time for the info to be + * made available by the custom endpoint monitor. This is necessary so that other plugins can rely on accurate custom + * endpoint info. Since custom endpoint monitors and information are shared, we should not have to wait often. + */ + void wait_for_custom_endpoint_info(std::shared_ptr monitor); + + private: + std::shared_ptr logger; + std::mutex mutex_; +}; + +#endif diff --git a/driver/handle.cc b/driver/handle.cc index 5f6386fd1..125ae86cf 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -57,6 +57,8 @@ #include +#include "custom_endpoint_proxy.h" + thread_local long thread_count = 0; std::mutex g_lock; @@ -123,6 +125,12 @@ void DBC::init_proxy_chain(DataSource* dsrc) { CONNECTION_PROXY *head = new MYSQL_PROXY(this, dsrc); + if (dsrc->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING) { + CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); + custom_endpoint_proxy->set_next_proxy(head); + head = custom_endpoint_proxy; + } + if (dsrc->opt_ENABLE_FAILURE_DETECTION) { CONNECTION_PROXY* efm_proxy = new EFM_PROXY(this, dsrc); efm_proxy->set_next_proxy(head); diff --git a/driver/rds_utils.cc b/driver/rds_utils.cc index 26b6f9157..878555332 100644 --- a/driver/rds_utils.cc +++ b/driver/rds_utils.cc @@ -133,6 +133,24 @@ std::string RDS_UTILS::get_rds_cluster_host_url(std::string host) { return f(AURORA_CHINA_CLUSTER_PATTERN); } +std::string RDS_UTILS::get_rds_cluster_id(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 1) { + return m.size() > 1 ? m.str(1) : std::string(""); + } + return std::string(); + }; + + auto result = f(AURORA_CLUSTER_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_CLUSTER_PATTERN); +} + + std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; @@ -155,6 +173,21 @@ std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { return f(AURORA_CHINA_DNS_PATTERN); } +std::string RDS_UTILS::get_rds_region(std::string host) { + auto f = [host](const std::regex pattern) { + // TODO: implement region + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} + + bool RDS_UTILS::is_ipv4(std::string host) { return std::regex_match(host, IPV4_PATTERN); } bool RDS_UTILS::is_ipv6(std::string host) { return std::regex_match(host, IPV6_PATTERN) || std::regex_match(host, IPV6_COMPRESSED_PATTERN); diff --git a/driver/rds_utils.h b/driver/rds_utils.h index 1475129a2..34ad0cfd9 100644 --- a/driver/rds_utils.h +++ b/driver/rds_utils.h @@ -45,7 +45,9 @@ class RDS_UTILS { static bool is_ipv6(std::string host); static std::string get_rds_cluster_host_url(std::string host); + static std::string get_rds_cluster_id(std::string host); static std::string get_rds_instance_host_pattern(std::string host); + static std::string get_rds_region(std::string host); }; #endif diff --git a/driver/sliding_expiration_cache.cc b/driver/sliding_expiration_cache.cc index f1a1169f0..423246e7e 100644 --- a/driver/sliding_expiration_cache.cc +++ b/driver/sliding_expiration_cache.cc @@ -33,6 +33,8 @@ #include #include +#include "custom_endpoint_monitor.h" + template void SLIDING_EXPIRATION_CACHE::remove_and_dispose(K key) { if (this->cache.count(key)) { @@ -64,12 +66,12 @@ void SLIDING_EXPIRATION_CACHE::clean_up() { } template -V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, +V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos) { this->clean_up(); auto cache_item = std::make_shared(mapping_function(key), std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); - this->cache.emplace(key, cache_item); + this->cache[key] = cache_item; return cache_item->with_extend_expiration(item_expiration_nanos)->item; } @@ -135,3 +137,6 @@ void SLIDING_EXPIRATION_CACHE::set_clean_up_interval_nanos(long long clean } template class SLIDING_EXPIRATION_CACHE; +template class SLIDING_EXPIRATION_CACHE>; +template class SHOULD_DISPOSE_FUNC>; +template class ITEM_DISPOSAL_FUNC>; diff --git a/driver/sliding_expiration_cache.h b/driver/sliding_expiration_cache.h index 642226b6f..bff3135a6 100644 --- a/driver/sliding_expiration_cache.h +++ b/driver/sliding_expiration_cache.h @@ -39,13 +39,15 @@ template class SHOULD_DISPOSE_FUNC { public: - virtual bool should_dispose(T item); + virtual ~SHOULD_DISPOSE_FUNC() = default; + virtual bool should_dispose(T item) { return true; }; }; template class ITEM_DISPOSAL_FUNC { public: - virtual void dispose(T item); + virtual ~ITEM_DISPOSAL_FUNC() = default; + virtual void dispose(T item) {/* Do nothing. */}; }; template @@ -55,7 +57,7 @@ class SLIDING_EXPIRATION_CACHE { public: CACHE_ITEM() = default; CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time) - : item(item), expiration_time(expiration_time){}; + : item(item), expiration_time(expiration_time){}; ~CACHE_ITEM() = default; V item; @@ -64,7 +66,7 @@ class SLIDING_EXPIRATION_CACHE { return this; } - bool should_clean_up(SHOULD_DISPOSE_FUNC* should_dispose_func) { + bool should_clean_up(std::shared_ptr> should_dispose_func) { if (should_dispose_func != nullptr) { return std::chrono::steady_clock::now() > this->expiration_time && should_dispose_func->should_dispose(this->item); @@ -82,15 +84,16 @@ class SLIDING_EXPIRATION_CACHE { this->item_disposal_func = nullptr; } - SLIDING_EXPIRATION_CACHE(SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func) - : should_dispose_func(should_dispose_func), item_disposal_func(item_disposal_func){}; - SLIDING_EXPIRATION_CACHE(SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func, - long long clean_up_interval_nanos) - : clean_up_interval_nanos(clean_up_interval_nanos), - should_dispose_func(should_dispose_func), - item_disposal_func(item_disposal_func){}; + SLIDING_EXPIRATION_CACHE(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func) + : should_dispose_func(should_dispose_func), item_disposal_func(item_disposal_func){}; + SLIDING_EXPIRATION_CACHE(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos) + : clean_up_interval_nanos(clean_up_interval_nanos), + should_dispose_func(std::move(should_dispose_func)), + item_disposal_func(std::move(item_disposal_func)){}; - V compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos); + V compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos); V put(K key, V value, long long item_expiration_nanos); V get(K key, long long item_expiration_nanos, V default_value); @@ -121,8 +124,8 @@ class SLIDING_EXPIRATION_CACHE { std::unordered_map> cache; long long clean_up_interval_nanos = 6000000000; // 1 minutes std::atomic clean_up_time_nanos; - SHOULD_DISPOSE_FUNC* should_dispose_func; - ITEM_DISPOSAL_FUNC* item_disposal_func; + std::shared_ptr> should_dispose_func; + std::shared_ptr> item_disposal_func; void remove_and_dispose(K key); void remove_if_expired(K key) { diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.cc b/driver/sliding_expiration_cache_with_clean_up_thread.cc index 815db1e76..7d6056466 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.cc +++ b/driver/sliding_expiration_cache_with_clean_up_thread.cc @@ -30,6 +30,9 @@ #include "sliding_expiration_cache_with_clean_up_thread.h" #include +#include + +#include "custom_endpoint_monitor.h" template void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() { @@ -65,16 +68,17 @@ SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WI template SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func) - : SLIDING_EXPIRATION_CACHE(should_dispose_func, item_disposal_func) { + std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func) + : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func)) { this->init_clean_up_thread(); } template SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func, - long long clean_up_interval_nanos) - : SLIDING_EXPIRATION_CACHE(should_dispose_func, item_disposal_func, clean_up_interval_nanos) { + std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos) + : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func), clean_up_interval_nanos) { this->init_clean_up_thread(); } @@ -96,3 +100,4 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::release_resources() { } template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD; +template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>; diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.h b/driver/sliding_expiration_cache_with_clean_up_thread.h index ebc5f0c46..807744347 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.h +++ b/driver/sliding_expiration_cache_with_clean_up_thread.h @@ -39,14 +39,13 @@ template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_CACHE { public: SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(); - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(SHOULD_DISPOSE_FUNC* should_dispose_func, - ITEM_DISPOSAL_FUNC* item_disposal_func); - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(SHOULD_DISPOSE_FUNC* should_dispose_func, - ITEM_DISPOSAL_FUNC* item_disposal_func, + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos); ~SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; - /** * Stop clean up thread. Should be called at the end of the cache's lifetime. */ diff --git a/util/installer.cc b/util/installer.cc index e5978810a..b348cd8db 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -277,13 +277,21 @@ static SQLWCHAR W_CONNECT_TIMEOUT[] = { 'C', 'O', 'N', 'N', 'E', 'C', 'T', '_', static SQLWCHAR W_NETWORK_TIMEOUT[] = { 'N', 'E', 'T', 'W', 'O', 'R', 'K', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0 }; /* Monitoring */ -static SQLWCHAR W_ENABLE_FAILURE_DETECTION[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_', 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', 0 }; +static SQLWCHAR W_ENABLE_FAILURE_DETECTION[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_','E', 'N', 'A', 'B', 'L', 'E', '_', 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', 0 }; static SQLWCHAR W_FAILURE_DETECTION_TIME[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 }; static SQLWCHAR W_FAILURE_DETECTION_INTERVAL[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'I', 'N', 'T', 'E', 'R', 'V', 'A', 'L', 0 }; static SQLWCHAR W_FAILURE_DETECTION_COUNT[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'C', 'O', 'U', 'N', 'T', 0 }; static SQLWCHAR W_MONITOR_DISPOSAL_TIME[] = { 'M', 'O', 'N', 'I', 'T', 'O', 'R', '_', 'D', 'I', 'S', 'P', 'O', 'S', 'A', 'L', '_', 'T', 'I', 'M', 'E', 0 }; static SQLWCHAR W_FAILURE_DETECTION_TIMEOUT[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0 }; +/* Custom Endpoint */ +static SQLWCHAR W_ENABLE_CUSTOM_ENDPOINT_MONITORING[] = {'E', 'N', 'A', 'B', 'L', 'E', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'M', 'O', 'N', 'I', 'T', 'O', 'R', 'I', 'N', 'G', 0}; +static SQLWCHAR W_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', '_', 'R', 'E', 'F', 'R', 'E', 'S', 'H', '_', 'R', 'A', 'T', 'E', '_', 'M', 'S', 0 }; +static SQLWCHAR W_WAIT_FOR_CUSTOM_ENDPOINT_INFO[] = { 'W', 'A', 'I', 'T', '_', 'F', 'O', 'R', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', 0 }; +static SQLWCHAR W_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS[] = { 'W', 'A', 'I', 'T', '_', 'F', 'O', 'R', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', '_', 'M', 'S', 0 }; +static SQLWCHAR W_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'M', 'O', 'N', 'I', 'T', 'O', 'R', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'S', 0 }; +static SQLWCHAR W_CUSTOM_ENDPOINT_REGION[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 }; + /* DS_PARAM */ /* externally used strings */ const SQLWCHAR W_DRIVER_PARAM[]= {';', 'D', 'R', 'I', 'V', 'E', 'R', '=', 0}; @@ -341,7 +349,12 @@ SQLWCHAR *dsnparams[]= {W_DSN, W_DRIVER, W_DESCRIPTION, W_SERVER, /* Monitoring */ W_ENABLE_FAILURE_DETECTION, W_FAILURE_DETECTION_TIME, W_FAILURE_DETECTION_INTERVAL, W_FAILURE_DETECTION_COUNT, - W_MONITOR_DISPOSAL_TIME, W_FAILURE_DETECTION_TIMEOUT}; + W_MONITOR_DISPOSAL_TIME, W_FAILURE_DETECTION_TIMEOUT, + /* Custom Endpoints */ + W_ENABLE_CUSTOM_ENDPOINT_MONITORING, + W_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS, W_WAIT_FOR_CUSTOM_ENDPOINT_INFO, + W_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS, W_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS, W_CUSTOM_ENDPOINT_REGION}; + static const int dsnparamcnt= sizeof(dsnparams) / sizeof(SQLWCHAR *); /* DS_PARAM */ @@ -675,7 +688,7 @@ int Driver::from_kvpair_semicolon(const SQLWCHAR *attrs) memcpy(attribute, attrs, (split - attrs) * sizeof(SQLWCHAR)); attribute[split - attrs]= 0; /* add null term */ ++split; - + /* if its one we want, copy it over */ if (!sqlwcharcasecmp(W_DRIVER, attribute)) dest = &lib; @@ -1060,6 +1073,11 @@ void DataSource::reset() { this->opt_MONITOR_DISPOSAL_TIME.set_default(MONITOR_DISPOSAL_TIME_MS); this->opt_FAILURE_DETECTION_TIMEOUT.set_default(FAILURE_DETECTION_TIMEOUT_SECS); + this->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO.set_default(true); + this->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set_default(5000); + this->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.set_default(30000); + this->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS.set_default(900000); + this->opt_AUTH_PORT.set_default(opt_PORT); this->opt_AUTH_EXPIRATION.set_default(900); // 15 minutes this->opt_FED_AUTH_PORT.set_default(opt_PORT); diff --git a/util/installer.h b/util/installer.h index 7968cfc4b..1e17ec51d 100644 --- a/util/installer.h +++ b/util/installer.h @@ -364,49 +364,53 @@ unsigned int get_network_timeout(unsigned int seconds); X(FAILURE_DETECTION_TIMEOUT) \ X(MONITOR_DISPOSAL_TIME) +#define CUSTOM_ENDPOINT_BOOL_OPTIONS_LIST(X) X(WAIT_FOR_CUSTOM_ENDPOINT_INFO) \ + X(ENABLE_CUSTOM_ENDPOINT_MONITORING) + +#define CUSTOM_ENDPOINT_INT_OPTIONS_LIST(X) \ + X(CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS) \ + X(WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS) \ + X(CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS) + +#define CUSTOM_ENDPOINT_STR_OPTIONS_LIST(X) X(CUSTOM_ENDPOINT_REGION) + #define STR_OPTIONS_LIST(X) \ X(DSN) \ X(DRIVER) \ X(DESCRIPTION) \ - X(SERVER) \ - X(UID) \ - X(PWD) MFA_OPTS(X) X(DATABASE) X(SOCKET) X(INITSTMT) X(CHARSET) X(SSL_KEY) \ - X(SSL_CERT) X(SSL_CA) X(SSL_CAPATH) X(SSL_CIPHER) X(SSL_MODE) X(RSAKEY) \ - X(SAVEFILE) X(PLUGIN_DIR) X(DEFAULT_AUTH) X(LOAD_DATA_LOCAL_DIR) \ - X(OCI_CONFIG_FILE) X(OCI_CONFIG_PROFILE) \ - X(AUTHENTICATION_KERBEROS_MODE) X(TLS_VERSIONS) X(SSL_CRL) \ - X(SSL_CRLPATH) X(SSLVERIFY) X(OPENTELEMETRY) \ - AWS_AUTH_STR_OPTIONS_LIST(X) FAILOVER_STR_OPTIONS_LIST(X) FED_AUTH_STR_OPTIONS_LIST(X) - -#define INT_OPTIONS_LIST(X) \ - X(PORT) \ - X(READTIMEOUT) \ - X(WRITETIMEOUT) \ - X(CLIENT_INTERACTIVE) X(PREFETCH) FAILOVER_INT_OPTIONS_LIST(X) \ - AWS_AUTH_INT_OPTIONS_LIST(X) MONITORING_INT_OPTIONS_LIST(X) FED_AUTH_INT_OPTIONS_LIST(X) + X(SERVER) \ + X(UID) \ + X(PWD) \ + MFA_OPTS(X) X(DATABASE) X(SOCKET) X(INITSTMT) X(CHARSET) X(SSL_KEY) X(SSL_CERT) X(SSL_CA) X(SSL_CAPATH) \ + X(SSL_CIPHER) X(SSL_MODE) X(RSAKEY) X(SAVEFILE) X(PLUGIN_DIR) X(DEFAULT_AUTH) X(LOAD_DATA_LOCAL_DIR) \ + X(OCI_CONFIG_FILE) X(OCI_CONFIG_PROFILE) X(AUTHENTICATION_KERBEROS_MODE) X(TLS_VERSIONS) X(SSL_CRL) \ + X(SSL_CRLPATH) X(SSLVERIFY) X(OPENTELEMETRY) AWS_AUTH_STR_OPTIONS_LIST(X) FAILOVER_STR_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_STR_OPTIONS_LIST(X) FED_AUTH_STR_OPTIONS_LIST(X) + +#define INT_OPTIONS_LIST(X) \ + X(PORT) \ + X(READTIMEOUT) \ + X(WRITETIMEOUT) \ + X(CLIENT_INTERACTIVE) \ + X(PREFETCH) FAILOVER_INT_OPTIONS_LIST(X) AWS_AUTH_INT_OPTIONS_LIST(X) MONITORING_INT_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_INT_OPTIONS_LIST(X) FED_AUTH_INT_OPTIONS_LIST(X) // TODO: remove AUTO_RECONNECT when special handling (warning) // is not needed anymore. -#define BOOL_OPTIONS_LIST(X) \ - X(FOUND_ROWS) \ - X(BIG_PACKETS) \ - X(COMPRESSED_PROTO) \ - X(NO_BIGINT) \ - X(SAFE) \ - X(AUTO_RECONNECT) X(AUTO_IS_NULL) X(NO_BINARY_RESULT) X(CAN_HANDLE_EXP_PWD) \ - X(ENABLE_CLEARTEXT_PLUGIN) X(GET_SERVER_PUBLIC_KEY) X(NO_PROMPT) \ - X(DYNAMIC_CURSOR) X(NO_DEFAULT_CURSOR) X(NO_LOCALE) X(PAD_SPACE) \ - X(NO_CACHE) X(FULL_COLUMN_NAMES) X(IGNORE_SPACE) X(NAMED_PIPE) \ - X(NO_CATALOG) X(NO_SCHEMA) X(USE_MYCNF) X(NO_TRANSACTIONS) \ - X(FORWARD_CURSOR) X(MULTI_STATEMENTS) X(COLUMN_SIZE_S32) \ - X(MIN_DATE_TO_ZERO) X(ZERO_DATE_TO_MIN) X( \ - DFLT_BIGINT_BIND_STR) X(LOG_QUERY) X(NO_SSPS) \ - X(NO_TLS_1_2) X(NO_TLS_1_3) X(NO_DATE_OVERFLOW) \ - X(ENABLE_LOCAL_INFILE) X(ENABLE_DNS_SRV) \ - X(MULTI_HOST) \ - FAILOVER_BOOL_OPTIONS_LIST(X) \ - MONITORING_BOOL_OPTIONS_LIST(X) \ - FED_AUTH_BOOL_OPTIONS_LIST(X) +#define BOOL_OPTIONS_LIST(X) \ + X(FOUND_ROWS) \ + X(BIG_PACKETS) \ + X(COMPRESSED_PROTO) \ + X(NO_BIGINT) \ + X(SAFE) \ + X(AUTO_RECONNECT) \ + X(AUTO_IS_NULL) X(NO_BINARY_RESULT) X(CAN_HANDLE_EXP_PWD) X(ENABLE_CLEARTEXT_PLUGIN) X(GET_SERVER_PUBLIC_KEY) \ + X(NO_PROMPT) X(DYNAMIC_CURSOR) X(NO_DEFAULT_CURSOR) X(NO_LOCALE) X(PAD_SPACE) X(NO_CACHE) X(FULL_COLUMN_NAMES) \ + X(IGNORE_SPACE) X(NAMED_PIPE) X(NO_CATALOG) X(NO_SCHEMA) X(USE_MYCNF) X(NO_TRANSACTIONS) X(FORWARD_CURSOR) \ + X(MULTI_STATEMENTS) X(COLUMN_SIZE_S32) X(MIN_DATE_TO_ZERO) X(ZERO_DATE_TO_MIN) X(DFLT_BIGINT_BIND_STR) \ + X(LOG_QUERY) X(NO_SSPS) X(NO_TLS_1_2) X(NO_TLS_1_3) X(NO_DATE_OVERFLOW) X(ENABLE_LOCAL_INFILE) \ + X(ENABLE_DNS_SRV) X(MULTI_HOST) FAILOVER_BOOL_OPTIONS_LIST(X) MONITORING_BOOL_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_BOOL_OPTIONS_LIST(X) FED_AUTH_BOOL_OPTIONS_LIST(X) #define FULL_OPTIONS_LIST(X) \ STR_OPTIONS_LIST(X) INT_OPTIONS_LIST(X) BOOL_OPTIONS_LIST(X) From 1dbd66ade1fea0e3ad25fe0f2f983a2341fb0664 Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Thu, 9 Jan 2025 17:59:19 -0800 Subject: [PATCH 2/5] chore: add custom endpoints to UI --- driver/custom_endpoint_monitor.cc | 2 +- setupgui/callbacks.cc | 57 ++++++++++++++++++++------- setupgui/setupgui.h | 15 +++---- setupgui/windows/odbcdialogparams.cpp | 2 + setupgui/windows/odbcdialogparams.rc | 34 ++++++++++++---- setupgui/windows/resource.h | 9 +++++ 6 files changed, 89 insertions(+), 30 deletions(-) diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc index 5afda5c6c..d5e53cae2 100644 --- a/driver/custom_endpoint_monitor.cc +++ b/driver/custom_endpoint_monitor.cc @@ -169,7 +169,7 @@ void CUSTOM_ENDPOINT_MONITOR::stop() { this->thread_pool.resize(0); custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host()); --SDK_HELPER; - MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host()); + MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str()); } void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); } diff --git a/setupgui/callbacks.cc b/setupgui/callbacks.cc index 53177485b..80776f34a 100644 --- a/setupgui/callbacks.cc +++ b/setupgui/callbacks.cc @@ -346,7 +346,15 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT); GET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL); - /* 5 - Failover */ + /* 5 - Custom Endpoint */ + GET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, ENABLE_CUSTOM_ENDPOINT_MONITORING); + GET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS); + GET_STRING_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_REGION); + + /* 6 - Failover */ GET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); GET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); GET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -365,7 +373,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(FAILOVER_TAB, CONNECT_TIMEOUT); GET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); - /* 6 - Monitoring */ + /* 7 - Monitoring */ GET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { @@ -376,7 +384,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(MONITORING_TAB, MONITOR_DISPOSAL_TIME); } - /* 7 - Metadata*/ + /* 8 - Metadata*/ GET_BOOL_TAB(METADATA_TAB, NO_BIGINT); GET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); GET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -384,7 +392,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); GET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 8 - Cursors/Results */ + /* 9 - Cursors/Results */ GET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); GET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); GET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -402,10 +410,10 @@ void syncTabsData(HWND hwnd, DataSource *params) { params->opt_PREFETCH = 0; } - /* 9 - debug*/ + /* 10 - debug*/ GET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 10 - ssl related */ + /* 11 - ssl related */ GET_STRING_TAB(SSL_TAB, SSL_KEY); GET_STRING_TAB(SSL_TAB, SSL_CERT); GET_STRING_TAB(SSL_TAB, SSL_CA); @@ -420,7 +428,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_STRING_TAB(SSL_TAB, SSL_CRL); GET_STRING_TAB(SSL_TAB, SSL_CRLPATH); - /* 11 - Misc*/ + /* 12 - Misc*/ GET_BOOL_TAB(MISC_TAB, SAFE); GET_BOOL_TAB(MISC_TAB, NO_LOCALE); GET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); @@ -501,7 +509,28 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT); SET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL); - /* 5 - Failover */ + /* 5 - Custom Endpoint */ + SET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, ENABLE_CUSTOM_ENDPOINT_MONITORING); + SET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO); + + if (params->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS); + } + + if (params->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS); + } + + if (params->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS); + } + + SET_STRING_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_REGION); + + /* 6 - Failover */ SET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); SET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); SET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -552,7 +581,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); } - /* 6 - Monitoring */ + /* 7 - Monitoring */ SET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { #ifdef _WIN32 @@ -569,7 +598,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(MONITORING_TAB, FAILURE_DETECTION_TIMEOUT); } - /* 7 - Metadata */ + /* 8 - Metadata */ SET_BOOL_TAB(METADATA_TAB, NO_BIGINT); SET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); SET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -577,7 +606,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); SET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 8 - Cursors/Results */ + /* 9 - Cursors/Results */ SET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); SET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); SET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -596,10 +625,10 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(CURSORS_TAB, PREFETCH); } - /* 9 - debug*/ + /* 10 - debug*/ SET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 10 - ssl related */ + /* 11 - ssl related */ #ifdef _WIN32 if ( getTabCtrlTabPages(SSL_TAB-1) ) #endif @@ -637,7 +666,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_STRING_TAB(SSL_TAB, TLS_VERSIONS); } - /* 11 - Misc*/ + /* 12 - Misc*/ SET_BOOL_TAB(MISC_TAB, SAFE); SET_BOOL_TAB(MISC_TAB, NO_LOCALE); SET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); diff --git a/setupgui/setupgui.h b/setupgui/setupgui.h index 903476b05..f9bd3c1a3 100644 --- a/setupgui/setupgui.h +++ b/setupgui/setupgui.h @@ -39,13 +39,14 @@ #define AUTH_TAB 2 #define AWS_AUTH_TAB 3 #define FED_AUTH_TAB 4 -#define FAILOVER_TAB 5 -#define MONITORING_TAB 6 -#define METADATA_TAB 7 -#define CURSORS_TAB 8 -#define DEBUG_TAB 9 -#define SSL_TAB 10 -#define MISC_TAB 11 +#define CUSTOM_ENDPOINT_TAB 5 +#define FAILOVER_TAB 6 +#define MONITORING_TAB 7 +#define METADATA_TAB 8 +#define CURSORS_TAB 9 +#define DEBUG_TAB 10 +#define SSL_TAB 11 +#define MISC_TAB 12 #else # include diff --git a/setupgui/windows/odbcdialogparams.cpp b/setupgui/windows/odbcdialogparams.cpp index 64bf5e332..174a0665a 100644 --- a/setupgui/windows/odbcdialogparams.cpp +++ b/setupgui/windows/odbcdialogparams.cpp @@ -376,6 +376,7 @@ void btnDetails_Click (HWND hwnd) L"Authentication", L"AWS Authentication", L"Federated Authentication", + L"Custom Endpoint Monitoring", L"Cluster Failover", L"Monitoring", L"Metadata", @@ -396,6 +397,7 @@ void btnDetails_Click (HWND hwnd) MAKEINTRESOURCE(IDD_TAB9), MAKEINTRESOURCE(IDD_TAB10), MAKEINTRESOURCE(IDD_TAB11), + MAKEINTRESOURCE(IDD_TAB12), 0}; New_TabControl( &TabCtrl_1, // address of TabControl struct diff --git a/setupgui/windows/odbcdialogparams.rc b/setupgui/windows/odbcdialogparams.rc index e0b6f325a..ac787e1a9 100644 --- a/setupgui/windows/odbcdialogparams.rc +++ b/setupgui/windows/odbcdialogparams.rc @@ -252,7 +252,25 @@ BEGIN CONTROL "&Enable SSL",IDC_CHECK_ENABLE_SSL,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,207,108,47,10 END -IDD_TAB5 DIALOGEX 0, 0, 209, 281 +IDD_TAB5 DIALOGEX 0, 0, 209, 181 +STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD +FONT 8, "MS Shell Dlg", 400, 0, 0x1 +BEGIN + CONTROL "&Enable custom endpoint monitoring",IDC_CHECK_ENABLE_CUSTOM_ENDPOINT_MONITORING, + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,147,10 + CONTROL "&Wait for custom endpoint info",IDC_CHECK_WAIT_FOR_CUSTOM_ENDPOINT_INFO, + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,27,147,10 + RTEXT "Custom endpoint info refresh rate (ms):",IDC_STATIC,12,42,150,10 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS,165,40,64,12,ES_AUTOHSCROLL | ES_NUMBER | WS_DISABLED + RTEXT "Wait for custom endpoint info timeout (ms):",IDC_STATIC,12,57,150,8 + EDITTEXT IDC_EDIT_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS,165,55,64,12,ES_AUTOHSCROLL | ES_NUMBER | WS_DISABLED + RTEXT "Custom endpoint monitor expiration time (ms):",IDC_STATIC,12,72,150,8 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS,165,70,64,12,ES_AUTOHSCROLL | ES_NUMBER | WS_DISABLED + RTEXT "Custom endpoint region:",IDC_STATIC,12,87,150,8 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_REGION,165,85,64,12,ES_AUTOHSCROLL +END + +IDD_TAB6 DIALOGEX 0, 0, 209, 281 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -284,7 +302,7 @@ BEGIN "Button", BS_AUTOCHECKBOX | WS_TABSTOP | WS_DISABLED, 210, 96, 120, 10 END -IDD_TAB6 DIALOGEX 0, 0, 209, 181 +IDD_TAB7 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -302,7 +320,7 @@ BEGIN EDITTEXT IDC_EDIT_MONITOR_DISPOSAL_TIME,132,85,64,12,ES_AUTOHSCROLL | ES_NUMBER| WS_DISABLED END -IDD_TAB7 DIALOGEX 0, 0, 209, 181 +IDD_TAB8 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -320,7 +338,7 @@ BEGIN "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,87,141,10 END -IDD_TAB8 DIALOGEX 0, 0, 209, 181 +IDD_TAB9 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -345,15 +363,15 @@ BEGIN "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,125,138,10 END -IDD_TAB9 DIALOGEX 0, 0, 209, 181 +IDD_TAB10 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN CONTROL "&Log driver activity to %TEMP%\\myodbc.log",IDC_CHECK_LOG_QUERY, - "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,148,10 + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,1170,10 END -IDD_TAB10 DIALOGEX 0, 0, 509, 181 +IDD_TAB11 DIALOGEX 0, 0, 509, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -386,7 +404,7 @@ BEGIN CONTROL "Disable TLS Version 1.&3",IDC_CHECK_NO_TLS_1_3,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,90,164,87,10 END -IDD_TAB11 DIALOGEX 0, 0, 209, 181 +IDD_TAB12 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN diff --git a/setupgui/windows/resource.h b/setupgui/windows/resource.h index dcf7f33c3..ac11fb104 100644 --- a/setupgui/windows/resource.h +++ b/setupgui/windows/resource.h @@ -60,6 +60,7 @@ #define IDD_TAB9 140 #define IDD_TAB10 141 #define IDD_TAB11 142 +#define IDD_TAB12 143 #define IDC_LOGO 1000 #define IDC_EDIT 1010 #define IDC_EDIT_PASSWORD 1010 @@ -199,6 +200,14 @@ #define IDC_EDIT_FED_AUTH_HOST 11032 #define IDC_EDIT_FED_AUTH_PORT 11033 #define IDC_EDIT_FED_AUTH_EXPIRATION 11034 +#define IDC_CHECK_ENABLE_CUSTOM_ENDPOINT_MONITORING 11040 +#define IDC_CHECK_WAIT_FOR_CUSTOM_ENDPOINT_INFO 11041 +#define IDC_EDIT_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS 11042 +#define IDC_EDIT_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS 11043 +#define IDC_EDIT_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS 11044 +#define IDC_EDIT_CUSTOM_ENDPOINT_REGION 11045 + +#define IDC #define MYSQL_ADMIN_PORT 33062 #define IDC_STATIC -1 From 9889f4b1be24440878f06631f7c9974681121abf Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Fri, 10 Jan 2025 13:24:31 -0800 Subject: [PATCH 3/5] chore: address review comments --- driver/cache_map.cc | 3 ++- driver/custom_endpoint_monitor.cc | 18 +++++++++--------- driver/custom_endpoint_proxy.cc | 1 - driver/custom_endpoint_proxy.h | 2 -- driver/rds_utils.cc | 4 +--- setupgui/windows/resource.h | 2 -- util/installer.cc | 2 +- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/driver/cache_map.cc b/driver/cache_map.cc index a20e97e88..faa47fe05 100644 --- a/driver/cache_map.cc +++ b/driver/cache_map.cc @@ -37,6 +37,7 @@ template void CACHE_MAP::put(K key, V value, long long item_expiration_nanos) { this->cache[key] = std::make_shared( value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); + this->clean_up(); } template @@ -60,6 +61,7 @@ void CACHE_MAP::remove(K key) { if (this->cache.count(key)) { this->cache.erase(key); } + this->clean_up(); } template @@ -70,7 +72,6 @@ int CACHE_MAP::size() { template void CACHE_MAP::clear() { this->cache.clear(); - this->clean_up(); } template diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc index d5e53cae2..b48d77234 100644 --- a/driver/custom_endpoint_monitor.cc +++ b/driver/custom_endpoint_monitor.cc @@ -139,27 +139,27 @@ void CUSTOM_ENDPOINT_MONITOR::run() { } catch (const std::exception& e) { // Log and continue monitoring. - if (this->enable_logging) { - MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what()); - } + MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what()); } }); } std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string( const std::vector& custom_endpoints) { + if (custom_endpoints.empty()) { + return ""; + } + std::string endpoints("["); for (auto const& e : custom_endpoints) { endpoints += e.GetDBClusterEndpointIdentifier(); endpoints += ","; } - if (endpoints.empty()) { - endpoints = ""; - } else { - endpoints.pop_back(); - endpoints += "]"; - } + + endpoints.pop_back(); + endpoints += "]"; + return endpoints; } diff --git a/driver/custom_endpoint_proxy.cc b/driver/custom_endpoint_proxy.cc index cbb337c23..777ae5bdb 100644 --- a/driver/custom_endpoint_proxy.cc +++ b/driver/custom_endpoint_proxy.cc @@ -36,7 +36,6 @@ SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), std::make_shared(), CACHE_CLEANUP_RATE_NANO); -std::mutex CUSTOM_ENDPOINT_PROXY::monitor_cache_mutex; CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds) : CUSTOM_ENDPOINT_PROXY(dbc, ds, nullptr) {} CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) diff --git a/driver/custom_endpoint_proxy.h b/driver/custom_endpoint_proxy.h index be7338bbb..a3c96d77d 100644 --- a/driver/custom_endpoint_proxy.h +++ b/driver/custom_endpoint_proxy.h @@ -72,7 +72,6 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { long idle_monitor_expiration_ms; static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> monitors; - static std::mutex monitor_cache_mutex; std::shared_ptr create_monitor_if_absent(DataSource* ds); @@ -85,7 +84,6 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { private: std::shared_ptr logger; - std::mutex mutex_; }; #endif diff --git a/driver/rds_utils.cc b/driver/rds_utils.cc index 878555332..5f674e677 100644 --- a/driver/rds_utils.cc +++ b/driver/rds_utils.cc @@ -137,7 +137,7 @@ std::string RDS_UTILS::get_rds_cluster_id(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; if (std::regex_search(host, m, pattern) && m.size() > 1) { - return m.size() > 1 ? m.str(1) : std::string(""); + return m.str(1); } return std::string(); }; @@ -150,7 +150,6 @@ std::string RDS_UTILS::get_rds_cluster_id(std::string host) { return f(AURORA_CHINA_CLUSTER_PATTERN); } - std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; @@ -187,7 +186,6 @@ std::string RDS_UTILS::get_rds_region(std::string host) { return f(AURORA_CHINA_DNS_PATTERN); } - bool RDS_UTILS::is_ipv4(std::string host) { return std::regex_match(host, IPV4_PATTERN); } bool RDS_UTILS::is_ipv6(std::string host) { return std::regex_match(host, IPV6_PATTERN) || std::regex_match(host, IPV6_COMPRESSED_PATTERN); diff --git a/setupgui/windows/resource.h b/setupgui/windows/resource.h index ac11fb104..d1dbcd73b 100644 --- a/setupgui/windows/resource.h +++ b/setupgui/windows/resource.h @@ -206,8 +206,6 @@ #define IDC_EDIT_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS 11043 #define IDC_EDIT_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS 11044 #define IDC_EDIT_CUSTOM_ENDPOINT_REGION 11045 - -#define IDC #define MYSQL_ADMIN_PORT 33062 #define IDC_STATIC -1 diff --git a/util/installer.cc b/util/installer.cc index b348cd8db..004dec17a 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -277,7 +277,7 @@ static SQLWCHAR W_CONNECT_TIMEOUT[] = { 'C', 'O', 'N', 'N', 'E', 'C', 'T', '_', static SQLWCHAR W_NETWORK_TIMEOUT[] = { 'N', 'E', 'T', 'W', 'O', 'R', 'K', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0 }; /* Monitoring */ -static SQLWCHAR W_ENABLE_FAILURE_DETECTION[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_','E', 'N', 'A', 'B', 'L', 'E', '_', 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', 0 }; +static SQLWCHAR W_ENABLE_FAILURE_DETECTION[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_', 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', 0 }; static SQLWCHAR W_FAILURE_DETECTION_TIME[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 }; static SQLWCHAR W_FAILURE_DETECTION_INTERVAL[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'I', 'N', 'T', 'E', 'R', 'V', 'A', 'L', 0 }; static SQLWCHAR W_FAILURE_DETECTION_COUNT[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'C', 'O', 'U', 'N', 'T', 0 }; From 07f6b7dc9eab40ac8fb2ada6f219996d5063fc71 Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Mon, 13 Jan 2025 20:22:26 +0000 Subject: [PATCH 4/5] fix: add missing imports --- driver/cache_map.h | 1 + driver/custom_endpoint_monitor.cc | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/driver/cache_map.h b/driver/cache_map.h index 82e7fed89..7e1e11753 100644 --- a/driver/cache_map.h +++ b/driver/cache_map.h @@ -30,6 +30,7 @@ #ifndef __CACHE_MAP__ #define __CACHE_MAP__ +#include #include #include #include diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc index b48d77234..8a9e0844a 100644 --- a/driver/custom_endpoint_monitor.cc +++ b/driver/custom_endpoint_monitor.cc @@ -30,6 +30,10 @@ #include "custom_endpoint_monitor.h" #include +#include +#include +#include +#include #include "allowed_and_blocked_hosts.h" #include "aws_sdk_helper.h" From cb99e7bad5918b3832a4b8972bfc038960455247 Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Mon, 13 Jan 2025 23:26:20 -0800 Subject: [PATCH 5/5] test: custom endpoints unit tests and integration tests --- .github/workflows/failover.yml | 1 + docs/using-the-aws-driver/CustomEndpoint.md | 24 + driver/CMakeLists.txt | 4 +- driver/cache_map.h | 4 +- driver/cluster_topology_info.cc | 22 + driver/cluster_topology_info.h | 2 + driver/custom_endpoint_info.h | 5 +- driver/custom_endpoint_monitor.cc | 88 +- driver/custom_endpoint_monitor.h | 24 +- driver/custom_endpoint_proxy.cc | 73 +- driver/custom_endpoint_proxy.h | 17 +- driver/driver.h | 6 + driver/failover_handler.cc | 3 +- driver/handle.cc | 16 +- driver/host_info.cc | 21 +- driver/host_info.h | 2 + driver/rds_utils.cc | 39 +- driver/rds_utils.h | 1 + driver/sliding_expiration_cache.cc | 6 +- ...g_expiration_cache_with_clean_up_thread.cc | 44 +- ...ng_expiration_cache_with_clean_up_thread.h | 14 +- driver/topology_service.cc | 44 +- driver/topology_service.h | 141 +-- integration/CMakeLists.txt | 1 + integration/base_failover_integration_test.cc | 18 +- integration/connection_string_builder.h | 840 ++++++++++-------- .../custom_endpoint_integration_test.cc | 178 ++++ scripts/build_aws_sdk_unix.sh | 2 +- scripts/build_aws_sdk_win.ps1 | 2 +- unit_testing/CMakeLists.txt | 2 + unit_testing/custom_endpoint_monitor_test.cc | 109 +++ unit_testing/custom_endpoint_proxy_test.cc | 104 +++ unit_testing/failover_handler_test.cc | 21 + unit_testing/mock_objects.h | 20 + unit_testing/sliding_expiration_cache_test.cc | 2 +- unit_testing/test_utils.cc | 20 + unit_testing/test_utils.h | 66 +- 37 files changed, 1346 insertions(+), 640 deletions(-) create mode 100644 docs/using-the-aws-driver/CustomEndpoint.md create mode 100644 integration/custom_endpoint_integration_test.cc create mode 100644 unit_testing/custom_endpoint_monitor_test.cc create mode 100644 unit_testing/custom_endpoint_proxy_test.cc diff --git a/.github/workflows/failover.yml b/.github/workflows/failover.yml index de8c1e3df..05a7cdd13 100644 --- a/.github/workflows/failover.yml +++ b/.github/workflows/failover.yml @@ -1,6 +1,7 @@ name: Failover Unit Tests on: + workflow_dispatch: push: branches: - main diff --git a/docs/using-the-aws-driver/CustomEndpoint.md b/docs/using-the-aws-driver/CustomEndpoint.md new file mode 100644 index 000000000..03dcd7d54 --- /dev/null +++ b/docs/using-the-aws-driver/CustomEndpoint.md @@ -0,0 +1,24 @@ +# Custom Endpoint Support + +The Custom Endpoint support allows client application to use the driver with RDS custom endpoints. When the Custom Endpoint feature is enabled, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover. + +## How to use the Driver with Custom Endpoint + +### Enabling the Custom Endpoint Feature + +1. If needed, create a custom endpoint using the AWS RDS Console: + - If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html). +2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support. +3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict-reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader-or-writer`. +4. Specify parameters that are required or specific to your case. + +### Custom Endpoint Plugin Parameters + +| Parameter | Value | Required | Description | Default Value | Example Value | +| ------------------------------------------ | :----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- | +| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` | +| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` | +| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` | +| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` | diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index cb172efbe..e0a5252cc 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -73,9 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) connect.cc connection_handler.cc connection_proxy.cc - custom_endpoint_proxy.cc custom_endpoint_info.cc custom_endpoint_monitor.cc + custom_endpoint_proxy.cc cursor.cc desc.cc dll.cc @@ -148,9 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) cluster_topology_info.h connection_handler.h connection_proxy.h - custom_endpoint_proxy.h custom_endpoint_info.h custom_endpoint_monitor.h + custom_endpoint_proxy.h driver.h efm_proxy.h error.h diff --git a/driver/cache_map.h b/driver/cache_map.h index 7e1e11753..f01adaecf 100644 --- a/driver/cache_map.h +++ b/driver/cache_map.h @@ -27,8 +27,8 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#ifndef __CACHE_MAP__ -#define __CACHE_MAP__ +#ifndef __CACHE_MAP_H__ +#define __CACHE_MAP_H__ #include #include diff --git a/driver/cluster_topology_info.cc b/driver/cluster_topology_info.cc index d33360d8a..db049caf7 100644 --- a/driver/cluster_topology_info.cc +++ b/driver/cluster_topology_info.cc @@ -30,6 +30,7 @@ #include "cluster_topology_info.h" #include +#include /** Initialize and return random number. @@ -75,6 +76,20 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr host_info) { update_time(); } +void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr host_info) { + auto position = std::find(writers.begin(), writers.end(), host_info); + if (position != writers.end()) { + writers.erase(position); + } + + position = std::find(readers.begin(), readers.end(), host_info); + if (position != readers.end()) { + readers.erase(position); + } + update_time(); +} + + size_t CLUSTER_TOPOLOGY_INFO::total_hosts() { return writers.size() + readers.size(); } @@ -136,6 +151,13 @@ std::vector> CLUSTER_TOPOLOGY_INFO::get_writers() { return writers; } +std::vector> CLUSTER_TOPOLOGY_INFO::get_instances() { + std::vector instances(writers); + instances.insert(instances.end(), writers.begin(), writers.end()); + + return instances; +} + std::shared_ptr CLUSTER_TOPOLOGY_INFO::get_last_used_reader() { return last_used_reader; } diff --git a/driver/cluster_topology_info.h b/driver/cluster_topology_info.h index 90d840370..1e7271ee7 100644 --- a/driver/cluster_topology_info.h +++ b/driver/cluster_topology_info.h @@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO { virtual ~CLUSTER_TOPOLOGY_INFO(); void add_host(std::shared_ptr host_info); + void remove_host(std::shared_ptr host_info); size_t total_hosts(); size_t num_readers(); // return number of readers in the cluster std::time_t time_last_updated(); @@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO { std::shared_ptr get_reader(int i); std::vector> get_writers(); std::vector> get_readers(); + std::vector> get_instances(); private: int current_reader = -1; diff --git a/driver/custom_endpoint_info.h b/driver/custom_endpoint_info.h index fe67706f1..738be6b62 100644 --- a/driver/custom_endpoint_info.h +++ b/driver/custom_endpoint_info.h @@ -33,10 +33,7 @@ #include #include -#include -#include - -#include "MYODBC_MYSQL.h" +#include "stringutil.h" #include "mylog.h" /** diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc index 8a9e0844a..b1edebcc0 100644 --- a/driver/custom_endpoint_monitor.cc +++ b/driver/custom_endpoint_monitor.cc @@ -27,18 +27,17 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#include "custom_endpoint_monitor.h" - #include #include #include #include +#include #include #include "allowed_and_blocked_hosts.h" #include "aws_sdk_helper.h" +#include "custom_endpoint_monitor.h" #include "driver.h" -#include "monitor_service.h" #include "mylog.h" namespace { @@ -47,13 +46,16 @@ AWS_SDK_HELPER SDK_HELPER; CACHE_MAP> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache; -CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, const std::string& region, - DataSource* ds, bool enable_logging) - : custom_endpoint_host_info(custom_endpoint_host_info), - endpoint_identifier(endpoint_identifier), - region(region), - enable_logging(enable_logging) { + long long refresh_rate_nanos, bool enable_logging) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + enable_logging(enable_logging) { if (enable_logging) { this->logger = init_log_file(); } @@ -66,34 +68,57 @@ CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptrrds_client = std::make_shared( - Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); + Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); + + this->run(); +} +#ifdef UNIT_TEST_BUILD +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, + const std::string& endpoint_identifier, const std::string& region, + long long refresh_rate_nanos, bool enable_logging, + std::shared_ptr client) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + enable_logging(enable_logging), + rds_client(std::move(client)) { + if (enable_logging) { + this->logger = init_log_file(); + } this->run(); } +#endif bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; } bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const { auto default_val = std::shared_ptr(nullptr); - return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val; + return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val; } void CUSTOM_ENDPOINT_MONITOR::run() { this->thread_pool.resize(1); this->thread_pool.push([=](int id) { - MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", - this->custom_endpoint_host_info->get_host().c_str()); + MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); try { while (!this->should_stop.load()) { const std::chrono::time_point start = std::chrono::steady_clock::now(); Aws::RDS::Model::Filter filter; filter.SetName("db-cluster-endpoint-type"); - filter.SetValues({"custom"}); + filter.AddValues("custom"); Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; - request.SetDBClusterIdentifier(this->endpoint_identifier); - request.SetFilters({filter}); + request.SetDBClusterEndpointIdentifier(this->endpoint_identifier); + // TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null. + // request.AddFilters(filter); + if (!this->rds_client) { + break; + } const auto response = this->rds_client->DescribeDBClusterEndpoints(request); const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints(); @@ -108,37 +133,37 @@ void CUSTOM_ENDPOINT_MONITOR::run() { continue; } const std::shared_ptr endpoint_info = - CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); + CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); const std::shared_ptr cache_endpoint_info = - custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr); + custom_endpoint_cache.get(this->custom_endpoint_host, nullptr); if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) { const long long elapsed_time = - std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); std::this_thread::sleep_for( - std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); continue; } MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}", - custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str()); + custom_endpoint_host.c_str(), endpoint_info->to_string().c_str()); // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. std::shared_ptr allowed_and_blocked_hosts; if (endpoint_info->get_member_list_type() == STATIC_LIST) { allowed_and_blocked_hosts = - std::make_shared(endpoint_info->get_static_members(), std::set()); + std::make_shared(endpoint_info->get_static_members(), std::set()); } else { - allowed_and_blocked_hosts = - std::make_shared(std::set(), endpoint_info->get_excluded_members()); + allowed_and_blocked_hosts = std::make_shared( + std::set(), endpoint_info->get_excluded_members()); } - custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info, - CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); + this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts); + custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); const long long elapsed_time = - std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); std::this_thread::sleep_for( - std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); } } catch (const std::exception& e) { @@ -149,7 +174,7 @@ void CUSTOM_ENDPOINT_MONITOR::run() { } std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string( - const std::vector& custom_endpoints) { + const std::vector& custom_endpoints) { if (custom_endpoints.empty()) { return ""; } @@ -171,9 +196,10 @@ void CUSTOM_ENDPOINT_MONITOR::stop() { this->should_stop.store(true); this->thread_pool.stop(true); this->thread_pool.resize(0); - custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host()); + custom_endpoint_cache.remove(this->custom_endpoint_host); + this->rds_client.reset(); --SDK_HELPER; - MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str()); + MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); } void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); } diff --git a/driver/custom_endpoint_monitor.h b/driver/custom_endpoint_monitor.h index e824b3d72..6c6f5b7ef 100644 --- a/driver/custom_endpoint_monitor.h +++ b/driver/custom_endpoint_monitor.h @@ -34,15 +34,23 @@ #include #include "cache_map.h" -#include "connection_handler.h" -#include "connection_proxy.h" #include "custom_endpoint_info.h" #include "host_info.h" +#include "topology_service.h" class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this { public: - CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, const std::string& endpoint_identifier, - const std::string& region, DataSource* ds, bool enable_logging = false); + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, bool enable_logging = false); +#ifdef UNIT_TEST_BUILD + CUSTOM_ENDPOINT_MONITOR() = default; + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, bool enable_logging, + std::shared_ptr client); +#endif + ~CUSTOM_ENDPOINT_MONITOR() = default; static bool should_dispose(); @@ -54,7 +62,7 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this> custom_endpoint_cache; static constexpr long long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS = 300000000000; // 5 minutes - std::shared_ptr custom_endpoint_host_info; + std::string custom_endpoint_host; std::string endpoint_identifier; std::string region; long long refresh_rate_nanos; @@ -63,9 +71,15 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this rds_client; + std::shared_ptr topology_service; private: static std::string get_endpoints_as_string(const std::vector& custom_endpoints); + +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif }; #endif diff --git a/driver/custom_endpoint_proxy.cc b/driver/custom_endpoint_proxy.cc index 777ae5bdb..041a6cffc 100644 --- a/driver/custom_endpoint_proxy.cc +++ b/driver/custom_endpoint_proxy.cc @@ -28,27 +28,32 @@ // http://www.gnu.org/licenses/gpl-2.0.html. #include "custom_endpoint_proxy.h" -#include "custom_endpoint_monitor.h" -#include "installer.h" #include "mylog.h" #include "rds_utils.h" SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> - CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), - std::make_shared(), CACHE_CLEANUP_RATE_NANO); + CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), + std::make_shared(), CACHE_CLEANUP_RATE_NANO); + +bool CUSTOM_ENDPOINT_PROXY::is_monitor_cache_initialized(false); CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds) : CUSTOM_ENDPOINT_PROXY(dbc, ds, nullptr) {} CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) - : CONNECTION_PROXY(dbc, ds) { + : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; + this->topology_service = dbc->get_topology_service(); if (ds->opt_LOG_QUERY) { this->logger = init_log_file(); } - this->should_wait_for_info = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO; this->wait_on_cached_info_duration_ms = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; this->idle_monitor_expiration_ms = ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS; + + if (!is_monitor_cache_initialized) { + monitors.init_clean_up_thread(); + is_monitor_cache_initialized = true; + } } bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const char* password, const char* database, @@ -71,8 +76,8 @@ bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const ch : RDS_UTILS::get_rds_region(host); if (this->region.empty()) { this->set_custom_error_message( - "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " - "'custom_endpoint_region' property"); + "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " + "'custom_endpoint_region' property"); return false; } @@ -85,6 +90,26 @@ bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const ch return this->next_proxy->connect(host, user, password, database, port, socket, flags); } +int CUSTOM_ENDPOINT_PROXY::query(const char* q) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + + return next_proxy->query(q); +} + +int CUSTOM_ENDPOINT_PROXY::real_query(const char* q, unsigned long length) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + + return next_proxy->real_query(q, length); +} + void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptr monitor) { bool has_custom_endpoint_info = monitor->has_custom_endpoint_info(); @@ -96,11 +121,11 @@ void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptrlogger, 0, "Custom endpoint info for '%s' was not found. Waiting %dms for the endpoint monitor to fetch info...", - this->custom_endpoint_host_info->get_host().c_str(), this->wait_on_cached_info_duration_ms) + this->custom_endpoint_host.c_str(), this->wait_on_cached_info_duration_ms) const auto wait_for_endpoint_info_timeout_nanos = - std::chrono::steady_clock::now() + std::chrono::duration_cast( - std::chrono::milliseconds(this->wait_on_cached_info_duration_ms)); + std::chrono::steady_clock::now() + std::chrono::duration_cast( + std::chrono::milliseconds(this->wait_on_cached_info_duration_ms)); while (!has_custom_endpoint_info && std::chrono::steady_clock::now() < wait_for_endpoint_info_timeout_nanos) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); @@ -110,24 +135,26 @@ void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptrwait_on_cached_info_duration_ms, this->custom_endpoint_host_info->get_host().c_str()); + buf, sizeof(buf), + "The custom endpoint plugin timed out after %ld ms while waiting for custom endpoint info for host %s.", + this->wait_on_cached_info_duration_ms, this->custom_endpoint_host.c_str()); set_custom_error_message(buf); } } +std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_custom_endpoint_monitor( + const long long refresh_rate_nanos) { + return std::make_shared(this->topology_service, this->custom_endpoint_host, + this->custom_endpoint_id, this->region, refresh_rate_nanos); +} + std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_monitor_if_absent(DataSource* ds) { - const auto refresh_rate_nanos = std::chrono::duration_cast( - std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) - .count(); + const long long refresh_rate_nanos = std::chrono::duration_cast( + std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) + .count(); return monitors.compute_if_absent( - this->custom_endpoint_host_info->get_host(), - [=](std::string key) { - return std::make_shared(this->custom_endpoint_host_info, this->custom_endpoint_id, - this->region, ds); - }, - refresh_rate_nanos); + this->custom_endpoint_host, + [=](std::string key) { return this->create_custom_endpoint_monitor(refresh_rate_nanos); }, refresh_rate_nanos); } diff --git a/driver/custom_endpoint_proxy.h b/driver/custom_endpoint_proxy.h index a3c96d77d..4c3fefbb5 100644 --- a/driver/custom_endpoint_proxy.h +++ b/driver/custom_endpoint_proxy.h @@ -27,13 +27,13 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#ifndef __CUSTOM_ENDPOINT_PROXY__ -#define __CUSTOM_ENDPOINT_PROXY__ +#ifndef __CUSTOM_ENDPOINT_PROXY_H__ +#define __CUSTOM_ENDPOINT_PROXY_H__ -#include #include #include "connection_proxy.h" #include "custom_endpoint_monitor.h" +#include "driver.h" #include "sliding_expiration_cache_with_clean_up_thread.h" class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { @@ -44,6 +44,9 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, const char* socket, unsigned long flags) override; + int query(const char* q) override; + int real_query(const char* q, unsigned long length) override; + class CUSTOM_ENDPOINTS_SHOULD_DISPOSE_FUNC : public SHOULD_DISPOSE_FUNC> { public: bool should_dispose(std::shared_ptr item) override { return true; } @@ -62,14 +65,15 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { static constexpr long long CACHE_CLEANUP_RATE_NANO = 60000000000; protected: + static bool is_monitor_cache_initialized; std::string custom_endpoint_id; std::string region; std::string custom_endpoint_host; - std::shared_ptr custom_endpoint_host_info; std::shared_ptr rds_client; bool should_wait_for_info; long wait_on_cached_info_duration_ms; long idle_monitor_expiration_ms; + std::shared_ptr topology_service; static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> monitors; @@ -84,6 +88,11 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { private: std::shared_ptr logger; + virtual std::shared_ptr create_custom_endpoint_monitor(long long refresh_rate_nanos); +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif }; #endif diff --git a/driver/driver.h b/driver/driver.h index c41caad62..a52cee9ee 100644 --- a/driver/driver.h +++ b/driver/driver.h @@ -47,6 +47,7 @@ #include "connection_handler.h" #include "connection_proxy.h" +#include "topology_service.h" #include "failover.h" /* Disable _attribute__ on non-gcc compilers. */ @@ -627,6 +628,7 @@ struct DBC FAILOVER_HANDLER *fh = nullptr; /* Failover handler */ std::shared_ptr connection_handler = nullptr; + std::shared_ptr topology_service = nullptr; DBC(ENV *p_env); void free_explicit_descriptors(); @@ -639,6 +641,10 @@ struct DBC void execute_prep_stmt(MYSQL_STMT *pstmt, std::string &query, std::vector ¶m_bind, MYSQL_BIND *result_bind); void init_proxy_chain(DataSource *dsrc); + std::shared_ptr get_topology_service() { + return this->topology_service ? this->topology_service + : std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); + } inline bool transactions_supported() { return connection_proxy->get_server_capabilities() & CLIENT_TRANSACTIONS; diff --git a/driver/failover_handler.cc b/driver/failover_handler.cc index f3233ac4d..8cdf0ec9f 100644 --- a/driver/failover_handler.cc +++ b/driver/failover_handler.cc @@ -52,8 +52,7 @@ const char* MYSQL_READONLY_QUERY = "SELECT @@innodb_read_only AS is_reader"; FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds) : FAILOVER_HANDLER( - dbc, ds, dbc ? dbc->connection_handler : nullptr, - std::make_shared(dbc ? dbc->id : 0, ds ? ds->opt_LOG_QUERY : false), + dbc, ds, dbc ? dbc->connection_handler : nullptr, dbc ? dbc->get_topology_service() : nullptr, std::make_shared(dbc, ds)) {} FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds, diff --git a/driver/handle.cc b/driver/handle.cc index 125ae86cf..7bee0ef95 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -123,13 +123,9 @@ void DBC::close() // construct a proxy chain, example: iam->efm->mysql void DBC::init_proxy_chain(DataSource* dsrc) { - CONNECTION_PROXY *head = new MYSQL_PROXY(this, dsrc); + this->topology_service = std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); - if (dsrc->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING) { - CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); - custom_endpoint_proxy->set_next_proxy(head); - head = custom_endpoint_proxy; - } + CONNECTION_PROXY* head = new MYSQL_PROXY(this, dsrc); if (dsrc->opt_ENABLE_FAILURE_DETECTION) { CONNECTION_PROXY* efm_proxy = new EFM_PROXY(this, dsrc); @@ -165,6 +161,13 @@ void DBC::init_proxy_chain(DataSource* dsrc) } } + if (dsrc->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING) { + CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); + custom_endpoint_proxy->set_next_proxy(head); + head = custom_endpoint_proxy; + } + + this->connection_proxy = head; } @@ -173,6 +176,7 @@ DBC::~DBC() if (env) env->remove_dbc(this); + this->topology_service.reset(); if (connection_proxy) delete connection_proxy; diff --git a/driver/host_info.cc b/driver/host_info.cc index ba9e3beba..139a5ad29 100644 --- a/driver/host_info.cc +++ b/driver/host_info.cc @@ -29,6 +29,8 @@ #include "host_info.h" +#include "rds_utils.h" + // TODO // the entire HOST_INFO needs to be reviewed based on needed interfaces and other objects like CLUSTER_TOPOLOGY_INFO // most/all of the HOST_INFO potentially could be internal to CLUSTER_TOPOLOGY_INFO and specfic information may be accessed @@ -45,27 +47,30 @@ HOST_INFO::HOST_INFO(const char* host, int port) : HOST_INFO(host, port, UP, false) {} HOST_INFO::HOST_INFO(std::string host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} // would need some checks for nulls HOST_INFO::HOST_INFO(const char* host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} HOST_INFO::~HOST_INFO() {} /** - * Returns the host. + * Returns the host endpoint. * - * @return the host + * @return the host endpoint */ std::string HOST_INFO::get_host() { return host; } +/** + * Returns the host name. + * + * @return the host name + */ +std::string HOST_INFO::get_host_id() { return host_id; } + /** * Returns the port. * diff --git a/driver/host_info.h b/driver/host_info.h index e5c64a420..6049cac05 100644 --- a/driver/host_info.h +++ b/driver/host_info.h @@ -49,6 +49,7 @@ class HOST_INFO { int get_port(); std::string get_host(); + std::string get_host_id(); std::string get_host_port_pair(); bool equal_host_port_pair(HOST_INFO& hi); HOST_STATE get_host_state(); @@ -69,6 +70,7 @@ class HOST_INFO { private: const std::string HOST_PORT_SEPARATOR = ":"; const std::string host; + const std::string host_id; const int port = NO_PORT; HOST_STATE host_state; diff --git a/driver/rds_utils.cc b/driver/rds_utils.cc index 5f674e677..a49667a0c 100644 --- a/driver/rds_utils.cc +++ b/driver/rds_utils.cc @@ -31,7 +31,7 @@ namespace { const std::regex AURORA_DNS_PATTERN( - R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", + R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.([a-zA-Z0-9\-]+)\.rds\.amazonaws\.com))#", std::regex_constants::icase); const std::regex AURORA_PROXY_DNS_PATTERN(R"#((.+)\.(proxy-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", std::regex_constants::icase); @@ -136,30 +136,44 @@ std::string RDS_UTILS::get_rds_cluster_host_url(std::string host) { std::string RDS_UTILS::get_rds_cluster_id(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; - if (std::regex_search(host, m, pattern) && m.size() > 1) { + if (std::regex_search(host, m, pattern) && m.size() > 1 && !m.str(2).empty()) { return m.str(1); } return std::string(); }; - auto result = f(AURORA_CLUSTER_PATTERN); + auto result = f(AURORA_DNS_PATTERN); if (!result.empty()) { return result; } - return f(AURORA_CHINA_CLUSTER_PATTERN); + return f(AURORA_CHINA_DNS_PATTERN); } +std::string RDS_UTILS::get_rds_instance_id(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 1 && m.str(2).empty()) { + return m.str(1); + } + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; - if (std::regex_search(host, m, pattern) && m.size() > 3) { - if (!m.str(3).empty()) { - std::string result("?."); - result.append(m.str(3)); + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(3).empty()) { + std::string result("?."); + result.append(m.str(3)); - return result; - } + return result; } return std::string(); }; @@ -174,7 +188,10 @@ std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { std::string RDS_UTILS::get_rds_region(std::string host) { auto f = [host](const std::regex pattern) { - // TODO: implement region + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(4).empty()) { + return m.str(4); + } return std::string(); }; diff --git a/driver/rds_utils.h b/driver/rds_utils.h index 34ad0cfd9..4c4b1648c 100644 --- a/driver/rds_utils.h +++ b/driver/rds_utils.h @@ -47,6 +47,7 @@ class RDS_UTILS { static std::string get_rds_cluster_host_url(std::string host); static std::string get_rds_cluster_id(std::string host); static std::string get_rds_instance_host_pattern(std::string host); + static std::string get_rds_instance_id(std::string host); static std::string get_rds_region(std::string host); }; diff --git a/driver/sliding_expiration_cache.cc b/driver/sliding_expiration_cache.cc index 423246e7e..af56f0a1c 100644 --- a/driver/sliding_expiration_cache.cc +++ b/driver/sliding_expiration_cache.cc @@ -69,9 +69,9 @@ template V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos) { this->clean_up(); - auto cache_item = std::make_shared(mapping_function(key), - std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); - this->cache[key] = cache_item; + V item = mapping_function(key); + auto cache_item = std::make_shared(item, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); + this->cache.emplace(key, cache_item); return cache_item->with_extend_expiration(item_expiration_nanos)->item; } diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.cc b/driver/sliding_expiration_cache_with_clean_up_thread.cc index 7d6056466..cc43582c5 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.cc +++ b/driver/sliding_expiration_cache_with_clean_up_thread.cc @@ -30,7 +30,6 @@ #include "sliding_expiration_cache_with_clean_up_thread.h" #include -#include #include "custom_endpoint_monitor.h" @@ -40,7 +39,6 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() std::unique_lock lock(mutex_); if (!this->is_initialized) { this->clean_up_thread_pool.resize(this->clean_up_thread_pool.size() + 1); - this->clean_up_thread_pool.push([=](int id) { while (!should_stop) { const std::chrono::nanoseconds clean_up_interval = std::chrono::nanoseconds(this->clean_up_interval_nanos); @@ -61,42 +59,16 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() } } -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() { - this->init_clean_up_thread(); -} - -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func) - : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func)) { - this->init_clean_up_thread(); -} - -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos) - : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func), clean_up_interval_nanos) { - this->init_clean_up_thread(); -} - -#ifdef UNIT_TEST_BUILD -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - long long clean_up_interval_nanos) { - this->clean_up_interval_nanos = clean_up_interval_nanos; - this->init_clean_up_thread(); -} -#endif - template void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::release_resources() { - this->should_stop = true; - this->clean_up_thread_pool.stop(true); - this->clean_up_thread_pool.resize(0); - this->is_initialized = false; + std::unique_lock lock(mutex_); + { + this->should_stop = true; + this->clean_up_thread_pool.stop(true); + this->clean_up_thread_pool.resize(0); + this->is_initialized = false; + this->clear(); + } } template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD; diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.h b/driver/sliding_expiration_cache_with_clean_up_thread.h index 807744347..dd0583f84 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.h +++ b/driver/sliding_expiration_cache_with_clean_up_thread.h @@ -38,12 +38,12 @@ template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_CACHE { public: - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func); + std::shared_ptr> item_disposal_func){}; SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, std::shared_ptr> item_disposal_func, - long long clean_up_interval_nanos); + long long clean_up_interval_nanos){}; ~SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; /** @@ -52,17 +52,17 @@ class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_ void release_resources(); #ifdef UNIT_TEST_BUILD - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos) { + this->clean_up_interval_nanos = clean_up_interval_nanos; + }; #endif + void init_clean_up_thread(); protected: bool is_initialized = false; bool should_stop = false; std::mutex mutex_; ctpl::thread_pool clean_up_thread_pool; - - private: - void init_clean_up_thread(); }; #endif diff --git a/driver/topology_service.cc b/driver/topology_service.cc index 265dee136..7364b9ad2 100644 --- a/driver/topology_service.cc +++ b/driver/topology_service.cc @@ -164,26 +164,46 @@ std::shared_ptr TOPOLOGY_SERVICE::get_cached_topology() { return get_from_cache(); } -//TODO consider the return value -//Note to determine whether or not force_update succeeded one would compare +// TODO consider the return value +// Note to determine whether or not force_update succeeded one would compare // CLUSTER_TOPOLOGY_INFO->time_last_updated() prior and after the call if non-null information was given prior. std::shared_ptr TOPOLOGY_SERVICE::get_topology(CONNECTION_PROXY* connection, bool force_update) { - //TODO reconsider using this cache. It appears that we only store information for the current cluster Id. + // TODO reconsider using this cache. It appears that we only store information for the current cluster Id. // therefore instead of a map we can just keep CLUSTER_TOPOLOGY_INFO* topology_info member variable. - auto cached_topology = get_from_cache(); - if (!cached_topology - || force_update - || refresh_needed(cached_topology->time_last_updated())) - { - auto latest_topology = query_for_topology(connection); - if (latest_topology) { + auto topology = get_from_cache(); + if (!topology || force_update || refresh_needed(topology->time_last_updated())) { + if (auto latest_topology = query_for_topology(connection)) { put_to_cache(latest_topology); - return latest_topology; + topology = latest_topology; } } - return cached_topology; + if (!this->allowed_and_blocked_hosts) { + return topology; + } + + std::set allowed_list = this->allowed_and_blocked_hosts->get_allowed_host_ids(); + std::set blocked_list = this->allowed_and_blocked_hosts->get_blocked_host_ids(); + + const std::shared_ptr filtered_topology = topology; + if (allowed_list.size() > 0) { + for (const auto& host : topology->get_instances()) { + if (allowed_list.find(host->get_host_id()) != allowed_list.end()) { + filtered_topology->add_host(host); + } + } + } + + if (blocked_list.size() > 0) { + for (const auto& host : filtered_topology->get_instances()) { + // Remove blocked hosts from the filtered_topology. + if (blocked_list.find(host->get_host_id()) != blocked_list.end()) { + filtered_topology->remove_host(host); + } + } + } + return filtered_topology; } // TODO consider thread safety and usage of pointers diff --git a/driver/topology_service.h b/driver/topology_service.h index 60f4b6d4d..3464e8519 100644 --- a/driver/topology_service.h +++ b/driver/topology_service.h @@ -24,7 +24,7 @@ // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License -// along with this program. If not, see +// along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #ifndef __TOPOLOGYSERVICE_H__ @@ -36,16 +36,18 @@ #include "cluster_topology_info.h" #include "connection_proxy.h" -#include -#include #include #include +#include +#include +#include "allowed_and_blocked_hosts.h" // TODO - consider - do we really need miliseconds for refresh? - the default numbers here are already 30 seconds.000; #define DEFAULT_REFRESH_RATE_IN_MILLISECONDS 30000 #define WRITER_SESSION_ID "MASTER_SESSION_ID" -#define RETRIEVE_TOPOLOGY_SQL "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ +#define RETRIEVE_TOPOLOGY_SQL \ + "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ FROM information_schema.replica_host_status \ WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 \ ORDER BY LAST_UPDATE_TIMESTAMP DESC" @@ -54,69 +56,74 @@ static std::map> topology_ca static std::mutex topology_cache_mutex; class TOPOLOGY_SERVICE { -public: - TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); - TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); - virtual ~TOPOLOGY_SERVICE(); - - virtual void set_cluster_id(std::string cluster_id); - virtual void set_cluster_instance_template(std::shared_ptr host_template); //is this equivalent to setcluster_instance_host - - virtual std::shared_ptr get_topology( - CONNECTION_PROXY* connection, bool force_update = false); - std::shared_ptr get_cached_topology(); - - std::shared_ptr get_last_used_reader(); - void set_last_used_reader(std::shared_ptr reader); - std::set get_down_hosts(); - virtual void mark_host_down(std::shared_ptr host); - virtual void mark_host_up(std::shared_ptr host); - void set_refresh_rate(int refresh_rate); - void set_gather_metric(bool can_gather); - void clear_all(); - void clear(); - - // Property Keys - const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; - const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; - const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; - const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; - -private: - const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min - - const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; - const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; - - const std::string FIELD_SERVER_ID = "SERVER_ID"; - const std::string FIELD_SESSION_ID = "SESSION_ID"; - const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; - const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; - - std::shared_ptr logger = nullptr; - unsigned long dbc_id = 0; - -protected: - const int NO_CONNECTION_INDEX = -1; - int refresh_rate_in_ms; - - std::string cluster_id; - std::shared_ptr cluster_instance_host; - - std::shared_ptr metrics_container; - - bool refresh_needed(std::time_t last_updated); - std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); - std::shared_ptr create_host(MYSQL_ROW& row); - std::string get_host_endpoint(const char* node_name); - static bool does_instance_exist( - std::map>& instances, - std::shared_ptr host_info); - - std::shared_ptr get_from_cache(); - void put_to_cache(std::shared_ptr topology_info); - - MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); + public: + TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); + TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); + virtual ~TOPOLOGY_SERVICE(); + + virtual void set_cluster_id(std::string cluster_id); + virtual void set_cluster_instance_template( + std::shared_ptr host_template); // is this equivalent to set_cluster_instance_host + + virtual std::shared_ptr get_topology(CONNECTION_PROXY* connection, bool force_update = false); + + std::shared_ptr get_cached_topology(); + + std::shared_ptr get_last_used_reader(); + void set_last_used_reader(std::shared_ptr reader); + std::set get_down_hosts(); + virtual void mark_host_down(std::shared_ptr host); + virtual void mark_host_up(std::shared_ptr host); + void set_refresh_rate(int refresh_rate); + void set_gather_metric(bool can_gather); + void clear_all(); + void clear(); + void set_allowed_and_blocked_hosts(std::shared_ptr hosts) { + this->allowed_and_blocked_hosts = hosts; + }; + + // Property Keys + const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; + const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; + const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; + const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; + + private: + const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min + + const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; + const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; + + const std::string FIELD_SERVER_ID = "SERVER_ID"; + const std::string FIELD_SESSION_ID = "SESSION_ID"; + const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; + const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; + + std::shared_ptr logger = nullptr; + unsigned long dbc_id = 0; + + protected: + const int NO_CONNECTION_INDEX = -1; + int refresh_rate_in_ms; + + std::string cluster_id; + std::shared_ptr cluster_instance_host; + + std::shared_ptr metrics_container; + + std::shared_ptr allowed_and_blocked_hosts; + + bool refresh_needed(std::time_t last_updated); + std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); + std::shared_ptr create_host(MYSQL_ROW& row); + std::string get_host_endpoint(const char* node_name); + static bool does_instance_exist(std::map>& instances, + std::shared_ptr host_info); + + std::shared_ptr get_from_cache(); + void put_to_cache(std::shared_ptr topology_info); + + MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); }; #endif /* __TOPOLOGYSERVICE_H__ */ diff --git a/integration/CMakeLists.txt b/integration/CMakeLists.txt index fe716112a..733c42dd9 100644 --- a/integration/CMakeLists.txt +++ b/integration/CMakeLists.txt @@ -99,6 +99,7 @@ set(TEST_SOURCES base_failover_integration_test.cc connection_string_builder_test.cc) set(INTEGRATION_TESTS + custom_endpoint_integration_test.cc iam_authentication_integration_test.cc secrets_manager_integration_test.cc network_failover_integration_test.cc diff --git a/integration/base_failover_integration_test.cc b/integration/base_failover_integration_test.cc index 77390b649..3e1e429b4 100644 --- a/integration/base_failover_integration_test.cc +++ b/integration/base_failover_integration_test.cc @@ -37,7 +37,9 @@ #include #include #include +#include #include +#include #include #include @@ -49,14 +51,17 @@ #include #include #include +#include #include #include -#include -#include -#include +#if defined(__APPLE__) || defined(__linux__) #include +#include #include +#include +#include +#endif #include "connection_string_builder.h" #include "integration_test_utils.h" @@ -267,6 +272,13 @@ class BaseFailoverIntegrationTest : public testing::Test { } } + static Aws::RDS::Model::DBClusterEndpoint get_custom_endpoint_info(const Aws::RDS::RDSClient& client, const std::string& endpoint_id) { + Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; + request.SetDBClusterEndpointIdentifier(endpoint_id); + const auto response = client.DescribeDBClusterEndpoints(request); + return response.GetResult().GetDBClusterEndpoints()[0]; + } + static Aws::RDS::Model::DBClusterMember get_DB_cluster_writer_instance(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) { Aws::RDS::Model::DBClusterMember instance; const Aws::RDS::Model::DBCluster cluster = get_DB_cluster(client, cluster_id); diff --git a/integration/connection_string_builder.h b/integration/connection_string_builder.h index 8eb450d25..f31cf1a6a 100644 --- a/integration/connection_string_builder.h +++ b/integration/connection_string_builder.h @@ -24,22 +24,22 @@ // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License -// along with this program. If not, see +// along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #ifndef __CONNECTIONSTRINGBUILDER_H__ #define __CONNECTIONSTRINGBUILDER_H__ +#include #include #include -#include class ConnectionStringBuilder; class ConnectionString { - public: - // friend class so the builder can access ConnectionString private attributes - friend class ConnectionStringBuilder; + public: + // friend class so the builder can access ConnectionString private attributes + friend class ConnectionStringBuilder; ConnectionString() : m_dsn(""), m_server(""), m_port(-1), m_uid(""), m_pwd(""), m_db(""), m_log_query(true), @@ -49,9 +49,9 @@ class ConnectionString { m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1), m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""), m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""), - + is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false), - is_set_failover_mode(false), + is_set_failover_mode(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false), is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false), is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false), @@ -59,417 +59,493 @@ class ConnectionString { is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false), is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {}; - std::string get_connection_string() const { - char conn_in[4096] = "\0"; - int length = 0; - length += sprintf(conn_in, "DSN=%s;SERVER=%s;PORT=%d;", m_dsn.c_str(), m_server.c_str(), m_port); - - if (is_set_uid) { - length += sprintf(conn_in + length, "UID=%s;", m_uid.c_str()); - } - if (is_set_pwd) { - length += sprintf(conn_in + length, "PWD=%s;", m_pwd.c_str()); - } - if (is_set_db) { - length += sprintf(conn_in + length, "DATABASE=%s;", m_db.c_str()); - } - if (is_set_log_query) { - length += sprintf(conn_in + length, "LOG_QUERY=%d;", m_log_query ? 1 : 0); - } - if (is_set_failover_mode) { - length += sprintf(conn_in + length, "FAILOVER_MODE=%s;", m_failover_mode.c_str()); - } - if (is_set_multi_statements) { - length += sprintf(conn_in + length, "MULTI_STATEMENTS=%d;", m_multi_statements ? 1 : 0); - } - if (is_set_enable_cluster_failover) { - length += sprintf(conn_in + length, "ENABLE_CLUSTER_FAILOVER=%d;", m_enable_cluster_failover ? 1 : 0); - } - if (is_set_failover_timeout) { - length += sprintf(conn_in + length, "FAILOVER_TIMEOUT=%d;", m_failover_timeout); - } - if (is_set_connect_timeout) { - length += sprintf(conn_in + length, "CONNECT_TIMEOUT=%d;", m_connect_timeout); - } - if (is_set_network_timeout) { - length += sprintf(conn_in + length, "NETWORK_TIMEOUT=%d;", m_network_timeout); - } - if (is_set_host_pattern) { - length += sprintf(conn_in + length, "HOST_PATTERN=%s;", m_host_pattern.c_str()); - } - if (is_set_enable_failure_detection) { - length += sprintf(conn_in + length, "ENABLE_FAILURE_DETECTION=%d;", m_enable_failure_detection ? 1 : 0); - } - if (is_set_failure_detection_time) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_TIME=%d;", m_failure_detection_time); - } - if (is_set_failure_detection_timeout) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_TIMEOUT=%d;", m_failure_detection_timeout); - } - if (is_set_failure_detection_interval) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_INTERVAL=%d;", m_failure_detection_interval); - } - if (is_set_failure_detection_count) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_COUNT=%d;", m_failure_detection_count); - } - if (is_set_monitor_disposal_time) { - length += sprintf(conn_in + length, "MONITOR_DISPOSAL_TIME=%d;", m_monitor_disposal_time); - } - if (is_set_read_timeout) { - length += sprintf(conn_in + length, "READTIMEOUT=%d;", m_read_timeout); - } - if (is_set_write_timeout) { - length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout); - } - if (is_set_auth_mode) { - length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str()); - } - if (is_set_auth_region) { - length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str()); - } - if (is_set_auth_host) { - length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str()); - } - if (is_set_auth_port) { - length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port); - } - if (is_set_auth_expiration) { - length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration); - } - if (is_set_secret_id) { - length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str()); - } - snprintf(conn_in + length, sizeof(conn_in) - length, "\0"); - - std::string connection_string(conn_in); - return connection_string; - } - - private: - // Required fields - std::string m_dsn, m_server; - int m_port; - - // Optional fields - std::string m_uid, m_pwd, m_db; - bool m_log_query, m_multi_statements, m_enable_cluster_failover; - int m_failover_timeout, m_connect_timeout, m_network_timeout; - std::string m_host_pattern, m_failover_mode; - bool m_enable_failure_detection; - int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout; - std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id; - int m_auth_port, m_auth_expiration; - - bool is_set_uid, is_set_pwd, is_set_db; - bool is_set_log_query, is_set_failover_mode, is_set_multi_statements; - bool is_set_enable_cluster_failover; - bool is_set_failover_timeout, is_set_connect_timeout, is_set_network_timeout; - bool is_set_host_pattern; - bool is_set_enable_failure_detection; - bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count; - bool is_set_monitor_disposal_time; - bool is_set_read_timeout, is_set_write_timeout; - bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id; - - void set_dsn(const std::string& dsn) { - m_dsn = dsn; - } - - void set_server(const std::string& server) { - m_server = server; - } - - void set_port(const int& port) { - m_port = port; - } - - void set_uid(const std::string& uid) { - m_uid = uid; - is_set_uid = true; - } + std::string get_connection_string() const { + char conn_in[4096] = "\0"; + int length = 0; + length += sprintf(conn_in, "DSN=%s;SERVER=%s;PORT=%d;", m_dsn.c_str(), m_server.c_str(), m_port); - void set_pwd(const std::string& pwd) { - m_pwd = pwd; - is_set_pwd = true; + if (is_set_uid) { + length += sprintf(conn_in + length, "UID=%s;", m_uid.c_str()); } - - void set_db(const std::string& db) { - m_db = db; - is_set_db = true; + if (is_set_pwd) { + length += sprintf(conn_in + length, "PWD=%s;", m_pwd.c_str()); } - - void set_log_query(const bool& log_query) { - m_log_query = log_query; - is_set_log_query = true; - } - - void set_failover_mode(const std::string& failover_mode) { - m_failover_mode = failover_mode; - is_set_failover_mode = true; + if (is_set_db) { + length += sprintf(conn_in + length, "DATABASE=%s;", m_db.c_str()); } - - void set_multi_statements(const bool& multi_statements) { - m_multi_statements = multi_statements; - is_set_multi_statements = true; + if (is_set_log_query) { + length += sprintf(conn_in + length, "LOG_QUERY=%d;", m_log_query ? 1 : 0); } - - void set_enable_cluster_failover(const bool& enable_cluster_failover) { - m_enable_cluster_failover = enable_cluster_failover; - is_set_enable_cluster_failover = true; + if (is_set_failover_mode) { + length += sprintf(conn_in + length, "FAILOVER_MODE=%s;", m_failover_mode.c_str()); } - - void set_failover_timeout(const int& failover_timeout) { - m_failover_timeout = failover_timeout; - is_set_failover_timeout = true; + if (is_set_multi_statements) { + length += sprintf(conn_in + length, "MULTI_STATEMENTS=%d;", m_multi_statements ? 1 : 0); } - - void set_connect_timeout(const int& connect_timeout) { - m_connect_timeout = connect_timeout; - is_set_connect_timeout = true; + if (is_set_enable_cluster_failover) { + length += sprintf(conn_in + length, "ENABLE_CLUSTER_FAILOVER=%d;", m_enable_cluster_failover ? 1 : 0); } - - void set_network_timeout(const int& network_timeout) { - m_network_timeout = network_timeout; - is_set_network_timeout = true; + if (is_set_failover_timeout) { + length += sprintf(conn_in + length, "FAILOVER_TIMEOUT=%d;", m_failover_timeout); } - - void set_host_pattern(const std::string& host_pattern) { - m_host_pattern = host_pattern; - is_set_host_pattern = true; + if (is_set_connect_timeout) { + length += sprintf(conn_in + length, "CONNECT_TIMEOUT=%d;", m_connect_timeout); } - - void set_enable_failure_detection(const bool& enable_failure_detection) { - m_enable_failure_detection = enable_failure_detection; - is_set_enable_failure_detection = true; + if (is_set_network_timeout) { + length += sprintf(conn_in + length, "NETWORK_TIMEOUT=%d;", m_network_timeout); } - - void set_failure_detection_time(const int& failure_detection_time) { - m_failure_detection_time = failure_detection_time; - is_set_failure_detection_time = true; + if (is_set_host_pattern) { + length += sprintf(conn_in + length, "HOST_PATTERN=%s;", m_host_pattern.c_str()); } - - void set_failure_detection_timeout(const int& failure_detection_timeout) { - m_failure_detection_timeout = failure_detection_timeout; - is_set_failure_detection_timeout = true; + if (is_set_enable_failure_detection) { + length += sprintf(conn_in + length, "ENABLE_FAILURE_DETECTION=%d;", m_enable_failure_detection ? 1 : 0); } - - void set_failure_detection_interval(const int& failure_detection_interval) { - m_failure_detection_interval = failure_detection_interval; - is_set_failure_detection_interval = true; + if (is_set_failure_detection_time) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_TIME=%d;", m_failure_detection_time); } - - void set_failure_detection_count(const int& failure_detection_count) { - m_failure_detection_count = failure_detection_count; - is_set_failure_detection_count = true; + if (is_set_failure_detection_timeout) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_TIMEOUT=%d;", m_failure_detection_timeout); } - - void set_monitor_disposal_time(const int& monitor_disposal_time) { - m_monitor_disposal_time = monitor_disposal_time; - is_set_monitor_disposal_time = true; + if (is_set_failure_detection_interval) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_INTERVAL=%d;", m_failure_detection_interval); } - - void set_read_timeout(const int& read_timeout) { - m_read_timeout = read_timeout; - is_set_read_timeout = true; + if (is_set_failure_detection_count) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_COUNT=%d;", m_failure_detection_count); } - - void set_write_timeout(const int& write_timeout) { - m_write_timeout = write_timeout; - is_set_write_timeout = true; + if (is_set_monitor_disposal_time) { + length += sprintf(conn_in + length, "MONITOR_DISPOSAL_TIME=%d;", m_monitor_disposal_time); } - - void set_auth_mode(const std::string& auth_mode) { - m_auth_mode = auth_mode; - is_set_auth_mode = true; + if (is_set_read_timeout) { + length += sprintf(conn_in + length, "READTIMEOUT=%d;", m_read_timeout); } - - void set_auth_region(const std::string& auth_region) { - m_auth_region = auth_region; - is_set_auth_region = true; + if (is_set_write_timeout) { + length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout); } - - void set_auth_host(const std::string& auth_host) { - m_auth_host = auth_host; - is_set_auth_host = true; + if (is_set_auth_mode) { + length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str()); } - - void set_auth_port(const int& auth_port) { - m_auth_port = auth_port; - is_set_auth_port = true; + if (is_set_auth_region) { + length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str()); } - - void set_auth_expiration(const int& auth_expiration) { - m_auth_expiration = auth_expiration; - is_set_auth_expiration = true; + if (is_set_auth_host) { + length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str()); } - - void set_secret_id(const std::string& secret_id) { - m_secret_id = secret_id; - is_set_secret_id = true; + if (is_set_auth_port) { + length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port); } -}; - -class ConnectionStringBuilder { - public: - ConnectionStringBuilder() { - connection_string.reset(new ConnectionString()); + if (is_set_auth_expiration) { + length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration); } - - ConnectionStringBuilder& withDSN(const std::string& dsn) { - connection_string->set_dsn(dsn); - return *this; + if (is_set_secret_id) { + length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str()); } - - ConnectionStringBuilder& withServer(const std::string& server) { - connection_string->set_server(server); - return *this; + if (is_set_enable_custom_endpoint_monitoring) { + length += sprintf(conn_in + length, "ENABLE_CUSTOM_ENDPOINT_MONITORING=%d;", m_enable_custom_endpoint_monitoring ? 1 : 0); } - - ConnectionStringBuilder& withPort(const int& port) { - connection_string->set_port(port); - return *this; + if (is_set_custom_endpoint_region) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_REGION=%s;", m_custom_endpoint_region.c_str()); } - - ConnectionStringBuilder& withUID(const std::string& uid) { - connection_string->set_uid(uid); - return *this; + if (is_set_should_wait_for_info) { + length += sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO=%d;", m_should_wait_for_info ? 1 : 0); } - - ConnectionStringBuilder& withPWD(const std::string& pwd) { - connection_string->set_pwd(pwd); - return *this; + if (is_set_custom_endpoint_info_refresh_rate_ms) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS=%d;", m_custom_endpoint_info_refresh_rate_ms); } - - ConnectionStringBuilder& withDatabase(const std::string& db) { - connection_string->set_db(db); - return *this; + if (is_set_wait_on_cached_info_duration_ms) { + length += sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS=%ld;", m_wait_on_cached_info_duration_ms); } - - ConnectionStringBuilder& withLogQuery(const bool& log_query) { - connection_string->set_log_query(log_query); - return *this; + if (is_set_idle_monitor_expiration_ms) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS=%ld;", m_idle_monitor_expiration_ms); } + snprintf(conn_in + length, sizeof(conn_in) - length, "\0"); - ConnectionStringBuilder& withFailoverMode(const std::string& failover_mode) { - connection_string->set_failover_mode(failover_mode); - return *this; - } + std::string connection_string(conn_in); + return connection_string; + } - ConnectionStringBuilder& withMultiStatements(const bool& multi_statements) { - connection_string->set_multi_statements(multi_statements); - return *this; - } + private: + // Required fields + std::string m_dsn, m_server; + int m_port; - ConnectionStringBuilder& withEnableClusterFailover(const bool& enable_cluster_failover) { - connection_string->set_enable_cluster_failover(enable_cluster_failover); - return *this; - } - - ConnectionStringBuilder& withFailoverTimeout(const int& failover_t) { - connection_string->set_failover_timeout(failover_t); - return *this; - } - - ConnectionStringBuilder& withConnectTimeout(const int& connect_timeout) { - connection_string->set_connect_timeout(connect_timeout); - return *this; - } - - ConnectionStringBuilder& withNetworkTimeout(const int& network_timeout) { - connection_string->set_network_timeout(network_timeout); - return *this; - } - - ConnectionStringBuilder& withHostPattern(const std::string& host_pattern) { - connection_string->set_host_pattern(host_pattern); - return *this; - } - - ConnectionStringBuilder& withEnableFailureDetection(const bool& enable_failure_detection) { - connection_string->set_enable_failure_detection(enable_failure_detection); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionTime(const int& failure_detection_time) { - connection_string->set_failure_detection_time(failure_detection_time); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionTimeout(const int& failure_detection_timeout) { - connection_string->set_failure_detection_timeout(failure_detection_timeout); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionInterval(const int& failure_detection_interval) { - connection_string->set_failure_detection_interval(failure_detection_interval); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionCount(const int& failure_detection_count) { - connection_string->set_failure_detection_count(failure_detection_count); - return *this; - } - - ConnectionStringBuilder& withMonitorDisposalTime(const int& monitor_disposal_time) { - connection_string->set_monitor_disposal_time(monitor_disposal_time); - return *this; - } - - ConnectionStringBuilder& withReadTimeout(const int& read_timeout) { - connection_string->set_read_timeout(read_timeout); - return *this; - } - - ConnectionStringBuilder& withWriteTimeout(const int& write_timeout) { - connection_string->set_write_timeout(write_timeout); - return *this; - } - - ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) { - connection_string->set_auth_mode(auth_mode); - return *this; - } - - ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) { - connection_string->set_auth_region(auth_region); - return *this; - } - - ConnectionStringBuilder& withAuthHost(const std::string& auth_host) { - connection_string->set_auth_host(auth_host); - return *this; - } - - ConnectionStringBuilder& withAuthPort(const int& auth_port) { - connection_string->set_auth_port(auth_port); - return *this; - } - - ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) { - connection_string->set_auth_expiration(auth_expiration); - return *this; - } - - ConnectionStringBuilder& withSecretId(const std::string& secret_id) { - connection_string->set_secret_id(secret_id); - return *this; - } + // Optional fields + std::string m_uid, m_pwd, m_db; + bool m_log_query, m_multi_statements, m_enable_cluster_failover; + int m_failover_timeout, m_connect_timeout, m_network_timeout; + std::string m_host_pattern, m_failover_mode; + bool m_enable_failure_detection; + int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, + m_monitor_disposal_time, m_read_timeout, m_write_timeout; + std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id, m_custom_endpoint_region; + int m_auth_port, m_auth_expiration; + bool m_enable_custom_endpoint_monitoring, m_should_wait_for_info; + long m_custom_endpoint_info_refresh_rate_ms, m_wait_on_cached_info_duration_ms, m_idle_monitor_expiration_ms; + + bool is_set_uid, is_set_pwd, is_set_db; + bool is_set_log_query, is_set_failover_mode, is_set_multi_statements; + bool is_set_enable_cluster_failover; + bool is_set_failover_timeout, is_set_connect_timeout, is_set_network_timeout; + bool is_set_host_pattern; + bool is_set_enable_failure_detection; + bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, + is_set_failure_detection_count; + bool is_set_monitor_disposal_time; + bool is_set_read_timeout, is_set_write_timeout; + bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, + is_set_secret_id; + bool is_set_enable_custom_endpoint_monitoring, is_set_should_wait_for_info, is_set_custom_endpoint_region; + bool is_set_custom_endpoint_info_refresh_rate_ms, is_set_wait_on_cached_info_duration_ms, is_set_idle_monitor_expiration_ms; + + void set_dsn(const std::string& dsn) { m_dsn = dsn; } + + void set_server(const std::string& server) { m_server = server; } + + void set_port(const int& port) { m_port = port; } + + void set_uid(const std::string& uid) { + m_uid = uid; + is_set_uid = true; + } + + void set_pwd(const std::string& pwd) { + m_pwd = pwd; + is_set_pwd = true; + } + + void set_db(const std::string& db) { + m_db = db; + is_set_db = true; + } + + void set_log_query(const bool& log_query) { + m_log_query = log_query; + is_set_log_query = true; + } + + void set_failover_mode(const std::string& failover_mode) { + m_failover_mode = failover_mode; + is_set_failover_mode = true; + } + + void set_multi_statements(const bool& multi_statements) { + m_multi_statements = multi_statements; + is_set_multi_statements = true; + } + + void set_enable_cluster_failover(const bool& enable_cluster_failover) { + m_enable_cluster_failover = enable_cluster_failover; + is_set_enable_cluster_failover = true; + } + + void set_failover_timeout(const int& failover_timeout) { + m_failover_timeout = failover_timeout; + is_set_failover_timeout = true; + } + + void set_connect_timeout(const int& connect_timeout) { + m_connect_timeout = connect_timeout; + is_set_connect_timeout = true; + } + + void set_network_timeout(const int& network_timeout) { + m_network_timeout = network_timeout; + is_set_network_timeout = true; + } + + void set_host_pattern(const std::string& host_pattern) { + m_host_pattern = host_pattern; + is_set_host_pattern = true; + } + + void set_enable_failure_detection(const bool& enable_failure_detection) { + m_enable_failure_detection = enable_failure_detection; + is_set_enable_failure_detection = true; + } + + void set_failure_detection_time(const int& failure_detection_time) { + m_failure_detection_time = failure_detection_time; + is_set_failure_detection_time = true; + } + + void set_failure_detection_timeout(const int& failure_detection_timeout) { + m_failure_detection_timeout = failure_detection_timeout; + is_set_failure_detection_timeout = true; + } + + void set_failure_detection_interval(const int& failure_detection_interval) { + m_failure_detection_interval = failure_detection_interval; + is_set_failure_detection_interval = true; + } + + void set_failure_detection_count(const int& failure_detection_count) { + m_failure_detection_count = failure_detection_count; + is_set_failure_detection_count = true; + } + + void set_monitor_disposal_time(const int& monitor_disposal_time) { + m_monitor_disposal_time = monitor_disposal_time; + is_set_monitor_disposal_time = true; + } + + void set_read_timeout(const int& read_timeout) { + m_read_timeout = read_timeout; + is_set_read_timeout = true; + } + + void set_write_timeout(const int& write_timeout) { + m_write_timeout = write_timeout; + is_set_write_timeout = true; + } + + void set_auth_mode(const std::string& auth_mode) { + m_auth_mode = auth_mode; + is_set_auth_mode = true; + } + + void set_auth_region(const std::string& auth_region) { + m_auth_region = auth_region; + is_set_auth_region = true; + } + + void set_auth_host(const std::string& auth_host) { + m_auth_host = auth_host; + is_set_auth_host = true; + } + + void set_auth_port(const int& auth_port) { + m_auth_port = auth_port; + is_set_auth_port = true; + } + + void set_auth_expiration(const int& auth_expiration) { + m_auth_expiration = auth_expiration; + is_set_auth_expiration = true; + } + + void set_secret_id(const std::string& secret_id) { + m_secret_id = secret_id; + is_set_secret_id = true; + } + + void set_enable_custom_endpoint_monitoring(const bool& enable_custom_endpoint_monitoring) { + m_enable_custom_endpoint_monitoring = enable_custom_endpoint_monitoring; + is_set_enable_custom_endpoint_monitoring = true; + } + + void set_custom_endpoint_monitoring_region(const std::string& region) { + m_custom_endpoint_region = region; + is_set_custom_endpoint_region = true; + } + + void set_should_wait_for_info(const bool& wait_for_info) { + m_should_wait_for_info = wait_for_info; + is_set_should_wait_for_info = true; + } + + void set_custom_endpoint_info_refresh_rate_ms(const long& custom_endpoint_info_refresh_rate_ms) { + m_custom_endpoint_info_refresh_rate_ms = custom_endpoint_info_refresh_rate_ms; + is_set_custom_endpoint_info_refresh_rate_ms = true; + } + + void set_wait_on_cached_info_duration_ms(const long& wait_on_cached_info_duration_ms) { + m_wait_on_cached_info_duration_ms = wait_on_cached_info_duration_ms; + is_set_wait_on_cached_info_duration_ms = true; + } + + void set_idle_monitor_expiration_ms(const long& idle_monitor_expiration_ms) { + m_idle_monitor_expiration_ms = idle_monitor_expiration_ms; + is_set_idle_monitor_expiration_ms = true; + } +}; - std::string build() const { - if (connection_string->m_dsn.empty()) { - throw std::runtime_error("DSN is a required field in a connection string."); - } - if (connection_string->m_server.empty()) { - throw std::runtime_error("Server is a required field in a connection string."); - } - if (connection_string->m_port < 1) { - throw std::runtime_error("Port is a required field in a connection string."); - } - return connection_string->get_connection_string(); - } - - private: - std::unique_ptr connection_string; +class ConnectionStringBuilder { + public: + ConnectionStringBuilder() { connection_string.reset(new ConnectionString()); } + + ConnectionStringBuilder& withDSN(const std::string& dsn) { + connection_string->set_dsn(dsn); + return *this; + } + + ConnectionStringBuilder& withServer(const std::string& server) { + connection_string->set_server(server); + return *this; + } + + ConnectionStringBuilder& withPort(const int& port) { + connection_string->set_port(port); + return *this; + } + + ConnectionStringBuilder& withUID(const std::string& uid) { + connection_string->set_uid(uid); + return *this; + } + + ConnectionStringBuilder& withPWD(const std::string& pwd) { + connection_string->set_pwd(pwd); + return *this; + } + + ConnectionStringBuilder& withDatabase(const std::string& db) { + connection_string->set_db(db); + return *this; + } + + ConnectionStringBuilder& withLogQuery(const bool& log_query) { + connection_string->set_log_query(log_query); + return *this; + } + + ConnectionStringBuilder& withFailoverMode(const std::string& failover_mode) { + connection_string->set_failover_mode(failover_mode); + return *this; + } + + ConnectionStringBuilder& withMultiStatements(const bool& multi_statements) { + connection_string->set_multi_statements(multi_statements); + return *this; + } + + ConnectionStringBuilder& withEnableClusterFailover(const bool& enable_cluster_failover) { + connection_string->set_enable_cluster_failover(enable_cluster_failover); + return *this; + } + + ConnectionStringBuilder& withFailoverTimeout(const int& failover_t) { + connection_string->set_failover_timeout(failover_t); + return *this; + } + + ConnectionStringBuilder& withConnectTimeout(const int& connect_timeout) { + connection_string->set_connect_timeout(connect_timeout); + return *this; + } + + ConnectionStringBuilder& withNetworkTimeout(const int& network_timeout) { + connection_string->set_network_timeout(network_timeout); + return *this; + } + + ConnectionStringBuilder& withHostPattern(const std::string& host_pattern) { + connection_string->set_host_pattern(host_pattern); + return *this; + } + + ConnectionStringBuilder& withEnableFailureDetection(const bool& enable_failure_detection) { + connection_string->set_enable_failure_detection(enable_failure_detection); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionTime(const int& failure_detection_time) { + connection_string->set_failure_detection_time(failure_detection_time); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionTimeout(const int& failure_detection_timeout) { + connection_string->set_failure_detection_timeout(failure_detection_timeout); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionInterval(const int& failure_detection_interval) { + connection_string->set_failure_detection_interval(failure_detection_interval); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionCount(const int& failure_detection_count) { + connection_string->set_failure_detection_count(failure_detection_count); + return *this; + } + + ConnectionStringBuilder& withMonitorDisposalTime(const int& monitor_disposal_time) { + connection_string->set_monitor_disposal_time(monitor_disposal_time); + return *this; + } + + ConnectionStringBuilder& withReadTimeout(const int& read_timeout) { + connection_string->set_read_timeout(read_timeout); + return *this; + } + + ConnectionStringBuilder& withWriteTimeout(const int& write_timeout) { + connection_string->set_write_timeout(write_timeout); + return *this; + } + + ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) { + connection_string->set_auth_mode(auth_mode); + return *this; + } + + ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) { + connection_string->set_auth_region(auth_region); + return *this; + } + + ConnectionStringBuilder& withAuthHost(const std::string& auth_host) { + connection_string->set_auth_host(auth_host); + return *this; + } + + ConnectionStringBuilder& withAuthPort(const int& auth_port) { + connection_string->set_auth_port(auth_port); + return *this; + } + + ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) { + connection_string->set_auth_expiration(auth_expiration); + return *this; + } + + ConnectionStringBuilder& withSecretId(const std::string& secret_id) { + connection_string->set_secret_id(secret_id); + return *this; + } + + ConnectionStringBuilder& withEnableCustomEndpointMonitoring(const bool& enable_custom_endpoint_monitoring) { + connection_string->set_enable_custom_endpoint_monitoring(enable_custom_endpoint_monitoring); + return *this; + } + ConnectionStringBuilder& withCustomEndpointRegion(const std::string& region) { + connection_string->set_custom_endpoint_monitoring_region(region); + return *this; + } + + ConnectionStringBuilder& withShouldWaitForInfo(const bool& should_wait_for_info) { + connection_string->set_should_wait_for_info(should_wait_for_info); + return *this; + } + + ConnectionStringBuilder& withCustomEndpointInfoRefreshRateMs(const long& custom_endpoint_info_refresh_rate_ms) { + connection_string->set_custom_endpoint_info_refresh_rate_ms(custom_endpoint_info_refresh_rate_ms); + return *this; + } + + ConnectionStringBuilder& withWaitOnCachedInfoDurationMs(const long& wait_on_cached_info_duration_ms) { + connection_string->set_wait_on_cached_info_duration_ms(wait_on_cached_info_duration_ms); + return *this; + } + + ConnectionStringBuilder& withIdleMonitorExpirationMs(const long& idle_monitor_expiration_ms) { + connection_string->set_idle_monitor_expiration_ms(idle_monitor_expiration_ms); + return *this; + } + + std::string build() const { + if (connection_string->m_dsn.empty()) { + throw std::runtime_error("DSN is a required field in a connection string."); + } + if (connection_string->m_server.empty()) { + throw std::runtime_error("Server is a required field in a connection string."); + } + if (connection_string->m_port < 1) { + throw std::runtime_error("Port is a required field in a connection string."); + } + return connection_string->get_connection_string(); + } + + private: + std::unique_ptr connection_string; }; #endif /* __CONNECTIONSTRINGBUILDER_H__ */ diff --git a/integration/custom_endpoint_integration_test.cc b/integration/custom_endpoint_integration_test.cc new file mode 100644 index 000000000..5e687c094 --- /dev/null +++ b/integration/custom_endpoint_integration_test.cc @@ -0,0 +1,178 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "base_failover_integration_test.cc" +#include +#include +#include + +class CustomEndpointIntegrationTest : public BaseFailoverIntegrationTest { + protected: + std::string ACCESS_KEY = std::getenv("AWS_ACCESS_KEY_ID"); + std::string SECRET_ACCESS_KEY = std::getenv("AWS_SECRET_ACCESS_KEY"); + std::string SESSION_TOKEN = std::getenv("AWS_SESSION_TOKEN"); + std::string RDS_ENDPOINT = std::getenv("RDS_ENDPOINT"); + std::string RDS_REGION = std::getenv("RDS_REGION"); + std::string ENDPOINT_ID = + "test-endpoint-1-" + std::to_string(std::chrono::steady_clock::now().time_since_epoch().count()); + std::string region = "us-east-2"; + Aws::RDS::RDSClientConfiguration client_config; + Aws::RDS::RDSClient rds_client; + SQLHENV env = nullptr; + SQLHDBC dbc = nullptr; + + bool is_endpoint_created = false; + + static void SetUpTestSuite() { Aws::InitAPI(options); } + + static void TearDownTestSuite() { Aws::ShutdownAPI(options); } + void SetUp() override { + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env); + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3), 0); + SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); + + Aws::Auth::AWSCredentials credentials = + SESSION_TOKEN.empty() ? Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY)) + : Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY), + Aws::String(SESSION_TOKEN)); + if (!RDS_REGION.empty()) { + region = RDS_REGION; + } + client_config.region = region; + if (!RDS_ENDPOINT.empty()) { + client_config.endpointOverride = RDS_ENDPOINT; + } + rds_client = Aws::RDS::RDSClient(credentials, client_config); + + cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id); + writer_id = get_writer_id(cluster_instances); + writer_endpoint = get_endpoint(writer_id); + readers = get_readers(cluster_instances); + reader_id = get_first_reader_id(cluster_instances); + reader_endpoint = get_proxied_endpoint(reader_id); + target_writer_id = get_random_DB_cluster_reader_instance_id(readers); + + builder = ConnectionStringBuilder(); + builder.withPort(MYSQL_PORT).withLogQuery(true).withEnableFailureDetection(true); + + if (!is_endpoint_created) { + const std::vector writer{writer_id}; + create_custom_endpoint(cluster_id, writer); + wait_until_endpoint_available(cluster_id); + } + } + + void create_custom_endpoint(const std::string& cluster_id, const std::vector& writer) const { + Aws::RDS::Model::CreateDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_req.SetDBClusterIdentifier(cluster_id); + rds_req.SetEndpointType("ANY"); + rds_req.SetStaticMembers(writer); + rds_client.CreateDBClusterEndpoint(rds_req); + } + + void delete_custom_endpoint() { + Aws::RDS::Model::DeleteDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_client.DeleteDBClusterEndpoint(rds_req); + } + + /** + * Wait up to 5 minutes for the new custom endpoint to become unavailable. + */ + void wait_until_endpoint_available(const std::string& cluster_id) const { + const std::chrono::steady_clock::time_point end_time = std::chrono::steady_clock::now() + std::chrono::minutes(5); + bool is_available = false; + + Aws::String status = get_DB_cluster(rds_client, cluster_id).GetStatus(); + + while (std::chrono::steady_clock::now() < end_time) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + is_available = endpoint_info.GetStatus() == "available"; + if (is_available) { + break; + } + std::this_thread::sleep_for(std::chrono::seconds(3)); + } + + if (!is_available) { + throw std::runtime_error( + "The test setup step timed out while waiting for the custom endpoint to become available: " + ENDPOINT_ID); + } + } + + void TearDown() override { + delete_custom_endpoint(); + + if (nullptr != dbc) { + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + } + if (nullptr != env) { + SQLFreeHandle(SQL_HANDLE_ENV, env); + } + } +}; + +TEST_F(CustomEndpointIntegrationTest, test_CustomeEndpointFailover) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + + connection_string = builder.withDSN(dsn) + .withServer(endpoint_info.GetEndpoint()) + .withUID(user) + .withPWD(pwd) + .withDatabase(db) + .withFailoverMode("reader or writer") + .withEnableCustomEndpointMonitoring(true) + .withCustomEndpointRegion(region) + .build(); + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, + MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); + + const std::vector endpoint_members = endpoint_info.GetStaticMembers(); + for (const auto& member : endpoint_members) { + std::cout << "static members: " << member << std::endl; + } + const std::string current_connection_id = query_instance_id(dbc); + std::cout << "current connection id: " << current_connection_id << std::endl; + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), current_connection_id), endpoint_members.end()); + + failover_cluster_and_wait_until_writer_changed( + rds_client, cluster_id, writer_id, current_connection_id == writer_id ? target_writer_id : current_connection_id); + + assert_query_failed(dbc, SERVER_ID_QUERY, ERROR_COMM_LINK_CHANGED); + + const std::string new_connection_id = query_instance_id(dbc); + std::cout << "new connection id: " << new_connection_id << std::endl; + + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), new_connection_id), endpoint_members.end()); + + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc)); +} diff --git a/scripts/build_aws_sdk_unix.sh b/scripts/build_aws_sdk_unix.sh index 4e73c5fdd..2a7388b0b 100755 --- a/scripts/build_aws_sdk_unix.sh +++ b/scripts/build_aws_sdk_unix.sh @@ -40,7 +40,7 @@ AWS_INSTALL_DIR=$AWS_SRC_DIR/../install mkdir -p $AWS_SRC_DIR $AWS_BUILD_DIR $AWS_INSTALL_DIR -git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR +git clone --recurse-submodules -b "1.11.488" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR cmake -S $AWS_SRC_DIR -B $AWS_BUILD_DIR -DCMAKE_INSTALL_PREFIX="${AWS_INSTALL_DIR}" -DCMAKE_BUILD_TYPE="${CONFIGURATION}" -DBUILD_ONLY="rds;secretsmanager;sts" -DENABLE_TESTING="OFF" -DBUILD_SHARED_LIBS="ON" -DCPP_STANDARD="14" cd $AWS_BUILD_DIR diff --git a/scripts/build_aws_sdk_win.ps1 b/scripts/build_aws_sdk_win.ps1 index 2ec5c4935..d887a8cab 100644 --- a/scripts/build_aws_sdk_win.ps1 +++ b/scripts/build_aws_sdk_win.ps1 @@ -44,7 +44,7 @@ Write-Host $args # Make AWS SDK source directory New-Item -Path $SRC_DIR -ItemType Directory -Force | Out-Null # Clone the AWS SDK CPP repo -git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR +git clone --recurse-submodules -b "1.11.488" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR # Make and move to build directory New-Item -Path $BUILD_DIR -ItemType Directory -Force | Out-Null diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 1d37f23fe..c8184dc78 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -57,6 +57,8 @@ add_executable( adfs_proxy_test.cc cluster_aware_metrics_test.cc + custom_endpoint_monitor_test.cc + custom_endpoint_proxy_test.cc efm_proxy_test.cc iam_proxy_test.cc failover_handler_test.cc diff --git a/unit_testing/custom_endpoint_monitor_test.cc b/unit_testing/custom_endpoint_monitor_test.cc new file mode 100644 index 000000000..3ee3c94bb --- /dev/null +++ b/unit_testing/custom_endpoint_monitor_test.cc @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include + +#include + +#include +#include + +#include "driver/custom_endpoint_monitor.h" +#include "test_utils.h" +#include "mock_objects.h" + +using namespace Aws::RDS; + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +const auto ENDPOINT = Aws::Utils::Json::JsonValue(CUSTOM_ENDPOINT_URL); + +const long long REFRESH_RATE_NANOS = 50000000; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions sdk_options; + +class CustomEndpointMonitorTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + std::shared_ptr mock_rds_client; + std::shared_ptr mock_topology_service; + + static void SetUpTestSuite() { + Aws::InitAPI(sdk_options); + } + + static void TearDownTestSuite() { + Aws::ShutdownAPI(sdk_options); + mysql_library_end(); + } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + mock_rds_client = std::make_shared(); + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + cleanup_odbc_handles(env, dbc, ds); + TEST_UTILS::get_custom_endpoint_cache().clear(); + delete mock_connection_proxy; + } +}; + +TEST_F(CustomEndpointMonitorTest, TestRun) { + Model::DBClusterEndpoint endpoint; + endpoint.AddStaticMembers(CUSTOM_ENDPOINT_URL); + std::vector endpoints{endpoint}; + + const auto expected_result = Model::DescribeDBClusterEndpointsResult().WithDBClusterEndpoints(endpoints); + const auto expected_outcome = Model::DescribeDBClusterEndpointsOutcome(expected_result); + + EXPECT_CALL(*mock_rds_client, DescribeDBClusterEndpoints(Property( + "GetDBClusterEndpointIdentifier", + &Model::DescribeDBClusterEndpointsRequest::GetDBClusterEndpointIdentifier, + StrEq("custom")))) + .WillRepeatedly(Return(expected_outcome)); + + CUSTOM_ENDPOINT_MONITOR monitor(mock_topology_service, CUSTOM_ENDPOINT_URL, "custom", "us-east-1", REFRESH_RATE_NANOS, + true, + mock_rds_client); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + monitor.stop(); +} diff --git a/unit_testing/custom_endpoint_proxy_test.cc b/unit_testing/custom_endpoint_proxy_test.cc new file mode 100644 index 000000000..41d7dc75b --- /dev/null +++ b/unit_testing/custom_endpoint_proxy_test.cc @@ -0,0 +1,104 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "test_utils.h" +#include "mock_objects.h" + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions options; + +class CustomEndpointProxyTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + std::shared_ptr mock_monitor = std::make_shared(); + + static void SetUpTestSuite() {} + + static void TearDownTestSuite() { mysql_library_end(); } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + ds->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING = true; + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO = true; + ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS = 60000; + + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + TEST_UTILS::get_custom_endpoint_monitor_cache().release_resources(); + cleanup_odbc_handles(env, dbc, ds); + } +}; + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorNotCreatedIfNotCustomEndpointHost) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + custom_endpoint_proxy.connect(WRITER_CLUSTER_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorCreated) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(0, TEST_UTILS::get_custom_endpoint_monitor_cache().size()); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(1); + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_TimeoutWaitingForInfo) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 100; + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 1; + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "user", "pwd", "db", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} diff --git a/unit_testing/failover_handler_test.cc b/unit_testing/failover_handler_test.cc index f86f1ce2b..60c561bde 100644 --- a/unit_testing/failover_handler_test.cc +++ b/unit_testing/failover_handler_test.cc @@ -43,6 +43,7 @@ using ::testing::StrEq; namespace { const std::string US_EAST_REGION_CLUSTER = "database-test-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CLUSTER_READ_ONLY = "database-test-name.cluster-ro-XYZ.us-east-2.rds.amazonaws.com"; + const std::string US_EAST_REGION_INSTANCE = "instance-test-name.XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_PROXY = "proxy-test-name.proxy-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CUSTON_DOMAIN = "custom-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com"; @@ -424,6 +425,26 @@ TEST_F(FailoverHandlerTest, GetRdsClusterHostUrl) { EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_host_url(CHINA_REGION_CUSTON_DOMAIN)); } +TEST_F(FailoverHandlerTest, GetRdsClusterId) { + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_INSTANCE)); + + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CUSTON_DOMAIN)); + + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CUSTON_DOMAIN)); +} + +TEST_F(FailoverHandlerTest, GetRdsInstanceId) { + EXPECT_EQ("instance-test-name", TEST_UTILS::get_rds_instance_id(US_EAST_REGION_INSTANCE)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER)); +} + TEST_F(FailoverHandlerTest, ConnectToNewWriter) { std::string server = "my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; ds->opt_SERVER.set_remove_brackets((SQLWCHAR*)to_sqlwchar_string(server).c_str(), server.size()); diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 80ebce772..f684ec3d9 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -35,6 +35,7 @@ #include #include "driver/connection_proxy.h" +#include "driver/custom_endpoint_proxy.h" #include "driver/failover.h" #include "driver/saml_http_client.h" #include "driver/monitor_thread_container.h" @@ -222,6 +223,13 @@ class MOCK_SECRETS_MANAGER_CLIENT : public Aws::SecretsManager::SecretsManagerCl MOCK_METHOD(Aws::SecretsManager::Model::GetSecretValueOutcome, GetSecretValue, (const Aws::SecretsManager::Model::GetSecretValueRequest&), (const)); }; +class MOCK_RDS_CLIENT : public Aws::RDS::RDSClient { +public: + MOCK_RDS_CLIENT() : RDSClient(){}; + + MOCK_METHOD(Aws::RDS::Model::DescribeDBClusterEndpointsOutcome, DescribeDBClusterEndpoints, (const Aws::RDS::Model::DescribeDBClusterEndpointsRequest&), (const)); +}; + class MOCK_AUTH_UTIL : public AUTH_UTIL { public: MOCK_AUTH_UTIL() : AUTH_UTIL() {}; @@ -234,4 +242,16 @@ class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { MOCK_METHOD(nlohmann::json, post, (const std::string&, const std::string&, const std::string&)); MOCK_METHOD(nlohmann::json, get, (const std::string&, const httplib::Headers&)); }; + +class TEST_CUSTOM_ENDPOINT_PROXY : public CUSTOM_ENDPOINT_PROXY { +public: + TEST_CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CUSTOM_ENDPOINT_PROXY(dbc, ds, next_proxy) {}; + MOCK_METHOD(std::shared_ptr, create_custom_endpoint_monitor, (const long long refresh_rate_nanos), (override)); + static int get_monitor_size() { return monitors.size(); } +}; + +class MOCK_CUSTOM_ENDPOINT_MONITOR : public CUSTOM_ENDPOINT_MONITOR { + public: + MOCK_CUSTOM_ENDPOINT_MONITOR() {}; +}; #endif /* __MOCKOBJECTS_H__ */ diff --git a/unit_testing/sliding_expiration_cache_test.cc b/unit_testing/sliding_expiration_cache_test.cc index b623c4b8a..e9364a613 100644 --- a/unit_testing/sliding_expiration_cache_test.cc +++ b/unit_testing/sliding_expiration_cache_test.cc @@ -145,7 +145,7 @@ TEST_F(SlidingExpirationCacheTest, ExpirationTimeUpdateGet) { TEST_F(SlidingExpirationCacheTest, GetCacheExpireThread) { SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD cache(cache_exp_short); - + cache.init_clean_up_thread(); EXPECT_EQ(0, cache.size()); cache.put(cache_key_a, cache_val_a, cache_exp_short); cache.put(cache_key_b, cache_val_b, cache_exp_long); diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index 3b96eb9bf..7a01cb977 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -29,6 +29,9 @@ #include "test_utils.h" +#include "driver/custom_endpoint_monitor.h" +#include "driver/custom_endpoint_proxy.h" + void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds) { SQLHDBC hdbc = nullptr; @@ -155,6 +158,23 @@ std::string TEST_UTILS::get_rds_cluster_host_url(std::string host) { return RDS_UTILS::get_rds_cluster_host_url(host); } +std::string TEST_UTILS::get_rds_cluster_id(std::string host) { + return RDS_UTILS::get_rds_cluster_id(host); +} + +std::string TEST_UTILS::get_rds_instance_id(std::string host) { + return RDS_UTILS::get_rds_instance_id(host); +} + std::string TEST_UTILS::get_rds_instance_host_pattern(std::string host) { return RDS_UTILS::get_rds_instance_host_pattern(host); } + +CACHE_MAP>& TEST_UTILS::get_custom_endpoint_cache() { + return std::ref(CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache); +} + +SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& +TEST_UTILS::get_custom_endpoint_monitor_cache() { + return std::ref(CUSTOM_ENDPOINT_PROXY::monitors); +} diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index d7ce4f84f..f431ce8fc 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -30,46 +30,54 @@ #ifndef __TESTUTILS_H__ #define __TESTUTILS_H__ -#include "driver/auth_util.h" +#include "driver/auth_util.h" +#include "driver/cache_map.h" +#include "driver/custom_endpoint_info.h" +#include "driver/custom_endpoint_monitor.h" #include "driver/driver.h" #include "driver/failover.h" #include "driver/iam_proxy.h" -#include "driver/okta_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" -#include "driver/secrets_manager_proxy.h" +#include "driver/okta_proxy.h" #include "driver/rds_utils.h" +#include "driver/secrets_manager_proxy.h" +#include "driver/sliding_expiration_cache_with_clean_up_thread.h" void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds); void cleanup_odbc_handles(SQLHENV env, DBC*& dbc, DataSource*& ds, bool call_myodbc_end = false); class TEST_UTILS { -public: - static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); - static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); - static void populate_monitor_map(std::shared_ptr container, - std::set node_keys, std::shared_ptr monitor); - static void populate_task_map(std::shared_ptr container, - std::shared_ptr monitor); - static bool has_monitor(std::shared_ptr container, std::string node_key); - static bool has_task(std::shared_ptr container, std::shared_ptr monitor); - static bool has_any_tasks(std::shared_ptr container); - static bool has_available_monitor(std::shared_ptr container); - static std::shared_ptr get_available_monitor(std::shared_ptr container); - static size_t get_map_size(std::shared_ptr container); - static std::list> get_contexts(std::shared_ptr monitor); - static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); - static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); - static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); - static bool try_parse_region_from_secret(std::string secret, std::string& region); - static bool is_dns_pattern_valid(std::string host); - static bool is_rds_dns(std::string host); - static bool is_rds_cluster_dns(std::string host); - static bool is_rds_proxy_dns(std::string host); - static bool is_rds_writer_cluster_dns(std::string host); - static bool is_rds_custom_cluster_dns(std::string host); - static std::string get_rds_cluster_host_url(std::string host); - static std::string get_rds_instance_host_pattern(std::string host); + public: + static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); + static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); + static void populate_monitor_map(std::shared_ptr container, std::set node_keys, + std::shared_ptr monitor); + static void populate_task_map(std::shared_ptr container, std::shared_ptr monitor); + static bool has_monitor(std::shared_ptr container, std::string node_key); + static bool has_task(std::shared_ptr container, std::shared_ptr monitor); + static bool has_any_tasks(std::shared_ptr container); + static bool has_available_monitor(std::shared_ptr container); + static std::shared_ptr get_available_monitor(std::shared_ptr container); + static size_t get_map_size(std::shared_ptr container); + static std::list> get_contexts(std::shared_ptr monitor); + static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); + static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); + static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); + static bool try_parse_region_from_secret(std::string secret, std::string& region); + static bool is_dns_pattern_valid(std::string host); + static bool is_rds_dns(std::string host); + static bool is_rds_cluster_dns(std::string host); + static bool is_rds_proxy_dns(std::string host); + static bool is_rds_writer_cluster_dns(std::string host); + static bool is_rds_custom_cluster_dns(std::string host); + static std::string get_rds_cluster_host_url(std::string host); + static std::string get_rds_cluster_id(std::string host); + static std::string get_rds_instance_id(std::string host); + static std::string get_rds_instance_host_pattern(std::string host); + static CACHE_MAP>& get_custom_endpoint_cache(); + static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& + get_custom_endpoint_monitor_cache(); }; #endif /* __TESTUTILS_H__ */