diff --git a/src/WebSocketServer.cpp b/src/WebSocketServer.cpp index 28aa2eee..1faeaf07 100644 --- a/src/WebSocketServer.cpp +++ b/src/WebSocketServer.cpp @@ -176,10 +176,13 @@ std::vector WebSocketServer::GetWebSocke std::unique_lock lock(_sessionMutex); for (auto & [hdl, session] : _sessions) { + if (!session.AddRef()) + continue; uint64_t connectedAt = session.ConnectedAt(); uint64_t incomingMessages = session.IncomingMessages(); uint64_t outgoingMessages = session.OutgoingMessages(); std::string remoteAddress = session.RemoteAddress(); + session.DelRef(); webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages}); } @@ -216,8 +219,12 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT // Recurse connected sessions and send the event to suitable sessions. std::unique_lock lock(_sessionMutex); for (auto & it : _sessions) { - if (!it.second.IsIdentified()) + if (!it.second.AddRef()) continue; + if (!it.second.IsIdentified()) { + it.second.DelRef(); + continue; + } if ((it.second.EventSubscriptions() & requiredIntent) != 0) { websocketpp::lib::error_code errorCode; switch (it.second.Encoding()) { @@ -238,6 +245,7 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT break; } } + it.second.DelRef(); } lock.unlock(); if (_debugEnabled) @@ -254,6 +262,9 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl) auto &session = _sessions[hdl]; lock.unlock(); + if (!session.AddRef()) + return; + // Configure session details session.SetRemoteAddress(conn->get_remote_endpoint()); session.SetConnectedAt(QDateTime::currentSecsSinceEpoch()); @@ -266,6 +277,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl) session.SetEncoding(WebSocketEncoding::MsgPack); } else { conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type."); + session.DelRef(); return; } @@ -295,6 +307,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl) _server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode); } session.IncrementOutgoingMessages(); + session.DelRef(); } void WebSocketServer::onClose(websocketpp::connection_hdl hdl) @@ -304,11 +317,14 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl) // Get info from the session and then delete it std::unique_lock lock(_sessionMutex); auto &session = _sessions[hdl]; + session.AddRef(); bool isIdentified = session.IsIdentified(); uint64_t connectedAt = session.ConnectedAt(); uint64_t incomingMessages = session.IncomingMessages(); uint64_t outgoingMessages = session.OutgoingMessages(); std::string remoteAddress = session.RemoteAddress(); + session.SetDeleted(); + session.DelRef(); _sessions.erase(hdl); lock.unlock(); @@ -332,6 +348,8 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se QtConcurrent::run(&_threadPool, [=]() { std::unique_lock lock(_sessionMutex); auto &session = _sessions[hdl]; + if (!session.AddRef()) + return; lock.unlock(); session.IncrementIncomingMessages(); @@ -345,6 +363,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se if (opcode != websocketpp::frame::opcode::text) { if (!session.IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode); + session.DelRef(); return; } @@ -353,12 +372,14 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se } catch (json::parse_error& e) { if (!session.IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode); + session.DelRef(); return; } } else if (sessionEncoding == WebSocketEncoding::MsgPack) { if (opcode != websocketpp::frame::opcode::binary) { if (!session.IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode); + session.DelRef(); return; } @@ -367,6 +388,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se } catch (json::parse_error& e) { if (!session.IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode); + session.DelRef(); return; } } @@ -379,6 +401,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se if (ret.closeCode != WebSocketCloseCode::DontClose) { websocketpp::lib::error_code errorCode; _server.close(hdl, ret.closeCode, ret.closeReason, errorCode); + session.DelRef(); return; } @@ -400,5 +423,6 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se if (errorCode) blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str()); } + session.DelRef(); }); } diff --git a/src/WebSocketSession.cpp b/src/WebSocketSession.cpp index ffa75a73..fd10b50c 100644 --- a/src/WebSocketSession.cpp +++ b/src/WebSocketSession.cpp @@ -1,8 +1,12 @@ +#include + #include "WebSocketSession.h" #include "plugin-macros.generated.h" WebSocketSession::WebSocketSession() : + _ref(0), + _deleted(false), _remoteAddress(""), _connectedAt(0), _incomingMessages(0), @@ -17,6 +21,31 @@ WebSocketSession::WebSocketSession() : { } +bool WebSocketSession::AddRef() +{ + std::lock_guard lock(_refMutex); + if (_deleted) + return false; + _ref++; + return true; +} + +void WebSocketSession::DelRef() +{ + std::lock_guard lock(_refMutex); + if (_ref == 0) { + blog(LOG_ERROR, "[WebSocketSession::DelRef] Failed to de-increment ref - already 0"); + return; + } + _ref--; +} + +void WebSocketSession::SetDeleted() +{ + std::lock_guard lock(_refMutex); + _deleted = true; +} + std::string WebSocketSession::RemoteAddress() { std::lock_guard lock(_remoteAddressMutex); diff --git a/src/WebSocketSession.h b/src/WebSocketSession.h index c280a253..6aa81591 100644 --- a/src/WebSocketSession.h +++ b/src/WebSocketSession.h @@ -9,6 +9,10 @@ class WebSocketSession public: WebSocketSession(); + bool AddRef(); + void DelRef(); + void SetDeleted(); + std::string RemoteAddress(); void SetRemoteAddress(std::string address); @@ -43,6 +47,9 @@ class WebSocketSession void SetEventSubscriptions(uint64_t subscriptions); private: + std::mutex _refMutex; + uint64_t _ref; + bool _deleted; std::mutex _remoteAddressMutex; std::string _remoteAddress; std::atomic _connectedAt;