From ac94020b48cf418a432443252ac341a8037adce8 Mon Sep 17 00:00:00 2001 From: Joshua Barretto <joshua.s.barretto@gmail.com> Date: Mon, 25 Feb 2019 16:45:54 +0000 Subject: [PATCH] Netcode usability improvements Former-commit-id: a0e0d0b6fd503b4de9c60d2fcb02562def98c75e --- common/src/net/error.rs | 54 +++++++++++++++++++-- common/src/net/postbox.rs | 89 ++++++++++++++++++++++++----------- common/src/net/postoffice.rs | 73 ++++++++++++++++++++--------- common/src/net/test.rs | 91 ++++++++++++++++++++---------------- 4 files changed, 215 insertions(+), 92 deletions(-) diff --git a/common/src/net/error.rs b/common/src/net/error.rs index cd86e809bf..c3dd162fc9 100644 --- a/common/src/net/error.rs +++ b/common/src/net/error.rs @@ -1,26 +1,74 @@ #[derive(Debug)] pub enum PostError { + InvalidMessage, + InternalError, + Disconnected, +} + +#[derive(Debug)] +pub enum PostErrorInternal { Io(std::io::Error), Serde(bincode::Error), ChannelRecv(std::sync::mpsc::TryRecvError), ChannelSend, // Empty because I couldn't figure out how to handle generic type in mpsc::TrySendError properly MsgSizeLimitExceeded, + MioError, } +impl<'a, T: Into<&'a PostErrorInternal>> From<T> for PostError { + fn from(err: T) -> Self { + match err.into() { + // TODO: Are I/O errors always disconnect errors? + PostErrorInternal::Io(_) => PostError::Disconnected, + PostErrorInternal::Serde(_) => PostError::InvalidMessage, + PostErrorInternal::MsgSizeLimitExceeded => PostError::InvalidMessage, + PostErrorInternal::MioError => PostError::InternalError, + PostErrorInternal::ChannelRecv(_) => PostError::InternalError, + PostErrorInternal::ChannelSend => PostError::InternalError, + } + } +} + +impl From<PostErrorInternal> for PostError { + fn from(err: PostErrorInternal) -> Self { + (&err).into() + } +} + +impl From<std::io::Error> for PostErrorInternal { + fn from(err: std::io::Error) -> Self { + PostErrorInternal::Io(err) + } +} + +impl From<bincode::Error> for PostErrorInternal { + fn from(err: bincode::Error) -> Self { + PostErrorInternal::Serde(err) + } +} + +impl From<std::sync::mpsc::TryRecvError> for PostErrorInternal { + fn from(err: std::sync::mpsc::TryRecvError) -> Self { + PostErrorInternal::ChannelRecv(err) + } +} + + + impl From<std::io::Error> for PostError { fn from(err: std::io::Error) -> Self { - PostError::Io(err) + (&PostErrorInternal::from(err)).into() } } impl From<bincode::Error> for PostError { fn from(err: bincode::Error) -> Self { - PostError::Serde(err) + (&PostErrorInternal::from(err)).into() } } impl From<std::sync::mpsc::TryRecvError> for PostError { fn from(err: std::sync::mpsc::TryRecvError) -> Self { - PostError::ChannelRecv(err) + (&PostErrorInternal::from(err)).into() } } diff --git a/common/src/net/postbox.rs b/common/src/net/postbox.rs index 08d6eb254e..87a41da4dd 100644 --- a/common/src/net/postbox.rs +++ b/common/src/net/postbox.rs @@ -1,10 +1,15 @@ // Standard -use std::collections::VecDeque; -use std::convert::TryFrom; -use std::io::ErrorKind; -use std::io::Read; -use std::net::SocketAddr; -use std::thread; +use std::{ + collections::VecDeque, + convert::TryFrom, + io::{ + ErrorKind, + Read, + }, + net::SocketAddr, + thread, + time::Duration, +}; // External use bincode; @@ -12,9 +17,15 @@ use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token}; use mio_extras::channel::{channel, Receiver, Sender}; // Crate -use super::data::ControlMsg; -use super::error::PostError; -use super::{PostRecv, PostSend}; +use super::{ + data::ControlMsg, + error::{ + PostError, + PostErrorInternal, + }, + PostRecv, + PostSend, +}; // Constants const CTRL_TOKEN: Token = Token(0); // Token for thread control messages @@ -31,9 +42,10 @@ where { handle: Option<thread::JoinHandle<()>>, ctrl: Sender<ControlMsg>, - recv: Receiver<Result<R, PostError>>, + recv: Receiver<Result<R, PostErrorInternal>>, send: Sender<S>, poll: Poll, + err: Option<PostErrorInternal>, } impl<S, R> PostBox<S, R> @@ -42,16 +54,16 @@ where R: PostRecv, { /// Creates a new [`PostBox`] connected to specified address, meant to be used by the client - pub fn to_server(addr: &SocketAddr) -> Result<PostBox<S, R>, PostError> { - let connection = TcpStream::connect(addr)?; + pub fn to_server<A: Into<SocketAddr>>(addr: A) -> Result<PostBox<S, R>, PostError> { + let connection = TcpStream::connect(&addr.into())?; Self::from_tcpstream(connection) } /// Creates a new [`PostBox`] from an existing connection, meant to be used by [`PostOffice`](super::PostOffice) on the server pub fn from_tcpstream(connection: TcpStream) -> Result<PostBox<S, R>, PostError> { - let (ctrl_tx, ctrl_rx) = channel::<ControlMsg>(); // Control messages - let (send_tx, send_rx) = channel::<S>(); // main thread -[data to be serialized and sent]> worker thread - let (recv_tx, recv_rx) = channel::<Result<R, PostError>>(); // main thread <[received and deserialized data]- worker thread + let (ctrl_tx, ctrl_rx) = channel(); // Control messages + let (send_tx, send_rx) = channel(); // main thread -[data to be serialized and sent]> worker thread + let (recv_tx, recv_rx) = channel(); // main thread <[received and deserialized data]- worker thread let thread_poll = Poll::new().unwrap(); let postbox_poll = Poll::new().unwrap(); thread_poll @@ -75,31 +87,54 @@ where recv: recv_rx, send: send_tx, poll: postbox_poll, + err: None, }) } + /// Return an `Option<PostError>` indicating the current status of the `PostBox`. + pub fn status(&self) -> Option<PostError> { + self.err.as_ref().map(|err| err.into()) + } + /// Non-blocking sender method - pub fn send(&self, data: S) { - self.send.send(data).unwrap_or(()); + pub fn send(&mut self, data: S) -> Result<(), PostError> { + match &mut self.err { + err @ None => if let Err(_) = self.send.send(data) { + *err = Some(PostErrorInternal::MioError); + Err(err.as_ref().unwrap().into()) + } else { + Ok(()) + }, + err => Err(err.as_ref().unwrap().into()), + } } /// Non-blocking receiver method returning an iterator over already received and deserialized objects /// # Errors /// If the other side disconnects PostBox won't realize that until you try to send something - pub fn recv_iter(&self) -> Result<impl Iterator<Item = Result<R, PostError>>, PostError> { + pub fn recv_iter(&mut self) -> impl Iterator<Item = R> { let mut events = Events::with_capacity(4096); - self.poll - .poll(&mut events, Some(core::time::Duration::new(0, 0)))?; - let mut data: VecDeque<Result<R, PostError>> = VecDeque::new(); + let mut items = VecDeque::new(); + + // If an error occured, or previously occured, just give up + if let Some(_) = self.err { + return items.into_iter(); + } else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) { + self.err = Some(err.into()); + return items.into_iter(); + } + for event in events { match event.token() { - DATA_TOKEN => { - data.push_back(self.recv.try_recv()?); - } + DATA_TOKEN => match self.recv.try_recv() { + Ok(Ok(item)) => items.push_back(item), + Err(err) => self.err = Some(err.into()), + Ok(Err(err)) => self.err = Some(err.into()), + }, _ => (), } } - Ok(data.into_iter()) + items.into_iter() } } @@ -107,7 +142,7 @@ fn postbox_thread<S, R>( mut connection: TcpStream, ctrl_rx: Receiver<ControlMsg>, send_rx: Receiver<S>, - recv_tx: Sender<Result<R, PostError>>, + recv_tx: Sender<Result<R, PostErrorInternal>>, poll: Poll, ) where S: PostSend, @@ -154,7 +189,7 @@ fn postbox_thread<S, R>( .unwrap(), ); if recv_nextlen > MESSAGE_SIZE_CAP { - recv_tx.send(Err(PostError::MsgSizeLimitExceeded)).unwrap(); + recv_tx.send(Err(PostErrorInternal::MsgSizeLimitExceeded)).unwrap(); connection.shutdown(std::net::Shutdown::Both).unwrap(); recv_buff.drain(..); recv_nextlen = 0; diff --git a/common/src/net/postoffice.rs b/common/src/net/postoffice.rs index c1e718ac30..479f940419 100644 --- a/common/src/net/postoffice.rs +++ b/common/src/net/postoffice.rs @@ -1,18 +1,26 @@ // Standard use core::time::Duration; -use std::collections::VecDeque; -use std::net::SocketAddr; -use std::thread; +use std::{ + collections::VecDeque, + net::SocketAddr, + thread, +}; // External use mio::{net::TcpListener, Events, Poll, PollOpt, Ready, Token}; use mio_extras::channel::{channel, Receiver, Sender}; // Crate -use super::data::ControlMsg; -use super::error::PostError; -use super::postbox::PostBox; -use super::{PostRecv, PostSend}; +use super::{ + data::ControlMsg, + error::{ + PostError, + PostErrorInternal, + }, + postbox::PostBox, + PostRecv, + PostSend, +}; // Constants const CTRL_TOKEN: Token = Token(0); // Token for thread control messages @@ -28,8 +36,9 @@ where { handle: Option<thread::JoinHandle<()>>, ctrl: Sender<ControlMsg>, - recv: Receiver<Result<PostBox<S, R>, PostError>>, + recv: Receiver<Result<PostBox<S, R>, PostErrorInternal>>, poll: Poll, + err: Option<PostErrorInternal>, } impl<S, R> PostOffice<S, R> @@ -38,49 +47,69 @@ where R: PostRecv, { /// Creates a new [`PostOffice`] listening on specified address - pub fn new(addr: &SocketAddr) -> Result<Self, PostError> { - let listener = TcpListener::bind(addr)?; + pub fn new<A: Into<SocketAddr>>(addr: A) -> Result<Self, PostError> { + let listener = TcpListener::bind(&addr.into())?; let (ctrl_tx, ctrl_rx) = channel(); let (recv_tx, recv_rx) = channel(); + let thread_poll = Poll::new()?; let postbox_poll = Poll::new()?; thread_poll.register(&listener, CONN_TOKEN, Ready::readable(), PollOpt::edge())?; thread_poll.register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge())?; postbox_poll.register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge())?; + let handle = thread::Builder::new() .name("postoffice_worker".into()) .spawn(move || postoffice_thread(listener, ctrl_rx, recv_tx, thread_poll))?; + Ok(PostOffice { handle: Some(handle), ctrl: ctrl_tx, recv: recv_rx, poll: postbox_poll, + err: None, }) } + /// Return an `Option<PostError>` indicating the current status of the `PostOffice`. + pub fn status(&self) -> Option<PostError> { + self.err.as_ref().map(|err| err.into()) + } + /// Non-blocking method returning an iterator over new connections wrapped in [`PostBox`]es - pub fn get_iter( - &self, - ) -> Result<impl Iterator<Item = Result<PostBox<S, R>, PostError>>, PostError> { + pub fn new_connections( + &mut self, + ) -> impl Iterator<Item = PostBox<S, R>> { let mut events = Events::with_capacity(256); - self.poll.poll(&mut events, Some(Duration::new(0, 0)))?; - let mut conns: VecDeque<Result<PostBox<S, R>, PostError>> = VecDeque::new(); + let mut conns = VecDeque::new(); + + // If an error occured, or previously occured, just give up + if let Some(_) = self.err { + return conns.into_iter(); + } else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) { + self.err = Some(err.into()); + return conns.into_iter(); + } + for event in events { match event.token() { - DATA_TOKEN => { - conns.push_back(self.recv.try_recv()?); - } + // Ignore recv error + DATA_TOKEN => match self.recv.try_recv() { + Ok(Ok(conn)) => conns.push_back(conn), + Err(err) => self.err = Some(err.into()), + Ok(Err(err)) => self.err = Some(err.into()), + }, _ => (), } } - Ok(conns.into_iter()) + conns.into_iter() } } fn postoffice_thread<S, R>( listener: TcpListener, ctrl_rx: Receiver<ControlMsg>, - recv_tx: Sender<Result<PostBox<S, R>, PostError>>, + recv_tx: Sender<Result<PostBox<S, R>, PostErrorInternal>>, poll: Poll, ) where S: PostSend, @@ -96,7 +125,9 @@ fn postoffice_thread<S, R>( }, CONN_TOKEN => { let (conn, _addr) = listener.accept().unwrap(); - recv_tx.send(PostBox::from_tcpstream(conn)).unwrap(); + recv_tx.send(PostBox::from_tcpstream(conn) + // TODO: Is it okay to count a failure to create a postbox here as an 'internal error'? + .map_err(|_| PostErrorInternal::MioError)).unwrap(); } _ => (), } diff --git a/common/src/net/test.rs b/common/src/net/test.rs index 1bcc7d2a9b..98e03142f3 100644 --- a/common/src/net/test.rs +++ b/common/src/net/test.rs @@ -1,69 +1,78 @@ -use std::io::Write; -use std::net::SocketAddr; +use std::{ + io::Write, + str::FromStr, + net::SocketAddr, + thread, + time::Duration, +}; use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token}; use super::{error::PostError, PostBox, PostOffice}; +fn new_local_addr(n: u16) -> SocketAddr { + SocketAddr::from(([127, 0, 0, 1], 12345 + n)) +} + #[test] fn basic_run() { - let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12345u16)); - let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12345u16)); - let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); - let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap(); + let srv_addr = new_local_addr(0); + let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap(); + let mut client: PostBox<String, String> = PostBox::to_server(srv_addr).unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); - let scon = server.get_iter().unwrap().next().unwrap().unwrap(); + let mut scon = server.new_connections().next().unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); - scon.send(String::from("foo")); - client.send(String::from("bar")); + scon.send(String::from("foo")).unwrap(); + client.send(String::from("bar")).unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); - assert_eq!("foo", client.recv_iter().unwrap().next().unwrap().unwrap()); - assert_eq!("bar", scon.recv_iter().unwrap().next().unwrap().unwrap()); + assert_eq!("foo", client.recv_iter().next().unwrap()); + assert_eq!("bar", scon.recv_iter().next().unwrap()); } #[test] fn huge_size_header() { - let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12346u16)); - let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12346u16)); - let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); - let mut client = TcpStream::connect(&conn_addr).unwrap(); + let srv_addr = new_local_addr(1); + + let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap(); + let mut client = TcpStream::connect(&srv_addr).unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); - let scon = server.get_iter().unwrap().next().unwrap().unwrap(); + let mut scon = server.new_connections().next().unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); client.write(&[0xffu8; 64]).unwrap(); std::thread::sleep(std::time::Duration::from_millis(10)); - assert!(match scon.recv_iter().unwrap().next().unwrap() { - Err(PostError::MsgSizeLimitExceeded) => true, - _ => false, - }); + assert_eq!(scon.recv_iter().next(), None); } #[test] fn disconnect() { - let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12347u16)); - let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12347u16)); - let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); + let srv_addr = new_local_addr(2); + + let mut server = PostOffice::<_, String>::new(srv_addr) + .unwrap(); + + // Create then close client { - #[allow(unused_variables)] - let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap(); + PostBox::<String, String>::to_server(srv_addr).unwrap(); } - std::thread::sleep(std::time::Duration::from_millis(10)); - let scon = server.get_iter().unwrap().next().unwrap().unwrap(); - scon.send(String::from("foo")); + std::thread::sleep(std::time::Duration::from_millis(10)); - match scon.recv_iter().unwrap().next().unwrap() { - Ok(_) => panic!("Didn't expect to receive anything"), - Err(err) => { - if !(match err { - PostError::Io(e) => e, - _ => panic!("PostError different than expected"), - } - .kind() - == std::io::ErrorKind::BrokenPipe) - { - panic!("Error different than disconnection") - } - } + let mut to_client = server + .new_connections() + .next() + .unwrap(); + + to_client.send(String::from("foo")).unwrap(); + + thread::sleep(Duration::from_millis(10)); + + match to_client.recv_iter().next() { + None => {}, + _ => panic!("Unexpected message!"), + } + + match to_client.status() { + Some(PostError::Disconnected) => {}, + s => panic!("Did not expect {:?}", s), } }