diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index a54f4121c3..f1ad37efea 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -9,13 +9,12 @@ use clap::{App, Arg}; use futures::executor::block_on; use serde::{Deserialize, Serialize}; use std::{ - sync::Arc, thread, time::{Duration, Instant}, }; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{MessageBuffer, Network, Pid, Promises, ProtocolAddr}; +use veloren_network::{Message, Network, Pid, Promises, ProtocolAddr}; #[derive(Serialize, Deserialize, Debug)] enum Msg { @@ -156,15 +155,15 @@ fn client(address: ProtocolAddr) { let mut s1 = block_on(p1.open(16, Promises::ORDERED | Promises::CONSISTENCY)).unwrap(); //remote representation of s1 let mut last = Instant::now(); let mut id = 0u64; - let raw_msg = Arc::new(MessageBuffer { - data: bincode::serialize(&Msg::Ping { + let raw_msg = Message::serialize( + &Msg::Ping { id, data: vec![0; 1000], - }) - .unwrap(), - }); + }, + &s1, + ); loop { - s1.send_raw(raw_msg.clone()).unwrap(); + s1.send_raw(&raw_msg).unwrap(); id += 1; if id.rem_euclid(1000000) == 0 { let new = Instant::now(); diff --git a/network/src/api.rs b/network/src/api.rs index 972c71b400..ae5a21852f 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -3,7 +3,7 @@ //! //! (cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ - message::{self, partial_eq_bincode, IncomingMessage, MessageBuffer, OutgoingMessage}, + message::{partial_eq_bincode, IncomingMessage, Message, OutgoingMessage}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, types::{Mid, Pid, Prio, Promises, Sid}, @@ -77,7 +77,7 @@ pub struct Stream { promises: Promises, send_closed: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: mpsc::UnboundedReceiver, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -667,7 +667,7 @@ impl Stream { promises, send_closed, a2b_msg_s, - b2a_msg_recv_r, + b2a_msg_recv_r: Some(b2a_msg_recv_r), a2b_close_stream_s: Some(a2b_close_stream_s), } } @@ -727,21 +727,18 @@ impl Stream { /// [`Serialized`]: Serialize #[inline] pub fn send(&mut self, msg: M) -> Result<(), StreamError> { - self.send_raw(Arc::new(message::serialize( - &msg, - #[cfg(feature = "compression")] - self.promises.contains(Promises::COMPRESSED), - ))) + self.send_raw(&Message::serialize(&msg, &self)) } /// This methods give the option to skip multiple calls of [`bincode`] and /// [`compress`], e.g. in case the same Message needs to send on /// multiple `Streams` to multiple [`Participants`]. Other then that, - /// the same rules apply than for [`send`] + /// the same rules apply than for [`send`]. + /// You need to create a Message via [`Message::serialize`]. /// /// # Example /// ```rust - /// use veloren_network::{Network, ProtocolAddr, Pid, MessageBuffer}; + /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; /// # use veloren_network::Promises; /// use futures::executor::block_on; /// use bincode; @@ -767,13 +764,10 @@ impl Stream { /// let mut stream_b = participant_b.opened().await?; /// /// //Prepare Message and decode it - /// let msg = "Hello World"; - /// let raw_msg = Arc::new(MessageBuffer{ - /// data: bincode::serialize(&msg).unwrap(), - /// }); + /// let msg = Message::serialize("Hello World", &stream_a); /// //Send same Message to multiple Streams - /// stream_a.send_raw(raw_msg.clone()); - /// stream_b.send_raw(raw_msg.clone()); + /// stream_a.send_raw(&msg); + /// stream_b.send_raw(&msg); /// # Ok(()) /// }) /// # } @@ -782,12 +776,15 @@ impl Stream { /// [`send`]: Stream::send /// [`Participants`]: crate::api::Participant /// [`compress`]: lz_fear::raw::compress2 - pub fn send_raw(&mut self, messagebuffer: Arc) -> Result<(), StreamError> { + /// [`Message::serialize`]: crate::message::Message::serialize + pub fn send_raw(&mut self, message: &Message) -> Result<(), StreamError> { if self.send_closed.load(Ordering::Relaxed) { return Err(StreamError::StreamClosed); } + #[cfg(debug_assertions)] + message.verify(&self); self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { - buffer: messagebuffer, + buffer: Arc::clone(&message.buffer), cursor: 0, mid: self.mid, sid: self.sid, @@ -824,7 +821,7 @@ impl Stream { /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; - /// //Send Message + /// //Recv Message /// println!("{}", stream_a.recv::().await?); /// # Ok(()) /// }) @@ -832,26 +829,120 @@ impl Stream { /// ``` #[inline] pub async fn recv(&mut self) -> Result { - message::deserialize( - self.recv_raw().await?, - #[cfg(feature = "compression")] - self.promises.contains(Promises::COMPRESSED), - ) + self.recv_raw().await?.deserialize() } /// the equivalent like [`send_raw`] but for [`recv`], no [`bincode`] or /// [`decompress`] is executed for performance reasons. /// + /// # Example + /// ``` + /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// # use veloren_network::Promises; + /// use futures::executor::block_on; + /// + /// # fn main() -> std::result::Result<(), Box> { + /// // Create a Network, listen on Port `2230` and wait for a Stream to be opened, then listen on it + /// let (network, f) = Network::new(Pid::new()); + /// std::thread::spawn(f); + /// # let (remote, fr) = Network::new(Pid::new()); + /// # std::thread::spawn(fr); + /// block_on(async { + /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # stream_p.send("Hello World"); + /// let participant_a = network.connected().await?; + /// let mut stream_a = participant_a.opened().await?; + /// //Recv Message + /// let msg = stream_a.recv_raw().await?; + /// //Resend Message, without deserializing + /// stream_a.send_raw(&msg)?; + /// # Ok(()) + /// }) + /// # } + /// ``` + /// /// [`send_raw`]: Stream::send_raw /// [`recv`]: Stream::recv /// [`decompress`]: lz_fear::raw::decompress_raw - pub async fn recv_raw(&mut self) -> Result { - let msg = self.b2a_msg_recv_r.next().await?; - Ok(msg.buffer) + pub async fn recv_raw(&mut self) -> Result { + match &mut self.b2a_msg_recv_r { + Some(b2a_msg_recv_r) => { + match b2a_msg_recv_r.next().await { + Some(msg) => Ok(Message { + buffer: Arc::new(msg.buffer), + #[cfg(feature = "compression")] + compressed: self.promises.contains(Promises::COMPRESSED), + }), + None => { + self.b2a_msg_recv_r = None; //prevent panic + Err(StreamError::StreamClosed) + }, + } + }, + None => Err(StreamError::StreamClosed), + } } + + /// use `try_recv` to check for a Message send from the remote side by their + /// `Stream`. This function does not block and returns immediately. It's + /// intended for use in non-async context only. Other then that, the + /// same rules apply than for [`recv`]. + /// + /// # Example + /// ``` + /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// # use veloren_network::Promises; + /// use futures::executor::block_on; + /// + /// # fn main() -> std::result::Result<(), Box> { + /// // Create a Network, listen on Port `2240` and wait for a Stream to be opened, then listen on it + /// let (network, f) = Network::new(Pid::new()); + /// std::thread::spawn(f); + /// # let (remote, fr) = Network::new(Pid::new()); + /// # std::thread::spawn(fr); + /// block_on(async { + /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # stream_p.send("Hello World"); + /// # std::thread::sleep(std::time::Duration::from_secs(1)); + /// let participant_a = network.connected().await?; + /// let mut stream_a = participant_a.opened().await?; + /// //Try Recv Message + /// println!("{:?}", stream_a.try_recv::()?); + /// # Ok(()) + /// }) + /// # } + /// ``` + /// + /// [`recv`]: Stream::recv + #[inline] + pub fn try_recv(&mut self) -> Result, StreamError> { + match &mut self.b2a_msg_recv_r { + Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_next() { + Err(_) => Ok(None), + Ok(None) => { + self.b2a_msg_recv_r = None; //prevent panic + Err(StreamError::StreamClosed) + }, + Ok(Some(msg)) => Ok(Some( + Message { + buffer: Arc::new(msg.buffer), + #[cfg(feature = "compression")] + compressed: self.promises().contains(Promises::COMPRESSED), + } + .deserialize()?, + )), + }, + None => Err(StreamError::StreamClosed), + } + } + + pub fn promises(&self) -> Promises { self.promises } } -/// impl core::cmp::PartialEq for Participant { fn eq(&self, other: &Self) -> bool { //don't check local_pid, 2 Participant from different network should match if diff --git a/network/src/lib.rs b/network/src/lib.rs index 611c4e6d54..bb14782a69 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -113,5 +113,5 @@ mod types; pub use api::{ Network, NetworkError, Participant, ParticipantError, ProtocolAddr, Stream, StreamError, }; -pub use message::MessageBuffer; +pub use message::Message; pub use types::{Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index f711168481..bb5f49a36e 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,14 +1,15 @@ use serde::{de::DeserializeOwned, Serialize}; //use std::collections::VecDeque; +#[cfg(feature = "compression")] +use crate::types::Promises; use crate::{ - api::StreamError, + api::{Stream, StreamError}, types::{Frame, Mid, Sid}, }; use std::{io, sync::Arc}; +#[cfg(all(feature = "compression", debug_assertions))] +use tracing::warn; -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. /// Support struct used for optimising sending the same Message to multiple /// [`Stream`] /// @@ -16,7 +17,16 @@ use std::{io, sync::Arc}; /// /// [`Stream`]: crate::api::Stream /// [`send_raw`]: crate::api::Stream::send_raw -pub struct MessageBuffer { +pub struct Message { + pub(crate) buffer: Arc, + #[cfg(feature = "compression")] + pub(crate) compressed: bool, +} + +//Todo: Evaluate switching to VecDeque for quickly adding and removing data +// from front, back. +// - It would prob require custom bincode code but thats possible. +pub(crate) struct MessageBuffer { pub data: Vec, } @@ -36,63 +46,127 @@ pub(crate) struct IncomingMessage { pub sid: Sid, } -pub(crate) fn serialize( - message: &M, - #[cfg(feature = "compression")] compress: bool, -) -> MessageBuffer { - //this will never fail: https://docs.rs/bincode/0.8.0/bincode/fn.serialize.html - let serialized_data = bincode::serialize(message).unwrap(); +impl Message { + /// This serializes any message, according to the [`Streams`] [`Promises`]. + /// You can reuse this `Message` and send it via other [`Streams`], if the + /// [`Promises`] match. E.g. Sending a `Message` via a compressed and + /// uncompressed Stream is dangerous, unless the remote site knows about + /// this. + /// + /// # Example + /// for example coding, see [`send_raw`] + /// + /// [`send_raw`]: Stream::send_raw + /// [`Participants`]: crate::api::Participant + /// [`compress`]: lz_fear::raw::compress2 + /// [`Message::serialize`]: crate::message::Message::serialize + /// + /// [`Streams`]: crate::api::Stream + pub fn serialize(message: &M, stream: &Stream) -> Self { + //this will never fail: https://docs.rs/bincode/0.8.0/bincode/fn.serialize.html + let serialized_data = bincode::serialize(message).unwrap(); - #[cfg(not(feature = "compression"))] - let compress = false; - - MessageBuffer { - data: if compress { - #[cfg(feature = "compression")] - { - let mut compressed_data = Vec::with_capacity(serialized_data.len() / 4 + 10); - let mut table = lz_fear::raw::U32Table::default(); - lz_fear::raw::compress2(&serialized_data, 0, &mut table, &mut compressed_data) - .unwrap(); - compressed_data - } - #[cfg(not(feature = "compression"))] - unreachable!("compression isn't enabled as a feature"); + #[cfg(feature = "compression")] + let compressed = stream.promises().contains(Promises::COMPRESSED); + #[cfg(feature = "compression")] + let data = if compressed { + let mut compressed_data = Vec::with_capacity(serialized_data.len() / 4 + 10); + let mut table = lz_fear::raw::U32Table::default(); + lz_fear::raw::compress2(&serialized_data, 0, &mut table, &mut compressed_data).unwrap(); + compressed_data } else { serialized_data - }, - } -} - -pub(crate) fn deserialize( - buffer: MessageBuffer, - #[cfg(feature = "compression")] compress: bool, -) -> Result { - #[cfg(not(feature = "compression"))] - let compress = false; - - let uncompressed_data = if compress { - #[cfg(feature = "compression")] - { - let mut uncompressed_data = Vec::with_capacity(buffer.data.len() * 2); - if let Err(e) = lz_fear::raw::decompress_raw( - &buffer.data, - &[0; 0], - &mut uncompressed_data, - usize::MAX, - ) { - return Err(StreamError::Compression(e)); - } - uncompressed_data - } + }; #[cfg(not(feature = "compression"))] - unreachable!("compression isn't enabled as a feature"); - } else { - buffer.data - }; - match bincode::deserialize(uncompressed_data.as_slice()) { - Ok(m) => Ok(m), - Err(e) => Err(StreamError::Deserialize(e)), + let data = serialized_data; + #[cfg(not(feature = "compression"))] + let _stream = stream; + + Self { + buffer: Arc::new(MessageBuffer { data }), + #[cfg(feature = "compression")] + compressed, + } + } + + /// deserialize this `Message`. This consumes the struct, as deserialization + /// is only expected once. Use this when deserialize a [`recv_raw`] + /// `Message`. If you are resending this message, deserialization might need + /// to copy memory + /// + /// # Example + /// ``` + /// # use veloren_network::{Network, ProtocolAddr, Pid}; + /// # use veloren_network::Promises; + /// # use futures::executor::block_on; + /// + /// # fn main() -> std::result::Result<(), Box> { + /// // Create a Network, listen on Port `2300` and wait for a Stream to be opened, then listen on it + /// # let (network, f) = Network::new(Pid::new()); + /// # std::thread::spawn(f); + /// # let (remote, fr) = Network::new(Pid::new()); + /// # std::thread::spawn(fr); + /// # block_on(async { + /// # network.listen(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # stream_p.send("Hello World"); + /// # let participant_a = network.connected().await?; + /// let mut stream_a = participant_a.opened().await?; + /// //Recv Message + /// let msg = stream_a.recv_raw().await?; + /// println!("Msg is {}", msg.deserialize::()?); + /// # Ok(()) + /// # }) + /// # } + /// ``` + /// + /// [`recv_raw`]: crate::api::Stream::recv_raw + pub fn deserialize(self) -> Result { + #[cfg(not(feature = "compression"))] + let uncompressed_data = match Arc::try_unwrap(self.buffer) { + Ok(d) => d.data, + Err(b) => b.data.clone(), + }; + + #[cfg(feature = "compression")] + let uncompressed_data = if self.compressed { + { + let mut uncompressed_data = Vec::with_capacity(self.buffer.data.len() * 2); + if let Err(e) = lz_fear::raw::decompress_raw( + &self.buffer.data, + &[0; 0], + &mut uncompressed_data, + usize::MAX, + ) { + return Err(StreamError::Compression(e)); + } + uncompressed_data + } + } else { + match Arc::try_unwrap(self.buffer) { + Ok(d) => d.data, + Err(b) => b.data.clone(), + } + }; + + match bincode::deserialize(uncompressed_data.as_slice()) { + Ok(m) => Ok(m), + Err(e) => Err(StreamError::Deserialize(e)), + } + } + + #[cfg(debug_assertions)] + pub(crate) fn verify(&self, stream: &Stream) { + #[cfg(not(feature = "compression"))] + let _stream = stream; + #[cfg(feature = "compression")] + if self.compressed != stream.promises().contains(Promises::COMPRESSED) { + warn!( + ?stream, + "verify failed, msg is {} and it doesn't match with stream", self.compressed + ); + } } } @@ -181,36 +255,61 @@ impl std::fmt::Debug for MessageBuffer { #[cfg(test)] mod tests { - use crate::message::*; + use crate::{api::Stream, message::*}; + use futures::channel::mpsc; + use std::sync::{atomic::AtomicBool, Arc}; + + fn stub_stream(compressed: bool) -> Stream { + use crate::{api::*, types::*}; + + #[cfg(feature = "compression")] + let promises = if compressed { + Promises::COMPRESSED + } else { + Promises::empty() + }; + + #[cfg(not(feature = "compression"))] + let promises = Promises::empty(); + + let (a2b_msg_s, _a2b_msg_r) = crossbeam_channel::unbounded(); + let (_b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded(); + let (a2b_close_stream_s, _a2b_close_stream_r) = mpsc::unbounded(); + + Stream::new( + Pid::fake(0), + Sid::new(0), + 0u8, + promises, + Arc::new(AtomicBool::new(true)), + a2b_msg_s, + b2a_msg_recv_r, + a2b_close_stream_s, + ) + } #[test] fn serialize_test() { - let msg = "abc"; - let mb = serialize( - &msg, - #[cfg(feature = "compression")] - false, - ); - assert_eq!(mb.data.len(), 11); - assert_eq!(mb.data[0], 3); - assert_eq!(mb.data[1..7], [0, 0, 0, 0, 0, 0]); - assert_eq!(mb.data[8], b'a'); - assert_eq!(mb.data[9], b'b'); - assert_eq!(mb.data[10], b'c'); + let msg = Message::serialize("abc", &stub_stream(false)); + assert_eq!(msg.buffer.data.len(), 11); + assert_eq!(msg.buffer.data[0], 3); + assert_eq!(msg.buffer.data[1..7], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.buffer.data[8], b'a'); + assert_eq!(msg.buffer.data[9], b'b'); + assert_eq!(msg.buffer.data[10], b'c'); } #[cfg(feature = "compression")] #[test] fn serialize_compress_small() { - let msg = "abc"; - let mb = serialize(&msg, true); - assert_eq!(mb.data.len(), 12); - assert_eq!(mb.data[0], 176); - assert_eq!(mb.data[1], 3); - assert_eq!(mb.data[2..8], [0, 0, 0, 0, 0, 0]); - assert_eq!(mb.data[9], b'a'); - assert_eq!(mb.data[10], b'b'); - assert_eq!(mb.data[11], b'c'); + let msg = Message::serialize("abc", &stub_stream(true)); + assert_eq!(msg.buffer.data.len(), 12); + assert_eq!(msg.buffer.data[0], 176); + assert_eq!(msg.buffer.data[1], 3); + assert_eq!(msg.buffer.data[2..8], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.buffer.data[9], b'a'); + assert_eq!(msg.buffer.data[10], b'b'); + assert_eq!(msg.buffer.data[11], b'c'); } #[cfg(feature = "compression")] @@ -227,15 +326,15 @@ mod tests { 0, "assets/data/plants/flowers/greenrose.ron", ); - let mb = serialize(&msg, true); - assert_eq!(mb.data.len(), 79); - assert_eq!(mb.data[0], 34); - assert_eq!(mb.data[1], 5); - assert_eq!(mb.data[2], 0); - assert_eq!(mb.data[3], 1); - assert_eq!(mb.data[20], 20); - assert_eq!(mb.data[40], 115); - assert_eq!(mb.data[60], 111); + let msg = Message::serialize(&msg, &stub_stream(true)); + assert_eq!(msg.buffer.data.len(), 79); + assert_eq!(msg.buffer.data[0], 34); + assert_eq!(msg.buffer.data[1], 5); + assert_eq!(msg.buffer.data[2], 0); + assert_eq!(msg.buffer.data[3], 1); + assert_eq!(msg.buffer.data[20], 20); + assert_eq!(msg.buffer.data[40], 115); + assert_eq!(msg.buffer.data[60], 111); } #[cfg(feature = "compression")] @@ -257,7 +356,7 @@ mod tests { _ => {}, } } - let mb = serialize(&msg, true); - assert_eq!(mb.data.len(), 1296); + let msg = Message::serialize(&msg, &stub_stream(true)); + assert_eq!(msg.buffer.data.len(), 1296); } } diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 2c714277d3..fac118ff5a 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -343,3 +343,46 @@ fn close_network_scheduler_completely() { ha.join().unwrap(); hb.join().unwrap(); } + +#[test] +fn dont_panic_on_multiply_recv_after_close() { + let (_, _) = helper::setup(false, 0); + let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + + s1_a.send(11u32).unwrap(); + drop(s1_a); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv::(), Ok(Some(11u32))); + assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); + // There was a "Feature" in futures::channels that they panic when you call recv + // a second time after it showed end of stream + assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); +} + +#[test] +fn dont_panic_on_recv_send_after_close() { + let (_, _) = helper::setup(false, 0); + let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + + s1_a.send(11u32).unwrap(); + drop(s1_a); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv::(), Ok(Some(11u32))); + assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); + assert_eq!(s1_b.send("foobar"), Err(StreamError::StreamClosed)); +} + +#[test] +fn dont_panic_on_multiple_send_after_close() { + let (_, _) = helper::setup(false, 0); + let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + + s1_a.send(11u32).unwrap(); + drop(s1_a); + drop(_p_a); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv::(), Ok(Some(11u32))); + assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); + assert_eq!(s1_b.send("foobar"), Err(StreamError::StreamClosed)); + assert_eq!(s1_b.send("foobar"), Err(StreamError::StreamClosed)); +} diff --git a/network/tests/integration.rs b/network/tests/integration.rs index 30e1f79fd6..696691d6d4 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -23,6 +23,16 @@ fn stream_simple() { assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); } +#[test] +fn stream_try_recv() { + let (_, _) = helper::setup(false, 0); + let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + + s1_a.send(4242u32).unwrap(); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv(), Ok(Some(4242u32))); +} + #[test] fn stream_simple_3msg() { let (_, _) = helper::setup(false, 0); @@ -182,3 +192,20 @@ fn wrong_parse() { _ => panic!("this should fail, but it doesnt!"), } } + +#[test] +fn multiple_try_recv() { + let (_, _) = helper::setup(false, 0); + let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + + s1_a.send("asd").unwrap(); + s1_a.send(11u32).unwrap(); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv(), Ok(Some("asd".to_string()))); + assert_eq!(s1_b.try_recv::(), Ok(Some(11u32))); + assert_eq!(s1_b.try_recv::(), Ok(None)); + + drop(s1_a); + std::thread::sleep(std::time::Duration::from_secs(1)); + assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); +}