From 33c570075c7688484b3482ec6fd6f55fa6347da7 Mon Sep 17 00:00:00 2001
From: Mikayla Fischler <mikayla@ky8.io>
Date: Mon, 17 Apr 2023 19:48:03 -0400
Subject: [PATCH] supervisor code cleanup

---
 supervisor/session/coordinator.lua |  4 +++-
 supervisor/session/plc.lua         |  4 +++-
 supervisor/session/rtu.lua         |  3 ++-
 supervisor/session/svsessions.lua  | 11 +++++------
 supervisor/startup.lua             |  6 +++---
 supervisor/supervisor.lua          |  1 +
 6 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/supervisor/session/coordinator.lua b/supervisor/session/coordinator.lua
index ad706e8..238765f 100644
--- a/supervisor/session/coordinator.lua
+++ b/supervisor/session/coordinator.lua
@@ -173,7 +173,7 @@ function coordinator.new_session(id, in_queue, out_queue, timeout, facility)
     end
 
     -- handle a packet
-    ---@param pkt crdn_frame
+    ---@param pkt mgmt_frame|crdn_frame
     local function _handle_packet(pkt)
         -- check sequence number
         if self.r_seq_num == nil then
@@ -190,6 +190,7 @@ function coordinator.new_session(id, in_queue, out_queue, timeout, facility)
 
         -- process packet
         if pkt.scada_frame.protocol() == PROTOCOL.SCADA_MGMT then
+            ---@cast pkt mgmt_frame
             if pkt.type == SCADA_MGMT_TYPE.KEEP_ALIVE then
                 -- keep alive reply
                 if pkt.length == 2 then
@@ -214,6 +215,7 @@ function coordinator.new_session(id, in_queue, out_queue, timeout, facility)
                 log.debug(log_header .. "handler received unsupported SCADA_MGMT packet type " .. pkt.type)
             end
         elseif pkt.scada_frame.protocol() == PROTOCOL.SCADA_CRDN then
+            ---@cast pkt crdn_frame
             if pkt.type == SCADA_CRDN_TYPE.INITIAL_BUILDS then
                 -- acknowledgement to coordinator receiving builds
                 self.acks.builds = true
diff --git a/supervisor/session/plc.lua b/supervisor/session/plc.lua
index 40efd8c..fd0d0f8 100644
--- a/supervisor/session/plc.lua
+++ b/supervisor/session/plc.lua
@@ -279,7 +279,7 @@ function plc.new_session(id, reactor_id, in_queue, out_queue, timeout)
     end
 
     -- handle a packet
-    ---@param pkt rplc_frame
+    ---@param pkt mgmt_frame|rplc_frame
     local function _handle_packet(pkt)
         -- check sequence number
         if self.r_seq_num == nil then
@@ -293,6 +293,7 @@ function plc.new_session(id, reactor_id, in_queue, out_queue, timeout)
 
         -- process packet
         if pkt.scada_frame.protocol() == PROTOCOL.RPLC then
+            ---@cast pkt rplc_frame
             -- check reactor ID
             if pkt.id ~= reactor_id then
                 log.warning(log_header .. "RPLC packet with ID not matching reactor ID: reactor " .. reactor_id .. " != " .. pkt.id)
@@ -469,6 +470,7 @@ function plc.new_session(id, reactor_id, in_queue, out_queue, timeout)
                 log.debug(log_header .. "handler received unsupported RPLC packet type " .. pkt.type)
             end
         elseif pkt.scada_frame.protocol() == PROTOCOL.SCADA_MGMT then
+            ---@cast pkt mgmt_frame
             if pkt.type == SCADA_MGMT_TYPE.KEEP_ALIVE then
                 -- keep alive reply
                 if pkt.length == 2 then
