extract protocol specific listen code from scheduler and move it to channel.rs

This commit is contained in:
Marcel Märtens 2021-04-27 17:59:36 +02:00
parent 4afadf57dc
commit 653fb065e0
8 changed files with 300 additions and 318 deletions

View File

@ -451,7 +451,10 @@ where
m.data.extend_from_slice(&data); m.data.extend_from_slice(&data);
if m.data.len() == m.length as usize { if m.data.len() == m.length as usize {
// finished, yay // finished, yay
let m = self.incoming.remove(&mid).unwrap(); let m = self
.incoming
.remove(&mid)
.ok_or(ProtocolError::Violated)?;
self.metrics.rmsg_ob( self.metrics.rmsg_ob(
m.sid, m.sid,
RemoveReason::Finished, RemoveReason::Finished,

View File

@ -145,8 +145,8 @@ pub struct StreamParams {
/// [`Arc`](std::sync::Arc) as all commands have internal mutability. /// [`Arc`](std::sync::Arc) as all commands have internal mutability.
/// ///
/// The `Network` has methods to [`connect`] to other [`Participants`] actively /// The `Network` has methods to [`connect`] to other [`Participants`] actively
/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`] /// via their [`ConnectAddr`], or [`listen`] passively for [`connected`]
/// [`Participants`] via [`ProtocolListenAddr`]. /// [`Participants`] via [`ListenAddr`].
/// ///
/// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before /// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before
/// the Network. /// the Network.
@ -178,6 +178,8 @@ pub struct StreamParams {
/// [`connect`]: Network::connect /// [`connect`]: Network::connect
/// [`listen`]: Network::listen /// [`listen`]: Network::listen
/// [`connected`]: Network::connected /// [`connected`]: Network::connected
/// [`ConnectAddr`]: crate::api::ConnectAddr
/// [`ListenAddr`]: crate::api::ListenAddr
pub struct Network { pub struct Network {
local_pid: Pid, local_pid: Pid,
participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>, participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>,
@ -293,7 +295,7 @@ impl Network {
} }
} }
/// starts listening on an [`ProtocolListenAddr`]. /// starts listening on an [`ListenAddr`].
/// When the method returns the `Network` is ready to listen for incoming /// When the method returns the `Network` is ready to listen for incoming
/// connections OR has returned a [`NetworkError`] (e.g. port already used). /// connections OR has returned a [`NetworkError`] (e.g. port already used).
/// You can call [`connected`] to asynchrony wait for a [`Participant`] to /// You can call [`connected`] to asynchrony wait for a [`Participant`] to
@ -303,7 +305,7 @@ impl Network {
/// # Examples /// # Examples
/// ```ignore /// ```ignore
/// use tokio::runtime::Runtime; /// use tokio::runtime::Runtime;
/// use veloren_network::{Network, Pid, ProtocolListenAddr}; /// use veloren_network::{Network, Pid, ListenAddr};
/// ///
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
/// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally
@ -311,10 +313,10 @@ impl Network {
/// let network = Network::new(Pid::new(), &runtime); /// let network = Network::new(Pid::new(), &runtime);
/// runtime.block_on(async { /// runtime.block_on(async {
/// network /// network
/// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap()))
/// .await?; /// .await?;
/// network /// network
/// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) /// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap()))
/// .await?; /// .await?;
/// drop(network); /// drop(network);
/// # Ok(()) /// # Ok(())
@ -323,6 +325,7 @@ impl Network {
/// ``` /// ```
/// ///
/// [`connected`]: Network::connected /// [`connected`]: Network::connected
/// [`ListenAddr`]: crate::api::ListenAddr
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> { pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> {
let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>(); let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>();
@ -339,13 +342,13 @@ impl Network {
} }
} }
/// starts connection to an [`ProtocolConnectAddr`]. /// starts connection to an [`ConnectAddr`].
/// When the method returns the Network either returns a [`Participant`] /// When the method returns the Network either returns a [`Participant`]
/// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g.
/// can't connect, or invalid Handshake) # Examples /// can't connect, or invalid Handshake) # Examples
/// ```ignore /// ```ignore
/// use tokio::runtime::Runtime; /// use tokio::runtime::Runtime;
/// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr}; /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr};
/// ///
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
/// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above
@ -353,16 +356,16 @@ impl Network {
/// let network = Network::new(Pid::new(), &runtime); /// let network = Network::new(Pid::new(), &runtime);
/// # let remote = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime);
/// runtime.block_on(async { /// runtime.block_on(async {
/// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?;
/// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?;
/// let p1 = network /// let p1 = network
/// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) /// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap()))
/// .await?; /// .await?;
/// # //this doesn't work yet, so skip the test /// # //this doesn't work yet, so skip the test
/// # //TODO fixme! /// # //TODO fixme!
/// # return Ok(()); /// # return Ok(());
/// let p2 = network /// let p2 = network
/// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) /// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap()))
/// .await?; /// .await?;
/// assert_eq!(&p1, &p2); /// assert_eq!(&p1, &p2);
/// # Ok(()) /// # Ok(())
@ -374,13 +377,13 @@ impl Network {
/// ``` /// ```
/// Usually the `Network` guarantees that a operation on a [`Participant`] /// Usually the `Network` guarantees that a operation on a [`Participant`]
/// succeeds, e.g. by automatic retrying unless it fails completely e.g. by /// succeeds, e.g. by automatic retrying unless it fails completely e.g. by
/// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you /// disconnecting from the remote. If 2 [`ConnectAddr] you
/// `connect` to belongs to the same [`Participant`], you get the same /// `connect` to belongs to the same [`Participant`], you get the same
/// [`Participant`] as a result. This is useful e.g. by connecting to /// [`Participant`] as a result. This is useful e.g. by connecting to
/// the same [`Participant`] via multiple Protocols. /// the same [`Participant`] via multiple Protocols.
/// ///
/// [`Streams`]: crate::api::Stream /// [`Streams`]: crate::api::Stream
/// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr /// [`ConnectAddr`]: crate::api::ConnectAddr
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> { pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> {
let (pid_sender, pid_receiver) = let (pid_sender, pid_receiver) =
@ -403,7 +406,7 @@ impl Network {
Ok(participant) Ok(participant)
} }
/// returns a [`Participant`] created from a [`ProtocolListenAddr`] you /// returns a [`Participant`] created from a [`ListenAddr`] you
/// called [`listen`] on before. This function will either return a /// called [`listen`] on before. This function will either return a
/// working [`Participant`] ready to open [`Streams`] on OR has returned /// working [`Participant`] ready to open [`Streams`] on OR has returned
/// a [`NetworkError`] (e.g. Network got closed) /// a [`NetworkError`] (e.g. Network got closed)
@ -437,6 +440,7 @@ impl Network {
/// ///
/// [`Streams`]: crate::api::Stream /// [`Streams`]: crate::api::Stream
/// [`listen`]: crate::api::Network::listen /// [`listen`]: crate::api::Network::listen
/// [`ListenAddr`]: crate::api::ListenAddr
#[instrument(name="network", skip(self), fields(p = %self.local_pid))] #[instrument(name="network", skip(self), fields(p = %self.local_pid))]
pub async fn connected(&self) -> Result<Participant, NetworkError> { pub async fn connected(&self) -> Result<Participant, NetworkError> {
let participant = self.connected_receiver.lock().await.recv().await?; let participant = self.connected_receiver.lock().await.recv().await?;

View File

@ -1,21 +1,34 @@
use crate::api::NetworkConnectError; use crate::api::NetworkConnectError;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::BytesMut;
use futures_util::FutureExt;
#[cfg(feature = "quic")]
use futures_util::StreamExt;
use network_protocol::{ use network_protocol::{
Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol,
TcpSendProtocol, UnreliableDrain, UnreliableSink, TcpSendProtocol, UnreliableDrain, UnreliableSink,
}; };
use std::{sync::Arc, time::Duration}; #[cfg(feature = "quic")]
use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
use std::{
collections::HashMap,
io,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net, net,
net::tcp::{OwnedReadHalf, OwnedWriteHalf}, net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::{mpsc, oneshot}, select,
sync::{mpsc, oneshot, Mutex},
}; };
use tokio_stream::StreamExt; use tracing::{error, info, trace, warn};
use tracing::{info, trace};
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Debug)] #[derive(Debug)]
@ -42,32 +55,67 @@ pub(crate) enum RecvProtocols {
Quic(QuicRecvProtocol<QuicSink>), Quic(QuicRecvProtocol<QuicSink>),
} }
lazy_static::lazy_static! {
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
Mutex::new(HashMap::new())
};
}
pub(crate) type C2cMpscConnect = (
mpsc::Sender<MpscMsg>,
oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
impl Protocols { impl Protocols {
const MPSC_CHANNEL_BOUND: usize = 1000; const MPSC_CHANNEL_BOUND: usize = 1000;
pub(crate) async fn with_tcp_connect( pub(crate) async fn with_tcp_connect(
addr: std::net::SocketAddr, addr: SocketAddr,
cid: Cid, metrics: ProtocolMetricCache,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> { ) -> Result<Self, NetworkConnectError> {
let stream = match net::TcpStream::connect(addr).await { let stream = net::TcpStream::connect(addr)
Ok(stream) => stream, .await
Err(e) => { .map_err(NetworkConnectError::Io)?;
return Err(crate::api::NetworkConnectError::Io(e)); info!(
}, "Connecting Tcp to: {}",
}; stream.peer_addr().map_err(NetworkConnectError::Io)?
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); );
Ok(Protocols::new_tcp(stream, cid, metrics)) Ok(Self::new_tcp(stream, metrics))
} }
pub(crate) fn new_tcp( pub(crate) async fn with_tcp_listen(
stream: tokio::net::TcpStream, addr: SocketAddr,
cid: Cid, cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>, metrics: Arc<ProtocolMetrics>,
) -> Self { s2s_stop_listening_r: oneshot::Receiver<()>,
let (r, w) = stream.into_split(); c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); ) -> std::io::Result<()> {
let listener = net::TcpListener::bind(addr).await?;
trace!(?addr, "Tcp Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some(data) = select! {
next = listener.accept().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let (stream, remote_addr) = match data {
Ok((s, p)) => (s, p),
Err(e) => {
trace!(?e, "TcpStream Error, ignoring connection attempt");
continue;
},
};
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Tcp from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid));
}
});
Ok(())
}
pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self {
let (r, w) = stream.into_split();
let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
let rp = TcpRecvProtocol::new( let rp = TcpRecvProtocol::new(
TcpSink { TcpSink {
@ -81,70 +129,104 @@ impl Protocols {
pub(crate) async fn with_mpsc_connect( pub(crate) async fn with_mpsc_connect(
addr: u64, addr: u64,
cid: Cid, metrics: ProtocolMetricCache,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> { ) -> Result<Self, NetworkConnectError> {
let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) { let mpsc_s = MPSC_POOL
Some(s) => s.clone(), .lock()
None => { .await
return Err(NetworkConnectError::Io(std::io::Error::new( .get(&addr)
std::io::ErrorKind::NotConnected, .ok_or_else(|| {
NetworkConnectError::Io(io::Error::new(
io::ErrorKind::NotConnected,
"no mpsc listen on this addr", "no mpsc listen on this addr",
))); ))
}, })?
}; .clone();
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
if mpsc_s mpsc_s
.send((remote_to_local_s, local_to_remote_oneshot_s)) .send((remote_to_local_s, local_to_remote_oneshot_s))
.is_err() .map_err(|_| {
{ NetworkConnectError::Io(io::Error::new(
return Err(NetworkConnectError::Io(std::io::Error::new( io::ErrorKind::BrokenPipe,
std::io::ErrorKind::BrokenPipe,
"mpsc pipe broke during connect", "mpsc pipe broke during connect",
))); ))
} })?;
let local_to_remote_s = match local_to_remote_oneshot_r.await { let local_to_remote_s = local_to_remote_oneshot_r
Ok(s) => s, .await
Err(e) => { .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
e,
)));
},
};
info!(?addr, "Connecting Mpsc"); info!(?addr, "Connecting Mpsc");
Ok(Self::new_mpsc( Ok(Self::new_mpsc(
local_to_remote_s, local_to_remote_s,
remote_to_local_r, remote_to_local_r,
cid,
metrics, metrics,
)) ))
} }
pub(crate) async fn with_mpsc_listen(
addr: u64,
cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s);
trace!(?addr, "Mpsc Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
next = mpsc_r.recv().fuse() => next,
_ = &mut end_receiver => None,
} {
let (remote_to_local_s, remote_to_local_r) =
mpsc::channel(Self::MPSC_CHANNEL_BOUND);
if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
error!(?e, "mpsc listen aborted");
}
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?addr, ?cid, "Accepting Mpsc from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
let _ = c2s_protocol_s.send((
Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
cid,
));
}
warn!("MpscStream Failed, stopping");
});
Ok(())
}
pub(crate) fn new_mpsc( pub(crate) fn new_mpsc(
sender: mpsc::Sender<MpscMsg>, sender: mpsc::Sender<MpscMsg>,
receiver: mpsc::Receiver<MpscMsg>, receiver: mpsc::Receiver<MpscMsg>,
cid: Cid, metrics: ProtocolMetricCache,
metrics: Arc<ProtocolMetrics>,
) -> Self { ) -> Self {
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
Protocols::Mpsc((sp, rp)) Protocols::Mpsc((sp, rp))
} }
#[cfg(feature = "quic")]
pub(crate) async fn with_quic_connect( pub(crate) async fn with_quic_connect(
addr: std::net::SocketAddr, addr: SocketAddr,
config: quinn::ClientConfig, config: quinn::ClientConfig,
name: String, name: String,
cid: Cid, metrics: ProtocolMetricCache,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> { ) -> Result<Self, NetworkConnectError> {
let config = config.clone(); let config = config.clone();
let endpoint = quinn::Endpoint::builder(); let endpoint = quinn::Endpoint::builder();
let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
let bindsock = match addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
},
};
let (endpoint, _) = match endpoint.bind(&bindsock) {
Ok(e) => e, Ok(e) => e,
Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)),
}; };
@ -164,7 +246,7 @@ impl Protocols {
e, e,
)) ))
})?; })?;
Protocols::new_quic(connection, false, cid, metrics) Self::new_quic(connection, false, metrics)
.await .await
.map_err(|e| { .map_err(|e| {
trace!(?e, "error with quic"); trace!(?e, "error with quic");
@ -175,15 +257,60 @@ impl Protocols {
}) })
} }
#[cfg(feature = "quic")]
pub(crate) async fn with_quic_listen(
addr: SocketAddr,
server_config: quinn::ServerConfig,
cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config);
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
Ok(v) => v,
Err(quinn::EndpointError::Socket(e)) => return Err(e),
};
trace!(?addr, "Quic Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some(Some(connecting)) = select! {
next = listener.next().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let remote_addr = connecting.remote_address();
let connection = match connecting.await {
Ok(c) => c,
Err(e) => {
tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
continue;
},
};
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Quic from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
match Protocols::new_quic(connection, true, metrics).await {
Ok(quic) => {
let _ = c2s_protocol_s.send((quic, cid));
},
Err(e) => {
trace!(?e, "failed to start quic");
continue;
},
}
}
});
Ok(())
}
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
pub(crate) async fn new_quic( pub(crate) async fn new_quic(
mut connection: quinn::NewConnection, mut connection: quinn::NewConnection,
listen: bool, listen: bool,
cid: Cid, metrics: ProtocolMetricCache,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, quinn::ConnectionError> { ) -> Result<Self, quinn::ConnectionError> {
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
let (sendstream, recvstream) = if listen { let (sendstream, recvstream) = if listen {
connection.connection.open_bi().await? connection.connection.open_bi().await?
} else { } else {
@ -191,7 +318,7 @@ impl Protocols {
.bi_streams .bi_streams
.next() .next()
.await .await
.ok_or_else(|| quinn::ConnectionError::LocallyClosed)?? .ok_or(quinn::ConnectionError::LocallyClosed)??
}; };
let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
let streams_s_clone = recvstreams_s.clone(); let streams_s_clone = recvstreams_s.clone();
@ -521,7 +648,8 @@ impl UnreliableSink for QuicSink {
mod tests { mod tests {
use super::*; use super::*;
use bytes::Bytes; use bytes::Bytes;
use network_protocol::{Promises, RecvProtocol, SendProtocol}; use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
#[tokio::test] #[tokio::test]
@ -533,9 +661,9 @@ mod tests {
}); });
let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); let client = TcpStream::connect("127.0.0.1:5000").await.unwrap();
let (_listener, server) = r1.await.unwrap(); let (_listener, server) = r1.await.unwrap();
let metrics = Arc::new(ProtocolMetrics::new().unwrap()); let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); let client = Protocols::new_tcp(client, metrics.clone());
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); let server = Protocols::new_tcp(server, metrics);
let (mut s, _) = client.split(); let (mut s, _) = client.split();
let (_, mut r) = server.split(); let (_, mut r) = server.split();
let event = ProtocolEvent::OpenStream { let event = ProtocolEvent::OpenStream {
@ -582,9 +710,9 @@ mod tests {
}); });
let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); let client = TcpStream::connect("127.0.0.1:5001").await.unwrap();
let (_listener, server) = r1.await.unwrap(); let (_listener, server) = r1.await.unwrap();
let metrics = Arc::new(ProtocolMetrics::new().unwrap()); let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); let client = Protocols::new_tcp(client, metrics.clone());
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); let server = Protocols::new_tcp(server, metrics);
let (s, _) = client.split(); let (s, _) = client.split();
let (_, mut r) = server.split(); let (_, mut r) = server.split();
let e = tokio::spawn(async move { r.recv().await }); let e = tokio::spawn(async move { r.recv().await });

