base: Use shared_ptr instead of explicit ref counts

Took a night of sleep but I realized how I could solve the
concurrency issues in a good way. Uses shared_ptr, where the map
always accounts for one reference to a session.
This commit is contained in:
tt2468 2021-04-30 08:45:34 -07:00
parent 32758198ab
commit 4be9b995fb
6 changed files with 42 additions and 92 deletions

View File

@ -4,14 +4,14 @@
#include "plugin-macros.generated.h"
WebSocketProtocol::ProcessResult SetSessionParameters(WebSocketSession *session, json incomingMessage)
WebSocketProtocol::ProcessResult SetSessionParameters(SessionPtr session, json incomingMessage)
{
WebSocketProtocol::ProcessResult ret;
return ret;
}
WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage)
WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(SessionPtr session, json incomingMessage)
{
WebSocketProtocol::ProcessResult ret;

View File

@ -14,5 +14,5 @@ namespace WebSocketProtocol {
json result;
};
ProcessResult ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage);
ProcessResult ProcessMessage(SessionPtr session, json incomingMessage);
}

View File

@ -176,13 +176,10 @@ std::vector<WebSocketServer::WebSocketSessionState> WebSocketServer::GetWebSocke
std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & [hdl, session] : _sessions) {
if (!session.AddRef())
continue;
uint64_t connectedAt = session.ConnectedAt();
uint64_t incomingMessages = session.IncomingMessages();
uint64_t outgoingMessages = session.OutgoingMessages();
std::string remoteAddress = session.RemoteAddress();
session.DelRef();
uint64_t connectedAt = session->ConnectedAt();
uint64_t incomingMessages = session->IncomingMessages();
uint64_t outgoingMessages = session->OutgoingMessages();
std::string remoteAddress = session->RemoteAddress();
webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages});
}
@ -219,21 +216,18 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
// Recurse connected sessions and send the event to suitable sessions.
std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & it : _sessions) {
if (!it.second.AddRef())
continue;
if (!it.second.IsIdentified()) {
it.second.DelRef();
if (!it.second->IsIdentified()) {
continue;
}
if ((it.second.EventSubscriptions() & requiredIntent) != 0) {
if ((it.second->EventSubscriptions() & requiredIntent) != 0) {
websocketpp::lib::error_code errorCode;
switch (it.second.Encoding()) {
switch (it.second->Encoding()) {
case WebSocketEncoding::Json:
if (messageJson.empty()) {
messageJson = eventMessage.dump();
}
_server.send((websocketpp::connection_hdl)it.first, messageJson, websocketpp::frame::opcode::text, errorCode);
it.second.IncrementOutgoingMessages();
it.second->IncrementOutgoingMessages();
break;
case WebSocketEncoding::MsgPack:
if (messageMsgPack.empty()) {
@ -241,11 +235,10 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
messageMsgPack = std::string(msgPackData.begin(), msgPackData.end());
}
_server.send((websocketpp::connection_hdl)it.first, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
it.second.IncrementOutgoingMessages();
it.second->IncrementOutgoingMessages();
break;
}
}
it.second.DelRef();
}
lock.unlock();
if (_debugEnabled)
@ -259,25 +252,21 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// Build new session
std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl];
SessionPtr session = _sessions[hdl] = std::make_shared<WebSocketSession>();
lock.unlock();
if (!session.AddRef())
return;
// Configure session details
session.SetRemoteAddress(conn->get_remote_endpoint());
session.SetConnectedAt(QDateTime::currentSecsSinceEpoch());
session->SetRemoteAddress(conn->get_remote_endpoint());
session->SetConnectedAt(QDateTime::currentSecsSinceEpoch());
std::string contentType = conn->get_request_header("Content-Type");
if (contentType == "") {
;
} else if (contentType == "application/json") {
session.SetEncoding(WebSocketEncoding::Json);
session->SetEncoding(WebSocketEncoding::Json);
} else if (contentType == "application/msgpack") {
session.SetEncoding(WebSocketEncoding::MsgPack);
session->SetEncoding(WebSocketEncoding::MsgPack);
} else {
conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type.");
session.DelRef();
return;
}
@ -289,7 +278,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// todo: Add request and event lists
if (AuthenticationRequired) {
std::string sessionChallenge = Utils::Crypto::GenerateSalt();
session.SetChallenge(sessionChallenge);
session->SetChallenge(sessionChallenge);
helloMessage["authentication"] = {};
helloMessage["authentication"]["challenge"] = sessionChallenge;
helloMessage["authentication"]["salt"] = AuthenticationSalt;
@ -297,7 +286,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
// Send object to client
websocketpp::lib::error_code errorCode;
auto sessionEncoding = session.Encoding();
auto sessionEncoding = session->Encoding();
if (sessionEncoding == WebSocketEncoding::Json) {
std::string helloMessageJson = helloMessage.dump();
_server.send(hdl, helloMessageJson, websocketpp::frame::opcode::text, errorCode);
@ -306,8 +295,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
}
session.IncrementOutgoingMessages();
session.DelRef();
session->IncrementOutgoingMessages();
}
void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
@ -316,15 +304,12 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
// Get info from the session and then delete it
std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl];
session.AddRef();
bool isIdentified = session.IsIdentified();
uint64_t connectedAt = session.ConnectedAt();
uint64_t incomingMessages = session.IncomingMessages();
uint64_t outgoingMessages = session.OutgoingMessages();
std::string remoteAddress = session.RemoteAddress();
session.SetDeleted();
session.DelRef();
SessionPtr session = _sessions[hdl];
bool isIdentified = session->IsIdentified();
uint64_t connectedAt = session->ConnectedAt();
uint64_t incomingMessages = session->IncomingMessages();
uint64_t outgoingMessages = session->OutgoingMessages();
std::string remoteAddress = session->RemoteAddress();
_sessions.erase(hdl);
lock.unlock();
@ -347,48 +332,47 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
std::string payload = message->get_payload();
QtConcurrent::run(&_threadPool, [=]() {
std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl];
if (!session.AddRef())
SessionPtr session;
try {
session = _sessions.at(hdl);
} catch (const std::out_of_range& oor) {
return;
}
lock.unlock();
session.IncrementIncomingMessages();
session->IncrementIncomingMessages();
json incomingMessage;
// Check for invalid opcode and decode
websocketpp::lib::error_code errorCode;
uint8_t sessionEncoding = session.Encoding();
uint8_t sessionEncoding = session->Encoding();
if (sessionEncoding == WebSocketEncoding::Json) {
if (opcode != websocketpp::frame::opcode::text) {
if (!session.IgnoreInvalidMessages())
if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode);
session.DelRef();
return;
}
try {
incomingMessage = json::parse(payload);
} catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages())
if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode);
session.DelRef();
return;
}
} else if (sessionEncoding == WebSocketEncoding::MsgPack) {
if (opcode != websocketpp::frame::opcode::binary) {
if (!session.IgnoreInvalidMessages())
if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode);
session.DelRef();
return;
}
try {
incomingMessage = json::from_msgpack(payload);
} catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages())
if (!session->IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode);
session.DelRef();
return;
}
}
@ -396,12 +380,11 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (_debugEnabled)
blog(LOG_INFO, "[WebSocketServer::onMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str());
WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(hdl, &session, incomingMessage);
WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(session, incomingMessage);
if (ret.closeCode != WebSocketCloseCode::DontClose) {
websocketpp::lib::error_code errorCode;
_server.close(hdl, ret.closeCode, ret.closeReason, errorCode);
session.DelRef();
return;
}
@ -415,7 +398,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
}
session.IncrementOutgoingMessages();
session->IncrementOutgoingMessages();
if (_debugEnabled)
blog(LOG_INFO, "[WebSocketServer::onMessage] Outgoing message:\n%s", ret.result.dump(2).c_str());
@ -423,6 +406,5 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (errorCode)
blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str());
}
session.DelRef();
});
}

