WebSocketServer: Implement haltOnFailure for batch requests

This commit is contained in:
tt2468 2021-12-10 22:28:22 -08:00
parent c9c5da6837
commit 43e2860709
3 changed files with 38 additions and 13 deletions

View File

@ -92,7 +92,7 @@ class WebSocketServer : QObject
void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData);
void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, const json &payloadData);
void ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables);
void ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables, bool haltOnFailure);
QThreadPool _threadPool;

View File

@ -28,7 +28,7 @@ with this program. If not, see <https://www.gnu.org/licenses/>
#include "../utils/Platform.h"
#include "../utils/Compat.h"
bool IsSupportedRpcVersion(uint8_t requestedVersion)
static bool IsSupportedRpcVersion(uint8_t requestedVersion)
{
return (requestedVersion == 1);
}
@ -265,10 +265,23 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces
}
}
bool haltOnFailure = false;
if (payloadData.contains("haltOnFailure") && !payloadData["haltOnFailure"].is_null()) {
if (!payloadData["haltOnFailure"].is_boolean()) {
if (!session->IgnoreInvalidMessages()) {
ret.closeCode = WebSocketCloseCode::InvalidDataFieldType;
ret.closeReason = "Your `haltOnFailure` is not a boolean.";
}
return;
}
haltOnFailure = payloadData["haltOnFailure"];
}
std::vector<json> requests = payloadData["requests"];
json variables = payloadData["variables"];
std::vector<json> results;
ProcessRequestBatch(session, executionType, requests, results, variables);
ProcessRequestBatch(session, executionType, requests, results, variables, haltOnFailure);
ret.result["op"] = WebSocketOpCode::RequestBatchResponse;
ret.result["d"]["requestId"] = payloadData["requestId"];

View File

@ -40,17 +40,20 @@ struct SerialFrameRequest
struct SerialFrameBatch
{
RequestHandler &requestHandler;
json &variables;
size_t frameCount;
size_t sleepUntilFrame;
std::queue<SerialFrameRequest> requests;
std::vector<RequestResult> results;
json &variables;
bool haltOnFailure;
size_t frameCount;
size_t sleepUntilFrame;
std::mutex conditionMutex;
std::condition_variable condition;
SerialFrameBatch(RequestHandler &requestHandler, json &variables) :
SerialFrameBatch(RequestHandler &requestHandler, json &variables, bool haltOnFailure) :
requestHandler(requestHandler),
variables(variables),
haltOnFailure(haltOnFailure),
frameCount(0),
sleepUntilFrame(0)
{}
@ -72,7 +75,7 @@ struct ParallelBatchResults
bool PreProcessVariables(const json &variables, const json &inputVariables, json &requestData)
static bool PreProcessVariables(const json &variables, const json &inputVariables, json &requestData)
{
if (variables.empty() || inputVariables.empty() || !inputVariables.is_object() || !requestData.is_object())
return !requestData.empty();
@ -97,7 +100,7 @@ bool PreProcessVariables(const json &variables, const json &inputVariables, json
return !requestData.empty();
}
void PostProcessVariables(json &variables, const json &outputVariables, const json &responseData)
static void PostProcessVariables(json &variables, const json &outputVariables, const json &responseData)
{
if (outputVariables.empty() || !outputVariables.is_object() || responseData.empty())
return;
@ -120,7 +123,7 @@ void PostProcessVariables(json &variables, const json &outputVariables, const js
}
}
json ConstructRequestResult(RequestResult requestResult, const json &requestJson)
static json ConstructRequestResult(RequestResult requestResult, const json &requestJson)
{
json ret;
@ -143,7 +146,7 @@ json ConstructRequestResult(RequestResult requestResult, const json &requestJson
return ret;
}
void ObsTickCallback(void *param, float)
static void ObsTickCallback(void *param, float)
{
ScopeProfiler prof{"obs_websocket_request_batch_frame_tick"};
@ -177,6 +180,12 @@ void ObsTickCallback(void *param, float)
// Remove from front of queue
serialFrameBatch->requests.pop();
// If haltOnFailure and the request failed, clear the queue to make the batch return early.
if (serialFrameBatch->haltOnFailure && requestResult.StatusCode != RequestStatus::Success) {
serialFrameBatch->requests = std::queue<SerialFrameRequest>();
break;
}
// If the processed request tells us to sleep, do so accordingly
if (requestResult.SleepFrames) {
serialFrameBatch->sleepUntilFrame = serialFrameBatch->frameCount + requestResult.SleepFrames;
@ -190,7 +199,7 @@ void ObsTickCallback(void *param, float)
}
}
void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables)
void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecutionType::RequestBatchExecutionType executionType, const std::vector<json> &requests, std::vector<json> &results, json &variables, bool haltOnFailure)
{
RequestHandler requestHandler(session);
if (executionType == RequestBatchExecutionType::SerialRealtime) {
@ -207,9 +216,12 @@ void WebSocketServer::ProcessRequestBatch(SessionPtr session, RequestBatchExecut
json result = ConstructRequestResult(requestResult, requestJson);
results.push_back(result);
if (haltOnFailure && requestResult.StatusCode != RequestStatus::Success)
break;
}
} else if (executionType == RequestBatchExecutionType::SerialFrame) {
SerialFrameBatch serialFrameBatch(requestHandler, variables);
SerialFrameBatch serialFrameBatch(requestHandler, variables, haltOnFailure);
// Create Request objects in the worker thread (avoid unnecessary processing in graphics thread)
for (auto requestJson : requests) {