veloren/network/src/channel.rs
Marcel Märtens aea52d8b54 implement Upload Bandwidth prediction.
It's available to `api` and `metrics` and can be used to slow down msg send in veloren.
It uses a tokio::watch for now, as I plan to have a watch job in the scheduler that recalculates prio on change.
Also cleaning up participant metrics after a disconnect
2021-03-26 08:58:03 +01:00

276 lines
8.2 KiB
Rust

use async_trait::async_trait;
use bytes::BytesMut;
use network_protocol::{
Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
TcpSendProtocol, UnreliableDrain, UnreliableSink,
};
use std::{sync::Arc, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::mpsc,
};
/// A matched send/receive protocol pair, with one variant per supported transport.
///
/// Built by `new_tcp` / `new_mpsc` and taken apart into its two directions by `split`.
#[derive(Debug)]
pub(crate) enum Protocols {
/// Send and receive protocols driven by the write and read halves of a TCP stream.
Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
/// Send and receive protocols driven by a tokio mpsc sender and receiver.
Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
}
/// The sending direction of a `Protocols` pair, as produced by `Protocols::split`.
#[derive(Debug)]
pub(crate) enum SendProtocols {
/// Send protocol writing into a `TcpDrain`.
Tcp(TcpSendProtocol<TcpDrain>),
/// Send protocol writing into an `MpscDrain`.
Mpsc(MpscSendProtocol<MpscDrain>),
}
/// The receiving direction of a `Protocols` pair, as produced by `Protocols::split`.
#[derive(Debug)]
pub(crate) enum RecvProtocols {
/// Receive protocol reading from a `TcpSink`.
Tcp(TcpRecvProtocol<TcpSink>),
/// Receive protocol reading from an `MpscSink`.
Mpsc(MpscRecvProtocol<MpscSink>),
}
impl Protocols {
    /// Builds a TCP-backed protocol pair from a connected stream.
    ///
    /// The stream is split into owned halves: the write half feeds the send
    /// protocol and the read half (plus an empty receive buffer) feeds the
    /// receive protocol. Both sides share one metric cache that is keyed by
    /// the string form of `cid`.
    pub(crate) fn new_tcp(
        stream: tokio::net::TcpStream,
        cid: Cid,
        metrics: Arc<ProtocolMetrics>,
    ) -> Self {
        let (read_half, write_half) = stream.into_split();
        let cache = ProtocolMetricCache::new(&cid.to_string(), metrics);

        let drain = TcpDrain { half: write_half };
        let sink = TcpSink {
            half: read_half,
            buffer: BytesMut::new(),
        };

        let send = TcpSendProtocol::new(drain, cache.clone());
        let recv = TcpRecvProtocol::new(sink, cache);
        Self::Tcp((send, recv))
    }

    /// Builds an mpsc-backed protocol pair from an existing channel sender
    /// and receiver.
    ///
    /// Both sides share one metric cache that is keyed by the string form of
    /// `cid`.
    pub(crate) fn new_mpsc(
        sender: mpsc::Sender<MpscMsg>,
        receiver: mpsc::Receiver<MpscMsg>,
        cid: Cid,
        metrics: Arc<ProtocolMetrics>,
    ) -> Self {
        let cache = ProtocolMetricCache::new(&cid.to_string(), metrics);

        let drain = MpscDrain { sender };
        let sink = MpscSink { receiver };

        let send = MpscSendProtocol::new(drain, cache.clone());
        let recv = MpscRecvProtocol::new(sink, cache);
        Self::Mpsc((send, recv))
    }

