From ac94020b48cf418a432443252ac341a8037adce8 Mon Sep 17 00:00:00 2001
From: Joshua Barretto <joshua.s.barretto@gmail.com>
Date: Mon, 25 Feb 2019 16:45:54 +0000
Subject: [PATCH] Netcode usability improvements

Former-commit-id: a0e0d0b6fd503b4de9c60d2fcb02562def98c75e
---
 common/src/net/error.rs      | 54 +++++++++++++++++++--
 common/src/net/postbox.rs    | 89 ++++++++++++++++++++++++-----------
 common/src/net/postoffice.rs | 73 ++++++++++++++++++++---------
 common/src/net/test.rs       | 91 ++++++++++++++++++++----------------
 4 files changed, 215 insertions(+), 92 deletions(-)

diff --git a/common/src/net/error.rs b/common/src/net/error.rs
index cd86e809bf..c3dd162fc9 100644
--- a/common/src/net/error.rs
+++ b/common/src/net/error.rs
@@ -1,26 +1,74 @@
 #[derive(Debug)]
 pub enum PostError {
+    InvalidMessage,
+    InternalError,
+    Disconnected,
+}
+
+#[derive(Debug)]
+pub enum PostErrorInternal {
     Io(std::io::Error),
     Serde(bincode::Error),
     ChannelRecv(std::sync::mpsc::TryRecvError),
     ChannelSend, // Empty because I couldn't figure out how to handle generic type in mpsc::TrySendError properly
     MsgSizeLimitExceeded,
+    MioError,
 }
 
+impl<'a, T: Into<&'a PostErrorInternal>> From<T> for PostError {
+    fn from(err: T) -> Self {
+        match err.into() {
+            // TODO: Are I/O errors always disconnect errors?
+            PostErrorInternal::Io(_) => PostError::Disconnected,
+            PostErrorInternal::Serde(_) => PostError::InvalidMessage,
+            PostErrorInternal::MsgSizeLimitExceeded => PostError::InvalidMessage,
+            PostErrorInternal::MioError => PostError::InternalError,
+            PostErrorInternal::ChannelRecv(_) => PostError::InternalError,
+            PostErrorInternal::ChannelSend => PostError::InternalError,
+        }
+    }
+}
+
+impl From<PostErrorInternal> for PostError {
+    fn from(err: PostErrorInternal) -> Self {
+        (&err).into()
+    }
+}
+
+impl From<std::io::Error> for PostErrorInternal {
+    fn from(err: std::io::Error) -> Self {
+        PostErrorInternal::Io(err)
+    }
+}
+
+impl From<bincode::Error> for PostErrorInternal {
+    fn from(err: bincode::Error) -> Self {
+        PostErrorInternal::Serde(err)
+    }
+}
+
+impl From<std::sync::mpsc::TryRecvError> for PostErrorInternal {
+    fn from(err: std::sync::mpsc::TryRecvError) -> Self {
+        PostErrorInternal::ChannelRecv(err)
+    }
+}
+
+
+
 impl From<std::io::Error> for PostError {
     fn from(err: std::io::Error) -> Self {
-        PostError::Io(err)
+        (&PostErrorInternal::from(err)).into()
     }
 }
 
 impl From<bincode::Error> for PostError {
     fn from(err: bincode::Error) -> Self {
-        PostError::Serde(err)
+        (&PostErrorInternal::from(err)).into()
     }
 }
 
 impl From<std::sync::mpsc::TryRecvError> for PostError {
     fn from(err: std::sync::mpsc::TryRecvError) -> Self {
-        PostError::ChannelRecv(err)
+        (&PostErrorInternal::from(err)).into()
     }
 }
diff --git a/common/src/net/postbox.rs b/common/src/net/postbox.rs
index 08d6eb254e..87a41da4dd 100644
--- a/common/src/net/postbox.rs
+++ b/common/src/net/postbox.rs
@@ -1,10 +1,15 @@
 // Standard
