From c1c6f95589381e4f856e6f34a83a7926dc37007f Mon Sep 17 00:00:00 2001 From: Yadu Date: Sat, 13 Apr 2024 23:04:50 +0800 Subject: [PATCH] Fix race conditions in rmw_wait and map queries to clients (#153) * Fix a race condition in rmw_wait. To very briefly explain, rmw_wait: 1. Checks to see if any of the entities (subscriptions, clients, etc) have data ready to go. 2. If they have data ready to go, then we skip attaching the condition variable and waiting. 3. If they do not have data ready to go, then we attach the condition variable to all entities, take the condition variable lock, and call wait_for/wait on the condition variable. 4. Regardless of whether we did 2 or 3, we check every entity to see if there is data ready, and mark that as appropriate in the wait set. There is a race in all of this, however. If data comes in after we've checked the entity (1), but before we've attached the condition variable (3), then we will never be woken up. In most cases, this means that we'll wait the full timeout for the wait_for, which is not what we want. Fix this by adding another step to 3. In particular, after we've locked the condition variable mutex, check the entities again. Since we change the entities to *also* take the lock before we notify, this ensures that the entities cannot make changes that get lost. Signed-off-by: Chris Lalancette * Small update to a comment. Signed-off-by: Chris Lalancette * Don't return an error if we can't find a number in the sequence map. I'm not really sure that this is correct, but do it for now. Signed-off-by: Chris Lalancette * Fix query queue for multiple clients. In particular, make sure that we track requests from individual clients separately so that we don't mix them up. To do that, we store the client gid in the server set along with the sequence_number and Query itself. 
Signed-off-by: Chris Lalancette * Finish changes Signed-off-by: Yadunund * Tweak api to store and retrieve query Signed-off-by: Yadunund * Lint Signed-off-by: Yadunund --------- Signed-off-by: Chris Lalancette Signed-off-by: Yadunund Co-authored-by: Chris Lalancette --- rmw_zenoh_cpp/src/detail/event.cpp | 10 +- rmw_zenoh_cpp/src/detail/event.hpp | 4 +- rmw_zenoh_cpp/src/detail/guard_condition.cpp | 7 +- rmw_zenoh_cpp/src/detail/guard_condition.hpp | 5 +- rmw_zenoh_cpp/src/detail/rmw_data_types.cpp | 107 +++++++++++++++---- rmw_zenoh_cpp/src/detail/rmw_data_types.hpp | 25 +++-- rmw_zenoh_cpp/src/rmw_zenoh.cpp | 68 +++++++----- 7 files changed, 160 insertions(+), 66 deletions(-) diff --git a/rmw_zenoh_cpp/src/detail/event.cpp b/rmw_zenoh_cpp/src/detail/event.cpp index 6cb83f62..94dbbea2 100644 --- a/rmw_zenoh_cpp/src/detail/event.cpp +++ b/rmw_zenoh_cpp/src/detail/event.cpp @@ -184,6 +184,7 @@ void EventsManager::add_new_event( ///============================================================================= void EventsManager::attach_event_condition( rmw_zenoh_event_type_t event_id, + std::mutex * condition_mutex, std::condition_variable * condition_variable) { if (event_id > ZENOH_EVENT_ID_MAX) { @@ -194,7 +195,8 @@ void EventsManager::attach_event_condition( return; } - std::lock_guard lock(event_condition_mutex_); + std::lock_guard lock(update_event_condition_mutex_); + event_condition_mutexes_[event_id] = condition_mutex; event_conditions_[event_id] = condition_variable; } @@ -209,7 +211,8 @@ void EventsManager::detach_event_condition(rmw_zenoh_event_type_t event_id) return; } - std::lock_guard lock(event_condition_mutex_); + std::lock_guard lock(update_event_condition_mutex_); + event_condition_mutexes_[event_id] = nullptr; event_conditions_[event_id] = nullptr; } @@ -224,8 +227,9 @@ void EventsManager::notify_event(rmw_zenoh_event_type_t event_id) return; } - std::lock_guard lock(event_condition_mutex_); + std::lock_guard lock(update_event_condition_mutex_); 
if (event_conditions_[event_id] != nullptr) { + std::lock_guard cvlk(*event_condition_mutexes_[event_id]); event_conditions_[event_id]->notify_one(); } } diff --git a/rmw_zenoh_cpp/src/detail/event.hpp b/rmw_zenoh_cpp/src/detail/event.hpp index a8246e97..8509da05 100644 --- a/rmw_zenoh_cpp/src/detail/event.hpp +++ b/rmw_zenoh_cpp/src/detail/event.hpp @@ -138,6 +138,7 @@ class EventsManager /// @param condition_variable to attach. void attach_event_condition( rmw_zenoh_event_type_t event_id, + std::mutex * condition_mutex, std::condition_variable * condition_variable); /// @brief Detach the condition variable provided by rmw_wait. @@ -154,7 +155,8 @@ class EventsManager /// Mutex to lock when read/writing members. mutable std::mutex event_mutex_; /// Mutex to lock for event_condition. - mutable std::mutex event_condition_mutex_; + mutable std::mutex update_event_condition_mutex_; + std::mutex * event_condition_mutexes_[ZENOH_EVENT_ID_MAX + 1]{nullptr}; /// Condition variable to attach for event notifications. std::condition_variable * event_conditions_[ZENOH_EVENT_ID_MAX + 1]{nullptr}; /// User callback that can be set via data_callback_mgr.set_callback(). 
diff --git a/rmw_zenoh_cpp/src/detail/guard_condition.cpp b/rmw_zenoh_cpp/src/detail/guard_condition.cpp index b850095f..214dff7e 100644 --- a/rmw_zenoh_cpp/src/detail/guard_condition.cpp +++ b/rmw_zenoh_cpp/src/detail/guard_condition.cpp @@ -33,14 +33,18 @@ void GuardCondition::trigger() has_triggered_ = true; if (condition_variable_ != nullptr) { + std::lock_guard cvlk(*condition_mutex_); condition_variable_->notify_one(); } } ///============================================================================== -void GuardCondition::attach_condition(std::condition_variable * condition_variable) +void GuardCondition::attach_condition( + std::mutex * condition_mutex, + std::condition_variable * condition_variable) { std::lock_guard lock(internal_mutex_); + condition_mutex_ = condition_mutex; condition_variable_ = condition_variable; } @@ -48,6 +52,7 @@ void GuardCondition::attach_condition(std::condition_variable * condition_variab void GuardCondition::detach_condition() { std::lock_guard lock(internal_mutex_); + condition_mutex_ = nullptr; condition_variable_ = nullptr; } diff --git a/rmw_zenoh_cpp/src/detail/guard_condition.hpp b/rmw_zenoh_cpp/src/detail/guard_condition.hpp index b556c5f7..ce13bf7e 100644 --- a/rmw_zenoh_cpp/src/detail/guard_condition.hpp +++ b/rmw_zenoh_cpp/src/detail/guard_condition.hpp @@ -29,7 +29,7 @@ class GuardCondition final // Sets has_triggered_ to true and calls notify_one() on condition_variable_ if set. 
void trigger(); - void attach_condition(std::condition_variable * condition_variable); + void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable); void detach_condition(); @@ -38,7 +38,8 @@ class GuardCondition final private: mutable std::mutex internal_mutex_; std::atomic_bool has_triggered_; - std::condition_variable * condition_variable_; + std::mutex * condition_mutex_{nullptr}; + std::condition_variable * condition_variable_{nullptr}; }; #endif // DETAIL__GUARD_CONDITION_HPP_ diff --git a/rmw_zenoh_cpp/src/detail/rmw_data_types.cpp b/rmw_zenoh_cpp/src/detail/rmw_data_types.cpp index 8f88b75c..11a8dc4e 100644 --- a/rmw_zenoh_cpp/src/detail/rmw_data_types.cpp +++ b/rmw_zenoh_cpp/src/detail/rmw_data_types.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -62,17 +63,23 @@ size_t rmw_publisher_data_t::get_next_sequence_number() } ///============================================================================= -void rmw_subscription_data_t::attach_condition(std::condition_variable * condition_variable) +void rmw_subscription_data_t::attach_condition( + std::mutex * condition_mutex, + std::condition_variable * condition_variable) { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = condition_mutex; condition_ = condition_variable; } ///============================================================================= void rmw_subscription_data_t::notify() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); if (condition_ != nullptr) { + // We also need to take the mutex for the condition_variable; see the comment + // in rmw_wait for more information + std::lock_guard cvlk(*condition_mutex_); condition_->notify_one(); } } @@ -80,7 +87,8 @@ void rmw_subscription_data_t::notify() ///============================================================================= void 
rmw_subscription_data_t::detach_condition() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = nullptr; condition_ = nullptr; } @@ -149,16 +157,20 @@ bool rmw_service_data_t::query_queue_is_empty() const } ///============================================================================= -void rmw_service_data_t::attach_condition(std::condition_variable * condition_variable) +void rmw_service_data_t::attach_condition( + std::mutex * condition_mutex, + std::condition_variable * condition_variable) { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = condition_mutex; condition_ = condition_variable; } ///============================================================================= void rmw_service_data_t::detach_condition() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = nullptr; condition_ = nullptr; } @@ -179,8 +191,11 @@ std::unique_ptr rmw_service_data_t::pop_next_query() ///============================================================================= void rmw_service_data_t::notify() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); if (condition_ != nullptr) { + // We also need to take the mutex for the condition_variable; see the comment + // in rmw_wait for more information + std::lock_guard cvlk(*condition_mutex_); condition_->notify_one(); } } @@ -208,31 +223,74 @@ void rmw_service_data_t::add_new_query(std::unique_ptr query) notify(); } +static size_t hash_gid(const rmw_request_id_t & request_id) +{ + std::stringstream hash_str; + hash_str << std::hex; + size_t i = 0; + for (; i < (RMW_GID_STORAGE_SIZE - 1); i++) { + hash_str << static_cast(request_id.writer_guid[i]); + } + return std::hash{}(hash_str.str()); +} + ///============================================================================= bool 
rmw_service_data_t::add_to_query_map( - int64_t sequence_number, std::unique_ptr query) + const rmw_request_id_t & request_id, std::unique_ptr query) { + size_t hash = hash_gid(request_id); + std::lock_guard lock(sequence_to_query_map_mutex_); - if (sequence_to_query_map_.find(sequence_number) != sequence_to_query_map_.end()) { - return false; + + std::unordered_map::iterator it = + sequence_to_query_map_.find(hash); + + if (it == sequence_to_query_map_.end()) { + SequenceToQuery stq; + + sequence_to_query_map_.insert(std::make_pair(hash, std::move(stq))); + + it = sequence_to_query_map_.find(hash); + } else { + // Client already in the map + + if (it->second.find(request_id.sequence_number) != it->second.end()) { + return false; + } } - sequence_to_query_map_.emplace( - std::pair(sequence_number, std::move(query))); + + it->second.insert( + std::make_pair(request_id.sequence_number, std::move(query))); return true; } ///============================================================================= -std::unique_ptr rmw_service_data_t::take_from_query_map(int64_t sequence_number) +std::unique_ptr rmw_service_data_t::take_from_query_map( + const rmw_request_id_t & request_id) { + size_t hash = hash_gid(request_id); + std::lock_guard lock(sequence_to_query_map_mutex_); - auto query_it = sequence_to_query_map_.find(sequence_number); - if (query_it == sequence_to_query_map_.end()) { + + std::unordered_map::iterator it = sequence_to_query_map_.find(hash); + + if (it == sequence_to_query_map_.end()) { + return nullptr; + } + + SequenceToQuery::iterator query_it = it->second.find(request_id.sequence_number); + + if (query_it == it->second.end()) { return nullptr; } std::unique_ptr query = std::move(query_it->second); - sequence_to_query_map_.erase(query_it); + it->second.erase(query_it); + + if (sequence_to_query_map_[hash].size() == 0) { + sequence_to_query_map_.erase(hash); + } return query; } @@ -240,8 +298,11 @@ std::unique_ptr 
rmw_service_data_t::take_from_query_map(int64_t sequ ///============================================================================= void rmw_client_data_t::notify() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); if (condition_ != nullptr) { + // We also need to take the mutex for the condition_variable; see the comment + // in rmw_wait for more information + std::lock_guard cvlk(*condition_mutex_); condition_->notify_one(); } } @@ -278,16 +339,20 @@ bool rmw_client_data_t::reply_queue_is_empty() const } ///============================================================================= -void rmw_client_data_t::attach_condition(std::condition_variable * condition_variable) +void rmw_client_data_t::attach_condition( + std::mutex * condition_mutex, + std::condition_variable * condition_variable) { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = condition_mutex; condition_ = condition_variable; } ///============================================================================= void rmw_client_data_t::detach_condition() { - std::lock_guard lock(condition_mutex_); + std::lock_guard lock(update_condition_mutex_); + condition_mutex_ = nullptr; condition_ = nullptr; } diff --git a/rmw_zenoh_cpp/src/detail/rmw_data_types.hpp b/rmw_zenoh_cpp/src/detail/rmw_data_types.hpp index 9a37eefc..0a307fd3 100644 --- a/rmw_zenoh_cpp/src/detail/rmw_data_types.hpp +++ b/rmw_zenoh_cpp/src/detail/rmw_data_types.hpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -173,7 +172,7 @@ class rmw_subscription_data_t final MessageTypeSupport * type_support; rmw_context_t * context; - void attach_condition(std::condition_variable * condition_variable); + void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable); void detach_condition(); @@ -192,8 +191,9 @@ class rmw_subscription_data_t final void notify(); + std::mutex 
* condition_mutex_{nullptr}; std::condition_variable * condition_{nullptr}; - std::mutex condition_mutex_; + std::mutex update_condition_mutex_; }; @@ -244,7 +244,7 @@ class rmw_service_data_t final bool query_queue_is_empty() const; - void attach_condition(std::condition_variable * condition_variable); + void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable); void detach_condition(); @@ -252,9 +252,9 @@ class rmw_service_data_t final void add_new_query(std::unique_ptr query); - bool add_to_query_map(int64_t sequence_number, std::unique_ptr query); + bool add_to_query_map(const rmw_request_id_t & request_id, std::unique_ptr query); - std::unique_ptr take_from_query_map(int64_t sequence_number); + std::unique_ptr take_from_query_map(const rmw_request_id_t & request_id); DataCallbackManager data_callback_mgr; @@ -265,12 +265,14 @@ class rmw_service_data_t final std::deque> query_queue_; mutable std::mutex query_queue_mutex_; - // Map to store the sequence_number -> query_id - std::unordered_map> sequence_to_query_map_; + // Map to store the sequence_number (as given by the client) -> ZenohQuery + using SequenceToQuery = std::unordered_map>; + std::unordered_map sequence_to_query_map_; std::mutex sequence_to_query_map_mutex_; + std::mutex * condition_mutex_{nullptr}; std::condition_variable * condition_{nullptr}; - std::mutex condition_mutex_; + std::mutex update_condition_mutex_; }; ///============================================================================= @@ -320,7 +322,7 @@ class rmw_client_data_t final bool reply_queue_is_empty() const; - void attach_condition(std::condition_variable * condition_variable); + void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable); void detach_condition(); @@ -334,8 +336,9 @@ class rmw_client_data_t final size_t sequence_number_{1}; std::mutex sequence_number_mutex_; + std::mutex * condition_mutex_{nullptr}; std::condition_variable * 
condition_{nullptr}; - std::mutex condition_mutex_; + std::mutex update_condition_mutex_; std::deque> reply_queue_; mutable std::mutex reply_queue_mutex_; diff --git a/rmw_zenoh_cpp/src/rmw_zenoh.cpp b/rmw_zenoh_cpp/src/rmw_zenoh.cpp index 9b81ff21..d2245704 100644 --- a/rmw_zenoh_cpp/src/rmw_zenoh.cpp +++ b/rmw_zenoh_cpp/src/rmw_zenoh.cpp @@ -2811,9 +2811,7 @@ rmw_take_request( request_header->received_timestamp = now_ns.count(); // Add this query to the map, so that rmw_send_response can quickly look it up later - if (!service_data->add_to_query_map( - request_header->request_id.sequence_number, std::move(query))) - { + if (!service_data->add_to_query_map(request_header->request_id, std::move(query))) { RMW_SET_ERROR_MSG("duplicate sequence number in the map"); return RMW_RET_ERROR; } @@ -2849,6 +2847,15 @@ rmw_send_response( rmw_service_data_t * service_data = static_cast(service->data); + // Create the queryable payload + std::unique_ptr query = + service_data->take_from_query_map(*request_header); + if (query == nullptr) { + // If there is no data associated with this request, the higher layers of + // ROS 2 seem to expect that we just silently return with no work. + return RMW_RET_OK; + } + rcutils_allocator_t * allocator = &(service_data->context->options.allocator); size_t max_data_length = ( @@ -2860,7 +2867,7 @@ rmw_send_response( max_data_length, allocator->state)); if (!response_bytes) { - RMW_SET_ERROR_MSG("failed allocate response message bytes"); + RMW_SET_ERROR_MSG("failed to allocate response message bytes"); return RMW_RET_ERROR; } auto free_response_bytes = rcpputils::make_scope_exit( @@ -2883,14 +2890,6 @@ rmw_send_response( size_t data_length = ser.get_serialized_data_length(); - // Create the queryable payload - std::unique_ptr query = - service_data->take_from_query_map(request_header->sequence_number); - if (query == nullptr) { - RMW_SET_ERROR_MSG("Unable to find taken request. 
Report this bug."); - return RMW_RET_ERROR; - } - const z_query_t loaned_query = query->get_query(); z_query_reply_options_t options = z_query_reply_options_default(); @@ -3240,7 +3239,7 @@ rmw_wait( // rmw_guard_condition_t. So we can directly cast it to GuardCondition. GuardCondition * gc = static_cast(guard_conditions->guard_conditions[i]); if (gc != nullptr) { - gc->attach_condition(&wait_set_data->condition_variable); + gc->attach_condition(&wait_set_data->condition_mutex, &wait_set_data->condition_variable); } } } @@ -3251,7 +3250,9 @@ rmw_wait( for (size_t i = 0; i < subscriptions->subscriber_count; ++i) { auto sub_data = static_cast(subscriptions->subscribers[i]); if (sub_data != nullptr) { - sub_data->attach_condition(&wait_set_data->condition_variable); + sub_data->attach_condition( + &wait_set_data->condition_mutex, + &wait_set_data->condition_variable); } } } @@ -3262,7 +3263,9 @@ rmw_wait( for (size_t i = 0; i < services->service_count; ++i) { auto serv_data = static_cast(services->services[i]); if (serv_data != nullptr) { - serv_data->attach_condition(&wait_set_data->condition_variable); + serv_data->attach_condition( + &wait_set_data->condition_mutex, + &wait_set_data->condition_variable); } } } @@ -3273,7 +3276,9 @@ rmw_wait( for (size_t i = 0; i < clients->client_count; ++i) { rmw_client_data_t * client_data = static_cast(clients->clients[i]); if (client_data != nullptr) { - client_data->attach_condition(&wait_set_data->condition_variable); + client_data->attach_condition( + &wait_set_data->condition_mutex, + &wait_set_data->condition_variable); } } } @@ -3287,6 +3292,7 @@ rmw_wait( if (zenoh_event_it != event_map.end()) { event_data->attach_event_condition( zenoh_event_it->second, + &wait_set_data->condition_mutex, &wait_set_data->condition_variable); } } @@ -3295,16 +3301,24 @@ rmw_wait( std::unique_lock lock(wait_set_data->condition_mutex); - // According to the RMW documentation, if wait_timeout is NULL that means - // "wait forever", if it 
specified by 0 it means "never wait", and if it is anything else wait - // for that amount of time. - if (wait_timeout == nullptr) { - wait_set_data->condition_variable.wait(lock); - } else { - if (wait_timeout->sec != 0 || wait_timeout->nsec != 0) { - std::cv_status wait_status = wait_set_data->condition_variable.wait_for( - lock, std::chrono::nanoseconds(wait_timeout->nsec + RCUTILS_S_TO_NS(wait_timeout->sec))); - wait_result = wait_status == std::cv_status::no_timeout; + // We have to check the triggered condition *again* under the lock so we + // don't miss notifications. + skip_wait = has_triggered_condition( + subscriptions, guard_conditions, services, clients, events); + + if (!skip_wait) { + // According to the RMW documentation, if wait_timeout is NULL that means + // "wait forever", if it specified by 0 it means "never wait", and if it is anything else wait + // for that amount of time. + if (wait_timeout == nullptr) { + wait_set_data->condition_variable.wait(lock); + } else { + if (wait_timeout->sec != 0 || wait_timeout->nsec != 0) { + std::cv_status wait_status = wait_set_data->condition_variable.wait_for( + lock, + std::chrono::nanoseconds(wait_timeout->nsec + RCUTILS_S_TO_NS(wait_timeout->sec))); + wait_result = wait_status == std::cv_status::no_timeout; + } } } } @@ -3417,7 +3431,7 @@ rmw_get_node_names( } //============================================================================== -/// Return the name, namespae, and enclave name of all nodes in the ROS graph. +/// Return the name, namespace, and enclave name of all nodes in the ROS graph. rmw_ret_t rmw_get_node_names_with_enclaves( const rmw_node_t * node,