From 653fb065e02de3253a2f26d8cb2f610a906e6e97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Tue, 27 Apr 2021 17:59:36 +0200 Subject: [PATCH] extract protocol specific listen code from scheduler and move it to channel.rs --- network/protocol/src/quic.rs | 5 +- network/src/api.rs | 34 ++-- network/src/channel.rs | 272 +++++++++++++++++++++++--------- network/src/message.rs | 2 +- network/src/metrics.rs | 2 + network/src/participant.rs | 10 +- network/src/scheduler.rs | 291 ++++++++--------------------------- network/tests/integration.rs | 2 +- 8 files changed, 300 insertions(+), 318 deletions(-) diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index a10764491b..a4dfa328d1 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -451,7 +451,10 @@ where m.data.extend_from_slice(&data); if m.data.len() == m.length as usize { // finished, yay - let m = self.incoming.remove(&mid).unwrap(); + let m = self + .incoming + .remove(&mid) + .ok_or(ProtocolError::Violated)?; self.metrics.rmsg_ob( m.sid, RemoveReason::Finished, diff --git a/network/src/api.rs b/network/src/api.rs index ad95dd3419..0da58aa6d5 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -145,8 +145,8 @@ pub struct StreamParams { /// [`Arc`](std::sync::Arc) as all commands have internal mutability. /// /// The `Network` has methods to [`connect`] to other [`Participants`] actively -/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`] -/// [`Participants`] via [`ProtocolListenAddr`]. +/// via their [`ConnectAddr`], or [`listen`] passively for [`connected`] +/// [`Participants`] via [`ListenAddr`]. /// /// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before /// the Network. @@ -178,6 +178,8 @@ pub struct StreamParams { /// [`connect`]: Network::connect /// [`listen`]: Network::listen /// [`connected`]: Network::connected +/// [`ConnectAddr`]: crate::api::ConnectAddr +/// [`ListenAddr`]: crate::api::ListenAddr pub struct Network { local_pid: Pid, participant_disconnect_sender: Arc>>, @@ -293,7 +295,7 @@ impl Network { } } - /// starts listening on an [`ProtocolListenAddr`]. + /// starts listening on an [`ListenAddr`]. /// When the method returns the `Network` is ready to listen for incoming /// connections OR has returned a [`NetworkError`] (e.g. port already used). /// You can call [`connected`] to asynchrony wait for a [`Participant`] to @@ -303,7 +305,7 @@ impl Network { /// # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolListenAddr}; + /// use veloren_network::{Network, Pid, ListenAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally @@ -311,10 +313,10 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; /// network - /// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) + /// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) /// .await?; /// drop(network); /// # Ok(()) @@ -323,6 +325,7 @@ impl Network { /// ``` /// /// [`connected`]: Network::connected + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> { let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); @@ -339,13 +342,13 @@ impl Network { } } - /// starts connection to an [`ProtocolConnectAddr`]. + /// starts connection to an [`ConnectAddr`]. /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above @@ -353,16 +356,16 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; - /// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) + /// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) /// .await?; /// # //this doesn't work yet, so skip the test /// # //TODO fixme! /// # return Ok(()); /// let p2 = network - /// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) + /// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) /// .await?; /// assert_eq!(&p1, &p2); /// # Ok(()) @@ -374,13 +377,13 @@ impl Network { /// ``` /// Usually the `Network` guarantees that a operation on a [`Participant`] /// succeeds, e.g. by automatic retrying unless it fails completely e.g. by - /// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you + /// disconnecting from the remote. If 2 [`ConnectAddr] you /// `connect` to belongs to the same [`Participant`], you get the same /// [`Participant`] as a result. This is useful e.g. by connecting to /// the same [`Participant`] via multiple Protocols. /// /// [`Streams`]: crate::api::Stream - /// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr + /// [`ConnectAddr`]: crate::api::ConnectAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn connect(&self, address: ConnectAddr) -> Result { let (pid_sender, pid_receiver) = @@ -403,7 +406,7 @@ impl Network { Ok(participant) } - /// returns a [`Participant`] created from a [`ProtocolListenAddr`] you + /// returns a [`Participant`] created from a [`ListenAddr`] you /// called [`listen`] on before. This function will either return a /// working [`Participant`] ready to open [`Streams`] on OR has returned /// a [`NetworkError`] (e.g. Network got closed) @@ -437,6 +440,7 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn connected(&self) -> Result { let participant = self.connected_receiver.lock().await.recv().await?; diff --git a/network/src/channel.rs b/network/src/channel.rs index fe3bff971e..03930c03ec 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,21 +1,34 @@ use crate::api::NetworkConnectError; use async_trait::async_trait; use bytes::BytesMut; +use futures_util::FutureExt; +#[cfg(feature = "quic")] +use futures_util::StreamExt; use network_protocol::{ Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, - ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, - QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; -use std::{sync::Arc, time::Duration}; +#[cfg(feature = "quic")] +use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; +use std::{ + collections::HashMap, + io, + net::SocketAddr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::{mpsc, oneshot}, + select, + sync::{mpsc, oneshot, Mutex}, }; -use tokio_stream::StreamExt; -use tracing::{info, trace}; +use tracing::{error, info, trace, warn}; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -42,32 +55,67 @@ pub(crate) enum RecvProtocols { Quic(QuicRecvProtocol), } +lazy_static::lazy_static! { + pub(crate) static ref MPSC_POOL: Mutex>> = { + Mutex::new(HashMap::new()) + }; +} + +pub(crate) type C2cMpscConnect = ( + mpsc::Sender, + oneshot::Sender>, +); + impl Protocols { const MPSC_CHANNEL_BOUND: usize = 1000; pub(crate) async fn with_tcp_connect( - addr: std::net::SocketAddr, - cid: Cid, - metrics: Arc, + addr: SocketAddr, + metrics: ProtocolMetricCache, ) -> Result { - let stream = match net::TcpStream::connect(addr).await { - Ok(stream) => stream, - Err(e) => { - return Err(crate::api::NetworkConnectError::Io(e)); - }, - }; - info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - Ok(Protocols::new_tcp(stream, cid, metrics)) + let stream = net::TcpStream::connect(addr) + .await + .map_err(NetworkConnectError::Io)?; + info!( + "Connecting Tcp to: {}", + stream.peer_addr().map_err(NetworkConnectError::Io)? + ); + Ok(Self::new_tcp(stream, metrics)) } - pub(crate) fn new_tcp( - stream: tokio::net::TcpStream, - cid: Cid, + pub(crate) async fn with_tcp_listen( + addr: SocketAddr, + cids: Arc, metrics: Arc, - ) -> Self { - let (r, w) = stream.into_split(); - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let listener = net::TcpListener::bind(addr).await?; + trace!(?addr, "Tcp Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(data) = select! { + next = listener.accept().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let (stream, remote_addr) = match data { + Ok((s, p)) => (s, p), + Err(e) => { + trace!(?e, "TcpStream Error, ignoring connection attempt"); + continue; + }, + }; + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Tcp from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid)); + } + }); + Ok(()) + } + pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self { + let (r, w) = stream.into_split(); let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); let rp = TcpRecvProtocol::new( TcpSink { @@ -81,70 +129,104 @@ impl Protocols { pub(crate) async fn with_mpsc_connect( addr: u64, - cid: Cid, - metrics: Arc, + metrics: ProtocolMetricCache, ) -> Result { - let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) { - Some(s) => s.clone(), - None => { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::NotConnected, + let mpsc_s = MPSC_POOL + .lock() + .await + .get(&addr) + .ok_or_else(|| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::NotConnected, "no mpsc listen on this addr", - ))); - }, - }; + )) + })? + .clone(); let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); - if mpsc_s + mpsc_s .send((remote_to_local_s, local_to_remote_oneshot_s)) - .is_err() - { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "mpsc pipe broke during connect", - ))); - } - let local_to_remote_s = match local_to_remote_oneshot_r.await { - Ok(s) => s, - Err(e) => { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - e, - ))); - }, - }; + .map_err(|_| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "mpsc pipe broke during connect", + )) + })?; + let local_to_remote_s = local_to_remote_oneshot_r + .await + .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?; info!(?addr, "Connecting Mpsc"); Ok(Self::new_mpsc( local_to_remote_s, remote_to_local_r, - cid, metrics, )) } + pub(crate) async fn with_mpsc_listen( + addr: u64, + cids: Arc, + metrics: Arc, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); + MPSC_POOL.lock().await.insert(addr, mpsc_s); + trace!(?addr, "Mpsc Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { + next = mpsc_r.recv().fuse() => next, + _ = &mut end_receiver => None, + } { + let (remote_to_local_s, remote_to_local_r) = + mpsc::channel(Self::MPSC_CHANNEL_BOUND); + if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) { + error!(?e, "mpsc listen aborted"); + } + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?addr, ?cid, "Accepting Mpsc from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send(( + Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()), + cid, + )); + } + warn!("MpscStream Failed, stopping"); + }); + Ok(()) + } + pub(crate) fn new_mpsc( sender: mpsc::Sender, receiver: mpsc::Receiver, - cid: Cid, - metrics: Arc, + metrics: ProtocolMetricCache, ) -> Self { - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); Protocols::Mpsc((sp, rp)) } + #[cfg(feature = "quic")] pub(crate) async fn with_quic_connect( - addr: std::net::SocketAddr, + addr: SocketAddr, config: quinn::ClientConfig, name: String, - cid: Cid, - metrics: Arc, + metrics: ProtocolMetricCache, ) -> Result { let config = config.clone(); let endpoint = quinn::Endpoint::builder(); - let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) { + + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + let bindsock = match addr { + SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), + SocketAddr::V6(_) => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) + }, + }; + let (endpoint, _) = match endpoint.bind(&bindsock) { Ok(e) => e, Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), }; @@ -164,7 +246,7 @@ impl Protocols { e, )) })?; - Protocols::new_quic(connection, false, cid, metrics) + Self::new_quic(connection, false, metrics) .await .map_err(|e| { trace!(?e, "error with quic"); @@ -175,15 +257,60 @@ impl Protocols { }) } + #[cfg(feature = "quic")] + pub(crate) async fn with_quic_listen( + addr: SocketAddr, + server_config: quinn::ServerConfig, + cids: Arc, + metrics: Arc, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let mut endpoint = quinn::Endpoint::builder(); + endpoint.listen(server_config); + let (_endpoint, mut listener) = match endpoint.bind(&addr) { + Ok(v) => v, + Err(quinn::EndpointError::Socket(e)) => return Err(e), + }; + trace!(?addr, "Quic Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(Some(connecting)) = select! { + next = listener.next().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let remote_addr = connecting.remote_address(); + let connection = match connecting.await { + Ok(c) => c, + Err(e) => { + tracing::debug!(?e, ?remote_addr, "skipping connection attempt"); + continue; + }, + }; + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Quic from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + match Protocols::new_quic(connection, true, metrics).await { + Ok(quic) => { + let _ = c2s_protocol_s.send((quic, cid)); + }, + Err(e) => { + trace!(?e, "failed to start quic"); + continue; + }, + } + } + }); + Ok(()) + } + #[cfg(feature = "quic")] pub(crate) async fn new_quic( mut connection: quinn::NewConnection, listen: bool, - cid: Cid, - metrics: Arc, + metrics: ProtocolMetricCache, ) -> Result { - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let (sendstream, recvstream) = if listen { connection.connection.open_bi().await? } else { @@ -191,7 +318,7 @@ impl Protocols { .bi_streams .next() .await - .ok_or_else(|| quinn::ConnectionError::LocallyClosed)?? + .ok_or(quinn::ConnectionError::LocallyClosed)?? }; let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); let streams_s_clone = recvstreams_s.clone(); @@ -521,7 +648,8 @@ impl UnreliableSink for QuicSink { mod tests { use super::*; use bytes::Bytes; - use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol}; + use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; #[tokio::test] @@ -533,9 +661,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (mut s, _) = client.split(); let (_, mut r) = server.split(); let event = ProtocolEvent::OpenStream { @@ -582,9 +710,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (s, _) = client.split(); let (_, mut r) = server.split(); let e = tokio::spawn(async move { r.recv().await }); diff --git a/network/src/message.rs b/network/src/message.rs index 5c0029cf16..f821511450 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -30,7 +30,7 @@ impl Message { /// # Example /// for example coding, see [`send_raw`] /// - /// [`send_raw`]: Stream::send_raw + /// [`send_raw`]: crate::api::Stream::send_raw /// [`Participants`]: crate::api::Participant /// [`compress`]: lz_fear::raw::compress2 /// [`Message::serialize`]: crate::message::Message::serialize diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d532347140..f3341e392b 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -251,6 +251,7 @@ fn protocolconnect_name(protocol: &ConnectAddr) -> &str { ConnectAddr::Tcp(_) => "tcp", ConnectAddr::Udp(_) => "udp", ConnectAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] ConnectAddr::Quic(_, _, _) => "quic", } } @@ -261,6 +262,7 @@ fn protocollisten_name(protocol: &ListenAddr) -> &str { ListenAddr::Tcp(_) => "tcp", ListenAddr::Udp(_) => "udp", ListenAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] ListenAddr::Quic(_, _) => "quic", } } diff --git a/network/src/participant.rs b/network/src/participant.rs index a06321201c..2735fd5bdd 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -756,7 +756,7 @@ impl BParticipant { #[cfg(test)] mod tests { use super::*; - use network_protocol::ProtocolMetrics; + use network_protocol::{ProtocolMetricCache, ProtocolMetrics}; use tokio::{ runtime::Runtime, sync::{mpsc, oneshot}, @@ -816,14 +816,16 @@ mod tests { ) -> Protocols { let (s1, r1) = mpsc::channel(100); let (s2, r2) = mpsc::channel(100); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); + let met = Arc::new(ProtocolMetrics::new().unwrap()); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met)); + let p1 = Protocols::new_mpsc(s1, r2, metrics); let (complete_s, complete_r) = oneshot::channel(); create_channel .send((cid, Sid::new(0), p1, complete_s)) .unwrap(); complete_r.await.unwrap(); - Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) + let metrics = ProtocolMetricCache::new(&cid.to_string(), met); + Protocols::new_mpsc(s2, r1, metrics) } #[test] diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 1e8d5c69a8..a232be440b 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -4,8 +4,8 @@ use crate::{ metrics::{NetworkMetrics, ProtocolInfo}, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, }; -use futures_util::{FutureExt, StreamExt}; -use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; +use futures_util::StreamExt; +use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -18,7 +18,7 @@ use std::{ time::Duration, }; use tokio::{ - io, net, select, + io, sync::{mpsc, oneshot, Mutex}, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -33,12 +33,6 @@ use tracing::*; // - w: wire // - c: channel/handshake -lazy_static::lazy_static! { - pub(crate) static ref MPSC_POOL: Mutex>> = { - Mutex::new(HashMap::new()) - }; -} - #[derive(Debug)] struct ParticipantInfo { secret: u128, @@ -52,10 +46,6 @@ pub(crate) type A2sConnect = ( oneshot::Sender>, ); type A2sDisconnect = (Pid, S2bShutdownBparticipant); -type S2sMpscConnect = ( - mpsc::Sender, - oneshot::Sender>, -); #[derive(Debug)] struct ControlChannels { @@ -88,8 +78,6 @@ pub struct Scheduler { } impl Scheduler { - const MPSC_CHANNEL_BOUND: usize = 1000; - pub fn new( local_pid: Pid, #[cfg(feature = "metrics")] registry: Option<&Registry>, @@ -157,7 +145,10 @@ impl Scheduler { } pub async fn run(mut self) { - let run_channels = self.run_channels.take().unwrap(); + let run_channels = self + .run_channels + .take() + .expect("run() can only be called once"); tokio::join!( self.listen_mgr(run_channels.a2s_listen_r), @@ -174,17 +165,66 @@ impl Scheduler { a2s_listen_r .for_each_concurrent(None, |(address, s2a_listen_result_s)| { let address = address; + let cids = Arc::clone(&self.channel_ids); + + #[cfg(feature = "metrics")] + let mcache = self.metrics.connect_requests_cache(&address); + + debug!(?address, "Got request to open a channel_creator"); + self.metrics.listen_request(&address); + let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>(); + let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel(); + let metrics = Arc::clone(&self.protocol_metrics); async move { - debug!(?address, "Got request to open a channel_creator"); - self.metrics.listen_request(&address); - let (end_sender, end_receiver) = oneshot::channel::<()>(); self.channel_listener .lock() .await - .insert(address.clone().into(), end_sender); - self.channel_creator(address, end_receiver, s2a_listen_result_s) - .await; + .insert(address.clone().into(), s2s_stop_listening_s); + + #[cfg(feature = "metrics")] + mcache.inc(); + + let res = match address { + ListenAddr::Tcp(addr) => { + Protocols::with_tcp_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + #[cfg(feature = "quic")] + ListenAddr::Quic(addr, ref server_config) => { + Protocols::with_quic_listen( + addr, + server_config.clone(), + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + ListenAddr::Mpsc(addr) => { + Protocols::with_mpsc_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + _ => unimplemented!(), + }; + let _ = s2a_listen_result_s.send(res); + + while let Some((prot, cid)) = c2s_protocol_r.recv().await { + self.init_protocol(prot, cid, None, true).await; + } } }) .await; @@ -195,15 +235,16 @@ impl Scheduler { trace!("Start connect_mgr"); while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - let metrics = Arc::clone(&self.protocol_metrics); + let metrics = + ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics)); self.metrics.connect_request(&addr); let protocol = match addr { - ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await, + ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await, #[cfg(feature = "quic")] ConnectAddr::Quic(addr, ref config, name) => { - Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await + Protocols::with_quic_connect(addr, config.clone(), name, metrics).await }, - ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await, + ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await, _ => unimplemented!(), }; let protocol = match protocol { @@ -327,204 +368,6 @@ impl Scheduler { trace!("Stop scheduler_shutdown_mgr"); } - async fn channel_creator( - &self, - addr: ListenAddr, - s2s_stop_listening_r: oneshot::Receiver<()>, - s2a_listen_result_s: oneshot::Sender>, - ) { - trace!(?addr, "Start up channel creator"); - #[cfg(feature = "metrics")] - let mcache = self.metrics.connect_requests_cache(&addr); - match addr { - ListenAddr::Tcp(addr) => { - let listener = match net::TcpListener::bind(addr).await { - Ok(listener) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - listener - }, - Err(e) => { - info!( - ?addr, - ?e, - "Tcp bind error during listener startup" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(data) = select! { - next = listener.accept().fuse() => Some(next), - _ = &mut end_receiver => None, - } { - let (stream, remote_addr) = match data { - Ok((s, p)) => (s, p), - Err(e) => { - warn!(?e, "TcpStream Error, ignoring connection attempt"); - continue; - }, - }; - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?remote_addr, ?cid, "Accepting Tcp from"); - self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - }, - #[cfg(feature = "quic")] - ListenAddr::Quic(addr, ref server_config) => { - let mut endpoint = quinn::Endpoint::builder(); - endpoint.listen(server_config.clone()); - let (_endpoint, mut listener) = match endpoint.bind(&addr) { - Ok((endpoint, listener)) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - (endpoint, listener) - }, - Err(quinn::EndpointError::Socket(e)) => { - info!( - ?addr, - ?e, - "Quic bind error during listener startup" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - } - }; - trace!(?addr, "Listener bound"); - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(Some(connecting)) = select! { - next = listener.next().fuse() => Some(next), - _ = &mut end_receiver => None, - } { - let remote_addr = connecting.remote_address(); - let connection = match connecting.await { - Ok(c) => c, - Err(e) => { - debug!(?e, ?remote_addr, "skipping connection attempt"); - continue; - }, - }; - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?remote_addr, ?cid, "Accepting Quic from"); - let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await { - Ok(quic) => quic, - Err(e) => { - trace!(?e, "failed to start quic"); - continue; - } - }; - self.init_protocol(quic, cid, None, true) - .await; - } - }, - ListenAddr::Mpsc(addr) => { - let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); - MPSC_POOL.lock().await.insert(addr, mpsc_s); - s2a_listen_result_s.send(Ok(())).unwrap(); - trace!(?addr, "Listener bound"); - - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { - next = mpsc_r.recv().fuse() => next, - _ = &mut end_receiver => None, - } { - let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); - local_remote_to_local_s.send(remote_to_local_s).unwrap(); - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?addr, ?cid, "Accepting Mpsc from"); - self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - warn!("MpscStream Failed, stopping"); - },/* - ProtocolListenAddr::Udp(addr) => { - let socket = match net::UdpSocket::bind(addr).await { - Ok(socket) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - Arc::new(socket) - }, - Err(e) => { - info!( - ?addr, - ?e, - "Listener couldn't be started due to error on udp bind" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut listeners = HashMap::new(); - let mut end_receiver = s2s_stop_listening_r.fuse(); - const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216; - let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER]; - while let Ok((size, remote_addr)) = select! { - next = socket.recv_from(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - //Due to the async nature i cannot make of .entry() as it would lead to a still - // borrowed in another branch situation - #[allow(clippy::map_entry)] - if !listeners.contains_key(&remote_addr) { - info!("Accepting Udp from: {}", &remote_addr); - let (udp_data_sender, udp_data_receiver) = - mpsc::unbounded_channel::>(); - listeners.insert(remote_addr, udp_data_sender); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - remote_addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.init_protocol(Protocols::Udp(protocol), None, false) - .await; - } - let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); - udp_data_sender.send(datavec).unwrap(); - } - },*/ - _ => unimplemented!(), - } - trace!(?addr, "Ending channel creator"); - } - - #[allow(dead_code)] - async fn udp_single_channel_connect( - socket: Arc, - w2p_udp_package_s: mpsc::UnboundedSender>, - ) { - let addr = socket.local_addr(); - trace!(?addr, "Start udp_single_channel_connect"); - //TODO: implement real closing - let (_end_sender, end_receiver) = oneshot::channel::<()>(); - - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut end_receiver = end_receiver.fuse(); - let mut data = [0u8; 9216]; - while let Ok(size) = select! { - next = socket.recv(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - w2p_udp_package_s.send(datavec).unwrap(); - } - trace!(?addr, "Stop udp_single_channel_connect"); - } - async fn init_protocol( &self, mut protocol: Protocols, diff --git a/network/tests/integration.rs b/network/tests/integration.rs index e81530b4f0..9d2e57bf77 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -85,7 +85,7 @@ fn stream_simple_quic() { #[test] fn stream_simple_quic_3msg() { - let (_, _) = helper::setup(true, 0); + let (_, _) = helper::setup(false, 0); let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); s1_a.send("Hello World").unwrap();