-use std::collections::VecDeque;
-use std::convert::TryFrom;
-use std::io::ErrorKind;
-use std::io::Read;
-use std::net::SocketAddr;
-use std::thread;
+use std::{
+    collections::VecDeque,
+    convert::TryFrom,
+    io::{
+        ErrorKind,
+        Read,
+    },
+    net::SocketAddr,
+    thread,
+    time::Duration,
+};
 
 // External
 use bincode;
@@ -12,9 +17,15 @@ use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token};
 use mio_extras::channel::{channel, Receiver, Sender};
 
 // Crate
-use super::data::ControlMsg;
-use super::error::PostError;
-use super::{PostRecv, PostSend};
+use super::{
+    data::ControlMsg,
+    error::{
+        PostError,
+        PostErrorInternal,
+    },
+    PostRecv,
+    PostSend,
+};
 
 // Constants
 const CTRL_TOKEN: Token = Token(0); // Token for thread control messages
@@ -31,9 +42,10 @@ where
 {
     handle: Option<thread::JoinHandle<()>>,
     ctrl: Sender<ControlMsg>,
-    recv: Receiver<Result<R, PostError>>,
+    recv: Receiver<Result<R, PostErrorInternal>>,
     send: Sender<S>,
     poll: Poll,
+    err: Option<PostErrorInternal>,
 }
 
 impl<S, R> PostBox<S, R>
@@ -42,16 +54,16 @@ where
     R: PostRecv,
 {
     /// Creates a new [`PostBox`] connected to specified address, meant to be used by the client
-    pub fn to_server(addr: &SocketAddr) -> Result<PostBox<S, R>, PostError> {
-        let connection = TcpStream::connect(addr)?;
+    pub fn to_server<A: Into<SocketAddr>>(addr: A) -> Result<PostBox<S, R>, PostError> {
+        let connection = TcpStream::connect(&addr.into())?;
         Self::from_tcpstream(connection)
     }
 
     /// Creates a new [`PostBox`] from an existing connection, meant to be used by [`PostOffice`](super::PostOffice) on the server
     pub fn from_tcpstream(connection: TcpStream) -> Result<PostBox<S, R>, PostError> {
-        let (ctrl_tx, ctrl_rx) = channel::<ControlMsg>(); // Control messages
-        let (send_tx, send_rx) = channel::<S>(); // main thread -[data to be serialized and sent]> worker thread
-        let (recv_tx, recv_rx) = channel::<Result<R, PostError>>(); // main thread <[received and deserialized data]- worker thread
+        let (ctrl_tx, ctrl_rx) = channel(); // Control messages
+        let (send_tx, send_rx) = channel(); // main thread -[data to be serialized and sent]> worker thread
+        let (recv_tx, recv_rx) = channel(); // main thread <[received and deserialized data]- worker thread
         let thread_poll = Poll::new().unwrap();
         let postbox_poll = Poll::new().unwrap();
         thread_poll
@@ -75,31 +87,54 @@ where
             recv: recv_rx,
             send: send_tx,
             poll: postbox_poll,
+            err: None,
         })
     }
 
+    /// Return an `Option<PostError>` indicating the current status of the `PostBox`.
+    pub fn status(&self) -> Option<PostError> {
+        self.err.as_ref().map(|err| err.into())
+    }
+
     /// Non-blocking sender method
