extract protocol specific listen code from scheduler and move it to channel.rs

This commit is contained in:
Marcel Märtens 2021-04-27 17:59:36 +02:00
parent 4afadf57dc
commit 653fb065e0
8 changed files with 300 additions and 318 deletions

View File

@ -451,7 +451,10 @@ where
m.data.extend_from_slice(&data);
if m.data.len() == m.length as usize {
// finished, yay
let m = self.incoming.remove(&mid).unwrap();
let m = self
.incoming
.remove(&mid)
.ok_or(ProtocolError::Violated)?;
self.metrics.rmsg_ob(
m.sid,
RemoveReason::Finished,

View File

@ -145,8 +145,8 @@ pub struct StreamParams {
/// [`Arc`](std::sync::Arc) as all commands have internal mutability.
///
/// The `Network` has methods to [`connect`] to other [`Participants`] actively
/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`]
/// [`Participants`] via [`ProtocolListenAddr`].
/// via their [`ConnectAddr`], or [`listen`] passively for [`connected`]
/// [`Participants`] via [`ListenAddr`].
///
/// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before
/// the Network.
@ -178,6 +178,8 @@ pub struct StreamParams {
/// [`connect`]: Network::connect
/// [`listen`]: Network::listen
/// [`connected`]: Network::connected
/// [`ConnectAddr`]: crate::api::ConnectAddr
/// [`ListenAddr`]: crate::api::ListenAddr
pub struct Network {
local_pid: Pid,
participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>,
@ -293,7 +295,7 @@ impl Network {
}
}
/// starts listening on an [`ProtocolListenAddr`].
/// starts listening on an [`ListenAddr`].
/// When the method returns the `Network` is ready to listen for incoming
/// connections OR has returned a [`NetworkError`] (e.g. port already used).
/// You can call [`connected`] to asynchrony wait for a [`Participant`] to
@ -303,7 +305,7 @@ impl Network {
/// # Examples
/// ```ignore
/// use tokio::runtime::Runtime;
/// use veloren_network::{Network, Pid, ProtocolListenAddr};
/// use veloren_network::{Network, Pid, ListenAddr};
///
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
/// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally
@ -311,10 +313,10 @@ impl Network {
/// let network = Network::new(Pid::new(), &runtime);
/// runtime.block_on(async {
/// network
/// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap()))
/// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap()))
/// .await?;
/// network
/// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap()))
/// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap()))
/// .await?;
/// drop(network);
/// # Ok(())
@ -323,6 +325,7 @@ impl Network {
/// ```
///
/// [`connected`]: Network::connected
/// [`ListenAddr`]: crate::api::ListenAddr
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> {
let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>();
@ -339,13 +342,13 @@ impl Network {
}
}
/// starts connection to an [`ProtocolConnectAddr`].
/// starts connection to an [`ConnectAddr`].
/// When the method returns the Network either returns a [`Participant`]
/// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g.
/// can't connect, or invalid Handshake) # Examples
/// ```ignore
/// use tokio::runtime::Runtime;
/// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr};
/// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr};
///
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
/// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above
@ -353,16 +356,16 @@ impl Network {
/// let network = Network::new(Pid::new(), &runtime);
/// # let remote = Network::new(Pid::new(), &runtime);
/// runtime.block_on(async {
/// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?;
/// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?;
/// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?;
/// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?;
/// let p1 = network
/// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap()))
/// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap()))
/// .await?;
/// # //this doesn't work yet, so skip the test
/// # //TODO fixme!
/// # return Ok(());
/// let p2 = network
/// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap()))
/// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap()))
/// .await?;
/// assert_eq!(&p1, &p2);
/// # Ok(())
@ -374,13 +377,13 @@ impl Network {
/// ```
/// Usually the `Network` guarantees that a operation on a [`Participant`]
/// succeeds, e.g. by automatic retrying unless it fails completely e.g. by
/// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you
/// disconnecting from the remote. If 2 [`ConnectAddr] you
/// `connect` to belongs to the same [`Participant`], you get the same
/// [`Participant`] as a result. This is useful e.g. by connecting to
/// the same [`Participant`] via multiple Protocols.
///
/// [`Streams`]: crate::api::Stream
/// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr
/// [`ConnectAddr`]: crate::api::ConnectAddr
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> {
let (pid_sender, pid_receiver) =
@ -403,7 +406,7 @@ impl Network {
Ok(participant)
}
/// returns a [`Participant`] created from a [`ProtocolListenAddr`] you
/// returns a [`Participant`] created from a [`ListenAddr`] you
/// called [`listen`] on before. This function will either return a
/// working [`Participant`] ready to open [`Streams`] on OR has returned
/// a [`NetworkError`] (e.g. Network got closed)
@ -437,6 +440,7 @@ impl Network {
///
/// [`Streams`]: crate::api::Stream
/// [`listen`]: crate::api::Network::listen
/// [`ListenAddr`]: crate::api::ListenAddr
#[instrument(name="network", skip(self), fields(p = %self.local_pid))]
pub async fn connected(&self) -> Result<Participant, NetworkError> {
let participant = self.connected_receiver.lock().await.recv().await?;

View File

@ -1,21 +1,34 @@
use crate::api::NetworkConnectError;
use async_trait::async_trait;
use bytes::BytesMut;
use futures_util::FutureExt;
#[cfg(feature = "quic")]
use futures_util::StreamExt;
use network_protocol::{
Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat,
QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol,
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
TcpSendProtocol, UnreliableDrain, UnreliableSink,
};
use std::{sync::Arc, time::Duration};
#[cfg(feature = "quic")]
use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
use std::{
collections::HashMap,
io,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net,
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::{mpsc, oneshot},
select,
sync::{mpsc, oneshot, Mutex},
};
use tokio_stream::StreamExt;
use tracing::{info, trace};
use tracing::{error, info, trace, warn};
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
@ -42,32 +55,67 @@ pub(crate) enum RecvProtocols {
Quic(QuicRecvProtocol<QuicSink>),
}
lazy_static::lazy_static! {
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
Mutex::new(HashMap::new())
};
}
pub(crate) type C2cMpscConnect = (
mpsc::Sender<MpscMsg>,
oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
impl Protocols {
const MPSC_CHANNEL_BOUND: usize = 1000;
pub(crate) async fn with_tcp_connect(
addr: std::net::SocketAddr,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
addr: SocketAddr,
metrics: ProtocolMetricCache,
) -> Result<Self, NetworkConnectError> {
let stream = match net::TcpStream::connect(addr).await {
Ok(stream) => stream,
Err(e) => {
return Err(crate::api::NetworkConnectError::Io(e));
},
};
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
Ok(Protocols::new_tcp(stream, cid, metrics))
let stream = net::TcpStream::connect(addr)
.await
.map_err(NetworkConnectError::Io)?;
info!(
"Connecting Tcp to: {}",
stream.peer_addr().map_err(NetworkConnectError::Io)?
);
Ok(Self::new_tcp(stream, metrics))
}
pub(crate) fn new_tcp(
stream: tokio::net::TcpStream,
cid: Cid,
pub(crate) async fn with_tcp_listen(
addr: SocketAddr,
cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>,
) -> Self {
let (r, w) = stream.into_split();
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
let listener = net::TcpListener::bind(addr).await?;
trace!(?addr, "Tcp Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some(data) = select! {
next = listener.accept().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let (stream, remote_addr) = match data {
Ok((s, p)) => (s, p),
Err(e) => {
trace!(?e, "TcpStream Error, ignoring connection attempt");
continue;
},
};
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Tcp from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid));
}
});
Ok(())
}
pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self {
let (r, w) = stream.into_split();
let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
let rp = TcpRecvProtocol::new(
TcpSink {
@ -81,70 +129,104 @@ impl Protocols {
pub(crate) async fn with_mpsc_connect(
addr: u64,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
metrics: ProtocolMetricCache,
) -> Result<Self, NetworkConnectError> {
let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) {
Some(s) => s.clone(),
None => {
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::NotConnected,
let mpsc_s = MPSC_POOL
.lock()
.await
.get(&addr)
.ok_or_else(|| {
NetworkConnectError::Io(io::Error::new(
io::ErrorKind::NotConnected,
"no mpsc listen on this addr",
)));
},
};
))
})?
.clone();
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
if mpsc_s
mpsc_s
.send((remote_to_local_s, local_to_remote_oneshot_s))
.is_err()
{
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"mpsc pipe broke during connect",
)));
}
let local_to_remote_s = match local_to_remote_oneshot_r.await {
Ok(s) => s,
Err(e) => {
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
e,
)));
},
};
.map_err(|_| {
NetworkConnectError::Io(io::Error::new(
io::ErrorKind::BrokenPipe,
"mpsc pipe broke during connect",
))
})?;
let local_to_remote_s = local_to_remote_oneshot_r
.await
.map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
info!(?addr, "Connecting Mpsc");
Ok(Self::new_mpsc(
local_to_remote_s,
remote_to_local_r,
cid,
metrics,
))
}
pub(crate) async fn with_mpsc_listen(
addr: u64,
cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s);
trace!(?addr, "Mpsc Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
next = mpsc_r.recv().fuse() => next,
_ = &mut end_receiver => None,
} {
let (remote_to_local_s, remote_to_local_r) =
mpsc::channel(Self::MPSC_CHANNEL_BOUND);
if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
error!(?e, "mpsc listen aborted");
}
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?addr, ?cid, "Accepting Mpsc from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
let _ = c2s_protocol_s.send((
Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
cid,
));
}
warn!("MpscStream Failed, stopping");
});
Ok(())
}
pub(crate) fn new_mpsc(
sender: mpsc::Sender<MpscMsg>,
receiver: mpsc::Receiver<MpscMsg>,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
metrics: ProtocolMetricCache,
) -> Self {
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
Protocols::Mpsc((sp, rp))
}
#[cfg(feature = "quic")]
pub(crate) async fn with_quic_connect(
addr: std::net::SocketAddr,
addr: SocketAddr,
config: quinn::ClientConfig,
name: String,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
metrics: ProtocolMetricCache,
) -> Result<Self, NetworkConnectError> {
let config = config.clone();
let endpoint = quinn::Endpoint::builder();
let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
let bindsock = match addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
},
};
let (endpoint, _) = match endpoint.bind(&bindsock) {
Ok(e) => e,
Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)),
};
@ -164,7 +246,7 @@ impl Protocols {
e,
))
})?;
Protocols::new_quic(connection, false, cid, metrics)
Self::new_quic(connection, false, metrics)
.await
.map_err(|e| {
trace!(?e, "error with quic");
@ -175,15 +257,60 @@ impl Protocols {
})
}
#[cfg(feature = "quic")]
pub(crate) async fn with_quic_listen(
addr: SocketAddr,
server_config: quinn::ServerConfig,
cids: Arc<AtomicU64>,
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config);
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
Ok(v) => v,
Err(quinn::EndpointError::Socket(e)) => return Err(e),
};
trace!(?addr, "Quic Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
tokio::spawn(async move {
while let Some(Some(connecting)) = select! {
next = listener.next().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let remote_addr = connecting.remote_address();
let connection = match connecting.await {
Ok(c) => c,
Err(e) => {
tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
continue;
},
};
let cid = cids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Quic from");
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
match Protocols::new_quic(connection, true, metrics).await {
Ok(quic) => {
let _ = c2s_protocol_s.send((quic, cid));
},
Err(e) => {
trace!(?e, "failed to start quic");
continue;
},
}
}
});
Ok(())
}
#[cfg(feature = "quic")]
pub(crate) async fn new_quic(
mut connection: quinn::NewConnection,
listen: bool,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
metrics: ProtocolMetricCache,
) -> Result<Self, quinn::ConnectionError> {
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
let (sendstream, recvstream) = if listen {
connection.connection.open_bi().await?
} else {
@ -191,7 +318,7 @@ impl Protocols {
.bi_streams
.next()
.await
.ok_or_else(|| quinn::ConnectionError::LocallyClosed)??
.ok_or(quinn::ConnectionError::LocallyClosed)??
};
let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
let streams_s_clone = recvstreams_s.clone();
@ -521,7 +648,8 @@ impl UnreliableSink for QuicSink {
mod tests {
use super::*;
use bytes::Bytes;
use network_protocol::{Promises, RecvProtocol, SendProtocol};
use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
#[tokio::test]
@ -533,9 +661,9 @@ mod tests {
});
let client = TcpStream::connect("127.0.0.1:5000").await.unwrap();
let (_listener, server) = r1.await.unwrap();
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
let client = Protocols::new_tcp(client, metrics.clone());
let server = Protocols::new_tcp(server, metrics);
let (mut s, _) = client.split();
let (_, mut r) = server.split();
let event = ProtocolEvent::OpenStream {
@ -582,9 +710,9 @@ mod tests {
});
let client = TcpStream::connect("127.0.0.1:5001").await.unwrap();
let (_listener, server) = r1.await.unwrap();
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
let client = Protocols::new_tcp(client, metrics.clone());
let server = Protocols::new_tcp(server, metrics);
let (s, _) = client.split();
let (_, mut r) = server.split();
let e = tokio::spawn(async move { r.recv().await });

View File

@ -30,7 +30,7 @@ impl Message {
/// # Example
/// for example coding, see [`send_raw`]
///
/// [`send_raw`]: Stream::send_raw
/// [`send_raw`]: crate::api::Stream::send_raw
/// [`Participants`]: crate::api::Participant
/// [`compress`]: lz_fear::raw::compress2
/// [`Message::serialize`]: crate::message::Message::serialize

View File

@ -251,6 +251,7 @@ fn protocolconnect_name(protocol: &ConnectAddr) -> &str {
ConnectAddr::Tcp(_) => "tcp",
ConnectAddr::Udp(_) => "udp",
ConnectAddr::Mpsc(_) => "mpsc",
#[cfg(feature = "quic")]
ConnectAddr::Quic(_, _, _) => "quic",
}
}
@ -261,6 +262,7 @@ fn protocollisten_name(protocol: &ListenAddr) -> &str {
ListenAddr::Tcp(_) => "tcp",
ListenAddr::Udp(_) => "udp",
ListenAddr::Mpsc(_) => "mpsc",
#[cfg(feature = "quic")]
ListenAddr::Quic(_, _) => "quic",
}
}

View File

@ -756,7 +756,7 @@ impl BParticipant {
#[cfg(test)]
mod tests {
use super::*;
use network_protocol::ProtocolMetrics;
use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
use tokio::{
runtime::Runtime,
sync::{mpsc, oneshot},
@ -816,14 +816,16 @@ mod tests {
) -> Protocols {
let (s1, r1) = mpsc::channel(100);
let (s2, r2) = mpsc::channel(100);
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics));
let met = Arc::new(ProtocolMetrics::new().unwrap());
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met));
let p1 = Protocols::new_mpsc(s1, r2, metrics);
let (complete_s, complete_r) = oneshot::channel();
create_channel
.send((cid, Sid::new(0), p1, complete_s))
.unwrap();
complete_r.await.unwrap();
Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics))
let metrics = ProtocolMetricCache::new(&cid.to_string(), met);
Protocols::new_mpsc(s2, r1, metrics)
}
#[test]

