diff --git a/src/WSRequestHandler.cpp b/src/WSRequestHandler.cpp index 5535d3c0..a414de69 100644 --- a/src/WSRequestHandler.cpp +++ b/src/WSRequestHandler.cpp @@ -128,10 +128,11 @@ QSet WSRequestHandler::authNotRequired { "Authenticate" }; -WSRequestHandler::WSRequestHandler() : +WSRequestHandler::WSRequestHandler(QVariantHash* connProperties) : _messageId(0), _requestType(""), - data(nullptr) + data(nullptr), + _connProperties(connProperties) { } @@ -160,13 +161,13 @@ std::string WSRequestHandler::processIncomingMessage(std::string& textMessage) { _requestType = obs_data_get_string(data, "request-type"); _messageId = obs_data_get_string(data, "message-id"); - // if (Config::Current()->AuthRequired - // && (_client->property(PROP_AUTHENTICATED).toBool() == false) - // && (authNotRequired.find(_requestType) == authNotRequired.end())) - // { - // SendErrorResponse("Not Authenticated"); - // return; - // } + if (Config::Current()->AuthRequired + && (!authNotRequired.contains(_requestType)) + && (_connProperties->value(PROP_AUTHENTICATED).toBool() == false)) + { + SendErrorResponse("Not Authenticated"); + return _response; + } void (*handlerFunc)(WSRequestHandler*) = (messageMap[_requestType]); diff --git a/src/WSRequestHandler.h b/src/WSRequestHandler.h index 8d69693f..86182056 100644 --- a/src/WSRequestHandler.h +++ b/src/WSRequestHandler.h @@ -22,6 +22,7 @@ with this program. If not, see #include #include +#include #include #include @@ -32,7 +33,7 @@ class WSRequestHandler : public QObject { Q_OBJECT public: - explicit WSRequestHandler(); + explicit WSRequestHandler(QVariantHash* connProperties); ~WSRequestHandler(); std::string processIncomingMessage(std::string& textMessage); bool hasField(QString name); @@ -41,6 +42,7 @@ class WSRequestHandler : public QObject { const char* _messageId; const char* _requestType; std::string _response; + QVariantHash* _connProperties; OBSDataAutoRelease data; void SendOKResponse(obs_data_t* additionalFields = NULL); diff --git a/src/WSRequestHandler_General.cpp b/src/WSRequestHandler_General.cpp index efe0b9db..76f1ca01 100644 --- a/src/WSRequestHandler_General.cpp +++ b/src/WSRequestHandler_General.cpp @@ -85,20 +85,24 @@ void WSRequestHandler::HandleAuthenticate(WSRequestHandler* req) { return; } + if (req->_connProperties->value(PROP_AUTHENTICATED).toBool() == true) { + req->SendErrorResponse("already authenticated"); + return; + } + QString auth = obs_data_get_string(req->data, "auth"); if (auth.isEmpty()) { req->SendErrorResponse("auth not specified!"); return; } - // if ((req->_client->property(PROP_AUTHENTICATED).toBool() == false) - // && Config::Current()->CheckAuth(auth)) - // { - // req->_client->setProperty(PROP_AUTHENTICATED, true); - // req->SendOKResponse(); - // } else { - // req->SendErrorResponse("Authentication Failed."); - // } + if (Config::Current()->CheckAuth(auth) == false) { + req->SendErrorResponse("Authentication Failed."); + return; + } + + req->_connProperties->insert(PROP_AUTHENTICATED, true); + req->SendOKResponse(); } /** diff --git a/src/WSServer.cpp b/src/WSServer.cpp index fc532e9c..d15e39fc 100644 --- a/src/WSServer.cpp +++ b/src/WSServer.cpp @@ -118,11 +118,15 @@ void WSServer::onMessage(connection_hdl hdl, server::message_ptr message) return; } + QVariantHash connProperties = _connectionProperties[hdl]; + std::string payload = message->get_payload(); - WSRequestHandler handler; + WSRequestHandler handler(&connProperties); std::string response = handler.processIncomingMessage(payload); + _connectionProperties[hdl] = connProperties; + _server.send(hdl, response, websocketpp::frame::opcode::text); } @@ -130,6 +134,7 @@ void WSServer::onClose(connection_hdl hdl) { QMutexLocker locker(&_clMutex); _connections.erase(hdl); + _connectionProperties.erase(hdl); locker.unlock(); QString clientIp = getRemoteEndpoint(hdl); diff --git a/src/WSServer.h b/src/WSServer.h index d0f79541..8c243700 100644 --- a/src/WSServer.h +++ b/src/WSServer.h @@ -21,8 +21,11 @@ with this program. If not, see #include #include +#include +#include #include + #include #include @@ -59,6 +62,7 @@ private: server _server; quint16 _serverPort; std::set> _connections; + std::map> _connectionProperties; QMutex _clMutex; };