-    pub fn send(&self, data: S) {
-        self.send.send(data).unwrap_or(());
+    pub fn send(&mut self, data: S) -> Result<(), PostError> {
+        match &mut self.err {
+            err @ None => if let Err(_) = self.send.send(data) {
+                *err = Some(PostErrorInternal::MioError);
+                Err(err.as_ref().unwrap().into())
+            } else {
+                Ok(())
+            },
+            err => Err(err.as_ref().unwrap().into()),
+        }
     }
 
     /// Non-blocking receiver method returning an iterator over already received and deserialized objects
     /// # Errors
     /// If the other side disconnects PostBox won't realize that until you try to send something
-    pub fn recv_iter(&self) -> Result<impl Iterator<Item = Result<R, PostError>>, PostError> {
+    pub fn recv_iter(&mut self) -> impl Iterator<Item = R> {
         let mut events = Events::with_capacity(4096);
-        self.poll
-            .poll(&mut events, Some(core::time::Duration::new(0, 0)))?;
-        let mut data: VecDeque<Result<R, PostError>> = VecDeque::new();
+        let mut items = VecDeque::new();
+
+        // If an error occured, or previously occured, just give up
+        if let Some(_) = self.err {
+            return items.into_iter();
+        } else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) {
+            self.err = Some(err.into());
+            return items.into_iter();
+        }
+
         for event in events {
             match event.token() {
-                DATA_TOKEN => {
-                    data.push_back(self.recv.try_recv()?);
-                }
+                DATA_TOKEN => match self.recv.try_recv() {
+                    Ok(Ok(item)) => items.push_back(item),
+                    Err(err) => self.err = Some(err.into()),
+                    Ok(Err(err)) => self.err = Some(err.into()),
+                },
                 _ => (),
             }
         }
-        Ok(data.into_iter())
+        items.into_iter()
     }
 }
 
@@ -107,7 +142,7 @@ fn postbox_thread<S, R>(
     mut connection: TcpStream,
     ctrl_rx: Receiver<ControlMsg>,
     send_rx: Receiver<S>,
-    recv_tx: Sender<Result<R, PostError>>,
+    recv_tx: Sender<Result<R, PostErrorInternal>>,
     poll: Poll,
 ) where
     S: PostSend,
