Netcode usability improvements

Former-commit-id: a0e0d0b6fd503b4de9c60d2fcb02562def98c75e
This commit is contained in:
Joshua Barretto 2019-02-25 16:45:54 +00:00
parent efb4273ad4
commit 3f5782e993
4 changed files with 215 additions and 92 deletions

View File

@ -1,26 +1,74 @@
#[derive(Debug)] #[derive(Debug)]
pub enum PostError { pub enum PostError {
InvalidMessage,
InternalError,
Disconnected,
}
#[derive(Debug)]
pub enum PostErrorInternal {
Io(std::io::Error), Io(std::io::Error),
Serde(bincode::Error), Serde(bincode::Error),
ChannelRecv(std::sync::mpsc::TryRecvError), ChannelRecv(std::sync::mpsc::TryRecvError),
ChannelSend, // Empty because I couldn't figure out how to handle generic type in mpsc::TrySendError properly ChannelSend, // Empty because I couldn't figure out how to handle generic type in mpsc::TrySendError properly
MsgSizeLimitExceeded, MsgSizeLimitExceeded,
MioError,
} }
impl<'a, T: Into<&'a PostErrorInternal>> From<T> for PostError {
fn from(err: T) -> Self {
match err.into() {
// TODO: Are I/O errors always disconnect errors?
PostErrorInternal::Io(_) => PostError::Disconnected,
PostErrorInternal::Serde(_) => PostError::InvalidMessage,
PostErrorInternal::MsgSizeLimitExceeded => PostError::InvalidMessage,
PostErrorInternal::MioError => PostError::InternalError,
PostErrorInternal::ChannelRecv(_) => PostError::InternalError,
PostErrorInternal::ChannelSend => PostError::InternalError,
}
}
}
impl From<PostErrorInternal> for PostError {
fn from(err: PostErrorInternal) -> Self {
(&err).into()
}
}
impl From<std::io::Error> for PostErrorInternal {
fn from(err: std::io::Error) -> Self {
PostErrorInternal::Io(err)
}
}
impl From<bincode::Error> for PostErrorInternal {
fn from(err: bincode::Error) -> Self {
PostErrorInternal::Serde(err)
}
}
impl From<std::sync::mpsc::TryRecvError> for PostErrorInternal {
fn from(err: std::sync::mpsc::TryRecvError) -> Self {
PostErrorInternal::ChannelRecv(err)
}
}
impl From<std::io::Error> for PostError { impl From<std::io::Error> for PostError {
fn from(err: std::io::Error) -> Self { fn from(err: std::io::Error) -> Self {
PostError::Io(err) (&PostErrorInternal::from(err)).into()
} }
} }
impl From<bincode::Error> for PostError { impl From<bincode::Error> for PostError {
fn from(err: bincode::Error) -> Self { fn from(err: bincode::Error) -> Self {
PostError::Serde(err) (&PostErrorInternal::from(err)).into()
} }
} }
impl From<std::sync::mpsc::TryRecvError> for PostError { impl From<std::sync::mpsc::TryRecvError> for PostError {
fn from(err: std::sync::mpsc::TryRecvError) -> Self { fn from(err: std::sync::mpsc::TryRecvError) -> Self {
PostError::ChannelRecv(err) (&PostErrorInternal::from(err)).into()
} }
} }

View File

