From a8d27ede9ef34c9cf502d9d9e041a1a1f13b906b Mon Sep 17 00:00:00 2001 From: tt2468 Date: Sun, 26 Sep 2021 03:12:29 -0700 Subject: [PATCH] Base: Add request batch execution types A new `executionType` field has been added to the `RequestBatch` Op Types added: - `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME`(default) - `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME` - `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL` `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME`: - Same as how request batches have always worked. - Requests are processed in-order - Requests are processed as soon as possible by one worker thread - The `Sleep` request blocks execution for a specified amount of real world time `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME`: - New! - Requests are processed in-order - Requests are processed on the graphics thread. BE VERY CAREFUL NOT TO OVERLOAD THE GRAPHICS THREAD WITH LARGE REQUESTS. A general rule of thumb is for your request batches to take a maximum of 2ms per frame of processing. - Requests processing starts right before the next frame is composited. This functionality is perfect for things like `SetSceneItemTransform` - The `Sleep` request will halt processing of the request batch for a specified number of frames (ticks) - To be clear: If you do not have any sleep requests, all requests in the batch will be processed in the span of a single frame - For developers: The execution of requests gets profiled by the OBS profiler under the `obs-websocket-request-batch-frame-tick` name. This value (shown in the OBS log after OBS shutdown) represents the amount of time that the graphics thread spent actively processing requests per frame. This tool can be used to determine the amount of load that your request batches are placing on the graphics thread. `OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL`: - New! - Requests are processed asynchronously at the soonest possible time. - Requests are processed by the core obs-websocket thread pool, where the number of workers == the number of threads on your machine. - If you have 12 threads on your machine, obs-websocket will be able to process 12 requests at any given moment. - The `results` array is populated by order of request completion. Consider the order to be random. - The `Sleep` request will return an error if attempted to be used in this mode. - Note: This feature is experimental and can increase the chances of causing race conditions (crashes). While the implementation is fully thread-safe, OBS itself is not. Usage of this is only recommended if you are processing very large batches and need the performance benefit. - Example use case: Performing `SaveSourceScreenshot` on 8 sources at once. --- CMakeLists.txt | 1 + src/WebSocketServer.cpp | 6 +- src/WebSocketServer.h | 7 + src/WebSocketServer_Protocol.cpp | 61 +++--- ...WebSocketServer_RequestBatchProcessing.cpp | 175 ++++++++++++++++++ src/requesthandler/RequestHandler_General.cpp | 21 ++- src/requesthandler/rpc/Request.cpp | 5 +- src/requesthandler/rpc/Request.h | 9 +- src/requesthandler/rpc/RequestResult.cpp | 3 +- src/requesthandler/rpc/RequestResult.h | 1 + src/requesthandler/rpc/RequestStatus.h | 2 + 11 files changed, 251 insertions(+), 40 deletions(-) create mode 100644 src/WebSocketServer_RequestBatchProcessing.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 16d495d3..f4288218 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,6 +87,7 @@ set(obs-websocket_SOURCES src/Config.cpp src/WebSocketServer.cpp src/WebSocketServer_Protocol.cpp + src/WebSocketServer_RequestBatchProcessing.cpp src/WebSocketSession.cpp src/eventhandler/EventHandler.cpp src/eventhandler/EventHandler_General.cpp diff --git a/src/WebSocketServer.cpp b/src/WebSocketServer.cpp index 6619da85..1f1d24fd 100644 --- a/src/WebSocketServer.cpp +++ b/src/WebSocketServer.cpp @@ -339,7 +339,7 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl) void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::server::message_ptr message) { - auto opcode = message->get_opcode(); + auto opCode = message->get_opcode(); std::string payload = message->get_payload(); _threadPool.start(Utils::Compat::CreateFunctionRunnable([=]() { std::unique_lock lock(_sessionMutex); @@ -359,7 +359,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se websocketpp::lib::error_code errorCode; uint8_t sessionEncoding = session->Encoding(); if (sessionEncoding == WebSocketEncoding::Json) { - if (opcode != websocketpp::frame::opcode::text) { + if (opCode != websocketpp::frame::opcode::text) { if (!session->IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to Json, but a binary message was received.", errorCode); return; @@ -373,7 +373,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, websocketpp::se return; } } else if (sessionEncoding == WebSocketEncoding::MsgPack) { - if (opcode != websocketpp::frame::opcode::binary) { + if (opCode != websocketpp::frame::opcode::binary) { if (!session->IgnoreInvalidMessages()) _server.close(hdl, WebSocketCloseCode::MessageDecodeError, "Your session encoding is set to MsgPack, but a text message was received.", errorCode); return; diff --git a/src/WebSocketServer.h b/src/WebSocketServer.h index d9c82736..3a0e3f5c 100644 --- a/src/WebSocketServer.h +++ b/src/WebSocketServer.h @@ -9,6 +9,7 @@ #include "utils/Json.h" #include "WebSocketSession.h" +#include "requesthandler/rpc/Request.h" #include "plugin-macros.generated.h" class WebSocketServer : QObject @@ -53,6 +54,10 @@ class WebSocketServer : QObject UnsupportedRpcVersion = 4009, // The websocket session has been invalidated by the obs-websocket server. SessionInvalidated = 4010, + // A data key's value is invalid, in the case of things like enums. + InvalidDataKeyValue = 4011, + // A feature is not supported because of hardware/software limitations. + UnsupportedFeature = 4012, }; WebSocketServer(); @@ -98,6 +103,8 @@ class WebSocketServer : QObject void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, json payloadData); void ProcessMessage(SessionPtr session, ProcessResult &ret, const uint8_t opCode, json incomingMessage); + void ProcessRequestBatch(SessionPtr session, ObsWebSocketRequestBatchExecutionType executionType, std::vector &requests, std::vector &results); + std::thread _serverThread; websocketpp::server _server; QThreadPool _threadPool; diff --git a/src/WebSocketServer_Protocol.cpp b/src/WebSocketServer_Protocol.cpp index aa4fea96..5b44e3f2 100644 --- a/src/WebSocketServer_Protocol.cpp +++ b/src/WebSocketServer_Protocol.cpp @@ -214,35 +214,42 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces return; } - std::vector requests = payloadData["requests"]; - json results = json::array(); - - RequestHandler requestHandler(session); - for (auto requestJson : requests) { - Request request(requestJson["requestType"], requestJson["requestData"]); - - RequestResult requestResult = requestHandler.ProcessRequest(request); - - json result; - result["requestType"] = requestJson["requestType"]; - - if (requestJson.contains("requestId")) - result["requestId"] = requestJson["requestId"]; - - result["requestStatus"] = { - {"result", requestResult.StatusCode == RequestStatus::Success}, - {"code", requestResult.StatusCode} - }; - - if (!requestResult.Comment.empty()) - result["requestStatus"]["comment"] = requestResult.Comment; - - if (requestResult.ResponseData.is_object()) - result["responseData"] = requestResult.ResponseData; - - results.push_back(result); + ObsWebSocketRequestBatchExecutionType executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME; + if (payloadData.contains("executionType") && !payloadData["executionType"].is_null()) { + if (!payloadData["executionType"].is_string()) { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::InvalidDataKeyType; + ret.closeReason = "Your `executionType` is not a string."; + } + return; + } + std::string executionTypeString = payloadData["executionType"]; + if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME") { + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME; + } else if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME") { + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME; + } else if (executionTypeString == "OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL") { + if (_threadPool.maxThreadCount() < 2) { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::UnsupportedFeature; + ret.closeReason = "Parallel request batch processing is not available on this system due to limited core count."; + } + return; + } + executionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL; + } else { + if (!session->IgnoreInvalidMessages()) { + ret.closeCode = WebSocketCloseCode::InvalidDataKeyValue; + ret.closeReason = "Your `executionType`'s value is not recognized."; + } + return; + } } + std::vector requests = payloadData["requests"]; + std::vector results; + ProcessRequestBatch(session, executionType, requests, results); + ret.result["op"] = WebSocketOpCode::RequestBatchResponse; ret.result["d"]["requestId"] = payloadData["requestId"]; ret.result["d"]["results"] = results; diff --git a/src/WebSocketServer_RequestBatchProcessing.cpp b/src/WebSocketServer_RequestBatchProcessing.cpp new file mode 100644 index 00000000..338e2b77 --- /dev/null +++ b/src/WebSocketServer_RequestBatchProcessing.cpp @@ -0,0 +1,175 @@ +#include + +#include "WebSocketServer.h" +#include "requesthandler/RequestHandler.h" +#include "obs-websocket.h" +#include "utils/Compat.h" + +struct SerialFrameBatch +{ + RequestHandler *requestHandler; + size_t frameCount; + size_t sleepUntilFrame; + std::queue requests; + std::vector results; + std::mutex conditionMutex; + std::condition_variable condition; + + SerialFrameBatch(RequestHandler *requestHandler) : + requestHandler(requestHandler), + frameCount(0), + sleepUntilFrame(0) + {} +}; + +struct ParallelBatchResults +{ + RequestHandler *requestHandler; + size_t requestCount; + std::mutex resultsMutex; + std::vector results; + std::condition_variable condition; + + ParallelBatchResults(RequestHandler *requestHandler, size_t requestCount) : + requestHandler(requestHandler), + requestCount(requestCount) + {} +}; + +json ConstructRequestResult(RequestResult requestResult, json requestJson) +{ + json ret; + + ret["requestType"] = requestJson["requestType"]; + + if (requestJson.contains("requestId") && !requestJson["requestId"].is_null()) + ret["requestId"] = requestJson["requestId"]; + + ret["requestStatus"] = { + {"result", requestResult.StatusCode == RequestStatus::Success}, + {"code", requestResult.StatusCode} + }; + + if (!requestResult.Comment.empty()) + ret["requestStatus"]["comment"] = requestResult.Comment; + + if (requestResult.ResponseData.is_object()) + ret["responseData"] = requestResult.ResponseData; + + return ret; +} + +void ObsTickCallback(void *param, float) +{ + profile_start("obs-websocket-request-batch-frame-tick"); + + auto serialFrameBatch = reinterpret_cast(param); + + // Increment frame count + serialFrameBatch->frameCount++; + + if (serialFrameBatch->sleepUntilFrame) { + if (serialFrameBatch->frameCount < serialFrameBatch->sleepUntilFrame) { + // Do not process any requests if in "sleep mode" + profile_end("obs-websocket-request-batch-frame-tick"); + return; + } else { + // Reset frame sleep until counter if not being used + serialFrameBatch->sleepUntilFrame = 0; + } + } + + // Begin recursing any unprocessed requests + while (!serialFrameBatch->requests.empty()) { + // Fetch first in queue + Request request = serialFrameBatch->requests.front(); + // Process request and get result + RequestResult requestResult = serialFrameBatch->requestHandler->ProcessRequest(request); + // Add to results vector + serialFrameBatch->results.push_back(requestResult); + // Remove from front of queue + serialFrameBatch->requests.pop(); + + // If the processed request tells us to sleep, do so accordingly + if (requestResult.SleepFrames) { + serialFrameBatch->sleepUntilFrame = serialFrameBatch->frameCount + requestResult.SleepFrames; + break; + } + } + + // If request queue is empty, we can notify the paused worker thread + if (serialFrameBatch->requests.empty()) { + serialFrameBatch->condition.notify_one(); + } + + profile_end("obs-websocket-request-batch-frame-tick"); +} + +void WebSocketServer::ProcessRequestBatch(SessionPtr session, ObsWebSocketRequestBatchExecutionType executionType, std::vector &requests, std::vector &results) +{ + RequestHandler requestHandler(session); + if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME) { + // Recurse all requests in batch serially, processing the request then moving to the next one + for (auto requestJson : requests) { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + + RequestResult requestResult = requestHandler.ProcessRequest(request); + + json result = ConstructRequestResult(requestResult, requestJson); + + results.push_back(result); + } + } else if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME) { + SerialFrameBatch serialFrameBatch(&requestHandler); + + // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) + for (auto requestJson : requests) { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + serialFrameBatch.requests.push(request); + } + + // Create a callback entry for the graphics thread to execute on each video frame + obs_add_tick_callback(ObsTickCallback, &serialFrameBatch); + + // Wait until the graphics thread processes the last request in the queue + std::unique_lock lock(serialFrameBatch.conditionMutex); + serialFrameBatch.condition.wait(lock, [&serialFrameBatch]{return serialFrameBatch.requests.empty();}); + + // Remove the created callback entry since we don't need it anymore + obs_remove_tick_callback(ObsTickCallback, &serialFrameBatch); + + // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) + size_t i = 0; + for (auto requestResult : serialFrameBatch.results) { + results.push_back(ConstructRequestResult(requestResult, requests[i])); + i++; + } + } else if (executionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL) { + ParallelBatchResults parallelResults(&requestHandler, requests.size()); + + // Submit each request as a task to the thread pool to be processed ASAP + for (auto requestJson : requests) { + _threadPool.start(Utils::Compat::CreateFunctionRunnable([¶llelResults, &executionType, requestJson]() { + Request request(requestJson["requestType"], requestJson["requestData"], executionType); + + RequestResult requestResult = parallelResults.requestHandler->ProcessRequest(request); + + json result = ConstructRequestResult(requestResult, requestJson); + + std::unique_lock lock(parallelResults.resultsMutex); + parallelResults.results.push_back(result); + lock.unlock(); + parallelResults.condition.notify_one(); + })); + } + + // Wait for the last request to finish processing + std::unique_lock lock(parallelResults.resultsMutex); + auto cb = [¶llelResults]{return parallelResults.results.size() == parallelResults.requestCount;}; + // A check just in case all requests managed to complete before we started waiting for the condition to be notified + if (!cb()) + parallelResults.condition.wait(lock, cb); + + results = parallelResults.results; + } +} diff --git a/src/requesthandler/RequestHandler_General.cpp b/src/requesthandler/RequestHandler_General.cpp index 044603f3..e49d5110 100644 --- a/src/requesthandler/RequestHandler_General.cpp +++ b/src/requesthandler/RequestHandler_General.cpp @@ -148,11 +148,20 @@ RequestResult RequestHandler::Sleep(const Request& request) { RequestStatus::RequestStatus statusCode; std::string comment; - if (!request.ValidateNumber("sleepMillis", statusCode, comment, 0, 50000)) - return RequestResult::Error(statusCode, comment); - int64_t sleepMillis = request.RequestData["sleepMillis"]; - std::this_thread::sleep_for(std::chrono::milliseconds(sleepMillis)); - - return RequestResult::Success(); + if (request.RequestBatchExecutionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME) { + if (!request.ValidateNumber("sleepMillis", statusCode, comment, 0, 50000)) + return RequestResult::Error(statusCode, comment); + int64_t sleepMillis = request.RequestData["sleepMillis"]; + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMillis)); + return RequestResult::Success(); + } else if (request.RequestBatchExecutionType == OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME) { + if (!request.ValidateNumber("sleepFrames", statusCode, comment, 0, 10000)) + return RequestResult::Error(statusCode, comment); + RequestResult ret = RequestResult::Success(); + ret.SleepFrames = request.RequestData["sleepFrames"]; + return ret; + } else { + return RequestResult::Error(RequestStatus::UnsupportedRequestBatchExecutionType); + } } diff --git a/src/requesthandler/rpc/Request.cpp b/src/requesthandler/rpc/Request.cpp index 8845b769..2ca60870 100644 --- a/src/requesthandler/rpc/Request.cpp +++ b/src/requesthandler/rpc/Request.cpp @@ -11,10 +11,11 @@ json GetDefaultJsonObject(json requestData) return requestData; } -Request::Request(std::string requestType, json requestData) : +Request::Request(std::string requestType, json requestData, ObsWebSocketRequestBatchExecutionType requestBatchExecutionType) : HasRequestData(requestData.is_object()), RequestType(requestType), - RequestData(GetDefaultJsonObject(requestData)) + RequestData(GetDefaultJsonObject(requestData)), + RequestBatchExecutionType(requestBatchExecutionType) { } diff --git a/src/requesthandler/rpc/Request.h b/src/requesthandler/rpc/Request.h index 26f6f638..417c988e 100644 --- a/src/requesthandler/rpc/Request.h +++ b/src/requesthandler/rpc/Request.h @@ -3,6 +3,12 @@ #include "RequestStatus.h" #include "../../utils/Json.h" +enum ObsWebSocketRequestBatchExecutionType { + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME, + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_FRAME, + OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_PARALLEL +}; + enum ObsWebSocketSceneFilter { OBS_WEBSOCKET_SCENE_FILTER_SCENE_ONLY, OBS_WEBSOCKET_SCENE_FILTER_GROUP_ONLY, @@ -11,7 +17,7 @@ enum ObsWebSocketSceneFilter { struct Request { - Request(const std::string requestType, const json requestData = nullptr); + Request(const std::string requestType, const json requestData = nullptr, const ObsWebSocketRequestBatchExecutionType requestBatchExecutionType = OBS_WEBSOCKET_REQUEST_BATCH_EXECUTION_TYPE_SERIAL_REALTIME); // Contains the key and is not null const bool Contains(const std::string keyName) const; @@ -37,4 +43,5 @@ struct Request const bool HasRequestData; const std::string RequestType; const json RequestData; + const ObsWebSocketRequestBatchExecutionType RequestBatchExecutionType; }; diff --git a/src/requesthandler/rpc/RequestResult.cpp b/src/requesthandler/rpc/RequestResult.cpp index 7308a8ad..e7b964d0 100644 --- a/src/requesthandler/rpc/RequestResult.cpp +++ b/src/requesthandler/rpc/RequestResult.cpp @@ -3,7 +3,8 @@ RequestResult::RequestResult(RequestStatus::RequestStatus statusCode, json responseData, std::string comment) : StatusCode(statusCode), ResponseData(responseData), - Comment(comment) + Comment(comment), + SleepFrames(0) { } diff --git a/src/requesthandler/rpc/RequestResult.h b/src/requesthandler/rpc/RequestResult.h index aa8fbc08..565218b8 100644 --- a/src/requesthandler/rpc/RequestResult.h +++ b/src/requesthandler/rpc/RequestResult.h @@ -11,4 +11,5 @@ struct RequestResult RequestStatus::RequestStatus StatusCode; json ResponseData; std::string Comment; + size_t SleepFrames; }; diff --git a/src/requesthandler/rpc/RequestStatus.h b/src/requesthandler/rpc/RequestStatus.h index 7647715e..3d1566e8 100644 --- a/src/requesthandler/rpc/RequestStatus.h +++ b/src/requesthandler/rpc/RequestStatus.h @@ -15,6 +15,8 @@ namespace RequestStatus { UnknownRequestType = 204, // Generic error code (comment required) GenericError = 205, + // The request batch execution type is not supported + UnsupportedRequestBatchExecutionType = 206, // A required request parameter is missing MissingRequestParameter = 300,