Skip to content

Commit

Permalink
Fix a race condition in rmw_wait.
Browse files Browse the repository at this point in the history
To very briefly explain, rmw_wait:
1.  Checks to see if any of the entities (subscriptions, clients, etc)
have data ready to go.
2.  If they have data ready to go, then we skip attaching the
condition variable and waiting.
3.  If they do not have data ready to go, then we attach the
condition variable to all entities, take the condition variable
lock, and call wait_for/wait on the condition variable.
4.  Regardless of whether we did 3 or 4, we check every entity
to see if there is data ready, and mark that as appropriate
in the wait set.

There is a race in all of this, however.  If data comes in after
we've checked the entity (1), but before we've attached the
condition variable (3), then we will never be woken up.  In most
cases, this means that we'll wait the full timeout for the wait_for,
which is not what we want.

Fix this by adding another step to 3.  In particular, after we've
locked the condition variable mutex, check the entities again.
Since we change the entities to *also* take the lock before we
notify, this ensures that the entities cannot make changes that
get lost.

Signed-off-by: Chris Lalancette <[email protected]>
  • Loading branch information
clalancette authored and Yadunund committed Apr 18, 2024
1 parent 12ebc2e commit f65940e
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 39 deletions.
10 changes: 7 additions & 3 deletions rmw_zenoh_cpp/src/detail/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ void EventsManager::add_new_event(
///=============================================================================
void EventsManager::attach_event_condition(
rmw_zenoh_event_type_t event_id,
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
if (event_id > ZENOH_EVENT_ID_MAX) {
Expand All @@ -194,7 +195,8 @@ void EventsManager::attach_event_condition(
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
event_condition_mutexes_[event_id] = condition_mutex;
event_conditions_[event_id] = condition_variable;
}

Expand All @@ -209,7 +211,8 @@ void EventsManager::detach_event_condition(rmw_zenoh_event_type_t event_id)
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
event_condition_mutexes_[event_id] = nullptr;
event_conditions_[event_id] = nullptr;
}

Expand All @@ -224,8 +227,9 @@ void EventsManager::notify_event(rmw_zenoh_event_type_t event_id)
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
if (event_conditions_[event_id] != nullptr) {
std::lock_guard<std::mutex> cvlk(*event_condition_mutexes_[event_id]);
event_conditions_[event_id]->notify_one();
}
}
4 changes: 3 additions & 1 deletion rmw_zenoh_cpp/src/detail/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class EventsManager
/// @param condition_variable to attach.
void attach_event_condition(
rmw_zenoh_event_type_t event_id,
std::mutex * condition_mutex,
std::condition_variable * condition_variable);

/// @brief Detach the condition variable provided by rmw_wait.
Expand All @@ -154,7 +155,8 @@ class EventsManager
/// Mutex to lock when read/writing members.
mutable std::mutex event_mutex_;
/// Mutex to lock for event_condition.
mutable std::mutex event_condition_mutex_;
mutable std::mutex update_event_condition_mutex_;
std::mutex * event_condition_mutexes_[ZENOH_EVENT_ID_MAX + 1]{nullptr};
/// Condition variable to attach for event notifications.
std::condition_variable * event_conditions_[ZENOH_EVENT_ID_MAX + 1]{nullptr};
/// User callback that can be set via data_callback_mgr.set_callback().
Expand Down
5 changes: 4 additions & 1 deletion rmw_zenoh_cpp/src/detail/guard_condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,24 @@ void GuardCondition::trigger()
has_triggered_ = true;

if (condition_variable_ != nullptr) {
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_variable_->notify_one();
}
}

///==============================================================================
void GuardCondition::attach_condition(std::condition_variable * condition_variable)
void GuardCondition::attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(internal_mutex_);
condition_mutex_ = condition_mutex;
condition_variable_ = condition_variable;
}

///==============================================================================
void GuardCondition::detach_condition()
{
std::lock_guard<std::mutex> lock(internal_mutex_);
condition_mutex_ = nullptr;
condition_variable_ = nullptr;
}

Expand Down
5 changes: 3 additions & 2 deletions rmw_zenoh_cpp/src/detail/guard_condition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class GuardCondition final
// Sets has_triggered_ to true and calls notify_one() on condition_variable_ if set.
void trigger();

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -38,7 +38,8 @@ class GuardCondition final
private:
mutable std::mutex internal_mutex_;
std::atomic_bool has_triggered_;
std::condition_variable * condition_variable_;
std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_variable_{nullptr};
};

#endif // DETAIL__GUARD_CONDITION_HPP_
39 changes: 27 additions & 12 deletions rmw_zenoh_cpp/src/detail/rmw_data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,30 @@ size_t rmw_publisher_data_t::get_next_sequence_number()
}

///=============================================================================
void rmw_subscription_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_subscription_data_t::attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_subscription_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}

///=============================================================================
void rmw_subscription_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand Down Expand Up @@ -150,16 +155,18 @@ bool rmw_service_data_t::query_queue_is_empty() const
}

///=============================================================================
void rmw_service_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_service_data_t::attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_service_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand All @@ -180,8 +187,11 @@ std::unique_ptr<ZenohQuery> rmw_service_data_t::pop_next_query()
///=============================================================================
void rmw_service_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}
Expand Down Expand Up @@ -282,8 +292,11 @@ std::unique_ptr<ZenohQuery> rmw_service_data_t::take_from_query_map(
///=============================================================================
void rmw_client_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}
Expand Down Expand Up @@ -320,16 +333,18 @@ bool rmw_client_data_t::reply_queue_is_empty() const
}