View File

@ -30,7 +30,7 @@ impl Message {
/// # Example /// # Example
/// for example coding, see [`send_raw`] /// for example coding, see [`send_raw`]
/// ///
/// [`send_raw`]: Stream::send_raw /// [`send_raw`]: crate::api::Stream::send_raw
/// [`Participants`]: crate::api::Participant /// [`Participants`]: crate::api::Participant
/// [`compress`]: lz_fear::raw::compress2 /// [`compress`]: lz_fear::raw::compress2
/// [`Message::serialize`]: crate::message::Message::serialize /// [`Message::serialize`]: crate::message::Message::serialize

View File

@ -251,6 +251,7 @@ fn protocolconnect_name(protocol: &ConnectAddr) -> &str {
ConnectAddr::Tcp(_) => "tcp", ConnectAddr::Tcp(_) => "tcp",
ConnectAddr::Udp(_) => "udp", ConnectAddr::Udp(_) => "udp",
ConnectAddr::Mpsc(_) => "mpsc", ConnectAddr::Mpsc(_) => "mpsc",
#[cfg(feature = "quic")]
ConnectAddr::Quic(_, _, _) => "quic", ConnectAddr::Quic(_, _, _) => "quic",
} }
} }
@ -261,6 +262,7 @@ fn protocollisten_name(protocol: &ListenAddr) -> &str {
ListenAddr::Tcp(_) => "tcp", ListenAddr::Tcp(_) => "tcp",
ListenAddr::Udp(_) => "udp", ListenAddr::Udp(_) => "udp",
ListenAddr::Mpsc(_) => "mpsc", ListenAddr::Mpsc(_) => "mpsc",
#[cfg(feature = "quic")]
ListenAddr::Quic(_, _) => "quic", ListenAddr::Quic(_, _) => "quic",
} }
} }