    /// Consumes the pair and returns its send and receive directions as
    /// separate values, keeping the transport variant.
    pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) {
        match self {
            Self::Tcp((send, recv)) => (SendProtocols::Tcp(send), RecvProtocols::Tcp(recv)),
            Self::Mpsc((send, recv)) => (SendProtocols::Mpsc(send), RecvProtocols::Mpsc(recv)),
        }
    }
}
#[async_trait]
impl network_protocol::InitProtocol for Protocols {
    /// Runs initialization on whichever transport pair this value wraps.
    ///
    /// `initializer`, `local_pid` and `secret` are passed through unchanged
    /// to the wrapped send/receive pair, and its result is returned as-is.
    async fn initialize(
        &mut self,
        initializer: bool,
        local_pid: Pid,
        secret: u128,
    ) -> Result<(Pid, Sid, u128), InitProtocolError> {
        match self {
            Protocols::Mpsc(pair) => pair.initialize(initializer, local_pid, secret).await,
            Protocols::Tcp(pair) => pair.initialize(initializer, local_pid, secret).await,
        }
    }
}
#[async_trait]
impl network_protocol::SendProtocol for SendProtocols {
    /// Hands an event coming from the receive side to the wrapped send
    /// protocol.
    fn notify_from_recv(&mut self, event: ProtocolEvent) {
        match self {
            SendProtocols::Mpsc(inner) => inner.notify_from_recv(event),
            SendProtocols::Tcp(inner) => inner.notify_from_recv(event),
        }
    }

    /// Passes `event` to the wrapped send protocol and returns its result.
    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
        match self {
            SendProtocols::Mpsc(inner) => inner.send(event).await,
            SendProtocols::Tcp(inner) => inner.send(event).await,
        }
    }

    /// Flushes the wrapped send protocol.
    ///
    /// `bandwidth` and `dt` are passed through unchanged, and the wrapped
    /// protocol's result is returned as-is.
    async fn flush(
        &mut self,
        bandwidth: Bandwidth,
        dt: Duration,
    ) -> Result<Bandwidth, ProtocolError> {
        match self {
            SendProtocols::Mpsc(inner) => inner.flush(bandwidth, dt).await,
            SendProtocols::Tcp(inner) => inner.flush(bandwidth, dt).await,
        }
    }
}
#[async_trait]
impl network_protocol::RecvProtocol for RecvProtocols {
    /// Waits for the next event from the wrapped receive protocol and returns
    /// its result unchanged.
    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
        match self {
            RecvProtocols::Mpsc(inner) => inner.recv().await,
            RecvProtocols::Tcp(inner) => inner.recv().await,
        }
    }
}
///////////////////////////////////////
//// TCP
/// Outgoing side of a TCP connection; its `UnreliableDrain` impl writes whole buffers to it.
#[derive(Debug)]
pub struct TcpDrain {
/// Owned write half obtained from `TcpStream::into_split`.
half: OwnedWriteHalf,
}
/// Incoming side of a TCP connection; its `UnreliableSink` impl reads chunks of bytes from it.
#[derive(Debug)]
pub struct TcpSink {
/// Owned read half obtained from `TcpStream::into_split`.
half: OwnedReadHalf,
/// Reusable read buffer: `recv` resizes it to 1500 bytes before each read and splits the
/// bytes that were actually read off its front.
buffer: BytesMut,
}
#[async_trait]
impl UnreliableDrain for TcpDrain {
    type DataFormat = BytesMut;

    /// Writes all of `data` to the TCP write half.
    ///
    /// Any I/O error is reported as `ProtocolError::Closed`; the underlying
    /// error detail is discarded.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
        self.half
            .write_all(&data)
            .await
            .map_err(|_| ProtocolError::Closed)
    }
}
#[async_trait]
impl UnreliableSink for TcpSink {
    type DataFormat = BytesMut;

    /// Reads the next chunk of bytes from the TCP read half.
    ///
    /// Returns the bytes that were read, split off the front of the internal
    /// buffer. A zero-length read or any I/O error is reported as
    /// `ProtocolError::Closed`.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
        // Size the buffer is brought back to before every read.
        const READ_CHUNK: usize = 1500;

