mirror of
https://gitlab.com/veloren/veloren.git
synced 2024-08-30 18:12:32 +00:00
extract protocol specific listen code from scheduler and move it to channel.rs
This commit is contained in:
parent
4afadf57dc
commit
c06d4d0156
@ -451,7 +451,10 @@ where
|
||||
m.data.extend_from_slice(&data);
|
||||
if m.data.len() == m.length as usize {
|
||||
// finished, yay
|
||||
let m = self.incoming.remove(&mid).unwrap();
|
||||
let m = self
|
||||
.incoming
|
||||
.remove(&mid)
|
||||
.ok_or(ProtocolError::Violated)?;
|
||||
self.metrics.rmsg_ob(
|
||||
m.sid,
|
||||
RemoveReason::Finished,
|
||||
|
@ -145,8 +145,8 @@ pub struct StreamParams {
|
||||
/// [`Arc`](std::sync::Arc) as all commands have internal mutability.
|
||||
///
|
||||
/// The `Network` has methods to [`connect`] to other [`Participants`] actively
|
||||
/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`]
|
||||
/// [`Participants`] via [`ProtocolListenAddr`].
|
||||
/// via their [`ConnectAddr`], or [`listen`] passively for [`connected`]
|
||||
/// [`Participants`] via [`ListenAddr`].
|
||||
///
|
||||
/// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before
|
||||
/// the Network.
|
||||
@ -178,6 +178,8 @@ pub struct StreamParams {
|
||||
/// [`connect`]: Network::connect
|
||||
/// [`listen`]: Network::listen
|
||||
/// [`connected`]: Network::connected
|
||||
/// [`ConnectAddr`]: crate::api::ConnectAddr
|
||||
/// [`ListenAddr`]: crate::api::ListenAddr
|
||||
pub struct Network {
|
||||
local_pid: Pid,
|
||||
participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>,
|
||||
@ -293,7 +295,7 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
/// starts listening on an [`ProtocolListenAddr`].
|
||||
/// starts listening on an [`ListenAddr`].
|
||||
/// When the method returns the `Network` is ready to listen for incoming
|
||||
/// connections OR has returned a [`NetworkError`] (e.g. port already used).
|
||||
/// You can call [`connected`] to asynchrony wait for a [`Participant`] to
|
||||
@ -303,7 +305,7 @@ impl Network {
|
||||
/// # Examples
|
||||
/// ```ignore
|
||||
/// use tokio::runtime::Runtime;
|
||||
/// use veloren_network::{Network, Pid, ProtocolListenAddr};
|
||||
/// use veloren_network::{Network, Pid, ListenAddr};
|
||||
///
|
||||
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
/// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally
|
||||
@ -311,10 +313,10 @@ impl Network {
|
||||
/// let network = Network::new(Pid::new(), &runtime);
|
||||
/// runtime.block_on(async {
|
||||
/// network
|
||||
/// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap()))
|
||||
/// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap()))
|
||||
/// .await?;
|
||||
/// network
|
||||
/// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap()))
|
||||
/// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap()))
|
||||
/// .await?;
|
||||
/// drop(network);
|
||||
/// # Ok(())
|
||||
@ -323,6 +325,7 @@ impl Network {
|
||||
/// ```
|
||||
///
|
||||
/// [`connected`]: Network::connected
|
||||
/// [`ListenAddr`]: crate::api::ListenAddr
|
||||
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
|
||||
pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> {
|
||||
let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>();
|
||||
@ -339,13 +342,13 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
/// starts connection to an [`ProtocolConnectAddr`].
|
||||
/// starts connection to an [`ConnectAddr`].
|
||||
/// When the method returns the Network either returns a [`Participant`]
|
||||
/// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g.
|
||||
/// can't connect, or invalid Handshake) # Examples
|
||||
/// ```ignore
|
||||
/// use tokio::runtime::Runtime;
|
||||
/// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr};
|
||||
/// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr};
|
||||
///
|
||||
/// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
/// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above
|
||||
@ -353,16 +356,16 @@ impl Network {
|
||||
/// let network = Network::new(Pid::new(), &runtime);
|
||||
/// # let remote = Network::new(Pid::new(), &runtime);
|
||||
/// runtime.block_on(async {
|
||||
/// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?;
|
||||
/// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?;
|
||||
/// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?;
|
||||
/// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?;
|
||||
/// let p1 = network
|
||||
/// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap()))
|
||||
/// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap()))
|
||||
/// .await?;
|
||||
/// # //this doesn't work yet, so skip the test
|
||||
/// # //TODO fixme!
|
||||
/// # return Ok(());
|
||||
/// let p2 = network
|
||||
/// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap()))
|
||||
/// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap()))
|
||||
/// .await?;
|
||||
/// assert_eq!(&p1, &p2);
|
||||
/// # Ok(())
|
||||
@ -374,13 +377,13 @@ impl Network {
|
||||
/// ```
|
||||
/// Usually the `Network` guarantees that a operation on a [`Participant`]
|
||||
/// succeeds, e.g. by automatic retrying unless it fails completely e.g. by
|
||||
/// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you
|
||||
/// disconnecting from the remote. If 2 [`ConnectAddr] you
|
||||
/// `connect` to belongs to the same [`Participant`], you get the same
|
||||
/// [`Participant`] as a result. This is useful e.g. by connecting to
|
||||
/// the same [`Participant`] via multiple Protocols.
|
||||
///
|
||||
/// [`Streams`]: crate::api::Stream
|
||||
/// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr
|
||||
/// [`ConnectAddr`]: crate::api::ConnectAddr
|
||||
#[instrument(name="network", skip(self, address), fields(p = %self.local_pid))]
|
||||
pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> {
|
||||
let (pid_sender, pid_receiver) =
|
||||
@ -403,7 +406,7 @@ impl Network {
|
||||
Ok(participant)
|
||||
}
|
||||
|
||||
/// returns a [`Participant`] created from a [`ProtocolListenAddr`] you
|
||||
/// returns a [`Participant`] created from a [`ListenAddr`] you
|
||||
/// called [`listen`] on before. This function will either return a
|
||||
/// working [`Participant`] ready to open [`Streams`] on OR has returned
|
||||
/// a [`NetworkError`] (e.g. Network got closed)
|
||||
@ -437,6 +440,7 @@ impl Network {
|
||||
///
|
||||
/// [`Streams`]: crate::api::Stream
|
||||
/// [`listen`]: crate::api::Network::listen
|
||||
/// [`ListenAddr`]: crate::api::ListenAddr
|
||||
#[instrument(name="network", skip(self), fields(p = %self.local_pid))]
|
||||
pub async fn connected(&self) -> Result<Participant, NetworkError> {
|
||||
let participant = self.connected_receiver.lock().await.recv().await?;
|
||||
|
@ -1,21 +1,34 @@
|
||||
use crate::api::NetworkConnectError;
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::FutureExt;
|
||||
#[cfg(feature = "quic")]
|
||||
use futures_util::StreamExt;
|
||||
use network_protocol::{
|
||||
Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
|
||||
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat,
|
||||
QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol,
|
||||
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
|
||||
TcpSendProtocol, UnreliableDrain, UnreliableSink,
|
||||
};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
#[cfg(feature = "quic")]
|
||||
use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net,
|
||||
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
sync::{mpsc, oneshot},
|
||||
select,
|
||||
sync::{mpsc, oneshot, Mutex},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, trace};
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug)]
|
||||
@ -42,32 +55,67 @@ pub(crate) enum RecvProtocols {
|
||||
Quic(QuicRecvProtocol<QuicSink>),
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
|
||||
Mutex::new(HashMap::new())
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) type C2cMpscConnect = (
|
||||
mpsc::Sender<MpscMsg>,
|
||||
oneshot::Sender<mpsc::Sender<MpscMsg>>,
|
||||
);
|
||||
|
||||
impl Protocols {
|
||||
const MPSC_CHANNEL_BOUND: usize = 1000;
|
||||
|
||||
pub(crate) async fn with_tcp_connect(
|
||||
addr: std::net::SocketAddr,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
addr: SocketAddr,
|
||||
metrics: ProtocolMetricCache,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let stream = match net::TcpStream::connect(addr).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
return Err(crate::api::NetworkConnectError::Io(e));
|
||||
},
|
||||
};
|
||||
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
|
||||
Ok(Protocols::new_tcp(stream, cid, metrics))
|
||||
let stream = net::TcpStream::connect(addr)
|
||||
.await
|
||||
.map_err(NetworkConnectError::Io)?;
|
||||
info!(
|
||||
"Connecting Tcp to: {}",
|
||||
stream.peer_addr().map_err(NetworkConnectError::Io)?
|
||||
);
|
||||
Ok(Self::new_tcp(stream, metrics))
|
||||
}
|
||||
|
||||
pub(crate) fn new_tcp(
|
||||
stream: tokio::net::TcpStream,
|
||||
cid: Cid,
|
||||
pub(crate) async fn with_tcp_listen(
|
||||
addr: SocketAddr,
|
||||
cids: Arc<AtomicU64>,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
) -> Self {
|
||||
let (r, w) = stream.into_split();
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
|
||||
s2s_stop_listening_r: oneshot::Receiver<()>,
|
||||
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
|
||||
) -> std::io::Result<()> {
|
||||
let listener = net::TcpListener::bind(addr).await?;
|
||||
trace!(?addr, "Tcp Listener bound");
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
tokio::spawn(async move {
|
||||
while let Some(data) = select! {
|
||||
next = listener.accept().fuse() => Some(next),
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let (stream, remote_addr) = match data {
|
||||
Ok((s, p)) => (s, p),
|
||||
Err(e) => {
|
||||
trace!(?e, "TcpStream Error, ignoring connection attempt");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
let cid = cids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?remote_addr, ?cid, "Accepting Tcp from");
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
|
||||
let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid));
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self {
|
||||
let (r, w) = stream.into_split();
|
||||
let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
|
||||
let rp = TcpRecvProtocol::new(
|
||||
TcpSink {
|
||||
@ -81,70 +129,104 @@ impl Protocols {
|
||||
|
||||
pub(crate) async fn with_mpsc_connect(
|
||||
addr: u64,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
metrics: ProtocolMetricCache,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::NotConnected,
|
||||
let mpsc_s = MPSC_POOL
|
||||
.lock()
|
||||
.await
|
||||
.get(&addr)
|
||||
.ok_or_else(|| {
|
||||
NetworkConnectError::Io(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"no mpsc listen on this addr",
|
||||
)));
|
||||
},
|
||||
};
|
||||
))
|
||||
})?
|
||||
.clone();
|
||||
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
|
||||
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
|
||||
if mpsc_s
|
||||
mpsc_s
|
||||
.send((remote_to_local_s, local_to_remote_oneshot_s))
|
||||
.is_err()
|
||||
{
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"mpsc pipe broke during connect",
|
||||
)));
|
||||
}
|
||||
let local_to_remote_s = match local_to_remote_oneshot_r.await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
e,
|
||||
)));
|
||||
},
|
||||
};
|
||||
.map_err(|_| {
|
||||
NetworkConnectError::Io(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"mpsc pipe broke during connect",
|
||||
))
|
||||
})?;
|
||||
let local_to_remote_s = local_to_remote_oneshot_r
|
||||
.await
|
||||
.map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
|
||||
info!(?addr, "Connecting Mpsc");
|
||||
Ok(Self::new_mpsc(
|
||||
local_to_remote_s,
|
||||
remote_to_local_r,
|
||||
cid,
|
||||
metrics,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) async fn with_mpsc_listen(
|
||||
addr: u64,
|
||||
cids: Arc<AtomicU64>,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
s2s_stop_listening_r: oneshot::Receiver<()>,
|
||||
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
|
||||
) -> std::io::Result<()> {
|
||||
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
|
||||
MPSC_POOL.lock().await.insert(addr, mpsc_s);
|
||||
trace!(?addr, "Mpsc Listener bound");
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
tokio::spawn(async move {
|
||||
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
|
||||
next = mpsc_r.recv().fuse() => next,
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let (remote_to_local_s, remote_to_local_r) =
|
||||
mpsc::channel(Self::MPSC_CHANNEL_BOUND);
|
||||
if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
|
||||
error!(?e, "mpsc listen aborted");
|
||||
}
|
||||
|
||||
let cid = cids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?addr, ?cid, "Accepting Mpsc from");
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
|
||||
let _ = c2s_protocol_s.send((
|
||||
Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
|
||||
cid,
|
||||
));
|
||||
}
|
||||
warn!("MpscStream Failed, stopping");
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn new_mpsc(
|
||||
sender: mpsc::Sender<MpscMsg>,
|
||||
receiver: mpsc::Receiver<MpscMsg>,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
metrics: ProtocolMetricCache,
|
||||
) -> Self {
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
|
||||
|
||||
let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
|
||||
let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
|
||||
Protocols::Mpsc((sp, rp))
|
||||
}
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
pub(crate) async fn with_quic_connect(
|
||||
addr: std::net::SocketAddr,
|
||||
addr: SocketAddr,
|
||||
config: quinn::ClientConfig,
|
||||
name: String,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
metrics: ProtocolMetricCache,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let config = config.clone();
|
||||
let endpoint = quinn::Endpoint::builder();
|
||||
let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) {
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
let bindsock = match addr {
|
||||
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
|
||||
SocketAddr::V6(_) => {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
},
|
||||
};
|
||||
let (endpoint, _) = match endpoint.bind(&bindsock) {
|
||||
Ok(e) => e,
|
||||
Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)),
|
||||
};
|
||||
@ -164,7 +246,7 @@ impl Protocols {
|
||||
e,
|
||||
))
|
||||
})?;
|
||||
Protocols::new_quic(connection, false, cid, metrics)
|
||||
Self::new_quic(connection, false, metrics)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
trace!(?e, "error with quic");
|
||||
@ -175,15 +257,60 @@ impl Protocols {
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
pub(crate) async fn with_quic_listen(
|
||||
addr: SocketAddr,
|
||||
server_config: quinn::ServerConfig,
|
||||
cids: Arc<AtomicU64>,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
s2s_stop_listening_r: oneshot::Receiver<()>,
|
||||
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
|
||||
) -> std::io::Result<()> {
|
||||
let mut endpoint = quinn::Endpoint::builder();
|
||||
endpoint.listen(server_config);
|
||||
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
|
||||
Ok(v) => v,
|
||||
Err(quinn::EndpointError::Socket(e)) => return Err(e),
|
||||
};
|
||||
trace!(?addr, "Quic Listener bound");
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
tokio::spawn(async move {
|
||||
while let Some(Some(connecting)) = select! {
|
||||
next = listener.next().fuse() => Some(next),
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let remote_addr = connecting.remote_address();
|
||||
let connection = match connecting.await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let cid = cids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?remote_addr, ?cid, "Accepting Quic from");
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
|
||||
match Protocols::new_quic(connection, true, metrics).await {
|
||||
Ok(quic) => {
|
||||
let _ = c2s_protocol_s.send((quic, cid));
|
||||
},
|
||||
Err(e) => {
|
||||
trace!(?e, "failed to start quic");
|
||||
continue;
|
||||
},
|
||||
}
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
pub(crate) async fn new_quic(
|
||||
mut connection: quinn::NewConnection,
|
||||
listen: bool,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
metrics: ProtocolMetricCache,
|
||||
) -> Result<Self, quinn::ConnectionError> {
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
|
||||
|
||||
let (sendstream, recvstream) = if listen {
|
||||
connection.connection.open_bi().await?
|
||||
} else {
|
||||
@ -191,7 +318,7 @@ impl Protocols {
|
||||
.bi_streams
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| quinn::ConnectionError::LocallyClosed)??
|
||||
.ok_or(quinn::ConnectionError::LocallyClosed)??
|
||||
};
|
||||
let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
|
||||
let streams_s_clone = recvstreams_s.clone();
|
||||
@ -521,7 +648,8 @@ impl UnreliableSink for QuicSink {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use bytes::Bytes;
|
||||
use network_protocol::{Promises, RecvProtocol, SendProtocol};
|
||||
use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
#[tokio::test]
|
||||
@ -533,9 +661,9 @@ mod tests {
|
||||
});
|
||||
let client = TcpStream::connect("127.0.0.1:5000").await.unwrap();
|
||||
let (_listener, server) = r1.await.unwrap();
|
||||
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
|
||||
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
|
||||
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
|
||||
let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
|
||||
let client = Protocols::new_tcp(client, metrics.clone());
|
||||
let server = Protocols::new_tcp(server, metrics);
|
||||
let (mut s, _) = client.split();
|
||||
let (_, mut r) = server.split();
|
||||
let event = ProtocolEvent::OpenStream {
|
||||
@ -582,9 +710,9 @@ mod tests {
|
||||
});
|
||||
let client = TcpStream::connect("127.0.0.1:5001").await.unwrap();
|
||||
let (_listener, server) = r1.await.unwrap();
|
||||
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
|
||||
let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
|
||||
let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
|
||||
let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
|
||||
let client = Protocols::new_tcp(client, metrics.clone());
|
||||
let server = Protocols::new_tcp(server, metrics);
|
||||
let (s, _) = client.split();
|
||||
let (_, mut r) = server.split();
|
||||
let e = tokio::spawn(async move { r.recv().await });
|
||||
|
@ -30,7 +30,7 @@ impl Message {
|
||||
/// # Example
|
||||
/// for example coding, see [`send_raw`]
|
||||
///
|
||||
/// [`send_raw`]: Stream::send_raw
|
||||
/// [`send_raw`]: crate::api::Stream::send_raw
|
||||
/// [`Participants`]: crate::api::Participant
|
||||
/// [`compress`]: lz_fear::raw::compress2
|
||||
/// [`Message::serialize`]: crate::message::Message::serialize
|
||||
|
@ -756,7 +756,7 @@ impl BParticipant {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use network_protocol::ProtocolMetrics;
|
||||
use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
|
||||
use tokio::{
|
||||
runtime::Runtime,
|
||||
sync::{mpsc, oneshot},
|
||||
@ -816,14 +816,16 @@ mod tests {
|
||||
) -> Protocols {
|
||||
let (s1, r1) = mpsc::channel(100);
|
||||
let (s2, r2) = mpsc::channel(100);
|
||||
let metrics = Arc::new(ProtocolMetrics::new().unwrap());
|
||||
let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics));
|
||||
let met = Arc::new(ProtocolMetrics::new().unwrap());
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met));
|
||||
let p1 = Protocols::new_mpsc(s1, r2, metrics);
|
||||
let (complete_s, complete_r) = oneshot::channel();
|
||||
create_channel
|
||||
.send((cid, Sid::new(0), p1, complete_s))
|
||||
.unwrap();
|
||||
complete_r.await.unwrap();
|
||||
Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics))
|
||||
let metrics = ProtocolMetricCache::new(&cid.to_string(), met);
|
||||
Protocols::new_mpsc(s2, r1, metrics)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -4,8 +4,8 @@ use crate::{
|
||||
metrics::{NetworkMetrics, ProtocolInfo},
|
||||
participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant},
|
||||
};
|
||||
use futures_util::{FutureExt, StreamExt};
|
||||
use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics};
|
||||
use futures_util::StreamExt;
|
||||
use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics};
|
||||
#[cfg(feature = "metrics")]
|
||||
use prometheus::Registry;
|
||||
use rand::Rng;
|
||||
@ -18,7 +18,7 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io, net, select,
|
||||
io,
|
||||
sync::{mpsc, oneshot, Mutex},
|
||||
};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
@ -33,12 +33,6 @@ use tracing::*;
|
||||
// - w: wire
|
||||
// - c: channel/handshake
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
|
||||
Mutex::new(HashMap::new())
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ParticipantInfo {
|
||||
secret: u128,
|
||||
@ -52,10 +46,6 @@ pub(crate) type A2sConnect = (
|
||||
oneshot::Sender<Result<Participant, NetworkConnectError>>,
|
||||
);
|
||||
type A2sDisconnect = (Pid, S2bShutdownBparticipant);
|
||||
type S2sMpscConnect = (
|
||||
mpsc::Sender<MpscMsg>,
|
||||
oneshot::Sender<mpsc::Sender<MpscMsg>>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ControlChannels {
|
||||
@ -88,8 +78,6 @@ pub struct Scheduler {
|
||||
}
|
||||
|
||||
impl Scheduler {
|
||||
const MPSC_CHANNEL_BOUND: usize = 1000;
|
||||
|
||||
pub fn new(
|
||||
local_pid: Pid,
|
||||
#[cfg(feature = "metrics")] registry: Option<&Registry>,
|
||||
@ -157,7 +145,10 @@ impl Scheduler {
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
let run_channels = self.run_channels.take().unwrap();
|
||||
let run_channels = self
|
||||
.run_channels
|
||||
.take()
|
||||
.expect("run() can only be called once");
|
||||
|
||||
tokio::join!(
|
||||
self.listen_mgr(run_channels.a2s_listen_r),
|
||||
@ -174,17 +165,66 @@ impl Scheduler {
|
||||
a2s_listen_r
|
||||
.for_each_concurrent(None, |(address, s2a_listen_result_s)| {
|
||||
let address = address;
|
||||
let cids = Arc::clone(&self.channel_ids);
|
||||
|
||||
#[cfg(feature = "metrics")]
|
||||
let mcache = self.metrics.connect_requests_cache(&address);
|
||||
|
||||
debug!(?address, "Got request to open a channel_creator");
|
||||
self.metrics.listen_request(&address);
|
||||
let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>();
|
||||
let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel();
|
||||
let metrics = Arc::clone(&self.protocol_metrics);
|
||||
|
||||
async move {
|
||||
debug!(?address, "Got request to open a channel_creator");
|
||||
self.metrics.listen_request(&address);
|
||||
let (end_sender, end_receiver) = oneshot::channel::<()>();
|
||||
self.channel_listener
|
||||
.lock()
|
||||
.await
|
||||
.insert(address.clone().into(), end_sender);
|
||||
self.channel_creator(address, end_receiver, s2a_listen_result_s)
|
||||
.await;
|
||||
.insert(address.clone().into(), s2s_stop_listening_s);
|
||||
|
||||
#[cfg(feature = "metrics")]
|
||||
mcache.inc();
|
||||
|
||||
let res = match address {
|
||||
ListenAddr::Tcp(addr) => {
|
||||
Protocols::with_tcp_listen(
|
||||
addr,
|
||||
cids,
|
||||
metrics,
|
||||
s2s_stop_listening_r,
|
||||
c2s_protocol_s,
|
||||
)
|
||||
.await
|
||||
},
|
||||
#[cfg(feature = "quic")]
|
||||
ListenAddr::Quic(addr, ref server_config) => {
|
||||
Protocols::with_quic_listen(
|
||||
addr,
|
||||
server_config.clone(),
|
||||
cids,
|
||||
metrics,
|
||||
s2s_stop_listening_r,
|
||||
c2s_protocol_s,
|
||||
)
|
||||
.await
|
||||
},
|
||||
ListenAddr::Mpsc(addr) => {
|
||||
Protocols::with_mpsc_listen(
|
||||
addr,
|
||||
cids,
|
||||
metrics,
|
||||
s2s_stop_listening_r,
|
||||
c2s_protocol_s,
|
||||
)
|
||||
.await
|
||||
},
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
let _ = s2a_listen_result_s.send(res);
|
||||
|
||||
while let Some((prot, cid)) = c2s_protocol_r.recv().await {
|
||||
self.init_protocol(prot, cid, None, true).await;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
@ -195,15 +235,16 @@ impl Scheduler {
|
||||
trace!("Start connect_mgr");
|
||||
while let Some((addr, pid_sender)) = a2s_connect_r.recv().await {
|
||||
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
|
||||
let metrics = Arc::clone(&self.protocol_metrics);
|
||||
let metrics =
|
||||
ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics));
|
||||
self.metrics.connect_request(&addr);
|
||||
let protocol = match addr {
|
||||
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await,
|
||||
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await,
|
||||
#[cfg(feature = "quic")]
|
||||
ConnectAddr::Quic(addr, ref config, name) => {
|
||||
Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await
|
||||
Protocols::with_quic_connect(addr, config.clone(), name, metrics).await
|
||||
},
|
||||
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await,
|
||||
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await,
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
let protocol = match protocol {
|
||||
@ -327,204 +368,6 @@ impl Scheduler {
|
||||
trace!("Stop scheduler_shutdown_mgr");
|
||||
}
|
||||
|
||||
async fn channel_creator(
|
||||
&self,
|
||||
addr: ListenAddr,
|
||||
s2s_stop_listening_r: oneshot::Receiver<()>,
|
||||
s2a_listen_result_s: oneshot::Sender<io::Result<()>>,
|
||||
) {
|
||||
trace!(?addr, "Start up channel creator");
|
||||
#[cfg(feature = "metrics")]
|
||||
let mcache = self.metrics.connect_requests_cache(&addr);
|
||||
match addr {
|
||||
ListenAddr::Tcp(addr) => {
|
||||
let listener = match net::TcpListener::bind(addr).await {
|
||||
Ok(listener) => {
|
||||
s2a_listen_result_s.send(Ok(())).unwrap();
|
||||
listener
|
||||
},
|
||||
Err(e) => {
|
||||
info!(
|
||||
?addr,
|
||||
?e,
|
||||
"Tcp bind error during listener startup"
|
||||
);
|
||||
s2a_listen_result_s.send(Err(e)).unwrap();
|
||||
return;
|
||||
},
|
||||
};
|
||||
trace!(?addr, "Listener bound");
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
while let Some(data) = select! {
|
||||
next = listener.accept().fuse() => Some(next),
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let (stream, remote_addr) = match data {
|
||||
Ok((s, p)) => (s, p),
|
||||
Err(e) => {
|
||||
warn!(?e, "TcpStream Error, ignoring connection attempt");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
#[cfg(feature = "metrics")]
|
||||
mcache.inc();
|
||||
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?remote_addr, ?cid, "Accepting Tcp from");
|
||||
self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
#[cfg(feature = "quic")]
|
||||
ListenAddr::Quic(addr, ref server_config) => {
|
||||
let mut endpoint = quinn::Endpoint::builder();
|
||||
endpoint.listen(server_config.clone());
|
||||
let (_endpoint, mut listener) = match endpoint.bind(&addr) {
|
||||
Ok((endpoint, listener)) => {
|
||||
s2a_listen_result_s.send(Ok(())).unwrap();
|
||||
(endpoint, listener)
|
||||
},
|
||||
Err(quinn::EndpointError::Socket(e)) => {
|
||||
info!(
|
||||
?addr,
|
||||
?e,
|
||||
"Quic bind error during listener startup"
|
||||
);
|
||||
s2a_listen_result_s.send(Err(e)).unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
trace!(?addr, "Listener bound");
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
while let Some(Some(connecting)) = select! {
|
||||
next = listener.next().fuse() => Some(next),
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let remote_addr = connecting.remote_address();
|
||||
let connection = match connecting.await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
debug!(?e, ?remote_addr, "skipping connection attempt");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
#[cfg(feature = "metrics")]
|
||||
mcache.inc();
|
||||
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?remote_addr, ?cid, "Accepting Quic from");
|
||||
let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await {
|
||||
Ok(quic) => quic,
|
||||
Err(e) => {
|
||||
trace!(?e, "failed to start quic");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
self.init_protocol(quic, cid, None, true)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
ListenAddr::Mpsc(addr) => {
|
||||
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
|
||||
MPSC_POOL.lock().await.insert(addr, mpsc_s);
|
||||
s2a_listen_result_s.send(Ok(())).unwrap();
|
||||
trace!(?addr, "Listener bound");
|
||||
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
|
||||
next = mpsc_r.recv().fuse() => next,
|
||||
_ = &mut end_receiver => None,
|
||||
} {
|
||||
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
|
||||
local_remote_to_local_s.send(remote_to_local_s).unwrap();
|
||||
#[cfg(feature = "metrics")]
|
||||
mcache.inc();
|
||||
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
|
||||
info!(?addr, ?cid, "Accepting Mpsc from");
|
||||
self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true)
|
||||
.await;
|
||||
}
|
||||
warn!("MpscStream Failed, stopping");
|
||||
},/*
|
||||
ProtocolListenAddr::Udp(addr) => {
|
||||
let socket = match net::UdpSocket::bind(addr).await {
|
||||
Ok(socket) => {
|
||||
s2a_listen_result_s.send(Ok(())).unwrap();
|
||||
Arc::new(socket)
|
||||
},
|
||||
Err(e) => {
|
||||
info!(
|
||||
?addr,
|
||||
?e,
|
||||
"Listener couldn't be started due to error on udp bind"
|
||||
);
|
||||
s2a_listen_result_s.send(Err(e)).unwrap();
|
||||
return;
|
||||
},
|
||||
};
|
||||
trace!(?addr, "Listener bound");
|
||||
// receiving is done from here and will be piped to protocol as UDP does not
|
||||
// have any state
|
||||
let mut listeners = HashMap::new();
|
||||
let mut end_receiver = s2s_stop_listening_r.fuse();
|
||||
const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216;
|
||||
let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER];
|
||||
while let Ok((size, remote_addr)) = select! {
|
||||
next = socket.recv_from(&mut data).fuse() => next,
|
||||
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
|
||||
} {
|
||||
let mut datavec = Vec::with_capacity(size);
|
||||
datavec.extend_from_slice(&data[0..size]);
|
||||
//Due to the async nature i cannot make of .entry() as it would lead to a still
|
||||
// borrowed in another branch situation
|
||||
#[allow(clippy::map_entry)]
|
||||
if !listeners.contains_key(&remote_addr) {
|
||||
info!("Accepting Udp from: {}", &remote_addr);
|
||||
let (udp_data_sender, udp_data_receiver) =
|
||||
mpsc::unbounded_channel::<Vec<u8>>();
|
||||
listeners.insert(remote_addr, udp_data_sender);
|
||||
let protocol = UdpProtocol::new(
|
||||
Arc::clone(&socket),
|
||||
remote_addr,
|
||||
#[cfg(feature = "metrics")]
|
||||
Arc::clone(&self.metrics),
|
||||
udp_data_receiver,
|
||||
);
|
||||
self.init_protocol(Protocols::Udp(protocol), None, false)
|
||||
.await;
|
||||
}
|
||||
let udp_data_sender = listeners.get_mut(&remote_addr).unwrap();
|
||||
udp_data_sender.send(datavec).unwrap();
|
||||
}
|
||||
},*/
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
trace!(?addr, "Ending channel creator");
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn udp_single_channel_connect(
|
||||
socket: Arc<net::UdpSocket>,
|
||||
w2p_udp_package_s: mpsc::UnboundedSender<Vec<u8>>,
|
||||
) {
|
||||
let addr = socket.local_addr();
|
||||
trace!(?addr, "Start udp_single_channel_connect");
|
||||
//TODO: implement real closing
|
||||
let (_end_sender, end_receiver) = oneshot::channel::<()>();
|
||||
|
||||
// receiving is done from here and will be piped to protocol as UDP does not
|
||||
// have any state
|
||||
let mut end_receiver = end_receiver.fuse();
|
||||
let mut data = [0u8; 9216];
|
||||
while let Ok(size) = select! {
|
||||
next = socket.recv(&mut data).fuse() => next,
|
||||
_ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")),
|
||||
} {
|
||||
let mut datavec = Vec::with_capacity(size);
|
||||
datavec.extend_from_slice(&data[0..size]);
|
||||
w2p_udp_package_s.send(datavec).unwrap();
|
||||
}
|
||||
trace!(?addr, "Stop udp_single_channel_connect");
|
||||
}
|
||||
|
||||
async fn init_protocol(
|
||||
&self,
|
||||
mut protocol: Protocols,
|
||||
|
@ -85,7 +85,7 @@ fn stream_simple_quic() {
|
||||
|
||||
#[test]
|
||||
fn stream_simple_quic_3msg() {
|
||||
let (_, _) = helper::setup(true, 0);
|
||||
let (_, _) = helper::setup(false, 0);
|
||||
let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic());
|
||||
|
||||
s1_a.send("Hello World").unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user