View File

@ -756,7 +756,7 @@ impl BParticipant {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use network_protocol::ProtocolMetrics; use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
use tokio::{ use tokio::{
runtime::Runtime, runtime::Runtime,
sync::{mpsc, oneshot}, sync::{mpsc, oneshot},
@ -816,14 +816,16 @@ mod tests {
) -> Protocols { ) -> Protocols {
let (s1, r1) = mpsc::channel(100); let (s1, r1) = mpsc::channel(100);
let (s2, r2) = mpsc::channel(100); let (s2, r2) = mpsc::channel(100);
let metrics = Arc::new(ProtocolMetrics::new().unwrap()); let met = Arc::new(ProtocolMetrics::new().unwrap());
let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met));
let p1 = Protocols::new_mpsc(s1, r2, metrics);
let (complete_s, complete_r) = oneshot::channel(); let (complete_s, complete_r) = oneshot::channel();
create_channel create_channel
.send((cid, Sid::new(0), p1, complete_s)) .send((cid, Sid::new(0), p1, complete_s))
.unwrap(); .unwrap();
complete_r.await.unwrap(); complete_r.await.unwrap();
Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) let metrics = ProtocolMetricCache::new(&cid.to_string(), met);
Protocols::new_mpsc(s2, r1, metrics)
} }
#[test] #[test]