@ -1,10 +1,15 @@
// Standard // Standard
use std::collections::VecDeque; use std::{
use std::convert::TryFrom; collections::VecDeque,
use std::io::ErrorKind; convert::TryFrom,
use std::io::Read; io::{
use std::net::SocketAddr; ErrorKind,
use std::thread; Read,
},
net::SocketAddr,
thread,
time::Duration,
};
// External // External
use bincode; use bincode;
@ -12,9 +17,15 @@ use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token};
use mio_extras::channel::{channel, Receiver, Sender}; use mio_extras::channel::{channel, Receiver, Sender};
// Crate // Crate
use super::data::ControlMsg; use super::{
use super::error::PostError; data::ControlMsg,
use super::{PostRecv, PostSend}; error::{
PostError,
PostErrorInternal,
},
PostRecv,
PostSend,
};
// Constants // Constants
const CTRL_TOKEN: Token = Token(0); // Token for thread control messages const CTRL_TOKEN: Token = Token(0); // Token for thread control messages
@ -31,9 +42,10 @@ where
{ {
handle: Option<thread::JoinHandle<()>>, handle: Option<thread::JoinHandle<()>>,
ctrl: Sender<ControlMsg>, ctrl: Sender<ControlMsg>,
recv: Receiver<Result<R, PostError>>, recv: Receiver<Result<R, PostErrorInternal>>,
send: Sender<S>, send: Sender<S>,
poll: Poll, poll: Poll,
err: Option<PostErrorInternal>,
} }
impl<S, R> PostBox<S, R> impl<S, R> PostBox<S, R>
@ -42,16 +54,16 @@ where
R: PostRecv, R: PostRecv,
{ {
/// Creates a new [`PostBox`] connected to specified address, meant to be used by the client /// Creates a new [`PostBox`] connected to specified address, meant to be used by the client
pub fn to_server(addr: &SocketAddr) -> Result<PostBox<S, R>, PostError> { pub fn to_server<A: Into<SocketAddr>>(addr: A) -> Result<PostBox<S, R>, PostError> {
let connection = TcpStream::connect(addr)?; let connection = TcpStream::connect(&addr.into())?;
Self::from_tcpstream(connection) Self::from_tcpstream(connection)
} }
/// Creates a new [`PostBox`] from an existing connection, meant to be used by [`PostOffice`](super::PostOffice) on the server /// Creates a new [`PostBox`] from an existing connection, meant to be used by [`PostOffice`](super::PostOffice) on the server
pub fn from_tcpstream(connection: TcpStream) -> Result<PostBox<S, R>, PostError> { pub fn from_tcpstream(connection: TcpStream) -> Result<PostBox<S, R>, PostError> {
let (ctrl_tx, ctrl_rx) = channel::<ControlMsg>(); // Control messages let (ctrl_tx, ctrl_rx) = channel(); // Control messages
let (send_tx, send_rx) = channel::<S>(); // main thread -[data to be serialized and sent]> worker thread let (send_tx, send_rx) = channel(); // main thread -[data to be serialized and sent]> worker thread
let (recv_tx, recv_rx) = channel::<Result<R, PostError>>(); // main thread <[received and deserialized data]- worker thread let (recv_tx, recv_rx) = channel(); // main thread <[received and deserialized data]- worker thread
let thread_poll = Poll::new().unwrap(); let thread_poll = Poll::new().unwrap();
let postbox_poll = Poll::new().unwrap(); let postbox_poll = Poll::new().unwrap();
thread_poll thread_poll
@ -75,31 +87,54 @@ where
recv: recv_rx, recv: recv_rx,
send: send_tx, send: send_tx,
poll: postbox_poll, poll: postbox_poll,
err: None,
}) })
} }
/// Return an `Option<PostError>` indicating the current status of the `PostBox`.
pub fn status(&self) -> Option<PostError> {
self.err.as_ref().map(|err| err.into())
}
/// Non-blocking sender method /// Non-blocking sender method
pub fn send(&self, data: S) { pub fn send(&mut self, data: S) -> Result<(), PostError> {
self.send.send(data).unwrap_or(()); match &mut self.err {
err @ None => if let Err(_) = self.send.send(data) {
*err = Some(PostErrorInternal::MioError);
Err(err.as_ref().unwrap().into())
} else {
Ok(())
},
err => Err(err.as_ref().unwrap().into()),
}
} }
/// Non-blocking receiver method returning an iterator over already received and deserialized objects /// Non-blocking receiver method returning an iterator over already received and deserialized objects
/// # Errors /// # Errors
/// If the other side disconnects PostBox won't realize that until you try to send something /// If the other side disconnects PostBox won't realize that until you try to send something
pub fn recv_iter(&self) -> Result<impl Iterator<Item = Result<R, PostError>>, PostError> { pub fn recv_iter(&mut self) -> impl Iterator<Item = R> {
let mut events = Events::with_capacity(4096); let mut events = Events::with_capacity(4096);
self.poll let mut items = VecDeque::new();
.poll(&mut events, Some(core::time::Duration::new(0, 0)))?;
let mut data: VecDeque<Result<R, PostError>> = VecDeque::new(); // If an error occured, or previously occured, just give up
if let Some(_) = self.err {
return items.into_iter();
} else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) {
self.err = Some(err.into());
return items.into_iter();
}
for event in events { for event in events {
match event.token() { match event.token() {
DATA_TOKEN => { DATA_TOKEN => match self.recv.try_recv() {
data.push_back(self.recv.try_recv()?); Ok(Ok(item)) => items.push_back(item),
} Err(err) => self.err = Some(err.into()),
Ok(Err(err)) => self.err = Some(err.into()),
},
_ => (), _ => (),
} }
} }
Ok(data.into_iter()) items.into_iter()
} }
} }
@ -107,7 +142,7 @@ fn postbox_thread<S, R>(
mut connection: TcpStream, mut connection: TcpStream,
ctrl_rx: Receiver<ControlMsg>, ctrl_rx: Receiver<ControlMsg>,
send_rx: Receiver<S>, send_rx: Receiver<S>,
recv_tx: Sender<Result<R, PostError>>, recv_tx: Sender<Result<R, PostErrorInternal>>,
poll: Poll, poll: Poll,
) where ) where
S: PostSend, S: PostSend,
@ -154,7 +189,7 @@ fn postbox_thread<S, R>(
.unwrap(), .unwrap(),
); );
if recv_nextlen > MESSAGE_SIZE_CAP { if recv_nextlen > MESSAGE_SIZE_CAP {
recv_tx.send(Err(PostError::MsgSizeLimitExceeded)).unwrap(); recv_tx.send(Err(PostErrorInternal::MsgSizeLimitExceeded)).unwrap();
connection.shutdown(std::net::Shutdown::Both).unwrap(); connection.shutdown(std::net::Shutdown::Both).unwrap();
recv_buff.drain(..); recv_buff.drain(..);
recv_nextlen = 0; recv_nextlen = 0;

View File

@ -1,18 +1,26 @@
// Standard // Standard
use core::time::Duration; use core::time::Duration;
use std::collections::VecDeque; use std::{
use std::net::SocketAddr; collections::VecDeque,
use std::thread; net::SocketAddr,
thread,
};
// External // External
use mio::{net::TcpListener, Events, Poll, PollOpt, Ready, Token}; use mio::{net::TcpListener, Events, Poll, PollOpt, Ready, Token};
use mio_extras::channel::{channel, Receiver, Sender}; use mio_extras::channel::{channel, Receiver, Sender};
// Crate // Crate
use super::data::ControlMsg; use super::{
use super::error::PostError; data::ControlMsg,
use super::postbox::PostBox; error::{
use super::{PostRecv, PostSend}; PostError,
PostErrorInternal,
},
postbox::PostBox,
PostRecv,
PostSend,
};
// Constants // Constants
const CTRL_TOKEN: Token = Token(0); // Token for thread control messages const CTRL_TOKEN: Token = Token(0); // Token for thread control messages
@ -28,8 +36,9 @@ where
{ {
handle: Option<thread::JoinHandle<()>>, handle: Option<thread::JoinHandle<()>>,
ctrl: Sender<ControlMsg>, ctrl: Sender<ControlMsg>,
recv: Receiver<Result<PostBox<S, R>, PostError>>, recv: Receiver<Result<PostBox<S, R>, PostErrorInternal>>,
poll: Poll, poll: Poll,
err: Option<PostErrorInternal>,
} }
impl<S, R> PostOffice<S, R> impl<S, R> PostOffice<S, R>
@ -38,49 +47,69 @@ where
R: PostRecv, R: PostRecv,
{ {
/// Creates a new [`PostOffice`] listening on specified address /// Creates a new [`PostOffice`] listening on specified address
pub fn new(addr: &SocketAddr) -> Result<Self, PostError> { pub fn new<A: Into<SocketAddr>>(addr: A) -> Result<Self, PostError> {
let listener = TcpListener::bind(addr)?; let listener = TcpListener::bind(&addr.into())?;
let (ctrl_tx, ctrl_rx) = channel(); let (ctrl_tx, ctrl_rx) = channel();
let (recv_tx, recv_rx) = channel(); let (recv_tx, recv_rx) = channel();
let thread_poll = Poll::new()?; let thread_poll = Poll::new()?;
let postbox_poll = Poll::new()?; let postbox_poll = Poll::new()?;
thread_poll.register(&listener, CONN_TOKEN, Ready::readable(), PollOpt::edge())?; thread_poll.register(&listener, CONN_TOKEN, Ready::readable(), PollOpt::edge())?;
thread_poll.register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge())?; thread_poll.register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge())?;
postbox_poll.register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge())?; postbox_poll.register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge())?;
let handle = thread::Builder::new() let handle = thread::Builder::new()
.name("postoffice_worker".into()) .name("postoffice_worker".into())
.spawn(move || postoffice_thread(listener, ctrl_rx, recv_tx, thread_poll))?; .spawn(move || postoffice_thread(listener, ctrl_rx, recv_tx, thread_poll))?;
Ok(PostOffice { Ok(PostOffice {
handle: Some(handle), handle: Some(handle),
ctrl: ctrl_tx, ctrl: ctrl_tx,
recv: recv_rx, recv: recv_rx,
poll: postbox_poll, poll: postbox_poll,
err: None,
}) })
} }
/// Return an `Option<PostError>` indicating the current status of the `PostOffice`.
pub fn status(&self) -> Option<PostError> {
self.err.as_ref().map(|err| err.into())
}
/// Non-blocking method returning an iterator over new connections wrapped in [`PostBox`]es /// Non-blocking method returning an iterator over new connections wrapped in [`PostBox`]es
pub fn get_iter( pub fn new_connections(
&self, &mut self,
) -> Result<impl Iterator<Item = Result<PostBox<S, R>, PostError>>, PostError> { ) -> impl Iterator<Item = PostBox<S, R>> {
let mut events = Events::with_capacity(256); let mut events = Events::with_capacity(256);
self.poll.poll(&mut events, Some(Duration::new(0, 0)))?; let mut conns = VecDeque::new();
let mut conns: VecDeque<Result<PostBox<S, R>, PostError>> = VecDeque::new();
// If an error occured, or previously occured, just give up
if let Some(_) = self.err {
return conns.into_iter();
} else if let Err(err) = self.poll.poll(&mut events, Some(Duration::new(0, 0))) {
self.err = Some(err.into());
return conns.into_iter();
}
for event in events { for event in events {
match event.token() { match event.token() {
DATA_TOKEN => { // Ignore recv error
conns.push_back(self.recv.try_recv()?); DATA_TOKEN => match self.recv.try_recv() {
} Ok(Ok(conn)) => conns.push_back(conn),
Err(err) => self.err = Some(err.into()),
Ok(Err(err)) => self.err = Some(err.into()),
},
_ => (), _ => (),
} }
} }
Ok(conns.into_iter()) conns.into_iter()
} }
} }
fn postoffice_thread<S, R>( fn postoffice_thread<S, R>(
listener: TcpListener, listener: TcpListener,
ctrl_rx: Receiver<ControlMsg>, ctrl_rx: Receiver<ControlMsg>,
recv_tx: Sender<Result<PostBox<S, R>, PostError>>, recv_tx: Sender<Result<PostBox<S, R>, PostErrorInternal>>,
poll: Poll, poll: Poll,
) where ) where
S: PostSend, S: PostSend,
@ -96,7 +125,9 @@ fn postoffice_thread<S, R>(
}, },
CONN_TOKEN => { CONN_TOKEN => {
let (conn, _addr) = listener.accept().unwrap(); let (conn, _addr) = listener.accept().unwrap();
recv_tx.send(PostBox::from_tcpstream(conn)).unwrap(); recv_tx.send(PostBox::from_tcpstream(conn)
// TODO: Is it okay to count a failure to create a postbox here as an 'internal error'?
.map_err(|_| PostErrorInternal::MioError)).unwrap();
} }
_ => (), _ => (),
} }