diff --git a/supervisor/session/rtu.lua b/supervisor/session/rtu.lua
index da5648c..9357c71 100644
--- a/supervisor/session/rtu.lua
+++ b/supervisor/session/rtu.lua
@@ -226,12 +226,13 @@ function rtu.new_session(id, in_queue, out_queue, timeout, advertisement, facili
 
         -- process packet
         if pkt.scada_frame.protocol() == PROTOCOL.MODBUS_TCP then
+            ---@cast pkt modbus_frame
             if self.units[pkt.unit_id] ~= nil then
                 local unit = self.units[pkt.unit_id]    ---@type unit_session
----@diagnostic disable-next-line: param-type-mismatch
                 unit.handle_packet(pkt)
             end
         elseif pkt.scada_frame.protocol() == PROTOCOL.SCADA_MGMT then
+            ---@cast pkt mgmt_frame
             -- handle management packet
             if pkt.type == SCADA_MGMT_TYPE.KEEP_ALIVE then
                 -- keep alive reply
diff --git a/supervisor/session/svsessions.lua b/supervisor/session/svsessions.lua
index aa3506b..9ed462d 100644
--- a/supervisor/session/svsessions.lua
+++ b/supervisor/session/svsessions.lua
@@ -123,7 +123,7 @@ local function _iterate(sessions)
 end
 
 -- cleanly close a session
----@param session plc_session_struct|rtu_session_struct
+---@param session plc_session_struct|rtu_session_struct|coord_session_struct
 local function _shutdown(session)
     session.open = false
     session.instance.close()
@@ -143,10 +143,8 @@ end
 ---@param sessions table
 local function _close(sessions)
     for i = 1, #sessions do
-        local session = sessions[i]  ---@type plc_session_struct|rtu_session_struct
-        if session.open then
-            _shutdown(session)
-        end
+        local session = sessions[i]  ---@type plc_session_struct|rtu_session_struct|coord_session_struct
+        if session.open then _shutdown(session) end
     end
 end
 
@@ -155,7 +153,7 @@ end
 ---@param timer_event number
 local function _check_watchdogs(sessions, timer_event)
     for i = 1, #sessions do
-        local session = sessions[i]  ---@type plc_session_struct|rtu_session_struct
+        local session = sessions[i]  ---@type plc_session_struct|rtu_session_struct|coord_session_struct
         if session.open then
             local triggered = session.instance.check_wd(timer_event)
             if triggered then
@@ -172,6 +170,7 @@ end
 local function _free_closed(sessions)
     local f = function (session) return session.open end
 
+    ---@param session plc_session_struct|rtu_session_struct|coord_session_struct
     local on_delete = function (session)
         log.debug(util.c("free'ing closed ", session.s_type, " session ", session.instance.get_id(),
             " on remote port ", session.r_port))
diff --git a/supervisor/startup.lua b/supervisor/startup.lua
index 87bd902..68db989 100644
--- a/supervisor/startup.lua
+++ b/supervisor/startup.lua
@@ -9,12 +9,12 @@ local log        = require("scada-common.log")
 local ppm        = require("scada-common.ppm")
 local util       = require("scada-common.util")
 
-local svsessions = require("supervisor.session.svsessions")
-
 local config     = require("supervisor.config")
 local supervisor = require("supervisor.supervisor")
 
-local SUPERVISOR_VERSION = "v0.14.4"
+local svsessions = require("supervisor.session.svsessions")
+
+local SUPERVISOR_VERSION = "v0.14.5"
 
 local println = util.println
 local println_ts = util.println_ts
diff --git a/supervisor/supervisor.lua b/supervisor/supervisor.lua
index 937dac8..99226fb 100644
--- a/supervisor/supervisor.lua
+++ b/supervisor/supervisor.lua
@@ -22,6 +22,7 @@ local println = util.println
 ---@param dev_listen integer listening port for PLC/RTU devices
 ---@param coord_listen integer listening port for coordinator
 ---@param range integer trusted device connection range
+---@diagnostic disable-next-line: unused-local
 function supervisor.comms(version, num_reactors, cooling_conf, modem, dev_listen, coord_listen, range)
     local self = {
         last_est_acks = {}