mirror of
https://github.com/Palakis/obs-websocket.git
synced 2024-08-30 18:12:16 +00:00
base: Use shared_ptr instead of explicit ref counts
Took a night of sleep but I realized how I could solve the concurrency issues in a good way. Uses shared_ptr, where the map always accounts for one reference to a session.
This commit is contained in:
parent
32758198ab
commit
4be9b995fb
@ -4,14 +4,14 @@
|
||||
|
||||
#include "plugin-macros.generated.h"
|
||||
|
||||
WebSocketProtocol::ProcessResult SetSessionParameters(WebSocketSession *session, json incomingMessage)
|
||||
WebSocketProtocol::ProcessResult SetSessionParameters(SessionPtr session, json incomingMessage)
|
||||
{
|
||||
WebSocketProtocol::ProcessResult ret;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage)
|
||||
WebSocketProtocol::ProcessResult WebSocketProtocol::ProcessMessage(SessionPtr session, json incomingMessage)
|
||||
{
|
||||
WebSocketProtocol::ProcessResult ret;
|
||||
|
||||
|
@ -14,5 +14,5 @@ namespace WebSocketProtocol {
|
||||
json result;
|
||||
};
|
||||
|
||||
ProcessResult ProcessMessage(websocketpp::connection_hdl hdl, WebSocketSession *session, json incomingMessage);
|
||||
ProcessResult ProcessMessage(SessionPtr session, json incomingMessage);
|
||||
}
|
||||
|
@ -176,13 +176,10 @@ std::vector<WebSocketServer::WebSocketSessionState> WebSocketServer::GetWebSocke
|
||||
|
||||
std::unique_lock<std::mutex> lock(_sessionMutex);
|
||||
for (auto & [hdl, session] : _sessions) {
|
||||
if (!session.AddRef())
|
||||
continue;
|
||||
uint64_t connectedAt = session.ConnectedAt();
|
||||
uint64_t incomingMessages = session.IncomingMessages();
|
||||
uint64_t outgoingMessages = session.OutgoingMessages();
|
||||
std::string remoteAddress = session.RemoteAddress();
|
||||
session.DelRef();
|
||||
uint64_t connectedAt = session->ConnectedAt();
|
||||
uint64_t incomingMessages = session->IncomingMessages();
|
||||
uint64_t outgoingMessages = session->OutgoingMessages();
|
||||
std::string remoteAddress = session->RemoteAddress();
|
||||
|
||||
webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages});
|
||||
}
|
||||
@ -219,21 +216,18 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
|
||||
// Recurse connected sessions and send the event to suitable sessions.
|
||||
std::unique_lock<std::mutex> lock(_sessionMutex);
|
||||
for (auto & it : _sessions) {
|
||||
if (!it.second.AddRef())
|
||||
continue;
|
||||
if (!it.second.IsIdentified()) {
|
||||
it.second.DelRef();
|
||||
if (!it.second->IsIdentified()) {
|
||||
continue;
|
||||
}
|
||||
if ((it.second.EventSubscriptions() & requiredIntent) != 0) {
|
||||
if ((it.second->EventSubscriptions() & requiredIntent) != 0) {
|
||||
websocketpp::lib::error_code errorCode;
|
||||
switch (it.second.Encoding()) {
|
||||
switch (it.second->Encoding()) {
|
||||
case WebSocketEncoding::Json:
|
||||
if (messageJson.empty()) {
|
||||
messageJson = eventMessage.dump();
|
||||
}
|
||||
_server.send((websocketpp::connection_hdl)it.first, messageJson, websocketpp::frame::opcode::text, errorCode);
|
||||
it.second.IncrementOutgoingMessages();
|
||||
it.second->IncrementOutgoingMessages();
|
||||
break;
|
||||
case WebSocketEncoding::MsgPack:
|
||||
if (messageMsgPack.empty()) {
|
||||
@ -241,11 +235,10 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
|
||||
messageMsgPack = std::string(msgPackData.begin(), msgPackData.end());
|
||||
}
|
||||
_server.send((websocketpp::connection_hdl)it.first, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
|
||||
it.second.IncrementOutgoingMessages();
|
||||
it.second->IncrementOutgoingMessages();
|
||||
break;
|
||||
}
|
||||
}
|
||||
it.second.DelRef();
|
||||
}
|
||||
lock.unlock();
|
||||
if (_debugEnabled)
|
||||
@ -259,25 +252,21 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
|
||||
|
||||
// Build new session
|
||||
std::unique_lock<std::mutex> lock(_sessionMutex);
|
||||
auto &session = _sessions[hdl];
|
||||
SessionPtr session = _sessions[hdl] = std::make_shared<WebSocketSession>();
|
||||
lock.unlock();
|
||||
|
||||
if (!session.AddRef())
|
||||
return;
|
||||
|
||||
// Configure session details
|
||||
session.SetRemoteAddress(conn->get_remote_endpoint());
|
||||
session.SetConnectedAt(QDateTime::currentSecsSinceEpoch());
|
||||
session->SetRemoteAddress(conn->get_remote_endpoint());
|
||||
session->SetConnectedAt(QDateTime::currentSecsSinceEpoch());
|
||||
std::string contentType = conn->get_request_header("Content-Type");
|
||||
if (contentType == "") {
|
||||
;
|
||||
} else if (contentType == "application/json") {
|
||||
session.SetEncoding(WebSocketEncoding::Json);
|
||||
session->SetEncoding(WebSocketEncoding::Json);
|
||||
} else if (contentType == "application/msgpack") {
|
||||
session.SetEncoding(WebSocketEncoding::MsgPack);
|
||||
session->SetEncoding(WebSocketEncoding::MsgPack);
|
||||
} else {
|
||||
conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type.");
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -289,7 +278,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
|
||||
// todo: Add request and event lists
|
||||
if (AuthenticationRequired) {
|
||||
std::string sessionChallenge = Utils::Crypto::GenerateSalt();
|
||||
session.SetChallenge(sessionChallenge);
|
||||
session->SetChallenge(sessionChallenge);
|
||||
helloMessage["authentication"] = {};
|
||||
helloMessage["authentication"]["challenge"] = sessionChallenge;
|
||||
helloMessage["authentication"]["salt"] = AuthenticationSalt;
|
||||
@ -297,7 +286,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
|
||||
|
||||
// Send object to client
|
||||
websocketpp::lib::error_code errorCode;
|
||||
auto sessionEncoding = session.Encoding();
|
||||
auto sessionEncoding = session->Encoding();
|
||||
if (sessionEncoding == WebSocketEncoding::Json) {
|
||||
std::string helloMessageJson = helloMessage.dump();
|
||||
_server.send(hdl, helloMessageJson, websocketpp::frame::opcode::text, errorCode);
|
||||
@ -306,8 +295,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
|
||||
std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
|
||||
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
|
||||
}
|
||||
session.IncrementOutgoingMessages();
|
||||
session.DelRef();
|
||||
session->IncrementOutgoingMessages();
|
||||
}
|
||||
|
||||
void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
|
||||
@ -316,15 +304,12 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
|
||||
|
||||
// Get info from the session and then delete it
|
||||
std::unique_lock<std::mutex> lock(_sessionMutex);
|
||||
auto &session = _sessions[hdl];
|
||||
session.AddRef();
|
||||
bool isIdentified = session.IsIdentified();
|
||||
uint64_t connectedAt = session.ConnectedAt();
|
||||
uint64_t incomingMessages = session.IncomingMessages();
|
||||
uint64_t outgoingMessages = session.OutgoingMessages();
|
||||
std::string remoteAddress = session.RemoteAddress();
|
||||
session.SetDeleted();
|
||||
session.DelRef();
|
||||
SessionPtr session = _sessions[hdl];
|
||||
bool isIdentified = session->IsIdentified();
|
||||
uint64_t connectedAt = session->ConnectedAt();
|
||||
uint64_t incomingMessages = session->IncomingMessages();
|
||||
uint64_t outgoingMessages = session->OutgoingMessages();
|
||||
std::string remoteAddress = session->RemoteAddress();
|
||||
_sessions.erase(hdl);
|
||||
lock.unlock();
|
||||
|
||||
@ -347,48 +332,47 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
|
||||
std::string payload = message->get_payload();
|
||||
QtConcurrent::run(&_threadPool, [=]() {
|
||||
std::unique_lock<std::mutex> lock(_sessionMutex);
|
||||
auto &session = _sessions[hdl];
|
||||
if (!session.AddRef())
|
||||
SessionPtr session;
|
||||
try {
|
||||
session = _sessions.at(hdl);
|
||||
} catch (const std::out_of_range& oor) {
|
||||
return;
|
||||
}
|
||||
lock.unlock();
|
||||
|
||||
session.IncrementIncomingMessages();
|
||||
session->IncrementIncomingMessages();
|
||||
|
||||
json incomingMessage;
|
||||
|
||||
// Check for invalid opcode and decode
|
||||
websocketpp::lib::error_code errorCode;
|
||||
uint8_t sessionEncoding = session.Encoding();
|
||||
uint8_t sessionEncoding = session->Encoding();
|
||||
if (sessionEncoding == WebSocketEncoding::Json) {
|
||||
if (opcode != websocketpp::frame::opcode::text) {
|
||||
if (!session.IgnoreInvalidMessages())
|
||||
if (!session->IgnoreInvalidMessages())
|
||||
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode);
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
incomingMessage = json::parse(payload);
|
||||
} catch (json::parse_error& e) {
|
||||
if (!session.IgnoreInvalidMessages())
|
||||
if (!session->IgnoreInvalidMessages())
|
||||
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode);
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
} else if (sessionEncoding == WebSocketEncoding::MsgPack) {
|
||||
if (opcode != websocketpp::frame::opcode::binary) {
|
||||
if (!session.IgnoreInvalidMessages())
|
||||
if (!session->IgnoreInvalidMessages())
|
||||
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode);
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
incomingMessage = json::from_msgpack(payload);
|
||||
} catch (json::parse_error& e) {
|
||||
if (!session.IgnoreInvalidMessages())
|
||||
if (!session->IgnoreInvalidMessages())
|
||||
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode);
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -396,12 +380,11 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
|
||||
if (_debugEnabled)
|
||||
blog(LOG_INFO, "[WebSocketServer::onMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str());
|
||||
|
||||
WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(hdl, &session, incomingMessage);
|
||||
WebSocketProtocol::ProcessResult ret = WebSocketProtocol::ProcessMessage(session, incomingMessage);
|
||||
|
||||
if (ret.closeCode != WebSocketCloseCode::DontClose) {
|
||||
websocketpp::lib::error_code errorCode;
|
||||
_server.close(hdl, ret.closeCode, ret.closeReason, errorCode);
|
||||
session.DelRef();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -415,7 +398,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
|
||||
std::string messageMsgPack(msgPackData.begin(), msgPackData.end());
|
||||
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
|
||||
}
|
||||
session.IncrementOutgoingMessages();
|
||||
session->IncrementOutgoingMessages();
|
||||
|
||||
if (_debugEnabled)
|
||||
blog(LOG_INFO, "[WebSocketServer::onMessage] Outgoing message:\n%s", ret.result.dump(2).c_str());
|
||||
@ -423,6 +406,5 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
|
||||
if (errorCode)
|
||||
blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str());
|
||||
}
|
||||
session.DelRef();
|
||||
});
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <QThreadPool>
|
||||
#include <QString>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <websocketpp/config/asio_no_tls.hpp>
|
||||
@ -12,6 +13,7 @@
|
||||
#include "WebSocketSession.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
typedef std::shared_ptr<WebSocketSession> SessionPtr;
|
||||
|
||||
class WebSocketServer : QObject
|
||||
{
|
||||
@ -100,7 +102,7 @@ class WebSocketServer : QObject
|
||||
websocketpp::server<websocketpp::config::asio> _server;
|
||||
QThreadPool _threadPool;
|
||||
std::mutex _sessionMutex;
|
||||
std::map<websocketpp::connection_hdl, WebSocketSession, std::owner_less<websocketpp::connection_hdl>> _sessions;
|
||||
std::map<websocketpp::connection_hdl, SessionPtr, std::owner_less<websocketpp::connection_hdl>> _sessions;
|
||||
uint16_t _serverPort;
|
||||
QString _serverPassword;
|
||||
bool _debugEnabled;
|
||||
|
@ -5,8 +5,6 @@
|
||||
#include "plugin-macros.generated.h"
|
||||
|
||||
WebSocketSession::WebSocketSession() :
|
||||
_ref(0),
|
||||
_deleted(false),
|
||||
_remoteAddress(""),
|
||||
_connectedAt(0),
|
||||
_incomingMessages(0),
|
||||
@ -21,31 +19,6 @@ WebSocketSession::WebSocketSession() :
|
||||
{
|
||||
}
|
||||
|
||||
bool WebSocketSession::AddRef()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(_refMutex);
|
||||
if (_deleted)
|
||||
return false;
|
||||
_ref++;
|
||||
return true;
|
||||
}
|
||||
|
||||
void WebSocketSession::DelRef()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(_refMutex);
|
||||
if (_ref == 0) {
|
||||
blog(LOG_ERROR, "[WebSocketSession::DelRef] Failed to de-increment ref - already 0");
|
||||
return;
|
||||
}
|
||||
_ref--;
|
||||
}
|
||||
|
||||
void WebSocketSession::SetDeleted()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(_refMutex);
|
||||
_deleted = true;
|
||||
}
|
||||
|
||||
std::string WebSocketSession::RemoteAddress()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(_remoteAddressMutex);
|
||||
|
@ -9,10 +9,6 @@ class WebSocketSession
|
||||
public:
|
||||
WebSocketSession();
|
||||
|
||||
bool AddRef();
|
||||
void DelRef();
|
||||
void SetDeleted();
|
||||
|
||||
std::string RemoteAddress();
|
||||
void SetRemoteAddress(std::string address);
|
||||
|
||||
@ -47,9 +43,6 @@ class WebSocketSession
|
||||
void SetEventSubscriptions(uint64_t subscriptions);
|
||||
|
||||
private:
|
||||
std::mutex _refMutex;
|
||||
uint64_t _ref;
|
||||
bool _deleted;
|
||||
std::mutex _remoteAddressMutex;
|
||||
std::string _remoteAddress;
|
||||
std::atomic<uint64_t> _connectedAt;
|
||||
|
Loading…
Reference in New Issue
Block a user