WebSocketSession: Add refcount

Working towards fixing concurrency issues. Todo:

- Wait for refcount to be 0 before deleting object
- Use .at() instead of operator[] to prevent recreating deleted
sessions
- There was a third thing. Dont remember what it was
This commit is contained in:
tt2468 2021-04-29 22:11:24 -07:00
parent e151a9a8db
commit 32758198ab
3 changed files with 61 additions and 1 deletions

View File

@ -176,10 +176,13 @@ std::vector<WebSocketServer::WebSocketSessionState> WebSocketServer::GetWebSocke
std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & [hdl, session] : _sessions) {
if (!session.AddRef())
continue;
uint64_t connectedAt = session.ConnectedAt();
uint64_t incomingMessages = session.IncomingMessages();
uint64_t outgoingMessages = session.OutgoingMessages();
std::string remoteAddress = session.RemoteAddress();
session.DelRef();
webSocketSessions.emplace_back(WebSocketSessionState{hdl, remoteAddress, connectedAt, incomingMessages, outgoingMessages});
}
@ -216,8 +219,12 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
// Recurse connected sessions and send the event to suitable sessions.
std::unique_lock<std::mutex> lock(_sessionMutex);
for (auto & it : _sessions) {
if (!it.second.IsIdentified())
if (!it.second.AddRef())
continue;
if (!it.second.IsIdentified()) {
it.second.DelRef();
continue;
}
if ((it.second.EventSubscriptions() & requiredIntent) != 0) {
websocketpp::lib::error_code errorCode;
switch (it.second.Encoding()) {
@ -238,6 +245,7 @@ void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, std::string eventT
break;
}
}
it.second.DelRef();
}
lock.unlock();
if (_debugEnabled)
@ -254,6 +262,9 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
auto &session = _sessions[hdl];
lock.unlock();
if (!session.AddRef())
return;
// Configure session details
session.SetRemoteAddress(conn->get_remote_endpoint());
session.SetConnectedAt(QDateTime::currentSecsSinceEpoch());
@ -266,6 +277,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
session.SetEncoding(WebSocketEncoding::MsgPack);
} else {
conn->close(WebSocketCloseCode::InvalidContentType, "Your HTTP `Content-Type` header specifies an invalid encoding type.");
session.DelRef();
return;
}
@ -295,6 +307,7 @@ void WebSocketServer::onOpen(websocketpp::connection_hdl hdl)
_server.send(hdl, messageMsgPack, websocketpp::frame::opcode::binary, errorCode);
}
session.IncrementOutgoingMessages();
session.DelRef();
}
void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
@ -304,11 +317,14 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl)
// Get info from the session and then delete it
std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl];
session.AddRef();
bool isIdentified = session.IsIdentified();
uint64_t connectedAt = session.ConnectedAt();
uint64_t incomingMessages = session.IncomingMessages();
uint64_t outgoingMessages = session.OutgoingMessages();
std::string remoteAddress = session.RemoteAddress();
session.SetDeleted();
session.DelRef();
_sessions.erase(hdl);
lock.unlock();
@ -332,6 +348,8 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
QtConcurrent::run(&_threadPool, [=]() {
std::unique_lock<std::mutex> lock(_sessionMutex);
auto &session = _sessions[hdl];
if (!session.AddRef())
return;
lock.unlock();
session.IncrementIncomingMessages();
@ -345,6 +363,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (opcode != websocketpp::frame::opcode::text) {
if (!session.IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode);
session.DelRef();
return;
}
@ -353,12 +372,14 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
} catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode Json: ") + e.what(), errorCode);
session.DelRef();
return;
}
} else if (sessionEncoding == WebSocketEncoding::MsgPack) {
if (opcode != websocketpp::frame::opcode::binary) {
if (!session.IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode);
session.DelRef();
return;
}
@ -367,6 +388,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
} catch (json::parse_error& e) {
if (!session.IgnoreInvalidMessages())
_server.close(hdl, WebSocketCloseCode::MessageDecodeError, std::string("Unable to decode MsgPack: ") + e.what(), errorCode);
session.DelRef();
return;
}
}
@ -379,6 +401,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (ret.closeCode != WebSocketCloseCode::DontClose) {
websocketpp::lib::error_code errorCode;
_server.close(hdl, ret.closeCode, ret.closeReason, errorCode);
session.DelRef();
return;
}
@ -400,5 +423,6 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se
if (errorCode)
blog(LOG_WARNING, "[WebSocketServer::onMessage] Sending message to client failed: %s", errorCode.message().c_str());
}
session.DelRef();
});
}

View File

@ -1,8 +1,12 @@
#include <obs-module.h>
#include "WebSocketSession.h"
#include "plugin-macros.generated.h"
WebSocketSession::WebSocketSession() :
_ref(0),
_deleted(false),
_remoteAddress(""),
_connectedAt(0),
_incomingMessages(0),
@ -17,6 +21,31 @@ WebSocketSession::WebSocketSession() :
{
}
bool WebSocketSession::AddRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_deleted)
return false;
_ref++;
return true;
}
void WebSocketSession::DelRef()
{
std::lock_guard<std::mutex> lock(_refMutex);
if (_ref == 0) {
blog(LOG_ERROR, "[WebSocketSession::DelRef] Failed to de-increment ref - already 0");
return;
}
_ref--;
}
void WebSocketSession::SetDeleted()
{
std::lock_guard<std::mutex> lock(_refMutex);
_deleted = true;
}
std::string WebSocketSession::RemoteAddress()
{
std::lock_guard<std::mutex> lock(_remoteAddressMutex);

View File

@ -9,6 +9,10 @@ class WebSocketSession
public:
WebSocketSession();
bool AddRef();
void DelRef();
void SetDeleted();
std::string RemoteAddress();
void SetRemoteAddress(std::string address);
@ -43,6 +47,9 @@ class WebSocketSession
void SetEventSubscriptions(uint64_t subscriptions);
private:
std::mutex _refMutex;
uint64_t _ref;
bool _deleted;
std::mutex _remoteAddressMutex;
std::string _remoteAddress;
std::atomic<uint64_t> _connectedAt;