diff --git a/network/src/api.rs b/network/src/api.rs index 33ce2de6ad..54de28cf35 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -54,7 +54,6 @@ pub struct Participant { remote_pid: Pid, a2b_stream_open_s: RwLock>, b2a_stream_opened_r: RwLock>, - closed: Arc>>, a2s_disconnect_s: A2sDisconnect, } @@ -78,9 +77,9 @@ pub struct Stream { mid: Mid, prio: Prio, promises: Promises, + send_closed: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, b2a_msg_recv_r: mpsc::UnboundedReceiver, - closed: Arc, a2b_close_stream_s: Option>, } @@ -427,14 +426,12 @@ impl Participant { a2b_stream_open_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>, - closed: Arc>>, ) -> Self { Self { local_pid, remote_pid, a2b_stream_open_s: RwLock::new(a2b_stream_open_s), b2a_stream_opened_r: RwLock::new(b2a_stream_opened_r), - closed, a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } } @@ -483,12 +480,14 @@ impl Participant { //use this lock for now to make sure that only one open at a time is made, // TODO: not sure if we can paralise that, check in future let mut a2b_stream_open_s = self.a2b_stream_open_s.write().await; - self.closed.read().await.clone()?; let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - a2b_stream_open_s + if let Err(e) = a2b_stream_open_s .send((prio, promises, p2a_return_stream_s)) .await - .unwrap(); + { + debug!(?e, "bParticipant is already closed, notifying"); + return Err(ParticipantError::ParticipantDisconnected); + } match p2a_return_stream_r.await { Ok(stream) => { let sid = stream.sid; @@ -497,8 +496,7 @@ impl Participant { }, Err(_) => { debug!(?self.remote_pid, "p2a_return_stream_r failed, closing participant"); - *self.closed.write().await = Err(ParticipantError::ProtocolFailedUnrecoverable); - Err(ParticipantError::ProtocolFailedUnrecoverable) + Err(ParticipantError::ParticipantDisconnected) }, } } @@ -540,7 +538,6 @@ impl Participant { //use this lock for now to make sure that only one open at a time is made, // TODO: not sure if we can paralise that, check in future let mut stream_opened_receiver = self.b2a_stream_opened_r.write().await; - self.closed.read().await.clone()?; match stream_opened_receiver.next().await { Some(stream) => { let sid = stream.sid; @@ -549,8 +546,7 @@ impl Participant { }, None => { debug!(?self.remote_pid, "stream_opened_receiver failed, closing participant"); - *self.closed.write().await = Err(ParticipantError::ProtocolFailedUnrecoverable); - Err(ParticipantError::ProtocolFailedUnrecoverable) + Err(ParticipantError::ParticipantDisconnected) }, } } @@ -602,11 +598,6 @@ impl Participant { // Remove, Close and try_unwrap error when unwrap fails! let pid = self.remote_pid; debug!(?pid, "Closing participant from network"); - { - let mut lock = self.closed.write().await; - lock.clone()?; - *lock = Err(ParticipantError::ParticipantDisconnected); - } //Streams will be closed by BParticipant match self.a2s_disconnect_s.lock().await.take() { @@ -619,17 +610,14 @@ impl Participant { .await .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { - Ok(Ok(())) => { - trace!(?pid, "Participant is now closed"); - Ok(()) - }, - Ok(Err(e)) => { - trace!( - ?e, - "Error occured during shutdown of participant and is propagated to \ - User" - ); - Err(ParticipantError::ProtocolFailedUnrecoverable) + Ok(res) => { + match res { + Ok(()) => trace!(?pid, "Participant is now closed"), + Err(ref e) => { + trace!(?pid, ?e, "Error occured during shutdown of participant") + }, + }; + res }, Err(e) => { //this is a bug. but as i am Participant i can't destroy the network @@ -664,9 +652,9 @@ impl Stream { sid: Sid, prio: Prio, promises: Promises, + send_closed: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, b2a_msg_recv_r: mpsc::UnboundedReceiver, - closed: Arc, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -675,9 +663,9 @@ impl Stream { mid: 0, prio, promises, + send_closed, a2b_msg_s, b2a_msg_recv_r, - closed, a2b_close_stream_s: Some(a2b_close_stream_s), } } @@ -788,10 +776,9 @@ impl Stream { /// [`send`]: Stream::send /// [`Participants`]: crate::api::Participant pub fn send_raw(&mut self, messagebuffer: Arc) -> Result<(), StreamError> { - if self.closed.load(Ordering::Relaxed) { + if self.send_closed.load(Ordering::Relaxed) { return Err(StreamError::StreamClosed); } - //debug!(?messagebuffer, "sending a message"); self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { buffer: messagebuffer, cursor: 0, @@ -847,10 +834,7 @@ impl Stream { /// [`send_raw`]: Stream::send_raw /// [`recv`]: Stream::recv pub async fn recv_raw(&mut self) -> Result { - //no need to access self.closed here, as when this stream is closed the Channel - // is closed which will trigger a None let msg = self.b2a_msg_recv_r.next().await?; - //info!(?msg, "delivering a message"); Ok(msg.buffer) } } @@ -959,8 +943,8 @@ impl Drop for Participant { impl Drop for Stream { fn drop(&mut self) { - // a send if closed is unecessary but doesnt hurt, we must not crash here - if !self.closed.load(Ordering::Relaxed) { + // send if closed is unecessary but doesnt hurt, we must not crash + if !self.send_closed.load(Ordering::Relaxed) { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); diff --git a/network/src/participant.rs b/network/src/participant.rs index 8d6fa66aab..d2035b9989 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -28,7 +28,7 @@ use tracing::*; pub(crate) type A2bStreamOpen = (Prio, Promises, oneshot::Sender); pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, Vec<(Cid, Frame)>, oneshot::Sender<()>); -pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; +pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; pub(crate) type B2sPrioStatistic = (Pid, u64, u64); #[derive(Debug)] @@ -43,8 +43,8 @@ struct ChannelInfo { struct StreamInfo { prio: Prio, promises: Promises, + send_closed: Arc, b2a_msg_recv_s: mpsc::UnboundedSender, - closed: Arc, } #[derive(Debug)] @@ -57,6 +57,13 @@ struct ControlChannels { s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ } +#[derive(Debug)] +struct ShutdownInfo { + //a2b_stream_open_r: mpsc::UnboundedReceiver, + b2a_stream_opened_s: mpsc::UnboundedSender, + error: Option, +} + #[derive(Debug)] pub struct BParticipant { remote_pid: Pid, @@ -64,12 +71,12 @@ pub struct BParticipant { offset_sid: Sid, channels: Arc>>, streams: RwLock>, - api_participant_closed: Arc>>, running_mgr: AtomicUsize, run_channels: Option, #[cfg(feature = "metrics")] metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, + shutdown_info: RwLock, } impl BParticipant { @@ -84,7 +91,6 @@ impl BParticipant { mpsc::UnboundedReceiver, mpsc::UnboundedSender, oneshot::Sender, - Arc>>, ) { let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded::(); let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded::(); @@ -92,6 +98,12 @@ impl BParticipant { let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded(); + let shutdown_info = RwLock::new(ShutdownInfo { + //a2b_stream_open_r: a2b_stream_open_r.clone(), + b2a_stream_opened_s: b2a_stream_opened_s.clone(), + error: None, + }); + let run_channels = Some(ControlChannels { a2b_stream_open_r, b2a_stream_opened_s, @@ -101,8 +113,6 @@ impl BParticipant { s2b_shutdown_bparticipant_r, }); - let api_participant_closed = Arc::new(RwLock::new(Ok(()))); - ( Self { remote_pid, @@ -110,18 +120,17 @@ impl BParticipant { offset_sid, channels: Arc::new(RwLock::new(vec![])), streams: RwLock::new(HashMap::new()), - api_participant_closed: api_participant_closed.clone(), running_mgr: AtomicUsize::new(0), run_channels, #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), + shutdown_info, }, a2b_steam_open_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, - api_participant_closed, ) } @@ -269,7 +278,7 @@ impl BParticipant { will be closed, but not if we do channel-takeover" ); //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_api(ParticipantError::ProtocolFailedUnrecoverable) + self.close_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) .await; false } else { @@ -347,7 +356,8 @@ impl BParticipant { .streams_closed_total .with_label_values(&[&self.remote_pid_string]) .inc(); - si.closed.store(true, Ordering::Relaxed); + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.close_channel(); trace!(?sid, "Closed stream from remote"); } else { warn!( @@ -411,7 +421,7 @@ impl BParticipant { }, Frame::Shutdown => { debug!("Shutdown received from remote side"); - self.close_api(ParticipantError::ParticipantDisconnected) + self.close_api(Some(ParticipantError::ParticipantDisconnected)) .await; }, f => unreachable!("Frame should never reache participant!: {:?}", f), @@ -495,6 +505,13 @@ impl BParticipant { _ = shutdown_open_mgr_receiver => None, } { debug!(?prio, ?promises, "Got request to open a new steam"); + //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. + //till then just check here if we are closed and in that case do nothing (not + // even answer) + if self.shutdown_info.read().await.error.is_some() { + continue; + } + let a2p_msg_s = a2p_msg_s.clone(); let sid = stream_ids; let stream = self @@ -538,10 +555,7 @@ impl BParticipant { trace!("Start participant_shutdown_mgr"); let sender = s2b_shutdown_bparticipant_r.await.unwrap(); - //Todo: isn't ParticipantDisconnected useless, as api is waiting rn for a - // callback? - self.close_api(ParticipantError::ParticipantDisconnected) - .await; + self.close_api(None).await; debug!("Closing all managers"); for sender in mgr_to_shutdown.drain(..) { @@ -580,7 +594,14 @@ impl BParticipant { self.metrics.participants_disconnected_total.inc(); debug!("BParticipant close done"); - sender.send(Ok(())).unwrap(); + let mut lock = self.shutdown_info.write().await; + sender + .send(match lock.error.take() { + None => Ok(()), + Some(e) => Err(e), + }) + .unwrap(); + trace!("Stop participant_shutdown_mgr"); self.running_mgr.fetch_sub(1, Ordering::Relaxed); } @@ -616,7 +637,8 @@ impl BParticipant { trace!(?sid, "Stopping api to use this stream"); match self.streams.read().await.get(&sid) { Some(si) => { - si.closed.store(true, Ordering::Relaxed); + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.close_channel(); }, None => warn!("Couldn't find the stream, might be simulanious close from remote"), } @@ -658,12 +680,12 @@ impl BParticipant { a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { let (b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded::(); - let closed = Arc::new(AtomicBool::new(false)); + let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, promises, + send_closed: send_closed.clone(), b2a_msg_recv_s, - closed: closed.clone(), }); #[cfg(feature = "metrics")] self.metrics @@ -675,20 +697,28 @@ impl BParticipant { sid, prio, promises, + send_closed, a2p_msg_s, b2a_msg_recv_r, - closed.clone(), a2b_close_stream_s.clone(), ) } /// close streams and set err - async fn close_api(&self, err: ParticipantError) { - *self.api_participant_closed.write().await = Err(err); + async fn close_api(&self, reason: Option) { + //closing api::Participant is done by closing all channels, exepct for the + // shutdown channel at this point! + let mut lock = self.shutdown_info.write().await; + if let Some(r) = reason { + lock.error = Some(r); + } + lock.b2a_stream_opened_s.close_channel(); + debug!("Closing all streams"); for (sid, si) in self.streams.write().await.drain() { trace!(?sid, "Shutting down Stream"); - si.closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.close_channel(); + si.send_closed.store(true, Ordering::Relaxed); } } } diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index b29818d1b1..000ccd92e7 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -3,7 +3,7 @@ use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, channel::Handshake, - participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel}, + participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, protocols::{Protocols, TcpProtocol, UdpProtocol}, types::Pid, }; @@ -45,13 +45,12 @@ use tracing_futures::Instrument; struct ParticipantInfo { secret: u128, s2b_create_channel_s: mpsc::UnboundedSender, - s2b_shutdown_bparticipant_s: - Option>>>, + s2b_shutdown_bparticipant_s: Option>, } type A2sListen = (ProtocolAddr, oneshot::Sender>); type A2sConnect = (ProtocolAddr, oneshot::Sender>); -type A2sDisconnect = (Pid, oneshot::Sender>); +type A2sDisconnect = (Pid, S2bShutdownBparticipant); #[derive(Debug)] struct ControlChannels { @@ -529,7 +528,6 @@ impl Scheduler { b2a_stream_opened_r, mut s2b_create_channel_s, s2b_shutdown_bparticipant_s, - api_participant_closed, ) = BParticipant::new( pid, sid, @@ -543,7 +541,6 @@ impl Scheduler { a2b_stream_open_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, - api_participant_closed, ); #[cfg(feature = "metrics")] diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 9b79f9c88a..1eb81b1990 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -20,7 +20,7 @@ use async_std::task; use task::block_on; -use veloren_network::StreamError; +use veloren_network::{Network, ParticipantError, Pid, StreamError, PROMISES_NONE}; mod helper; use helper::{network_participant_stream, tcp}; @@ -42,14 +42,10 @@ fn close_participant() { let (_n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); block_on(p1_a.disconnect()).unwrap(); - // The following will `Err`, but we don't know the exact error message. - // Why? because of the TCP layer we have no guarantee if the TCP messages send - // one line above already reached `p1_b`. If they reached them it would fail - // with a `ParticipantDisconnected` as a clean disconnect was performed. - // If they haven't reached them yet but will reach them during the execution it - // will return a unclean shutdown was detected. Nevertheless, if it returns - // Ok(()) then something is wrong! - assert!(block_on(p1_b.disconnect()).is_err()); + assert_eq!( + block_on(p1_b.disconnect()), + Err(ParticipantError::ParticipantDisconnected) + ); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( @@ -229,3 +225,99 @@ fn close_network_then_disconnect_part() { assert!(block_on(p_a.disconnect()).is_err()); std::thread::sleep(std::time::Duration::from_millis(1000)); } + +#[test] +fn opened_stream_before_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); + let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + let mut s2_a = block_on(p_a.open(10, PROMISES_NONE)).unwrap(); + s2_a.send("HelloWorld").unwrap(); + let mut s2_b = block_on(p_b.opened()).unwrap(); + drop(p_a); + std::thread::sleep(std::time::Duration::from_millis(1000)); + assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); +} + +#[test] +fn opened_stream_after_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); + let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + let mut s2_a = block_on(p_a.open(10, PROMISES_NONE)).unwrap(); + s2_a.send("HelloWorld").unwrap(); + drop(p_a); + std::thread::sleep(std::time::Duration::from_millis(1000)); + let mut s2_b = block_on(p_b.opened()).unwrap(); + assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + assert_eq!( + block_on(p_b.opened()).unwrap_err(), + ParticipantError::ParticipantDisconnected + ); +} + +#[test] +fn open_stream_after_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); + let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + let mut s2_a = block_on(p_a.open(10, PROMISES_NONE)).unwrap(); + s2_a.send("HelloWorld").unwrap(); + drop(p_a); + std::thread::sleep(std::time::Duration::from_millis(1000)); + let mut s2_b = block_on(p_b.opened()).unwrap(); + assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + assert_eq!( + block_on(p_b.open(20, PROMISES_NONE)).unwrap_err(), + ParticipantError::ParticipantDisconnected + ); +} + +#[test] +fn failed_stream_open_after_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); + let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + drop(p_a); + std::thread::sleep(std::time::Duration::from_millis(1000)); + assert_eq!( + block_on(p_b.opened()).unwrap_err(), + ParticipantError::ParticipantDisconnected + ); +} + +#[test] +fn open_participant_before_remote_part_is_closed() { + let (n_a, f) = Network::new(Pid::fake(1)); + std::thread::spawn(f); + let (n_b, f) = Network::new(Pid::fake(2)); + std::thread::spawn(f); + let addr = tcp(); + block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = block_on(p_b.open(10, PROMISES_NONE)).unwrap(); + s1_b.send("HelloWorld").unwrap(); + let p_a = block_on(n_a.connected()).unwrap(); + drop(s1_b); + drop(p_b); + drop(n_b); + std::thread::sleep(std::time::Duration::from_millis(1000)); + let mut s1_a = block_on(p_a.opened()).unwrap(); + assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); +} + +#[test] +fn open_participant_after_remote_part_is_closed() { + let (n_a, f) = Network::new(Pid::fake(1)); + std::thread::spawn(f); + let (n_b, f) = Network::new(Pid::fake(2)); + std::thread::spawn(f); + let addr = tcp(); + block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = block_on(p_b.open(10, PROMISES_NONE)).unwrap(); + s1_b.send("HelloWorld").unwrap(); + drop(s1_b); + drop(p_b); + drop(n_b); + std::thread::sleep(std::time::Duration::from_millis(1000)); + let p_a = block_on(n_a.connected()).unwrap(); + let mut s1_a = block_on(p_a.opened()).unwrap(); + assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); +} diff --git a/server/src/lib.rs b/server/src/lib.rs index de784f3025..c37fdb9708 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -376,13 +376,7 @@ impl Server { let before_new_connections = Instant::now(); // 3) Handle inputs from clients - block_on(async { - //TIMEOUT 0.1 ms for msg handling - select!( - _ = Delay::new(std::time::Duration::from_micros(100)).fuse() => Ok(()), - err = self.handle_new_connections(&mut frontend_events).fuse() => err, - ) - })?; + block_on(self.handle_new_connections(&mut frontend_events))?; let before_message_system = Instant::now(); @@ -629,16 +623,29 @@ impl Server { &mut self, frontend_events: &mut Vec, ) -> Result<(), Error> { + //TIMEOUT 0.1 ms for msg handling + const TIMEOUT: Duration = Duration::from_micros(100); loop { - let participant = self.network.connected().await?; + let participant = match select!( + _ = Delay::new(TIMEOUT).fuse() => None, + pr = self.network.connected().fuse() => Some(pr), + ) { + None => return Ok(()), + Some(pr) => pr?, + }; debug!("New Participant connected to the server"); - let singleton_stream = match participant.opened().await { - Ok(s) => s, - Err(e) => { - warn!( - ?e, - "Failed to open a Stream from remote client. Dropping it" - ); + + let singleton_stream = match select!( + _ = Delay::new(TIMEOUT*100).fuse() => None, + sr = participant.opened().fuse() => Some(sr), + ) { + None => { + warn!("Either Slowloris attack or very slow client, dropping"); + return Ok(()); //return rather then continue to give removes a tick more to send data. + }, + Some(Ok(s)) => s, + Some(Err(e)) => { + warn!(?e, "Failed to open a Stream from remote client. dropping"); continue; }, };