base: Use shared_ptr instead of explicit ref counts

Took a night of sleep but I realized how I could solve the
concurrency issues in a good way. Uses shared_ptr, where the map
always accounts for one reference to a session.
This commit is contained in:
tt2468 2021-04-30 08:45:34 -07:00
parent 32758198ab
commit 4be9b995fb
6 changed files with 42 additions and 92 deletions

View File

@ -4,14 +4,14 @@
#include "plugin-macros.generated.h" #include "plugin-macros.generated.h"
WebSocketProtocol::ProcessResult SetSessionParameters(WebSocketSession *session, json incomingMessage) WebSocketProtocol::ProcessResult SetSessionParameters(SessionPtr session, json incomingMessage)
{ {
WebSocketProtocol::ProcessResult ret; WebSocketProtocol::ProcessResult ret;
return ret; return ret;
} }
WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage) WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(SessionPtr session, json incomingMessage)
{ {
WebSocketProtocol::ProcessResult ret; WebSocketProtocol::ProcessResult ret;

View File

@ -14,5 +14,5 @@ namespace WebSocketProtocol {
json result; json result;
}; };
ProcessResult ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage); ProcessResult ProcessMessage(SessionPtr session, json incomingMessage);
} }

View File

@ -176,13 +176,10 @@ std::vector<WebSocketServer::WebSocketSessionState> WebSocketServer::GetWebSocke
std::unique_lock<std::mutex> lock(_sessionMutex); std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & [hdl, session] : _sessions) { for (auto & [hdl, session] : _sessions) {
if (!session.AddRef()) uint64_t connectedAt = session->ConnectedAt();
continue; uint64_t incomingMessages = session->IncomingMessages();
uint64_t connectedAt = session.ConnectedAt(); uint64_t outgoingMessages = session->OutgoingMessages();
uint64_t incomingMessages = session.IncomingMessages(); std::string remoteAddress = session->RemoteAddress();
uint64_t outgoingMessages = session.OutgoingMessages();
std::string remoteAddress = session.RemoteAddress();
session.DelRef();
webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages}); webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages});
} }
@ -219,21 +216,18 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
// Recurse connected sessions and send the event to suitable sessions. // Recurse connected sessions and send the event to suitable sessions.
std::unique_lock<std::mutex> lock(_sessionMutex); std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & it : _sessions) { for (auto & it : _sessions) {
if (!it.second.AddRef()) if (!it.second->IsIdentified()) {
continue;
if (!it.second.IsIdentified()) {
it.second.DelRef();
continue; continue;
} }
if ((it.second.EventSubscriptions() & requiredIntent) != 0) { if ((it.second->EventSubscriptions() & requiredIntent) != 0) {
websocketpp::lib::error_code errorCode; websocketpp::lib::error_code errorCode;
switch (it.second.Encoding()) { switch (it.second->Encoding()) {
case WebSocketEncoding::Json: case WebSocketEncoding::Json:
if (messageJson.empty()) { if (messageJson.empty()) {
messageJson = eventMessage.dump(); messageJson = eventMessage.dump();
} }
_server.send((websocketpp::connection_hdl)it.first, messageJson, websocketpp::frame::opcode::text, errorCode); _server.send((websocketpp::connection_hdl)it.first, messageJson, websocketpp::frame::opcode::text, errorCode);
it.second.IncrementOutgoingMessages(); it.second->IncrementOutgoingMessages();
break; break;
case WebSocketEncoding::MsgPack: case WebSocketEncoding::MsgPack:
if (messageMsgPack.empty()) { if (messageMsgPack.empty()) {
@ -241,11 +235,10 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
messageMsgPack = std::string(msgPackData.begin(), msgPackData.end()); messageMsgPack = std::string(msgPackData.begin(), msgPackData.end());
} }
_server.send((websocketpp::connection_hdl)it.first, messageMsgPack, websocketpp::frame::opcode::binary, errorCode); _server.send((websocketpp::connection_hdl)it.first, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
it.second.IncrementOutgoingMessages(); it.second->IncrementOutgoingMessages();
break; break;
} }
} }
it.second.DelRef();
} }
lock.unlock(); lock.unlock();
if (_debugEnabled) if (_debugEnabled)
@ -259,25 +252,21 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// Build new session // Build new session
std::unique_lock<std::mutex> lock(_sessionMutex); std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl]; SessionPtr session = _sessions[hdl] = std::make_shared<WebSocketSession>();
lock.unlock(); lock.unlock();
if (!session.AddRef())
return;
// Configure session details // Configure session details
session.SetRemoteAddress(conn->get_remote_endpoint()); session->SetRemoteAddress(conn->get_remote_endpoint());
session.SetConnectedAt(QDateTime::currentSecsSinceEpoch()); session->SetConnectedAt(QDateTime::currentSecsSinceEpoch());
std::string contentType = conn->get_request_header("Content-Type"); std::string contentType = conn->get_request_header("Content-Type");
if (contentType == "") { if (contentType == "") {
; ;
} else if (contentType == "application/json") { } else if (contentType == "application/json") {
session.SetEncoding(WebSocketEncoding::Json); session->SetEncoding(WebSocketEncoding::Json);
} else if (contentType == "application/msgpack") { } else if (contentType == "application/msgpack") {
session.SetEncoding(WebSocketEncoding::MsgPack); session->SetEncoding(WebSocketEncoding::MsgPack);
} else { } else {
conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type."); conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type.");
session.DelRef();
return; return;
} }
@ -289,7 +278,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// todo: Add request and event lists // todo: Add request and event lists
if (AuthenticationRequired) { if (AuthenticationRequired) {
std::string sessionChallenge = Utils::Crypto::GenerateSalt(); std::string sessionChallenge = Utils::Crypto::GenerateSalt();
session.SetChallenge(sessionChallenge); session->SetChallenge(sessionChallenge);
helloMessage["authentication"] = {}; helloMessage["authentication"] = {};
helloMessage["authentication"]["challenge"] = sessionChallenge; helloMessage["authentication"]["challenge"] = sessionChallenge;
helloMessage["authentication"]["salt"] = AuthenticationSalt; helloMessage["authentication"]["salt"] = AuthenticationSalt;
@ -297,7 +286,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// Send object to client // Send object to client
websocketpp::lib::error_code errorCode; websocketpp::lib::error_code errorCode;
auto sessionEncoding = session.Encoding(); auto sessionEncoding = session->Encoding();
if (sessionEncoding == WebSocketEncoding::Json) { if (sessionEncoding == WebSocketEncoding::Json) {
std::string helloMessageJson = helloMessage.dump(); std::string helloMessageJson = helloMessage.dump();
_server.send(hdl, helloMessageJson, websocketpp::frame::opcode::text, errorCode); _server.send(hdl, helloMessageJson, websocketpp::frame::opcode::text, errorCode);
@ -306,8 +295,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
std::string messageMsgPack(msgPackData.begin(), msgPackData.end()); std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode); _server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
} }
session.IncrementOutgoingMessages(); session->IncrementOutgoingMessages();
session.DelRef();
} }
void WebSocketServer::onClose(websocketpp::connection_hdl hdl) void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
@ -316,15 +304,12 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
// Get info from the session and then delete it // Get info from the session and then delete it
std::unique_lock<std::mutex> lock(_sessionMutex); std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl]; SessionPtr session = _sessions[hdl];
session.AddRef(); bool isIdentified = session->IsIdentified();
bool isIdentified = session.IsIdentified(); uint64_t connectedAt = session->ConnectedAt();
uint64_t connectedAt = session.ConnectedAt(); uint64_t incomingMessages = session->IncomingMessages();
uint64_t incomingMessages = session.IncomingMessages(); uint64_t outgoingMessages = session->OutgoingMessages();
uint64_t outgoingMessages = session.OutgoingMessages(); std::string remoteAddress = session->RemoteAddress();
std::string remoteAddress = session.RemoteAddress();
session.SetDeleted();
session.DelRef();
_sessions.erase(hdl); _sessions.erase(hdl);
lock.unlock(); lock.unlock();
@ -347,48 +332,47 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
std::string payload = message->get_payload(); std::string payload = message->get_payload();
QtConcurrent::run(&_threadPool, [=]() { QtConcurrent::run(&_threadPool, [=]() {
std::unique_lock<std::mutex> lock(_sessionMutex); std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl]; SessionPtr session;
if (!session.AddRef()) try {
session = _sessions.at(hdl);
} catch (const std::out_of_range& oor) {
return; return;
}
lock.unlock(); lock.unlock();
session.IncrementIncomingMessages(); session->IncrementIncomingMessages();
json incomingMessage; json incomingMessage;
// Check for invalid opcode and decode // Check for invalid opcode and decode
websocketpp::lib::error_code errorCode; websocketpp::lib::error_code errorCode;
uint8_t sessionEncoding = session.Encoding(); uint8_t sessionEncoding = session->Encoding();
if (sessionEncoding == WebSocketEncoding::Json) { if (sessionEncoding == WebSocketEncoding::Json) {
if (opcode != websocketpp::frame::opcode::text) { if (opcode != websocketpp::frame::opcode::text) {
if (!session.IgnoreInvalidMessages()) if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode); _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode);
session.DelRef();
return; return;
} }
try { try {
incomingMessage = json::parse(payload); incomingMessage = json::parse(payload);
} catch (json::parse_error& e) { } catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages()) if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode); _server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode);
session.DelRef();
return; return;
} }
} else if (sessionEncoding == WebSocketEncoding::MsgPack) { } else if (sessionEncoding == WebSocketEncoding::MsgPack) {
if (opcode != websocketpp::frame::opcode::binary) { if (opcode != websocketpp::frame::opcode::binary) {
if (!session.IgnoreInvalidMessages()) if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode); _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode);
session.DelRef();
return; return;
} }
try { try {
incomingMessage = json::from_msgpack(payload); incomingMessage = json::from_msgpack(payload);
} catch (json::parse_error& e) { } catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages()) if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode); _server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode);
session.DelRef();
return; return;
} }
} }
@ -396,12 +380,11 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (_debugEnabled) if (_debugEnabled)
blog(LOG_INFO, "[WebSocketServer::onMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str()); blog(LOG_INFO, "[WebSocketServer::onMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str());
WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(hdl, &session, incomingMessage); WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(session, incomingMessage);
if (ret.closeCode != WebSocketCloseCode::DontClose) { if (ret.closeCode != WebSocketCloseCode::DontClose) {
websocketpp::lib::error_code errorCode; websocketpp::lib::error_code errorCode;
_server.close(hdl, ret.closeCode, ret.closeReason, errorCode); _server.close(hdl, ret.closeCode, ret.closeReason, errorCode);
session.DelRef();
return; return;
} }
@ -415,7 +398,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
std::string messageMsgPack(msgPackData.begin(), msgPackData.end()); std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode); _server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
} }
session.IncrementOutgoingMessages(); session->IncrementOutgoingMessages();
if (_debugEnabled) if (_debugEnabled)
blog(LOG_INFO, "[WebSocketServer::onMessage] Outgoing message:\n%s", ret.result.dump(2).c_str()); blog(LOG_INFO, "[WebSocketServer::onMessage] Outgoing message:\n%s", ret.result.dump(2).c_str());
@ -423,6 +406,5 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (errorCode) if (errorCode)
blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str()); blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str());
} }
session.DelRef();
}); });
} }