View File

@ -1,69 +1,78 @@
use std::io::Write; use std::{
use std::net::SocketAddr; io::Write,
str::FromStr,
net::SocketAddr,
thread,
time::Duration,
};
use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token}; use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token};
use super::{error::PostError, PostBox, PostOffice}; use super::{error::PostError, PostBox, PostOffice};
fn new_local_addr(n: u16) -> SocketAddr {
SocketAddr::from(([127, 0, 0, 1], 12345 + n))
}
#[test] #[test]
fn basic_run() { fn basic_run() {
let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12345u16)); let srv_addr = new_local_addr(0);
let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12345u16)); let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap();
let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); let mut client: PostBox<String, String> = PostBox::to_server(srv_addr).unwrap();
let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
let scon = server.get_iter().unwrap().next().unwrap().unwrap(); let mut scon = server.new_connections().next().unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
scon.send(String::from("foo")); scon.send(String::from("foo")).unwrap();
client.send(String::from("bar")); client.send(String::from("bar")).unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
assert_eq!("foo", client.recv_iter().unwrap().next().unwrap().unwrap()); assert_eq!("foo", client.recv_iter().next().unwrap());
assert_eq!("bar", scon.recv_iter().unwrap().next().unwrap().unwrap()); assert_eq!("bar", scon.recv_iter().next().unwrap());
} }
#[test] #[test]
fn huge_size_header() { fn huge_size_header() {
let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12346u16)); let srv_addr = new_local_addr(1);
let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12346u16));
let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); let mut server: PostOffice<String, String> = PostOffice::new(srv_addr).unwrap();
let mut client = TcpStream::connect(&conn_addr).unwrap(); let mut client = TcpStream::connect(&srv_addr).unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
let scon = server.get_iter().unwrap().next().unwrap().unwrap(); let mut scon = server.new_connections().next().unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
client.write(&[0xffu8; 64]).unwrap(); client.write(&[0xffu8; 64]).unwrap();
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
assert!(match scon.recv_iter().unwrap().next().unwrap() { assert_eq!(scon.recv_iter().next(), None);
Err(PostError::MsgSizeLimitExceeded) => true,
_ => false,
});
} }
#[test] #[test]
fn disconnect() { fn disconnect() {
let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12347u16)); let srv_addr = new_local_addr(2);
let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12347u16));
let server: PostOffice<String, String> = PostOffice::new(&listen_addr).unwrap(); let mut server = PostOffice::<_, String>::new(srv_addr)
.unwrap();
// Create then close client
{ {
#[allow(unused_variables)] PostBox::<String, String>::to_server(srv_addr).unwrap();
let client: PostBox<String, String> = PostBox::to_server(&conn_addr).unwrap();
} }
std::thread::sleep(std::time::Duration::from_millis(10));
let scon = server.get_iter().unwrap().next().unwrap().unwrap();
scon.send(String::from("foo"));
std::thread::sleep(std::time::Duration::from_millis(10)); std::thread::sleep(std::time::Duration::from_millis(10));
match scon.recv_iter().unwrap().next().unwrap() { let mut to_client = server
Ok(_) => panic!("Didn't expect to receive anything"), .new_connections()
Err(err) => { .next()
if !(match err { .unwrap();
PostError::Io(e) => e,
_ => panic!("PostError different than expected"), to_client.send(String::from("foo")).unwrap();
}
.kind() thread::sleep(Duration::from_millis(10));
== std::io::ErrorKind::BrokenPipe)
{ match to_client.recv_iter().next() {
panic!("Error different than disconnection") None => {},
} _ => panic!("Unexpected message!"),
} }
match to_client.status() {
Some(PostError::Disconnected) => {},
s => panic!("Did not expect {:?}", s),
} }
} }