From 1da0214201f96e6e53dfa88701cb8f3721946bde Mon Sep 17 00:00:00 2001 From: tt2468 Date: Mon, 25 Apr 2022 21:31:52 -0700 Subject: [PATCH] Config, websocketserver: Add feature to bind to loopback (default) Binds to localhost or 127.0.0.1 by default, since most users don't have to access obs-websocket externally. --- src/Config.cpp | 5 +++++ src/Config.h | 1 + src/utils/Platform.cpp | 15 +++++++++++++++ src/utils/Platform.h | 1 + src/websocketserver/WebSocketServer.cpp | 16 ++++++++++++---- 5 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/Config.cpp b/src/Config.cpp index 2243183b..91ee476f 100644 --- a/src/Config.cpp +++ b/src/Config.cpp @@ -28,6 +28,7 @@ with this program. If not, see #define PARAM_FIRSTLOAD "FirstLoad" #define PARAM_ENABLED "ServerEnabled" #define PARAM_PORT "ServerPort" +#define PARAM_BINDLOOPBACK "BindLoopback" #define PARAM_ALERTS "AlertsEnabled" #define PARAM_AUTHREQUIRED "AuthRequired" #define PARAM_PASSWORD "ServerPassword" @@ -43,6 +44,7 @@ Config::Config() : FirstLoad(true), ServerEnabled(true), ServerPort(4455), + BindLoopback(true), Ipv4Only(false), DebugEnabled(false), AlertsEnabled(false), @@ -64,6 +66,7 @@ void Config::Load() ServerEnabled = config_get_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_ENABLED); AlertsEnabled = config_get_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_ALERTS); ServerPort = config_get_uint(obsConfig, CONFIG_SECTION_NAME, PARAM_PORT); + BindLoopback = config_get_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_BINDLOOPBACK); AuthRequired = config_get_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_AUTHREQUIRED); ServerPassword = config_get_string(obsConfig, CONFIG_SECTION_NAME, PARAM_PASSWORD); @@ -131,6 +134,7 @@ void Config::Save() if (!PortOverridden) { config_set_uint(obsConfig, CONFIG_SECTION_NAME, PARAM_PORT, ServerPort); } + config_set_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_BINDLOOPBACK, BindLoopback); config_set_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_ALERTS, AlertsEnabled); if (!PasswordOverridden) { config_set_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_AUTHREQUIRED, AuthRequired); @@ -151,6 +155,7 @@ void Config::SetDefaultsToGlobalStore() config_set_default_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_FIRSTLOAD, FirstLoad); config_set_default_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_ENABLED, ServerEnabled); config_set_default_uint(obsConfig, CONFIG_SECTION_NAME, PARAM_PORT, ServerPort); + config_set_default_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_BINDLOOPBACK, BindLoopback); config_set_default_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_ALERTS, AlertsEnabled); config_set_default_bool(obsConfig, CONFIG_SECTION_NAME, PARAM_AUTHREQUIRED, AuthRequired); config_set_default_string(obsConfig, CONFIG_SECTION_NAME, PARAM_PASSWORD, QT_TO_UTF8(ServerPassword)); diff --git a/src/Config.h b/src/Config.h index 17533223..3b3f4a54 100644 --- a/src/Config.h +++ b/src/Config.h @@ -38,6 +38,7 @@ struct Config { std::atomic FirstLoad; std::atomic ServerEnabled; std::atomic ServerPort; + std::atomic BindLoopback; std::atomic Ipv4Only; std::atomic DebugEnabled; std::atomic AlertsEnabled; diff --git a/src/utils/Platform.cpp b/src/utils/Platform.cpp index 9079d88b..83b5f897 100644 --- a/src/utils/Platform.cpp +++ b/src/utils/Platform.cpp @@ -76,6 +76,21 @@ std::string Utils::Platform::GetLocalAddress() return preferredAddresses[0].first.toStdString(); } +std::string Utils::Platform::GetLoopbackAddress(bool allowIpv6) +{ + std::vector validAddresses; + for (auto address : QNetworkInterface::allAddresses()) { + if (address == QHostAddress::LocalHost) + return address.toString().toStdString(); + else if (address == QHostAddress::LocalHostIPv6 && allowIpv6) + return address.toString().toStdString(); + else if (address.isLoopback()) + return address.toString().toStdString(); + } + + return ""; +} + QString Utils::Platform::GetCommandLineArgument(QString arg) { QCommandLineParser parser; diff --git a/src/utils/Platform.h b/src/utils/Platform.h index aee2213d..a2692e15 100644 --- a/src/utils/Platform.h +++ b/src/utils/Platform.h @@ -26,6 +26,7 @@ with this program. If not, see namespace Utils { namespace Platform { std::string GetLocalAddress(); + std::string GetLoopbackAddress(bool allowIpv6 = true); QString GetCommandLineArgument(QString arg); bool GetCommandLineFlagSet(QString arg); void SendTrayNotification(QSystemTrayIcon::MessageIcon icon, QString title, QString body); diff --git a/src/websocketserver/WebSocketServer.cpp b/src/websocketserver/WebSocketServer.cpp index 2fa1c40f..ad7b1539 100644 --- a/src/websocketserver/WebSocketServer.cpp +++ b/src/websocketserver/WebSocketServer.cpp @@ -129,17 +129,25 @@ void WebSocketServer::Start() _server.reset(); websocketpp::lib::error_code errorCode; - if (conf->Ipv4Only) { - blog(LOG_INFO, "[WebSocketServer::Start] Locked to IPv4 bindings"); + if (conf->BindLoopback) { + std::string addr = Utils::Platform::GetLoopbackAddress(!conf->Ipv4Only); + if (addr.empty()) { + blog(LOG_ERROR, "[WebSocketServer::Start] Failed to find loopback interface. Server not started."); + return; + } + _server.listen(addr, std::to_string(conf->ServerPort), errorCode); + blog(LOG_INFO, "[WebSocketServer::Start] Locked to loopback interface."); + } else if (conf->Ipv4Only) { _server.listen(websocketpp::lib::asio::ip::tcp::v4(), conf->ServerPort, errorCode); + blog(LOG_INFO, "[WebSocketServer::Start] Locked to IPv4 bindings."); } else { - blog(LOG_INFO, "[WebSocketServer::Start] Not locked to IPv4 bindings"); _server.listen(conf->ServerPort, errorCode); + blog(LOG_INFO, "[WebSocketServer::Start] Not locked to IPv4 bindings."); } if (errorCode) { std::string errorCodeMessage = errorCode.message(); - blog(LOG_INFO, "[WebSocketServer::Start] Listen failed: %s", errorCodeMessage.c_str()); + blog(LOG_ERROR, "[WebSocketServer::Start] Listen failed: %s", errorCodeMessage.c_str()); return; }