View File

@ -4,6 +4,7 @@
#include <QThreadPool> #include <QThreadPool>
#include <QString> #include <QString>
#include <mutex> #include <mutex>
#include <memory>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <websocketpp/config/asio_no_tls.hpp> #include <websocketpp/config/asio_no_tls.hpp>
@ -12,6 +13,7 @@
#include "WebSocketSession.h" #include "WebSocketSession.h"
using json = nlohmann::json; using json = nlohmann::json;
typedef std::shared_ptr<WebSocketSession> SessionPtr;
class WebSocketServer : QObject class WebSocketServer : QObject
{ {
@ -100,7 +102,7 @@ class WebSocketServer : QObject
websocketpp::server<websocketpp::config::asio> _server; websocketpp::server<websocketpp::config::asio> _server;
QThreadPool _threadPool; QThreadPool _threadPool;
std::mutex _sessionMutex; std::mutex _sessionMutex;
std::map<websocketpp::connection_hdl, WebSocketSession, std::owner_less<websocketpp::connection_hdl>> _sessions; std::map<websocketpp::connection_hdl, SessionPtr, std::owner_less<websocketpp::connection_hdl>> _sessions;
uint16_t _serverPort; uint16_t _serverPort;
QString _serverPassword; QString _serverPassword;
bool _debugEnabled; bool _debugEnabled;

View File

@ -5,8 +5,6 @@
#include "plugin-macros.generated.h" #include "plugin-macros.generated.h"
WebSocketSession::WebSocketSession() : WebSocketSession::WebSocketSession() :
_ref(0),
_deleted(false),
_remoteAddress(""), _remoteAddress(""),
_connectedAt(0), _connectedAt(0),
_incomingMessages(0), _incomingMessages(0),
@ -21,31 +19,6 @@ WebSocketSession::WebSocketSession() :
{ {
} }
bool WebSocketSession::AddRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_deleted)
return false;
_ref++;
return true;
}
void WebSocketSession::DelRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_ref == 0) {
blog(LOG_ERROR, "[WebSocketSession::DelRef] Failed to de-increment ref - already 0");
return;
}
_ref--;
}
void WebSocketSession::SetDeleted()
{
std::lock_guard<std::mutex> lock(_refMutex);
_deleted = true;
}
std::string WebSocketSession::RemoteAddress() std::string WebSocketSession::RemoteAddress()
{ {
std::lock_guard<std::mutex> lock(_remoteAddressMutex); std::lock_guard<std::mutex> lock(_remoteAddressMutex);

View File

@ -9,10 +9,6 @@ class WebSocketSession
public: public:
WebSocketSession(); WebSocketSession();
bool AddRef();
void DelRef();
void SetDeleted();
std::string RemoteAddress(); std::string RemoteAddress();
void SetRemoteAddress(std::string address); void SetRemoteAddress(std::string address);
@ -47,9 +43,6 @@ class WebSocketSession
void SetEventSubscriptions(uint64_t subscriptions); void SetEventSubscriptions(uint64_t subscriptions);
private: private:
std::mutex _refMutex;
uint64_t _ref;
bool _deleted;
std::mutex _remoteAddressMutex; std::mutex _remoteAddressMutex;
std::string _remoteAddress; std::string _remoteAddress;
std::atomic<uint64_t> _connectedAt; std::atomic<uint64_t> _connectedAt;