@@ -154,7 +189,7 @@ fn postbox_thread<S, R>(
                         .unwrap(),
                 );
                 if recv_nextlen > MESSAGE_SIZE_CAP {
-                    recv_tx.send(Err(PostError::MsgSizeLimitExceeded)).unwrap();
+                    recv_tx.send(Err(PostErrorInternal::MsgSizeLimitExceeded)).unwrap();
                     connection.shutdown(std::net::Shutdown::Both).unwrap();
                     recv_buff.drain(..);
                     recv_nextlen = 0;
diff --git a/common/src/net/postoffice.rs b/common/src/net/postoffice.rs
index c1e718ac30..479f940419 100644
--- a/common/src/net/postoffice.rs
+++ b/common/src/net/postoffice.rs
@@ -1,18 +1,26 @@
 // Standard
 use core::time::Duration;
-use std::collections::VecDeque;
-use std::net::SocketAddr;
-use std::thread;
+use std::{
+    collections::VecDeque,
+    net::SocketAddr,
+    thread,
+};
 
 // External
 use mio::{net::TcpListener, Events, Poll, PollOpt, Ready, Token};
 use mio_extras::channel::{channel, Receiver, Sender};
 
 // Crate
-use super::data::ControlMsg;
-use super::error::PostError;
-use super::postbox::PostBox;
-use super::{PostRecv, PostSend};
+use super::{
+    data::ControlMsg,
+    error::{
+        PostError,
+        PostErrorInternal,
+    },
+    postbox::PostBox,
+    PostRecv,
+    PostSend,
+};
 
 // Constants
 const CTRL_TOKEN: Token = Token(0); // Token for thread control messages
@@ -28,8 +36,9 @@ where
 {
     handle: Option<thread::JoinHandle<()>>,
     ctrl: Sender<ControlMsg>,
-    recv: Receiver<Result<PostBox<S, R>, PostError>>,
+    recv: Receiver<Result<PostBox<S, R>, PostErrorInternal>>,
     poll: Poll,
+    err: Option<PostErrorInternal>,
 }
 
 impl<S, R> PostOffice<S, R>
@@ -38,49 +47,69 @@ where
     R: PostRecv,
 {
     /// Creates a new [`PostOffice`] listening on specified address
-    pub fn new(addr: &SocketAddr) -> Result<Self, PostError> {
-        let listener = TcpListener::bind(addr)?;
+    pub fn new<A: Into<SocketAddr>>(addr: A) -> Result<Self, PostError> {
+        let listener = TcpListener::bind(&addr.into())?;
         let (ctrl_tx, ctrl_rx) = channel();
         let (recv_tx, recv_rx) = channel();
+
         let thread_poll = Poll::new()?;
         let postbox_poll = Poll::new()?;
         thread_poll.register(&listener, CONN_TOKEN, Ready::readable(), PollOpt::edge())?;
         thread_poll.register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge())?;
         postbox_poll.register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge())?;
+
         let handle = thread::Builder::new()
             .name("postoffice_worker".into())
             .spawn(move || postoffice_thread(listener, ctrl_rx, recv_tx, thread_poll))?;
+
         Ok(PostOffice {
             handle: Some(handle),
             ctrl: ctrl_tx,
             recv: recv_rx,
             poll: postbox_poll,
+            err: None,
         })
     }
 
+    /// Return an `Option<PostError>` indicating the current status of the `PostOffice`.
+    pub fn status(&self) -> Option<PostError> {
+        self.err.as_ref().map(|err| err.into())
+    }
+
     /// Non-blocking method returning an iterator over new connections wrapped in [`PostBox`]es
-    pub fn get_iter(
-        &self,
-    ) -> Result<impl Iterator<Item = Result<PostBox<S, R>, PostError>>, PostError> {
+    pub fn new_connections(
+        &mut self,
+    ) -> impl Iterator<Item = PostBox<S, R>> {
         let mut events = Events::with_capacity(256);
-        self.poll.poll(&mut events, Some(Duration::new(0, 0)))?;
-        let mut conns: VecDeque<Result<PostBox<S, R>, PostError>> = VecDeque::new();
+        let mut conns = VecDeque::new();
+
+        // If an error occured, or previously occured, just give up
+        if let Some(_) = self.err {
+            return conns.into_iter();
+        } else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) {
+            self.err = Some(err.into());
+            return conns.into_iter();
+        }
+
         for event in events {
             match event.token() {
-                DATA_TOKEN => {
-                    conns.push_back(self.recv.try_recv()?);
-                }
+                // Ignore recv error
+                DATA_TOKEN => match self.recv.try_recv() {
+                    Ok(Ok(conn)) => conns.push_back(conn),
+                    Err(err) => self.err = Some(err.into()),
+                    Ok(Err(err)) => self.err = Some(err.into()),
+                },
                 _ => (),
             }
         }
-        Ok(conns.into_iter())
+        conns.into_iter()
     }
 }
 
 fn postoffice_thread<S, R>(
     listener: TcpListener,
     ctrl_rx: Receiver<ControlMsg>,
-    recv_tx: Sender<Result<PostBox<S, R>, PostError>>,
+    recv_tx: Sender<Result<PostBox<S, R>, PostErrorInternal>>,
     poll: Poll,
 ) where
     S: PostSend,
@@ -96,7 +125,9 @@ fn postoffice_thread<S, R>(
                 },
                 CONN_TOKEN => {
                     let (conn, _addr) = listener.accept().unwrap();
-                    recv_tx.send(PostBox::from_tcpstream(conn)).unwrap();
+                    recv_tx.send(PostBox::from_tcpstream(conn)
+                        // TODO: Is it okay to count a failure to create a postbox here as an 'internal error'?
+                        .map_err(|_| PostErrorInternal::MioError)).unwrap();
                 }
                 _ => (),
             }
