diff --git a/CMakeLists.txt b/CMakeLists.txt index 16d495d3..f4288218 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,6 +87,7 @@ set(obs-websocket_SOURCES src/Config.cpp src/WebSocketServer.cpp src/WebSocketServer_Protocol.cpp + src/WebSocketServer_RequestBatchProcessing.cpp src/WebSocketSession.cpp src/eventhandler/EventHandler.cpp src/eventhandler/EventHandler_General.cpp diff --git a/src/WebSocketServer.cpp b/src/WebSocketServer.cpp index 6619da85..1f1d24fd 100644 --- a/src/WebSocketServer.cpp +++ b/src/WebSocketServer.cpp @@ -339,7 +339,7 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl) void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::server::message_ptr message) { - auto opcode = message->get_opcode(); + auto opCode = message->get_opcode(); std::string payload = message->get_payload(); _threadPool.start(Utils::Compat::CreateFunctionRunnable([=]() { std::unique_lock lock(_sessionMutex); @@ -359,7 +359,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se websocketpp::lib::error_code errorCode; uint8_t sessionEncoding = session->Encoding(); if (sessionEncoding == WebSocketEncoding::Json) { - if (opcode != websocketpp::frame::opcode::text) { + if (opCode != websocketpp::frame::opcode::text) { if (!session->IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode); return; @@ -373,7 +373,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se return; } } else if (sessionEncoding == WebSocketEncoding::MsgPack) { - if (opcode != websocketpp::frame::opcode::binary) { + if (opCode != websocketpp::frame::opcode::binary) { if (!session->IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode); return; diff --git a/src/WebSocketServer.h b/src/WebSocketServer.h index d9c82736..3a0e3f5c 100644 --- a/src/WebSocketServer.h +++ b/src/WebSocketServer.h @@ -9,6 +9,7 @@ #include "utils/Json.h" #include "WebSocketSession.h" +#include "requesthandler/rpc/Request.h" #include "plugin-macros.generated.h" class WebSocketServer : QObject @@ -53,6 +54,10 @@ class WebSocketServer : QObject UnsupportedRpcVersion = 4009, // The websocket session has been invalidated by the obs-websocket server. SessionInvalidated = 4010, + // A data key's value is invalid, in the case of things like enums. + InvalidDataKeyValue = 4011, + // A feature is not supported because of hardware/software limitations. + UnsupportedFeature = 4012, }; WebSocketServer(); @@ -98,6 +103,8 @@ class WebSocketServer : QObject void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, json payloadData); void ProcessMessage(SessionPtr session, ProcessResult &ret, const uint8_t opCode, json incomingMessage); + void ProcessRequestBatch(SessionPtr session, ObsWebSocketRequestBatchExecutionType executionType, std::vector &requests, std::vector &results); + std::thread _serverThread; websocketpp::server _server; QThreadPool _threadPool; diff --git a/src/WebSocketServer_Protocol.cpp b/src/WebSocketServer_Protocol.cpp index aa4fea96..5b44e3f2 100644 --- a/src/WebSocketServer_Protocol.cpp +++ b/src/WebSocketServer_Protocol.cpp @@ -214,35 +214,42 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces return; } - std::vector requests = payloadData["requests"]; - json results = json::array(); - - RequestHandler requestHandler(session); - for (auto requestJson : requests) { - Request request(requestJson["requestType"], requestJson["requestData"]); - - RequestResult requestResult = requestHandler.ProcessRequest(request); - - json result; - result["requestType"] = requestJson["requestType"]; - - if (requestJson.contains("requestId")) - result["requestId"] = requestJson["requestId"]; - - result["requestStatus"] = { - {"result", requestResult.StatusCode == RequestStatus::Success}, - {"code", requestResult.StatusCode} - }; - - if (!requestResult.Comment.empty()) - result["requestStatus"]["comment"] = requestResult.Comment; - - if (requestResult.ResponseData.is_object()) - result["responseData"] = requestResult.ResponseData; - - results.push_back(result); + ObsWebSocketRequestBatchExecutionType executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME; + if (payloadData.contains("executionType") && !payloadData["executionType"].is_null()) { + if (!payloadData["executionType"].is_string()) { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::InvalidDataKeyType; + ret.closeReason = "Your `executionType` is not a string."; + } + return; + } + std::string executionTypeString = payloadData["executionType"]; + if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME") { + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME; + } else if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME") { + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME; + } else if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL") { + if (_threadPool.maxThreadCount() < 2) { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::UnsupportedFeature; + ret.closeReason = "Parallel request batch processing is not available on this system due to limited core count."; + } + return; + } + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL; + } else { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::InvalidDataKeyValue; + ret.closeReason = "Your `executionType`'s value is not recognized."; + } + return; + } } + std::vector requests = payloadData["requests"]; + std::vector results; + ProcessRequestBatch(session, executionType, requests, results); + ret.result["op"] = WebSocketOpCode::RequestBatchResponse; ret.result["d"]["requestId"] = payloadData["requestId"]; ret.result["d"]["results"] = results; diff --git a/src/WebSocketServer_RequestBatchProcessing.cpp b/src/WebSocketServer_RequestBatchProcessing.cpp new file mode 100644 index 00000000..338e2b77 --- /dev/null +++ b/src/WebSocketServer_RequestBatchProcessing.cpp @@ -0,0 +1,175 @@ +#include + +#include "WebSocketServer.h" +#include "requesthandler/RequestHandler.h" +#include "obs-websocket.h" +#include "utils/Compat.h" + +struct SerialFrameBatch +{ + RequestHandler *requestHandler; + size_t frameCount; + size_t sleepUntilFrame; + std::queue requests; + std::vector results; + std::mutex conditionMutex; + std::condition_variable condition; + + SerialFrameBatch(RequestHandler *requestHandler) : + requestHandler(requestHandler), + frameCount(0), + sleepUntilFrame(0) + {} +}; + +struct ParallelBatchResults +{ + RequestHandler *requestHandler; + size_t requestCount; + std::mutex resultsMutex; + std::vector results; + std::condition_variable condition; + + ParallelBatchResults(RequestHandler *requestHandler, size_t requestCount) : + requestHandler(requestHandler), + requestCount(requestCount) + {} +}; + +json ConstructRequestResult(RequestResult requestResult, json requestJson) +{ + json ret; + + ret["requestType"] = requestJson["requestType"]; + + if (requestJson.contains("requestId") && !requestJson["requestId"].is_null()) + ret["requestId"] = requestJson["requestId"]; + + ret["requestStatus"] = { + {"result", requestResult.StatusCode == RequestStatus::Success}, + {"code", requestResult.StatusCode} + }; + + if (!requestResult.Comment.empty()) + ret["requestStatus"]["comment"] = requestResult.Comment; + + if (requestResult.ResponseData.is_object()) + ret["responseData"] = requestResult.ResponseData; + + return ret; +} + +void ObsTickCallback(void *param, float) +{ + profile_start("obs-websocket-request-batch-frame-tick"); + + auto serialFrameBatch = reinterpret_cast(param); + + // Increment frame count + serialFrameBatch->frameCount++; + + if (serialFrameBatch->sleepUntilFrame) { + if (serialFrameBatch->frameCount < serialFrameBatch->sleepUntilFrame) { + // Do not process any requests if in "sleep mode" + profile_end("obs-websocket-request-batch-frame-tick"); + return; + } else { + // Reset frame sleep until counter if not being used + serialFrameBatch->sleepUntilFrame = 0; + } + } + + // Begin recursing any unprocessed requests + while (!serialFrameBatch->requests.empty()) { + // Fetch first in queue + Request request = serialFrameBatch->requests.front(); + // Process request and get result + RequestResult requestResult = serialFrameBatch->requestHandler->ProcessRequest(request); + // Add to results vector + serialFrameBatch->results.push_back(requestResult); + // Remove from front of queue + serialFrameBatch->requests.pop(); + + // If the processed request tells us to sleep, do so accordingly + if (requestResult.SleepFrames) { + serialFrameBatch->sleepUntilFrame = serialFrameBatch->frameCount + requestResult.SleepFrames; + break; + } + } + + // If request queue is empty, we can notify the paused worker thread + if (serialFrameBatch->requests.empty()) { + serialFrameBatch->condition.notify_one(); + } + + profile_end("obs-websocket-request-batch-frame-tick"); +} + +void WebSocketServer::ProcessRequestBatch(SessionPtr session, ObsWebSocketRequestBatchExecutionType executionType, std::vector &requests, std::vector &results) +{ + RequestHandler requestHandler(session); + if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME) { + // Recurse all requests in batch serially, processing the request then moving to the next one + for (auto requestJson : requests) { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + + RequestResult requestResult = requestHandler.ProcessRequest(request); + + json result = ConstructRequestResult(requestResult, requestJson); + + results.push_back(result); + } + } else if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME) { + SerialFrameBatch serialFrameBatch(&requestHandler); + + // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) + for (auto requestJson : requests) { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + serialFrameBatch.requests.push(request); + } + + // Create a callback entry for the graphics thread to execute on each video frame + obs_add_tick_callback(ObsTickCallback, &serialFrameBatch); + + // Wait until the graphics thread processes the last request in the queue + std::unique_lock lock(serialFrameBatch.conditionMutex); + serialFrameBatch.condition.wait(lock, [&serialFrameBatch]{return serialFrameBatch.requests.empty();}); + + // Remove the created callback entry since we don't need it anymore + obs_remove_tick_callback(ObsTickCallback, &serialFrameBatch); + + // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) + size_t i = 0; + for (auto requestResult : serialFrameBatch.results) { + results.push_back(ConstructRequestResult(requestResult, requests[i])); + i++; + } + } else if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL) { + ParallelBatchResults parallelResults(&requestHandler, requests.size()); + + // Submit each request as a task to the thread pool to be processed ASAP + for (auto requestJson : requests) { + _threadPool.start(Utils::Compat::CreateFunctionRunnable([¶llelResults, &executionType, requestJson]() { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + + RequestResult requestResult = parallelResults.requestHandler->ProcessRequest(request); + + json result = ConstructRequestResult(requestResult, requestJson); + + std::unique_lock lock(parallelResults.resultsMutex); + parallelResults.results.push_back(result); + lock.unlock(); + parallelResults.condition.notify_one(); + })); + } + + // Wait for the last request to finish processing + std::unique_lock lock(parallelResults.resultsMutex); + auto cb = [¶llelResults]{return parallelResults.results.size() == parallelResults.requestCount;}; + // A check just in case all requests managed to complete before we started waiting for the condition to be notified + if (!cb()) + parallelResults.condition.wait(lock, cb); + + results = parallelResults.results; + } +} diff --git a/src/requesthandler/RequestHandler_General.cpp b/src/requesthandler/RequestHandler_General.cpp index 044603f3..e49d5110 100644 --- a/src/requesthandler/RequestHandler_General.cpp +++ b/src/requesthandler/RequestHandler_General.cpp @@ -148,11 +148,20 @@ RequestResult RequestHandler::Sleep(const Request& request) { RequestStatus::RequestStatus statusCode; std::string comment; - if (!request.ValidateNumber("sleepMillis", statusCode, comment, 0, 50000)) - return RequestResult::Error(statusCode, comment); - int64_t sleepMillis = request.RequestData["sleepMillis"]; - std::this_thread::sleep_for(std::chrono::milliseconds(sleepMillis)); - - return RequestResult::Success(); + if (request.RequestBatchExecutionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME) { + if (!request.ValidateNumber("sleepMillis", statusCode, comment, 0, 50000)) + return RequestResult::Error(statusCode, comment); + int64_t sleepMillis = request.RequestData["sleepMillis"]; + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMillis)); + return RequestResult::Success(); + } else if (request.RequestBatchExecutionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME) { + if (!request.ValidateNumber("sleepFrames", statusCode, comment, 0, 10000)) + return RequestResult::Error(statusCode, comment); + RequestResult ret = RequestResult::Success(); + ret.SleepFrames = request.RequestData["sleepFrames"]; + return ret; + } else { + return RequestResult::Error(RequestStatus::UnsupportedRequestBatchExecutionType); + } } diff --git a/src/requesthandler/rpc/Request.cpp b/src/requesthandler/rpc/Request.cpp index 8845b769..2ca60870 100644 --- a/src/requesthandler/rpc/Request.cpp +++ b/src/requesthandler/rpc/Request.cpp @@ -11,10 +11,11 @@ json GetDefaultJsonObject(json requestData) return requestData; } -Request::Request(std::string requestType, json requestData) : +Request::Request(std::string requestType, json requestData, ObsWebSocketRequestBatchExecutionType requestBatchExecutionType) : HasRequestData(requestData.is_object()), RequestType(requestType), - RequestData(GetDefaultJsonObject(requestData)) + RequestData(GetDefaultJsonObject(requestData)), + RequestBatchExecutionType(requestBatchExecutionType) { } diff --git a/src/requesthandler/rpc/Request.h b/src/requesthandler/rpc/Request.h index 26f6f638..417c988e 100644 --- a/src/requesthandler/rpc/Request.h +++ b/src/requesthandler/rpc/Request.h @@ -3,6 +3,12 @@ #include "RequestStatus.h" #include "../../utils/Json.h" +enum ObsWebSocketRequestBatchExecutionType { + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME, + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME, + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL +}; + enum ObsWebSocketSceneFilter { OBS_WEBSOCKET_SCENE_FILTER_SCENE_ONLY, OBS_WEBSOCKET_SCENE_FILTER_GROUP_ONLY, @@ -11,7 +17,7 @@ enum ObsWebSocketSceneFilter { struct Request { - Request(const std::string requestType, const json requestData = nullptr); + Request(const std::string requestType, const json requestData = nullptr, const ObsWebSocketRequestBatchExecutionType requestBatchExecutionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME); // Contains the key and is not null const bool Contains(const std::string keyName) const; @@ -37,4 +43,5 @@ struct Request const bool HasRequestData; const std::string RequestType; const json RequestData; + const ObsWebSocketRequestBatchExecutionType RequestBatchExecutionType; }; diff --git a/src/requesthandler/rpc/RequestResult.cpp b/src/requesthandler/rpc/RequestResult.cpp index 7308a8ad..e7b964d0 100644 --- a/src/requesthandler/rpc/RequestResult.cpp +++ b/src/requesthandler/rpc/RequestResult.cpp @@ -3,7 +3,8 @@ RequestResult::RequestResult(RequestStatus::RequestStatus statusCode, json responseData, std::string comment) : StatusCode(statusCode), ResponseData(responseData), - Comment(comment) + Comment(comment), + SleepFrames(0) { } diff --git a/src/requesthandler/rpc/RequestResult.h b/src/requesthandler/rpc/RequestResult.h index aa8fbc08..565218b8 100644 --- a/src/requesthandler/rpc/RequestResult.h +++ b/src/requesthandler/rpc/RequestResult.h @@ -11,4 +11,5 @@ struct RequestResult RequestStatus::RequestStatus StatusCode; json ResponseData; std::string Comment; + size_t SleepFrames; }; diff --git a/src/requesthandler/rpc/RequestStatus.h b/src/requesthandler/rpc/RequestStatus.h index 7647715e..3d1566e8 100644 --- a/src/requesthandler/rpc/RequestStatus.h +++ b/src/requesthandler/rpc/RequestStatus.h @@ -15,6 +15,8 @@ namespace RequestStatus { UnknownRequestType = 204, // Generic error code (comment required) GenericError = 205, + // The request batch execution type is not supported + UnsupportedRequestBatchExecutionType = 206, // A required request parameter is missing MissingRequestParameter = 300,