View File

@ -4,6 +4,7 @@
#include <QThreadPool>
#include <QString>
#include <mutex>
#include <memory>
#include <nlohmann/json.hpp>
#include <websocketpp/config/asio_no_tls.hpp>
@ -12,6 +13,7 @@
#include "WebSocketSession.h"
using json = nlohmann::json;
typedef std::shared_ptr<WebSocketSession> SessionPtr;
class WebSocketServer : QObject
{
@ -100,7 +102,7 @@ class WebSocketServer : QObject
websocketpp::server<websocketpp::config::asio> _server;
QThreadPool _threadPool;
std::mutex _sessionMutex;
std::map<websocketpp::connection_hdl, WebSocketSession, std::owner_less<websocketpp::connection_hdl>> _sessions;
std::map<websocketpp::connection_hdl, SessionPtr, std::owner_less<websocketpp::connection_hdl>> _sessions;
uint16_t _serverPort;
QString _serverPassword;
bool _debugEnabled;

View File

@ -5,8 +5,6 @@
#include "plugin-macros.generated.h"
WebSocketSession::WebSocketSession() :
_ref(0),
_deleted(false),
_remoteAddress(""),
_connectedAt(0),
_incomingMessages(0),
@ -21,31 +19,6 @@ WebSocketSession::WebSocketSession() :
{
}
bool WebSocketSession::AddRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_deleted)
return false;
_ref++;
return true;
}
void WebSocketSession::DelRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_ref == 0) {
blog(LOG_ERROR, "[WebSocketSession::DelRef] Failed to de-increment ref - already 0");
return;
}
_ref--;
}
void WebSocketSession::SetDeleted()
{
std::lock_guard<std::mutex> lock(_refMutex);
_deleted = true;
}
std::string WebSocketSession::RemoteAddress()
{
std::lock_guard<std::mutex> lock(_remoteAddressMutex);

View File

@ -9,10 +9,6 @@ class WebSocketSession
public:
WebSocketSession();
bool AddRef();
void DelRef();
void SetDeleted();
std::string RemoteAddress();
void SetRemoteAddress(std::string address);
@ -47,9 +43,6 @@ class WebSocketSession
void SetEventSubscriptions(uint64_t subscriptions);
private:
std::mutex _refMutex;
uint64_t _ref;
bool _deleted;
std::mutex _remoteAddressMutex;
std::string _remoteAddress;
std::atomic<uint64_t> _connectedAt;