diff --git a/Cargo.lock b/Cargo.lock index 69b549cd2b..bd0dcb62ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-trait" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", +] + [[package]] name = "atom" version = "0.3.6" @@ -1057,6 +1068,7 @@ dependencies = [ "clap", "criterion-plot", "csv", + "futures", "itertools 0.10.0", "lazy_static", "num-traits", @@ -1069,6 +1081,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio 1.2.0", "walkdir 2.3.1", ] @@ -5580,7 +5593,7 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", - "veloren_network", + "veloren-network", ] [[package]] @@ -5661,6 +5674,47 @@ dependencies = [ "wasmer", ] +[[package]] +name = "veloren-network" +version = "0.3.0" +dependencies = [ + "async-channel", + "async-trait", + "bincode", + "bitflags", + "clap", + "crossbeam-channel 0.5.0", + "futures-core", + "futures-util", + "lazy_static", + "lz-fear", + "prometheus", + "rand 0.8.3", + "serde", + "shellexpand", + "tiny_http", + "tokio 1.2.0", + "tokio-stream", + "tracing", + "tracing-futures", + "tracing-subscriber", + "veloren-network-protocol", +] + +[[package]] +name = "veloren-network-protocol" +version = "0.5.0" +dependencies = [ + "async-channel", + "async-trait", + "bitflags", + "criterion", + "prometheus", + "rand 0.8.3", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "veloren-plugin-api" version = "0.1.0" @@ -5725,9 +5779,9 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", + "veloren-network", "veloren-plugin-api", "veloren-world", - "veloren_network", ] [[package]] @@ -5864,31 +5918,6 @@ dependencies = [ "veloren-common-net", ] -[[package]] -name = "veloren_network" -version = "0.3.0" 
-dependencies = [ - "async-channel", - "bincode", - "bitflags", - "clap", - "crossbeam-channel 0.5.0", - "futures-core", - "futures-util", - "lazy_static", - "lz-fear", - "prometheus", - "rand 0.8.3", - "serde", - "shellexpand", - "tiny_http", - "tokio 1.2.0", - "tokio-stream", - "tracing", - "tracing-futures", - "tracing-subscriber", -] - [[package]] name = "version-compare" version = "0.0.10" diff --git a/Cargo.toml b/Cargo.toml index 6bd656493d..e8104a7195 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "voxygen/anim", "world", "network", + "network/protocol" ] # default profile for devs, fast to compile, okay enough to run, no debug information @@ -30,8 +31,10 @@ incremental = true # All dependencies (but not this crate itself) [profile.dev.package."*"] opt-level = 3 -[profile.dev.package."veloren_network"] +[profile.dev.package."veloren-network"] opt-level = 2 +[profile.dev.package."veloren-network-protocol"] +opt-level = 3 [profile.dev.package."veloren-common"] opt-level = 2 [profile.dev.package."veloren-client"] diff --git a/client/Cargo.toml b/client/Cargo.toml index b2ebcfead1..fcda01cb65 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,7 +14,7 @@ default = ["simd"] common = { package = "veloren-common", path = "../common", features = ["no-assets"] } common-sys = { package = "veloren-common-sys", path = "../common/sys", default-features = false } common-net = { package = "veloren-common-net", path = "../common/net" } -network = { package = "veloren_network", path = "../network", features = ["compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["compression"], default-features = false } byteorder = "1.3.2" uvth = "3.1.1" diff --git a/network/Cargo.toml b/network/Cargo.toml index 0a540ca6dc..f548278896 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "veloren_network" +name = "veloren-network" version = "0.3.0" authors 
= ["Marcel Märtens "] edition = "2018" @@ -7,13 +7,15 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -metrics = ["prometheus"] +metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] default = ["metrics","compression"] [dependencies] +network-protocol = { package = "veloren-network-protocol", path = "protocol", default-features = false } + #serialisation bincode = "1.3.1" serde = { version = "1.0" } @@ -35,10 +37,12 @@ rand = { version = "0.8" } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } +# async traits +async-trait = "0.1.42" [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -tokio = { version = "1.0.1", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +tokio = { version = "1.1.0", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml new file mode 100644 index 0000000000..e097314b6b --- /dev/null +++ b/network/protocol/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "veloren-network-protocol" +description = "pure Protocol without any I/O itself" +version = "0.5.0" +authors = ["Marcel Märtens "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +metrics = ["prometheus"] + +default = ["metrics"] + +[dependencies] + +#tracing and metrics +tracing = { version = "0.1", default-features = false } +prometheus = { version = "0.11", default-features = false, optional = true } +#stream flags +bitflags = "1.2.1" +rand = { version = "0.8" } +# async traits +async-trait = 
"0.1.42" + +[dev-dependencies] +async-channel = "1.5.1" +tokio = { version = "1.2", default-features = false, features = ["rt", "macros"] } +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "protocols" +harness = false \ No newline at end of file diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs new file mode 100644 index 0000000000..5151083b98 --- /dev/null +++ b/network/protocol/benches/protocols.rs @@ -0,0 +1,243 @@ +use async_channel::*; +use async_trait::async_trait; +use criterion::{criterion_group, criterion_main, Criterion}; +use std::{sync::Arc, time::Duration}; +use veloren_network_protocol::{ + InitProtocol, MessageBuffer, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, Promises, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, + Sid, TcpRecvProtcol, TcpSendProtcol, UnreliableDrain, UnreliableSink, _internal::Frame, +}; + +fn frame_serialize(frame: Frame, buffer: &mut [u8]) -> usize { frame.to_bytes(buffer).0 } + +async fn mpsc_msg(buffer: Arc) { + // Arrrg, need to include constructor here + let [p1, p2] = utils::ac_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: 0, + buffer, + }) + .await + .unwrap(); + r.recv().await.unwrap(); +} + +async fn mpsc_handshake() { + let [mut p1, mut p2] = utils::ac_bound(10, None); + let r1 = tokio::spawn(async move { + p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); + p1 + }); + let r2 = tokio::spawn(async move { + p2.initialize(false, Pid::fake(3), 42).await.unwrap(); + p2 + }); + let (r1, r2) = tokio::join!(r1, r2); + r1.unwrap(); + r2.unwrap(); +} + +async fn tcp_msg(buffer: Arc, cnt: usize) { + let [p1, p2] = utils::tcp_bound(10000, None); /*10kbit*/ + let (mut s, mut r) = (p1.0, p2.1); + + let buffer = Arc::clone(&buffer); + let bandwidth = buffer.data.len() as u64 + 1000; + + let r1 = tokio::spawn(async move 
{
        // Open the stream first; the receiving task below awaits one event for it
        // before the per-message events.
        s.send(ProtocolEvent::OpenStream {
            sid: Sid::new(12),
            prio: 0,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 100_000,
        })
        .await
        .unwrap();

        for i in 0..cnt {
            s.send(ProtocolEvent::Message {
                sid: Sid::new(12),
                mid: i as u64,
                buffer: Arc::clone(&buffer),
            })
            .await
            .unwrap();
            s.flush(bandwidth, Duration::from_secs(1)).await.unwrap();
        }
    });
    let r2 = tokio::spawn(async move {
        // One event for the opened stream, then one per message.
        r.recv().await.unwrap();

        for _ in 0..cnt {
            r.recv().await.unwrap();
        }
    });
    let (r1, r2) = tokio::join!(r1, r2);
    r1.unwrap();
    r2.unwrap();
}

/// Registers every benchmark of this file with criterion.
///
/// Async benchmarks each run on a fresh current-thread tokio runtime created by the local
/// `rt` closure; `frame_serialize_short` is a plain synchronous benchmark.
fn criterion_benchmark(c: &mut Criterion) {
    let rt = || {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
    };

    c.bench_function("mpsc_short_msg", |b| {
        let buffer = Arc::new(MessageBuffer {
            data: b"hello_world".to_vec(),
        });
        b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer)))
    });
    c.bench_function("mpsc_long_msg", |b| {
        let buffer = Arc::new(MessageBuffer {
            data: vec![150u8; 500_000],
        });
        b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer)))
    });
    c.bench_function("mpsc_handshake", |b| {
        // `mpsc_handshake` takes no arguments, so it can be passed directly as the routine.
        b.to_async(rt()).iter(mpsc_handshake)
    });

    // Scratch buffer the serialization benchmark writes into.
    let mut buffer = [0u8; 1500];

    c.bench_function("frame_serialize_short", |b| {
        let frame = Frame::Data {
            mid: 65,
            start: 89u64,
            data: b"hello_world".to_vec(),
        };
        b.iter(move || frame_serialize(frame.clone(), &mut buffer))
    });

    c.bench_function("tcp_short_msg", |b| {
        let buffer = Arc::new(MessageBuffer {
            data: b"hello_world".to_vec(),
        });
        b.to_async(rt()).iter(|| tcp_msg(Arc::clone(&buffer), 1))
    });
    c.bench_function("tcp_1GB_in_10000_msg", |b| {
        // 10_000 messages of 100_000 bytes each.
        let buffer = Arc::new(MessageBuffer {
            data: vec![155u8; 100_000],
        });
        b.to_async(rt())
            .iter(|| tcp_msg(Arc::clone(&buffer), 10_000))
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

mod utils {
    use super::*;

    /// Sending half of an `async_channel` used as the drain of the mpsc protocol.
    pub struct ACDrain {
        sender: Sender<MpscMsg>,
    }

    pub struct ACSink
{ + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + pub struct TcpDrain { + sender: Sender>, + } + + pub struct TcpSink { + receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + 
.map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} diff --git a/network/protocol/src/event.rs b/network/protocol/src/event.rs new file mode 100644 index 0000000000..14b74de558 --- /dev/null +++ b/network/protocol/src/event.rs @@ -0,0 +1,74 @@ +use crate::{ + frame::Frame, + message::MessageBuffer, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::sync::Arc; + +/* used for communication with Protocols */ +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ProtocolEvent { + Shutdown, + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + }, + CloseStream { + sid: Sid, + }, + Message { + buffer: Arc, + mid: Mid, + sid: Sid, + }, +} + +impl ProtocolEvent { + pub(crate) fn to_frame(&self) -> Frame { + match self { + ProtocolEvent::Shutdown => Frame::Shutdown, + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + } => Frame::OpenStream { + sid: *sid, + prio: *prio, + promises: *promises, + }, + ProtocolEvent::CloseStream { sid } => Frame::CloseStream { sid: *sid }, + ProtocolEvent::Message { .. 
} => { + unimplemented!("Event::Message to Frame IS NOT supported") + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_to_frame() { + assert_eq!(ProtocolEvent::Shutdown.to_frame(), Frame::Shutdown); + assert_eq!( + ProtocolEvent::CloseStream { sid: Sid::new(42) }.to_frame(), + Frame::CloseStream { sid: Sid::new(42) } + ); + } + + #[test] + #[should_panic] + fn test_sixlet_to_str() { + let _ = ProtocolEvent::Message { + buffer: Arc::new(MessageBuffer { data: vec![] }), + mid: 0, + sid: Sid::new(23), + } + .to_frame(); + } +} diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs new file mode 100644 index 0000000000..4824498940 --- /dev/null +++ b/network/protocol/src/frame.rs @@ -0,0 +1,634 @@ +use crate::types::{Mid, Pid, Prio, Promises, Sid}; +use std::{collections::VecDeque, convert::TryFrom}; + +// const FRAME_RESERVED_1: u8 = 0; +const FRAME_HANDSHAKE: u8 = 1; +const FRAME_INIT: u8 = 2; +const FRAME_SHUTDOWN: u8 = 3; +const FRAME_OPEN_STREAM: u8 = 4; +const FRAME_CLOSE_STREAM: u8 = 5; +const FRAME_DATA_HEADER: u8 = 6; +const FRAME_DATA: u8 = 7; +const FRAME_RAW: u8 = 8; +//const FRAME_RESERVED_2: u8 = 10; +//const FRAME_RESERVED_3: u8 = 13; + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub /* should be crate only */ enum InitFrame { + Handshake { + magic_number: [u8; 7], + version: [u32; 3], + }, + Init { + pid: Pid, + secret: u128, + }, + /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API + * against veloren Server! 
*/
    Raw(Vec<u8>),
}

/// Used for Communication between Channel <----(TCP/UDP)----> Channel
#[derive(Debug, PartialEq, Clone)]
pub enum Frame {
    Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully),
               * Participant is deleted */
    OpenStream {
        sid: Sid,
        prio: Prio,
        promises: Promises,
    },
    CloseStream {
        sid: Sid,
    },
    DataHeader {
        mid: Mid,
        sid: Sid,
        length: u64,
    },
    Data {
        mid: Mid,
        start: u64,
        data: Vec<u8>,
    },
}

impl InitFrame {
    // Size WITHOUT the first (frame type) byte
    pub(crate) const HANDSHAKE_CNS: usize = 19;
    pub(crate) const INIT_CNS: usize = 32;
    /// const part of the RAW frame, actual size is variable
    pub(crate) const RAW_CNS: usize = 2;

    /// Serializes the frame into the start of `bytes` and returns the number of bytes written.
    ///
    /// Layout: one frame type byte followed by the little-endian encoded fields.
    ///
    /// NOTE: the length field of a `Raw` frame is a `u16`; a payload longer than `u16::MAX`
    /// bytes gets its announced length truncated by the cast while the whole payload is still
    /// copied.
    ///
    /// # Panics
    /// Panics when `bytes` is too short for the frame (provide an appropriate buffer size,
    /// > 1500).
    pub(crate) fn to_bytes(self, bytes: &mut [u8]) -> usize {
        match self {
            InitFrame::Handshake {
                magic_number,
                version,
            } => {
                bytes[0] = FRAME_HANDSHAKE;
                bytes[1..8].copy_from_slice(&magic_number);
                bytes[8..12].copy_from_slice(&version[0].to_le_bytes());
                bytes[12..16].copy_from_slice(&version[1].to_le_bytes());
                bytes[16..Self::HANDSHAKE_CNS + 1].copy_from_slice(&version[2].to_le_bytes());
                Self::HANDSHAKE_CNS + 1
            },
            InitFrame::Init { pid, secret } => {
                bytes[0] = FRAME_INIT;
                bytes[1..17].copy_from_slice(&pid.to_le_bytes());
                bytes[17..Self::INIT_CNS + 1].copy_from_slice(&secret.to_le_bytes());
                Self::INIT_CNS + 1
            },
            InitFrame::Raw(data) => {
                bytes[0] = FRAME_RAW;
                bytes[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes());
                bytes[Self::RAW_CNS + 1..(data.len() + Self::RAW_CNS + 1)]
                    .copy_from_slice(&data[..]);
                Self::RAW_CNS + 1 + data.len()
            },
        }
    }

    pub(crate) fn to_frame(bytes: Vec<u8>) -> Option<Self> {
        // An empty buffer carries no frame type byte.
        let frame_no = *bytes.first()?;
        let frame = match frame_no {
            FRAME_HANDSHAKE => {
                if bytes.len() < Self::HANDSHAKE_CNS + 1 {
return None;
                }
                InitFrame::gen_handshake(
                    *<&[u8; Self::HANDSHAKE_CNS]>::try_from(&bytes[1..Self::HANDSHAKE_CNS + 1])
                        .unwrap(),
                )
            },
            FRAME_INIT => {
                if bytes.len() < Self::INIT_CNS + 1 {
                    return None;
                }
                InitFrame::gen_init(
                    *<&[u8; Self::INIT_CNS]>::try_from(&bytes[1..Self::INIT_CNS + 1]).unwrap(),
                )
            },
            FRAME_RAW => {
                if bytes.len() < Self::RAW_CNS + 1 {
                    return None;
                }
                let length = InitFrame::gen_raw(
                    *<&[u8; Self::RAW_CNS]>::try_from(&bytes[1..Self::RAW_CNS + 1]).unwrap(),
                );
                let payload = &bytes[Self::RAW_CNS + 1..];
                // Check the announced length against what actually arrived BEFORE allocating,
                // so the size of the allocation is never dictated by the length field alone.
                if payload.len() != length as usize {
                    return None;
                }
                InitFrame::Raw(payload.to_vec())
            },
            // Anything with an unknown type byte is handed back verbatim as a raw frame.
            _ => InitFrame::Raw(bytes),
        };
        Some(frame)
    }

    /// Decodes the body (everything after the type byte) of a handshake frame:
    /// 7 bytes magic number followed by three little-endian `u32` version parts.
    fn gen_handshake(buf: [u8; Self::HANDSHAKE_CNS]) -> Self {
        let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap();
        InitFrame::Handshake {
            magic_number,
            version: [
                u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()),
                u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()),
                u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..Self::HANDSHAKE_CNS]).unwrap()),
            ],
        }
    }

    /// Decodes the body of an init frame: 16 bytes pid followed by a little-endian `u128`
    /// secret.
    fn gen_init(buf: [u8; Self::INIT_CNS]) -> Self {
        InitFrame::Init {
            pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()),
            secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..Self::INIT_CNS]).unwrap()),
        }
    }

    /// Decodes the little-endian payload length stored in the constant part of a raw frame.
    fn gen_raw(buf: [u8; Self::RAW_CNS]) -> u16 {
        // `buf` already is the two byte array `from_le_bytes` expects.
        u16::from_le_bytes(buf)
    }
}