///=============================================================================
void rmw_client_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_client_data_t::attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_client_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand Down
15 changes: 9 additions & 6 deletions rmw_zenoh_cpp/src/detail/rmw_data_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class rmw_subscription_data_t final
MessageTypeSupport * type_support;
rmw_context_t * context;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -191,8 +191,9 @@ class rmw_subscription_data_t final

void notify();

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;
};


Expand Down Expand Up @@ -243,7 +244,7 @@ class rmw_service_data_t final

bool query_queue_is_empty() const;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -269,8 +270,9 @@ class rmw_service_data_t final
std::unordered_map<size_t, SequenceToQuery> sequence_to_query_map_;
std::mutex sequence_to_query_map_mutex_;

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;
};

///=============================================================================
Expand Down Expand Up @@ -320,7 +322,7 @@ class rmw_client_data_t final

bool reply_queue_is_empty() const;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -334,8 +336,9 @@ class rmw_client_data_t final
size_t sequence_number_{1};
std::mutex sequence_number_mutex_;

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;

std::deque<std::unique_ptr<ZenohReply>> reply_queue_;
mutable std::mutex reply_queue_mutex_;
Expand Down
35 changes: 21 additions & 14 deletions rmw_zenoh_cpp/src/rmw_zenoh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3239,7 +3239,7 @@ rmw_wait(
// rmw_guard_condition_t. So we can directly cast it to GuardCondition.
GuardCondition * gc = static_cast<GuardCondition *>(guard_conditions->guard_conditions[i]);
if (gc != nullptr) {
gc->attach_condition(&wait_set_data->condition_variable);
gc->attach_condition(&wait_set_data->condition_mutex, &wait_set_data->condition_variable);
}
}
}
Expand All @@ -3250,7 +3250,7 @@ rmw_wait(
for (size_t i = 0; i < subscriptions->subscriber_count; ++i) {
auto sub_data = static_cast<rmw_subscription_data_t *>(subscriptions->subscribers[i]);
if (sub_data != nullptr) {
sub_data->attach_condition(&wait_set_data->condition_variable);
sub_data->attach_condition(&wait_set_data->condition_mutex, &wait_set_data->condition_variable);
}
}
}
Expand All @@ -3261,7 +3261,7 @@ rmw_wait(
for (size_t i = 0; i < services->service_count; ++i) {
auto serv_data = static_cast<rmw_service_data_t *>(services->services[i]);
if (serv_data != nullptr) {
serv_data->attach_condition(&wait_set_data->condition_variable);
serv_data->attach_condition(&wait_set_data->condition_mutex, &wait_set_data->condition_variable);
}
}
}
Expand All @@ -3272,7 +3272,7 @@ rmw_wait(
for (size_t i = 0; i < clients->client_count; ++i) {
rmw_client_data_t * client_data = static_cast<rmw_client_data_t *>(clients->clients[i]);
if (client_data != nullptr) {
client_data->attach_condition(&wait_set_data->condition_variable);
client_data->attach_condition(&wait_set_data->condition_mutex, &wait_set_data->condition_variable);
}
}
}
Expand All @@ -3286,6 +3286,7 @@ rmw_wait(
if (zenoh_event_it != event_map.end()) {
event_data->attach_event_condition(
zenoh_event_it->second,
&wait_set_data->condition_mutex,
&wait_set_data->condition_variable);
}
}
Expand All @@ -3294,16 +3295,22 @@ rmw_wait(

std::unique_lock<std::mutex> lock(wait_set_data->condition_mutex);

// According to the RMW documentation, if wait_timeout is NULL that means
// "wait forever", if it specified by 0 it means "never wait", and if it is anything else wait
// for that amount of time.
if (wait_timeout == nullptr) {
wait_set_data->condition_variable.wait(lock);
} else {
if (wait_timeout->sec != 0 || wait_timeout->nsec != 0) {
std::cv_status wait_status = wait_set_data->condition_variable.wait_for(
lock, std::chrono::nanoseconds(wait_timeout->nsec + RCUTILS_S_TO_NS(wait_timeout->sec)));
wait_result = wait_status == std::cv_status::no_timeout;
// We have to check the triggered condition *again* under the lock so we don't miss notifications.
skip_wait = has_triggered_condition(
subscriptions, guard_conditions, services, clients, events);

if (!skip_wait) {
// According to the RMW documentation, if wait_timeout is NULL that means
// "wait forever", if it specified by 0 it means "never wait", and if it is anything else wait
// for that amount of time.
if (wait_timeout == nullptr) {
wait_set_data->condition_variable.wait(lock);
} else {
if (wait_timeout->sec != 0 || wait_timeout->nsec != 0) {
std::cv_status wait_status = wait_set_data->condition_variable.wait_for(
lock, std::chrono::nanoseconds(wait_timeout->nsec + RCUTILS_S_TO_NS(wait_timeout->sec)));
wait_result = wait_status == std::cv_status::no_timeout;
}
}
}
}
Expand Down

0 comments on commit f65940e

Please sign in to comment.