View File

@ -4,8 +4,8 @@ use crate::{
metrics::{NetworkMetrics, ProtocolInfo}, metrics::{NetworkMetrics, ProtocolInfo},
participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant},
}; };
use futures_util::{FutureExt, StreamExt}; use futures_util::StreamExt;
use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics};
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
use prometheus::Registry; use prometheus::Registry;
use rand::Rng; use rand::Rng;
@ -18,7 +18,7 @@ use std::{
time::Duration, time::Duration,
}; };
use tokio::{ use tokio::{
io, net, select, io,
sync::{mpsc, oneshot, Mutex}, sync::{mpsc, oneshot, Mutex},
}; };
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
@ -33,12 +33,6 @@ use tracing::*;
// - w: wire // - w: wire
// - c: channel/handshake // - c: channel/handshake
lazy_static::lazy_static! {
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
Mutex::new(HashMap::new())
};
}
#[derive(Debug)] #[derive(Debug)]
struct ParticipantInfo { struct ParticipantInfo {
secret: u128, secret: u128,
@ -52,10 +46,6 @@ pub(crate) type A2sConnect = (
oneshot::Sender<Result<Participant, NetworkConnectError>>, oneshot::Sender<Result<Participant, NetworkConnectError>>,
); );
type A2sDisconnect = (Pid, S2bShutdownBparticipant); type A2sDisconnect = (Pid, S2bShutdownBparticipant);
type S2sMpscConnect = (
mpsc::Sender<MpscMsg>,
oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
#[derive(Debug)] #[derive(Debug)]
struct ControlChannels { struct ControlChannels {
@ -88,8 +78,6 @@ pub struct Scheduler {
} }
impl Scheduler { impl Scheduler {
const MPSC_CHANNEL_BOUND: usize = 1000;
pub fn new( pub fn new(
local_pid: Pid, local_pid: Pid,
#[cfg(feature = "metrics")] registry: Option<&Registry>, #[cfg(feature = "metrics")] registry: Option<&Registry>,
@ -157,7 +145,10 @@ impl Scheduler {
} }
pub async fn run(mut self) { pub async fn run(mut self) {
let run_channels = self.run_channels.take().unwrap(); let run_channels = self
.run_channels
.take()
.expect("run() can only be called once");
tokio::join!( tokio::join!(
self.listen_mgr(run_channels.a2s_listen_r), self.listen_mgr(run_channels.a2s_listen_r),
@ -174,17 +165,66 @@ impl Scheduler {
a2s_listen_r a2s_listen_r
.for_each_concurrent(None, |(address, s2a_listen_result_s)| { .for_each_concurrent(None, |(address, s2a_listen_result_s)| {
let address = address; let address = address;
let cids = Arc::clone(&self.channel_ids);
#[cfg(feature = "metrics")]
let mcache = self.metrics.connect_requests_cache(&address);
async move {
debug!(?address, "Got request to open a channel_creator"); debug!(?address, "Got request to open a channel_creator");
self.metrics.listen_request(&address); self.metrics.listen_request(&address);
let (end_sender, end_receiver) = oneshot::channel::<()>(); let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>();
let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel();
let metrics = Arc::clone(&self.protocol_metrics);
async move {
self.channel_listener self.channel_listener
.lock() .lock()
.await .await
.insert(address.clone().into(), end_sender); .insert(address.clone().into(), s2s_stop_listening_s);
self.channel_creator(address, end_receiver, s2a_listen_result_s)
.await; #[cfg(feature = "metrics")]
mcache.inc();
let res = match address {
ListenAddr::Tcp(addr) => {
Protocols::with_tcp_listen(
addr,
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
#[cfg(feature = "quic")]
ListenAddr::Quic(addr, ref server_config) => {
Protocols::with_quic_listen(
addr,
server_config.clone(),
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
ListenAddr::Mpsc(addr) => {
Protocols::with_mpsc_listen(
addr,
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
_ => unimplemented!(),
};
let _ = s2a_listen_result_s.send(res);
while let Some((prot, cid)) = c2s_protocol_r.recv().await {
self.init_protocol(prot, cid, None, true).await;
}
} }
}) })
.await; .await;
@ -195,15 +235,16 @@ impl Scheduler {
trace!("Start connect_mgr"); trace!("Start connect_mgr");
while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { while let Some((addr, pid_sender)) = a2s_connect_r.recv().await {
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
let metrics = Arc::clone(&self.protocol_metrics); let metrics =
ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics));
self.metrics.connect_request(&addr); self.metrics.connect_request(&addr);
let protocol = match addr { let protocol = match addr {
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await, ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await,
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
ConnectAddr::Quic(addr, ref config, name) => { ConnectAddr::Quic(addr, ref config, name) => {
Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await Protocols::with_quic_connect(addr, config.clone(), name, metrics).await
}, },
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await, ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await,
_ => unimplemented!(), _ => unimplemented!(),
}; };
let protocol = match protocol { let protocol = match protocol {
@ -327,204 +368,6 @@ impl Scheduler {
trace!("Stop scheduler_shutdown_mgr"); trace!("Stop scheduler_shutdown_mgr");
} }
async fn channel_creator(
&self,
addr: ListenAddr,
s2s_stop_listening_r: oneshot::Receiver<()>,
s2a_listen_result_s: oneshot::Sender<io::Result<()>>,
) {
trace!(?addr, "Start up channel creator");
#[cfg(feature = "metrics")]
let mcache = self.metrics.connect_requests_cache(&addr);
match addr {
ListenAddr::Tcp(addr) => {
let listener = match net::TcpListener::bind(addr).await {
Ok(listener) => {
s2a_listen_result_s.send(Ok(())).unwrap();
listener
},
Err(e) => {
info!(
?addr,
?e,
"Tcp bind error during listener startup"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
},
};
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some(data) = select! {
next = listener.accept().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let (stream, remote_addr) = match data {
Ok((s, p)) => (s, p),
Err(e) => {
warn!(?e, "TcpStream Error, ignoring connection attempt");
continue;
},
};
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Tcp from");
self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
.await;
}
},
#[cfg(feature = "quic")]
ListenAddr::Quic(addr, ref server_config) => {
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config.clone());
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
Ok((endpoint, listener)) => {
s2a_listen_result_s.send(Ok(())).unwrap();
(endpoint, listener)
},
Err(quinn::EndpointError::Socket(e)) => {
info!(
?addr,
?e,
"Quic bind error during listener startup"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
}
};
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some(Some(connecting)) = select! {
next = listener.next().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let remote_addr = connecting.remote_address();
let connection = match connecting.await {
Ok(c) => c,
Err(e) => {
debug!(?e, ?remote_addr, "skipping connection attempt");
continue;
},
};
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Quic from");
let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await {
Ok(quic) => quic,
Err(e) => {
trace!(?e, "failed to start quic");
continue;
}
};
self.init_protocol(quic, cid, None, true)
.await;
}
},
ListenAddr::Mpsc(addr) => {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s);
s2a_listen_result_s.send(Ok(())).unwrap();
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
next = mpsc_r.recv().fuse() => next,
_ = &mut end_receiver => None,
} {
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
local_remote_to_local_s.send(remote_to_local_s).unwrap();
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?addr, ?cid, "Accepting Mpsc from");
self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
.await;
}
warn!("MpscStream Failed, stopping");
},/*
ProtocolListenAddr::Udp(addr) => {
let socket = match net::UdpSocket::bind(addr).await {
Ok(socket) => {
s2a_listen_result_s.send(Ok(())).unwrap();
Arc::new(socket)
},
Err(e) => {
info!(
?addr,
?e,
"Listener couldn't be started due to error on udp bind"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
},
};
trace!(?addr, "Listener bound");
// receiving is done from here and will be piped to protocol as UDP does not
// have any state
let mut listeners = HashMap::new();
let mut end_receiver = s2s_stop_listening_r.fuse();
const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216;
let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER];
while let Ok((size, remote_addr)) = select! {
next = socket.recv_from(&mut data).fuse() => next,
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
} {
let mut datavec = Vec::with_capacity(size);
datavec.extend_from_slice(&data[0..size]);
//Due to the async nature i cannot make of .entry() as it would lead to a still
// borrowed in another branch situation
#[allow(clippy::map_entry)]
if !listeners.contains_key(&remote_addr) {
info!("Accepting Udp from: {}", &remote_addr);
let (udp_data_sender, udp_data_receiver) =
mpsc::unbounded_channel::<Vec<u8>>();
listeners.insert(remote_addr, udp_data_sender);
let protocol = UdpProtocol::new(
Arc::clone(&socket),
remote_addr,
#[cfg(feature = "metrics")]
Arc::clone(&self.metrics),
udp_data_receiver,
);
self.init_protocol(Protocols::Udp(protocol), None, false)
.await;
}
let udp_data_sender = listeners.get_mut(&remote_addr).unwrap();
udp_data_sender.send(datavec).unwrap();
}
},*/
_ => unimplemented!(),
}
trace!(?addr, "Ending channel creator");
}
#[allow(dead_code)]
async fn udp_single_channel_connect(
socket: Arc<net::UdpSocket>,
w2p_udp_package_s: mpsc::UnboundedSender<Vec<u8>>,
) {
let addr = socket.local_addr();
trace!(?addr, "Start udp_single_channel_connect");
//TODO: implement real closing
let (_end_sender, end_receiver) = oneshot::channel::<()>();
// receiving is done from here and will be piped to protocol as UDP does not
// have any state
let mut end_receiver = end_receiver.fuse();
let mut data = [0u8; 9216];
while let Ok(size) = select! {
next = socket.recv(&mut data).fuse() => next,
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
} {
let mut datavec = Vec::with_capacity(size);
datavec.extend_from_slice(&data[0..size]);
w2p_udp_package_s.send(datavec).unwrap();
}
trace!(?addr, "Stop udp_single_channel_connect");
}
async fn init_protocol( async fn init_protocol(
&self, &self,
mut protocol: Protocols, mut protocol: Protocols,

View File

@ -85,7 +85,7 @@ fn stream_simple_quic() {
#[test] #[test]
fn stream_simple_quic_3msg() { fn stream_simple_quic_3msg() {
let (_, _) = helper::setup(true, 0); let (_, _) = helper::setup(false, 0);
let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic());
s1_a.send("Hello World").unwrap(); s1_a.send("Hello World").unwrap();