impl Frame {
    pub(crate) const CLOSE_STREAM_CNS: usize = 8;
    /// const part of the DATA frame, actual size is variable
    pub(crate) const DATA_CNS: usize = 18;
    pub(crate) const DATA_HEADER_CNS: usize = 24;
    #[cfg(feature = "metrics")]
    pub const FRAMES_LEN: u8 = 5;
    pub(crate) const OPEN_STREAM_CNS: usize = 10;
    // Size WITHOUT the first (frame type) byte
    pub(crate) const SHUTDOWN_CNS: usize =
0; + + #[cfg(feature = "metrics")] + pub const fn int_to_string(i: u8) -> &'static str { + match i { + 0 => "Shutdown", + 1 => "OpenStream", + 2 => "CloseStream", + 3 => "DataHeader", + 4 => "Data", + _ => "", + } + } + + #[cfg(feature = "metrics")] + pub fn get_int(&self) -> u8 { + match self { + Frame::Shutdown => 0, + Frame::OpenStream { .. } => 1, + Frame::CloseStream { .. } => 2, + Frame::DataHeader { .. } => 3, + Frame::Data { .. } => 4, + } + } + + #[cfg(feature = "metrics")] + pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } + + //provide an appropriate buffer size. > 1500 + pub fn to_bytes(self, bytes: &mut [u8]) -> (/* buf */ usize, /* actual data */ u64) { + match self { + Frame::Shutdown => { + bytes[Self::SHUTDOWN_CNS] = FRAME_SHUTDOWN.to_be_bytes()[0]; + (Self::SHUTDOWN_CNS + 1, 0) + }, + Frame::OpenStream { + sid, + prio, + promises, + } => { + bytes[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&sid.to_le_bytes()); + bytes[9] = prio.to_le_bytes()[0]; + bytes[Self::OPEN_STREAM_CNS] = promises.to_le_bytes()[0]; + (Self::OPEN_STREAM_CNS + 1, 0) + }, + Frame::CloseStream { sid } => { + bytes[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; + bytes[1..Self::CLOSE_STREAM_CNS + 1].copy_from_slice(&sid.to_le_bytes()); + (Self::CLOSE_STREAM_CNS + 1, 0) + }, + Frame::DataHeader { mid, sid, length } => { + bytes[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&sid.to_le_bytes()); + bytes[17..Self::DATA_HEADER_CNS + 1].copy_from_slice(&length.to_le_bytes()); + (Self::DATA_HEADER_CNS + 1, 0) + }, + Frame::Data { mid, start, data } => { + bytes[0] = FRAME_DATA.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&start.to_le_bytes()); + bytes[17..Self::DATA_CNS + 1].copy_from_slice(&(data.len() as u16).to_le_bytes()); + bytes[Self::DATA_CNS + 1..(data.len() + Self::DATA_CNS + 1)] + 
.copy_from_slice(&data[..]);
                (Self::DATA_CNS + 1 + data.len(), data.len() as u64)
            },
        }
    }

    /// Tries to parse one frame from the front of `bytes`.
    ///
    /// On success the bytes of that frame are drained from `bytes` and the frame is returned.
    /// Returns `None` and leaves `bytes` untouched when the buffer is empty, starts with an
    /// unknown frame type byte, or does not yet contain the complete frame.
    pub(crate) fn to_frame(bytes: &mut VecDeque<u8>) -> Option<Self> {
        let frame_no = *bytes.front()?;
        let size = match frame_no {
            FRAME_SHUTDOWN => Self::SHUTDOWN_CNS,
            FRAME_OPEN_STREAM => Self::OPEN_STREAM_CNS,
            FRAME_CLOSE_STREAM => Self::CLOSE_STREAM_CNS,
            FRAME_DATA_HEADER => Self::DATA_HEADER_CNS,
            FRAME_DATA => {
                // The payload length lives in the last two bytes of the constant part. Indexing
                // a `VecDeque` out of bounds panics, so a DATA frame whose constant part has
                // only partially arrived must be rejected before the length is read.
                if bytes.len() < Self::DATA_CNS + 1 {
                    return None;
                }
                u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + Self::DATA_CNS
            },
            _ => return None,
        };

        // Wait for more bytes until the whole frame (type byte + `size`) is available.
        if bytes.len() < size + 1 {
            return None;
        }

        let frame = match frame_no {
            FRAME_SHUTDOWN => {
                let _ = bytes.drain(..size + 1);
                Frame::Shutdown
            },
            FRAME_OPEN_STREAM => {
                let bytes = bytes.drain(..size + 1).skip(1).collect::<Vec<u8>>();
                Frame::gen_open_stream(<[u8; 10]>::try_from(bytes).unwrap())
            },
            FRAME_CLOSE_STREAM => {
                let bytes = bytes.drain(..size + 1).skip(1).collect::<Vec<u8>>();
                Frame::gen_close_stream(<[u8; 8]>::try_from(bytes).unwrap())
            },
            FRAME_DATA_HEADER => {
                let bytes = bytes.drain(..size + 1).skip(1).collect::<Vec<u8>>();
                Frame::gen_data_header(<[u8; 24]>::try_from(bytes).unwrap())
            },
            FRAME_DATA => {
                // Constant part first (without the type byte), then exactly `length` payload
                // bytes.
                let info = bytes
                    .drain(..Self::DATA_CNS + 1)
                    .skip(1)
                    .collect::<Vec<u8>>();
                let (mid, start, length) = Frame::gen_data(<[u8; 18]>::try_from(info).unwrap());
                debug_assert_eq!(length as usize, size - Self::DATA_CNS);
                let data = bytes.drain(..length as usize).collect::<Vec<u8>>();
                Frame::Data { mid, start, data }
            },
            _ => unreachable!("Frame::to_frame should be handled before!"),
        };
        Some(frame)
    }

    /// Decodes the body of an open-stream frame: 8 bytes sid, 1 byte prio, 1 byte promises
    /// (unknown promise bits are dropped).
    fn gen_open_stream(buf: [u8; Self::OPEN_STREAM_CNS]) -> Self {
        Frame::OpenStream {
            sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()),
            prio: buf[8],
            promises: Promises::from_bits_truncate(buf[Self::OPEN_STREAM_CNS - 1]),
        }
    }

    fn gen_close_stream(buf: [u8; Self::CLOSE_STREAM_CNS]) -> Self {
        Frame::CloseStream {
            sid:
Sid::from_le_bytes(
                *<&[u8; 8]>::try_from(&buf[0..Self::CLOSE_STREAM_CNS]).unwrap(),
            ),
        }
    }

    /// Decodes the body of a data-header frame: mid, sid and total message length, 8
    /// little-endian bytes each.
    fn gen_data_header(buf: [u8; Self::DATA_HEADER_CNS]) -> Self {
        Frame::DataHeader {
            mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()),
            sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()),
            length: u64::from_le_bytes(
                *<&[u8; 8]>::try_from(&buf[16..Self::DATA_HEADER_CNS]).unwrap(),
            ),
        }
    }

    /// Decodes the constant part of a data frame into `(mid, start, payload length)`.
    fn gen_data(buf: [u8; Self::DATA_CNS]) -> (Mid, u64, u16) {
        let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap());
        let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap());
        let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..Self::DATA_CNS]).unwrap());
        (mid, start, length)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{VELOREN_MAGIC_NUMBER, VELOREN_NETWORK_VERSION};

    /// One sample of every `InitFrame` variant.
    fn get_initframes() -> Vec<InitFrame> {
        vec![
            InitFrame::Handshake {
                magic_number: VELOREN_MAGIC_NUMBER,
                version: VELOREN_NETWORK_VERSION,
            },
            InitFrame::Init {
                pid: Pid::fake(0),
                secret: 0u128,
            },
            InitFrame::Raw(vec![1, 2, 3]),
        ]
    }

    /// A plausible frame sequence: open a stream, send one message split into two data frames,
    /// close the stream, shut down.
    fn get_frames() -> Vec<Frame> {
        vec![
            Frame::OpenStream {
                sid: Sid::new(1337),
                prio: 14,
                promises: Promises::GUARANTEED_DELIVERY,
            },
            Frame::DataHeader {
                sid: Sid::new(1337),
                mid: 0,
                length: 36,
            },
            Frame::Data {
                mid: 0,
                start: 0,
                data: vec![77u8; 20],
            },
            Frame::Data {
                mid: 0,
                start: 20,
                data: vec![42u8; 16],
            },
            Frame::CloseStream {
                sid: Sid::new(1337),
            },
            Frame::Shutdown,
        ]
    }

    #[test]
    fn initframe_individual() {
        // Round trip: serialize into a large buffer, cut it to size, parse it again.
        let dupl = |frame: InitFrame| {
            let mut buffer = vec![0u8; 1500];
            // `frame` is owned by the closure, so it can be consumed without a clone
            let size = InitFrame::to_bytes(frame, &mut buffer);
            buffer.truncate(size);
            InitFrame::to_frame(buffer)
        };

        for frame in get_initframes() {
            println!("initframe: {:?}", &frame);
            assert_eq!(Some(frame.clone()), dupl(frame));
        }
    }

    #[test]
    fn initframe_multiple() {
        let mut buffer =
vec![0u8; 3000]; + + let mut frames = get_initframes(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = InitFrame::to_bytes(f.clone(), &mut buffer[last..]); + last += s; + s + }) + .collect::>(); + + // from string + let mut last = 0; + let mut framesd = sizes + .iter() + .map(|&s| { + let f = InitFrame::to_frame(buffer[last..last + s].to_vec()); + last += s; + f + }) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("initframe: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_individual() { + let dupl = |frame: Frame| { + let mut buffer = vec![0u8; 1500]; + let (size, _) = Frame::to_bytes(frame.clone(), &mut buffer); + let mut deque = buffer[..size].iter().map(|b| *b).collect(); + Frame::to_frame(&mut deque) + }; + + for frame in get_frames() { + println!("frame: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn frame_multiple() { + let mut buffer = vec![0u8; 3000]; + + let mut frames = get_frames(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = Frame::to_bytes(f.clone(), &mut buffer[last..]).0; + last += s; + s + }) + .collect::>(); + + assert_eq!(sizes[0], 1 + Frame::OPEN_STREAM_CNS); + assert_eq!(sizes[1], 1 + Frame::DATA_HEADER_CNS); + assert_eq!(sizes[2], 1 + Frame::DATA_CNS + 20); + assert_eq!(sizes[3], 1 + Frame::DATA_CNS + 16); + assert_eq!(sizes[4], 1 + Frame::CLOSE_STREAM_CNS); + assert_eq!(sizes[5], 1 + Frame::SHUTDOWN_CNS); + + let mut buffer = buffer.drain(..).collect::>(); + + // from string + let mut framesd = sizes + .iter() + .map(|&_| Frame::to_frame(&mut buffer)) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("frame: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_exact_size() { + let mut buffer = vec![0u8; Frame::CLOSE_STREAM_CNS+1/*first byte*/]; + + let frame1 = 
Frame::CloseStream {
            sid: Sid::new(1337),
        };
        let _ = Frame::to_bytes(frame1.clone(), &mut buffer);
        let mut deque = buffer.iter().copied().collect();
        let frame2 = Frame::to_frame(&mut deque);
        assert_eq!(Some(frame1), frame2);
    }

    #[test]
    #[should_panic]
    fn initframe_too_short_buffer() {
        // 10 bytes cannot hold a handshake frame, serializing must panic
        let mut buffer = vec![0u8; 10];

        let handshake = InitFrame::Handshake {
            magic_number: VELOREN_MAGIC_NUMBER,
            version: VELOREN_NETWORK_VERSION,
        };
        let _ = InitFrame::to_bytes(handshake, &mut buffer);
    }

    #[test]
    fn initframe_too_less_data() {
        let mut buffer = vec![0u8; 20];

        let handshake = InitFrame::Handshake {
            magic_number: VELOREN_MAGIC_NUMBER,
            version: VELOREN_NETWORK_VERSION,
        };
        let _ = InitFrame::to_bytes(handshake, &mut buffer);
        buffer.truncate(6); // simulate partial retrieve
        let parsed = InitFrame::to_frame(buffer[..6].to_vec());
        assert_eq!(parsed, None);
    }

    #[test]
    fn initframe_rubish() {
        // an unknown type byte turns the whole input into a raw frame
        let buffer = b"dtrgwcser".to_vec();
        assert_eq!(
            InitFrame::to_frame(buffer),
            Some(InitFrame::Raw(b"dtrgwcser".to_vec()))
        );
    }

    #[test]
    fn initframe_attack_too_much_length() {
        let mut buffer = vec![0u8; 50];

        let raw = InitFrame::Raw(b"foobar".to_vec());
        let _ = InitFrame::to_bytes(raw, &mut buffer);
        // overwrite the high byte of the length field
        buffer[2] = 255;
        let framed = InitFrame::to_frame(buffer);
        assert_eq!(framed, None);
    }

    #[test]
    fn initframe_attack_too_low_length() {
        let mut buffer = vec![0u8; 50];

        let raw = InitFrame::Raw(b"foobar".to_vec());
        let _ = InitFrame::to_bytes(raw, &mut buffer);
        // overwrite the high byte of the length field
        buffer[2] = 3;
        let framed = InitFrame::to_frame(buffer);
        assert_eq!(framed, None);
    }

    #[test]
    #[should_panic]
    fn frame_too_short_buffer() {
        // 10 bytes cannot hold an open-stream frame, serializing must panic
        let mut buffer = vec![0u8; 10];

        let open = Frame::OpenStream {
            sid: Sid::new(88),
            promises: Promises::ENCRYPTED,
            prio: 88,
        };
        let _ = Frame::to_bytes(open, &mut buffer);
    }

    #[test]
    fn frame_too_less_data() {
let mut buffer = vec![0u8; 20]; + + let frame1 = Frame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let mut buffer = buffer.drain(..6).collect::>(); + let frame1d = Frame::to_frame(&mut buffer); + assert_eq!(frame1d, None); + } + + #[test] + fn frame_rubish() { + let mut buffer = b"dtrgwcser".iter().map(|u| *u).collect::>(); + assert_eq!(Frame::to_frame(&mut buffer), None); + } + + #[test] + fn frame_attack_too_much_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 255; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_attack_too_low_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 3; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!( + framed, + Some(Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foo".to_vec(), + }) + ); + //next = Invalid => Empty + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_int2str() { + assert_eq!(Frame::int_to_string(0), "Shutdown"); + assert_eq!(Frame::int_to_string(1), "OpenStream"); + assert_eq!(Frame::int_to_string(2), "CloseStream"); + assert_eq!(Frame::int_to_string(3), "DataHeader"); + assert_eq!(Frame::int_to_string(4), "Data"); + } +} diff --git a/network/protocol/src/handshake.rs b/network/protocol/src/handshake.rs new file mode 100644 index 0000000000..cc46791fc6 --- /dev/null +++ b/network/protocol/src/handshake.rs @@ -0,0 +1,227 @@ +use crate::{ + 
frame::InitFrame, + types::{ + Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, + VELOREN_NETWORK_VERSION, + }, + InitProtocol, InitProtocolError, ProtocolError, +}; +use async_trait::async_trait; +use tracing::{debug, error, info, trace}; + +// Protocols might define a Reliable Variant for auto Handshake discovery +// this doesn't need to be effective +#[async_trait] +pub trait ReliableDrain { + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait ReliableSink { + async fn recv(&mut self) -> Result; +} + +#[async_trait] +impl InitProtocol for (D, S) +where + D: ReliableDrain + Send, + S: ReliableSink + Send, +{ + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + local_secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + #[cfg(debug_assertions)] + const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required \ + by veloren server.\nWe are not sure if you are a \ + valid veloren client.\nClosing the connection" + .as_bytes(); + #[cfg(debug_assertions)] + const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ + invalid version.\nWe don't know how to communicate \ + with you.\nClosing the connection"; + const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ + something went wrong on network layer and connection will be closed"; + + let drain = &mut self.0; + let sink = &mut self.1; + + if initializer { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + + match sink.recv().await? 
{ + InitFrame::Handshake { + magic_number, + version, + } => { + trace!(?magic_number, ?version, "Recv handshake"); + if magic_number != VELOREN_MAGIC_NUMBER { + error!(?magic_number, "Connection with invalid magic_number"); + #[cfg(debug_assertions)] + drain.send(InitFrame::Raw(WRONG_NUMBER.to_vec())).await?; + Err(InitProtocolError::WrongMagicNumber(magic_number)) + } else if version != VELOREN_NETWORK_VERSION { + error!(?version, "Connection with wrong network version"); + #[cfg(debug_assertions)] + drain + .send(InitFrame::Raw( + format!( + "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", + WRONG_VERSION, VELOREN_NETWORK_VERSION, version, + ) + .as_bytes() + .to_vec(), + )) + .await?; + Err(InitProtocolError::WrongVersion(version)) + } else { + trace!("Handshake Frame completed"); + if initializer { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + } else { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + Ok(()) + } + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + }?; + + match sink.recv().await? 
{ + InitFrame::Init { pid, secret } => { + debug!(?pid, "Participant send their ID"); + let stream_id_offset = if initializer { + STREAM_ID_OFFSET1 + } else { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + STREAM_ID_OFFSET2 + }; + info!(?pid, "This Handshake is now configured!"); + Ok((pid, stream_id_offset, secret)) + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{mpsc::test_utils::*, InitProtocolError}; + + #[tokio::test] + async fn handshake_drop_start() { + let [mut p1, p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2; + }); + let (r1, _) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_wrong_magic_number() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: *b"woopsie", + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!( + r1.unwrap(), + Err(InitProtocolError::WrongMagicNumber(*b"woopsie")) + ); + assert_eq!(r2.unwrap(), Ok(())); + } + + #[tokio::test] + async fn handshake_wrong_version() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + 
p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: [0, 1, 2], + }) + .await?; + let _ = p2.1.recv().await?; + let _ = p2.1.recv().await?; //this should be closed now + Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2]))); + assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_unexpected_raw() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + assert_eq!(r2.unwrap(), Ok(())); + } +} diff --git a/network/protocol/src/io.rs b/network/protocol/src/io.rs new file mode 100644 index 0000000000..c4e3eba43e --- /dev/null +++ b/network/protocol/src/io.rs @@ -0,0 +1,62 @@ +use crate::ProtocolError; +use async_trait::async_trait; +use std::collections::VecDeque; +///! I/O-Free (Sans-I/O) protocol https://sans-io.readthedocs.io/how-to-sans-io.html + +// Protocols should base on the Unrealiable variants to get something effective! 
+#[async_trait] +pub trait UnreliableDrain: Send { + type DataFormat; + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait UnreliableSink: Send { + type DataFormat; + async fn recv(&mut self) -> Result; +} + +pub struct BaseDrain { + data: VecDeque>, +} + +pub struct BaseSink { + data: VecDeque>, +} + +impl BaseDrain { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +impl BaseSink { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +//TODO: Test Sinks that drop 20% by random and log that + +#[async_trait] +impl UnreliableDrain for BaseDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.data.push_back(data); + Ok(()) + } +} + +#[async_trait] +impl UnreliableSink for BaseSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.data.pop_front().ok_or(ProtocolError::Closed) + } +} diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs new file mode 100644 index 0000000000..8d49ed58c9 --- /dev/null +++ b/network/protocol/src/lib.rs @@ -0,0 +1,75 @@ +mod event; +mod frame; +mod handshake; +mod io; +mod message; +mod metrics; +mod mpsc; +mod prio; +mod tcp; +mod types; + +pub use event::ProtocolEvent; +pub use io::{BaseDrain, BaseSink, UnreliableDrain, UnreliableSink}; +pub use message::MessageBuffer; +pub use metrics::ProtocolMetricCache; +#[cfg(feature = "metrics")] +pub use metrics::ProtocolMetrics; +pub use mpsc::{MpscMsg, MpscRecvProtcol, MpscSendProtcol}; +pub use tcp::{TcpRecvProtcol, TcpSendProtcol}; +pub use types::{Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, VELOREN_NETWORK_VERSION}; + +///use at own risk, might change any time, for internal benchmarks +pub mod _internal { + pub use crate::frame::Frame; +} + +use async_trait::async_trait; + +#[async_trait] +pub trait InitProtocol { + async fn initialize( + &mut self, + initializer: bool, + 
local_pid: Pid, + secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError>; +} + +#[async_trait] +pub trait SendProtocol { + //a stream MUST be bound to a specific Protocol, there will be a failover + // feature comming for the case where a Protocol fails completly + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: std::time::Duration, + ) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait RecvProtocol { + async fn recv(&mut self) -> Result; +} + +#[derive(Debug, PartialEq)] +pub enum InitProtocolError { + Closed, + WrongMagicNumber([u8; 7]), + WrongVersion([u32; 3]), +} + +#[derive(Debug, PartialEq)] +/// When you return closed you must stay closed! +pub enum ProtocolError { + Closed, +} + +impl From for InitProtocolError { + fn from(err: ProtocolError) -> Self { + match err { + ProtocolError::Closed => InitProtocolError::Closed, + } + } +} diff --git a/network/protocol/src/message.rs b/network/protocol/src/message.rs new file mode 100644 index 0000000000..1bda1325ad --- /dev/null +++ b/network/protocol/src/message.rs @@ -0,0 +1,127 @@ +use crate::{ + frame::Frame, + types::{Mid, Sid}, +}; +use std::{collections::VecDeque, sync::Arc}; + +//Todo: Evaluate switching to VecDeque for quickly adding and removing data +// from front, back. +// - It would prob require custom bincode code but thats possible. +#[cfg_attr(test, derive(PartialEq))] +pub struct MessageBuffer { + pub data: Vec, +} + +impl std::fmt::Debug for MessageBuffer { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + //TODO: small messages! 
+ let len = self.data.len(); + if len > 20 { + write!( + f, + "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", + len, + u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), + u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), + u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), + &self.data[13..16], + &self.data[len - 8..len] + ) + } else { + write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) + } + } +} + +/// Contains a outgoing message and store what was *send* and *confirmed* +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. +/// ```ignore +/// msg = OutgoingMessage::new(); +/// msg.next(); +/// msg.next(); +/// msg.confirm(1); +/// msg.confirm(2); +/// ``` +#[derive(Debug)] +pub(crate) struct OutgoingMessage { + buffer: Arc, + send_index: u64, // 3 => 4200 (3*FRAME_DATA_SIZE) + send_header: bool, + mid: Mid, + sid: Sid, + max_index: u64, //speedup + missing_header: bool, + missing_indices: VecDeque, +} + +impl OutgoingMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(buffer: Arc, mid: Mid, sid: Sid) -> Self { + let max_index = + (buffer.data.len() as u64 + Self::FRAME_DATA_SIZE - 1) / Self::FRAME_DATA_SIZE; + Self { + buffer, + send_index: 0, + send_header: false, + mid, + sid, + max_index, + missing_header: false, + missing_indices: VecDeque::new(), + } + } + + /// all has been send once, but might been resend due to failures. 
+ #[allow(dead_code)] + pub(crate) fn initial_sent(&self) -> bool { self.send_index == self.max_index } + + pub fn get_header(&self) -> Frame { + Frame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.buffer.data.len() as u64, + } + } + + pub fn get_data(&self, index: u64) -> Frame { + let start = index * Self::FRAME_DATA_SIZE; + let to_send = std::cmp::min( + self.buffer.data[start as usize..].len() as u64, + Self::FRAME_DATA_SIZE, + ); + Frame::Data { + mid: self.mid, + start, + data: self.buffer.data[start as usize..][..to_send as usize].to_vec(), + } + } + + #[allow(dead_code)] + pub(crate) fn set_missing(&mut self, missing_header: bool, missing_indicies: VecDeque) { + self.missing_header = missing_header; + self.missing_indices = missing_indicies; + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if self.send_index < self.max_index { + self.send_index += 1; + Some(self.get_data(self.send_index - 1)) + } else if self.missing_header { + self.missing_header = false; + Some(self.get_header()) + } else if let Some(index) = self.missing_indices.pop_front() { + Some(self.get_data(index)) + } else { + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.buffer.data.len() as u64) } +} diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs new file mode 100644 index 0000000000..715a06fc9d --- /dev/null +++ b/network/protocol/src/metrics.rs @@ -0,0 +1,414 @@ +use crate::types::Sid; +#[cfg(feature = "metrics")] +use prometheus::{IntCounterVec, IntGaugeVec, Opts, Registry}; +#[cfg(feature = "metrics")] +use std::{error::Error, sync::Arc}; + +#[allow(dead_code)] +pub enum RemoveReason { + Finished, + Dropped, +} + +#[cfg(feature = "metrics")] +pub struct ProtocolMetrics { + // smsg=send_msg rdata=receive_data + // i=in o=out + // t=total b=byte throughput + //e.g smsg_it = sending 
messages, in (responsibility of protocol) total + + // based on CHANNEL/STREAM + /// messages added to be send total, by STREAM, + smsg_it: IntCounterVec, + /// messages bytes added to be send throughput, by STREAM, + smsg_ib: IntCounterVec, + /// messages removed from to be send, because they where finished total, by + /// STREAM AND REASON(finished/canceled), + smsg_ot: IntCounterVec, + /// messages bytes removed from to be send throughput, because they where + /// finished total, by STREAM AND REASON(finished/dropped), + smsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + sdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + sdata_frames_b: IntCounterVec, + + // based on CHANNEL/STREAM + /// messages added to be received total, by STREAM, + rmsg_it: IntCounterVec, + /// messages bytes added to be received throughput, by STREAM, + rmsg_ib: IntCounterVec, + /// messages removed from to be received, because they where finished total, + /// by STREAM AND REASON(finished/canceled), + rmsg_ot: IntCounterVec, + /// messages bytes removed from to be received throughput, because they + /// where finished total, by STREAM AND REASON(finished/dropped), + rmsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + rdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + rdata_frames_b: IntCounterVec, + /// ping per CHANNEL //TODO: implement + ping: IntGaugeVec, +} + +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache { + cid: String, + m: Arc, +} + +#[cfg(not(feature = "metrics"))] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache {} + +#[cfg(feature = "metrics")] +impl ProtocolMetrics { + pub fn new() -> Result> { + let smsg_it = IntCounterVec::new( + Opts::new( + "send_messages_in_total", + "All Messages that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ib = IntCounterVec::new( + Opts::new( + 
"send_messages_in_throughput", + "All Message bytes that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ot = IntCounterVec::new( + Opts::new( + "send_messages_out_total", + "All Messages that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let smsg_ob = IntCounterVec::new( + Opts::new( + "send_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let sdata_frames_t = IntCounterVec::new( + Opts::new( + "send_data_frames_total", + "Number of data frames send per channel", + ), + &["channel"], + )?; + let sdata_frames_b = IntCounterVec::new( + Opts::new( + "send_data_frames_throughput", + "Number of data frames bytes send per channel", + ), + &["channel"], + )?; + + let rmsg_it = IntCounterVec::new( + Opts::new( + "recv_messages_in_total", + "All Messages that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ib = IntCounterVec::new( + Opts::new( + "recv_messages_in_throughput", + "All Message bytes that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ot = IntCounterVec::new( + Opts::new( + "recv_messages_out_total", + "All Messages that are removed from this Protocol to be received at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rmsg_ob = IntCounterVec::new( + Opts::new( + "recv_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be received at stream \ + and reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rdata_frames_t = IntCounterVec::new( + Opts::new( + "recv_data_frames_total", + "Number of data frames received per channel", + ), + 
&["channel"], + )?; + let rdata_frames_b = IntCounterVec::new( + Opts::new( + "recv_data_frames_throughput", + "Number of data frames bytes received per channel", + ), + &["channel"], + )?; + let ping = IntGaugeVec::new(Opts::new("ping", "Ping per channel"), &["channel"])?; + + Ok(Self { + smsg_it, + smsg_ib, + smsg_ot, + smsg_ob, + sdata_frames_t, + sdata_frames_b, + rmsg_it, + rmsg_ib, + rmsg_ot, + rmsg_ob, + rdata_frames_t, + rdata_frames_b, + ping, + }) + } + + pub fn register(&self, registry: &Registry) -> Result<(), Box> { + registry.register(Box::new(self.smsg_it.clone()))?; + registry.register(Box::new(self.smsg_ib.clone()))?; + registry.register(Box::new(self.smsg_ot.clone()))?; + registry.register(Box::new(self.smsg_ob.clone()))?; + registry.register(Box::new(self.sdata_frames_t.clone()))?; + registry.register(Box::new(self.sdata_frames_b.clone()))?; + registry.register(Box::new(self.rmsg_it.clone()))?; + registry.register(Box::new(self.rmsg_ib.clone()))?; + registry.register(Box::new(self.rmsg_ot.clone()))?; + registry.register(Box::new(self.rmsg_ob.clone()))?; + registry.register(Box::new(self.rdata_frames_t.clone()))?; + registry.register(Box::new(self.rdata_frames_b.clone()))?; + registry.register(Box::new(self.ping.clone()))?; + Ok(()) + } +} + +#[cfg(feature = "metrics")] +impl ProtocolMetricCache { + pub fn new(channel_key: &str, metrics: Arc) -> Self { + Self { + cid: channel_key.to_string(), + m: metrics, + } + } + + pub(crate) fn smsg_it(&self, sid: Sid) { + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn smsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn smsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn smsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m 
+ .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn sdata_frames_t(&self) { + self.m.sdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn sdata_frames_b(&self, bytes: u64) { + self.m + .sdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_it(&self, sid: Sid) { + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn rmsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn rmsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn rdata_frames_t(&self) { + self.m.rdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn rdata_frames_b(&self, bytes: u64) { + self.m + .rdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + #[cfg(test)] + pub(crate) fn assert_msg(&self, sid: Sid, cnt: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_msg_bytes(&self, sid: Sid, bytes: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + 
.get(), + bytes + ); + assert_eq!( + self.m + .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames(&self, cnt: u64) { + assert_eq!( + self.m.sdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + assert_eq!( + self.m.rdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames_bytes(&self, bytes: u64) { + assert_eq!( + self.m.sdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + assert_eq!( + self.m.rdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + } +} + +#[cfg(feature = "metrics")] +impl std::fmt::Debug for ProtocolMetrics { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ProtocolMetrics()") + } +} + +#[cfg(not(feature = "metrics"))] +impl ProtocolMetricCache { + pub(crate) fn smsg_it(&self, _sid: Sid) {} + + pub(crate) fn smsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn smsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn smsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn sdata_frames_t(&self) {} + + pub(crate) fn sdata_frames_b(&self, _b: u64) {} + + pub(crate) fn rmsg_it(&self, _sid: Sid) {} + + pub(crate) fn rmsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn rmsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn rmsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn rdata_frames_t(&self) {} + + pub(crate) fn rdata_frames_b(&self, _b: u64) {} +} + +impl RemoveReason { + #[cfg(feature = "metrics")] + fn to_str(&self) -> &str { + match self { + RemoveReason::Dropped => "Dropped", + 
RemoveReason::Finished => "Finished", + } + } +} diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs new file mode 100644 index 0000000000..3e9e5d55fe --- /dev/null +++ b/network/protocol/src/mpsc.rs @@ -0,0 +1,217 @@ +use crate::{ + event::ProtocolEvent, + frame::InitFrame, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::time::{Duration, Instant}; + +pub /* should be private */ enum MpscMsg { + Event(ProtocolEvent), + InitFrame(InitFrame), +} + +#[derive(Debug)] +pub struct MpscSendProtcol +where + D: UnreliableDrain, +{ + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +#[derive(Debug)] +pub struct MpscRecvProtcol +where + S: UnreliableSink, +{ + sink: S, + metrics: ProtocolMetricCache, +} + +impl MpscSendProtcol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + drain, + last: Instant::now(), + metrics, + } + } +} + +impl MpscRecvProtcol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } } +} + +#[async_trait] +impl SendProtocol for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match &event { + ProtocolEvent::Message { + buffer, + mid: _, + sid, + } => { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, bytes); + let r = self.drain.send(MpscMsg::Event(event)).await; + self.metrics.smsg_ot(sid, RemoveReason::Finished); + self.metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + r + }, + _ => self.drain.send(MpscMsg::Event(event)).await, + } + } + + async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<(), ProtocolError> { Ok(()) } +} + +#[async_trait] +impl 
RecvProtocol for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(e) => { + if let ProtocolEvent::Message { + buffer, + mid: _, + sid, + } = &e + { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, bytes); + self.metrics.rmsg_ot(sid, RemoveReason::Finished); + self.metrics.rmsg_ob(sid, RemoveReason::Finished, bytes); + } + Ok(e) + }, + MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl ReliableDrain for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.drain.send(MpscMsg::InitFrame(frame)).await + } +} + +#[async_trait] +impl ReliableSink for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(_) => Err(ProtocolError::Closed), + MpscMsg::InitFrame(f) => Ok(f), + } + } +} + +#[cfg(test)] +pub mod test_utils { + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + use std::sync::Arc; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn 
send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + mpsc::test_utils::*, + types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, + }; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } +} diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs new file mode 100644 index 0000000000..35b7067352 --- /dev/null +++ b/network/protocol/src/prio.rs @@ -0,0 +1,139 @@ +use crate::{ + frame::Frame, + message::{MessageBuffer, OutgoingMessage}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::{collections::HashMap, sync::Arc, time::Duration}; + +#[derive(Debug)] +struct StreamInfo { + pub(crate) guaranteed_bandwidth: Bandwidth, + pub(crate) prio: Prio, + pub(crate) promises: Promises, + pub(crate) messages: Vec, +} + +/// Responsible for queueing messages. +/// every stream has a guaranteed bandwidth and a prio 0-7. +/// when `n` Bytes are available in the buffer, first the guaranteed bandwidth +/// is used. Then remaining bandwidth is used to fill up the prios. 
+#[derive(Debug)] +pub(crate) struct PrioManager { + streams: HashMap, + metrics: ProtocolMetricCache, +} + +// Send everything ONCE, then keep it till it's confirmed + +impl PrioManager { + const HIGHEST_PRIO: u8 = 7; + + pub fn new(metrics: ProtocolMetricCache) -> Self { + Self { + streams: HashMap::new(), + metrics, + } + } + + pub fn open_stream( + &mut self, + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + ) { + self.streams.insert(sid, StreamInfo { + guaranteed_bandwidth, + prio, + promises, + messages: vec![], + }); + } + + pub fn try_close_stream(&mut self, sid: Sid) -> bool { + if let Some(si) = self.streams.get(&sid) { + if si.messages.is_empty() { + self.streams.remove(&sid); + return true; + } + } + false + } + + pub fn is_empty(&self) -> bool { self.streams.is_empty() } + + pub fn add(&mut self, buffer: Arc, mid: Mid, sid: Sid) { + self.streams + .get_mut(&sid) + .unwrap() + .messages + .push(OutgoingMessage::new(buffer, mid, sid)); + } + + /// bandwidth might be extended, as for technical reasons + /// guaranteed_bandwidth is used and frames are always 1400 bytes. + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> Vec { + let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; + let mut cur_bytes = 0u64; + let mut frames = vec![]; + + let mut prios = [0u64; (Self::HIGHEST_PRIO + 1) as usize]; + let metrics = &self.metrics; + + let mut process_stream = + |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + let mut finished = vec![]; + 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { + while let Some(frame) = msg.next() { + let b = if matches!(frame, Frame::DataHeader { .. 
}) { + 25 + } else { + 19 + OutgoingMessage::FRAME_DATA_SIZE + }; + bandwidth -= b as i64; + *cur_bytes += b; + frames.push(frame); + if bandwidth <= 0 { + break 'outer; + } + } + finished.push(i); + } + + //cleanup + for i in finished.iter().rev() { + let msg = stream.messages.remove(*i); + let (sid, bytes) = msg.get_sid_len(); + metrics.smsg_ot(sid, RemoveReason::Finished); + metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + } + }; + + // Add guaranteed bandwidth + for (_, stream) in &mut self.streams { + prios[stream.prio.min(Self::HIGHEST_PRIO) as usize] += 1; + let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; + process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + } + + if cur_bytes < total_bytes { + // Add optional bandwidth + for prio in 0..=Self::HIGHEST_PRIO { + if prios[prio as usize] == 0 { + continue; + } + let per_stream_bytes = (total_bytes - cur_bytes) / prios[prio as usize]; + + for (_, stream) in &mut self.streams { + if stream.prio != prio { + continue; + } + process_stream(stream, per_stream_bytes as i64, &mut cur_bytes); + } + } + } + + frames + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs new file mode 100644 index 0000000000..e1c8e10e84 --- /dev/null +++ b/network/protocol/src/tcp.rs @@ -0,0 +1,584 @@ +use crate::{ + event::ProtocolEvent, + frame::{Frame, InitFrame}, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, + time::{Duration, Instant}, +}; +use tracing::info; + +#[derive(Debug)] +pub struct TcpSendProtcol +where + D: UnreliableDrain>, +{ + buffer: Vec, + store: PrioManager, + closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + 
+#[derive(Debug)] +pub struct TcpRecvProtcol +where + S: UnreliableSink>, +{ + buffer: VecDeque, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl TcpSendProtcol +where + D: UnreliableDrain>, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: vec![0u8; 1500], + store: PrioManager::new(metrics.clone()), + closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } +} + +impl TcpRecvProtcol +where + S: UnreliableSink>, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: VecDeque::new(), + incoming: HashMap::new(), + sink, + metrics, + } + } +} + +#[async_trait] +impl SendProtocol for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + tracing::error!(?event, "send frame"); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { buffer, mid, sid } => { + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, buffer.data.len() as u64); + self.store.add(buffer, mid, sid); + }, + } + Ok(()) + } + + async fn flush(&mut self, bandwidth: Bandwidth, dt: Duration) -> Result<(), 
ProtocolError> { + let frames = self.store.grab(bandwidth, dt); + for frame in frames { + if let Frame::Data { + mid: _, + start: _, + data, + } = &frame + { + self.metrics.sdata_frames_t(); + self.metrics.sdata_frames_b(data.len() as u64); + } + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + tracing::warn!("send data frame, woop"); + } + let mut finished_streams = vec![]; + for (i, sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + let frame = ProtocolEvent::CloseStream { sid: *sid }.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + if self.pending_shutdown && self.store.is_empty() { + tracing::error!("send shutdown frame"); + let frame = ProtocolEvent::Shutdown {}.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + self.pending_shutdown = false; + } + Ok(()) + } +} + +use crate::{ + message::MessageBuffer, + types::{Mid, Sid}, +}; + +#[derive(Debug)] +struct IncomingMsg { + sid: Sid, + length: u64, + data: MessageBuffer, +} + +#[async_trait] +impl RecvProtocol for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + tracing::error!(?self.buffer, "enter loop"); + 'outer: loop { + tracing::error!(?self.buffer, "continue loop"); + while let Some(frame) = Frame::to_frame(&mut self.buffer) { + tracing::error!(?frame, "recv frame"); + match frame { + Frame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + Frame::OpenStream { + sid, + prio, + promises, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: 1_000_000, + }); + }, + Frame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + Frame::DataHeader 
{ sid, mid, length } => { + let m = IncomingMsg { + sid, + length, + data: MessageBuffer { data: vec![] }, + }; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + Frame::Data { + mid, + start: _, + mut data, + } => { + self.metrics.rdata_frames_t(); + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!("protocol violation by remote side: send Data before Header"); + break 'outer Err(ProtocolError::Closed); + }, + }; + m.data.data.append(&mut data); + if m.data.data.len() == m.length as usize { + // finished, yay + drop(m); + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ot(m.sid, RemoveReason::Finished); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + mid, + buffer: Arc::new(m.data), + }); + } + }, + }; + } + tracing::error!(?self.buffer, "receiving on tcp sink"); + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + tracing::error!(?self.buffer,"receiving on tcp sink done"); + } + } +} + +#[async_trait] +impl ReliableDrain for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + let mut buffer = vec![0u8; 1500]; + let s = frame.to_bytes(&mut buffer); + buffer.truncate(s); + self.drain.send(buffer).await + } +} + +#[async_trait] +impl ReliableSink for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + while self.buffer.len() < 100 { + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + let todo_use_bytes_instead = self.buffer.iter().map(|b| *b).collect(); + if let Some(frame) = InitFrame::to_frame(todo_use_bytes_instead) { + match frame { + InitFrame::Handshake { 
.. } => self.buffer.drain(.. InitFrame::HANDSHAKE_CNS + 1), + InitFrame::Init { .. } => self.buffer.drain(.. InitFrame::INIT_CNS + 1), + InitFrame::Raw { .. } => self.buffer.drain(.. InitFrame::RAW_CNS + 1), + }; + return Ok(frame); + } + } + Err(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod test_utils { + //TCP protocol based on Channel + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + + pub struct TcpDrain { + pub sender: Sender>, + } + + pub struct TcpSink { + pub receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + tcp::test_utils::*, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, MessageBuffer, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn 
handshake_all_good() { + let [mut p1, mut p2] = tcp_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + buffer: Arc::new(MessageBuffer { + data: vec![188u8; 600], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 1, + buffer: Arc::new(MessageBuffer { + data: vec![7u8; 30], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let metrics = + ProtocolMetricCache::new("long_tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, Some(metrics.clone())); + let 
(mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. 
})); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtcol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. 
and then keep the connection open"; + let mut buf = vec![0u8; 1500]; + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + let (i, _) = event.to_frame().to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .to_bytes(&mut buf[i..]); + buf.truncate(i + i2); + s.send(buf).await.unwrap(); + + let mut buf = vec![0u8; 1500]; + let (i, _) = crate::frame::Frame::Data { + mid: 99, + start: 0, + data: DATA1.to_vec(), + } + .to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::Data { + mid: 99, + start: DATA1.len() as u64, + data: DATA2.to_vec(), + } + .to_bytes(&mut buf[i..]); + let (i3, _) = crate::frame::Frame::CloseStream { sid }.to_bytes(&mut buf[i + i2..]); + buf.truncate(i + i2 + i3); + s.send(buf).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } +} diff --git a/network/src/types.rs b/network/protocol/src/types.rs similarity index 52% rename from network/src/types.rs rename to network/protocol/src/types.rs index d257ed808f..b6f63ca208 100644 --- a/network/src/types.rs +++ b/network/protocol/src/types.rs @@ -1,10 +1,10 @@ use bitflags::bitflags; use rand::Rng; -use std::convert::TryFrom; pub type Mid = u64; pub type Cid = u64; pub type Prio = u8; +pub type Bandwidth = u64; bitflags! { /// use promises to modify the behavior of [`Streams`]. @@ -21,9 +21,8 @@ bitflags! 
{ /// this will guarantee that the other side will receive every message exactly /// once no messages are dropped const GUARANTEED_DELIVERY = 0b00000100; - /// this will enable the internal compression on this + /// this will enable the internal compression on this, only useable with #[cfg(feature = "compression")] /// [`Stream`](crate::api::Stream) - #[cfg(feature = "compression")] const COMPRESSED = 0b00001000; /// this will enable the internal encryption on this /// [`Stream`](crate::api::Stream) @@ -35,7 +34,7 @@ impl Promises { pub const fn to_le_bytes(self) -> [u8; 1] { self.bits.to_le_bytes() } } -pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = [86, 69, 76, 79, 82, 69, 78]; //VELOREN +pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = *b"VELOREN"; pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); @@ -51,144 +50,18 @@ pub struct Pid { } #[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub(crate) struct Sid { +pub struct Sid { internal: u64, } -// Used for Communication between Channel <----(TCP/UDP)----> Channel -#[derive(Debug)] -pub(crate) enum Frame { - Handshake { - magic_number: [u8; 7], - version: [u32; 3], - }, - Init { - pid: Pid, - secret: u128, - }, - Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown, Participant - * is deleted */ - OpenStream { - sid: Sid, - prio: Prio, - promises: Promises, - }, - CloseStream { - sid: Sid, - }, - DataHeader { - mid: Mid, - sid: Sid, - length: u64, - }, - Data { - mid: Mid, - start: u64, - data: Vec, - }, - /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API - * against veloren Server! 
*/ - Raw(Vec), -} - -impl Frame { - #[cfg(feature = "metrics")] - pub const FRAMES_LEN: u8 = 8; - - #[cfg(feature = "metrics")] - pub const fn int_to_string(i: u8) -> &'static str { - match i { - 0 => "Handshake", - 1 => "Init", - 2 => "Shutdown", - 3 => "OpenStream", - 4 => "CloseStream", - 5 => "DataHeader", - 6 => "Data", - 7 => "Raw", - _ => "", - } - } - - #[cfg(feature = "metrics")] - pub fn get_int(&self) -> u8 { - match self { - Frame::Handshake { .. } => 0, - Frame::Init { .. } => 1, - Frame::Shutdown => 2, - Frame::OpenStream { .. } => 3, - Frame::CloseStream { .. } => 4, - Frame::DataHeader { .. } => 5, - Frame::Data { .. } => 6, - Frame::Raw(_) => 7, - } - } - - #[cfg(feature = "metrics")] - pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } - - pub fn gen_handshake(buf: [u8; 19]) -> Self { - let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..19]).unwrap()), - ], - } - } - - pub fn gen_init(buf: [u8; 32]) -> Self { - Frame::Init { - pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), - secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..32]).unwrap()), - } - } - - pub fn gen_open_stream(buf: [u8; 10]) -> Self { - Frame::OpenStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - prio: buf[8], - promises: Promises::from_bits_truncate(buf[9]), - } - } - - pub fn gen_close_stream(buf: [u8; 8]) -> Self { - Frame::CloseStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - } - } - - pub fn gen_data_header(buf: [u8; 24]) -> Self { - Frame::DataHeader { - mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), - length: 
u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[16..24]).unwrap()), - } - } - - pub fn gen_data(buf: [u8; 18]) -> (Mid, u64, u16) { - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..18]).unwrap()); - (mid, start, length) - } - - pub fn gen_raw(buf: [u8; 2]) -> u16 { - u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..2]).unwrap()) - } -} - impl Pid { /// create a new Pid with a random interior value /// /// # Example /// ```rust - /// use veloren_network::{Network, Pid}; + /// use veloren_network_protocol::Pid; /// /// let pid = Pid::new(); - /// let _ = Network::new(pid); /// ``` pub fn new() -> Self { Self { @@ -295,20 +168,7 @@ fn sixlet_to_str(sixlet: u128) -> char { #[cfg(test)] mod tests { - use crate::types::*; - - #[test] - fn frame_int2str() { - assert_eq!(Frame::int_to_string(3), "OpenStream"); - assert_eq!(Frame::int_to_string(7), "Raw"); - assert_eq!(Frame::int_to_string(8), ""); - } - - #[test] - fn frame_get_int() { - assert_eq!(Frame::get_int(&Frame::Raw(b"Foo".to_vec())), 7); - assert_eq!(Frame::get_int(&Frame::Shutdown), 2); - } + use super::*; #[test] fn frame_creation() { diff --git a/network/protocol/src/udp.rs b/network/protocol/src/udp.rs new file mode 100644 index 0000000000..ad5c31a126 --- /dev/null +++ b/network/protocol/src/udp.rs @@ -0,0 +1,37 @@ +// TODO: quick and dirty which activly waits for an ack! +/* +UDP protocol + +All Good Case: +S --HEADER--> R +S --DATA--> R +S --DATA--> R +S <--FINISHED-- R + + +Delayed HEADER: +S --HEADER--> +S --DATA--> R // STORE IT + --HEADER--> R // apply left data and continue +S --DATA--> R +S <--FINISHED-- R + + +NO HEADER: +S --HEADER--> ! 
+S --DATA--> R // STORE IT +S --DATA--> R // STORE IT +S <--MISSING_HEADER-- R // SEND AFTER 10 ms after DATA1 +S --HEADER--> R +S <--FINISHED-- R + + +NO DATA: +S --HEADER--> R +S --DATA--> R +S --DATA--> ! +S --STATUS--> R +S <--MISSING_DATA -- R +S --DATA--> R +S <--FINISHED-- R +*/ diff --git a/network/src/api.rs b/network/src/api.rs index ef6fb113db..08274c90be 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -3,13 +3,13 @@ //! //! (cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ - message::{partial_eq_bincode, IncomingMessage, Message, OutgoingMessage}, + message::{partial_eq_bincode, Message}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, - types::{Mid, Pid, Prio, Promises, Sid}, }; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; +use network_protocol::{Bandwidth, MessageBuffer, Mid, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -20,6 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, @@ -49,8 +50,7 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: Mutex>, + a2b_open_stream_s: Mutex>, b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, } @@ -75,9 +75,10 @@ pub struct Stream { mid: Mid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: Option>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -419,16 +420,14 @@ impl Participant { pub(crate) fn new( local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: mpsc::UnboundedSender, + a2b_open_stream_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, 
S2bShutdownBparticipant)>, ) -> Self { Self { local_pid, remote_pid, - runtime, - a2b_stream_open_s: Mutex::new(a2b_stream_open_s), + a2b_open_stream_s: Mutex::new(a2b_open_stream_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } @@ -477,13 +476,13 @@ impl Participant { /// /// [`Streams`]: crate::api::Stream pub async fn open(&self, prio: u8, promises: Promises) -> Result { - let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = - self.a2b_stream_open_s - .lock() - .await - .send((prio, promises, p2a_return_stream_s)) - { + let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); + if let Err(e) = self.a2b_open_stream_s.lock().await.send(( + prio, + promises, + 100000u64, + p2a_return_stream_s, + )) { debug!(?e, "bParticipant is already closed, notifying"); return Err(ParticipantError::ParticipantDisconnected); } @@ -602,7 +601,7 @@ impl Participant { // Participant is connecting to Scheduler here, not as usual // Participant<->BParticipant a2s_disconnect_s - .send((pid, finished_sender)) + .send((pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { @@ -647,9 +646,10 @@ impl Stream { sid: Sid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: async_channel::Receiver, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -658,6 +658,7 @@ impl Stream { mid: 0, prio, promises, + guaranteed_bandwidth, send_closed, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), @@ -776,12 +777,8 @@ impl Stream { } #[cfg(debug_assertions)] message.verify(&self); - self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { - buffer: 
Arc::clone(&message.buffer), - cursor: 0, - mid: self.mid, - sid: self.sid, - }))?; + self.a2b_msg_s + .send((self.sid, Arc::clone(&message.buffer)))?; self.mid += 1; Ok(()) } @@ -864,7 +861,7 @@ impl Stream { Some(b2a_msg_recv_r) => { match b2a_msg_recv_r.recv().await { Ok(msg) => Ok(Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), @@ -917,7 +914,7 @@ impl Stream { Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { Ok(msg) => Ok(Some( Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises().contains(Promises::COMPRESSED), } @@ -953,47 +950,62 @@ impl Drop for Network { "Shutting down Participants of Network, while we still have metrics" ); let mut finished_receiver_list = vec![]; - self.runtime.block_on(async { - // we MUST avoid nested block_on, good that Network::Drop no longer triggers - // Participant::Drop directly but just the BParticipant - for (remote_pid, a2s_disconnect_s) in - self.participant_disconnect_sender.lock().await.drain() - { - match a2s_disconnect_s.lock().await.take() { - Some(a2s_disconnect_s) => { - trace!(?remote_pid, "Participants will be closed"); - let (finished_sender, finished_receiver) = oneshot::channel(); - finished_receiver_list.push((remote_pid, finished_receiver)); - a2s_disconnect_s.send((remote_pid, finished_sender)).expect( - "Scheduler is closed, but nobody other should be able to close it", - ); - }, - None => trace!(?remote_pid, "Participant already disconnected gracefully"), + + if tokio::runtime::Handle::try_current().is_ok() { + error!("we have a runtime but we mustn't, DROP NETWORK from async runtime is illegal") + } + + tokio::task::block_in_place(|| { + /* This context prevents panic if Dropped in a async fn */ + self.runtime.block_on(async { + for (remote_pid, a2s_disconnect_s) in + 
self.participant_disconnect_sender.lock().await.drain() + { + match a2s_disconnect_s.lock().await.take() { + Some(a2s_disconnect_s) => { + trace!(?remote_pid, "Participants will be closed"); + let (finished_sender, finished_receiver) = oneshot::channel(); + finished_receiver_list.push((remote_pid, finished_receiver)); + a2s_disconnect_s + .send((remote_pid, (Duration::from_secs(120), finished_sender))) + .expect( + "Scheduler is closed, but nobody other should be able to \ + close it", + ); + }, + None => trace!(?remote_pid, "Participant already disconnected gracefully"), + } } - } - //wait after close is requested for all - for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { - match finished_receiver.await { - Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), - Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), - Err(e) => warn!( - ?remote_pid, - ?e, - "Failed to get a message back from the scheduler, seems like the network \ - is already closed" - ), + //wait after close is requested for all + for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) 
{ + match finished_receiver.await { + Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), + Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), + Err(e) => warn!( + ?remote_pid, + ?e, + "Failed to get a message back from the scheduler, seems like the \ + network is already closed" + ), + } } - } + }); }); trace!(?pid, "Participants have shut down!"); trace!(?pid, "Shutting down Scheduler"); - self.shutdown_sender.take().unwrap().send(()).expect("Scheduler is closed, but nobody other should be able to close it"); + self.shutdown_sender + .take() + .unwrap() + .send(()) + .expect("Scheduler is closed, but nobody other should be able to close it"); debug!(?pid, "Network has shut down"); } } impl Drop for Participant { fn drop(&mut self) { + use tokio::sync::oneshot::error::TryRecvError; + // ignore closed, as we need to send it even though we disconnected the // participant from network let pid = self.remote_pid; @@ -1011,23 +1023,28 @@ impl Drop for Participant { ), Some(a2s_disconnect_s) => { debug!(?pid, "Disconnect from Scheduler"); - self.runtime.block_on(async { - let (finished_sender, finished_receiver) = oneshot::channel(); - a2s_disconnect_s - .send((self.remote_pid, finished_sender)) - .expect("Something is wrong in internal scheduler coding"); - if let Err(e) = finished_receiver - .await - .expect("Something is wrong in internal scheduler/participant coding") - { - error!( + let (finished_sender, mut finished_receiver) = oneshot::channel(); + a2s_disconnect_s + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) + .expect("Something is wrong in internal scheduler coding"); + loop { + match finished_receiver.try_recv() { + Ok(Ok(())) => break, + Ok(Err(e)) => error!( ?pid, ?e, "Error while dropping the participant, couldn't send all outgoing \ messages, dropping remaining" - ); - }; - }); + ), + Err(TryRecvError::Closed) => { + panic!("Something is wrong in internal scheduler/participant coding") + }, + 
Err(TryRecvError::Empty) => { + trace!("activly sleeping"); + std::thread::sleep(Duration::from_millis(20)); + }, + } + } }, } debug!(?pid, "Participant dropped"); @@ -1041,11 +1058,12 @@ impl Drop for Stream { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); - self.a2b_close_stream_s - .take() - .unwrap() - .send(self.sid) - .expect("bparticipant part of a gracefully shutdown must have crashed"); + if let Err(e) = self.a2b_close_stream_s.take().unwrap().send(self.sid) { + debug!( + ?e, + "bparticipant part of a gracefully shutdown was already closed" + ); + } } else { let sid = self.sid; let pid = self.pid; diff --git a/network/src/channel.rs b/network/src/channel.rs index 7928337bd1..654175fb1d 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,361 +1,231 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - participant::C2pFrame, - protocols::Protocols, - types::{ - Cid, Frame, Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, - VELOREN_NETWORK_VERSION, - }, -}; -use futures_core::task::Poll; -use futures_util::{ - task::{noop_waker, Context}, - FutureExt, +use async_trait::async_trait; +use network_protocol::{ + InitProtocolError, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtcol, TcpSendProtcol, + UnreliableDrain, UnreliableSink, }; #[cfg(feature = "metrics")] use std::sync::Arc; +use std::time::Duration; use tokio::{ - join, - sync::{mpsc, oneshot}, + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + sync::mpsc, }; -use tracing::*; -pub(crate) struct Channel { - cid: Cid, - c2w_frame_r: Option>, - read_stop_receiver: Option>, -} - -impl Channel { - pub fn new(cid: u64) -> (Self, mpsc::UnboundedSender, oneshot::Sender<()>) { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = 
oneshot::channel(); - ( - Self { - cid, - c2w_frame_r: Some(c2w_frame_r), - read_stop_receiver: Some(read_stop_receiver), - }, - c2w_frame_s, - read_stop_sender, - ) - } - - pub async fn run( - mut self, - protocol: Protocols, - mut w2c_cid_frame_s: mpsc::UnboundedSender, - mut leftover_cid_frame: Vec, - ) { - let c2w_frame_r = self.c2w_frame_r.take().unwrap(); - let read_stop_receiver = self.read_stop_receiver.take().unwrap(); - - //reapply leftovers from handshake - let cnt = leftover_cid_frame.len(); - trace!(?cnt, "Reapplying leftovers"); - for cid_frame in leftover_cid_frame.drain(..) { - w2c_cid_frame_s.send(cid_frame).unwrap(); - } - trace!(?cnt, "All leftovers reapplied"); - - trace!("Start up channel"); - match protocol { - Protocols::Tcp(tcp) => { - join!( - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - Protocols::Udp(udp) => { - join!( - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - } - - trace!("Shut down channel"); - } +#[derive(Debug)] +pub(crate) enum Protocols { + Tcp((TcpSendProtcol, TcpRecvProtcol)), + Mpsc((MpscSendProtcol, MpscRecvProtcol)), } #[derive(Debug)] -pub(crate) struct Handshake { - cid: Cid, - local_pid: Pid, - secret: u128, - init_handshake: bool, - #[cfg(feature = "metrics")] - metrics: Arc, +pub(crate) enum SendProtocols { + Tcp(TcpSendProtcol), + Mpsc(MpscSendProtcol), } -impl Handshake { - #[cfg(debug_assertions)] - const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required by \ - veloren server.\nWe are not sure if you are a valid \ - veloren client.\nClosing the connection" - .as_bytes(); - #[cfg(debug_assertions)] - const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ - invalid version.\nWe don't know how to communicate with \ - you.\nClosing the connection"; +#[derive(Debug)] +pub(crate) enum 
RecvProtocols { + Tcp(TcpRecvProtcol), + Mpsc(MpscRecvProtcol), +} - pub fn new( - cid: u64, +impl Protocols { + pub(crate) fn new_tcp(stream: tokio::net::TcpStream) -> Self { + let (r, w) = stream.into_split(); + #[cfg(feature = "metrics")] + let metrics = ProtocolMetricCache::new( + "foooobaaaarrrrrrrr", + Arc::new(ProtocolMetrics::new().unwrap()), + ); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = TcpSendProtcol::new(TcpDrain { half: w }, metrics.clone()); + let rp = TcpRecvProtcol::new(TcpSink { half: r }, metrics.clone()); + Protocols::Tcp((sp, rp)) + } + + pub(crate) fn new_mpsc( + sender: mpsc::Sender, + receiver: mpsc::Receiver, + ) -> Self { + #[cfg(feature = "metrics")] + let metrics = + ProtocolMetricCache::new("mppppsssscccc", Arc::new(ProtocolMetrics::new().unwrap())); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = MpscSendProtcol::new(MpscDrain { sender }, metrics.clone()); + let rp = MpscRecvProtcol::new(MpscSink { receiver }, metrics.clone()); + Protocols::Mpsc((sp, rp)) + } + + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { + match self { + Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), + Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + } + } +} + +#[async_trait] +impl network_protocol::InitProtocol for Protocols { + async fn initialize( + &mut self, + initializer: bool, local_pid: Pid, secret: u128, - #[cfg(feature = "metrics")] metrics: Arc, - init_handshake: bool, - ) -> Self { - Self { - cid, - local_pid, - secret, - #[cfg(feature = "metrics")] - metrics, - init_handshake, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + match self { + Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, + Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + } + } +} + +#[async_trait] +impl network_protocol::SendProtocol for SendProtocols { + async fn send(&mut 
self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.send(event).await, + SendProtocols::Mpsc(s) => s.send(event).await, + } + } + + async fn flush(&mut self, bandwidth: u64, dt: Duration) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, + SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + } + } +} + +#[async_trait] +impl network_protocol::RecvProtocol for RecvProtocols { + async fn recv(&mut self) -> Result { + match self { + RecvProtocols::Tcp(r) => r.recv().await, + RecvProtocols::Mpsc(r) => r.recv().await, + } + } +} + +/////////////////////////////////////// +//// TCP +#[derive(Debug)] +pub struct TcpDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct TcpSink { + half: OwnedReadHalf, +} + +#[async_trait] +impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + //self.half.recv + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + let mut data = vec![0u8; 1500]; + match self.half.read(&mut data).await { + Ok(n) => { + data.truncate(n); + Ok(data) + }, + Err(_) => Err(ProtocolError::Closed), + } + } +} + +/////////////////////////////////////// +//// MPSC +#[derive(Debug)] +pub struct MpscDrain { + sender: tokio::sync::mpsc::Sender, +} + +#[derive(Debug)] +pub struct MpscSink { + receiver: tokio::sync::mpsc::Receiver, +} + +#[async_trait] +impl UnreliableDrain for MpscDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } +} + +#[async_trait] +impl UnreliableSink for MpscSink { + type DataFormat = MpscMsg; + + async fn recv(&mut 
self) -> Result { + self.receiver.recv().await.ok_or(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tokio_sinks() { + let listener = TcpListener::bind("127.0.0.1:5000").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let client = Protocols::new_tcp(client); + let server = Protocols::new_tcp(server); + let (mut s, _) = client.split(); + let (_, mut r) = server.split(); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1), + prio: 4u8, + promises: Promises::GUARANTEED_DELIVERY, + guaranteed_bandwidth: 1_000, + }; + s.send(event.clone()).await.unwrap(); + let r = r.recv().await; + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + }) => { + assert_eq!(sid, Sid::new(1)); + assert_eq!(prio, 4u8); + assert_eq!(promises, Promises::GUARANTEED_DELIVERY); + }, + _ => { + panic!("wrong type {:?}", r); + }, } } - - pub async fn setup(self, protocol: &Protocols) -> Result<(Pid, Sid, u128, Vec), ()> { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let handler_future = - self.frame_handler(&mut w2c_cid_frame_r, c2w_frame_s, read_stop_sender); - let res = match protocol { - Protocols::Tcp(tcp) => { - (join! { - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r).fuse(), - handler_future, - }) - .2 - }, - Protocols::Udp(udp) => { - (join! 
{ - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - handler_future, - }) - .2 - }, - }; - - match res { - Ok(res) => { - let fake_waker = noop_waker(); - let mut ctx = Context::from_waker(&fake_waker); - let mut leftover_frames = vec![]; - while let Poll::Ready(Some(cid_frame)) = w2c_cid_frame_r.poll_recv(&mut ctx) { - leftover_frames.push(cid_frame); - } - let cnt = leftover_frames.len(); - if cnt > 0 { - debug!( - ?cnt, - "Some additional frames got already transferred, piping them to the \ - bparticipant as leftover_frames" - ); - } - Ok((res.0, res.1, res.2, leftover_frames)) - }, - Err(()) => Err(()), - } - } - - async fn frame_handler( - &self, - w2c_cid_frame_r: &mut mpsc::UnboundedReceiver, - mut c2w_frame_s: mpsc::UnboundedSender, - read_stop_sender: oneshot::Sender<()>, - ) -> Result<(Pid, Sid, u128), ()> { - const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ - something went wrong on network layer and connection will be closed"; - #[cfg(feature = "metrics")] - let cid_string = self.cid.to_string(); - - if self.init_handshake { - self.send_handshake(&mut c2w_frame_s).await; - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - #[cfg(feature = "metrics")] - { - if let Some(Ok(ref frame)) = frame { - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, &frame.get_string()]) - .inc(); - } - } - let r = match frame { - Some(Ok(Frame::Handshake { - magic_number, - version, - })) => { - trace!(?magic_number, ?version, "Recv handshake"); - if magic_number != VELOREN_MAGIC_NUMBER { - error!(?magic_number, "Connection with invalid magic_number"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown(&mut c2w_frame_s, Self::WRONG_NUMBER.to_vec()) - .await; - Err(()) - } else if version != VELOREN_NETWORK_VERSION { - error!(?version, "Connection with wrong network version"); - #[cfg(debug_assertions)] - 
self.send_raw_and_shutdown( - &mut c2w_frame_s, - format!( - "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", - Self::WRONG_VERSION, - VELOREN_NETWORK_VERSION, - version, - ) - .as_bytes() - .to_vec(), - ) - .await; - Err(()) - } else { - debug!("Handshake completed"); - if self.init_handshake { - self.send_init(&mut c2w_frame_s).await; - } else { - self.send_handshake(&mut c2w_frame_s).await; - } - Ok(()) - } - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if let Err(()) = r { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - return Err(()); - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - let r = match frame { - Some(Ok(Frame::Init { pid, secret })) => { - debug!(?pid, "Participant send their ID"); - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, "ParticipantId"]) - .inc(); - let stream_id_offset = if self.init_handshake { - STREAM_ID_OFFSET1 - } else { - self.send_init(&mut c2w_frame_s).await; - STREAM_ID_OFFSET2 - }; - info!(?pid, "This Handshake is now configured!"); - Ok((pid, stream_id_offset, secret)) - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - 
Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if r.is_err() { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - } - r - } - - async fn send_handshake(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "Handshake"]) - .inc(); - c2w_frame_s - .send(Frame::Handshake { - magic_number: VELOREN_MAGIC_NUMBER, - version: VELOREN_NETWORK_VERSION, - }) - .unwrap(); - } - - async fn send_init(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "ParticipantId"]) - .inc(); - c2w_frame_s - .send(Frame::Init { - pid: self.local_pid, - secret: self.secret, - }) - .unwrap(); - } - - #[cfg(debug_assertions)] - async fn send_raw_and_shutdown( - &self, - c2w_frame_s: &mut mpsc::UnboundedSender, - data: Vec, - ) { - debug!("Sending client instructions before killing"); - #[cfg(feature = "metrics")] - { - let cid_string = self.cid.to_string(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Raw"]) - .inc(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Shutdown"]) - .inc(); - } - c2w_frame_s.send(Frame::Raw(data)).unwrap(); - c2w_frame_s.send(Frame::Shutdown).unwrap(); - } } diff --git a/network/src/lib.rs b/network/src/lib.rs index ffba192643..7593f1edd8 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -104,14 +104,10 @@ mod channel; mod message; #[cfg(feature = "metrics")] mod metrics; mod participant; -mod prios; -mod protocols; mod scheduler; -#[macro_use] -mod types; pub use api::{ Network, NetworkError, Participant, ParticipantError, ProtocolAddr, Stream, StreamError, }; pub use message::Message; -pub use types::{Pid, 
Promises}; +pub use network_protocol::{Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index 9ab9941599..0ad24c63ad 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,11 +1,9 @@ use serde::{de::DeserializeOwned, Serialize}; //use std::collections::VecDeque; +use crate::api::{Stream, StreamError}; +use network_protocol::MessageBuffer; #[cfg(feature = "compression")] -use crate::types::Promises; -use crate::{ - api::{Stream, StreamError}, - types::{Frame, Mid, Sid}, -}; +use network_protocol::Promises; use std::{io, sync::Arc}; #[cfg(all(feature = "compression", debug_assertions))] use tracing::warn; @@ -23,29 +21,6 @@ pub struct Message { pub(crate) compressed: bool, } -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. -pub(crate) struct MessageBuffer { - pub data: Vec, -} - -#[derive(Debug)] -pub(crate) struct OutgoingMessage { - pub buffer: Arc, - pub cursor: u64, - pub mid: Mid, - pub sid: Sid, -} - -#[derive(Debug)] -pub(crate) struct IncomingMessage { - pub buffer: MessageBuffer, - pub length: u64, - pub mid: Mid, - pub sid: Sid, -} - impl Message { /// This serializes any message, according to the [`Streams`] [`Promises`]. 
/// You can reuse this `Message` and send it via other [`Streams`], if the @@ -170,38 +145,6 @@ impl Message { } } -impl OutgoingMessage { - pub(crate) const FRAME_DATA_SIZE: u64 = 1400; - - /// returns if msg is empty - pub(crate) fn fill_next>( - &mut self, - msg_sid: Sid, - frames: &mut E, - ) -> bool { - let to_send = std::cmp::min( - self.buffer.data[self.cursor as usize..].len() as u64, - Self::FRAME_DATA_SIZE, - ); - if to_send > 0 { - if self.cursor == 0 { - frames.extend(std::iter::once((msg_sid, Frame::DataHeader { - mid: self.mid, - sid: self.sid, - length: self.buffer.data.len() as u64, - }))); - } - frames.extend(std::iter::once((msg_sid, Frame::Data { - mid: self.mid, - start: self.cursor, - data: self.buffer.data[self.cursor as usize..][..to_send as usize].to_vec(), - }))); - }; - self.cursor += to_send; - self.cursor >= self.buffer.data.len() as u64 - } -} - ///wouldn't trust this aaaassss much, fine for tests pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool { if let Some(f) = first.raw_os_error() { @@ -231,28 +174,6 @@ pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::E } } -impl std::fmt::Debug for MessageBuffer { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //TODO: small messages! 
- let len = self.data.len(); - if len > 20 { - write!( - f, - "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", - len, - u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), - u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), - u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), - &self.data[13..16], - &self.data[len - 8..len] - ) - } else { - write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) - } - } -} - #[cfg(test)] mod tests { use crate::{api::Stream, message::*}; @@ -260,7 +181,8 @@ mod tests { use tokio::sync::mpsc; fn stub_stream(compressed: bool) -> Stream { - use crate::{api::*, types::*}; + use crate::api::*; + use network_protocol::*; #[cfg(feature = "compression")] let promises = if compressed { @@ -281,6 +203,7 @@ mod tests { Sid::new(0), 0u8, promises, + 1_000_000, Arc::new(AtomicBool::new(true)), a2b_msg_s, b2a_msg_recv_r, diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d43aeaae3a..60650f68fe 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,10 +1,6 @@ -use crate::types::{Cid, Frame, Pid}; -use prometheus::{ - core::{AtomicU64, GenericCounter}, - IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry, -}; +use network_protocol::Pid; +use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; use std::error::Error; -use tracing::*; /// 1:1 relation between NetworkMetrics and Network /// use 2NF here and avoid redundant data like CHANNEL AND PARTICIPANT encoding. 
@@ -25,29 +21,6 @@ pub struct NetworkMetrics { pub streams_opened_total: IntCounterVec, pub streams_closed_total: IntCounterVec, pub network_info: IntGauge, - // Frames counted a channel level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_out_total: IntCounterVec, - pub frames_in_total: IntCounterVec, - // Frames counted at protocol level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_wire_out_total: IntCounterVec, - pub frames_wire_in_total: IntCounterVec, - // throughput at protocol level, seperated by CHANNEL (and PARTICIPANT), - pub wire_out_throughput: IntCounterVec, - pub wire_in_throughput: IntCounterVec, - // send(prio) Messages count, seperated by STREAM AND PARTICIPANT, - pub message_out_total: IntCounterVec, - // send(prio) Messages throughput, seperated by STREAM AND PARTICIPANT, - pub message_out_throughput: IntCounterVec, - // flushed(prio) stream count, seperated by PARTICIPANT, - pub streams_flushed: IntCounterVec, - // TODO: queued Messages, seperated by STREAM (add PART, CHANNEL), - // queued Messages, seperated by PARTICIPANT - pub queued_count: IntGaugeVec, - // TODO: queued Messages bytes, seperated by STREAM (add PART, CHANNEL), - // queued Messages bytes, seperated by PARTICIPANT - pub queued_bytes: IntGaugeVec, - // ping calculated based on last msg seperated by PARTICIPANT - pub participants_ping: IntGaugeVec, } impl NetworkMetrics { @@ -115,99 +88,13 @@ impl NetworkMetrics { "version", &format!( "{}.{}.{}", - &crate::types::VELOREN_NETWORK_VERSION[0], - &crate::types::VELOREN_NETWORK_VERSION[1], - &crate::types::VELOREN_NETWORK_VERSION[2] + &network_protocol::VELOREN_NETWORK_VERSION[0], + &network_protocol::VELOREN_NETWORK_VERSION[1], + &network_protocol::VELOREN_NETWORK_VERSION[2] ), ) .const_label("local_pid", &format!("{}", &local_pid)); let network_info = IntGauge::with_opts(opts)?; - let frames_out_total = IntCounterVec::new( - Opts::new( - "frames_out_total", - "Number of all frames send 
per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_in_total = IntCounterVec::new( - Opts::new( - "frames_in_total", - "Number of all frames received per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_wire_out_total = IntCounterVec::new( - Opts::new( - "frames_wire_out_total", - "Number of all frames send per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let frames_wire_in_total = IntCounterVec::new( - Opts::new( - "frames_wire_in_total", - "Number of all frames received per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let wire_out_throughput = IntCounterVec::new( - Opts::new( - "wire_out_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - let wire_in_throughput = IntCounterVec::new( - Opts::new( - "wire_in_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - //TODO IN - let message_out_total = IntCounterVec::new( - Opts::new( - "message_out_total", - "Number of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - //TODO IN - let message_out_throughput = IntCounterVec::new( - Opts::new( - "message_out_throughput", - "Throughput of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - let streams_flushed = IntCounterVec::new( - Opts::new( - "stream_flushed", - "Number of flushed streams requested to PrioManager at participant level", - ), - &["participant"], - )?; - let queued_count = IntGaugeVec::new( - Opts::new( - "queued_count", - "Queued number of messages by participant on the network", - ), - &["channel"], - )?; - let queued_bytes = IntGaugeVec::new( - Opts::new( - "queued_bytes", - "Queued bytes of messages by participant on the network", - ), - &["channel"], - )?; - let participants_ping = IntGaugeVec::new( - Opts::new( - 
"participants_ping", - "Ping time to participants on the network", - ), - &["channel"], - )?; Ok(Self { listen_requests_total, @@ -220,18 +107,6 @@ impl NetworkMetrics { streams_opened_total, streams_closed_total, network_info, - frames_out_total, - frames_in_total, - frames_wire_out_total, - frames_wire_in_total, - wire_out_throughput, - wire_in_throughput, - message_out_total, - message_out_throughput, - streams_flushed, - queued_count, - queued_bytes, - participants_ping, }) } @@ -246,22 +121,8 @@ impl NetworkMetrics { registry.register(Box::new(self.streams_opened_total.clone()))?; registry.register(Box::new(self.streams_closed_total.clone()))?; registry.register(Box::new(self.network_info.clone()))?; - registry.register(Box::new(self.frames_out_total.clone()))?; - registry.register(Box::new(self.frames_in_total.clone()))?; - registry.register(Box::new(self.frames_wire_out_total.clone()))?; - registry.register(Box::new(self.frames_wire_in_total.clone()))?; - registry.register(Box::new(self.wire_out_throughput.clone()))?; - registry.register(Box::new(self.wire_in_throughput.clone()))?; - registry.register(Box::new(self.message_out_total.clone()))?; - registry.register(Box::new(self.message_out_throughput.clone()))?; - registry.register(Box::new(self.queued_count.clone()))?; - registry.register(Box::new(self.queued_bytes.clone()))?; - registry.register(Box::new(self.participants_ping.clone()))?; Ok(()) } - - //pub fn _is_100th_tick(&self) -> bool { - // self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } } impl std::fmt::Debug for NetworkMetrics { @@ -270,138 +131,3 @@ impl std::fmt::Debug for NetworkMetrics { write!(f, "NetworkMetrics()") } } - -/* -pub(crate) struct PidCidFrameCache { - metric: MetricVec, - pid: String, - cache: Vec<[T::M; 8]>, -} -*/ - -pub(crate) struct MultiCidFrameCache { - metric: IntCounterVec, - cache: Vec<[Option>; Frame::FRAMES_LEN as usize]>, -} - -impl MultiCidFrameCache { - const CACHE_SIZE: usize = 2048; - - pub fn new(metric: 
IntCounterVec) -> Self { - Self { - metric, - cache: vec![], - } - } - - fn populate(&mut self, cid: Cid) { - let start_cid = self.cache.len(); - if cid >= start_cid as u64 && cid > (Self::CACHE_SIZE as Cid) { - warn!( - ?cid, - "cid, getting quite high, is this a attack on the cache?" - ); - } - self.cache.resize((cid + 1) as usize, [ - None, None, None, None, None, None, None, None, - ]); - } - - pub fn with_label_values(&mut self, cid: Cid, frame: &Frame) -> &GenericCounter { - self.populate(cid); - let frame_int = frame.get_int() as usize; - let r = &mut self.cache[cid as usize][frame_int]; - if r.is_none() { - *r = Some( - self.metric - .with_label_values(&[&cid.to_string(), &frame_int.to_string()]), - ); - } - r.as_ref().unwrap() - } -} - -pub(crate) struct CidFrameCache { - cache: [GenericCounter; Frame::FRAMES_LEN as usize], -} - -impl CidFrameCache { - pub fn new(metric: IntCounterVec, cid: Cid) -> Self { - let cid = cid.to_string(); - let cache = [ - metric.with_label_values(&[&cid, Frame::int_to_string(0)]), - metric.with_label_values(&[&cid, Frame::int_to_string(1)]), - metric.with_label_values(&[&cid, Frame::int_to_string(2)]), - metric.with_label_values(&[&cid, Frame::int_to_string(3)]), - metric.with_label_values(&[&cid, Frame::int_to_string(4)]), - metric.with_label_values(&[&cid, Frame::int_to_string(5)]), - metric.with_label_values(&[&cid, Frame::int_to_string(6)]), - metric.with_label_values(&[&cid, Frame::int_to_string(7)]), - ]; - Self { cache } - } - - pub fn with_label_values(&mut self, frame: &Frame) -> &GenericCounter { - &self.cache[frame.get_int() as usize] - } -} - -#[cfg(test)] -mod tests { - use crate::{ - metrics::*, - types::{Frame, Pid}, - }; - - #[test] - fn register_metrics() { - let registry = Registry::new(); - let metrics = NetworkMetrics::new(&Pid::fake(1)).unwrap(); - metrics.register(®istry).unwrap(); - } - - #[test] - fn multi_cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let 
frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = MultiCidFrameCache::new(metrics.frames_in_total); - let v1 = cache.with_label_values(1, &frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(1, &frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(1, &frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(3, &frame1); - v4.inc(); - assert_eq!(v4.get(), 1); - let v5 = cache.with_label_values(3, &Frame::Shutdown); - v5.inc(); - assert_eq!(v5.get(), 1); - } - - #[test] - fn cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = CidFrameCache::new(metrics.frames_wire_out_total, 1); - let v1 = cache.with_label_values(&frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(&frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(&frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(&Frame::Shutdown); - v4.inc(); - assert_eq!(v4.get(), 1); - } -} diff --git a/network/src/participant.rs b/network/src/participant.rs index 6986a70e8f..a942632f7b 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,43 +1,39 @@ #[cfg(feature = "metrics")] -use crate::metrics::{MultiCidFrameCache, NetworkMetrics}; +use crate::metrics::NetworkMetrics; use crate::{ api::{ParticipantError, Stream}, - channel::Channel, - message::{IncomingMessage, MessageBuffer, OutgoingMessage}, - prios::PrioManager, - protocols::Protocols, - types::{Cid, Frame, Pid, Prio, Promises, Sid}, + channel::{Protocols, RecvProtocols, SendProtocols}, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::{ + Bandwidth, Cid, MessageBuffer, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, + Sid, 
+}; use std::{ - collections::{HashMap, VecDeque}, + collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicI32, Ordering}, Arc, }, time::{Duration, Instant}, }; use tokio::{ - runtime::Runtime, select, sync::{mpsc, oneshot, Mutex, RwLock}, + task::JoinHandle, }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -pub(crate) type A2bStreamOpen = (Prio, Promises, oneshot::Sender); -pub(crate) type C2pFrame = (Cid, Result); -pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, Vec, oneshot::Sender<()>); -pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; +pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender); +pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, oneshot::Sender<()>); +pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender>); pub(crate) type B2sPrioStatistic = (Pid, u64, u64); #[derive(Debug)] struct ChannelInfo { cid: Cid, cid_string: String, //optimisationmetrics - b2w_frame_s: mpsc::UnboundedSender, - b2r_read_shutdown: oneshot::Sender<()>, } #[derive(Debug)] @@ -45,23 +41,19 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] struct ControlChannels { - a2b_stream_open_r: mpsc::UnboundedReceiver, + a2b_open_stream_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, s2b_create_channel_r: mpsc::UnboundedReceiver, - a2b_close_stream_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ } #[derive(Debug)] struct ShutdownInfo { - //a2b_stream_open_r: mpsc::UnboundedReceiver, b2b_close_stream_opened_sender_s: Option>, error: Option, } @@ -71,29 +63,27 @@ pub struct BParticipant { remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, - 
runtime: Arc, channels: Arc>>>, streams: RwLock>, - running_mgr: AtomicUsize, run_channels: Option, + shutdown_barrier: AtomicI32, #[cfg(feature = "metrics")] metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, - shutdown_info: RwLock, } impl BParticipant { - const BANDWIDTH: u64 = 25_000_000; - const FRAMES_PER_TICK: u64 = Self::BANDWIDTH * Self::TICK_TIME_MS / 1000 / 1400 /*TCP FRAME*/; + // We use integer instead of Barrier to not block mgr from freeing at the end + const BARR_CHANNEL: i32 = 1; + const BARR_RECV: i32 = 4; + const BARR_SEND: i32 = 2; const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS); - //in bit/s const TICK_TIME_MS: u64 = 10; #[allow(clippy::type_complexity)] pub(crate) fn new( remote_pid: Pid, offset_sid: Sid, - runtime: Arc, #[cfg(feature = "metrics")] metrics: Arc, ) -> ( Self, @@ -102,27 +92,15 @@ impl BParticipant { mpsc::UnboundedSender, oneshot::Sender, ) { - let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded_channel::(); + let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::(); let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); - let (b2b_close_stream_opened_sender_s, b2b_close_stream_opened_sender_r) = - oneshot::channel(); - let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel(); let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); - let shutdown_info = RwLock::new(ShutdownInfo { - //a2b_stream_open_r: a2b_stream_open_r.clone(), - b2b_close_stream_opened_sender_s: Some(b2b_close_stream_opened_sender_s), - error: None, - }); - let run_channels = Some(ControlChannels { - a2b_stream_open_r, + a2b_open_stream_r, b2a_stream_opened_s, - b2b_close_stream_opened_sender_r, s2b_create_channel_r, - a2b_close_stream_r, - a2b_close_stream_s, s2b_shutdown_bparticipant_r, }); @@ -131,17 +109,17 @@ impl BParticipant { remote_pid, 
remote_pid_string: remote_pid.to_string(), offset_sid, - runtime, channels: Arc::new(RwLock::new(HashMap::new())), streams: RwLock::new(HashMap::new()), - running_mgr: AtomicUsize::new(0), + shutdown_barrier: AtomicI32::new( + Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV, + ), run_channels, #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), - shutdown_info, }, - a2b_steam_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, @@ -149,693 +127,486 @@ impl BParticipant { } pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender) { - //those managers that listen on api::Participant need an additional oneshot for - // shutdown scenario, those handled by scheduler will be closed by it. - let (shutdown_send_mgr_sender, shutdown_send_mgr_receiver) = oneshot::channel(); - let (shutdown_stream_close_mgr_sender, shutdown_stream_close_mgr_receiver) = - oneshot::channel(); - let (shutdown_open_mgr_sender, shutdown_open_mgr_receiver) = oneshot::channel(); - let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded_channel::(); - let (prios, a2p_msg_s, b2p_notify_empty_stream_s) = PrioManager::new( - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - self.remote_pid_string.clone(), - ); + let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) = + mpsc::unbounded_channel::<(Cid, SendProtocols)>(); + let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) = + mpsc::unbounded_channel::<(Cid, RecvProtocols)>(); + let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) = + async_channel::unbounded::(); + let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = + async_channel::unbounded::(); + + let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); + const STREAM_BOUND: usize = 10_000; + let (a2b_msg_s, a2b_msg_r) = + crossbeam_channel::bounded::<(Sid, Arc)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); 
tokio::join!( - self.open_mgr( - run_channels.a2b_stream_open_r, - run_channels.a2b_close_stream_s.clone(), - a2p_msg_s.clone(), - shutdown_open_mgr_receiver, + self.send_mgr( + run_channels.a2b_open_stream_r, + a2b_close_stream_r, + a2b_msg_r, + b2b_add_send_protocol_r, + b2b_close_send_protocol_r, + b2s_prio_statistic_s, + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.handle_frames_mgr( - w2b_frames_r, + self.recv_mgr( run_channels.b2a_stream_opened_s, - run_channels.b2b_close_stream_opened_sender_r, - run_channels.a2b_close_stream_s, - a2p_msg_s.clone(), + b2b_add_recv_protocol_r, + b2b_force_close_recv_protocol_r, + b2b_close_send_protocol_s.clone(), + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.create_channel_mgr(run_channels.s2b_create_channel_r, w2b_frames_s), - self.send_mgr(prios, shutdown_send_mgr_receiver, b2s_prio_statistic_s), - self.stream_close_mgr( - run_channels.a2b_close_stream_r, - shutdown_stream_close_mgr_receiver, - b2p_notify_empty_stream_s, + self.create_channel_mgr( + run_channels.s2b_create_channel_r, + b2b_add_send_protocol_s, + b2b_add_recv_protocol_s, ), self.participant_shutdown_mgr( run_channels.s2b_shutdown_bparticipant_r, - shutdown_open_mgr_sender, - shutdown_stream_close_mgr_sender, - shutdown_send_mgr_sender, + b2b_close_send_protocol_s.clone(), + b2b_force_close_recv_protocol_s, ), ); } + //TODO: local stream_cid: HashMap to know the respective protocol async fn send_mgr( &self, - mut prios: PrioManager, - mut shutdown_send_mgr_receiver: oneshot::Receiver>, - b2s_prio_statistic_s: mpsc::UnboundedSender, - ) { - //This time equals the MINIMUM Latency in average, so keep it down and //Todo: - // make it configurable or switch to await E.g. 
Prio 0 = await, prio 50 - // wait for more messages - self.running_mgr.fetch_add(1, Ordering::Relaxed); - let mut b2b_prios_flushed_s = None; //closing up - let mut interval = tokio::time::interval(Self::TICK_TIME); - trace!("Start send_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut i: u64 = 0; - loop { - let mut frames = VecDeque::new(); - prios - .fill_frames(Self::FRAMES_PER_TICK as usize, &mut frames) - .await; - let len = frames.len(); - for (_, frame) in frames { - self.send_frame( - frame, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - b2s_prio_statistic_s - .send((self.remote_pid, len as u64, /* */ 0)) - .unwrap(); - interval.tick().await; - i += 1; - if i.rem_euclid(1000) == 0 { - trace!("Did 1000 ticks"); - } - //shutdown after all msg are send! - // Make sure this is called after the API is closed, and all streams are known - // to be droped to the priomgr - if b2b_prios_flushed_s.is_some() && (len == 0) { - break; - } - if b2b_prios_flushed_s.is_none() { - if let Ok(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv() { - b2b_prios_flushed_s = Some(prios_flushed_s); - } - } - } - trace!("Stop send_mgr"); - b2b_prios_flushed_s - .expect("b2b_prios_flushed_s not set") - .send(()) - .unwrap(); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - //returns false if sending isn't possible. In that case we have to render the - // Participant `closed` - #[must_use = "You need to check if the send was successful and report to client!"] - async fn send_frame( - &self, - frame: Frame, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, - ) -> bool { - let mut drop_cid = None; - // TODO: find out ideal channel here - - let res = if let Some(ci) = self.channels.read().await.values().next() { - let ci = ci.lock().await; - //we are increasing metrics without checking the result to please - // borrow_checker. 
otherwise we would need to close `frame` what we - // dont want! - #[cfg(feature = "metrics")] - frames_out_total_cache - .with_label_values(ci.cid, &frame) - .inc(); - if let Err(e) = ci.b2w_frame_s.send(frame) { - let cid = ci.cid; - info!(?e, ?cid, "channel no longer available"); - drop_cid = Some(cid); - false - } else { - true - } - } else { - let mut guard = self.no_channel_error_info.write().await; - let now = Instant::now(); - if now.duration_since(guard.0) > Duration::from_secs(1) { - guard.0 = now; - let occurrences = guard.1 + 1; - guard.1 = 0; - let lastframe = frame; - error!( - ?occurrences, - ?lastframe, - "Participant has no channel to communicate on" - ); - } else { - guard.1 += 1; - } - false - }; - if let Some(cid) = drop_cid { - if let Some(ci) = self.channels.write().await.remove(&cid) { - let ci = ci.into_inner(); - trace!(?cid, "stopping read protocol"); - if let Err(e) = ci.b2r_read_shutdown.send(()) { - trace!(?cid, ?e, "seems like was already shut down"); - } - } - //TODO FIXME tags: takeover channel multiple - info!( - "FIXME: the frame is actually drop. 
which is fine for now as the participant will \ - be closed, but not if we do channel-takeover" - ); - //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) - .await; - }; - res - } - - async fn handle_frames_mgr( - &self, - mut w2b_frames_r: mpsc::UnboundedReceiver, - b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, + mut a2b_open_stream_r: mpsc::UnboundedReceiver, + mut a2b_close_stream_r: mpsc::UnboundedReceiver, + a2b_msg_r: crossbeam_channel::Receiver<(Sid, Arc)>, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, + b2b_close_send_protocol_r: async_channel::Receiver, + _b2s_prio_statistic_s: mpsc::UnboundedSender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start handle_frames_mgr"); - let mut messages = HashMap::new(); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut dropped_instant = Instant::now(); - let mut dropped_cnt = 0u64; - let mut dropped_sid = Sid::new(0); - let mut b2a_stream_opened_s = Some(b2a_stream_opened_s); - let mut b2b_close_stream_opened_sender_r = b2b_close_stream_opened_sender_r.fuse(); + let mut send_protocols: HashMap = HashMap::new(); + let mut interval = tokio::time::interval(Self::TICK_TIME); + let mut stream_ids = self.offset_sid; + trace!("workaround, activly wait for first protocol"); + b2b_add_protocol_r + .recv() + .await + .map(|(c, p)| send_protocols.insert(c, p)); + trace!("Start send_mgr"); + loop { + let (open, close, _, addp, remp) = select!( + next = a2b_open_stream_r.recv().fuse() => (Some(next), None, None, None, None), + next = a2b_close_stream_r.recv().fuse() => (None, Some(next), 
None, None, None), + _ = interval.tick() => (None, None, Some(()), None, None), + next = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(next), None), + next = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(next)), + ); - while let Some((cid, result_frame)) = select!( - next = w2b_frames_r.recv().fuse() => next, - _ = &mut b2b_close_stream_opened_sender_r => { - b2a_stream_opened_s = None; - None - }, - ) { - //trace!(?result_frame, "handling frame"); - let frame = match result_frame { - Ok(frame) => frame, - Err(()) => { - // The read protocol stopped, i need to make sure that write gets stopped, can - // drop channel as it's dead anyway - debug!("read protocol was closed. Stopping channel"); - self.channels.write().await.remove(&cid); + trace!(?open, ?close, ?addp, ?remp, "foobar"); + + addp.flatten().map(|(c, p)| send_protocols.insert(c, p)); + match remp { + Some(Ok(cid)) => { + trace!(?cid, "remove send protocol"); + match send_protocols.remove(&cid) { + Some(mut prot) => { + trace!("blocking flush"); + let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; + trace!("shutdown prot"); + let _ = prot.send(ProtocolEvent::Shutdown).await; + }, + None => trace!("tried to remove protocol twice"), + }; + if send_protocols.is_empty() { + break; + } + }, + _ => (), + }; + + let cid = 0; + let active = match send_protocols.get_mut(&cid) { + Some(a) => a, + None => { + warn!("no channel arrg"); continue; }, }; - #[cfg(feature = "metrics")] - { - let cid_string = cid.to_string(); - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - } - match frame { - Frame::OpenStream { - sid, - prio, - promises, - } => { - trace!(?sid, ?prio, ?promises, "Opened frame from remote"); - let a2p_msg_s = a2p_msg_s.clone(); + + let active_err = async { + if let Some(Some((prio, promises, guaranteed_bandwidth, return_s))) = open { + trace!(?stream_ids, "openuing some new stream"); + let sid = 
stream_ids; + stream_ids += Sid::from(1); let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) .await; - match &b2a_stream_opened_s { - None => debug!("dropping openStream as Channel is already closing"), - Some(s) => { - if let Err(e) = s.send(stream) { - warn!( - ?e, - ?sid, - "couldn't notify api::Participant that a stream got opened. \ - Is the participant already dropped?" - ); - } - }, - } - }, - Frame::CloseStream { sid } => { - // no need to keep flushing as the remote no longer knows about this stream - // anyway - self.delete_stream( - sid, - None, - true, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - }, - Frame::DataHeader { mid, sid, length } => { - let imsg = IncomingMessage { - buffer: MessageBuffer { data: Vec::new() }, - length, - mid, + + let event = ProtocolEvent::OpenStream { sid, + prio, + promises, + guaranteed_bandwidth, }; - messages.insert(mid, imsg); - }, - Frame::Data { - mid, - start: _, - mut data, - } => { - let finished = if let Some(imsg) = messages.get_mut(&mid) { - imsg.buffer.data.append(&mut data); - imsg.buffer.data.len() as u64 == imsg.length - } else { - false - }; - if finished { - //trace!(?mid, "finished receiving message"); - let imsg = messages.remove(&mid).unwrap(); - if let Some(si) = self.streams.read().await.get(&imsg.sid) { - if let Err(e) = si.b2a_msg_recv_s.lock().await.send(imsg).await { - warn!( - ?e, - ?mid, - "Dropping message, as streams seem to be in act of being \ - dropped right now" - ); - } - } else { - //aggregate errors - let n = Instant::now(); - if dropped_cnt > 0 - && (dropped_sid != imsg.sid - || n.duration_since(dropped_instant) > Duration::from_secs(1)) - { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to \ - exist because it was dropped probably." 
- ); - dropped_cnt = 0; - dropped_instant = n; - dropped_sid = imsg.sid; - } else { - dropped_cnt += 1; - } - } - } - }, - Frame::Shutdown => { - debug!("Shutdown received from remote side"); - self.close_api(Some(ParticipantError::ParticipantDisconnected)) - .await; - }, - f => { - unreachable!( - "Frame should never reach participant!: {:?}, cid: {}", - f, cid - ); - }, + + return_s.send(stream).unwrap(); + active.send(event).await?; + } + + // get all messages and assign it to a channel + for (sid, buffer) in a2b_msg_r.try_iter() { + warn!(?sid, "sending!"); + active + .send(ProtocolEvent::Message { + buffer, + mid: 0u64, + sid, + }) + .await? + } + + if let Some(Some(sid)) = close { + warn!(?sid, "delete_stream!"); + self.delete_stream(sid).await; + // Fire&Forget the protocol will take care to verify that this Frame is delayed + // till the last msg was received! + active.send(ProtocolEvent::CloseStream { sid }).await?; + } + + warn!("flush!"); + active + .flush(1_000_000, Duration::from_secs(1) /* TODO */) + .await?; //this actually blocks, so we cant set streams whilte it. + let r: Result<(), network_protocol::ProtocolError> = Ok(()); + r + } + .await; + if let Err(e) = active_err { + info!(?cid, ?e, "send protocol failed, shutting down channel"); + // remote recv will now fail, which will trigger remote send which will trigger + // recv + send_protocols.remove(&cid).unwrap(); } } - if dropped_cnt > 0 { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to exist because it was \ - dropped probably." 
+ trace!("Stop send_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_SEND, Ordering::Relaxed); + } + + async fn recv_mgr( + &self, + b2a_stream_opened_s: mpsc::UnboundedSender, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, + b2b_force_close_recv_protocol_r: async_channel::Receiver, + b2b_close_send_protocol_s: async_channel::Sender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + a2b_close_stream_s: mpsc::UnboundedSender, + ) { + let mut recv_protocols: HashMap> = HashMap::new(); + // we should be able to directly await futures imo + let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel(); + + let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| { + let hacky_recv_s = hacky_recv_s.clone(); + let handle = tokio::spawn(async move { + let cid = cid; + let r = p.recv().await; + let _ = hacky_recv_s.send((cid, r, p)); // ignoring failed + }); + map.insert(cid, handle); + }; + + let remove_c = |recv_protocols: &mut HashMap>, cid: &Cid| { + match recv_protocols.remove(&cid) { + Some(h) => h.abort(), + None => trace!("tried to remove protocol twice"), + }; + recv_protocols.is_empty() + }; + + trace!("Start recv_mgr"); + loop { + let (event, addp, remp) = select!( + next = hacky_recv_r.recv().fuse() => (Some(next), None, None), + Some(next) = b2b_add_protocol_r.recv().fuse() => (None, Some(next), None), + next = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(next)), ); + + addp.map(|(cid, p)| { + retrigger(cid, p, &mut recv_protocols); + }); + if let Some(Ok(cid)) = remp { + // no need to stop the send_mgr here as it has been canceled before + if remove_c(&mut recv_protocols, &cid) { + break; + } + }; + + warn!(?event, "recv event!"); + if let Some(Some((cid, r, p))) = event { + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + }) => { + trace!(?sid, "open stream"); + let stream = self + .create_stream( + sid, + prio, + promises, + 
guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) + .await; + b2a_stream_opened_s.send(stream).unwrap(); + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::CloseStream { sid }) => { + trace!(?sid, "close stream"); + self.delete_stream(sid).await; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Message { + buffer, + mid: _, + sid, + }) => { + let buffer = Arc::try_unwrap(buffer).unwrap(); + let lock = self.streams.read().await; + match lock.get(&sid) { + Some(stream) => { + stream + .b2a_msg_recv_s + .lock() + .await + .send(buffer) + .await + .unwrap(); + }, + None => warn!("recv a msg with orphan stream"), + }; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Shutdown) => { + info!(?cid, "shutdown protocol"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + Err(e) => { + info!(?cid, ?e, "recv protocol failed, shutting down channel"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + } + } } - trace!("Stop handle_frames_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + + trace!("Stop recv_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_RECV, Ordering::Relaxed); } async fn create_channel_mgr( &self, s2b_create_channel_r: mpsc::UnboundedReceiver, - w2b_frames_s: mpsc::UnboundedSender, + b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, + b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); trace!("Start create_channel_mgr"); let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r - .for_each_concurrent( - None, - |(cid, _, protocol, leftover_cid_frame, 
b2s_create_channel_done_s)| { - // This channel is now configured, and we are running it in scope of the - // participant. - let w2b_frames_s = w2b_frames_s.clone(); - let channels = Arc::clone(&self.channels); - async move { - let (channel, b2w_frame_s, b2r_read_shutdown) = Channel::new(cid); - let mut lock = channels.write().await; - #[cfg(feature = "metrics")] - let mut channel_no = lock.len(); - #[cfg(not(feature = "metrics"))] - let channel_no = lock.len(); - lock.insert( + .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { + // This channel is now configured, and we are running it in scope of the + // participant. + //let w2b_frames_s = w2b_frames_s.clone(); + let channels = Arc::clone(&self.channels); + let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); + let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + async move { + let mut lock = channels.write().await; + #[cfg(feature = "metrics")] + let mut channel_no = lock.len(); + lock.insert( + cid, + Mutex::new(ChannelInfo { cid, - Mutex::new(ChannelInfo { - cid, - cid_string: cid.to_string(), - b2w_frame_s, - b2r_read_shutdown, - }), - ); - drop(lock); - b2s_create_channel_done_s.send(()).unwrap(); - #[cfg(feature = "metrics")] - { - self.metrics - .channels_connected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; - } - self.metrics - .participants_channel_ids - .with_label_values(&[ - &self.remote_pid_string, - &channel_no.to_string(), - ]) - .set(cid as i64); - } - trace!(?cid, ?channel_no, "Running channel in participant"); - channel - .run(protocol, w2b_frames_s, leftover_cid_frame) - .instrument(tracing::info_span!("", ?cid)) - .await; - #[cfg(feature = "metrics")] + cid_string: cid.to_string(), + }), + ); + drop(lock); + let (send, recv) = protocol.split(); + b2b_add_send_protocol_s.send((cid, send)).unwrap(); + 
b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + b2s_create_channel_done_s.send(()).unwrap(); + #[cfg(feature = "metrics")] + { self.metrics - .channels_disconnected_total + .channels_connected_total .with_label_values(&[&self.remote_pid_string]) .inc(); - info!(?cid, "Channel got closed"); - //maybe channel got already dropped, we don't know. - channels.write().await.remove(&cid); - trace!(?cid, "Channel cleanup completed"); - //TEMP FIX: as we dont have channel takeover yet drop the whole - // bParticipant - self.close_write_api(None).await; + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; + } + self.metrics + .participants_channel_ids + .with_label_values(&[&self.remote_pid_string, &channel_no.to_string()]) + .set(cid as i64); } - }, - ) + } + }) .await; trace!("Stop create_channel_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + self.shutdown_barrier + .fetch_sub(Self::BARR_CHANNEL, Ordering::Relaxed); } - async fn open_mgr( - &self, - mut a2b_stream_open_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - shutdown_open_mgr_receiver: oneshot::Receiver<()>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start open_mgr"); - let mut stream_ids = self.offset_sid; - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_open_mgr_receiver = shutdown_open_mgr_receiver.fuse(); - //from api or shutdown signal - while let Some((prio, promises, p2a_return_stream)) = select! { - next = a2b_stream_open_r.recv().fuse() => next, - _ = &mut shutdown_open_mgr_receiver => None, - } { - debug!(?prio, ?promises, "Got request to open a new steam"); - //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. 
- //till then just check here if we are closed and in that case do nothing (not - // even answer) - if self.shutdown_info.read().await.error.is_some() { - continue; - } - - let a2p_msg_s = a2p_msg_s.clone(); - let sid = stream_ids; - let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) - .await; - if self - .send_frame( - Frame::OpenStream { - sid, - prio, - promises, - }, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - //On error, we drop this, so it gets closed and client will handle this as an - // Err any way (: - p2a_return_stream.send(stream).unwrap(); - stream_ids += Sid::from(1); - } - } - trace!("Stop open_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - /// when activated this function will drop the participant completely and - /// wait for everything to go right! Then return 1. Shutting down - /// Streams for API and End user! 2. Wait for all "prio queued" Messages - /// to be send. 3. Send Stream + /// sink shutdown: + /// Situation AS, AR, BS, BR. A wants to close. + /// AS shutdown. + /// BR notices shutdown and tries to stops BS. (success) + /// BS shutdown + /// AR notices shutdown and tries to stop AS. (fails) + /// For the case where BS didn't get shutdowned, e.g. by a handing situation + /// on the remote, we have a timeout to also force close AR. + /// + /// This fn will: + /// - 1. stop api to interact with bparticipant by closing sendmsg and + /// openstream + /// - 2. stop the send_mgr (it will take care of clearing the + /// queue and finish with a Shutdown) + /// - (3). force stop recv after 60 + /// seconds + /// - (4). this fn finishes last and afterwards BParticipant + /// drops + /// + /// before calling this fn, make sure `s2b_create_channel` is closed! 
/// If BParticipant kills itself managers stay active till this function is /// called by api to get the result status async fn participant_shutdown_mgr( &self, s2b_shutdown_bparticipant_r: oneshot::Receiver, - shutdown_open_mgr_sender: oneshot::Sender<()>, - shutdown_stream_close_mgr_sender: oneshot::Sender>, - shutdown_send_mgr_sender: oneshot::Sender>, + b2b_close_send_protocol_s: async_channel::Sender, + b2b_force_close_recv_protocol_s: async_channel::Sender, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); + let wait_for_manager = || async { + let mut sleep = 0.01f64; + loop { + let bytes = self.shutdown_barrier.load(Ordering::Relaxed); + if bytes == 0 { + break; + } + sleep *= 1.4; + tokio::time::sleep(Duration::from_secs_f64(sleep)).await; + if sleep > 0.2 { + trace!(?bytes, "wait for mgr to close"); + } + } + }; + trace!("Start participant_shutdown_mgr"); - let sender = s2b_shutdown_bparticipant_r.await.unwrap(); + let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); + debug!("participant_shutdown_mgr triggered"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - - self.close_api(None).await; - - debug!("Closing all managers"); - shutdown_open_mgr_sender - .send(()) - .expect("open_mgr must have crashed before"); - let (b2b_stream_close_shutdown_confirmed_s, b2b_stream_close_shutdown_confirmed_r) = - oneshot::channel(); - shutdown_stream_close_mgr_sender - .send(b2b_stream_close_shutdown_confirmed_s) - .expect("stream_close_mgr must have crashed before"); - // We need to wait for the stream_close_mgr BEFORE send_mgr, as the - // stream_close_mgr needs to wait on the API to drop `Stream` and be triggered - // It will then sleep for streams to be flushed in PRIO, and send_mgr is - // responsible for ticking PRIO WHILE this happens, so we cant close it before! 
- b2b_stream_close_shutdown_confirmed_r.await.unwrap(); - - //closing send_mgr now: - let (b2b_prios_flushed_s, b2b_prios_flushed_r) = oneshot::channel(); - shutdown_send_mgr_sender - .send(b2b_prios_flushed_s) - .expect("stream_close_mgr must have crashed before"); - b2b_prios_flushed_r.await.unwrap(); - - if Some(ParticipantError::ParticipantDisconnected) != self.shutdown_info.read().await.error + debug!("Closing all streams for send"); { - debug!("Sending shutdown frame after flushed all prios"); - if !self - .send_frame( - Frame::Shutdown, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - warn!("couldn't send shutdown frame, are channels already closed?"); + let lock = self.streams.read().await; + for si in lock.values() { + si.send_closed.store(true, Ordering::Relaxed); } } - debug!("Closing all channels, after flushed prios"); - for (cid, ci) in self.channels.write().await.drain() { - let ci = ci.into_inner(); - if let Err(e) = ci.b2r_read_shutdown.send(()) { + let lock = self.channels.read().await; + assert!( + !lock.is_empty(), + "no channel existed remote_pid={}", + self.remote_pid + ); + for cid in lock.keys() { + if let Err(e) = b2b_close_send_protocol_s.send(*cid).await { debug!( ?e, ?cid, - "Seems like this read protocol got already dropped by closing the Stream \ - itself, ignoring" - ); - }; - } - - //Wait for other bparticipants mgr to close via AtomicUsize - const SLEEP_TIME: Duration = Duration::from_millis(5); - const ALLOWED_MANAGER: usize = 1; - tokio::time::sleep(SLEEP_TIME).await; - let mut i: u32 = 1; - while self.running_mgr.load(Ordering::Relaxed) > ALLOWED_MANAGER { - i += 1; - if i.rem_euclid(10) == 1 { - trace!( - ?ALLOWED_MANAGER, - "Waiting for bparticipant mgr to shut down, remaining {}", - self.running_mgr.load(Ordering::Relaxed) - ALLOWED_MANAGER + "closing send_mgr may fail if we got a recv error simultaneously" ); } - tokio::time::sleep(SLEEP_TIME * i).await; } - trace!("All BParticipant mgr (except me) are shut 
down now"); + drop(lock); + + trace!("wait for other managers"); + let timeout = tokio::time::sleep(timeout_time); + let timeout = tokio::select! { + _ = wait_for_manager() => false, + _ = timeout => true, + }; + if timeout { + warn!("timeout triggered: for killing recv"); + let lock = self.channels.read().await; + for cid in lock.keys() { + if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await { + debug!( + ?e, + ?cid, + "closing recv_mgr may fail if we got a recv error simultaneously" + ); + } + } + } + + trace!("wait again"); + wait_for_manager().await; + + sender.send(Ok(())).unwrap(); #[cfg(feature = "metrics")] self.metrics.participants_disconnected_total.inc(); - debug!("BParticipant close done"); - - let mut lock = self.shutdown_info.write().await; - sender - .send(match lock.error.take() { - None => Ok(()), - Some(ParticipantError::ProtocolFailedUnrecoverable) => { - Err(ParticipantError::ProtocolFailedUnrecoverable) - }, - Some(ParticipantError::ParticipantDisconnected) => Ok(()), - }) - .unwrap(); - trace!("Stop participant_shutdown_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - async fn stream_close_mgr( - &self, - mut a2b_close_stream_r: mpsc::UnboundedReceiver, - shutdown_stream_close_mgr_receiver: oneshot::Receiver>, - b2p_notify_empty_stream_s: crossbeam_channel::Sender<(Sid, oneshot::Sender<()>)>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start stream_close_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_stream_close_mgr_receiver = shutdown_stream_close_mgr_receiver.fuse(); - let mut b2b_stream_close_shutdown_confirmed_s = None; - - //from api or shutdown signal - while let Some(sid) = select! 
{ - next = a2b_close_stream_r.recv().fuse() => next, - sender = &mut shutdown_stream_close_mgr_receiver => { - b2b_stream_close_shutdown_confirmed_s = Some(sender.unwrap()); - None - } - } { - //TODO: make this concurrent! - //TODO: Performance, closing is slow! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - trace!("deleting all leftover streams"); - let sids = self - .streams - .read() - .await - .keys() - .cloned() - .collect::>(); - for sid in sids { - //flushing is still important, e.g. when Participant::drop is called (but - // Stream:drop isn't)! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - if b2b_stream_close_shutdown_confirmed_s.is_none() { - b2b_stream_close_shutdown_confirmed_s = - Some(shutdown_stream_close_mgr_receiver.await.unwrap()); - } - b2b_stream_close_shutdown_confirmed_s - .unwrap() - .send(()) - .unwrap(); - trace!("Stop stream_close_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); } + /// Stopping API and participant usage + /// Protocol will take care of the order of the frame async fn delete_stream( &self, sid: Sid, - b2p_notify_empty_stream_s: Option)>>, - from_remote: bool, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, + /* #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, */ ) { - //This needs to first stop clients from sending any more. - //Then it will wait for all pending messages (in prio) to be send to the - // protocol After this happened the stream is closed - //Only after all messages are send to the protocol, we can send the CloseStream - // frame! If we would send it before, all followup messages couldn't - // be handled at the remote side. 
- async { - trace!("Stopping api to use this stream"); - match self.streams.read().await.get(&sid) { - Some(si) => { - si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.lock().await.close(); - }, - None => trace!( - "Couldn't find the stream, might be simultaneous close from local/remote" - ), - } - - if !from_remote { - trace!("Wait for stream to be flushed"); - let (s2b_stream_finished_closed_s, s2b_stream_finished_closed_r) = - oneshot::channel(); - b2p_notify_empty_stream_s - .expect("needs to be set when from_remote is false") - .send((sid, s2b_stream_finished_closed_s)) - .unwrap(); - s2b_stream_finished_closed_r.await.unwrap(); - - trace!("Stream was successfully flushed"); - } - - #[cfg(feature = "metrics")] - self.metrics - .streams_closed_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - //only now remove the Stream, that means we can still recv on it. - self.streams.write().await.remove(&sid); - - if !from_remote { - self.send_frame( - Frame::CloseStream { sid }, - #[cfg(feature = "metrics")] - frames_out_total_cache, - ) - .await; - } + let stream = { self.streams.write().await.remove(&sid) }; + match stream { + Some(si) => { + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.lock().await.close(); + }, + None => { + trace!("Couldn't find the stream, might be simultaneous close from local/remote") + }, } - .instrument(tracing::info_span!("close", ?sid, ?from_remote)) - .await; + /* + #[cfg(feature = "metrics")] + self.metrics + .streams_closed_total + .with_label_values(&[&self.remote_pid_string]) + .inc();*/ } async fn create_stream( @@ -843,10 +614,11 @@ impl BParticipant { sid: Sid, prio: Prio, promises: Promises, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, + guaranteed_bandwidth: Bandwidth, + a2b_msg_s: &crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); + 
let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -864,38 +636,248 @@ impl BParticipant { sid, prio, promises, + guaranteed_bandwidth, send_closed, - a2p_msg_s, + a2b_msg_s.clone(), b2a_msg_recv_r, a2b_close_stream_s.clone(), ) } +} - async fn close_write_api(&self, reason: Option) { - trace!(?reason, "close_api"); - let mut lock = self.shutdown_info.write().await; - if let Some(r) = reason { - lock.error = Some(r); - } - lock.b2b_close_stream_opened_sender_s - .take() - .map(|s| s.send(())); +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + runtime::Runtime, + sync::{mpsc, oneshot}, + task::JoinHandle, + }; - debug!("Closing all streams for write"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream for write"); - si.send_closed.store(true, Ordering::Relaxed); - } + fn mock_bparticipant() -> ( + Arc, + mpsc::UnboundedSender, + mpsc::UnboundedReceiver, + mpsc::UnboundedSender, + oneshot::Sender, + mpsc::UnboundedReceiver, + JoinHandle<()>, + ) { + let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap()); + let runtime_clone = Arc::clone(&runtime); + + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); + + let ( + bparticipant, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + ) = runtime_clone.block_on(async move { + let pid = Pid::fake(1); + let sid = Sid::new(1000); + let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); + + BParticipant::new(pid, sid, Arc::clone(&metrics)) + }); + + let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s)); + ( + runtime_clone, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) } - ///closing api::Participant is done by closing all channels, expect for the - 
/// shutdown channel at this point! - async fn close_api(&self, reason: Option) { - self.close_write_api(reason).await; - debug!("Closing all streams"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.lock().await.close(); - } + async fn mock_mpsc( + cid: Cid, + _runtime: &Arc, + create_channel: &mut mpsc::UnboundedSender, + ) -> Protocols { + let (s1, r1) = mpsc::channel(100); + let (s2, r2) = mpsc::channel(100); + let p1 = Protocols::new_mpsc(s1, r2); + let (complete_s, complete_r) = oneshot::channel(); + create_channel + .send((cid, Sid::new(0), p1, complete_s)) + .unwrap(); + complete_r.await.unwrap(); + Protocols::new_mpsc(s2, r1) + } + + #[test] + fn close_bparticipant_by_timeout_during_close() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() > Duration::from_millis(900), + "timeout wasn't triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn close_bparticipant_cleanly() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); 
+ runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(2), s)) + .unwrap(); + drop(remote); // remote needs to be dropped as soon as local.sender is closed + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() < Duration::from_millis(1900), + "timeout was triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn create_stream() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // created stream + let (rs, mut rr) = remote.split(); + let (stream_sender, _stream_receiver) = oneshot::channel(); + a2b_open_stream_s + .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender)) + .unwrap(); + + let stream_event = runtime.block_on(rr.recv()).unwrap(); + match stream_event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + assert_eq!(sid, Sid::new(1000)); + assert_eq!(prio, 7u8); + assert_eq!(promises, Promises::ENCRYPTED); + assert_eq!(guaranteed_bandwidth, 1_000_000); + }, + _ => panic!("wrong event"), + }; + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn created_stream() { + let ( + runtime, + a2b_open_stream_s, + mut b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + 
+ let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // create stream + let (mut rs, rr) = remote.split(); + runtime + .block_on(rs.send(ProtocolEvent::OpenStream { + sid: Sid::new(1000), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + })) + .unwrap(); + + let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap(); + assert_eq!(stream.promises(), Promises::ORDERED); + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); } } diff --git a/network/src/prios.rs b/network/src/prios.rs deleted file mode 100644 index a544a31241..0000000000 --- a/network/src/prios.rs +++ /dev/null @@ -1,697 +0,0 @@ -//!Priorities are handled the following way. -//!Prios from 0-63 are allowed. -//!all 5 numbers the throughput is halved. -//!E.g. in the same time 100 prio0 messages are send, only 50 prio5, 25 prio10, -//! 12 prio15 or 6 prio20 messages are send. Note: TODO: prio0 will be send -//! immediately when found! 
-#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - message::OutgoingMessage, - types::{Frame, Prio, Sid}, -}; -use crossbeam_channel::{unbounded, Receiver, Sender}; -use std::collections::{HashMap, HashSet, VecDeque}; -#[cfg(feature = "metrics")] use std::sync::Arc; -use tokio::sync::oneshot; -use tracing::trace; - -const PRIO_MAX: usize = 64; - -#[derive(Default)] -struct PidSidInfo { - len: u64, - empty_notify: Option>, -} - -pub(crate) struct PrioManager { - points: [u32; PRIO_MAX], - messages: [VecDeque<(Sid, OutgoingMessage)>; PRIO_MAX], - messages_rx: Receiver<(Prio, Sid, OutgoingMessage)>, - sid_owned: HashMap, - //you can register to be notified if a pid_sid combination is flushed completely here - sid_flushed_rx: Receiver<(Sid, oneshot::Sender<()>)>, - queued: HashSet, - #[cfg(feature = "metrics")] - metrics: Arc, - #[cfg(feature = "metrics")] - pid: String, -} - -impl PrioManager { - const PRIOS: [u32; PRIO_MAX] = [ - 100, 115, 132, 152, 174, 200, 230, 264, 303, 348, 400, 459, 528, 606, 696, 800, 919, 1056, - 1213, 1393, 1600, 1838, 2111, 2425, 2786, 3200, 3676, 4222, 4850, 5572, 6400, 7352, 8445, - 9701, 11143, 12800, 14703, 16890, 19401, 22286, 25600, 29407, 33779, 38802, 44572, 51200, - 58813, 67559, 77605, 89144, 102400, 117627, 135118, 155209, 178289, 204800, 235253, 270235, - 310419, 356578, 409600, 470507, 540470, 620838, - ]; - - #[allow(clippy::type_complexity)] - pub fn new( - #[cfg(feature = "metrics")] metrics: Arc, - pid: String, - ) -> ( - Self, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - #[cfg(not(feature = "metrics"))] - let _pid = pid; - // (a2p_msg_s, a2p_msg_r) - let (messages_tx, messages_rx) = unbounded(); - let (sid_flushed_tx, sid_flushed_rx) = unbounded(); - ( - Self { - points: [0; PRIO_MAX], - messages: [ - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - 
VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - ], - messages_rx, - queued: HashSet::new(), //TODO: optimize with u64 and 64 bits - sid_flushed_rx, - sid_owned: HashMap::new(), - #[cfg(feature = "metrics")] - metrics, - #[cfg(feature = "metrics")] - pid, - }, - messages_tx, - sid_flushed_tx, - ) - } - - async fn tick(&mut self) { - // Check Range - for (prio, sid, msg) in self.messages_rx.try_iter() { - debug_assert!(prio as usize <= PRIO_MAX); - #[cfg(feature = "metrics")] - { - let sid_string = sid.to_string(); - self.metrics - .message_out_total - .with_label_values(&[&self.pid, &sid_string]) - .inc(); - self.metrics - .message_out_throughput - .with_label_values(&[&self.pid, &sid_string]) - .inc_by(msg.buffer.data.len() as u64); - } - - //trace!(?prio, ?sid_string, "tick"); - self.queued.insert(prio); - self.messages[prio as usize].push_back((sid, msg)); - self.sid_owned.entry(sid).or_default().len += 1; - } - //this must be AFTER messages - for (sid, 
return_sender) in self.sid_flushed_rx.try_iter() { - #[cfg(feature = "metrics")] - self.metrics - .streams_flushed - .with_label_values(&[&self.pid]) - .inc(); - if let Some(cnt) = self.sid_owned.get_mut(&sid) { - // register sender - cnt.empty_notify = Some(return_sender); - trace!(?sid, "register empty notify"); - } else { - // return immediately - return_sender.send(()).unwrap(); - trace!(?sid, "return immediately that stream is empty"); - } - } - } - - //if None returned, we are empty! - fn calc_next_prio(&self) -> Option { - // compare all queued prios, max 64 operations - let mut lowest = std::u32::MAX; - let mut lowest_id = None; - for &n in &self.queued { - let n_points = self.points[n as usize]; - if n_points < lowest { - lowest = n_points; - lowest_id = Some(n) - } else if n_points == lowest && lowest_id.is_some() && n < lowest_id.unwrap() { - //on equal points lowest first! - lowest_id = Some(n) - } - } - lowest_id - /* - self.queued - .iter() - .min_by_key(|&n| self.points[*n as usize]).cloned()*/ - } - - /// no_of_frames = frames.len() - /// Your goal is to try to find a realistic no_of_frames! - /// no_of_frames should be choosen so, that all Frames can be send out till - /// the next tick! - /// - if no_of_frames is too high you will fill either the Socket buffer, - /// or your internal buffer. In that case you will increase latency for - /// high prio messages! 
- /// - if no_of_frames is too low you wont saturate your Socket fully, thus - /// have a lower bandwidth as possible - pub async fn fill_frames>( - &mut self, - no_of_frames: usize, - frames: &mut E, - ) { - for v in self.messages.iter_mut() { - v.reserve_exact(no_of_frames) - } - self.tick().await; - for _ in 0..no_of_frames { - match self.calc_next_prio() { - Some(prio) => { - //let prio2 = self.calc_next_prio().unwrap(); - //trace!(?prio, "handle next prio"); - self.points[prio as usize] += Self::PRIOS[prio as usize]; - //pop message from front of VecDeque, handle it and push it back, so that all - // => messages with same prio get a fair chance :) - //TODO: evaluate not popping every time - let (sid, mut msg) = self.messages[prio as usize].pop_front().unwrap(); - if msg.fill_next(sid, frames) { - //trace!(?m.mid, "finish message"); - //check if prio is empty - if self.messages[prio as usize].is_empty() { - self.queued.remove(&prio); - } - //decrease pid_sid counter by 1 again - let cnt = self.sid_owned.get_mut(&sid).expect( - "The pid_sid_owned counter works wrong, more pid,sid removed than \ - inserted", - ); - cnt.len -= 1; - if cnt.len == 0 { - let cnt = self.sid_owned.remove(&sid).unwrap(); - if let Some(empty_notify) = cnt.empty_notify { - empty_notify.send(()).unwrap(); - trace!(?sid, "returned that stream is empty"); - } - } - } else { - self.messages[prio as usize].push_front((sid, msg)); - } - }, - None => { - //QUEUE is empty, we are clearing the POINTS to not build up huge pipes of - // POINTS on a prio from the past - self.points = [0; PRIO_MAX]; - break; - }, - } - } - } -} - -impl std::fmt::Debug for PrioManager { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut cnt = 0; - for m in self.messages.iter() { - cnt += m.len(); - } - write!(f, "PrioManager(len: {}, queued: {:?})", cnt, &self.queued,) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - message::{MessageBuffer, OutgoingMessage}, - 
metrics::NetworkMetrics, - prios::*, - types::{Frame, Pid, Prio, Sid}, - }; - use crossbeam_channel::Sender; - use std::{collections::VecDeque, sync::Arc}; - use tokio::{runtime::Runtime, sync::oneshot}; - - const SIZE: u64 = OutgoingMessage::FRAME_DATA_SIZE; - const USIZE: usize = OutgoingMessage::FRAME_DATA_SIZE as usize; - - #[allow(clippy::type_complexity)] - fn mock_new() -> ( - PrioManager, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - let pid = Pid::fake(1); - PrioManager::new( - Arc::new(NetworkMetrics::new(&pid).unwrap()), - pid.to_string(), - ) - } - - fn mock_out(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { - data: vec![48, 49, 50], - }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn mock_out_large(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - let mut data = vec![48; USIZE]; - data.append(&mut vec![49; USIZE]); - data.append(&mut vec![50; 20]); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn assert_header(frames: &mut VecDeque<(Sid, Frame)>, f_sid: u64, f_length: u64) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::DataHeader { mid, sid, length } = frame { - assert_eq!(mid, 1); - assert_eq!(sid, Sid::new(f_sid)); - assert_eq!(length, f_length); - } else { - panic!("Wrong frame type!, expected DataHeader"); - } - } - - fn assert_data(frames: &mut VecDeque<(Sid, Frame)>, f_start: u64, f_data: Vec) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::Data { mid, start, data } = frame { - assert_eq!(mid, 1); - assert_eq!(start, f_start); - assert_eq!(data, f_data); - } else { - panic!("Wrong frame type!, expected Data"); - } - } - - #[test] - fn single_p16() { - let 
(mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - msg_tx.send(mock_out(20, 42)).unwrap(); - let mut frames = VecDeque::new(); - - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 42)).unwrap(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - 
.unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - for i in 1..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn multiple_fill_frames_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(3, &mut frames)); - for i in 1..4 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(11, &mut frames)); - for i in 4..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn single_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = 
VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16_sudden_p0() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - - msg_tx.send(mock_out(0, 3)).unwrap(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 3, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_thousand_p16_at_once() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 
3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn single_p20_thousand_p16_later() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - //^unimportant frames, gonna be dropped - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - //important in that test is, that after the first frames got cleared i reset - // the Points even though 998 prio 16 messages have been send at this - // point and 0 prio20 messages the next message is a prio16 message - // again, and only then prio20! 
we dont want to build dept over a idling - // connection - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn gigantic_message() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } - - #[test] - fn gigantic_message_order() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 
5600, vec![5; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - - #[test] - fn gigantic_message_order_other_prio() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(20, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } -} diff --git a/network/src/protocols.rs b/network/src/protocols.rs deleted file mode 100644 index a18c1e1cbd..0000000000 --- a/network/src/protocols.rs +++ /dev/null @@ -1,591 +0,0 @@ -#[cfg(feature = "metrics")] -use crate::metrics::{CidFrameCache, NetworkMetrics}; -use crate::{ - participant::C2pFrame, - types::{Cid, Frame}, -}; -use futures_util::{future::Fuse, FutureExt}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpStream, UdpSocket}, - select, - sync::{mpsc, oneshot, Mutex}, -}; - -use std::{convert::TryFrom, net::SocketAddr, sync::Arc}; -use tracing::*; - -// Reserving bytes 0, 10, 13 as i have enough space and want to make it easy to -// detect a invalid client, e.g. 
sending an empty line would make 10 first char -// const FRAME_RESERVED_1: u8 = 0; -const FRAME_HANDSHAKE: u8 = 1; -const FRAME_INIT: u8 = 2; -const FRAME_SHUTDOWN: u8 = 3; -const FRAME_OPEN_STREAM: u8 = 4; -const FRAME_CLOSE_STREAM: u8 = 5; -const FRAME_DATA_HEADER: u8 = 6; -const FRAME_DATA: u8 = 7; -const FRAME_RAW: u8 = 8; -//const FRAME_RESERVED_2: u8 = 10; -//const FRAME_RESERVED_3: u8 = 13; - -#[derive(Debug)] -pub(crate) enum Protocols { - Tcp(TcpProtocol), - Udp(UdpProtocol), - //Mpsc(MpscChannel), -} - -#[derive(Debug)] -pub(crate) struct TcpProtocol { - read_stream: tokio::sync::Mutex, - write_stream: tokio::sync::Mutex, - #[cfg(feature = "metrics")] - metrics: Arc, -} - -#[derive(Debug)] -pub(crate) struct UdpProtocol { - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] - metrics: Arc, - data_in: Mutex>>, -} - -//TODO: PERFORMACE: Use BufWriter and BufReader from std::io! -impl TcpProtocol { - pub(crate) fn new( - stream: TcpStream, - #[cfg(feature = "metrics")] metrics: Arc, - ) -> Self { - let (read_stream, write_stream) = stream.into_split(); - Self { - read_stream: tokio::sync::Mutex::new(read_stream), - write_stream: tokio::sync::Mutex::new(write_stream), - #[cfg(feature = "metrics")] - metrics, - } - } - - async fn read_frame( - r: &mut R, - end_receiver: &mut Fuse>, - ) -> Result> { - let handle = |read_result| match read_result { - Ok(_) => Ok(()), - Err(e) => Err(Some(e)), - }; - - let mut frame_no = [0u8; 1]; - match select! 
{ - r = r.read_exact(&mut frame_no).fuse() => Some(r), - _ = end_receiver => None, - } { - Some(read_result) => handle(read_result)?, - None => { - trace!("shutdown requested"); - return Err(None); - }, - }; - - match frame_no[0] { - FRAME_HANDSHAKE => { - let mut bytes = [0u8; 19]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_handshake(bytes)) - }, - FRAME_INIT => { - let mut bytes = [0u8; 32]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_init(bytes)) - }, - FRAME_SHUTDOWN => Ok(Frame::Shutdown), - FRAME_OPEN_STREAM => { - let mut bytes = [0u8; 10]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_open_stream(bytes)) - }, - FRAME_CLOSE_STREAM => { - let mut bytes = [0u8; 8]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_close_stream(bytes)) - }, - FRAME_DATA_HEADER => { - let mut bytes = [0u8; 24]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_data_header(bytes)) - }, - FRAME_DATA => { - let mut bytes = [0u8; 18]; - handle(r.read_exact(&mut bytes).await)?; - let (mid, start, length) = Frame::gen_data(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Data { mid, start, data }) - }, - FRAME_RAW => { - let mut bytes = [0u8; 2]; - handle(r.read_exact(&mut bytes).await)?; - let length = Frame::gen_raw(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Raw(data)) - }, - other => { - // report a RAW frame, but cannot rely on the next 2 bytes to be a size. - // guessing 32 bytes, which might help to sort down issues - let mut data = vec![0; 32]; - //keep the first byte! 
- match r.read(&mut data[1..]).await { - Ok(n) => { - data.truncate(n + 1); - Ok(()) - }, - Err(e) => Err(Some(e)), - }?; - data[0] = other; - warn!(?data, "got a unexpected RAW msg"); - Ok(Frame::Raw(data)) - }, - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up tcp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut read_stream = self.read_stream.lock().await; - let mut end_r = end_r.fuse(); - - loop { - match Self::read_frame(&mut *read_stream, &mut end_r).await { - Ok(frame) => { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = w2c_cid_frame_s.send((cid, Ok(frame))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - }, - Err(e_option) => { - if let Some(e) = e_option { - info!(?e, "Closing tcp protocol due to read error"); - //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as - // every channel is holding a sender! 
thats why Ne - // need a explicit STOP here - if let Err(e) = w2c_cid_frame_s.send((cid, Err(()))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - } - //None is clean shutdown - break; - }, - } - } - trace!("Shutting down tcp read()"); - } - - pub async fn write_frame( - w: &mut W, - frame: Frame, - ) -> Result<(), std::io::Error> { - match frame { - Frame::Handshake { - magic_number, - version, - } => { - w.write_all(&FRAME_HANDSHAKE.to_be_bytes()).await?; - w.write_all(&magic_number).await?; - w.write_all(&version[0].to_le_bytes()).await?; - w.write_all(&version[1].to_le_bytes()).await?; - w.write_all(&version[2].to_le_bytes()).await?; - }, - Frame::Init { pid, secret } => { - w.write_all(&FRAME_INIT.to_be_bytes()).await?; - w.write_all(&pid.to_le_bytes()).await?; - w.write_all(&secret.to_le_bytes()).await?; - }, - Frame::Shutdown => { - w.write_all(&FRAME_SHUTDOWN.to_be_bytes()).await?; - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - w.write_all(&FRAME_OPEN_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&prio.to_le_bytes()).await?; - w.write_all(&promises.to_le_bytes()).await?; - }, - Frame::CloseStream { sid } => { - w.write_all(&FRAME_CLOSE_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - }, - Frame::DataHeader { mid, sid, length } => { - w.write_all(&FRAME_DATA_HEADER.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&length.to_le_bytes()).await?; - }, - Frame::Data { mid, start, data } => { - w.write_all(&FRAME_DATA.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&start.to_le_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - Frame::Raw(data) => { - w.write_all(&FRAME_RAW.to_be_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - }; - 
Ok(()) - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up tcp write()"); - let mut write_stream = self.write_stream.lock().await; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - .with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = Self::write_frame(&mut *write_stream, frame).await { - info!( - ?e, - "Got an error writing to tcp, going to close this channel" - ); - c2w_frame_r.close(); - break; - }; - } - trace!("shutting down tcp write()"); - } -} - -impl UdpProtocol { - pub(crate) fn new( - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] metrics: Arc, - data_in: mpsc::UnboundedReceiver>, - ) -> Self { - Self { - socket, - remote_addr, - #[cfg(feature = "metrics")] - metrics, - data_in: Mutex::new(data_in), - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up udp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut data_in = self.data_in.lock().await; - let mut end_r = end_r.fuse(); - while let Some(bytes) = select! 
{ - r = data_in.recv().fuse() => match r { - Some(r) => Some(r), - None => { - info!("Udp read ended"); - w2c_cid_frame_s.send((cid, Err(()))).expect("Channel or Participant seems no longer to exist"); - None - } - }, - _ = &mut end_r => None, - } { - trace!("Got raw UDP message with len: {}", bytes.len()); - let frame_no = bytes[0]; - let frame = match frame_no { - FRAME_HANDSHAKE => { - Frame::gen_handshake(*<&[u8; 19]>::try_from(&bytes[1..20]).unwrap()) - }, - FRAME_INIT => Frame::gen_init(*<&[u8; 32]>::try_from(&bytes[1..33]).unwrap()), - FRAME_SHUTDOWN => Frame::Shutdown, - FRAME_OPEN_STREAM => { - Frame::gen_open_stream(*<&[u8; 10]>::try_from(&bytes[1..11]).unwrap()) - }, - FRAME_CLOSE_STREAM => { - Frame::gen_close_stream(*<&[u8; 8]>::try_from(&bytes[1..9]).unwrap()) - }, - FRAME_DATA_HEADER => { - Frame::gen_data_header(*<&[u8; 24]>::try_from(&bytes[1..25]).unwrap()) - }, - FRAME_DATA => { - let (mid, start, length) = - Frame::gen_data(*<&[u8; 18]>::try_from(&bytes[1..19]).unwrap()); - let mut data = vec![0; length as usize]; - #[cfg(feature = "metrics")] - throughput_cache.inc_by(length as u64); - data.copy_from_slice(&bytes[19..]); - Frame::Data { mid, start, data } - }, - FRAME_RAW => { - let length = Frame::gen_raw(*<&[u8; 2]>::try_from(&bytes[1..3]).unwrap()); - let mut data = vec![0; length as usize]; - data.copy_from_slice(&bytes[3..]); - Frame::Raw(data) - }, - _ => Frame::Raw(bytes), - }; - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s.send((cid, Ok(frame))).unwrap(); - } - trace!("Shutting down udp read()"); - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up udp write()"); - let mut buffer = [0u8; 2000]; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - 
.with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - let len = match frame { - Frame::Handshake { - magic_number, - version, - } => { - let x = FRAME_HANDSHAKE.to_be_bytes(); - buffer[0] = x[0]; - buffer[1..8].copy_from_slice(&magic_number); - buffer[8..12].copy_from_slice(&version[0].to_le_bytes()); - buffer[12..16].copy_from_slice(&version[1].to_le_bytes()); - buffer[16..20].copy_from_slice(&version[2].to_le_bytes()); - 20 - }, - Frame::Init { pid, secret } => { - buffer[0] = FRAME_INIT.to_be_bytes()[0]; - buffer[1..17].copy_from_slice(&pid.to_le_bytes()); - buffer[17..33].copy_from_slice(&secret.to_le_bytes()); - 33 - }, - Frame::Shutdown => { - buffer[0] = FRAME_SHUTDOWN.to_be_bytes()[0]; - 1 - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - buffer[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - buffer[9] = prio.to_le_bytes()[0]; - buffer[10] = promises.to_le_bytes()[0]; - 11 - }, - Frame::CloseStream { sid } => { - buffer[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - 9 - }, - Frame::DataHeader { mid, sid, length } => { - buffer[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&sid.to_le_bytes()); - buffer[17..25].copy_from_slice(&length.to_le_bytes()); - 25 - }, - Frame::Data { mid, start, data } => { - buffer[0] = FRAME_DATA.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&start.to_le_bytes()); - buffer[17..19].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[19..(data.len() + 19)].clone_from_slice(&data[..]); - #[cfg(feature = "metrics")] - throughput_cache.inc_by(data.len() as u64); - 19 + data.len() - }, - Frame::Raw(data) => { - 
buffer[0] = FRAME_RAW.to_be_bytes()[0]; - buffer[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[3..(data.len() + 3)].clone_from_slice(&data[..]); - 3 + data.len() - }, - }; - let mut start = 0; - while start < len { - trace!(?start, ?len, "Splitting up udp frame in multiple packages"); - match self - .socket - .send_to(&buffer[start..len], self.remote_addr) - .await - { - Ok(n) => { - start += n; - if n != len { - error!( - "THIS DOESN'T WORK, as RECEIVER CURRENTLY ONLY HANDLES 1 FRAME \ - per UDP message. splitting up will fail!" - ); - } - }, - Err(e) => error!(?e, "Need to handle that error!"), - } - } - } - trace!("Shutting down udp write()"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{metrics::NetworkMetrics, types::Pid}; - use std::sync::Arc; - use tokio::{net, runtime::Runtime, sync::mpsc}; - - #[test] - fn tcp_read_handshake() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50500); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client.write_all(&[FRAME_HANDSHAKE]).await.unwrap(); - client.write_all(b"HELLOWO").await.unwrap(); - client.write_all(&1337u32.to_le_bytes()).await.unwrap(); - client.write_all(&0u32.to_le_bytes()).await.unwrap(); - client.write_all(&42u32.to_le_bytes()).await.unwrap(); - client.flush().await.unwrap(); - - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut 
w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Handshake! - //tokio::task::sleep(std::time::Duration::from_millis(1000)); - let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Handshake { - magic_number, - version, - }) = frame - { - assert_eq!(&magic_number, b"HELLOWO"); - assert_eq!(version, [1337, 0, 42]); - } else { - panic!("wrong handshake"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } - - #[test] - fn tcp_read_garbage() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50501); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client - .write_all("x4hrtzsektfhxugzdtz5r78gzrtzfhxfdthfthuzhfzzufasgasdfg".as_bytes()) - .await - .unwrap(); - client.flush().await.unwrap(); - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Raw! 
- let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Raw(data)) = frame { - assert_eq!(&data.as_slice(), b"x4hrtzsektfhxugzdtz5r78gzrtzfhxf"); - } else { - panic!("wrong frame type"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } -} diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index f648d48a15..eb6d21bd7e 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -2,12 +2,11 @@ use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, - channel::Handshake, + channel::Protocols, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, - protocols::{Protocols, TcpProtocol, UdpProtocol}, - types::Pid, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::Pid; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -17,6 +16,7 @@ use std::{ atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, net, @@ -214,47 +214,40 @@ impl Scheduler { }, }; info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - ( - Protocols::Tcp(TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - )), - false, - ) - }, - ProtocolAddr::Udp(addr) => { - #[cfg(feature = "metrics")] - self.metrics - .connect_requests_total - .with_label_values(&["udp"]) - .inc(); - let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - Ok(socket) => Arc::new(socket), - Err(e) => { - pid_sender.send(Err(e)).unwrap(); - continue; - }, - }; - if let Err(e) = socket.connect(addr).await { - pid_sender.send(Err(e)).unwrap(); - continue; - }; - info!("Connecting Udp to: {}", addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.runtime.spawn( - 
Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - .instrument(tracing::info_span!("udp", ?addr)), - ); - (Protocols::Udp(protocol), true) - }, + (Protocols::new_tcp(stream), false) + }, /* */ + //ProtocolAddr::Udp(addr) => { + //#[cfg(feature = "metrics")] + //self.metrics + //.connect_requests_total + //.with_label_values(&["udp"]) + //.inc(); + //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { + //Ok(socket) => Arc::new(socket), + //Err(e) => { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}, + //}; + //if let Err(e) = socket.connect(addr).await { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}; + //info!("Connecting Udp to: {}", addr); + //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); + //let protocol = UdpProtocol::new( + //Arc::clone(&socket), + //addr, + //#[cfg(feature = "metrics")] + //Arc::clone(&self.metrics), + //udp_data_receiver, + //); + //self.runtime.spawn( + //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) + //.instrument(tracing::info_span!("udp", ?addr)), + //); + //(Protocols::Udp(protocol), true) + //}, _ => unimplemented!(), }; self.init_protocol(protocol, Some(pid_sender), handshake) @@ -265,7 +258,9 @@ impl Scheduler { async fn disconnect_mgr(&self, mut a2s_disconnect_r: mpsc::UnboundedReceiver) { trace!("Start disconnect_mgr"); - while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.recv().await { + while let Some((pid, (timeout_time, return_once_successful_shutdown))) = + a2s_disconnect_r.recv().await + { //Closing Participants is done the following way: // 1. We drop our senders and receivers // 2. 
we need to close BParticipant, this will drop its senderns and receivers @@ -279,7 +274,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((timeout_time, finished_sender)) .unwrap(); drop(pi); trace!(?pid, "dropped bparticipant, waiting for finish"); @@ -322,7 +317,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((Duration::from_secs(120), finished_sender)) .unwrap(); (pid, finished_receiver) }) @@ -392,15 +387,10 @@ impl Scheduler { }, }; info!("Accepting Tcp from: {}", remote_addr); - let protocol = TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - ); - self.init_protocol(Protocols::Tcp(protocol), None, true) + self.init_protocol(Protocols::new_tcp(stream), None, true) .await; } - }, + },/* ProtocolAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { Ok(socket) => { @@ -451,12 +441,13 @@ impl Scheduler { let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); udp_data_sender.send(datavec).unwrap(); } - }, + },*/ _ => unimplemented!(), } trace!(?addr, "Ending channel creator"); } + #[allow(dead_code)] async fn udp_single_channel_connect( socket: Arc, w2p_udp_package_s: mpsc::UnboundedSender>, @@ -483,7 +474,7 @@ impl Scheduler { async fn init_protocol( &self, - protocol: Protocols, + mut protocol: Protocols, s2a_return_pid_s: Option>>, send_handshake: bool, ) { @@ -509,20 +500,13 @@ impl Scheduler { self.runtime.spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); - let handshake = Handshake::new( - cid, - local_pid, - local_secret, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - send_handshake, - ); - match handshake - .setup(&protocol) + use network_protocol::InitProtocol; + let init_result = protocol + .initialize(send_handshake, local_pid, local_secret) .instrument(tracing::info_span!("handshake", ?cid)) - .await - { - Ok((pid, sid, secret, 
leftover_cid_frame)) => { + .await; + match init_result { + Ok((pid, sid, secret)) => { trace!( ?cid, ?pid, @@ -533,14 +517,13 @@ impl Scheduler { debug!(?cid, "New participant connected via a channel"); let ( bparticipant, - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = BParticipant::new( pid, sid, - Arc::clone(&runtime), #[cfg(feature = "metrics")] Arc::clone(&metrics), ); @@ -548,8 +531,7 @@ impl Scheduler { let participant = Participant::new( local_pid, pid, - Arc::clone(&runtime), - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, ); @@ -573,13 +555,7 @@ impl Scheduler { oneshot::channel(); //From now on wire connects directly with bparticipant! s2b_create_channel_s - .send(( - cid, - sid, - protocol, - leftover_cid_frame, - b2s_create_channel_done_s, - )) + .send((cid, sid, protocol, b2s_create_channel_done_s)) .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s { @@ -627,8 +603,8 @@ impl Scheduler { //From now on this CHANNEL can receiver other frames! // move directly to participant! 
}, - Err(()) => { - debug!(?cid, "Handshake from a new connection failed"); + Err(e) => { + debug!(?cid, ?e, "Handshake from a new connection failed"); if let Some(pid_oneshot) = s2a_return_pid_s { // someone is waiting with `connect`, so give them their Error trace!(?cid, "returning the Err to api who requested the connect"); diff --git a/server/Cargo.toml b/server/Cargo.toml index f997749565..8726a037d8 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,7 +16,7 @@ common = { package = "veloren-common", path = "../common" } common-sys = { package = "veloren-common-sys", path = "../common/sys" } common-net = { package = "veloren-common-net", path = "../common/net" } world = { package = "veloren-world", path = "../world" } -network = { package = "veloren_network", path = "../network", features = ["metrics", "compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["metrics", "compression"], default-features = false } specs = { git = "https://github.com/amethyst/specs.git", features = ["shred-derive"], rev = "d4435bdf496cf322c74886ca09dd8795984919b4" } specs-idvs = { git = "https://gitlab.com/veloren/specs-idvs.git", rev = "9fab7b396acd6454585486e50ae4bfe2069858a9" } diff --git a/voxygen/src/hud/chat.rs b/voxygen/src/hud/chat.rs index b2dc69c780..1ecc5fe621 100644 --- a/voxygen/src/hud/chat.rs +++ b/voxygen/src/hud/chat.rs @@ -373,7 +373,7 @@ impl<'a> Widget for Chat<'a> { let ChatMsg { chat_type, .. } = &message; // For each ChatType needing localization get/set matching pre-formatted // localized string. This string will be formatted with the data - // provided in ChatType in the client/src/lib.rs + // provided in ChatType in the client/src/mod.rs // fn format_message called below message.message = match chat_type { ChatType::Online(_) => self