Skip to content

Commit

Permalink
Lock protocolMap with a mutex to prevent data race (#284)
Browse files Browse the repository at this point in the history
* Lock protocolMap with a mutex to prevent data race

* Fix protocolMapMutex comments

* Convert thread-safe protocolMap access to a function

* Make getProtocol return an optional

* Add additional locks for ProtocolData client and heartbeatInfo

* Change entry mutex access in validate

* Remove chained getters

* Remove duplicate protocolMapMutex lock and heartbeat timeout

* Remove lock from onMessage and lock stop earlier

* Make onMessage use getProtocol

* Log heartbeat timeout earlier

* Fix server stop, reintroduce sendRawString data mutex without heartbeat timeout

* Add comment explaining missing lock

* Replace protocolMap.find()  with getProtocol()
  • Loading branch information
quinnmp authored Dec 2, 2023
1 parent a01aa7b commit e7b7b62
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/network/MissionControlProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ void MissionControlProtocol::handleConnection() {
}

void MissionControlProtocol::handleHeartbeatTimedOut() {
LOG_F(ERROR, "Heartbeat timed out! Emergency stopping.");
this->stopAndShutdownPowerRepeat();
robot::emergencyStop();
LOG_F(ERROR, "Heartbeat timed out! Emergency stopping.");
Globals::E_STOP = true;
Globals::armIKEnabled = false;
}
Expand Down
81 changes: 51 additions & 30 deletions src/network/websocket/WebSocketServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,23 @@ void SingleClientWSServer::serverTask() {

void SingleClientWSServer::stop() {
if (isRunning) {
isRunning = false;
server.stop_listening();
for (auto& entry : protocolMap) {
if (entry.second.client) {
try {
server.close(entry.second.client.value(),
websocketpp::close::status::going_away,
"Server shutting down");
} catch (const websocketpp::exception& e) {
LOG_F(ERROR, "Server=%s : An error occurred while shutting down: %s",
serverName.c_str(), e.what());
{
std::lock_guard lock(protocolMapMutex);
isRunning = false;
server.stop_listening();
for (auto& entry : protocolMap) {
std::lock_guard lock(entry.second.mutex);
if (entry.second.client) {
try {
server.close(entry.second.client.value(),
websocketpp::close::status::going_away,
"Server shutting down");
} catch (const websocketpp::exception& e) {
LOG_F(ERROR, "Server=%s : An error occurred while shutting down: %s",
serverName.c_str(), e.what());
}
entry.second.client.reset();
}
entry.second.client.reset();
}
}
if (serverThread.joinable()) {
Expand All @@ -81,6 +85,7 @@ void SingleClientWSServer::stop() {

bool SingleClientWSServer::addProtocol(std::unique_ptr<WebSocketProtocol> protocol) {
std::string path = protocol->getProtocolPath();
std::lock_guard lock(protocolMapMutex);
if (protocolMap.find(path) == protocolMap.end()) {
protocolMap.emplace(path, std::move(protocol));
return true;
Expand All @@ -91,9 +96,10 @@ bool SingleClientWSServer::addProtocol(std::unique_ptr<WebSocketProtocol> protoc

void SingleClientWSServer::sendRawString(const std::string& protocolPath,
const std::string& str) {
auto entry = protocolMap.find(protocolPath);
if (entry != protocolMap.end()) {
auto& protocolData = entry->second;
auto protocolDataOpt = this->getProtocol(protocolPath);
if (protocolDataOpt.has_value()) {
ProtocolData& protocolData = protocolDataOpt.value();
std::lock_guard lock(protocolData.mutex);
if (protocolData.client) {
connection_hdl hdl = protocolData.client.value();
auto conn = server.get_con_from_hdl(hdl);
Expand All @@ -112,12 +118,14 @@ void SingleClientWSServer::sendJSON(const std::string& protocolPath, const json&
bool SingleClientWSServer::validate(connection_hdl hdl) {
auto conn = server.get_con_from_hdl(hdl);
std::string path = conn->get_resource();
auto entry = protocolMap.find(path);
if (entry != protocolMap.end()) {
if (!entry->second.client.has_value()) {
auto protocolDataOpt = this->getProtocol(path);
if (protocolDataOpt.has_value()) {
ProtocolData& pd = protocolDataOpt.value();
std::lock_guard lock(pd.mutex);
if (!pd.client.has_value()) {
return true;
} else {
auto existingConn = server.get_con_from_hdl(entry->second.client.value());
auto existingConn = server.get_con_from_hdl(pd.client.value());
LOG_F(INFO,
"Server=%s, Endpoint=%s : Rejected connection from %s - A client is already "
"connected: %s\n",
Expand All @@ -139,15 +147,15 @@ void SingleClientWSServer::onOpen(connection_hdl hdl) {
LOG_F(INFO, "Server=%s, Endpoint=%s : Connection opened from %s", serverName.c_str(),
path.c_str(), client.c_str());

auto& protocolData = protocolMap.at(path);
ProtocolData& protocolData = this->getProtocol(path).value();
{
std::lock_guard lock(protocolData.mutex);
protocolData.client = hdl;
const auto& heartbeatInfo = protocolData.protocol->heartbeatInfo;
if (heartbeatInfo.has_value()) {
auto eventID =
pingScheduler.scheduleEvent(heartbeatInfo->first / 2, [this, path]() {
auto& pd = this->protocolMap.at(path);
ProtocolData& pd = this->getProtocol(path).value();
std::lock_guard lock(pd.mutex);
if (pd.client.has_value()) {
LOG_F(2, "Ping!");
Expand All @@ -162,9 +170,8 @@ void SingleClientWSServer::onOpen(connection_hdl hdl) {
std::tuple<decltype(eventID)>{eventID},
util::pairToTuple(heartbeatInfo.value()));
}

protocolData.protocol->clientConnected();
}
protocolData.protocol->clientConnected();
}

void SingleClientWSServer::onClose(connection_hdl hdl) {
Expand All @@ -174,7 +181,7 @@ void SingleClientWSServer::onClose(connection_hdl hdl) {
LOG_F(INFO, "Server=%s, Endpoint=%s : Connection disconnected from %s", serverName.c_str(),
path.c_str(), client.c_str());

auto& protocolData = protocolMap.at(path);
ProtocolData& protocolData = this->getProtocol(path).value();
{
std::lock_guard lock(protocolData.mutex);
protocolData.client.reset();
Expand All @@ -190,12 +197,14 @@ void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) {
auto conn = server.get_con_from_hdl(hdl);
std::string path = conn->get_resource();

auto it = protocolMap.find(path);
if (it != protocolMap.end()) {
auto protocolDataOpt = this->getProtocol(path);
if (protocolDataOpt.has_value()) {
// No need to lock this pd because we only access the protocol, which is constant
ProtocolData& pd = protocolDataOpt.value();
std::string jsonStr = message->get_payload();
LOG_F(1, "Message on %s: %s", path.c_str(), jsonStr.c_str());
json obj = json::parse(jsonStr);
it->second.protocol->processMessage(obj);
pd.protocol->processMessage(obj);
} else {
LOG_F(WARNING, "Received message on unknown protocol path %s", path.c_str());
}
Expand All @@ -205,9 +214,9 @@ void SingleClientWSServer::onPong(connection_hdl hdl, const std::string& payload
LOG_F(2, "Pong from %s", payload.c_str());
auto conn = server.get_con_from_hdl(hdl);

auto it = protocolMap.find(payload);
if (it != protocolMap.end()) {
auto& pd = it->second;
auto protocolDataOpt = this->getProtocol(payload);
if (protocolDataOpt.has_value()) {
ProtocolData& pd = protocolDataOpt.value();
std::lock_guard lock(pd.mutex);
if (pd.heartbeatInfo.has_value()) {
pd.heartbeatInfo->second.feed();
Expand All @@ -216,5 +225,17 @@ void SingleClientWSServer::onPong(connection_hdl hdl, const std::string& payload
LOG_F(WARNING, "Received pong on unknown protocol path %s", payload.c_str());
}
}

std::optional<std::reference_wrapper<SingleClientWSServer::ProtocolData>>
SingleClientWSServer::getProtocol(const std::string& protocolPath) {
std::lock_guard lock(protocolMapMutex);

auto it = protocolMap.find(protocolPath);
if (it != protocolMap.end()) {
return std::ref(it->second);
} else {
return std::nullopt;
}
}
} // namespace websocket
} // namespace net
5 changes: 5 additions & 0 deletions src/network/websocket/WebSocketServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class SingleClientWSServer {
uint16_t port;
websocketpp::server<websocketpp::config::asio> server;
bool isRunning;
// protects against race conditions modifying protocolMap
std::mutex protocolMapMutex;
// maps path prefix to ProtocolData for each protocol
std::map<std::string, ProtocolData> protocolMap;
std::thread serverThread;
Expand All @@ -125,6 +127,9 @@ class SingleClientWSServer {
// called when pong message received from WS client
void onPong(connection_hdl hdl, const std::string& payload);
void serverTask();
// Thread-safe access of the protocolMap
std::optional<std::reference_wrapper<SingleClientWSServer::ProtocolData>>
getProtocol(const std::string& protocolPath);
};
} // namespace websocket
} // namespace net

0 comments on commit e7b7b62

Please sign in to comment.