diff --git a/CMakeLists.txt b/CMakeLists.txt index 53bd40f9..9b7b4adc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,7 +88,6 @@ set(obs-websocket_SOURCES src/WebSocketApi.cpp src/websocketserver/WebSocketServer.cpp src/websocketserver/WebSocketServer_Protocol.cpp - src/websocketserver/WebSocketServer_RequestBatchProcessing.cpp src/websocketserver/rpc/WebSocketSession.cpp src/eventhandler/EventHandler.cpp src/eventhandler/EventHandler_General.cpp @@ -101,6 +100,7 @@ set(obs-websocket_SOURCES src/eventhandler/EventHandler_SceneItems.cpp src/eventhandler/EventHandler_MediaInputs.cpp src/requesthandler/RequestHandler.cpp + src/requesthandler/RequestBatchHandler.cpp src/requesthandler/RequestHandler_General.cpp src/requesthandler/RequestHandler_Config.cpp src/requesthandler/RequestHandler_Sources.cpp @@ -111,6 +111,7 @@ set(obs-websocket_SOURCES src/requesthandler/RequestHandler_Record.cpp src/requesthandler/RequestHandler_MediaInputs.cpp src/requesthandler/rpc/Request.cpp + src/requesthandler/rpc/RequestBatchRequest.cpp src/requesthandler/rpc/RequestResult.cpp src/forms/SettingsDialog.cpp src/forms/ConnectInfo.cpp @@ -134,9 +135,11 @@ set(obs-websocket_HEADERS src/eventhandler/EventHandler.h src/eventhandler/types/EventSubscription.h src/requesthandler/RequestHandler.h + src/requesthandler/RequestBatchHandler.h src/requesthandler/types/RequestStatus.h src/requesthandler/types/RequestBatchExecutionType.h src/requesthandler/rpc/Request.h + src/requesthandler/rpc/RequestBatchRequest.h src/requesthandler/rpc/RequestResult.h src/forms/SettingsDialog.h src/forms/ConnectInfo.h diff --git a/src/websocketserver/WebSocketServer_RequestBatchProcessing.cpp b/src/requesthandler/RequestBatchHandler.cpp similarity index 53% rename from src/websocketserver/WebSocketServer_RequestBatchProcessing.cpp rename to src/requesthandler/RequestBatchHandler.cpp index fc3e47fb..82f8453a 100644 --- a/src/websocketserver/WebSocketServer_RequestBatchProcessing.cpp +++ b/src/requesthandler/RequestBatchHandler.cpp @@ -1,6 +1,5 @@ /* obs-websocket -Copyright (C) 2016-2021 Stephane Lepin Copyright (C) 2020-2021 Kyle Manning This program is free software; you can redistribute it and/or modify @@ -17,30 +16,18 @@ You should have received a copy of the GNU General Public License along with this program. If not, see */ +#include +#include #include -#include "WebSocketServer.h" -#include "../requesthandler/RequestHandler.h" -#include "../obs-websocket.h" +#include "RequestBatchHandler.h" #include "../utils/Compat.h" - -struct SerialFrameRequest -{ - Request request; - const json inputVariables; - const json outputVariables; - - SerialFrameRequest(const std::string &requestType, const json &requestData, const json &inputVariables, const json &outputVariables) : - request(requestType, requestData, RequestBatchExecutionType::SerialFrame), - inputVariables(inputVariables), - outputVariables(outputVariables) - {} -}; +#include "../obs-websocket.h" struct SerialFrameBatch { RequestHandler &requestHandler; - std::queue requests; + std::queue requests; std::vector results; json &variables; bool haltOnFailure; @@ -62,25 +49,23 @@ struct SerialFrameBatch struct ParallelBatchResults { RequestHandler &requestHandler; - size_t requestCount; - std::mutex resultsMutex; - std::vector results; + std::vector results; + + std::mutex conditionMutex; std::condition_variable condition; - ParallelBatchResults(RequestHandler &requestHandler, size_t requestCount) : - requestHandler(requestHandler), - requestCount(requestCount) + ParallelBatchResults(RequestHandler &requestHandler) : + requestHandler(requestHandler) {} }; - // `{"inputName": "inputNameVariable"}` is essentially `inputName = inputNameVariable` -static void PreProcessVariables(const json &variables, const json &inputVariables, json &requestData) +static void PreProcessVariables(const json &variables, RequestBatchRequest &request) { - if (variables.empty() || !inputVariables.is_object() || inputVariables.empty() || !requestData.is_object()) + if (variables.empty() || !request.InputVariables.is_object() || request.InputVariables.empty() || !request.RequestData.is_object()) return; - for (auto& [key, value] : inputVariables.items()) { + for (auto& [key, value] : request.InputVariables.items()) { if (!value.is_string()) { blog_debug("[WebSocketServer::ProcessRequestBatch] Value of field `%s` in `inputVariables `is not a string. Skipping!", key.c_str()); continue; @@ -92,55 +77,34 @@ static void PreProcessVariables(const json &variables, const json &inputVariable continue; } - requestData[key] = variables[valueString]; + request.RequestData[key] = variables[valueString]; } + + request.HasRequestData = !request.RequestData.empty(); } // `{"sceneItemIdVariable": "sceneItemId"}` is essentially `sceneItemIdVariable = sceneItemId` -static void PostProcessVariables(json &variables, const json &outputVariables, const json &responseData) +static void PostProcessVariables(json &variables, const RequestBatchRequest &request, const RequestResult &requestResult) { - if (!outputVariables.is_object() || outputVariables.empty() || responseData.empty()) + if (!request.OutputVariables.is_object() || request.OutputVariables.empty() || requestResult.ResponseData.empty()) return; - for (auto& [key, value] : outputVariables.items()) { + for (auto& [key, value] : request.OutputVariables.items()) { if (!value.is_string()) { blog_debug("[WebSocketServer::ProcessRequestBatch] Value of field `%s` in `outputVariables` is not a string. Skipping!", key.c_str()); continue; } std::string valueString = value; - if (!responseData.contains(valueString)) { + if (!requestResult.ResponseData.contains(valueString)) { blog_debug("[WebSocketServer::ProcessRequestBatch] `outputVariables` requested responseData field `%s`, but it does not exist. Skipping!", valueString.c_str()); continue; } - variables[key] = responseData[valueString]; + variables[key] = requestResult.ResponseData[valueString]; } } -static json ConstructRequestResult(RequestResult requestResult, const json &requestJson) -{ - json ret; - - ret["requestType"] = requestJson["requestType"]; - - if (requestJson.contains("requestId") && !requestJson["requestId"].is_null()) - ret["requestId"] = requestJson["requestId"]; - - ret["requestStatus"] = { - {"result", requestResult.StatusCode == RequestStatus::Success}, - {"code", requestResult.StatusCode} - }; - - if (!requestResult.Comment.empty()) - ret["requestStatus"]["comment"] = requestResult.Comment; - - if (requestResult.ResponseData.is_object()) - ret["responseData"] = requestResult.ResponseData; - - return ret; -} - static void ObsTickCallback(void *param, float) { ScopeProfiler prof{"obs_websocket_request_batch_frame_tick"}; @@ -162,15 +126,13 @@ static void ObsTickCallback(void *param, float) // Begin recursing any unprocessed requests while (!serialFrameBatch->requests.empty()) { // Fetch first in queue - SerialFrameRequest frameRequest = serialFrameBatch->requests.front(); + RequestBatchRequest request = serialFrameBatch->requests.front(); // Pre-process batch variables - PreProcessVariables(serialFrameBatch->variables, frameRequest.inputVariables, frameRequest.request.RequestData); - // Determine if there is request data - frameRequest.request.HasRequestData = !frameRequest.request.RequestData.empty(); + PreProcessVariables(serialFrameBatch->variables, request); // Process request and get result - RequestResult requestResult = serialFrameBatch->requestHandler.ProcessRequest(frameRequest.request); + RequestResult requestResult = serialFrameBatch->requestHandler.ProcessRequest(request); // Post-process batch variables - PostProcessVariables(serialFrameBatch->variables, frameRequest.outputVariables, requestResult.ResponseData); + PostProcessVariables(serialFrameBatch->variables, request, requestResult); // Add to results vector serialFrameBatch->results.push_back(requestResult); // Remove from front of queue @@ -178,7 +140,7 @@ static void ObsTickCallback(void *param, float) // If haltOnFailure and the request failed, clear the queue to make the batch return early. if (serialFrameBatch->haltOnFailure && requestResult.StatusCode != RequestStatus::Success) { - serialFrameBatch->requests = std::queue(); + serialFrameBatch->requests = std::queue(); break; } @@ -194,37 +156,33 @@ static void ObsTickCallback(void *param, float) serialFrameBatch->condition.notify_one(); } -void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector &requests, std::vector &results, json &variables, bool haltOnFailure) +std::vector RequestBatchHandler::ProcessRequestBatch(QThreadPool &threadPool, SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, std::vector &requests, json &variables, bool haltOnFailure) { RequestHandler requestHandler(session); if (executionType == RequestBatchExecutionType::SerialRealtime) { + std::vector ret; + // Recurse all requests in batch serially, processing the request then moving to the next one - for (auto requestJson : requests) { - Request request(requestJson["requestType"], requestJson["requestData"], RequestBatchExecutionType::SerialRealtime); - - PreProcessVariables(variables, requestJson["inputVariables"], request.RequestData); - - request.HasRequestData = !request.RequestData.empty(); + for (auto &request : requests) { + PreProcessVariables(variables, request); RequestResult requestResult = requestHandler.ProcessRequest(request); - PostProcessVariables(variables, requestJson["outputVariables"], requestResult.ResponseData); + PostProcessVariables(variables, request, requestResult); - json result = ConstructRequestResult(requestResult, requestJson); - - results.push_back(result); + ret.push_back(requestResult); if (haltOnFailure && requestResult.StatusCode != RequestStatus::Success) break; } + + return ret; } else if (executionType == RequestBatchExecutionType::SerialFrame) { SerialFrameBatch serialFrameBatch(requestHandler, variables, haltOnFailure); // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) - for (auto requestJson : requests) { - SerialFrameRequest frameRequest(requestJson["requestType"], requestJson["requestData"], requestJson["inputVariables"], requestJson["outputVariables"]); - serialFrameBatch.requests.push(frameRequest); - } + for (auto &request : requests) + serialFrameBatch.requests.push(request); // Create a callback entry for the graphics thread to execute on each video frame obs_add_tick_callback(ObsTickCallback, &serialFrameBatch); @@ -236,38 +194,32 @@ void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecut // Remove the created callback entry since we don't need it anymore obs_remove_tick_callback(ObsTickCallback, &serialFrameBatch); - // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) - size_t i = 0; - for (auto requestResult : serialFrameBatch.results) { - results.push_back(ConstructRequestResult(requestResult, requests[i])); - i++; - } + return serialFrameBatch.results; } else if (executionType == RequestBatchExecutionType::Parallel) { - ParallelBatchResults parallelResults(requestHandler, requests.size()); + ParallelBatchResults parallelResults(requestHandler); + + // Acquire the lock early to prevent the batch from finishing before we're ready + std::unique_lock lock(parallelResults.conditionMutex); // Submit each request as a task to the thread pool to be processed ASAP - for (auto requestJson : requests) { - _threadPool.start(Utils::Compat::CreateFunctionRunnable([¶llelResults, &executionType, requestJson]() { - Request request(requestJson["requestType"], requestJson["requestData"], RequestBatchExecutionType::Parallel); - + for (auto &request : requests) { + threadPool.start(Utils::Compat::CreateFunctionRunnable([¶llelResults, &request]() { RequestResult requestResult = parallelResults.requestHandler.ProcessRequest(request); - json result = ConstructRequestResult(requestResult, requestJson); - - std::unique_lock lock(parallelResults.resultsMutex); - parallelResults.results.push_back(result); + std::unique_lock lock(parallelResults.conditionMutex); + parallelResults.results.push_back(requestResult); lock.unlock(); parallelResults.condition.notify_one(); })); } // Wait for the last request to finish processing - std::unique_lock lock(parallelResults.resultsMutex); - auto cb = [¶llelResults]{return parallelResults.results.size() == parallelResults.requestCount;}; - // A check just in case all requests managed to complete before we started waiting for the condition to be notified - if (!cb()) - parallelResults.condition.wait(lock, cb); + size_t requestCount = requests.size(); + parallelResults.condition.wait(lock, [¶llelResults, requestCount]{return parallelResults.results.size() == requestCount;}); - results = parallelResults.results; + return parallelResults.results; } + + // Return empty vector if not a batch somehow + return std::vector(); } diff --git a/src/requesthandler/RequestBatchHandler.h b/src/requesthandler/RequestBatchHandler.h new file mode 100644 index 00000000..aa5d0574 --- /dev/null +++ b/src/requesthandler/RequestBatchHandler.h @@ -0,0 +1,28 @@ +/* +obs-websocket +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#pragma once + +#include + +#include "RequestHandler.h" +#include "rpc/RequestBatchRequest.h" + +namespace RequestBatchHandler { + std::vector ProcessRequestBatch(QThreadPool &threadPool, SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, std::vector &requests, json &variables, bool haltOnFailure); +} diff --git a/src/requesthandler/rpc/Request.cpp b/src/requesthandler/rpc/Request.cpp index aaccafda..669a7954 100644 --- a/src/requesthandler/rpc/Request.cpp +++ b/src/requesthandler/rpc/Request.cpp @@ -19,7 +19,6 @@ with this program. If not, see #include "Request.h" #include "../../obs-websocket.h" -#include "../../plugin-macros.generated.h" json GetDefaultJsonObject(const json &requestData) { @@ -30,15 +29,7 @@ json GetDefaultJsonObject(const json &requestData) return requestData; } -Request::Request(const std::string &requestType, const json &requestData) : - RequestType(requestType), - HasRequestData(requestData.is_object()), - RequestData(GetDefaultJsonObject(requestData)), - ExecutionType(RequestBatchExecutionType::None) -{ -} - -Request::Request(const std::string &requestType, const json &requestData, RequestBatchExecutionType::RequestBatchExecutionType executionType) : +Request::Request(const std::string &requestType, const json &requestData, const RequestBatchExecutionType::RequestBatchExecutionType executionType) : RequestType(requestType), HasRequestData(requestData.is_object()), RequestData(GetDefaultJsonObject(requestData)), diff --git a/src/requesthandler/rpc/Request.h b/src/requesthandler/rpc/Request.h index 86022a3f..03f6fb21 100644 --- a/src/requesthandler/rpc/Request.h +++ b/src/requesthandler/rpc/Request.h @@ -31,8 +31,7 @@ enum ObsWebSocketSceneFilter { struct Request { - Request(const std::string &requestType, const json &requestData = nullptr); - Request(const std::string &requestType, const json &requestData, RequestBatchExecutionType::RequestBatchExecutionType executionType); + Request(const std::string &requestType, const json &requestData = nullptr, const RequestBatchExecutionType::RequestBatchExecutionType executionType = RequestBatchExecutionType::None); // Contains the key and is not null bool Contains(const std::string &keyName) const; diff --git a/src/requesthandler/rpc/RequestBatchRequest.cpp b/src/requesthandler/rpc/RequestBatchRequest.cpp new file mode 100644 index 00000000..6ea83eb6 --- /dev/null +++ b/src/requesthandler/rpc/RequestBatchRequest.cpp @@ -0,0 +1,26 @@ +/* +obs-websocket +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#include "RequestBatchRequest.h" + +RequestBatchRequest::RequestBatchRequest(const std::string &requestType, const json &requestData, RequestBatchExecutionType::RequestBatchExecutionType executionType, const json &inputVariables, const json &outputVariables) : + Request(requestType, requestData, executionType), + InputVariables(inputVariables), + OutputVariables(outputVariables) +{ +} diff --git a/src/requesthandler/rpc/RequestBatchRequest.h b/src/requesthandler/rpc/RequestBatchRequest.h new file mode 100644 index 00000000..bbe5e6e0 --- /dev/null +++ b/src/requesthandler/rpc/RequestBatchRequest.h @@ -0,0 +1,28 @@ +/* +obs-websocket +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#pragma once + +#include "Request.h" + +struct RequestBatchRequest : Request { + RequestBatchRequest(const std::string &requestType, const json &requestData, RequestBatchExecutionType::RequestBatchExecutionType executionType, const json &inputVariables = nullptr, const json &outputVariables = nullptr); + + json InputVariables; + json OutputVariables; +}; diff --git a/src/websocketserver/WebSocketServer.h b/src/websocketserver/WebSocketServer.h index 4a2a03fd..4147a362 100644 --- a/src/websocketserver/WebSocketServer.h +++ b/src/websocketserver/WebSocketServer.h @@ -89,11 +89,9 @@ class WebSocketServer : QObject void onClose(websocketpp::connection_hdl hdl); void onMessage(websocketpp::connection_hdl hdl, websocketpp::server::message_ptr message); - void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData); + static void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData); void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, json &payloadData); - void ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector &requests, std::vector &results, json &variables, bool haltOnFailure); - QThreadPool _threadPool; std::thread _serverThread; diff --git a/src/websocketserver/WebSocketServer_Protocol.cpp b/src/websocketserver/WebSocketServer_Protocol.cpp index 50604786..2e42b25c 100644 --- a/src/websocketserver/WebSocketServer_Protocol.cpp +++ b/src/websocketserver/WebSocketServer_Protocol.cpp @@ -18,9 +18,11 @@ with this program. If not, see */ #include +#include #include "WebSocketServer.h" #include "../requesthandler/RequestHandler.h" +#include "../requesthandler/RequestBatchHandler.h" #include "../eventhandler/EventHandler.h" #include "../obs-websocket.h" #include "../Config.h" @@ -33,6 +35,29 @@ static bool IsSupportedRpcVersion(uint8_t requestedVersion) return (requestedVersion == 1); } +static json ConstructRequestResult(RequestResult requestResult, const json &requestJson) +{ + json ret; + + ret["requestType"] = requestJson["requestType"]; + + if (requestJson.contains("requestId") && !requestJson["requestId"].is_null()) + ret["requestId"] = requestJson["requestId"]; + + ret["requestStatus"] = { + {"result", requestResult.StatusCode == RequestStatus::Success}, + {"code", requestResult.StatusCode} + }; + + if (!requestResult.Comment.empty()) + ret["requestStatus"]["comment"] = requestResult.Comment; + + if (requestResult.ResponseData.is_object()) + ret["responseData"] = requestResult.ResponseData; + + return ret; +} + void WebSocketServer::SetSessionParameters(SessionPtr session, ProcessResult &ret, const json &payloadData) { if (payloadData.contains("ignoreInvalidMessages")) { @@ -281,9 +306,19 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces } std::vector requests = payloadData["requests"]; - json variables = payloadData["variables"]; + + std::vector requestsVector; + for (auto &requestJson : requests) + requestsVector.emplace_back(requestJson["requestType"], requestJson["requestData"], executionType, requestJson["inputVariables"], requestJson["outputVariables"]); + + auto resultsVector = RequestBatchHandler::ProcessRequestBatch(_threadPool, session, executionType, requestsVector, payloadData["variables"], haltOnFailure); + + size_t i = 0; std::vector results; - ProcessRequestBatch(session, executionType, requests, results, variables, haltOnFailure); + for (auto &requestResult : resultsVector) { + results.push_back(ConstructRequestResult(requestResult, requests[i])); + i++; + } ret.result["op"] = WebSocketOpCode::RequestBatchResponse; ret.result["d"]["requestId"] = payloadData["requestId"];