diff --git a/common/Cargo.toml b/common/Cargo.toml index 0975cf220f..0bd937c298 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "veloren-common" version = "0.1.0" -authors = ["Joshua Barretto "] +authors = ["Joshua Barretto ", "Maciej Ćwięka "] edition = "2018" [dependencies] @@ -10,3 +10,8 @@ shred = "0.7" vek = "0.9" dot_vox = "1.0" threadpool = "1.7" +mio = "0.6" +mio-extras = "2.0" +serde = "1.0" +serde_derive = "1.0" +bincode = "1.0" diff --git a/common/src/lib.rs b/common/src/lib.rs index 4d33fd7988..4e61fe8c56 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,4 +1,7 @@ -#![feature(euclidean_division, duration_float)] +#![feature(euclidean_division, duration_float, try_from, trait_alias)] + +#[macro_use] +extern crate serde_derive; pub mod clock; pub mod comp; @@ -8,3 +11,28 @@ pub mod terrain; pub mod util; pub mod volumes; pub mod vol; +// TODO: unignore the code here, for some reason it refuses to compile here while has no problems copy-pasted elsewhere +/// The networking module containing high-level wrappers of `TcpListener` and `TcpStream` (`PostOffice` and `PostBox` respectively) and data types used by both the server and client +/// # Examples +/// ```ignore +/// use std::net::SocketAddr; +/// use veloren_common::net::{PostOffice, PostBox}; +/// +/// let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12345u16)); +/// let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12345u16)); +/// +/// let server: PostOffice = PostOffice::new(&listen_addr).unwrap(); +/// let client: PostBox = PostBox::to_server(&conn_addr).unwrap(); +/// std::thread::sleep(std::time::Duration::from_millis(100)); +/// +/// let scon = server.get_iter().unwrap().next().unwrap().unwrap(); +/// std::thread::sleep(std::time::Duration::from_millis(100)); +/// +/// scon.send(String::from("foo")); +/// client.send(String::from("bar")); +/// std::thread::sleep(std::time::Duration::from_millis(100)); +/// +/// assert_eq!("foo", client.recv_iter().unwrap().next().unwrap().unwrap()); +/// assert_eq!("bar", scon.recv_iter().unwrap().next().unwrap().unwrap()); +/// ``` +pub mod net; diff --git a/common/src/net/data.rs b/common/src/net/data.rs new file mode 100644 index 0000000000..ee0b19090a --- /dev/null +++ b/common/src/net/data.rs @@ -0,0 +1,18 @@ +/// Messages server sends to client +#[derive(Deserialize, Serialize, Debug)] +pub enum ServerMsg { + // VersionInfo MUST always stay first in this struct + VersionInfo {}, +} + +/// Messages client sends to server +#[derive(Deserialize, Serialize, Debug)] +pub enum ClientMsg { + // VersionInfo MUST always stay first in this struct + VersionInfo {}, +} + +/// Control message type, used in [PostBox](super::PostBox) and [PostOffice](super::PostOffice) to control threads +pub enum ControlMsg { + Shutdown, +} diff --git a/common/src/net/error.rs b/common/src/net/error.rs new file mode 100644 index 0000000000..cd86e809bf --- /dev/null +++ b/common/src/net/error.rs @@ -0,0 +1,26 @@ +#[derive(Debug)] +pub enum PostError { + Io(std::io::Error), + Serde(bincode::Error), + ChannelRecv(std::sync::mpsc::TryRecvError), + ChannelSend, // Empty because I couldn't figure out how to handle generic type in mpsc::TrySendError properly + MsgSizeLimitExceeded, +} + +impl From for PostError { + fn from(err: std::io::Error) -> Self { + PostError::Io(err) + } +} + +impl From for PostError { + fn from(err: bincode::Error) -> Self { + PostError::Serde(err) + } +} + +impl From for PostError { + fn from(err: std::sync::mpsc::TryRecvError) -> Self { + PostError::ChannelRecv(err) + } +} diff --git a/common/src/net/mod.rs b/common/src/net/mod.rs new file mode 100644 index 0000000000..d9fe3929ce --- /dev/null +++ b/common/src/net/mod.rs @@ -0,0 +1,16 @@ +pub mod data; +pub mod error; +pub mod postbox; +pub mod postoffice; +mod test; + +// Reexports +pub use self::{ + data::{ClientMsg, ServerMsg}, + error::PostError, + postbox::PostBox, + postoffice::PostOffice, +}; + +pub trait PostSend = 'static + serde::Serialize + std::marker::Send + std::fmt::Debug; +pub trait PostRecv = 'static + serde::de::DeserializeOwned + std::marker::Send + std::fmt::Debug; diff --git a/common/src/net/postbox.rs b/common/src/net/postbox.rs new file mode 100644 index 0000000000..08d6eb254e --- /dev/null +++ b/common/src/net/postbox.rs @@ -0,0 +1,226 @@ +// Standard +use std::collections::VecDeque; +use std::convert::TryFrom; +use std::io::ErrorKind; +use std::io::Read; +use std::net::SocketAddr; +use std::thread; + +// External +use bincode; +use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token}; +use mio_extras::channel::{channel, Receiver, Sender}; + +// Crate +use super::data::ControlMsg; +use super::error::PostError; +use super::{PostRecv, PostSend}; + +// Constants +const CTRL_TOKEN: Token = Token(0); // Token for thread control messages +const DATA_TOKEN: Token = Token(1); // Token for thread data exchange +const CONN_TOKEN: Token = Token(2); // Token for TcpStream for the PostBox child thread +const MESSAGE_SIZE_CAP: u64 = 1 << 20; // Maximum accepted length of a packet + +/// A high-level wrapper of [`TcpStream`](mio::net::TcpStream). +/// [`PostBox`] takes care of serializing sent packets and deserializing received packets in the background, providing a simple API for sending and receiving objects over network. +pub struct PostBox +where + S: PostSend, + R: PostRecv, +{ + handle: Option>, + ctrl: Sender, + recv: Receiver>, + send: Sender, + poll: Poll, +} + +impl PostBox +where + S: PostSend, + R: PostRecv, +{ + /// Creates a new [`PostBox`] connected to specified address, meant to be used by the client + pub fn to_server(addr: &SocketAddr) -> Result, PostError> { + let connection = TcpStream::connect(addr)?; + Self::from_tcpstream(connection) + } + + /// Creates a new [`PostBox`] from an existing connection, meant to be used by [`PostOffice`](super::PostOffice) on the server + pub fn from_tcpstream(connection: TcpStream) -> Result, PostError> { + let (ctrl_tx, ctrl_rx) = channel::(); // Control messages + let (send_tx, send_rx) = channel::(); // main thread -[data to be serialized and sent]> worker thread + let (recv_tx, recv_rx) = channel::>(); // main thread <[received and deserialized data]- worker thread + let thread_poll = Poll::new().unwrap(); + let postbox_poll = Poll::new().unwrap(); + thread_poll + .register(&connection, CONN_TOKEN, Ready::readable(), PollOpt::edge()) + .unwrap(); + thread_poll + .register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge()) + .unwrap(); + thread_poll + .register(&send_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge()) + .unwrap(); + postbox_poll + .register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge()) + .unwrap(); + let handle = thread::Builder::new() + .name("postbox_worker".into()) + .spawn(move || postbox_thread(connection, ctrl_rx, send_rx, recv_tx, thread_poll))?; + Ok(PostBox { + handle: Some(handle), + ctrl: ctrl_tx, + recv: recv_rx, + send: send_tx, + poll: postbox_poll, + }) + } + + /// Non-blocking sender method + pub fn send(&self, data: S) { + self.send.send(data).unwrap_or(()); + } + + /// Non-blocking receiver method returning an iterator over already received and deserialized objects + /// # Errors + /// If the other side disconnects PostBox won't realize that until you try to send something + pub fn recv_iter(&self) -> Result>, PostError> { + let mut events = Events::with_capacity(4096); + self.poll + .poll(&mut events, Some(core::time::Duration::new(0, 0)))?; + let mut data: VecDeque> = VecDeque::new(); + for event in events { + match event.token() { + DATA_TOKEN => { + data.push_back(self.recv.try_recv()?); + } + _ => (), + } + } + Ok(data.into_iter()) + } +} + +fn postbox_thread( + mut connection: TcpStream, + ctrl_rx: Receiver, + send_rx: Receiver, + recv_tx: Sender>, + poll: Poll, +) where + S: PostSend, + R: PostRecv, +{ + let mut events = Events::with_capacity(64); + // Receiving related variables + let mut recv_buff = Vec::new(); + let mut recv_nextlen: u64 = 0; + loop { + let mut disconnected = false; + poll.poll(&mut events, None) + .expect("Failed to execute poll(), most likely fault of the OS"); + for event in events.iter() { + match event.token() { + CTRL_TOKEN => match ctrl_rx.try_recv().unwrap() { + ControlMsg::Shutdown => return, + }, + CONN_TOKEN => match connection.read_to_end(&mut recv_buff) { + Ok(_) => {} + // Returned when all the data has been read + Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} + Err(e) => { + recv_tx.send(Err(e.into())).unwrap(); + } + }, + DATA_TOKEN => { + let mut packet = bincode::serialize(&send_rx.try_recv().unwrap()).unwrap(); + packet.splice(0..0, (packet.len() as u64).to_be_bytes().iter().cloned()); + match connection.write_bufs(&[packet.as_slice().into()]) { + Ok(_) => {} + Err(e) => { + recv_tx.send(Err(e.into())).unwrap(); + } + }; + } + _ => {} + } + } + loop { + if recv_nextlen == 0 && recv_buff.len() >= 8 { + recv_nextlen = u64::from_be_bytes( + <[u8; 8]>::try_from(recv_buff.drain(0..8).collect::>().as_slice()) + .unwrap(), + ); + if recv_nextlen > MESSAGE_SIZE_CAP { + recv_tx.send(Err(PostError::MsgSizeLimitExceeded)).unwrap(); + connection.shutdown(std::net::Shutdown::Both).unwrap(); + recv_buff.drain(..); + recv_nextlen = 0; + break; + } + } + if recv_buff.len() as u64 >= recv_nextlen && recv_nextlen != 0 { + match bincode::deserialize(recv_buff + .drain( + 0..usize::try_from(recv_nextlen) + .expect("Message size was larger than usize (insane message size and 32 bit OS)"), + ) + .collect::>() + .as_slice()) { + Ok(ok) => { + recv_tx + .send(Ok(ok)) + .unwrap(); + recv_nextlen = 0; + } + Err(e) => { + recv_tx.send(Err(e.into())).unwrap(); + recv_nextlen = 0; + continue + } + } + } else { + break; + } + } + match connection.take_error().unwrap() { + Some(e) => { + if e.kind() == ErrorKind::BrokenPipe { + disconnected = true; + } + recv_tx.send(Err(e.into())).unwrap(); + } + None => {} + } + if disconnected == true { + break; + } + } + + // Loop after disconnected + loop { + poll.poll(&mut events, None) + .expect("Failed to execute poll(), most likely fault of the OS"); + for event in events.iter() { + match event.token() { + CTRL_TOKEN => match ctrl_rx.try_recv().unwrap() { + ControlMsg::Shutdown => return, + }, + _ => {} + } + } + } +} + +impl Drop for PostBox +where + S: PostSend, + R: PostRecv, +{ + fn drop(&mut self) { + self.ctrl.send(ControlMsg::Shutdown).unwrap_or(()); + self.handle.take().map(|handle| handle.join()); + } +} diff --git a/common/src/net/postoffice.rs b/common/src/net/postoffice.rs new file mode 100644 index 0000000000..c1e718ac30 --- /dev/null +++ b/common/src/net/postoffice.rs @@ -0,0 +1,116 @@ +// Standard +use core::time::Duration; +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::thread; + +// External +use mio::{net::TcpListener, Events, Poll, PollOpt, Ready, Token}; +use mio_extras::channel::{channel, Receiver, Sender}; + +// Crate +use super::data::ControlMsg; +use super::error::PostError; +use super::postbox::PostBox; +use super::{PostRecv, PostSend}; + +// Constants +const CTRL_TOKEN: Token = Token(0); // Token for thread control messages +const DATA_TOKEN: Token = Token(1); // Token for thread data exchange +const CONN_TOKEN: Token = Token(2); // Token for TcpStream for the PostBox child thread + +/// A high-level wrapper of [`TcpListener`](mio::net::TcpListener). +/// [`PostOffice`] listens for incoming connections in the background and wraps them into [`PostBox`]es, providing a simple non-blocking API for receiving them. +pub struct PostOffice +where + S: PostSend, + R: PostRecv, +{ + handle: Option>, + ctrl: Sender, + recv: Receiver, PostError>>, + poll: Poll, +} + +impl PostOffice +where + S: PostSend, + R: PostRecv, +{ + /// Creates a new [`PostOffice`] listening on specified address + pub fn new(addr: &SocketAddr) -> Result { + let listener = TcpListener::bind(addr)?; + let (ctrl_tx, ctrl_rx) = channel(); + let (recv_tx, recv_rx) = channel(); + let thread_poll = Poll::new()?; + let postbox_poll = Poll::new()?; + thread_poll.register(&listener, CONN_TOKEN, Ready::readable(), PollOpt::edge())?; + thread_poll.register(&ctrl_rx, CTRL_TOKEN, Ready::readable(), PollOpt::edge())?; + postbox_poll.register(&recv_rx, DATA_TOKEN, Ready::readable(), PollOpt::edge())?; + let handle = thread::Builder::new() + .name("postoffice_worker".into()) + .spawn(move || postoffice_thread(listener, ctrl_rx, recv_tx, thread_poll))?; + Ok(PostOffice { + handle: Some(handle), + ctrl: ctrl_tx, + recv: recv_rx, + poll: postbox_poll, + }) + } + + /// Non-blocking method returning an iterator over new connections wrapped in [`PostBox`]es + pub fn get_iter( + &self, + ) -> Result, PostError>>, PostError> { + let mut events = Events::with_capacity(256); + self.poll.poll(&mut events, Some(Duration::new(0, 0)))?; + let mut conns: VecDeque, PostError>> = VecDeque::new(); + for event in events { + match event.token() { + DATA_TOKEN => { + conns.push_back(self.recv.try_recv()?); + } + _ => (), + } + } + Ok(conns.into_iter()) + } +} + +fn postoffice_thread( + listener: TcpListener, + ctrl_rx: Receiver, + recv_tx: Sender, PostError>>, + poll: Poll, +) where + S: PostSend, + R: PostRecv, +{ + let mut events = Events::with_capacity(256); + loop { + poll.poll(&mut events, None).expect("Failed to execute recv_poll.poll() in PostOffce receiver thread, most likely fault of the OS."); + for event in events.iter() { + match event.token() { + CTRL_TOKEN => match ctrl_rx.try_recv().unwrap() { + ControlMsg::Shutdown => return, + }, + CONN_TOKEN => { + let (conn, _addr) = listener.accept().unwrap(); + recv_tx.send(PostBox::from_tcpstream(conn)).unwrap(); + } + _ => (), + } + } + } +} + +impl Drop for PostOffice +where + S: PostSend, + R: PostRecv, +{ + fn drop(&mut self) { + self.ctrl.send(ControlMsg::Shutdown).unwrap_or(()); // If this fails the thread is dead already + self.handle.take().map(|handle| handle.join()); + } +} diff --git a/common/src/net/test.rs b/common/src/net/test.rs new file mode 100644 index 0000000000..1bcc7d2a9b --- /dev/null +++ b/common/src/net/test.rs @@ -0,0 +1,69 @@ +use std::io::Write; +use std::net::SocketAddr; + +use mio::{net::TcpStream, Events, Poll, PollOpt, Ready, Token}; + +use super::{error::PostError, PostBox, PostOffice}; + +#[test] +fn basic_run() { + let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12345u16)); + let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12345u16)); + let server: PostOffice = PostOffice::new(&listen_addr).unwrap(); + let client: PostBox = PostBox::to_server(&conn_addr).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let scon = server.get_iter().unwrap().next().unwrap().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); + scon.send(String::from("foo")); + client.send(String::from("bar")); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert_eq!("foo", client.recv_iter().unwrap().next().unwrap().unwrap()); + assert_eq!("bar", scon.recv_iter().unwrap().next().unwrap().unwrap()); +} + +#[test] +fn huge_size_header() { + let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12346u16)); + let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12346u16)); + let server: PostOffice = PostOffice::new(&listen_addr).unwrap(); + let mut client = TcpStream::connect(&conn_addr).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let scon = server.get_iter().unwrap().next().unwrap().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); + client.write(&[0xffu8; 64]).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert!(match scon.recv_iter().unwrap().next().unwrap() { + Err(PostError::MsgSizeLimitExceeded) => true, + _ => false, + }); +} + +#[test] +fn disconnect() { + let listen_addr = SocketAddr::from(([0, 0, 0, 0], 12347u16)); + let conn_addr = SocketAddr::from(([127, 0, 0, 1], 12347u16)); + let server: PostOffice = PostOffice::new(&listen_addr).unwrap(); + { + #[allow(unused_variables)] + let client: PostBox = PostBox::to_server(&conn_addr).unwrap(); + } + std::thread::sleep(std::time::Duration::from_millis(10)); + let scon = server.get_iter().unwrap().next().unwrap().unwrap(); + scon.send(String::from("foo")); + std::thread::sleep(std::time::Duration::from_millis(10)); + + match scon.recv_iter().unwrap().next().unwrap() { + Ok(_) => panic!("Didn't expect to receive anything"), + Err(err) => { + if !(match err { + PostError::Io(e) => e, + _ => panic!("PostError different than expected"), + } + .kind() + == std::io::ErrorKind::BrokenPipe) + { + panic!("Error different than disconnection") + } + } + } +}