WebSocketServer: Implement haltOnFailure for batch requests

This commit is contained in:
tt2468 2021-12-10 22:28:22 -08:00
parent c9c5da6837
commit 43e2860709
3 changed files with 38 additions and 13 deletions

View File

@ -92,7 +92,7 @@ class WebSocketServer : QObject
void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData); void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData);
void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, const json &payloadData); void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, const json &payloadData);
void ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables); void ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables, bool haltOnFailure);
QThreadPool _threadPool; QThreadPool _threadPool;

View File

@ -28,7 +28,7 @@ with this program. If not, see <https://www.gnu.org/licenses/>
#include "../utils/Platform.h" #include "../utils/Platform.h"
#include "../utils/Compat.h" #include "../utils/Compat.h"
bool IsSupportedRpcVersion(uint8_t requestedVersion) static bool IsSupportedRpcVersion(uint8_t requestedVersion)
{ {
return (requestedVersion == 1); return (requestedVersion == 1);
} }
@ -265,10 +265,23 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces
} }
} }
bool haltOnFailure = false;
if (payloadData.contains("haltOnFailure") && !payloadData["haltOnFailure"].is_null()) {
if (!payloadData["haltOnFailure"].is_boolean()) {
if (!session->IgnoreInvalidMessages()) {
ret.closeCode = WebSocketCloseCode::InvalidDataFieldType;
ret.closeReason = "Your `haltOnFailure` is not a boolean.";
}
return;
}
haltOnFailure = payloadData["haltOnFailure"];
}
std::vector<json> requests = payloadData["requests"]; std::vector<json> requests = payloadData["requests"];
json variables = payloadData["variables"]; json variables = payloadData["variables"];
std::vector<json> results; std::vector<json> results;
ProcessRequestBatch(session, executionType, requests, results, variables); ProcessRequestBatch(session, executionType, requests, results, variables, haltOnFailure);
ret.result["op"] = WebSocketOpCode::RequestBatchResponse; ret.result["op"] = WebSocketOpCode::RequestBatchResponse;
ret.result["d"]["requestId"] = payloadData["requestId"]; ret.result["d"]["requestId"] = payloadData["requestId"];

View File

@ -40,17 +40,20 @@ struct SerialFrameRequest
struct SerialFrameBatch struct SerialFrameBatch
{ {
RequestHandler &requestHandler; RequestHandler &requestHandler;
json &variables;
size_t frameCount;
size_t sleepUntilFrame;
std::queue<SerialFrameRequest> requests; std::queue<SerialFrameRequest> requests;
std::vector<RequestResult> results; std::vector<RequestResult> results;
json &variables;
bool haltOnFailure;
size_t frameCount;
size_t sleepUntilFrame;
std::mutex conditionMutex; std::mutex conditionMutex;
std::condition_variable condition; std::condition_variable condition;
SerialFrameBatch(RequestHandler &requestHandler, json &variables) : SerialFrameBatch(RequestHandler &requestHandler, json &variables, bool haltOnFailure) :
requestHandler(requestHandler), requestHandler(requestHandler),
variables(variables), variables(variables),
haltOnFailure(haltOnFailure),
frameCount(0), frameCount(0),
sleepUntilFrame(0) sleepUntilFrame(0)
{} {}
@ -72,7 +75,7 @@ struct ParallelBatchResults
bool PreProcessVariables(const json &variables, const json &inputVariables, json &requestData) static bool PreProcessVariables(const json &variables, const json &inputVariables, json &requestData)
{ {
if (variables.empty() || inputVariables.empty() || !inputVariables.is_object() || !requestData.is_object()) if (variables.empty() || inputVariables.empty() || !inputVariables.is_object() || !requestData.is_object())
return !requestData.empty(); return !requestData.empty();
@ -97,7 +100,7 @@ bool PreProcessVariables(const json &variables, const json &inputVariables, json
return !requestData.empty(); return !requestData.empty();
} }
void PostProcessVariables(json &variables, const json &outputVariables, const json &responseData) static void PostProcessVariables(json &variables, const json &outputVariables, const json &responseData)
{ {
if (outputVariables.empty() || !outputVariables.is_object() || responseData.empty()) if (outputVariables.empty() || !outputVariables.is_object() || responseData.empty())
return; return;
@ -120,7 +123,7 @@ void PostProcessVariables(json &variables, const json &outputVariables, const js
} }
} }
json ConstructRequestResult(RequestResult requestResult, const json &requestJson) static json ConstructRequestResult(RequestResult requestResult, const json &requestJson)
{ {
json ret; json ret;
@ -143,7 +146,7 @@ json ConstructRequestResult(RequestResult requestResult, const json &requestJson
return ret; return ret;
} }
void ObsTickCallback(void *param, float) static void ObsTickCallback(void *param, float)
{ {
ScopeProfiler prof{"obs_websocket_request_batch_frame_tick"}; ScopeProfiler prof{"obs_websocket_request_batch_frame_tick"};
@ -177,6 +180,12 @@ void ObsTickCallback(void *param, float)
// Remove from front of queue // Remove from front of queue
serialFrameBatch->requests.pop(); serialFrameBatch->requests.pop();
// If haltOnFailure and the request failed, clear the queue to make the batch return early.
if (serialFrameBatch->haltOnFailure && requestResult.StatusCode != RequestStatus::Success) {
serialFrameBatch->requests = std::queue<SerialFrameRequest>();
break;
}
// If the processed request tells us to sleep, do so accordingly // If the processed request tells us to sleep, do so accordingly
if (requestResult.SleepFrames) { if (requestResult.SleepFrames) {
serialFrameBatch->sleepUntilFrame = serialFrameBatch->frameCount + requestResult.SleepFrames; serialFrameBatch->sleepUntilFrame = serialFrameBatch->frameCount + requestResult.SleepFrames;
@ -190,7 +199,7 @@ void ObsTickCallback(void *param, float)
} }
} }
void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables) void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables, bool haltOnFailure)
{ {
RequestHandler requestHandler(session); RequestHandler requestHandler(session);
if (executionType == RequestBatchExecutionType::SerialRealtime) { if (executionType == RequestBatchExecutionType::SerialRealtime) {
@ -207,9 +216,12 @@ void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecut
json result = ConstructRequestResult(requestResult, requestJson); json result = ConstructRequestResult(requestResult, requestJson);
results.push_back(result); results.push_back(result);
if (haltOnFailure && requestResult.StatusCode != RequestStatus::Success)
break;
} }
} else if (executionType == RequestBatchExecutionType::SerialFrame) { } else if (executionType == RequestBatchExecutionType::SerialFrame) {
SerialFrameBatch serialFrameBatch(requestHandler, variables); SerialFrameBatch serialFrameBatch(requestHandler, variables, haltOnFailure);
// Create Request objects in the worker thread (avoid unnecessary processing in graphics thread) // Create Request objects in the worker thread (avoid unnecessary processing in graphics thread)
for (auto requestJson : requests) { for (auto requestJson : requests) {