        self.buffer.resize(READ_CHUNK, 0u8);
        match self.half.read(&mut self.buffer).await {
            Ok(len) if len > 0 => Ok(self.buffer.split_to(len)),
            // Covers both `Ok(0)` and `Err(_)`.
            _ => Err(ProtocolError::Closed),
        }
    }
}
///////////////////////////////////////
//// MPSC
/// Outgoing side of an mpsc-backed channel; its `UnreliableDrain` impl forwards messages to it.
#[derive(Debug)]
pub struct MpscDrain {
/// Sending end of the tokio mpsc channel that carries `MpscMsg` values.
sender: tokio::sync::mpsc::Sender<MpscMsg>,
}
/// Incoming side of an mpsc-backed channel; its `UnreliableSink` impl takes messages from it.
#[derive(Debug)]
pub struct MpscSink {
/// Receiving end of the tokio mpsc channel that carries `MpscMsg` values.
receiver: tokio::sync::mpsc::Receiver<MpscMsg>,
}
#[async_trait]
impl UnreliableDrain for MpscDrain {
    type DataFormat = MpscMsg;

    /// Sends `data` over the mpsc channel.
    ///
    /// A failed send is reported as `ProtocolError::Closed`; the rejected
    /// message is dropped.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
        match self.sender.send(data).await {
            Ok(()) => Ok(()),
            Err(_) => Err(ProtocolError::Closed),
        }
    }
}
#[async_trait]
impl UnreliableSink for MpscSink {
    type DataFormat = MpscMsg;

    /// Waits for the next message on the mpsc channel.
    ///
    /// Returns `ProtocolError::Closed` when the receiver yields `None`.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
        match self.receiver.recv().await {
            Some(msg) => Ok(msg),
            None => Err(ProtocolError::Closed),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use network_protocol::{Promises, RecvProtocol, SendProtocol};
    use tokio::net::{TcpListener, TcpStream};

    /// Sends an `OpenStream` and a `Message` over a loopback TCP pair, drops
    /// the sender, and checks that both events can still be received.
    #[tokio::test]
    async fn tokio_sinks() {
        // Bind to port 0 so the OS picks a free port; a fixed port makes the
        // test fail when the port is in use or suites run in parallel.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let r1 = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            (listener, server)
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let (_listener, server) = r1.await.unwrap();
        let metrics = Arc::new(ProtocolMetrics::new().unwrap());
        let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
        let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
        let (mut s, _) = client.split();
        let (_, mut r) = server.split();
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1),
            prio: 4u8,
            promises: Promises::GUARANTEED_DELIVERY,
            guaranteed_bandwidth: 1_000,
        };
        s.send(event.clone()).await.unwrap();
        s.send(ProtocolEvent::Message {
            sid: Sid::new(1),
            data: Bytes::from(&[8u8; 8][..]),
        })
        .await
        .unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        drop(s); // recv must work even after shutdown of send!
        tokio::time::sleep(Duration::from_secs(1)).await;
        let res = r.recv().await;
        match res {
            Ok(ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth: _,
            }) => {
                assert_eq!(sid, Sid::new(1));
                assert_eq!(prio, 4u8);
                assert_eq!(promises, Promises::GUARANTEED_DELIVERY);
            },
            _ => {
                panic!("wrong type {:?}", res);
            },
        }
        r.recv().await.unwrap();
    }

    /// A pending `recv` must resolve to `ProtocolError::Closed` once the
    /// peer's send side has been dropped.
    #[tokio::test]
    async fn tokio_sink_stop_after_drop() {
        // Ephemeral port, for the same reason as in `tokio_sinks`.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let r1 = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            (listener, server)
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let (_listener, server) = r1.await.unwrap();
        let metrics = Arc::new(ProtocolMetrics::new().unwrap());
        let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics));
        let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics));
        let (s, _) = client.split();
        let (_, mut r) = server.split();
        let e = tokio::spawn(async move { r.recv().await });
        drop(s);
        let e = e.await.unwrap();
        // `unwrap_err` already panics on `Ok`, so no separate `is_err` check.
        assert_eq!(e.unwrap_err(), ProtocolError::Closed);
    }
}