diff --git a/common/src/net/test.rs b/common/src/net/test.rs
index 1bcc7d2a9b..98e03142f3 100644
--- a/common/src/net/test.rs
+++ b/common/src/net/test.rs
@@ -1,69 +1,78 @@
-use std::io::Write;
-use std::net::SocketAddr;
+use std::{
+    io::Write,
+    str::FromStr,
+    net::SocketAddr,
+    thread,
+    time::Duration,
+};
 
 use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token};
 
 use super::{error::PostError, PostBox, PostOffice};
 
+fn new_local_addr(n: u16) -> SocketAddr {
+    SocketAddr::from(([127, 0, 0, 1], 12345 + n))
+}
+
 #[test]
 fn basic_run() {
-    let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12345u16));
-    let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12345u16));
-    let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap();
-    let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap();
+    let srv_addr = new_local_addr(0);
+    let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap();
+    let mut client: PostBox<String, String> = PostBox::to_server(srv_addr).unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
-    let scon = server.get_iter().unwrap().next().unwrap().unwrap();
+    let mut scon = server.new_connections().next().unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
-    scon.send(String::from("foo"));
-    client.send(String::from("bar"));
+    scon.send(String::from("foo")).unwrap();
+    client.send(String::from("bar")).unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
-    assert_eq!("foo", client.recv_iter().unwrap().next().unwrap().unwrap());
-    assert_eq!("bar", scon.recv_iter().unwrap().next().unwrap().unwrap());
+    assert_eq!("foo", client.recv_iter().next().unwrap());
+    assert_eq!("bar", scon.recv_iter().next().unwrap());
 }
 
 #[test]
 fn huge_size_header() {
-    let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12346u16));
-    let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12346u16));
-    let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap();
-    let mut client = TcpStream::connect(&conn_addr).unwrap();
+    let srv_addr = new_local_addr(1);
+
+    let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap();
+    let mut client = TcpStream::connect(&srv_addr).unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
-    let scon = server.get_iter().unwrap().next().unwrap().unwrap();
+    let mut scon = server.new_connections().next().unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
     client.write(&[0xffu8; 64]).unwrap();
     std::thread::sleep(std::time::Duration::from_millis(10));
-    assert!(match scon.recv_iter().unwrap().next().unwrap() {
-        Err(PostError::MsgSizeLimitExceeded) => true,
-        _ => false,
-    });
+    assert_eq!(scon.recv_iter().next(), None);
 }
 
 #[test]
 fn disconnect() {
-    let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12347u16));
-    let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12347u16));
-    let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap();
+    let srv_addr = new_local_addr(2);
+
+    let mut server = PostOffice::<_, String>::new(srv_addr)
+        .unwrap();
+
+    // Create then close client
     {
-        #[allow(unused_variables)]
-        let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap();
+        PostBox::<String, String>::to_server(srv_addr).unwrap();
     }
-    std::thread::sleep(std::time::Duration::from_millis(10));
-    let scon = server.get_iter().unwrap().next().unwrap().unwrap();
-    scon.send(String::from("foo"));
+
     std::thread::sleep(std::time::Duration::from_millis(10));
 
-    match scon.recv_iter().unwrap().next().unwrap() {
-        Ok(_) => panic!("Didn't expect to receive anything"),
-        Err(err) => {
-            if !(match err {
-                PostError::Io(e) => e,
-                _ => panic!("PostError different than expected"),
-            }
-            .kind()
-                == std::io::ErrorKind::BrokenPipe)
-            {
-                panic!("Error different than disconnection")
-            }
-        }
+    let mut to_client = server
+        .new_connections()
+        .next()
+        .unwrap();
+
+    to_client.send(String::from("foo")).unwrap();
+
+    thread::sleep(Duration::from_millis(10));
+
+    match to_client.recv_iter().next() {
+        None => {},
+        _ => panic!("Unexpected message!"),
+    }
+
+    match to_client.status() {
+        Some(PostError::Disconnected) => {},
+        s => panic!("Did not expect {:?}", s),
     }
 }