diff --git a/client/src/lib.rs b/client/src/lib.rs index 13c191ecdd..eb6da6b25a 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -2364,6 +2364,11 @@ impl Client { return Err(Error::ServerTimeout); } + // ignore network events + while let Some(Ok(Some(event))) = self.participant.as_ref().map(|p| p.try_fetch_event()) { + trace!(?event, "received network event"); + } + Ok(frontend_events) } diff --git a/network/src/api.rs b/network/src/api.rs index dc028dc797..d5322502a7 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -49,6 +49,14 @@ pub enum ListenAddr { Mpsc(u64), } +/// a Participant can throw different events, you are obligated to carefully +/// empty the queue from time to time +#[derive(Clone, Debug)] +pub enum ParticipantEvent { + ChannelCreated(ConnectAddr), + ChannelDeleted(ConnectAddr), +} + /// `Participants` are generated by the [`Network`] and represent a connection /// to a remote Participant. Look at the [`connect`] and [`connected`] method of /// [`Networks`] on how to generate `Participants` @@ -61,6 +69,7 @@ pub struct Participant { remote_pid: Pid, a2b_open_stream_s: Mutex>, b2a_stream_opened_r: Mutex>, + b2a_event_r: Mutex>, b2a_bandwidth_stats_r: watch::Receiver, a2s_disconnect_s: A2sDisconnect, } @@ -520,6 +529,7 @@ impl Participant { remote_pid: Pid, a2b_open_stream_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, + b2a_event_r: mpsc::UnboundedReceiver, b2a_bandwidth_stats_r: watch::Receiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>, ) -> Self { @@ -528,6 +538,7 @@ impl Participant { remote_pid, a2b_open_stream_s: Mutex::new(a2b_open_stream_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), + b2a_event_r: Mutex::new(b2a_event_r), b2a_bandwidth_stats_r, a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } @@ -752,6 +763,67 @@ impl Participant { } } + /// Use this method to query [`ParticipantEvent`]. Those are internal events + /// from the network crate that will get reported to the frontend. + /// E.g. Creation and Deletion of Channels. + /// + /// Make sure to call this function from time to time to not let events + /// stack up endlessly and create a memory leak. + /// + /// # Examples + /// ```rust + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr, Promises, ParticipantEvent}; + /// + /// # fn main() -> std::result::Result<(), Box> { + /// // Create a Network, connect on port 2040 and wait for the other side to open a stream + /// // Note: It's quite unusual to actively connect, but then wait on a stream to be connected, usually the Application taking initiative want's to also create the first Stream. + /// let runtime = Runtime::new().unwrap(); + /// let network = Network::new(Pid::new(), &runtime); + /// # let remote = Network::new(Pid::new(), &runtime); + /// runtime.block_on(async { + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2040".parse().unwrap())).await?; + /// let p1 = network.connect(ConnectAddr::Tcp("127.0.0.1:2040".parse().unwrap())).await?; + /// # let p2 = remote.connected().await?; + /// let event = p1.fetch_event().await?; + /// drop(network); + /// # drop(remote); + /// # Ok(()) + /// }) + /// # } + /// ``` + /// + /// [`ParticipantEvent`]: crate::api::ParticipantEvent + pub async fn fetch_event(&self) -> Result { + match self.b2a_event_r.lock().await.recv().await { + Some(event) => Ok(event), + None => { + debug!("event_receiver failed, closing participant"); + Err(ParticipantError::ParticipantDisconnected) + }, + } + } + + /// use `try_fetch_event` to check for a [`ParticipantEvent`] . This + /// function does not block and returns immediately. It's intended for + /// use in non-async context only. Other then that, the same rules apply + /// than for [`fetch_event`]. + /// + /// [`ParticipantEvent`]: crate::api::ParticipantEvent + /// [`fetch_event`]: Participant::fetch_event + pub fn try_fetch_event(&self) -> Result, ParticipantError> { + match &mut self.b2a_event_r.try_lock() { + Ok(b2a_event_r) => match b2a_event_r.try_recv() { + Ok(event) => Ok(Some(event)), + Err(mpsc::error::TryRecvError::Empty) => Ok(None), + Err(mpsc::error::TryRecvError::Disconnected) => { + Err(ParticipantError::ParticipantDisconnected) + }, + }, + Err(_) => Ok(None), + } + } + /// Returns the current approximation on the maximum bandwidth available. /// This WILL fluctuate based on the amount/size of send messages. pub fn bandwidth(&self) -> f32 { *self.b2a_bandwidth_stats_r.borrow() } diff --git a/network/src/lib.rs b/network/src/lib.rs index 742ab23eb7..a7da1cf5fc 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -2,6 +2,7 @@ #![cfg_attr(test, deny(rust_2018_idioms))] #![cfg_attr(test, deny(warnings))] #![deny(clippy::clone_on_ref_ptr)] +#![feature(assert_matches)] //! Crate to handle high level networking of messages with different //! requirements and priorities over a number of protocols @@ -108,7 +109,7 @@ mod util; pub use api::{ ConnectAddr, ListenAddr, Network, NetworkConnectError, NetworkError, Participant, - ParticipantError, Stream, StreamError, StreamParams, + ParticipantError, ParticipantEvent, Stream, StreamError, StreamParams, }; pub use message::Message; pub use network_protocol::{InitProtocolError, Pid, Promises}; diff --git a/network/src/participant.rs b/network/src/participant.rs index 9564f55a81..2836b81e65 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,5 +1,5 @@ use crate::{ - api::{ConnectAddr, ParticipantError, Stream}, + api::{ConnectAddr, ParticipantError, ParticipantEvent, Stream}, channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, util::DeferredTracer, @@ -53,6 +53,7 @@ struct StreamInfo { struct ControlChannels { a2b_open_stream_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, + b2a_event_s: mpsc::UnboundedSender, s2b_create_channel_r: mpsc::UnboundedReceiver, b2a_bandwidth_stats_s: watch::Sender, s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ @@ -95,12 +96,14 @@ impl BParticipant { Self, mpsc::UnboundedSender, mpsc::UnboundedReceiver, + mpsc::UnboundedReceiver, mpsc::UnboundedSender, oneshot::Sender, watch::Receiver, ) { let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::(); let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); + let (b2a_event_s, b2a_event_r) = mpsc::unbounded_channel::(); let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); let (b2a_bandwidth_stats_s, b2a_bandwidth_stats_r) = watch::channel::(0.0); @@ -108,6 +111,7 @@ impl BParticipant { let run_channels = Some(ControlChannels { a2b_open_stream_r, b2a_stream_opened_s, + b2a_event_s, s2b_create_channel_r, b2a_bandwidth_stats_s, s2b_shutdown_bparticipant_r, @@ -130,6 +134,7 @@ impl BParticipant { }, a2b_open_stream_s, b2a_stream_opened_r, + b2a_event_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2a_bandwidth_stats_r, @@ -168,6 +173,7 @@ impl BParticipant { b2b_close_send_protocol_r, b2b_notify_send_of_recv_open_r, b2b_notify_send_of_recv_close_r, + run_channels.b2a_event_s.clone(), b2s_prio_statistic_s, run_channels.b2a_bandwidth_stats_s, ) @@ -185,6 +191,7 @@ impl BParticipant { run_channels.s2b_create_channel_r, b2b_add_send_protocol_s, b2b_add_recv_protocol_s, + run_channels.b2a_event_s, ), self.participant_shutdown_mgr( run_channels.s2b_shutdown_bparticipant_r, @@ -238,6 +245,7 @@ impl BParticipant { Bandwidth, )>, b2b_notify_send_of_recv_close_r: crossbeam_channel::Receiver<(Cid, Sid)>, + b2a_event_s: mpsc::UnboundedSender, _b2s_prio_statistic_s: mpsc::UnboundedSender, b2a_bandwidth_stats_s: watch::Sender, ) { @@ -382,6 +390,13 @@ impl BParticipant { // recv trace!("TODO: for now decide to FAIL this participant and not wait for a failover"); sorted_send_protocols.delete(&cid).unwrap(); + if let Some(info) = self.channels.write().await.get(&cid) { + if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted( + info.lock().await.remote_con_addr.clone(), + )) { + debug!(?e, "Participant was dropped during channel disconnect"); + }; + } self.metrics.channels_disconnected(&self.remote_pid_string); if sorted_send_protocols.data.is_empty() { break; @@ -392,6 +407,13 @@ impl BParticipant { debug!(?cid, "remove protocol"); match sorted_send_protocols.delete(&cid) { Some(mut prot) => { + if let Some(info) = self.channels.write().await.get(&cid) { + if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted( + info.lock().await.remote_con_addr.clone(), + )) { + debug!(?e, "Participant was dropped during channel disconnect"); + }; + } self.metrics.channels_disconnected(&self.remote_pid_string); trace!("blocking flush"); let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; @@ -558,6 +580,7 @@ impl BParticipant { s2b_create_channel_r: mpsc::UnboundedReceiver, b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, + b2a_event_s: mpsc::UnboundedSender, ) { let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r @@ -569,6 +592,7 @@ impl BParticipant { let channels = Arc::clone(&self.channels); let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + let b2a_event_s = b2a_event_s.clone(); async move { let mut lock = channels.write().await; let mut channel_no = lock.len(); @@ -577,13 +601,18 @@ impl BParticipant { Mutex::new(ChannelInfo { cid, cid_string: cid.to_string(), - remote_con_addr, + remote_con_addr: remote_con_addr.clone(), }), ); drop(lock); let (send, recv) = protocol.split(); b2b_add_send_protocol_s.send((cid, send)).unwrap(); b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + if let Err(e) = + b2a_event_s.send(ParticipantEvent::ChannelCreated(remote_con_addr)) + { + debug!(?e, "Participant was dropped during channel connect"); + }; b2s_create_channel_done_s.send(()).unwrap(); if channel_no > 5 { debug!(?channel_no, "metrics will overwrite channel #5"); @@ -777,6 +806,7 @@ impl BParticipant { #[cfg(test)] mod tests { use super::*; + use core::assert_matches::assert_matches; use network_protocol::{ProtocolMetricCache, ProtocolMetrics}; use tokio::{ runtime::Runtime, @@ -788,6 +818,7 @@ mod tests { Arc, mpsc::UnboundedSender, mpsc::UnboundedReceiver, + mpsc::UnboundedReceiver, mpsc::UnboundedSender, oneshot::Sender, mpsc::UnboundedReceiver, @@ -804,6 +835,7 @@ mod tests { bparticipant, a2b_open_stream_s, b2a_stream_opened_r, + b2a_event_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2a_bandwidth_stats_r, @@ -821,6 +853,7 @@ mod tests { runtime_clone, a2b_open_stream_s, b2a_stream_opened_r, + b2a_event_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2s_prio_statistic_r, @@ -854,6 +887,7 @@ mod tests { runtime, a2b_open_stream_s, b2a_stream_opened_r, + mut b2a_event_r, mut s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2s_prio_statistic_r, @@ -877,6 +911,15 @@ mod tests { before.elapsed() > Duration::from_millis(900), "timeout wasn't triggered" ); + assert_matches!( + b2a_event_r.try_recv().unwrap(), + ParticipantEvent::ChannelCreated(_) + ); + assert_matches!( + b2a_event_r.try_recv().unwrap(), + ParticipantEvent::ChannelDeleted(_) + ); + assert_matches!(b2a_event_r.try_recv(), Err(_)); runtime.block_on(handle).unwrap(); @@ -890,6 +933,7 @@ mod tests { runtime, a2b_open_stream_s, b2a_stream_opened_r, + mut b2a_event_r, mut s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2s_prio_statistic_r, @@ -914,6 +958,15 @@ mod tests { before.elapsed() < Duration::from_millis(1900), "timeout was triggered" ); + assert_matches!( + b2a_event_r.try_recv().unwrap(), + ParticipantEvent::ChannelCreated(_) + ); + assert_matches!( + b2a_event_r.try_recv().unwrap(), + ParticipantEvent::ChannelDeleted(_) + ); + assert_matches!(b2a_event_r.try_recv(), Err(_)); runtime.block_on(handle).unwrap(); @@ -927,6 +980,7 @@ mod tests { runtime, a2b_open_stream_s, b2a_stream_opened_r, + _b2a_event_r, mut s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2s_prio_statistic_r, @@ -982,6 +1036,7 @@ mod tests { runtime, a2b_open_stream_s, mut b2a_stream_opened_r, + _b2a_event_r, mut s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2s_prio_statistic_r, diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 63a37356c5..89bb14a580 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -418,6 +418,7 @@ impl Scheduler { bparticipant, a2b_open_stream_s, b2a_stream_opened_r, + b2a_event_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, b2a_bandwidth_stats_r, @@ -428,6 +429,7 @@ impl Scheduler { pid, a2b_open_stream_s, b2a_stream_opened_r, + b2a_event_r, b2a_bandwidth_stats_r, participant_channels.a2s_disconnect_s, ); diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 04c42c8d05..88eedc9439 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -1,3 +1,4 @@ +#![feature(assert_matches)] //! How to read those tests: //! - in the first line we call the helper, this is only debug code. in case //! you want to have tracing for a special test you set set the bool = true @@ -18,9 +19,9 @@ //! - You sometimes see sleep(1000ms) this is used when we rely on the //! underlying TCP functionality, as this simulates client and server -use std::sync::Arc; +use std::{assert_matches::assert_matches, sync::Arc}; use tokio::runtime::Runtime; -use veloren_network::{Network, ParticipantError, Pid, Promises, StreamError}; +use veloren_network::{Network, ParticipantError, ParticipantEvent, Pid, Promises, StreamError}; mod helper; use helper::{network_participant_stream, tcp, SLEEP_EXTERNAL, SLEEP_INTERNAL}; @@ -389,15 +390,35 @@ fn close_network_scheduler_completely() { let addr = tcp(); r.block_on(n_a.listen(addr.0)).unwrap(); let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); + assert_matches!( + r.block_on(p_b.fetch_event()), + Ok(ParticipantEvent::ChannelCreated(_)) + ); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); let p_a = r.block_on(n_a.connected()).unwrap(); + assert_matches!( + r.block_on(p_a.fetch_event()), + Ok(ParticipantEvent::ChannelCreated(_)) + ); + assert_matches!(p_a.try_fetch_event(), Ok(None)); + assert_matches!(p_b.try_fetch_event(), Ok(None)); let mut s1_a = r.block_on(p_a.opened()).unwrap(); assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); drop(n_a); drop(n_b); std::thread::sleep(SLEEP_EXTERNAL); //p_b is INTERNAL, but p_a is EXTERNAL + assert_matches!( + p_a.try_fetch_event(), + Ok(Some(ParticipantEvent::ChannelDeleted(_))) + ); + assert_matches!( + r.block_on(p_b.fetch_event()), + Ok(ParticipantEvent::ChannelDeleted(_)) + ); + assert_matches!(p_a.try_fetch_event(), Err(_)); + assert_matches!(p_b.try_fetch_event(), Err(_)); drop(p_b); drop(p_a); diff --git a/network/tests/integration.rs b/network/tests/integration.rs index 84d54d6211..473844b31a 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -1,10 +1,11 @@ +#![feature(assert_matches)] use std::sync::Arc; use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; use helper::{mpsc, network_participant_stream, quic, tcp, udp, SLEEP_EXTERNAL, SLEEP_INTERNAL}; use std::io::ErrorKind; -use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; +use veloren_network::{ConnectAddr, ListenAddr, Network, ParticipantEvent, Pid, Promises}; #[test] fn stream_simple() { @@ -307,3 +308,55 @@ fn listen_on_ipv6_doesnt_block_ipv4() { drop((s1_a, s1_b, _n_a, _n_b, _p_a, _p_b)); drop((s1_a2, s1_b2, _n_a2, _n_b2, _p_a2, _p_b2)); //clean teardown } + +#[test] +fn check_correct_channel_events() { + let (_, _) = helper::setup(false, 0); + let con_addr = tcp(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(con_addr.clone()); + + let event_a = r.block_on(p_a.fetch_event()).unwrap(); + let event_b = r.block_on(p_b.fetch_event()).unwrap(); + if let ConnectAddr::Tcp(listen_addr) = con_addr.1 { + match event_a { + ParticipantEvent::ChannelCreated(ConnectAddr::Tcp(socket_addr)) => { + assert_ne!(socket_addr, listen_addr); + assert_eq!(socket_addr.ip(), std::net::Ipv4Addr::LOCALHOST); + }, + e => panic!("wrong event {:?}", e), + } + match event_b { + ParticipantEvent::ChannelCreated(ConnectAddr::Tcp(socket_addr)) => { + assert_eq!(socket_addr, listen_addr); + }, + e => panic!("wrong event {:?}", e), + } + } else { + unreachable!(); + } + + std::thread::sleep(SLEEP_EXTERNAL); + drop((_n_a, _n_b)); //drop network + + let event_a = r.block_on(p_a.fetch_event()).unwrap(); + let event_b = r.block_on(p_b.fetch_event()).unwrap(); + if let ConnectAddr::Tcp(listen_addr) = con_addr.1 { + match event_a { + ParticipantEvent::ChannelDeleted(ConnectAddr::Tcp(socket_addr)) => { + assert_ne!(socket_addr, listen_addr); + assert_eq!(socket_addr.ip(), std::net::Ipv4Addr::LOCALHOST); + }, + e => panic!("wrong event {:?}", e), + } + match event_b { + ParticipantEvent::ChannelDeleted(ConnectAddr::Tcp(socket_addr)) => { + assert_eq!(socket_addr, listen_addr); + }, + e => panic!("wrong event {:?}", e), + } + } else { + unreachable!(); + } + + drop((p_a, p_b)); //clean teardown +} diff --git a/server/src/sys/msg/ping.rs b/server/src/sys/msg/ping.rs index 82e42612a3..8224133d3e 100644 --- a/server/src/sys/msg/ping.rs +++ b/server/src/sys/msg/ping.rs @@ -41,6 +41,10 @@ impl<'a> System<'a> for Sys { let mut server_emitter = server_event_bus.emitter(); for (entity, client) in (&entities, &clients).join() { + // ignore network events + while let Some(Ok(Some(_))) = client.participant.as_ref().map(|p| p.try_fetch_event()) { + } + let res = super::try_recv_all(client, 4, Self::handle_ping_msg); match res {