Skip to content

Commit

Permalink
Fix UB in ClientData stuff.
Browse files Browse the repository at this point in the history
The num_in_flight accounting was *still* wrong here.  First of
all, we forgot to increment num_in_flight when actually
kicking off a new query.  Once we did that, we had to
change the lock in NodeData to a recursive one, since
delete_client_data, which ClientData calls, could be entered
recursively.  And then finally we had to drop the ClientData
lock before calling delete_client_data, since we are about to
delete ourselves and the unlock would have been undefined behavior (UB).

Signed-off-by: Chris Lalancette <[email protected]>
  • Loading branch information
clalancette committed Nov 20, 2024
1 parent 3d771c8 commit 36539bd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
5 changes: 4 additions & 1 deletion rmw_zenoh_cpp/src/detail/rmw_client_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ rmw_ret_t ClientData::send_request(
opts.value.payload = z_bytes_t{data_length, reinterpret_cast<const uint8_t *>(request_bytes)};
// TODO(Yadunund): Once we switch to zenoh-cpp with lambda closures,
// capture shared_from_this() instead of this.
num_in_flight_++;
z_owned_closure_reply_t zn_closure_reply =
z_closure(client_data_handler, client_data_drop, this);
z_get(
Expand Down Expand Up @@ -563,7 +564,7 @@ bool ClientData::shutdown_and_query_in_flight()
///=============================================================================
void ClientData::decrement_in_flight_and_conditionally_remove()
{
std::lock_guard<std::recursive_mutex> lock(mutex_);
std::unique_lock<std::recursive_mutex> lock(mutex_);
--num_in_flight_;

if (is_shutdown_ && num_in_flight_ == 0) {
Expand All @@ -575,6 +576,8 @@ void ClientData::decrement_in_flight_and_conditionally_remove()
if (node_data == nullptr) {
return;
}
// We have to unlock here since we are about to delete ourselves; unlocking after that would be UB.
lock.unlock();
node_data->delete_client_data(rmw_client_);
}
}
Expand Down
30 changes: 15 additions & 15 deletions rmw_zenoh_cpp/src/detail/rmw_node_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ NodeData::~NodeData()
///=============================================================================
std::size_t NodeData::id() const
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
return id_;
}

Expand All @@ -128,7 +128,7 @@ bool NodeData::create_pub_data(
const rosidl_message_type_support_t * type_support,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -169,7 +169,7 @@ bool NodeData::create_pub_data(
///=============================================================================
PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = pubs_.find(publisher);
if (it == pubs_.end()) {
return nullptr;
Expand All @@ -181,7 +181,7 @@ PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher)
///=============================================================================
void NodeData::delete_pub_data(const rmw_publisher_t * const publisher)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
pubs_.erase(publisher);
}

Expand All @@ -195,7 +195,7 @@ bool NodeData::create_sub_data(
const rosidl_message_type_support_t * type_support,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -237,7 +237,7 @@ bool NodeData::create_sub_data(
///=============================================================================
SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subscription)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = subs_.find(subscription);
if (it == subs_.end()) {
return nullptr;
Expand All @@ -249,7 +249,7 @@ SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subs
///=============================================================================
void NodeData::delete_sub_data(const rmw_subscription_t * const subscription)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
subs_.erase(subscription);
}

Expand All @@ -262,7 +262,7 @@ bool NodeData::create_service_data(
const rosidl_service_type_support_t * type_supports,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -303,7 +303,7 @@ bool NodeData::create_service_data(
///=============================================================================
ServiceDataPtr NodeData::get_service_data(const rmw_service_t * const service)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = services_.find(service);
if (it == services_.end()) {
return nullptr;
Expand All @@ -315,7 +315,7 @@ ServiceDataPtr NodeData::get_service_data(const rmw_service_t * const service)
///=============================================================================
void NodeData::delete_service_data(const rmw_service_t * const service)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
services_.erase(service);
}

Expand All @@ -329,7 +329,7 @@ bool NodeData::create_client_data(
const rosidl_service_type_support_t * type_supports,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -371,7 +371,7 @@ bool NodeData::create_client_data(
///=============================================================================
ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = clients_.find(client);
if (it == clients_.end()) {
return nullptr;
Expand All @@ -383,7 +383,7 @@ ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client)
///=============================================================================
void NodeData::delete_client_data(const rmw_client_t * const client)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto client_it = clients_.find(client);
if (client_it == clients_.end()) {
return;
Expand All @@ -396,7 +396,7 @@ void NodeData::delete_client_data(const rmw_client_t * const client)
///=============================================================================
rmw_ret_t NodeData::shutdown()
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
rmw_ret_t ret = RMW_RET_OK;
if (is_shutdown_) {
return ret;
Expand Down Expand Up @@ -463,7 +463,7 @@ rmw_ret_t NodeData::shutdown()
// Check if the Node is shutdown.
bool NodeData::is_shutdown() const
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
return is_shutdown_;
}

Expand Down
2 changes: 1 addition & 1 deletion rmw_zenoh_cpp/src/detail/rmw_node_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class NodeData final
std::shared_ptr<liveliness::Entity> entity,
zc_owned_liveliness_token_t token);
// Internal mutex.
mutable std::mutex mutex_;
mutable std::recursive_mutex mutex_;
// The rmw_node_t associated with this NodeData.
const rmw_node_t * node_;
// The entity id of this node as generated by get_next_entity_id().
Expand Down

0 comments on commit 36539bd

Please sign in to comment.