View File

@ -4,8 +4,8 @@ use crate::{
metrics::{NetworkMetrics, ProtocolInfo},
participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant},
};
use futures_util::{FutureExt, StreamExt};
use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics};
use futures_util::StreamExt;
use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics};
#[cfg(feature = "metrics")]
use prometheus::Registry;
use rand::Rng;
@ -18,7 +18,7 @@ use std::{
time::Duration,
};
use tokio::{
io, net, select,
io,
sync::{mpsc, oneshot, Mutex},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
@ -33,12 +33,6 @@ use tracing::*;
// - w: wire
// - c: channel/handshake
lazy_static::lazy_static! {
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
Mutex::new(HashMap::new())
};
}
#[derive(Debug)]
struct ParticipantInfo {
secret: u128,
@ -52,10 +46,6 @@ pub(crate) type A2sConnect = (
oneshot::Sender<Result<Participant, NetworkConnectError>>,
);
type A2sDisconnect = (Pid, S2bShutdownBparticipant);
type S2sMpscConnect = (
mpsc::Sender<MpscMsg>,
oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
#[derive(Debug)]
struct ControlChannels {
@ -88,8 +78,6 @@ pub struct Scheduler {
}
impl Scheduler {
const MPSC_CHANNEL_BOUND: usize = 1000;
pub fn new(
local_pid: Pid,
#[cfg(feature = "metrics")] registry: Option<&Registry>,
@ -157,7 +145,10 @@ impl Scheduler {
}
pub async fn run(mut self) {
let run_channels = self.run_channels.take().unwrap();
let run_channels = self
.run_channels
.take()
.expect("run() can only be called once");
tokio::join!(
self.listen_mgr(run_channels.a2s_listen_r),
@ -174,17 +165,66 @@ impl Scheduler {
a2s_listen_r
.for_each_concurrent(None, |(address, s2a_listen_result_s)| {
let address = address;
let cids = Arc::clone(&self.channel_ids);
#[cfg(feature = "metrics")]
let mcache = self.metrics.connect_requests_cache(&address);
debug!(?address, "Got request to open a channel_creator");
self.metrics.listen_request(&address);
let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>();
let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel();
let metrics = Arc::clone(&self.protocol_metrics);
async move {
debug!(?address, "Got request to open a channel_creator");
self.metrics.listen_request(&address);
let (end_sender, end_receiver) = oneshot::channel::<()>();
self.channel_listener
.lock()
.await
.insert(address.clone().into(), end_sender);
self.channel_creator(address, end_receiver, s2a_listen_result_s)
.await;
.insert(address.clone().into(), s2s_stop_listening_s);
#[cfg(feature = "metrics")]
mcache.inc();
let res = match address {
ListenAddr::Tcp(addr) => {
Protocols::with_tcp_listen(
addr,
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
#[cfg(feature = "quic")]
ListenAddr::Quic(addr, ref server_config) => {
Protocols::with_quic_listen(
addr,
server_config.clone(),
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
ListenAddr::Mpsc(addr) => {
Protocols::with_mpsc_listen(
addr,
cids,
metrics,
s2s_stop_listening_r,
c2s_protocol_s,
)
.await
},
_ => unimplemented!(),
};
let _ = s2a_listen_result_s.send(res);
while let Some((prot, cid)) = c2s_protocol_r.recv().await {
self.init_protocol(prot, cid, None, true).await;
}
}
})
.await;
@ -195,15 +235,16 @@ impl Scheduler {
trace!("Start connect_mgr");
while let Some((addr, pid_sender)) = a2s_connect_r.recv().await {
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
let metrics = Arc::clone(&self.protocol_metrics);
let metrics =
ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics));
self.metrics.connect_request(&addr);
let protocol = match addr {
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await,
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await,
#[cfg(feature = "quic")]
ConnectAddr::Quic(addr, ref config, name) => {
Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await
Protocols::with_quic_connect(addr, config.clone(), name, metrics).await
},
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await,
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await,
_ => unimplemented!(),
};
let protocol = match protocol {
@ -327,204 +368,6 @@ impl Scheduler {
trace!("Stop scheduler_shutdown_mgr");
}
async fn channel_creator(
&self,
addr: ListenAddr,
s2s_stop_listening_r: oneshot::Receiver<()>,
s2a_listen_result_s: oneshot::Sender<io::Result<()>>,
) {
trace!(?addr, "Start up channel creator");
#[cfg(feature = "metrics")]
let mcache = self.metrics.connect_requests_cache(&addr);
match addr {
ListenAddr::Tcp(addr) => {
let listener = match net::TcpListener::bind(addr).await {
Ok(listener) => {
s2a_listen_result_s.send(Ok(())).unwrap();
listener
},
Err(e) => {
info!(
?addr,
?e,
"Tcp bind error during listener startup"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
},
};
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some(data) = select! {
next = listener.accept().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let (stream, remote_addr) = match data {
Ok((s, p)) => (s, p),
Err(e) => {
warn!(?e, "TcpStream Error, ignoring connection attempt");
continue;
},
};
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Tcp from");
self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
.await;
}
},
#[cfg(feature = "quic")]
ListenAddr::Quic(addr, ref server_config) => {
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config.clone());
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
Ok((endpoint, listener)) => {
s2a_listen_result_s.send(Ok(())).unwrap();
(endpoint, listener)
},
Err(quinn::EndpointError::Socket(e)) => {
info!(
?addr,
?e,
"Quic bind error during listener startup"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
}
};
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some(Some(connecting)) = select! {
next = listener.next().fuse() => Some(next),
_ = &mut end_receiver => None,
} {
let remote_addr = connecting.remote_address();
let connection = match connecting.await {
Ok(c) => c,
Err(e) => {
debug!(?e, ?remote_addr, "skipping connection attempt");
continue;
},
};
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?remote_addr, ?cid, "Accepting Quic from");
let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await {
Ok(quic) => quic,
Err(e) => {
trace!(?e, "failed to start quic");
continue;
}
};
self.init_protocol(quic, cid, None, true)
.await;
}
},
ListenAddr::Mpsc(addr) => {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s);
s2a_listen_result_s.send(Ok(())).unwrap();
trace!(?addr, "Listener bound");
let mut end_receiver = s2s_stop_listening_r.fuse();
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
next = mpsc_r.recv().fuse() => next,
_ = &mut end_receiver => None,
} {
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
local_remote_to_local_s.send(remote_to_local_s).unwrap();
#[cfg(feature = "metrics")]
mcache.inc();
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
info!(?addr, ?cid, "Accepting Mpsc from");
self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
.await;
}
warn!("MpscStream Failed, stopping");
},/*
ProtocolListenAddr::Udp(addr) => {
let socket = match net::UdpSocket::bind(addr).await {
Ok(socket) => {
s2a_listen_result_s.send(Ok(())).unwrap();
Arc::new(socket)
},
Err(e) => {
info!(
?addr,
?e,
"Listener couldn't be started due to error on udp bind"
);
s2a_listen_result_s.send(Err(e)).unwrap();
return;
},
};
trace!(?addr, "Listener bound");
// receiving is done from here and will be piped to protocol as UDP does not
// have any state
let mut listeners = HashMap::new();
let mut end_receiver = s2s_stop_listening_r.fuse();
const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216;
let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER];
while let Ok((size, remote_addr)) = select! {
next = socket.recv_from(&mut data).fuse() => next,
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
} {
let mut datavec = Vec::with_capacity(size);
datavec.extend_from_slice(&data[0..size]);
//Due to the async nature i cannot make of .entry() as it would lead to a still
// borrowed in another branch situation
#[allow(clippy::map_entry)]
if !listeners.contains_key(&remote_addr) {
info!("Accepting Udp from: {}", &remote_addr);
let (udp_data_sender, udp_data_receiver) =
mpsc::unbounded_channel::<Vec<u8>>();
listeners.insert(remote_addr, udp_data_sender);
let protocol = UdpProtocol::new(
Arc::clone(&socket),
remote_addr,
#[cfg(feature = "metrics")]
Arc::clone(&self.metrics),
udp_data_receiver,
);
self.init_protocol(Protocols::Udp(protocol), None, false)
.await;
}
let udp_data_sender = listeners.get_mut(&remote_addr).unwrap();
udp_data_sender.send(datavec).unwrap();
}
},*/
_ => unimplemented!(),
}
trace!(?addr, "Ending channel creator");
}
#[allow(dead_code)]
async fn udp_single_channel_connect(
socket: Arc<net::UdpSocket>,
w2p_udp_package_s: mpsc::UnboundedSender<Vec<u8>>,
) {
let addr = socket.local_addr();
trace!(?addr, "Start udp_single_channel_connect");
//TODO: implement real closing
let (_end_sender, end_receiver) = oneshot::channel::<()>();
// receiving is done from here and will be piped to protocol as UDP does not
// have any state
let mut end_receiver = end_receiver.fuse();
let mut data = [0u8; 9216];
while let Ok(size) = select! {
next = socket.recv(&mut data).fuse() => next,
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
} {
let mut datavec = Vec::with_capacity(size);
datavec.extend_from_slice(&data[0..size]);
w2p_udp_package_s.send(datavec).unwrap();
}
trace!(?addr, "Stop udp_single_channel_connect");
}
async fn init_protocol(
&self,
mut protocol: Protocols,

View File

@ -85,7 +85,7 @@ fn stream_simple_quic() {
#[test]
fn stream_simple_quic_3msg() {
let (_, _) = helper::setup(true, 0);
let (_, _) = helper::setup(false, 0);
let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic());
s1_a.send("Hello World").unwrap();